diff --git a/backend/src/api/routers/conversations.py b/backend/src/api/routers/conversations.py index f0450628..99c74941 100644 --- a/backend/src/api/routers/conversations.py +++ b/backend/src/api/routers/conversations.py @@ -10,6 +10,7 @@ from langchain_ollama import ChatOllama from langchain_core.messages import AIMessageChunk from starlette.responses import StreamingResponse +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session from ...agents.retriever_graph import RetrieverGraph @@ -222,7 +223,11 @@ def parse_agent_output(output: list) -> tuple[str, list[ContextSource], list[str def get_history_str(db: Session | None, conversation_uuid: UUID | None) -> str: if use_db and db and conversation_uuid: - history = crud.get_conversation_history(db, conversation_uuid) + try: + history = crud.get_conversation_history(db, conversation_uuid) + except SQLAlchemyError: + logging.error("Failed to retrieve conversation history", exc_info=True) + return "" history_str = "" for i in history: user_msg = i.get("User", "") @@ -254,21 +259,27 @@ async def get_agent_response( conversation_uuid = user_input.conversation_uuid - if use_db and db: - conversation = crud.get_or_create_conversation( - db, - conversation_uuid=conversation_uuid, - title=user_question[:100] if user_question else None, - ) - conversation_uuid = conversation.uuid + db_persist = use_db and db is not None + if db_persist: + try: + conversation = crud.get_or_create_conversation( + db, + conversation_uuid=conversation_uuid, + title=user_question[:100] if user_question else None, + ) + conversation_uuid = conversation.uuid - crud.create_message( - db=db, - conversation_uuid=conversation.uuid, - role="user", - content=user_question, - ) - else: + crud.create_message( + db=db, + conversation_uuid=conversation.uuid, + role="user", + content=user_question, + ) + except SQLAlchemyError: + logging.error("Failed to persist user message", exc_info=True) + db_persist = False + + if not db_persist: if conversation_uuid is None: from uuid import uuid4 @@ -280,7 +291,7 @@ async def get_agent_response( "messages": [ ("user", user_question), ], - "chat_history": get_history_str(db, conversation_uuid), + "chat_history": get_history_str(db if db_persist else None, conversation_uuid), } if rg.graph is not None: @@ -296,15 +307,18 @@ async def get_agent_response( ] } - if use_db and db and conversation_uuid: - crud.create_message( - db=db, - conversation_uuid=conversation_uuid, - role="assistant", - content=llm_response, - context_sources=context_sources_dict, - tools=tools, - ) + if db_persist and conversation_uuid: + try: + crud.create_message( + db=db, + conversation_uuid=conversation_uuid, + role="assistant", + content=llm_response, + context_sources=context_sources_dict, + tools=tools, + ) + except SQLAlchemyError: + logging.error("Failed to persist assistant message", exc_info=True) else: if conversation_uuid: chat_history[conversation_uuid].append( @@ -349,21 +363,27 @@ async def get_response_stream(user_input: UserInput, db: Session | None) -> Any: conversation_uuid = user_input.conversation_uuid - if use_db and db: - conversation = crud.get_or_create_conversation( - db, - conversation_uuid=conversation_uuid, - title=user_question[:100] if user_question else None, - ) - conversation_uuid = conversation.uuid + db_persist = use_db and db is not None + if db_persist: + try: + conversation = crud.get_or_create_conversation( + db, + conversation_uuid=conversation_uuid, + title=user_question[:100] if user_question else None, + ) + conversation_uuid = conversation.uuid - crud.create_message( - db=db, - conversation_uuid=conversation.uuid, - role="user", - content=user_question, - ) - else: + crud.create_message( + db=db, + conversation_uuid=conversation.uuid, + role="user", + content=user_question, + ) + except SQLAlchemyError: + logging.error("Failed to persist user message", exc_info=True) + db_persist = False + + if not db_persist: if conversation_uuid is None: from uuid import uuid4 @@ -375,7 +395,7 @@ async def get_response_stream(user_input: UserInput, db: Session | None) -> Any: "messages": [ ("user", user_question), ], - "chat_history": get_history_str(db, conversation_uuid), + "chat_history": get_history_str(db if db_persist else None, conversation_uuid), } urls: list[str] = [] @@ -412,18 +432,21 @@ async def get_response_stream(user_input: UserInput, db: Session | None) -> Any: full_response = "".join(chunks) - if use_db and db and conversation_uuid: - context_sources_dict: dict[str, Any] = { - "sources": [{"source": url, "context": ""} for url in urls] - } - crud.create_message( - db=db, - conversation_uuid=conversation_uuid, - role="assistant", - content=full_response, - context_sources=context_sources_dict, - tools=[], - ) + if db_persist and conversation_uuid: + try: + context_sources_dict: dict[str, Any] = { + "sources": [{"source": url, "context": ""} for url in urls] + } + crud.create_message( + db=db, + conversation_uuid=conversation_uuid, + role="assistant", + content=full_response, + context_sources=context_sources_dict, + tools=[], + ) + except SQLAlchemyError: + logging.error("Failed to persist assistant message", exc_info=True) else: if conversation_uuid: chat_history[conversation_uuid].append( @@ -446,7 +469,13 @@ async def create_conversation(db: Session = Depends(get_db)) -> ConversationResp """Creates a new conversation with an auto-generated UUID identifier.""" if not use_db: raise HTTPException(status_code=503, detail=DB_DISABLED_MSG) - new_conversation = crud.create_conversation(db, conversation_uuid=None, title=None) + try: + new_conversation = crud.create_conversation( + db, conversation_uuid=None, title=None + ) + except SQLAlchemyError: + logging.error("Failed to create conversation", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to create conversation") return ConversationResponse.model_validate(new_conversation) @@ -457,7 +486,11 @@ async def list_conversations( """Retrieves a paginated list of all conversations without their messages.""" if not use_db: raise HTTPException(status_code=503, detail=DB_DISABLED_MSG) - conversations = crud.get_all_conversations(db, skip=skip, limit=limit) + try: + conversations = crud.get_all_conversations(db, skip=skip, limit=limit) + except SQLAlchemyError: + logging.error("Failed to list conversations", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to list conversations") return [ConversationListResponse.model_validate(conv) for conv in conversations] @@ -468,7 +501,11 @@ async def get_conversation( """Retrieves a complete conversation including all associated messages.""" if not use_db: raise HTTPException(status_code=503, detail=DB_DISABLED_MSG) - conversation = crud.get_conversation(db, id) + try: + conversation = crud.get_conversation(db, id) + except SQLAlchemyError: + logging.error("Failed to get conversation", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to get conversation") if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") return ConversationResponse.model_validate(conversation) @@ -479,7 +516,13 @@ async def delete_conversation(id: UUID, db: Session = Depends(get_db)) -> None: """Permanently removes a conversation and all associated messages from the database.""" if not use_db: raise HTTPException(status_code=503, detail=DB_DISABLED_MSG) - conversation = crud.get_conversation(db, id) - if not conversation: - raise HTTPException(status_code=404, detail="Conversation not found") - crud.delete_conversation(db, id) + try: + conversation = crud.get_conversation(db, id) + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + crud.delete_conversation(db, id) + except HTTPException: + raise + except SQLAlchemyError: + logging.error("Failed to delete conversation", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to delete conversation") diff --git a/backend/src/database/crud.py b/backend/src/database/crud.py index 7ef66d97..48089b6f 100644 --- a/backend/src/database/crud.py +++ b/backend/src/database/crud.py @@ -1,9 +1,13 @@ +import logging from typing import Optional from sqlalchemy.orm import Session from sqlalchemy import desc +from sqlalchemy.exc import SQLAlchemyError from .models import Conversation, Message from uuid import UUID +logger = logging.getLogger(__name__) + def create_conversation( db: Session, conversation_uuid: Optional[UUID] = None, title: Optional[str] = None @@ -13,9 +17,14 @@ def create_conversation( if conversation_uuid else Conversation(title=title) ) - db.add(conversation) - db.commit() - db.refresh(conversation) + try: + db.add(conversation) + db.commit() + db.refresh(conversation) + except SQLAlchemyError: + db.rollback() + logger.error("Failed to create conversation", exc_info=True) + raise return conversation @@ -50,17 +59,27 @@ def update_conversation_title( ) -> Optional[Conversation]: conversation = get_conversation(db, conversation_uuid) if conversation: - conversation.title = title - db.commit() - db.refresh(conversation) + try: + conversation.title = title + db.commit() + db.refresh(conversation) + except SQLAlchemyError: + db.rollback() + logger.error("Failed to update conversation title", exc_info=True) + raise return conversation def delete_conversation(db: Session, conversation_uuid: UUID) -> bool: conversation = get_conversation(db, conversation_uuid) if conversation: - db.delete(conversation) - db.commit() + try: + db.delete(conversation) + db.commit() + except SQLAlchemyError: + db.rollback() + logger.error("Failed to delete conversation", exc_info=True) + raise return True return False @@ -80,9 +99,14 @@ def create_message( context_sources=context_sources, tools=tools, ) - db.add(message) - db.commit() - db.refresh(message) + try: + db.add(message) + db.commit() + db.refresh(message) + except SQLAlchemyError: + db.rollback() + logger.error("Failed to create message", exc_info=True) + raise return message @@ -106,8 +130,13 @@ def get_conversation_messages( def delete_message(db: Session, message_id: UUID) -> bool: message = get_message(db, message_id) if message: - db.delete(message) - db.commit() + try: + db.delete(message) + db.commit() + except SQLAlchemyError: + db.rollback() + logger.error("Failed to delete message", exc_info=True) + raise return True return False diff --git a/backend/tests/test_database_crud.py b/backend/tests/test_database_crud.py index 5f59f789..4eaf3ab2 100644 --- a/backend/tests/test_database_crud.py +++ b/backend/tests/test_database_crud.py @@ -1,8 +1,10 @@ """Unit tests for database CRUD operations.""" import pytest +from unittest.mock import Mock from uuid import uuid4, UUID from sqlalchemy import create_engine +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker, Session from src.database.models import Base from src.database import crud @@ -400,3 +402,69 @@ def test_get_conversation_history_only_assistant_messages( # Assistant messages without user messages should be ignored assert len(history) == 0 + + +class TestCRUDExceptionHandling: + """Test suite for CRUD exception handling and rollback behavior.""" + + def test_create_conversation_rollback_on_error(self): + """Test that create_conversation rolls back on commit failure.""" + mock_db = Mock(spec=Session) + mock_db.commit.side_effect = SQLAlchemyError("DB error") + + with pytest.raises(SQLAlchemyError): + crud.create_conversation(mock_db, title="Test") + + mock_db.rollback.assert_called_once() + + def test_update_conversation_title_rollback_on_error(self, db_session: Session): + """Test that update_conversation_title rolls back on commit failure.""" + conv = crud.create_conversation(db_session, title="Original") + + mock_db = Mock(spec=Session) + mock_db.query.return_value.filter.return_value.first.return_value = conv + mock_db.commit.side_effect = SQLAlchemyError("DB error") + + with pytest.raises(SQLAlchemyError): + crud.update_conversation_title(mock_db, conv.uuid, "New Title") + + mock_db.rollback.assert_called_once() + + def test_delete_conversation_rollback_on_error(self, db_session: Session): + """Test that delete_conversation rolls back on commit failure.""" + conv = crud.create_conversation(db_session, title="Test") + + mock_db = Mock(spec=Session) + mock_db.query.return_value.filter.return_value.first.return_value = conv + mock_db.commit.side_effect = SQLAlchemyError("DB error") + + with pytest.raises(SQLAlchemyError): + crud.delete_conversation(mock_db, conv.uuid) + + mock_db.rollback.assert_called_once() + + def test_create_message_rollback_on_error(self, db_session: Session): + """Test that create_message rolls back on commit failure.""" + conv = crud.create_conversation(db_session, title="Test") + + mock_db = Mock(spec=Session) + mock_db.commit.side_effect = SQLAlchemyError("DB error") + + with pytest.raises(SQLAlchemyError): + crud.create_message(mock_db, conv.uuid, "user", "Hello") + + mock_db.rollback.assert_called_once() + + def test_delete_message_rollback_on_error(self, db_session: Session): + """Test that delete_message rolls back on commit failure.""" + conv = crud.create_conversation(db_session, title="Test") + msg = crud.create_message(db_session, conv.uuid, "user", "Hello") + + mock_db = Mock(spec=Session) + mock_db.query.return_value.filter.return_value.first.return_value = msg + mock_db.commit.side_effect = SQLAlchemyError("DB error") + + with pytest.raises(SQLAlchemyError): + crud.delete_message(mock_db, msg.uuid) + + mock_db.rollback.assert_called_once()