diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py index d31e9492..81d8dcd0 100644 --- a/src/askui/chat/api/app.py +++ b/src/askui/chat/api/app.py @@ -152,6 +152,17 @@ def forbidden_error_handler( ) +@app.exception_handler(ValueError) +def value_error_handler( + request: Request, # noqa: ARG001 + exc: ValueError, +) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content={"detail": str(exc)}, + ) + + @app.exception_handler(Exception) def catch_all_exception_handler( request: Request, # noqa: ARG001 diff --git a/src/askui/chat/api/db/orm/types.py b/src/askui/chat/api/db/orm/types.py index aab81624..0501ec8f 100644 --- a/src/askui/chat/api/db/orm/types.py +++ b/src/askui/chat/api/db/orm/types.py @@ -56,3 +56,64 @@ def process_result_value( if value is None: return value return datetime.fromtimestamp(value, timezone.utc) + + +def create_sentinel_id_type( + prefix: str, sentinel_value: str +) -> type[TypeDecorator[str]]: + """Create a type decorator that converts between a sentinel value and NULL. + + This is useful for self-referential nullable foreign keys where NULL in the database + is represented by a sentinel value in the API (e.g., root nodes in a tree structure). + + Args: + prefix (str): The prefix for the ID (e.g., "msg"). + sentinel_value (str): The sentinel value representing NULL (e.g., "msg_000000000000000000000000"). + + Returns: + type[TypeDecorator[str]]: A TypeDecorator class that handles the transformation. + + Example: + ```python + ParentMessageId = create_sentinel_id_type("msg", ROOT_MESSAGE_PARENT_ID) + parent_id: Mapped[str] = mapped_column(ParentMessageId, nullable=True) + ``` + """ + + class SentinelId(TypeDecorator[str]): + """Type decorator that converts between sentinel value (API) and NULL (database). + + - When writing to DB: sentinel_value → NULL + - When reading from DB: NULL → sentinel_value + """ + + impl = String(24) + cache_ok = ( + False # Disable caching due to closure over prefix and sentinel_value + ) + + def process_bind_param( + self, + value: str | None, + dialect: Any, # noqa: ARG002 + ) -> str | None: + """Convert from API model to database storage.""" + if value is None or value == sentinel_value: + # Both None and sentinel value become NULL in database + return None + # Remove prefix for storage (like regular PrefixedObjectId) + return value.removeprefix(f"{prefix}_") + + def process_result_value( + self, + value: str | None, + dialect: Any, # noqa: ARG002 + ) -> str: + """Convert from database storage to API model.""" + if value is None: + # NULL in database becomes sentinel value in API + return sentinel_value + # Add prefix (like regular PrefixedObjectId) + return f"{prefix}_{value}" + + return SentinelId diff --git a/src/askui/chat/api/messages/chat_history_manager.py b/src/askui/chat/api/messages/chat_history_manager.py index 9120f092..640ff7a3 100644 --- a/src/askui/chat/api/messages/chat_history_manager.py +++ b/src/askui/chat/api/messages/chat_history_manager.py @@ -3,9 +3,10 @@ from askui.chat.api.messages.models import Message, MessageCreate from askui.chat.api.messages.service import MessageService from askui.chat.api.messages.translator import MessageTranslator -from askui.chat.api.models import ThreadId, WorkspaceId +from askui.chat.api.models import MessageId, ThreadId, WorkspaceId from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.truncation_strategies import TruncationStrategyFactory +from askui.utils.api_utils import NotFoundError class ChatHistoryManager: @@ -53,12 +54,26 @@ async def retrieve_message_params( ) ) for msg in self._message_service.iter( - workspace_id=workspace_id, thread_id=thread_id + workspace_id=workspace_id, + thread_id=thread_id, ): anthropic_message = await self._message_translator.to_anthropic(msg) truncation_strategy.append_message(anthropic_message) return truncation_strategy.messages + def retrieve_last_message( + self, + workspace_id: WorkspaceId, + thread_id: ThreadId, + ) -> MessageId: + last_message_id = self._message_service.retrieve_last_message_id( + workspace_id, thread_id + ) + if last_message_id is None: + error_msg = f"No messages found in thread {thread_id}" + raise NotFoundError(error_msg) + return last_message_id + async def append_message( self, workspace_id: WorkspaceId, @@ -66,11 +81,13 @@ async def append_message( assistant_id: str | None, run_id: str, message: MessageParam, + parent_id: str, ) -> Message: return self._message_service.create( workspace_id=workspace_id, thread_id=thread_id, params=MessageCreate( + parent_id=parent_id, assistant_id=assistant_id if message.role == "assistant" else None, role=message.role, content=await self._message_content_translator.from_anthropic( diff --git a/src/askui/chat/api/messages/models.py b/src/askui/chat/api/messages/models.py index 80bbe365..2f9d34aa 100644 --- a/src/askui/chat/api/messages/models.py +++ b/src/askui/chat/api/messages/models.py @@ -24,6 +24,8 @@ from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import generate_time_ordered_id +ROOT_MESSAGE_PARENT_ID = "msg_000000000000000000000000" + class BetaFileDocumentSourceParam(BaseModel): file_id: str @@ -80,6 +82,7 @@ class MessageParam(BaseModel): class MessageBase(MessageParam): assistant_id: AssistantId | None = None run_id: RunId | None = None + parent_id: MessageId | None = None class MessageCreate(MessageBase): diff --git a/src/askui/chat/api/messages/orms.py b/src/askui/chat/api/messages/orms.py index 21b12c42..f2931bd4 100644 --- a/src/askui/chat/api/messages/orms.py +++ b/src/askui/chat/api/messages/orms.py @@ -14,10 +14,12 @@ ThreadId, UnixDatetime, create_prefixed_id_type, + create_sentinel_id_type, ) -from askui.chat.api.messages.models import Message +from askui.chat.api.messages.models import ROOT_MESSAGE_PARENT_ID, Message MessageId = create_prefixed_id_type("msg") +_ParentMessageId = create_sentinel_id_type("msg", ROOT_MESSAGE_PARENT_ID) class MessageOrm(Base): @@ -43,6 +45,12 @@ class MessageOrm(Base): run_id: Mapped[str | None] = mapped_column( RunId, ForeignKey("runs.id", ondelete="SET NULL"), nullable=True ) + parent_id: Mapped[str] = mapped_column( + _ParentMessageId, + ForeignKey("messages.id", ondelete="CASCADE"), + nullable=True, + index=True, + ) @classmethod def from_model(cls, model: Message) -> "MessageOrm": diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 409ac3b5..1878a260 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -24,6 +24,38 @@ def list_messages( ) +@router.get("/{message_id}/siblings") +def list_siblings( + askui_workspace: Annotated[WorkspaceId, Header()], + thread_id: ThreadId, + message_id: MessageId, + message_service: MessageService = MessageServiceDep, +) -> list[Message]: + """List all sibling messages for a given message. + + Sibling messages are messages that share the same `parent_id` as the specified message. + The specified message itself is included in the results. + Results are sorted by ID (chronological order, as IDs are BSON-based). + + Args: + askui_workspace (WorkspaceId): The workspace ID from header. + thread_id (ThreadId): The thread ID. + message_id (MessageId): The message ID to find siblings for. + message_service (MessageService): The message service dependency. + + Returns: + list[Message]: List of sibling messages sorted by ID. + + Raises: + NotFoundError: If the specified message does not exist. + """ + return message_service.list_siblings( + workspace_id=askui_workspace, + thread_id=thread_id, + message_id=message_id, + ) + + @router.post("", status_code=status.HTTP_201_CREATED) async def create_message( askui_workspace: Annotated[WorkspaceId, Header()], diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index b1c08d54..45cfc06b 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -1,9 +1,13 @@ -from typing import Iterator +from typing import Any, Iterator -from sqlalchemy.orm import Session +from sqlalchemy import CTE, desc, select +from sqlalchemy.orm import Query, Session -from askui.chat.api.db.queries import list_all -from askui.chat.api.messages.models import Message, MessageCreate +from askui.chat.api.messages.models import ( + ROOT_MESSAGE_PARENT_ID, + Message, + MessageCreate, +) from askui.chat.api.messages.orms import MessageOrm from askui.chat.api.models import MessageId, ThreadId, WorkspaceId from askui.chat.api.threads.orms import ThreadOrm @@ -40,6 +44,173 @@ def _find_by_id( raise NotFoundError(error_msg) return message_orm + def _retrieve_latest_root( + self, workspace_id: WorkspaceId, thread_id: ThreadId + ) -> str | None: + """Retrieve the latest root message ID in a thread. + + Args: + workspace_id (WorkspaceId): The workspace ID. + thread_id (ThreadId): The thread ID. + + Returns: + str | None: The ID of the latest root message, or `None` if no root messages exist. + """ + return self._session.execute( + select(MessageOrm.id) + .filter( + MessageOrm.parent_id.is_(None), + MessageOrm.thread_id == thread_id, + MessageOrm.workspace_id == workspace_id, + ) + .order_by(desc(MessageOrm.id)) + .limit(1) + ).scalar_one_or_none() + + def _build_ancestors_cte( + self, message_id: MessageId, workspace_id: WorkspaceId, thread_id: ThreadId + ) -> CTE: + """Build a recursive CTE to traverse up the message tree from a given message. + + Args: + message_id (MessageId): The ID of the message to start traversing from. + workspace_id (WorkspaceId): The workspace ID. + thread_id (ThreadId): The thread ID. + + Returns: + CTE: A recursive common table expression that contains all ancestors of the message. + """ + # Build CTE to traverse up the tree from message_id + _ancestors_cte = ( + select(MessageOrm.id, MessageOrm.parent_id) + .filter( + MessageOrm.id == message_id, + MessageOrm.thread_id == thread_id, + MessageOrm.workspace_id == workspace_id, + ) + .cte(name="ancestors", recursive=True) + ) + + # Recursively traverse up until we hit NULL (root message) + _ancestors_recursive = select(MessageOrm.id, MessageOrm.parent_id).filter( + MessageOrm.id == _ancestors_cte.c.parent_id, + _ancestors_cte.c.parent_id.is_not(None), + ) + return _ancestors_cte.union_all(_ancestors_recursive) + + def _build_descendants_cte(self, message_id: MessageId) -> CTE: + """Build a recursive CTE to traverse down the message tree from a given message. + + Args: + message_id (MessageId): The ID of the message to start traversing from. + + Returns: + CTE: A recursive common table expression that contains all descendants of the message. + """ + # Build CTE to traverse down the tree from message_id + _descendants_cte = ( + select(MessageOrm.id, MessageOrm.parent_id) + .filter( + MessageOrm.id == message_id, + ) + .cte(name="descendants", recursive=True) + ) + + # Recursively traverse down + _descendants_recursive = select(MessageOrm.id, MessageOrm.parent_id).filter( + MessageOrm.parent_id == _descendants_cte.c.id, + ) + return _descendants_cte.union_all(_descendants_recursive) + + def _retrieve_latest_leaf(self, message_id: MessageId) -> str | None: + """Retrieve the latest leaf node in the subtree rooted at the given message. + + Args: + message_id (MessageId): The ID of the root message to start from. + + Returns: + str | None: The ID of the latest leaf node (highest ID), or `None` if no descendants exist. + """ + # Build CTE to traverse down the tree from message_id + _descendants_cte = self._build_descendants_cte(message_id) + + # Get the latest leaf (highest ID) + return self._session.execute( + select(_descendants_cte.c.id).order_by(desc(_descendants_cte.c.id)).limit(1) + ).scalar_one_or_none() + + def _retrieve_branch_root( + self, leaf_id: MessageId, workspace_id: WorkspaceId, thread_id: ThreadId + ) -> str | None: + """Retrieve the branch root node by traversing up from a leaf node. + + Args: + leaf_id (MessageId): The ID of the leaf message to start from. + workspace_id (WorkspaceId): The workspace ID. + thread_id (ThreadId): The thread ID. + + Returns: + str | None: The ID of the root node (with parent_id == NULL), or `None` if not found. + """ + # Build CTE to traverse up the tree from leaf_id + _ancestors_cte = self._build_ancestors_cte(leaf_id, workspace_id, thread_id) + + # Get the root node (the one with parent_id == NULL) + return self._session.execute( + select(MessageOrm.id).filter( + MessageOrm.id.in_(select(_ancestors_cte.c.id)), + MessageOrm.parent_id.is_(None), + ) + ).scalar_one_or_none() + + def _build_path_query(self, path_start: str, path_end: str) -> Query[MessageOrm]: + """Build a query for messages in the path from end to start. + + Args: + path_start (str): The ID of the path start message (upper node). + path_end (str): The ID of the path end message (lower node). + + Returns: + Query[MessageOrm]: A query object for fetching messages in the path. + """ + # Build path from path_end up to path_start using recursive CTE + # Start from path_end and traverse upward following parent_id until we reach path_start + _path_cte = ( + select(MessageOrm.id, MessageOrm.parent_id) + .filter( + MessageOrm.id == path_end, + ) + .cte(name="path", recursive=True) + ) + + # Recursively fetch parent nodes, stopping before we go past path_start + # No need to filter by thread_id/workspace_id - parent_id relationship ensures correct path + _path_recursive = select(MessageOrm.id, MessageOrm.parent_id).filter( + MessageOrm.id == _path_cte.c.parent_id, + # Stop recursion: don't fetch parent of path_start + _path_cte.c.id != path_start, + ) + + _path_cte = _path_cte.union_all(_path_recursive) + + return self._session.query(MessageOrm).join( + _path_cte, MessageOrm.id == _path_cte.c.id + ) + + def retrieve_last_message_id( + self, workspace_id: WorkspaceId, thread_id: ThreadId + ) -> MessageId | None: + """Get the last message ID in a thread. If no messages exist, return the root message ID.""" + return self._session.execute( + select(MessageOrm.id) + .filter( + MessageOrm.thread_id == thread_id, + MessageOrm.workspace_id == workspace_id, + ) + .order_by(desc(MessageOrm.id)) + .limit(1) + ).scalar_one_or_none() + def create( self, workspace_id: WorkspaceId, @@ -60,23 +231,169 @@ def create( error_msg = f"Thread {thread_id} not found" raise NotFoundError(error_msg) + if ( + params.parent_id is None + ): # If no parent ID is provided, use the last message in the thread + parent_id = self.retrieve_last_message_id(workspace_id, thread_id) + + # if the thread is empty, use the root message parent ID + if parent_id is None: + parent_id = ROOT_MESSAGE_PARENT_ID + params.parent_id = parent_id + + # Validate parent message exists (if not root) + if params.parent_id and params.parent_id != ROOT_MESSAGE_PARENT_ID: + parent_message_orm: MessageOrm | None = ( + self._session.query(MessageOrm) + .filter( + MessageOrm.id == params.parent_id, + MessageOrm.thread_id == thread_id, + MessageOrm.workspace_id == workspace_id, + ) + .first() + ) + if parent_message_orm is None: + error_msg = ( + f"Parent message {params.parent_id} not found in thread {thread_id}" + ) + raise NotFoundError(error_msg) + message = Message.create(workspace_id, thread_id, params) message_orm = MessageOrm.from_model(message) self._session.add(message_orm) self._session.commit() return message + def _get_path_endpoints( + self, workspace_id: WorkspaceId, thread_id: ThreadId, query: ListQuery + ) -> tuple[str, str] | None: + """Determine the path start and end node IDs for path traversal. + + Executes queries to get concrete ID values for the path start and end nodes. + + Args: + workspace_id (WorkspaceId): The workspace ID. + thread_id (ThreadId): The thread ID. + query (ListQuery): Pagination query (after/before, limit, order). + + Returns: + tuple[str, str] | None: A tuple of (path_start, path_end) where path_start is the + upper node and path_end is the lower node. Returns `None` if no messages exist + in the thread. + + Raises: + ValueError: If both `after` and `before` parameters are specified. + NotFoundError: If the specified message in `before` or `after` does not exist. + """ + if query.after and query.before: + error_msg = "Cannot specify both 'after' and 'before' parameters" + raise ValueError(error_msg) + + # Determine cursor and direction based on after/before and order + # Key insight: (after+desc) and (before+asc) both traverse UP (towards root) + # (after+asc) and (before+desc) both traverse DOWN (towards leaves) + _cursor = query.after or query.before + _should_traverse_up = (query.after and query.order == "desc") or ( + query.before and query.order == "asc" + ) + + path_start: str | None + path_end: str | None + + if _cursor: + if _should_traverse_up: + # Traverse UP: set path_end to cursor and find path_start by going to root + path_end = _cursor + path_start = self._retrieve_branch_root( + path_end, workspace_id, thread_id + ) + if path_start is None: + error_msg = f"Message with id '{path_end}' not found" + raise NotFoundError(error_msg) + else: + # Traverse DOWN: set path_start to cursor and find path_end by going to leaf + path_start = _cursor + path_end = self._retrieve_latest_leaf(path_start) + if path_end is None: + error_msg = f"Message with id '{path_start}' not found" + raise NotFoundError(error_msg) + else: + # No pagination - get the full branch from latest root to latest leaf + path_end = self.retrieve_last_message_id(workspace_id, thread_id) + if path_end is None: + return None + path_start = self._retrieve_branch_root(path_end, workspace_id, thread_id) + if path_start is None: + error_msg = f"Message with id '{path_end}' not found" + raise NotFoundError(error_msg) + + return path_start, path_end + def list_( self, workspace_id: WorkspaceId, thread_id: ThreadId, query: ListQuery ) -> ListResponse[Message]: - """List messages with pagination and filtering.""" - q = self._session.query(MessageOrm).filter( - MessageOrm.thread_id == thread_id, - MessageOrm.workspace_id == workspace_id, + """List messages in a tree path with pagination and filtering. + + Behavior: + - If `after` is provided: + - With `order=desc`: Returns path from `after` node up to root (excludes `after` itself) + - With `order=asc`: Returns path from `after` node down to latest leaf (excludes `after` itself) + - If `before` is provided: + - With `order=asc`: Returns path from `before` node up to root (excludes `before` itself) + - With `order=desc`: Returns path from `before` node down to latest leaf (excludes `before` itself) + - If neither: Returns main branch (root to latest leaf in entire thread) + + The method identifies a start_id (upper node) and end_id (leaf node), + traverses from end_id up to start_id, then applies the specified order. + + Args: + workspace_id (WorkspaceId): The workspace ID. + thread_id (ThreadId): The thread ID. + query (ListQuery): Pagination query (after/before, limit, order). + + Returns: + ListResponse[Message]: Paginated list of messages in the tree path. + + Raises: + ValueError: If both `after` and `before` parameters are specified. + NotFoundError: If the specified message in `before` or `after` does not exist. + """ + # Step 1: Get concrete path_start and path_end + _endpoints = self._get_path_endpoints(workspace_id, thread_id, query) + + # If no messages exist yet, return empty response + if _endpoints is None: + return ListResponse(data=[], has_more=False) + + _path_start, _path_end = _endpoints + + # Step 2: Build path query from path_end up to path_start + _query = self._build_path_query(_path_start, _path_end) + + # Build all filters at once for better query planning + _filters: list[Any] = [] + if query.after: + _filters.append(MessageOrm.id != query.after) + if query.before: + _filters.append(MessageOrm.id != query.before) + + if _filters: + _query = _query.filter(*_filters) + + orms = ( + _query.order_by( + MessageOrm.id if query.order == "asc" else desc(MessageOrm.id) + ) + .limit(query.limit + 1) + .all() ) - orms: list[MessageOrm] - orms, has_more = list_all(q, query, MessageOrm.id) - data = [orm.to_model() for orm in orms] + + if not orms: + return ListResponse(data=[], has_more=False) + + has_more = len(orms) > query.limit + data = [orm.to_model() for orm in orms[: query.limit]] + return ListResponse( data=data, has_more=has_more, @@ -92,6 +409,7 @@ def iter( batch_size: int = LIST_LIMIT_DEFAULT, ) -> Iterator[Message]: """Iterate through messages in batches.""" + has_more = True last_id: str | None = None while has_more: @@ -119,3 +437,55 @@ def delete( message_orm = self._find_by_id(workspace_id, thread_id, message_id) self._session.delete(message_orm) self._session.commit() + + def list_siblings( + self, + workspace_id: WorkspaceId, + thread_id: ThreadId, + message_id: MessageId, + ) -> list[Message]: + """List all sibling messages for a given message. + + Sibling messages are messages that share the same `parent_id` as the specified message. + The specified message itself is included in the results. + Results are sorted by ID (chronological order, as IDs are BSON-based). + + Args: + workspace_id (WorkspaceId): The workspace ID. + thread_id (ThreadId): The thread ID. + message_id (MessageId): The message ID to find siblings for. + + Returns: + list[Message]: List of sibling messages sorted by ID. + + Raises: + NotFoundError: If the specified message does not exist. + """ + # Query for all sibling messages using a subquery to get parent_id + _parent_id_subquery = ( + select(MessageOrm.parent_id) + .filter( + MessageOrm.id == message_id, + MessageOrm.thread_id == thread_id, + MessageOrm.workspace_id == workspace_id, + ) + .scalar_subquery() + ) + + orms = ( + self._session.query(MessageOrm) + .filter( + MessageOrm.parent_id.is_not_distinct_from(_parent_id_subquery), + MessageOrm.thread_id == thread_id, + MessageOrm.workspace_id == workspace_id, + ) + .order_by(desc(MessageOrm.id)) + .all() + ) + + # Validate that the message exists (if no results, message doesn't exist) + if not orms: + error_msg = f"Message {message_id} not found in thread {thread_id}" + raise NotFoundError(error_msg) + + return [orm.to_model() for orm in orms] diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 066c273a..3c52c0ed 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -10,7 +10,7 @@ from askui.chat.api.assistants.models import Assistant from askui.chat.api.mcp_clients.manager import McpClientManagerManager from askui.chat.api.messages.chat_history_manager import ChatHistoryManager -from askui.chat.api.models import RunId, ThreadId, WorkspaceId +from askui.chat.api.models import MessageId, RunId, ThreadId, WorkspaceId from askui.chat.api.runs.events.done_events import DoneEvent from askui.chat.api.runs.events.error_events import ( ErrorEvent, @@ -65,6 +65,7 @@ def __init__( mcp_client_manager_manager: McpClientManagerManager, run_service: RunnerRunService, settings: Settings, + last_message_id: MessageId, model: str | None = None, ) -> None: self._run_id = run_id @@ -75,6 +76,7 @@ def __init__( self._mcp_client_manager_manager = mcp_client_manager_manager self._run_service = run_service self._settings = settings + self._last_message_id = last_message_id self._model: str | None = model def _retrieve_run(self) -> Run: @@ -139,7 +141,10 @@ async def async_on_message( assistant_id=self._assistant.id, run_id=self._run_id, message=on_message_cb_param.message, + parent_id=self._last_message_id, ) + # Update the parent_id for the next message + self._last_message_id = created_message.id await send_stream.send( MessageEvent( data=created_message, diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 7bbffdbc..5e785935 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -83,6 +83,11 @@ async def create( ) run = self._create(workspace_id, thread_id, params) send_stream, receive_stream = anyio.create_memory_object_stream[Event]() + + last_message_id = self._chat_history_manager.retrieve_last_message( + workspace_id=workspace_id, + thread_id=thread_id, + ) runner = Runner( run_id=run.id, thread_id=thread_id, @@ -92,6 +97,7 @@ async def create( mcp_client_manager_manager=self._mcp_client_manager_manager, run_service=self, settings=self._settings, + last_message_id=last_message_id, model=params.model, ) diff --git a/src/askui/chat/migrations/versions/7b8c9d0e1f2a_add_parent_id_to_messages.py b/src/askui/chat/migrations/versions/7b8c9d0e1f2a_add_parent_id_to_messages.py new file mode 100644 index 00000000..2c8a1daf --- /dev/null +++ b/src/askui/chat/migrations/versions/7b8c9d0e1f2a_add_parent_id_to_messages.py @@ -0,0 +1,92 @@ +"""add_parent_id_to_messages + +Revision ID: 7b8c9d0e1f2a +Revises: 5e6f7a8b9c0d +Create Date: 2025-11-05 12:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "7b8c9d0e1f2a" +down_revision: Union[str, None] = "5e6f7a8b9c0d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Get database connection + connection = op.get_bind() + + # Check if parent_id column already exists + inspector = sa.inspect(connection) + columns = [col["name"] for col in inspector.get_columns("messages")] + column_exists = "parent_id" in columns + + # Only run batch operation if column doesn't exist + if not column_exists: + # Add column, foreign key, and index all in one batch operation + # This ensures the table is only recreated once in SQLite + with op.batch_alter_table("messages") as batch_op: + # Add parent_id column + batch_op.add_column(sa.Column("parent_id", sa.String(24), nullable=True)) + + # Add foreign key constraint (self-referential) + # parent_id remains nullable - NULL indicates a root message + batch_op.create_foreign_key( + "fk_messages_parent_id", + "messages", + ["parent_id"], + ["id"], + ondelete="CASCADE", + ) + + # Add index for performance + batch_op.create_index("ix_messages_parent_id", ["parent_id"]) + + # NOW populate parent_id values AFTER the table structure is finalized + # Fetch all threads + threads_result = connection.execute(sa.text("SELECT id FROM threads")) + thread_ids = [row[0] for row in threads_result] + + # For each thread, set up parent-child relationships + for thread_id in thread_ids: + # Get all messages in this thread, sorted by ID (which is time-ordered) + messages_result = connection.execute( + sa.text( + "SELECT id FROM messages WHERE thread_id = :thread_id ORDER BY id ASC" + ), + {"thread_id": thread_id}, + ) + message_ids = [row[0] for row in messages_result] + + # Set parent_id for each message + for i, message_id in enumerate(message_ids): + if i == 0: + # First message in thread has NULL as parent (root message) + parent_id = None + else: + # Each subsequent message's parent is the previous message + parent_id = message_ids[i - 1] + + connection.execute( + sa.text( + "UPDATE messages SET parent_id = :parent_id WHERE id = :message_id" + ), + {"parent_id": parent_id, "message_id": message_id}, + ) + + +def downgrade() -> None: + # Use batch_alter_table for SQLite compatibility + with op.batch_alter_table("messages") as batch_op: + # Drop index + batch_op.drop_index("ix_messages_parent_id") + # Drop foreign key constraint + batch_op.drop_constraint("fk_messages_parent_id", type_="foreignkey") + # Drop column + batch_op.drop_column("parent_id") diff --git a/tests/integration/chat/api/test_message_service.py b/tests/integration/chat/api/test_message_service.py new file mode 100644 index 00000000..0075460c --- /dev/null +++ b/tests/integration/chat/api/test_message_service.py @@ -0,0 +1,464 @@ +"""Unit tests for the MessageService.""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +import pytest +from sqlalchemy.orm import Session + +from askui.chat.api.messages.models import ( + ROOT_MESSAGE_PARENT_ID, + Message, + MessageCreate, +) +from askui.chat.api.messages.service import MessageService +from askui.chat.api.threads.models import Thread +from askui.chat.api.threads.orms import ThreadOrm +from askui.utils.api_utils import ListQuery + + +class TestMessageServicePagination: + """Test pagination behavior with different order and after/before parameters.""" + + @pytest.fixture + def _workspace_id(self) -> UUID: + """Create a test workspace ID.""" + return uuid4() + + @pytest.fixture + def _thread_id(self, test_db_session: Session, _workspace_id: UUID) -> str: + """Create a test thread.""" + _thread = Thread( + id="thread_testpagination", + object="thread", + created_at=datetime.now(timezone.utc), + name="Test Thread for Pagination", + workspace_id=_workspace_id, + ) + _thread_orm = ThreadOrm.from_model(_thread) + test_db_session.add(_thread_orm) + test_db_session.commit() + return _thread.id + + @pytest.fixture + def _message_service(self, test_db_session: Session) -> MessageService: + """Create a MessageService instance.""" + return MessageService(test_db_session) + + @pytest.fixture + def _messages( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + ) -> list[Message]: + """Create two branches of messages for testing. + + Branch 1: Messages 0-9 (linear chain from ROOT) + Branch 2: Messages 10-19 (separate linear chain from ROOT) + """ + _created_messages: list[Message] = [] + + # Create first branch: messages 0-9 (linear chain) + for i in range(10): + _msg = _message_service.create( + workspace_id=_workspace_id, + thread_id=_thread_id, + params=MessageCreate( + role="user" if i % 2 == 0 else "assistant", + content=f"Test message {i}", + parent_id=( + ROOT_MESSAGE_PARENT_ID + if i == 0 + else _created_messages[i - 1].id + ), + ), + ) + _created_messages.append(_msg) + + # Create second branch: messages 10-19 (separate linear chain from ROOT) + for i in range(10, 20): + _msg = _message_service.create( + workspace_id=_workspace_id, + thread_id=_thread_id, + params=MessageCreate( + role="user" if i % 2 == 0 else "assistant", + content=f"Test message {i}", + parent_id=( + ROOT_MESSAGE_PARENT_ID + if i == 10 + else _created_messages[i - 1].id + ), + ), + ) + _created_messages.append(_msg) + + return _created_messages + + def test_list_asc_without_after( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test listing messages in ascending order without 'after' parameter.""" + # Without before/after, gets latest branch (branch 2) + _response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=5, order="asc"), + ) + + assert len(_response.data) == 5 + # Should get the first 5 messages from branch 2 (10, 11, 12, 13, 14) + assert [_msg.content for _msg in _response.data] == [ + "Test message 10", + "Test message 11", + "Test message 12", + "Test message 13", + "Test message 14", + ] + assert _response.has_more is True + + def test_list_asc_with_after( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test listing messages in ascending order with 'after' parameter.""" + # First, get the first page from branch 2 (default) + _first_page = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=3, order="asc"), + ) + + assert len(_first_page.data) == 3 + assert [_msg.content for _msg in _first_page.data] == [ + "Test message 10", + "Test message 11", + "Test message 12", + ] + + # Now get the second page using 'after' + _second_page = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=3, order="asc", after=_first_page.last_id), + ) + + assert len(_second_page.data) == 3 + # Should get the next 3 messages (13, 14, 15) + assert [_msg.content for _msg in _second_page.data] == [ + "Test message 13", + "Test message 14", + "Test message 15", + ] + assert _second_page.has_more is True + + # Get the third page + _third_page = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=3, order="asc", after=_second_page.last_id), + ) + + assert len(_third_page.data) == 3 + # Should get the next 3 messages (16, 17, 18) + assert [_msg.content for _msg in _third_page.data] == [ + "Test message 16", + "Test message 17", + "Test message 18", + ] + + def test_list_desc_without_after( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test listing messages in descending order without 'after' parameter.""" + # Without before/after, gets latest branch (branch 2) + _response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=5, order="desc"), + ) + + assert len(_response.data) == 5 + # Should get the last 5 messages from branch 2 (19, 18, 17, 16, 15) + assert [_msg.content for _msg in _response.data] == [ + "Test message 19", + "Test message 18", + "Test message 17", + "Test message 16", + "Test message 15", + ] + assert _response.has_more is True + + def test_list_desc_with_after( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test listing messages in descending order with 'after' parameter.""" + # First, get the first page from branch 2 (default) + _first_page = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=3, order="desc"), + ) + + assert len(_first_page.data) == 3 + assert [_msg.content for _msg in _first_page.data] == [ + "Test message 19", + "Test message 18", + "Test message 17", + ] + + # Now get the second page using 'after' + _second_page = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=3, order="desc", after=_first_page.last_id), + ) + + assert len(_second_page.data) == 3 + # Should get the previous 3 messages (16, 15, 14) + assert [_msg.content for _msg in _second_page.data] == [ + "Test message 16", + "Test message 15", + "Test message 14", + ] + + def test_iter_asc( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test iterating through messages in ascending order.""" + # Without before/after, iter returns the latest branch (branch 2) + _collected_messages: list[Message] = list( + _message_service.iter( + workspace_id=_workspace_id, + thread_id=_thread_id, + order="asc", + batch_size=3, + ) + ) + + # Should get all 10 messages from branch 2 in ascending order + assert len(_collected_messages) == 10 + assert [_msg.content for _msg in _collected_messages] == [ + f"Test message {i}" for i in range(10, 20) + ] + + def test_iter_desc( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test iterating through messages in descending order.""" + # Without before/after, iter returns the latest branch (branch 2) + _collected_messages: list[Message] = list( + _message_service.iter( + workspace_id=_workspace_id, + thread_id=_thread_id, + order="desc", + batch_size=3, + ) + ) + + # Should get all 10 messages from branch 2 in descending order + assert len(_collected_messages) == 10 + assert [_msg.content for _msg in _collected_messages] == [ + f"Test message {i}" for i in range(19, 9, -1) + ] + + def test_list_asc_with_before( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test listing messages in ascending order with 'before' parameter.""" + # Get messages before message 7 in ascending order + # Should get messages from root up to (but excluding) message 7 + _response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=10, order="asc", before=_messages[7].id), + ) + + # Should get messages 0-6 in ascending order + assert len(_response.data) == 7 + assert [_msg.content for _msg in _response.data] == [ + f"Test message {i}" for i in range(7) + ] + assert _response.has_more is False + + def test_list_desc_with_before( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test listing messages in descending order with 'before' parameter.""" + # Get messages before (i.e., after in the tree) message 3 in descending + # order. Should get messages from message 3 down to the latest leaf + # (excluding message 3) + _response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=10, order="desc", before=_messages[3].id), + ) + + # Should get messages 9-4 in descending order (excluding message 3) + assert len(_response.data) == 6 + assert [_msg.content for _msg in _response.data] == [ + f"Test message {i}" for i in range(9, 3, -1) + ] + assert _response.has_more is False + + def test_list_asc_with_before_paginated( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test listing messages in ascending order with 'before' and pagination.""" + # Get 3 messages before message 7 in ascending order + _response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=3, order="asc", before=_messages[7].id), + ) + + # Should get messages 0-2 in ascending order + assert len(_response.data) == 3 + assert [_msg.content for _msg in _response.data] == [ + "Test message 0", + "Test message 1", + "Test message 2", + ] + assert _response.has_more is True + + def test_list_desc_with_before_paginated( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test listing messages in descending order with 'before' and pagination.""" + # Get 3 messages before (after in tree) message 3 in descending order + _response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=3, order="desc", before=_messages[3].id), + ) + + # Should get messages 9-7 in descending order + assert len(_response.data) == 3 + assert [_msg.content for _msg in _response.data] == [ + "Test message 9", + "Test message 8", + "Test message 7", + ] + assert _response.has_more is True + + def test_list_branch1_with_after( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test querying branch 1 by starting from its first message.""" + # Query from the first message of branch 1 downward + _response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=20, order="asc", after=_messages[0].id), + ) + + # Should get messages 1-9 from branch 1 (excluding message 0) + assert len(_response.data) == 9 + assert [_msg.content for _msg in _response.data] == [ + f"Test message {i}" for i in range(1, 10) + ] + assert _response.has_more is False + + def test_list_branch2_with_after( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test querying branch 2 by starting from its first message.""" + # Query from the first message of branch 2 downward + _response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=20, order="asc", after=_messages[10].id), + ) + + # Should get messages 11-19 from branch 2 (excluding message 10) + assert len(_response.data) == 9 + assert [_msg.content for _msg in _response.data] == [ + f"Test message {i}" for i in range(11, 20) + ] + assert _response.has_more is False + + def test_list_branches_separately( + self, + _message_service: MessageService, + _workspace_id: UUID, + _thread_id: str, + _messages: list[Message], + ) -> None: + """Test that the two branches are separate by querying from each.""" + # Get branch 1: query from branch 1's last message going up + _branch1_response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=20, order="desc", after=_messages[9].id), + ) + + # Should get messages 9-0 from branch 1 in descending order + assert len(_branch1_response.data) == 9 + assert [_msg.content for _msg in _branch1_response.data] == [ + f"Test message {i}" for i in range(8, -1, -1) + ] + + # Get branch 2: query from branch 2's last message going up + _branch2_response = _message_service.list_( + workspace_id=_workspace_id, + thread_id=_thread_id, + query=ListQuery(limit=20, order="desc", after=_messages[19].id), + ) + + # Should get messages 19-10 from branch 2 in descending order + assert len(_branch2_response.data) == 9 + assert [_msg.content for _msg in _branch2_response.data] == [ + f"Test message {i}" for i in range(18, 9, -1) + ] + + # Verify no overlap between branches + _branch1_ids = {_msg.id for _msg in _branch1_response.data} + _branch2_ids = {_msg.id for _msg in _branch2_response.data} + assert _branch1_ids.isdisjoint(_branch2_ids) diff --git a/tests/integration/chat/api/test_messages.py b/tests/integration/chat/api/test_messages.py index 15e61246..1ae347f8 100644 --- a/tests/integration/chat/api/test_messages.py +++ b/tests/integration/chat/api/test_messages.py @@ -9,7 +9,7 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from askui.chat.api.messages.models import Message +from askui.chat.api.messages.models import ROOT_MESSAGE_PARENT_ID, Message from askui.chat.api.messages.orms import MessageOrm from askui.chat.api.messages.service import MessageService from askui.chat.api.threads.models import Thread @@ -111,6 +111,7 @@ def test_list_messages_with_messages( workspace_id = UUID(test_headers["askui-workspace"]) mock_message = Message( id="msg_test123", + parent_id=ROOT_MESSAGE_PARENT_ID, object="thread.message", created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), thread_id="thread_test123", @@ -184,6 +185,7 @@ def test_list_messages_with_pagination( mock_message = Message( id=f"msg_test{i}", object="thread.message", + parent_id=ROOT_MESSAGE_PARENT_ID if i == 0 else f"msg_test{i - 1}", created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), thread_id="thread_test123", role="user" if i % 2 == 0 else "assistant", @@ -380,6 +382,7 @@ def test_retrieve_message( mock_message = Message( id="msg_test123", object="thread.message", + parent_id=ROOT_MESSAGE_PARENT_ID, created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), thread_id="thread_test123", role="user", @@ -464,6 +467,7 @@ def test_delete_message( mock_message = Message( id="msg_test123", object="thread.message", + parent_id=ROOT_MESSAGE_PARENT_ID, created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), thread_id="thread_test123", role="user",