From c85b1107b852ca8550ef18ab7784a6d792194d36 Mon Sep 17 00:00:00 2001 From: Harsh Kumar Date: Sun, 1 Mar 2026 19:25:59 +0530 Subject: [PATCH] fix: share ML model instances to reduce startup time The backend startup was slow because RetrieverTools.initialize() creates 6 retriever chains, and each one independently loaded its own copy of the embedding model (thenlper/gte-large) and reranker model (BAAI/bge-reranker-base). That meant 12 heavy model loads when only 2 are actually needed, since all chains use the same model config. This fix creates both models once at the top of initialize() and passes the shared instances down through HybridRetrieverChain, SimilarityRetrieverChain, and FAISSVectorDatabase. Both models are stateless (they only run encode/score inference) so sharing a single instance across all chains is safe. Each chain still builds its own independent FAISS index with its own documents. Startup model loading goes from ~34s to ~7s on a local machine (4.9x). Resolves #88 Signed-off-by: Harsh Kumar --- backend/src/agents/retriever_tools.py | 53 +++++++++++++ backend/src/chains/hybrid_retriever_chain.py | 27 ++++++- .../src/chains/similarity_retriever_chain.py | 20 ++++- backend/src/vectorstores/faiss.py | 12 ++- backend/tests/test_retriever_tools.py | 75 +++++++++++++++++-- .../tests/test_similarity_retriever_chain.py | 30 +++++++- 6 files changed, 201 insertions(+), 16 deletions(-) diff --git a/backend/src/agents/retriever_tools.py b/backend/src/agents/retriever_tools.py index e72001b0..a1aafdcc 100644 --- a/backend/src/agents/retriever_tools.py +++ b/backend/src/agents/retriever_tools.py @@ -1,10 +1,15 @@ import os +import logging from typing import Tuple, Optional, Union from dotenv import load_dotenv from langchain_core.tools import tool from langchain.retrievers import EnsembleRetriever from langchain.retrievers import ContextualCompressionRetriever +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_google_genai import GoogleGenerativeAIEmbeddings +from langchain_google_vertexai import VertexAIEmbeddings +from langchain_community.cross_encoders import HuggingFaceCrossEncoder from ..chains.hybrid_retriever_chain import HybridRetrieverChain from ..tools.format_docs import format_docs @@ -39,6 +44,35 @@ def __init__(self) -> None: ] tool_descriptions: str = "" + @staticmethod + def _create_embedding_model( + embeddings_config: dict[str, str], + use_cuda: bool = False, + ) -> Union[HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings]: + embeddings_type = embeddings_config["type"] + embeddings_model_name = embeddings_config["name"] + + if embeddings_type == "GOOGLE_GENAI": + logging.info("Using Google GenerativeAI embeddings...") + return GoogleGenerativeAIEmbeddings( + model=embeddings_model_name, + task_type="retrieval_document", + ) + elif embeddings_type == "GOOGLE_VERTEXAI": + logging.info("Using Google VertexAI embeddings...") + return VertexAIEmbeddings(model_name=embeddings_model_name) + elif embeddings_type == "HF": + logging.info("Using HuggingFace embeddings...") + model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"} + return HuggingFaceEmbeddings( + model_name=embeddings_model_name, + multi_process=False, + encode_kwargs={"normalize_embeddings": True}, + model_kwargs=model_kwargs, + ) + else: + raise ValueError("Invalid embeddings type specified.") + def initialize( self, embeddings_config: dict[str, str], @@ -46,6 +80,13 @@ def initialize( use_cuda: bool = False, fast_mode: bool = False, ) -> None: + # Create shared model instances once + embedding_model = self._create_embedding_model(embeddings_config, use_cuda) + logging.info("Shared embedding model created.") + + reranker_model = HuggingFaceCrossEncoder(model_name=reranking_model_name) + logging.info("Shared reranker model created.") + markdown_docs_map = { "general": [ "./data/markdown/OR_docs", @@ -100,6 +141,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) general_retriever_chain.create_hybrid_retriever() RetrieverTools.general_retriever = general_retriever_chain.retriever @@ -115,6 +158,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) install_retriever_chain.create_hybrid_retriever() RetrieverTools.install_retriever = install_retriever_chain.retriever @@ -131,6 +176,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) commands_retriever_chain.create_hybrid_retriever() RetrieverTools.commands_retriever = commands_retriever_chain.retriever @@ -146,6 +193,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) yosys_rtdocs_retriever_chain.create_hybrid_retriever() RetrieverTools.yosys_rtdocs_retriever = yosys_rtdocs_retriever_chain.retriever @@ -161,6 +210,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) klayout_retriever_chain.create_hybrid_retriever() RetrieverTools.klayout_retriever = klayout_retriever_chain.retriever @@ -176,6 +227,8 @@ def initialize( contextual_rerank=True, search_k=search_k, chunk_size=chunk_size, + embedding_model=embedding_model, + reranker_model=reranker_model, ) errinfo_retriever_chain.create_hybrid_retriever() RetrieverTools.errinfo_retriever = errinfo_retriever_chain.retriever diff --git a/backend/src/chains/hybrid_retriever_chain.py b/backend/src/chains/hybrid_retriever_chain.py index 1b68c14f..791f694d 100644 --- a/backend/src/chains/hybrid_retriever_chain.py +++ b/backend/src/chains/hybrid_retriever_chain.py @@ -5,8 +5,9 @@ from langchain.retrievers import ContextualCompressionRetriever from langchain_core.runnables import RunnableParallel, RunnablePassthrough from langchain_community.cross_encoders import HuggingFaceCrossEncoder -from langchain_google_vertexai import ChatVertexAI -from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings +from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings from langchain_ollama import ChatOllama from langchain.retrievers.document_compressors.cross_encoder_rerank import ( CrossEncoderReranker, @@ -38,6 +39,14 @@ def __init__( weights: list[float] = [0.33, 0.33, 0.33], chunk_size: int = 500, contextual_rerank: bool = False, + embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, + GoogleGenerativeAIEmbeddings, + VertexAIEmbeddings, + ] + ] = None, + reranker_model: Optional[HuggingFaceCrossEncoder] = None, ): super().__init__( llm_model=llm_model, @@ -48,6 +57,14 @@ def __init__( self.reranking_model_name: Optional[str] = reranking_model_name self.use_cuda: bool = use_cuda + self.embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, + GoogleGenerativeAIEmbeddings, + VertexAIEmbeddings, + ] + ] = embedding_model + self.reranker_model: Optional[HuggingFaceCrossEncoder] = reranker_model self.search_k: int = search_k self.weights: list[float] = weights @@ -74,6 +91,7 @@ def create_hybrid_retriever(self) -> None: html_docs_path=self.html_docs_path, chunk_size=self.chunk_size, use_cuda=self.use_cuda, + embedding_model=self.embedding_model, ) if self.vector_db is None: cur_path = os.path.abspath(__file__) @@ -121,8 +139,11 @@ def create_hybrid_retriever(self) -> None: ) if self.contextual_rerank: + reranker = self.reranker_model or HuggingFaceCrossEncoder( + model_name=self.reranking_model_name + ) compressor = CrossEncoderReranker( - model=HuggingFaceCrossEncoder(model_name=self.reranking_model_name), + model=reranker, top_n=self.search_k, ) self.retriever = ContextualCompressionRetriever( diff --git a/backend/src/chains/similarity_retriever_chain.py b/backend/src/chains/similarity_retriever_chain.py index d3a054e0..fa408056 100644 --- a/backend/src/chains/similarity_retriever_chain.py +++ b/backend/src/chains/similarity_retriever_chain.py @@ -3,8 +3,9 @@ from langchain_core.runnables import RunnableParallel, RunnablePassthrough from langchain.docstore.document import Document -from langchain_google_vertexai import ChatVertexAI -from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings from langchain_ollama import ChatOllama from ..vectorstores.faiss import FAISSVectorDatabase @@ -28,6 +29,13 @@ def __init__( embeddings_config: Optional[dict[str, str]] = None, use_cuda: bool = False, chunk_size: int = 500, + embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, + GoogleGenerativeAIEmbeddings, + VertexAIEmbeddings, + ] + ] = None, ): super().__init__( llm_model=llm_model, @@ -40,6 +48,13 @@ def __init__( self.embeddings_config: Optional[dict[str, str]] = embeddings_config self.use_cuda: bool = use_cuda + self.embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, + GoogleGenerativeAIEmbeddings, + VertexAIEmbeddings, + ] + ] = embedding_model self.markdown_docs_path: Optional[list[str]] = markdown_docs_path self.other_docs_path: Optional[list[str]] = other_docs_path @@ -125,6 +140,7 @@ def create_vector_db(self) -> None: embeddings_model_name=self.embeddings_config["name"], embeddings_type=self.embeddings_config["type"], use_cuda=self.use_cuda, + embedding_model=self.embedding_model, ) else: raise ValueError("Embeddings model config not provided correctly.") diff --git a/backend/src/vectorstores/faiss.py b/backend/src/vectorstores/faiss.py index f4308dd5..d49b4ead 100644 --- a/backend/src/vectorstores/faiss.py +++ b/backend/src/vectorstores/faiss.py @@ -28,16 +28,21 @@ def __init__( distance_strategy: DistanceStrategy = DistanceStrategy.COSINE, debug: bool = False, use_cuda: bool = False, + embedding_model: Optional[ + Union[ + HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings + ] + ] = None, ): self.embeddings_model_name = embeddings_model_name - model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"} - self.embedding_model: Union[ HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings ] - if embeddings_type == "GOOGLE_GENAI": + if embedding_model is not None: + self.embedding_model = embedding_model + elif embeddings_type == "GOOGLE_GENAI": self.embedding_model = GoogleGenerativeAIEmbeddings( model=self.embeddings_model_name, task_type="retrieval_document", @@ -51,6 +56,7 @@ def __init__( logging.info("Using Google VertexAI embeddings...") elif embeddings_type == "HF": + model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"} self.embedding_model = HuggingFaceEmbeddings( model_name=self.embeddings_model_name, multi_process=False, diff --git a/backend/tests/test_retriever_tools.py b/backend/tests/test_retriever_tools.py index f6bd25d0..5b99c300 100644 --- a/backend/tests/test_retriever_tools.py +++ b/backend/tests/test_retriever_tools.py @@ -14,11 +14,18 @@ def test_init(self): # Check that it's a valid instance assert isinstance(tools, RetrieverTools) + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_initialize_success(self, mock_hybrid_chain): + def test_initialize_success( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test successful initialization of all retrievers.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range( @@ -40,6 +47,18 @@ def test_initialize_success(self, mock_hybrid_chain): fast_mode=False, ) + # Verify models are created exactly once + mock_create_embed.assert_called_once_with(embeddings_config, True) + mock_cross_encoder.assert_called_once_with(model_name=reranking_model_name) + + # Verify the same shared instances are passed to all 6 chains + shared_embedding = mock_create_embed.return_value + shared_reranker = mock_cross_encoder.return_value + for call in mock_hybrid_chain.call_args_list: + kwargs = call[1] + assert kwargs["embedding_model"] is shared_embedding + assert kwargs["reranker_model"] is shared_reranker + # Verify all retrievers are created assert mock_hybrid_chain.call_count == 6 @@ -55,11 +74,18 @@ def test_initialize_success(self, mock_hybrid_chain): assert RetrieverTools.klayout_retriever == mock_chains[4].retriever assert RetrieverTools.errinfo_retriever == mock_chains[5].retriever + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_initialize_with_fast_mode(self, mock_hybrid_chain): + def test_initialize_with_fast_mode( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test initialization with fast mode enabled.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -250,11 +276,18 @@ def test_retrieve_klayout_docs_not_initialized(self): with pytest.raises(ValueError, match="KLayout Retriever not initialized"): RetrieverTools.retrieve_klayout_docs.invoke(input="test query") + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_initialize_verifies_configuration_parameters(self, mock_hybrid_chain): + def test_initialize_verifies_configuration_parameters( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test that initialize passes correct configuration parameters.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -283,11 +316,18 @@ def test_initialize_verifies_configuration_parameters(self, mock_hybrid_chain): assert kwargs["weights"] == [0.6, 0.2, 0.2] assert kwargs["contextual_rerank"] is True + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_initialize_with_environment_variables(self, mock_hybrid_chain): + def test_initialize_with_environment_variables( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test initialization respects environment variables.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -323,11 +363,18 @@ def test_tool_decorators_applied(self): assert hasattr(RetrieverTools.retrieve_yosys_rtdocs, "name") assert hasattr(RetrieverTools.retrieve_klayout_docs, "name") + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_different_docs_paths_for_retrievers(self, mock_hybrid_chain): + def test_different_docs_paths_for_retrievers( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test that different retrievers use different document paths.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -369,11 +416,18 @@ def test_different_docs_paths_for_retrievers(self, mock_hybrid_chain): # Errinfo should have error-specific paths assert any("man3" in path for path in errinfo_paths) + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_html_docs_configuration(self, mock_hybrid_chain): + def test_html_docs_configuration( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test HTML docs configuration for specific retrievers.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): @@ -426,11 +480,18 @@ def test_staticmethod_decorators(self): result = RetrieverTools.retrieve_general.invoke(input="test") assert result == ("", [], [], []) + @patch("src.agents.retriever_tools.HuggingFaceCrossEncoder") + @patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model") @patch("src.agents.retriever_tools.HybridRetrieverChain") - def test_retriever_chain_create_hybrid_retriever_called(self, mock_hybrid_chain): + def test_retriever_chain_create_hybrid_retriever_called( + self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder + ): """Test that create_hybrid_retriever is called on all chains.""" tools = RetrieverTools() + mock_create_embed.return_value = Mock() + mock_cross_encoder.return_value = Mock() + # Mock the HybridRetrieverChain instances mock_chains = [] for i in range(6): diff --git a/backend/tests/test_similarity_retriever_chain.py b/backend/tests/test_similarity_retriever_chain.py index 8e2a76f6..4cc0d1fd 100644 --- a/backend/tests/test_similarity_retriever_chain.py +++ b/backend/tests/test_similarity_retriever_chain.py @@ -238,7 +238,35 @@ def test_create_vector_db_success(self, mock_faiss_db): assert chain.vector_db == mock_db_instance mock_faiss_db.assert_called_once_with( - embeddings_model_name="test-model", embeddings_type="HF", use_cuda=True + embeddings_model_name="test-model", + embeddings_type="HF", + use_cuda=True, + embedding_model=None, + ) + + @patch("src.chains.similarity_retriever_chain.FAISSVectorDatabase") + def test_create_vector_db_uses_provided_embedding_model(self, mock_faiss_db): + """Test that a provided embedding_model is forwarded to FAISSVectorDatabase.""" + mock_db_instance = Mock() + mock_faiss_db.return_value = mock_db_instance + + embeddings_config = {"type": "HF", "name": "test-model"} + sentinel_embedding_model = object() + + chain = SimilarityRetrieverChain( + embeddings_config=embeddings_config, + use_cuda=True, + embedding_model=sentinel_embedding_model, + ) + + chain.create_vector_db() + + assert chain.vector_db == mock_db_instance + mock_faiss_db.assert_called_once_with( + embeddings_model_name="test-model", + embeddings_type="HF", + use_cuda=True, + embedding_model=sentinel_embedding_model, ) def test_create_vector_db_missing_config_raises_error(self):