diff --git a/src/infrastructure/rag/lightrag_adapter.py b/src/infrastructure/rag/lightrag_adapter.py index 6184af3..b0c004a 100644 --- a/src/infrastructure/rag/lightrag_adapter.py +++ b/src/infrastructure/rag/lightrag_adapter.py @@ -1,7 +1,9 @@ +import asyncio import hashlib import os import tempfile import time +from pathlib import Path from typing import Literal, cast from fastapi.logger import logger @@ -139,49 +141,6 @@ async def vision_call( ) return self.rag[working_dir] - # ------------------------------------------------------------------ - # LLM callables (passed directly to RAGAnything) - # ------------------------------------------------------------------ - - async def _llm_call( - self, prompt, system_prompt=None, history_messages=None, **kwargs - ): - if history_messages is None: - history_messages = [] - return await openai_complete_if_cache( - self._llm_config.CHAT_MODEL, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=self._llm_config.api_key, - base_url=self._llm_config.api_base_url, - **kwargs, - ) - - async def _vision_call( - self, - prompt, - system_prompt=None, - history_messages=None, - image_data=None, - **kwargs, - ): - if history_messages is None: - history_messages = [] - messages = _build_vision_messages( - system_prompt, history_messages, prompt, image_data - ) - return await openai_complete_if_cache( - self._llm_config.VISION_MODEL, - "Image Description Task", - system_prompt=None, - history_messages=messages, - api_key=self._llm_config.api_key, - base_url=self._llm_config.api_base_url, - messages=messages, - **kwargs, - ) - # ------------------------------------------------------------------ # Port implementation — indexing # ------------------------------------------------------------------ @@ -232,19 +191,18 @@ async def index_folder( file_extensions: list[str] | None = None, working_dir: str = "", ) -> FolderIndexingResult: - """Index a folder by processing each document sequentially. + """Index a folder by processing documents concurrently. - RAGAnything's process_folder_complete uses deepcopy internally which - fails with asyncpg/asyncio objects. We iterate files manually and - call process_document_complete for each one instead. + Uses ``asyncio.Semaphore`` bounded by ``MAX_CONCURRENT_FILES`` so + that at most *N* files are processed at the same time. When + ``MAX_CONCURRENT_FILES <= 1`` behaviour is identical to the old + sequential loop. """ start_time = time.time() rag = self._ensure_initialized(working_dir) await rag._ensure_lightrag_initialized() glob_pattern = "**/*" if recursive else "*" - from pathlib import Path - folder = Path(folder_path) all_files = [f for f in folder.glob(glob_pattern) if f.is_file()] @@ -252,54 +210,61 @@ async def index_folder( exts = set(file_extensions) all_files = [f for f in all_files if f.suffix in exts] + max_workers = max(1, self._rag_config.MAX_CONCURRENT_FILES) + semaphore = asyncio.Semaphore(max_workers) + succeeded = 0 failed = 0 file_results: list[FileProcessingDetail] = [] - for file_path_obj in all_files: - try: - await rag.process_document_complete( - file_path=str(file_path_obj), - output_dir=output_dir, - parse_method="txt", - ) - succeeded += 1 - file_results.append( - FileProcessingDetail( + async def _process_file(file_path_obj: Path) -> None: + nonlocal succeeded, failed + async with semaphore: + try: + await rag.process_document_complete( file_path=str(file_path_obj), - file_name=file_path_obj.name, - status=IndexingStatus.SUCCESS, + output_dir=output_dir, + parse_method="txt", ) - ) - logger.info( - f"Indexed {file_path_obj.name} ({succeeded}/{len(all_files)})" - ) - except Exception as e: - failed += 1 - logger.error(f"Failed to index {file_path_obj.name}: {e}") - file_results.append( - FileProcessingDetail( - file_path=str(file_path_obj), - file_name=file_path_obj.name, - status=IndexingStatus.FAILED, - error=str(e), + succeeded += 1 + file_results.append( + FileProcessingDetail( + file_path=str(file_path_obj), + file_name=file_path_obj.name, + status=IndexingStatus.SUCCESS, + ) ) - ) + logger.info( + f"Indexed {file_path_obj.name} ({succeeded}/{len(all_files)})" + ) + except Exception as e: + failed += 1 + logger.error(f"Failed to index {file_path_obj.name}: {e}") + file_results.append( + FileProcessingDetail( + file_path=str(file_path_obj), + file_name=file_path_obj.name, + status=IndexingStatus.FAILED, + error=str(e), + ) + ) + + await asyncio.gather(*[_process_file(f) for f in all_files]) processing_time_ms = (time.time() - start_time) * 1000 total = len(all_files) if total == 0: status = IndexingStatus.SUCCESS message = f"No files found in '{folder_path}'" - elif failed == 0 and succeeded > 0: + elif failed == 0: status = IndexingStatus.SUCCESS message = f"Successfully indexed {succeeded} file(s) from '{folder_path}'" - elif succeeded > 0 and failed > 0: - status = IndexingStatus.PARTIAL - message = f"Partially indexed: {succeeded} succeeded, {failed} failed" - else: + elif succeeded == 0: status = IndexingStatus.FAILED message = f"Failed to index folder '{folder_path}'" + else: + status = IndexingStatus.PARTIAL + message = f"Partially indexed: {succeeded} succeeded, {failed} failed" return FolderIndexingResult( status=status, @@ -358,47 +323,6 @@ async def query_multimodal( top_k=top_k, ) - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - - @staticmethod - def _build_folder_result( - result, folder_path: str, recursive: bool, processing_time_ms: float - ) -> FolderIndexingResult: - result_dict = result if isinstance(result, dict) else {} - stats = FolderIndexingStats( - total_files=result_dict.get("total_files", 0), - files_processed=result_dict.get("successful_files", 0), - files_failed=result_dict.get("failed_files", 0), - files_skipped=result_dict.get("skipped_files", 0), - ) - - file_results = _parse_file_details(result_dict) - - if stats.files_failed == 0 and stats.files_processed > 0: - status = IndexingStatus.SUCCESS - message = f"Successfully indexed {stats.files_processed} file(s) from '{folder_path}'" - elif stats.files_processed > 0 and stats.files_failed > 0: - status = IndexingStatus.PARTIAL - message = f"Partially indexed folder '{folder_path}': {stats.files_processed} succeeded, {stats.files_failed} failed" - elif stats.files_processed == 0 and stats.total_files > 0: - status = IndexingStatus.FAILED - message = f"Failed to index any files from '{folder_path}'" - else: - status = IndexingStatus.SUCCESS - message = f"No files found to index in '{folder_path}'" - - return FolderIndexingResult( - status=status, - message=message, - folder_path=folder_path, - recursive=recursive, - stats=stats, - processing_time_ms=round(processing_time_ms, 2), - file_results=file_results, - ) - # ------------------------------------------------------------------ # Module-level helpers @@ -430,22 +354,3 @@ def _build_vision_messages( messages.append({"role": "user", "content": content}) return messages - - -def _parse_file_details(result_dict: dict) -> list[FileProcessingDetail] | None: - if "file_details" not in result_dict: - return None - file_details = result_dict["file_details"] - if not isinstance(file_details, list): - return None - return [ - FileProcessingDetail( - file_path=d.get("file_path", ""), - file_name=os.path.basename(d.get("file_path", "")), - status=IndexingStatus.SUCCESS - if d.get("success", False) - else IndexingStatus.FAILED, - error=d.get("error"), - ) - for d in file_details - ] diff --git a/tests/unit/test_lightrag_adapter.py b/tests/unit/test_lightrag_adapter.py index 21e2341..a2bea84 100644 --- a/tests/unit/test_lightrag_adapter.py +++ b/tests/unit/test_lightrag_adapter.py @@ -1,3 +1,4 @@ +import asyncio import os import tempfile from unittest.mock import AsyncMock, MagicMock, patch @@ -559,3 +560,240 @@ async def test_index_txt_with_various_encodings( # All three should be processed assert mock_rag.process_document_complete.call_count == 3 + + # ------------------------------------------------------------------ + # Concurrent index_folder tests + # ------------------------------------------------------------------ + + +class TestLightRAGAdapterConcurrentIndexFolder: + """Tests for concurrent file processing in index_folder. + + Verifies that the ``asyncio.Semaphore`` + ``asyncio.gather`` implementation + respects ``MAX_CONCURRENT_FILES`` and handles edge cases correctly. + """ + + @staticmethod + def _make_adapter( + llm_config: LLMConfig, max_concurrent_files: int + ) -> LightRAGAdapter: + """Create an adapter with a custom MAX_CONCURRENT_FILES.""" + rag_config = RAGConfig( + RAG_STORAGE_TYPE="postgres", + MAX_CONCURRENT_FILES=max_concurrent_files, + ) + return LightRAGAdapter(llm_config, rag_config) + + @staticmethod + def _make_mock_rag() -> MagicMock: + """Create a mock RAG with standard async stubs.""" + mock_rag = MagicMock() + mock_rag._ensure_lightrag_initialized = AsyncMock() + return mock_rag + + async def test_index_folder_concurrent_respects_max_concurrency( + self, + llm_config: LLMConfig, + tmp_path, + ) -> None: + """With MAX_CONCURRENT_FILES=2 and 5 files, at most 2 calls in-flight.""" + adapter = self._make_adapter(llm_config, max_concurrent_files=2) + mock_rag = self._make_mock_rag() + adapter.rag["test_dir"] = mock_rag + + # Create 5 test files + for i in range(5): + (tmp_path / f"doc_{i}.pdf").write_text(f"content_{i}") + + max_concurrent = 0 + current_concurrent = 0 + lock = asyncio.Lock() + + async def slow_process(**_kwargs): + nonlocal max_concurrent, current_concurrent + async with lock: + current_concurrent += 1 + if current_concurrent > max_concurrent: + max_concurrent = current_concurrent + # Simulate I/O delay so tasks overlap + await asyncio.sleep(0.05) + async with lock: + current_concurrent -= 1 + + mock_rag.process_document_complete = AsyncMock(side_effect=slow_process) + + result = await adapter.index_folder( + folder_path=str(tmp_path), + output_dir="/tmp/output", + working_dir="test_dir", + ) + + assert result.status == IndexingStatus.SUCCESS + assert result.stats.total_files == 5 + assert result.stats.files_processed == 5 + assert result.stats.files_failed == 0 + assert max_concurrent <= 2, ( + f"Expected max 2 concurrent calls, got {max_concurrent}" + ) + assert max_concurrent >= 2, ( + f"Expected at least 2 concurrent calls (concurrent execution), got {max_concurrent}" + ) + + async def test_index_folder_concurrent_all_succeed( + self, + llm_config: LLMConfig, + tmp_path, + ) -> None: + """With MAX_CONCURRENT_FILES=4 and 8 files, all succeed.""" + adapter = self._make_adapter(llm_config, max_concurrent_files=4) + mock_rag = self._make_mock_rag() + mock_rag.process_document_complete = AsyncMock() + adapter.rag["test_dir"] = mock_rag + + for i in range(8): + (tmp_path / f"file_{i}.pdf").write_text(f"data_{i}") + + result = await adapter.index_folder( + folder_path=str(tmp_path), + output_dir="/tmp/output", + working_dir="test_dir", + ) + + assert result.status == IndexingStatus.SUCCESS + assert result.stats.total_files == 8 + assert result.stats.files_processed == 8 + assert result.stats.files_failed == 0 + assert mock_rag.process_document_complete.call_count == 8 + + async def test_index_folder_concurrent_max_zero_treated_as_one( + self, + llm_config: LLMConfig, + tmp_path, + ) -> None: + """MAX_CONCURRENT_FILES=0 should be treated as 1 (deadlock prevention).""" + adapter = self._make_adapter(llm_config, max_concurrent_files=0) + mock_rag = self._make_mock_rag() + adapter.rag["test_dir"] = mock_rag + + for i in range(3): + (tmp_path / f"doc_{i}.pdf").write_text(f"content_{i}") + + max_concurrent = 0 + current_concurrent = 0 + lock = asyncio.Lock() + + async def tracked_process(**_kwargs): + nonlocal max_concurrent, current_concurrent + async with lock: + current_concurrent += 1 + if current_concurrent > max_concurrent: + max_concurrent = current_concurrent + await asyncio.sleep(0.05) + async with lock: + current_concurrent -= 1 + + mock_rag.process_document_complete = AsyncMock(side_effect=tracked_process) + + result = await adapter.index_folder( + folder_path=str(tmp_path), + output_dir="/tmp/output", + working_dir="test_dir", + ) + + assert result.status == IndexingStatus.SUCCESS + assert result.stats.files_processed == 3 + assert result.stats.files_failed == 0 + # With concurrency clamped to 1, never more than 1 in-flight + assert max_concurrent <= 1, ( + f"Expected max 1 concurrent call with MAX_CONCURRENT_FILES=0, got {max_concurrent}" + ) + + async def test_index_folder_concurrent_greater_than_file_count( + self, + llm_config: LLMConfig, + tmp_path, + ) -> None: + """With MAX_CONCURRENT_FILES=10 and only 3 files, all start immediately.""" + adapter = self._make_adapter(llm_config, max_concurrent_files=10) + mock_rag = self._make_mock_rag() + mock_rag.process_document_complete = AsyncMock() + adapter.rag["test_dir"] = mock_rag + + for i in range(3): + (tmp_path / f"small_{i}.pdf").write_text(f"data_{i}") + + result = await adapter.index_folder( + folder_path=str(tmp_path), + output_dir="/tmp/output", + working_dir="test_dir", + ) + + assert result.status == IndexingStatus.SUCCESS + assert result.stats.total_files == 3 + assert result.stats.files_processed == 3 + assert result.stats.files_failed == 0 + assert mock_rag.process_document_complete.call_count == 3 + + async def test_index_folder_concurrent_single_file( + self, + llm_config: LLMConfig, + tmp_path, + ) -> None: + """Single file with any concurrency setting produces identical result.""" + adapter = self._make_adapter(llm_config, max_concurrent_files=5) + mock_rag = self._make_mock_rag() + mock_rag.process_document_complete = AsyncMock() + adapter.rag["test_dir"] = mock_rag + + (tmp_path / "only.pdf").write_text("solo content") + + result = await adapter.index_folder( + folder_path=str(tmp_path), + output_dir="/tmp/output", + working_dir="test_dir", + ) + + assert result.status == IndexingStatus.SUCCESS + assert result.stats.total_files == 1 + assert result.stats.files_processed == 1 + assert result.stats.files_failed == 0 + mock_rag.process_document_complete.assert_awaited_once_with( + file_path=str(tmp_path / "only.pdf"), + output_dir="/tmp/output", + parse_method="txt", + ) + + async def test_index_folder_concurrent_mixed_success_failure( + self, + llm_config: LLMConfig, + tmp_path, + ) -> None: + """Some files succeed, some fail under concurrency → PARTIAL status.""" + adapter = self._make_adapter(llm_config, max_concurrent_files=3) + mock_rag = self._make_mock_rag() + adapter.rag["test_dir"] = mock_rag + + for i in range(4): + (tmp_path / f"doc_{i}.pdf").write_text(f"content_{i}") + + call_count = 0 + + async def flaky_process(**_kwargs): + nonlocal call_count + call_count += 1 + if call_count % 2 == 0: + raise RuntimeError("Simulated failure") + + mock_rag.process_document_complete = AsyncMock(side_effect=flaky_process) + + result = await adapter.index_folder( + folder_path=str(tmp_path), + output_dir="/tmp/output", + working_dir="test_dir", + ) + + assert result.status == IndexingStatus.PARTIAL + assert result.stats.files_processed == 2 + assert result.stats.files_failed == 2 + assert result.file_results is not None + assert len(result.file_results) == 4