fix: share ML model instances to reduce startup time #229
Open
harsh-kumar-patwa wants to merge 1 commit into The-OpenROAD-Project:master from
Pull request overview
This PR reduces backend startup time by creating the embeddings and reranker models once and sharing those instances across the 6 retriever chains, instead of loading identical models repeatedly per chain.
Changes:
- Create a shared embedding model and shared cross-encoder reranker once in `RetrieverTools.initialize()` and pass them down into retriever chains.
- Add optional `embedding_model` plumbing through `HybridRetrieverChain` → `SimilarityRetrieverChain` → `FAISSVectorDatabase`.
- Update unit tests/mocks to reflect the new initialization path and constructor signature.
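The optional-parameter plumbing described in the list above can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `FakeEmbeddings`, `VectorDatabase`, `SimilarityChain`, and `HybridChain` are hypothetical stand-ins for the real classes, and only the `embedding_model=None` fallback pattern is taken from the PR description.

```python
class FakeEmbeddings:
    """Hypothetical stand-in for a real embedding model."""

    created = 0  # counts constructions so the sharing is observable

    def __init__(self):
        FakeEmbeddings.created += 1


class VectorDatabase:
    """Simplified stand-in for FAISSVectorDatabase (assumption)."""

    def __init__(self, embedding_model=None):
        # Reuse the injected model; only build a new one when none is given.
        if embedding_model is None:
            embedding_model = FakeEmbeddings()
        self.embedding_model = embedding_model


class SimilarityChain:
    """Simplified stand-in for SimilarityRetrieverChain (assumption)."""

    def __init__(self, embedding_model=None):
        self.embedding_model = embedding_model

    def create_vector_db(self):
        # Forward whatever was injected (possibly None) to the vector store.
        return VectorDatabase(embedding_model=self.embedding_model)


class HybridChain:
    """Simplified stand-in for HybridRetrieverChain (assumption)."""

    def __init__(self, embedding_model=None):
        self.similarity_chain = SimilarityChain(embedding_model=embedding_model)


shared = FakeEmbeddings()
injected_db = HybridChain(embedding_model=shared).similarity_chain.create_vector_db()
default_db = HybridChain().similarity_chain.create_vector_db()
```

Here `injected_db` reuses `shared`, while `default_db` falls back to building its own model, which is the backward-compatible path for callers that pass nothing.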
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `backend/src/agents/retriever_tools.py` | Adds an embedding-model factory and shares embedding/reranker instances across all retriever chain instances. |
| `backend/src/vectorstores/faiss.py` | Allows injecting a pre-created embeddings object to avoid redundant model instantiation. |
| `backend/src/chains/similarity_retriever_chain.py` | Accepts `embedding_model` and forwards it into `FAISSVectorDatabase`. |
| `backend/src/chains/hybrid_retriever_chain.py` | Accepts shared embedding/reranker instances and uses the shared reranker when contextual rerank is enabled. |
| `backend/tests/test_retriever_tools.py` | Updates initialization tests to patch the new shared-model creation path. |
| `backend/tests/test_similarity_retriever_chain.py` | Updates constructor assertion to include the new `embedding_model` argument. |
Comments suppressed due to low confidence (1)
backend/src/chains/hybrid_retriever_chain.py:140
`create_hybrid_retriever()` can raise `UnboundLocalError` because `ensemble_retriever` is only assigned inside the `(similarity_retriever and mmr_retriever and bm25_retriever)` block, but it's used unconditionally when building `ContextualCompressionRetriever` (and in the `else`). If `processed_docs` is empty (or BM25 isn't created for any reason), `ensemble_retriever` (and even `bm25_retriever`) will be undefined. Initialize these locals to `None` and either build the ensemble with available retrievers or raise a clear error before referencing them.
        self.retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )
    else:
        self.retriever = ensemble_retriever
The backend startup was slow because RetrieverTools.initialize() creates 6 retriever chains, and each one independently loaded its own copy of the embedding model (thenlper/gte-large) and reranker model (BAAI/bge-reranker-base). That meant 12 heavy model loads when only 2 are actually needed, since all chains use the same model config.

This fix creates both models once at the top of initialize() and passes the shared instances down through HybridRetrieverChain, SimilarityRetrieverChain, and FAISSVectorDatabase.

Both models are stateless (they only run encode/score inference) so sharing a single instance across all chains is safe. Each chain still builds its own independent FAISS index with its own documents.

Startup model loading goes from ~34s to ~7s on a local machine (4.9x).

Resolves The-OpenROAD-Project#88

Signed-off-by: Harsh Kumar <harshkumar3446@gmail.com>
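The load-count arithmetic in the commit message (12 loads versus 2 for 6 chains) can be checked with a toy model. Everything here is an assumption for illustration: `load_model` only counts calls instead of loading anything, and `Chain` is a hypothetical stand-in whose `index` list plays the role of a per-chain FAISS index.

```python
load_count = 0


def load_model(name):
    """Stand-in for an expensive model load; it only counts calls."""
    global load_count
    load_count += 1
    return object()


class Chain:
    """Hypothetical retriever chain with optional injected models."""

    def __init__(self, docs, embedding_model=None, reranker_model=None):
        if embedding_model is None:
            embedding_model = load_model("thenlper/gte-large")
        if reranker_model is None:
            reranker_model = load_model("BAAI/bge-reranker-base")
        self.embedding_model = embedding_model
        self.reranker_model = reranker_model
        self.index = list(docs)  # each chain keeps its own index


def build_chains(share, n_chains=6):
    """Build the chains and return them with the number of model loads."""
    global load_count
    load_count = 0
    embedding = load_model("thenlper/gte-large") if share else None
    reranker = load_model("BAAI/bge-reranker-base") if share else None
    chains = [Chain([f"doc-{i}"], embedding, reranker) for i in range(n_chains)]
    return chains, load_count


per_chain, per_chain_loads = build_chains(share=False)
shared_chains, shared_loads = build_chains(share=True)
```

With `share=False` every chain loads both models, and with `share=True` the two models are loaded once up front while each chain still holds a distinct `index`.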
f41df8e to c85b110
Summary
Fixes #88
The backend takes a long time to start because `RetrieverTools.initialize()` creates 6 retriever chains, and each one independently loads its own copy of the embedding model (`thenlper/gte-large`) and reranker model (`BAAI/bge-reranker-base`). That is 12 heavy model loads when only 2 are needed, since all 6 chains use the exact same model configuration.
This PR creates both models once at the top of `initialize()` and passes the shared instances down through `HybridRetrieverChain` → `SimilarityRetrieverChain` → `FAISSVectorDatabase`. Each chain still builds its own independent FAISS index with its own documents.
Changes
`backend/src/agents/retriever_tools.py`
Added a `_create_embedding_model()` factory method that creates the embedding model once based on config. In `initialize()`, both the embedding model and reranker model are created once and passed to all 6 chain constructors.
`backend/src/vectorstores/faiss.py`
Added an optional `embedding_model` parameter to `FAISSVectorDatabase.__init__()`. When provided, it skips creating a new model and uses the shared one directly.
`backend/src/chains/similarity_retriever_chain.py`
Added an `embedding_model` parameter and passes it through to `FAISSVectorDatabase` in `create_vector_db()`.
`backend/src/chains/hybrid_retriever_chain.py`
Added `embedding_model` and `reranker_model` parameters. Passes the embedding model to `SimilarityRetrieverChain`, and uses the shared reranker in `create_hybrid_retriever()` with a fallback to creating a new one if not provided.
Test files
Updated mocks in `test_retriever_tools.py` and `test_similarity_retriever_chain.py` to account for the new model creation path.
Why sharing is safe
Both models are stateless. The embedding model only runs `encode()` to produce vectors, and the reranker only runs `score()` to produce similarity scores. Neither modifies internal state during inference. Each FAISS database still maintains its own independent index. All new parameters default to `None`, so existing code that doesn't pass shared models continues to work exactly as before.
Benchmark
Tested on local machine (Apple M3, thenlper/gte-large + BAAI/bge-reranker-base):
Test plan
- `uv run pytest tests/ -v`
- `make format` passes