diff --git a/src/application/api/mcp_tools.py b/src/application/api/mcp_tools.py index 62cc9c7..e3620d5 100644 --- a/src/application/api/mcp_tools.py +++ b/src/application/api/mcp_tools.py @@ -5,7 +5,8 @@ from fastmcp import FastMCP -from dependencies import get_query_use_case +from application.requests.query_request import MultimodalContentItem +from dependencies import get_multimodal_query_use_case, get_query_use_case mcp = FastMCP("RAGAnything") @@ -29,3 +30,40 @@ async def query_knowledge_base( return await use_case.execute( working_dir=working_dir, query=query, mode=mode, top_k=top_k ) + + +@mcp.tool() +async def query_knowledge_base_multimodal( + working_dir: str, + query: str, + multimodal_content: list[MultimodalContentItem], + mode: str = "hybrid", + top_k: int = 10, +) -> dict: + """Query the knowledge base with multimodal content (images, tables, equations). + + Use this tool when your query involves visual or structured data. + Each item in multimodal_content must have a "type" field ("image", "table", or "equation") + plus type-specific fields: + - image: img_path (file path) or image_data (base64) + - table: table_data (CSV format), optional table_caption + - equation: latex (LaTeX string), optional equation_caption + + Args: + working_dir: RAG workspace directory for this project + query: The user's question or search query + multimodal_content: List of multimodal content items + mode: Search mode - "hybrid" (recommended), "naive", "local", "global", "mix" + top_k: Number of chunks to retrieve (default 10) + + Returns: + Query response with multimodal analysis + """ + use_case = get_multimodal_query_use_case() + return await use_case.execute( + working_dir=working_dir, + query=query, + multimodal_content=multimodal_content, + mode=mode, + top_k=top_k, + ) diff --git a/src/application/api/query_routes.py b/src/application/api/query_routes.py index 93d1f90..c58f978 100644 --- a/src/application/api/query_routes.py +++ b/src/application/api/query_routes.py @@ -1,9 +1,10 @@ from fastapi import APIRouter, Depends, status -from application.requests.query_request import QueryRequest -from application.responses.query_response import QueryResponse +from application.requests.query_request import MultimodalQueryRequest, QueryRequest +from application.responses.query_response import MultimodalQueryResponse, QueryResponse +from application.use_cases.multimodal_query_use_case import MultimodalQueryUseCase from application.use_cases.query_use_case import QueryUseCase -from dependencies import get_query_use_case +from dependencies import get_multimodal_query_use_case, get_query_use_case query_router = APIRouter(tags=["RAG Query"]) @@ -22,3 +23,22 @@ async def query_knowledge_base( top_k=request.top_k, ) return QueryResponse(**result) + + +@query_router.post( + "/query/multimodal", + response_model=MultimodalQueryResponse, + status_code=status.HTTP_200_OK, +) +async def query_knowledge_base_multimodal( + request: MultimodalQueryRequest, + use_case: MultimodalQueryUseCase = Depends(get_multimodal_query_use_case), +) -> MultimodalQueryResponse: + result = await use_case.execute( + working_dir=request.working_dir, + query=request.query, + multimodal_content=request.multimodal_content, + mode=request.mode, + top_k=request.top_k, + ) + return MultimodalQueryResponse(**result) diff --git a/src/application/requests/query_request.py b/src/application/requests/query_request.py index 6e41f47..97d5a6b 100644 --- a/src/application/requests/query_request.py +++ b/src/application/requests/query_request.py @@ -2,6 +2,8 @@ from pydantic import BaseModel, Field +QueryMode = Literal["local", "global", "hybrid", "naive", "mix", "bypass"] + class QueryRequest(BaseModel): working_dir: str = Field( @@ -11,7 +13,7 @@ class QueryRequest(BaseModel): ..., description="The user's question or search query (e.g., 'What are the main findings?')", ) - mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field( + mode: QueryMode = Field( default="naive", description=( "Search mode - 'naive' (default, recommended), 'local' (context-aware), " @@ -25,3 +27,40 @@ class QueryRequest(BaseModel): "Use 10 for fast, focused results; use 20 for comprehensive search." ), ) + + +class MultimodalContentItem(BaseModel): + type: Literal["image", "table", "equation"] = Field( + ..., description="Type de contenu multimodal" + ) + img_path: str | None = Field( + default=None, description="Chemin vers un fichier image" + ) + image_data: str | None = Field( + default=None, description="Image encodée en base64 (alternative à img_path)" + ) + table_data: str | None = Field( + default=None, description="Données tabulaires au format CSV" + ) + table_caption: str | None = Field( + default=None, description="Légende décrivant la table" + ) + latex: str | None = Field(default=None, description="Equation au format LaTeX") + equation_caption: str | None = Field( + default=None, description="Légende décrivant l'équation" + ) + + +class MultimodalQueryRequest(BaseModel): + working_dir: str = Field( + ..., description="RAG workspace directory for this project" + ) + query: str = Field(..., description="The user's question or search query") + mode: QueryMode = Field( + default="hybrid", + description="Search mode - 'hybrid' recommended for multimodal queries", + ) + top_k: int = Field(default=10, description="Number of chunks to retrieve") + multimodal_content: list[MultimodalContentItem] = Field( + ..., description="Liste de contenus multimodaux à inclure dans la requête" + ) diff --git a/src/application/responses/query_response.py b/src/application/responses/query_response.py index 91d40e9..b250ab1 100644 --- a/src/application/responses/query_response.py +++ b/src/application/responses/query_response.py @@ -65,3 +65,11 @@ class QueryResponse(BaseModel): message: str = "" data: QueryDataResponse = Field(default_factory=QueryDataResponse) metadata: QueryMetadataResponse | None = None + + +class MultimodalQueryResponse(BaseModel): + status: str + message: str = "" + data: str = Field( + default="", description="Réponse textuelle de l'analyse multimodale" + ) diff --git a/src/application/use_cases/multimodal_query_use_case.py b/src/application/use_cases/multimodal_query_use_case.py new file mode 100644 index 0000000..471f243 --- /dev/null +++ b/src/application/use_cases/multimodal_query_use_case.py @@ -0,0 +1,27 @@ +from application.requests.query_request import MultimodalContentItem +from domain.ports.rag_engine import RAGEnginePort + + +class MultimodalQueryUseCase: + """Use case for querying the RAG knowledge base with multimodal content.""" + + def __init__(self, rag_engine: RAGEnginePort) -> None: + self.rag_engine = rag_engine + + async def execute( + self, + working_dir: str, + query: str, + multimodal_content: list[MultimodalContentItem], + mode: str = "hybrid", + top_k: int = 10, + ) -> dict: + self.rag_engine.init_project(working_dir) + result = await self.rag_engine.query_multimodal( + query=query, + multimodal_content=multimodal_content, + mode=mode, + top_k=top_k, + working_dir=working_dir, + ) + return {"status": "success", "data": result} diff --git a/src/dependencies.py b/src/dependencies.py index 611a6c7..baad639 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -4,6 +4,7 @@ from application.use_cases.index_file_use_case import IndexFileUseCase from application.use_cases.index_folder_use_case import IndexFolderUseCase +from application.use_cases.multimodal_query_use_case import MultimodalQueryUseCase from application.use_cases.query_use_case import QueryUseCase from config import AppConfig, LLMConfig, MinioConfig, RAGConfig from infrastructure.rag.lightrag_adapter import LightRAGAdapter @@ -45,3 +46,7 @@ def get_index_folder_use_case() -> IndexFolderUseCase: def get_query_use_case() -> QueryUseCase: return QueryUseCase(rag_adapter) + + +def get_multimodal_query_use_case() -> MultimodalQueryUseCase: + return MultimodalQueryUseCase(rag_adapter) diff --git a/src/domain/ports/rag_engine.py b/src/domain/ports/rag_engine.py index 9a53afc..a174bd3 100644 --- a/src/domain/ports/rag_engine.py +++ b/src/domain/ports/rag_engine.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from application.requests.query_request import MultimodalContentItem from domain.entities.indexing_result import FileIndexingResult, FolderIndexingResult @@ -29,5 +30,18 @@ async def index_folder( pass @abstractmethod - async def query(self, query: str, mode: str = "naive", top_k: int = 10, working_dir: str = "") -> dict: + async def query( + self, query: str, mode: str = "naive", top_k: int = 10, working_dir: str = "" + ) -> dict: + pass + + @abstractmethod + async def query_multimodal( + self, + query: str, + multimodal_content: list[MultimodalContentItem], + mode: str = "hybrid", + top_k: int = 10, + working_dir: str = "", + ) -> str: pass diff --git a/src/infrastructure/rag/lightrag_adapter.py b/src/infrastructure/rag/lightrag_adapter.py index 3fc97e8..78abbe8 100644 --- a/src/infrastructure/rag/lightrag_adapter.py +++ b/src/infrastructure/rag/lightrag_adapter.py @@ -9,6 +9,7 @@ from lightrag.utils import EmbeddingFunc from raganything import RAGAnything, RAGAnythingConfig +from application.requests.query_request import MultimodalContentItem from config import LLMConfig, RAGConfig from domain.entities.indexing_result import ( FileIndexingResult, @@ -42,7 +43,7 @@ class LightRAGAdapter(RAGEnginePort): def __init__(self, llm_config: LLMConfig, rag_config: RAGConfig) -> None: self._llm_config = llm_config self._rag_config = rag_config - self.rag: dict[str,RAGAnything] = {} + self.rag: dict[str, RAGAnything] = {} @staticmethod def _make_workspace(working_dir: str) -> str: @@ -93,6 +94,7 @@ def init_project(self, working_dir: str) -> RAGAnything: }, ) return self.rag[working_dir] + # ------------------------------------------------------------------ # LLM callables (passed directly to RAGAnything) # ------------------------------------------------------------------ @@ -217,7 +219,9 @@ async def index_folder( # Port implementation — query # ------------------------------------------------------------------ - async def query(self, query: str, mode: str = "naive", top_k: int = 10, working_dir: str = "") -> dict: + async def query( + self, query: str, mode: str = "naive", top_k: int = 10, working_dir: str = "" + ) -> dict: rag = self._ensure_initialized(working_dir) await rag._ensure_lightrag_initialized() if rag.lightrag is None: @@ -229,6 +233,26 @@ async def query(self, query: str, mode: str = "naive", top_k: int = 10, working_ param = QueryParam(mode=cast(QueryMode, mode), top_k=top_k, chunk_top_k=top_k) return await rag.lightrag.aquery_data(query=query, param=param) + async def query_multimodal( + self, + query: str, + multimodal_content: list[MultimodalContentItem], + mode: str = "hybrid", + top_k: int = 10, + working_dir: str = "", + ) -> str: + rag = self._ensure_initialized(working_dir) + await rag._ensure_lightrag_initialized() + raw_content = [ + item.model_dump(exclude_none=True) for item in multimodal_content + ] + return await rag.aquery_with_multimodal( + query=query, + multimodal_content=raw_content, + mode=mode, + top_k=top_k, + ) + # ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ diff --git a/tests/fixtures/external.py b/tests/fixtures/external.py index 7d191a9..9f696f2 100644 --- a/tests/fixtures/external.py +++ b/tests/fixtures/external.py @@ -39,6 +39,8 @@ def mock_rag_engine() -> AsyncMock: processing_time_ms=500.0, ) + mock.query_multimodal.return_value = "Multimodal analysis result" + return mock diff --git a/tests/unit/test_multimodal_query_use_case.py b/tests/unit/test_multimodal_query_use_case.py new file mode 100644 index 0000000..351f6c4 --- /dev/null +++ b/tests/unit/test_multimodal_query_use_case.py @@ -0,0 +1,91 @@ +from unittest.mock import AsyncMock + +from application.requests.query_request import MultimodalContentItem +from application.use_cases.multimodal_query_use_case import MultimodalQueryUseCase + + +class TestMultimodalQueryUseCase: + """Tests for MultimodalQueryUseCase — rag_engine is external, mocked.""" + + async def test_execute_calls_init_project( + self, + mock_rag_engine: AsyncMock, + ) -> None: + """Should call rag_engine.init_project with the working_dir.""" + use_case = MultimodalQueryUseCase(rag_engine=mock_rag_engine) + content = [MultimodalContentItem(type="image", img_path="/tmp/img.png")] + + await use_case.execute( + working_dir="/tmp/rag/project_42", + query="What does this image show?", + multimodal_content=content, + ) + + mock_rag_engine.init_project.assert_called_once_with("/tmp/rag/project_42") + + async def test_execute_calls_query_multimodal_with_correct_params( + self, + mock_rag_engine: AsyncMock, + ) -> None: + """Should call rag_engine.query_multimodal with all params.""" + use_case = MultimodalQueryUseCase(rag_engine=mock_rag_engine) + content = [ + MultimodalContentItem( + type="table", table_data="A,B\n1,2", table_caption="Test table" + ), + ] + + await use_case.execute( + working_dir="/tmp/rag/test", + query="Analyze this table", + multimodal_content=content, + mode="global", + top_k=20, + ) + + mock_rag_engine.query_multimodal.assert_called_once_with( + query="Analyze this table", + multimodal_content=content, + mode="global", + top_k=20, + working_dir="/tmp/rag/test", + ) + + async def test_execute_returns_success_with_result( + self, + mock_rag_engine: AsyncMock, + ) -> None: + """Should return dict with status='success' and data from rag_engine.""" + mock_rag_engine.query_multimodal.return_value = "Analysis of the image content" + use_case = MultimodalQueryUseCase(rag_engine=mock_rag_engine) + content = [MultimodalContentItem(type="image", img_path="/tmp/diagram.png")] + + result = await use_case.execute( + working_dir="/tmp/rag/test", + query="Describe this diagram", + multimodal_content=content, + ) + + assert result == {"status": "success", "data": "Analysis of the image content"} + + async def test_execute_uses_default_mode_and_top_k( + self, + mock_rag_engine: AsyncMock, + ) -> None: + """Should use mode='hybrid' and top_k=10 by default.""" + use_case = MultimodalQueryUseCase(rag_engine=mock_rag_engine) + content = [MultimodalContentItem(type="equation", latex="E=mc^2")] + + await use_case.execute( + working_dir="/tmp/rag/test", + query="Explain this formula", + multimodal_content=content, + ) + + mock_rag_engine.query_multimodal.assert_called_once_with( + query="Explain this formula", + multimodal_content=content, + mode="hybrid", + top_k=10, + working_dir="/tmp/rag/test", + ) diff --git a/tests/unit/test_routes.py b/tests/unit/test_routes.py index edf1dff..89dbfda 100644 --- a/tests/unit/test_routes.py +++ b/tests/unit/test_routes.py @@ -4,10 +4,17 @@ import pytest from httpx import ASGITransport +from application.requests.query_request import MultimodalContentItem from application.use_cases.index_file_use_case import IndexFileUseCase from application.use_cases.index_folder_use_case import IndexFolderUseCase +from application.use_cases.multimodal_query_use_case import MultimodalQueryUseCase from application.use_cases.query_use_case import QueryUseCase -from dependencies import get_index_file_use_case, get_index_folder_use_case, get_query_use_case +from dependencies import ( + get_index_file_use_case, + get_index_folder_use_case, + get_multimodal_query_use_case, + get_query_use_case, +) from main import app @@ -313,3 +320,145 @@ async def test_query_rejects_invalid_mode(self) -> None: ) assert response.status_code == 422 + + +class TestMultimodalQueryRoute: + @pytest.fixture + def mock_multimodal_query_use_case(self) -> AsyncMock: + mock = AsyncMock(spec=MultimodalQueryUseCase) + mock.execute.return_value = { + "status": "success", + "data": "Multimodal analysis result", + } + return mock + + async def test_multimodal_query_returns_200( + self, + mock_multimodal_query_use_case: AsyncMock, + ) -> None: + app.dependency_overrides[get_multimodal_query_use_case] = ( + lambda: mock_multimodal_query_use_case + ) + + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/query/multimodal", + json={ + "working_dir": "/tmp/rag/test", + "query": "What does this image show?", + "multimodal_content": [ + {"type": "image", "img_path": "/tmp/img.png"}, + ], + }, + ) + + assert response.status_code == 200 + + async def test_multimodal_query_calls_use_case_with_correct_params( + self, + mock_multimodal_query_use_case: AsyncMock, + ) -> None: + app.dependency_overrides[get_multimodal_query_use_case] = ( + lambda: mock_multimodal_query_use_case + ) + + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + await client.post( + "/api/v1/query/multimodal", + json={ + "working_dir": "/tmp/rag/project_42", + "query": "Analyze this table", + "multimodal_content": [ + {"type": "table", "table_data": "A,B\n1,2", "table_caption": "Test"}, + ], + "mode": "global", + "top_k": 20, + }, + ) + + mock_multimodal_query_use_case.execute.assert_called_once_with( + working_dir="/tmp/rag/project_42", + query="Analyze this table", + multimodal_content=[ + MultimodalContentItem(type="table", table_data="A,B\n1,2", table_caption="Test"), + ], + mode="global", + top_k=20, + ) + + async def test_multimodal_query_returns_response_body( + self, + mock_multimodal_query_use_case: AsyncMock, + ) -> None: + app.dependency_overrides[get_multimodal_query_use_case] = ( + lambda: mock_multimodal_query_use_case + ) + + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/query/multimodal", + json={ + "working_dir": "/tmp/rag/test", + "query": "Describe this", + "multimodal_content": [ + {"type": "image", "img_path": "/tmp/img.png"}, + ], + }, + ) + + body = response.json() + assert body["status"] == "success" + assert body["data"] == "Multimodal analysis result" + + async def test_multimodal_query_rejects_missing_query(self) -> None: + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/query/multimodal", + json={ + "working_dir": "/tmp/rag/test", + "multimodal_content": [ + {"type": "image", "img_path": "/tmp/img.png"}, + ], + }, + ) + + assert response.status_code == 422 + + async def test_multimodal_query_rejects_missing_multimodal_content(self) -> None: + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/query/multimodal", + json={ + "working_dir": "/tmp/rag/test", + "query": "test", + }, + ) + + assert response.status_code == 422 + + async def test_multimodal_query_rejects_invalid_content_type(self) -> None: + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/query/multimodal", + json={ + "working_dir": "/tmp/rag/test", + "query": "test", + "multimodal_content": [ + {"type": "video", "path": "/tmp/video.mp4"}, + ], + }, + ) + + assert response.status_code == 422