Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 102 additions & 59 deletions backend/src/api/routers/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from langchain_ollama import ChatOllama
from langchain_core.messages import AIMessageChunk
from starlette.responses import StreamingResponse
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from ...agents.retriever_graph import RetrieverGraph
Expand Down Expand Up @@ -222,7 +223,11 @@ def parse_agent_output(output: list) -> tuple[str, list[ContextSource], list[str

def get_history_str(db: Session | None, conversation_uuid: UUID | None) -> str:
if use_db and db and conversation_uuid:
history = crud.get_conversation_history(db, conversation_uuid)
try:
history = crud.get_conversation_history(db, conversation_uuid)
except SQLAlchemyError:
logging.error("Failed to retrieve conversation history", exc_info=True)
return ""
history_str = ""
for i in history:
user_msg = i.get("User", "")
Expand Down Expand Up @@ -254,21 +259,27 @@ async def get_agent_response(

conversation_uuid = user_input.conversation_uuid

if use_db and db:
conversation = crud.get_or_create_conversation(
db,
conversation_uuid=conversation_uuid,
title=user_question[:100] if user_question else None,
)
conversation_uuid = conversation.uuid
db_persist = use_db and db is not None
if db_persist:
try:
conversation = crud.get_or_create_conversation(
db,
conversation_uuid=conversation_uuid,
title=user_question[:100] if user_question else None,
)
conversation_uuid = conversation.uuid

crud.create_message(
db=db,
conversation_uuid=conversation.uuid,
role="user",
content=user_question,
)
else:
crud.create_message(
db=db,
conversation_uuid=conversation.uuid,
role="user",
content=user_question,
)
except SQLAlchemyError:
logging.error("Failed to persist user message", exc_info=True)
db_persist = False

if not db_persist:
if conversation_uuid is None:
from uuid import uuid4

Expand All @@ -280,7 +291,7 @@ async def get_agent_response(
"messages": [
("user", user_question),
],
"chat_history": get_history_str(db, conversation_uuid),
"chat_history": get_history_str(db if db_persist else None, conversation_uuid),
}

if rg.graph is not None:
Expand All @@ -296,15 +307,18 @@ async def get_agent_response(
]
}

if use_db and db and conversation_uuid:
crud.create_message(
db=db,
conversation_uuid=conversation_uuid,
role="assistant",
content=llm_response,
context_sources=context_sources_dict,
tools=tools,
)
if db_persist and conversation_uuid:
try:
crud.create_message(
db=db,
conversation_uuid=conversation_uuid,
role="assistant",
content=llm_response,
context_sources=context_sources_dict,
tools=tools,
)
except SQLAlchemyError:
logging.error("Failed to persist assistant message", exc_info=True)
else:
if conversation_uuid:
chat_history[conversation_uuid].append(
Expand Down Expand Up @@ -349,21 +363,27 @@ async def get_response_stream(user_input: UserInput, db: Session | None) -> Any:

conversation_uuid = user_input.conversation_uuid

if use_db and db:
conversation = crud.get_or_create_conversation(
db,
conversation_uuid=conversation_uuid,
title=user_question[:100] if user_question else None,
)
conversation_uuid = conversation.uuid
db_persist = use_db and db is not None
if db_persist:
try:
conversation = crud.get_or_create_conversation(
db,
conversation_uuid=conversation_uuid,
title=user_question[:100] if user_question else None,
)
conversation_uuid = conversation.uuid

crud.create_message(
db=db,
conversation_uuid=conversation.uuid,
role="user",
content=user_question,
)
else:
crud.create_message(
db=db,
conversation_uuid=conversation.uuid,
role="user",
content=user_question,
)
except SQLAlchemyError:
logging.error("Failed to persist user message", exc_info=True)
db_persist = False

if not db_persist:
if conversation_uuid is None:
from uuid import uuid4

Expand All @@ -375,7 +395,7 @@ async def get_response_stream(user_input: UserInput, db: Session | None) -> Any:
"messages": [
("user", user_question),
],
"chat_history": get_history_str(db, conversation_uuid),
"chat_history": get_history_str(db if db_persist else None, conversation_uuid),
}

urls: list[str] = []
Expand Down Expand Up @@ -412,18 +432,21 @@ async def get_response_stream(user_input: UserInput, db: Session | None) -> Any:

full_response = "".join(chunks)

if use_db and db and conversation_uuid:
context_sources_dict: dict[str, Any] = {
"sources": [{"source": url, "context": ""} for url in urls]
}
crud.create_message(
db=db,
conversation_uuid=conversation_uuid,
role="assistant",
content=full_response,
context_sources=context_sources_dict,
tools=[],
)
if db_persist and conversation_uuid:
try:
context_sources_dict: dict[str, Any] = {
"sources": [{"source": url, "context": ""} for url in urls]
}
crud.create_message(
db=db,
conversation_uuid=conversation_uuid,
role="assistant",
content=full_response,
context_sources=context_sources_dict,
tools=[],
)
except SQLAlchemyError:
logging.error("Failed to persist assistant message", exc_info=True)
else:
if conversation_uuid:
chat_history[conversation_uuid].append(
Expand All @@ -446,7 +469,13 @@ async def create_conversation(db: Session = Depends(get_db)) -> ConversationResp
"""Creates a new conversation with an auto-generated UUID identifier."""
if not use_db:
raise HTTPException(status_code=503, detail=DB_DISABLED_MSG)
new_conversation = crud.create_conversation(db, conversation_uuid=None, title=None)
try:
new_conversation = crud.create_conversation(
db, conversation_uuid=None, title=None
)
except SQLAlchemyError:
logging.error("Failed to create conversation", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to create conversation")
return ConversationResponse.model_validate(new_conversation)


Expand All @@ -457,7 +486,11 @@ async def list_conversations(
"""Retrieves a paginated list of all conversations without their messages."""
if not use_db:
raise HTTPException(status_code=503, detail=DB_DISABLED_MSG)
conversations = crud.get_all_conversations(db, skip=skip, limit=limit)
try:
conversations = crud.get_all_conversations(db, skip=skip, limit=limit)
except SQLAlchemyError:
logging.error("Failed to list conversations", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to list conversations")
return [ConversationListResponse.model_validate(conv) for conv in conversations]


Expand All @@ -468,7 +501,11 @@ async def get_conversation(
"""Retrieves a complete conversation including all associated messages."""
if not use_db:
raise HTTPException(status_code=503, detail=DB_DISABLED_MSG)
conversation = crud.get_conversation(db, id)
try:
conversation = crud.get_conversation(db, id)
except SQLAlchemyError:
logging.error("Failed to get conversation", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to get conversation")
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
return ConversationResponse.model_validate(conversation)
Expand All @@ -479,7 +516,13 @@ async def delete_conversation(id: UUID, db: Session = Depends(get_db)) -> None:
"""Permanently removes a conversation and all associated messages from the database."""
if not use_db:
raise HTTPException(status_code=503, detail=DB_DISABLED_MSG)
conversation = crud.get_conversation(db, id)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
crud.delete_conversation(db, id)
try:
conversation = crud.get_conversation(db, id)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
crud.delete_conversation(db, id)
except HTTPException:
raise
except SQLAlchemyError:
logging.error("Failed to delete conversation", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to delete conversation")
55 changes: 42 additions & 13 deletions backend/src/database/crud.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import logging
from typing import Optional
from sqlalchemy.orm import Session
from sqlalchemy import desc
from sqlalchemy.exc import SQLAlchemyError
from .models import Conversation, Message
from uuid import UUID

logger = logging.getLogger(__name__)


def create_conversation(
db: Session, conversation_uuid: Optional[UUID] = None, title: Optional[str] = None
Expand All @@ -13,9 +17,14 @@ def create_conversation(
if conversation_uuid
else Conversation(title=title)
)
db.add(conversation)
db.commit()
db.refresh(conversation)
try:
db.add(conversation)
db.commit()
db.refresh(conversation)
except SQLAlchemyError:
db.rollback()
logger.error("Failed to create conversation", exc_info=True)
raise
return conversation


Expand Down Expand Up @@ -50,17 +59,27 @@ def update_conversation_title(
) -> Optional[Conversation]:
conversation = get_conversation(db, conversation_uuid)
if conversation:
conversation.title = title
db.commit()
db.refresh(conversation)
try:
conversation.title = title
db.commit()
db.refresh(conversation)
except SQLAlchemyError:
db.rollback()
logger.error("Failed to update conversation title", exc_info=True)
raise
return conversation


def delete_conversation(db: Session, conversation_uuid: UUID) -> bool:
conversation = get_conversation(db, conversation_uuid)
if conversation:
db.delete(conversation)
db.commit()
try:
db.delete(conversation)
db.commit()
except SQLAlchemyError:
db.rollback()
logger.error("Failed to delete conversation", exc_info=True)
raise
return True
return False

Expand All @@ -80,9 +99,14 @@ def create_message(
context_sources=context_sources,
tools=tools,
)
db.add(message)
db.commit()
db.refresh(message)
try:
db.add(message)
db.commit()
db.refresh(message)
except SQLAlchemyError:
db.rollback()
logger.error("Failed to create message", exc_info=True)
raise
return message


Expand All @@ -106,8 +130,13 @@ def get_conversation_messages(
def delete_message(db: Session, message_id: UUID) -> bool:
message = get_message(db, message_id)
if message:
db.delete(message)
db.commit()
try:
db.delete(message)
db.commit()
except SQLAlchemyError:
db.rollback()
logger.error("Failed to delete message", exc_info=True)
raise
return True
return False

Expand Down
Loading