Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion src/application/api/mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from fastmcp import FastMCP

from dependencies import get_query_use_case
from application.requests.query_request import MultimodalContentItem
from dependencies import get_multimodal_query_use_case, get_query_use_case

mcp = FastMCP("RAGAnything")

Expand All @@ -29,3 +30,40 @@ async def query_knowledge_base(
return await use_case.execute(
working_dir=working_dir, query=query, mode=mode, top_k=top_k
)


@mcp.tool()
async def query_knowledge_base_multimodal(
    working_dir: str,
    query: str,
    multimodal_content: list[MultimodalContentItem],
    mode: str = "hybrid",
    top_k: int = 10,
) -> dict:
    """Query the knowledge base together with multimodal content (images, tables, equations).

    Choose this tool whenever the question involves visual or structured data.
    Every entry of multimodal_content carries a "type" field ("image", "table",
    or "equation") plus the fields for that type:
    - image: img_path (file path) or image_data (base64)
    - table: table_data (CSV format), optional table_caption
    - equation: latex (LaTeX string), optional equation_caption

    Args:
        working_dir: RAG workspace directory for this project
        query: The user's question or search query
        multimodal_content: List of multimodal content items
        mode: Search mode - "hybrid" (recommended), "naive", "local", "global", "mix"
        top_k: Number of chunks to retrieve (default 10)

    Returns:
        Query response with multimodal analysis
    """
    # Resolve the use case per call so the shared adapter wiring stays in one place.
    return await get_multimodal_query_use_case().execute(
        working_dir=working_dir,
        query=query,
        multimodal_content=multimodal_content,
        mode=mode,
        top_k=top_k,
    )
26 changes: 23 additions & 3 deletions src/application/api/query_routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from fastapi import APIRouter, Depends, status

from application.requests.query_request import QueryRequest
from application.responses.query_response import QueryResponse
from application.requests.query_request import MultimodalQueryRequest, QueryRequest
from application.responses.query_response import MultimodalQueryResponse, QueryResponse
from application.use_cases.multimodal_query_use_case import MultimodalQueryUseCase
from application.use_cases.query_use_case import QueryUseCase
from dependencies import get_query_use_case
from dependencies import get_multimodal_query_use_case, get_query_use_case

query_router = APIRouter(tags=["RAG Query"])

Expand All @@ -22,3 +23,22 @@ async def query_knowledge_base(
top_k=request.top_k,
)
return QueryResponse(**result)


@query_router.post(
    "/query/multimodal",
    response_model=MultimodalQueryResponse,
    status_code=status.HTTP_200_OK,
)
async def query_knowledge_base_multimodal(
    request: MultimodalQueryRequest,
    use_case: MultimodalQueryUseCase = Depends(get_multimodal_query_use_case),
) -> MultimodalQueryResponse:
    """Answer a query enriched with multimodal content (images, tables, equations).

    Delegates to the injected multimodal query use case and reshapes its
    dict result into the HTTP response model.
    """
    payload = await use_case.execute(
        working_dir=request.working_dir,
        query=request.query,
        multimodal_content=request.multimodal_content,
        mode=request.mode,
        top_k=request.top_k,
    )
    return MultimodalQueryResponse(**payload)
41 changes: 40 additions & 1 deletion src/application/requests/query_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pydantic import BaseModel, Field

QueryMode = Literal["local", "global", "hybrid", "naive", "mix", "bypass"]


class QueryRequest(BaseModel):
working_dir: str = Field(
Expand All @@ -11,7 +13,7 @@ class QueryRequest(BaseModel):
...,
description="The user's question or search query (e.g., 'What are the main findings?')",
)
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
mode: QueryMode = Field(
default="naive",
description=(
"Search mode - 'naive' (default, recommended), 'local' (context-aware), "
Expand All @@ -25,3 +27,40 @@ class QueryRequest(BaseModel):
"Use 10 for fast, focused results; use 20 for comprehensive search."
),
)


class MultimodalContentItem(BaseModel):
    """One multimodal item attached to a query: an image, a table, or an equation.

    Only the fields matching ``type`` are expected to be set; the rest stay
    ``None`` (the adapter strips ``None`` fields via ``model_dump(exclude_none=True)``).
    NOTE(review): there is no cross-field validation here — an item with
    type="image" and only table_data would pass validation; confirm whether
    a model validator should enforce the pairing.
    """

    # Discriminator selecting which of the optional fields below apply.
    type: Literal["image", "table", "equation"] = Field(
        ..., description="Type de contenu multimodal"
    )
    # image: provide either a file path or base64-encoded data.
    img_path: str | None = Field(
        default=None, description="Chemin vers un fichier image"
    )
    image_data: str | None = Field(
        default=None, description="Image encodée en base64 (alternative à img_path)"
    )
    # table: CSV payload plus an optional caption.
    table_data: str | None = Field(
        default=None, description="Données tabulaires au format CSV"
    )
    table_caption: str | None = Field(
        default=None, description="Légende décrivant la table"
    )
    # equation: LaTeX source plus an optional caption.
    latex: str | None = Field(default=None, description="Equation au format LaTeX")
    equation_caption: str | None = Field(
        default=None, description="Légende décrivant l'équation"
    )


class MultimodalQueryRequest(BaseModel):
    """Request body for POST /query/multimodal.

    Mirrors QueryRequest but requires a non-empty list of multimodal items
    and defaults ``mode`` to "hybrid" instead of "naive".
    """

    working_dir: str = Field(
        ..., description="RAG workspace directory for this project"
    )
    query: str = Field(..., description="The user's question or search query")
    # Same Literal set as QueryRequest ("bypass" included via the shared alias).
    mode: QueryMode = Field(
        default="hybrid",
        description="Search mode - 'hybrid' recommended for multimodal queries",
    )
    top_k: int = Field(default=10, description="Number of chunks to retrieve")
    # Required: at least the field must be present; list content is validated
    # item-by-item by MultimodalContentItem.
    multimodal_content: list[MultimodalContentItem] = Field(
        ..., description="Liste de contenus multimodaux à inclure dans la requête"
    )
8 changes: 8 additions & 0 deletions src/application/responses/query_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,11 @@ class QueryResponse(BaseModel):
message: str = ""
data: QueryDataResponse = Field(default_factory=QueryDataResponse)
metadata: QueryMetadataResponse | None = None


class MultimodalQueryResponse(BaseModel):
    """Response body for POST /query/multimodal.

    Unlike QueryResponse, ``data`` is a plain string: the engine returns a
    single textual analysis rather than structured chunks.
    """

    # "success" on the happy path (set by the use case).
    status: str
    message: str = ""
    data: str = Field(
        default="", description="Réponse textuelle de l'analyse multimodale"
    )
27 changes: 27 additions & 0 deletions src/application/use_cases/multimodal_query_use_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from application.requests.query_request import MultimodalContentItem
from domain.ports.rag_engine import RAGEnginePort


class MultimodalQueryUseCase:
    """Use case for querying the RAG knowledge base with multimodal content."""

    def __init__(self, rag_engine: RAGEnginePort) -> None:
        # The engine port is the only collaborator; kept for execute().
        self.rag_engine = rag_engine

    async def execute(
        self,
        working_dir: str,
        query: str,
        multimodal_content: list[MultimodalContentItem],
        mode: str = "hybrid",
        top_k: int = 10,
    ) -> dict:
        """Run a multimodal query and wrap the engine's answer in a status envelope."""
        engine = self.rag_engine
        # Make sure the per-project RAG instance exists before querying.
        engine.init_project(working_dir)
        answer = await engine.query_multimodal(
            query=query,
            multimodal_content=multimodal_content,
            mode=mode,
            top_k=top_k,
            working_dir=working_dir,
        )
        return {"status": "success", "data": answer}
5 changes: 5 additions & 0 deletions src/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from application.use_cases.index_file_use_case import IndexFileUseCase
from application.use_cases.index_folder_use_case import IndexFolderUseCase
from application.use_cases.multimodal_query_use_case import MultimodalQueryUseCase
from application.use_cases.query_use_case import QueryUseCase
from config import AppConfig, LLMConfig, MinioConfig, RAGConfig
from infrastructure.rag.lightrag_adapter import LightRAGAdapter
Expand Down Expand Up @@ -45,3 +46,7 @@ def get_index_folder_use_case() -> IndexFolderUseCase:

def get_query_use_case() -> QueryUseCase:
    """Provide the text-only query use case bound to the module-level RAG adapter."""
    return QueryUseCase(rag_adapter)


def get_multimodal_query_use_case() -> MultimodalQueryUseCase:
    """Provide the multimodal query use case bound to the module-level RAG adapter."""
    use_case = MultimodalQueryUseCase(rag_adapter)
    return use_case
16 changes: 15 additions & 1 deletion src/domain/ports/rag_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod

from application.requests.query_request import MultimodalContentItem
from domain.entities.indexing_result import FileIndexingResult, FolderIndexingResult


Expand Down Expand Up @@ -29,5 +30,18 @@ async def index_folder(
pass

@abstractmethod
async def query(self, query: str, mode: str = "naive", top_k: int = 10, working_dir: str = "") -> dict:
async def query(
self, query: str, mode: str = "naive", top_k: int = 10, working_dir: str = ""
) -> dict:
pass

    @abstractmethod
    async def query_multimodal(
        self,
        query: str,
        multimodal_content: list[MultimodalContentItem],
        mode: str = "hybrid",
        top_k: int = 10,
        working_dir: str = "",
    ) -> str:
        """Query the knowledge base with multimodal content; return the textual answer.

        Args:
            query: The user's question or search query.
            multimodal_content: Structured items (image/table/equation) to include.
            mode: Retrieval mode; "hybrid" is the default for multimodal queries.
            top_k: Number of chunks to retrieve.
            working_dir: Per-project RAG workspace directory.

        Returns:
            The engine's textual multimodal analysis (a plain string, unlike
            ``query`` which returns a dict).
        """
        pass
28 changes: 26 additions & 2 deletions src/infrastructure/rag/lightrag_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from lightrag.utils import EmbeddingFunc
from raganything import RAGAnything, RAGAnythingConfig

from application.requests.query_request import MultimodalContentItem
from config import LLMConfig, RAGConfig
from domain.entities.indexing_result import (
FileIndexingResult,
Expand Down Expand Up @@ -42,7 +43,7 @@ class LightRAGAdapter(RAGEnginePort):
def __init__(self, llm_config: LLMConfig, rag_config: RAGConfig) -> None:
self._llm_config = llm_config
self._rag_config = rag_config
self.rag: dict[str,RAGAnything] = {}
self.rag: dict[str, RAGAnything] = {}

@staticmethod
def _make_workspace(working_dir: str) -> str:
Expand Down Expand Up @@ -93,6 +94,7 @@ def init_project(self, working_dir: str) -> RAGAnything:
},
)
return self.rag[working_dir]

# ------------------------------------------------------------------
# LLM callables (passed directly to RAGAnything)
# ------------------------------------------------------------------
Expand Down Expand Up @@ -217,7 +219,9 @@ async def index_folder(
# Port implementation — query
# ------------------------------------------------------------------

async def query(self, query: str, mode: str = "naive", top_k: int = 10, working_dir: str = "") -> dict:
async def query(
self, query: str, mode: str = "naive", top_k: int = 10, working_dir: str = ""
) -> dict:
rag = self._ensure_initialized(working_dir)
await rag._ensure_lightrag_initialized()
if rag.lightrag is None:
Expand All @@ -229,6 +233,26 @@ async def query(self, query: str, mode: str = "naive", top_k: int = 10, working_
param = QueryParam(mode=cast(QueryMode, mode), top_k=top_k, chunk_top_k=top_k)
return await rag.lightrag.aquery_data(query=query, param=param)

async def query_multimodal(
self,
query: str,
multimodal_content: list[MultimodalContentItem],
mode: str = "hybrid",
top_k: int = 10,
working_dir: str = "",
) -> str:
rag = self._ensure_initialized(working_dir)
await rag._ensure_lightrag_initialized()
raw_content = [
item.model_dump(exclude_none=True) for item in multimodal_content
]
return await rag.aquery_with_multimodal(
query=query,
multimodal_content=raw_content,
mode=mode,
top_k=top_k,
)

# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def mock_rag_engine() -> AsyncMock:
processing_time_ms=500.0,
)

mock.query_multimodal.return_value = "Multimodal analysis result"

return mock


Expand Down
91 changes: 91 additions & 0 deletions tests/unit/test_multimodal_query_use_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from unittest.mock import AsyncMock

from application.requests.query_request import MultimodalContentItem
from application.use_cases.multimodal_query_use_case import MultimodalQueryUseCase


class TestMultimodalQueryUseCase:
    """Tests for MultimodalQueryUseCase — rag_engine is external, mocked."""

    async def test_execute_calls_init_project(
        self,
        mock_rag_engine: AsyncMock,
    ) -> None:
        """init_project must receive the exact working_dir."""
        sut = MultimodalQueryUseCase(rag_engine=mock_rag_engine)
        items = [MultimodalContentItem(type="image", img_path="/tmp/img.png")]

        await sut.execute(
            working_dir="/tmp/rag/project_42",
            query="What does this image show?",
            multimodal_content=items,
        )

        mock_rag_engine.init_project.assert_called_once_with("/tmp/rag/project_42")

    async def test_execute_calls_query_multimodal_with_correct_params(
        self,
        mock_rag_engine: AsyncMock,
    ) -> None:
        """Every explicit parameter must be forwarded to query_multimodal."""
        sut = MultimodalQueryUseCase(rag_engine=mock_rag_engine)
        items = [
            MultimodalContentItem(
                type="table", table_data="A,B\n1,2", table_caption="Test table"
            ),
        ]

        await sut.execute(
            working_dir="/tmp/rag/test",
            query="Analyze this table",
            multimodal_content=items,
            mode="global",
            top_k=20,
        )

        mock_rag_engine.query_multimodal.assert_called_once_with(
            query="Analyze this table",
            multimodal_content=items,
            mode="global",
            top_k=20,
            working_dir="/tmp/rag/test",
        )

    async def test_execute_returns_success_with_result(
        self,
        mock_rag_engine: AsyncMock,
    ) -> None:
        """The engine answer must come back wrapped as status='success'."""
        mock_rag_engine.query_multimodal.return_value = "Analysis of the image content"
        sut = MultimodalQueryUseCase(rag_engine=mock_rag_engine)
        items = [MultimodalContentItem(type="image", img_path="/tmp/diagram.png")]

        outcome = await sut.execute(
            working_dir="/tmp/rag/test",
            query="Describe this diagram",
            multimodal_content=items,
        )

        assert outcome == {"status": "success", "data": "Analysis of the image content"}

    async def test_execute_uses_default_mode_and_top_k(
        self,
        mock_rag_engine: AsyncMock,
    ) -> None:
        """Omitted mode/top_k must default to 'hybrid' and 10."""
        sut = MultimodalQueryUseCase(rag_engine=mock_rag_engine)
        items = [MultimodalContentItem(type="equation", latex="E=mc^2")]

        await sut.execute(
            working_dir="/tmp/rag/test",
            query="Explain this formula",
            multimodal_content=items,
        )

        mock_rag_engine.query_multimodal.assert_called_once_with(
            query="Explain this formula",
            multimodal_content=items,
            mode="hybrid",
            top_k=10,
            working_dir="/tmp/rag/test",
        )
Loading
Loading