diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f947f2a..ef80eec 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,10 +26,10 @@ jobs: run: black --check --diff app/ tests/ - name: Check import ordering with isort - run: isort --check-only --diff app/ tests/ + run: isort --check-only --diff --profile black app/ tests/ - name: Lint with flake8 - run: flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402 + run: flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402,E203 test: name: Test diff --git a/Makefile b/Makefile index 31a1715..676e5c2 100644 --- a/Makefile +++ b/Makefile @@ -29,12 +29,12 @@ test-cov: lint: black --check --diff app/ tests/ - isort --check-only --diff app/ tests/ - flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402 + isort --check-only --diff --profile black app/ tests/ + flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402,E203 format: black app/ tests/ - isort app/ tests/ + isort --profile black app/ tests/ clean: find . -type d -name "__pycache__" -exec rm -r {} + diff --git a/app/api/api.py b/app/api/api.py index c2ff5a3..b80785a 100644 --- a/app/api/api.py +++ b/app/api/api.py @@ -1,10 +1,12 @@ from fastapi import APIRouter -from app.api.endpoints import auth, users, knowledge_bases, documents, messages +from app.api.endpoints import auth, documents, knowledge_bases, messages, users api_router = APIRouter() api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) api_router.include_router(users.router, prefix="/users", tags=["users"]) -api_router.include_router(knowledge_bases.router, prefix="/knowledge-bases", tags=["knowledge-bases"]) +api_router.include_router( + knowledge_bases.router, prefix="/knowledge-bases", tags=["knowledge-bases"] +) api_router.include_router(documents.router, prefix="/documents", tags=["documents"]) -api_router.include_router(messages.router, prefix="/messages", tags=["messages"]) \ No newline at end of file +api_router.include_router(messages.router, prefix="/messages", tags=["messages"]) diff --git a/app/api/deps.py b/app/api/deps.py index 0745dc5..bf87923 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -1,34 +1,40 @@ -from fastapi import Depends, HTTPException -from fastapi.security import OAuth2PasswordBearer -from jose import JWTError, jwt from datetime import datetime, timedelta from typing import Optional +from fastapi import Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt from sqlalchemy.orm import Session + +from app.core.config import settings from app.db.database import get_db from app.db.models.user import User -from app.core.config import settings from app.schemas.user import UserResponse oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") -async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> UserResponse: + +async def get_current_user( + token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) +) -> UserResponse: try: - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) user_id = payload.get("sub") if not user_id: raise HTTPException(status_code=401, detail="Invalid token") - - # get db + # get db user = db.query(User).filter(User.id == user_id).first() if not user: raise HTTPException(status_code=401, detail="User not found") - + return UserResponse.model_validate(user) except JWTError: raise HTTPException(status_code=401, detail="Invalid token") + def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str: """ Create a new JWT access token for a user @@ -39,11 +45,13 @@ def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) str: JWT access token """ to_encode = {"sub": user_id} - + if expires_delta: expire = datetime.utcnow() + expires_delta else: - expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - + expire = datetime.utcnow() + timedelta( + minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + ) + to_encode.update({"exp": expire}) - return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) \ No newline at end of file + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) diff --git a/app/api/endpoints/auth.py b/app/api/endpoints/auth.py index 552e88a..a632284 100644 --- a/app/api/endpoints/auth.py +++ b/app/api/endpoints/auth.py @@ -1,19 +1,20 @@ +from datetime import timedelta + from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm -from datetime import timedelta from sqlalchemy.orm import Session -from app.services.user_service import UserService from app.api.deps import create_access_token from app.core.config import settings from app.db.database import get_db +from app.services.user_service import UserService router = APIRouter() + @router.post("/token") async def login_for_access_token( - form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) ): """ OAuth2 compatible token login, get an access token for future requests @@ -30,10 +31,7 @@ async def login_for_access_token( # Create access token access_token = create_access_token( user_id=str(user.id), - expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), ) - - return { - "access_token": access_token, - "token_type": "bearer" - } \ No newline at end of file + + return {"access_token": access_token, "token_type": "bearer"} diff --git a/app/api/endpoints/conversations.py b/app/api/endpoints/conversations.py index d46656a..34908d5 100644 --- a/app/api/endpoints/conversations.py +++ b/app/api/endpoints/conversations.py @@ -1,82 +1,100 @@ +import logging +from functools import lru_cache from typing import List + from fastapi import APIRouter, Body, Depends from fastapi.responses import JSONResponse -from functools import lru_cache from sqlalchemy.orm import Session -from app.schemas.conversation import ConversationCreate, ConversationUpdate, ConversationResponse from app.api.deps import get_current_user +from app.api.endpoints.knowledge_bases import get_knowledge_base_service +from app.db.database import get_db +from app.repositories.conversation_repository import ConversationRepository +from app.schemas.conversation import ( + ConversationCreate, + ConversationResponse, + ConversationUpdate, +) from app.schemas.user import UserResponse from app.services.conversation_service import ConversationService -from app.repositories.conversation_repository import ConversationRepository from app.services.knowledge_base_service import KnowledgeBaseService -from app.api.endpoints.knowledge_bases import get_knowledge_base_service -from app.db.database import get_db -import logging router = APIRouter() logger = logging.getLogger(__name__) + @lru_cache() def get_conversation_repository() -> ConversationRepository: """Get conversation repository instance""" return ConversationRepository() + def get_conversation_service( - conversation_repository: ConversationRepository = Depends(get_conversation_repository), + conversation_repository: ConversationRepository = Depends( + get_conversation_repository + ), knowledge_base_service: KnowledgeBaseService = Depends(get_knowledge_base_service), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> ConversationService: """Get conversation service instance""" return ConversationService( conversation_repository=conversation_repository, knowledge_base_service=knowledge_base_service, - db=db + db=db, ) + @router.post("", response_model=ConversationResponse) async def create_conversation( payload: ConversationCreate = Body(..., description="Conversation details"), current_user: UserResponse = Depends(get_current_user), - conversation_service: ConversationService = Depends(get_conversation_service) + conversation_service: ConversationService = Depends(get_conversation_service), ): """Create a new conversation""" return await conversation_service.create_conversation(payload, current_user) + @router.get("", response_model=List[ConversationResponse]) async def list_conversations( current_user: UserResponse = Depends(get_current_user), - conversation_service: ConversationService = Depends(get_conversation_service) + conversation_service: ConversationService = Depends(get_conversation_service), ): """List all conversations for the current user""" logger.info(f"Listing conversations for user {current_user.id}") return await conversation_service.list_conversations(current_user) + @router.get("/{conversation_id}", response_model=ConversationResponse) async def get_conversation( conversation_id: str, current_user: UserResponse = Depends(get_current_user), - conversation_service: ConversationService = Depends(get_conversation_service) + conversation_service: ConversationService = Depends(get_conversation_service), ): """Get conversation details including messages""" return await conversation_service.get_conversation(conversation_id, current_user) + @router.put("/{conversation_id}", response_model=ConversationResponse) async def update_conversation( conversation_id: str, - conversation_update: ConversationUpdate = Body(..., description="Conversation details"), + conversation_update: ConversationUpdate = Body( + ..., description="Conversation details" + ), current_user: UserResponse = Depends(get_current_user), - conversation_service: ConversationService = Depends(get_conversation_service) + conversation_service: ConversationService = Depends(get_conversation_service), ): """Update conversation details""" - return await conversation_service.update_conversation(conversation_id, conversation_update, current_user) + return await conversation_service.update_conversation( + conversation_id, conversation_update, current_user + ) + @router.delete("/{conversation_id}") async def delete_conversation( conversation_id: str, current_user: UserResponse = Depends(get_current_user), - conversation_service: ConversationService = Depends(get_conversation_service) + conversation_service: ConversationService = Depends(get_conversation_service), ): """Delete a conversation and all its messages""" await conversation_service.delete_conversation(conversation_id, current_user) - return JSONResponse(content={"message": "Conversation deleted successfully"}) \ No newline at end of file + return JSONResponse(content={"message": "Conversation deleted successfully"}) diff --git a/app/api/endpoints/documents.py b/app/api/endpoints/documents.py index 242530e..f02de82 100644 --- a/app/api/endpoints/documents.py +++ b/app/api/endpoints/documents.py @@ -1,40 +1,45 @@ -from fastapi import APIRouter, Body, Depends, UploadFile, File, Form import logging + +from fastapi import APIRouter, Body, Depends from sqlalchemy.orm import Session +from app.api.deps import get_current_user +from app.core.config import settings from app.db.database import get_db +from app.repositories.document_repository import DocumentRepository +from app.repositories.knowledge_base_repository import KnowledgeBaseRepository +from app.schemas.document import DocumentResponse, DocumentUpdate from app.schemas.user import UserResponse from app.services.document_service import DocumentService from app.services.knowledge_base_service import KnowledgeBaseService, LocalFileStorage -from app.repositories.document_repository import DocumentRepository from app.services.rag.vector_store import VectorStore, get_vector_store -from app.repositories.knowledge_base_repository import KnowledgeBaseRepository from app.worker.celery import celery_app -from app.api.deps import get_current_user -from app.schemas.document import DocumentResponse, DocumentUpdate, DocumentUpload -from app.core.config import settings router = APIRouter() logger = logging.getLogger(__name__) + # Dependencies def get_file_storage() -> LocalFileStorage: """Get file storage instance""" return LocalFileStorage(upload_dir=settings.UPLOAD_DIR) + def get_knowledge_base_repository() -> KnowledgeBaseRepository: """Get knowledge base repository instance""" return KnowledgeBaseRepository() + def get_document_repository() -> DocumentRepository: """Get document repository instance""" return DocumentRepository() + def get_knowledge_base_service( repository: KnowledgeBaseRepository = Depends(get_knowledge_base_repository), vector_store: VectorStore = Depends(lambda: get_vector_store()), file_storage: LocalFileStorage = Depends(get_file_storage), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> KnowledgeBaseService: """Dependency for KnowledgeBaseService""" return KnowledgeBaseService( @@ -42,15 +47,16 @@ def get_knowledge_base_service( vector_store=vector_store, file_storage=file_storage, celery_app=celery_app, - db=db + db=db, ) + def get_document_service( kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), document_repository: DocumentRepository = Depends(get_document_repository), vector_store: VectorStore = Depends(lambda: get_vector_store()), file_storage: LocalFileStorage = Depends(get_file_storage), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> DocumentService: """Dependency for DocumentService""" return DocumentService( @@ -59,20 +65,22 @@ def get_document_service( knowledge_base_service=kb_service, file_storage=file_storage, celery_app=celery_app, - db=db + db=db, ) + @router.get("/{document_id}", response_model=DocumentResponse) async def get_document( document_id: str, current_user: UserResponse = Depends(get_current_user), - doc_service: DocumentService = Depends(get_document_service) + doc_service: DocumentService = Depends(get_document_service), ): """ Get document details by ID. """ return await doc_service.get_document(document_id, current_user) + # @router.get("/", response_model=List[DocumentResponse]) # async def list_documents( # knowledge_base_id: str, @@ -84,26 +92,28 @@ async def get_document( # """ # return await doc_service.list_documents(knowledge_base_id, current_user) + @router.put("/{document_id}", response_model=DocumentResponse) async def update_document( document_id: str, doc_update: DocumentUpdate = Body(..., description="Document details"), current_user: UserResponse = Depends(get_current_user), - doc_service: DocumentService = Depends(get_document_service) + doc_service: DocumentService = Depends(get_document_service), ): """ Update document details. """ return await doc_service.update_document(document_id, doc_update, current_user) + @router.delete("/{document_id}") async def delete_document( document_id: str, current_user: UserResponse = Depends(get_current_user), - doc_service: DocumentService = Depends(get_document_service) + doc_service: DocumentService = Depends(get_document_service), ): """ Delete a document. """ await doc_service.delete_document(document_id, current_user) - return {"message": "Document deleted successfully"} \ No newline at end of file + return {"message": "Document deleted successfully"} diff --git a/app/api/endpoints/knowledge_bases.py b/app/api/endpoints/knowledge_bases.py index 73e395e..6b82950 100644 --- a/app/api/endpoints/knowledge_bases.py +++ b/app/api/endpoints/knowledge_bases.py @@ -1,62 +1,67 @@ -from typing import Annotated, List -from fastapi import APIRouter, Body, Depends, UploadFile, File, Path, HTTPException -from fastapi.responses import JSONResponse -from functools import lru_cache -from sqlalchemy.orm import Session import csv import io +import logging +from functools import lru_cache +from typing import List +from fastapi import APIRouter, Body, Depends, File, HTTPException, Path, UploadFile +from fastapi.responses import JSONResponse +from sqlalchemy.orm import Session + +from app.api.deps import get_current_user +from app.core.config import settings +from app.core.permissions import Permission, check_permission +from app.db.database import get_db +from app.repositories.document_repository import DocumentRepository +from app.repositories.knowledge_base_repository import KnowledgeBaseRepository +from app.repositories.question_repository import QuestionRepository +from app.schemas.document import DocumentResponse, DocumentUpdate, DocumentUpload from app.schemas.knowledge_base import ( KnowledgeBaseCreate, - KnowledgeBaseUpdate, KnowledgeBaseResponse, KnowledgeBaseShareRequest, - KnowledgeBaseUnshareRequest, KnowledgeBaseSharingResponse, - SharedUserInfo + KnowledgeBaseUnshareRequest, + KnowledgeBaseUpdate, + SharedUserInfo, ) -from app.schemas.document import DocumentUpdate, DocumentResponse, DocumentUpload -from app.schemas.question import QuestionResponse, QuestionCreate, QuestionUpdate -from app.api.deps import get_current_user +from app.schemas.question import QuestionCreate, QuestionResponse, QuestionUpdate from app.schemas.user import UserResponse -from app.services.knowledge_base_service import KnowledgeBaseService, LocalFileStorage from app.services.document_service import DocumentService +from app.services.knowledge_base_service import KnowledgeBaseService, LocalFileStorage from app.services.question_service import QuestionService -from app.repositories.document_repository import DocumentRepository -from app.repositories.question_repository import QuestionRepository from app.services.rag.vector_store import VectorStore, get_vector_store -from app.repositories.knowledge_base_repository import KnowledgeBaseRepository -from app.core.config import settings from app.worker.celery import celery_app -from app.db.database import get_db -from app.core.permissions import check_permission, Permission - -import logging logger = logging.getLogger(__name__) router = APIRouter() + @lru_cache() def get_file_storage() -> LocalFileStorage: """Get file storage instance""" return LocalFileStorage(upload_dir=settings.UPLOAD_DIR) + def get_knowledge_base_repository() -> KnowledgeBaseRepository: """Get knowledge base repository instance""" return KnowledgeBaseRepository() + def get_document_repository() -> DocumentRepository: """Get document repository instance""" return DocumentRepository() + def get_question_repository() -> QuestionRepository: return QuestionRepository() + def get_knowledge_base_service( repository: KnowledgeBaseRepository = Depends(get_knowledge_base_repository), vector_store: VectorStore = Depends(lambda: get_vector_store()), file_storage: LocalFileStorage = Depends(get_file_storage), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> KnowledgeBaseService: """Dependency for KnowledgeBaseService""" return KnowledgeBaseService( @@ -64,15 +69,16 @@ def get_knowledge_base_service( vector_store=vector_store, file_storage=file_storage, celery_app=celery_app, - db=db + db=db, ) + def get_document_service( kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), document_repository: DocumentRepository = Depends(get_document_repository), vector_store: VectorStore = Depends(lambda: get_vector_store()), file_storage: LocalFileStorage = Depends(get_file_storage), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> DocumentService: """Dependency for DocumentService""" return DocumentService( @@ -81,37 +87,40 @@ def get_document_service( knowledge_base_service=kb_service, file_storage=file_storage, celery_app=celery_app, - db=db + db=db, ) + def get_question_service( kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), question_repository: QuestionRepository = Depends(get_question_repository), vector_store: VectorStore = Depends(lambda: get_vector_store()), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> QuestionService: return QuestionService( question_repository=question_repository, vector_store=vector_store, knowledge_base_service=kb_service, celery_app=celery_app, - db=db + db=db, ) + @router.post("", response_model=KnowledgeBaseResponse) async def create_knowledge_base( kb: KnowledgeBaseCreate = Body(..., description="Knowledge base details"), current_user: UserResponse = Depends(get_current_user), - kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service) + kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), ): """Create a new knowledge base""" return await kb_service.create_knowledge_base(kb, current_user) + @router.get("", response_model=List[KnowledgeBaseResponse]) async def list_knowledge_bases( current_user: UserResponse = Depends(get_current_user), kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), - _: UserResponse = Depends(check_permission(Permission.VIEW_KNOWLEDGE_BASES)) + _: UserResponse = Depends(check_permission(Permission.VIEW_KNOWLEDGE_BASES)), ): """ List knowledge bases based on role: @@ -121,310 +130,368 @@ async def list_knowledge_bases( """ return await kb_service.list_knowledge_bases(current_user) + @router.get("/shared-with-me", response_model=List[KnowledgeBaseResponse]) async def get_shared_knowledge_bases( kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), current_user: UserResponse = Depends(get_current_user), - _: UserResponse = Depends(check_permission(Permission.VIEW_KNOWLEDGE_BASES)) + _: UserResponse = Depends(check_permission(Permission.VIEW_KNOWLEDGE_BASES)), ): """Get all knowledge bases shared with the current user""" return await kb_service.list_shared_knowledge_bases(current_user) + @router.get("/{kb_id}", response_model=KnowledgeBaseResponse) async def get_knowledge_base( kb_id: str, current_user: UserResponse = Depends(get_current_user), - kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service) + kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), ): """Get knowledge base details""" return await kb_service.get_knowledge_base(kb_id, current_user) + @router.put("/{kb_id}", response_model=KnowledgeBaseResponse) async def update_knowledge_base( kb_id: str, kb_update: KnowledgeBaseUpdate = Body(..., description="Knowledge base details"), current_user: UserResponse = Depends(get_current_user), - kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service) + kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), ): """Update knowledge base""" return await kb_service.update_knowledge_base(kb_id, kb_update, current_user) + @router.delete("/{kb_id}") async def delete_knowledge_base( kb_id: str, current_user: UserResponse = Depends(get_current_user), - kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service) + kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), ): """Delete knowledge base and all its documents""" await kb_service.delete_knowledge_base(kb_id, current_user) return JSONResponse(content={"message": "Knowledge base deleted successfully"}) + @router.post("/{kb_id}/share", response_model=KnowledgeBaseSharingResponse) async def share_knowledge_base( kb_id: str, share_data: KnowledgeBaseShareRequest, kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), current_user: UserResponse = Depends(get_current_user), - _: UserResponse = Depends(check_permission(Permission.UPDATE_KNOWLEDGE_BASE)) + _: UserResponse = Depends(check_permission(Permission.UPDATE_KNOWLEDGE_BASE)), ): """Share a knowledge base with another user""" - success = await kb_service.share_knowledge_base(kb_id, share_data.user_id, current_user) + success = await kb_service.share_knowledge_base( + kb_id, share_data.user_id, current_user + ) return KnowledgeBaseSharingResponse( success=success, - message="Knowledge base shared successfully" if success else "Failed to share knowledge base" + message=( + "Knowledge base shared successfully" + if success + else "Failed to share knowledge base" + ), ) + @router.post("/{kb_id}/unshare", response_model=KnowledgeBaseSharingResponse) async def unshare_knowledge_base( kb_id: str, unshare_data: KnowledgeBaseUnshareRequest, kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), current_user: UserResponse = Depends(get_current_user), - _: UserResponse = Depends(check_permission(Permission.UPDATE_KNOWLEDGE_BASE)) + _: UserResponse = Depends(check_permission(Permission.UPDATE_KNOWLEDGE_BASE)), ): """Remove a user's access to a knowledge base""" - success = await kb_service.unshare_knowledge_base(kb_id, unshare_data.user_id, current_user) + success = await kb_service.unshare_knowledge_base( + kb_id, unshare_data.user_id, current_user + ) return KnowledgeBaseSharingResponse( success=success, - message="Knowledge base access removed successfully" if success else "Failed to remove knowledge base access" + message=( + "Knowledge base access removed successfully" + if success + else "Failed to remove knowledge base access" + ), ) + @router.get("/{kb_id}/shared-users", response_model=List[SharedUserInfo]) async def get_shared_users( kb_id: str, kb_service: KnowledgeBaseService = Depends(get_knowledge_base_service), current_user: UserResponse = Depends(get_current_user), - _: UserResponse = Depends(check_permission(Permission.VIEW_KNOWLEDGE_BASES)) + _: UserResponse = Depends(check_permission(Permission.VIEW_KNOWLEDGE_BASES)), ): """Get all users who have access to a knowledge base""" return await kb_service.list_shared_users(kb_id, current_user) + @router.post("/{kb_id}/documents", response_model=DocumentResponse) async def create_document( kb_id: str = Path(..., description="Knowledge base ID"), file: UploadFile = File(..., description="Document to upload"), current_user: UserResponse = Depends(get_current_user), - doc_service: DocumentService = Depends(get_document_service) + doc_service: DocumentService = Depends(get_document_service), ): """Upload a new document to a knowledge base""" logger.info(f"Uploading document {file.filename} to knowledge base {kb_id}") - payload = DocumentUpload(title=file.filename, content=file.file.read(), knowledge_base_id=kb_id, content_type=file.content_type) + payload = DocumentUpload( + title=file.filename, + content=file.file.read(), + knowledge_base_id=kb_id, + content_type=file.content_type, + ) return await doc_service.create_document(kb_id, payload, current_user) + @router.get("/{kb_id}/documents", response_model=List[DocumentResponse]) async def list_documents( kb_id: str = Path(..., description="Knowledge base ID"), current_user: UserResponse = Depends(get_current_user), - doc_service: DocumentService = Depends(get_document_service) + doc_service: DocumentService = Depends(get_document_service), ): """List all documents in a knowledge base""" return await doc_service.list_documents(kb_id, current_user) + @router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse) async def get_document( kb_id: str, doc_id: str, current_user: UserResponse = Depends(get_current_user), - doc_service: DocumentService = Depends(get_document_service) + doc_service: DocumentService = Depends(get_document_service), ): """Get document details""" return await doc_service.get_document(doc_id, current_user) + @router.put("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse) async def update_document( doc_id: str, doc_update: DocumentUpdate = Body(..., description="Document details"), current_user: UserResponse = Depends(get_current_user), - doc_service: DocumentService = Depends(get_document_service) + doc_service: DocumentService = Depends(get_document_service), ): """Update document details""" return await doc_service.update_document(doc_id, doc_update, current_user) + @router.delete("/{kb_id}/documents/{doc_id}") async def delete_document( kb_id: str, doc_id: str, current_user: UserResponse = Depends(get_current_user), - doc_service: DocumentService = Depends(get_document_service) + doc_service: DocumentService = Depends(get_document_service), ): """Delete a document""" await doc_service.delete_document(doc_id, current_user) return JSONResponse(content={"message": "Document deleted successfully"}) + @router.post("/{kb_id}/documents/{doc_id}/retry", response_model=DocumentResponse) async def retry_document( kb_id: str = Path(..., description="Knowledge base ID"), doc_id: str = Path(..., description="Document ID"), current_user: UserResponse = Depends(get_current_user), - doc_service: DocumentService = Depends(get_document_service) + doc_service: DocumentService = Depends(get_document_service), ): """ Retry processing a failed document. - + This endpoint allows you to retry processing a document that failed during the initial ingestion. It can only be used on documents with a FAILED status. """ return await doc_service.retry_failed_document(kb_id, doc_id, current_user) + # Question endpoints + @router.get("/{kb_id}/questions", response_model=List[QuestionResponse]) async def list_questions( kb_id: str, skip: int = 0, limit: int = 100, current_user: UserResponse = Depends(get_current_user), - question_service: QuestionService = Depends(get_question_service) + question_service: QuestionService = Depends(get_question_service), ): """List all questions for a knowledge base""" return await question_service.list_questions(kb_id, current_user, skip, limit) + @router.get("/{kb_id}/questions/{question_id}", response_model=QuestionResponse) async def get_question( kb_id: str, question_id: str, current_user: UserResponse = Depends(get_current_user), - question_service: QuestionService = Depends(get_question_service) + question_service: QuestionService = Depends(get_question_service), ): """Get a specific question by ID""" return await question_service.get_question(question_id, current_user) + @router.post("/{kb_id}/questions", response_model=QuestionResponse) async def create_question( kb_id: str, question: QuestionCreate, current_user: UserResponse = Depends(get_current_user), - question_service: QuestionService = Depends(get_question_service) + question_service: QuestionService = Depends(get_question_service), ): """Create a new question in a knowledge base""" return await question_service.create_question(kb_id, question, current_user) + @router.put("/{kb_id}/questions/{question_id}", response_model=QuestionResponse) async def update_question( kb_id: str, question_id: str, question_update: QuestionUpdate, current_user: UserResponse = Depends(get_current_user), - question_service: QuestionService = Depends(get_question_service) + question_service: QuestionService = Depends(get_question_service), ): """Update a question""" - return await question_service.update_question(question_id, question_update, current_user) + return await question_service.update_question( + question_id, question_update, current_user + ) + @router.delete("/{kb_id}/questions/{question_id}") async def delete_question( kb_id: str, question_id: str, current_user: UserResponse = Depends(get_current_user), - question_service: QuestionService = Depends(get_question_service) + question_service: QuestionService = Depends(get_question_service), ): """Delete a question""" await question_service.delete_question(question_id, current_user) return {"message": "Question deleted successfully"} + @router.get("/{kb_id}/questions/{question_id}/status") async def get_question_status( kb_id: str, question_id: str, current_user: UserResponse = Depends(get_current_user), - question_service: QuestionService = Depends(get_question_service) + question_service: QuestionService = Depends(get_question_service), ): """Get the status of a question""" return await question_service.get_question_status(question_id, current_user) + @router.post("/{kb_id}/questions/bulk-upload") async def bulk_upload_questions( kb_id: str, file: UploadFile = File(...), current_user: UserResponse = Depends(get_current_user), - question_service: QuestionService = Depends(get_question_service) + question_service: QuestionService = Depends(get_question_service), ): """Bulk upload questions from a CSV file""" # Check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise HTTPException(status_code=400, detail="Only CSV files are allowed") - + try: # Read CSV file contents = await file.read() - + # Try to decode with different encodings if UTF-8 fails try: - csv_data = contents.decode('utf-8') + csv_data = contents.decode("utf-8") except UnicodeDecodeError: try: - csv_data = contents.decode('latin-1') - except: - raise HTTPException(status_code=400, detail="Unable to decode CSV file. Please ensure it's properly encoded (UTF-8 or Latin-1).") - + csv_data = contents.decode("latin-1") + except Exception: + raise HTTPException( + status_code=400, + detail="Unable to decode CSV file. Please ensure it's properly encoded (UTF-8 or Latin-1).", + ) + # Handle empty file if not csv_data.strip(): - raise HTTPException(status_code=400, detail="The uploaded CSV file is empty.") - + raise HTTPException( + status_code=400, detail="The uploaded CSV file is empty." + ) + # Parse CSV try: csv_reader = csv.DictReader(io.StringIO(csv_data)) - + # Check if we got any fieldnames if not csv_reader.fieldnames: - raise HTTPException(status_code=400, detail="Could not parse CSV headers. Please ensure the file is properly formatted.") - + raise HTTPException( + status_code=400, + detail="Could not parse CSV headers. Please ensure the file is properly formatted.", + ) + # Validate required fields - required_fields = ['question', 'answer', 'answer_type'] - missing_fields = [field for field in required_fields if field not in csv_reader.fieldnames] - + required_fields = ["question", "answer", "answer_type"] + missing_fields = [ + field for field in required_fields if field not in csv_reader.fieldnames + ] + if missing_fields: raise HTTPException( - status_code=400, - detail=f"CSV is missing required columns: {', '.join(missing_fields)}. Required columns are: {', '.join(required_fields)}" + status_code=400, + detail=f"CSV is missing required columns: {', '.join(missing_fields)}. Required columns are: {', '.join(required_fields)}", ) - + # Process questions - results = { - "success": 0, - "failed": 0, - "errors": [] - } - + results = {"success": 0, "failed": 0, "errors": []} + rows = list(csv_reader) if not rows: - return {"success": 0, "failed": 0, "errors": ["The CSV file contains no data rows."]} - - for row_idx, row in enumerate(rows, start=2): # Start at 2 to account for header row + return { + "success": 0, + "failed": 0, + "errors": ["The CSV file contains no data rows."], + } + + for row_idx, row in enumerate( + rows, start=2 + ): # Start at 2 to account for header row try: # Validate answer_type - answer_type = row['answer_type'].strip().upper() + answer_type = row["answer_type"].strip().upper() if answer_type not in ["DIRECT", "SQL_QUERY"]: - raise ValueError(f"Invalid answer_type '{answer_type}'. Must be one of: DIRECT, SQL_QUERY") - + raise ValueError( + f"Invalid answer_type '{answer_type}'. Must be one of: DIRECT, SQL_QUERY" + ) + # Validate required values - if not row['question'].strip(): + if not row["question"].strip(): raise ValueError("Question cannot be empty") - if not row['answer'].strip(): + if not row["answer"].strip(): raise ValueError("Answer cannot be empty") - + # Create question model question_data = QuestionCreate( - question=row['question'].strip(), - answer=row['answer'].strip(), - answer_type=answer_type + question=row["question"].strip(), + answer=row["answer"].strip(), + answer_type=answer_type, ) - + # Create question - await question_service.create_question(kb_id, question_data, current_user) + await question_service.create_question( + kb_id, question_data, current_user + ) results["success"] += 1 - + except Exception as e: results["failed"] += 1 results["errors"].append(f"Row {row_idx}: {str(e)}") - + return results - + except csv.Error as e: raise HTTPException(status_code=400, detail=f"CSV parsing error: {str(e)}") - + except HTTPException: # Re-raise HTTP exceptions raise except Exception as e: # Catch all other exceptions - raise HTTPException(status_code=500, detail=f"Error processing CSV file: {str(e)}") \ No newline at end of file + raise HTTPException( + status_code=500, detail=f"Error processing CSV file: {str(e)}" + ) diff --git a/app/api/endpoints/messages.py b/app/api/endpoints/messages.py index f3187a3..ac99b58 100644 --- a/app/api/endpoints/messages.py +++ b/app/api/endpoints/messages.py @@ -1,42 +1,47 @@ -from typing import List, Dict, Any -from fastapi import APIRouter, Body, Depends, Path from functools import lru_cache +from typing import List + +from fastapi import APIRouter, Body, Depends, Path from sqlalchemy.orm import Session -from app.schemas.message import MessageCreate, MessageResponse from app.api.deps import get_current_user -from app.schemas.user import UserResponse -from app.services.message_service import MessageService -from app.repositories.message_repository import MessageRepository from app.api.endpoints.conversations import get_conversation_service from app.db.database import get_db -from app.services.query_router import get_query_router +from app.repositories.message_repository import MessageRepository +from app.schemas.message import MessageCreate, MessageResponse +from app.schemas.user import UserResponse +from app.services.message_service import MessageService router = APIRouter() + @lru_cache() def get_message_repository() -> MessageRepository: """Get message repository instance""" return MessageRepository() + def get_message_service( message_repository: MessageRepository = Depends(get_message_repository), - conversation_service = Depends(get_conversation_service), - db: Session = Depends(get_db) + conversation_service=Depends(get_conversation_service), + db: Session = Depends(get_db), ) -> MessageService: """Get message service instance""" return MessageService( message_repository=message_repository, conversation_service=conversation_service, - db=db + db=db, ) + @router.post("/{conversation_id}/messages", response_model=MessageResponse) async def create_message( - conversation_id: str = Path(..., description="ID of the conversation this message belongs to"), + conversation_id: str = Path( + ..., description="ID of the conversation this message belongs to" + ), payload: MessageCreate = Body(..., description="Message details"), current_user: UserResponse = Depends(get_current_user), - message_service: MessageService = Depends(get_message_service) + message_service: MessageService = Depends(get_message_service), ): """Create a new message in a conversation""" return await message_service.create_message( @@ -45,21 +50,23 @@ async def create_message( current_user, ) + @router.get("/{conversation_id}/messages", response_model=List[MessageResponse]) async def list_messages( conversation_id: str, current_user: UserResponse = Depends(get_current_user), - message_service: MessageService = Depends(get_message_service) + message_service: MessageService = Depends(get_message_service), ): """List all messages in a conversation""" return await message_service.list_messages(conversation_id, current_user) + @router.get("/{conversation_id}/messages/{message_id}", response_model=MessageResponse) async def get_message( conversation_id: str, message_id: str, current_user: UserResponse = Depends(get_current_user), - message_service: MessageService = Depends(get_message_service) + message_service: MessageService = Depends(get_message_service), ): """Get a message by ID""" return await message_service.get_message(conversation_id, message_id, current_user) diff --git a/app/api/endpoints/users.py b/app/api/endpoints/users.py index cee9bea..2aae902 100644 --- a/app/api/endpoints/users.py +++ b/app/api/endpoints/users.py @@ -1,32 +1,32 @@ from typing import List + from fastapi import APIRouter, Depends from sqlalchemy.orm import Session -from app.db.models.user import User -from app.schemas.user import UserCreate, UserUpdate, UserResponse, UserWithPermissions -from app.services.user_service import UserService from app.api.deps import get_current_user +from app.core.permissions import Permission, check_permission, get_permissions_for_role from app.db.database import get_db -from app.core.permissions import get_permissions_for_role, Permission, check_permission +from app.db.models.user import User +from app.schemas.user import UserCreate, UserResponse, UserUpdate, UserWithPermissions +from app.services.user_service import UserService router = APIRouter() + @router.post("", response_model=UserResponse) -async def create_user( - user_data: UserCreate, - db: Session = Depends(get_db) -): +async def create_user(user_data: UserCreate, db: Session = Depends(get_db)): """ Create a new user. """ user_service = UserService(db) return await user_service.create_user(user_data) + @router.get("", response_model=List[UserResponse]) async def list_users( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - _: User = Depends(check_permission(Permission.VIEW_USERS)) + _: User = Depends(check_permission(Permission.VIEW_USERS)), ): """ List all users (admin and owner roles). @@ -34,10 +34,10 @@ async def list_users( user_service = UserService(db) return await user_service.list_users(current_user) + @router.get("/me", response_model=UserWithPermissions) async def get_current_user_info( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """ Get current user information with permissions. @@ -45,21 +45,19 @@ async def get_current_user_info( """ user_service = UserService(db) user = await user_service.get_user(str(current_user.id)) - + # Get permissions for the user's role permissions = [perm.value for perm in get_permissions_for_role(user.role)] - + # Create a UserWithPermissions response - return { - **user.dict(), - "permissions": permissions - } + return {**user.dict(), "permissions": permissions} + @router.get("/{user_id}", response_model=UserResponse) async def get_user( user_id: str, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Get user by ID. @@ -67,12 +65,13 @@ async def get_user( user_service = UserService(db) return await user_service.get_user(user_id) + @router.put("/{user_id}", response_model=UserResponse) async def update_user( user_id: str, user_update: UserUpdate, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Update user. @@ -80,15 +79,16 @@ async def update_user( user_service = UserService(db) return await user_service.update_user(user_id, user_update, current_user) + @router.delete("/{user_id}") async def delete_user( user_id: str, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Delete user (admin only). """ user_service = UserService(db) await user_service.delete_user(user_id, current_user) - return {"message": "User deleted successfully"} \ No newline at end of file + return {"message": "User deleted successfully"} diff --git a/app/core/config.py b/app/core/config.py index 2ca5199..d30d374 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,7 +1,9 @@ -from typing import List, Any, Dict, Optional, Union -from pydantic_settings import BaseSettings -from pydantic import EmailStr, AnyHttpUrl, PostgresDsn, field_validator import os +from typing import List, Optional, Union + +from pydantic import AnyHttpUrl, field_validator +from pydantic_settings import BaseSettings + class Settings(BaseSettings): # App Settings @@ -12,7 +14,9 @@ class Settings(BaseSettings): # Security SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key") ALGORITHM: str = "HS256" - ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")) # 24 hours + ACCESS_TOKEN_EXPIRE_MINUTES: int = int( + os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440") + ) # 24 hours # Email Settings SENDGRID_API_KEY: str = os.getenv("SENDGRID_API_KEY", "") @@ -22,8 +26,12 @@ class Settings(BaseSettings): PINECONE_API_KEY: str = os.getenv("PINECONE_API_KEY", "") PINECONE_ENVIRONMENT: str = os.getenv("PINECONE_ENVIRONMENT", "") PINECONE_INDEX_NAME: str = os.getenv("PINECONE_INDEX_NAME", "docbrain") - PINECONE_SUMMARY_INDEX_NAME: str = os.getenv("PINECONE_SUMMARY_INDEX_NAME", "summary") - PINECONE_QUESTIONS_INDEX_NAME: str = os.getenv("PINECONE_QUESTIONS_INDEX_NAME", "questions") + PINECONE_SUMMARY_INDEX_NAME: str = os.getenv( + "PINECONE_SUMMARY_INDEX_NAME", "summary" + ) + PINECONE_QUESTIONS_INDEX_NAME: str = os.getenv( + "PINECONE_QUESTIONS_INDEX_NAME", "questions" + ) RETRIEVER_TYPE: str = os.getenv("RETRIEVER_TYPE", "pinecone") # LLM @@ -39,7 +47,7 @@ class Settings(BaseSettings): RAG_TOP_K: int = 3 RAG_SIMILARITY_THRESHOLD: float = 0.3 RERANKER_TYPE: str = os.getenv("RERANKER_TYPE", "flag") - + @property def WHITELISTED_EMAIL_LIST(self) -> List[str]: return [email.strip() for email in self.WHITELISTED_EMAILS.split(",")] @@ -57,7 +65,7 @@ def WHITELISTED_EMAIL_LIST(self) -> List[str]: MYSQL_USER: str = os.getenv("MYSQL_USER", "docbrain") MYSQL_PASSWORD: str = os.getenv("MYSQL_PASSWORD", "docbrain") MYSQL_DATABASE: str = os.getenv("MYSQL_DATABASE", "docbrain") - + @property def DATABASE_URL(self) -> str: """Get SQLAlchemy database URL""" @@ -65,15 +73,21 @@ def DATABASE_URL(self) -> str: # Celery CELERY_BROKER_URL: str = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") - CELERY_RESULT_BACKEND: str = os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/0") + CELERY_RESULT_BACKEND: str = os.getenv( + "CELERY_RESULT_BACKEND", "redis://localhost:6379/0" + ) # CORS - CORS_ORIGINS: str = os.getenv("CORS_ORIGINS", "http://localhost:5173,http://127.0.0.1:5173") + CORS_ORIGINS: str = os.getenv( + "CORS_ORIGINS", "http://localhost:5173,http://127.0.0.1:5173" + ) @property def CORS_ORIGIN_LIST(self) -> List[str]: """Parse comma-separated CORS origins.""" - return [origin.strip() for origin in self.CORS_ORIGINS.split(",") if origin.strip()] + return [ + origin.strip() for origin in self.CORS_ORIGINS.split(",") if origin.strip() + ] BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] @@ -88,7 +102,7 @@ def CORS_ORIGIN_LIST(self) -> List[str]: def STORAGE_URL(self) -> str: """Get SQLAlchemy database URL""" return f"mysql://{self.STORAGE_USER}:{self.STORAGE_PASSWORD}@{self.STORAGE_HOST}:{self.STORAGE_PORT}/{self.STORAGE_DATABASE}" - + @field_validator("BACKEND_CORS_ORIGINS") def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]: if isinstance(v, str) and not v.startswith("["): @@ -100,11 +114,16 @@ def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str LLM_PROVIDER: str = os.getenv("LLM_PROVIDER", "gemini") # Default to gemini OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "") - DEFAULT_LLM_MODEL: Optional[str] = os.getenv("DEFAULT_LLM_MODEL", None) # Default model based on provider - EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "text-embedding-004") # Default embedding model + DEFAULT_LLM_MODEL: Optional[str] = os.getenv( + "DEFAULT_LLM_MODEL", None + ) # Default model based on provider + EMBEDDING_MODEL: str = os.getenv( + "EMBEDDING_MODEL", "text-embedding-004" + ) # Default embedding model class Config: env_file = ".env" case_sensitive = True -settings = Settings() \ No newline at end of file + +settings = Settings() diff --git a/app/core/middleware.py b/app/core/middleware.py index 53826d5..779ef59 100644 --- a/app/core/middleware.py +++ b/app/core/middleware.py @@ -1,14 +1,13 @@ -import time import logging +import time from collections import defaultdict +from typing import Callable, Dict, List, Optional -from fastapi import Request, HTTPException, status +from fastapi import HTTPException, Request, status from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse -from typing import Dict, List, Callable, Optional from app.core.permissions import Permission, get_permissions_for_role -from app.db.models.user import UserRole logger = logging.getLogger(__name__) @@ -19,16 +18,16 @@ class PermissionsMiddleware(BaseHTTPMiddleware): This middleware is automatically applied to all routes and checks if the user has the required permissions based on path and method. """ - + def __init__( - self, + self, app, path_permissions: Dict[str, Dict[str, List[Permission]]] = None, public_paths: List[str] = None, ): """ Initialize the middleware - + Args: app: The FastAPI application path_permissions: Dict mapping paths to methods to required permissions @@ -38,19 +37,19 @@ def __init__( super().__init__(app) self.path_permissions = path_permissions or {} self.public_paths = public_paths or [ - "/docs", - "/redoc", + "/docs", + "/redoc", "/openapi.json", "/auth/login", "/auth/register", "/auth/password-reset", "/auth/password-reset-confirm", ] - + async def dispatch(self, request: Request, call_next: Callable): """ Process the request and check permissions - + Args: request: The FastAPI request call_next: The next middleware or endpoint to call @@ -58,42 +57,42 @@ async def dispatch(self, request: Request, call_next: Callable): # Skip permission check for public paths if any(request.url.path.startswith(path) for path in self.public_paths): return await call_next(request) - + # Get user from request state (set by authentication middleware) user = getattr(request.state, "user", None) - + # If no user is authenticated and path is not public, return 401 if not user: return await call_next(request) # Let the endpoint handle authentication - + # Check if path has permission requirements path_match = None for path in self.path_permissions: if request.url.path.startswith(path): path_match = path break - + if not path_match: # No permissions defined for this path, allow access return await call_next(request) - + # Get method-specific permissions method_permissions = self.path_permissions[path_match].get(request.method, []) - + if not method_permissions: # No permissions defined for this method, allow access return await call_next(request) - + # Get user permissions based on role user_permissions = get_permissions_for_role(user.role) - + # Check if user has all required permissions if not all(perm in user_permissions for perm in method_permissions): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to access this resource", ) - + # User has permissions, proceed with request return await call_next(request) @@ -110,7 +109,6 @@ async def dispatch(self, request: Request, call_next: Callable): "PUT": [Permission.UPDATE_KNOWLEDGE_BASE], "DELETE": [Permission.DELETE_KNOWLEDGE_BASE], }, - # Document routes "/api/documents": { "GET": [Permission.VIEW_DOCUMENTS], @@ -120,7 +118,6 @@ async def dispatch(self, request: Request, call_next: Callable): "GET": [Permission.VIEW_DOCUMENTS], "DELETE": [Permission.DELETE_DOCUMENT], }, - # Conversation routes "/api/conversations": { "GET": [Permission.CONVERSE_WITH_KNOWLEDGE_BASE], @@ -129,7 +126,6 @@ async def dispatch(self, request: Request, call_next: Callable): "/api/conversations/": { "GET": [Permission.CONVERSE_WITH_KNOWLEDGE_BASE], }, - # User management routes "/api/users": { "GET": [Permission.VIEW_USERS], @@ -143,7 +139,6 @@ async def dispatch(self, request: Request, call_next: Callable): "PUT": [Permission.UPDATE_USER], "DELETE": [Permission.DELETE_USER], }, - # System routes "/api/system": { "GET": [Permission.MANAGE_SYSTEM], @@ -171,7 +166,12 @@ def __init__( ): super().__init__(app) self.requests_per_minute = requests_per_minute - self.exempt_paths = exempt_paths or ["/health", "/docs", "/openapi.json", "/redoc"] + self.exempt_paths = exempt_paths or [ + "/health", + "/docs", + "/openapi.json", + "/redoc", + ] # {client_ip: [timestamp, ...]} self._requests: Dict[str, List[float]] = defaultdict(list) @@ -205,4 +205,4 @@ async def dispatch(self, request: Request, call_next: Callable): ) self._requests[client_ip].append(now) - return await call_next(request) \ No newline at end of file + return await call_next(request) diff --git a/app/core/permissions.py b/app/core/permissions.py index 652817d..916b2b2 100644 --- a/app/core/permissions.py +++ b/app/core/permissions.py @@ -1,7 +1,7 @@ -from enum import Enum, auto -from typing import Dict, List, Set +from enum import Enum +from typing import Dict, List + from fastapi import Depends, HTTPException, status -from sqlalchemy.orm import Session from app.api.deps import get_current_user from app.db.models.user import UserRole @@ -14,21 +14,21 @@ class Permission(str, Enum): CREATE_KNOWLEDGE_BASE = "create_knowledge_base" UPDATE_KNOWLEDGE_BASE = "update_knowledge_base" DELETE_KNOWLEDGE_BASE = "delete_knowledge_base" - + # Document permissions VIEW_DOCUMENTS = "view_documents" UPLOAD_DOCUMENT = "upload_document" DELETE_DOCUMENT = "delete_document" - + # Conversation permissions CONVERSE_WITH_KNOWLEDGE_BASE = "converse_with_knowledge_base" - + # User management permissions VIEW_USERS = "view_users" CREATE_USER = "create_user" UPDATE_USER = "update_user" DELETE_USER = "delete_user" - + # System management permissions MANAGE_SYSTEM = "manage_system" @@ -82,26 +82,29 @@ def endpoint(_, permission=Depends(check_permission(Permission.SOME_PERMISSION)) # This will only execute if the user has the required permission pass """ - async def permission_dependency(current_user: UserResponse = Depends(get_current_user)): + + async def permission_dependency( + current_user: UserResponse = Depends(get_current_user), + ): if current_user.role not in ROLE_PERMISSIONS: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Role {current_user.role} has no defined permissions", ) - + # Get permissions for the user's role user_permissions = get_permissions_for_role(current_user.role) - + # Check if the user has the required permission if required_permission not in user_permissions: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Permission denied: {required_permission} required", ) - + # User has the permission, return the current user return current_user - + return permission_dependency @@ -114,28 +117,31 @@ def endpoint(_, permission=Depends(require_permissions([Permission.PERM1, Permis # This will only execute if the user has ALL the required permissions pass """ - async def permissions_dependency(current_user: UserResponse = Depends(get_current_user)): + + async def permissions_dependency( + current_user: UserResponse = Depends(get_current_user), + ): if current_user.role not in ROLE_PERMISSIONS: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Role {current_user.role} has no defined permissions", ) - + # Get permissions for the user's role user_permissions = get_permissions_for_role(current_user.role) - + # Check if the user has all required permissions missing_permissions = [ perm for perm in required_permissions if perm not in user_permissions ] - + if missing_permissions: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Permission denied: missing {', '.join(missing_permissions)}", ) - + # User has all permissions, return the current user return current_user - - return permissions_dependency \ No newline at end of file + + return permissions_dependency diff --git a/app/core/prompts.py b/app/core/prompts.py index 63040ef..63493ad 100644 --- a/app/core/prompts.py +++ b/app/core/prompts.py @@ -5,87 +5,81 @@ Prompts are organized by domain and purpose, and can be parameterized with variables. """ -from typing import Dict, Any, Optional -import jinja2 import logging +import jinja2 + logger = logging.getLogger(__name__) # Configure Jinja2 environment for template rendering _template_env = jinja2.Environment( autoescape=False, # We don't need HTML escaping for prompts trim_blocks=True, - lstrip_blocks=True + lstrip_blocks=True, ) + class PromptRegistry: """Registry for managing and accessing prompts throughout the application.""" - + # Organized by domain/module and then by purpose PROMPTS = { - "query_router": { - "analyze_query": """ + "query_router": {"analyze_query": """ You are a query router for a hybrid retrieval system. Your job is to determine whether to route a user query to: - - 1. TAG (Table Augmented Generation) - for queries that need access to structured data and would be best answered with SQL + + 1. TAG (Table Augmented Generation) - for queries that need access to structured data and would be best answered with SQL 2. RAG (Retrieval Augmented Generation) - for queries about unstructured text/content - + Routes to TAG when: - The query asks about statistical information (averages, counts, sums) - The query explicitly asks for database information - The query requests tabular data, spreadsheets, or data analysis - The query involves filtering, sorting, or comparing quantitative data - The query is clearly asking for information that would be stored in a structured format - + Routes to RAG when: - The query asks about concepts, explanations, or general information - The query is looking for specific text content - The query seems to be related to documents, reports, or unstructured content - The query is asking about procedures, policies, or general knowledge - + For the following query, determine the appropriate service (tag or rag), provide a confidence score (0-1), and explain your reasoning. Return your answer as a JSON object with the keys: service, confidence, reasoning. - + User Query: {{ query }} - """ - }, - "tag_service": { - "generate_sql": """ + """}, + "tag_service": {"generate_sql": """ You are an AI assistant that converts natural language questions into SQL queries. - + I have the following database tables: {{ schema_text }} - + Given this schema, please generate a SQL query that answers the following question: "{{ query }}" - + Return ONLY a valid SQL query without any explanations or markdown formatting. Make sure the query is compatible with common SQL dialects. Do not use features specific to one SQL dialect unless necessary. - """ - }, - "rag_service": { - "generate_answer": """ + """}, + "rag_service": {"generate_answer": """ You are an assistant that provides accurate, helpful answers based on the given context. - + CONTEXT: {{ context }} - + USER QUERY: {{ query }} - + Based only on the context provided, answer the user query. If the context doesn't contain enough information to provide a complete answer, say so. Cite relevant parts of the context as part of your answer using [Document: Title] format. - """ - }, - "ingestor": { - "generate_table_schema": """ - Generate a SQL database create table query for the given table name and headers. + """}, + "ingestor": {"generate_table_schema": """ + Generate a SQL database create table query for the given table name and headers. Make sure to use headers as column names and rows as sample data. - Rows contain sample data for the table. + Rows contain sample data for the table. Use your understanding to extrapolate scenario where datatype is not obvious, or might be different from the sample data. - + Example: CREATE TABLE IF NOT EXISTS {{ table_name }} ( {{ headers[0] }} VARCHAR(255) NOT NULL, {{ headers[1] }} VARCHAR(255) NULL, @@ -95,21 +89,20 @@ class PromptRegistry: Table name: {{ table_name }} Headers: {{ headers }} Rows: {{ rows }} - """ - } + """}, # Add more domains and prompts as needed } - + @classmethod def get_prompt(cls, domain: str, prompt_name: str, **kwargs) -> str: """ Get a prompt by domain and name, with optional variable substitution. - + Args: domain: The domain or module the prompt belongs to prompt_name: The specific prompt identifier **kwargs: Variables to substitute in the prompt template - + Returns: The rendered prompt as a string """ @@ -118,31 +111,33 @@ def get_prompt(cls, domain: str, prompt_name: str, **kwargs) -> str: if domain not in cls.PROMPTS: logger.warning(f"Domain '{domain}' not found in prompt registry") return "" - + if prompt_name not in cls.PROMPTS[domain]: logger.warning(f"Prompt '{prompt_name}' not found in domain '{domain}'") return "" - + raw_prompt = cls.PROMPTS[domain][prompt_name] - + # If no variables to substitute, return the raw prompt if not kwargs: return raw_prompt - + # Render the template with the provided variables template = _template_env.from_string(raw_prompt) return template.render(**kwargs) - + except Exception as e: logger.error(f"Error rendering prompt '{domain}.{prompt_name}': {e}") # Return an empty string or the raw template in case of error return cls.PROMPTS.get(domain, {}).get(prompt_name, "") @classmethod - def register_prompt(cls, domain: str, prompt_name: str, prompt_template: str) -> None: + def register_prompt( + cls, domain: str, prompt_name: str, prompt_template: str + ) -> None: """ Register a new prompt or update an existing one. - + Args: domain: The domain or module the prompt belongs to prompt_name: The specific prompt identifier @@ -151,11 +146,12 @@ def register_prompt(cls, domain: str, prompt_name: str, prompt_template: str) -> # Create domain if it doesn't exist if domain not in cls.PROMPTS: cls.PROMPTS[domain] = {} - + # Register or update the prompt cls.PROMPTS[domain][prompt_name] = prompt_template logger.info(f"Registered prompt '{domain}.{prompt_name}'") + # Simple alias for brevity in imports get_prompt = PromptRegistry.get_prompt -register_prompt = PromptRegistry.register_prompt \ No newline at end of file +register_prompt = PromptRegistry.register_prompt diff --git a/app/core/security.py b/app/core/security.py index a33e789..a97d274 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -3,14 +3,16 @@ # Password hashing context pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + def verify_password(plain_password: str, hashed_password: str) -> bool: """ Verify a plain password against a hashed password """ return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password: str) -> str: """ Hash a password for storing """ - return pwd_context.hash(password) \ No newline at end of file + return pwd_context.hash(password) diff --git a/app/db/base_class.py b/app/db/base_class.py index 7f3dc00..1f675d0 100644 --- a/app/db/base_class.py +++ b/app/db/base_class.py @@ -1,16 +1,19 @@ +import uuid + +from sqlalchemy import Column, DateTime, String from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, String, DateTime from sqlalchemy.sql import func -import uuid # Create base class for SQLAlchemy models Base = declarative_base() + # Define a base model class with common fields class BaseModel(Base): """Base model class with common fields for all SQLAlchemy models""" + __abstract__ = True - + id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) created_at = Column(DateTime, default=func.now()) - updated_at = Column(DateTime, onupdate=func.now()) \ No newline at end of file + updated_at = Column(DateTime, onupdate=func.now()) diff --git a/app/db/database.py b/app/db/database.py index e84178b..ff046b0 100644 --- a/app/db/database.py +++ b/app/db/database.py @@ -1,11 +1,12 @@ -from typing import Generator import logging +from typing import Generator + from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import Session, sessionmaker + from app.core.config import settings # Import the base class -from app.db.base_class import Base logger = logging.getLogger(__name__) @@ -22,10 +23,11 @@ # Create session factory SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + def get_db() -> Generator[Session, None, None]: """ Get a database session. - + This function is used as a dependency in FastAPI endpoints to get a database session. It yields a session and ensures it's closed after use. """ @@ -33,4 +35,4 @@ def get_db() -> Generator[Session, None, None]: try: yield db finally: - db.close() \ No newline at end of file + db.close() diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index a8edd5f..ef30b69 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -1,9 +1,14 @@ from app.db.models.base import DBModel -from app.db.models.knowledge_base import Document, DocumentType, DocumentStatus, KnowledgeBase -from app.db.models.user import User, UserRole from app.db.models.conversation import Conversation -from app.db.models.message import Message, MessageStatus, MessageContentType -from app.db.models.question import Question, AnswerType, QuestionStatus +from app.db.models.knowledge_base import ( + Document, + DocumentStatus, + DocumentType, + KnowledgeBase, +) +from app.db.models.message import Message, MessageContentType, MessageStatus +from app.db.models.question import AnswerType, Question, QuestionStatus +from app.db.models.user import User, UserRole __all__ = [ "DBModel", @@ -19,5 +24,5 @@ "MessageContentType", "Question", "AnswerType", - "QuestionStatus" + "QuestionStatus", ] diff --git a/app/db/models/base.py b/app/db/models/base.py index e478acb..2960c40 100644 --- a/app/db/models/base.py +++ b/app/db/models/base.py @@ -1,12 +1,17 @@ +import uuid from datetime import datetime from typing import Optional + from pydantic import BaseModel, Field -import uuid + class DBModel(BaseModel): """Base model class for all database models""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) - created_at: Optional[str] = Field(default_factory=lambda: datetime.utcnow().isoformat()) + created_at: Optional[str] = Field( + default_factory=lambda: datetime.utcnow().isoformat() + ) updated_at: Optional[str] = None def update_timestamp(self): @@ -14,4 +19,4 @@ def update_timestamp(self): self.updated_at = datetime.utcnow().isoformat() class Config: - from_attributes = True \ No newline at end of file + from_attributes = True diff --git a/app/db/models/conversation.py b/app/db/models/conversation.py index 00acf07..68ecce5 100644 --- a/app/db/models/conversation.py +++ b/app/db/models/conversation.py @@ -1,17 +1,17 @@ -from sqlalchemy import Column, String, ForeignKey, Boolean, DateTime +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String from sqlalchemy.sql import func from app.db.base_class import BaseModel + class Conversation(BaseModel): """Conversation SQLAlchemy model""" + __tablename__ = "conversations" - + title = Column(String, nullable=False) user_id = Column(String, ForeignKey("users.id")) knowledge_base_id = Column(String, ForeignKey("knowledge_bases.id")) is_active = Column(Boolean, default=True) created_at = Column(DateTime, default=func.now()) updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) - - \ No newline at end of file diff --git a/app/db/models/knowledge_base.py b/app/db/models/knowledge_base.py index 609b444..4c96ab3 100644 --- a/app/db/models/knowledge_base.py +++ b/app/db/models/knowledge_base.py @@ -1,19 +1,33 @@ -from sqlalchemy import Column, LargeBinary, String, Text, ForeignKey, DateTime, Integer, Table +import enum + +from sqlalchemy import ( + Column, + DateTime, + ForeignKey, + Integer, + LargeBinary, + String, + Table, + Text, +) from sqlalchemy.orm import relationship from sqlalchemy.sql import func -import enum from app.db.base_class import BaseModel + class DocumentStatus(str, enum.Enum): """Document processing status""" + PENDING = "PENDING" PROCESSING = "PROCESSING" PROCESSED = "PROCESSED" FAILED = "FAILED" + class DocumentType(str, enum.Enum): """Document type""" + PDF = "application/pdf" JPG = "image/jpeg" PNG = "image/png" @@ -28,10 +42,12 @@ class DocumentType(str, enum.Enum): TXT = "text/plain" HTML = "text/html" + class Document(BaseModel): """Document model""" + __tablename__ = "documents" - + title = Column(String, nullable=False) knowledge_base_id = Column(String, ForeignKey("knowledge_bases.id"), nullable=False) content = Column(LargeBinary, nullable=False) # Base64 encoded content @@ -44,7 +60,7 @@ class Document(BaseModel): summary = Column(Text, nullable=True) created_at = Column(DateTime, default=func.now()) updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) - + class Config: json_schema_extra = { "example": { @@ -57,27 +73,41 @@ class Config: "status": "completed", "summary": "This is a summary of the document", } - } + } + # Knowledge base sharing association table knowledge_base_sharing = Table( "knowledge_base_sharing", BaseModel.metadata, - Column("knowledge_base_id", String, ForeignKey("knowledge_bases.id", ondelete="CASCADE"), primary_key=True), - Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True), + Column( + "knowledge_base_id", + String, + ForeignKey("knowledge_bases.id", ondelete="CASCADE"), + primary_key=True, + ), + Column( + "user_id", String, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True + ), Column("created_at", DateTime, default=func.now()), ) + class KnowledgeBase(BaseModel): """Knowledge base model""" + __tablename__ = "knowledge_bases" - + name = Column(String, nullable=False) description = Column(Text, nullable=True) user_id = Column(String, ForeignKey("users.id"), nullable=False) created_at = Column(DateTime, default=func.now()) updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) - + # Define relationship to users that this knowledge base is shared with - shared_with = relationship("User", secondary=knowledge_base_sharing, lazy="joined", backref="shared_knowledge_bases") - \ No newline at end of file + shared_with = relationship( + "User", + secondary=knowledge_base_sharing, + lazy="joined", + backref="shared_knowledge_bases", + ) diff --git a/app/db/models/message.py b/app/db/models/message.py index 5f56e3b..bc67519 100644 --- a/app/db/models/message.py +++ b/app/db/models/message.py @@ -1,14 +1,17 @@ -from sqlalchemy import Column, String, Text, ForeignKey, DateTime, JSON -from sqlalchemy.sql import func import enum +from sqlalchemy import JSON, Column, DateTime, ForeignKey, String, Text +from sqlalchemy.sql import func + from app.db.base_class import BaseModel + class MessageKind(str, enum.Enum): USER = "USER" ASSISTANT = "ASSISTANT" SYSTEM = "SYSTEM" + class MessageContentType(str, enum.Enum): TEXT = "TEXT" IMAGE = "IMAGE" @@ -16,27 +19,31 @@ class MessageContentType(str, enum.Enum): VIDEO = "VIDEO" DOCUMENT = "DOCUMENT" + class MessageStatus(str, enum.Enum): RECEIVED = "RECEIVED" PROCESSING = "PROCESSING" PROCESSED = "PROCESSED" FAILED = "FAILED" + class Message(BaseModel): """Message SQLAlchemy model""" + __tablename__ = "messages" - + content = Column(Text, nullable=False) content_type = Column(String, nullable=False, default=MessageContentType.TEXT.value) kind = Column(String, nullable=False, default=MessageKind.USER.value) conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False) user_id = Column(String, ForeignKey("users.id"), nullable=False) sources = Column(JSON, nullable=True) - message_metadata = Column(JSON, nullable=True) # For storing query routing info and other metadata + message_metadata = Column( + JSON, nullable=True + ) # For storing query routing info and other metadata status = Column(String, nullable=False) created_at = Column(DateTime, default=func.now()) updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) - + # Optional fields for tracking sources knowledge_base_id = Column(String, ForeignKey("knowledge_bases.id")) - \ No newline at end of file diff --git a/app/db/models/question.py b/app/db/models/question.py index a8595ee..60f0b79 100644 --- a/app/db/models/question.py +++ b/app/db/models/question.py @@ -1,27 +1,30 @@ from enum import Enum -from typing import Optional + from pydantic import Field -from sqlalchemy import Column, String, Text, ForeignKey, DateTime -from sqlalchemy.sql import func +from sqlalchemy import Column, ForeignKey, String, Text from app.db.base_class import BaseModel from app.db.models.base import DBModel + class AnswerType(str, Enum): DIRECT = "DIRECT" SQL_QUERY = "SQL_QUERY" + class QuestionStatus(str, Enum): PENDING = "PENDING" INGESTING = "INGESTING" COMPLETED = "COMPLETED" FAILED = "FAILED" + # SQLAlchemy model for database operations class Question(BaseModel): """SQLAlchemy model for questions in a knowledge base""" + __tablename__ = "questions" - + question = Column(Text, nullable=False) answer = Column(Text, nullable=False) answer_type = Column(String, nullable=False) @@ -29,9 +32,11 @@ class Question(BaseModel): knowledge_base_id = Column(String, ForeignKey("knowledge_bases.id"), nullable=False) user_id = Column(String, ForeignKey("users.id"), nullable=False) + # Pydantic model for API validation and serialization class QuestionModel(DBModel): """Pydantic model for questions in a knowledge base""" + question: str answer: str answer_type: AnswerType @@ -40,4 +45,4 @@ class QuestionModel(DBModel): user_id: str class Config: - from_attributes = True \ No newline at end of file + from_attributes = True diff --git a/app/db/models/user.py b/app/db/models/user.py index 0a4ecf9..f9f77d2 100644 --- a/app/db/models/user.py +++ b/app/db/models/user.py @@ -1,21 +1,24 @@ from enum import Enum -from typing import Optional, List -from sqlalchemy import Column, String, Boolean, Enum as SQLAlchemyEnum -from sqlalchemy.orm import relationship +from typing import List, Optional + from pydantic import EmailStr, Field +from sqlalchemy import Boolean, Column, String from app.db.base_class import BaseModel + class UserRole(str, Enum): ADMIN = "admin" OWNER = "owner" USER = "user" + # SQLAlchemy User model class User(BaseModel): """User SQLAlchemy model""" + __tablename__ = "users" - + email = Column(String, unique=True, nullable=False) hashed_password = Column(String, nullable=False) full_name = Column(String, nullable=False) @@ -24,10 +27,12 @@ class User(BaseModel): verification_token = Column(String, nullable=True) reset_token = Column(String, nullable=True) is_active = Column(Boolean, default=True) - + + # Pydantic User model for API class UserModel: """Pydantic User model for API""" + email: EmailStr hashed_password: str full_name: str @@ -43,6 +48,6 @@ class Config: "email": "user@example.com", "full_name": "John Doe", "role": "user", - "is_verified": True + "is_verified": True, } - } \ No newline at end of file + } diff --git a/app/db/storage.py b/app/db/storage.py index 394d389..8587811 100644 --- a/app/db/storage.py +++ b/app/db/storage.py @@ -1,25 +1,26 @@ +from typing import Generator + from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker + from app.core.config import settings -from typing import Generator -from sqlalchemy.orm import Session engine = create_engine( - settings.STORAGE_URL, - pool_size=3, - max_overflow=6, - pool_timeout=30, - pool_recycle=1800, - pool_pre_ping=True + settings.STORAGE_URL, + pool_size=3, + max_overflow=6, + pool_timeout=30, + pool_recycle=1800, + pool_pre_ping=True, ) # Create session factory SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + def get_storage_db() -> Generator[Session, None, None]: db = SessionLocal() try: yield db finally: db.close() - diff --git a/app/main.py b/app/main.py index c6339b2..f1e8d26 100644 --- a/app/main.py +++ b/app/main.py @@ -1,9 +1,13 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from app.api.endpoints import auth, conversations, knowledge_bases, messages, users from app.core.config import settings -from app.api.endpoints import auth, knowledge_bases, conversations, messages, users -from app.core.middleware import PermissionsMiddleware, RateLimitMiddleware, DEFAULT_PATH_PERMISSIONS +from app.core.middleware import ( + DEFAULT_PATH_PERMISSIONS, + PermissionsMiddleware, + RateLimitMiddleware, +) app = FastAPI( title=settings.APP_NAME, @@ -33,15 +37,21 @@ # Include routers app.include_router(auth.router, prefix="/auth", tags=["Authentication"]) -app.include_router(knowledge_bases.router, prefix="/knowledge-bases", tags=["Knowledge Bases"]) -app.include_router(conversations.router, prefix="/conversations", tags=["Conversations"]) +app.include_router( + knowledge_bases.router, prefix="/knowledge-bases", tags=["Knowledge Bases"] +) +app.include_router( + conversations.router, prefix="/conversations", tags=["Conversations"] +) app.include_router(messages.router, prefix="/conversations", tags=["Messages"]) app.include_router(users.router, prefix="/users", tags=["Users"]) + @app.get("/") async def root(): return {"message": "Welcome to DocBrain API"} + @app.get("/health") async def health(): """Health check endpoint for monitoring and orchestration.""" @@ -49,4 +59,4 @@ async def health(): "status": "healthy", "service": settings.APP_NAME, "version": app.version, - } \ No newline at end of file + } diff --git a/app/repositories/conversation_repository.py b/app/repositories/conversation_repository.py index d1fb24f..5f41920 100644 --- a/app/repositories/conversation_repository.py +++ b/app/repositories/conversation_repository.py @@ -1,12 +1,18 @@ +import logging from typing import List, Optional -from datetime import datetime + +from sqlalchemy.orm import Session + from app.db.models.conversation import Conversation from app.db.models.user import User -from app.schemas.conversation import ConversationCreate, ConversationResponse, ConversationUpdate -import logging -from sqlalchemy.orm import Session +from app.schemas.conversation import ( + ConversationResponse, + ConversationUpdate, +) + logger = logging.getLogger(__name__) + class ConversationRepository: @staticmethod async def create(conversation: Conversation, db: Session) -> ConversationResponse: @@ -21,12 +27,17 @@ async def create(conversation: Conversation, db: Session) -> ConversationRespons logger.error(f"Failed to create conversation: {e}") raise - @staticmethod - async def get_by_id(conversation_id: str, user: User, db: Session) -> Optional[ConversationResponse]: + async def get_by_id( + conversation_id: str, user: User, db: Session + ) -> Optional[ConversationResponse]: """Get conversation by ID for a specific user""" try: - conversation = db.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation = ( + db.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) if conversation and conversation.user_id == str(user.id): return ConversationResponse.model_validate(conversation) return None @@ -34,21 +45,34 @@ async def get_by_id(conversation_id: str, user: User, db: Session) -> Optional[C logger.error(f"Failed to get conversation by ID: {e}") raise - @staticmethod async def list_by_user(user: User, db: Session) -> List[ConversationResponse]: """List all conversations for a user""" logger.info(f"Listing conversations for user {user.id}") - conversations = db.query(Conversation).filter(Conversation.user_id == user.id).all() + conversations = ( + db.query(Conversation).filter(Conversation.user_id == user.id).all() + ) logger.info(f"Found {len(conversations)} conversations for user {user.id}") - return [ConversationResponse.model_validate(conversation) for conversation in conversations] + return [ + ConversationResponse.model_validate(conversation) + for conversation in conversations + ] @staticmethod - async def update(conversation_id: str, conversation_update: ConversationUpdate, user: User, db: Session) -> Optional[ConversationResponse]: + async def update( + conversation_id: str, + conversation_update: ConversationUpdate, + user: User, + db: Session, + ) -> Optional[ConversationResponse]: """Update conversation details""" try: # First verify ownership - conversation = db.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation = ( + db.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) if not conversation: return None @@ -68,7 +92,11 @@ async def delete(conversation_id: str, user: User, db: Session) -> bool: """Delete a conversation and all its messages""" try: # First verify ownership - conversation = db.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation = ( + db.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) if not conversation: return False @@ -77,15 +105,19 @@ async def delete(conversation_id: str, user: User, db: Session) -> bool: try: # Delete all messages first logger.debug(f"Deleting messages for conversation {conversation_id}") - db.execute("DELETE FROM messages WHERE conversation_id = ?", [conversation_id]) - + db.execute( + "DELETE FROM messages WHERE conversation_id = ?", [conversation_id] + ) + # Then delete the conversation logger.debug(f"Deleting conversation {conversation_id}") db.execute("DELETE FROM conversations WHERE id = ?", [conversation_id]) - + # Commit transaction db.execute("COMMIT") - logger.info(f"Successfully deleted conversation {conversation_id} and its messages") + logger.info( + f"Successfully deleted conversation {conversation_id} and its messages" + ) return True except Exception as e: # Rollback on error @@ -94,4 +126,4 @@ async def delete(conversation_id: str, user: User, db: Session) -> bool: raise except Exception as e: logger.error(f"Error in delete operation: {e}") - raise \ No newline at end of file + raise diff --git a/app/repositories/document_repository.py b/app/repositories/document_repository.py index 7a07ed1..c45f56c 100644 --- a/app/repositories/document_repository.py +++ b/app/repositories/document_repository.py @@ -1,25 +1,26 @@ -from typing import List, Optional, Dict, Any import logging +from typing import Any, Dict, List, Optional + from sqlalchemy.orm import Session -from sqlalchemy import select from app.db.models.knowledge_base import Document, DocumentStatus from app.schemas.document import DocumentResponse logger = logging.getLogger(__name__) + class DocumentRepository: """Repository for document operations""" - + @staticmethod async def create(document: Document, db: Session) -> DocumentResponse: """ Create a new document. - + Args: document: Document instance db: Database session - + Returns: Created document """ @@ -32,16 +33,16 @@ async def create(document: Document, db: Session) -> DocumentResponse: db.rollback() logger.error(f"Failed to create document: {e}") raise - + @staticmethod async def get_by_id(document_id: str, db: Session) -> Optional[DocumentResponse]: """ Get a document by ID. - + Args: document_id: Document ID db: Database session - + Returns: Document if found, None otherwise """ @@ -53,17 +54,19 @@ async def get_by_id(document_id: str, db: Session) -> Optional[DocumentResponse] except Exception as e: logger.error(f"Failed to get document by ID {document_id}: {e}") raise - + @staticmethod - async def list_all(db: Session, skip: int = 0, limit: int = 100) -> List[DocumentResponse]: + async def list_all( + db: Session, skip: int = 0, limit: int = 100 + ) -> List[DocumentResponse]: """ Get all documents with pagination. - + Args: db: Database session skip: Number of records to skip limit: Maximum number of records to return - + Returns: List of documents """ @@ -73,41 +76,50 @@ async def list_all(db: Session, skip: int = 0, limit: int = 100) -> List[Documen except Exception as e: logger.error(f"Failed to list documents: {e}") raise - + @staticmethod async def list_by_knowledge_base( knowledge_base_id: str, db: Session, - skip: int = 0, + skip: int = 0, limit: int = 100, - status: Optional[str] = None + status: Optional[str] = None, ) -> List[DocumentResponse]: """ Get documents by knowledge base ID with optional status filter. - + Args: knowledge_base_id: Knowledge base ID db: Database session skip: Number of records to skip limit: Maximum number of records to return status: Optional status filter - + Returns: List of documents """ try: - query = db.query(Document).filter(Document.knowledge_base_id == knowledge_base_id) - + query = db.query(Document).filter( + Document.knowledge_base_id == knowledge_base_id + ) + if status: query = query.filter(Document.status == status) - - return [DocumentResponse.model_validate(doc) for doc in query.offset(skip).limit(limit).all()] + + return [ + DocumentResponse.model_validate(doc) + for doc in query.offset(skip).limit(limit).all() + ] except Exception as e: - logger.error(f"Failed to list documents for knowledge base {knowledge_base_id}: {e}") + logger.error( + f"Failed to list documents for knowledge base {knowledge_base_id}: {e}" + ) raise @staticmethod - async def set_processing(document_id: str, db: Session) -> Optional[DocumentResponse]: + async def set_processing( + document_id: str, db: Session + ) -> Optional[DocumentResponse]: """ Set a document as processing. """ @@ -125,7 +137,9 @@ async def set_processing(document_id: str, db: Session) -> Optional[DocumentResp raise @staticmethod - async def set_processed(document_id: str, summary: Optional[str], processed_chunks: int, db: Session) -> Optional[DocumentResponse]: + async def set_processed( + document_id: str, summary: Optional[str], processed_chunks: int, db: Session + ) -> Optional[DocumentResponse]: """ Set a document as processed. """ @@ -145,7 +159,9 @@ async def set_processed(document_id: str, summary: Optional[str], processed_chun raise @staticmethod - async def set_failed(document_id: str, error_message: str, db: Session) -> Optional[DocumentResponse]: + async def set_failed( + document_id: str, error_message: str, db: Session + ) -> Optional[DocumentResponse]: """ Set a document as failed. """ @@ -164,15 +180,17 @@ async def set_failed(document_id: str, error_message: str, db: Session) -> Optio raise @staticmethod - async def update(document_id: str, update_data: Dict[str, Any], db: Session) -> Optional[DocumentResponse]: + async def update( + document_id: str, update_data: Dict[str, Any], db: Session + ) -> Optional[DocumentResponse]: """ Update a document. - + Args: document_id: Document ID update_data: Data to update db: Database session - + Returns: Updated document if found, None otherwise """ @@ -181,11 +199,11 @@ async def update(document_id: str, update_data: Dict[str, Any], db: Session) -> document = db.query(Document).filter(Document.id == document_id).first() if not document: return None - + # Update attributes for key, value in update_data.items(): setattr(document, key, value) - + db.commit() db.refresh(document) return DocumentResponse.model_validate(document) @@ -193,16 +211,16 @@ async def update(document_id: str, update_data: Dict[str, Any], db: Session) -> db.rollback() logger.error(f"Failed to update document {document_id}: {e}") raise - + @staticmethod async def delete(document_id: str, db: Session) -> bool: """ Delete a document. - + Args: document_id: Document ID db: Database session - + Returns: True if document was deleted, False otherwise """ @@ -210,7 +228,7 @@ async def delete(document_id: str, db: Session) -> bool: document = db.query(Document).filter(Document.id == document_id).first() if not document: return False - + db.delete(document) db.commit() return True @@ -218,19 +236,3 @@ async def delete(document_id: str, db: Session) -> bool: db.rollback() logger.error(f"Failed to delete document {document_id}: {e}") raise - - @staticmethod - async def get_by_id(document_id: str, db: Session) -> Optional[DocumentResponse]: - """ - Class method to get a document by ID. - - Args: - document_id: Document ID - - Returns: - Document if found, None otherwise - """ - document = db.query(Document).filter(Document.id == document_id).first() - if not document: - return None - return DocumentResponse.model_validate(document) \ No newline at end of file diff --git a/app/repositories/knowledge_base_repository.py b/app/repositories/knowledge_base_repository.py index f0eb7a9..ccd354f 100644 --- a/app/repositories/knowledge_base_repository.py +++ b/app/repositories/knowledge_base_repository.py @@ -1,19 +1,22 @@ +import logging from typing import List, Optional -from sqlalchemy.orm import Session + from sqlalchemy import text +from sqlalchemy.orm import Session -from app.db.models.knowledge_base import KnowledgeBase, Document from app.db.models.conversation import Conversation -import logging - -from app.schemas.knowledge_base import KnowledgeBaseResponse +from app.db.models.knowledge_base import Document, KnowledgeBase from app.db.models.message import Message +from app.schemas.knowledge_base import KnowledgeBaseResponse logger = logging.getLogger(__name__) + class KnowledgeBaseRepository: @staticmethod - async def create(knowledge_base: KnowledgeBase, db: Session) -> KnowledgeBaseResponse: + async def create( + knowledge_base: KnowledgeBase, db: Session + ) -> KnowledgeBaseResponse: """Create a new knowledge base""" try: db.add(knowledge_base) @@ -24,7 +27,7 @@ async def create(knowledge_base: KnowledgeBase, db: Session) -> KnowledgeBaseRes db.rollback() logger.error(f"Failed to create knowledge base: {e}") raise - + @staticmethod async def get_by_id(kb_id: str, db: Session) -> Optional[KnowledgeBaseResponse]: """Get knowledge base by ID""" @@ -36,7 +39,7 @@ async def get_by_id(kb_id: str, db: Session) -> Optional[KnowledgeBaseResponse]: except Exception as e: logger.error(f"Failed to get knowledge base by ID {kb_id}: {e}") raise - + @staticmethod async def list_all(db: Session) -> List[KnowledgeBaseResponse]: """List all knowledge bases""" @@ -46,29 +49,33 @@ async def list_all(db: Session) -> List[KnowledgeBaseResponse]: except Exception as e: logger.error(f"Failed to list knowledge bases: {e}") raise - + @staticmethod async def list_by_owner(owner_id: str, db: Session) -> List[KnowledgeBaseResponse]: """List all knowledge bases owned by a user""" try: - knowledge_bases = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == owner_id).all() + knowledge_bases = ( + db.query(KnowledgeBase).filter(KnowledgeBase.user_id == owner_id).all() + ) return [KnowledgeBaseResponse.model_validate(kb) for kb in knowledge_bases] except Exception as e: logger.error(f"Failed to list knowledge bases by owner {owner_id}: {e}") raise - + @staticmethod - async def update(kb_id: str, update_data: dict, db: Session) -> Optional[KnowledgeBaseResponse]: + async def update( + kb_id: str, update_data: dict, db: Session + ) -> Optional[KnowledgeBaseResponse]: """Update knowledge base""" try: kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() if not kb: return None - + # Update attributes for key, value in update_data.items(): setattr(kb, key, value) - + db.commit() db.refresh(kb) return KnowledgeBaseResponse.model_validate(kb) @@ -76,49 +83,55 @@ async def update(kb_id: str, update_data: dict, db: Session) -> Optional[Knowled db.rollback() logger.error(f"Failed to update knowledge base {kb_id}: {e}") raise - + @staticmethod async def delete(kb_id: str, db: Session) -> bool: """Delete knowledge base and all related data in cascade""" try: # First get all conversations related to the knowledge base - conversations = db.query(Conversation).filter(Conversation.knowledge_base_id == kb_id).all() - + conversations = ( + db.query(Conversation) + .filter(Conversation.knowledge_base_id == kb_id) + .all() + ) + # For each conversation, delete all its messages first for conv in conversations: # Delete messages associated with this conversation db.query(Message).filter(Message.conversation_id == conv.id).delete() - + # Commit the message deletions db.commit() - + # Then delete all conversations for conv in conversations: db.delete(conv) - + # Commit the conversation deletions db.commit() - + # Delete all documents - documents = db.query(Document).filter(Document.knowledge_base_id == kb_id).all() + documents = ( + db.query(Document).filter(Document.knowledge_base_id == kb_id).all() + ) for doc in documents: db.delete(doc) db.commit() - + # Finally delete the knowledge base itself kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() if not kb: return False - + db.delete(kb) db.commit() return True - + except Exception as e: db.rollback() logger.error(f"Failed to cascade delete knowledge base {kb_id}: {e}") raise - + @staticmethod async def get_documents(kb_id: str, db: Session) -> List[Document]: """Get all documents in a knowledge base""" @@ -127,7 +140,7 @@ async def get_documents(kb_id: str, db: Session) -> List[Document]: except Exception as e: logger.error(f"Failed to get documents for knowledge base {kb_id}: {e}") raise - + @staticmethod async def list_documents_by_kb(kb_id: str, db: Session) -> List[Document]: """List all documents in a knowledge base""" @@ -136,29 +149,33 @@ async def list_documents_by_kb(kb_id: str, db: Session) -> List[Document]: except Exception as e: logger.error(f"Failed to list documents for knowledge base {kb_id}: {e}") raise - + @staticmethod async def is_shared_with_user(kb_id: str, user_id: str, db: Session) -> bool: """Check if a knowledge base is shared with a specific user""" try: from sqlalchemy import text + query = text(""" - SELECT COUNT(*) FROM knowledge_base_sharing + SELECT COUNT(*) FROM knowledge_base_sharing WHERE knowledge_base_id = :kb_id AND user_id = :user_id """) result = db.execute(query, {"kb_id": kb_id, "user_id": user_id}).scalar() return result > 0 except Exception as e: - logger.error(f"Failed to check if knowledge base {kb_id} is shared with user {user_id}: {e}") + logger.error( + f"Failed to check if knowledge base {kb_id} is shared with user {user_id}: {e}" + ) raise - + @staticmethod async def add_user_access(kb_id: str, user_id: str, db: Session) -> bool: """Share a knowledge base with a user""" try: from sqlalchemy import text + query = text(""" - INSERT INTO knowledge_base_sharing (knowledge_base_id, user_id) + INSERT INTO knowledge_base_sharing (knowledge_base_id, user_id) VALUES (:kb_id, :user_id) ON DUPLICATE KEY UPDATE knowledge_base_id = knowledge_base_id """) @@ -167,16 +184,19 @@ async def add_user_access(kb_id: str, user_id: str, db: Session) -> bool: return True except Exception as e: db.rollback() - logger.error(f"Failed to share knowledge base {kb_id} with user {user_id}: {e}") + logger.error( + f"Failed to share knowledge base {kb_id} with user {user_id}: {e}" + ) raise - + @staticmethod async def remove_user_access(kb_id: str, user_id: str, db: Session) -> bool: """Remove a user's access to a knowledge base""" try: from sqlalchemy import text + query = text(""" - DELETE FROM knowledge_base_sharing + DELETE FROM knowledge_base_sharing WHERE knowledge_base_id = :kb_id AND user_id = :user_id """) db.execute(query, {"kb_id": kb_id, "user_id": user_id}) @@ -184,34 +204,38 @@ async def remove_user_access(kb_id: str, user_id: str, db: Session) -> bool: return True except Exception as e: db.rollback() - logger.error(f"Failed to unshare knowledge base {kb_id} from user {user_id}: {e}") + logger.error( + f"Failed to unshare knowledge base {kb_id} from user {user_id}: {e}" + ) raise - + @staticmethod async def get_shared_users(kb_id: str, db: Session) -> List: """Get all users who have access to a knowledge base""" try: - from app.db.models.user import User from app.schemas.user import UserResponse - + # Get the knowledge base to ensure it exists kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() if not kb: return [] - + # Get all users who have access through the shared_with relationship - shared_users = [UserResponse.model_validate(user) for user in kb.shared_with] + shared_users = [ + UserResponse.model_validate(user) for user in kb.shared_with + ] return shared_users except Exception as e: logger.error(f"Failed to get shared users for knowledge base {kb_id}: {e}") raise - + @staticmethod - async def list_shared_with_user(user_id: str, db: Session) -> List[KnowledgeBaseResponse]: + async def list_shared_with_user( + user_id: str, db: Session + ) -> List[KnowledgeBaseResponse]: """List all knowledge bases shared with a specific user""" try: - - + query = text(f""" SELECT kb.* FROM knowledge_bases kb JOIN knowledge_base_sharing kbs ON kb.id = kbs.knowledge_base_id @@ -222,8 +246,10 @@ async def list_shared_with_user(user_id: str, db: Session) -> List[KnowledgeBase shared_kbs = [] for row in result: shared_kbs.append(KnowledgeBaseResponse.model_validate(row)) - + return shared_kbs # This will be an empty list if no shared knowledge bases exist except Exception as e: - logger.error(f"Failed to list knowledge bases shared with user {user_id}: {e}") - raise \ No newline at end of file + logger.error( + f"Failed to list knowledge bases shared with user {user_id}: {e}" + ) + raise diff --git a/app/repositories/message_repository.py b/app/repositories/message_repository.py index 9005562..855afb6 100644 --- a/app/repositories/message_repository.py +++ b/app/repositories/message_repository.py @@ -1,15 +1,15 @@ -from typing import List, Optional import json +import logging +from typing import List, Optional + +from sqlalchemy.orm import Session from app.db.models.message import Message, MessageContentType, MessageStatus from app.schemas.message import MessageResponse -from sqlalchemy.orm import Session -import logging logger = logging.getLogger(__name__) - class MessageRepository: @staticmethod async def create(message: Message, db: Session) -> Message: @@ -35,36 +35,42 @@ async def get_by_id(message_id: str, db: Session) -> Optional[MessageResponse]: return MessageResponse.model_validate(message) @staticmethod - async def list_by_conversation(conversation_id: str, db: Session) -> List[MessageResponse]: + async def list_by_conversation( + conversation_id: str, db: Session + ) -> List[MessageResponse]: """List all messages in a conversation""" - messages = db.query(Message).filter(Message.conversation_id == conversation_id).all() + messages = ( + db.query(Message).filter(Message.conversation_id == conversation_id).all() + ) return [MessageResponse.model_validate(message) for message in messages] @staticmethod - async def update_with_sources(message_id: str, content: str, sources: List[dict], db: Session) -> Optional[MessageResponse]: + async def update_with_sources( + message_id: str, content: str, sources: List[dict], db: Session + ) -> Optional[MessageResponse]: """Update message with response and sources""" update_data = { "content": content, "sources": json.dumps(sources) if sources is not None else None, - "status": "completed" + "status": "completed", } db.query(Message).filter(Message.id == message_id).update(update_data) db.commit() db.refresh(message_id) - return MessageResponse.model_validate(message_id) - + return MessageResponse.model_validate(message_id) + @staticmethod async def set_processed( - message_id: str, - content: str, - content_type: MessageContentType, - sources: List[dict], + message_id: str, + content: str, + content_type: MessageContentType, + sources: List[dict], db: Session, - metadata: Optional[dict] = None + metadata: Optional[dict] = None, ) -> Optional[MessageResponse]: """ Set message as processed - + Args: message_id: The ID of the message to update content: The content of the message @@ -81,12 +87,12 @@ async def set_processed( message.content_type = content_type message.sources = sources message.status = MessageStatus.PROCESSED - + # Add metadata if provided - store directly as a dictionary # SQLAlchemy will handle the JSON serialization if metadata: message.message_metadata = metadata - + db.commit() db.refresh(message) return MessageResponse.model_validate(message) @@ -94,9 +100,11 @@ async def set_processed( db.rollback() logger.error(f"Failed to set message as processed: {e}") raise - + @staticmethod - async def set_failed(message_id: str, error_message: str, db: Session) -> Optional[MessageResponse]: + async def set_failed( + message_id: str, error_message: str, db: Session + ) -> Optional[MessageResponse]: """Set message as failed""" try: message = db.query(Message).filter(Message.id == message_id).first() @@ -112,4 +120,4 @@ async def set_failed(message_id: str, error_message: str, db: Session) -> Option except Exception as e: db.rollback() logger.error(f"Failed to set message as failed: {e}") - raise \ No newline at end of file + raise diff --git a/app/repositories/question_repository.py b/app/repositories/question_repository.py index a7f14fe..48c8001 100644 --- a/app/repositories/question_repository.py +++ b/app/repositories/question_repository.py @@ -1,25 +1,26 @@ -from typing import List, Optional, Dict, Any import logging +from typing import Any, Dict, List, Optional + from sqlalchemy.orm import Session -from sqlalchemy import select from app.db.models.question import Question, QuestionStatus from app.schemas.question import QuestionResponse logger = logging.getLogger(__name__) + class QuestionRepository: """Repository for question operations""" - + @staticmethod async def create(question: Question, db: Session) -> QuestionResponse: """ Create a new question. - + Args: question: Question instance db: Database session - + Returns: Created question """ @@ -32,16 +33,16 @@ async def create(question: Question, db: Session) -> QuestionResponse: db.rollback() logger.error(f"Failed to create question: {e}") raise - + @staticmethod async def get_by_id(question_id: str, db: Session) -> Optional[QuestionResponse]: """ Get a question by ID. - + Args: question_id: Question ID db: Database session - + Returns: Question if found, None otherwise """ @@ -53,17 +54,19 @@ async def get_by_id(question_id: str, db: Session) -> Optional[QuestionResponse] except Exception as e: logger.error(f"Failed to get question by ID: {e}") raise - + @staticmethod - async def list_all(db: Session, skip: int = 0, limit: int = 100) -> List[QuestionResponse]: + async def list_all( + db: Session, skip: int = 0, limit: int = 100 + ) -> List[QuestionResponse]: """ List all questions. - + Args: db: Database session skip: Number of records to skip limit: Maximum number of records to return - + Returns: List of questions """ @@ -73,76 +76,82 @@ async def list_all(db: Session, skip: int = 0, limit: int = 100) -> List[Questio except Exception as e: logger.error(f"Failed to list all questions: {e}") raise - + @staticmethod async def list_by_knowledge_base( knowledge_base_id: str, db: Session, - skip: int = 0, + skip: int = 0, limit: int = 100, - status: Optional[str] = None + status: Optional[str] = None, ) -> List[QuestionResponse]: """ List questions by knowledge base ID. - + Args: knowledge_base_id: Knowledge base ID db: Database session skip: Number of records to skip limit: Maximum number of records to return status: Filter by status - + Returns: List of questions """ try: - query = db.query(Question).filter(Question.knowledge_base_id == knowledge_base_id) - + query = db.query(Question).filter( + Question.knowledge_base_id == knowledge_base_id + ) + if status: query = query.filter(Question.status == status) - + questions = query.offset(skip).limit(limit).all() return [QuestionResponse.model_validate(q) for q in questions] except Exception as e: logger.error(f"Failed to list questions by knowledge base: {e}") raise - + @staticmethod - async def set_ingesting(question_id: str, db: Session) -> Optional[QuestionResponse]: + async def set_ingesting( + question_id: str, db: Session + ) -> Optional[QuestionResponse]: """Set question status to INGESTING""" try: question = db.query(Question).filter(Question.id == question_id).first() if not question: return None - + question.status = QuestionStatus.INGESTING.value db.commit() db.refresh(question) - + return QuestionResponse.model_validate(question) except Exception as e: db.rollback() logger.error(f"Failed to set question to INGESTING: {e}") raise - + @staticmethod - async def set_completed(question_id: str, db: Session) -> Optional[QuestionResponse]: + async def set_completed( + question_id: str, db: Session + ) -> Optional[QuestionResponse]: """Set question status to COMPLETED""" try: question = db.query(Question).filter(Question.id == question_id).first() if not question: return None - + question.status = QuestionStatus.COMPLETED.value db.commit() db.refresh(question) - + return QuestionResponse.model_validate(question) except Exception as e: db.rollback() logger.error(f"Failed to set question to COMPLETED: {e}") raise - + @staticmethod async def set_failed(question_id: str, db: Session) -> Optional[QuestionResponse]: """Set question status to FAILED""" @@ -150,27 +159,29 @@ async def set_failed(question_id: str, db: Session) -> Optional[QuestionResponse question = db.query(Question).filter(Question.id == question_id).first() if not question: return None - + question.status = QuestionStatus.FAILED.value db.commit() db.refresh(question) - + return QuestionResponse.model_validate(question) except Exception as e: db.rollback() logger.error(f"Failed to set question to FAILED: {e}") raise - + @staticmethod - async def update(question_id: str, update_data: Dict[str, Any], db: Session) -> Optional[QuestionResponse]: + async def update( + question_id: str, update_data: Dict[str, Any], db: Session + ) -> Optional[QuestionResponse]: """ Update a question. - + Args: question_id: Question ID update_data: Data to update db: Database session - + Returns: Updated question """ @@ -178,30 +189,30 @@ async def update(question_id: str, update_data: Dict[str, Any], db: Session) -> question = db.query(Question).filter(Question.id == question_id).first() if not question: return None - + for key, value in update_data.items(): if hasattr(question, key): setattr(question, key, value) - + question.update_timestamp() db.commit() db.refresh(question) - + return QuestionResponse.model_validate(question) except Exception as e: db.rollback() logger.error(f"Failed to update question: {e}") raise - + @staticmethod async def delete(question_id: str, db: Session) -> bool: """ Delete a question. - + Args: question_id: Question ID db: Database session - + Returns: True if successful, False otherwise """ @@ -209,12 +220,12 @@ async def delete(question_id: str, db: Session) -> bool: question = db.query(Question).filter(Question.id == question_id).first() if not question: return False - + db.delete(question) db.commit() - + return True except Exception as e: db.rollback() logger.error(f"Failed to delete question: {e}") - raise \ No newline at end of file + raise diff --git a/app/repositories/storage_repository.py b/app/repositories/storage_repository.py index 936386b..d6b17ac 100644 --- a/app/repositories/storage_repository.py +++ b/app/repositories/storage_repository.py @@ -1,29 +1,42 @@ -from sqlalchemy.orm import Session -from sqlalchemy import text import logging +from sqlalchemy import text +from sqlalchemy.orm import Session + logger = logging.getLogger(__name__) + class StorageRepository: """ Repository for storing and retrieving data in the storage database. """ + @staticmethod - async def insert_csv(db: Session, table_name: str, create_table_query: str, columns: list[str], data: list[dict]): + async def insert_csv( + db: Session, + table_name: str, + create_table_query: str, + columns: list[str], + data: list[dict], + ): """ Insert CSV data into the storage database. """ - logger.info(f"Inserting CSV data into {table_name} with {len(data)} rows and {len(create_table_query)} columns") + logger.info( + f"Inserting CSV data into {table_name} with {len(data)} rows and {len(create_table_query)} columns" + ) try: # create a table if it doesn't exist logger.info(f"Create Table Query: {create_table_query}") db.execute(text(create_table_query)) logger.info(f"Successfully created table {table_name}") - # insert the data one by one + # insert the data one by one for row in data: - values = ', '.join(["'{}'".format(str(cell)) for cell in row]) - INSERT_ROW_QUERY = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({values})" + values = ", ".join(["'{}'".format(str(cell)) for cell in row]) + INSERT_ROW_QUERY = ( + f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({values})" + ) logger.info(f"Insert Row Query: {INSERT_ROW_QUERY}") db.execute(text(INSERT_ROW_QUERY)) db.commit() @@ -41,5 +54,3 @@ async def query(db: Session, query: str): except Exception as e: logger.error(f"Failed to query data: {e}") raise - - \ No newline at end of file diff --git a/app/repositories/user_repository.py b/app/repositories/user_repository.py index 46364d0..b91b871 100644 --- a/app/repositories/user_repository.py +++ b/app/repositories/user_repository.py @@ -1,11 +1,14 @@ -from typing import List, Optional import logging +from typing import List, Optional + from sqlalchemy.orm import Session + from app.db.models.user import User -from app.schemas.user import UserCreate, UserUpdate, UserResponse +from app.schemas.user import UserResponse, UserUpdate logger = logging.getLogger(__name__) + class UserRepository: @staticmethod async def create(user_data: User, db: Session) -> UserResponse: @@ -15,14 +18,14 @@ async def create(user_data: User, db: Session) -> UserResponse: db.add(user_data) db.commit() db.refresh(user_data) - + # Convert to response model return UserResponse.model_validate(user_data) except Exception as e: db.rollback() logger.error(f"Failed to create user: {e}") raise - + @staticmethod async def get_by_id(user_id: str, db: Session) -> Optional[UserResponse]: """Get user by ID""" @@ -34,7 +37,7 @@ async def get_by_id(user_id: str, db: Session) -> Optional[UserResponse]: except Exception as e: logger.error(f"Failed to get user by ID {user_id}: {e}") raise - + @staticmethod async def get_by_email(email: str, db: Session) -> Optional[UserResponse]: """Get user by email""" @@ -46,7 +49,7 @@ async def get_by_email(email: str, db: Session) -> Optional[UserResponse]: except Exception as e: logger.error(f"Failed to get user by email {email}: {e}") raise - + @staticmethod async def list_all(db: Session) -> List[UserResponse]: """List all users""" @@ -56,31 +59,33 @@ async def list_all(db: Session) -> List[UserResponse]: except Exception as e: logger.error(f"Failed to list users: {e}") raise - + @staticmethod - async def update(user_id: str, update_data: UserUpdate, db: Session) -> Optional[UserResponse]: + async def update( + user_id: str, update_data: UserUpdate, db: Session + ) -> Optional[UserResponse]: """Update user""" try: # Get the user db_user = db.query(User).filter(User.id == user_id).first() if db_user is None: return None - + # Update user attributes update_dict = update_data.model_dump(exclude_unset=True) for key, value in update_dict.items(): setattr(db_user, key, value) - + # Commit changes db.commit() db.refresh(db_user) - + return UserResponse.model_validate(db_user) except Exception as e: db.rollback() logger.error(f"Failed to update user {user_id}: {e}") raise - + @staticmethod async def delete(user_id: str, db: Session) -> bool: """Delete user""" @@ -88,11 +93,11 @@ async def delete(user_id: str, db: Session) -> bool: db_user = db.query(User).filter(User.id == user_id).first() if db_user is None: return False - + db.delete(db_user) db.commit() return True except Exception as e: db.rollback() logger.error(f"Failed to delete user {user_id}: {e}") - raise \ No newline at end of file + raise diff --git a/app/schemas/conversation.py b/app/schemas/conversation.py index 6f1e62e..a99f017 100644 --- a/app/schemas/conversation.py +++ b/app/schemas/conversation.py @@ -1,27 +1,38 @@ -from typing import Optional from datetime import datetime +from typing import Optional + from pydantic import BaseModel, Field + class ConversationBase(BaseModel): """Base conversation attributes""" + title: str = Field(..., description="Title of the conversation") - knowledge_base_id: str = Field(..., description="ID of the knowledge base this conversation is linked to") + knowledge_base_id: str = Field( + ..., description="ID of the knowledge base this conversation is linked to" + ) + class ConversationCreate(ConversationBase): """Attributes for creating a new conversation""" - pass + class ConversationUpdate(BaseModel): """Attributes that can be updated""" + title: Optional[str] = Field(None, description="New title for the conversation") + class ConversationResponse(ConversationBase): """Response model for conversations""" + id: str = Field(..., description="Unique identifier for the conversation") user_id: str = Field(..., description="ID of the user who created the conversation") is_active: bool = Field(..., description="Whether the conversation is active") created_at: datetime = Field(..., description="When the conversation was created") - updated_at: datetime = Field(..., description="When the conversation was last updated") + updated_at: datetime = Field( + ..., description="When the conversation was last updated" + ) class Config: - from_attributes = True \ No newline at end of file + from_attributes = True diff --git a/app/schemas/document.py b/app/schemas/document.py index a247373..79ca4e3 100644 --- a/app/schemas/document.py +++ b/app/schemas/document.py @@ -1,39 +1,52 @@ +from datetime import datetime from typing import Optional + from pydantic import BaseModel, Field -from datetime import datetime from app.db.models.knowledge_base import DocumentStatus, DocumentType + class DocumentBase(BaseModel): """Base document schema""" + title: str = Field(..., description="Title of the document") content_type: DocumentType = Field(..., description="Content type of the document") knowledge_base_id: str = Field(..., description="Knowledge base ID of the document") + class DocumentUpload(DocumentBase): """Schema for uploading a document""" + content: bytes = Field(..., description="Content of the document") + class DocumentUpdate(BaseModel): """Schema for updating a document""" + title: Optional[str] = None status: Optional[str] = None error_message: Optional[str] = None processed_chunks: Optional[int] = None summary: Optional[str] = None + class DocumentResponse(DocumentBase): """Schema for document response""" + id: str = Field(..., description="ID of the document") user_id: str = Field(..., description="User ID of the document") content: bytes = Field(..., description="Content of the document", exclude=True) size_bytes: int = Field(..., description="Size of the document in bytes") status: DocumentStatus = Field(..., description="Status of the document") - error_message: Optional[str] = Field(default=None, description="Error message if the document processing failed") - processed_chunks: Optional[int] = Field(default=None, description="Number of chunks processed") + error_message: Optional[str] = Field( + default=None, description="Error message if the document processing failed" + ) + processed_chunks: Optional[int] = Field( + default=None, description="Number of chunks processed" + ) summary: Optional[str] = Field(default=None, description="Summary of the document") created_at: datetime = Field(..., description="Created timestamp") updated_at: datetime = Field(..., description="Last updated timestamp") - + class Config: from_attributes = True diff --git a/app/schemas/knowledge_base.py b/app/schemas/knowledge_base.py index 33fa072..b0c8f58 100644 --- a/app/schemas/knowledge_base.py +++ b/app/schemas/knowledge_base.py @@ -1,21 +1,26 @@ -from typing import Annotated, Any, List, Optional -from pydantic import BaseModel, Field from datetime import datetime +from typing import Annotated, Optional + +from pydantic import BaseModel, Field from app.db.models.knowledge_base import DocumentStatus from app.schemas.user import UserResponse + class KnowledgeBaseBase(BaseModel): name: str = Field(..., alias="name") description: str = Field(..., alias="description") + class KnowledgeBaseCreate(KnowledgeBaseBase): pass + class KnowledgeBaseUpdate(BaseModel): name: Optional[str] = None description: Optional[str] = None + class SharedUserInfo(BaseModel): id: str email: str @@ -24,6 +29,7 @@ class SharedUserInfo(BaseModel): class Config: from_attributes = True + class KnowledgeBaseResponse(KnowledgeBaseBase): id: str user_id: str @@ -34,6 +40,7 @@ class KnowledgeBaseResponse(KnowledgeBaseBase): class Config: from_attributes = True + class DocumentResponse(BaseModel): id: str filename: str @@ -46,17 +53,23 @@ class DocumentResponse(BaseModel): created_at: datetime updated_at: Optional[datetime] = None + class DocumentUploadResponse(BaseModel): document_id: str status: DocumentStatus message: str = "Document upload initiated" + class KnowledgeBaseShareRequest(BaseModel): - user_id: str = Field(..., description="ID of the user to share the knowledge base with") + user_id: str = Field( + ..., description="ID of the user to share the knowledge base with" + ) + class KnowledgeBaseUnshareRequest(BaseModel): user_id: str = Field(..., description="ID of the user to remove access for") + class KnowledgeBaseSharingResponse(BaseModel): success: bool - message: str \ No newline at end of file + message: str diff --git a/app/schemas/message.py b/app/schemas/message.py index 0179466..d8f6611 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -1,11 +1,13 @@ -from typing import List, Optional, Dict, Any +import json from datetime import datetime -from pydantic import BaseModel, Field, model_validator from enum import Enum -import json +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field, model_validator from app.db.models.message import MessageContentType, MessageKind, MessageStatus + class MessageType(str, Enum): USER = "user" ASSISTANT = "assistant" @@ -13,44 +15,71 @@ class MessageType(str, Enum): class MessageSource(BaseModel): """Source document or question information""" + score: float = Field(..., description="Relevance score of the source") - content: str = Field(..., description="Relevant content from the document or question") - + content: str = Field( + ..., description="Relevant content from the document or question" + ) + # Document-specific fields (optional for questions) document_id: Optional[str] = Field(None, description="ID of the source document") title: Optional[str] = Field(None, description="Title of the source document") - chunk_index: Optional[int] = Field(None, description="Index of the chunk in the document") - + chunk_index: Optional[int] = Field( + None, description="Index of the chunk in the document" + ) + # Question-specific fields (optional for documents) question_id: Optional[str] = Field(None, description="ID of the source question") question: Optional[str] = Field(None, description="The question that was matched") - answer: Optional[str] = Field(None, description="The answer for the matched question") - answer_type: Optional[str] = Field(None, description="Type of answer (DIRECT, SQL_QUERY, etc.)") + answer: Optional[str] = Field( + None, description="The answer for the matched question" + ) + answer_type: Optional[str] = Field( + None, description="Type of answer (DIRECT, SQL_QUERY, etc.)" + ) + class MessageBase(BaseModel): """Base message attributes""" + content: str = Field(..., description="Content of the message") - content_type: MessageContentType = Field(..., description="Type of message (TEXT/IMAGE/AUDIO/VIDEO/DOCUMENT)") + content_type: MessageContentType = Field( + ..., description="Type of message (TEXT/IMAGE/AUDIO/VIDEO/DOCUMENT)" + ) + class MessageCreate(MessageBase): """Attributes for creating a new message""" - pass + class MessageResponse(MessageBase): """Response model for messages""" + id: str = Field(..., description="Unique identifier for the message") - kind: MessageKind = Field(..., description="Kind of message (USER/ASSISTANT/SYSTEM)") + kind: MessageKind = Field( + ..., description="Kind of message (USER/ASSISTANT/SYSTEM)" + ) user_id: str = Field(..., description="ID of the user who created the message") - conversation_id: str = Field(..., description="ID of the conversation this message belongs to") - knowledge_base_id: str = Field(..., description="ID of the knowledge base this message belongs to") - sources: Optional[List[MessageSource]] = Field(None, description="Source documents used for assistant's response") - message_metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata including routing information") - status: MessageStatus = Field(..., description="Status of the message (RECEIVED/PROCESSING/SENT/FAILED)") + conversation_id: str = Field( + ..., description="ID of the conversation this message belongs to" + ) + knowledge_base_id: str = Field( + ..., description="ID of the knowledge base this message belongs to" + ) + sources: Optional[List[MessageSource]] = Field( + None, description="Source documents used for assistant's response" + ) + message_metadata: Optional[Dict[str, Any]] = Field( + None, description="Additional metadata including routing information" + ) + status: MessageStatus = Field( + ..., description="Status of the message (RECEIVED/PROCESSING/SENT/FAILED)" + ) created_at: datetime = Field(..., description="When the message was created") updated_at: datetime = Field(..., description="When the message was last updated") - @model_validator(mode='after') - def parse_message_metadata(self) -> 'MessageResponse': + @model_validator(mode="after") + def parse_message_metadata(self) -> "MessageResponse": """Parse message_metadata if it's a string""" if self.message_metadata and isinstance(self.message_metadata, str): try: @@ -63,8 +92,12 @@ def parse_message_metadata(self) -> 'MessageResponse': class Config: from_attributes = True + class MessageProcessingResponse(BaseModel): """Response for asynchronous message processing""" + request_id: str = Field(..., description="Unique identifier for the request") - status: str = Field(..., description="Status of the request (processing/completed/failed)") - message: str = Field(..., description="Status message or error description") \ No newline at end of file + status: str = Field( + ..., description="Status of the request (processing/completed/failed)" + ) + message: str = Field(..., description="Status message or error description") diff --git a/app/schemas/question.py b/app/schemas/question.py index 8d747e8..2ceea83 100644 --- a/app/schemas/question.py +++ b/app/schemas/question.py @@ -1,36 +1,45 @@ -from typing import Optional -from pydantic import BaseModel, Field from datetime import datetime from enum import Enum +from typing import Optional + +from pydantic import BaseModel + class AnswerType(str, Enum): DIRECT = "DIRECT" SQL_QUERY = "SQL_QUERY" + class QuestionStatus(str, Enum): PENDING = "PENDING" INGESTING = "INGESTING" COMPLETED = "COMPLETED" FAILED = "FAILED" + class QuestionBase(BaseModel): """Base schema for question data""" + question: str answer: str answer_type: AnswerType + class QuestionCreate(QuestionBase): """Schema for creating a new question""" - pass + class QuestionUpdate(BaseModel): """Schema for updating an existing question""" + question: Optional[str] = None answer: Optional[str] = None answer_type: Optional[AnswerType] = None + class QuestionResponse(QuestionBase): """Schema for question response""" + id: str status: QuestionStatus knowledge_base_id: str @@ -45,13 +54,13 @@ class Config: def model_validate(cls, obj): """Custom validation to handle SQLAlchemy model to Pydantic conversion""" # Convert string status to enum - if hasattr(obj, 'status') and isinstance(obj.status, str): + if hasattr(obj, "status") and isinstance(obj.status, str): status = obj.status obj.status = QuestionStatus(status) - + # Convert string answer_type to enum - if hasattr(obj, 'answer_type') and isinstance(obj.answer_type, str): + if hasattr(obj, "answer_type") and isinstance(obj.answer_type, str): answer_type = obj.answer_type obj.answer_type = AnswerType(answer_type) - - return super().model_validate(obj) \ No newline at end of file + + return super().model_validate(obj) diff --git a/app/schemas/user.py b/app/schemas/user.py index ac91903..932538c 100644 --- a/app/schemas/user.py +++ b/app/schemas/user.py @@ -1,52 +1,66 @@ -from typing import Annotated, Optional, List +from typing import Annotated, List, Optional + from pydantic import BaseModel, EmailStr, Field + from app.db.models.user import UserRole + class UserBase(BaseModel): email: EmailStr = Field(..., alias="email") full_name: str = Field(..., alias="full_name") role: UserRole = Field(default=UserRole.USER) + class UserCreate(UserBase): password: str = Field(..., alias="password") + class UserLogin(BaseModel): email: EmailStr = Field(..., alias="email") password: str = Field(..., alias="password") + class UserUpdate(BaseModel): email: Optional[EmailStr] = Field(default=None, alias="email") full_name: Optional[str] = Field(default=None, alias="full_name") password: Optional[str] = Field(default=None, alias="password") role: UserRole = Field(default=UserRole.USER) + class Token(BaseModel): access_token: str = Field(..., alias="access_token") token_type: str = Field(default="bearer", alias="token_type") + class TokenData(BaseModel): email: Optional[str] = Field(default=None, alias="email") role: Optional[str] = Field(default=None, alias="role") + class PasswordReset(BaseModel): email: EmailStr = Field(..., alias="email") + class PasswordResetConfirm(BaseModel): token: str = Field(..., alias="token") new_password: str = Field(..., min_length=8) + class UserResponse(UserBase): id: str = Field(..., alias="id") is_active: bool = Field(..., alias="is_active") hashed_password: str = Annotated[str, Field(exclude=True)] class Config: - from_attributes = True - + from_attributes = True + + class UserWithPermissions(UserResponse): """User response with permissions information""" - permissions: List[str] = Field(..., description="List of permissions the user has based on their role") - + + permissions: List[str] = Field( + ..., description="List of permissions the user has based on their role" + ) + class Config: - from_attributes = True - \ No newline at end of file + from_attributes = True diff --git a/app/services/conversation_service.py b/app/services/conversation_service.py index ab224f7..91492f9 100644 --- a/app/services/conversation_service.py +++ b/app/services/conversation_service.py @@ -1,73 +1,87 @@ -from typing import List, Optional -from fastapi import HTTPException import logging +from typing import List, Optional +from fastapi import HTTPException from sqlalchemy.orm import Session + from app.db.models.conversation import Conversation -from app.db.models.user import User from app.repositories.conversation_repository import ConversationRepository +from app.schemas.conversation import ( + ConversationCreate, + ConversationResponse, + ConversationUpdate, +) from app.schemas.user import UserResponse from app.services.knowledge_base_service import KnowledgeBaseService -from app.schemas.conversation import ConversationCreate, ConversationResponse, ConversationUpdate # Set up logging logger = logging.getLogger(__name__) + class ConversationService: def __init__( self, conversation_repository: ConversationRepository, knowledge_base_service: KnowledgeBaseService, - db: Session + db: Session, ): self.repository = conversation_repository self.kb_service = knowledge_base_service self.db = db async def create_conversation( - self, - payload: ConversationCreate, - current_user: UserResponse + self, payload: ConversationCreate, current_user: UserResponse ) -> ConversationResponse: """Create a new conversation""" try: # Verify knowledge base access await self.kb_service.get_knowledge_base( - payload.knowledge_base_id, - current_user + payload.knowledge_base_id, current_user ) conversation = Conversation( title=payload.title, knowledge_base_id=payload.knowledge_base_id, - user_id=current_user.id + user_id=current_user.id, ) logger.info(f"Creating conversation for user {current_user.id}") - conversation: ConversationResponse = await self.repository.create(conversation, self.db) - logger.info(f"Conversation {conversation.id} created by user {current_user.id}") + conversation: ConversationResponse = await self.repository.create( + conversation, self.db + ) + logger.info( + f"Conversation {conversation.id} created by user {current_user.id}" + ) return conversation except Exception as e: logger.error(f"Failed to create conversation: {e}") raise HTTPException(status_code=500, detail=str(e)) - async def list_conversations(self, current_user: UserResponse) -> List[ConversationResponse]: + async def list_conversations( + self, current_user: UserResponse + ) -> List[ConversationResponse]: """List all conversations for the current user""" try: logger.info(f"Listing conversations for user {current_user.id}") - conversations: List[ConversationResponse] = await self.repository.list_by_user(current_user, self.db) - logger.info(f"Retrieved {len(conversations)} conversations for user {current_user.id}") + conversations: List[ConversationResponse] = ( + await self.repository.list_by_user(current_user, self.db) + ) + logger.info( + f"Retrieved {len(conversations)} conversations for user {current_user.id}" + ) return conversations except Exception as e: - logger.error(f"Failed to list conversations for user {current_user.id}: {e}") + logger.error( + f"Failed to list conversations for user {current_user.id}: {e}" + ) raise HTTPException(status_code=500, detail=str(e)) async def get_conversation( - self, - conversation_id: str, - current_user: UserResponse + self, conversation_id: str, current_user: UserResponse ) -> ConversationResponse: """Get conversation details""" try: - conversation: Optional[ConversationResponse] = await self.repository.get_by_id(conversation_id, current_user, self.db) + conversation: Optional[ConversationResponse] = ( + await self.repository.get_by_id(conversation_id, current_user, self.db) + ) if not conversation: logger.warning(f"Conversation {conversation_id} not found") raise HTTPException(status_code=404, detail="Conversation not found") @@ -82,23 +96,25 @@ async def update_conversation( self, conversation_id: str, conversation_update: ConversationUpdate, - current_user: UserResponse + current_user: UserResponse, ) -> ConversationResponse: """Update conversation details""" try: # Verify ownership - conversation: ConversationResponse = await self.get_conversation(conversation_id, current_user) + conversation: ConversationResponse = await self.get_conversation( + conversation_id, current_user + ) if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") # Update conversation updated_conversation = await self.repository.update( - conversation_id, - conversation_update, - current_user + conversation_id, conversation_update, current_user ) if updated_conversation: - logger.info(f"Conversation {conversation_id} updated by user {current_user.id}") + logger.info( + f"Conversation {conversation_id} updated by user {current_user.id}" + ) return updated_conversation raise HTTPException(status_code=404, detail="Conversation not found") except HTTPException: @@ -108,24 +124,26 @@ async def update_conversation( raise HTTPException(status_code=500, detail=str(e)) async def delete_conversation( - self, - conversation_id: str, - current_user: UserResponse + self, conversation_id: str, current_user: UserResponse ) -> None: """Delete conversation and all its messages""" try: # Verify ownership - conversation: ConversationResponse = await self.get_conversation(conversation_id, current_user) + conversation: ConversationResponse = await self.get_conversation( + conversation_id, current_user + ) if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") # Delete conversation (cascade will handle messages) if await self.repository.delete(conversation_id, current_user): - logger.info(f"Conversation {conversation_id} deleted by user {current_user.id}") + logger.info( + f"Conversation {conversation_id} deleted by user {current_user.id}" + ) else: raise HTTPException(status_code=404, detail="Conversation not found") except HTTPException: raise except Exception as e: logger.error(f"Failed to delete conversation {conversation_id}: {e}") - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=500, detail=str(e)) diff --git a/app/services/document_service.py b/app/services/document_service.py index 2b39e00..9c16300 100644 --- a/app/services/document_service.py +++ b/app/services/document_service.py @@ -1,22 +1,22 @@ -from typing import List -from fastapi import HTTPException -import base64 import logging -import aiofiles +from typing import List + from celery import Celery +from fastapi import HTTPException from sqlalchemy.orm import Session from app.db.models.knowledge_base import Document, DocumentStatus, DocumentType from app.db.models.user import UserRole from app.repositories.document_repository import DocumentRepository +from app.schemas.document import DocumentResponse, DocumentUpdate, DocumentUpload from app.schemas.user import UserResponse +from app.services.knowledge_base_service import FileStorage, KnowledgeBaseService from app.services.rag.vector_store import VectorStore -from app.services.knowledge_base_service import KnowledgeBaseService, FileStorage -from app.schemas.document import DocumentResponse, DocumentUpdate, DocumentUpload # Set up logging logger = logging.getLogger(__name__) + class DocumentService: def __init__( self, @@ -25,7 +25,7 @@ def __init__( knowledge_base_service: KnowledgeBaseService, file_storage: FileStorage, celery_app: Celery, - db: Session + db: Session, ): self.document_repository = document_repository self.vector_store = vector_store @@ -35,47 +35,45 @@ def __init__( self.db = db async def create_document( - self, - kb_id: str, - payload: DocumentUpload, - current_user: UserResponse + self, kb_id: str, payload: DocumentUpload, current_user: UserResponse ) -> DocumentResponse: """Create a new document in a knowledge base""" # file_path = None try: # Check knowledge base access await self.kb_service.get_knowledge_base(kb_id, current_user) - + # Check the number of documents in the knowledge base - existing_docs = await self.document_repository.list_by_knowledge_base(kb_id, self.db) + existing_docs = await self.document_repository.list_by_knowledge_base( + kb_id, self.db + ) if len(existing_docs) >= 20: raise HTTPException( status_code=400, - detail="Maximum number of documents (20) reached for this knowledge base" + detail="Maximum number of documents (20) reached for this knowledge base", ) - + # # Save file temporarily and get content # file_path = await self.file_storage.save_file(payload.title) # async with aiofiles.open(file_path, 'rb') as f: # content = await f.read() - + # Create document record document = await self._create_document_record( kb_id=kb_id, title=payload.title, content_type=self._detect_document_type(payload.content_type), content=payload.content, - user_id=str(current_user.id) + user_id=str(current_user.id), ) - + # Queue document processing task self.celery_app.send_task( - 'app.worker.tasks.initiate_document_ingestion', - args=[document.id] + "app.worker.tasks.initiate_document_ingestion", args=[document.id] ) - + return document - + except Exception as e: logger.error(f"Failed to create document in service: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -86,37 +84,35 @@ async def create_document( def _detect_document_type(self, content_type: str) -> DocumentType: """Detect document type from content type""" content_type = content_type.lower() - if 'pdf' in content_type: + if "pdf" in content_type: return DocumentType.PDF - if 'image/jpg' in content_type or 'image/jpeg' in content_type: + if "image/jpg" in content_type or "image/jpeg" in content_type: return DocumentType.JPG - if 'image/png' in content_type: + if "image/png" in content_type: return DocumentType.PNG - if 'image/gif' in content_type: + if "image/gif" in content_type: return DocumentType.GIF - if 'image/tiff' in content_type: + if "image/tiff" in content_type: return DocumentType.TIFF - if 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' in content_type: + if ( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + in content_type + ): return DocumentType.DOCX - if 'application/msword' in content_type: + if "application/msword" in content_type: return DocumentType.DOC - if 'text/csv' in content_type: + if "text/csv" in content_type: return DocumentType.CSV - if 'text/plain' in content_type: + if "text/plain" in content_type: return DocumentType.TXT return DocumentType.TXT async def _create_document_record( - self, - kb_id: str, - title: str, - content_type: str, - content: bytes, - user_id: str + self, kb_id: str, title: str, content_type: str, content: bytes, user_id: str ) -> DocumentResponse: """Create document record with encoded content""" # content_base64 = base64.b64encode(content).decode('utf-8') - + document = Document( title=title, knowledge_base_id=kb_id, @@ -124,25 +120,23 @@ async def _create_document_record( content=content, size_bytes=len(content), user_id=user_id, - status=DocumentStatus.PENDING + status=DocumentStatus.PENDING, ) logger.info(f"Creating document record: {document}") return await self.document_repository.create(document, self.db) - async def get_document( - self, - doc_id: str, - current_user: UserResponse - ) -> Document: + async def get_document(self, doc_id: str, current_user: UserResponse) -> Document: """Get document details""" try: doc = await self.document_repository.get_by_id(doc_id, self.db) if not doc: raise HTTPException(status_code=404, detail="Document not found") - + # Check access through knowledge base - await self.kb_service.get_knowledge_base(doc.knowledge_base_id, current_user) - + await self.kb_service.get_knowledge_base( + doc.knowledge_base_id, current_user + ) + return doc except HTTPException: raise @@ -151,9 +145,7 @@ async def get_document( raise HTTPException(status_code=500, detail=str(e)) async def list_documents( - self, - kb_id: str, - current_user: UserResponse + self, kb_id: str, current_user: UserResponse ) -> List[DocumentResponse]: """List all documents in a knowledge base""" try: @@ -165,113 +157,110 @@ async def list_documents( raise HTTPException(status_code=500, detail=str(e)) async def update_document( - self, - doc_id: str, - doc_update: DocumentUpdate, - current_user: UserResponse + self, doc_id: str, doc_update: DocumentUpdate, current_user: UserResponse ) -> Document: """Update document metadata""" try: # Get document and check access doc = await self.get_document(doc_id, current_user) - kb = await self.kb_service.get_knowledge_base(doc.knowledge_base_id, current_user) - + kb = await self.kb_service.get_knowledge_base( + doc.knowledge_base_id, current_user + ) + # Only owner or admin can update - if current_user.role != UserRole.ADMIN and str(kb.user_id) != str(current_user.id): + if current_user.role != UserRole.ADMIN and str(kb.user_id) != str( + current_user.id + ): raise HTTPException(status_code=403, detail="Not enough privileges") - + # Update document update_data = doc_update.model_dump(exclude_unset=True) - updated_doc = await self.document_repository.update(doc_id, update_data, self.db) + updated_doc = await self.document_repository.update( + doc_id, update_data, self.db + ) if not updated_doc: raise HTTPException(status_code=404, detail="Document not found") - + logger.info(f"Document {doc_id} updated") return updated_doc - + except HTTPException: raise except Exception as e: logger.error(f"Failed to update document {doc_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) - async def delete_document( - self, - doc_id: str, - current_user: UserResponse - ) -> None: + async def delete_document(self, doc_id: str, current_user: UserResponse) -> None: """Delete a document and its vectors""" try: # Get document and check access doc = await self.get_document(doc_id, current_user) - kb = await self.kb_service.get_knowledge_base(doc.knowledge_base_id, current_user) - + kb = await self.kb_service.get_knowledge_base( + doc.knowledge_base_id, current_user + ) + # Only owner or admin can delete - if current_user.role != UserRole.ADMIN and str(kb.user_id) != str(current_user.id): + if current_user.role != UserRole.ADMIN and str(kb.user_id) != str( + current_user.id + ): raise HTTPException(status_code=403, detail="Not enough privileges") - + # Queue vector deletion task self.celery_app.send_task( - 'app.worker.tasks.initiate_document_vector_deletion', - args=[doc_id] + "app.worker.tasks.initiate_document_vector_deletion", args=[doc_id] ) - + # Delete document success = await self.document_repository.delete(doc_id, self.db) if not success: raise HTTPException(status_code=404, detail="Document not found") - + logger.info(f"Document {doc_id} deleted") - + except HTTPException: raise except Exception as e: logger.error(f"Failed to delete document {doc_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) - + async def retry_failed_document( - self, - kb_id: str, - doc_id: str, - current_user: UserResponse + self, kb_id: str, doc_id: str, current_user: UserResponse ) -> DocumentResponse: """Retry processing a failed document""" try: # Get document and check access doc = await self.get_document(doc_id, current_user) - + # Verify this document belongs to the specified knowledge base if doc.knowledge_base_id != kb_id: raise HTTPException( - status_code=400, - detail="Document does not belong to the specified knowledge base" + status_code=400, + detail="Document does not belong to the specified knowledge base", ) - + # Check if document is in a failed state if doc.status != DocumentStatus.FAILED: raise HTTPException( status_code=400, - detail=f"Only failed documents can be retried. Current status: {doc.status}" + detail=f"Only failed documents can be retried. Current status: {doc.status}", ) - + # Update document status back to pending - update_data = { - "status": DocumentStatus.PENDING, - "error_message": None - } - updated_doc = await self.document_repository.update(doc_id, update_data, self.db) - + update_data = {"status": DocumentStatus.PENDING, "error_message": None} + updated_doc = await self.document_repository.update( + doc_id, update_data, self.db + ) + # Queue document processing task self.celery_app.send_task( - 'app.worker.tasks.initiate_document_ingestion', - args=[doc_id] + "app.worker.tasks.initiate_document_ingestion", args=[doc_id] ) - + logger.info(f"Retrying processing for document {doc_id}") return updated_doc - + except HTTPException: raise except Exception as e: logger.error(f"Failed to retry document processing for {doc_id}: {e}") - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=500, detail=str(e)) diff --git a/app/services/knowledge_base_service.py b/app/services/knowledge_base_service.py index 7e1ea4e..951b797 100644 --- a/app/services/knowledge_base_service.py +++ b/app/services/knowledge_base_service.py @@ -1,48 +1,57 @@ -from typing import List, Protocol -from fastapi import HTTPException, UploadFile import logging -from datetime import datetime import os +from datetime import datetime +from typing import List, Protocol + import aiofiles from celery import Celery +from fastapi import HTTPException, UploadFile from sqlalchemy.orm import Session from app.db.models.knowledge_base import KnowledgeBase from app.db.models.user import UserRole from app.repositories.knowledge_base_repository import KnowledgeBaseRepository +from app.schemas.knowledge_base import ( + KnowledgeBaseCreate, + KnowledgeBaseResponse, + KnowledgeBaseUpdate, +) from app.schemas.user import UserResponse from app.services.rag.vector_store import VectorStore -from app.schemas.knowledge_base import KnowledgeBaseCreate, KnowledgeBaseResponse, KnowledgeBaseUpdate # Set up logging logger = logging.getLogger(__name__) + class FileStorage(Protocol): """Protocol for file storage operations""" + async def save_file(self, file: UploadFile) -> str: """Save a file and return its path""" ... - + def cleanup_file(self, file_path: str) -> None: """Clean up a saved file""" ... + class LocalFileStorage: """Local filesystem implementation of FileStorage""" + def __init__(self, upload_dir: str): self.upload_dir = upload_dir - + async def save_file(self, file: UploadFile) -> str: timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") safe_filename = f"{timestamp}_{file.filename}" file_path = os.path.join(self.upload_dir, safe_filename) - - async with aiofiles.open(file_path, 'wb') as out_file: + + async with aiofiles.open(file_path, "wb") as out_file: content = await file.read() await out_file.write(content) - + return file_path - + def cleanup_file(self, file_path: str) -> None: try: if os.path.exists(file_path): @@ -50,15 +59,17 @@ def cleanup_file(self, file_path: str) -> None: except Exception as e: logger.error(f"Failed to cleanup file {file_path}: {e}") + class KnowledgeBaseService: """Service for knowledge base operations""" + def __init__( self, repository: KnowledgeBaseRepository, vector_store: VectorStore, file_storage: LocalFileStorage, celery_app: Celery, - db: Session + db: Session, ): self.repository = repository self.vector_store = vector_store @@ -67,56 +78,58 @@ def __init__( self.db = db async def create_knowledge_base( - self, - kb_data: KnowledgeBaseCreate, - current_user: UserResponse + self, kb_data: KnowledgeBaseCreate, current_user: UserResponse ) -> KnowledgeBase: """Create a new knowledge base""" try: knowledge_base = KnowledgeBase( name=kb_data.name, description=kb_data.description, - user_id=current_user.id - ) - kb = await self.repository.create( - knowledge_base, - self.db + user_id=current_user.id, ) + kb = await self.repository.create(knowledge_base, self.db) return kb except Exception as e: logger.error(f"Error creating knowledge base: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to create knowledge base: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to create knowledge base: {str(e)}" + ) async def get_knowledge_base( - self, - kb_id: str, - current_user: UserResponse + self, kb_id: str, current_user: UserResponse ) -> KnowledgeBaseResponse: """Get a knowledge base by ID""" kb = await self.repository.get_by_id(kb_id, self.db) if not kb: raise HTTPException(status_code=404, detail="Knowledge base not found") - + # Check if user has access if str(kb.user_id) != str(current_user.id): # Check if user is an admin or owner (role-based access) - if current_user.role != UserRole.ADMIN and current_user.role != UserRole.OWNER: + if ( + current_user.role != UserRole.ADMIN + and current_user.role != UserRole.OWNER + ): # Check if the KB is explicitly shared with this user - is_shared = await self.repository.is_shared_with_user(kb_id, current_user.id, self.db) + is_shared = await self.repository.is_shared_with_user( + kb_id, current_user.id, self.db + ) if not is_shared: - raise HTTPException(status_code=403, detail="You don't have access to this knowledge base") - + raise HTTPException( + status_code=403, + detail="You don't have access to this knowledge base", + ) + return kb async def list_knowledge_bases( - self, - current_user: UserResponse + self, current_user: UserResponse ) -> List[KnowledgeBase]: """List all knowledge bases accessible to the user""" # For admin, return all knowledge bases if current_user.role == UserRole.ADMIN: return await self.repository.list_all(self.db) - # For owner, return only knowledge bases they own + # For owner, return only knowledge bases they own elif current_user.role == UserRole.OWNER: return await self.repository.list_by_owner(current_user.id, self.db) # For regular users, return an empty list @@ -124,149 +137,181 @@ async def list_knowledge_bases( return [] # Regular users see an empty list, but don't get an error async def update_knowledge_base( - self, - kb_id: str, - kb_data: KnowledgeBaseUpdate, - current_user: UserResponse + self, kb_id: str, kb_data: KnowledgeBaseUpdate, current_user: UserResponse ) -> KnowledgeBase: """Update a knowledge base""" kb = await self.repository.get_by_id(kb_id, self.db) if not kb: raise HTTPException(status_code=404, detail="Knowledge base not found") - + # Check if user has permission to update - if str(kb.user_id) != str(current_user.id) and current_user.role != UserRole.ADMIN: - raise HTTPException(status_code=403, detail="You don't have permission to update this knowledge base") - + if ( + str(kb.user_id) != str(current_user.id) + and current_user.role != UserRole.ADMIN + ): + raise HTTPException( + status_code=403, + detail="You don't have permission to update this knowledge base", + ) + try: updated_kb = await self.repository.update( - self.db, - kb_id=kb_id, - name=kb_data.name, - description=kb_data.description + self.db, kb_id=kb_id, name=kb_data.name, description=kb_data.description ) return updated_kb except Exception as e: logger.error(f"Error updating knowledge base: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to update knowledge base: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to update knowledge base: {str(e)}" + ) async def delete_knowledge_base( - self, - kb_id: str, - current_user: UserResponse + self, kb_id: str, current_user: UserResponse ) -> None: """Delete a knowledge base and all its documents""" kb = await self.repository.get_by_id(kb_id, self.db) if not kb: raise HTTPException(status_code=404, detail="Knowledge base not found") - + # Check if user has permission to delete - if str(kb.user_id) != str(current_user.id) and current_user.role != UserRole.ADMIN: - raise HTTPException(status_code=403, detail="You don't have permission to delete this knowledge base") - + if ( + str(kb.user_id) != str(current_user.id) + and current_user.role != UserRole.ADMIN + ): + raise HTTPException( + status_code=403, + detail="You don't have permission to delete this knowledge base", + ) + try: # Delete all documents in the knowledge base documents = await self.repository.get_documents(kb_id, self.db) for doc in documents: # Delete document vectors await self.vector_store.delete_document_chunks(doc.id, kb_id) - + # Delete the knowledge base (this will cascade delete documents) await self.repository.delete(kb_id, self.db) except Exception as e: logger.error(f"Error deleting knowledge base: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to delete knowledge base: {str(e)}") - + raise HTTPException( + status_code=500, detail=f"Failed to delete knowledge base: {str(e)}" + ) + async def share_knowledge_base( - self, - kb_id: str, - user_id: str, - current_user: UserResponse + self, kb_id: str, user_id: str, current_user: UserResponse ) -> bool: """Share a knowledge base with another user""" # Get the knowledge base kb = await self.repository.get_by_id(kb_id, self.db) if not kb: raise HTTPException(status_code=404, detail="Knowledge base not found") - + # Check if the current user has permission to share - if str(kb.user_id) != str(current_user.id) and current_user.role != UserRole.ADMIN and current_user.role != UserRole.OWNER: - raise HTTPException(status_code=403, detail="You don't have permission to share this knowledge base") - + if ( + str(kb.user_id) != str(current_user.id) + and current_user.role != UserRole.ADMIN + and current_user.role != UserRole.OWNER + ): + raise HTTPException( + status_code=403, + detail="You don't have permission to share this knowledge base", + ) + # Get the user to share with from app.repositories.user_repository import UserRepository + user_repository = UserRepository() user = await user_repository.get_by_id(user_id, self.db) if not user: raise HTTPException(status_code=404, detail="User not found") - + # Check if already shared with this user - is_already_shared = await self.repository.is_shared_with_user(kb_id, user_id, self.db) + is_already_shared = await self.repository.is_shared_with_user( + kb_id, user_id, self.db + ) if is_already_shared: return True # Already shared, consider it a success - + try: # Add the sharing relationship await self.repository.add_user_access(kb_id, user_id, self.db) return True except Exception as e: logger.error(f"Error sharing knowledge base: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to share knowledge base: {str(e)}") - + raise HTTPException( + status_code=500, detail=f"Failed to share knowledge base: {str(e)}" + ) + async def unshare_knowledge_base( - self, - kb_id: str, - user_id: str, - current_user: UserResponse + self, kb_id: str, user_id: str, current_user: UserResponse ) -> bool: """Remove a user's access to a knowledge base""" # Get the knowledge base kb = await self.repository.get_by_id(kb_id, self.db) if not kb: raise HTTPException(status_code=404, detail="Knowledge base not found") - + # Check if the current user has permission to unshare - if str(kb.user_id) != str(current_user.id) and current_user.role != UserRole.ADMIN and current_user.role != UserRole.OWNER: - raise HTTPException(status_code=403, detail="You don't have permission to modify sharing for this knowledge base") - + if ( + str(kb.user_id) != str(current_user.id) + and current_user.role != UserRole.ADMIN + and current_user.role != UserRole.OWNER + ): + raise HTTPException( + status_code=403, + detail="You don't have permission to modify sharing for this knowledge base", + ) + try: # Remove the sharing relationship await self.repository.remove_user_access(kb_id, user_id, self.db) return True except Exception as e: logger.error(f"Error unsharing knowledge base: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to unshare knowledge base: {str(e)}") - - async def list_shared_users( - self, - kb_id: str, - current_user: UserResponse - ) -> List: + raise HTTPException( + status_code=500, detail=f"Failed to unshare knowledge base: {str(e)}" + ) + + async def list_shared_users(self, kb_id: str, current_user: UserResponse) -> List: """List all users who have access to a knowledge base""" # Get the knowledge base kb = await self.repository.get_by_id(kb_id, self.db) if not kb: raise HTTPException(status_code=404, detail="Knowledge base not found") - + # Check if the current user has permission to view sharing info - if str(kb.user_id) != str(current_user.id) and current_user.role != UserRole.ADMIN and current_user.role != UserRole.OWNER: - raise HTTPException(status_code=403, detail="You don't have permission to view sharing information for this knowledge base") - + if ( + str(kb.user_id) != str(current_user.id) + and current_user.role != UserRole.ADMIN + and current_user.role != UserRole.OWNER + ): + raise HTTPException( + status_code=403, + detail="You don't have permission to view sharing information for this knowledge base", + ) + try: return await self.repository.get_shared_users(kb_id, self.db) except Exception as e: logger.error(f"Error listing shared users: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to list shared users: {str(e)}") - + raise HTTPException( + status_code=500, detail=f"Failed to list shared users: {str(e)}" + ) + async def list_shared_knowledge_bases( - self, - current_user: UserResponse + self, current_user: UserResponse ) -> List[KnowledgeBaseResponse]: """List all knowledge bases shared with the current user""" try: - shared_kbs = await self.repository.list_shared_with_user(current_user.id, self.db) + shared_kbs = await self.repository.list_shared_with_user( + current_user.id, self.db + ) # Always return a list, even if empty return shared_kbs if shared_kbs else [] except Exception as e: logger.error(f"Error listing shared knowledge bases: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to list shared knowledge bases: {str(e)}") \ No newline at end of file + raise HTTPException( + status_code=500, + detail=f"Failed to list shared knowledge bases: {str(e)}", + ) diff --git a/app/services/llm/factory.py b/app/services/llm/factory.py index cc6c9e3..fe32383 100644 --- a/app/services/llm/factory.py +++ b/app/services/llm/factory.py @@ -5,12 +5,11 @@ It abstracts away the specific LLM provider implementations to allow easy switching between providers. """ -from typing import Dict, List, Any, Optional, Union, Literal import logging -import os -import json -from enum import Enum from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + from pydantic import BaseModel, Field from app.core.config import settings @@ -19,102 +18,115 @@ # ----- Type Definitions ----- + class Role(str, Enum): """Message roles in the OpenAI Chat API format.""" + SYSTEM = "system" USER = "user" ASSISTANT = "assistant" FUNCTION = "function" + class Message(BaseModel): """Message in the OpenAI Chat API format.""" + role: Role content: str name: Optional[str] = None - + + class CompletionOptions(BaseModel): """Common options for completion requests across providers.""" + temperature: float = Field(default=0.7, ge=0, le=1) max_tokens: Optional[int] = None top_p: float = Field(default=1.0, ge=0, le=1) stream: bool = False stop: Optional[Union[str, List[str]]] = None - + + class CompletionResult(BaseModel): """Standardized result from any LLM provider.""" + content: str role: str = "assistant" finish_reason: Optional[str] = None model: str - usage: Dict[str, int] = Field(default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}) + usage: Dict[str, int] = Field( + default_factory=lambda: { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + ) raw_response: Optional[Any] = None # Provider-specific raw response + # ----- LLM Provider Interfaces ----- + class LLMProvider(ABC): """Abstract base class for all LLM providers.""" - + @abstractmethod async def complete( - self, - messages: List[Message], - options: Optional[CompletionOptions] = None + self, messages: List[Message], options: Optional[CompletionOptions] = None ) -> CompletionResult: """ Send a completion request to the LLM provider. - + Args: messages: List of messages in the conversation options: Completion options - + Returns: CompletionResult with the LLM's response """ - pass + # ----- Provider Implementations ----- + class GeminiProvider(LLMProvider): """Google Gemini implementation.""" - + def __init__(self, api_key: Optional[str] = None, model: str = "gemini-2.0-flash"): """ Initialize the Gemini provider. - + Args: api_key: Gemini API key (defaults to settings) model: Gemini model to use """ import google.generativeai as genai - + self.api_key = api_key or settings.GEMINI_API_KEY genai.configure(api_key=self.api_key) self.model_name = model self.model = genai.GenerativeModel(model) logger.info(f"Initialized GeminiProvider with model: {model}") - + async def complete( - self, - messages: List[Message], - options: Optional[CompletionOptions] = None + self, messages: List[Message], options: Optional[CompletionOptions] = None ) -> CompletionResult: """ Send a completion request to Gemini. - + Args: messages: List of messages in the conversation options: Completion options - + Returns: CompletionResult with the LLM's response """ if options is None: options = CompletionOptions() - + try: # Convert OpenAI-style messages to Gemini format gemini_messages = [] - + for msg in messages: # Gemini only supports user and model roles directly # System messages need to be injected into the first user message @@ -126,91 +138,109 @@ async def complete( elif msg.role == Role.ASSISTANT: gemini_messages.append({"role": "model", "parts": [msg.content]}) # Function messages aren't directly supported, could be converted to text - + # Handle system message by prepending to the first user message if present system_messages = [msg for msg in messages if msg.role == Role.SYSTEM] - if system_messages and gemini_messages and gemini_messages[0]["role"] == "user": + if ( + system_messages + and gemini_messages + and gemini_messages[0]["role"] == "user" + ): system_content = "\n".join([msg.content for msg in system_messages]) - gemini_messages[0]["parts"][0] = f"{system_content}\n\n{gemini_messages[0]['parts'][0]}" - + gemini_messages[0]["parts"][ + 0 + ] = f"{system_content}\n\n{gemini_messages[0]['parts'][0]}" + # Set up generation config from options generation_config = { "temperature": options.temperature, "top_p": options.top_p, "max_output_tokens": options.max_tokens, - "stop_sequences": options.stop if isinstance(options.stop, list) else [options.stop] if options.stop else None + "stop_sequences": ( + options.stop + if isinstance(options.stop, list) + else [options.stop] if options.stop else None + ), } - + # Remove None values - generation_config = {k: v for k, v in generation_config.items() if v is not None} - + generation_config = { + k: v for k, v in generation_config.items() if v is not None + } + # For empty conversations or just system message, create a simple content generation - if not gemini_messages or (len(gemini_messages) == 1 and "role" in gemini_messages[0] and gemini_messages[0]["role"] == "model"): + if not gemini_messages or ( + len(gemini_messages) == 1 + and "role" in gemini_messages[0] + and gemini_messages[0]["role"] == "model" + ): # Get the system message if any prompt = system_messages[0].content if system_messages else "" response = self.model.generate_content( - prompt, - generation_config=generation_config + prompt, generation_config=generation_config ) else: # For chat-style conversations - chat = self.model.start_chat(history=gemini_messages[:-1] if gemini_messages else []) + chat = self.model.start_chat( + history=gemini_messages[:-1] if gemini_messages else [] + ) response = chat.send_message( gemini_messages[-1]["parts"][0] if gemini_messages else "", - generation_config=generation_config + generation_config=generation_config, ) - + return CompletionResult( content=response.text, model=self.model_name, finish_reason="stop", # Gemini doesn't provide this explicitly - raw_response=response + raw_response=response, ) - + except Exception as e: logger.error(f"Error completing with Gemini: {e}", exc_info=True) raise + class OpenAIProvider(LLMProvider): """OpenAI ChatGPT implementation.""" - + def __init__(self, api_key: Optional[str] = None, model: str = "gpt-3.5-turbo"): """ Initialize the OpenAI provider. - + Args: api_key: OpenAI API key (defaults to settings) model: OpenAI model to use """ from openai import AsyncOpenAI - + self.api_key = api_key or settings.OPENAI_API_KEY self.model_name = model self.client = AsyncOpenAI(api_key=self.api_key) logger.info(f"Initialized OpenAIProvider with model: {model}") - + async def complete( - self, - messages: List[Message], - options: Optional[CompletionOptions] = None + self, messages: List[Message], options: Optional[CompletionOptions] = None ) -> CompletionResult: """ Send a completion request to OpenAI. - + Args: messages: List of messages in the conversation options: Completion options - + Returns: CompletionResult with the LLM's response """ if options is None: options = CompletionOptions() - + try: # Convert to OpenAI format (already mostly compatible) - openai_messages = [{"role": msg.role.value, "content": msg.content} for msg in messages] - + openai_messages = [ + {"role": msg.role.value, "content": msg.content} for msg in messages + ] + # Create completion response = await self.client.chat.completions.create( model=self.model_name, @@ -219,12 +249,12 @@ async def complete( max_tokens=options.max_tokens, top_p=options.top_p, stream=options.stream, - stop=options.stop + stop=options.stop, ) - + # Handle the response choice = response.choices[0] - + return CompletionResult( content=choice.message.content, model=self.model_name, @@ -232,55 +262,56 @@ async def complete( usage={ "prompt_tokens": response.usage.prompt_tokens, "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens + "total_tokens": response.usage.total_tokens, }, - raw_response=response + raw_response=response, ) - + except Exception as e: logger.error(f"Error completing with OpenAI: {e}", exc_info=True) raise + class AnthropicProvider(LLMProvider): """Anthropic Claude implementation.""" - - def __init__(self, api_key: Optional[str] = None, model: str = "claude-3-sonnet-20240229"): + + def __init__( + self, api_key: Optional[str] = None, model: str = "claude-3-sonnet-20240229" + ): """ Initialize the Anthropic provider. - + Args: api_key: Anthropic API key (defaults to settings) model: Anthropic model to use """ import anthropic - + self.api_key = api_key or settings.ANTHROPIC_API_KEY self.model_name = model self.client = anthropic.AsyncAnthropic(api_key=self.api_key) logger.info(f"Initialized AnthropicProvider with model: {model}") - + async def complete( - self, - messages: List[Message], - options: Optional[CompletionOptions] = None + self, messages: List[Message], options: Optional[CompletionOptions] = None ) -> CompletionResult: """ Send a completion request to Anthropic. - + Args: messages: List of messages in the conversation options: Completion options - + Returns: CompletionResult with the LLM's response """ if options is None: options = CompletionOptions() - + try: # Convert to Anthropic format anthropic_messages = [] - + for msg in messages: if msg.role == Role.SYSTEM: # Anthropic has a separate system parameter @@ -288,83 +319,89 @@ async def complete( elif msg.role == Role.USER: anthropic_messages.append({"role": "user", "content": msg.content}) elif msg.role == Role.ASSISTANT: - anthropic_messages.append({"role": "assistant", "content": msg.content}) + anthropic_messages.append( + {"role": "assistant", "content": msg.content} + ) # Function messages aren't directly supported - + # Make the completion request response = await self.client.messages.create( model=self.model_name, messages=anthropic_messages, - system=system_content if 'system_content' in locals() else None, + system=system_content if "system_content" in locals() else None, temperature=options.temperature, max_tokens=options.max_tokens or 1024, top_p=options.top_p, - stream=options.stream + stream=options.stream, # Anthropic doesn't support stop sequences directly in the same way ) - + # Create the completion result return CompletionResult( content=response.content[0].text, model=self.model_name, usage={}, # Anthropic doesn't provide token usage in the same format - raw_response=response + raw_response=response, ) - + except Exception as e: logger.error(f"Error completing with Anthropic: {e}", exc_info=True) raise + # ----- Factory Implementation ----- + class LLMFactory: """Factory for creating LLM provider instances.""" - + @staticmethod def create( provider: Literal["openai", "gemini", "anthropic"] = None, model: Optional[str] = None, - api_key: Optional[str] = None + api_key: Optional[str] = None, ) -> LLMProvider: """ Create an LLM provider instance. - + Args: provider: The LLM provider to use (defaults to settings.LLM_PROVIDER) model: The specific model to use (defaults to provider-specific default) api_key: API key for the provider (defaults to settings) - + Returns: An instance of the requested LLM provider """ # Default to configuration provider = provider or settings.LLM_PROVIDER - + if provider == "openai": return OpenAIProvider(api_key=api_key, model=model or "gpt-3.5-turbo") elif provider == "gemini": return GeminiProvider(api_key=api_key, model=model or "gemini-2.0-flash") elif provider == "anthropic": - return AnthropicProvider(api_key=api_key, model=model or "claude-3-sonnet-20240229") + return AnthropicProvider( + api_key=api_key, model=model or "claude-3-sonnet-20240229" + ) else: raise ValueError(f"Unsupported LLM provider: {provider}") - + @staticmethod async def complete( messages: List[Union[Dict[str, str], Message]], provider: Optional[str] = None, model: Optional[str] = None, - options: Optional[CompletionOptions] = None + options: Optional[CompletionOptions] = None, ) -> CompletionResult: """ Convenience method to create a provider and complete in one step. - + Args: messages: List of messages in the conversation provider: The LLM provider to use (defaults to settings.LLM_PROVIDER) model: The specific model to use (defaults to provider-specific default) options: Completion options - + Returns: CompletionResult with the LLM's response """ @@ -372,32 +409,34 @@ async def complete( processed_messages = [] for msg in messages: if isinstance(msg, dict): - processed_messages.append(Message( - role=msg.get("role", "user"), - content=msg.get("content", ""), - name=msg.get("name") - )) + processed_messages.append( + Message( + role=msg.get("role", "user"), + content=msg.get("content", ""), + name=msg.get("name"), + ) + ) else: processed_messages.append(msg) - + # Create provider and complete llm = LLMFactory.create(provider=provider, model=model) return await llm.complete(processed_messages, options=options) - + @staticmethod async def embed_text( text: str, provider: Optional[str] = None, - model: Optional[str] = None # Default will come from settings + model: Optional[str] = None, # Default will come from settings ) -> List[float]: """ Generate an embedding for the provided text. - + Args: text: The text to embed provider: The provider to use for embeddings (defaults to Google) model: The embedding model to use (defaults to settings.EMBEDDING_MODEL) - + Returns: List of floats representing the embedding """ @@ -405,26 +444,25 @@ async def embed_text( # Use the model specified in settings if not provided embedding_model = model or settings.EMBEDDING_MODEL logger.info(f"Generating embedding using model: {embedding_model}") - + # Currently only Google's text-embedding-004 is supported # In the future, we can extend this to support other providers - + # Use Google's embedding API directly for now from google import genai from google.genai.types import ContentEmbedding - + # Initialize client with API key client = genai.Client(api_key=settings.GEMINI_API_KEY) - + # Get embedding result: ContentEmbedding = client.models.embed_content( - model=embedding_model, - contents=text + model=embedding_model, contents=text ) - + # Return the embedding values return result.embeddings[0].values - + except Exception as e: logger.error(f"Error generating embedding: {e}", exc_info=True) # If embedding fails, return a zero vector of default dimension @@ -432,23 +470,24 @@ async def embed_text( # Consider adding appropriate error handling in the calling code return [0.0] * 768 + # Convenience function async def complete( messages: List[Union[Dict[str, str], Message]], provider: Optional[str] = None, model: Optional[str] = None, - options: Optional[CompletionOptions] = None + options: Optional[CompletionOptions] = None, ) -> CompletionResult: """ Complete a conversation with the configured LLM provider. - + Args: messages: List of messages in the conversation provider: The LLM provider to use (defaults to settings.LLM_PROVIDER) model: The specific model to use (defaults to provider-specific default) options: Completion options - + Returns: CompletionResult with the LLM's response """ - return await LLMFactory.complete(messages, provider, model, options) \ No newline at end of file + return await LLMFactory.complete(messages, provider, model, options) diff --git a/app/services/message_service.py b/app/services/message_service.py index a982970..6c2ab97 100644 --- a/app/services/message_service.py +++ b/app/services/message_service.py @@ -1,24 +1,31 @@ +import logging from typing import List + from fastapi import HTTPException -import logging from sqlalchemy.orm import Session -from app.db.models.message import Message, MessageContentType, MessageKind, MessageStatus +from app.db.models.message import ( + Message, + MessageContentType, + MessageKind, + MessageStatus, +) from app.repositories.message_repository import MessageRepository +from app.schemas.message import MessageCreate, MessageResponse from app.schemas.user import UserResponse from app.services.conversation_service import ConversationService from app.worker.celery import celery_app -from app.schemas.message import MessageCreate, MessageResponse, MessageType # Set up logging logger = logging.getLogger(__name__) + class MessageService: def __init__( self, message_repository: MessageRepository, conversation_service: ConversationService, - db: Session + db: Session, ): self.repository = message_repository self.conversation_service = conversation_service @@ -34,8 +41,7 @@ async def create_message( try: # Check conversation access conversation = await self.conversation_service.get_conversation( - conversation_id, - current_user + conversation_id, current_user ) # Create user message @@ -46,13 +52,12 @@ async def create_message( content_type=payload.content_type, kind=MessageKind.USER, status=MessageStatus.RECEIVED, - user_id=current_user.id + user_id=current_user.id, ) - user_message = await self.repository.create( - message, - self.db + user_message = await self.repository.create(message, self.db) + logger.info( + f"User message {user_message.id} created in conversation {conversation_id}" ) - logger.info(f"User message {user_message.id} created in conversation {conversation_id}") # Create assistant message message = Message( @@ -62,43 +67,42 @@ async def create_message( content_type=MessageContentType.TEXT, kind=MessageKind.ASSISTANT, status=MessageStatus.PROCESSING, - user_id=current_user.id + user_id=current_user.id, ) - assistant_message = await self.repository.create( - message, - self.db + assistant_message = await self.repository.create(message, self.db) + logger.info( + f"Assistant message {assistant_message.id} created in conversation {conversation_id}" ) - logger.info(f"Assistant message {assistant_message.id} created in conversation {conversation_id}") # Queue RAG processing task celery_app.send_task( - 'app.worker.tasks.initiate_rag_retrieval', - args=[ - user_message.id, - assistant_message.id - ] + "app.worker.tasks.initiate_rag_retrieval", + args=[user_message.id, assistant_message.id], ) return user_message except Exception as e: - logger.error(f"Failed to create message in conversation {conversation_id}: {e}") + logger.error( + f"Failed to create message in conversation {conversation_id}: {e}" + ) raise HTTPException(status_code=500, detail=str(e)) async def get_message( - self, - conversation_id: str, - message_id: str, - current_user: UserResponse + self, conversation_id: str, message_id: str, current_user: UserResponse ) -> MessageResponse: """Get message details""" try: # Check conversation access first - await self.conversation_service.get_conversation(conversation_id, current_user) + await self.conversation_service.get_conversation( + conversation_id, current_user + ) message = await self.repository.get_by_id(message_id, self.db) if not message: - logger.warning(f"Message {message_id} not found in conversation {conversation_id}") + logger.warning( + f"Message {message_id} not found in conversation {conversation_id}" + ) raise HTTPException(status_code=404, detail="Message not found") return message @@ -109,18 +113,24 @@ async def get_message( raise HTTPException(status_code=500, detail=str(e)) async def list_messages( - self, - conversation_id: str, - current_user: UserResponse + self, conversation_id: str, current_user: UserResponse ) -> List[MessageResponse]: """List all messages in a conversation""" try: # Check conversation access first - await self.conversation_service.get_conversation(conversation_id, current_user) + await self.conversation_service.get_conversation( + conversation_id, current_user + ) - messages = await self.repository.list_by_conversation(conversation_id, self.db) - logger.info(f"Retrieved {len(messages)} messages from conversation {conversation_id}") + messages = await self.repository.list_by_conversation( + conversation_id, self.db + ) + logger.info( + f"Retrieved {len(messages)} messages from conversation {conversation_id}" + ) return messages except Exception as e: - logger.error(f"Failed to list messages for conversation {conversation_id}: {e}") - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + logger.error( + f"Failed to list messages for conversation {conversation_id}: {e}" + ) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/app/services/query_router.py b/app/services/query_router.py index d1f7d66..aaeeb22 100644 --- a/app/services/query_router.py +++ b/app/services/query_router.py @@ -1,25 +1,28 @@ -from typing import List, Dict, Any, Set, Optional -from app.services.rag.vector_store import get_vector_store -from app.core.config import settings -from app.db.models.knowledge_base import DocumentType -from app.services.rag_service import get_rag_service -from app.services.tag_service import get_tag_service +import json import logging from functools import lru_cache -import json +from typing import Any, Dict, List, Optional -# Replace direct Gemini import with our LLM factory -from app.services.llm.factory import LLMFactory, Message, Role, CompletionOptions +from app.core.config import settings from app.core.prompts import get_prompt, register_prompt -from app.db.models.question import AnswerType +from app.db.models.knowledge_base import DocumentType + +# Replace direct Gemini import with our LLM factory +from app.services.llm.factory import CompletionOptions, LLMFactory, Message, Role +from app.services.rag.vector_store import get_vector_store +from app.services.rag_service import get_rag_service +from app.services.tag_service import get_tag_service logger = logging.getLogger(__name__) # Register query router prompts -register_prompt("query_router", "analyze_query", """ +register_prompt( + "query_router", + "analyze_query", + """ You are a query router for a hybrid retrieval system. Your job is to determine whether to route a user query to: -1. TAG (Table Augmented Generation) - for queries that need access to structured data and would be best answered with SQL +1. TAG (Table Augmented Generation) - for queries that need access to structured data and would be best answered with SQL 2. RAG (Retrieval Augmented Generation) - for queries about unstructured text/content Routes to TAG when: @@ -39,120 +42,134 @@ and explain your reasoning. Return your answer as a JSON object with the keys: service, confidence, reasoning. User Query: {{ query }} -""") +""", +) # Document types categorization UNSTRUCTURED_DOCUMENT_TYPES = { DocumentType.PDF, DocumentType.TXT, DocumentType.MARKDOWN, - DocumentType.HTML + DocumentType.HTML, } STRUCTURED_DOCUMENT_TYPES = { DocumentType.CSV, DocumentType.EXCEL, DocumentType.DOC, - DocumentType.DOCX + DocumentType.DOCX, } + class QueryRouter: """ QueryRouter handles both routing and dispatching of queries. It analyzes document types to decide which service to use, then dispatches the query to the appropriate service. """ - + def __init__(self): """Initialize the query router with service instances""" # Get singleton instances of services self.rag_service = get_rag_service() self.tag_service = get_tag_service() logger.info("Initialized QueryRouter") - + async def route_and_dispatch( self, query: str, metadata_filter: Optional[Dict[str, Any]] = None, top_k: int = 5, similarity_threshold: float = 0.3, - force_service: Optional[str] = None + force_service: Optional[str] = None, ) -> Dict[str, Any]: """ Route a query to the appropriate service and dispatch it. - + Args: query: The query to route and dispatch metadata_filter: Optional filter for the query top_k: Number of results to retrieve (for RAG service) similarity_threshold: Similarity threshold (for RAG service) force_service: Optional service to force routing to - + Returns: Response from the service that processed the query """ try: logger.info(f"Routing and dispatching query: '{query}'") - + # First, check if we have a direct answer in the questions index - kb_id = metadata_filter.get("knowledge_base_id") if metadata_filter else None - + kb_id = ( + metadata_filter.get("knowledge_base_id") if metadata_filter else None + ) + # Only search questions if we're not forcing a specific service if not force_service: direct_answer = await self._check_questions_index(query, kb_id) if direct_answer: logger.info("Found direct answer in questions index") return direct_answer - + # Continue with the regular routing flow if no direct answer was found # Use the specified service if provided if force_service: service = force_service.lower() if service not in ["rag", "tag"]: - logger.warning(f"Invalid force_service value: {force_service}. Defaulting to RAG.") + logger.warning( + f"Invalid force_service value: {force_service}. Defaulting to RAG." + ) service = "rag" - + # Create a routing info dict routing_info = { "service": service, "confidence": 1.0, # High confidence since it's explicitly forced "reasoning": f"Service was explicitly forced to {service}", - "fallback": False + "fallback": False, } else: # Get routing decision from LLM-based analysis routing_info = await self.analyze_query(query) service = routing_info.get("service", "rag") - + # Log which service was selected logger.info(f"Selected service: {service}") - + # Dispatch to the appropriate service if service == "tag": logger.info(f"Dispatching to TAG service: '{query}'") # TAG service no longer requires top_k and similarity_threshold response = await self.tag_service.retrieve( - knowledge_base_id=metadata_filter.get("knowledge_base_id") if metadata_filter else None, + knowledge_base_id=( + metadata_filter.get("knowledge_base_id") + if metadata_filter + else None + ), query=query, - metadata_filter=metadata_filter + metadata_filter=metadata_filter, ) response["service"] = "tag" else: # Default to RAG logger.info(f"Dispatching to RAG service: '{query}'") response = await self.rag_service.retrieve( - knowledge_base_id=metadata_filter.get("knowledge_base_id") if metadata_filter else None, + knowledge_base_id=( + metadata_filter.get("knowledge_base_id") + if metadata_filter + else None + ), query=query, top_k=top_k, similarity_threshold=similarity_threshold, - metadata_filter=metadata_filter + metadata_filter=metadata_filter, ) response["service"] = "rag" - + # Add routing information to the response response["routing_info"] = routing_info - + return response - + except Exception as e: logger.error(f"Error in route_and_dispatch: {e}", exc_info=True) # Return a basic error response @@ -165,90 +182,105 @@ async def route_and_dispatch( "service": "unknown", "confidence": 0, "reasoning": f"Error: {str(e)}", - "fallback": True + "fallback": True, }, - "error": str(e) + "error": str(e), } - - async def _check_questions_index(self, query: str, knowledge_base_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + + async def _check_questions_index( + self, query: str, knowledge_base_id: Optional[str] = None + ) -> Optional[Dict[str, Any]]: """ Check if there's a direct answer in the questions index. - + Args: query: The user query knowledge_base_id: Optional ID of the knowledge base to search in - + Returns: Response with the direct answer found in questions index, None otherwise """ try: - logger.info(f"=== QUESTIONS INDEX CHECK ===") + logger.info("=== QUESTIONS INDEX CHECK ===") logger.info(f"Checking questions index for direct answer to: '{query}'") - + # Higher threshold for questions to ensure high quality matches QUESTIONS_SIMILARITY_THRESHOLD = 0.5 logger.info(f"Using similarity threshold: {QUESTIONS_SIMILARITY_THRESHOLD}") - + if not knowledge_base_id: - logger.info("No knowledge base ID provided, cannot search questions index") + logger.info( + "No knowledge base ID provided, cannot search questions index" + ) return None - + # Get the questions vector store - logger.info(f"Getting vector store for questions index: {settings.PINECONE_QUESTIONS_INDEX_NAME}") + logger.info( + f"Getting vector store for questions index: {settings.PINECONE_QUESTIONS_INDEX_NAME}" + ) questions_vector_store = get_vector_store( - store_type="pinecone", - index_name=settings.PINECONE_QUESTIONS_INDEX_NAME + store_type="pinecone", index_name=settings.PINECONE_QUESTIONS_INDEX_NAME ) - + # Log the search parameters - logger.info(f"Searching questions with: query='{query}', knowledge_base_id='{knowledge_base_id}', threshold={QUESTIONS_SIMILARITY_THRESHOLD}") - + logger.info( + f"Searching questions with: query='{query}', knowledge_base_id='{knowledge_base_id}', threshold={QUESTIONS_SIMILARITY_THRESHOLD}" + ) + # Search for similar questions in the knowledge base namespace try: results = await questions_vector_store.search_similar( query=query, knowledge_base_id=knowledge_base_id, # Use knowledge_base_id as namespace limit=1, # We only need the top result - similarity_threshold=QUESTIONS_SIMILARITY_THRESHOLD + similarity_threshold=QUESTIONS_SIMILARITY_THRESHOLD, ) except Exception as e: logger.error(f"Error searching questions index: {e}", exc_info=True) return None - + # If we have a match if results and len(results) > 0: top_match = results[0] - match_score = top_match.get('score', 0) - - logger.info(f"Found question match with score {match_score} (threshold: {QUESTIONS_SIMILARITY_THRESHOLD})") - + match_score = top_match.get("score", 0) + + logger.info( + f"Found question match with score {match_score} (threshold: {QUESTIONS_SIMILARITY_THRESHOLD})" + ) + # Only use if the match is above our threshold if match_score >= QUESTIONS_SIMILARITY_THRESHOLD: - content = top_match.get('content', '') - metadata = top_match.get('metadata', {}) + content = top_match.get("content", "") + metadata = top_match.get("metadata", {}) logger.info(f"Match content: {content[:100]}...") - + # Get question and answer from metadata if available - matched_question = metadata.get('question', '') - matched_answer = metadata.get('answer', '') - + matched_question = metadata.get("question", "") + matched_answer = metadata.get("answer", "") + # If not in metadata, try to parse from content as fallback if not matched_question or not matched_answer: - logger.info("Question/answer not found in metadata, trying to parse from content") - parts = content.split('\nAnswer: ') + logger.info( + "Question/answer not found in metadata, trying to parse from content" + ) + parts = content.split("\nAnswer: ") if len(parts) == 2: - matched_question = parts[0].replace('Question: ', '') + matched_question = parts[0].replace("Question: ", "") matched_answer = parts[1] else: - logger.warning(f"Could not parse question/answer from content: {content}") + logger.warning( + f"Could not parse question/answer from content: {content}" + ) return None - + logger.info(f"Matched question: {matched_question}") logger.info(f"Matched answer preview: {matched_answer[:100]}...") - + # Use LLM to generate the final answer based on the matched question and answer - logger.info("Generating answer with LLM using matched question and answer") - + logger.info( + "Generating answer with LLM using matched question and answer" + ) + # Create prompt for the LLM prompt = f"""You are answering a user question based on an existing similar question and answer pair. @@ -261,83 +293,86 @@ async def _check_questions_index(self, query: str, knowledge_base_id: Optional[s Please provide an accurate, helpful answer to the user's question based on this information. If the user's question is substantially different, just use the provided answer as context to help frame your response. """ - + # Create message for LLM - messages = [ - Message(role=Role.USER, content=prompt) - ] - + messages = [Message(role=Role.USER, content=prompt)] + # Set completion options options = CompletionOptions( temperature=0.3, # Lower temperature for more accurate response - max_tokens=500 + max_tokens=500, ) - + # Get LLM response llm_response = await LLMFactory.complete( - messages=messages, - options=options + messages=messages, options=options ) - + final_answer = llm_response.content.strip() logger.info(f"LLM generated answer: {final_answer[:100]}...") - + # Prepare response in the same format as other services response = { "query": query, "answer": final_answer, - "sources": [{ - # Content is the matched question for display - "content": matched_question, - # Add question-specific metadata using the updated schema - "question_id": metadata.get('question_id', ''), - "question": matched_question, - "answer": matched_answer, # Include the original answer - "answer_type": metadata.get('answer_type', 'DIRECT'), - "score": match_score, - # We no longer need these document-specific fields - # but include dummy values for backward compatibility if needed - "metadata": { - "question_id": metadata.get('question_id', ''), - "knowledge_base_id": knowledge_base_id, + "sources": [ + { + # Content is the matched question for display + "content": matched_question, + # Add question-specific metadata using the updated schema + "question_id": metadata.get("question_id", ""), + "question": matched_question, + "answer": matched_answer, # Include the original answer + "answer_type": metadata.get("answer_type", "DIRECT"), "score": match_score, - "answer": matched_answer # Include answer in metadata too + # We no longer need these document-specific fields + # but include dummy values for backward compatibility if needed + "metadata": { + "question_id": metadata.get("question_id", ""), + "knowledge_base_id": knowledge_base_id, + "score": match_score, + "answer": matched_answer, # Include answer in metadata too + }, } - }], + ], "service": "questions", "routing_info": { "service": "questions", "confidence": match_score, "reasoning": f"Found direct answer in questions index with confidence {match_score:.2f}", - "fallback": False - } + "fallback": False, + }, } - logger.info("Returning LLM-generated answer from questions index match") - logger.info(f"=== END QUESTIONS INDEX CHECK ===") + logger.info( + "Returning LLM-generated answer from questions index match" + ) + logger.info("=== END QUESTIONS INDEX CHECK ===") return response else: - logger.info(f"Match score {match_score} below threshold {QUESTIONS_SIMILARITY_THRESHOLD}, not using") + logger.info( + f"Match score {match_score} below threshold {QUESTIONS_SIMILARITY_THRESHOLD}, not using" + ) else: logger.info("No matches found in questions index") - + logger.info("No suitable direct answer found in questions index") - logger.info(f"=== END QUESTIONS INDEX CHECK ===") + logger.info("=== END QUESTIONS INDEX CHECK ===") return None - + except Exception as e: logger.error(f"Error checking questions index: {e}", exc_info=True) - logger.info(f"=== END QUESTIONS INDEX CHECK (ERROR) ===") + logger.info("=== END QUESTIONS INDEX CHECK (ERROR) ===") return None - + async def analyze_query(self, query: str) -> Dict[str, Any]: """ Analyze a query to determine which service to route it to. Uses LLM to analyze query semantics and determine if it requires structured data analysis (TAG) or unstructured text search (RAG). - + Args: query: The query to analyze - + Returns: Dictionary containing routing information: - service: "tag" or "rag" @@ -347,35 +382,31 @@ async def analyze_query(self, query: str) -> Dict[str, Any]: """ try: logger.info(f"Analyzing query for routing: '{query}'") - + # Get the prompt from the registry prompt = get_prompt("query_router", "analyze_query", query=query) - + # Create a message for the LLM - messages = [ - Message(role=Role.USER, content=prompt) - ] - + messages = [Message(role=Role.USER, content=prompt)] + # Set completion options options = CompletionOptions( temperature=0.3, # Lower temperature for more deterministic results - max_tokens=500 + max_tokens=500, ) - + # Make the LLM call using our factory - response = await LLMFactory.complete( - messages=messages, - options=options - ) - + response = await LLMFactory.complete(messages=messages, options=options) + # Parse the response as JSON response_text = response.content.strip() - + # Try to find JSON content, handling various ways the LLM might format it import re + json_pattern = r"\{[\s\S]*\}" json_matches = re.search(json_pattern, response_text) - + if json_matches: json_str = json_matches.group(0) try: @@ -383,18 +414,20 @@ async def analyze_query(self, query: str) -> Dict[str, Any]: except json.JSONDecodeError: # If direct parsing fails, try to clean up the JSON string # Remove any trailing commas before closing brackets - json_str = re.sub(r',\s*}', '}', json_str) - json_str = re.sub(r',\s*]', ']', json_str) + json_str = re.sub(r",\s*}", "}", json_str) + json_str = re.sub(r",\s*]", "]", json_str) try: routing_info = json.loads(json_str) except json.JSONDecodeError: # If still failing, default to RAG - logger.warning(f"Failed to parse LLM response as JSON: {response_text}") + logger.warning( + f"Failed to parse LLM response as JSON: {response_text}" + ) return { "service": "rag", "confidence": 0.6, "reasoning": "Failed to parse routing decision, defaulting to RAG", - "fallback": True + "fallback": True, } else: # If no JSON-like structure was found @@ -403,16 +436,16 @@ async def analyze_query(self, query: str) -> Dict[str, Any]: "service": "rag", "confidence": 0.6, "reasoning": "Could not extract routing information, defaulting to RAG", - "fallback": True + "fallback": True, } - + # Validate and normalize the routing info service = routing_info.get("service", "").lower() if service not in ["tag", "rag"]: # Default to RAG for invalid service service = "rag" routing_info["fallback"] = True - + # Ensure confidence is a float between 0 and 1 confidence = routing_info.get("confidence", 0.5) try: @@ -421,23 +454,25 @@ async def analyze_query(self, query: str) -> Dict[str, Any]: confidence = 0.5 except (ValueError, TypeError): confidence = 0.5 - + # Apply a lower threshold for TAG service to prefer RAG in unclear cases TAG_CONFIDENCE_THRESHOLD = 0.7 if service == "tag" and confidence < TAG_CONFIDENCE_THRESHOLD: service = "rag" - routing_info["reasoning"] = f"Original choice was TAG with confidence {confidence}, but this is below threshold {TAG_CONFIDENCE_THRESHOLD}. Defaulting to RAG." + routing_info["reasoning"] = ( + f"Original choice was TAG with confidence {confidence}, but this is below threshold {TAG_CONFIDENCE_THRESHOLD}. Defaulting to RAG." + ) routing_info["fallback"] = True confidence = 0.6 # Set a moderate confidence for the fallback - + # Return the normalized routing info return { "service": service, "confidence": confidence, "reasoning": routing_info.get("reasoning", "No reasoning provided"), - "fallback": routing_info.get("fallback", False) + "fallback": routing_info.get("fallback", False), } - + except Exception as e: logger.error(f"Error analyzing query: {e}", exc_info=True) # In case of any error, default to RAG service @@ -445,34 +480,34 @@ async def analyze_query(self, query: str) -> Dict[str, Any]: "service": "rag", "confidence": 0.5, "reasoning": f"Error during analysis: {str(e)}. Defaulting to RAG.", - "fallback": True + "fallback": True, } - - - async def get_relevant_knowledge_bases(self, query: str, knowledge_base_id: str) -> List[str]: + + async def get_relevant_knowledge_bases( + self, query: str, knowledge_base_id: str + ) -> List[str]: """ Get a list of knowledge base IDs that are relevant to the query. - + Args: query: The user query - + Returns: List of knowledge base IDs sorted by relevance """ # Get the summary vector store summary_vector_store = get_vector_store( - store_type="pinecone", - index_name=settings.PINECONE_SUMMARY_INDEX_NAME + store_type="pinecone", index_name=settings.PINECONE_SUMMARY_INDEX_NAME ) - + # Search for relevant document summaries results = await summary_vector_store.search_similar( query=query, knowledge_base_id=knowledge_base_id, limit=10, - similarity_threshold=0.3 + similarity_threshold=0.3, ) - + # Extract knowledge base IDs from matches kb_ids = [] for match in results: @@ -480,13 +515,13 @@ async def get_relevant_knowledge_bases(self, query: str, knowledge_base_id: str) kb_id = match["metadata"]["knowledge_base_id"] if kb_id and kb_id not in kb_ids: kb_ids.append(kb_id) - + logger.info(f"Relevant knowledge bases for query: {kb_ids}") return kb_ids + # Create a singleton instance of QueryRouter @lru_cache() def get_query_router() -> QueryRouter: """Get a singleton instance of QueryRouter""" return QueryRouter() - diff --git a/app/services/question_service.py b/app/services/question_service.py index bcef0ac..898b656 100644 --- a/app/services/question_service.py +++ b/app/services/question_service.py @@ -1,18 +1,20 @@ -from typing import List, Optional import logging +from typing import List + +from celery import Celery from fastapi import HTTPException from sqlalchemy.orm import Session -from celery import Celery -from app.db.models.question import Question, QuestionStatus, AnswerType +from app.db.models.question import Question, QuestionStatus from app.repositories.question_repository import QuestionRepository -from app.services.rag.vector_store import VectorStore -from app.services.knowledge_base_service import KnowledgeBaseService -from app.schemas.question import QuestionResponse, QuestionCreate, QuestionUpdate +from app.schemas.question import QuestionCreate, QuestionResponse, QuestionUpdate from app.schemas.user import UserResponse +from app.services.knowledge_base_service import KnowledgeBaseService +from app.services.rag.vector_store import VectorStore logger = logging.getLogger(__name__) + class QuestionService: def __init__( self, @@ -20,7 +22,7 @@ def __init__( vector_store: VectorStore, knowledge_base_service: KnowledgeBaseService, celery_app: Celery, - db: Session + db: Session, ): self.question_repository = question_repository self.vector_store = vector_store @@ -29,67 +31,60 @@ def __init__( self.db = db async def create_question( - self, - kb_id: str, - payload: QuestionCreate, - current_user: UserResponse + self, kb_id: str, payload: QuestionCreate, current_user: UserResponse ) -> QuestionResponse: """Create a new question in a knowledge base""" try: # Check knowledge base access await self.kb_service.get_knowledge_base(kb_id, current_user) - + # Create question record question = Question( question=payload.question, answer=payload.answer, answer_type=payload.answer_type.value, # Use .value for enum - status=QuestionStatus.PENDING.value, # Use .value for enum + status=QuestionStatus.PENDING.value, # Use .value for enum knowledge_base_id=kb_id, - user_id=str(current_user.id) + user_id=str(current_user.id), ) - + # Save to database created_question = await self.question_repository.create(question, self.db) - + # Queue question ingestion task self.celery_app.send_task( - 'app.worker.tasks.initiate_question_ingestion', - args=[created_question.id] + "app.worker.tasks.initiate_question_ingestion", + args=[created_question.id], ) - + return created_question - + except Exception as e: logger.error(f"Failed to create question in service: {e}") raise HTTPException(status_code=500, detail=str(e)) async def get_question( - self, - question_id: str, - current_user: UserResponse + self, question_id: str, current_user: UserResponse ) -> QuestionResponse: """Get a question by ID""" question = await self.question_repository.get_by_id(question_id, self.db) if not question: raise HTTPException(status_code=404, detail="Question not found") - + # Check knowledge base access - await self.kb_service.get_knowledge_base(question.knowledge_base_id, current_user) - + await self.kb_service.get_knowledge_base( + question.knowledge_base_id, current_user + ) + return question async def list_questions( - self, - kb_id: str, - current_user: UserResponse, - skip: int = 0, - limit: int = 100 + self, kb_id: str, current_user: UserResponse, skip: int = 0, limit: int = 100 ) -> List[QuestionResponse]: """List all questions for a knowledge base""" # Check knowledge base access await self.kb_service.get_knowledge_base(kb_id, current_user) - + questions = await self.question_repository.list_by_knowledge_base( kb_id, self.db, skip, limit ) @@ -99,71 +94,72 @@ async def update_question( self, question_id: str, question_update: QuestionUpdate, - current_user: UserResponse + current_user: UserResponse, ) -> QuestionResponse: """Update a question""" # Get existing question question = await self.question_repository.get_by_id(question_id, self.db) if not question: raise HTTPException(status_code=404, detail="Question not found") - + # Check knowledge base access - await self.kb_service.get_knowledge_base(question.knowledge_base_id, current_user) - + await self.kb_service.get_knowledge_base( + question.knowledge_base_id, current_user + ) + # Update question update_data = question_update.model_dump(exclude_unset=True) - + # If content changes, set status back to PENDING for re-ingestion if "question" in update_data or "answer" in update_data: update_data["status"] = QuestionStatus.PENDING.value - + updated_question = await self.question_repository.update( question_id, update_data, self.db ) - + # If content changed, queue re-ingestion if "question" in update_data or "answer" in update_data: self.celery_app.send_task( - 'app.worker.tasks.initiate_question_ingestion', - args=[question_id] + "app.worker.tasks.initiate_question_ingestion", args=[question_id] ) - + return updated_question async def delete_question( - self, - question_id: str, - current_user: UserResponse + self, question_id: str, current_user: UserResponse ) -> None: """Delete a question""" # Get existing question question = await self.question_repository.get_by_id(question_id, self.db) if not question: raise HTTPException(status_code=404, detail="Question not found") - + # Check knowledge base access - await self.kb_service.get_knowledge_base(question.knowledge_base_id, current_user) - + await self.kb_service.get_knowledge_base( + question.knowledge_base_id, current_user + ) + # Delete question from vector store self.celery_app.send_task( - 'app.worker.tasks.initiate_question_vector_deletion', - args=[question_id, question.knowledge_base_id] + "app.worker.tasks.initiate_question_vector_deletion", + args=[question_id, question.knowledge_base_id], ) - + # Delete question from database await self.question_repository.delete(question_id, self.db) async def get_question_status( - self, - question_id: str, - current_user: UserResponse + self, question_id: str, current_user: UserResponse ) -> QuestionStatus: """Get the status of a question""" question = await self.question_repository.get_by_id(question_id, self.db) if not question: raise HTTPException(status_code=404, detail="Question not found") - + # Check knowledge base access - await self.kb_service.get_knowledge_base(question.knowledge_base_id, current_user) - - return question.status \ No newline at end of file + await self.kb_service.get_knowledge_base( + question.knowledge_base_id, current_user + ) + + return question.status diff --git a/app/services/rag/chunker/chunker.py b/app/services/rag/chunker/chunker.py index 8592a2c..68a5e0e 100644 --- a/app/services/rag/chunker/chunker.py +++ b/app/services/rag/chunker/chunker.py @@ -1,63 +1,66 @@ -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional import logging import re +from abc import ABC, abstractmethod from enum import Enum +from typing import Any, Dict, List logger = logging.getLogger(__name__) + class ChunkSize(str, Enum): """Enum for chunk sizes""" + SMALL = "small" MEDIUM = "medium" LARGE = "large" + class Chunker(ABC): """ Abstract base class for chunkers that split documents into chunks. """ - + @abstractmethod async def chunk( self, text: str, metadata: Dict[str, Any], - chunk_size: ChunkSize = ChunkSize.MEDIUM + chunk_size: ChunkSize = ChunkSize.MEDIUM, ) -> List[Dict[str, Any]]: """ Split text into chunks with metadata. - + Args: text: The text to chunk metadata: Metadata about the document chunk_size: Size of chunks to create - + Returns: List of dictionaries containing: - content: The chunk text - metadata: Enhanced metadata for the chunk """ - pass + class SingleChunker(Chunker): """ Simple chunker that splits text into chunks of roughly equal size. """ - + async def chunk( self, text: str, metadata: Dict[str, Any], - chunk_size: ChunkSize = ChunkSize.MEDIUM + chunk_size: ChunkSize = ChunkSize.MEDIUM, ) -> List[Dict[str, Any]]: """ Split text into chunks of roughly equal size. - + Args: text: The text to chunk metadata: Metadata about the document chunk_size: Size of chunks to create - + Returns: List of dictionaries containing: - content: The chunk text @@ -65,29 +68,29 @@ async def chunk( """ try: logger.info(f"Chunking text with SingleChunker, chunk_size={chunk_size}") - + # Determine chunk size in characters size_map = { ChunkSize.SMALL: 1000, ChunkSize.MEDIUM: 2000, - ChunkSize.LARGE: 4000 + ChunkSize.LARGE: 4000, } target_size = size_map.get(chunk_size, 2000) - + # Split text into paragraphs - paragraphs = re.split(r'\n\s*\n', text) - + paragraphs = re.split(r"\n\s*\n", text) + # Create chunks chunks = [] current_chunk = "" current_size = 0 chunk_index = 0 - + for paragraph in paragraphs: paragraph = paragraph.strip() if not paragraph: continue - + # If adding this paragraph would exceed the target size, start a new chunk if current_size + len(paragraph) > target_size and current_chunk: # Create chunk @@ -96,23 +99,22 @@ async def chunk( "chunk_index": chunk_index, "chunk_size": chunk_size, "nearest_header": "", - "section_path": [] + "section_path": [], } - - chunks.append({ - "content": current_chunk.strip(), - "metadata": chunk_metadata - }) - + + chunks.append( + {"content": current_chunk.strip(), "metadata": chunk_metadata} + ) + # Reset for next chunk current_chunk = "" current_size = 0 chunk_index += 1 - + # Add paragraph to current chunk current_chunk += paragraph + "\n\n" current_size += len(paragraph) - + # Add the last chunk if it's not empty if current_chunk: chunk_metadata = { @@ -120,57 +122,59 @@ async def chunk( "chunk_index": chunk_index, "chunk_size": chunk_size, "nearest_header": "", - "section_path": [] + "section_path": [], } - - chunks.append({ - "content": current_chunk.strip(), - "metadata": chunk_metadata - }) - + + chunks.append( + {"content": current_chunk.strip(), "metadata": chunk_metadata} + ) + logger.info(f"Created {len(chunks)} chunks with SingleChunker") - + return chunks - + except Exception as e: logger.error(f"Failed to chunk text with SingleChunker: {e}", exc_info=True) raise + class MultiLevelChunker(Chunker): """ Advanced chunker that splits text into chunks based on document structure. Preserves section hierarchy and headers in metadata. """ - + async def chunk( self, text: str, metadata: Dict[str, Any], - chunk_size: ChunkSize = ChunkSize.MEDIUM + chunk_size: ChunkSize = ChunkSize.MEDIUM, ) -> List[Dict[str, Any]]: """ Split text into chunks based on document structure. - + Args: text: The text to chunk metadata: Metadata about the document chunk_size: Size of chunks to create - + Returns: List of dictionaries containing: - content: The chunk text - metadata: Enhanced metadata for the chunk """ try: - logger.info(f"Chunking text with MultiLevelChunker, chunk_size={chunk_size}") - + logger.info( + f"Chunking text with MultiLevelChunker, chunk_size={chunk_size}" + ) + # Determine chunk size in characters size_map = { ChunkSize.SMALL: 128, ChunkSize.MEDIUM: 256, - ChunkSize.LARGE: 512 + ChunkSize.LARGE: 512, } - + # Extract headers and sections sections = self._extract_sections(text) @@ -182,48 +186,53 @@ async def chunk( section_text = section["text"] section_path = section["path"] section_header = section["header"] - + # Skip empty sections if not section_text.strip(): continue - + # Split section into paragraphs - paragraphs = re.split(r'\n\s*\n', section_text) - + paragraphs = re.split(r"\n\s*\n", section_text) + # Create chunks from paragraphs current_chunk = "" current_size = 0 - + for paragraph in paragraphs: paragraph = paragraph.strip() if not paragraph: continue - + # If adding this paragraph would exceed the target size, start a new chunk - if current_size + len(paragraph) > target_size and current_chunk: + if ( + current_size + len(paragraph) > target_size + and current_chunk + ): # Create chunk chunk_metadata = { **metadata, "chunk_index": chunk_index, "chunk_size": chunk_size, "nearest_header": section_header, - "section_path": section_path + "section_path": section_path, } - - chunks.append({ - "content": current_chunk.strip(), - "metadata": chunk_metadata - }) - + + chunks.append( + { + "content": current_chunk.strip(), + "metadata": chunk_metadata, + } + ) + # Reset for next chunk current_chunk = "" current_size = 0 chunk_index += 1 - + # Add paragraph to current chunk current_chunk += paragraph + "\n\n" current_size += len(paragraph) - + # Add the last chunk if it's not empty if current_chunk: chunk_metadata = { @@ -231,31 +240,35 @@ async def chunk( "chunk_index": chunk_index, "chunk_size": chunk_size, "nearest_header": section_header, - "section_path": section_path + "section_path": section_path, } - - chunks.append({ - "content": current_chunk.strip(), - "metadata": chunk_metadata - }) - + + chunks.append( + { + "content": current_chunk.strip(), + "metadata": chunk_metadata, + } + ) + chunk_index += 1 - + chunk_index = 0 - + return chunks - + except Exception as e: - logger.error(f"Failed to chunk text with MultiLevelChunker: {e}", exc_info=True) + logger.error( + f"Failed to chunk text with MultiLevelChunker: {e}", exc_info=True + ) raise - + def _extract_sections(self, text: str) -> List[Dict[str, Any]]: """ Extract sections and headers from text. - + Args: text: The text to extract sections from - + Returns: List of dictionaries containing: - header: The section header @@ -263,51 +276,41 @@ def _extract_sections(self, text: str) -> List[Dict[str, Any]]: - path: The section path (list of parent headers) """ # Split text into lines - lines = text.split('\n') - + lines = text.split("\n") + # Find headers (lines starting with #) headers = [] for i, line in enumerate(lines): - if re.match(r'^#+\s', line): - level = len(re.match(r'^#+', line).group(0)) - header_text = line.lstrip('#').strip() - headers.append({ - "level": level, - "text": header_text, - "line": i - }) - + if re.match(r"^#+\s", line): + level = len(re.match(r"^#+", line).group(0)) + header_text = line.lstrip("#").strip() + headers.append({"level": level, "text": header_text, "line": i}) + # If no headers found, return the entire text as one section if not headers: - return [{ - "header": "", - "text": text, - "path": [] - }] - + return [{"header": "", "text": text, "path": []}] + # Extract sections sections = [] for i, header in enumerate(headers): # Determine section start and end start_line = header["line"] + 1 end_line = len(lines) - + if i < len(headers) - 1: end_line = headers[i + 1]["line"] - + # Extract section text - section_text = '\n'.join(lines[start_line:end_line]) - + section_text = "\n".join(lines[start_line:end_line]) + # Determine section path path = [] for prev_header in headers[:i]: if prev_header["level"] < header["level"]: path.append(prev_header["text"]) - - sections.append({ - "header": header["text"], - "text": section_text, - "path": path - }) - - return sections \ No newline at end of file + + sections.append( + {"header": header["text"], "text": section_text, "path": path} + ) + + return sections diff --git a/app/services/rag/chunker/chunker_factory.py b/app/services/rag/chunker/chunker_factory.py index b748133..deaea7c 100644 --- a/app/services/rag/chunker/chunker_factory.py +++ b/app/services/rag/chunker/chunker_factory.py @@ -1,24 +1,25 @@ -from typing import Dict, Any import logging +from typing import Any, Dict from app.db.models.knowledge_base import DocumentType -from app.services.rag.chunker.chunker import Chunker, MultiLevelChunker, SingleChunker +from app.services.rag.chunker.chunker import Chunker, MultiLevelChunker logger = logging.getLogger(__name__) + class ChunkerFactory: """ Factory for creating chunkers based on document type. """ - + @staticmethod def create_chunker(document_type: DocumentType) -> Chunker: """ Create a chunker based on document type. - + Args: document_type: Type of document to chunk - + Returns: Chunker instance """ @@ -29,27 +30,27 @@ def create_chunker(document_type: DocumentType) -> Chunker: except Exception as e: logger.error(f"Failed to create chunker: {e}", exc_info=True) raise - + @staticmethod def create_chunker_from_metadata(metadata: Dict[str, Any]) -> Chunker: """ Create a chunker based on document metadata. - + Args: metadata: Document metadata - + Returns: Chunker instance """ try: logger.info("Creating chunker from metadata") - + # Extract document type from metadata - document_type = metadata.get('document_type', DocumentType.TXT) - + document_type = metadata.get("document_type", DocumentType.TXT) + # Create chunker based on document type return ChunkerFactory.create_chunker(document_type) - + except Exception as e: logger.error(f"Failed to create chunker from metadata: {e}", exc_info=True) - raise \ No newline at end of file + raise diff --git a/app/services/rag/ingestor/ingestor.py b/app/services/rag/ingestor/ingestor.py index 9668d93..587f8d3 100644 --- a/app/services/rag/ingestor/ingestor.py +++ b/app/services/rag/ingestor/ingestor.py @@ -1,37 +1,38 @@ +import csv +import io +import logging +import re from abc import ABC, abstractmethod from datetime import datetime -import re -from typing import Dict, Any, List -import logging -import io -import os # Added for environment variables -import PyPDF2 -import csv +from typing import Any, Dict, List + import markdown -from PIL import Image +import PyPDF2 import pytesseract - -from app.core.config import settings - -from docling.document_converter import DocumentConverter from docling.datamodel.base_models import InputFormat -from docling.document_converter import PdfFormatOption from docling.datamodel.pipeline_options import PdfPipelineOptions -from docling.document_converter import DocumentStream +from docling.document_converter import ( + DocumentConverter, + DocumentStream, + PdfFormatOption, +) +from PIL import Image +from app.core.prompts import get_prompt, register_prompt from app.db.storage import get_storage_db from app.repositories.storage_repository import StorageRepository - -from app.services.llm.factory import LLMFactory, Message, Role, CompletionOptions -from app.core.prompts import get_prompt, register_prompt +from app.services.llm.factory import CompletionOptions, LLMFactory, Message, Role logger = logging.getLogger(__name__) # Register the prompt for SQL table generation -register_prompt("ingestor", "generate_table_schema", """ -Generate a SQL database create table query for the given table name and headers. +register_prompt( + "ingestor", + "generate_table_schema", + """ +Generate a SQL database create table query for the given table name and headers. Make sure to use headers as column names and rows as sample data. -Rows contain sample data for the table. +Rows contain sample data for the table. Use your understanding to extrapolate scenario where datatype is not obvious, or might be different from the sample data. @@ -44,42 +45,44 @@ Table name: {{ table_name }} Headers: {{ headers }} Rows: {{ rows }} -""") +""", +) + class Ingestor(ABC): """ Abstract base class for document ingestors that extract text and metadata from documents. """ - + @abstractmethod async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from a document. - + Args: content: Raw document content as bytes metadata: Additional metadata about the document - + Returns: Dictionary containing: - text: Extracted text - metadata: Enhanced metadata """ - pass + class PDFIngestor(Ingestor): """ Ingestor for PDF documents using docling for better text extraction and structure preservation. """ - + async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from a PDF document using docling. - + Args: content: Raw PDF content as bytes metadata: Additional metadata about the document - + Returns: Dictionary containing: - text: Extracted text (in markdown format) @@ -87,70 +90,77 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An """ try: logger.info("Ingesting PDF document with docling v2") - + # Configure pipeline options pipeline_options = PdfPipelineOptions() pipeline_options.do_ocr = True pipeline_options.do_table_structure = True pipeline_options.table_structure_options.do_cell_matching = True - + # Set up the document converter with PDF format options doc_converter = DocumentConverter( allowed_formats=[InputFormat.PDF], format_options={ InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options) - } + }, ) # Convert bytes to DocumentStream content_stream = DocumentStream( - name=metadata.get("title", "temp.pdf"), - stream=io.BytesIO(content) + name=metadata.get("title", "temp.pdf"), stream=io.BytesIO(content) ) # Convert the PDF content conv_result = doc_converter.convert(source=content_stream) - + # Get markdown representation markdown_text = conv_result.document.export_to_markdown() - + # Extract metadata from docling document - docling_metadata = conv_result.document.metadata.model_dump() if hasattr(conv_result.document, 'metadata') else {} - + docling_metadata = ( + conv_result.document.metadata.model_dump() + if hasattr(conv_result.document, "metadata") + else {} + ) + # Extract text from each page for additional processing if needed page_texts = [] for page in conv_result.document.pages: - page_text = page.export_to_text() if hasattr(page, 'export_to_text') else "" + page_text = ( + page.export_to_text() if hasattr(page, "export_to_text") else "" + ) page_texts.append(page_text) - + # Combine with provided metadata enhanced_metadata = { **metadata, "page_count": len(conv_result.document.pages), "pdf_metadata": docling_metadata, - "document_type": "pdf" + "document_type": "pdf", } - - logger.info(f"Successfully ingested PDF with {len(conv_result.document.pages)} pages using docling v2") - + + logger.info( + f"Successfully ingested PDF with {len(conv_result.document.pages)} pages using docling v2" + ) + return { "text": markdown_text, "metadata": enhanced_metadata, - "page_texts": page_texts + "page_texts": page_texts, } - + except Exception as e: logger.error(f"Failed to ingest PDF with docling v2: {e}", exc_info=True) - + # Fallback to PyPDF2 if docling fails logger.info("Falling back to PyPDF2 for PDF ingestion") try: # Create a file-like object from bytes pdf_file = io.BytesIO(content) - + # Open the PDF file pdf_reader = PyPDF2.PdfReader(pdf_file) - + # Extract text from each page text = "" page_texts = [] @@ -159,47 +169,53 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An if page_text: page_texts.append(page_text) text += f"\n\n--- Page {i+1} ---\n\n{page_text}" - + # Extract metadata pdf_metadata = {} if pdf_reader.metadata: for key, value in pdf_reader.metadata.items(): - if key.startswith('/'): + if key.startswith("/"): pdf_metadata[key[1:]] = value - + # Combine with provided metadata enhanced_metadata = { **metadata, "page_count": len(pdf_reader.pages), "pdf_metadata": pdf_metadata, - "document_type": "pdf" + "document_type": "pdf", } - - logger.info(f"Successfully ingested PDF with {len(pdf_reader.pages)} pages using PyPDF2 fallback") - + + logger.info( + f"Successfully ingested PDF with {len(pdf_reader.pages)} pages using PyPDF2 fallback" + ) + return { "text": text, "metadata": enhanced_metadata, - "page_texts": page_texts + "page_texts": page_texts, } - + except Exception as fallback_error: - logger.error(f"Fallback PDF ingestion also failed: {fallback_error}", exc_info=True) + logger.error( + f"Fallback PDF ingestion also failed: {fallback_error}", + exc_info=True, + ) raise + class CSVIngestor(Ingestor): """ Ingestor for CSV documents. """ - + async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from a CSV document. - + Args: content: Raw CSV content as bytes metadata: Additional metadata about the document - + Returns: Dictionary containing: - text: Extracted text @@ -207,27 +223,36 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An """ try: logger.info("Ingesting CSV document") - + # Create a file-like object from bytes - csv_file = io.StringIO(content.decode('utf-8')) - + csv_file = io.StringIO(content.decode("utf-8")) + # Read CSV csv_reader = csv.reader(csv_file) rows = list(csv_reader) # Extract headers headers = rows[0] if rows else [] - + # Generate create table query - table_name = metadata.get("document_title", f"csv_data_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}").replace(".csv", "") - create_table_query = await CSVIngestor.generate_create_table_query(table_name, headers, rows[1:10]) - + table_name = metadata.get( + "document_title", + f"csv_data_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}", + ).replace(".csv", "") + create_table_query = await CSVIngestor.generate_create_table_query( + table_name, headers, rows[1:10] + ) + # insert into storage database as new table storage_session = get_storage_db().__next__() - await StorageRepository.insert_csv(storage_session, table_name, create_table_query, headers, rows[1:]) + await StorageRepository.insert_csv( + storage_session, table_name, create_table_query, headers, rows[1:] + ) - logger.info(f"Successfully ingested CSV with {len(rows) - 1 if rows else 0} rows and {len(headers)} columns") + logger.info( + f"Successfully ingested CSV with {len(rows) - 1 if rows else 0} rows and {len(headers)} columns" + ) return { "text": create_table_query, @@ -239,44 +264,44 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An "row_count": len(rows) - 1 if rows else 0, "column_count": len(headers), }, - } + } except Exception as e: logger.error(f"Failed to ingest CSV: {e}", exc_info=True) raise @staticmethod - async def generate_create_table_query(table_name: str, headers: List[str], rows: List[List[str]]) -> str: + async def generate_create_table_query( + table_name: str, headers: List[str], rows: List[List[str]] + ) -> str: """ Generate a create table query for the given table name and rows. """ try: # Get the prompt from the registry - prompt = get_prompt("ingestor", "generate_table_schema", - table_name=table_name, - headers=headers, - rows=rows) - + prompt = get_prompt( + "ingestor", + "generate_table_schema", + table_name=table_name, + headers=headers, + rows=rows, + ) + # Create a message for the LLM - messages = [ - Message(role=Role.USER, content=prompt) - ] - + messages = [Message(role=Role.USER, content=prompt)] + # Set completion options options = CompletionOptions( temperature=0.3, # Low temperature for more deterministic SQL generation - max_tokens=1000 + max_tokens=1000, ) - + # Generate SQL using LLM Factory - response = await LLMFactory.complete( - messages=messages, - options=options - ) - + response = await LLMFactory.complete(messages=messages, options=options) + # Extract SQL from response response_text = response.content.strip() - + # Try to extract SQL from markdown code blocks if present sql_match = re.search(r"```sql(.*)```", response_text, re.DOTALL) if sql_match: @@ -284,10 +309,10 @@ async def generate_create_table_query(table_name: str, headers: List[str], rows: else: # If no code block, just use the full response sql_query = response_text - + logger.info(f"Generated SQL table creation query for {table_name}") return sql_query - + except Exception as e: logger.error(f"Error generating SQL table schema: {e}", exc_info=True) # Fallback to a simple CREATE TABLE statement @@ -297,23 +322,24 @@ async def generate_create_table_query(table_name: str, headers: List[str], rows: if i < len(headers) - 1: sql_query += "," sql_query += "\n);" - + logger.warning(f"Using fallback SQL schema: {sql_query}") return sql_query + class MarkdownIngestor(Ingestor): """ Ingestor for Markdown documents. """ - + async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from a Markdown document. - + Args: content: Raw Markdown content as bytes metadata: Additional metadata about the document - + Returns: Dictionary containing: - text: Extracted text @@ -321,63 +347,57 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An """ try: logger.info("Ingesting Markdown document") - + # Decode bytes to string - md_text = content.decode('utf-8') - + md_text = content.decode("utf-8") + # Convert to HTML (for metadata extraction) html = markdown.markdown(md_text) - + # Extract headers headers = [] - for line in md_text.split('\n'): - if line.startswith('#'): + for line in md_text.split("\n"): + if line.startswith("#"): # Count the number of # to determine header level level = 0 for char in line: - if char == '#': + if char == "#": level += 1 else: break - + header_text = line[level:].strip() - headers.append({ - "level": level, - "text": header_text - }) - + headers.append({"level": level, "text": header_text}) + # Enhance metadata enhanced_metadata = { **metadata, "headers": headers, - "document_type": "markdown" + "document_type": "markdown", } - + logger.info(f"Successfully ingested Markdown with {len(headers)} headers") - - return { - "text": md_text, - "metadata": enhanced_metadata, - "html": html - } - + + return {"text": md_text, "metadata": enhanced_metadata, "html": html} + except Exception as e: logger.error(f"Failed to ingest Markdown: {e}", exc_info=True) raise + class ImageIngestor(Ingestor): """ Ingestor for image documents using docling v2 for better OCR. """ - + async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from an image using docling OCR. - + Args: content: Raw image content as bytes metadata: Additional metadata about the document - + Returns: Dictionary containing: - text: Extracted text @@ -385,99 +405,97 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An """ try: logger.info("Ingesting image document with docling v2 OCR") - + # Convert to document stream content_stream = DocumentStream( - name=metadata.get("title", "temp.png"), - stream=io.BytesIO(content) + name=metadata.get("title", "temp.png"), stream=io.BytesIO(content) ) - + # Set up the document converter with IMAGE format options - doc_converter = DocumentConverter( - allowed_formats=[InputFormat.IMAGE] - ) - + doc_converter = DocumentConverter(allowed_formats=[InputFormat.IMAGE]) + # Convert the image file conv_result = doc_converter.convert(source=content_stream) - + # Get text from docling document text = conv_result.document.export_to_text() - + # Extract image metadata image = Image.open(io.BytesIO(content)) image_metadata = { "format": image.format, "size": image.size, - "mode": image.mode + "mode": image.mode, } - + # Enhance metadata enhanced_metadata = { **metadata, "image_metadata": image_metadata, - "document_type": "image" + "document_type": "image", } - - logger.info(f"Successfully ingested image with docling v2 OCR, extracted {len(text)} characters") - - return { - "text": text, - "metadata": enhanced_metadata - } - + + logger.info( + f"Successfully ingested image with docling v2 OCR, extracted {len(text)} characters" + ) + + return {"text": text, "metadata": enhanced_metadata} + except Exception as e: logger.error(f"Failed to ingest image with docling v2: {e}", exc_info=True) - + # Fallback to pytesseract if docling fails logger.info("Falling back to pytesseract for image OCR") try: # Create a file-like object from bytes image_file = io.BytesIO(content) - + # Open the image image = Image.open(image_file) - + # Extract image metadata image_metadata = { "format": image.format, "size": image.size, - "mode": image.mode + "mode": image.mode, } - + # Perform OCR text = pytesseract.image_to_string(image) - + # Enhance metadata enhanced_metadata = { **metadata, "image_metadata": image_metadata, - "document_type": "image" - } - - logger.info(f"Successfully ingested image with pytesseract fallback, extracted {len(text)} characters") - - return { - "text": text, - "metadata": enhanced_metadata + "document_type": "image", } - + + logger.info( + f"Successfully ingested image with pytesseract fallback, extracted {len(text)} characters" + ) + + return {"text": text, "metadata": enhanced_metadata} + except Exception as fallback_error: - logger.error(f"Fallback image OCR also failed: {fallback_error}", exc_info=True) + logger.error( + f"Fallback image OCR also failed: {fallback_error}", exc_info=True + ) raise + class TextIngestor(Ingestor): """ Ingestor for plain text documents. """ - + async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from a plain text document. - + Args: content: Raw text content as bytes metadata: Additional metadata about the document - + Returns: Dictionary containing: - text: Extracted text @@ -485,46 +503,46 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An """ try: logger.info("Ingesting plain text document") - + # Decode bytes to string - text = content.decode('utf-8') - + text = content.decode("utf-8") + # Count lines and words - lines = text.split('\n') + lines = text.split("\n") words = text.split() - + # Enhance metadata enhanced_metadata = { **metadata, "line_count": len(lines), "word_count": len(words), - "document_type": "text" + "document_type": "text", } - - logger.info(f"Successfully ingested text with {enhanced_metadata['line_count']} lines and {enhanced_metadata['word_count']} words") - - return { - "text": text, - "metadata": enhanced_metadata - } - + + logger.info( + f"Successfully ingested text with {enhanced_metadata['line_count']} lines and {enhanced_metadata['word_count']} words" + ) + + return {"text": text, "metadata": enhanced_metadata} + except Exception as e: logger.error(f"Failed to ingest text: {e}", exc_info=True) raise + class DocxIngestor(Ingestor): """ Ingestor for DOCX documents using docling v2. """ - + async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from a DOCX document using docling. - + Args: content: Raw DOCX content as bytes metadata: Additional metadata about the document - + Returns: Dictionary containing: - text: Extracted text @@ -532,57 +550,60 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An """ try: logger.info("Ingesting DOCX document with docling v2") - + # Create a file-like object from bytes docx_file = io.BytesIO(content) - + # Set up the document converter with DOCX format options - doc_converter = DocumentConverter( - allowed_formats=[InputFormat.DOCX] - ) - + doc_converter = DocumentConverter(allowed_formats=[InputFormat.DOCX]) + # Convert the DOCX file conv_result = doc_converter.convert(docx_file) - + # Get text and markdown representation text = conv_result.document.export_to_text() markdown_text = conv_result.document.export_to_markdown() - + # Extract metadata from docling document - docling_metadata = conv_result.document.metadata.model_dump() if hasattr(conv_result.document, 'metadata') else {} - + docling_metadata = ( + conv_result.document.metadata.model_dump() + if hasattr(conv_result.document, "metadata") + else {} + ) + # Enhance metadata enhanced_metadata = { **metadata, "docx_metadata": docling_metadata, - "document_type": "docx" + "document_type": "docx", } - - logger.info(f"Successfully ingested DOCX document using docling v2") - + + logger.info("Successfully ingested DOCX document using docling v2") + return { "text": text, "markdown": markdown_text, - "metadata": enhanced_metadata + "metadata": enhanced_metadata, } - + except Exception as e: logger.error(f"Failed to ingest DOCX with docling v2: {e}", exc_info=True) raise + class PptxIngestor(Ingestor): """ Ingestor for PPTX documents using docling v2. """ - + async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from a PPTX document using docling. - + Args: content: Raw PPTX content as bytes metadata: Additional metadata about the document - + Returns: Dictionary containing: - text: Extracted text @@ -590,61 +611,70 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An """ try: logger.info("Ingesting PPTX document with docling v2") - + # Create a file-like object from bytes pptx_file = io.BytesIO(content) - + # Set up the document converter with PPTX format options - doc_converter = DocumentConverter( - allowed_formats=[InputFormat.PPTX] - ) - + doc_converter = DocumentConverter(allowed_formats=[InputFormat.PPTX]) + # Convert the PPTX file conv_result = doc_converter.convert(pptx_file) - + # Get text and markdown representation text = conv_result.document.export_to_text() markdown_text = conv_result.document.export_to_markdown() - + # Extract metadata from docling document - docling_metadata = conv_result.document.metadata.model_dump() if hasattr(conv_result.document, 'metadata') else {} - + docling_metadata = ( + conv_result.document.metadata.model_dump() + if hasattr(conv_result.document, "metadata") + else {} + ) + # Extract slide count if available - slide_count = len(conv_result.document.pages) if hasattr(conv_result.document, 'pages') else 0 - + slide_count = ( + len(conv_result.document.pages) + if hasattr(conv_result.document, "pages") + else 0 + ) + # Enhance metadata enhanced_metadata = { **metadata, "slide_count": slide_count, "pptx_metadata": docling_metadata, - "document_type": "pptx" + "document_type": "pptx", } - - logger.info(f"Successfully ingested PPTX document with {slide_count} slides using docling v2") - + + logger.info( + f"Successfully ingested PPTX document with {slide_count} slides using docling v2" + ) + return { "text": text, "markdown": markdown_text, - "metadata": enhanced_metadata + "metadata": enhanced_metadata, } - + except Exception as e: logger.error(f"Failed to ingest PPTX with docling v2: {e}", exc_info=True) raise + class HTMLIngestor(Ingestor): """ Ingestor for HTML documents using docling v2. """ - + async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from an HTML document using docling. - + Args: content: Raw HTML content as bytes metadata: Additional metadata about the document - + Returns: Dictionary containing: - text: Extracted text @@ -652,70 +682,73 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An """ try: logger.info("Ingesting HTML document with docling v2") - + # Create a file-like object from bytes html_file = io.BytesIO(content) - + # Set up the document converter with HTML format options - doc_converter = DocumentConverter( - allowed_formats=[InputFormat.HTML] - ) - + doc_converter = DocumentConverter(allowed_formats=[InputFormat.HTML]) + # Convert the HTML file conv_result = doc_converter.convert(html_file) - + # Get text and markdown representation text = conv_result.document.export_to_text() markdown_text = conv_result.document.export_to_markdown() - + # Extract metadata from docling document - docling_metadata = conv_result.document.metadata.model_dump() if hasattr(conv_result.document, 'metadata') else {} - + docling_metadata = ( + conv_result.document.metadata.model_dump() + if hasattr(conv_result.document, "metadata") + else {} + ) + # Enhance metadata enhanced_metadata = { **metadata, "html_metadata": docling_metadata, - "document_type": "html" + "document_type": "html", } - - logger.info(f"Successfully ingested HTML document using docling v2") - + + logger.info("Successfully ingested HTML document using docling v2") + return { "text": text, "markdown": markdown_text, - "metadata": enhanced_metadata + "metadata": enhanced_metadata, } - + except Exception as e: logger.error(f"Failed to ingest HTML with docling v2: {e}", exc_info=True) - + # Fallback to simple HTML parsing if docling fails logger.info("Falling back to simple HTML parsing") try: # Decode bytes to string - html_text = content.decode('utf-8') - + html_text = content.decode("utf-8") + # Convert to markdown md_text = html_to_markdown(html_text) - + # Enhance metadata - enhanced_metadata = { - **metadata, - "document_type": "html" - } - - logger.info(f"Successfully ingested HTML with fallback method") - + enhanced_metadata = {**metadata, "document_type": "html"} + + logger.info("Successfully ingested HTML with fallback method") + return { "text": md_text, "markdown": md_text, - "metadata": enhanced_metadata + "metadata": enhanced_metadata, } - + except Exception as fallback_error: - logger.error(f"Fallback HTML ingestion also failed: {fallback_error}", exc_info=True) + logger.error( + f"Fallback HTML ingestion also failed: {fallback_error}", + exc_info=True, + ) raise + def html_to_markdown(html_text): """ Simple function to convert HTML to markdown. @@ -725,74 +758,74 @@ def html_to_markdown(html_text): # In a real-world scenario, you might want to use a more robust library # like html2text or markdownify from bs4 import BeautifulSoup - - soup = BeautifulSoup(html_text, 'html.parser') - + + soup = BeautifulSoup(html_text, "html.parser") + # Extract text - text = soup.get_text(separator='\n\n') - + text = soup.get_text(separator="\n\n") + # Try to preserve some structure - for heading in soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']): + for heading in soup.find_all(["h1", "h2", "h3", "h4", "h5", "h6"]): level = int(heading.name[1]) heading_text = heading.get_text().strip() - heading_md = '#' * level + ' ' + heading_text + heading_md = "#" * level + " " + heading_text text = text.replace(heading_text, heading_md) - + return text """ Unified ingestor that can handle multiple document formats using docling v2. Supports PDF, DOCX, PPTX, HTML, and images. """ - + async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, Any]: """ Extract text and metadata from a document using docling v2. - + Args: content: Raw document content as bytes metadata: Additional metadata about the document, including 'file_extension' - + Returns: Dictionary containing: - text: Extracted text - metadata: Enhanced metadata """ # Determine the input format based on file extension - file_extension = metadata.get('file_extension', '').lower() - + file_extension = metadata.get("file_extension", "").lower() + # Map file extensions to docling InputFormat format_mapping = { - 'pdf': InputFormat.PDF, - 'docx': InputFormat.DOCX, - 'doc': InputFormat.DOCX, # Treat .doc as .docx (may not work perfectly) - 'pptx': InputFormat.PPTX, - 'ppt': InputFormat.PPTX, # Treat .ppt as .pptx (may not work perfectly) - 'html': InputFormat.HTML, - 'htm': InputFormat.HTML, - 'png': InputFormat.IMAGE, - 'jpg': InputFormat.IMAGE, - 'jpeg': InputFormat.IMAGE, - 'gif': InputFormat.IMAGE, - 'bmp': InputFormat.IMAGE, - 'tiff': InputFormat.IMAGE, - 'tif': InputFormat.IMAGE, + "pdf": InputFormat.PDF, + "docx": InputFormat.DOCX, + "doc": InputFormat.DOCX, # Treat .doc as .docx (may not work perfectly) + "pptx": InputFormat.PPTX, + "ppt": InputFormat.PPTX, # Treat .ppt as .pptx (may not work perfectly) + "html": InputFormat.HTML, + "htm": InputFormat.HTML, + "png": InputFormat.IMAGE, + "jpg": InputFormat.IMAGE, + "jpeg": InputFormat.IMAGE, + "gif": InputFormat.IMAGE, + "bmp": InputFormat.IMAGE, + "tiff": InputFormat.IMAGE, + "tif": InputFormat.IMAGE, } - + input_format = format_mapping.get(file_extension) - + if not input_format: logger.warning(f"Unsupported file extension: {file_extension}") # Fall back to text ingestor for unsupported formats text_ingestor = TextIngestor() return await text_ingestor.ingest(content, metadata) - + try: logger.info(f"Ingesting {file_extension.upper()} document with docling v2") - + # Create a file-like object from bytes file_obj = io.BytesIO(content) - + # Configure pipeline options for PDF pipeline_options = None if input_format == InputFormat.PDF: @@ -800,47 +833,58 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An pipeline_options.do_ocr = True pipeline_options.do_table_structure = True pipeline_options.table_structure_options.do_cell_matching = True - + # Set up the document converter with appropriate format options format_options = {} if input_format == InputFormat.PDF and pipeline_options: format_options = { InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options) } - + # Create the document converter doc_converter = DocumentConverter( - allowed_formats=[input_format], - format_options=format_options + allowed_formats=[input_format], format_options=format_options ) - + # Convert the file conv_result = doc_converter.convert(file_obj) - + # Get text and markdown representation text = conv_result.document.export_to_text() markdown_text = conv_result.document.export_to_markdown() - + # Extract metadata from docling document - docling_metadata = conv_result.document.metadata.model_dump() if hasattr(conv_result.document, 'metadata') else {} - + docling_metadata = ( + conv_result.document.metadata.model_dump() + if hasattr(conv_result.document, "metadata") + else {} + ) + # Extract page/slide count if available - page_count = len(conv_result.document.pages) if hasattr(conv_result.document, 'pages') else 0 - + page_count = ( + len(conv_result.document.pages) + if hasattr(conv_result.document, "pages") + else 0 + ) + # Extract page texts for PDF page_texts = [] - if input_format == InputFormat.PDF and hasattr(conv_result.document, 'pages'): + if input_format == InputFormat.PDF and hasattr( + conv_result.document, "pages" + ): for page in conv_result.document.pages: - page_text = page.export_to_text() if hasattr(page, 'export_to_text') else "" + page_text = ( + page.export_to_text() if hasattr(page, "export_to_text") else "" + ) page_texts.append(page_text) - + # Combine with provided metadata enhanced_metadata = { **metadata, "docling_metadata": docling_metadata, "document_type": file_extension, } - + # Add format-specific metadata if input_format == InputFormat.PDF: enhanced_metadata["page_count"] = page_count @@ -859,32 +903,36 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An enhanced_metadata["image_metadata"] = { "format": image.format, "size": image.size, - "mode": image.mode + "mode": image.mode, } except Exception as img_error: logger.warning(f"Failed to extract image metadata: {img_error}") - - logger.info(f"Successfully ingested {file_extension.upper()} document using docling v2") - + + logger.info( + f"Successfully ingested {file_extension.upper()} document using docling v2" + ) + result = { "text": text, "markdown": markdown_text, - "metadata": enhanced_metadata + "metadata": enhanced_metadata, } - + # Add page_texts for PDF if input_format == InputFormat.PDF: result["page_texts"] = page_texts - + return result - + except Exception as e: - logger.error(f"Failed to ingest document with docling v2: {e}", exc_info=True) - + logger.error( + f"Failed to ingest document with docling v2: {e}", exc_info=True + ) + # Fall back to specific ingestors based on format try: logger.info(f"Falling back to specific ingestor for {file_extension}") - + if input_format == InputFormat.PDF: pdf_ingestor = PDFIngestor() return await pdf_ingestor.ingest(content, metadata) @@ -899,28 +947,36 @@ async def ingest(self, content: bytes, metadata: Dict[str, Any]) -> Dict[str, An return await csv_ingestor.ingest(content, metadata) elif input_format == InputFormat.HTML: # Use the fallback method directly - html_text = content.decode('utf-8') + html_text = content.decode("utf-8") md_text = html_to_markdown(html_text) return { "text": md_text, "markdown": md_text, - "metadata": {**metadata, "document_type": "html"} + "metadata": {**metadata, "document_type": "html"}, } else: # For other formats, fall back to text ingestor text_ingestor = TextIngestor() return await text_ingestor.ingest(content, metadata) - + except Exception as fallback_error: - logger.error(f"Fallback ingestion also failed: {fallback_error}", exc_info=True) - + logger.error( + f"Fallback ingestion also failed: {fallback_error}", exc_info=True + ) + # Last resort: try to extract as plain text try: - text = content.decode('utf-8', errors='replace') + text = content.decode("utf-8", errors="replace") return { "text": text, - "metadata": {**metadata, "document_type": "text", "extraction_method": "fallback_text"} + "metadata": { + **metadata, + "document_type": "text", + "extraction_method": "fallback_text", + }, } - except: + except Exception: # If all else fails, return an error - raise Exception(f"Failed to ingest document with any available method") \ No newline at end of file + raise Exception( + "Failed to ingest document with any available method" + ) diff --git a/app/services/rag/ingestor/ingestor_factory.py b/app/services/rag/ingestor/ingestor_factory.py index 513ae88..c200077 100644 --- a/app/services/rag/ingestor/ingestor_factory.py +++ b/app/services/rag/ingestor/ingestor_factory.py @@ -1,48 +1,55 @@ -from typing import Dict, Any, Optional import logging import mimetypes +from typing import Optional from app.db.models.knowledge_base import DocumentType -from app.services.rag.ingestor.ingestor import CSVIngestor, ImageIngestor, Ingestor, MarkdownIngestor, PDFIngestor, TextIngestor - +from app.services.rag.ingestor.ingestor import ( + CSVIngestor, + ImageIngestor, + Ingestor, + MarkdownIngestor, + PDFIngestor, + TextIngestor, +) logger = logging.getLogger(__name__) + class IngestorFactory: """ Factory for creating ingestors based on document type. - + Uses a singleton pattern to ensure ingestors are only initialized once, which helps prevent issues in multiprocessing environments. """ - + # Singleton instances for each ingestor type _pdf_ingestor: Optional[Ingestor] = None _csv_ingestor: Optional[Ingestor] = None _markdown_ingestor: Optional[Ingestor] = None _image_ingestor: Optional[Ingestor] = None _text_ingestor: Optional[Ingestor] = None - + @staticmethod def create_ingestor(content_type: DocumentType) -> Ingestor: """ Create or return a singleton ingestor based on content type. - + Args: content_type: MIME type of the document - + Returns: Ingestor instance - + Raises: ValueError: If no ingestor is available for the content type """ try: logger.info(f"Getting ingestor for content type: {content_type}") - + # Normalize content type content_type = content_type.lower() - + # Map content types to ingestors if content_type == DocumentType.PDF: if IngestorFactory._pdf_ingestor is None: @@ -51,7 +58,7 @@ def create_ingestor(content_type: DocumentType) -> Ingestor: else: logger.info("Returning existing PDFIngestor") return IngestorFactory._pdf_ingestor - + elif content_type in [DocumentType.CSV, DocumentType.EXCEL]: if IngestorFactory._csv_ingestor is None: logger.info("Creating new CSVIngestor") @@ -59,7 +66,7 @@ def create_ingestor(content_type: DocumentType) -> Ingestor: else: logger.info("Returning existing CSVIngestor") return IngestorFactory._csv_ingestor - + elif content_type in [DocumentType.MARKDOWN, DocumentType.MD]: if IngestorFactory._markdown_ingestor is None: logger.info("Creating new MarkdownIngestor") @@ -67,95 +74,108 @@ def create_ingestor(content_type: DocumentType) -> Ingestor: else: logger.info("Returning existing MarkdownIngestor") return IngestorFactory._markdown_ingestor - - elif content_type in [DocumentType.JPG, DocumentType.PNG, DocumentType.GIF, DocumentType.TIFF]: + + elif content_type in [ + DocumentType.JPG, + DocumentType.PNG, + DocumentType.GIF, + DocumentType.TIFF, + ]: if IngestorFactory._image_ingestor is None: logger.info("Creating new ImageIngestor") IngestorFactory._image_ingestor = ImageIngestor() else: logger.info("Returning existing ImageIngestor") return IngestorFactory._image_ingestor - - elif content_type in [DocumentType.TXT, DocumentType.DOC, DocumentType.DOCX]: + + elif content_type in [ + DocumentType.TXT, + DocumentType.DOC, + DocumentType.DOCX, + ]: if IngestorFactory._text_ingestor is None: logger.info("Creating new TextIngestor") IngestorFactory._text_ingestor = TextIngestor() else: logger.info("Returning existing TextIngestor") return IngestorFactory._text_ingestor - + else: # Default to text ingestor if IngestorFactory._text_ingestor is None: - logger.warning(f"No specific ingestor for content type {content_type}, creating new TextIngestor") + logger.warning( + f"No specific ingestor for content type {content_type}, creating new TextIngestor" + ) IngestorFactory._text_ingestor = TextIngestor() else: - logger.warning(f"No specific ingestor for content type {content_type}, returning existing TextIngestor") + logger.warning( + f"No specific ingestor for content type {content_type}, returning existing TextIngestor" + ) return IngestorFactory._text_ingestor - + except Exception as e: logger.error(f"Failed to create ingestor: {e}", exc_info=True) raise - + @staticmethod def create_ingestor_from_filename(filename: str) -> Ingestor: """ Create or return a singleton ingestor based on filename extension. - + Args: filename: Name of the file - + Returns: Ingestor instance """ try: logger.info(f"Getting ingestor for file: {filename}") - + # Guess content type from filename content_type, _ = mimetypes.guess_type(filename) - + if not content_type: # Try to determine from extension - extension = filename.split('.')[-1].lower() if '.' in filename else '' - - if extension == 'pdf': - content_type = 'application/pdf' - elif extension == 'csv': - content_type = 'text/csv' - elif extension in ['md', 'markdown']: - content_type = 'text/markdown' - elif extension in ['jpg', 'jpeg', 'png', 'gif', 'bmp']: - content_type = f'image/{extension}' - elif extension in ['txt', 'text']: - content_type = 'text/plain' + extension = filename.split(".")[-1].lower() if "." in filename else "" + + if extension == "pdf": + content_type = "application/pdf" + elif extension == "csv": + content_type = "text/csv" + elif extension in ["md", "markdown"]: + content_type = "text/markdown" + elif extension in ["jpg", "jpeg", "png", "gif", "bmp"]: + content_type = f"image/{extension}" + elif extension in ["txt", "text"]: + content_type = "text/plain" else: - content_type = 'text/plain' # Default - + content_type = "text/plain" # Default + logger.info(f"Determined content type: {content_type}") - + # Create ingestor based on content type return IngestorFactory.create_ingestor(content_type) - + except Exception as e: logger.error(f"Failed to create ingestor from filename: {e}", exc_info=True) raise - + @staticmethod def initialize_ingestors() -> None: """ Pre-initialize all ingestor types at startup. - + This method should be called during application or worker startup to ensure ingestors are initialized in the main process before any forking occurs. """ logger.info("Pre-initializing ingestors...") - + # Initialize all ingestor types IngestorFactory.create_ingestor(DocumentType.PDF) IngestorFactory.create_ingestor(DocumentType.CSV) IngestorFactory.create_ingestor(DocumentType.MARKDOWN) IngestorFactory.create_ingestor(DocumentType.JPG) IngestorFactory.create_ingestor(DocumentType.TXT) - - logger.info("Ingestors pre-initialized successfully") \ No newline at end of file + + logger.info("Ingestors pre-initialized successfully") diff --git a/app/services/rag/reranker/__init__.py b/app/services/rag/reranker/__init__.py index c965466..f4157ea 100644 --- a/app/services/rag/reranker/__init__.py +++ b/app/services/rag/reranker/__init__.py @@ -5,28 +5,20 @@ by reordering them based on their relevance to the query. """ -from app.services.rag.reranker.reranker import ( - Reranker, - CrossEncoderReranker, -) -from app.services.rag.reranker.reranker_factory import ( - RerankerFactory, - create_reranker -) +from app.services.rag.reranker.flag_reranker import FlagEmbeddingReranker from app.services.rag.reranker.pinecone_reranker import ( + PINECONE_AVAILABLE, PineconeReranker, - PINECONE_AVAILABLE -) -from app.services.rag.reranker.flag_reranker import ( - FlagEmbeddingReranker, ) +from app.services.rag.reranker.reranker import CrossEncoderReranker, Reranker +from app.services.rag.reranker.reranker_factory import RerankerFactory, create_reranker __all__ = [ - 'Reranker', - 'CrossEncoderReranker', - 'PineconeReranker', - 'RerankerFactory', - 'create_reranker', - 'PINECONE_AVAILABLE', - 'FlagEmbeddingReranker' -] \ No newline at end of file + "Reranker", + "CrossEncoderReranker", + "PineconeReranker", + "RerankerFactory", + "create_reranker", + "PINECONE_AVAILABLE", + "FlagEmbeddingReranker", +] diff --git a/app/services/rag/reranker/flag_reranker.py b/app/services/rag/reranker/flag_reranker.py index 734f5f6..ab98530 100644 --- a/app/services/rag/reranker/flag_reranker.py +++ b/app/services/rag/reranker/flag_reranker.py @@ -1,17 +1,24 @@ +import logging from typing import Any, Dict, List + from FlagEmbedding import FlagReranker + from app.services.rag.reranker.reranker import Reranker -import logging logger = logging.getLogger(__name__) + class FlagEmbeddingReranker(Reranker): def __init__(self, model_name: str = "BAAI/bge-reranker-v2-gemma"): # Force CPU usage for the model to avoid MPS issues self.reranker = FlagReranker(model_name, use_fp16=False) - logger.info(f"Initialized FlagEmbeddingReranker with model: {model_name} (using CPU only)") + logger.info( + f"Initialized FlagEmbeddingReranker with model: {model_name} (using CPU only)" + ) - async def rerank(self, query: str, chunks: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]: + async def rerank( + self, query: str, chunks: List[Dict[str, Any]], top_k: int = 5 + ) -> List[Dict[str, Any]]: try: logger.info(f"Reranking {len(chunks)} chunks with query: {query}") pairs = [(query, chunk["content"]) for chunk in chunks] @@ -27,16 +34,22 @@ async def rerank(self, query: str, chunks: List[Dict[str, Any]], top_k: int = 5) chunks[i]["score"] = float(score) # Sort by rerank score - reranked_chunks = sorted(chunks, key=lambda x: x["rerank_score"], reverse=True) - + reranked_chunks = sorted( + chunks, key=lambda x: x["rerank_score"], reverse=True + ) + # Limit to top_k result = reranked_chunks[:top_k] - - logger.info(f"Reranking complete. Top score: {result[0]['rerank_score'] if result else 0}") - + + logger.info( + f"Reranking complete. Top score: {result[0]['rerank_score'] if result else 0}" + ) + return result except Exception as e: logger.error(f"Failed to rerank chunks: {e}", exc_info=True) # Return the original chunks sorted by their original scores as fallback - sorted_chunks = sorted(chunks, key=lambda x: x.get("score", 0), reverse=True) - return sorted_chunks[:top_k] \ No newline at end of file + sorted_chunks = sorted( + chunks, key=lambda x: x.get("score", 0), reverse=True + ) + return sorted_chunks[:top_k] diff --git a/app/services/rag/reranker/pinecone_reranker.py b/app/services/rag/reranker/pinecone_reranker.py index 98c6d99..bf8eb66 100644 --- a/app/services/rag/reranker/pinecone_reranker.py +++ b/app/services/rag/reranker/pinecone_reranker.py @@ -1,13 +1,15 @@ -from typing import Any, Dict, List, Optional -from app.services.rag.reranker.reranker import Reranker import logging +from typing import Any, Dict, List, Optional + from app.core.config import settings +from app.services.rag.reranker.reranker import Reranker logger = logging.getLogger(__name__) # Try importing pinecone for the PineconeReranker try: from pinecone import Pinecone + PINECONE_AVAILABLE = True except ImportError: PINECONE_AVAILABLE = False @@ -17,18 +19,18 @@ class PineconeReranker(Reranker): """ Reranker implementation that uses Pinecone's Cohere rerank-3.5 model. - + This reranker offers: - Improved performance for complex queries with constraints - Multilingual support for over 100+ languages - SOTA performance in domains like finance, hospitality, and more - Context length of 4096 tokens """ - + def __init__(self, model_name: str = "cohere-rerank-3.5"): """ Initialize the PineconeReranker with an API key. - + Args: api_key: Pinecone API key model_name: Name of the reranking model to use (default: "cohere-rerank-3.5") @@ -38,7 +40,7 @@ def __init__(self, model_name: str = "cohere-rerank-3.5"): "Pinecone package is not installed. " "Please install it with 'pip install -U pinecone'" ) - + try: logger.info(f"Initializing PineconeReranker with model {model_name}") self.pc = Pinecone(api_key=settings.PINECONE_API_KEY) @@ -47,21 +49,18 @@ def __init__(self, model_name: str = "cohere-rerank-3.5"): except Exception as e: logger.error(f"Failed to initialize PineconeReranker: {e}", exc_info=True) raise - + async def rerank( - self, - query: str, - chunks: List[Dict[str, Any]], - top_k: Optional[int] = None + self, query: str, chunks: List[Dict[str, Any]], top_k: Optional[int] = None ) -> List[Dict[str, Any]]: """ Rerank chunks using Pinecone's Cohere rerank model. - + Args: query: The search query chunks: List of chunks to rerank top_k: Maximum number of chunks to return after reranking - + Returns: Reranked list of chunks """ @@ -69,43 +68,47 @@ async def rerank( if not chunks: logger.warning("No chunks provided for reranking") return [] - + # If top_k is not specified, use the number of chunks if top_k is None: top_k = len(chunks) - - logger.info(f"Reranking {len(chunks)} chunks with Pinecone using query: {query}") - + + logger.info( + f"Reranking {len(chunks)} chunks with Pinecone using query: {query}" + ) + # Extract content from chunks documents = [chunk["content"] for chunk in chunks] - + # Call Pinecone's rerank API results = self.pc.inference.rerank( model=self.model_name, query=query, documents=documents, top_n=top_k, - return_documents=True + return_documents=True, ) - + # Map the reranked results back to the original chunks reranked_chunks = [] for result in results.data: # Find the original chunk that corresponds to this result original_idx = documents.index(result.document.text) original_chunk = chunks[original_idx] - + # Update the scores original_chunk["rerank_score"] = float(result.score) original_chunk["similarity_score"] = original_chunk.get("score", 0.0) original_chunk["score"] = float(result.score) - + reranked_chunks.append(original_chunk) - - logger.info(f"Reranking complete. Top score: {reranked_chunks[0]['rerank_score'] if reranked_chunks else 0}") - + + logger.info( + f"Reranking complete. Top score: {reranked_chunks[0]['rerank_score'] if reranked_chunks else 0}" + ) + return reranked_chunks - + except Exception as e: logger.error(f"Failed to rerank chunks with Pinecone: {e}", exc_info=True) # Return the original chunks if reranking fails diff --git a/app/services/rag/reranker/reranker.py b/app/services/rag/reranker/reranker.py index ef081e3..4fbfd95 100644 --- a/app/services/rag/reranker/reranker.py +++ b/app/services/rag/reranker/reranker.py @@ -1,9 +1,10 @@ -from abc import ABC, abstractmethod -import math -from typing import List, Dict, Any, Optional import logging -import os +import math import multiprocessing +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + import torch from sentence_transformers import CrossEncoder @@ -12,7 +13,7 @@ # Configure multiprocessing to use 'spawn' instead of 'fork' # This can help prevent segmentation faults in multiprocessing environments try: - multiprocessing.set_start_method('spawn', force=True) + multiprocessing.set_start_method("spawn", force=True) except RuntimeError: # If already set, this will raise a RuntimeError pass @@ -24,76 +25,75 @@ os.environ["MPS_VISIBLE_DEVICES"] = "" # Disable MPS for PyTorch os.environ["CUDA_VISIBLE_DEVICES"] = "" # Disable CUDA if present + class Reranker(ABC): """ Abstract base class for rerankers that reorder retrieved chunks based on their relevance to the query. """ - + @abstractmethod async def rerank( - self, - query: str, - chunks: List[Dict[str, Any]], - top_k: Optional[int] = None + self, query: str, chunks: List[Dict[str, Any]], top_k: Optional[int] = None ) -> List[Dict[str, Any]]: """ Rerank chunks based on their relevance to the query. - + Args: query: The search query chunks: List of chunks to rerank top_k: Maximum number of chunks to return after reranking - + Returns: Reranked list of chunks """ - pass + class CrossEncoderReranker(Reranker): """ Reranker implementation that uses a cross-encoder model to rerank chunks. """ - + def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"): """ Initialize the CrossEncoderReranker with a model. - + Args: model_name: Name of the cross-encoder model to use """ try: logger.info(f"Initializing CrossEncoderReranker with model {model_name}") - + # Force CPU usage to avoid segmentation faults if torch.cuda.is_available(): - logger.info("CUDA is available but forcing CPU usage to avoid segmentation faults") - + logger.info( + "CUDA is available but forcing CPU usage to avoid segmentation faults" + ) + # Explicitly set device to CPU self.device = "cpu" - + # Initialize the model with device="cpu" to force CPU usage self.model = CrossEncoder(model_name, device=self.device) - + logger.info(f"CrossEncoderReranker initialized on device: {self.device}") except Exception as e: - logger.error(f"Failed to initialize CrossEncoderReranker: {e}", exc_info=True) + logger.error( + f"Failed to initialize CrossEncoderReranker: {e}", exc_info=True + ) raise - + async def rerank( - self, - query: str, - chunks: List[Dict[str, Any]], - top_k: Optional[int] = None + self, query: str, chunks: List[Dict[str, Any]], top_k: Optional[int] = None ) -> List[Dict[str, Any]]: """ Rerank chunks using the cross-encoder model. - + Args: query: The search query chunks: List of chunks to rerank top_k: Maximum number of chunks to return after reranking - + Returns: Reranked list of chunks """ @@ -101,16 +101,16 @@ async def rerank( if not chunks: logger.warning("No chunks provided for reranking") return [] - + # If top_k is not specified, use the number of chunks if top_k is None: top_k = len(chunks) - + logger.info(f"Reranking {len(chunks)} chunks with query: {query}") - + # Prepare pairs for cross-encoder pairs = [(query, chunk["content"]) for chunk in chunks] - + # Get scores from cross-encoder with explicit batch size to avoid memory issues # and show_progress_bar=False to avoid tqdm issues in multiprocessing logger.info(f"self.device: {self.device}") @@ -130,19 +130,23 @@ async def rerank( # Update the main score to the rerank score chunks[i]["score"] = float(score) logger.info(f"Chunk {i}: {chunks[i]}") - + # Sort by rerank score - reranked_chunks = sorted(chunks, key=lambda x: x["rerank_score"], reverse=True) - + reranked_chunks = sorted( + chunks, key=lambda x: x["rerank_score"], reverse=True + ) + # Limit to top_k result = reranked_chunks[:top_k] - - logger.info(f"Reranking complete. Top score: {result[0]['rerank_score'] if result else 0}") - + + logger.info( + f"Reranking complete. Top score: {result[0]['rerank_score'] if result else 0}" + ) + return result - + except Exception as e: logger.error(f"Failed to rerank chunks: {e}", exc_info=True) # Return the original chunks if reranking fails logger.warning("Returning original chunks due to reranking failure") - return chunks[:top_k] if top_k is not None else chunks \ No newline at end of file + return chunks[:top_k] if top_k is not None else chunks diff --git a/app/services/rag/reranker/reranker_factory.py b/app/services/rag/reranker/reranker_factory.py index 1943d65..83513fb 100644 --- a/app/services/rag/reranker/reranker_factory.py +++ b/app/services/rag/reranker/reranker_factory.py @@ -1,56 +1,52 @@ -from typing import Dict, Any, Optional import logging +from typing import Any, Dict, Optional -from app.services.rag.reranker.reranker import ( - Reranker, - CrossEncoderReranker, -) +from app.services.rag.reranker.flag_reranker import FlagEmbeddingReranker from app.services.rag.reranker.pinecone_reranker import ( + PINECONE_AVAILABLE, PineconeReranker, - PINECONE_AVAILABLE -) -from app.services.rag.reranker.flag_reranker import ( - FlagEmbeddingReranker, ) +from app.services.rag.reranker.reranker import CrossEncoderReranker, Reranker logger = logging.getLogger(__name__) + class RerankerFactory: """ Factory class for creating reranker instances based on configuration. - + This factory provides a centralized way to create different types of rerankers with appropriate configuration and fallback mechanisms. - + It uses a singleton pattern to ensure models are only initialized once, which helps prevent segmentation faults in multiprocessing environments. """ - + # Singleton instances for each reranker type _pinecone_instance: Optional[Reranker] = None _cross_encoder_instance: Optional[Reranker] = None _flag_embedding_instance: Optional[Reranker] = None - + @staticmethod def create(config: Dict[str, Any]) -> Reranker: """ Create or return a singleton reranker based on configuration. - + Args: config: Configuration dictionary with reranker settings - type: Type of reranker to create ("pinecone", "cross_encoder", etc.) - model_name: Model name for the reranker - api_key: API key for services like Pinecone - + Returns: An instance of a Reranker - + Raises: ValueError: If required configuration is missing """ reranker_type = config.get("type", "cross_encoder").lower() logger.info(f"Getting reranker of type: {reranker_type}") - + # Create the appropriate reranker based on type if reranker_type == "pinecone": return RerankerFactory._create_pinecone_reranker(config) @@ -59,17 +55,19 @@ def create(config: Dict[str, Any]) -> Reranker: elif reranker_type == "flag": return RerankerFactory._create_flag_embedding_reranker(config) else: - logger.warning(f"Unknown reranker type: {reranker_type}. Falling back to CrossEncoderReranker.") + logger.warning( + f"Unknown reranker type: {reranker_type}. Falling back to CrossEncoderReranker." + ) return RerankerFactory._create_cross_encoder_reranker(config) - + @staticmethod def _create_pinecone_reranker(config: Dict[str, Any]) -> Reranker: """ Create a PineconeReranker instance or return existing singleton. - + Args: config: Configuration for the PineconeReranker - + Returns: A PineconeReranker instance or fallback to CrossEncoderReranker if Pinecone is not available """ @@ -77,26 +75,28 @@ def _create_pinecone_reranker(config: Dict[str, Any]) -> Reranker: if RerankerFactory._pinecone_instance is not None: logger.info("Returning existing PineconeReranker instance") return RerankerFactory._pinecone_instance - + if not PINECONE_AVAILABLE: - logger.warning("Pinecone not installed. Falling back to CrossEncoderReranker.") + logger.warning( + "Pinecone not installed. Falling back to CrossEncoderReranker." + ) return RerankerFactory._create_cross_encoder_reranker(config) - + model_name = config.get("model_name", "cohere-rerank-3.5") logger.info(f"Creating new PineconeReranker with model: {model_name}") - + # Create and store the instance RerankerFactory._pinecone_instance = PineconeReranker(model_name=model_name) return RerankerFactory._pinecone_instance - + @staticmethod def _create_cross_encoder_reranker(config: Dict[str, Any]) -> Reranker: """ Create a CrossEncoderReranker instance or return existing singleton. - + Args: config: Configuration for the CrossEncoderReranker - models: + models: - cross-encoder/ms-marco-MiniLM-L-6-v2 - BAAI/bge-reranker-v2-m3 Returns: @@ -109,16 +109,18 @@ def _create_cross_encoder_reranker(config: Dict[str, Any]) -> Reranker: model_name = config.get("model_name", "BAAI/bge-reranker-v2-m3") logger.info(f"Creating new CrossEncoderReranker with model: {model_name}") - + # Create and store the instance - RerankerFactory._cross_encoder_instance = CrossEncoderReranker(model_name=model_name) + RerankerFactory._cross_encoder_instance = CrossEncoderReranker( + model_name=model_name + ) return RerankerFactory._cross_encoder_instance - + @staticmethod def _create_flag_embedding_reranker(config: Dict[str, Any]) -> Reranker: """ Create a FlagEmbeddingReranker instance or return existing singleton. - + Args: config: Configuration for the FlagEmbeddingReranker models: @@ -131,53 +133,58 @@ def _create_flag_embedding_reranker(config: Dict[str, Any]) -> Reranker: if RerankerFactory._flag_embedding_instance is not None: logger.info("Returning existing FlagEmbeddingReranker instance") return RerankerFactory._flag_embedding_instance - + model_name = config.get("model_name", "BAAI/bge-reranker-v2-m3") logger.info(f"Creating new FlagEmbeddingReranker with model: {model_name}") - + # Create and store the instance - RerankerFactory._flag_embedding_instance = FlagEmbeddingReranker(model_name=model_name) + RerankerFactory._flag_embedding_instance = FlagEmbeddingReranker( + model_name=model_name + ) return RerankerFactory._flag_embedding_instance @staticmethod def initialize_models(config: Dict[str, Any] = None) -> None: """ Pre-initialize all reranker models at startup. - + This method should be called during application or worker startup to ensure models are initialized in the main process before any forking occurs. - + Args: config: Optional configuration dictionary """ if config is None: config = {} - + logger.info("Pre-initializing reranker models...") - + # Initialize the default reranker (usually flag or cross-encoder) default_type = config.get("type", "flag") if default_type == "flag": RerankerFactory._create_flag_embedding_reranker(config) else: RerankerFactory._create_cross_encoder_reranker(config) - + logger.info("Reranker models pre-initialized successfully") + # Function for backward compatibility def create_reranker(config: Dict[str, Any]) -> Reranker: """ Factory function to create a reranker based on configuration. - + This function is maintained for backward compatibility. New code should use RerankerFactory.create() instead. - + Args: config: Configuration dictionary with reranker settings - + Returns: An instance of a Reranker """ - logger.info("Using create_reranker (deprecated). Consider using RerankerFactory.create() instead.") - return RerankerFactory.create(config) \ No newline at end of file + logger.info( + "Using create_reranker (deprecated). Consider using RerankerFactory.create() instead." + ) + return RerankerFactory.create(config) diff --git a/app/services/rag/retriever/pinecone_retriever.py b/app/services/rag/retriever/pinecone_retriever.py index 7239b76..dc54bfe 100644 --- a/app/services/rag/retriever/pinecone_retriever.py +++ b/app/services/rag/retriever/pinecone_retriever.py @@ -1,86 +1,101 @@ -from typing import List, Dict, Any, Optional import logging import random +from typing import Any, Dict, List, Optional + from pinecone import Pinecone + from app.core.config import settings -from app.services.rag.retriever.retriever import Retriever from app.services.llm.factory import LLMFactory +from app.services.rag.retriever.retriever import Retriever logger = logging.getLogger(__name__) + class PineconeRetriever(Retriever): """ Retriever implementation that uses Pinecone as the vector store. Uses the knowledge base ID as the namespace in Pinecone. """ - + def __init__(self, knowledge_base_id: str): """ Initialize the PineconeRetriever with a knowledge base ID. - + Args: knowledge_base_id: The ID of the knowledge base this retriever will work with """ super().__init__(knowledge_base_id) - + # Initialize Pinecone self.pc = Pinecone(api_key=settings.PINECONE_API_KEY) self.index_name = settings.PINECONE_INDEX_NAME self.index = self.pc.Index(self.index_name) - + # Vector dimension for embeddings self.dimension = 768 # Dimension for text-embedding-004 - + async def add_chunks(self, chunks: List[Dict[str, Any]]) -> None: """ Add document chunks to Pinecone. - + Args: chunks: List of dictionaries containing: - content: str - metadata: Dict containing document_id, chunk_index, metadata, etc. """ try: - logger.info(f"Starting to process {len(chunks)} chunks for knowledge base {self.knowledge_base_id}") - + logger.info( + f"Starting to process {len(chunks)} chunks for knowledge base {self.knowledge_base_id}" + ) + # Generate embeddings for chunks vectors = [] for i, chunk in enumerate(chunks): # Get embedding - document_id = str(chunk['metadata']['document_id']) - logger.info(f"Generating embedding for chunk {i+1}/{len(chunks)} (doc_id: {document_id}) using LLM Factory") - embedding = await self._get_embedding(chunk['content']) + document_id = str(chunk["metadata"]["document_id"]) + logger.info( + f"Generating embedding for chunk {i+1}/{len(chunks)} (doc_id: {document_id}) using LLM Factory" + ) + embedding = await self._get_embedding(chunk["content"]) logger.info(f"Generated embedding with dimension {len(embedding)}") - + # Store content and metadata separately for Pinecone metadata = { - 'document_id': document_id, - 'chunk_index': int(chunk['metadata']['chunk_index']), - 'chunk_size': str(chunk['metadata']['chunk_size']), - 'doc_title': str(chunk['metadata']['document_title']), - 'doc_type': str(chunk['metadata']['document_type']), - 'section': str(chunk['metadata']['nearest_header']), - 'path': ','.join(str(x) for x in chunk['metadata']['section_path']), - 'content': str(chunk['content']) + "document_id": document_id, + "chunk_index": int(chunk["metadata"]["chunk_index"]), + "chunk_size": str(chunk["metadata"]["chunk_size"]), + "doc_title": str(chunk["metadata"]["document_title"]), + "doc_type": str(chunk["metadata"]["document_type"]), + "section": str(chunk["metadata"]["nearest_header"]), + "path": ",".join(str(x) for x in chunk["metadata"]["section_path"]), + "content": str(chunk["content"]), } - + # Create vector record with unique ID - vector_id = f"{document_id}_{metadata['chunk_index']}_{metadata['chunk_size']}" - vectors.append({ - 'id': vector_id, - 'values': [float(x) for x in embedding], - 'metadata': metadata - }) - logger.info(f"Created vector record for chunk {i+1} with id {vector_id}") - + vector_id = ( + f"{document_id}_{metadata['chunk_index']}_{metadata['chunk_size']}" + ) + vectors.append( + { + "id": vector_id, + "values": [float(x) for x in embedding], + "metadata": metadata, + } + ) + logger.info( + f"Created vector record for chunk {i+1} with id {vector_id}" + ) + # Upsert vectors in batches of 100 batch_size = 100 total_batches = (len(vectors) + batch_size - 1) // batch_size - + for i in range(0, len(vectors), batch_size): - batch = vectors[i:i + batch_size] + batch = vectors[i : i + batch_size] batch_num = (i // batch_size) + 1 - logger.info(f"Upserting batch {batch_num}/{total_batches} ({len(batch)} vectors)") + logger.info( + f"Upserting batch {batch_num}/{total_batches} ({len(batch)} vectors)" + ) try: # Use knowledge_base_id as namespace self.index.upsert(vectors=batch, namespace=self.knowledge_base_id) @@ -91,90 +106,110 @@ async def add_chunks(self, chunks: List[Dict[str, Any]]) -> None: if batch: logger.info(f"Sample vector from failing batch: {batch[0]}") raise - - logger.info(f"Successfully added {len(chunks)} chunks to Pinecone for knowledge base {self.knowledge_base_id}") - + + logger.info( + f"Successfully added {len(chunks)} chunks to Pinecone for knowledge base {self.knowledge_base_id}" + ) + except Exception as e: logger.error(f"Failed to add chunks to Pinecone: {e}", exc_info=True) raise - + async def delete_document_chunks(self, document_id: str) -> None: """ Delete all chunks for a document from Pinecone. - + Args: document_id: ID of the document to delete """ try: - logger.info(f"Attempting to delete chunks for document {document_id} in knowledge base {self.knowledge_base_id}") - + logger.info( + f"Attempting to delete chunks for document {document_id} in knowledge base {self.knowledge_base_id}" + ) + # First, try to use metadata filtering (works on Standard and Enterprise tiers) try: self.index.delete( filter={"document_id": {"$eq": str(document_id)}}, - namespace=self.knowledge_base_id + namespace=self.knowledge_base_id, + ) + logger.info( + f"Successfully deleted chunks using metadata filter for document {document_id}" ) - logger.info(f"Successfully deleted chunks using metadata filter for document {document_id}") return except Exception as e: # If metadata filtering fails (Serverless and Starter tiers), use vector IDs - if "Serverless and Starter indexes do not support deleting with metadata filtering" in str(e): - logger.info("Pinecone Serverless/Starter tier detected, switching to ID-based deletion") - + if ( + "Serverless and Starter indexes do not support deleting with metadata filtering" + in str(e) + ): + logger.info( + "Pinecone Serverless/Starter tier detected, switching to ID-based deletion" + ) + # Query to get all vectors for this document # We need to use a dummy vector for the query dummy_vector = [0.0] * 768 # Dimension for text-embedding-004 - + # Query with a high top_k to get all vectors for this document results = self.index.query( vector=dummy_vector, top_k=10000, # Set a high limit to get all vectors include_metadata=True, - namespace=self.knowledge_base_id + namespace=self.knowledge_base_id, ) - + # Filter results to only include vectors for this document vector_ids = [] for match in results.matches: - if match.metadata and match.metadata.get('document_id') == str(document_id): + if match.metadata and match.metadata.get("document_id") == str( + document_id + ): vector_ids.append(match.id) - + if vector_ids: - logger.info(f"Found {len(vector_ids)} vectors to delete for document {document_id}") + logger.info( + f"Found {len(vector_ids)} vectors to delete for document {document_id}" + ) # Delete vectors by ID self.index.delete( - ids=vector_ids, - namespace=self.knowledge_base_id + ids=vector_ids, namespace=self.knowledge_base_id + ) + logger.info( + f"Successfully deleted {len(vector_ids)} vectors by ID for document {document_id}" ) - logger.info(f"Successfully deleted {len(vector_ids)} vectors by ID for document {document_id}") else: - logger.info(f"No vectors found for document {document_id} in knowledge base {self.knowledge_base_id}") + logger.info( + f"No vectors found for document {document_id} in knowledge base {self.knowledge_base_id}" + ) else: # If it's a different error, re-raise it raise - - logger.info(f"Successfully deleted chunks for document {document_id} in knowledge base {self.knowledge_base_id}") - + + logger.info( + f"Successfully deleted chunks for document {document_id} in knowledge base {self.knowledge_base_id}" + ) + except Exception as e: logger.error(f"Failed to delete document chunks: {e}", exc_info=True) raise - + async def search( self, query: str, top_k: int = 5, similarity_threshold: float = 0.3, - metadata_filter: Optional[Dict[str, Any]] = None + metadata_filter: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ Search for similar chunks in Pinecone. - + Args: query: The search query top_k: Maximum number of results to return similarity_threshold: Minimum similarity score for results metadata_filter: Optional filter to apply to the search - + Returns: List of dictionaries containing chunk data """ @@ -183,19 +218,19 @@ async def search( logger.info(f"Query: {query}") logger.info(f"Knowledge Base: {self.knowledge_base_id}") logger.info(f"Limit: {top_k}, Threshold: {similarity_threshold}") - + # Get query embedding logger.info("Generating query embedding using LLM Factory") query_vector = await self._get_embedding(query) logger.info(f"Generated embedding with dimension {len(query_vector)}") - + # Prepare filter filter_dict = {} - + # Handle metadata filters if metadata_filter: logger.info(f"Original metadata filter: {metadata_filter}") - + # All metadata fields are already flattened in Pinecone storage # Just add to filter, ensuring all values are strings for key, value in metadata_filter.items(): @@ -204,24 +239,24 @@ async def search( filter_dict[key] = value elif isinstance(value, (list, tuple)): # Handle list values (like section_path) by joining with commas - filter_dict[key] = ','.join(str(x) for x in value) + filter_dict[key] = ",".join(str(x) for x in value) else: # Convert all other values to strings filter_dict[key] = str(value) - + logger.info(f"Final Pinecone filter: {filter_dict}") - + # Search in Pinecone with filter and namespace results = self.index.query( vector=query_vector, filter=filter_dict if filter_dict else None, top_k=top_k, include_metadata=True, - namespace=self.knowledge_base_id + namespace=self.knowledge_base_id, ) - + logger.info(f"Found {len(results.matches)} initial matches") - + # Process matches chunks = [] filtered_out = 0 @@ -231,51 +266,59 @@ async def search( # Get metadata safely metadata = match.metadata or {} logger.info(f"Metadata: {metadata}") - + # Build chunk with required fields chunk = { - 'id': match.id, - 'score': float(match.score), - 'document_id': str(metadata.get('document_id', '')), - 'content': str(metadata.get('content', '')), - 'chunk_index': int(metadata.get('chunk_index', 0)), - 'title': str(metadata.get('doc_title', 'Untitled')), - 'metadata': { - 'document_id': str(metadata.get('document_id', '')), - 'chunk_index': int(metadata.get('chunk_index', 0)), - 'chunk_size': str(metadata.get('chunk_size', '')), - 'doc_title': str(metadata.get('doc_title', '')), - 'doc_type': str(metadata.get('doc_type', '')), - 'section': str(metadata.get('section', '')), - 'path': metadata.get('path', '').split(',') if metadata.get('path') else [] - } + "id": match.id, + "score": float(match.score), + "document_id": str(metadata.get("document_id", "")), + "content": str(metadata.get("content", "")), + "chunk_index": int(metadata.get("chunk_index", 0)), + "title": str(metadata.get("doc_title", "Untitled")), + "metadata": { + "document_id": str(metadata.get("document_id", "")), + "chunk_index": int(metadata.get("chunk_index", 0)), + "chunk_size": str(metadata.get("chunk_size", "")), + "doc_title": str(metadata.get("doc_title", "")), + "doc_type": str(metadata.get("doc_type", "")), + "section": str(metadata.get("section", "")), + "path": ( + metadata.get("path", "").split(",") + if metadata.get("path") + else [] + ), + }, } logger.info(f"Chunk: {chunk}") - + # Only skip if absolutely necessary - if not chunk['content']: - logger.warning(f"Skipping chunk with empty content") + if not chunk["content"]: + logger.warning("Skipping chunk with empty content") continue - + chunks.append(chunk) logger.info(f"Included chunk with score {match.score:.3f}") - + except Exception as chunk_error: logger.error(f"Error processing chunk: {chunk_error}") logger.info(f"Problematic metadata: {match.metadata}") continue else: filtered_out += 1 - logger.info(f"Filtered out chunk with score {match.score:.3f} (below threshold)") - + logger.info( + f"Filtered out chunk with score {match.score:.3f} (below threshold)" + ) + # Sort chunks by score - chunks.sort(key=lambda x: x['score'], reverse=True) + chunks.sort(key=lambda x: x["score"], reverse=True) final_chunks = chunks[:top_k] - + logger.info(f"Returning {len(final_chunks)} total chunks") if final_chunks: - logger.info(f"Final score range: {min(c['score'] for c in final_chunks):.3f} - {max(c['score'] for c in final_chunks):.3f}") - + logger.info( + f"Final score range: {min(c['score'] for c in final_chunks):.3f} - {max(c['score'] for c in final_chunks):.3f}" + ) + # Log sample content from top chunk top_chunk = final_chunks[0] logger.info("Top chunk preview:") @@ -283,102 +326,112 @@ async def search( logger.info(f"Document ID: {top_chunk['document_id']}") logger.info(f"Title: {top_chunk['title']}") logger.info(f"Content: {top_chunk['content'][:200]}...") - + return final_chunks - + except Exception as e: logger.error(f"Failed to search in Pinecone: {e}") - logger.error("Search parameters:", extra={ - 'query': query, - 'knowledge_base_id': self.knowledge_base_id, - 'limit': top_k, - 'threshold': similarity_threshold, - 'filter': metadata_filter - }) + logger.error( + "Search parameters:", + extra={ + "query": query, + "knowledge_base_id": self.knowledge_base_id, + "limit": top_k, + "threshold": similarity_threshold, + "filter": metadata_filter, + }, + ) raise - + async def get_random_chunks(self, limit: int = 5) -> List[Dict[str, Any]]: """ Get random chunks from Pinecone. - + Args: limit: Maximum number of chunks to return - + Returns: List of dictionaries containing chunk data """ try: - logger.info(f"Fetching random chunks from knowledge base {self.knowledge_base_id}") - + logger.info( + f"Fetching random chunks from knowledge base {self.knowledge_base_id}" + ) + # Fetch a larger sample to select random chunks from sample_size = min(limit * 5, 100) # Get more than we need, but not too many - + # Create a random vector to fetch diverse results random_vector = [random.uniform(-1, 1) for _ in range(self.dimension)] - + # Query with the random vector results = self.index.query( vector=random_vector, top_k=sample_size, include_metadata=True, - namespace=self.knowledge_base_id + namespace=self.knowledge_base_id, ) - + if not results.matches: - logger.warning(f"No chunks found in knowledge base {self.knowledge_base_id}") + logger.warning( + f"No chunks found in knowledge base {self.knowledge_base_id}" + ) return [] - + # Shuffle the results to randomize further matches = list(results.matches) random.shuffle(matches) - + # Take only the requested number of chunks selected_matches = matches[:limit] - + # Process the selected matches chunks = [] for match in selected_matches: metadata = match.metadata or {} - + chunk = { - 'id': match.id, - 'document_id': str(metadata.get('document_id', '')), - 'title': str(metadata.get('doc_title', 'Untitled')), - 'content': str(metadata.get('content', '')), - 'chunk_index': int(metadata.get('chunk_index', 0)), - 'metadata': metadata + "id": match.id, + "document_id": str(metadata.get("document_id", "")), + "title": str(metadata.get("doc_title", "Untitled")), + "content": str(metadata.get("content", "")), + "chunk_index": int(metadata.get("chunk_index", 0)), + "metadata": metadata, } - + chunks.append(chunk) - - logger.info(f"Retrieved {len(chunks)} random chunks from knowledge base {self.knowledge_base_id}") + + logger.info( + f"Retrieved {len(chunks)} random chunks from knowledge base {self.knowledge_base_id}" + ) return chunks - + except Exception as e: logger.error(f"Failed to get random chunks: {e}", exc_info=True) return [] - + async def _get_embedding(self, text: str) -> List[float]: """ Get embedding for text using the LLM Factory. - + Args: text: The text to embed - + Returns: List of floats representing the embedding """ try: - logger.info(f"Generating embedding using LLM Factory with model: {settings.EMBEDDING_MODEL}") - + logger.info( + f"Generating embedding using LLM Factory with model: {settings.EMBEDDING_MODEL}" + ) + # Use the LLM Factory for text embeddings embedding = await LLMFactory.embed_text( - text=text, - model=settings.EMBEDDING_MODEL + text=text, model=settings.EMBEDDING_MODEL ) - + return embedding - + except Exception as e: logger.error(f"Failed to get embedding: {e}", exc_info=True) - raise \ No newline at end of file + raise diff --git a/app/services/rag/retriever/retriever.py b/app/services/rag/retriever/retriever.py index cbdd8a0..75b59c2 100644 --- a/app/services/rag/retriever/retriever.py +++ b/app/services/rag/retriever/retriever.py @@ -1,60 +1,59 @@ from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional + class Retriever(ABC): """ Abstract base class for retrievers that handle vector storage and retrieval. Each retriever is associated with a specific knowledge base. """ - + def __init__(self, knowledge_base_id: str): """ Initialize the retriever with a knowledge base ID. - + Args: knowledge_base_id: The ID of the knowledge base this retriever will work with """ self.knowledge_base_id = knowledge_base_id - + @abstractmethod async def add_chunks(self, chunks: List[Dict[str, Any]]) -> None: """ Add document chunks to the vector store. - + Args: chunks: List of dictionaries containing: - content: str - metadata: Dict containing document_id, chunk_index, metadata, etc. """ - pass - + @abstractmethod async def delete_document_chunks(self, document_id: str) -> None: """ Delete all chunks for a document from the vector store. - + Args: document_id: ID of the document to delete """ - pass - + @abstractmethod async def search( self, query: str, top_k: int = 5, similarity_threshold: float = 0.3, - metadata_filter: Optional[Dict[str, Any]] = None + metadata_filter: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ Search for similar chunks in the vector store. - + Args: query: The search query top_k: Maximum number of results to return similarity_threshold: Minimum similarity score for results metadata_filter: Optional filter to apply to the search - + Returns: List of dictionaries containing: - id: str @@ -63,17 +62,15 @@ async def search( - document_id: str - metadata: Dict """ - pass - + @abstractmethod async def get_random_chunks(self, limit: int = 5) -> List[Dict[str, Any]]: """ Get random chunks from the vector store. - + Args: limit: Maximum number of chunks to return - + Returns: List of dictionaries containing chunk data """ - pass \ No newline at end of file diff --git a/app/services/rag/retriever/retriever_factory.py b/app/services/rag/retriever/retriever_factory.py index b4ed158..b173cec 100644 --- a/app/services/rag/retriever/retriever_factory.py +++ b/app/services/rag/retriever/retriever_factory.py @@ -1,49 +1,63 @@ -from typing import Optional import logging -from app.services.rag.retriever.retriever import Retriever -from app.services.rag.retriever.pinecone_retriever import PineconeRetriever +from typing import Optional + from app.core.config import settings +from app.services.rag.retriever.pinecone_retriever import PineconeRetriever +from app.services.rag.retriever.retriever import Retriever logger = logging.getLogger(__name__) + class RetrieverFactory: """ Factory for creating retrievers based on configuration. """ - + @staticmethod - def create_retriever(knowledge_base_id: str, retriever_type: Optional[str] = None) -> Retriever: + def create_retriever( + knowledge_base_id: str, retriever_type: Optional[str] = None + ) -> Retriever: """ Create a retriever based on configuration. - + Args: knowledge_base_id: ID of the knowledge base to retrieve from retriever_type: Type of retriever to create (defaults to configured type) - + Returns: Retriever instance - + Raises: ValueError: If no retriever is available for the specified type """ try: # If no type specified, use the configured type if retriever_type is None: - retriever_type = settings.RETRIEVER_TYPE if hasattr(settings, 'RETRIEVER_TYPE') else "pinecone" - - logger.info(f"Creating retriever of type '{retriever_type}' for knowledge base {knowledge_base_id}") - + retriever_type = ( + settings.RETRIEVER_TYPE + if hasattr(settings, "RETRIEVER_TYPE") + else "pinecone" + ) + + logger.info( + f"Creating retriever of type '{retriever_type}' for knowledge base {knowledge_base_id}" + ) + # Create retriever based on type if retriever_type.lower() == "pinecone": - logger.info(f"Creating PineconeRetriever for knowledge base {knowledge_base_id}") + logger.info( + f"Creating PineconeRetriever for knowledge base {knowledge_base_id}" + ) return PineconeRetriever(knowledge_base_id) - + # Add more retriever types here as needed - + else: - logger.warning(f"Unknown retriever type '{retriever_type}', falling back to PineconeRetriever") + logger.warning( + f"Unknown retriever type '{retriever_type}', falling back to PineconeRetriever" + ) return PineconeRetriever(knowledge_base_id) - + except Exception as e: logger.error(f"Failed to create retriever: {e}", exc_info=True) - raise \ No newline at end of file + raise diff --git a/app/services/rag/vector_store.py b/app/services/rag/vector_store.py index 7d207bb..f1470a9 100644 --- a/app/services/rag/vector_store.py +++ b/app/services/rag/vector_store.py @@ -1,32 +1,42 @@ -from typing import List, Dict, Any, Optional -from pinecone import Pinecone -from app.core.config import settings import logging import random -from abc import ABC, abstractmethod import threading +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from pinecone import Pinecone + +from app.core.config import settings from app.services.llm.factory import LLMFactory logger = logging.getLogger(__name__) + class VectorStore(ABC): """Abstract base class for vector stores""" - + @abstractmethod - async def add_chunks(self, chunks: List[Dict[str, Any]], knowledge_base_id: str) -> None: + async def add_chunks( + self, chunks: List[Dict[str, Any]], knowledge_base_id: str + ) -> None: """Add document chunks to the vector store""" - pass - + @abstractmethod - async def add_texts(self, texts: List[str], metadatas: List[Dict], ids: List[str], collection_name: str) -> None: + async def add_texts( + self, + texts: List[str], + metadatas: List[Dict], + ids: List[str], + collection_name: str, + ) -> None: """Add texts with metadata to the vector store""" - pass - + @abstractmethod - async def delete_document_chunks(self, document_id: str, knowledge_base_id: str) -> None: + async def delete_document_chunks( + self, document_id: str, knowledge_base_id: str + ) -> None: """Delete all chunks for a document from the vector store""" - pass - + @abstractmethod async def search_similar( self, @@ -34,44 +44,43 @@ async def search_similar( knowledge_base_id: str, limit: int = 5, similarity_threshold: float = 0.3, - metadata_filter: Dict[str, Any] = None + metadata_filter: Dict[str, Any] = None, ) -> List[Dict[str, Any]]: """Search for similar chunks in the vector store""" - pass - + @abstractmethod - async def get_random_chunks(self, knowledge_base_id: str, limit: int = 5) -> List[Dict]: + async def get_random_chunks( + self, knowledge_base_id: str, limit: int = 5 + ) -> List[Dict]: """Get random chunks from a knowledge base""" - pass - + @abstractmethod async def search_chunks( - self, - query: str, - knowledge_base_id: str, - top_k: int = 5, + self, + query: str, + knowledge_base_id: str, + top_k: int = 5, metadata_filter: Optional[Dict] = None, - similarity_threshold: float = 0.3 + similarity_threshold: float = 0.3, ) -> List[Dict]: """Search for chunks based on a query with metadata filtering""" - pass class PineconeVectorStore(VectorStore): """Pinecone implementation of the VectorStore interface""" - + def __init__(self, index_name: str = "docbrain"): """Initialize PineconeVectorStore with specific index name - + Args: index_name: Name of the Pinecone index to use ('docbrain' or 'summary') """ # Initialize Pinecone client self.pc = Pinecone(api_key=settings.PINECONE_API_KEY) - + # Set the index name self.index_name = index_name - + # Get the index try: self.index = self.pc.Index(self.index_name) @@ -79,10 +88,10 @@ def __init__(self, index_name: str = "docbrain"): except Exception as e: logger.error(f"Failed to initialize Pinecone index {self.index_name}: {e}") raise - + # Vector dimension from text-embedding-004 self.dimension = 768 - + # Connection status self._connected = True @@ -92,15 +101,17 @@ def cleanup(self): # cleanup logic here if needed in the future self._connected = False logger.info(f"Cleaned up PineconeVectorStore for index {self.index_name}") - + def __del__(self): """Destructor to ensure resources are cleaned up""" - if hasattr(self, '_connected') and self._connected: + if hasattr(self, "_connected") and self._connected: self.cleanup() - async def add_chunks(self, chunks: List[Dict[str, Any]], knowledge_base_id: str) -> None: + async def add_chunks( + self, chunks: List[Dict[str, Any]], knowledge_base_id: str + ) -> None: """Add document chunks to Pinecone - + Args: chunks: List of dictionaries containing: - content: str @@ -108,50 +119,62 @@ async def add_chunks(self, chunks: List[Dict[str, Any]], knowledge_base_id: str) knowledge_base_id: ID of the knowledge base (used as namespace) """ try: - logger.info(f"Starting to process {len(chunks)} chunks for knowledge base {knowledge_base_id} in index {self.index_name}") - + logger.info( + f"Starting to process {len(chunks)} chunks for knowledge base {knowledge_base_id} in index {self.index_name}" + ) + # Generate embeddings for chunks vectors = [] for i, chunk in enumerate(chunks): # Get embedding - document_id = str(chunk['metadata']['document_id']) - logger.info(f"Generating embedding for chunk {i+1}/{len(chunks)} (doc_id: {document_id}) using LLM Factory") - embedding = await self._get_embedding(chunk['content']) + document_id = str(chunk["metadata"]["document_id"]) + logger.info( + f"Generating embedding for chunk {i+1}/{len(chunks)} (doc_id: {document_id}) using LLM Factory" + ) + embedding = await self._get_embedding(chunk["content"]) logger.info(f"Generated embedding with dimension {len(embedding)}") - + # Store content and metadata separately for Pinecone metadata = { - 'document_id': document_id, - 'chunk_index': int(chunk['metadata']['chunk_index']), - 'chunk_size': str(chunk['metadata']['chunk_size']), - 'doc_title': str(chunk['metadata']['document_title']), - 'doc_type': str(chunk['metadata']['document_type']), - 'section': str(chunk['metadata']['nearest_header']), - 'path': ','.join(str(x) for x in chunk['metadata']['section_path']), - 'content': str(chunk['content']) + "document_id": document_id, + "chunk_index": int(chunk["metadata"]["chunk_index"]), + "chunk_size": str(chunk["metadata"]["chunk_size"]), + "doc_title": str(chunk["metadata"]["document_title"]), + "doc_type": str(chunk["metadata"]["document_type"]), + "section": str(chunk["metadata"]["nearest_header"]), + "path": ",".join(str(x) for x in chunk["metadata"]["section_path"]), + "content": str(chunk["content"]), } - + # Log metadata structure for infoging logger.info(f"Input chunk metadata structure: {chunk['metadata']}") logger.info(f"Processed metadata for Pinecone: {metadata}") - + # Create vector record with unique ID - vector_id = f"{document_id}_{metadata['chunk_index']}_{metadata['chunk_size']}" - vectors.append({ - 'id': vector_id, - 'values': [float(x) for x in embedding], - 'metadata': metadata - }) - logger.info(f"Created vector record for chunk {i+1} with id {vector_id}") - + vector_id = ( + f"{document_id}_{metadata['chunk_index']}_{metadata['chunk_size']}" + ) + vectors.append( + { + "id": vector_id, + "values": [float(x) for x in embedding], + "metadata": metadata, + } + ) + logger.info( + f"Created vector record for chunk {i+1} with id {vector_id}" + ) + # Upsert vectors in batches of 100 batch_size = 100 total_batches = (len(vectors) + batch_size - 1) // batch_size - + for i in range(0, len(vectors), batch_size): - batch = vectors[i:i + batch_size] + batch = vectors[i : i + batch_size] batch_num = (i // batch_size) + 1 - logger.info(f"Upserting batch {batch_num}/{total_batches} ({len(batch)} vectors)") + logger.info( + f"Upserting batch {batch_num}/{total_batches} ({len(batch)} vectors)" + ) try: # Use knowledge_base_id as namespace self.index.upsert(vectors=batch, namespace=knowledge_base_id) @@ -162,16 +185,24 @@ async def add_chunks(self, chunks: List[Dict[str, Any]], knowledge_base_id: str) if batch: logger.info(f"Sample vector from failing batch: {batch[0]}") raise - - logger.info(f"Successfully added {len(chunks)} chunks to Pinecone for knowledge base {knowledge_base_id}") - + + logger.info( + f"Successfully added {len(chunks)} chunks to Pinecone for knowledge base {knowledge_base_id}" + ) + except Exception as e: logger.error(f"Failed to add chunks to Pinecone: {e}", exc_info=True) raise - async def add_texts(self, texts: List[str], metadatas: List[Dict], ids: List[str], collection_name: str) -> None: + async def add_texts( + self, + texts: List[str], + metadatas: List[Dict], + ids: List[str], + collection_name: str, + ) -> None: """Add texts with metadata to Pinecone - + Args: texts: List of text content metadatas: List of metadata dictionaries @@ -180,90 +211,112 @@ async def add_texts(self, texts: List[str], metadatas: List[Dict], ids: List[str """ try: logger.info(f"Adding {len(texts)} texts to collection {collection_name}") - + # Convert to the format expected by add_chunks chunks = [] for i, (text, metadata, id) in enumerate(zip(texts, metadatas, ids)): chunk = { - 'content': text, - 'metadata': { - 'document_id': id, - 'chunk_index': i, - 'chunk_size': len(text), - 'document_title': metadata.get('title', ''), - 'document_type': metadata.get('type', ''), - 'nearest_header': metadata.get('section', ''), - 'section_path': metadata.get('path', '').split(',') if 'path' in metadata else [], - **metadata # Include other metadata fields - } + "content": text, + "metadata": { + "document_id": id, + "chunk_index": i, + "chunk_size": len(text), + "document_title": metadata.get("title", ""), + "document_type": metadata.get("type", ""), + "nearest_header": metadata.get("section", ""), + "section_path": ( + metadata.get("path", "").split(",") + if "path" in metadata + else [] + ), + **metadata, # Include other metadata fields + }, } chunks.append(chunk) - + # Use collection_name as knowledge_base_id await self.add_chunks(chunks, collection_name) - + except Exception as e: logger.error(f"Failed to add texts to Pinecone: {e}", exc_info=True) raise - async def delete_document_chunks(self, document_id: str, knowledge_base_id: str) -> None: + async def delete_document_chunks( + self, document_id: str, knowledge_base_id: str + ) -> None: """Delete all chunks for a document from Pinecone - + Args: document_id: ID of the document to delete knowledge_base_id: ID of the knowledge base (used as namespace) """ try: - logger.info(f"Attempting to delete chunks for document {document_id} in knowledge base {knowledge_base_id} in index {self.index_name}") - + logger.info( + f"Attempting to delete chunks for document {document_id} in knowledge base {knowledge_base_id} in index {self.index_name}" + ) + # First, fetch all vector IDs for this document by querying with the document_id try: # Try to use metadata filtering first (works on Standard and Enterprise tiers) self.index.delete( filter={"document_id": {"$eq": str(document_id)}}, - namespace=knowledge_base_id + namespace=knowledge_base_id, + ) + logger.info( + f"Successfully deleted chunks using metadata filter for document {document_id}" ) - logger.info(f"Successfully deleted chunks using metadata filter for document {document_id}") return except Exception as e: # If metadata filtering fails (Serverless and Starter tiers), use vector IDs - if "Serverless and Starter indexes do not support deleting with metadata filtering" in str(e): - logger.info("Pinecone Serverless/Starter tier detected, switching to ID-based deletion") - + if ( + "Serverless and Starter indexes do not support deleting with metadata filtering" + in str(e) + ): + logger.info( + "Pinecone Serverless/Starter tier detected, switching to ID-based deletion" + ) + # Query to get all vectors for this document # We need to use a dummy vector for the query dummy_vector = [0.0] * self.dimension - + # Query with a high top_k to get all vectors for this document results = self.index.query( vector=dummy_vector, top_k=10000, # Set a high limit to get all vectors include_metadata=True, - namespace=knowledge_base_id + namespace=knowledge_base_id, ) - + # Filter results to only include vectors for this document vector_ids = [] for match in results.matches: - if match.metadata and match.metadata.get('document_id') == str(document_id): + if match.metadata and match.metadata.get("document_id") == str( + document_id + ): vector_ids.append(match.id) - + if vector_ids: - logger.info(f"Found {len(vector_ids)} vectors to delete for document {document_id}") + logger.info( + f"Found {len(vector_ids)} vectors to delete for document {document_id}" + ) # Delete vectors by ID - self.index.delete( - ids=vector_ids, - namespace=knowledge_base_id + self.index.delete(ids=vector_ids, namespace=knowledge_base_id) + logger.info( + f"Successfully deleted {len(vector_ids)} vectors by ID for document {document_id}" ) - logger.info(f"Successfully deleted {len(vector_ids)} vectors by ID for document {document_id}") else: - logger.info(f"No vectors found for document {document_id} in knowledge base {knowledge_base_id}") + logger.info( + f"No vectors found for document {document_id} in knowledge base {knowledge_base_id}" + ) else: # If it's a different error, re-raise it raise - - logger.info(f"Successfully deleted chunks for document {document_id} in knowledge base {knowledge_base_id}") - + + logger.info( + f"Successfully deleted chunks for document {document_id} in knowledge base {knowledge_base_id}" + ) + except Exception as e: logger.error(f"Failed to delete document chunks: {e}", exc_info=True) raise @@ -274,7 +327,7 @@ async def search_similar( knowledge_base_id: str, limit: int = 5, similarity_threshold: float = 0.3, - metadata_filter: Dict[str, Any] = None + metadata_filter: Dict[str, Any] = None, ) -> List[Dict[str, Any]]: """Search for similar chunks in Pinecone""" try: @@ -283,44 +336,44 @@ async def search_similar( logger.info(f"Knowledge Base: {knowledge_base_id}") logger.info(f"Index: {self.index_name}") logger.info(f"Limit: {limit}, Threshold: {similarity_threshold}") - + # Get query embedding using LLM Factory logger.info("Generating query embedding using LLM Factory") query_vector = await self._get_embedding(query) logger.info(f"Generated embedding with dimension {len(query_vector)}") - + # Prepare filter (no need to include knowledge_base_id as it's now a namespace) filter_dict = {} - + # Handle metadata filters if metadata_filter: logger.info(f"Original metadata filter: {metadata_filter}") - + # All metadata fields are already flattened in Pinecone storage # Just add to filter, ensuring all values are strings for key, value in metadata_filter.items(): if isinstance(value, (list, tuple)): # Handle list values (like section_path) by joining with commas - filter_dict[key] = ','.join(str(x) for x in value) + filter_dict[key] = ",".join(str(x) for x in value) else: # Convert all other values to strings filter_dict[key] = str(value) - + logger.info(f"Final Pinecone filter: {filter_dict}") - + logger.info(f"Final Pinecone filter: {filter_dict}") - + # Search in Pinecone with filter and namespace results = self.index.query( vector=query_vector, filter=filter_dict if filter_dict else None, top_k=limit, include_metadata=True, - namespace=knowledge_base_id + namespace=knowledge_base_id, ) - + logger.info(f"Found {len(results.matches)} initial matches") - + # Process matches chunks = [] filtered_out = 0 @@ -329,77 +382,93 @@ async def search_similar( try: # Get metadata safely metadata = match.metadata or {} - + # Check if this is a question result (based on metadata fields) - is_question = 'question_id' in metadata and 'question' in metadata - + is_question = ( + "question_id" in metadata and "question" in metadata + ) + if is_question: # This is a question result, use question-specific fields chunk = { - 'score': float(match.score), - 'content': str(metadata.get('content', '')), - 'question_id': str(metadata.get('question_id', '')), - 'question': str(metadata.get('question', '')), - 'answer': str(metadata.get('answer', '')), - 'answer_type': str(metadata.get('answer_type', 'DIRECT')), - 'metadata': { - 'question_id': str(metadata.get('question_id', '')), - 'knowledge_base_id': str(metadata.get('knowledge_base_id', '')), - 'answer': str(metadata.get('answer', '')), - 'answer_type': str(metadata.get('answer_type', 'DIRECT')), - 'score': float(match.score) - } + "score": float(match.score), + "content": str(metadata.get("content", "")), + "question_id": str(metadata.get("question_id", "")), + "question": str(metadata.get("question", "")), + "answer": str(metadata.get("answer", "")), + "answer_type": str( + metadata.get("answer_type", "DIRECT") + ), + "metadata": { + "question_id": str(metadata.get("question_id", "")), + "knowledge_base_id": str( + metadata.get("knowledge_base_id", "") + ), + "answer": str(metadata.get("answer", "")), + "answer_type": str( + metadata.get("answer_type", "DIRECT") + ), + "score": float(match.score), + }, } else: # This is a document chunk, use document-specific fields chunk = { - 'score': float(match.score), - 'document_id': str(metadata.get('document_id', '')), - 'content': str(metadata.get('content', '')), - 'chunk_index': int(metadata.get('chunk_index', 0)), - 'title': str(metadata.get('doc_title', 'Untitled')), - 'metadata': { - 'document_id': str(metadata.get('document_id', '')), - 'chunk_index': int(metadata.get('chunk_index', 0)), - 'chunk_size': str(metadata.get('chunk_size', '')), - 'doc_title': str(metadata.get('doc_title', '')), - 'doc_type': str(metadata.get('doc_type', '')), - 'section': str(metadata.get('section', '')), - 'path': metadata.get('path', '').split(',') if metadata.get('path') else [] - } + "score": float(match.score), + "document_id": str(metadata.get("document_id", "")), + "content": str(metadata.get("content", "")), + "chunk_index": int(metadata.get("chunk_index", 0)), + "title": str(metadata.get("doc_title", "Untitled")), + "metadata": { + "document_id": str(metadata.get("document_id", "")), + "chunk_index": int(metadata.get("chunk_index", 0)), + "chunk_size": str(metadata.get("chunk_size", "")), + "doc_title": str(metadata.get("doc_title", "")), + "doc_type": str(metadata.get("doc_type", "")), + "section": str(metadata.get("section", "")), + "path": ( + metadata.get("path", "").split(",") + if metadata.get("path") + else [] + ), + }, } - + # Only skip if absolutely necessary - if not chunk['content']: - logger.warning(f"Skipping chunk with empty content") + if not chunk["content"]: + logger.warning("Skipping chunk with empty content") continue - + chunks.append(chunk) logger.info(f"Included chunk with score {match.score:.3f}") - + except Exception as chunk_error: logger.error(f"Error processing chunk: {chunk_error}") logger.info(f"Problematic metadata: {match.metadata}") continue else: filtered_out += 1 - logger.info(f"Filtered out chunk with score {match.score:.3f} (below threshold)") - + logger.info( + f"Filtered out chunk with score {match.score:.3f} (below threshold)" + ) + # Sort chunks by score - chunks.sort(key=lambda x: x['score'], reverse=True) + chunks.sort(key=lambda x: x["score"], reverse=True) final_chunks = chunks[:limit] - + logger.info(f"Returning {len(final_chunks)} total chunks") if final_chunks: - logger.info(f"Final score range: {min(c['score'] for c in final_chunks):.3f} - {max(c['score'] for c in final_chunks):.3f}") - + logger.info( + f"Final score range: {min(c['score'] for c in final_chunks):.3f} - {max(c['score'] for c in final_chunks):.3f}" + ) + # Log sample content from top chunk top_chunk = final_chunks[0] logger.info("Top chunk preview:") logger.info(f"Score: {top_chunk['score']:.3f}") - + # Check if this is a question result - if 'question_id' in top_chunk: + if "question_id" in top_chunk: # Question result logger.info(f"Question ID: {top_chunk['question_id']}") logger.info(f"Question: {top_chunk['question']}") @@ -409,191 +478,209 @@ async def search_similar( logger.info(f"Document ID: {top_chunk['document_id']}") logger.info(f"Title: {top_chunk['title']}") logger.info(f"Content: {top_chunk['content'][:200]}...") - + return final_chunks - + except Exception as e: logger.error(f"Failed to search in Pinecone: {e}") - logger.error("Search parameters:", extra={ - 'query': query, - 'knowledge_base_id': knowledge_base_id, - 'limit': limit, - 'threshold': similarity_threshold, - 'filter': metadata_filter - }) + logger.error( + "Search parameters:", + extra={ + "query": query, + "knowledge_base_id": knowledge_base_id, + "limit": limit, + "threshold": similarity_threshold, + "filter": metadata_filter, + }, + ) raise - + async def _get_embedding(self, text: str) -> List[float]: """ Get embedding for text using the centralized LLM Factory. - + Args: text: The text to embed - + Returns: Embedding as a list of floats """ try: # Use the LLM Factory for text embeddings (using model from settings) embedding = await LLMFactory.embed_text( - text=text, - model=settings.EMBEDDING_MODEL # Use the model from settings + text=text, model=settings.EMBEDDING_MODEL # Use the model from settings ) - + # Verify dimension if len(embedding) != self.dimension: - raise ValueError(f"Expected embedding dimension {self.dimension}, got {len(embedding)}") - + raise ValueError( + f"Expected embedding dimension {self.dimension}, got {len(embedding)}" + ) + logger.info(f"Generated embedding with dimension {len(embedding)}") return embedding - + except Exception as e: logger.error(f"Failed to get embedding: {e}", exc_info=True) raise - async def get_random_chunks(self, knowledge_base_id: str, limit: int = 5) -> List[Dict]: + async def get_random_chunks( + self, knowledge_base_id: str, limit: int = 5 + ) -> List[Dict]: """ Get random chunks from a knowledge base. - + Args: knowledge_base_id: The ID of the knowledge base to fetch from (used as namespace) limit: The number of chunks to retrieve - + Returns: List of random chunks with content and metadata """ try: - logger.info(f"Fetching random chunks from knowledge base {knowledge_base_id} in index {self.index_name}") - + logger.info( + f"Fetching random chunks from knowledge base {knowledge_base_id} in index {self.index_name}" + ) + # Fetch all document IDs in this knowledge base # We need to use a dummy query to get all vectors in the namespace dummy_vector = [0.0] * self.dimension - + # Query with a high top_k to get a good sample response = self.index.query( vector=dummy_vector, top_k=1000, # Get a large sample to choose from include_metadata=True, - namespace=knowledge_base_id + namespace=knowledge_base_id, ) - + if not response.matches: logger.info(f"No chunks found in knowledge base {knowledge_base_id}") return [] - + # Get all unique document IDs doc_ids = set() for match in response.matches: - if match.metadata and 'document_id' in match.metadata: - doc_ids.add(match.metadata['document_id']) - - logger.info(f"Found {len(doc_ids)} unique documents in knowledge base {knowledge_base_id}") - + if match.metadata and "document_id" in match.metadata: + doc_ids.add(match.metadata["document_id"]) + + logger.info( + f"Found {len(doc_ids)} unique documents in knowledge base {knowledge_base_id}" + ) + if not doc_ids: return [] - + # Select random document IDs (up to 5) selected_doc_ids = random.sample(list(doc_ids), min(5, len(doc_ids))) logger.info(f"Selected {len(selected_doc_ids)} random documents") - + # For each selected document, get a random chunk chunks = [] for doc_id in selected_doc_ids: # Get all chunks for this document doc_chunks = [ - match for match in response.matches - if match.metadata and match.metadata.get('document_id') == doc_id + match + for match in response.matches + if match.metadata and match.metadata.get("document_id") == doc_id ] - + if doc_chunks: # Select a random chunk random_chunk = random.choice(doc_chunks) - + # Format the chunk metadata = random_chunk.metadata or {} chunk = { - 'document_id': str(metadata.get('document_id', '')), - 'content': str(metadata.get('content', '')), - 'chunk_index': int(metadata.get('chunk_index', 0)), - 'title': str(metadata.get('doc_title', 'Untitled')), - 'metadata': { - 'document_id': str(metadata.get('document_id', '')), - 'chunk_index': int(metadata.get('chunk_index', 0)), - 'doc_title': str(metadata.get('doc_title', '')), - 'doc_type': str(metadata.get('doc_type', '')), - 'section': str(metadata.get('section', '')), - } + "document_id": str(metadata.get("document_id", "")), + "content": str(metadata.get("content", "")), + "chunk_index": int(metadata.get("chunk_index", 0)), + "title": str(metadata.get("doc_title", "Untitled")), + "metadata": { + "document_id": str(metadata.get("document_id", "")), + "chunk_index": int(metadata.get("chunk_index", 0)), + "doc_title": str(metadata.get("doc_title", "")), + "doc_type": str(metadata.get("doc_type", "")), + "section": str(metadata.get("section", "")), + }, } chunks.append(chunk) - + # If we don't have enough chunks, get more random ones if len(chunks) < limit and response.matches: remaining = limit - len(chunks) - random_matches = random.sample(response.matches, min(remaining, len(response.matches))) - + random_matches = random.sample( + response.matches, min(remaining, len(response.matches)) + ) + for match in random_matches: # Skip if already included - if any(c['document_id'] == match.metadata.get('document_id', '') and - c['chunk_index'] == int(match.metadata.get('chunk_index', 0)) - for c in chunks): + if any( + c["document_id"] == match.metadata.get("document_id", "") + and c["chunk_index"] + == int(match.metadata.get("chunk_index", 0)) + for c in chunks + ): continue - + metadata = match.metadata or {} chunk = { - 'document_id': str(metadata.get('document_id', '')), - 'content': str(metadata.get('content', '')), - 'chunk_index': int(metadata.get('chunk_index', 0)), - 'title': str(metadata.get('doc_title', 'Untitled')), - 'metadata': { - 'document_id': str(metadata.get('document_id', '')), - 'chunk_index': int(metadata.get('chunk_index', 0)), - 'doc_title': str(metadata.get('doc_title', '')), - 'doc_type': str(metadata.get('doc_type', '')), - 'section': str(metadata.get('section', '')), - } + "document_id": str(metadata.get("document_id", "")), + "content": str(metadata.get("content", "")), + "chunk_index": int(metadata.get("chunk_index", 0)), + "title": str(metadata.get("doc_title", "Untitled")), + "metadata": { + "document_id": str(metadata.get("document_id", "")), + "chunk_index": int(metadata.get("chunk_index", 0)), + "doc_title": str(metadata.get("doc_title", "")), + "doc_type": str(metadata.get("doc_type", "")), + "section": str(metadata.get("section", "")), + }, } chunks.append(chunk) - + if len(chunks) >= limit: break - + logger.info(f"Returning {len(chunks)} random chunks") return chunks[:limit] - + except Exception as e: logger.error(f"Error getting random chunks: {e}", exc_info=True) return [] async def search_chunks( - self, - query: str, - knowledge_base_id: str, - top_k: int = 5, + self, + query: str, + knowledge_base_id: str, + top_k: int = 5, metadata_filter: Optional[Dict] = None, - similarity_threshold: float = 0.3 + similarity_threshold: float = 0.3, ) -> List[Dict]: """ Search for chunks based on a query with metadata filtering. - + Args: query: The query to process knowledge_base_id: The ID of the knowledge base to search (used as namespace) top_k: The number of chunks to retrieve metadata_filter: Optional additional metadata filter to apply similarity_threshold: Minimum similarity score for chunks to be included (default: 0.3) - + Returns: List of chunks with content and metadata """ try: # Get embedding for the query using LLM Factory - logger.info(f"Generating embedding for query using LLM Factory: '{query[:50]}...' (truncated)") + logger.info( + f"Generating embedding for query using LLM Factory: '{query[:50]}...' (truncated)" + ) embedding = await self._get_embedding(query) logger.info(f"Generated embedding with dimension {len(embedding)}") - + # Create filter (no need to include knowledge_base_id as it's now a namespace) filter = {} - + # Add additional metadata filters if provided if metadata_filter: logger.info(f"Applying metadata filter: {metadata_filter}") @@ -601,35 +688,37 @@ async def search_chunks( if key == "similarity_threshold": # Skip this key as it's not a metadata filter continue - + if isinstance(value, dict): # If value is a dict, it's already in Pinecone filter format # Special handling for $in operator with document_id if key == "document_id" and "$in" in value: # Ensure all document IDs are strings value["$in"] = [str(doc_id) for doc_id in value["$in"]] - logger.info(f"Formatted document_id $in filter with {len(value['$in'])} IDs") - + logger.info( + f"Formatted document_id $in filter with {len(value['$in'])} IDs" + ) + filter[key] = value else: # Otherwise, create an equality filter filter[key] = {"$eq": str(value)} - + logger.info(f"Final Pinecone filter: {filter}") - + # Query Pinecone with namespace response = self.index.query( vector=embedding, filter=filter if filter else None, top_k=top_k * 2, # Get more results than needed to allow for filtering include_metadata=True, - namespace=knowledge_base_id + namespace=knowledge_base_id, ) - + if not response.matches: logger.info(f"No chunks found for query: '{query}'") return [] - + # Convert to list of dictionaries with content and metadata chunks = [] filtered_out = 0 @@ -639,30 +728,42 @@ async def search_chunks( chunk = { "id": match.id, "content": match.metadata.get("content", ""), - "document_id": match.metadata.get("document_id", "") or match.metadata.get("doc_id", ""), + "document_id": match.metadata.get("document_id", "") + or match.metadata.get("doc_id", ""), "metadata": { "doc_title": match.metadata.get("doc_title", ""), "doc_id": match.metadata.get("doc_id", ""), - "document_id": match.metadata.get("document_id", "") or match.metadata.get("doc_id", ""), + "document_id": match.metadata.get("document_id", "") + or match.metadata.get("doc_id", ""), "chunk_id": match.metadata.get("chunk_id", ""), - "knowledge_base_id": match.metadata.get("knowledge_base_id", "") + "knowledge_base_id": match.metadata.get( + "knowledge_base_id", "" + ), }, - "score": match.score + "score": match.score, } chunks.append(chunk) else: filtered_out += 1 - - logger.info(f"Found {len(chunks)} chunks above similarity threshold {similarity_threshold} (filtered out {filtered_out})") + + logger.info( + f"Found {len(chunks)} chunks above similarity threshold {similarity_threshold} (filtered out {filtered_out})" + ) return chunks - + except Exception as e: logger.error(f"Error searching chunks: {e}", exc_info=True) return [] - async def add_questions(self, texts: List[str], metadatas: List[Dict], ids: List[str], collection_name: str) -> None: + async def add_questions( + self, + texts: List[str], + metadatas: List[Dict], + ids: List[str], + collection_name: str, + ) -> None: """Add questions with metadata to Pinecone - specialized for question ingestion - + Args: texts: List of question + answer text content metadatas: List of metadata dictionaries with question-specific fields @@ -670,46 +771,56 @@ async def add_questions(self, texts: List[str], metadatas: List[Dict], ids: List collection_name: Name of collection (used as namespace) """ try: - logger.info(f"Adding {len(texts)} questions to collection {collection_name}") - + logger.info( + f"Adding {len(texts)} questions to collection {collection_name}" + ) + # Process each question for storage vectors = [] for i, (text, metadata, id) in enumerate(zip(texts, metadatas, ids)): # Get embedding - logger.info(f"Generating embedding for question {i+1}/{len(texts)} (id: {id})") + logger.info( + f"Generating embedding for question {i+1}/{len(texts)} (id: {id})" + ) embedding = await self._get_embedding(text) logger.info(f"Generated embedding with dimension {len(embedding)}") - + # Clean up metadata - only include question-specific fields # No document-specific fields like chunk_index, doc_title, etc. pinecone_metadata = { - 'content': text, - 'question_id': metadata.get('question_id', ''), - 'knowledge_base_id': metadata.get('knowledge_base_id', ''), - 'answer_type': metadata.get('answer_type', ''), - 'question': metadata.get('question', ''), - 'answer': metadata.get('answer', ''), - 'user_id': metadata.get('user_id', '') + "content": text, + "question_id": metadata.get("question_id", ""), + "knowledge_base_id": metadata.get("knowledge_base_id", ""), + "answer_type": metadata.get("answer_type", ""), + "question": metadata.get("question", ""), + "answer": metadata.get("answer", ""), + "user_id": metadata.get("user_id", ""), } - - logger.info(f"Prepared metadata for question: {pinecone_metadata.keys()}") - + + logger.info( + f"Prepared metadata for question: {pinecone_metadata.keys()}" + ) + # Create vector record with unique ID - vectors.append({ - 'id': id, - 'values': [float(x) for x in embedding], - 'metadata': pinecone_metadata - }) - + vectors.append( + { + "id": id, + "values": [float(x) for x in embedding], + "metadata": pinecone_metadata, + } + ) + # Upsert vectors in batches of 100 batch_size = 100 total_batches = (len(vectors) + batch_size - 1) // batch_size - + for i in range(0, len(vectors), batch_size): - batch = vectors[i:i + batch_size] + batch = vectors[i : i + batch_size] batch_num = (i // batch_size) + 1 - logger.info(f"Upserting batch {batch_num}/{total_batches} ({len(batch)} vectors)") - + logger.info( + f"Upserting batch {batch_num}/{total_batches} ({len(batch)} vectors)" + ) + try: # Use collection_name as namespace self.index.upsert(vectors=batch, namespace=collection_name) @@ -720,9 +831,11 @@ async def add_questions(self, texts: List[str], metadatas: List[Dict], ids: List if batch: logger.info(f"Sample vector from failing batch: {batch[0]}") raise - - logger.info(f"Successfully added {len(texts)} questions to Pinecone for collection {collection_name}") - + + logger.info( + f"Successfully added {len(texts)} questions to Pinecone for collection {collection_name}" + ) + except Exception as e: logger.error(f"Failed to add questions to Pinecone: {e}", exc_info=True) raise @@ -730,38 +843,44 @@ async def add_questions(self, texts: List[str], metadatas: List[Dict], ids: List class VectorStoreFactory: """Factory class for creating VectorStore instances""" - + # Class-level registry to store instances _instances = {} # Lock for thread safety _lock = threading.RLock() - + @classmethod - def create(cls, store_type: str = "pinecone", index_name: str = "docbrain") -> VectorStore: + def create( + cls, store_type: str = "pinecone", index_name: str = "docbrain" + ) -> VectorStore: """Create or retrieve a VectorStore instance - + Args: store_type: Type of vector store to create ('pinecone', 'weaviate', 'chroma', etc.) index_name: Name of the Pinecone index to use ('docbrain' or 'summary') - + Returns: VectorStore: An instance of a VectorStore implementation """ # Create a key for the instance registry instance_key = f"{store_type}_{index_name}" - + # First check without lock for performance if instance_key in cls._instances: - logger.debug(f"Returning existing {store_type} instance for index {index_name}") + logger.debug( + f"Returning existing {store_type} instance for index {index_name}" + ) return cls._instances[instance_key] - + # If not found, acquire lock and check again (double-checked locking pattern) with cls._lock: # Check again with lock held if instance_key in cls._instances: - logger.debug(f"Returning existing {store_type} instance for index {index_name} (after lock)") + logger.debug( + f"Returning existing {store_type} instance for index {index_name} (after lock)" + ) return cls._instances[instance_key] - + # Create a new instance logger.info(f"Creating new {store_type} instance for index {index_name}") if store_type == "pinecone": @@ -773,12 +892,12 @@ def create(cls, store_type: str = "pinecone", index_name: str = "docbrain") -> V # For future implementations like weaviate or chroma # We can add them here when needed raise ValueError(f"Unsupported vector store type: {store_type}") - + @classmethod def cleanup(cls, store_type: str = None, index_name: str = None): """ Clean up vector store instances - + Args: store_type: Optional type to clean up only instances of this type index_name: Optional index name to clean up only instances for this index @@ -789,59 +908,63 @@ def cleanup(cls, store_type: str = None, index_name: str = None): instance_key = f"{store_type}_{index_name}" if instance_key in cls._instances: instance = cls._instances[instance_key] - if hasattr(instance, 'cleanup'): + if hasattr(instance, "cleanup"): instance.cleanup() del cls._instances[instance_key] - logger.info(f"Cleaned up {store_type} instance for index {index_name}") + logger.info( + f"Cleaned up {store_type} instance for index {index_name}" + ) elif store_type: # Clean up all instances of this type keys_to_remove = [] for key, instance in cls._instances.items(): if key.startswith(f"{store_type}_"): - if hasattr(instance, 'cleanup'): + if hasattr(instance, "cleanup"): instance.cleanup() keys_to_remove.append(key) - + for key in keys_to_remove: del cls._instances[key] - + logger.info(f"Cleaned up all {store_type} instances") elif index_name: # Clean up all instances with this index name keys_to_remove = [] for key, instance in cls._instances.items(): if key.endswith(f"_{index_name}"): - if hasattr(instance, 'cleanup'): + if hasattr(instance, "cleanup"): instance.cleanup() keys_to_remove.append(key) - + for key in keys_to_remove: del cls._instances[key] - + logger.info(f"Cleaned up all instances for index {index_name}") else: # Clean up all instances for instance in cls._instances.values(): - if hasattr(instance, 'cleanup'): + if hasattr(instance, "cleanup"): instance.cleanup() - + cls._instances.clear() logger.info("Cleaned up all vector store instances") # Helper functions for working with vector stores -def get_vector_store(store_type: str = "pinecone", index_name: str = "docbrain") -> VectorStore: +def get_vector_store( + store_type: str = "pinecone", index_name: str = "docbrain" +) -> VectorStore: """Get a vector store instance - uses singleton pattern - + This function returns a singleton instance for each unique combination of store_type and index_name. Subsequent calls with the same parameters will return the same instance. - + Args: store_type: Type of vector store ('pinecone', 'weaviate', 'chroma', etc.) index_name: Name of the Pinecone index to use ('docbrain' or 'summary') - + Returns: VectorStore instance (singleton per unique combination of parameters) """ - return VectorStoreFactory.create(store_type=store_type, index_name=index_name) \ No newline at end of file + return VectorStoreFactory.create(store_type=store_type, index_name=index_name) diff --git a/app/services/rag_service.py b/app/services/rag_service.py index 63caaa8..bdc38bf 100644 --- a/app/services/rag_service.py +++ b/app/services/rag_service.py @@ -1,20 +1,24 @@ -from typing import Dict, Any, Optional, List import logging +from functools import lru_cache +from typing import Any, Dict, List, Optional + +from app.core.config import settings +from app.core.prompts import get_prompt, register_prompt from app.db.models.knowledge_base import DocumentType +from app.services.llm.factory import CompletionOptions, LLMFactory, Message, Role from app.services.rag.chunker.chunker_factory import ChunkerFactory from app.services.rag.ingestor.ingestor_factory import IngestorFactory from app.services.rag.reranker.reranker_factory import RerankerFactory from app.services.rag.retriever.retriever_factory import RetrieverFactory -from app.services.llm.factory import LLMFactory, Message, Role, CompletionOptions -from app.core.prompts import get_prompt, register_prompt -from app.core.config import settings from app.services.rag.vector_store import get_vector_store -from functools import lru_cache logger = logging.getLogger(__name__) # Register rag_service prompts -register_prompt("rag_service", "generate_answer", """ +register_prompt( + "rag_service", + "generate_answer", + """ Based on the given sources, please answer the question. If you cannot find a relevant answer in the sources, please say so. @@ -28,18 +32,26 @@ 2. Uses [Source X] notation to cite the sources 3. Only uses information from the provided sources 4. Maintains a professional and helpful tone -""") +""", +) + class RAGService: """ Retrieval-Augmented Generation (RAG) service that combines document ingestion, chunking, retrieval, reranking, and answer generation. """ - - def __init__(self, use_reranker: bool = True, reranker_model: str = "Cohere/rerank-v3.5", llm_model: str = "gemini-2.0-flash", llm_provider: Optional[str] = None): + + def __init__( + self, + use_reranker: bool = True, + reranker_model: str = "Cohere/rerank-v3.5", + llm_model: str = "gemini-2.0-flash", + llm_provider: Optional[str] = None, + ): """ Initialize the RAG service. - + Args: use_reranker: Whether to use reranking reranker_model: Model to use for reranking @@ -47,7 +59,9 @@ def __init__(self, use_reranker: bool = True, reranker_model: str = "Cohere/rera llm_provider: Provider to use for answer generation (defaults to settings.LLM_PROVIDER) """ try: - logger.info(f"Initializing RAG service with model {llm_model} from provider {llm_provider or settings.LLM_PROVIDER}") + logger.info( + f"Initializing RAG service with model {llm_model} from provider {llm_provider or settings.LLM_PROVIDER}" + ) self.llm_model = llm_model self.llm_provider = llm_provider self.use_reranker = use_reranker @@ -56,7 +70,7 @@ def __init__(self, use_reranker: bool = True, reranker_model: str = "Cohere/rera except Exception as e: logger.error(f"Failed to initialize RAG service: {e}", exc_info=True) raise - + async def ingest_document( self, content: bytes, @@ -65,12 +79,12 @@ async def ingest_document( ) -> Dict[str, Any]: """ Ingest a document into the knowledge base. - + Args: content: Document content as string or bytes metadata: Document metadata content_type: MIME type of the document - + Returns: Dictionary containing: - document_id: ID of the ingested document @@ -78,67 +92,67 @@ async def ingest_document( """ try: logger.info(f"Ingesting document of type {content_type}") - + # Create ingestor ingestor = IngestorFactory.create_ingestor(content_type) - + # Ingest document ingestion_result = await ingestor.ingest(content, metadata) - + # Extract text and enhanced metadata text = ingestion_result["text"] enhanced_metadata = ingestion_result["metadata"] - + # Create chunker chunker = ChunkerFactory.create_chunker_from_metadata(enhanced_metadata) - + # Chunk document chunks = await chunker.chunk(text, enhanced_metadata) - + # Use knowledge_base_id from metadata to create retriever kb_id = metadata.get("knowledge_base_id") retriever = RetrieverFactory.create_retriever(kb_id) await retriever.add_chunks(chunks) - + logger.info(f"Successfully ingested document with {len(chunks)} chunks") - + return { "document_id": enhanced_metadata.get("document_id", ""), - "chunk_count": len(chunks) + "chunk_count": len(chunks), } - + except Exception as e: logger.error(f"Failed to ingest document: {e}", exc_info=True) raise - + async def delete_document(self, document_id: str, knowledge_base_id: str) -> bool: """ Delete a document from the knowledge base. - + Args: document_id: ID of the document to delete knowledge_base_id: ID of the knowledge base to delete the document from - + Returns: True if successful, False otherwise """ try: logger.info(f"Deleting document {document_id}") - + # Create retriever using the provided knowledge_base_id retriever = RetrieverFactory.create_retriever(knowledge_base_id) - + # Delete document chunks from vector store await retriever.delete_document_chunks(document_id) - + logger.info(f"Successfully deleted document {document_id}") - + return True - + except Exception as e: logger.error(f"Failed to delete document: {e}", exc_info=True) return False - + async def retrieve_from_storage( self, knowledge_base_id: str, @@ -151,7 +165,7 @@ async def retrieve_from_storage( """ try: logger.info(f"Retrieving from storage for query: '{query}'") - + # Create retriever using the provided knowledge_base_id retriever = RetrieverFactory.create_retriever(knowledge_base_id) @@ -162,7 +176,6 @@ async def retrieve_from_storage( except Exception as e: logger.error(f"Failed to retrieve from storage: {e}", exc_info=True) raise - async def retrieve( self, @@ -174,14 +187,14 @@ async def retrieve( ) -> Dict[str, Any]: """ Retrieve relevant chunks and generate an answer. - + Args: knowledge_base_id: ID of the knowledge base to use query: The query to process top_k: Number of chunks to retrieve similarity_threshold: Minimum similarity score for chunks metadata_filter: Optional filter for retrieval - + Returns: Dictionary containing: - answer: Generated answer @@ -189,16 +202,16 @@ async def retrieve( """ try: logger.info(f"Processing query: '{query}'") - + # Create retriever using the provided knowledge_base_id retriever = RetrieverFactory.create_retriever(knowledge_base_id) # Retrieve chunks chunks = await retriever.search( query=query, - top_k=top_k * 2, # always retrive twice to decide whether to rerank + top_k=top_k * 2, # always retrive twice to decide whether to rerank similarity_threshold=similarity_threshold, - metadata_filter=metadata_filter + metadata_filter=metadata_filter, ) logger.info(f"Retrieved {len(chunks)} chunks") @@ -214,15 +227,17 @@ async def retrieve( elif len(chunks) > top_k: # Limit to top_k if not reranking chunks = chunks[:top_k] - + # Generate answer using the LLMFactory directly if chunks: logger.info("Generating answer") answer = await self._generate_answer(query, chunks) else: logger.warning("No chunks found, returning empty answer") - answer = "I couldn't find any relevant information to answer your question." - + answer = ( + "I couldn't find any relevant information to answer your question." + ) + # Format sources sources = [] for chunk in chunks: @@ -231,93 +246,88 @@ async def retrieve( "title": chunk.get("title", "Untitled"), "content": chunk.get("content", ""), "chunk_index": chunk.get("chunk_index", 0), - "score": chunk.get("score", 0.0) + "score": chunk.get("score", 0.0), } sources.append(source) - + logger.info(f"Successfully processed query with {len(sources)} sources") - - return { - "answer": answer, - "sources": sources - } - + + return {"answer": answer, "sources": sources} + except Exception as e: logger.error(f"Failed to process query: {e}", exc_info=True) return { "answer": f"I encountered an error while processing your query: {str(e)}", - "sources": [] + "sources": [], } - + async def _generate_answer(self, query: str, context: List[Dict[str, Any]]) -> str: """ Generate an answer using the LLM Factory. - + Args: query: The user's query context: List of context chunks to use for answering - + Returns: Generated answer as a string """ try: logger.info(f"Generating answer for query: {query}") logger.info(f"Using {len(context)} context chunks") - + # Format context chunks formatted_context = self._format_context(context) - + # Get the prompt from the registry - prompt = get_prompt("rag_service", "generate_answer", - query=query, - context=formatted_context) - + prompt = get_prompt( + "rag_service", "generate_answer", query=query, context=formatted_context + ) + # Create the message for the LLM - messages = [ - Message(role=Role.USER, content=prompt) - ] - + messages = [Message(role=Role.USER, content=prompt)] + # Set completion options options = CompletionOptions( temperature=0.3, # Low temperature for more factual answers - max_tokens=1000 + max_tokens=1000, ) - + # Generate response using LLM Factory response = await LLMFactory.complete( messages=messages, provider=self.llm_provider, model=self.llm_model, - options=options + options=options, ) - + logger.info(f"Generated answer with {len(response.content)} characters") - + return response.content - + except Exception as e: logger.error(f"Failed to generate answer: {e}", exc_info=True) return f"I apologize, but I encountered an error while generating an answer: {str(e)}" - + def _format_context(self, context: List[Dict[str, Any]]) -> str: """ Format context chunks for the prompt. - + Args: context: List of context chunks - + Returns: Formatted context as a string """ formatted_chunks = [] - + for i, chunk in enumerate(context, 1): # Extract metadata - document_id = chunk.get('document_id', 'unknown') - title = chunk.get('title', 'Untitled') - content = chunk.get('content', '') - score = chunk.get('score', 0.0) - + document_id = chunk.get("document_id", "unknown") + title = chunk.get("title", "Untitled") + content = chunk.get("content", "") + score = chunk.get("score", 0.0) + # Format chunk formatted_chunk = ( f"[Source {i}]\n" @@ -325,9 +335,9 @@ def _format_context(self, context: List[Dict[str, Any]]) -> str: f"Relevance: {score:.3f}\n" f"Content: {content}\n" ) - + formatted_chunks.append(formatted_chunk) - + return "\n\n".join(formatted_chunks) async def add_document_summary( @@ -336,30 +346,29 @@ async def add_document_summary( knowledge_base_id: str, document_title: str, document_type: str, - summary: str + summary: str, ) -> bool: """ Add a document summary to the summary index for semantic routing. - + Args: document_id: ID of the document knowledge_base_id: ID of the knowledge base containing the document document_title: Title of the document document_type: Type of the document summary: Generated summary of the document - + Returns: True if successful, False otherwise """ try: logger.info(f"Adding summary for document {document_id} to summary index") - + # Get the summary vector store summary_vector_store = get_vector_store( - store_type="pinecone", - index_name=settings.PINECONE_SUMMARY_INDEX_NAME + store_type="pinecone", index_name=settings.PINECONE_SUMMARY_INDEX_NAME ) - + # Create a single chunk with the summary chunk = { "content": summary, @@ -372,39 +381,43 @@ async def add_document_summary( "chunk_size": len(summary), "nearest_header": "Document Summary", "section_path": ["Document Summary"], - "is_summary": True - } + "is_summary": True, + }, } - + # Add the summary to the summary index # We use "summaries" as a special namespace for all document summaries - await summary_vector_store.add_chunks(chunks=[chunk], knowledge_base_id=knowledge_base_id) - - logger.info(f"Successfully added summary for document {document_id} to summary index") + await summary_vector_store.add_chunks( + chunks=[chunk], knowledge_base_id=knowledge_base_id + ) + + logger.info( + f"Successfully added summary for document {document_id} to summary index" + ) return True - + except Exception as e: logger.error(f"Failed to add document summary to index: {e}", exc_info=True) - return False + return False + # Create a singleton instance of RAGService @lru_cache() def get_rag_service( - llm_model: Optional[str] = None, - llm_provider: Optional[str] = None + llm_model: Optional[str] = None, llm_provider: Optional[str] = None ) -> RAGService: """ Get a singleton instance of RAGService. - + Args: llm_model: Model to use for answer generation (defaults to settings.DEFAULT_LLM_MODEL or provider default) llm_provider: Provider to use for answer generation (defaults to settings.LLM_PROVIDER) - + Returns: RAGService instance """ model = llm_model or settings.DEFAULT_LLM_MODEL or "gemini-2.0-flash" provider = llm_provider or settings.LLM_PROVIDER - + logger.info(f"Creating RAGService with model={model}, provider={provider}") - return RAGService(llm_model=model, llm_provider=provider) \ No newline at end of file + return RAGService(llm_model=model, llm_provider=provider) diff --git a/app/services/tag_service.py b/app/services/tag_service.py index 69e420a..f12f667 100644 --- a/app/services/tag_service.py +++ b/app/services/tag_service.py @@ -1,22 +1,23 @@ -from typing import Dict, Any, List, Optional -import logging import json -from app.core.config import settings -from app.db.database import get_db -from functools import lru_cache +import logging import re -from app.db.models.knowledge_base import DocumentType +from functools import lru_cache +from typing import Any, Dict, List, Optional + +from app.core.prompts import get_prompt, register_prompt +from app.db.database import get_db +from app.db.models.knowledge_base import Document, DocumentType from app.db.storage import get_storage_db from app.repositories.storage_repository import StorageRepository -from app.db.models.knowledge_base import Document -from app.services.llm.factory import LLMFactory, Message, Role, CompletionOptions -from app.core.prompts import get_prompt, register_prompt - +from app.services.llm.factory import CompletionOptions, LLMFactory, Message, Role logger = logging.getLogger(__name__) # Register prompts used by the TAG service -register_prompt("tag_service", "generate_sql", """ +register_prompt( + "tag_service", + "generate_sql", + """ You are an AI assistant that converts natural language questions into SQL queries. I have the following database tables: @@ -28,9 +29,13 @@ Return ONLY a valid SQL query without any explanations or markdown formatting. Make sure the query is compatible with common SQL dialects. Do not use features specific to one SQL dialect unless necessary. -""") +""", +) -register_prompt("tag_service", "generate_answer", """ +register_prompt( + "tag_service", + "generate_answer", + """ You are an AI assistant that explains SQL query results. Original question: "{{ query }}" @@ -49,7 +54,9 @@ 2. Summarize the key findings from the data 3. Be easy to understand for someone without technical SQL knowledge 4. Include specific numbers/values from the results when relevant -""") +""", +) + class TagService: """ @@ -57,11 +64,11 @@ class TagService: This service converts natural language queries to SQL and executes them against structured data stored in a database. """ - + def __init__(self): """Initialize the Tag Service""" logger.info("Initializing TagService") - + async def retrieve( self, knowledge_base_id: str, @@ -70,12 +77,12 @@ async def retrieve( ) -> Dict[str, Any]: """ Process a natural language query against structured data using text-to-SQL. - + Args: knowledge_base_id: The ID of the knowledge base to search query: The natural language query to process metadata_filter: Additional filtering criteria - + Returns: Dictionary containing: - query: The original query @@ -87,10 +94,10 @@ async def retrieve( """ try: logger.info(f"TagService processing query: '{query}'") - + # Get all table schemas from the database table_schemas = await self._get_all_table_schemas() - + if not table_schemas: logger.warning(f"No table schemas found to process query: {query}") return { @@ -99,33 +106,35 @@ async def retrieve( "sql": None, "results": [], "sources": [], - "service": "tag" + "service": "tag", } - + # Get documents from the knowledge base for sources documents = await self._get_knowledge_base_documents(knowledge_base_id) - + # Extract document IDs and titles for sources sources = [] for i, doc in enumerate(documents): if doc.content_type in [DocumentType.CSV, DocumentType.EXCEL]: # Create a source with all required fields for MessageResponse validation - sources.append({ - "document_id": doc.id, - "title": doc.title, - "score": 1.0, # Default high score for TAG sources - "content": f"Table data from {doc.title}", # Provide a description as content - "chunk_index": i, # Use document index as chunk_index - "metadata": { + sources.append( + { "document_id": doc.id, - "knowledge_base_id": knowledge_base_id, - "content_type": doc.content_type + "title": doc.title, + "score": 1.0, # Default high score for TAG sources + "content": f"Table data from {doc.title}", # Provide a description as content + "chunk_index": i, # Use document index as chunk_index + "metadata": { + "document_id": doc.id, + "knowledge_base_id": knowledge_base_id, + "content_type": doc.content_type, + }, } - }) - + ) + # Generate SQL from the natural language query sql_query = await self._generate_sql(query, table_schemas) - + if not sql_query: return { "query": query, @@ -133,24 +142,24 @@ async def retrieve( "sql": None, "results": [], "sources": sources, - "service": "tag" + "service": "tag", } - + # Execute the SQL query results = await self._execute_sql(sql_query) - + # Generate a natural language answer answer = await self._generate_answer(query, sql_query, results) - + return { "query": query, "answer": answer, "sql": sql_query, "results": results, "sources": sources, - "service": "tag" + "service": "tag", } - + except Exception as e: logger.error(f"Error in TagService.retrieve: {e}", exc_info=True) return { @@ -160,19 +169,19 @@ async def retrieve( "results": [], "sources": [], "service": "tag", - "error": str(e) + "error": str(e), } - + async def _get_all_table_schemas(self) -> Dict[str, Any]: """ Get schemas for all tables in the storage database using direct SQL queries. - + Returns: Dictionary containing table schemas """ try: schemas = {} - + # Use StorageRepository to execute SQL queries # First get all tables using SHOW TABLES logger.info("Fetching all tables from storage database") @@ -183,109 +192,121 @@ async def _get_all_table_schemas(self) -> Dict[str, Any]: except Exception as e: logger.error(f"All table query methods failed: {e}") table_names = [] - + logger.info(f"Found tables: {table_names}") - + # For each table, get its schema information for table_name in table_names: # Skip system tables - if (table_name.startswith('sqlite_') or table_name.startswith('pg_') or - table_name.startswith('alembic_') or table_name == 'spatial_ref_sys'): + if ( + table_name.startswith("sqlite_") + or table_name.startswith("pg_") + or table_name.startswith("alembic_") + or table_name == "spatial_ref_sys" + ): continue - + # Get column information try: # Try DESCRIBE command (MySQL/MariaDB) describe_query = f"DESCRIBE {table_name}" db = get_storage_db().__next__() describe_result = await StorageRepository.query(db, describe_query) - + columns = [] for row in describe_result: # DESCRIBE typically returns: Field, Type, Null, Key, Default, Extra col_name = row[0] col_type = row[1] - is_nullable = row[2].upper() == 'YES' if len(row) > 2 else True - key_type = row[3] if len(row) > 3 else '' - - columns.append({ - "name": col_name, - "type": col_type, - "nullable": is_nullable, - "key": key_type - }) - + is_nullable = row[2].upper() == "YES" if len(row) > 2 else True + key_type = row[3] if len(row) > 3 else "" + + columns.append( + { + "name": col_name, + "type": col_type, + "nullable": is_nullable, + "key": key_type, + } + ) + except Exception as e: - logger.error(f"All schema query methods failed for {table_name}: {e}") + logger.error( + f"All schema query methods failed for {table_name}: {e}" + ) # Add a minimal entry columns = [] - + # Get sample data (first few rows) to help LLM understand the data try: sample_query = f"SELECT * FROM {table_name} LIMIT 3" db = get_storage_db().__next__() sample_result = await StorageRepository.query(db, sample_query) - + # Convert sample data to list of dicts sample_data = [] if sample_result and len(sample_result) > 0: - if hasattr(sample_result[0], '_fields'): + if hasattr(sample_result[0], "_fields"): fields = sample_result[0]._fields for row in sample_result: - sample_data.append({field: getattr(row, field) for field in fields}) + sample_data.append( + {field: getattr(row, field) for field in fields} + ) else: # Fallback - sample_data = [dict(zip([c["name"] for c in columns], row)) for row in sample_result] + sample_data = [ + dict(zip([c["name"] for c in columns], row)) + for row in sample_result + ] except Exception as e: logger.warning(f"Failed to get sample data for {table_name}: {e}") sample_data = [] - + # Store schema information - schemas[table_name] = { - "columns": columns, - "sample_data": sample_data - } - + schemas[table_name] = {"columns": columns, "sample_data": sample_data} + if not schemas: logger.warning("No table schemas found in the storage database") - + return schemas - + except Exception as e: logger.error(f"Error getting table schemas: {e}", exc_info=True) return {} - + async def _get_knowledge_base_documents(self, knowledge_base_id: str) -> List[Any]: """ Get all documents from a knowledge base. - + Args: knowledge_base_id: The ID of the knowledge base - + Returns: List of documents """ try: - + db = get_db().__next__() - documents = db.query(Document).filter( - Document.knowledge_base_id == knowledge_base_id - ).all() - + documents = ( + db.query(Document) + .filter(Document.knowledge_base_id == knowledge_base_id) + .all() + ) + return documents - + except Exception as e: logger.error(f"Error getting knowledge base documents: {e}", exc_info=True) return [] - + async def _generate_sql(self, query: str, table_schemas: Dict[str, Any]) -> str: """ Generate SQL from a natural language query using an LLM. - + Args: query: The natural language query table_schemas: Dictionary of table schemas - + Returns: Generated SQL query as a string """ @@ -295,67 +316,62 @@ async def _generate_sql(self, query: str, table_schemas: Dict[str, Any]) -> str: for table_name, schema in table_schemas.items(): schema_text += f"Table: {table_name}\n" schema_text += "Columns:\n" - + for column in schema["columns"]: key_info = "" if column.get("key") == "PRI": key_info = " (PRIMARY KEY)" nullable = "NULL" if column.get("nullable", True) else "NOT NULL" schema_text += f" - {column['name']} ({column['type']}) {nullable}{key_info}\n" - + # Include sample data if available if schema.get("sample_data") and len(schema["sample_data"]) > 0: schema_text += "\nSample data (first 3 rows):\n" for i, row in enumerate(schema["sample_data"][:3]): schema_text += f"Row {i+1}: {json.dumps(row)}\n" - + schema_text += "\n" - + # Get the prompt from the registry - prompt = get_prompt("tag_service", "generate_sql", - query=query, - schema_text=schema_text) - + prompt = get_prompt( + "tag_service", "generate_sql", query=query, schema_text=schema_text + ) + # Create a message for the LLM - messages = [ - Message(role=Role.USER, content=prompt) - ] - + messages = [Message(role=Role.USER, content=prompt)] + # Set completion options options = CompletionOptions( temperature=0.2, # Lower temperature for more deterministic SQL generation - max_tokens=1000 + max_tokens=1000, ) - + # Generate SQL using LLM Factory - response = await LLMFactory.complete( - messages=messages, - options=options - ) - + response = await LLMFactory.complete(messages=messages, options=options) + # Extract SQL from response sql_query = response.content.strip() - + # Remove any markdown code block formatting if present - sql_query = re.sub(r'```sql\s*', '', sql_query) - sql_query = re.sub(r'```', '', sql_query) - + sql_query = re.sub(r"```sql\s*", "", sql_query) + sql_query = re.sub(r"```", "", sql_query) + # Log the generated SQL logger.info(f"Generated SQL query: {sql_query}") - + return sql_query - + except Exception as e: logger.error(f"Error generating SQL: {e}", exc_info=True) return "" - + async def _execute_sql(self, sql_query: str) -> List[Dict[str, Any]]: """ Execute the SQL query against the database using the storage repository. - + Args: sql_query: The SQL query to execute - + Returns: List of results as dictionaries """ @@ -364,35 +380,37 @@ async def _execute_sql(self, sql_query: str) -> List[Dict[str, Any]]: if not sql_query.strip().upper().startswith("SELECT"): logger.error(f"Attempted to execute non-SELECT query: {sql_query}") return [] - + logger.info(f"Executing SQL query via Storage Repository: {sql_query}") - + # Execute the query using the storage repository # Use the instance method through self.storage_repository db = get_storage_db().__next__() result_proxy = await StorageRepository.query(db, sql_query) results = [] - + # Convert result to list of dictionaries if result_proxy: # Get column names from the first result - if hasattr(result_proxy[0], '_fields'): + if hasattr(result_proxy[0], "_fields"): columns = result_proxy[0]._fields - elif hasattr(result_proxy[0], '_mapping'): + elif hasattr(result_proxy[0], "_mapping"): # For SQLAlchemy 2.0+ compatibility columns = [col for col in result_proxy[0]._mapping.keys()] else: # Try to infer column names from the result - logger.warning("Could not determine column names from result, using position") + logger.warning( + "Could not determine column names from result, using position" + ) columns = [f"column_{i}" for i in range(len(result_proxy[0]))] - + # Convert each row to a dictionary for row in result_proxy: # For named tuples - if hasattr(row, '_asdict'): + if hasattr(row, "_asdict"): row_dict = row._asdict() # For SQLAlchemy 2.0+ Row objects - elif hasattr(row, '_mapping'): + elif hasattr(row, "_mapping"): row_dict = dict(row._mapping) # Fallback else: @@ -402,69 +420,73 @@ async def _execute_sql(self, sql_query: str) -> List[Dict[str, Any]]: else: # Last resort row_dict = {f"value_{i}": val for i, val in enumerate(row)} - + # Handle special data types for key, value in list(row_dict.items()): # Convert date/time objects to ISO strings - if hasattr(value, 'isoformat'): + if hasattr(value, "isoformat"): row_dict[key] = value.isoformat() # Handle None values elif value is None: row_dict[key] = None # Convert non-JSON serializable types to strings - elif not isinstance(value, (str, int, float, bool, type(None), list, dict)): + elif not isinstance( + value, (str, int, float, bool, type(None), list, dict) + ): row_dict[key] = str(value) - + results.append(row_dict) - + logger.info(f"Query returned {len(results)} results") return results - + except Exception as e: logger.error(f"Error executing SQL: {e}", exc_info=True) return [] - - async def _generate_answer(self, query: str, sql: str, results: List[Dict[str, Any]]) -> str: + + async def _generate_answer( + self, query: str, sql: str, results: List[Dict[str, Any]] + ) -> str: """ Generate a natural language answer based on query, SQL, and results. - + Args: query: The original natural language query sql: The SQL query that was executed results: The results from the SQL query - + Returns: Natural language answer as a string """ try: # Format results for the prompt - results_text = json.dumps(results[:10], indent=2) # Limit to 10 rows for prompt size - + results_text = json.dumps( + results[:10], indent=2 + ) # Limit to 10 rows for prompt size + # Get the prompt from the registry - prompt = get_prompt("tag_service", "generate_answer", - query=query, - sql=sql, - results=results_text) - + prompt = get_prompt( + "tag_service", + "generate_answer", + query=query, + sql=sql, + results=results_text, + ) + # Create a message for the LLM - messages = [ - Message(role=Role.USER, content=prompt) - ] - + messages = [Message(role=Role.USER, content=prompt)] + # Set completion options options = CompletionOptions( temperature=0.5, # Moderate temperature for natural language generation - max_tokens=1000 + max_tokens=1000, ) - + # Generate answer using LLM Factory - response = await LLMFactory.complete( - messages=messages, - options=options - ) - + response = await LLMFactory.complete(messages=messages, options=options) + return response.content.strip() - + except Exception as e: logger.error(f"Error generating answer: {e}", exc_info=True) return f"I found results, but couldn't generate a natural language answer due to an error: {str(e)}" @@ -473,4 +495,4 @@ async def _generate_answer(self, query: str, sql: str, results: List[Dict[str, A @lru_cache() def get_tag_service() -> TagService: """Get a singleton instance of TagService""" - return TagService() \ No newline at end of file + return TagService() diff --git a/app/services/user_service.py b/app/services/user_service.py index 31aab34..7bec935 100644 --- a/app/services/user_service.py +++ b/app/services/user_service.py @@ -1,11 +1,14 @@ from typing import List, Optional -from fastapi import HTTPException, Depends + +from fastapi import Depends, HTTPException from sqlalchemy.orm import Session + from app.core.security import get_password_hash, verify_password +from app.db.database import get_db from app.db.models.user import User, UserRole from app.repositories.user_repository import UserRepository -from app.schemas.user import UserCreate, UserUpdate, UserResponse -from app.db.database import get_db +from app.schemas.user import UserCreate, UserResponse, UserUpdate + class UserService: def __init__(self, db: Session = Depends(get_db)): @@ -17,14 +20,14 @@ async def create_user(self, user_data: UserCreate) -> UserResponse: # Check if email already exists if await self.repository.get_by_email(user_data.email, self.db): raise HTTPException(status_code=400, detail="Email already registered") - + # Create user with hashed password user = User( email=user_data.email, hashed_password=get_password_hash(user_data.password), full_name=user_data.full_name, role=user_data.role or UserRole.USER, - is_verified=True + is_verified=True, ) return await self.repository.create(user, self.db) @@ -46,50 +49,53 @@ async def list_users(self, current_user: User) -> List[UserResponse]: return await self.repository.list_all(self.db) async def update_user( - self, - user_id: str, - user_update: UserUpdate, - current_user: User + self, user_id: str, user_update: UserUpdate, current_user: User ) -> UserResponse: """Update user""" # Check permissions if current_user.role != UserRole.ADMIN and current_user.id != user_id: raise HTTPException(status_code=403, detail="Not enough privileges") - + # Get existing user - user = await self.get_user(user_id) - + await self.get_user(user_id) + # Prepare update data update_data = user_update.model_dump(exclude_unset=True) - + # Hash new password if provided if "password" in update_data: - update_data["hashed_password"] = get_password_hash(update_data.pop("password")) - + update_data["hashed_password"] = get_password_hash( + update_data.pop("password") + ) + # Only admin can update role if "role" in update_data and current_user.role != UserRole.ADMIN: - raise HTTPException(status_code=403, detail="Not enough privileges to change role") - + raise HTTPException( + status_code=403, detail="Not enough privileges to change role" + ) + # Update user updated_user = await self.repository.update(user_id, update_data, self.db) if not updated_user: raise HTTPException(status_code=404, detail="User not found") - + return updated_user async def delete_user(self, user_id: str, current_user: User) -> None: """Delete user (admin only)""" if current_user.role != UserRole.ADMIN: raise HTTPException(status_code=403, detail="Not enough privileges") - + if not await self.repository.delete(user_id, self.db): raise HTTPException(status_code=404, detail="User not found") - async def authenticate_user(self, email: str, password: str) -> Optional[UserResponse]: + async def authenticate_user( + self, email: str, password: str + ) -> Optional[UserResponse]: """Authenticate user by email and password""" user = await self.repository.get_by_email(email, self.db) if not user: return None if not verify_password(password, user.hashed_password): return None - return user \ No newline at end of file + return user diff --git a/app/worker/celery.py b/app/worker/celery.py index 635dc43..d9606e2 100644 --- a/app/worker/celery.py +++ b/app/worker/celery.py @@ -1,8 +1,9 @@ -import os -import sys import logging -import platform import multiprocessing +import os +import platform +import sys + from celery import Celery # ===================================================================== @@ -24,25 +25,26 @@ # Set multiprocessing start method to 'spawn' on macOS # This prevents issues with fork() and MPS - if multiprocessing.get_start_method(allow_none=True) != 'spawn': + if multiprocessing.get_start_method(allow_none=True) != "spawn": try: - multiprocessing.set_start_method('spawn', force=True) + multiprocessing.set_start_method("spawn", force=True) except RuntimeError: # If already set, this will raise a RuntimeError pass # Add the parent directory to the path so we can import from app -parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +parent_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) if parent_dir not in sys.path: sys.path.append(parent_dir) # Set the default Python path -os.environ.setdefault('PYTHONPATH', '.') +os.environ.setdefault("PYTHONPATH", ".") # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -52,23 +54,23 @@ # Initialize Celery celery_app = Celery( - 'docbrain', + "docbrain", broker=settings.REDIS_URL or settings.CELERY_BROKER_URL, backend=settings.REDIS_URL or settings.CELERY_RESULT_BACKEND, - include=["app.worker.tasks"] + include=["app.worker.tasks"], ) # Configure Celery celery_app.conf.update( - task_serializer='json', - accept_content=['json'], - result_serializer='json', - timezone='UTC', + task_serializer="json", + accept_content=["json"], + result_serializer="json", + timezone="UTC", enable_utc=True, task_track_started=True, task_time_limit=3600, # 1 hour worker_max_tasks_per_child=100, - worker_prefetch_multiplier=1 + worker_prefetch_multiplier=1, ) # Configure Celery for macOS to avoid MPS issues @@ -79,6 +81,7 @@ worker_max_tasks_per_child=10, # Restart workers periodically to prevent memory leaks ) + # Pre-initialize models to prevent segmentation faults def pre_initialize_models(): """ @@ -87,43 +90,47 @@ def pre_initialize_models(): """ try: logger.info("Pre-initializing models...") - + # Import factories - from app.services.rag.reranker.reranker_factory import RerankerFactory from app.services.rag.ingestor.ingestor_factory import IngestorFactory - + from app.services.rag.reranker.reranker_factory import RerankerFactory + # Initialize rerankers with default configuration - RerankerFactory.initialize_models({"type": "flag", "model_name": "BAAI/bge-reranker-v2-m3"}) - + RerankerFactory.initialize_models( + {"type": "flag", "model_name": "BAAI/bge-reranker-v2-m3"} + ) + # Initialize ingestors IngestorFactory.initialize_ingestors() - + logger.info("Model pre-initialization complete") except Exception as e: logger.error(f"Failed to pre-initialize models: {e}", exc_info=True) logger.warning("Continuing without model pre-initialization") + # Function to run the worker (used by restart_worker.sh) def run_worker(): """Run the Celery worker with appropriate configuration""" logger.info("Starting Celery worker") - + # Log platform and multiprocessing information logger.info(f"Platform: {platform.system()} {platform.release()}") logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}") - + # Pre-initialize models to prevent segmentation faults # This must be done before worker starts and forks processes pre_initialize_models() - + # Use --pool=solo on macOS to avoid fork-related issues worker_args = ["worker", "--purge", "--loglevel=info", "-E", "--concurrency=50"] if platform.system() == "Darwin": worker_args.append("--pool=solo") logger.info("Using solo pool for macOS to avoid fork() issues") - + celery_app.worker_main(worker_args) + # This allows the file to be used both as a module and as a script -if __name__ == '__main__': - run_worker() \ No newline at end of file +if __name__ == "__main__": + run_worker() diff --git a/app/worker/tasks.py b/app/worker/tasks.py index 0ddf16d..1fcee7a 100644 --- a/app/worker/tasks.py +++ b/app/worker/tasks.py @@ -1,29 +1,35 @@ -from typing import Optional +import asyncio +import base64 import logging +from typing import Optional + from celery import shared_task -import base64 -import asyncio -from sqlalchemy.orm import Session from celery.exceptions import MaxRetriesExceededError +from sqlalchemy.orm import Session +from app.core.config import settings +from app.core.prompts import get_prompt, register_prompt from app.db.database import get_db from app.db.models.message import MessageContentType -from app.services.rag.vector_store import get_vector_store -from app.services.query_router import get_query_router -from app.core.config import settings from app.repositories.document_repository import DocumentRepository from app.repositories.message_repository import MessageRepository from app.repositories.question_repository import QuestionRepository from app.schemas.document import DocumentResponse from app.schemas.question import QuestionResponse +from app.services.llm.factory import CompletionOptions, LLMFactory +from app.services.llm.factory import Message as LLMMessage +from app.services.llm.factory import Role +from app.services.query_router import get_query_router +from app.services.rag.vector_store import get_vector_store from app.services.rag_service import get_rag_service -from app.services.llm.factory import LLMFactory, Message as LLMMessage, Role, CompletionOptions -from app.core.prompts import get_prompt, register_prompt logger = logging.getLogger(__name__) # Register prompts -register_prompt("worker", "document_summary", """Create a comprehensive summary of the following document. +register_prompt( + "worker", + "document_summary", + """Create a comprehensive summary of the following document. The summary should capture the main topics, key points, and important details. It should be detailed enough to understand what information is contained in the document. Focus on factual information rather than opinions. @@ -33,25 +39,27 @@ Document Content: {{ content }} -Summary:""") +Summary:""", +) DOCUMENT_REPO = DocumentRepository() MESSAGE_REPO = MessageRepository() QUESTION_REPO = QuestionRepository() RAG_SERVICE = get_rag_service() + @shared_task( bind=True, max_retries=3, default_retry_delay=60, autoretry_for=(Exception,), retry_backoff=True, - retry_jitter=True + retry_jitter=True, ) def initiate_document_ingestion(self, document_id: str) -> None: """ Process document content and create chunks. - + This task is triggered when a document is uploaded to the system. It performs the following steps: 1. Retrieve the document from the database @@ -64,30 +72,31 @@ def initiate_document_ingestion(self, document_id: str) -> None: 8. Update document status to COMPLETED """ logger.info(f"Starting document processing task for document_id: {document_id}") - + async def _ingest(db: Session): try: # Get document logger.info(f"Fetching document {document_id} from repository") - document: Optional[DocumentResponse] = await DOCUMENT_REPO.get_by_id(document_id, db) + document: Optional[DocumentResponse] = await DOCUMENT_REPO.get_by_id( + document_id, db + ) if not document: logger.error(f"Document {document_id} not found") raise ValueError(f"Document {document_id} not found") - - logger.info(f"Processing document: {document.title} (type: {document.content_type})") - + + logger.info( + f"Processing document: {document.title} (type: {document.content_type})" + ) + # Update status to processing logger.info(f"Updating document {document_id} status to PROCESSING") await DOCUMENT_REPO.set_processing(document_id, db) - + # Generate document summary logger.info(f"Generating summary for document {document_id}") - summary = await _generate_document_summary( - document.content, - document.title - ) + summary = await _generate_document_summary(document.content, document.title) logger.info(f"Summary generated for document {document_id}") - + # Prepare metadata metadata = { "document_id": document_id, @@ -96,7 +105,7 @@ async def _ingest(db: Session): "knowledge_base_id": document.knowledge_base_id, "document_type": document.content_type, } - + # Method 1: Use RAG service for end-to-end processing # This is simpler but provides less control over individual steps result = await RAG_SERVICE.ingest_document( @@ -105,185 +114,216 @@ async def _ingest(db: Session): content_type=document.content_type, ) chunk_count = result["chunk_count"] - + # Add the summary to the summary index for semantic routing - logger.info(f"Adding summary to the summary index for document {document_id}") + logger.info( + f"Adding summary to the summary index for document {document_id}" + ) await RAG_SERVICE.add_document_summary( document_id=document_id, knowledge_base_id=document.knowledge_base_id, document_title=document.title, document_type=document.content_type, - summary=summary + summary=summary, ) - - logger.info(f"Document {document_id} processed successfully with {chunk_count} chunks") - - # Update document with summary and status - await DOCUMENT_REPO.set_processed( - document_id, - summary, - chunk_count, - db + + logger.info( + f"Document {document_id} processed successfully with {chunk_count} chunks" ) + + # Update document with summary and status + await DOCUMENT_REPO.set_processed(document_id, summary, chunk_count, db) except MaxRetriesExceededError: logger.error(f"Max retries exceeded for document {document_id}") await DOCUMENT_REPO.set_failed( - document_id, - "Processing failed after maximum retries", - db + document_id, "Processing failed after maximum retries", db ) except Exception as e: - logger.error(f"Failed to process document {document_id}: {e}", exc_info=True) - await DOCUMENT_REPO.set_failed( - document_id, - str(e), - db + logger.error( + f"Failed to process document {document_id}: {e}", exc_info=True ) + await DOCUMENT_REPO.set_failed(document_id, str(e), db) raise # Let Celery handle the retry # Run the async function using asyncio.run() try: return asyncio.run(_ingest(get_db().__next__())) except Exception as e: - logger.error(f"Failed to run async process for document {document_id}: {e}", exc_info=True) + logger.error( + f"Failed to run async process for document {document_id}: {e}", + exc_info=True, + ) raise + async def _generate_document_summary(content: bytes, title: str) -> str: """Generate a summary of the document content using the LLM factory""" try: # Convert bytes to base64 string - content_str = base64.b64encode(content).decode('utf-8') + content_str = base64.b64encode(content).decode("utf-8") # Truncate content if it's too long max_content_length = 10000 # Adjust based on model limits - truncated_content = content_str[:max_content_length] + "..." if len(content_str) > max_content_length else content_str - + truncated_content = ( + content_str[:max_content_length] + "..." + if len(content_str) > max_content_length + else content_str + ) + # Get the prompt from the registry - prompt = get_prompt("worker", "document_summary", - title=title, - content=truncated_content) - + prompt = get_prompt( + "worker", "document_summary", title=title, content=truncated_content + ) + # Create a message for the LLM - messages = [ - LLMMessage(role=Role.USER, content=prompt) - ] - + messages = [LLMMessage(role=Role.USER, content=prompt)] + # Set completion options options = CompletionOptions( temperature=0.3, # Lower temperature for more factual summarization - max_tokens=1000 + max_tokens=1000, ) - + # Generate summary using LLM Factory - response = await LLMFactory.complete( - messages=messages, - options=options - ) - + response = await LLMFactory.complete(messages=messages, options=options) + summary = response.content.strip() - + # Ensure summary isn't too long for database storage max_summary_length = 5000 # Adjust based on database field size if len(summary) > max_summary_length: - summary = summary[:max_summary_length - 3] + "..." - + summary = summary[: max_summary_length - 3] + "..." + return summary except Exception as e: logger.error(f"Error generating document summary: {e}", exc_info=True) return f"Summary generation failed: {str(e)}" + @shared_task( bind=True, max_retries=3, default_retry_delay=30, autoretry_for=(Exception,), - retry_backoff=True + retry_backoff=True, ) def initiate_document_vector_deletion(self, document_id: str) -> None: """Delete document vectors from vector store""" + async def _delete_vectors(db: Session): - + try: logger.info(f"Starting vector deletion for document {document_id}") - + # Get the document to find its knowledge base ID document = await DOCUMENT_REPO.get_by_id(document_id, db) if not document: - logger.warning(f"Document {document_id} not found, cannot delete vectors") + logger.warning( + f"Document {document_id} not found, cannot delete vectors" + ) return True - + knowledge_base_id = document.knowledge_base_id - logger.info(f"Found document {document_id} in knowledge base {knowledge_base_id}") - + logger.info( + f"Found document {document_id} in knowledge base {knowledge_base_id}" + ) + # Delete document success = await RAG_SERVICE.delete_document(document_id, knowledge_base_id) - + # Also delete document summary from the summary index - logger.info(f"Deleting document summary from summary index for document {document_id}") + logger.info( + f"Deleting document summary from summary index for document {document_id}" + ) summary_vector_store = get_vector_store( - store_type="pinecone", - index_name=settings.PINECONE_SUMMARY_INDEX_NAME + store_type="pinecone", index_name=settings.PINECONE_SUMMARY_INDEX_NAME ) await summary_vector_store.delete_document_chunks(document_id, "summaries") - + if success: - logger.info(f"Successfully deleted vectors for document {document_id} from knowledge base {knowledge_base_id}") + logger.info( + f"Successfully deleted vectors for document {document_id} from knowledge base {knowledge_base_id}" + ) else: - logger.warning(f"Failed to delete vectors for document {document_id}, but continuing") - + logger.warning( + f"Failed to delete vectors for document {document_id}, but continuing" + ) + return True - + except Exception as e: error_msg = str(e) - logger.error(f"Failed to delete vectors for document {document_id}: {error_msg}", exc_info=True) - + logger.error( + f"Failed to delete vectors for document {document_id}: {error_msg}", + exc_info=True, + ) + # Check if this is a Pinecone API error if "400" in error_msg and "Bad Request" in error_msg: - logger.warning(f"Pinecone API error (400 Bad Request) for document {document_id}") - + logger.warning( + f"Pinecone API error (400 Bad Request) for document {document_id}" + ) + # Check specifically for the Serverless/Starter tier error - if "Serverless and Starter indexes do not support deleting with metadata filtering" in error_msg: - logger.warning("This error is expected with Pinecone Serverless/Starter tiers") - logger.warning("The vector_store.py and pinecone_retriever.py have been updated to handle this case") + if ( + "Serverless and Starter indexes do not support deleting with metadata filtering" + in error_msg + ): + logger.warning( + "This error is expected with Pinecone Serverless/Starter tiers" + ) + logger.warning( + "The vector_store.py and pinecone_retriever.py have been updated to handle this case" + ) logger.warning("Please retry the document deletion") # We'll raise the exception to trigger a retry, as our updated code should handle it raise else: - logger.warning("This may be due to an invalid filter format or the document not existing in the index") + logger.warning( + "This may be due to an invalid filter format or the document not existing in the index" + ) # We'll consider this a "success" since we can't do anything about it - logger.info(f"Marking document {document_id} vector deletion as complete despite error") + logger.info( + f"Marking document {document_id} vector deletion as complete despite error" + ) return True - + raise try: return asyncio.run(_delete_vectors(get_db().__next__())) except Exception as e: - logger.error(f"Failed to run async process for document vector deletion {document_id}: {e}", exc_info=True) + logger.error( + f"Failed to run async process for document vector deletion {document_id}: {e}", + exc_info=True, + ) # Check retry count and provide more context retry_count = self.request.retries max_retries = self.max_retries logger.info(f"Current retry count: {retry_count}/{max_retries}") - + if retry_count >= max_retries: - logger.warning(f"Max retries ({max_retries}) exceeded for document {document_id} vector deletion") + logger.warning( + f"Max retries ({max_retries}) exceeded for document {document_id} vector deletion" + ) # We'll consider this a "success" to prevent the task from being stuck in the queue - logger.info(f"Marking document {document_id} vector deletion as complete despite errors") + logger.info( + f"Marking document {document_id} vector deletion as complete despite errors" + ) return True - + raise + @shared_task( bind=True, max_retries=2, default_retry_delay=10, autoretry_for=(Exception,), - retry_backoff=True + retry_backoff=True, ) def initiate_rag_retrieval( - self, - user_message_id: str, - assistant_message_id: str + self, user_message_id: str, assistant_message_id: str ) -> None: """Initiate query processing for a given user message and update the corresponding assistant message. This implementation first uses the QueryRouter to determine which service to use (RAG or TAG), @@ -307,85 +347,95 @@ async def _retrieve(db: Session): # Use the content from the user message as the query query = user_msg.content knowledge_base_id = user_msg.knowledge_base_id - + # Get the query router - this is a singleton query_router = get_query_router() - + # Use the full route_and_dispatch method to handle questions index and proper routing logger.info(f"Calling query router to route and dispatch query: '{query}'") - - metadata_filter = {"knowledge_base_id": knowledge_base_id} if knowledge_base_id else {} - + + metadata_filter = ( + {"knowledge_base_id": knowledge_base_id} if knowledge_base_id else {} + ) + # This will check questions index first, then route if needed response = await query_router.route_and_dispatch( query=query, metadata_filter=metadata_filter, top_k=top_k, - similarity_threshold=similarity_threshold + similarity_threshold=similarity_threshold, ) - + # Log which service was used service = response.get("service", "unknown") routing_info = response.get("routing_info", {}) logger.info(f"Successfully processed query using {service} service") - + # Add routing metadata to the response sources and ensure all required fields are present sources = response.get("sources", []) for i, source in enumerate(sources): # Ensure all required fields are present in each source if "score" not in source: - source["score"] = 1.0 if service == "tag" else source.get("similarity", 0.8) - + source["score"] = ( + 1.0 if service == "tag" else source.get("similarity", 0.8) + ) + if "content" not in source: if service == "tag": - source["content"] = f"Table data from {source.get('title', 'database')}" + source["content"] = ( + f"Table data from {source.get('title', 'database')}" + ) else: source["content"] = source.get("text", "No content available") - + # Handle the source differently based on the service if service == "questions": # For questions service, use question-specific fields if "question_id" not in source: metadata = source.get("metadata", {}) - source["question_id"] = metadata.get("question_id", f"question_{i}") - + source["question_id"] = metadata.get( + "question_id", f"question_{i}" + ) + if "question" not in source: source["question"] = source.get("content", "") - + # Add answer field from metadata if not present if "answer" not in source: metadata = source.get("metadata", {}) source["answer"] = metadata.get("answer", "") - + if "answer_type" not in source: metadata = source.get("metadata", {}) source["answer_type"] = metadata.get("answer_type", "DIRECT") - + # Make document_id optional if "document_id" not in source: - source["document_id"] = source.get("question_id", f"question_{i}") + source["document_id"] = source.get( + "question_id", f"question_{i}" + ) else: # For document-based services (RAG/TAG) if "chunk_index" not in source: source["chunk_index"] = i - + # Ensure document_id field is present if "document_id" not in source: metadata = source.get("metadata", {}) source["document_id"] = metadata.get("document_id", f"doc_{i}") - + # Ensure title field is present if "title" not in source: metadata = source.get("metadata", {}) source["title"] = metadata.get("doc_title", "Untitled Document") - + # Add routing information to each source source["routing"] = { "service": service, "confidence": routing_info.get("confidence", 0), - "reasoning": routing_info.get("reasoning", "No reasoning provided") + "reasoning": routing_info.get("reasoning", "No reasoning provided"), } - + # If TAG service was used, add SQL information to the source if service == "tag" and response.get("sql"): source["sql_query"] = response.get("sql") @@ -396,18 +446,20 @@ async def _retrieve(db: Session): if service == "tag" and response.get("sql"): metadata["sql_query"] = response.get("sql") if response.get("results"): - metadata["sql_results"] = response.get("results")[:5] # Limit to first 5 results to avoid huge payload - + metadata["sql_results"] = response.get("results")[ + :5 + ] # Limit to first 5 results to avoid huge payload + # Check if sources have all required fields before calling set_processed logger.info(f"Checking sources for required fields: {sources}") - + await MESSAGE_REPO.set_processed( message_id=assistant_message_id, content=response.get("answer", ""), content_type=MessageContentType.TEXT, sources=sources, metadata=metadata, - db=db + db=db, ) except MaxRetriesExceededError: @@ -415,24 +467,28 @@ async def _retrieve(db: Session): logger.error(err_msg) await MESSAGE_REPO.set_failed(assistant_message_id, err_msg, db) except Exception as e: - logger.error(f"Failed to process response for message {assistant_message_id}: {e}", exc_info=True) + logger.error( + f"Failed to process response for message {assistant_message_id}: {e}", + exc_info=True, + ) await MESSAGE_REPO.set_failed(assistant_message_id, str(e), db) raise return asyncio.run(_retrieve(get_db().__next__())) + @shared_task( bind=True, max_retries=3, default_retry_delay=60, autoretry_for=(Exception,), retry_backoff=True, - retry_jitter=True + retry_jitter=True, ) def initiate_question_ingestion(self, question_id: str) -> None: """ Process question and add it to the vector store. - + This task is triggered when a question is created or updated. It performs the following steps: 1. Retrieve the question from the database @@ -442,60 +498,70 @@ def initiate_question_ingestion(self, question_id: str) -> None: 5. Update question status to COMPLETED """ logger.info(f"Starting question ingestion task for question_id: {question_id}") - + async def _ingest(db: Session): try: # Get question logger.info(f"Fetching question {question_id} from repository") - question: Optional[QuestionResponse] = await QUESTION_REPO.get_by_id(question_id, db) + question: Optional[QuestionResponse] = await QUESTION_REPO.get_by_id( + question_id, db + ) if not question: logger.error(f"Question {question_id} not found") raise ValueError(f"Question {question_id} not found") - - logger.info(f"Processing question: {question.question} (type: {question.answer_type})") - + + logger.info( + f"Processing question: {question.question} (type: {question.answer_type})" + ) + # Update status to ingesting logger.info(f"Updating question {question_id} status to INGESTING") await QUESTION_REPO.set_ingesting(question_id, db) - + try: # Get vector store instance vector_store = get_vector_store( store_type="pinecone", - index_name=settings.PINECONE_QUESTIONS_INDEX_NAME + index_name=settings.PINECONE_QUESTIONS_INDEX_NAME, ) - + # Create metadata with question-specific fields only metadata = { "question_id": question_id, "knowledge_base_id": question.knowledge_base_id, "answer_type": str(question.answer_type), "question": question.question, # Store the actual question for retrieval - "answer": question.answer, # Store the answer in metadata as well + "answer": question.answer, # Store the answer in metadata as well "user_id": str(question.user_id), } - + # Format content for vector store - include both question and answer in content - formatted_content = f"Question: {question.question}\nAnswer: {question.answer}" - + formatted_content = ( + f"Question: {question.question}\nAnswer: {question.answer}" + ) + # Store in vector store using knowledge_base_id as namespace # Use the specialized add_questions method for questions await vector_store.add_questions( texts=[formatted_content], metadatas=[metadata], ids=[f"question:{question_id}"], - collection_name=question.knowledge_base_id + collection_name=question.knowledge_base_id, + ) + + logger.info( + f"Question {question_id} successfully ingested into questions index" ) - - logger.info(f"Question {question_id} successfully ingested into questions index") await QUESTION_REPO.set_completed(question_id, db) - + except Exception as e: # Update status to failed - logger.error(f"Failed to ingest question {question_id}: {e}", exc_info=True) + logger.error( + f"Failed to ingest question {question_id}: {e}", exc_info=True + ) await QUESTION_REPO.set_failed(question_id, db) raise - + except Exception as e: logger.error(f"Failed to ingest question: {e}", exc_info=True) raise @@ -503,7 +569,7 @@ async def _ingest(db: Session): # Create a new event loop for this task loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + try: # Get database session db = next(get_db()) @@ -522,41 +588,47 @@ async def _ingest(db: Session): except Exception as e: logger.error(f"Error cleaning up event loop: {e}", exc_info=True) + @shared_task( bind=True, max_retries=3, default_retry_delay=30, autoretry_for=(Exception,), - retry_backoff=True + retry_backoff=True, ) -def initiate_question_vector_deletion(self, question_id: str, knowledge_base_id: str) -> None: +def initiate_question_vector_deletion( + self, question_id: str, knowledge_base_id: str +) -> None: """ Delete question vectors from vector store. - + This task is triggered when a question is deleted. """ - logger.info(f"Starting question vector deletion task for question_id: {question_id}") - + logger.info( + f"Starting question vector deletion task for question_id: {question_id}" + ) + async def _delete_vectors(db: Session): try: # Get vector store instance for questions index vector_store = get_vector_store( - store_type="pinecone", - index_name=settings.PINECONE_QUESTIONS_INDEX_NAME + store_type="pinecone", index_name=settings.PINECONE_QUESTIONS_INDEX_NAME ) - + # Delete from vector store using knowledge_base_id as namespace await vector_store.delete_document_chunks( document_id=f"question:{question_id}", - knowledge_base_id=knowledge_base_id + knowledge_base_id=knowledge_base_id, + ) + + logger.info( + f"Successfully deleted question {question_id} vectors from questions index" ) - - logger.info(f"Successfully deleted question {question_id} vectors from questions index") - + except Exception as e: logger.error(f"Failed to delete question vectors: {e}", exc_info=True) raise - + # Create or get the event loop explicitly try: # Try to get existing event loop @@ -565,7 +637,7 @@ async def _delete_vectors(db: Session): # If there's no event loop, create a new one and set it loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + try: for db in get_db(): loop.run_until_complete(_delete_vectors(db)) diff --git a/requirements-test.txt b/requirements-test.txt index da224ea..787221d 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -23,6 +23,9 @@ email-validator>=2.0.0 # Config python-dotenv>=1.0.0 +# Required for FastAPI form/file upload endpoints +python-multipart>=0.0.6 + # Testing pytest>=8.0.0 pytest-asyncio>=0.25.0 diff --git a/tests/conftest.py b/tests/conftest.py index 8b709bc..053b20e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ Environment variables and module mocks are set up before any app imports to avoid database connection errors during testing. """ + import os import sys from unittest.mock import MagicMock @@ -37,6 +38,7 @@ # Now safe to import app modules # --------------------------------------------------------------------------- import pytest + from app.db.models.user import UserRole from app.schemas.user import UserResponse diff --git a/tests/test_auth.py b/tests/test_auth.py index fa991bd..3dace62 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,7 +1,7 @@ """Tests for authentication utilities.""" -import pytest + from datetime import timedelta -from unittest.mock import patch + from jose import jwt @@ -10,6 +10,7 @@ class TestAccessToken: def test_create_token_returns_string(self): from app.api.deps import create_access_token + token = create_access_token("user-123") assert isinstance(token, str) assert len(token) > 0 @@ -17,26 +18,36 @@ def test_create_token_returns_string(self): def test_token_contains_user_id(self): from app.api.deps import create_access_token from app.core.config import settings + token = create_access_token("user-abc") - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) assert payload["sub"] == "user-abc" def test_token_has_expiry(self): from app.api.deps import create_access_token from app.core.config import settings + token = create_access_token("user-123") - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) assert "exp" in payload def test_custom_expiry_delta(self): from app.api.deps import create_access_token from app.core.config import settings + token = create_access_token("user-123", expires_delta=timedelta(minutes=5)) - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) assert "exp" in payload def test_different_users_get_different_tokens(self): from app.api.deps import create_access_token + token1 = create_access_token("user-1") token2 = create_access_token("user-2") assert token1 != token2 diff --git a/tests/test_config.py b/tests/test_config.py index 09179ef..370649b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,4 @@ """Tests for configuration and settings.""" -import os -import pytest class TestCORSConfig: @@ -8,6 +6,7 @@ class TestCORSConfig: def test_default_cors_origins(self): from app.core.config import settings + origins = settings.CORS_ORIGIN_LIST assert isinstance(origins, list) assert len(origins) >= 1 @@ -16,6 +15,7 @@ def test_default_cors_origins(self): def test_cors_origins_are_strings(self): from app.core.config import settings + for origin in settings.CORS_ORIGIN_LIST: assert isinstance(origin, str) assert origin.startswith("http") @@ -26,11 +26,13 @@ class TestSecurityConfig: def test_token_expiry_is_reasonable(self): from app.core.config import settings + # Should be at most 24 hours (1440 minutes) assert settings.ACCESS_TOKEN_EXPIRE_MINUTES <= 1440 def test_algorithm_is_set(self): from app.core.config import settings + assert settings.ALGORITHM == "HS256" @@ -39,9 +41,11 @@ class TestRateLimitConfig: def test_rate_limit_has_default(self): from app.core.config import settings + assert settings.RATE_LIMIT_PER_MINUTE > 0 def test_rate_limit_is_reasonable(self): from app.core.config import settings + # Should be between 10 and 10000 assert 10 <= settings.RATE_LIMIT_PER_MINUTE <= 10000 diff --git a/tests/test_factories.py b/tests/test_factories.py index a87df31..21007e2 100644 --- a/tests/test_factories.py +++ b/tests/test_factories.py @@ -4,9 +4,11 @@ sentence-transformers). These tests mock those dependencies so they run in any environment, including CI without GPU or ML packages. """ + import sys +from unittest.mock import MagicMock + import pytest -from unittest.mock import MagicMock, patch from app.db.models.knowledge_base import DocumentType @@ -16,21 +18,24 @@ # --------------------------------------------------------------------------- class TestChunkerFactory: def test_create_chunker_returns_multi_level(self): - from app.services.rag.chunker.chunker_factory import ChunkerFactory from app.services.rag.chunker.chunker import MultiLevelChunker + from app.services.rag.chunker.chunker_factory import ChunkerFactory + chunker = ChunkerFactory.create_chunker(DocumentType.PDF) assert isinstance(chunker, MultiLevelChunker) def test_create_from_metadata_uses_document_type(self): - from app.services.rag.chunker.chunker_factory import ChunkerFactory from app.services.rag.chunker.chunker import MultiLevelChunker + from app.services.rag.chunker.chunker_factory import ChunkerFactory + metadata = {"document_type": DocumentType.CSV} chunker = ChunkerFactory.create_chunker_from_metadata(metadata) assert isinstance(chunker, MultiLevelChunker) def test_create_from_metadata_defaults_to_txt(self): - from app.services.rag.chunker.chunker_factory import ChunkerFactory from app.services.rag.chunker.chunker import MultiLevelChunker + from app.services.rag.chunker.chunker_factory import ChunkerFactory + chunker = ChunkerFactory.create_chunker_from_metadata({}) assert isinstance(chunker, MultiLevelChunker) @@ -71,24 +76,28 @@ def _mock_ingestors(self): def test_pdf_type(self, _mock_ingestors): from app.services.rag.ingestor.ingestor_factory import IngestorFactory + IngestorFactory._pdf_ingestor = None ingestor = IngestorFactory.create_ingestor(DocumentType.PDF) assert type(ingestor).__name__ == "PDFIngestor" def test_csv_type(self, _mock_ingestors): from app.services.rag.ingestor.ingestor_factory import IngestorFactory + IngestorFactory._csv_ingestor = None ingestor = IngestorFactory.create_ingestor(DocumentType.CSV) assert type(ingestor).__name__ == "CSVIngestor" def test_txt_type(self, _mock_ingestors): from app.services.rag.ingestor.ingestor_factory import IngestorFactory + IngestorFactory._text_ingestor = None ingestor = IngestorFactory.create_ingestor(DocumentType.TXT) assert type(ingestor).__name__ == "TextIngestor" def test_singleton_returns_same_instance(self, _mock_ingestors): from app.services.rag.ingestor.ingestor_factory import IngestorFactory + IngestorFactory._pdf_ingestor = None first = IngestorFactory.create_ingestor(DocumentType.PDF) second = IngestorFactory.create_ingestor(DocumentType.PDF) @@ -106,7 +115,9 @@ def _mock_retriever(self): mock_retriever_mod.PineconeRetriever = mock_pinecone_retriever orig = sys.modules.get("app.services.rag.retriever.pinecone_retriever") - sys.modules["app.services.rag.retriever.pinecone_retriever"] = mock_retriever_mod + sys.modules["app.services.rag.retriever.pinecone_retriever"] = ( + mock_retriever_mod + ) if "app.services.rag.retriever.retriever_factory" in sys.modules: del sys.modules["app.services.rag.retriever.retriever_factory"] @@ -122,10 +133,12 @@ def _mock_retriever(self): def test_default_creates_pinecone(self, _mock_retriever): from app.services.rag.retriever.retriever_factory import RetrieverFactory + RetrieverFactory.create_retriever("kb-123") _mock_retriever.assert_called_with("kb-123") def test_unknown_type_falls_back(self, _mock_retriever): from app.services.rag.retriever.retriever_factory import RetrieverFactory + RetrieverFactory.create_retriever("kb-789", retriever_type="unknown") _mock_retriever.assert_called_with("kb-789") diff --git a/tests/test_health.py b/tests/test_health.py index eefe9f2..bae51cf 100644 --- a/tests/test_health.py +++ b/tests/test_health.py @@ -1,22 +1,35 @@ """Tests for the health and root endpoints using a real TestClient.""" + import sys from unittest.mock import MagicMock -import pytest - # Mock heavy dependencies so we can import app.main without ML libraries. _MOCKED_MODULES = [ - "aiofiles", "celery", "celery.result", + "aiofiles", + "celery", + "celery.result", "pinecone", - "PyPDF2", "markdown", - "PIL", "PIL.Image", "pytesseract", - "docling", "docling.document_converter", - "docling.datamodel", "docling.datamodel.base_models", + "PyPDF2", + "markdown", + "PIL", + "PIL.Image", + "pytesseract", + "docling", + "docling.document_converter", + "docling.datamodel", + "docling.datamodel.base_models", "docling.datamodel.pipeline_options", - "torch", "sentence_transformers", "FlagEmbedding", - "sendgrid", "sendgrid.helpers", "sendgrid.helpers.mail", - "google.generativeai", "google.genai", - "openai", "anthropic", "dirtyjson", + "torch", + "sentence_transformers", + "FlagEmbedding", + "sendgrid", + "sendgrid.helpers", + "sendgrid.helpers.mail", + "google.generativeai", + "google.genai", + "openai", + "anthropic", + "dirtyjson", ] for _mod in _MOCKED_MODULES: @@ -25,11 +38,13 @@ # pymysql shim so SQLAlchemy can resolve the mysql dialect try: import pymysql + pymysql.install_as_MySQLdb() except ImportError: sys.modules.setdefault("MySQLdb", MagicMock()) from fastapi.testclient import TestClient + from app.main import app # noqa: E402 — must come after mocks client = TestClient(app) diff --git a/tests/test_models.py b/tests/test_models.py index d0b1a22..08c8f7a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,5 +1,5 @@ """Tests for database models and enums.""" -import pytest + from app.db.models.knowledge_base import DocumentStatus, DocumentType from app.db.models.user import UserRole diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 8ec7133..123dbe1 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -1,10 +1,6 @@ """Tests for the RBAC permission system.""" -import pytest -from app.core.permissions import ( - Permission, - ROLE_PERMISSIONS, - get_permissions_for_role, -) + +from app.core.permissions import Permission, get_permissions_for_role from app.db.models.user import UserRole diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py index eadf7a0..90b3c91 100644 --- a/tests/test_rate_limiter.py +++ b/tests/test_rate_limiter.py @@ -1,8 +1,9 @@ """Tests for the rate limiting middleware.""" + import time +from unittest.mock import AsyncMock, MagicMock + import pytest -from collections import defaultdict -from unittest.mock import MagicMock, AsyncMock from app.core.middleware import RateLimitMiddleware @@ -54,7 +55,7 @@ async def test_exempt_paths_are_not_limited(self): # Should pass through even with rpm=1 for _ in range(5): - response = await mw.dispatch(request, call_next) + await mw.dispatch(request, call_next) assert call_next.call_count == 5 @pytest.mark.asyncio diff --git a/tests/test_schemas.py b/tests/test_schemas.py index a730a13..0974f6d 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -1,9 +1,10 @@ """Tests for Pydantic schemas validation.""" + import pytest from pydantic import ValidationError -from app.schemas.user import UserCreate, UserResponse, UserUpdate from app.db.models.user import UserRole +from app.schemas.user import UserCreate, UserResponse, UserUpdate class TestUserCreate: