53 changes: 53 additions & 0 deletions backend/src/agents/retriever_tools.py
@@ -1,10 +1,15 @@
import os
import logging
from typing import Tuple, Optional, Union
from dotenv import load_dotenv

from langchain_core.tools import tool
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers import ContextualCompressionRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

from ..chains.hybrid_retriever_chain import HybridRetrieverChain
from ..tools.format_docs import format_docs
@@ -39,13 +44,49 @@ def __init__(self) -> None:
]
tool_descriptions: str = ""

@staticmethod
def _create_embedding_model(
embeddings_config: dict[str, str],
use_cuda: bool = False,
) -> Union[HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings]:
embeddings_type = embeddings_config["type"]
embeddings_model_name = embeddings_config["name"]

if embeddings_type == "GOOGLE_GENAI":
logging.info("Using Google GenerativeAI embeddings...")
return GoogleGenerativeAIEmbeddings(
model=embeddings_model_name,
task_type="retrieval_document",
)
elif embeddings_type == "GOOGLE_VERTEXAI":
logging.info("Using Google VertexAI embeddings...")
return VertexAIEmbeddings(model_name=embeddings_model_name)
elif embeddings_type == "HF":
logging.info("Using HuggingFace embeddings...")
model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
return HuggingFaceEmbeddings(
model_name=embeddings_model_name,
multi_process=False,
encode_kwargs={"normalize_embeddings": True},
model_kwargs=model_kwargs,
)
else:
raise ValueError("Invalid embeddings type specified.")

def initialize(
self,
embeddings_config: dict[str, str],
reranking_model_name: str,
use_cuda: bool = False,
fast_mode: bool = False,
) -> None:
# Create shared model instances once
embedding_model = self._create_embedding_model(embeddings_config, use_cuda)
logging.info("Shared embedding model created.")

reranker_model = HuggingFaceCrossEncoder(model_name=reranking_model_name)
logging.info("Shared reranker model created.")

markdown_docs_map = {
"general": [
"./data/markdown/OR_docs",
@@ -100,6 +141,8 @@ def initialize(
contextual_rerank=True,
search_k=search_k,
chunk_size=chunk_size,
embedding_model=embedding_model,
reranker_model=reranker_model,
)
general_retriever_chain.create_hybrid_retriever()
RetrieverTools.general_retriever = general_retriever_chain.retriever
@@ -115,6 +158,8 @@ def initialize(
contextual_rerank=True,
search_k=search_k,
chunk_size=chunk_size,
embedding_model=embedding_model,
reranker_model=reranker_model,
)
install_retriever_chain.create_hybrid_retriever()
RetrieverTools.install_retriever = install_retriever_chain.retriever
@@ -131,6 +176,8 @@ def initialize(
contextual_rerank=True,
search_k=search_k,
chunk_size=chunk_size,
embedding_model=embedding_model,
reranker_model=reranker_model,
)
commands_retriever_chain.create_hybrid_retriever()
RetrieverTools.commands_retriever = commands_retriever_chain.retriever
@@ -146,6 +193,8 @@ def initialize(
contextual_rerank=True,
search_k=search_k,
chunk_size=chunk_size,
embedding_model=embedding_model,
reranker_model=reranker_model,
)
yosys_rtdocs_retriever_chain.create_hybrid_retriever()
RetrieverTools.yosys_rtdocs_retriever = yosys_rtdocs_retriever_chain.retriever
@@ -161,6 +210,8 @@ def initialize(
contextual_rerank=True,
search_k=search_k,
chunk_size=chunk_size,
embedding_model=embedding_model,
reranker_model=reranker_model,
)
klayout_retriever_chain.create_hybrid_retriever()
RetrieverTools.klayout_retriever = klayout_retriever_chain.retriever
@@ -176,6 +227,8 @@ def initialize(
contextual_rerank=True,
search_k=search_k,
chunk_size=chunk_size,
embedding_model=embedding_model,
reranker_model=reranker_model,
)
errinfo_retriever_chain.create_hybrid_retriever()
RetrieverTools.errinfo_retriever = errinfo_retriever_chain.retriever
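Taken together, the changes in this file mean initialize() builds a single embedding model and a single cross-encoder and hands the same instances to every HybridRetrieverChain, instead of each chain loading its own copies. A minimal sketch of that pattern follows; it assumes the imports shown in the hunks above, most constructor arguments are omitted, the markdown_docs_path keyword is taken from the parent SimilarityRetrieverChain signature, and the config values, model name, and search settings are illustrative rather than copied from this PR.

    embeddings_config = {"type": "HF", "name": "BAAI/bge-large-en-v1.5"}  # illustrative

    # Load the heavyweight models once...
    embedding_model = RetrieverTools._create_embedding_model(embeddings_config, use_cuda=False)
    reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")  # illustrative name

    # ...then reuse the same instances across every retriever chain.
    retrievers = []
    for docs_path in (["./data/markdown/OR_docs"], ["./data/markdown/OR_docs/installation"]):
        chain = HybridRetrieverChain(
            markdown_docs_path=docs_path,
            contextual_rerank=True,
            search_k=5,          # illustrative
            chunk_size=500,      # illustrative
            embedding_model=embedding_model,   # shared, not re-created per chain
            reranker_model=reranker_model,     # shared, not re-created per chain
        )
        chain.create_hybrid_retriever()
        retrievers.append(chain.retriever)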
27 changes: 24 additions & 3 deletions backend/src/chains/hybrid_retriever_chain.py
@@ -5,8 +5,9 @@
from langchain.retrievers import ContextualCompressionRetriever
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_google_vertexai import ChatVertexAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_ollama import ChatOllama
from langchain.retrievers.document_compressors.cross_encoder_rerank import (
CrossEncoderReranker,
@@ -38,6 +39,14 @@ def __init__(
weights: list[float] = [0.33, 0.33, 0.33],
chunk_size: int = 500,
contextual_rerank: bool = False,
embedding_model: Optional[
Union[
HuggingFaceEmbeddings,
GoogleGenerativeAIEmbeddings,
VertexAIEmbeddings,
]
] = None,
reranker_model: Optional[HuggingFaceCrossEncoder] = None,
):
super().__init__(
llm_model=llm_model,
@@ -48,6 +57,14 @@

self.reranking_model_name: Optional[str] = reranking_model_name
self.use_cuda: bool = use_cuda
self.embedding_model: Optional[
Union[
HuggingFaceEmbeddings,
GoogleGenerativeAIEmbeddings,
VertexAIEmbeddings,
]
] = embedding_model
self.reranker_model: Optional[HuggingFaceCrossEncoder] = reranker_model

self.search_k: int = search_k
self.weights: list[float] = weights
@@ -74,6 +91,7 @@ def create_hybrid_retriever(self) -> None:
html_docs_path=self.html_docs_path,
chunk_size=self.chunk_size,
use_cuda=self.use_cuda,
embedding_model=self.embedding_model,
)
if self.vector_db is None:
cur_path = os.path.abspath(__file__)
@@ -121,8 +139,11 @@ def create_hybrid_retriever(self) -> None:
)

if self.contextual_rerank:
reranker = self.reranker_model or HuggingFaceCrossEncoder(
model_name=self.reranking_model_name
)
compressor = CrossEncoderReranker(
model=HuggingFaceCrossEncoder(model_name=self.reranking_model_name),
model=reranker,
top_n=self.search_k,
)
self.retriever = ContextualCompressionRetriever(
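The key behavioural change in this file is the fallback: create_hybrid_retriever() now prefers an injected cross-encoder and only builds one from reranking_model_name when nothing was passed in, so existing callers keep working. Roughly, the two call styles look like the sketch below; other constructor arguments are omitted, the model name is illustrative, and the imports are the ones already shown in the hunks above.

    shared_reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")  # illustrative name

    # New style: the chain reuses the shared instance and skips loading its own copy.
    chain_shared = HybridRetrieverChain(
        contextual_rerank=True,
        reranker_model=shared_reranker,
    )

    # Old style: still supported; the chain constructs a HuggingFaceCrossEncoder
    # from reranking_model_name inside create_hybrid_retriever().
    chain_legacy = HybridRetrieverChain(
        contextual_rerank=True,
        reranking_model_name="BAAI/bge-reranker-base",
    )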
20 changes: 18 additions & 2 deletions backend/src/chains/similarity_retriever_chain.py
@@ -3,8 +3,9 @@

from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain.docstore.document import Document
from langchain_google_vertexai import ChatVertexAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_ollama import ChatOllama

from ..vectorstores.faiss import FAISSVectorDatabase
@@ -28,6 +29,13 @@ def __init__(
embeddings_config: Optional[dict[str, str]] = None,
use_cuda: bool = False,
chunk_size: int = 500,
embedding_model: Optional[
Union[
HuggingFaceEmbeddings,
GoogleGenerativeAIEmbeddings,
VertexAIEmbeddings,
]
] = None,
):
super().__init__(
llm_model=llm_model,
@@ -40,6 +48,13 @@

self.embeddings_config: Optional[dict[str, str]] = embeddings_config
self.use_cuda: bool = use_cuda
self.embedding_model: Optional[
Union[
HuggingFaceEmbeddings,
GoogleGenerativeAIEmbeddings,
VertexAIEmbeddings,
]
] = embedding_model

self.markdown_docs_path: Optional[list[str]] = markdown_docs_path
self.other_docs_path: Optional[list[str]] = other_docs_path
@@ -125,6 +140,7 @@ def create_vector_db(self) -> None:
embeddings_model_name=self.embeddings_config["name"],
embeddings_type=self.embeddings_config["type"],
use_cuda=self.use_cuda,
embedding_model=self.embedding_model,
)
else:
raise ValueError("Embeddings model config not provided correctly.")
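At this layer the injected model is simply threaded through to the FAISS store. Note that create_vector_db() still reads the model name and type from embeddings_config, so that dict remains required even when embedding_model is supplied. A rough usage sketch, with other arguments omitted and the model name and docs path purely illustrative:

    shared_embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",            # illustrative
        encode_kwargs={"normalize_embeddings": True},
    )

    chain = SimilarityRetrieverChain(
        embeddings_config={"type": "HF", "name": "BAAI/bge-large-en-v1.5"},  # still required
        embedding_model=shared_embeddings,  # forwarded to FAISSVectorDatabase instead of a fresh instance
        markdown_docs_path=["./data/markdown/OR_docs"],
    )
    chain.create_vector_db()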
12 changes: 9 additions & 3 deletions backend/src/vectorstores/faiss.py
@@ -28,16 +28,21 @@ def __init__(
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
debug: bool = False,
use_cuda: bool = False,
embedding_model: Optional[
Union[
HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings
]
] = None,
):
self.embeddings_model_name = embeddings_model_name

model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}

self.embedding_model: Union[
HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings
]

if embeddings_type == "GOOGLE_GENAI":
if embedding_model is not None:
self.embedding_model = embedding_model
elif embeddings_type == "GOOGLE_GENAI":
self.embedding_model = GoogleGenerativeAIEmbeddings(
model=self.embeddings_model_name,
task_type="retrieval_document",
@@ -51,6 +56,7 @@
logging.info("Using Google VertexAI embeddings...")

elif embeddings_type == "HF":
model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
self.embedding_model = HuggingFaceEmbeddings(
model_name=self.embeddings_model_name,
multi_process=False,
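At the bottom of the stack, FAISSVectorDatabase now gives an injected embedding_model precedence over constructing one from embeddings_type / embeddings_model_name, and the CUDA/CPU model_kwargs are only computed inside the HuggingFace branch where they are used. A condensed, standalone sketch of that selection order is below; logging and the VertexAI branch are omitted, and the final error message is assumed since that part of the file is not visible in this diff.

    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_google_genai import GoogleGenerativeAIEmbeddings

    def select_embedding_model(embedding_model, embeddings_type, embeddings_model_name, use_cuda=False):
        if embedding_model is not None:
            return embedding_model  # an injected shared instance always wins
        if embeddings_type == "GOOGLE_GENAI":
            return GoogleGenerativeAIEmbeddings(
                model=embeddings_model_name,
                task_type="retrieval_document",
            )
        if embeddings_type == "HF":
            # Device selection now lives inside the HF branch only.
            model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
            return HuggingFaceEmbeddings(
                model_name=embeddings_model_name,
                multi_process=False,
                model_kwargs=model_kwargs,
            )
        raise ValueError(f"Unsupported embeddings type: {embeddings_type}")  # assumed error handling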