diff --git a/backend/src/agents/retriever_tools.py b/backend/src/agents/retriever_tools.py index e72001b0..a1aafdcc 100644 --- a/backend/src/agents/retriever_tools.py +++ b/backend/src/agents/retriever_tools.py @@ -1,10 +1,15 @@ import os +import logging from typing import Tuple, Optional, Union from dotenv import load_dotenv from langchain_core.tools import tool from langchain.retrievers import EnsembleRetriever from langchain.retrievers import ContextualCompressionRetriever +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_google_genai import GoogleGenerativeAIEmbeddings +from langchain_google_vertexai import VertexAIEmbeddings +from langchain_community.cross_encoders import HuggingFaceCrossEncoder from ..chains.hybrid_retriever_chain import HybridRetrieverChain from ..tools.format_docs import format_docs @@ -39,6 +44,35 @@ def __init__(self) -> None: ] tool_descriptions: str = "" + @staticmethod + def _create_embedding_model( + embeddings_config: dict[str, str], + use_cuda: bool = False, + ) -> Union[HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings]: + embeddings_type = embeddings_config["type"] + embeddings_model_name = embeddings_config["name"] + + if embeddings_type == "GOOGLE_GENAI": + logging.info("Using Google GenerativeAI embeddings...") + return GoogleGenerativeAIEmbeddings( + model=embeddings_model_name, + task_type="retrieval_document", + ) + elif embeddings_type == "GOOGLE_VERTEXAI": + logging.info("Using Google VertexAI embeddings...") + return VertexAIEmbeddings(model_name=embeddings_model_name) + elif embeddings_type == "HF": + logging.info("Using HuggingFace embeddings...") + model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"} + return HuggingFaceEmbeddings( + model_name=embeddings_model_name, + multi_process=False, + encode_kwargs={"normalize_embeddings": True}, + model_kwargs=model_kwargs, + ) + else: + raise ValueError("Invalid embeddings type specified.") + def initialize( self, embeddings_config: dict[str, str], @@ -46,6 +80,13 @@ def initialize( use_cuda: bool = False, fast_mode: bool = False, ) -> None: + # Create shared model instances once + embedding_model = self._create_embedding_model(embeddings_config, use_cuda) + logging.info("Shared embedding model created.") + + reranker_model = HuggingFaceCrossEncoder(model_name=reranking_model_name) + logging.info("Shared reranker model created.") + markdown_docs_map = { "general": [ "./data/markdown/OR_docs", @@ -100,6 +141,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) general_retriever_chain.create_hybrid_retriever() RetrieverTools.general_retriever = general_retriever_chain.retriever @@ -115,6 +158,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) install_retriever_chain.create_hybrid_retriever() RetrieverTools.install_retriever = install_retriever_chain.retriever @@ -131,6 +176,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) commands_retriever_chain.create_hybrid_retriever() RetrieverTools.commands_retriever = commands_retriever_chain.retriever @@ -146,6 +193,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) yosys_rtdocs_retriever_chain.create_hybrid_retriever() RetrieverTools.yosys_rtdocs_retriever = yosys_rtdocs_retriever_chain.retriever @@ -161,6 +210,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) klayout_retriever_chain.create_hybrid_retriever() RetrieverTools.klayout_retriever = klayout_retriever_chain.retriever @@ -176,6 +227,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) errinfo_retriever_chain.create_hybrid_retriever() RetrieverTools.errinfo_retriever = errinfo_retriever_chain.retriever diff --git a/backend/src/chains/hybrid_retriever_chain.py b/backend/src/chains/hybrid_retriever_chain.py index 1b68c14f..791f694d 100644 --- a/backend/src/chains/hybrid_retriever_chain.py +++ b/backend/src/chains/hybrid_retriever_chain.py @@ -5,8 +5,9 @@ from langchain.retrievers import ContextualCompressionRetriever from langchain_core.runnables import RunnableParallel, RunnablePassthrough from langchain_community.cross_encoders import HuggingFaceCrossEncoder -from langchain_google_vertexai import ChatVertexAI -from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings +from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings from langchain_ollama import ChatOllama from langchain.retrievers.document_compressors.cross_encoder_rerank import ( CrossEncoderReranker, @@ -38,6 +39,14 @@ def __init__( weights: list[float] = [0.33, 0.33, 0.33], chunk_size: int = 500, contextual_rerank: bool = False, + embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, + GoogleGenerativeAIEmbeddings, + VertexAIEmbeddings, + ] + ] = None, + reranker_model: Optional[HuggingFaceCrossEncoder] = None, ): super().__init__( llm_model=llm_model, @@ -48,6 +57,14 @@ def __init__( self.reranking_model_name: Optional[str] = reranking_model_name self.use_cuda: bool = use_cuda + self.embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, + GoogleGenerativeAIEmbeddings, + VertexAIEmbeddings, + ] + ] = embedding_model + self.reranker_model: Optional[HuggingFaceCrossEncoder] = reranker_model self.search_k: int = search_k self.weights: list[float] = weights @@ -74,6 +91,7 @@ def create_hybrid_retriever(self) -> None: html_docs_path=self.html_docs_path, chunk_size=self.chunk_size, use_cuda=self.use_cuda, + embedding_model=self.embedding_model, ) if self.vector_db is None: cur_path = os.path.abspath(__file__) @@ -121,8 +139,11 @@ def create_hybrid_retriever(self) -> None: ) if self.contextual_rerank: + reranker = self.reranker_model or HuggingFaceCrossEncoder( + model_name=self.reranking_model_name + ) compressor = CrossEncoderReranker( - model=HuggingFaceCrossEncoder(model_name=self.reranking_model_name), + model=reranker, top_n=self.search_k, ) self.retriever = ContextualCompressionRetriever( diff --git a/backend/src/chains/similarity_retriever_chain.py b/backend/src/chains/similarity_retriever_chain.py index d3a054e0..fa408056 100644 --- a/backend/src/chains/similarity_retriever_chain.py +++ b/backend/src/chains/similarity_retriever_chain.py @@ -3,8 +3,9 @@ from langchain_core.runnables import RunnableParallel, RunnablePassthrough from langchain.docstore.document import Document -from langchain_google_vertexai import ChatVertexAI -from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings from langchain_ollama import ChatOllama from ..vectorstores.faiss import FAISSVectorDatabase @@ -28,6 +29,13 @@ def __init__( embeddings_config: Optional[dict[str, str]] = None, use_cuda: bool = False, chunk_size: int = 500, + embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, + GoogleGenerativeAIEmbeddings, + VertexAIEmbeddings, + ] + ] = None, ): super().__init__( llm_model=llm_model, @@ -40,6 +48,13 @@ def __init__( self.embeddings_config: Optional[dict[str, str]] = embeddings_config self.use_cuda: bool = use_cuda + self.embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, + GoogleGenerativeAIEmbeddings, + VertexAIEmbeddings, + ] + ] = embedding_model self.markdown_docs_path: Optional[list[str]] = markdown_docs_path self.other_docs_path: Optional[list[str]] = other_docs_path @@ -125,6 +140,7 @@ def create_vector_db(self) -> None: embeddings_model_name=self.embeddings_config["name"], embeddings_type=self.embeddings_config["type"], use_cuda=self.use_cuda, + embedding_model=self.embedding_model, ) else: raise ValueError("Embeddings model config not provided correctly.") diff --git a/backend/src/vectorstores/faiss.py b/backend/src/vectorstores/faiss.py index f4308dd5..d49b4ead 100644 --- a/backend/src/vectorstores/faiss.py +++ b/backend/src/vectorstores/faiss.py @@ -28,16 +28,21 @@ def __init__( distance_strategy: DistanceStrategy = DistanceStrategy.COSINE, debug: bool = False, use_cuda: bool = False, + embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings + ] + ] = None, ): self.embeddings_model_name = embeddings_model_name - model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"} - self.embedding_model: Union[ HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings ] - if embeddings_type == "GOOGLE_GENAI": + if embedding_model is not None: + self.embedding_model = embedding_model + elif embeddings_type == "GOOGLE_GENAI": self.embedding_model = GoogleGenerativeAIEmbeddings( model=self.embeddings_model_name, task_type="retrieval_document", @@ -51,6 +56,7 @@ def __init__( logging.info("Using Google VertexAI embeddings...") elif embeddings_type == "HF": + model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"} self.embedding_model = HuggingFaceEmbeddings( model_name=self.embeddings_model_name, multi_process=False, diff --git a/backend/tests/test_retriever_tools.py b/backend/tests/test_retriever_tools.py index f6bd25d0..5b99c300 100644 --- a/backend/tests/test_retriever_tools.py +++ b/backend/tests/test_retriever_tools.py @@ -14,11 +14,18 @@ def test_init(self): # Check that it's a valid instance assert isinstance(tools, RetrieverTools) + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_initialize_success(self, mock_hybrid_chain): + def test_initialize_success( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test successful initialization of all retrievers.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range( @@ -40,6 +47,18 @@ def test_initialize_success(self, mock_hybrid_chain): fast_mode=False, ) + # Verify models are created exactly once + mock_create_embed.assert_called_once_with(embeddings_config, True) + mock_cross_encoder.assert_called_once_with(model_name=reranking_model_name) + + # Verify the same shared instances are passed to all 6 chains + shared_embedding = mock_create_embed.return_value + shared_reranker = mock_cross_encoder.return_value + for call in mock_hybrid_chain.call_args_list: + kwargs = call[1] + assert kwargs["embedding_model"] is shared_embedding + assert kwargs["reranker_model"] is shared_reranker + # Verify all retrievers are created assert mock_hybrid_chain.call_count == 6 @@ -55,11 +74,18 @@ def test_initialize_success(self, mock_hybrid_chain): assert RetrieverTools.klayout_retriever == mock_chains[4].retriever assert RetrieverTools.errinfo_retriever == mock_chains[5].retriever + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_initialize_with_fast_mode(self, mock_hybrid_chain): + def test_initialize_with_fast_mode( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test initialization with fast mode enabled.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -250,11 +276,18 @@ def test_retrieve_klayout_docs_not_initialized(self): with pytest.raises(ValueError, match="KLayout Retriever not initialized"): RetrieverTools.retrieve_klayout_docs.invoke(input="test query") + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_initialize_verifies_configuration_parameters(self, mock_hybrid_chain): + def test_initialize_verifies_configuration_parameters( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test that initialize passes correct configuration parameters.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -283,11 +316,18 @@ def test_initialize_verifies_configuration_parameters(self, mock_hybrid_chain): assert kwargs["weights"] == [0.6, 0.2, 0.2] assert kwargs["contextual_rerank"] is True + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_initialize_with_environment_variables(self, mock_hybrid_chain): + def test_initialize_with_environment_variables( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test initialization respects environment variables.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -323,11 +363,18 @@ def test_tool_decorators_applied(self): assert hasattr(RetrieverTools.retrieve_yosys_rtdocs, "name") assert hasattr(RetrieverTools.retrieve_klayout_docs, "name") + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_different_docs_paths_for_retrievers(self, mock_hybrid_chain): + def test_different_docs_paths_for_retrievers( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test that different retrievers use different document paths.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -369,11 +416,18 @@ def test_different_docs_paths_for_retrievers(self, mock_hybrid_chain): # Errinfo should have error-specific paths assert any("man3" in path for path in errinfo_paths) + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_html_docs_configuration(self, mock_hybrid_chain): + def test_html_docs_configuration( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test HTML docs configuration for specific retrievers.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -426,11 +480,18 @@ def test_staticmethod_decorators(self): result = RetrieverTools.retrieve_general.invoke(input="test") assert result == ("", [], [], []) + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_retriever_chain_create_hybrid_retriever_called(self, mock_hybrid_chain): + def test_retriever_chain_create_hybrid_retriever_called( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test that create_hybrid_retriever is called on all chains.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): diff --git a/backend/tests/test_similarity_retriever_chain.py b/backend/tests/test_similarity_retriever_chain.py index 8e2a76f6..4cc0d1fd 100644 --- a/backend/tests/test_similarity_retriever_chain.py +++ b/backend/tests/test_similarity_retriever_chain.py @@ -238,7 +238,35 @@ def test_create_vector_db_success(self, mock_faiss_db): assert chain.vector_db == mock_db_instance mock_faiss_db.assert_called_once_with( - embeddings_model_name="test-model", embeddings_type="HF", use_cuda=True + embeddings_model_name="test-model", + embeddings_type="HF", + use_cuda=True, + embedding_model=None, + ) + + @patch("src.chains.similarity_retriever_chain.FAISSVectorDatabase") + def test_create_vector_db_uses_provided_embedding_model(self, mock_faiss_db): + """Test that a provided embedding_model is forwarded to FAISSVectorDatabase.""" + mock_db_instance = Mock() + mock_faiss_db.return_value = mock_db_instance + + embeddings_config = {"type": "HF", "name": "test-model"} + sentinel_embedding_model = object() + + chain = SimilarityRetrieverChain( + embeddings_config=embeddings_config, + use_cuda=True, + embedding_model=sentinel_embedding_model, + ) + + chain.create_vector_db() + + assert chain.vector_db == mock_db_instance + mock_faiss_db.assert_called_once_with( + embeddings_model_name="test-model", + embeddings_type="HF", + use_cuda=True, + embedding_model=sentinel_embedding_model, ) def test_create_vector_db_missing_config_raises_error(self):