From e64b45d75a724d7532e05b2966836b7e806d0ad1 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Fri, 10 Oct 2025 10:46:10 +0200 Subject: [PATCH 01/14] feat(chat): migrate from json persistence to sqlite --- .gitignore | 1 + pdm.lock | 50 +++++- pyproject.toml | 1 + src/askui/chat/api/app.py | 7 +- src/askui/chat/api/assistants/dependencies.py | 9 +- src/askui/chat/api/assistants/models.py | 14 +- src/askui/chat/api/assistants/orms.py | 40 +++++ src/askui/chat/api/assistants/router.py | 16 +- src/askui/chat/api/assistants/service.py | 159 +++++++----------- src/askui/chat/api/db/__init__.py | 0 src/askui/chat/api/db/engine.py | 12 ++ src/askui/chat/api/db/orm/__init__.py | 0 src/askui/chat/api/db/orm/base.py | 13 ++ src/askui/chat/api/db/orm/types.py | 55 ++++++ src/askui/chat/api/db/queries.py | 34 ++++ src/askui/chat/api/db/session.py | 14 ++ src/askui/chat/api/settings.py | 14 +- 17 files changed, 310 insertions(+), 129 deletions(-) create mode 100644 src/askui/chat/api/assistants/orms.py create mode 100644 src/askui/chat/api/db/__init__.py create mode 100644 src/askui/chat/api/db/engine.py create mode 100644 src/askui/chat/api/db/orm/__init__.py create mode 100644 src/askui/chat/api/db/orm/base.py create mode 100644 src/askui/chat/api/db/orm/types.py create mode 100644 src/askui/chat/api/db/queries.py create mode 100644 src/askui/chat/api/db/session.py diff --git a/.gitignore b/.gitignore index b4a259a5..49db86ed 100644 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,5 @@ cython_debug/ reports/ .DS_Store /chat +/askui_chat.db diff --git a/pdm.lock b/pdm.lock index b97d166a..4c1ea011 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "android", "bedrock", "chat", "dev", "pynput", "vertex", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:b27f12e014074fafc8d1b4233d1b497bcb93a5c7f8c600a9766b4c7e591863a9" +content_hash = "sha256:a0f7f67a1fcacfb2f01865e1e523b861c4e2247dc7ec6578fe16eb60f639b956" 
[[metadata.targets]] requires_python = ">=3.10" @@ -3876,6 +3876,54 @@ files = [ {file = "soupsieve-2.8.tar.gz", hash = "sha256:e2dd4a40a628cb5f28f6d4b0db8800b8f581b65bb380b97de22ba5ca8d72572f"}, ] +[[package]] +name = "sqlalchemy" +version = "2.0.43" +requires_python = ">=3.7" +summary = "Database Abstraction Library" +groups = ["all", "chat"] +dependencies = [ + "greenlet>=1; (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\") and python_version < \"3.14\"", + "importlib-metadata; python_version < \"3.8\"", + "typing-extensions>=4.6.0", +] +files = [ + {file = "sqlalchemy-2.0.43-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70322986c0c699dca241418fcf18e637a4369e0ec50540a2b907b184c8bca069"}, + {file = "sqlalchemy-2.0.43-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:87accdbba88f33efa7b592dc2e8b2a9c2cdbca73db2f9d5c510790428c09c154"}, + {file = "sqlalchemy-2.0.43-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c00e7845d2f692ebfc7d5e4ec1a3fd87698e4337d09e58d6749a16aedfdf8612"}, + {file = "sqlalchemy-2.0.43-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:022e436a1cb39b13756cf93b48ecce7aa95382b9cfacceb80a7d263129dfd019"}, + {file = "sqlalchemy-2.0.43-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c5e73ba0d76eefc82ec0219d2301cb33bfe5205ed7a2602523111e2e56ccbd20"}, + {file = "sqlalchemy-2.0.43-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9c2e02f06c68092b875d5cbe4824238ab93a7fa35d9c38052c033f7ca45daa18"}, + {file = "sqlalchemy-2.0.43-cp310-cp310-win32.whl", hash = "sha256:e7a903b5b45b0d9fa03ac6a331e1c1d6b7e0ab41c63b6217b3d10357b83c8b00"}, + {file = "sqlalchemy-2.0.43-cp310-cp310-win_amd64.whl", hash = "sha256:4bf0edb24c128b7be0c61cd17eef432e4bef507013292415f3fb7023f02b7d4b"}, + {file = 
"sqlalchemy-2.0.43-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:52d9b73b8fb3e9da34c2b31e6d99d60f5f99fd8c1225c9dad24aeb74a91e1d29"}, + {file = "sqlalchemy-2.0.43-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f42f23e152e4545157fa367b2435a1ace7571cab016ca26038867eb7df2c3631"}, + {file = "sqlalchemy-2.0.43-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4fb1a8c5438e0c5ea51afe9c6564f951525795cf432bed0c028c1cb081276685"}, + {file = "sqlalchemy-2.0.43-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db691fa174e8f7036afefe3061bc40ac2b770718be2862bfb03aabae09051aca"}, + {file = "sqlalchemy-2.0.43-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fe2b3b4927d0bc03d02ad883f402d5de201dbc8894ac87d2e981e7d87430e60d"}, + {file = "sqlalchemy-2.0.43-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4d3d9b904ad4a6b175a2de0738248822f5ac410f52c2fd389ada0b5262d6a1e3"}, + {file = "sqlalchemy-2.0.43-cp311-cp311-win32.whl", hash = "sha256:5cda6b51faff2639296e276591808c1726c4a77929cfaa0f514f30a5f6156921"}, + {file = "sqlalchemy-2.0.43-cp311-cp311-win_amd64.whl", hash = "sha256:c5d1730b25d9a07727d20ad74bc1039bbbb0a6ca24e6769861c1aa5bf2c4c4a8"}, + {file = "sqlalchemy-2.0.43-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:20d81fc2736509d7a2bd33292e489b056cbae543661bb7de7ce9f1c0cd6e7f24"}, + {file = "sqlalchemy-2.0.43-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25b9fc27650ff5a2c9d490c13c14906b918b0de1f8fcbb4c992712d8caf40e83"}, + {file = "sqlalchemy-2.0.43-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6772e3ca8a43a65a37c88e2f3e2adfd511b0b1da37ef11ed78dea16aeae85bd9"}, + {file = "sqlalchemy-2.0.43-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a113da919c25f7f641ffbd07fbc9077abd4b3b75097c888ab818f962707eb48"}, + {file = "sqlalchemy-2.0.43-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:4286a1139f14b7d70141c67a8ae1582fc2b69105f1b09d9573494eb4bb4b2687"}, + {file = "sqlalchemy-2.0.43-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:529064085be2f4d8a6e5fab12d36ad44f1909a18848fcfbdb59cc6d4bbe48efe"}, + {file = "sqlalchemy-2.0.43-cp312-cp312-win32.whl", hash = "sha256:b535d35dea8bbb8195e7e2b40059e2253acb2b7579b73c1b432a35363694641d"}, + {file = "sqlalchemy-2.0.43-cp312-cp312-win_amd64.whl", hash = "sha256:1c6d85327ca688dbae7e2b06d7d84cfe4f3fffa5b5f9e21bb6ce9d0e1a0e0e0a"}, + {file = "sqlalchemy-2.0.43-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e7c08f57f75a2bb62d7ee80a89686a5e5669f199235c6d1dac75cd59374091c3"}, + {file = "sqlalchemy-2.0.43-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:14111d22c29efad445cd5021a70a8b42f7d9152d8ba7f73304c4d82460946aaa"}, + {file = "sqlalchemy-2.0.43-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21b27b56eb2f82653168cefe6cb8e970cdaf4f3a6cb2c5e3c3c1cf3158968ff9"}, + {file = "sqlalchemy-2.0.43-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c5a9da957c56e43d72126a3f5845603da00e0293720b03bde0aacffcf2dc04f"}, + {file = "sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5d79f9fdc9584ec83d1b3c75e9f4595c49017f5594fee1a2217117647225d738"}, + {file = "sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9df7126fd9db49e3a5a3999442cc67e9ee8971f3cb9644250107d7296cb2a164"}, + {file = "sqlalchemy-2.0.43-cp313-cp313-win32.whl", hash = "sha256:7f1ac7828857fcedb0361b48b9ac4821469f7694089d15550bbcf9ab22564a1d"}, + {file = "sqlalchemy-2.0.43-cp313-cp313-win_amd64.whl", hash = "sha256:971ba928fcde01869361f504fcff3b7143b47d30de188b11c6357c0505824197"}, + {file = "sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc"}, + {file = "sqlalchemy-2.0.43.tar.gz", hash = "sha256:788bfcef6787a7764169cfe9859fe425bf44559619e1d9f56f5bddf2ebf6f417"}, +] + [[package]] 
name = "sse-starlette" version = "3.0.2" diff --git a/pyproject.toml b/pyproject.toml index ec838d81..7528fcc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -229,6 +229,7 @@ chat = [ "asgi-correlation-id>=4.3.4", "prometheus-fastapi-instrumentator>=7.1.0", "starlette-context>=0.4.0", + "sqlalchemy>=2.0.43", ] pynput = [ "mss>=10.0.0", diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py index d5fa2080..496815fd 100644 --- a/src/askui/chat/api/app.py +++ b/src/askui/chat/api/app.py @@ -6,8 +6,9 @@ from fastapi.responses import JSONResponse from fastmcp import FastMCP -from askui.chat.api.assistants.dependencies import get_assistant_service from askui.chat.api.assistants.router import router as assistants_router +from askui.chat.api.db.orm.base import Base +from askui.chat.api.db.session import engine from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings from askui.chat.api.files.router import router as files_router from askui.chat.api.health.router import router as health_router @@ -36,8 +37,8 @@ @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 - assistant_service = get_assistant_service(settings=settings) - assistant_service.seed() + # TODO Move to mgiration script + Base.metadata.create_all(bind=engine) # type: ignore mcp_config_service = get_mcp_config_service(settings=settings) mcp_config_service.seed() yield diff --git a/src/askui/chat/api/assistants/dependencies.py b/src/askui/chat/api/assistants/dependencies.py index d0d99dfb..3211094b 100644 --- a/src/askui/chat/api/assistants/dependencies.py +++ b/src/askui/chat/api/assistants/dependencies.py @@ -1,13 +1,14 @@ from fastapi import Depends from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.dependencies import SettingsDep -from askui.chat.api.settings import Settings +from askui.chat.api.db.session import SessionDep -def get_assistant_service(settings: Settings = SettingsDep) -> 
AssistantService: +def get_assistant_service( + session: SessionDep, +) -> AssistantService: """Get AssistantService instance.""" - return AssistantService(settings.data_dir) + return AssistantService(session) AssistantServiceDep = Depends(get_assistant_service) diff --git a/src/askui/chat/api/assistants/models.py b/src/askui/chat/api/assistants/models.py index 9d3a23aa..da18a7e3 100644 --- a/src/askui/chat/api/assistants/models.py +++ b/src/askui/chat/api/assistants/models.py @@ -18,11 +18,11 @@ class AssistantBase(BaseModel): system: str | None = None -class AssistantCreateParams(AssistantBase): +class AssistantCreate(AssistantBase): """Parameters for creating an assistant.""" -class AssistantModifyParams(BaseModelWithNotGiven): +class AssistantModify(BaseModelWithNotGiven): """Parameters for modifying an assistant.""" name: str | NotGiven = NOT_GIVEN @@ -41,7 +41,7 @@ class Assistant(AssistantBase, WorkspaceResource): @classmethod def create( - cls, workspace_id: WorkspaceId, params: AssistantCreateParams + cls, workspace_id: WorkspaceId | None, params: AssistantCreate ) -> "Assistant": return cls( id=generate_time_ordered_id("asst"), @@ -49,11 +49,3 @@ def create( workspace_id=workspace_id, **params.model_dump(), ) - - def modify(self, params: AssistantModifyParams) -> "Assistant": - return Assistant.model_validate( - { - **self.model_dump(), - **params.model_dump(), - } - ) diff --git a/src/askui/chat/api/assistants/orms.py b/src/askui/chat/api/assistants/orms.py new file mode 100644 index 00000000..b55e1cd1 --- /dev/null +++ b/src/askui/chat/api/assistants/orms.py @@ -0,0 +1,40 @@ +"""Assistant database model.""" + +from datetime import datetime +from uuid import UUID + +from sqlalchemy import JSON, String, Text, Uuid +from sqlalchemy.orm import Mapped, mapped_column + +from askui.chat.api.assistants.models import Assistant +from askui.chat.api.db.orm.base import Base +from askui.chat.api.db.orm.types import UnixDatetime, create_prefixed_id_type + 
+AssistantId = create_prefixed_id_type("asst") + + +class AssistantOrm(Base): + """Assistant database model.""" + + __tablename__ = "assistants" + + id: Mapped[str] = mapped_column(AssistantId, primary_key=True) + workspace_id: Mapped[UUID | None] = mapped_column(Uuid, nullable=True, index=True) + created_at: Mapped[datetime] = mapped_column( + UnixDatetime, nullable=False, index=True + ) + name: Mapped[str | None] = mapped_column(String, nullable=True) + description: Mapped[str | None] = mapped_column(String, nullable=True) + avatar: Mapped[str | None] = mapped_column(Text, nullable=True) + tools: Mapped[list[str]] = mapped_column(JSON, nullable=False) + system: Mapped[str | None] = mapped_column(Text, nullable=True) + + @classmethod + def from_model(cls, model: Assistant) -> "AssistantOrm": + return cls( + **model.model_dump(exclude={"object", "created_at"}), + created_at=model.created_at, + ) + + def to_model(self) -> Assistant: + return Assistant.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/assistants/router.py b/src/askui/chat/api/assistants/router.py index 76ae8ab4..15888257 100644 --- a/src/askui/chat/api/assistants/router.py +++ b/src/askui/chat/api/assistants/router.py @@ -3,11 +3,7 @@ from fastapi import APIRouter, Header, status from askui.chat.api.assistants.dependencies import AssistantServiceDep -from askui.chat.api.assistants.models import ( - Assistant, - AssistantCreateParams, - AssistantModifyParams, -) +from askui.chat.api.assistants.models import Assistant, AssistantCreate, AssistantModify from askui.chat.api.assistants.service import AssistantService from askui.chat.api.dependencies import ListQueryDep from askui.chat.api.models import AssistantId, WorkspaceId @@ -18,7 +14,7 @@ @router.get("") def list_assistants( - askui_workspace: Annotated[WorkspaceId | None, Header()], + askui_workspace: Annotated[WorkspaceId | None, Header()] = None, query: ListQuery = ListQueryDep, assistant_service: AssistantService = 
AssistantServiceDep, ) -> ListResponse[Assistant]: @@ -27,7 +23,7 @@ def list_assistants( @router.post("", status_code=status.HTTP_201_CREATED) def create_assistant( - params: AssistantCreateParams, + params: AssistantCreate, askui_workspace: Annotated[WorkspaceId, Header()], assistant_service: AssistantService = AssistantServiceDep, ) -> Assistant: @@ -37,7 +33,7 @@ def create_assistant( @router.get("/{assistant_id}") def retrieve_assistant( assistant_id: AssistantId, - askui_workspace: Annotated[WorkspaceId | None, Header()], + askui_workspace: Annotated[WorkspaceId | None, Header()] = None, assistant_service: AssistantService = AssistantServiceDep, ) -> Assistant: return assistant_service.retrieve( @@ -49,7 +45,7 @@ def retrieve_assistant( def modify_assistant( assistant_id: AssistantId, askui_workspace: Annotated[WorkspaceId, Header()], - params: AssistantModifyParams, + params: AssistantModify, assistant_service: AssistantService = AssistantServiceDep, ) -> Assistant: return assistant_service.modify( @@ -60,7 +56,7 @@ def modify_assistant( @router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_assistant( assistant_id: AssistantId, - askui_workspace: Annotated[WorkspaceId | None, Header()], + askui_workspace: Annotated[WorkspaceId, Header()], assistant_service: AssistantService = AssistantServiceDep, ) -> None: assistant_service.delete(workspace_id=askui_workspace, assistant_id=assistant_id) diff --git a/src/askui/chat/api/assistants/service.py b/src/askui/chat/api/assistants/service.py index 3c4248fc..a562687f 100644 --- a/src/askui/chat/api/assistants/service.py +++ b/src/askui/chat/api/assistants/service.py @@ -1,95 +1,84 @@ -from pathlib import Path +from sqlalchemy import or_ +from sqlalchemy.orm import Session -from askui.chat.api.assistants.models import ( - Assistant, - AssistantCreateParams, - AssistantModifyParams, -) -from askui.chat.api.assistants.seeds import SEEDS +from askui.chat.api.assistants.models import 
Assistant, AssistantCreate, AssistantModify +from askui.chat.api.assistants.orms import AssistantOrm +from askui.chat.api.db.queries import list_all from askui.chat.api.models import AssistantId, WorkspaceId -from askui.chat.api.utils import build_workspace_filter_fn -from askui.utils.api_utils import ( - LIST_LIMIT_MAX, - ConflictError, - ForbiddenError, - ListQuery, - ListResponse, - NotFoundError, - list_resources, -) +from askui.utils.api_utils import ForbiddenError, ListQuery, ListResponse, NotFoundError class AssistantService: - def __init__(self, base_dir: Path) -> None: - self._base_dir = base_dir - self._assistants_dir = base_dir / "assistants" - - def _get_assistant_path(self, assistant_id: AssistantId, new: bool = False) -> Path: - assistant_path = self._assistants_dir / f"{assistant_id}.json" - exists = assistant_path.exists() - if new and exists: - error_msg = f"Assistant {assistant_id} already exists" - raise ConflictError(error_msg) - if not new and not exists: - error_msg = f"Assistant {assistant_id} not found" - raise NotFoundError(error_msg) - return assistant_path + def __init__(self, session: Session) -> None: + self._session = session def list_( self, workspace_id: WorkspaceId | None, query: ListQuery ) -> ListResponse[Assistant]: - return list_resources( - self._assistants_dir, - query, - Assistant, - filter_fn=build_workspace_filter_fn(workspace_id, Assistant), + q = self._session.query(AssistantOrm).filter( + or_( + AssistantOrm.workspace_id == workspace_id, + AssistantOrm.workspace_id.is_(None), + ), ) + orms, has_more = list_all(q, query, AssistantOrm.id) + data = [orm.to_model() for orm in orms] + return ListResponse( + data=data, + has_more=has_more, + first_id=data[0].id if data else None, + last_id=data[-1].id if data else None, + ) + + def _find_by_id( + self, workspace_id: WorkspaceId | None, assistant_id: AssistantId + ) -> AssistantOrm: + assistant_orm = ( + self._session.query(AssistantOrm) + .filter( + AssistantOrm.id == 
assistant_id, + or_( + AssistantOrm.workspace_id == workspace_id, + AssistantOrm.workspace_id.is_(None), + ), + ) + .first() + ) + if assistant_orm is None: + error_msg = f"Assistant {assistant_id} not found" + raise NotFoundError(error_msg) + return assistant_orm def retrieve( self, workspace_id: WorkspaceId | None, assistant_id: AssistantId ) -> Assistant: - try: - assistant_path = self._get_assistant_path(assistant_id) - content = assistant_path.read_text() - if not content.strip(): - error_msg = f"Assistant {assistant_id} not found" - raise NotFoundError(error_msg) - assistant = Assistant.model_validate_json(content) - if not ( - assistant.workspace_id is None or assistant.workspace_id == workspace_id - ): - error_msg = f"Assistant {assistant_id} not found" - raise NotFoundError(error_msg) - except FileNotFoundError as e: - error_msg = f"Assistant {assistant_id} not found" - raise NotFoundError(error_msg) from e - except (ValueError, TypeError) as e: - # Handle JSON parsing errors - error_msg = f"Assistant {assistant_id} not found" - raise NotFoundError(error_msg) from e - else: - return assistant + assistant_model = self._find_by_id(workspace_id, assistant_id) + return assistant_model.to_model() def create( - self, workspace_id: WorkspaceId, params: AssistantCreateParams + self, workspace_id: WorkspaceId | None, params: AssistantCreate ) -> Assistant: assistant = Assistant.create(workspace_id, params) - self._save(assistant, new=True) + assistant_model = AssistantOrm.from_model(assistant) + self._session.add(assistant_model) + self._session.commit() return assistant def modify( self, - workspace_id: WorkspaceId, + workspace_id: WorkspaceId | None, assistant_id: AssistantId, - params: AssistantModifyParams, + params: AssistantModify, + force: bool = False, ) -> Assistant: - assistant = self.retrieve(workspace_id, assistant_id) - if assistant.workspace_id is None: + assistant_model = self._find_by_id(workspace_id, assistant_id) + if assistant_model.workspace_id 
is None and not force: error_msg = f"Default assistant {assistant_id} cannot be modified" raise ForbiddenError(error_msg) - modified = assistant.modify(params) - self._save(modified) - return modified + assistant_model.update(params.model_dump()) + self._session.commit() + self._session.refresh(assistant_model) + return assistant_model.to_model() def delete( self, @@ -97,35 +86,9 @@ def delete( assistant_id: AssistantId, force: bool = False, ) -> None: - try: - assistant = self.retrieve(workspace_id, assistant_id) - if assistant.workspace_id is None and not force: - error_msg = f"Default assistant {assistant_id} cannot be deleted" - raise ForbiddenError(error_msg) - try: - self._get_assistant_path(assistant_id).unlink() - except FileNotFoundError: - # File already deleted, that's fine - pass - except FileNotFoundError as e: - error_msg = f"Assistant {assistant_id} not found" - raise NotFoundError(error_msg) from e - except NotFoundError: - # If force=True and assistant doesn't exist, just ignore - if not force: - raise - # For force=True, we can ignore the NotFoundError - - def _save(self, assistant: Assistant, new: bool = False) -> None: - self._assistants_dir.mkdir(parents=True, exist_ok=True) - assistant_file = self._get_assistant_path(assistant.id, new=new) - assistant_file.write_text(assistant.model_dump_json(), encoding="utf-8") - - def seed(self) -> None: - """Seed the assistant service with default assistants.""" - for seed in SEEDS: - self.delete(None, seed.id, force=True) - try: - self._save(seed, new=True) - except ConflictError: # noqa: PERF203 - self._save(seed) + assistant_model = self._find_by_id(workspace_id, assistant_id) + if assistant_model.workspace_id is None and not force: + error_msg = f"Default assistant {assistant_id} cannot be deleted" + raise ForbiddenError(error_msg) + self._session.delete(assistant_model) + self._session.commit() diff --git a/src/askui/chat/api/db/__init__.py b/src/askui/chat/api/db/__init__.py new file mode 100644 
index 00000000..e69de29b diff --git a/src/askui/chat/api/db/engine.py b/src/askui/chat/api/db/engine.py new file mode 100644 index 00000000..4ed4fc43 --- /dev/null +++ b/src/askui/chat/api/db/engine.py @@ -0,0 +1,12 @@ +import logging + +from sqlalchemy import create_engine + +from askui.chat.api.dependencies import get_settings + +logger = logging.getLogger(__name__) + +settings = get_settings() +connect_args = {"check_same_thread": False} +echo = logger.isEnabledFor(logging.DEBUG) +engine = create_engine(settings.db.url, connect_args=connect_args, echo=echo) diff --git a/src/askui/chat/api/db/orm/__init__.py b/src/askui/chat/api/db/orm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/db/orm/base.py b/src/askui/chat/api/db/orm/base.py new file mode 100644 index 00000000..6e5ca243 --- /dev/null +++ b/src/askui/chat/api/db/orm/base.py @@ -0,0 +1,13 @@ +"""SQLAlchemy declarative base.""" + +from typing import Any + +from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass +from typing_extensions import Self + + +class Base(MappedAsDataclass, DeclarativeBase): + def update(self, values: dict[str, Any]) -> Self: + for key, value in values.items(): + setattr(self, key, value) + return self diff --git a/src/askui/chat/api/db/orm/types.py b/src/askui/chat/api/db/orm/types.py new file mode 100644 index 00000000..b6fe2cf1 --- /dev/null +++ b/src/askui/chat/api/db/orm/types.py @@ -0,0 +1,55 @@ +"""Custom SQLAlchemy types for chat API.""" + +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import Integer, String, TypeDecorator + + +def create_prefixed_id_type(prefix: str) -> type[TypeDecorator[str]]: + class PrefixedObjectId(TypeDecorator[str]): + impl = String(24) + cache_ok = True + + def process_bind_param(self, value: str | None, dialect: Any) -> str | None: + if value is None: + return value + return value[len(prefix) + 1 :] + + def process_result_value(self, value: str | None, dialect: Any) 
-> str | None: + if value is None: + return value + return f"{prefix}_{value}" + + return PrefixedObjectId + + +# Specialized types for each resource +# TODO Move into orms.py of the respective resource +ThreadId = create_prefixed_id_type("thread") +MessageId = create_prefixed_id_type("msg") +RunId = create_prefixed_id_type("run") +FileId = create_prefixed_id_type("file") +WorkflowId = create_prefixed_id_type("workflow") +McpConfigId = create_prefixed_id_type("mcp") + + +class UnixDatetime(TypeDecorator[datetime]): + impl = Integer + LOCAL_TIMEZONE = datetime.now().astimezone().tzinfo + + def process_bind_param( + self, value: datetime | int | None, dialect: Any + ) -> int | None: + if value is None: + return value + if isinstance(value, int): + return value + if value.tzinfo is None: + value = value.astimezone(self.LOCAL_TIMEZONE) + return int(value.astimezone(timezone.utc).timestamp()) + + def process_result_value(self, value: int | None, dialect: Any) -> datetime | None: + if value is None: + return value + return datetime.fromtimestamp(value, timezone.utc) diff --git a/src/askui/chat/api/db/queries.py b/src/askui/chat/api/db/queries.py new file mode 100644 index 00000000..ffb05550 --- /dev/null +++ b/src/askui/chat/api/db/queries.py @@ -0,0 +1,34 @@ +"""Shared query building utilities for database operations.""" + +from typing import Any, TypeVar + +from sqlalchemy import desc +from sqlalchemy.orm import InstrumentedAttribute, Query + +from askui.chat.api.db.orm.base import Base +from askui.utils.api_utils import ListQuery + +OrmT = TypeVar("OrmT", bound=Base) + + +def list_all( + db_query: Query[OrmT], + list_query: ListQuery, + id_column: InstrumentedAttribute[Any], +) -> tuple[list[OrmT], bool]: + if list_query.order == "asc": + if list_query.after: + db_query = db_query.filter(id_column > list_query.after) + if list_query.before: + db_query = db_query.filter(id_column < list_query.before) + db_query = db_query.order_by(id_column) + else: + if 
list_query.after: + db_query = db_query.filter(id_column < list_query.after) + if list_query.before: + db_query = db_query.filter(id_column > list_query.before) + db_query = db_query.order_by(desc(id_column)) + db_query = db_query.limit(list_query.limit + 1) + orms = db_query.all() + has_more = len(orms) > list_query.limit + return orms[: list_query.limit], has_more diff --git a/src/askui/chat/api/db/session.py b/src/askui/chat/api/db/session.py new file mode 100644 index 00000000..47118718 --- /dev/null +++ b/src/askui/chat/api/db/session.py @@ -0,0 +1,14 @@ +from typing import Annotated, Generator + +from fastapi import Depends +from sqlalchemy.orm import Session + +from askui.chat.api.db.engine import engine + + +def get_session() -> Generator[Session, None, None]: + with Session(engine) as session: + yield session + + +SessionDep = Annotated[Session, Depends(get_session)] diff --git a/src/askui/chat/api/settings.py b/src/askui/chat/api/settings.py index 18028e41..ee4388b0 100644 --- a/src/askui/chat/api/settings.py +++ b/src/askui/chat/api/settings.py @@ -1,7 +1,7 @@ from pathlib import Path from fastmcp.mcp_config import StdioMCPServer -from pydantic import Field +from pydantic import BaseModel, Field from pydantic_settings import BaseSettings, SettingsConfigDict from askui.chat.api.mcp_configs.models import McpConfig, RemoteMCPServer @@ -10,6 +10,15 @@ from askui.utils.datetime_utils import now +class DbSettings(BaseModel): + """Database configuration settings.""" + + url: str = Field( + default_factory=lambda: f"sqlite:///{(Path.cwd().absolute() / 'askui_chat.db').as_posix()}", + description="Database URL for SQLAlchemy connection", + ) + + def _get_default_mcp_configs(chat_api_host: str, chat_api_port: int) -> list[McpConfig]: return [ McpConfig( @@ -45,8 +54,9 @@ class Settings(BaseSettings): data_dir: Path = Field( default_factory=lambda: Path.cwd() / "chat", - description="Base directory for storing chat data", + description="Base directory for chat data 
(used during migration)", ) + db: DbSettings = Field(default_factory=DbSettings) host: str = Field( default="127.0.0.1", description="Host for the chat API", From 1eed6e20da9f52f47c50b3b534895fbcfd3aa596 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 14 Oct 2025 14:06:56 +0200 Subject: [PATCH 02/14] feat(chat): add migrations using alembic --- pdm.lock | 37 ++- pyproject.toml | 2 + src/askui/chat/api/app.py | 15 +- src/askui/chat/api/assistants/orms.py | 4 +- src/askui/chat/api/db/orm/types.py | 2 +- src/askui/chat/api/settings.py | 17 +- src/askui/chat/api/telemetry/logs/__init__.py | 10 + src/askui/chat/api/telemetry/logs/settings.py | 22 +- .../chat/api/telemetry/logs/structlog.py | 6 +- src/askui/chat/migrations/__init__.py | 0 src/askui/chat/migrations/alembic.ini | 114 +++++++ src/askui/chat/migrations/env.py | 71 ++++ src/askui/chat/migrations/runner.py | 23 ++ src/askui/chat/migrations/script.py.mako | 26 ++ src/askui/chat/migrations/shared/__init__.py | 0 .../migrations/shared/assistants/__init__.py | 0 .../migrations/shared/assistants/models.py | 35 ++ .../shared}/assistants/seeds.py | 307 ++++++++++++++++-- src/askui/chat/migrations/shared/models.py | 13 + src/askui/chat/migrations/shared/settings.py | 29 ++ src/askui/chat/migrations/shared/utils.py | 7 + .../057f82313448_import_json_assistants.py | 109 +++++++ .../37007a499ca7_remove_assistants_dir.py | 51 +++ .../4d1e043b4254_create_assistants_table.py | 40 +++ .../c35e88ea9595_seed_default_assistants.py | 61 ++++ 25 files changed, 943 insertions(+), 58 deletions(-) create mode 100644 src/askui/chat/migrations/__init__.py create mode 100644 src/askui/chat/migrations/alembic.ini create mode 100644 src/askui/chat/migrations/env.py create mode 100644 src/askui/chat/migrations/runner.py create mode 100644 src/askui/chat/migrations/script.py.mako create mode 100644 src/askui/chat/migrations/shared/__init__.py create mode 100644 src/askui/chat/migrations/shared/assistants/__init__.py create 
mode 100644 src/askui/chat/migrations/shared/assistants/models.py rename src/askui/chat/{api => migrations/shared}/assistants/seeds.py (63%) create mode 100644 src/askui/chat/migrations/shared/models.py create mode 100644 src/askui/chat/migrations/shared/settings.py create mode 100644 src/askui/chat/migrations/shared/utils.py create mode 100644 src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py create mode 100644 src/askui/chat/migrations/versions/37007a499ca7_remove_assistants_dir.py create mode 100644 src/askui/chat/migrations/versions/4d1e043b4254_create_assistants_table.py create mode 100644 src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py diff --git a/pdm.lock b/pdm.lock index 4c1ea011..20266647 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "android", "bedrock", "chat", "dev", "pynput", "vertex", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:a0f7f67a1fcacfb2f01865e1e523b861c4e2247dc7ec6578fe16eb60f639b956" +content_hash = "sha256:3dd80953b4e83ea83b341f9627ac0f45b11dd4b0ba1f69a19d5405a27c90a495" [[metadata.targets]] requires_python = ">=3.10" @@ -21,6 +21,23 @@ files = [ {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, ] +[[package]] +name = "alembic" +version = "1.16.5" +requires_python = ">=3.9" +summary = "A database migration tool for SQLAlchemy." 
+groups = ["all", "chat"] +dependencies = [ + "Mako", + "SQLAlchemy>=1.4.0", + "tomli; python_version < \"3.11\"", + "typing-extensions>=4.12", +] +files = [ + {file = "alembic-1.16.5-py3-none-any.whl", hash = "sha256:e845dfe090c5ffa7b92593ae6687c5cb1a101e91fa53868497dbd79847f9dbe3"}, + {file = "alembic-1.16.5.tar.gz", hash = "sha256:a88bb7f6e513bd4301ecf4c7f2206fe93f9913f9b48dac3b78babde2d6fe765e"}, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -2085,6 +2102,20 @@ files = [ {file = "magika-0.6.2.tar.gz", hash = "sha256:37eb6ae8020f6e68f231bc06052c0a0cbe8e6fa27492db345e8dc867dbceb067"}, ] +[[package]] +name = "mako" +version = "1.3.10" +requires_python = ">=3.8" +summary = "A super-fast templating language that borrows the best ideas from the existing templating languages." +groups = ["all", "chat"] +dependencies = [ + "MarkupSafe>=0.9.2", +] +files = [ + {file = "mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59"}, + {file = "mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28"}, +] + [[package]] name = "mammoth" version = "1.11.0" @@ -2173,7 +2204,7 @@ name = "markupsafe" version = "3.0.2" requires_python = ">=3.9" summary = "Safely add untrusted strings to HTML/XML markup." 
-groups = ["default", "dev"] +groups = ["default", "all", "chat", "dev"] files = [ {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"}, {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"}, @@ -4012,7 +4043,7 @@ name = "tomli" version = "2.2.1" requires_python = ">=3.8" summary = "A lil' TOML parser" -groups = ["dev"] +groups = ["all", "chat", "dev"] marker = "python_version <= \"3.11\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, diff --git a/pyproject.toml b/pyproject.toml index 7528fcc9..97c45dba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ path = "src/askui/__init__.py" distribution = true [tool.pdm.scripts] +alembic = "alembic -c src/askui/chat/migrations/alembic.ini" test = "pytest -n auto" "test:cov" = "pytest -n auto --cov=src/askui --cov-report=html" "test:cov:view" = "python -m http.server --directory htmlcov" @@ -230,6 +231,7 @@ chat = [ "prometheus-fastapi-instrumentator>=7.1.0", "starlette-context>=0.4.0", "sqlalchemy>=2.0.43", + "alembic>=1.16.5", ] pynput = [ "mss>=10.0.0", diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py index 496815fd..28f93c4d 100644 --- a/src/askui/chat/api/app.py +++ b/src/askui/chat/api/app.py @@ -1,3 +1,4 @@ +import logging from contextlib import asynccontextmanager from typing import AsyncGenerator @@ -7,8 +8,6 @@ from fastmcp import FastMCP from askui.chat.api.assistants.router import router as assistants_router -from askui.chat.api.db.orm.base import Base -from askui.chat.api.db.session import engine from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings from askui.chat.api.files.router import router as files_router from askui.chat.api.health.router import router as 
health_router @@ -24,6 +23,7 @@ from askui.chat.api.runs.router import router as runs_router from askui.chat.api.threads.router import router as threads_router from askui.chat.api.workflows.router import router as workflows_router +from askui.chat.migrations.runner import run_migrations from askui.utils.api_utils import ( ConflictError, FileTooLargeError, @@ -32,16 +32,23 @@ NotFoundError, ) +logger = logging.getLogger(__name__) + + settings = get_settings() @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 - # TODO Move to mgiration script - Base.metadata.create_all(bind=engine) # type: ignore + if settings.db.auto_migrate: + run_migrations() + else: + logger.info("Automatic migrations are disabled. Skipping migrations...") + logger.info("Seeding default MCP configurations...") mcp_config_service = get_mcp_config_service(settings=settings) mcp_config_service.seed() yield + logger.info("Disconnecting all MCP clients...") await get_mcp_client_manager_manager(mcp_config_service).disconnect_all(force=True) diff --git a/src/askui/chat/api/assistants/orms.py b/src/askui/chat/api/assistants/orms.py index b55e1cd1..4c84f875 100644 --- a/src/askui/chat/api/assistants/orms.py +++ b/src/askui/chat/api/assistants/orms.py @@ -20,9 +20,7 @@ class AssistantOrm(Base): id: Mapped[str] = mapped_column(AssistantId, primary_key=True) workspace_id: Mapped[UUID | None] = mapped_column(Uuid, nullable=True, index=True) - created_at: Mapped[datetime] = mapped_column( - UnixDatetime, nullable=False, index=True - ) + created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) name: Mapped[str | None] = mapped_column(String, nullable=True) description: Mapped[str | None] = mapped_column(String, nullable=True) avatar: Mapped[str | None] = mapped_column(Text, nullable=True) diff --git a/src/askui/chat/api/db/orm/types.py b/src/askui/chat/api/db/orm/types.py index b6fe2cf1..f1fe85c2 100644 --- a/src/askui/chat/api/db/orm/types.py 
+++ b/src/askui/chat/api/db/orm/types.py @@ -14,7 +14,7 @@ class PrefixedObjectId(TypeDecorator[str]): def process_bind_param(self, value: str | None, dialect: Any) -> str | None: if value is None: return value - return value[len(prefix) + 1 :] + return value.removeprefix(f"{prefix}_") def process_result_value(self, value: str | None, dialect: Any) -> str | None: if value is None: diff --git a/src/askui/chat/api/settings.py b/src/askui/chat/api/settings.py index ee4388b0..72e9b6c9 100644 --- a/src/askui/chat/api/settings.py +++ b/src/askui/chat/api/settings.py @@ -1,7 +1,7 @@ from pathlib import Path from fastmcp.mcp_config import StdioMCPServer -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict from askui.chat.api.mcp_configs.models import McpConfig, RemoteMCPServer @@ -17,6 +17,21 @@ class DbSettings(BaseModel): default_factory=lambda: f"sqlite:///{(Path.cwd().absolute() / 'askui_chat.db').as_posix()}", description="Database URL for SQLAlchemy connection", ) + auto_migrate: bool = Field( + default=True, + description="Whether to run migrations automatically on startup", + ) + + @field_validator("url") + @classmethod + def validate_sqlite_url(cls, v: str) -> str: + """Ensure only synchronous SQLite URLs are allowed.""" + if not v.startswith("sqlite://"): + error_msg = ( + "Only synchronous SQLite URLs are allowed (must start with 'sqlite://')" + ) + raise ValueError(error_msg) + return v def _get_default_mcp_configs(chat_api_host: str, chat_api_port: int) -> list[McpConfig]: diff --git a/src/askui/chat/api/telemetry/logs/__init__.py b/src/askui/chat/api/telemetry/logs/__init__.py index d572fd8c..739ec2d8 100644 --- a/src/askui/chat/api/telemetry/logs/__init__.py +++ b/src/askui/chat/api/telemetry/logs/__init__.py @@ -7,6 +7,8 @@ from .settings import LogSettings from .structlog import setup_structlog +logger = logging_stdlib.getLogger(__name__) + def 
setup_uncaught_exception_logging(logger: logging_stdlib.Logger) -> None: def handle_uncaught_exception( @@ -44,11 +46,19 @@ def silence_logs(loggers: list[str]) -> None: logger.propagate = False +_logging_setup = False + + def setup_logging( settings: LogSettings, pre_processors: list[structlog_types.Processor] | None = None, ) -> None: + global _logging_setup + if _logging_setup: + logger.debug("Logging already setup. Skipping setup...") + return logging_stdlib.captureWarnings(True) root_logger = logging_stdlib.getLogger() setup_structlog(root_logger, settings, pre_processors) setup_uncaught_exception_logging(root_logger) + _logging_setup = True diff --git a/src/askui/chat/api/telemetry/logs/settings.py b/src/askui/chat/api/telemetry/logs/settings.py index eb663101..81b2386f 100644 --- a/src/askui/chat/api/telemetry/logs/settings.py +++ b/src/askui/chat/api/telemetry/logs/settings.py @@ -1,20 +1,12 @@ -import enum +import logging from typing import Literal -from pydantic import BaseModel +from pydantic import BaseModel, Field +logger = logging.getLogger(__name__) -class LogLevel(str, enum.Enum): - CRITICAL = "CRITICAL" - ERROR = "ERROR" - WARNING = "WARNING" - INFO = "INFO" - DEBUG = "DEBUG" - - -class LogFormat(str, enum.Enum): - JSON = "json" - LOGFMT = "logfmt" +LogFormat = Literal["JSON", "LOGFMT"] +LogLevel = Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] class EqualsLogFilter(BaseModel): @@ -27,6 +19,6 @@ class EqualsLogFilter(BaseModel): class LogSettings(BaseModel): - format: LogFormat = LogFormat.LOGFMT - level: LogLevel = LogLevel.INFO + format: LogFormat = Field("LOGFMT") + level: LogLevel = Field("INFO") filters: list[LogFilter] | None = None diff --git a/src/askui/chat/api/telemetry/logs/structlog.py b/src/askui/chat/api/telemetry/logs/structlog.py index 20ea4110..c2f61e8f 100644 --- a/src/askui/chat/api/telemetry/logs/structlog.py +++ b/src/askui/chat/api/telemetry/logs/structlog.py @@ -38,7 +38,7 @@ def configure_stdlib_logger( handler = 
logging.StreamHandler() handler.setFormatter(formatter) logger.addHandler(handler) - logger.setLevel(log_level.value) + logger.setLevel(log_level) EVENT_KEY = "message" @@ -67,7 +67,7 @@ def get_shared_processors(settings: LogSettings) -> list[structlog.types.Process def get_format_dependent_processors( log_format: LogFormat, ) -> list[structlog.types.Processor]: - if log_format == LogFormat.JSON: + if log_format == "JSON": return [structlog.processors.format_exc_info] return [ structlog.dev.set_exc_info, @@ -76,6 +76,6 @@ def get_format_dependent_processors( def get_renderer(log_format: LogFormat) -> structlog.types.Processor: - if log_format == LogFormat.JSON: + if log_format == "JSON": return structlog.processors.JSONRenderer() return structlog.dev.ConsoleRenderer(event_key=EVENT_KEY) diff --git a/src/askui/chat/migrations/__init__.py b/src/askui/chat/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/migrations/alembic.ini b/src/askui/chat/migrations/alembic.ini new file mode 100644 index 00000000..a035c7ac --- /dev/null +++ b/src/askui/chat/migrations/alembic.ini @@ -0,0 +1,114 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = src/askui/chat/migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. 
+# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version number format +version_num_format = %%04d + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses +# os.pathsep. If this key is omitted entirely, it falls back to the legacy +# behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# overridden inside env.py +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +# Overridden by env.py +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = INFO +handlers = +qualname = + +[logger_sqlalchemy] +level = INFO +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/askui/chat/migrations/env.py b/src/askui/chat/migrations/env.py new file mode 100644 index 00000000..5b83ccc7 --- /dev/null +++ b/src/askui/chat/migrations/env.py @@ -0,0 +1,71 @@ +"""Alembic environment configuration.""" + +import logging + +from alembic import context + +# We need to import the orms to ensure they are registered +import askui.chat.api.assistants.orms +from askui.chat.api.db.orm.base import Base +from askui.chat.api.dependencies import get_settings +from askui.chat.api.telemetry.logs import setup_logging + +config = context.config +settings = get_settings() +setup_logging(settings.telemetry.log) +sqlalchemy_logger = logging.getLogger("sqlalchemy.engine") +alembic_logger = logging.getLogger("alembic") +sqlalchemy_logger.setLevel(settings.telemetry.log.level) +alembic_logger.setLevel(settings.telemetry.log.level) +target_metadata = Base.metadata + + +def get_url() -> str: + """Get database URL from settings.""" 
+ return settings.db.url + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = get_url() + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + from askui.chat.api.db.engine import engine + + with engine.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/askui/chat/migrations/runner.py b/src/askui/chat/migrations/runner.py new file mode 100644 index 00000000..014a9f74 --- /dev/null +++ b/src/askui/chat/migrations/runner.py @@ -0,0 +1,23 @@ +"""Migration runner for Alembic.""" + +import logging +from pathlib import Path + +from alembic import command +from alembic.config import Config + +logger = logging.getLogger(__name__) + + +def run_migrations() -> None: + """Run Alembic migrations to upgrade database to head.""" + migrations_dir = Path(__file__).parent + alembic_cfg = Config(str(migrations_dir / "alembic.ini")) + alembic_cfg.set_main_option("script_location", str(migrations_dir)) + logger.info("Running database migrations...") + try: + command.upgrade(alembic_cfg, "head") + logger.info("Database migrations completed successfully") + except Exception: + logger.exception("Failed to run database 
migrations") + raise diff --git a/src/askui/chat/migrations/script.py.mako b/src/askui/chat/migrations/script.py.mako new file mode 100644 index 00000000..fbc4b07d --- /dev/null +++ b/src/askui/chat/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/src/askui/chat/migrations/shared/__init__.py b/src/askui/chat/migrations/shared/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/migrations/shared/assistants/__init__.py b/src/askui/chat/migrations/shared/assistants/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/migrations/shared/assistants/models.py b/src/askui/chat/migrations/shared/assistants/models.py new file mode 100644 index 00000000..99d9ddf8 --- /dev/null +++ b/src/askui/chat/migrations/shared/assistants/models.py @@ -0,0 +1,35 @@ +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, BeforeValidator, Field + +from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 + + +def add_prefix(id_: str) -> str: + if id_.startswith("asst_"): + return id_ + return f"asst_{id_}" + + +AssistantIdV1 = Annotated[ + str, Field(pattern=r"^asst_[a-z0-9]+$"), BeforeValidator(add_prefix) +] + + +class AssistantV1(BaseModel): + id: AssistantIdV1 + object: Literal["assistant"] = "assistant" + created_at: 
UnixDatetimeV1 + workspace_id: WorkspaceIdV1 | None = None + name: str | None = None + description: str | None = None + avatar: str | None = None + tools: list[str] = Field(default_factory=list) + system: str | None = None + + def to_db_dict(self) -> dict[str, Any]: + return { + **self.model_dump(exclude={"id", "object", "workspace_id"}), + "id": self.id.removeprefix("asst_"), + "workspace_id": str(self.workspace_id) if self.workspace_id else None, + } diff --git a/src/askui/chat/api/assistants/seeds.py b/src/askui/chat/migrations/shared/assistants/seeds.py similarity index 63% rename from src/askui/chat/api/assistants/seeds.py rename to src/askui/chat/migrations/shared/assistants/seeds.py index a49e0b5a..0b7fe8bc 100644 --- a/src/askui/chat/api/assistants/seeds.py +++ b/src/askui/chat/migrations/shared/assistants/seeds.py @@ -1,19 +1,28 @@ -from askui.chat.api.assistants.models import Assistant -from askui.prompts.system import ( - ANDROID_AGENT_SYSTEM_PROMPT, - COMPUTER_AGENT_SYSTEM_PROMPT, - ORCHESTRATOR_AGENT_SYSTEM_PROMPT, - TESTING_AGENT_SYSTEM_PROMPT, - WEB_AGENT_SYSTEM_PROMPT, -) -from askui.utils.datetime_utils import now +import platform +import sys + +from askui.chat.migrations.shared.assistants.models import AssistantV1 +from askui.chat.migrations.shared.utils import now_v1 -COMPUTER_AGENT = Assistant( +COMPUTER_AGENT_V1 = AssistantV1( id="asst_68ac2c4edc4b2f27faa5a253", - created_at=now(), + created_at=now_v1(), name="Computer Agent", 
avatar="data:image/webp;base64,UklGRswRAABXRUJQVlA4WAoAAAA4AAAAPwAAPwAASUNDUEgMAAAAAAxITGlubwIQAABtbnRyUkdCIFhZWiAHzgACAAkABgAxAABhY3NwTVNGVAAAAABJRUMgc1JHQgAAAAAAAAAAAAAAAQAA9tYAAQAAAADTLUhQICAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABFjcHJ0AAABUAAAADNkZXNjAAABhAAAAGx3dHB0AAAB8AAAABRia3B0AAACBAAAABRyWFlaAAACGAAAABRnWFlaAAACLAAAABRiWFlaAAACQAAAABRkbW5kAAACVAAAAHBkbWRkAAACxAAAAIh2dWVkAAADTAAAAIZ2aWV3AAAD1AAAACRsdW1pAAAD+AAAABRtZWFzAAAEDAAAACR0ZWNoAAAEMAAAAAxyVFJDAAAEPAAACAxnVFJDAAAEPAAACAxiVFJDAAAEPAAACAx0ZXh0AAAAAENvcHlyaWdodCAoYykgMTk5OCBIZXdsZXR0LVBhY2thcmQgQ29tcGFueQAAZGVzYwAAAAAAAAASc1JHQiBJRUM2MTk2Ni0yLjEAAAAAAAAAAAAAABJzUkdCIElFQzYxOTY2LTIuMQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWFlaIAAAAAAAAPNRAAEAAAABFsxYWVogAAAAAAAAAAAAAAAAAAAAAFhZWiAAAAAAAABvogAAOPUAAAOQWFlaIAAAAAAAAGKZAAC3hQAAGNpYWVogAAAAAAAAJKAAAA+EAAC2z2Rlc2MAAAAAAAAAFklFQyBodHRwOi8vd3d3LmllYy5jaAAAAAAAAAAAAAAAFklFQyBodHRwOi8vd3d3LmllYy5jaAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABkZXNjAAAAAAAAAC5JRUMgNjE5NjYtMi4xIERlZmF1bHQgUkdCIGNvbG91ciBzcGFjZSAtIHNSR0IAAAAAAAAAAAAAAC5JRUMgNjE5NjYtMi4xIERlZmF1bHQgUkdCIGNvbG91ciBzcGFjZSAtIHNSR0IAAAAAAAAAAAAAAAAAAAAAAAAAAAAAZGVzYwAAAAAAAAAsUmVmZXJlbmNlIFZpZXdpbmcgQ29uZGl0aW9uIGluIElFQzYxOTY2LTIuMQAAAAAAAAAAAAAALFJlZmVyZW5jZSBWaWV3aW5nIENvbmRpdGlvbiBpbiBJRUM2MTk2Ni0yLjEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHZpZXcAAAAAABOk/gAUXy4AEM8UAAPtzAAEEwsAA1yeAAAAAVhZWiAAAAAAAEwJVgBQAAAAVx/nbWVhcwAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAAAAAo8AAAACc2lnIAAAAABDUlQgY3VydgAAAAAAAAQAAAAABQAKAA8AFAAZAB4AIwAoAC0AMgA3ADsAQABFAEoATwBUAFkAXgBjAGgAbQByAHcAfACBAIYAiwCQAJUAmgCfAKQAqQCuALIAtwC8AMEAxgDLANAA1QDbAOAA5QDrAPAA9gD7AQEBBwENARMBGQEfASUBKwEyATgBPgFFAUwBUgFZAWABZwFuAXUBfAGDAYsBkgGaAaEBqQGxAbkBwQHJAdEB2QHhAekB8gH6AgMCDAIUAh0CJgIvAjgCQQJLAlQCXQJnAnECegKEAo4CmAKiAqwCtgLBAssC1QLgAusC9QMAAwsDFgMhAy0DOANDA08DWgNmA3IDfgOKA5YDogOuA7oDxwPTA+AD7AP5BAYEEwQgBC0EOwRIBFUEYwRxBH4EjASaBKgEtgTEBNME4QTwBP4FDQUcBSsFOgVJBVgFZwV3BYYFlgWmBbUFxQXVBeUF9gYGB
hYGJwY3BkgGWQZqBnsGjAadBq8GwAbRBuMG9QcHBxkHKwc9B08HYQd0B4YHmQesB78H0gflB/gICwgfCDIIRghaCG4IggiWCKoIvgjSCOcI+wkQCSUJOglPCWQJeQmPCaQJugnPCeUJ+woRCicKPQpUCmoKgQqYCq4KxQrcCvMLCwsiCzkLUQtpC4ALmAuwC8gL4Qv5DBIMKgxDDFwMdQyODKcMwAzZDPMNDQ0mDUANWg10DY4NqQ3DDd4N+A4TDi4OSQ5kDn8Omw62DtIO7g8JDyUPQQ9eD3oPlg+zD88P7BAJECYQQxBhEH4QmxC5ENcQ9RETETERTxFtEYwRqhHJEegSBxImEkUSZBKEEqMSwxLjEwMTIxNDE2MTgxOkE8UT5RQGFCcUSRRqFIsUrRTOFPAVEhU0FVYVeBWbFb0V4BYDFiYWSRZsFo8WshbWFvoXHRdBF2UXiReuF9IX9xgbGEAYZRiKGK8Y1Rj6GSAZRRlrGZEZtxndGgQaKhpRGncanhrFGuwbFBs7G2MbihuyG9ocAhwqHFIcexyjHMwc9R0eHUcdcB2ZHcMd7B4WHkAeah6UHr4e6R8THz4faR+UH78f6iAVIEEgbCCYIMQg8CEcIUghdSGhIc4h+yInIlUigiKvIt0jCiM4I2YjlCPCI/AkHyRNJHwkqyTaJQklOCVoJZclxyX3JicmVyaHJrcm6CcYJ0kneierJ9woDSg/KHEooijUKQYpOClrKZ0p0CoCKjUqaCqbKs8rAis2K2krnSvRLAUsOSxuLKIs1y0MLUEtdi2rLeEuFi5MLoIuty7uLyQvWi+RL8cv/jA1MGwwpDDbMRIxSjGCMbox8jIqMmMymzLUMw0zRjN/M7gz8TQrNGU0njTYNRM1TTWHNcI1/TY3NnI2rjbpNyQ3YDecN9c4FDhQOIw4yDkFOUI5fzm8Ofk6Njp0OrI67zstO2s7qjvoPCc8ZTykPOM9Ij1hPaE94D4gPmA+oD7gPyE/YT+iP+JAI0BkQKZA50EpQWpBrEHuQjBCckK1QvdDOkN9Q8BEA0RHRIpEzkUSRVVFmkXeRiJGZ0arRvBHNUd7R8BIBUhLSJFI10kdSWNJqUnwSjdKfUrESwxLU0uaS+JMKkxyTLpNAk1KTZNN3E4lTm5Ot08AT0lPk0/dUCdQcVC7UQZRUFGbUeZSMVJ8UsdTE1NfU6pT9lRCVI9U21UoVXVVwlYPVlxWqVb3V0RXklfgWC9YfVjLWRpZaVm4WgdaVlqmWvVbRVuVW+VcNVyGXNZdJ114XcleGl5sXr1fD19hX7NgBWBXYKpg/GFPYaJh9WJJYpxi8GNDY5dj62RAZJRk6WU9ZZJl52Y9ZpJm6Gc9Z5Nn6Wg/aJZo7GlDaZpp8WpIap9q92tPa6dr/2xXbK9tCG1gbbluEm5rbsRvHm94b9FwK3CGcOBxOnGVcfByS3KmcwFzXXO4dBR0cHTMdSh1hXXhdj52m3b4d1Z3s3gReG54zHkqeYl553pGeqV7BHtje8J8IXyBfOF9QX2hfgF+Yn7CfyN/hH/lgEeAqIEKgWuBzYIwgpKC9INXg7qEHYSAhOOFR4Wrhg6GcobXhzuHn4gEiGmIzokziZmJ/opkisqLMIuWi/yMY4zKjTGNmI3/jmaOzo82j56QBpBukNaRP5GokhGSepLjk02TtpQglIqU9JVflcmWNJaflwqXdZfgmEyYuJkkmZCZ/JpomtWbQpuvnByciZz3nWSd0p5Anq6fHZ+Ln/qgaaDYoUehtqImopajBqN2o+akVqTHpTilqaYapoum/adup+CoUqjEqTepqaocqo+rAqt1q+msXKzQrUStuK4trqGvFq+LsACwdbDqsWCx1rJLssKzOLOutCW0nLUTtYq2AbZ5tvC3aLfguFm40blKucK6O7q1uy67p7whvJu9Fb2Pvgq+hL7/v3q/9cBwwOzBZ8Hjwl/C28NYw9TEUcTOxUvFyMZGxsPHQce/yD3IvMk6y
bnKOMq3yzbLtsw1zLXNNc21zjbOts83z7jQOdC60TzRvtI/0sHTRNPG1EnUy9VO1dHWVdbY11zX4Nhk2OjZbNnx2nba+9uA3AXcit0Q3ZbeHN6i3ynfr+A24L3hROHM4lPi2+Nj4+vkc+T85YTmDeaW5x/nqegy6LzpRunQ6lvq5etw6/vshu0R7ZzuKO6070DvzPBY8OXxcvH/8ozzGfOn9DT0wvVQ9d72bfb794r4Gfio+Tj5x/pX+uf7d/wH/Jj9Kf26/kv+3P9t//9BTFBISgIAAAGQRFubIUmRKLRt27btXo3tmZVtGyvbtm3btm1PKf9hRNQ/2kbEBJD/XtEzJDTSW/59UqZfffjo2bUV7dx+D8uOD+CXO2N+B8eZBqA8EfEbdAT6lXboHA4zKH3Q5bxngDsh2BooLEpzbK2AeQIy13VsA5DVUdiaIyv6wPQqA5n9YRbjWEtkZBTL8SAZW+YbOkN/Gwmb3SG6k6EaAZv1AboF9hLB7nqJ6lVNtYAu/Q3N255OEsHraPOTMUA7I1AroLFsf2pbdHBFQbtXNGeKnCSC0jK2dt0lBoCLd3Sf9EC7LFojYBCb7H5qNAFX3WofgjB07Hsw46K6bXpNTDZPnRtgXiMAnPY3R41ngLGtGYIvA8ozsdwsZwHSmw15NdBjgUfJnKYB3h1+XByPI4K5Wh7xLzG9z+NR/BmTsSaPeiZM7zJ49AHMpnY8xqGCCwlsmnW44HZPZxbvm8hA2R3H4HICG8AyLR3pge9J+O/2OIxOtQbfgyA6/9v4nkXSxb3Cdz+IzvUMvqcRdGQovvuBDFV6bKbhMoP7MWzX/AlrTNdBG3WYJqiZCCGWoxQ8t5IEHsTtCJ4jToRv5VUFy3SJEwndhaUz4d7OhONdEb8en42KguCol8it5sW3egwLrPhZZLUYMGnD1TeKmbpLArefar0zWq5+qtPr9XqdTvfl86dPHz+8f/fm5dOH927duH7l4rpIgZjfMqasqrqqsry8vLQwPz8vKyMtJS4yJMDHy9PNxZL8FQVBlCRZlmWVxlIjyGqNVqsSCSGCyAJWUDggmAIAAPAPAJ0BKkAAQAA+kTyZSaWjIiEoG/tQsBIJZwDJ+FvKFv67tp+RMGNcDv9lccKdO60gqH5MeC/EtYd0zUOkfunNPYDWjWiNhmKAoXQQo+GDUJN88TbmU3UWr3+Ddvu5s2Eukjyw96p2S2tDw6aI0pTAjeRIQ6n3XnUckPx9C41nFxlceofxDAAA/v02dN7UOf8yY/rkv0FMLUFIHHa9HnubqLmPG/0PkUpWsuTIsdDzQ+GNLWeNexqbAVHtbBcjuWw7309/2TyreZ4Urf9AIlZhs2YdWHB63q32OItvwlhMbe45JBwS+01b8Kv8Kys6xPiAbY2ffwyrFZto2clrktrTUGODsZj+xsiMg2gl61QSCU6DPDV/rkv4TSw71kah9dNnmusKgQ6kYtkXRrTI2mInVps+WK5IPA2358a8ksYj5rE4xoPWHkFx9zSZhyReGaEAObqnmScxyOC4U7O6+B4L26wEbKNLu3shOv9rwGNJuEpkB37Cke7M2WB0728OvDtbU/JtjA5E11BzeNp6Ax2+fOFvC+MGMvSevWynClCPHHT2nmylDPSwOCcGMx6g/q57JifzellyZiU/g448apyvOPqLjdWJLujKS6TyCoYSsLf9FeE6vFMZIPAIZaTR4gkqNvwkooY+t8DeCa+herqzE9IGXG7duDTcPz2sJaW6kgcDWu2IJ9413Lh+puu4DoZmBclBMEyfdFEGHOpjQAnQsrjckgWfbAqWF6+FaKFVf1hTeuhWERcMd9nNw2aRrKCxWia1VJTU0FQEPUG2osZQtEKcCIlcCWB4LurhOl1hs98Rqpu5ghk+cTd/ToGo12tm/Zmd+5O+j8XtgDIwF/NTWiHsPRndZq58MnOxBkOAaYBLYHRAAABFWElGbAAAAE1NACoAAAAQRXhpZk1ldGEABQEaA
AUAAAABAAAAUgEbAAUAAAABAAAAWgEoAAMAAAABAAIAAAExAAIAAAAKAAAAYgITAAMAAAABAAEAAAAAAAAAAABIAAAAAQAAAEgAAAABZXpnaWYuY29tAA==", - system=COMPUTER_AGENT_SYSTEM_PROMPT, + system=( + f""" +* You are utilising a {sys.platform} machine using {platform.machine()} architecture with internet access. +* When you cannot find something (application window, ui element etc.) on the currently selected/active displa/screen, check the other available displays by listing them and checking which one is currently active and then going through the other displays one by one until you find it or you have checked all of them. +* When asked to perform web tasks try to open the browser (firefox, chrome, safari, ...) if not already open. Often you can find the browser icons in the toolbars of the operating systems. +* When viewing a page it can be helpful to zoom out/in so that you can see everything on the page. Either that, or make sure you scroll down/up to see everything before deciding something isn't available. +* When using your function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. + + + +* When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. +* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. 
+""" + ), tools=[ "computer", "list_displays", @@ -22,12 +31,86 @@ ], ) -ANDROID_AGENT = Assistant( +ANDROID_AGENT_V1 = AssistantV1( id="asst_68ac2c4edc4b2f27faa5a255", - created_at=now(), + created_at=now_v1(), name="Android Agent", avatar="data:image/webp;base64,UklGRoIRAABXRUJQVlA4WAoAAAA4AAAAPwAAPwAASUNDUEgMAAAAAAxITGlubwIQAABtbnRyUkdCIFhZWiAHzgACAAkABgAxAABhY3NwTVNGVAAAAABJRUMgc1JHQgAAAAAAAAAAAAAAAQAA9tYAAQAAAADTLUhQICAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABFjcHJ0AAABUAAAADNkZXNjAAABhAAAAGx3dHB0AAAB8AAAABRia3B0AAACBAAAABRyWFlaAAACGAAAABRnWFlaAAACLAAAABRiWFlaAAACQAAAABRkbW5kAAACVAAAAHBkbWRkAAACxAAAAIh2dWVkAAADTAAAAIZ2aWV3AAAD1AAAACRsdW1pAAAD+AAAABRtZWFzAAAEDAAAACR0ZWNoAAAEMAAAAAxyVFJDAAAEPAAACAxnVFJDAAAEPAAACAxiVFJDAAAEPAAACAx0ZXh0AAAAAENvcHlyaWdodCAoYykgMTk5OCBIZXdsZXR0LVBhY2thcmQgQ29tcGFueQAAZGVzYwAAAAAAAAASc1JHQiBJRUM2MTk2Ni0yLjEAAAAAAAAAAAAAABJzUkdCIElFQzYxOTY2LTIuMQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWFlaIAAAAAAAAPNRAAEAAAABFsxYWVogAAAAAAAAAAAAAAAAAAAAAFhZWiAAAAAAAABvogAAOPUAAAOQWFlaIAAAAAAAAGKZAAC3hQAAGNpYWVogAAAAAAAAJKAAAA+EAAC2z2Rlc2MAAAAAAAAAFklFQyBodHRwOi8vd3d3LmllYy5jaAAAAAAAAAAAAAAAFklFQyBodHRwOi8vd3d3LmllYy5jaAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABkZXNjAAAAAAAAAC5JRUMgNjE5NjYtMi4xIERlZmF1bHQgUkdCIGNvbG91ciBzcGFjZSAtIHNSR0IAAAAAAAAAAAAAAC5JRUMgNjE5NjYtMi4xIERlZmF1bHQgUkdCIGNvbG91ciBzcGFjZSAtIHNSR0IAAAAAAAAAAAAAAAAAAAAAAAAAAAAAZGVzYwAAAAAAAAAsUmVmZXJlbmNlIFZpZXdpbmcgQ29uZGl0aW9uIGluIElFQzYxOTY2LTIuMQAAAAAAAAAAAAAALFJlZmVyZW5jZSBWaWV3aW5nIENvbmRpdGlvbiBpbiBJRUM2MTk2Ni0yLjEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHZpZXcAAAAAABOk/gAUXy4AEM8UAAPtzAAEEwsAA1yeAAAAAVhZWiAAAAAAAEwJVgBQAAAAVx/nbWVhcwAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAAAAAo8AAAACc2lnIAAAAABDUlQgY3VydgAAAAAAAAQAAAAABQAKAA8AFAAZAB4AIwAoAC0AMgA3ADsAQABFAEoATwBUAFkAXgBjAGgAbQByAHcAfACBAIYAiwCQAJUAmgCfAKQAqQCuALIAtwC8AMEAxgDLANAA1QDbAOAA5QDrAPAA9gD7AQEBBwENARMBGQEfASUBKwEyATgBPgFFAUwBUgFZAWABZwFuAXUBfAGDAYsBkgGaAaEBqQGxAbkBwQHJAdEB2QH
hAekB8gH6AgMCDAIUAh0CJgIvAjgCQQJLAlQCXQJnAnECegKEAo4CmAKiAqwCtgLBAssC1QLgAusC9QMAAwsDFgMhAy0DOANDA08DWgNmA3IDfgOKA5YDogOuA7oDxwPTA+AD7AP5BAYEEwQgBC0EOwRIBFUEYwRxBH4EjASaBKgEtgTEBNME4QTwBP4FDQUcBSsFOgVJBVgFZwV3BYYFlgWmBbUFxQXVBeUF9gYGBhYGJwY3BkgGWQZqBnsGjAadBq8GwAbRBuMG9QcHBxkHKwc9B08HYQd0B4YHmQesB78H0gflB/gICwgfCDIIRghaCG4IggiWCKoIvgjSCOcI+wkQCSUJOglPCWQJeQmPCaQJugnPCeUJ+woRCicKPQpUCmoKgQqYCq4KxQrcCvMLCwsiCzkLUQtpC4ALmAuwC8gL4Qv5DBIMKgxDDFwMdQyODKcMwAzZDPMNDQ0mDUANWg10DY4NqQ3DDd4N+A4TDi4OSQ5kDn8Omw62DtIO7g8JDyUPQQ9eD3oPlg+zD88P7BAJECYQQxBhEH4QmxC5ENcQ9RETETERTxFtEYwRqhHJEegSBxImEkUSZBKEEqMSwxLjEwMTIxNDE2MTgxOkE8UT5RQGFCcUSRRqFIsUrRTOFPAVEhU0FVYVeBWbFb0V4BYDFiYWSRZsFo8WshbWFvoXHRdBF2UXiReuF9IX9xgbGEAYZRiKGK8Y1Rj6GSAZRRlrGZEZtxndGgQaKhpRGncanhrFGuwbFBs7G2MbihuyG9ocAhwqHFIcexyjHMwc9R0eHUcdcB2ZHcMd7B4WHkAeah6UHr4e6R8THz4faR+UH78f6iAVIEEgbCCYIMQg8CEcIUghdSGhIc4h+yInIlUigiKvIt0jCiM4I2YjlCPCI/AkHyRNJHwkqyTaJQklOCVoJZclxyX3JicmVyaHJrcm6CcYJ0kneierJ9woDSg/KHEooijUKQYpOClrKZ0p0CoCKjUqaCqbKs8rAis2K2krnSvRLAUsOSxuLKIs1y0MLUEtdi2rLeEuFi5MLoIuty7uLyQvWi+RL8cv/jA1MGwwpDDbMRIxSjGCMbox8jIqMmMymzLUMw0zRjN/M7gz8TQrNGU0njTYNRM1TTWHNcI1/TY3NnI2rjbpNyQ3YDecN9c4FDhQOIw4yDkFOUI5fzm8Ofk6Njp0OrI67zstO2s7qjvoPCc8ZTykPOM9Ij1hPaE94D4gPmA+oD7gPyE/YT+iP+JAI0BkQKZA50EpQWpBrEHuQjBCckK1QvdDOkN9Q8BEA0RHRIpEzkUSRVVFmkXeRiJGZ0arRvBHNUd7R8BIBUhLSJFI10kdSWNJqUnwSjdKfUrESwxLU0uaS+JMKkxyTLpNAk1KTZNN3E4lTm5Ot08AT0lPk0/dUCdQcVC7UQZRUFGbUeZSMVJ8UsdTE1NfU6pT9lRCVI9U21UoVXVVwlYPVlxWqVb3V0RXklfgWC9YfVjLWRpZaVm4WgdaVlqmWvVbRVuVW+VcNVyGXNZdJ114XcleGl5sXr1fD19hX7NgBWBXYKpg/GFPYaJh9WJJYpxi8GNDY5dj62RAZJRk6WU9ZZJl52Y9ZpJm6Gc9Z5Nn6Wg/aJZo7GlDaZpp8WpIap9q92tPa6dr/2xXbK9tCG1gbbluEm5rbsRvHm94b9FwK3CGcOBxOnGVcfByS3KmcwFzXXO4dBR0cHTMdSh1hXXhdj52m3b4d1Z3s3gReG54zHkqeYl553pGeqV7BHtje8J8IXyBfOF9QX2hfgF+Yn7CfyN/hH/lgEeAqIEKgWuBzYIwgpKC9INXg7qEHYSAhOOFR4Wrhg6GcobXhzuHn4gEiGmIzokziZmJ/opkisqLMIuWi/yMY4zKjTGNmI3/jmaOzo82j56QBpBukNaRP5GokhGSepLjk02TtpQglIqU9JVflcmWNJaflwqXdZfgmEyYuJkkmZCZ/JpomtWbQpuvnByciZz3nWSd0p5Anq6fHZ+Ln/qgaaD
YoUehtqImopajBqN2o+akVqTHpTilqaYapoum/adup+CoUqjEqTepqaocqo+rAqt1q+msXKzQrUStuK4trqGvFq+LsACwdbDqsWCx1rJLssKzOLOutCW0nLUTtYq2AbZ5tvC3aLfguFm40blKucK6O7q1uy67p7whvJu9Fb2Pvgq+hL7/v3q/9cBwwOzBZ8Hjwl/C28NYw9TEUcTOxUvFyMZGxsPHQce/yD3IvMk6ybnKOMq3yzbLtsw1zLXNNc21zjbOts83z7jQOdC60TzRvtI/0sHTRNPG1EnUy9VO1dHWVdbY11zX4Nhk2OjZbNnx2nba+9uA3AXcit0Q3ZbeHN6i3ynfr+A24L3hROHM4lPi2+Nj4+vkc+T85YTmDeaW5x/nqegy6LzpRunQ6lvq5etw6/vshu0R7ZzuKO6070DvzPBY8OXxcvH/8ozzGfOn9DT0wvVQ9d72bfb794r4Gfio+Tj5x/pX+uf7d/wH/Jj9Kf26/kv+3P9t//9BTFBI/gEAAAGQBEm2aVvr2bb1bdu2bdu2bdu2bXNm23q21udZe/fnMCImgP5zmmRu0r5916pB5r+K3/SXacyc8GRljl+j3BX++e1yv0LNd2z0QR48+5Ns/IAbXNlYAQ+BG8bS2z5os0SJVcBM14q4M1jdz6L4CmDTWfwpF1izFNEBe7CAx5IH+QjcbKVkNMHXSTJ2Mwte9g/GphB+0WhDKQ1+gdJxhrY6gmWoXyS4UbKRPV4EbGrm0O5BatSLG+kGnuYh3KBRW3adT2Bhak/CDbzIKq/U8LNBMRvDSqOjHh1q6IRR7K2a7xPPdnHVl6XjEdaZfmRoz9reWjwvs/7kG50sNFhtBWBO6GuizukUBEfUUNcvGYOveavK94ZB0zupmsKwx+3UeN/ACS+gpmMaDrdR4nKWgScraZ+K9HGUpcz6EENvspblfgu135fkYxl6BcndrmPd9pYVDMdKqiMrG4/F42UZn4KdtBVRm3Cs8JIyi2VYPEVGk8HWK1gKdtFZNgksrrJsOBgfsBMNRbvtLhqPdtpWYn8GbQFJA56htRc57AeLKiaiIuEoaenMnDTFUma9DmVZ+/kjW9ewJYX+M95BHAogjXU+AbwrRTpN1+l7Wd1EC4U279p/yMhxk2fNW7xi1Zo1a9auX7d23fpNGzesWbV07oyxfVoVJWSTn5uampiQRlZQOCCaAgAAcA4AnQEqQABAAD6RRp1LpaOioaQYCSCwEglnANAf6AeQHe3XstJbbZcWloaHR/56MDDk1JzbdWf27Soj3JHVF/BoubkvX8Bp3XzyR/oAldRu5TVjCGd9xVdUWuQc0BP47rzGCtAn6zv3wXHtB4bG9CSVD+oEq2pE0c7mAAD+/K1GfqGH49UVwpvCqjsxVhXg5o+M0D1wmzLC6U8M96RzqcqHYYL64nYfkP3NlmcIqUTADl4R/8Ztsmrt/8YGH4qXr2393zDsf2khD8UfXRv/zElSaeuqFGQzG+NapW+EPCWqxeGmGrucoYt0vuhR9z08U7tVe4smQtRrffH8AMzxnDkwsia9Q7ikGBk1Oc51vten/chX3UXHCp/q6ZYPSo4Kha5Xd2ZV4gBJ85FK0+VXYMTTqGW6OiTDUVdeXHR9njW6s1h24eZbsiPlutth9j3Ou5cllsXoG3L55dNM6Q5rptkxCVpvh922fZyK2tsmHdT+TLatnLfIlRXgwM5LFyx68u4BPdZNnN2utj4GjgF0Wa+WKgmRFL3o9gVgHJQ8P6Mmr9OjAjM7lXPSvRQPIYETWU8Cy6AhzyNC1w0NORNEZv5NuiejUtHF1CgehtJTM8JVjcsuccE709lBgCa7EsZLDXV+L2EcRNLXSoFEUVwn7KlWyM9voYncouQo+DZHd4uyaqvRMi0YwBIYNjywWIIGlGJjcxzQvuvEM08A6lEavQk149HoiRFqfwVyaaLGDV95XN/JC9XugBjdw5TPMaBo5GByVDmM9SaBS4iIRwlUGiS
KRerUrlX5Y/pFCVsqbIsZ7a9YX+4mpCKphVCD7+mZyvqTtNnyqH32GRoJBGV/p5+HaWMuX9k4kyc1uy4r6+/36lLEkyFwQAAARVhJRmwAAABNTQAqAAAAEEV4aWZNZXRhAAUBGgAFAAAAAQAAAFIBGwAFAAAAAQAAAFoBKAADAAAAAQACAAABMQACAAAACgAAAGICEwADAAAAAQABAAAAAAAAAAAASAAAAAEAAABIAAAAAWV6Z2lmLmNvbQA=", - system=ANDROID_AGENT_SYSTEM_PROMPT, + system=( + """ + +You are an autonomous Android device control agent operating via ADB on a test device with full system access. +Your primary goal is to execute tasks efficiently and reliably while maintaining system stability. + + + +* Autonomy: Operate independently and make informed decisions without requiring user input. +* Never ask for other tasks to be done, only do the task you are given. +* Reliability: Ensure actions are repeatable and maintain system stability. +* Efficiency: Optimize operations to minimize latency and resource usage. +* Safety: Always verify actions before execution, even with full system access. + + + +1. Tool Usage: + * Verify tool availability before starting any operation + * Use the most direct and efficient tool for each task + * Combine tools strategically for complex operations + * Prefer built-in tools over shell commands when possible + +2. Error Handling: + * Assess failures systematically: check tool availability, permissions, and device state + * Implement retry logic with exponential backoff for transient failures + * Use fallback strategies when primary approaches fail + * Provide clear, actionable error messages with diagnostic information + +3. Performance Optimization: + * Use one-liner shell commands with inline filtering (grep, cut, awk, jq) for efficiency + * Minimize screen captures and coordinate calculations + * Cache device state information when appropriate + * Batch related operations when possible + +4. 
Screen Interaction: + * Ensure all coordinates are integers and within screen bounds + * Implement smart scrolling for off-screen elements + * Use appropriate gestures (tap, swipe, drag) based on context + * Verify element visibility before interaction + +5. System Access: + * Leverage full system access responsibly + * Use shell commands for system-level operations + * Monitor system state and resource usage + * Maintain system stability during operations + +6. Recovery Strategies: + * If an element is not visible, try: + - Scrolling in different directions + - Adjusting view parameters + - Using alternative interaction methods + * If a tool fails: + - Check device connection and state + - Verify tool availability and permissions + - Try alternative tools or approaches + * If stuck: + - Provide clear diagnostic information + - Suggest potential solutions + - Request user intervention only if necessary + +7. Best Practices: + * Document all significant operations + * Maintain operation logs for debugging + * Implement proper cleanup after operations + * Follow Android best practices for UI interaction + + +* This is a test device with full system access - use this capability responsibly +* Always verify the success of critical operations +* Maintain system stability as the highest priority +* Provide clear, actionable feedback for all operations +* Use the most efficient method for each task + +""" + ), tools=[ "android_screenshot_tool", "android_tap_tool", @@ -47,12 +130,21 @@ ], ) -WEB_AGENT = Assistant( +WEB_AGENT_V1 = AssistantV1( id="asst_68ac2c4edc4b2f27faa5a256", - created_at=now(), + created_at=now_v1(), name="Web Agent", 
avatar="data:image/webp;base64,UklGRj4SAABXRUJQVlA4WAoAAAA4AAAAPwAAPwAASUNDUEgMAAAAAAxITGlubwIQAABtbnRyUkdCIFhZWiAHzgACAAkABgAxAABhY3NwTVNGVAAAAABJRUMgc1JHQgAAAAAAAAAAAAAAAQAA9tYAAQAAAADTLUhQICAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABFjcHJ0AAABUAAAADNkZXNjAAABhAAAAGx3dHB0AAAB8AAAABRia3B0AAACBAAAABRyWFlaAAACGAAAABRnWFlaAAACLAAAABRiWFlaAAACQAAAABRkbW5kAAACVAAAAHBkbWRkAAACxAAAAIh2dWVkAAADTAAAAIZ2aWV3AAAD1AAAACRsdW1pAAAD+AAAABRtZWFzAAAEDAAAACR0ZWNoAAAEMAAAAAxyVFJDAAAEPAAACAxnVFJDAAAEPAAACAxiVFJDAAAEPAAACAx0ZXh0AAAAAENvcHlyaWdodCAoYykgMTk5OCBIZXdsZXR0LVBhY2thcmQgQ29tcGFueQAAZGVzYwAAAAAAAAASc1JHQiBJRUM2MTk2Ni0yLjEAAAAAAAAAAAAAABJzUkdCIElFQzYxOTY2LTIuMQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWFlaIAAAAAAAAPNRAAEAAAABFsxYWVogAAAAAAAAAAAAAAAAAAAAAFhZWiAAAAAAAABvogAAOPUAAAOQWFlaIAAAAAAAAGKZAAC3hQAAGNpYWVogAAAAAAAAJKAAAA+EAAC2z2Rlc2MAAAAAAAAAFklFQyBodHRwOi8vd3d3LmllYy5jaAAAAAAAAAAAAAAAFklFQyBodHRwOi8vd3d3LmllYy5jaAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABkZXNjAAAAAAAAAC5JRUMgNjE5NjYtMi4xIERlZmF1bHQgUkdCIGNvbG91ciBzcGFjZSAtIHNSR0IAAAAAAAAAAAAAAC5JRUMgNjE5NjYtMi4xIERlZmF1bHQgUkdCIGNvbG91ciBzcGFjZSAtIHNSR0IAAAAAAAAAAAAAAAAAAAAAAAAAAAAAZGVzYwAAAAAAAAAsUmVmZXJlbmNlIFZpZXdpbmcgQ29uZGl0aW9uIGluIElFQzYxOTY2LTIuMQAAAAAAAAAAAAAALFJlZmVyZW5jZSBWaWV3aW5nIENvbmRpdGlvbiBpbiBJRUM2MTk2Ni0yLjEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHZpZXcAAAAAABOk/gAUXy4AEM8UAAPtzAAEEwsAA1yeAAAAAVhZWiAAAAAAAEwJVgBQAAAAVx/nbWVhcwAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAAAAAo8AAAACc2lnIAAAAABDUlQgY3VydgAAAAAAAAQAAAAABQAKAA8AFAAZAB4AIwAoAC0AMgA3ADsAQABFAEoATwBUAFkAXgBjAGgAbQByAHcAfACBAIYAiwCQAJUAmgCfAKQAqQCuALIAtwC8AMEAxgDLANAA1QDbAOAA5QDrAPAA9gD7AQEBBwENARMBGQEfASUBKwEyATgBPgFFAUwBUgFZAWABZwFuAXUBfAGDAYsBkgGaAaEBqQGxAbkBwQHJAdEB2QHhAekB8gH6AgMCDAIUAh0CJgIvAjgCQQJLAlQCXQJnAnECegKEAo4CmAKiAqwCtgLBAssC1QLgAusC9QMAAwsDFgMhAy0DOANDA08DWgNmA3IDfgOKA5YDogOuA7oDxwPTA+AD7AP5BAYEEwQgBC0EOwRIBFUEYwRxBH4EjASaBKgEtgTEBNME4QTwBP4FDQUcBSsFOgVJBVgFZwV3BYYFlgWmBbUFxQXVBeUF9gYGB
hYGJwY3BkgGWQZqBnsGjAadBq8GwAbRBuMG9QcHBxkHKwc9B08HYQd0B4YHmQesB78H0gflB/gICwgfCDIIRghaCG4IggiWCKoIvgjSCOcI+wkQCSUJOglPCWQJeQmPCaQJugnPCeUJ+woRCicKPQpUCmoKgQqYCq4KxQrcCvMLCwsiCzkLUQtpC4ALmAuwC8gL4Qv5DBIMKgxDDFwMdQyODKcMwAzZDPMNDQ0mDUANWg10DY4NqQ3DDd4N+A4TDi4OSQ5kDn8Omw62DtIO7g8JDyUPQQ9eD3oPlg+zD88P7BAJECYQQxBhEH4QmxC5ENcQ9RETETERTxFtEYwRqhHJEegSBxImEkUSZBKEEqMSwxLjEwMTIxNDE2MTgxOkE8UT5RQGFCcUSRRqFIsUrRTOFPAVEhU0FVYVeBWbFb0V4BYDFiYWSRZsFo8WshbWFvoXHRdBF2UXiReuF9IX9xgbGEAYZRiKGK8Y1Rj6GSAZRRlrGZEZtxndGgQaKhpRGncanhrFGuwbFBs7G2MbihuyG9ocAhwqHFIcexyjHMwc9R0eHUcdcB2ZHcMd7B4WHkAeah6UHr4e6R8THz4faR+UH78f6iAVIEEgbCCYIMQg8CEcIUghdSGhIc4h+yInIlUigiKvIt0jCiM4I2YjlCPCI/AkHyRNJHwkqyTaJQklOCVoJZclxyX3JicmVyaHJrcm6CcYJ0kneierJ9woDSg/KHEooijUKQYpOClrKZ0p0CoCKjUqaCqbKs8rAis2K2krnSvRLAUsOSxuLKIs1y0MLUEtdi2rLeEuFi5MLoIuty7uLyQvWi+RL8cv/jA1MGwwpDDbMRIxSjGCMbox8jIqMmMymzLUMw0zRjN/M7gz8TQrNGU0njTYNRM1TTWHNcI1/TY3NnI2rjbpNyQ3YDecN9c4FDhQOIw4yDkFOUI5fzm8Ofk6Njp0OrI67zstO2s7qjvoPCc8ZTykPOM9Ij1hPaE94D4gPmA+oD7gPyE/YT+iP+JAI0BkQKZA50EpQWpBrEHuQjBCckK1QvdDOkN9Q8BEA0RHRIpEzkUSRVVFmkXeRiJGZ0arRvBHNUd7R8BIBUhLSJFI10kdSWNJqUnwSjdKfUrESwxLU0uaS+JMKkxyTLpNAk1KTZNN3E4lTm5Ot08AT0lPk0/dUCdQcVC7UQZRUFGbUeZSMVJ8UsdTE1NfU6pT9lRCVI9U21UoVXVVwlYPVlxWqVb3V0RXklfgWC9YfVjLWRpZaVm4WgdaVlqmWvVbRVuVW+VcNVyGXNZdJ114XcleGl5sXr1fD19hX7NgBWBXYKpg/GFPYaJh9WJJYpxi8GNDY5dj62RAZJRk6WU9ZZJl52Y9ZpJm6Gc9Z5Nn6Wg/aJZo7GlDaZpp8WpIap9q92tPa6dr/2xXbK9tCG1gbbluEm5rbsRvHm94b9FwK3CGcOBxOnGVcfByS3KmcwFzXXO4dBR0cHTMdSh1hXXhdj52m3b4d1Z3s3gReG54zHkqeYl553pGeqV7BHtje8J8IXyBfOF9QX2hfgF+Yn7CfyN/hH/lgEeAqIEKgWuBzYIwgpKC9INXg7qEHYSAhOOFR4Wrhg6GcobXhzuHn4gEiGmIzokziZmJ/opkisqLMIuWi/yMY4zKjTGNmI3/jmaOzo82j56QBpBukNaRP5GokhGSepLjk02TtpQglIqU9JVflcmWNJaflwqXdZfgmEyYuJkkmZCZ/JpomtWbQpuvnByciZz3nWSd0p5Anq6fHZ+Ln/qgaaDYoUehtqImopajBqN2o+akVqTHpTilqaYapoum/adup+CoUqjEqTepqaocqo+rAqt1q+msXKzQrUStuK4trqGvFq+LsACwdbDqsWCx1rJLssKzOLOutCW0nLUTtYq2AbZ5tvC3aLfguFm40blKucK6O7q1uy67p7whvJu9Fb2Pvgq+hL7/v3q/9cBwwOzBZ8Hjwl/C28NYw9TEUcTOxUvFyMZGxsPHQce/yD3IvMk6y
bnKOMq3yzbLtsw1zLXNNc21zjbOts83z7jQOdC60TzRvtI/0sHTRNPG1EnUy9VO1dHWVdbY11zX4Nhk2OjZbNnx2nba+9uA3AXcit0Q3ZbeHN6i3ynfr+A24L3hROHM4lPi2+Nj4+vkc+T85YTmDeaW5x/nqegy6LzpRunQ6lvq5etw6/vshu0R7ZzuKO6070DvzPBY8OXxcvH/8ozzGfOn9DT0wvVQ9d72bfb794r4Gfio+Tj5x/pX+uf7d/wH/Jj9Kf26/kv+3P9t//9BTFBIxAIAAAGQRtvWGUkv+pKqGtu2bdu2bdu2bdu2bdu2PVN8o3z58lp/I2ICIKwrqUHJKNSh5/x1Q5uklYJI0iVf8L9PWmlBIvdZNP8x1hUE8txFXm9/jZy2Bvm/1SeX/KkFvByHWgWPFV9dag0CVnAcGVZvYL1sLqjus9SXTMHP6P+wukCXgKUuVNgq/PfdK7T6rSiVjK/+I3CWTkHS42bZhoI/ZgWCsWbuvf0uIGprBAJsMtoYqAsEu7jt+F6QQq/PduyPTEErOHrtV46Ph9w8t7IBUVZ+2JxHF3bfvLdpY4lGPg5fcyCsJIvK4idlDPoj5+FIlMzVtRzfagDtKIliawBZnnNMUf6RJCLh2l14fGd16+mX0HxfHAAAV4zokRmPLglK1PvwT7R8vboGAFFU2YgcSQaQnTGiR0qaIoqY/BdR6NdVk+dNbJEEACJmqdKwdO60aXLkSJWKKQKK3EEb/SdyRmm65ZHX+/bC0uklwZk2icslW6n8HO29c9SHptfysAjxIjmcFiKfQcKP6xssAQvPpzT9SQlf1w6fJKKTyzHiK9K+VTx1Fp2rixeJ+8flyergKfkGye+rmU/mSH4T6X9olZ2nCQbBz7UyMo6ib4PAm6LZDI6Kr4PA23KZVTN5CQbBK9VSgbk0LyjMLB+NA6p+o/erfTaNJ9IZeoHuCSSemNfJuXGwAbyJH1J78OJTPuCWenlp+Z/hUsYHjjY3vJQCb4+kB8txay3fvPMXlaVZY4BIRcnwjsaL8YkVEJz7CwX3rLQ6CK/spTBHBxtbIcEH2cHOXhSGgK1DKYyzZzwBb42gdi62PX0pRBEXNVedeS8JfCwizLHkBxL0X5uWUZhea+jkhatWLV244eYPn1+Qe8+oMvFUYf+VZVlxJStetVaLXkPGTZu/fM2G7Tu2b9+2Yc3yuRMHtisXBahLssoYY5oqSxCMA1ZQOCCQAgAAcA8AnQEqQABAAD6RPphIpaOiISwaqzCwEglpAM2MQa+QDgh8l4QY2S6dDRb5+9STFnsT53G3bI/tXlpji6Mj5+V4r8cZSRtBry/tRIinOXo63rB7Wi1GyNMYy0rV4zkAU3LN+pO75fM/xafU+nBCBpTvibygmDy4lufaLt2Cp6xDvs9QAP730H//wM///AnP//v9b9+fIdAd+JCJlHhg2L+U0BbHAXsiYjhfXh7j3RBU8mujTpmfy/iYn8qP5zA9sxPC+Wt0IcccpBvl06y7EXh+UwyO7zw/dX8rMfRA99qDkw/+DJX2N5E/ZNYaetvg9X5Uq4XOKgyyDyqvPk1PMTTW6Xfl7hWFRraHtG6ksVUrLJt0WpfmPgTcxoJMCHEBjW/nGcZtIwxA0WMdkQKqKg3CxjG7yGxJ7qGwSpn8RIZ5F9k5MdjCajiKikzYxesh9luS0ctptAe1OvnW6qCh5sQIXAfXRRgxoetREPXrGOxuAOGJEuE8IEhKedRig4t2wg118Rzh32C4c9OQzS2Kq1/08Ar9X++rO3p3/hQw7U0V9hIbdJQXR7eYdjTvJ4jr2/kx9H6cwfxVSD7+GdaelQvLzcYILjm7/zzq46MAgWWdG19VfZ2sCFDqcBkyeE+/6Ic/Hc/FZcTRUT8pdAtmtRKSZtiXUrBHcaraz14JCCH78fwTvJKYvSFjpaVKCmuWh/8D763znzj5fJdGpfnVK3dEdgY6UyLG+9rPQ4XVyWyoMehoQhZhV+bhJ
H5FX0ZhMh9wopvx/RND6OG5dgkwS0i4UFiIHafJi2Kr7jK4CPZiELMgRZ960JBoimrcNQn66ksmytF8pVSiB6ajkoAYMvhq0Gz/4/+Ss4kOiDkAAABFWElGbAAAAE1NACoAAAAQRXhpZk1ldGEABQEaAAUAAAABAAAAUgEbAAUAAAABAAAAWgEoAAMAAAABAAIAAAExAAIAAAAKAAAAYgITAAMAAAABAAEAAAAAAAAAAABIAAAAAQAAAEgAAAABZXpnaWYuY29tAA==", - system=WEB_AGENT_SYSTEM_PROMPT, + system=( + """ + +* You are utilizing a webbrowser in full-screen mode. So you are only seeing the content of the currently opened webpage (tab). +* It can be helpful to zoom in/out or scroll down/up so that you can see everything on the page. Make sure to that before deciding something isn't available. +* When using your tools, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. +* If a tool call returns with an error that a browser distribution is not found, stop, so that the user can install it and, then, continue the conversation. + +""" + ), tools=[ "browser_click", "browser_close", @@ -85,12 +177,107 @@ ], ) -TESTING_AGENT = Assistant( +TESTING_AGENT_V1 = AssistantV1( id="asst_68ac2c4edc4b2f27faa5a257", - created_at=now(), + created_at=now_v1(), name="Testing Agent", 
avatar="data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHhtbG5zOnhsaW5rPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5L3hsaW5rIiB2aWV3Qm94PSIwIDAgMjcgMjciIGFyaWEtaGlkZGVuPSJ0cnVlIiByb2xlPSJpbWciIGNsYXNzPSJpY29uaWZ5IGljb25pZnktLXR3ZW1vamkiIHByZXNlcnZlQXNwZWN0UmF0aW89InhNaWRZTWlkIG1lZXQiPjxwYXRoIGZpbGw9IiNDQ0Q2REQiIGQ9Ik0xMC45MjIgMTAuODEgMTkuMTAyIDIuNjI5bDUuMjIxIDUuMjIxIC04LjE4MSA4LjE4MXoiLz48cGF0aCBmaWxsPSIjNjhFMDkwIiBkPSJNNi4wNzcgMjUuNzk5QzEuODc1IDI1LjUgMS4xMjUgMjIuNTQ3IDEuMjI2IDIwLjk0OWMwLjI0MSAtMy44MDMgMTEuNzAxIC0xMi40MTMgMTEuNzAxIC0xMi40MTNsOS4zODggMS40NDhjMC4wMDEgMCAtMTMuMDQyIDE2LjA0NCAtMTYuMjM3IDE1LjgxNiIvPjxwYXRoIGZpbGw9IiM4ODk5QTYiIGQ9Ik0yNC4yNDUgMi43ODFDMjIuMDU0IDAuNTkgMTkuNTc4IC0wLjQ4NyAxOC43MTUgMC4zNzdjLTAuMDEgMC4wMSAtMC4wMTcgMC4wMjMgLTAuMDI2IDAuMDMzIC0wLjAwNSAwLjAwNSAtMC4wMTEgMC4wMDYgLTAuMDE2IDAuMDExTDEuNzIxIDE3LjM3M2E1LjU3MiA1LjU3MiAwIDAgMCAtMS42NDMgMy45NjZjMCAxLjQ5OCAwLjU4NCAyLjkwNiAxLjY0MyAzLjk2NWE1LjU3MiA1LjU3MiAwIDAgMCAzLjk2NiAxLjY0MyA1LjU3MiA1LjU3MiAwIDAgMCAzLjk2NSAtMS42NDJsMTYuOTUzIC0xNi45NTNjMC4wMDUgLTAuMDA1IDAuMDA3IC0wLjAxMiAwLjAxMSAtMC4wMTcgMC4wMSAtMC4wMDkgMC4wMjIgLTAuMDE1IDAuMDMyIC0wLjAyNSAwLjg2MyAtMC44NjIgLTAuMjE0IC0zLjMzOCAtMi40MDUgLTUuNTI5TTguMDYzIDIzLjcxNGMtMC42MzQgMC42MzQgLTEuNDc4IDAuOTgzIC0yLjM3NCAwLjk4M3MtMS43NDEgLTAuMzUgLTIuMzc1IC0wLjk4NGEzLjMzOCAzLjMzOCAwIDAgMSAtMC45ODQgLTIuMzc1YzAgLTAuODk3IDAuMzUgLTEuNzQgMC45ODMgLTIuMzc0TDE5LjA1OSAzLjIxOGMwLjQ2NyAwLjg1OCAxLjE3IDEuNzk2IDIuMDYyIDIuNjg4czEuODMgMS41OTUgMi42ODggMi4wNjJ6Ii8+PHBhdGggZmlsbD0iIzE3QkY2MyIgZD0iTTIxLjg5NyA5Ljg1OGMtMC4wNDQgMC4yODQgLTEuOTcgMC41NjMgLTQuMjY4IDAuMjU3cy00LjExMiAtMC45MTcgLTQuMDUyIC0xLjM2NSAxLjk3IC0wLjU2MyA0LjI2OCAtMC4yNTcgNC4xMjEgMC45MTggNC4wNTIgMS4zNjVNOC4xMyAxNy40MzVhMC41OTYgMC41OTYgMCAxIDEgLTAuODQyIC0wLjg0MyAwLjU5NiAwLjU5NiAwIDAgMSAwLjg0MiAwLjg0M20yLjQ4OCAxLjk2MWEwLjk3NCAwLjk3NCAwIDEgMSAtMS4zNzYgLTEuMzc3IDAuOTc0IDAuOTc0IDAgMCAxIDEuMzc2IDEuMzc3bTEuMjU4IC0zLjk5M2EwLjkxNiAwLjkxNiAwIDAgMSAtMS4yOTQgLTEuMjk0IDAuOTE1IDAuOTE1IDAgMSAxIDEuMjk0IDEuMjk0bS01LjE1MS
A2LjY0NGExLjExNyAxLjExNyAwIDEgMSAtMS41NzkgLTEuNTc5IDEuMTE3IDEuMTE3IDAgMCAxIDEuNTc5IDEuNTc5bTguNTQ3IC02Ljg2OGEwLjc5NCAwLjc5NCAwIDEgMSAtMS4xMjIgLTEuMTIzIDAuNzk0IDAuNzk0IDAgMCAxIDEuMTIyIDEuMTIzbS0wLjkwNSAtMy4yMTZhMC41MiAwLjUyIDAgMSAxIC0wLjczNCAtMC43MzUgMC41MiAwLjUyIDAgMCAxIDAuNzM0IDAuNzM1Ii8+PHBhdGggdHJhbnNmb3JtPSJyb3RhdGUoLTQ1LjAwMSAzMC44MTcgNS4yMjMpIiBmaWxsPSIjQ0NENkREIiBjeD0iMzAuODE3IiBjeT0iNS4yMjMiIHJ4PSIxLjE4NCIgcnk9IjQuODQ3IiBkPSJNMjQuMDAxIDMuOTE3QTAuODg4IDMuNjM1IDAgMCAxIDIzLjExMyA3LjU1M0EwLjg4OCAzLjYzNSAwIDAgMSAyMi4yMjUgMy45MTdBMC44ODggMy42MzUgMCAwIDEgMjQuMDAxIDMuOTE3eiIvPjwvc3ZnPg==", - system=TESTING_AGENT_SYSTEM_PROMPT, + system=( + """ +You are an advanced AI testing agent responsible for managing and executing software tests. Your primary goal is to create, refine, and execute test scenarios based on given specifications or targets. You have access to various tools and subagents to accomplish this task. + +Available tools: +1. Feature management: retrieve, list, modify, create, delete +2. Scenario management: retrieve, list, modify, create, delete +3. Execution management: retrieve, list, modify, create, delete +4. Tools for executing tests using subagents: + - create_thread_and_run_v1_runs_post: Delegate tasks to subagents + - retrieve_run_v1_threads: Check the status of a run + - list_messages_v1_threads: Retrieve messages from a thread + - utility_wait: Wait for a specified number of seconds + +Subagents: +1. Computer control agent (ID: asst_68ac2c4edc4b2f27faa5a253) +2. Web browser control agent (ID: asst_68ac2c4edc4b2f27faa5a256) + +Main process: +1. Analyze test specification +2. Create and refine features if necessary by exploring the features (exploratory testing) +3. Create and refine scenarios if necessary by exploring the scenarios (exploratory testing) +4. Execute scenarios +5. Report results +6. Handle user feedback + +Detailed instructions: + +1. 
Analyze the test specification: + +{TEST_SPECIFICATION} + + +Review the provided test specification carefully. Identify the key features, functionalities, or areas that need to be tested. +Instead of a test specification, the user may also provide just the testing target (feature, url, application name etc.). Make +sure that you ask the user if it is a webapp or desktop app or where to find the app in general if not clear from the specification. + +2. Create and refine features: +a. Use the feature management tools to list existing features. +b. Create new features based on user input and if necessary exploring the features in the actual application using a subagent, ensuring no duplicates. +c. Present the features to the user and wait for feedback. +d. Refine the features based on user feedback until confirmation is received. + +3. Create and refine scenarios: +a. For each confirmed feature, use the scenario management tools to list existing scenarios. +b. Create new scenarios using Gherkin syntax, ensuring no duplicates. +c. Present the scenarios to the user and wait for feedback. +d. Refine the scenarios based on user feedback until confirmation is received. + +4. Execute scenarios: +a. Determine whether to use the computer control agent or web browser control agent (prefer web browser if possible). +b. Create and run a thread with the chosen subagent with a user message that contains the commands (scenario) to be executed. Set `stream` to `false` to wait for the agent to complete. +c. Use the retrieve_run_v1_threads tool to check the status of the task and the utility_wait tool for it to complete with an exponential backoff starting with 5 seconds increasing. +d. Collect and analyze the responses from the agent using the list_messages_v1_threads tool. Usually, you only need the last message within the thread (`limit=1`) which contains a summary of the execution results. 
If you need more details, you can use a higher limit and potentially multiple calls to the tool. + +5. Report results: +a. Use the execution management tools to create new execution records. +b. Update the execution records with the results (passed, failed, etc.). +c. Present a summary of the execution results to the user. + +6. Handle user feedback: +a. Review user feedback on the executions. +b. Based on feedback, determine whether to restart the process, modify existing tests, or perform other actions. + +Handling user commands: +Respond appropriately to user commands, such as: + +{USER_COMMAND} + + +- Execute existing scenarios +- List all available features +- Modify specific features or scenarios +- Delete features or scenarios + +Output format (for none tool calls): +``` +[Your detailed response, including any necessary explanations, lists, or summaries] + +**Next Actions**: +[Clearly state the next actions you will take or the next inputs you require from the user] + +``` + +Important reminders: +1. Always check for existing features and scenarios before creating new ones to avoid duplicates. +2. Use Gherkin syntax when creating or modifying scenarios. +3. Prefer the web browser control agent for test execution when possible. +4. Always wait for user confirmation before proceeding to the next major step in the process. +5. Be prepared to restart the process or modify existing tests based on user feedback. +6. Use tags for organizing the features and scenarios describing what is being tested and how it is being tested. +7. Prioritize sunny cases and critical features/scenarios first if not specified otherwise by the user. + +Your final output should only include the content within the and tags. Do not include any other tags or internal thought processes in your final output. 
+""" + ), tools=[ "create_feature", "retrieve_feature", @@ -114,12 +301,76 @@ ], ) -ORCHESTRATOR_AGENT = Assistant( +ORCHESTRATOR_AGENT_V1 = AssistantV1( id="asst_68ac2c4edc4b2f27faa5a258", - created_at=now(), + created_at=now_v1(), name="Orchestrator", avatar="data:image/svg+xml;base64,PHN2ZyAgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIgogIHdpZHRoPSIyNCIKICBoZWlnaHQ9IjI0IgogIHZpZXdCb3g9IjAgMCAyNCAyNCIKICBmaWxsPSJub25lIgogIHN0cm9rZT0iIzAwMCIgc3R5bGU9ImJhY2tncm91bmQtY29sb3I6ICNmZmY7IGJvcmRlci1yYWRpdXM6IDJweCIKICBzdHJva2Utd2lkdGg9IjIiCiAgc3Ryb2tlLWxpbmVjYXA9InJvdW5kIgogIHN0cm9rZS1saW5lam9pbj0icm91bmQiCj4KICA8cGF0aCBkPSJNMTIgOFY0SDgiIC8+CiAgPHJlY3Qgd2lkdGg9IjE2IiBoZWlnaHQ9IjEyIiB4PSI0IiB5PSI4IiByeD0iMiIgLz4KICA8cGF0aCBkPSJNMiAxNGgyIiAvPgogIDxwYXRoIGQ9Ik0yMCAxNGgyIiAvPgogIDxwYXRoIGQ9Ik0xNSAxM3YyIiAvPgogIDxwYXRoIGQ9Ik05IDEzdjIiIC8+Cjwvc3ZnPgo=", - system=ORCHESTRATOR_AGENT_SYSTEM_PROMPT, + system=( + """ +You are an AI agent called "Orchestrator" with the ID "asst_68ac2c4edc4b2f27faa5a258". Your primary role is to perform high-level planning and management of all tasks involved in responding to a given prompt. For simple prompts, you will respond directly. For more complex, you will delegate and route the execution of these tasks to other specialized agents. + +You have the following tools at your disposal: + +1. list_assistants_v1_assistants_get + This tool enables you to discover all available assistants (agents) for task delegation. + +2. create_thread_and_run_v1_runs_post + This tool enables you to delegate tasks to other agents by starting a conversation (thread) with initial messages containing necessary instructions, and then running (calling/executing) the agent to get a response. The "stream" parameter should always be set to "false". + +3. retrieve_run_v1_threads + This tool enables you to retrieve the details of a run by its ID and, by that, checking wether an assistant is still answering or completed its answer (`status` field). + +4. 
list_messages_v1_threads + This tool enables you to retrieve the messages of the assistant. Depending on the prompt, you may only need the last message within the thread (`limit=1`) or the whole thread using a higher limit and potentially multiple calls to the tool. + +5. utility_wait + This tool enables you to wait for a specified number of seconds, e.g. to wait for an agent to finish its task / complete its answer. + +Your main task is to analyze the user prompt and classify it as simple vs. complex. For simple prompts, respond directly. For complex prompts, create a comprehensive plan to address it by utilizing the available agents. + +Follow these steps to complete your task: + +1. Analyze the user prompt and identify the main components or subtasks required to provide a complete response. + +2. Use the list_assistants_v1_assistants_get tool to discover all available agents. + +3. Create a plan that outlines how you will delegate these subtasks to the most appropriate agents based on their specialties. + +4. For each subtask: + a. Prepare clear and concise instructions for the chosen agent. + b. Use the create_thread_and_run_v1_runs_post tool to delegate the task to the agent. + c. Include all necessary context and information in the initial messages. + d. Set the "stream" parameter to "true". + +5. Use the retrieve_run_v1_threads tool to check the status of the task and the utility_wait tool for it to complete with an exponential backoff starting with 5 seconds increasing. + +5. Collect and analyze the responses from each agent using the list_messages_v1_threads tool. + +6. Synthesize the information from all agents into a coherent and comprehensive response to the original user prompt. + +Present your final output should be eitehr in the format of + +[Simple answer] + +or + +[ +# Plan +[Provide a detailed plan outlining the subtasks and the agents assigned to each] + +# Report +[For each agent interaction, include: +1. The agent's ID and specialty +2. 
The subtask assigned +3. A summary of the instructions given +4. A brief summary of the agent's response] + +# Answer +[Synthesize all the information into a cohesive response to the original user prompt] +] +""" + ), tools=[ "list_assistants_v1_assistants_get", "create_thread_and_run_v1_runs_post", @@ -129,8 +380,8 @@ ], ) -SEEDS = [ - COMPUTER_AGENT, - ANDROID_AGENT, - WEB_AGENT, +SEEDS_V1 = [ + COMPUTER_AGENT_V1, + ANDROID_AGENT_V1, + WEB_AGENT_V1, ] diff --git a/src/askui/chat/migrations/shared/models.py b/src/askui/chat/migrations/shared/models.py new file mode 100644 index 00000000..750d9d87 --- /dev/null +++ b/src/askui/chat/migrations/shared/models.py @@ -0,0 +1,13 @@ +from typing import Annotated +from uuid import UUID + +from pydantic import AwareDatetime, PlainSerializer + +UnixDatetimeV1 = Annotated[ + AwareDatetime, + PlainSerializer( + lambda v: int(v.timestamp()), + return_type=int, + ), +] +WorkspaceIdV1 = UUID diff --git a/src/askui/chat/migrations/shared/settings.py b/src/askui/chat/migrations/shared/settings.py new file mode 100644 index 00000000..c7671a1a --- /dev/null +++ b/src/askui/chat/migrations/shared/settings.py @@ -0,0 +1,29 @@ +from pathlib import Path +from typing import Annotated +from uuid import UUID + +from pydantic import AwareDatetime, Field, PlainSerializer +from pydantic_settings import BaseSettings, SettingsConfigDict + +# Local models to avoid dependencies on askui.chat.api +UnixDatetime = Annotated[ + AwareDatetime, + PlainSerializer( + lambda v: int(v.timestamp()), + return_type=int, + ), +] + +AssistantId = Annotated[str, Field(pattern=r"^asst_[a-z0-9]+$")] +WorkspaceId = UUID + + +class SettingsV1(BaseSettings): + model_config = SettingsConfigDict( + env_prefix="ASKUI__CHAT_API__", env_nested_delimiter="__" + ) + + data_dir: Path = Field( + default_factory=lambda: Path.cwd() / "chat", + description="Base directory for chat data (used during migration)", + ) diff --git a/src/askui/chat/migrations/shared/utils.py 
b/src/askui/chat/migrations/shared/utils.py new file mode 100644 index 00000000..dc4c8c9e --- /dev/null +++ b/src/askui/chat/migrations/shared/utils.py @@ -0,0 +1,7 @@ +from datetime import datetime, timezone + +from pydantic import AwareDatetime + + +def now_v1() -> AwareDatetime: + return datetime.now(tz=timezone.utc) diff --git a/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py b/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py new file mode 100644 index 00000000..947767a6 --- /dev/null +++ b/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py @@ -0,0 +1,109 @@ +"""import_json_assistants + +Revision ID: 057f82313448 +Revises: 4d1e043b4254 +Create Date: 2025-10-10 11:21:55.527341 + +""" + +import json +import logging +from typing import Sequence, Union + +from alembic import op +from sqlalchemy import MetaData, Table + +from askui.chat.migrations.shared.assistants.models import AssistantV1 +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "057f82313448" +down_revision: Union[str, None] = "4d1e043b4254" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + + +BATCH_SIZE = 100 + + +def _insert_assistants_batch( + assistants_table: Table, assistants_batch: list[AssistantV1] +) -> None: + """Insert a batch of assistants into the database.""" + op.bulk_insert( + assistants_table, + [assistant.to_db_dict() for assistant in assistants_batch], + ) + + +settings = SettingsV1() +assistants_dir = settings.data_dir / "assistants" + + +def upgrade() -> None: + """Import existing assistants from JSON files.""" + + # Skip if directory doesn't exist (e.g., first-time setup) + if not assistants_dir.exists(): + return + + # Get the table from the current database schema + connection = op.get_bind() + assistants_table = Table("assistants", MetaData(), autoload_with=connection) + + # Get all JSON files in the assistants directory + json_files = list(assistants_dir.glob("*.json")) + + # Process assistants in batches + assistants_batch: list[AssistantV1] = [] + + for json_file in json_files: + try: + content = json_file.read_text(encoding="utf-8").strip() + data = json.loads(content) + assistant = AssistantV1.model_validate(data) + assistants_batch.append(assistant) + + if len(assistants_batch) >= BATCH_SIZE: + _insert_assistants_batch(assistants_table, assistants_batch) + assistants_batch.clear() + except Exception: # noqa: PERF203 + error_msg = "Failed to import" + logger.exception(error_msg, extra={"json_file": str(json_file)}) + continue + + # Insert remaining assistants in the final batch + if assistants_batch: + _insert_assistants_batch(assistants_table, assistants_batch) + + +def downgrade() -> None: + """Recreate JSON files for assistants during downgrade.""" + + assistants_dir.mkdir(parents=True, exist_ok=True) + + connection = op.get_bind() + assistants_table = Table("assistants", MetaData(), 
autoload_with=connection) + + # Fetch all assistants from the database + result = connection.execute(assistants_table.select()) + rows = result.fetchall() + if not rows: + return + + for row in rows: + try: + assistant: AssistantV1 = AssistantV1.model_validate( + row, from_attributes=True + ) + json_path = assistants_dir / f"{assistant.id}.json" + if json_path.exists(): + continue + with json_path.open("w", encoding="utf-8") as f: + f.write(json.dumps(assistant.model_dump())) + except Exception as e: # noqa: PERF203 + error_msg = f"Failed to export row to json: {e}" + logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) + continue diff --git a/src/askui/chat/migrations/versions/37007a499ca7_remove_assistants_dir.py b/src/askui/chat/migrations/versions/37007a499ca7_remove_assistants_dir.py new file mode 100644 index 00000000..18f004dd --- /dev/null +++ b/src/askui/chat/migrations/versions/37007a499ca7_remove_assistants_dir.py @@ -0,0 +1,51 @@ +"""remove_assistants_dir + +Revision ID: 37007a499ca7 +Revises: c35e88ea9595 +Create Date: 2025-10-10 14:01:53.410908 + +""" + +import logging +import shutil +from typing import Sequence, Union + +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "37007a499ca7" +down_revision: Union[str, None] = "c35e88ea9595" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + +settings = SettingsV1() +assistants_dir = settings.data_dir / "assistants" + + +def upgrade() -> None: + """Remove the assistants directory and all its contents.""" + + # Skip if directory doesn't exist + if not assistants_dir.exists(): + logger.info("Assistants directory does not exist, skipping removal") + return + + try: + shutil.rmtree(assistants_dir) + logger.info( + "Successfully removed assistants directory", + extra={"assistants_dir": str(assistants_dir)}, + ) + except Exception as e: + error_msg = "Failed to remove assistants directory" + logger.exception( + error_msg, + extra={"assistants_dir": str(assistants_dir)}, + ) + raise RuntimeError(error_msg) from e + + +def downgrade() -> None: + assistants_dir.mkdir(parents=True, exist_ok=True) diff --git a/src/askui/chat/migrations/versions/4d1e043b4254_create_assistants_table.py b/src/askui/chat/migrations/versions/4d1e043b4254_create_assistants_table.py new file mode 100644 index 00000000..cbbdcfd0 --- /dev/null +++ b/src/askui/chat/migrations/versions/4d1e043b4254_create_assistants_table.py @@ -0,0 +1,40 @@ +"""create_assistants_table + +Revision ID: 4d1e043b4254 +Revises: +Create Date: 2025-10-10 11:21:24.218911 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "4d1e043b4254" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "assistants", + sa.Column("id", sa.String(24), nullable=False, primary_key=True), + sa.Column("workspace_id", sa.Uuid(), nullable=True, index=True), + sa.Column("created_at", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("avatar", sa.Text(), nullable=True), + sa.Column("tools", sa.JSON(), nullable=False), + sa.Column("system", sa.Text(), nullable=True), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("assistants") + # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py b/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py new file mode 100644 index 00000000..9f9e0492 --- /dev/null +++ b/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py @@ -0,0 +1,61 @@ +"""seed_default_assistants + +Revision ID: c35e88ea9595 +Revises: 057f82313448 +Create Date: 2025-10-10 11:22:12.576195 + +""" + +import logging +from typing import Sequence, Union + +from alembic import op +from sqlalchemy import MetaData, Table +from sqlalchemy.exc import IntegrityError + +from askui.chat.migrations.shared.assistants.seeds import SEEDS_V1 + +# revision identifiers, used by Alembic. +revision: str = "c35e88ea9595" +down_revision: Union[str, None] = "057f82313448" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + + +def upgrade() -> None: + """Seed default assistants one by one, skipping duplicates. + + For each assistant in `SEEDS_V1`, insert a row into `assistants`. If a + row with the same `id` already exists, skip it and log on info level. 
+ """ + connection = op.get_bind() + assistants_table: Table = Table("assistants", MetaData(), autoload_with=connection) + + for seed in SEEDS_V1: + payload: dict[str, object] = seed.to_db_dict() + try: + connection.execute(assistants_table.insert().values(**payload)) + except IntegrityError: + logger.info( + "Assistant already exists, skipping", extra={"assistant_id": seed.id} + ) + continue + except Exception as e: # noqa: PERF203 + logger.exception( + "Failed to insert assistant", + extra={"assistant": seed.model_dump_json()}, + exc_info=e, + ) + continue + + +def downgrade() -> None: + """Remove exactly those assistants that were seeded in upgrade().""" + connection = op.get_bind() + assistant_table: Table = Table("assistants", MetaData(), autoload_with=connection) + + seed_db_ids: list[str] = [seed.id for seed in SEEDS_V1] + for id_ in seed_db_ids: + connection.execute(assistant_table.delete().where(assistant_table.c.id == id_)) From 086b2ac6bf9227af75f75a7a60683e8338205ced Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 14 Oct 2025 15:48:21 +0200 Subject: [PATCH 03/14] chore: fix type and linting errors --- .cursor/worktrees.json | 6 + mypy.ini | 23 ++- pdm.lock | 134 ++++++++++++------ pyproject.toml | 3 +- src/askui/chat/api/assistants/service.py | 3 +- src/askui/chat/api/db/orm/types.py | 15 +- .../chat/migrations/versions/__init__.py | 0 src/askui/tools/android/ppadb_agent_os.py | 4 +- 8 files changed, 134 insertions(+), 54 deletions(-) create mode 100644 .cursor/worktrees.json create mode 100644 src/askui/chat/migrations/versions/__init__.py diff --git a/.cursor/worktrees.json b/.cursor/worktrees.json new file mode 100644 index 00000000..e691ad38 --- /dev/null +++ b/.cursor/worktrees.json @@ -0,0 +1,6 @@ +{ + "setup-worktree": [ + "pdm install", + "cp $ROOT_WORKTREE_PATH/.env .env" + ] +} diff --git a/mypy.ini b/mypy.ini index d3b36c12..56a05114 100644 --- a/mypy.ini +++ b/mypy.ini @@ -12,7 +12,7 @@ warn_unused_ignores = true warn_no_return 
= true warn_unreachable = true strict_optional = true -plugins = pydantic.mypy +plugins = pydantic.mypy,sqlalchemy.ext.mypy.plugin exclude = (?x)( ^src/askui/models/ui_tars_ep/ui_tars_api\.py$ | ^src/askui/tools/askui/askui_ui_controller_grpc/.*$ @@ -26,3 +26,24 @@ ignore_missing_imports = true [mypy-bson.*] ignore_missing_imports = true + +[mypy-alembic.*] +ignore_missing_imports = true + +[mypy-structlog.*] +ignore_missing_imports = true + +[mypy-starlette_context.*] +ignore_missing_imports = true + +[mypy-asgi_correlation_id.*] +ignore_missing_imports = true + +[mypy-prometheus_fastapi_instrumentator.*] +ignore_missing_imports = true + +[mypy-mss.*] +ignore_missing_imports = true + +[mypy-ppadb.*] +ignore_missing_imports = true diff --git a/pdm.lock b/pdm.lock index 20266647..d22f1460 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "android", "bedrock", "chat", "dev", "pynput", "vertex", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:3dd80953b4e83ea83b341f9627ac0f45b11dd4b0ba1f69a19d5405a27c90a495" +content_hash = "sha256:3b5c70118ce8b743db5aaf3f2c06765e3d77f7e1b8da13d081a731b24e4d98e7" [[metadata.targets]] requires_python = ">=3.10" @@ -1366,7 +1366,7 @@ name = "greenlet" version = "3.2.4" requires_python = ">=3.9" summary = "Lightweight in-process concurrent programming" -groups = ["all", "chat", "dev", "web"] +groups = ["default", "all", "chat", "dev", "web"] files = [ {file = "greenlet-3.2.4-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:8c68325b0d0acf8d91dde4e6f930967dd52a5302cd4062932a6b2e7c2969f47c"}, {file = "greenlet-3.2.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:94385f101946790ae13da500603491f04a76b6e4c059dab271b3ce2e283b2590"}, @@ -2332,7 +2332,7 @@ name = "mypy" version = "1.18.2" requires_python = ">=3.9" summary = "Optional static typing for Python" -groups = ["dev"] +groups = ["default", "dev"] dependencies = [ 
"mypy-extensions>=1.0.0", "pathspec>=0.9.0", @@ -2379,7 +2379,7 @@ name = "mypy-extensions" version = "1.1.0" requires_python = ">=3.8" summary = "Type system extensions for programs checked with the mypy type checker." -groups = ["dev"] +groups = ["default", "dev"] files = [ {file = "mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505"}, {file = "mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558"}, @@ -2681,7 +2681,7 @@ name = "pathspec" version = "0.12.1" requires_python = ">=3.8" summary = "Utility library for gitignore style pattern matching of file paths." -groups = ["dev"] +groups = ["default", "dev"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -3909,50 +3909,98 @@ files = [ [[package]] name = "sqlalchemy" -version = "2.0.43" +version = "2.0.44" requires_python = ">=3.7" summary = "Database Abstraction Library" -groups = ["all", "chat"] +groups = ["default", "all", "chat"] dependencies = [ - "greenlet>=1; (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\") and python_version < \"3.14\"", + "greenlet>=1; platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\"", "importlib-metadata; python_version < \"3.8\"", "typing-extensions>=4.6.0", ] files = [ - {file = "sqlalchemy-2.0.43-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:70322986c0c699dca241418fcf18e637a4369e0ec50540a2b907b184c8bca069"}, - {file = "sqlalchemy-2.0.43-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:87accdbba88f33efa7b592dc2e8b2a9c2cdbca73db2f9d5c510790428c09c154"}, - {file = "sqlalchemy-2.0.43-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c00e7845d2f692ebfc7d5e4ec1a3fd87698e4337d09e58d6749a16aedfdf8612"}, - {file = "sqlalchemy-2.0.43-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:022e436a1cb39b13756cf93b48ecce7aa95382b9cfacceb80a7d263129dfd019"}, - {file = "sqlalchemy-2.0.43-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c5e73ba0d76eefc82ec0219d2301cb33bfe5205ed7a2602523111e2e56ccbd20"}, - {file = "sqlalchemy-2.0.43-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9c2e02f06c68092b875d5cbe4824238ab93a7fa35d9c38052c033f7ca45daa18"}, - {file = "sqlalchemy-2.0.43-cp310-cp310-win32.whl", hash = "sha256:e7a903b5b45b0d9fa03ac6a331e1c1d6b7e0ab41c63b6217b3d10357b83c8b00"}, - {file = "sqlalchemy-2.0.43-cp310-cp310-win_amd64.whl", hash = "sha256:4bf0edb24c128b7be0c61cd17eef432e4bef507013292415f3fb7023f02b7d4b"}, - {file = "sqlalchemy-2.0.43-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:52d9b73b8fb3e9da34c2b31e6d99d60f5f99fd8c1225c9dad24aeb74a91e1d29"}, - {file = "sqlalchemy-2.0.43-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f42f23e152e4545157fa367b2435a1ace7571cab016ca26038867eb7df2c3631"}, - {file = "sqlalchemy-2.0.43-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4fb1a8c5438e0c5ea51afe9c6564f951525795cf432bed0c028c1cb081276685"}, - {file = "sqlalchemy-2.0.43-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db691fa174e8f7036afefe3061bc40ac2b770718be2862bfb03aabae09051aca"}, - {file = "sqlalchemy-2.0.43-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fe2b3b4927d0bc03d02ad883f402d5de201dbc8894ac87d2e981e7d87430e60d"}, - {file = 
"sqlalchemy-2.0.43-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4d3d9b904ad4a6b175a2de0738248822f5ac410f52c2fd389ada0b5262d6a1e3"}, - {file = "sqlalchemy-2.0.43-cp311-cp311-win32.whl", hash = "sha256:5cda6b51faff2639296e276591808c1726c4a77929cfaa0f514f30a5f6156921"}, - {file = "sqlalchemy-2.0.43-cp311-cp311-win_amd64.whl", hash = "sha256:c5d1730b25d9a07727d20ad74bc1039bbbb0a6ca24e6769861c1aa5bf2c4c4a8"}, - {file = "sqlalchemy-2.0.43-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:20d81fc2736509d7a2bd33292e489b056cbae543661bb7de7ce9f1c0cd6e7f24"}, - {file = "sqlalchemy-2.0.43-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25b9fc27650ff5a2c9d490c13c14906b918b0de1f8fcbb4c992712d8caf40e83"}, - {file = "sqlalchemy-2.0.43-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6772e3ca8a43a65a37c88e2f3e2adfd511b0b1da37ef11ed78dea16aeae85bd9"}, - {file = "sqlalchemy-2.0.43-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a113da919c25f7f641ffbd07fbc9077abd4b3b75097c888ab818f962707eb48"}, - {file = "sqlalchemy-2.0.43-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4286a1139f14b7d70141c67a8ae1582fc2b69105f1b09d9573494eb4bb4b2687"}, - {file = "sqlalchemy-2.0.43-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:529064085be2f4d8a6e5fab12d36ad44f1909a18848fcfbdb59cc6d4bbe48efe"}, - {file = "sqlalchemy-2.0.43-cp312-cp312-win32.whl", hash = "sha256:b535d35dea8bbb8195e7e2b40059e2253acb2b7579b73c1b432a35363694641d"}, - {file = "sqlalchemy-2.0.43-cp312-cp312-win_amd64.whl", hash = "sha256:1c6d85327ca688dbae7e2b06d7d84cfe4f3fffa5b5f9e21bb6ce9d0e1a0e0e0a"}, - {file = "sqlalchemy-2.0.43-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e7c08f57f75a2bb62d7ee80a89686a5e5669f199235c6d1dac75cd59374091c3"}, - {file = "sqlalchemy-2.0.43-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:14111d22c29efad445cd5021a70a8b42f7d9152d8ba7f73304c4d82460946aaa"}, - {file = 
"sqlalchemy-2.0.43-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21b27b56eb2f82653168cefe6cb8e970cdaf4f3a6cb2c5e3c3c1cf3158968ff9"}, - {file = "sqlalchemy-2.0.43-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c5a9da957c56e43d72126a3f5845603da00e0293720b03bde0aacffcf2dc04f"}, - {file = "sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5d79f9fdc9584ec83d1b3c75e9f4595c49017f5594fee1a2217117647225d738"}, - {file = "sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9df7126fd9db49e3a5a3999442cc67e9ee8971f3cb9644250107d7296cb2a164"}, - {file = "sqlalchemy-2.0.43-cp313-cp313-win32.whl", hash = "sha256:7f1ac7828857fcedb0361b48b9ac4821469f7694089d15550bbcf9ab22564a1d"}, - {file = "sqlalchemy-2.0.43-cp313-cp313-win_amd64.whl", hash = "sha256:971ba928fcde01869361f504fcff3b7143b47d30de188b11c6357c0505824197"}, - {file = "sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc"}, - {file = "sqlalchemy-2.0.43.tar.gz", hash = "sha256:788bfcef6787a7764169cfe9859fe425bf44559619e1d9f56f5bddf2ebf6f417"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c77f3080674fc529b1bd99489378c7f63fcb4ba7f8322b79732e0258f0ea3ce"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4c26ef74ba842d61635b0152763d057c8d48215d5be9bb8b7604116a059e9985"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4a172b31785e2f00780eccab00bc240ccdbfdb8345f1e6063175b3ff12ad1b0"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9480c0740aabd8cb29c329b422fb65358049840b34aba0adf63162371d2a96e"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:17835885016b9e4d0135720160db3095dc78c583e7b902b6be799fb21035e749"}, + {file = 
"sqlalchemy-2.0.44-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cbe4f85f50c656d753890f39468fcd8190c5f08282caf19219f684225bfd5fd2"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-win32.whl", hash = "sha256:2fcc4901a86ed81dc76703f3b93ff881e08761c63263c46991081fd7f034b165"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-win_amd64.whl", hash = "sha256:9919e77403a483ab81e3423151e8ffc9dd992c20d2603bf17e4a8161111e55f5"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fe3917059c7ab2ee3f35e77757062b1bea10a0b6ca633c58391e3f3c6c488dd"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:de4387a354ff230bc979b46b2207af841dc8bf29847b6c7dbe60af186d97aefa"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3678a0fb72c8a6a29422b2732fe423db3ce119c34421b5f9955873eb9b62c1e"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cf6872a23601672d61a68f390e44703442639a12ee9dd5a88bbce52a695e46e"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:329aa42d1be9929603f406186630135be1e7a42569540577ba2c69952b7cf399"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:70e03833faca7166e6a9927fbee7c27e6ecde436774cd0b24bbcc96353bce06b"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-win32.whl", hash = "sha256:253e2f29843fb303eca6b2fc645aca91fa7aa0aa70b38b6950da92d44ff267f3"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-win_amd64.whl", hash = "sha256:7a8694107eb4308a13b425ca8c0e67112f8134c846b6e1f722698708741215d5"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:72fea91746b5890f9e5e0997f16cbf3d53550580d76355ba2d998311b17b2250"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:585c0c852a891450edbb1eaca8648408a3cc125f18cf433941fa6babcc359e29"}, + {file = 
"sqlalchemy-2.0.44-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b94843a102efa9ac68a7a30cd46df3ff1ed9c658100d30a725d10d9c60a2f44"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:119dc41e7a7defcefc57189cfa0e61b1bf9c228211aba432b53fb71ef367fda1"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0765e318ee9179b3718c4fd7ba35c434f4dd20332fbc6857a5e8df17719c24d7"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2e7b5b079055e02d06a4308d0481658e4f06bc7ef211567edc8f7d5dce52018d"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-win32.whl", hash = "sha256:846541e58b9a81cce7dee8329f352c318de25aa2f2bbe1e31587eb1f057448b4"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-win_amd64.whl", hash = "sha256:7cbcb47fd66ab294703e1644f78971f6f2f1126424d2b300678f419aa73c7b6e"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ff486e183d151e51b1d694c7aa1695747599bb00b9f5f604092b54b74c64a8e1"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0b1af8392eb27b372ddb783b317dea0f650241cea5bd29199b22235299ca2e45"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b61188657e3a2b9ac4e8f04d6cf8e51046e28175f79464c67f2fd35bceb0976"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b87e7b91a5d5973dda5f00cd61ef72ad75a1db73a386b62877d4875a8840959c"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:15f3326f7f0b2bfe406ee562e17f43f36e16167af99c4c0df61db668de20002d"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1e77faf6ff919aa8cd63f1c4e561cac1d9a454a191bb864d5dd5e545935e5a40"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-win32.whl", hash = 
"sha256:ee51625c2d51f8baadf2829fae817ad0b66b140573939dd69284d2ba3553ae73"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-win_amd64.whl", hash = "sha256:c1c80faaee1a6c3428cecf40d16a2365bcf56c424c92c2b6f0f9ad204b899e9e"}, + {file = "sqlalchemy-2.0.44-py3-none-any.whl", hash = "sha256:19de7ca1246fbef9f9d1bff8f1ab25641569df226364a0e40457dc5457c54b05"}, + {file = "sqlalchemy-2.0.44.tar.gz", hash = "sha256:0ae7454e1ab1d780aee69fd2aae7d6b8670a581d8847f2d1e0f7ddfbf47e5a22"}, +] + +[[package]] +name = "sqlalchemy" +version = "2.0.44" +extras = ["mypy"] +requires_python = ">=3.7" +summary = "Database Abstraction Library" +groups = ["default"] +dependencies = [ + "mypy>=0.910", + "sqlalchemy==2.0.44", +] +files = [ + {file = "sqlalchemy-2.0.44-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c77f3080674fc529b1bd99489378c7f63fcb4ba7f8322b79732e0258f0ea3ce"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4c26ef74ba842d61635b0152763d057c8d48215d5be9bb8b7604116a059e9985"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4a172b31785e2f00780eccab00bc240ccdbfdb8345f1e6063175b3ff12ad1b0"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9480c0740aabd8cb29c329b422fb65358049840b34aba0adf63162371d2a96e"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:17835885016b9e4d0135720160db3095dc78c583e7b902b6be799fb21035e749"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cbe4f85f50c656d753890f39468fcd8190c5f08282caf19219f684225bfd5fd2"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-win32.whl", hash = "sha256:2fcc4901a86ed81dc76703f3b93ff881e08761c63263c46991081fd7f034b165"}, + {file = "sqlalchemy-2.0.44-cp310-cp310-win_amd64.whl", hash = "sha256:9919e77403a483ab81e3423151e8ffc9dd992c20d2603bf17e4a8161111e55f5"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-macosx_10_9_x86_64.whl", 
hash = "sha256:0fe3917059c7ab2ee3f35e77757062b1bea10a0b6ca633c58391e3f3c6c488dd"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:de4387a354ff230bc979b46b2207af841dc8bf29847b6c7dbe60af186d97aefa"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3678a0fb72c8a6a29422b2732fe423db3ce119c34421b5f9955873eb9b62c1e"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cf6872a23601672d61a68f390e44703442639a12ee9dd5a88bbce52a695e46e"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:329aa42d1be9929603f406186630135be1e7a42569540577ba2c69952b7cf399"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:70e03833faca7166e6a9927fbee7c27e6ecde436774cd0b24bbcc96353bce06b"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-win32.whl", hash = "sha256:253e2f29843fb303eca6b2fc645aca91fa7aa0aa70b38b6950da92d44ff267f3"}, + {file = "sqlalchemy-2.0.44-cp311-cp311-win_amd64.whl", hash = "sha256:7a8694107eb4308a13b425ca8c0e67112f8134c846b6e1f722698708741215d5"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:72fea91746b5890f9e5e0997f16cbf3d53550580d76355ba2d998311b17b2250"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:585c0c852a891450edbb1eaca8648408a3cc125f18cf433941fa6babcc359e29"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b94843a102efa9ac68a7a30cd46df3ff1ed9c658100d30a725d10d9c60a2f44"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:119dc41e7a7defcefc57189cfa0e61b1bf9c228211aba432b53fb71ef367fda1"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0765e318ee9179b3718c4fd7ba35c434f4dd20332fbc6857a5e8df17719c24d7"}, + {file = 
"sqlalchemy-2.0.44-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2e7b5b079055e02d06a4308d0481658e4f06bc7ef211567edc8f7d5dce52018d"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-win32.whl", hash = "sha256:846541e58b9a81cce7dee8329f352c318de25aa2f2bbe1e31587eb1f057448b4"}, + {file = "sqlalchemy-2.0.44-cp312-cp312-win_amd64.whl", hash = "sha256:7cbcb47fd66ab294703e1644f78971f6f2f1126424d2b300678f419aa73c7b6e"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ff486e183d151e51b1d694c7aa1695747599bb00b9f5f604092b54b74c64a8e1"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0b1af8392eb27b372ddb783b317dea0f650241cea5bd29199b22235299ca2e45"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b61188657e3a2b9ac4e8f04d6cf8e51046e28175f79464c67f2fd35bceb0976"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b87e7b91a5d5973dda5f00cd61ef72ad75a1db73a386b62877d4875a8840959c"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:15f3326f7f0b2bfe406ee562e17f43f36e16167af99c4c0df61db668de20002d"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1e77faf6ff919aa8cd63f1c4e561cac1d9a454a191bb864d5dd5e545935e5a40"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-win32.whl", hash = "sha256:ee51625c2d51f8baadf2829fae817ad0b66b140573939dd69284d2ba3553ae73"}, + {file = "sqlalchemy-2.0.44-cp313-cp313-win_amd64.whl", hash = "sha256:c1c80faaee1a6c3428cecf40d16a2365bcf56c424c92c2b6f0f9ad204b899e9e"}, + {file = "sqlalchemy-2.0.44-py3-none-any.whl", hash = "sha256:19de7ca1246fbef9f9d1bff8f1ab25641569df226364a0e40457dc5457c54b05"}, + {file = "sqlalchemy-2.0.44.tar.gz", hash = "sha256:0ae7454e1ab1d780aee69fd2aae7d6b8670a581d8847f2d1e0f7ddfbf47e5a22"}, ] [[package]] @@ -4043,7 +4091,7 @@ name = "tomli" version = "2.2.1" requires_python = ">=3.8" summary = 
"A lil' TOML parser" -groups = ["all", "chat", "dev"] +groups = ["default", "all", "chat", "dev"] marker = "python_version <= \"3.11\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, diff --git a/pyproject.toml b/pyproject.toml index 97c45dba..604add0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "bson>=0.5.10", "aiofiles>=24.1.0", "anyio==4.10.0", # We need to pin this version otherwise listing mcp tools using fastmcp within runner fails + "sqlalchemy[mypy]>=2.0.44", ] requires-python = ">=3.10,<=3.13" readme = "README.md" @@ -45,8 +46,6 @@ build-backend = "hatchling.build" path = "src/askui/__init__.py" - - [tool.pdm] distribution = true diff --git a/src/askui/chat/api/assistants/service.py b/src/askui/chat/api/assistants/service.py index a562687f..91154e51 100644 --- a/src/askui/chat/api/assistants/service.py +++ b/src/askui/chat/api/assistants/service.py @@ -21,6 +21,7 @@ def list_( AssistantOrm.workspace_id.is_(None), ), ) + orms: list[AssistantOrm] orms, has_more = list_all(q, query, AssistantOrm.id) data = [orm.to_model() for orm in orms] return ListResponse( @@ -33,7 +34,7 @@ def list_( def _find_by_id( self, workspace_id: WorkspaceId | None, assistant_id: AssistantId ) -> AssistantOrm: - assistant_orm = ( + assistant_orm: AssistantOrm | None = ( self._session.query(AssistantOrm) .filter( AssistantOrm.id == assistant_id, diff --git a/src/askui/chat/api/db/orm/types.py b/src/askui/chat/api/db/orm/types.py index f1fe85c2..0087eff6 100644 --- a/src/askui/chat/api/db/orm/types.py +++ b/src/askui/chat/api/db/orm/types.py @@ -11,12 +11,12 @@ class PrefixedObjectId(TypeDecorator[str]): impl = String(24) cache_ok = True - def process_bind_param(self, value: str | None, dialect: Any) -> str | None: + def process_bind_param(self, value: str | None, dialect: Any) -> str | None: # noqa: ARG002 if value is None: return value 
return value.removeprefix(f"{prefix}_") - def process_result_value(self, value: str | None, dialect: Any) -> str | None: + def process_result_value(self, value: str | None, dialect: Any) -> str | None: # noqa: ARG002 if value is None: return value return f"{prefix}_{value}" @@ -25,7 +25,6 @@ def process_result_value(self, value: str | None, dialect: Any) -> str | None: # Specialized types for each resource -# TODO Move into orms.py of the respective resource ThreadId = create_prefixed_id_type("thread") MessageId = create_prefixed_id_type("msg") RunId = create_prefixed_id_type("run") @@ -39,7 +38,9 @@ class UnixDatetime(TypeDecorator[datetime]): LOCAL_TIMEZONE = datetime.now().astimezone().tzinfo def process_bind_param( - self, value: datetime | int | None, dialect: Any + self, + value: datetime | int | None, + dialect: Any, # noqa: ARG002 ) -> int | None: if value is None: return value @@ -49,7 +50,11 @@ def process_bind_param( value = value.astimezone(self.LOCAL_TIMEZONE) return int(value.astimezone(timezone.utc).timestamp()) - def process_result_value(self, value: int | None, dialect: Any) -> datetime | None: + def process_result_value( + self, + value: int | None, + dialect: Any, # noqa: ARG002 + ) -> datetime | None: if value is None: return value return datetime.fromtimestamp(value, timezone.utc) diff --git a/src/askui/chat/migrations/versions/__init__.py b/src/askui/chat/migrations/versions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/tools/android/ppadb_agent_os.py b/src/askui/tools/android/ppadb_agent_os.py index 7d7ea15f..90285ead 100644 --- a/src/askui/tools/android/ppadb_agent_os.py +++ b/src/askui/tools/android/ppadb_agent_os.py @@ -5,8 +5,8 @@ from typing import List, Optional, get_args from PIL import Image -from ppadb.client import Client as AdbClient # type: ignore[import-untyped] -from ppadb.device import Device as AndroidDevice # type: ignore[import-untyped] +from ppadb.client import Client as AdbClient +from 
ppadb.device import Device as AndroidDevice from askui.tools.android.agent_os import ANDROID_KEY, AndroidAgentOs, AndroidDisplay From 95099a0c70e66d1e3e4259e8ee21f7b7c995dc0a Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 14 Oct 2025 16:53:11 +0200 Subject: [PATCH 04/14] test(assistants): fix tests after persistence migration from json to sqlite --- tests/integration/chat/api/conftest.py | 37 +- tests/integration/chat/api/test_assistants.py | 976 +++++++----------- tests/integration/chat/api/test_runs.py | 561 ++++------ 3 files changed, 614 insertions(+), 960 deletions(-) diff --git a/tests/integration/chat/api/conftest.py b/tests/integration/chat/api/conftest.py index a9272840..b58adec5 100644 --- a/tests/integration/chat/api/conftest.py +++ b/tests/integration/chat/api/conftest.py @@ -2,16 +2,39 @@ import tempfile import uuid +from collections.abc import Generator from pathlib import Path import pytest from fastapi import FastAPI from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker from askui.chat.api.app import app +from askui.chat.api.assistants.dependencies import get_assistant_service +from askui.chat.api.assistants.service import AssistantService +from askui.chat.api.db.orm.base import Base from askui.chat.api.files.service import FileService +@pytest.fixture +def test_db_session() -> Generator[Session, None, None]: + """Create a test database session with temporary SQLite file.""" + with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_db: + # Create an engine with the temporary file + engine = create_engine(f"sqlite:///{temp_db.name}", echo=True) + # Create all tables + Base.metadata.create_all(engine) + # Create a session + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + session = SessionLocal() + try: + yield session + finally: + session.close() + + @pytest.fixture def test_app() -> FastAPI: """Get the FastAPI test 
application.""" @@ -19,9 +42,17 @@ def test_app() -> FastAPI: @pytest.fixture -def test_client(test_app: FastAPI) -> TestClient: - """Get a test client for the FastAPI application.""" - return TestClient(test_app) +def test_client( + test_app: FastAPI, test_db_session: Session +) -> Generator[TestClient, None, None]: + """Yield a TestClient with common overrides (assistants service uses the test DB).""" + app.dependency_overrides[get_assistant_service] = lambda: AssistantService( + test_db_session + ) + try: + yield TestClient(test_app) + finally: + app.dependency_overrides.pop(get_assistant_service, None) @pytest.fixture diff --git a/tests/integration/chat/api/test_assistants.py b/tests/integration/chat/api/test_assistants.py index 6c60b64a..46197d4a 100644 --- a/tests/integration/chat/api/test_assistants.py +++ b/tests/integration/chat/api/test_assistants.py @@ -1,309 +1,208 @@ """Integration tests for the assistants API endpoints.""" -import tempfile -from pathlib import Path +from datetime import datetime, timezone +from uuid import UUID from fastapi import status from fastapi.testclient import TestClient +from sqlalchemy.orm import Session from askui.chat.api.assistants.models import Assistant -from askui.chat.api.assistants.service import AssistantService +from askui.chat.api.assistants.orms import AssistantOrm +from askui.chat.api.models import WorkspaceId class TestAssistantsAPI: """Test suite for the assistants API endpoints.""" - def test_list_assistants_empty(self, test_headers: dict[str, str]) -> None: - """Test listing assistants when no assistants exist.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) + def _create_test_assistant( + self, + assistant_id: str, + 
workspace_id: WorkspaceId | None = None, + name: str = "Test Assistant", + description: str = "A test assistant", + avatar: str | None = None, + created_at: datetime | None = None, + ) -> Assistant: + """Create a test assistant model.""" + if created_at is None: + created_at = datetime.fromtimestamp(1234567890, tz=timezone.utc) + return Assistant( + id=assistant_id, + object="assistant", + created_at=created_at, + name=name, + description=description, + avatar=avatar, + workspace_id=workspace_id, + ) - app.dependency_overrides[get_assistant_service] = override_assistant_service + def _add_assistant_to_db( + self, assistant: Assistant, test_db_session: Session + ) -> None: + """Add an assistant to the test database.""" + assistant_orm = AssistantOrm.from_model(assistant) + test_db_session.add(assistant_orm) + test_db_session.commit() - try: - with TestClient(app) as client: - response = client.get("/v1/assistants", headers=test_headers) + def test_list_assistants_empty( + self, test_headers: dict[str, str], test_client: TestClient + ) -> None: + """Test listing assistants when no assistants exist.""" + response = test_client.get("/v1/assistants", headers=test_headers) - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert data["data"] == [] - assert data["has_more"] is False - finally: - app.dependency_overrides.clear() + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert data["data"] == [] + assert data["has_more"] is False def test_list_assistants_with_assistants( - self, test_headers: dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, ) -> None: """Test listing assistants when assistants exist.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - # Create a 
mock assistant - workspace_id = test_headers["askui-workspace"] - mock_assistant = Assistant( - id="asst_test123", - object="assistant", - created_at=1234567890, - name="Test Assistant", - description="A test assistant", - avatar="test_avatar.png", - workspace_id=workspace_id, + workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 + mock_assistant = self._create_test_assistant( + "asst_test123", workspace_id=workspace_id, avatar="test_avatar.png" ) - (assistants_dir / "asst_test123.json").write_text( - mock_assistant.model_dump_json() - ) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service + self._add_assistant_to_db(mock_assistant, test_db_session) + response = test_client.get("/v1/assistants", headers=test_headers) - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - response = client.get("/v1/assistants", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert len(data["data"]) == 1 - assert data["data"][0]["id"] == "asst_test123" - assert data["data"][0]["name"] == "Test Assistant" - finally: - app.dependency_overrides.clear() + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "asst_test123" + assert data["data"][0]["name"] == "Test Assistant" + assert data["data"][0]["description"] == "A test assistant" + assert data["data"][0]["avatar"] == "test_avatar.png" + assert data["has_more"] is False def test_list_assistants_with_pagination( - self, test_headers: dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + 
test_client: TestClient, ) -> None: """Test listing assistants with pagination parameters.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) + workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 - # Create multiple mock assistants - workspace_id = test_headers["askui-workspace"] + # Create multiple mock assistants in the database for i in range(5): - mock_assistant = Assistant( - id=f"asst_test{i}", - object="assistant", - created_at=1234567890 + i, + mock_assistant = self._create_test_assistant( + f"asst_test{i}", + workspace_id=workspace_id, name=f"Test Assistant {i}", description=f"Test assistant {i}", - workspace_id=workspace_id, - ) - (assistants_dir / f"asst_test{i}.json").write_text( - mock_assistant.model_dump_json() + created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), ) + self._add_assistant_to_db(mock_assistant, test_db_session) - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service + response = test_client.get("/v1/assistants?limit=3", headers=test_headers) - try: - with TestClient(app) as client: - response = client.get("/v1/assistants?limit=3", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert len(data["data"]) == 3 - assert data["has_more"] is True - finally: - app.dependency_overrides.clear() + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 3 + assert data["has_more"] is True - def test_create_assistant(self, test_headers: dict[str, str]) -> None: + def test_create_assistant( + self, test_headers: dict[str, str], test_client: TestClient + ) -> 
None: """Test creating a new assistant.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - assistant_data = { - "name": "New Test Assistant", - "description": "A newly created test assistant", - "avatar": "new_avatar.png", - } - response = client.post( - "/v1/assistants", json=assistant_data, headers=test_headers - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["name"] == "New Test Assistant" - assert data["description"] == "A newly created test assistant" - assert data["avatar"] == "new_avatar.png" - assert data["object"] == "assistant" - assert "id" in data - assert "created_at" in data - finally: - app.dependency_overrides.clear() - - def test_create_assistant_minimal(self, test_headers: dict[str, str]) -> None: - """Test creating an assistant with minimal data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service + assistant_data = { + "name": "New Test Assistant", + "description": "A newly created test assistant", + "avatar": "new_avatar.png", + } + response = test_client.post( + "/v1/assistants", json=assistant_data, headers=test_headers + ) - try: - with TestClient(app) as client: - response = client.post("/v1/assistants", json={}, headers=test_headers) + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert 
data["name"] == "New Test Assistant" + assert data["description"] == "A newly created test assistant" + assert data["avatar"] == "new_avatar.png" + assert data["object"] == "assistant" + assert "id" in data + assert "created_at" in data + + def test_create_assistant_minimal( + self, test_headers: dict[str, str], test_client: TestClient + ) -> None: + """Test creating an assistant with minimal data.""" + response = test_client.post("/v1/assistants", json={}, headers=test_headers) - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["object"] == "assistant" - assert data["name"] is None - assert data["description"] is None - assert data["avatar"] is None - finally: - app.dependency_overrides.clear() + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["object"] == "assistant" + assert data["name"] is None + assert data["description"] is None + assert data["avatar"] is None def test_create_assistant_with_tools_and_system( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_client: TestClient ) -> None: """Test creating a new assistant with tools and system prompt.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - response = client.post( - "/v1/assistants", - headers=test_headers, - json={ - "name": "Custom Assistant", - "description": "A custom assistant with tools", - "tools": ["tool1", "tool2", "tool3"], - "system": "You are a helpful custom assistant.", - }, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - 
assert data["name"] == "Custom Assistant" - assert data["description"] == "A custom assistant with tools" - assert data["tools"] == ["tool1", "tool2", "tool3"] - assert data["system"] == "You are a helpful custom assistant." - assert "id" in data - assert "created_at" in data - finally: - app.dependency_overrides.clear() + response = test_client.post( + "/v1/assistants", + headers=test_headers, + json={ + "name": "Custom Assistant", + "description": "A custom assistant with tools", + "tools": ["tool1", "tool2", "tool3"], + "system": "You are a helpful custom assistant.", + }, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "Custom Assistant" + assert data["description"] == "A custom assistant with tools" + assert data["tools"] == ["tool1", "tool2", "tool3"] + assert data["system"] == "You are a helpful custom assistant." + assert "id" in data + assert "created_at" in data def test_create_assistant_with_empty_tools( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_client: TestClient ) -> None: """Test creating a new assistant with empty tools list.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - - # Create a test app with overridden dependencies - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - response = client.post( - "/v1/assistants", - headers=test_headers, - json={ - "name": "Empty Tools Assistant", - "tools": [], - }, - ) - - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["name"] == "Empty Tools Assistant" - assert data["tools"] == [] - assert "id" in data - assert "created_at" in data - finally: - 
app.dependency_overrides.clear() - - def test_retrieve_assistant(self, test_headers: dict[str, str]) -> None: - """Test retrieving an existing assistant.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - mock_assistant = Assistant( - id="asst_test123", - object="assistant", - created_at=1234567890, - name="Test Assistant", - description="A test assistant", - ) - (assistants_dir / "asst_test123.json").write_text( - mock_assistant.model_dump_json() + response = test_client.post( + "/v1/assistants", + headers=test_headers, + json={ + "name": "Empty Tools Assistant", + "tools": [], + }, ) - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - response = client.get( - "/v1/assistants/asst_test123", headers=test_headers - ) + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "Empty Tools Assistant" + assert data["tools"] == [] + assert "id" in data + assert "created_at" in data + + def test_retrieve_assistant( + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, + ) -> None: + """Test retrieving an existing assistant.""" + mock_assistant = self._create_test_assistant("asst_test123") + self._add_assistant_to_db(mock_assistant, test_db_session) + response = test_client.get("/v1/assistants/asst_test123", headers=test_headers) - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["id"] == "asst_test123" - assert data["name"] == "Test Assistant" - assert data["description"] == "A test assistant" - finally: - app.dependency_overrides.clear() + 
assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == "asst_test123" + assert data["name"] == "Test Assistant" + assert data["description"] == "A test assistant" def test_retrieve_assistant_not_found( self, test_client: TestClient, test_headers: dict[str, str] @@ -317,151 +216,98 @@ def test_retrieve_assistant_not_found( data = response.json() assert "detail" in data - def test_modify_assistant(self, test_headers: dict[str, str]) -> None: + def test_modify_assistant( + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, + ) -> None: """Test modifying an existing assistant.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - workspace_id = test_headers["askui-workspace"] - mock_assistant = Assistant( - id="asst_test123", - object="assistant", - created_at=1234567890, + workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 + mock_assistant = self._create_test_assistant( + "asst_test123", + workspace_id=workspace_id, name="Original Name", description="Original description", - workspace_id=workspace_id, ) - (assistants_dir / "asst_test123.json").write_text( - mock_assistant.model_dump_json() + self._add_assistant_to_db(mock_assistant, test_db_session) + modify_data = { + "name": "Modified Name", + "description": "Modified description", + } + response = test_client.post( + "/v1/assistants/asst_test123", + json=modify_data, + headers=test_headers, ) - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - modify_data = { - "name": "Modified Name", - "description": 
"Modified description", - } - response = client.post( - "/v1/assistants/asst_test123", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Modified Name" - assert data["description"] == "Modified description" - assert data["id"] == "asst_test123" - assert data["created_at"] == 1234567890 - finally: - app.dependency_overrides.clear() + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Modified Name" + assert data["description"] == "Modified description" + assert data["id"] == "asst_test123" + assert data["created_at"] == 1234567890 def test_modify_assistant_with_tools_and_system( - self, test_headers: dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, ) -> None: """Test modifying an assistant with tools and system prompt.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - workspace_id = test_headers["askui-workspace"] - mock_assistant = Assistant( - id="asst_test123", - object="assistant", - created_at=1234567890, + workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 + mock_assistant = self._create_test_assistant( + "asst_test123", + workspace_id=workspace_id, name="Original Name", description="Original description", - workspace_id=workspace_id, ) - (assistants_dir / "asst_test123.json").write_text( - mock_assistant.model_dump_json() + self._add_assistant_to_db(mock_assistant, test_db_session) + modify_data = { + "name": "Modified Name", + "tools": ["new_tool1", "new_tool2"], + "system": "You are a modified custom assistant.", + } + response = test_client.post( + "/v1/assistants/asst_test123", + json=modify_data, + headers=test_headers, ) - from askui.chat.api.app import app - from 
askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - modify_data = { - "name": "Modified Name", - "tools": ["new_tool1", "new_tool2"], - "system": "You are a modified custom assistant.", - } - response = client.post( - "/v1/assistants/asst_test123", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Modified Name" - assert data["tools"] == ["new_tool1", "new_tool2"] - assert data["system"] == "You are a modified custom assistant." - assert data["id"] == "asst_test123" - assert data["created_at"] == 1234567890 - finally: - app.dependency_overrides.clear() - - def test_modify_assistant_partial(self, test_headers: dict[str, str]) -> None: + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Modified Name" + assert data["tools"] == ["new_tool1", "new_tool2"] + assert data["system"] == "You are a modified custom assistant." 
+ assert data["id"] == "asst_test123" + assert data["created_at"] == 1234567890 + + def test_modify_assistant_partial( + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, + ) -> None: """Test modifying an assistant with partial data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - workspace_id = test_headers["askui-workspace"] - mock_assistant = Assistant( - id="asst_test123", - object="assistant", - created_at=1234567890, + workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 + mock_assistant = self._create_test_assistant( + "asst_test123", + workspace_id=workspace_id, name="Original Name", description="Original description", - workspace_id=workspace_id, ) - (assistants_dir / "asst_test123.json").write_text( - mock_assistant.model_dump_json() + self._add_assistant_to_db(mock_assistant, test_db_session) + modify_data = {"name": "Only Name Modified"} + response = test_client.post( + "/v1/assistants/asst_test123", + json=modify_data, + headers=test_headers, ) - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - modify_data = {"name": "Only Name Modified"} - response = client.post( - "/v1/assistants/asst_test123", - json=modify_data, - headers=test_headers, - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["name"] == "Only Name Modified" - assert data["description"] == "Original description" # Unchanged - finally: - app.dependency_overrides.clear() + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Only 
Name Modified" + assert data["description"] == "Original description" # Unchanged def test_modify_assistant_not_found( self, test_client: TestClient, test_headers: dict[str, str] @@ -474,43 +320,24 @@ def test_modify_assistant_not_found( assert response.status_code == status.HTTP_404_NOT_FOUND - def test_delete_assistant(self, test_headers: dict[str, str]) -> None: + def test_delete_assistant( + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, + ) -> None: """Test deleting an existing assistant.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - workspace_id = test_headers["askui-workspace"] - mock_assistant = Assistant( - id="asst_test123", - object="assistant", - created_at=1234567890, - name="Test Assistant", - workspace_id=workspace_id, + workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 + mock_assistant = self._create_test_assistant( + "asst_test123", workspace_id=workspace_id ) - (assistants_dir / "asst_test123.json").write_text( - mock_assistant.model_dump_json() + self._add_assistant_to_db(mock_assistant, test_db_session) + response = test_client.delete( + "/v1/assistants/asst_test123", headers=test_headers ) - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - response = client.delete( - "/v1/assistants/asst_test123", headers=test_headers - ) - - assert response.status_code == status.HTTP_204_NO_CONTENT - assert response.content == b"" - finally: - app.dependency_overrides.clear() + assert response.status_code == status.HTTP_204_NO_CONTENT + assert response.content == b"" def 
test_delete_assistant_not_found( self, test_client: TestClient, test_headers: dict[str, str] @@ -523,263 +350,164 @@ def test_delete_assistant_not_found( assert response.status_code == status.HTTP_404_NOT_FOUND def test_modify_default_assistant_forbidden( - self, test_headers: dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, ) -> None: """Test that modifying a default assistant returns 403 Forbidden.""" - # Create a default assistant (no workspace_id) - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - default_assistant = Assistant( - id="asst_default123", - object="assistant", - created_at=1234567890, + default_assistant = self._create_test_assistant( + "asst_default123", + workspace_id=None, # No workspace_id = default name="Default Assistant", description="This is a default assistant", - workspace_id=None, # No workspace_id = default ) - (assistants_dir / "asst_default123.json").write_text( - default_assistant.model_dump_json() + self._add_assistant_to_db(default_assistant, test_db_session) + # Try to modify the default assistant + response = test_client.post( + "/v1/assistants/asst_default123", + headers=test_headers, + json={"name": "Modified Name"}, ) - - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - # Try to modify the default assistant - response = client.post( - "/v1/assistants/asst_default123", - headers=test_headers, - json={"name": "Modified Name"}, - ) - assert response.status_code == 403 - assert "cannot be modified" in response.json()["detail"] - finally: - app.dependency_overrides.clear() 
+ assert response.status_code == 403 + assert "cannot be modified" in response.json()["detail"] def test_delete_default_assistant_forbidden( - self, test_headers: dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, ) -> None: """Test that deleting a default assistant returns 403 Forbidden.""" - # Create a default assistant (no workspace_id) - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - default_assistant = Assistant( - id="asst_default456", - object="assistant", - created_at=1234567890, + default_assistant = self._create_test_assistant( + "asst_default456", + workspace_id=None, # No workspace_id = default name="Default Assistant", description="This is a default assistant", - workspace_id=None, # No workspace_id = default ) - (assistants_dir / "asst_default456.json").write_text( - default_assistant.model_dump_json() + self._add_assistant_to_db(default_assistant, test_db_session) + # Try to delete the default assistant + response = test_client.delete( + "/v1/assistants/asst_default456", + headers=test_headers, ) - - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - # Try to delete the default assistant - response = client.delete( - "/v1/assistants/asst_default456", - headers=test_headers, - ) - assert response.status_code == 403 - assert "cannot be deleted" in response.json()["detail"] - finally: - app.dependency_overrides.clear() + assert response.status_code == 403 + assert "cannot be deleted" in response.json()["detail"] def test_list_assistants_includes_default_and_workspace( - self, test_headers: 
dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, ) -> None: """Test that listing assistants includes both default and workspace-scoped ones. """ - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - # Create a default assistant (no workspace_id) - default_assistant = Assistant( - id="asst_default789", - object="assistant", - created_at=1234567890, + default_assistant = self._create_test_assistant( + "asst_default789", + workspace_id=None, # No workspace_id = default name="Default Assistant", description="This is a default assistant", - workspace_id=None, # No workspace_id = default - ) - (assistants_dir / "asst_default789.json").write_text( - default_assistant.model_dump_json() ) + self._add_assistant_to_db(default_assistant, test_db_session) # Create a workspace-scoped assistant - workspace_id = test_headers["askui-workspace"] - workspace_assistant = Assistant( - id="asst_workspace123", - object="assistant", - created_at=1234567890, + workspace_id = UUID(test_headers["askui-workspace"]) # WorkspaceId is UUID4 + workspace_assistant = self._create_test_assistant( + "asst_workspace123", + workspace_id=workspace_id, name="Workspace Assistant", description="This is a workspace assistant", - workspace_id=workspace_id, ) - (assistants_dir / "asst_workspace123.json").write_text( - workspace_assistant.model_dump_json() - ) - - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service + self._add_assistant_to_db(workspace_assistant, test_db_session) - try: - with TestClient(app) as client: - # List assistants - should include both - response = client.get("/v1/assistants", 
headers=test_headers) - assert response.status_code == 200 + # List assistants - should include both + response = test_client.get("/v1/assistants", headers=test_headers) + assert response.status_code == 200 - data = response.json() - assistant_ids = [assistant["id"] for assistant in data["data"]] + data = response.json() + assistant_ids = [assistant["id"] for assistant in data["data"]] - # Should include both default and workspace assistants - assert "asst_default789" in assistant_ids - assert "asst_workspace123" in assistant_ids + # Should include both default and workspace assistants + assert "asst_default789" in assistant_ids + assert "asst_workspace123" in assistant_ids - # Verify workspace_id fields - default_assistant_data = next( - a for a in data["data"] if a["id"] == "asst_default789" - ) - workspace_assistant_data = next( - a for a in data["data"] if a["id"] == "asst_workspace123" - ) + # Verify workspace_id fields + default_assistant_data = next( + a for a in data["data"] if a["id"] == "asst_default789" + ) + workspace_assistant_data = next( + a for a in data["data"] if a["id"] == "asst_workspace123" + ) - assert default_assistant_data["workspace_id"] is None - assert workspace_assistant_data["workspace_id"] == workspace_id - finally: - app.dependency_overrides.clear() + assert default_assistant_data["workspace_id"] is None + assert workspace_assistant_data["workspace_id"] == str(workspace_id) def test_retrieve_default_assistant_success( - self, test_headers: dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, ) -> None: """Test that retrieving a default assistant works.""" - # Create a default assistant (no workspace_id) - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - default_assistant = Assistant( - id="asst_defaultretrieve", - object="assistant", - created_at=1234567890, + 
default_assistant = self._create_test_assistant( + "asst_defaultretrieve", + workspace_id=None, # No workspace_id = default name="Default Assistant", description="This is a default assistant", - workspace_id=None, # No workspace_id = default ) - (assistants_dir / "asst_defaultretrieve.json").write_text( - default_assistant.model_dump_json() + self._add_assistant_to_db(default_assistant, test_db_session) + # Retrieve the default assistant + response = test_client.get( + "/v1/assistants/asst_defaultretrieve", + headers=test_headers, ) + assert response.status_code == 200 - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - # Retrieve the default assistant - response = client.get( - "/v1/assistants/asst_defaultretrieve", - headers=test_headers, - ) - assert response.status_code == 200 - - data = response.json() - assert data["id"] == "asst_defaultretrieve" - assert data["workspace_id"] is None - finally: - app.dependency_overrides.clear() + data = response.json() + assert data["id"] == "asst_defaultretrieve" + assert data["workspace_id"] is None def test_workspace_scoped_assistant_operations_success( - self, test_headers: dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, ) -> None: """Test that workspace-scoped assistants can be modified and deleted.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - workspace_id = test_headers["askui-workspace"] - workspace_assistant = Assistant( - id="asst_workspaceops", - object="assistant", - created_at=1234567890, + workspace_id = UUID(test_headers["askui-workspace"]) # 
WorkspaceId is UUID4 + workspace_assistant = self._create_test_assistant( + "asst_workspaceops", + workspace_id=workspace_id, name="Workspace Assistant", description="This is a workspace assistant", - workspace_id=workspace_id, ) - (assistants_dir / "asst_workspaceops.json").write_text( - workspace_assistant.model_dump_json() + self._add_assistant_to_db(workspace_assistant, test_db_session) + # Modify the workspace assistant + response = test_client.post( + "/v1/assistants/asst_workspaceops", + headers=test_headers, + json={"name": "Modified Workspace Assistant"}, + ) + assert response.status_code == 200 + + data = response.json() + assert data["name"] == "Modified Workspace Assistant" + assert data["workspace_id"] == str(workspace_id) + + # Delete the workspace assistant + response = test_client.delete( + "/v1/assistants/asst_workspaceops", + headers=test_headers, ) + assert response.status_code == 204 - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - - def override_assistant_service() -> AssistantService: - return AssistantService(workspace_path) - - app.dependency_overrides[get_assistant_service] = override_assistant_service - - try: - with TestClient(app) as client: - # Modify the workspace assistant - response = client.post( - "/v1/assistants/asst_workspaceops", - headers=test_headers, - json={"name": "Modified Workspace Assistant"}, - ) - assert response.status_code == 200 - - data = response.json() - assert data["name"] == "Modified Workspace Assistant" - assert data["workspace_id"] == workspace_id - - # Delete the workspace assistant - response = client.delete( - "/v1/assistants/asst_workspaceops", - headers=test_headers, - ) - assert response.status_code == 204 - - # Verify it's deleted - response = client.get( - 
"/v1/assistants/asst_workspaceops", - headers=test_headers, - ) - assert response.status_code == 404 - finally: - app.dependency_overrides.clear() + # Verify it's deleted + response = test_client.get( + "/v1/assistants/asst_workspaceops", + headers=test_headers, + ) + assert response.status_code == 404 diff --git a/tests/integration/chat/api/test_runs.py b/tests/integration/chat/api/test_runs.py index ed06f357..b99515db 100644 --- a/tests/integration/chat/api/test_runs.py +++ b/tests/integration/chat/api/test_runs.py @@ -1,13 +1,19 @@ """Integration tests for the runs API endpoints.""" import tempfile +from datetime import datetime, timezone from pathlib import Path from unittest.mock import Mock +from uuid import UUID from fastapi import status from fastapi.testclient import TestClient +from sqlalchemy.orm import Session +from askui.chat.api.assistants.models import Assistant +from askui.chat.api.assistants.orms import AssistantOrm from askui.chat.api.assistants.service import AssistantService +from askui.chat.api.models import WorkspaceId from askui.chat.api.runs.models import Run from askui.chat.api.runs.service import RunService from askui.chat.api.threads.models import Thread @@ -25,109 +31,104 @@ def create_mock_mcp_client_manager_manager() -> Mock: class TestRunsAPI: """Test suite for the runs API endpoints.""" - def test_list_runs_empty(self, test_headers: dict[str, str]) -> None: - """Test listing runs when no runs exist.""" - # First create a thread - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Test Thread", + def _create_test_assistant( + self, + assistant_id: str, + workspace_id: WorkspaceId | None = None, + name: str = "Test Assistant", + description: str = "A test assistant", + avatar: str | None = None, + created_at: datetime | None = None, 
+ tools: list[str] | None = None, + system: str | None = None, + ) -> Assistant: + """Create a test assistant model.""" + if created_at is None: + created_at = datetime.fromtimestamp(1234567890, tz=timezone.utc) + if tools is None: + tools = [] + return Assistant( + id=assistant_id, + object="assistant", + created_at=created_at, + name=name, + description=description, + avatar=avatar, + workspace_id=workspace_id, + tools=tools, + system=system, ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) - - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - return RunService( - base_dir=workspace_path, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=Mock(), - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - try: - with TestClient(app) as client: - response = client.get( - "/v1/runs?thread=thread_test123", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert data["data"] == [] - assert data["has_more"] is False - finally: - app.dependency_overrides.clear() + def _add_assistant_to_db( + self, assistant: Assistant, test_db_session: Session + ) -> None: + """Add an assistant to the test database.""" + assistant_orm = 
AssistantOrm.from_model(assistant) + test_db_session.add(assistant_orm) + test_db_session.commit() - def test_list_runs_with_runs(self, test_headers: dict[str, str]) -> None: - """Test listing runs when runs exist.""" + def _create_test_workspace(self) -> Path: + """Create a temporary workspace directory for testing.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) threads_dir = workspace_path / "threads" threads_dir.mkdir(parents=True, exist_ok=True) - runs_dir = workspace_path / "runs" / "thread_test123" - runs_dir.mkdir(parents=True, exist_ok=True) + return workspace_path - # Create a mock thread + def _create_test_thread( + self, workspace_path: Path, thread_id: str = "thread_test123" + ) -> None: + """Create a test thread in the workspace.""" + threads_dir = workspace_path / "threads" mock_thread = Thread( - id="thread_test123", + id=thread_id, object="thread", created_at=1234567890, name="Test Thread", ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + (threads_dir / f"{thread_id}.json").write_text(mock_thread.model_dump_json()) + + def _create_test_run( + self, + workspace_path: Path, + thread_id: str = "thread_test123", + run_id: str = "run_test123", + ) -> None: + """Create a test run in the workspace.""" + runs_dir = workspace_path / "runs" / thread_id + runs_dir.mkdir(parents=True, exist_ok=True) - # Create a mock run mock_run = Run( - id="run_test123", + id=run_id, object="thread.run", created_at=1234567890, - thread_id="thread_test123", + thread_id=thread_id, assistant_id="asst_test123", expires_at=1755846718, # 10 minutes later started_at=1234567890, completed_at=1234567900, ) - (runs_dir / "run_test123.json").write_text(mock_run.model_dump_json()) + (runs_dir / f"{run_id}.json").write_text(mock_run.model_dump_json()) + def _setup_runs_dependencies( + self, workspace_path: Path, test_db_session: Session + ) -> None: + """Set up dependency overrides for runs and threads services.""" from 
askui.chat.api.app import app from askui.chat.api.runs.dependencies import get_runs_service from askui.chat.api.threads.dependencies import get_thread_service def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() mock_run_service = Mock() return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: - mock_assistant_service = Mock() + assistant_service = AssistantService(test_db_session) mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() return RunService( base_dir=workspace_path, - assistant_service=mock_assistant_service, + assistant_service=assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), settings=Mock(), @@ -136,157 +137,143 @@ def override_runs_service() -> RunService: app.dependency_overrides[get_thread_service] = override_thread_service app.dependency_overrides[get_runs_service] = override_runs_service - try: - with TestClient(app) as client: - response = client.get( - "/v1/runs?thread=thread_test123", headers=test_headers - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert len(data["data"]) == 1 - assert data["data"][0]["id"] == "run_test123" - assert data["data"][0]["status"] == "completed" - assert data["data"][0]["assistant_id"] == "asst_test123" - finally: - app.dependency_overrides.clear() - - def test_list_runs_with_pagination(self, test_headers: dict[str, str]) -> None: - """Test listing runs with pagination parameters.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - runs_dir = workspace_path / "runs" / "thread_test123" + def _create_multiple_test_runs( + self, workspace_path: Path, thread_id: str = "thread_test123", count: int = 5 + ) -> 
None: + """Create multiple test runs in the workspace.""" + runs_dir = workspace_path / "runs" / thread_id runs_dir.mkdir(parents=True, exist_ok=True) - # Create a mock thread - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Test Thread", - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Create multiple mock runs - for i in range(5): + for i in range(count): mock_run = Run( id=f"run_test{i}", object="thread.run", created_at=1234567890 + i, - thread_id="thread_test123", + thread_id=thread_id, assistant_id=f"asst_test{i}", expires_at=1234567890 + i + 600, # 10 minutes later ) (runs_dir / f"run_test{i}.json").write_text(mock_run.model_dump_json()) + def _cleanup_dependencies(self) -> None: + """Clean up dependency overrides.""" from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService + app.dependency_overrides.clear() - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + def test_list_runs_empty( + self, + test_headers: dict[str, str], + test_client: TestClient, + test_db_session: Session, + ) -> None: + """Test listing runs when no runs exist.""" + workspace_path = self._create_test_workspace() + self._create_test_thread(workspace_path) - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - return RunService( - base_dir=workspace_path, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=Mock(), - ) - - app.dependency_overrides[get_thread_service] = override_thread_service - 
app.dependency_overrides[get_runs_service] = override_runs_service + self._setup_runs_dependencies(workspace_path, test_db_session) try: - with TestClient(app) as client: - response = client.get( - "/v1/runs?thread=thread_test123&limit=3", headers=test_headers - ) + response = test_client.get( + "/v1/runs?thread=thread_test123", headers=test_headers + ) - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert len(data["data"]) == 3 - assert data["has_more"] is True + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert data["data"] == [] + assert data["has_more"] is False finally: - app.dependency_overrides.clear() + self._cleanup_dependencies() - def test_create_run(self, test_headers: dict[str, str]) -> None: - """Test creating a new run.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) + def test_list_runs_with_runs( + self, + test_headers: dict[str, str], + test_client: TestClient, + test_db_session: Session, + ) -> None: + """Test listing runs when runs exist.""" + workspace_path = self._create_test_workspace() + self._create_test_thread(workspace_path) + self._create_test_run(workspace_path) - # Create a mock thread - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Test Thread", - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + self._setup_runs_dependencies(workspace_path, test_db_session) - from askui.chat.api.app import app - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service + try: + response = test_client.get( + "/v1/runs?thread=thread_test123", headers=test_headers + ) - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService + assert 
response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "run_test123" + assert data["data"][0]["status"] == "completed" + assert data["data"][0]["assistant_id"] == "asst_test123" + finally: + self._cleanup_dependencies() - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService( - workspace_path, - mock_message_service, - mock_run_service, - ) + def test_list_runs_with_pagination( + self, + test_headers: dict[str, str], + test_client: TestClient, + test_db_session: Session, + ) -> None: + """Test listing runs with pagination parameters.""" + workspace_path = self._create_test_workspace() + self._create_test_thread(workspace_path) + self._create_multiple_test_runs(workspace_path) - def override_runs_service() -> RunService: - mock_assistant_service = Mock() - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - return RunService( - base_dir=workspace_path, - assistant_service=mock_assistant_service, - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=Mock(), + self._setup_runs_dependencies(workspace_path, test_db_session) + + try: + response = test_client.get( + "/v1/runs?thread=thread_test123&limit=3", headers=test_headers ) - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 3 + assert data["has_more"] is True + finally: + self._cleanup_dependencies() + + def test_create_run( + self, + test_headers: dict[str, str], + test_client: TestClient, + test_db_session: Session, + ) -> None: + """Test creating a new run.""" + workspace_path = self._create_test_workspace() + self._create_test_thread(workspace_path) + self._setup_runs_dependencies(workspace_path, 
test_db_session) + self._add_assistant_to_db( + self._create_test_assistant(assistant_id="asst_test123"), test_db_session + ) try: - with TestClient(app) as client: - run_data = { - "assistant_id": "asst_test123", - "stream": False, - "metadata": {"key": "value", "number": 42}, - } - response = client.post( - "/v1/threads/thread_test123/runs", - json=run_data, - headers=test_headers, - ) + run_data = { + "assistant_id": "asst_test123", + "stream": False, + "metadata": {"key": "value", "number": 42}, + } + response = test_client.post( + "/v1/threads/thread_test123/runs", + json=run_data, + headers=test_headers, + ) - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["assistant_id"] == "asst_test123" - assert data["thread_id"] == "thread_test123" - assert data["object"] == "thread.run" - assert "id" in data - assert "created_at" in data + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["assistant_id"] == "asst_test123" + assert data["thread_id"] == "thread_test123" + assert data["object"] == "thread.run" + assert "id" in data + assert "created_at" in data finally: - app.dependency_overrides.clear() + self._cleanup_dependencies() def test_create_run_minimal(self, test_headers: dict[str, str]) -> None: """Test creating a run with minimal data.""" @@ -907,173 +894,81 @@ def test_cancel_run_not_found( assert response.status_code == status.HTTP_404_NOT_FOUND def test_create_run_with_custom_assistant( - self, test_headers: dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, ) -> None: """Test creating a run with a custom assistant.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - 
mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Test Thread", - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Create a mock custom assistant - from askui.chat.api.assistants.models import Assistant - - mock_assistant = Assistant( - id="asst_custom123", - object="assistant", - created_at=1234567890, + workspace_path = self._create_test_workspace() + self._create_test_thread(workspace_path) + + # Create a custom assistant in the database + workspace_id = UUID(test_headers["askui-workspace"]) + custom_assistant = self._create_test_assistant( + "asst_custom123", + workspace_id=workspace_id, name="Custom Assistant", tools=["tool1", "tool2"], system="You are a custom assistant.", ) - (assistants_dir / "asst_custom123.json").write_text( - mock_assistant.model_dump_json() - ) - - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) - - def override_runs_service() -> RunService: - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - from askui.chat.api.assistants.service import AssistantService - - return RunService( - base_dir=workspace_path, - assistant_service=AssistantService(workspace_path), - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=Mock(), - ) - - def override_assistant_service() -> AssistantService: - from askui.chat.api.assistants.service import AssistantService - - return AssistantService(workspace_path) + self._add_assistant_to_db(custom_assistant, 
test_db_session) - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - app.dependency_overrides[get_assistant_service] = override_assistant_service + self._setup_runs_dependencies(workspace_path, test_db_session) try: - with TestClient(app) as client: - response = client.post( - "/v1/threads/thread_test123/runs", - headers=test_headers, - json={"assistant_id": "asst_custom123"}, - ) + response = test_client.post( + "/v1/threads/thread_test123/runs", + headers=test_headers, + json={"assistant_id": "asst_custom123"}, + ) - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["assistant_id"] == "asst_custom123" - assert data["thread_id"] == "thread_test123" - assert data["status"] == "queued" - assert "id" in data - assert "created_at" in data + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["assistant_id"] == "asst_custom123" + assert data["thread_id"] == "thread_test123" + assert data["status"] == "queued" + assert "id" in data + assert "created_at" in data finally: - app.dependency_overrides.clear() + self._cleanup_dependencies() def test_create_run_with_custom_assistant_empty_tools( - self, test_headers: dict[str, str] + self, + test_headers: dict[str, str], + test_db_session: Session, + test_client: TestClient, ) -> None: """Test creating a run with a custom assistant that has empty tools.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - assistants_dir = workspace_path / "assistants" - assistants_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Test Thread", - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - - # Create a mock 
custom assistant with empty tools - from askui.chat.api.assistants.models import Assistant - - mock_assistant = Assistant( - id="asst_customempty123", - object="assistant", - created_at=1234567890, + workspace_path = self._create_test_workspace() + self._create_test_thread(workspace_path) + + # Create a custom assistant with empty tools in the database + workspace_id = UUID(test_headers["askui-workspace"]) + empty_tools_assistant = self._create_test_assistant( + "asst_customempty123", + workspace_id=workspace_id, name="Empty Tools Assistant", tools=[], system="You are a assistant with no tools.", ) - (assistants_dir / "asst_customempty123.json").write_text( - mock_assistant.model_dump_json() - ) + self._add_assistant_to_db(empty_tools_assistant, test_db_session) - from askui.chat.api.app import app - from askui.chat.api.assistants.dependencies import get_assistant_service - from askui.chat.api.runs.dependencies import get_runs_service - from askui.chat.api.threads.dependencies import get_thread_service - - def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) - - def override_runs_service() -> RunService: - mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - from askui.chat.api.assistants.service import AssistantService - - return RunService( - base_dir=workspace_path, - assistant_service=AssistantService(workspace_path), - mcp_client_manager_manager=mock_mcp_client_manager_manager, - chat_history_manager=Mock(), - settings=Mock(), - ) - - def override_assistant_service() -> AssistantService: - from askui.chat.api.assistants.service import AssistantService - - return AssistantService(workspace_path) - - app.dependency_overrides[get_thread_service] = override_thread_service - app.dependency_overrides[get_runs_service] = override_runs_service - 
app.dependency_overrides[get_assistant_service] = override_assistant_service + self._setup_runs_dependencies(workspace_path, test_db_session) try: - with TestClient(app) as client: - response = client.post( - "/v1/threads/thread_test123/runs", - headers=test_headers, - json={"assistant_id": "asst_customempty123"}, - ) + response = test_client.post( + "/v1/threads/thread_test123/runs", + headers=test_headers, + json={"assistant_id": "asst_customempty123"}, + ) - assert response.status_code == status.HTTP_201_CREATED - data = response.json() - assert data["assistant_id"] == "asst_customempty123" - assert data["thread_id"] == "thread_test123" - assert data["status"] == "queued" - assert "id" in data - assert "created_at" in data + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["assistant_id"] == "asst_customempty123" + assert data["thread_id"] == "thread_test123" + assert data["status"] == "queued" + assert "id" in data + assert "created_at" in data finally: - app.dependency_overrides.clear() + self._cleanup_dependencies() From 8ba28da48480c58a7bc4ca6e3a45babbf6df8e1b Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 14 Oct 2025 16:57:45 +0200 Subject: [PATCH 05/14] chore: fix linting errors --- tests/integration/chat/api/conftest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integration/chat/api/conftest.py b/tests/integration/chat/api/conftest.py index b58adec5..9c845c33 100644 --- a/tests/integration/chat/api/conftest.py +++ b/tests/integration/chat/api/conftest.py @@ -45,7 +45,9 @@ def test_app() -> FastAPI: def test_client( test_app: FastAPI, test_db_session: Session ) -> Generator[TestClient, None, None]: - """Yield a TestClient with common overrides (assistants service uses the test DB).""" + """Yield a TestClient with common overrides + (assistants service uses the test DB). 
+ """ app.dependency_overrides[get_assistant_service] = lambda: AssistantService( test_db_session ) From ed102681a923c4552763d499314fd665b3edb74a Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 14 Oct 2025 17:08:02 +0200 Subject: [PATCH 06/14] docs(migrations): add documentation for migrations --- docs/migrations.md | 202 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 docs/migrations.md diff --git a/docs/migrations.md b/docs/migrations.md new file mode 100644 index 00000000..cb694625 --- /dev/null +++ b/docs/migrations.md @@ -0,0 +1,202 @@ +# Database Migrations + +This document explains how database migrations work in the AskUI Chat system. + +## Overview + +Database migrations are used to manage changes to the database schema and data over time. They ensure that your database structure stays in sync with the application code and handle data transformations when the schema changes. + +## What Are Migrations Used For? + +Migrations in the AskUI Chat system are primarily used for: + +- **Schema Changes**: Creating, modifying, or dropping database tables and columns +- **Data Migrations**: Transforming existing data when the schema changes +- **Persistence Layer Evolution**: Migrating from one persistence format to another (e.g., JSON files to SQLite database) +- **Seed Data**: Populating the database with default data + +### Example Use Cases + +The current migration history shows several real-world examples: + +1. **`4d1e043b4254_create_assistants_table.py`**: Creates the initial `assistants` table with columns for ID, workspace, timestamps, and assistant configuration +2. **`057f82313448_import_json_assistants.py`**: Migrates existing assistant data from JSON files to the new SQLite database +3. **`c35e88ea9595_seed_default_assistants.py`**: Seeds the database with default assistant configurations +4. 
**`37007a499ca7_remove_assistants_dir.py`**: Cleans up the old JSON-based persistence by removing the assistants directory
+
+## Automatic Migrations on Startup
+
+By default, migrations are automatically run when the chat API starts up. This ensures that users are always upgraded to the newest database schema version without manual intervention.
+
+### Configuration
+
+The automatic migration behavior is controlled by the `auto_migrate` setting in the database configuration:
+
+```python
+class DbSettings(BaseModel):
+    auto_migrate: bool = Field(
+        default=True,
+        description="Whether to run migrations automatically on startup",
+    )
+```
+
+### Environment Variable Override
+
+You can disable automatic migrations for debugging purposes using the environment variable:
+
+```bash
+export ASKUI__CHAT_API__DB__AUTO_MIGRATE=false
+```
+
+When disabled, the application will log:
+```
+Automatic migrations are disabled. Skipping migrations...
+```
+
+## Manual Migration Commands
+
+You can run migrations manually using the Alembic command-line interface:
+
+```bash
+# Run all pending migrations
+pdm run alembic upgrade head
+
+# Run migrations to a specific revision
+pdm run alembic upgrade <revision_id>
+
+# Downgrade to a previous revision
+pdm run alembic downgrade <revision_id>
+
+# Show current migration status
+pdm run alembic current
+
+# Show migration history
+pdm run alembic history
+
+# Generate a new migration
+pdm run alembic revision --autogenerate -m "description of changes"
+```
+
+## Migration Structure
+
+### Directory Layout
+
+```
+src/askui/chat/migrations/
+├── alembic.ini             # Alembic configuration
+├── env.py                  # Migration environment setup
+├── runner.py               # Migration runner for programmatic execution
+├── script.py.mako          # Template for new migration files
+├── shared/                 # Shared utilities and models for migrations
+│   ├── assistants/         # Assistant-related migration utilities
+│   ├── models.py           # Shared data models
+│   └── settings.py         # Settings for migrations
+└── versions/               # Individual 
migration files
+    ├── 4d1e043b4254_create_assistants_table.py
+    ├── 057f82313448_import_json_assistants.py
+    ├── c35e88ea9595_seed_default_assistants.py
+    └── 37007a499ca7_remove_assistants_dir.py
+```
+
+### Migration File Structure
+
+Each migration file follows this structure:
+
+```python
+"""<migration_description>
+
+Revision ID: <revision_id>
+Revises: <down_revision_id>
+Create Date: <creation_date>
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "<revision_id>"
+down_revision: Union[str, None] = "<down_revision_id>"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Apply the migration changes."""
+    # Migration logic here
+    pass
+
+
+def downgrade() -> None:
+    """Revert the migration changes."""
+    # Rollback logic here
+    pass
+```
+
+## Migration Execution Flow
+
+1. **Startup Check**: When the chat API starts, it checks the `auto_migrate` setting
+2. **Migration Runner**: If enabled, calls `run_migrations()` from `runner.py`
+3. **Alembic Execution**: Uses Alembic's `upgrade` command to apply all pending migrations
+4. **Database Connection**: Connects to the database using settings from `env.py`
+5. **Schema Application**: Applies each migration in sequence until reaching the "head" revision
+
+## Database Configuration
+
+The migration system uses the same database configuration as the main application:
+
+- **Database URL**: Configured via `ASKUI__CHAT_API__DB__URL` (defaults to SQLite)
+- **Connection**: Uses the same SQLAlchemy engine as the main application
+- **Metadata**: Automatically detects schema changes from SQLAlchemy models
+
+## Best Practices
+
+### Creating New Migrations
+
+1. **Use Autogenerate**: Let Alembic detect schema changes automatically:
+   ```bash
+   pdm run alembic revision --autogenerate -m "add new column to table"
+   ```
+
+2. 
**Review Generated Code**: Always review and test autogenerated migrations before applying
+
+3. **Handle Data Migrations**: For complex data transformations, write custom migration logic
+
+4. **Test Both Directions**: Ensure both `upgrade()` and `downgrade()` functions work correctly
+
+### Migration Safety
+
+1. **Backup First**: Always back up the database before running migrations so that it can be easily restored if something goes wrong
+2. **Test Locally**: Test migrations on a copy of production data
+3. **Rollback Plan**: Have a rollback strategy for critical migrations
+4. **Batch Operations**: For large data migrations, process data in batches to avoid memory issues
+5. **Keep Old Code Around**: Keep old code versions available so that migrations stay independent of the installed AskUI chat version
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Migration Conflicts**: If multiple developers create migrations simultaneously, you may need to resolve conflicts manually
+2. **Data Loss**: Some migrations (like dropping columns) can cause data loss - always review carefully
+3. **Performance**: Large data migrations can be slow - consider running them in the background instead of during startup (keeping compatibility with the old code while they run), or temporarily disabling the affected APIs for that period of time
+
+### Debugging
+
+1. **Check Migration Status**:
+   ```bash
+   pdm run alembic current
+   ```
+
+2. **View Migration History**:
+   ```bash
+   pdm run alembic history --verbose
+   ```
+
+3. 
**Disable Auto-Migration**: Use the environment variable to disable automatic migrations during debugging + +## Related Documentation + +- [Alembic Documentation](https://alembic.sqlalchemy.org/) - Official Alembic migration tool documentation +- [SQLAlchemy Documentation](https://docs.sqlalchemy.org/) - SQLAlchemy ORM and database toolkit +- [Database Models](../src/askui/chat/api/) - Current database schema and models From bfca8a3b7d880158c6b51636068628be57ce4df9 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 21 Oct 2025 15:30:23 +0200 Subject: [PATCH 07/14] feat(chat/mcp_configs): migrate MCP configurations to database --- src/askui/chat/api/app.py | 4 +- src/askui/chat/api/db/orm/types.py | 1 - .../chat/api/mcp_configs/dependencies.py | 7 +- src/askui/chat/api/mcp_configs/models.py | 8 +- src/askui/chat/api/mcp_configs/orms.py | 36 ++++ src/askui/chat/api/mcp_configs/router.py | 8 +- src/askui/chat/api/mcp_configs/service.py | 184 +++++++++--------- .../migrations/shared/assistants/models.py | 10 +- .../migrations/shared/mcp_configs/__init__.py | 0 .../migrations/shared/mcp_configs/models.py | 41 ++++ src/askui/chat/migrations/shared/utils.py | 10 + .../5a1b2c3d4e5f_create_mcp_configs_table.py | 50 +++++ .../6b2c3d4e5f6a_import_json_mcp_configs.py | 109 +++++++++++ .../7c3d4e5f6a7b_remove_mcp_configs_dir.py | 51 +++++ 14 files changed, 406 insertions(+), 113 deletions(-) create mode 100644 src/askui/chat/api/mcp_configs/orms.py create mode 100644 src/askui/chat/migrations/shared/mcp_configs/__init__.py create mode 100644 src/askui/chat/migrations/shared/mcp_configs/models.py create mode 100644 src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py create mode 100644 src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py create mode 100644 src/askui/chat/migrations/versions/7c3d4e5f6a7b_remove_mcp_configs_dir.py diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py index 28f93c4d..d31e9492 
100644 --- a/src/askui/chat/api/app.py +++ b/src/askui/chat/api/app.py @@ -8,6 +8,7 @@ from fastmcp import FastMCP from askui.chat.api.assistants.router import router as assistants_router +from askui.chat.api.db.session import get_session from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings from askui.chat.api.files.router import router as files_router from askui.chat.api.health.router import router as health_router @@ -45,7 +46,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 else: logger.info("Automatic migrations are disabled. Skipping migrations...") logger.info("Seeding default MCP configurations...") - mcp_config_service = get_mcp_config_service(settings=settings) + session = next(get_session()) + mcp_config_service = get_mcp_config_service(session=session, settings=settings) mcp_config_service.seed() yield logger.info("Disconnecting all MCP clients...") diff --git a/src/askui/chat/api/db/orm/types.py b/src/askui/chat/api/db/orm/types.py index 0087eff6..2c8072b2 100644 --- a/src/askui/chat/api/db/orm/types.py +++ b/src/askui/chat/api/db/orm/types.py @@ -30,7 +30,6 @@ def process_result_value(self, value: str | None, dialect: Any) -> str | None: RunId = create_prefixed_id_type("run") FileId = create_prefixed_id_type("file") WorkflowId = create_prefixed_id_type("workflow") -McpConfigId = create_prefixed_id_type("mcp") class UnixDatetime(TypeDecorator[datetime]): diff --git a/src/askui/chat/api/mcp_configs/dependencies.py b/src/askui/chat/api/mcp_configs/dependencies.py index 023b2bcb..fc807081 100644 --- a/src/askui/chat/api/mcp_configs/dependencies.py +++ b/src/askui/chat/api/mcp_configs/dependencies.py @@ -1,13 +1,16 @@ from fastapi import Depends +from askui.chat.api.db.session import SessionDep from askui.chat.api.dependencies import SettingsDep from askui.chat.api.mcp_configs.service import McpConfigService from askui.chat.api.settings import Settings -def get_mcp_config_service(settings: Settings = 
SettingsDep) -> McpConfigService: +def get_mcp_config_service( + session: SessionDep, settings: Settings = SettingsDep +) -> McpConfigService: """Get McpConfigService instance.""" - return McpConfigService(settings.data_dir, settings.mcp_configs) + return McpConfigService(session, settings.mcp_configs) McpConfigServiceDep = Depends(get_mcp_config_service) diff --git a/src/askui/chat/api/mcp_configs/models.py b/src/askui/chat/api/mcp_configs/models.py index 049da895..a98fcfbe 100644 --- a/src/askui/chat/api/mcp_configs/models.py +++ b/src/askui/chat/api/mcp_configs/models.py @@ -29,11 +29,11 @@ class McpConfigBase(BaseModel): mcp_server: McpServer -class McpConfigCreateParams(McpConfigBase): +class McpConfigCreate(McpConfigBase): """Parameters for creating an MCP configuration.""" -class McpConfigModifyParams(BaseModelWithNotGiven): +class McpConfigModify(BaseModelWithNotGiven): """Parameters for modifying an MCP configuration.""" name: str | NotGiven = NOT_GIVEN @@ -49,7 +49,7 @@ class McpConfig(McpConfigBase, WorkspaceResource): @classmethod def create( - cls, workspace_id: WorkspaceId, params: McpConfigCreateParams + cls, workspace_id: WorkspaceId | None, params: McpConfigCreate ) -> "McpConfig": return cls( id=generate_time_ordered_id("mcpcnf"), @@ -58,7 +58,7 @@ def create( **params.model_dump(), ) - def modify(self, params: McpConfigModifyParams) -> "McpConfig": + def modify(self, params: McpConfigModify) -> "McpConfig": return McpConfig.model_validate( { **self.model_dump(), diff --git a/src/askui/chat/api/mcp_configs/orms.py b/src/askui/chat/api/mcp_configs/orms.py new file mode 100644 index 00000000..b2bc65e1 --- /dev/null +++ b/src/askui/chat/api/mcp_configs/orms.py @@ -0,0 +1,36 @@ +"""MCP configuration database model.""" + +from datetime import datetime +from typing import Any +from uuid import UUID + +from sqlalchemy import JSON, String, Uuid +from sqlalchemy.orm import Mapped, mapped_column + +from askui.chat.api.db.orm.base import Base +from 
askui.chat.api.db.orm.types import UnixDatetime, create_prefixed_id_type +from askui.chat.api.mcp_configs.models import McpConfig + +McpConfigId = create_prefixed_id_type("mcpcnf") + + +class McpConfigOrm(Base): + """MCP configuration database model.""" + + __tablename__ = "mcp_configs" + + id: Mapped[str] = mapped_column(McpConfigId, primary_key=True) + workspace_id: Mapped[UUID | None] = mapped_column(Uuid, nullable=True, index=True) + created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) + name: Mapped[str] = mapped_column(String, nullable=False) + mcp_server: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False) + + @classmethod + def from_model(cls, model: McpConfig) -> "McpConfigOrm": + return cls( + **model.model_dump(exclude={"object", "created_at"}), + created_at=model.created_at, + ) + + def to_model(self) -> McpConfig: + return McpConfig.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/mcp_configs/router.py b/src/askui/chat/api/mcp_configs/router.py index 07f3b7ac..41b9e3c5 100644 --- a/src/askui/chat/api/mcp_configs/router.py +++ b/src/askui/chat/api/mcp_configs/router.py @@ -6,8 +6,8 @@ from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep from askui.chat.api.mcp_configs.models import ( McpConfig, - McpConfigCreateParams, - McpConfigModifyParams, + McpConfigCreate, + McpConfigModify, ) from askui.chat.api.mcp_configs.service import McpConfigService from askui.chat.api.models import McpConfigId, WorkspaceId @@ -27,7 +27,7 @@ def list_mcp_configs( @router.post("", status_code=status.HTTP_201_CREATED, response_model_exclude_none=True) def create_mcp_config( - params: McpConfigCreateParams, + params: McpConfigCreate, askui_workspace: Annotated[WorkspaceId, Header()], mcp_config_service: McpConfigService = McpConfigServiceDep, ) -> McpConfig: @@ -50,7 +50,7 @@ def retrieve_mcp_config( @router.post("/{mcp_config_id}", response_model_exclude_none=True) def modify_mcp_config( 
mcp_config_id: McpConfigId, - params: McpConfigModifyParams, + params: McpConfigModify, askui_workspace: Annotated[WorkspaceId, Header()], mcp_config_service: McpConfigService = McpConfigServiceDep, ) -> McpConfig: diff --git a/src/askui/chat/api/mcp_configs/service.py b/src/askui/chat/api/mcp_configs/service.py index 4c376142..72324195 100644 --- a/src/askui/chat/api/mcp_configs/service.py +++ b/src/askui/chat/api/mcp_configs/service.py @@ -1,75 +1,77 @@ -from pathlib import Path - from fastmcp.mcp_config import MCPConfig +from sqlalchemy import or_ +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session +from askui.chat.api.db.queries import list_all from askui.chat.api.mcp_configs.models import ( McpConfig, - McpConfigCreateParams, + McpConfigCreate, McpConfigId, - McpConfigModifyParams, + McpConfigModify, ) +from askui.chat.api.mcp_configs.orms import McpConfigOrm from askui.chat.api.models import WorkspaceId -from askui.chat.api.utils import build_workspace_filter_fn from askui.utils.api_utils import ( LIST_LIMIT_MAX, - ConflictError, ForbiddenError, LimitReachedError, ListQuery, ListResponse, NotFoundError, - list_resources, ) class McpConfigService: - """Service for managing McpConfig resources with filesystem persistence.""" + """Service for managing McpConfig resources with database persistence.""" - def __init__(self, base_dir: Path, seeds: list[McpConfig]) -> None: - self._base_dir = base_dir - self._mcp_configs_dir = base_dir / "mcp_configs" + def __init__(self, session: Session, seeds: list[McpConfig]) -> None: + self._session = session self._seeds = seeds - def _get_mcp_config_path( - self, mcp_config_id: McpConfigId, new: bool = False - ) -> Path: - mcp_config_path = self._mcp_configs_dir / f"{mcp_config_id}.json" - exists = mcp_config_path.exists() - if new and exists: - error_msg = f"MCP configuration {mcp_config_id} already exists" - raise ConflictError(error_msg) - if not new and not exists: - error_msg = f"MCP 
configuration {mcp_config_id} not found" - raise NotFoundError(error_msg) - return mcp_config_path - def list_( self, workspace_id: WorkspaceId | None, query: ListQuery ) -> ListResponse[McpConfig]: - return list_resources( - self._mcp_configs_dir, - query, - McpConfig, - filter_fn=build_workspace_filter_fn(workspace_id, McpConfig), + q = self._session.query(McpConfigOrm).filter( + or_( + McpConfigOrm.workspace_id == workspace_id, + McpConfigOrm.workspace_id.is_(None), + ), + ) + orms: list[McpConfigOrm] + orms, has_more = list_all(q, query, McpConfigOrm.id) + data = [orm.to_model() for orm in orms] + return ListResponse( + data=data, + has_more=has_more, + first_id=data[0].id if data else None, + last_id=data[-1].id if data else None, + ) + + def _find_by_id( + self, workspace_id: WorkspaceId | None, mcp_config_id: McpConfigId + ) -> McpConfigOrm: + mcp_config_orm: McpConfigOrm | None = ( + self._session.query(McpConfigOrm) + .filter( + McpConfigOrm.id == mcp_config_id, + or_( + McpConfigOrm.workspace_id == workspace_id, + McpConfigOrm.workspace_id.is_(None), + ), + ) + .first() ) + if mcp_config_orm is None: + error_msg = f"MCP configuration {mcp_config_id} not found" + raise NotFoundError(error_msg) + return mcp_config_orm def retrieve( self, workspace_id: WorkspaceId | None, mcp_config_id: McpConfigId ) -> McpConfig: - try: - mcp_config_path = self._get_mcp_config_path(mcp_config_id) - mcp_config = McpConfig.model_validate_json(mcp_config_path.read_text()) - if not ( - mcp_config.workspace_id is None - or mcp_config.workspace_id == workspace_id - ): - error_msg = f"MCP configuration {mcp_config_id} not found" - raise NotFoundError(error_msg) - except FileNotFoundError as e: - error_msg = f"MCP configuration {mcp_config_id} not found" - raise NotFoundError(error_msg) from e - else: - return mcp_config + mcp_config_model = self._find_by_id(workspace_id, mcp_config_id) + return mcp_config_model.to_model() def retrieve_fast_mcp_config( self, workspace_id: 
WorkspaceId | None @@ -83,38 +85,36 @@ def retrieve_fast_mcp_config( } return MCPConfig(mcpServers=mcp_servers_dict) if mcp_servers_dict else None - def _check_limit(self, workspace_id: WorkspaceId | None) -> None: - limit = LIST_LIMIT_MAX - list_result = self.list_(workspace_id, ListQuery(limit=limit)) - if len(list_result.data) >= limit: - error_msg = ( - "MCP configuration limit reached. " - f"You may only have {limit} MCP configurations. " - "You can delete some MCP configurations to create new ones. " - ) - raise LimitReachedError(error_msg) - def create( - self, workspace_id: WorkspaceId, params: McpConfigCreateParams + self, workspace_id: WorkspaceId | None, params: McpConfigCreate ) -> McpConfig: - self._check_limit(workspace_id) - mcp_config = McpConfig.create(workspace_id, params) - self._save(mcp_config, new=True) - return mcp_config + try: + mcp_config = McpConfig.create(workspace_id, params) + mcp_config_model = McpConfigOrm.from_model(mcp_config) + self._session.add(mcp_config_model) + self._session.commit() + except IntegrityError as e: + if "MCP configuration limit reached" in str(e): + raise LimitReachedError(str(e)) from e + raise + else: + return mcp_config def modify( self, workspace_id: WorkspaceId | None, mcp_config_id: McpConfigId, - params: McpConfigModifyParams, + params: McpConfigModify, + force: bool = False, ) -> McpConfig: - mcp_config = self.retrieve(workspace_id, mcp_config_id) - if mcp_config.workspace_id is None: + mcp_config_model = self._find_by_id(workspace_id, mcp_config_id) + if mcp_config_model.workspace_id is None and not force: error_msg = f"Default MCP configuration {mcp_config_id} cannot be modified" raise ForbiddenError(error_msg) - modified = mcp_config.modify(params) - self._save(modified) - return modified + mcp_config_model.update(params.model_dump()) + self._session.commit() + self._session.refresh(mcp_config_model) + return mcp_config_model.to_model() def delete( self, @@ -122,37 +122,35 @@ def delete( 
mcp_config_id: McpConfigId, force: bool = False, ) -> None: - try: - mcp_config = self.retrieve(workspace_id, mcp_config_id) - if mcp_config.workspace_id is None and not force: - error_msg = ( - f"Default MCP configuration {mcp_config_id} cannot be deleted" - ) - raise ForbiddenError(error_msg) - self._get_mcp_config_path(mcp_config_id).unlink() - except FileNotFoundError as e: + # Use a single query to find and delete atomically + mcp_config_model = ( + self._session.query(McpConfigOrm) + .filter( + McpConfigOrm.id == mcp_config_id, + or_( + McpConfigOrm.workspace_id == workspace_id, + McpConfigOrm.workspace_id.is_(None), + ), + ) + .first() + ) + + if mcp_config_model is None: error_msg = f"MCP configuration {mcp_config_id} not found" - if not force: - raise NotFoundError(error_msg) from e - except NotFoundError: - if not force: - raise + raise NotFoundError(error_msg) - def _save(self, mcp_config: McpConfig, new: bool = False) -> None: - self._mcp_configs_dir.mkdir(parents=True, exist_ok=True) - mcp_config_file = self._get_mcp_config_path(mcp_config.id, new=new) - mcp_config_file.write_text( - mcp_config.model_dump_json( - exclude_unset=True, exclude_none=True, exclude_defaults=True - ), - encoding="utf-8", - ) + if mcp_config_model.workspace_id is None and not force: + error_msg = f"Default MCP configuration {mcp_config_id} cannot be deleted" + raise ForbiddenError(error_msg) + + self._session.delete(mcp_config_model) + self._session.commit() def seed(self) -> None: """Seed the MCP configuration service with default MCP configurations.""" for seed in self._seeds: - try: - self.delete(None, seed.id, force=True) - self._save(seed, new=True) - except ConflictError: # noqa: PERF203 - self._save(seed) + with self._session.begin(): + self._session.query(McpConfigOrm).filter( + McpConfigOrm.id == seed.id + ).delete() + self._session.add(McpConfigOrm.from_model(seed)) diff --git a/src/askui/chat/migrations/shared/assistants/models.py 
b/src/askui/chat/migrations/shared/assistants/models.py index 99d9ddf8..c02d896f 100644 --- a/src/askui/chat/migrations/shared/assistants/models.py +++ b/src/askui/chat/migrations/shared/assistants/models.py @@ -3,16 +3,10 @@ from pydantic import BaseModel, BeforeValidator, Field from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 - - -def add_prefix(id_: str) -> str: - if id_.startswith("asst_"): - return id_ - return f"asst_{id_}" - +from askui.chat.migrations.shared.utils import build_prefixer AssistantIdV1 = Annotated[ - str, Field(pattern=r"^asst_[a-z0-9]+$"), BeforeValidator(add_prefix) + str, Field(pattern=r"^asst_[a-z0-9]+$"), BeforeValidator(build_prefixer("asst")) ] diff --git a/src/askui/chat/migrations/shared/mcp_configs/__init__.py b/src/askui/chat/migrations/shared/mcp_configs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/migrations/shared/mcp_configs/models.py b/src/askui/chat/migrations/shared/mcp_configs/models.py new file mode 100644 index 00000000..4c960efb --- /dev/null +++ b/src/askui/chat/migrations/shared/mcp_configs/models.py @@ -0,0 +1,41 @@ +from typing import Annotated, Any, Literal + +from fastmcp.mcp_config import RemoteMCPServer as _RemoteMCPServer +from fastmcp.mcp_config import StdioMCPServer +from httpx import Auth +from pydantic import BaseModel, BeforeValidator, Field + +from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 +from askui.chat.migrations.shared.utils import build_prefixer + +McpConfigIdV1 = Annotated[ + str, Field(pattern=r"^mcpcnf_[a-z0-9]+$"), BeforeValidator(build_prefixer("mcpcnf")) +] + + +class RemoteMCPServerV1(_RemoteMCPServer): + auth: Annotated[ + str | Literal["oauth"] | Auth | None, # noqa: PYI051 + Field( + description='Either a string representing a Bearer token or the literal "oauth" to use OAuth authentication.', + ), + ] = None + + +McpServerV1 = StdioMCPServer | RemoteMCPServerV1 + + +class 
McpConfigV1(BaseModel): + id: McpConfigIdV1 + object: Literal["mcp_config"] = "mcp_config" + created_at: UnixDatetimeV1 + workspace_id: WorkspaceIdV1 | None = None + name: str + mcp_server: McpServerV1 + + def to_db_dict(self) -> dict[str, Any]: + return { + **self.model_dump(exclude={"id", "object", "workspace_id"}), + "id": self.id.removeprefix("mcpcnf_"), + "workspace_id": str(self.workspace_id) if self.workspace_id else None, + } diff --git a/src/askui/chat/migrations/shared/utils.py b/src/askui/chat/migrations/shared/utils.py index dc4c8c9e..5345ed1c 100644 --- a/src/askui/chat/migrations/shared/utils.py +++ b/src/askui/chat/migrations/shared/utils.py @@ -1,7 +1,17 @@ from datetime import datetime, timezone +from typing import Callable from pydantic import AwareDatetime def now_v1() -> AwareDatetime: return datetime.now(tz=timezone.utc) + + +def build_prefixer(prefix: str) -> Callable[[str], str]: + def prefixer(id_: str) -> str: + if id_.startswith(prefix): + return id_ + return f"{prefix}_{id_}" + + return prefixer diff --git a/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py b/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py new file mode 100644 index 00000000..c2a2e6ba --- /dev/null +++ b/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py @@ -0,0 +1,50 @@ +"""create_mcp_configs_table + +Revision ID: 5a1b2c3d4e5f +Revises: 37007a499ca7 +Create Date: 2025-01-27 10:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "5a1b2c3d4e5f" +down_revision: Union[str, None] = "37007a499ca7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "mcp_configs", + sa.Column("id", sa.String(24), nullable=False, primary_key=True), + sa.Column("workspace_id", sa.Uuid(), nullable=True, index=True), + sa.Column("created_at", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("mcp_server", sa.JSON(), nullable=False), + ) + + # Add constraint to enforce MCP configuration limit + op.execute(""" + CREATE TRIGGER check_mcp_config_limit + BEFORE INSERT ON mcp_configs + WHEN ( + SELECT COUNT(*) FROM mcp_configs + WHERE workspace_id = NEW.workspace_id OR workspace_id IS NULL + ) >= 100 + BEGIN + SELECT RAISE(ABORT, 'MCP configuration limit reached. You may only have 100 MCP configurations. You can delete some MCP configurations to create new ones.'); + END; + """) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("mcp_configs") + # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py b/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py new file mode 100644 index 00000000..a2f80526 --- /dev/null +++ b/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py @@ -0,0 +1,109 @@ +"""import_json_mcp_configs + +Revision ID: 6b2c3d4e5f6a +Revises: 5a1b2c3d4e5f +Create Date: 2025-01-27 10:01:00.000000 + +""" + +import json +import logging +from typing import Sequence, Union + +from alembic import op +from sqlalchemy import MetaData, Table + +from askui.chat.migrations.shared.mcp_configs.models import McpConfigV1 +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "6b2c3d4e5f6a" +down_revision: Union[str, None] = "5a1b2c3d4e5f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + + +BATCH_SIZE = 100 + + +def _insert_mcp_configs_batch( + mcp_configs_table: Table, mcp_configs_batch: list[McpConfigV1] +) -> None: + """Insert a batch of MCP configs into the database.""" + op.bulk_insert( + mcp_configs_table, + [mcp_config.to_db_dict() for mcp_config in mcp_configs_batch], + ) + + +settings = SettingsV1() +mcp_configs_dir = settings.data_dir / "mcp_configs" + + +def upgrade() -> None: + """Import existing MCP configs from JSON files.""" + + # Skip if directory doesn't exist (e.g., first-time setup) + if not mcp_configs_dir.exists(): + return + + # Get the table from the current database schema + connection = op.get_bind() + mcp_configs_table = Table("mcp_configs", MetaData(), autoload_with=connection) + + # Get all JSON files in the mcp_configs directory + json_files = list(mcp_configs_dir.glob("*.json")) + + # Process MCP configs in batches + mcp_configs_batch: list[McpConfigV1] = [] + + for json_file in json_files: + try: + content = json_file.read_text(encoding="utf-8").strip() + data = json.loads(content) + mcp_config = McpConfigV1.model_validate(data) + mcp_configs_batch.append(mcp_config) + + if len(mcp_configs_batch) >= BATCH_SIZE: + _insert_mcp_configs_batch(mcp_configs_table, mcp_configs_batch) + mcp_configs_batch.clear() + except Exception: # noqa: PERF203 + error_msg = "Failed to import" + logger.exception(error_msg, extra={"json_file": str(json_file)}) + continue + + # Insert remaining MCP configs in the final batch + if mcp_configs_batch: + _insert_mcp_configs_batch(mcp_configs_table, mcp_configs_batch) + + +def downgrade() -> None: + """Recreate JSON files for MCP configs during downgrade.""" + + mcp_configs_dir.mkdir(parents=True, exist_ok=True) + + connection = op.get_bind() + mcp_configs_table = 
Table("mcp_configs", MetaData(), autoload_with=connection) + + # Fetch all MCP configs from the database + result = connection.execute(mcp_configs_table.select()) + rows = result.fetchall() + if not rows: + return + + for row in rows: + try: + mcp_config: McpConfigV1 = McpConfigV1.model_validate( + row, from_attributes=True + ) + json_path = mcp_configs_dir / f"{mcp_config.id}.json" + if json_path.exists(): + continue + with json_path.open("w", encoding="utf-8") as f: + f.write(json.dumps(mcp_config.model_dump())) + except Exception as e: # noqa: PERF203 + error_msg = f"Failed to export row to json: {e}" + logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) + continue diff --git a/src/askui/chat/migrations/versions/7c3d4e5f6a7b_remove_mcp_configs_dir.py b/src/askui/chat/migrations/versions/7c3d4e5f6a7b_remove_mcp_configs_dir.py new file mode 100644 index 00000000..17419cee --- /dev/null +++ b/src/askui/chat/migrations/versions/7c3d4e5f6a7b_remove_mcp_configs_dir.py @@ -0,0 +1,51 @@ +"""remove_mcp_configs_dir + +Revision ID: 7c3d4e5f6a7b +Revises: 6b2c3d4e5f6a +Create Date: 2025-01-27 10:02:00.000000 + +""" + +import logging +import shutil +from typing import Sequence, Union + +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "7c3d4e5f6a7b" +down_revision: Union[str, None] = "6b2c3d4e5f6a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + +settings = SettingsV1() +mcp_configs_dir = settings.data_dir / "mcp_configs" + + +def upgrade() -> None: + """Remove the mcp_configs directory and all its contents.""" + + # Skip if directory doesn't exist + if not mcp_configs_dir.exists(): + logger.info("MCP configs directory does not exist, skipping removal") + return + + try: + shutil.rmtree(mcp_configs_dir) + logger.info( + "Successfully removed mcp_configs directory", + extra={"mcp_configs_dir": str(mcp_configs_dir)}, + ) + except Exception as e: + error_msg = "Failed to remove mcp_configs directory" + logger.exception( + error_msg, + extra={"mcp_configs_dir": str(mcp_configs_dir)}, + ) + raise RuntimeError(error_msg) from e + + +def downgrade() -> None: + mcp_configs_dir.mkdir(parents=True, exist_ok=True) From a5e1190a56c282765ec748ccde66ee63c2dcc440 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 21 Oct 2025 16:20:33 +0200 Subject: [PATCH 08/14] feat(chat/files): migrate files to database --- src/askui/chat/api/db/orm/types.py | 1 - src/askui/chat/api/files/dependencies.py | 13 +- src/askui/chat/api/files/models.py | 10 +- src/askui/chat/api/files/orms.py | 36 +++++ src/askui/chat/api/files/router.py | 24 ++- src/askui/chat/api/files/service.py | 153 +++++++++++------- .../chat/migrations/shared/files/__init__.py | 0 .../chat/migrations/shared/files/models.py | 27 ++++ .../6b2c3d4e5f6a_import_json_mcp_configs.py | 1 - .../8d9e0f1a2b3c_create_files_table.py | 38 +++++ .../9e0f1a2b3c4d_import_json_files.py | 119 ++++++++++++++ .../a0f1a2b3c4d5_remove_files_dirs.py | 68 ++++++++ 12 files changed, 412 insertions(+), 78 deletions(-) create mode 100644 src/askui/chat/api/files/orms.py create mode 100644 src/askui/chat/migrations/shared/files/__init__.py create mode 100644 
src/askui/chat/migrations/shared/files/models.py create mode 100644 src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py create mode 100644 src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py create mode 100644 src/askui/chat/migrations/versions/a0f1a2b3c4d5_remove_files_dirs.py diff --git a/src/askui/chat/api/db/orm/types.py b/src/askui/chat/api/db/orm/types.py index 2c8072b2..aab81624 100644 --- a/src/askui/chat/api/db/orm/types.py +++ b/src/askui/chat/api/db/orm/types.py @@ -28,7 +28,6 @@ def process_result_value(self, value: str | None, dialect: Any) -> str | None: ThreadId = create_prefixed_id_type("thread") MessageId = create_prefixed_id_type("msg") RunId = create_prefixed_id_type("run") -FileId = create_prefixed_id_type("file") WorkflowId = create_prefixed_id_type("workflow") diff --git a/src/askui/chat/api/files/dependencies.py b/src/askui/chat/api/files/dependencies.py index 75f2f39c..babeb7e2 100644 --- a/src/askui/chat/api/files/dependencies.py +++ b/src/askui/chat/api/files/dependencies.py @@ -1,14 +1,17 @@ -from pathlib import Path - from fastapi import Depends -from askui.chat.api.dependencies import WorkspaceDirDep +from askui.chat.api.db.session import SessionDep +from askui.chat.api.dependencies import SettingsDep from askui.chat.api.files.service import FileService +from askui.chat.api.settings import Settings -def get_file_service(workspace_dir: Path = WorkspaceDirDep) -> FileService: +def get_file_service( + session: SessionDep, + settings: Settings = SettingsDep, +) -> FileService: """Get FileService instance.""" - return FileService(workspace_dir) + return FileService(session, settings.data_dir) FileServiceDep = Depends(get_file_service) diff --git a/src/askui/chat/api/files/models.py b/src/askui/chat/api/files/models.py index cf55c127..f542f6ae 100644 --- a/src/askui/chat/api/files/models.py +++ b/src/askui/chat/api/files/models.py @@ -3,8 +3,7 @@ from pydantic import BaseModel, Field -from 
askui.chat.api.models import FileId -from askui.utils.api_utils import Resource +from askui.chat.api.models import FileId, WorkspaceId, WorkspaceResource from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import generate_time_ordered_id @@ -16,11 +15,11 @@ class FileBase(BaseModel): media_type: str -class FileCreateParams(FileBase): +class FileCreate(FileBase): filename: str | None = None -class File(FileBase, Resource): +class File(FileBase, WorkspaceResource): """A file that can be stored and managed.""" id: FileId @@ -29,7 +28,7 @@ class File(FileBase, Resource): filename: str = Field(min_length=1) @classmethod - def create(cls, params: FileCreateParams) -> "File": + def create(cls, workspace_id: WorkspaceId | None, params: FileCreate) -> "File": id_ = generate_time_ordered_id("file") filename = ( params.filename or f"{id_}{mimetypes.guess_extension(params.media_type)}" @@ -37,6 +36,7 @@ def create(cls, params: FileCreateParams) -> "File": return cls( id=id_, created_at=now(), + workspace_id=workspace_id, filename=filename, **params.model_dump(exclude={"filename"}), ) diff --git a/src/askui/chat/api/files/orms.py b/src/askui/chat/api/files/orms.py new file mode 100644 index 00000000..9a9ea0ad --- /dev/null +++ b/src/askui/chat/api/files/orms.py @@ -0,0 +1,36 @@ +"""File database model.""" + +from datetime import datetime +from uuid import UUID + +from sqlalchemy import Integer, String, Uuid +from sqlalchemy.orm import Mapped, mapped_column + +from askui.chat.api.db.orm.base import Base +from askui.chat.api.db.orm.types import UnixDatetime, create_prefixed_id_type +from askui.chat.api.files.models import File + +FileId = create_prefixed_id_type("file") + + +class FileOrm(Base): + """File database model.""" + + __tablename__ = "files" + + id: Mapped[str] = mapped_column(FileId, primary_key=True) + workspace_id: Mapped[UUID | None] = mapped_column(Uuid, nullable=True, index=True) + created_at: Mapped[datetime] = 
mapped_column(UnixDatetime, nullable=False) + filename: Mapped[str] = mapped_column(String, nullable=False) + size: Mapped[int] = mapped_column(Integer, nullable=False) + media_type: Mapped[str] = mapped_column(String, nullable=False) + + @classmethod + def from_model(cls, model: File) -> "FileOrm": + return cls( + **model.model_dump(exclude={"object", "created_at"}), + created_at=model.created_at, + ) + + def to_model(self) -> File: + return File.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/files/router.py b/src/askui/chat/api/files/router.py index 3ebcf8cd..7356c85a 100644 --- a/src/askui/chat/api/files/router.py +++ b/src/askui/chat/api/files/router.py @@ -1,11 +1,14 @@ -from fastapi import APIRouter, UploadFile, status +from typing import Annotated + +from fastapi import APIRouter, Header, UploadFile, status from fastapi.responses import FileResponse from askui.chat.api.dependencies import ListQueryDep from askui.chat.api.files.dependencies import FileServiceDep from askui.chat.api.files.models import File as FileModel +from askui.chat.api.files.models import FileId from askui.chat.api.files.service import FileService -from askui.chat.api.models import FileId +from askui.chat.api.models import WorkspaceId from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/files", tags=["files"]) @@ -13,45 +16,52 @@ @router.get("") def list_files( + askui_workspace: Annotated[WorkspaceId | None, Header()] = None, query: ListQuery = ListQueryDep, file_service: FileService = FileServiceDep, ) -> ListResponse[FileModel]: """List all files.""" - return file_service.list_(query=query) + return file_service.list_(workspace_id=askui_workspace, query=query) @router.post("", status_code=status.HTTP_201_CREATED) async def upload_file( file: UploadFile, + askui_workspace: Annotated[WorkspaceId, Header()], file_service: FileService = FileServiceDep, ) -> FileModel: """Upload a new file.""" - return await 
file_service.upload_file(file) + return await file_service.upload_file(workspace_id=askui_workspace, file=file) @router.get("/{file_id}") def retrieve_file( file_id: FileId, + askui_workspace: Annotated[WorkspaceId | None, Header()] = None, file_service: FileService = FileServiceDep, ) -> FileModel: """Get file metadata by ID.""" - return file_service.retrieve(file_id) + return file_service.retrieve(workspace_id=askui_workspace, file_id=file_id) @router.get("/{file_id}/content") def download_file( file_id: FileId, + askui_workspace: Annotated[WorkspaceId | None, Header()] = None, file_service: FileService = FileServiceDep, ) -> FileResponse: """Retrieve a file by ID.""" - file, file_path = file_service.retrieve_file_content(file_id) + file, file_path = file_service.retrieve_file_content( + workspace_id=askui_workspace, file_id=file_id + ) return FileResponse(file_path, media_type=file.media_type, filename=file.filename) @router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_file( file_id: FileId, + askui_workspace: Annotated[WorkspaceId, Header()], file_service: FileService = FileServiceDep, ) -> None: """Delete a file by ID.""" - file_service.delete(file_id) + file_service.delete(workspace_id=askui_workspace, file_id=file_id) diff --git a/src/askui/chat/api/files/service.py b/src/askui/chat/api/files/service.py index ee5aed9a..e23a7514 100644 --- a/src/askui/chat/api/files/service.py +++ b/src/askui/chat/api/files/service.py @@ -5,16 +5,19 @@ from pathlib import Path from fastapi import UploadFile +from sqlalchemy import or_ +from sqlalchemy.orm import Session -from askui.chat.api.files.models import File, FileCreateParams -from askui.chat.api.models import FileId +from askui.chat.api.db.queries import list_all +from askui.chat.api.files.models import File, FileCreate +from askui.chat.api.files.orms import FileOrm +from askui.chat.api.models import FileId, WorkspaceId from askui.utils.api_utils import ( - ConflictError, FileTooLargeError, 
+ ForbiddenError, ListQuery, ListResponse, NotFoundError, - list_resources, ) logger = logging.getLogger(__name__) @@ -25,24 +28,29 @@ class FileService: - """Service for managing File resources with filesystem persistence.""" - - def __init__(self, base_dir: Path) -> None: - self._base_dir = base_dir - self._files_dir = base_dir / "files" - self._static_dir = base_dir / "static" - - def _get_file_path(self, file_id: FileId, new: bool = False) -> Path: - """Get the path for file metadata.""" - file_path = self._files_dir / f"{file_id}.json" - exists = file_path.exists() - if new and exists: - error_msg = f"File {file_id} already exists" - raise ConflictError(error_msg) - if not new and not exists: + """Service for managing File resources with database persistence.""" + + def __init__(self, session: Session, data_dir: Path) -> None: + self._session = session + self._data_dir = data_dir + + def _find_by_id(self, workspace_id: WorkspaceId | None, file_id: FileId) -> FileOrm: + """Find file by ID.""" + file_orm: FileOrm | None = ( + self._session.query(FileOrm) + .filter( + FileOrm.id == file_id, + or_( + FileOrm.workspace_id == workspace_id, + FileOrm.workspace_id.is_(None), + ), + ) + .first() + ) + if file_orm is None: error_msg = f"File {file_id} not found" raise NotFoundError(error_msg) - return file_path + return file_orm def _get_static_file_path(self, file: File) -> Path: """Get the path for the static file based on extension.""" @@ -50,52 +58,77 @@ def _get_static_file_path(self, file: File) -> Path: extension = "" if file.media_type != "application/octet-stream": extension = mimetypes.guess_extension(file.media_type) or "" - return self._static_dir / f"{file.id}{extension}" - - def list_(self, query: ListQuery) -> ListResponse[File]: + base_name = f"{file.id}{extension}" + path = self._data_dir / "static" / base_name + if file.workspace_id is not None: + path = ( + self._data_dir + / "workspaces" + / str(file.workspace_id) + / "static" + / base_name + ) + 
path.parent.mkdir(parents=True, exist_ok=True) + return path + + def list_( + self, workspace_id: WorkspaceId | None, query: ListQuery + ) -> ListResponse[File]: """List files with pagination and filtering.""" - return list_resources(self._files_dir, query, File) + q = self._session.query(FileOrm).filter( + or_( + FileOrm.workspace_id == workspace_id, + FileOrm.workspace_id.is_(None), + ), + ) + orms: list[FileOrm] + orms, has_more = list_all(q, query, FileOrm.id) + data = [orm.to_model() for orm in orms] + return ListResponse( + data=data, + has_more=has_more, + first_id=data[0].id if data else None, + last_id=data[-1].id if data else None, + ) - def retrieve(self, file_id: FileId) -> File: + def retrieve(self, workspace_id: WorkspaceId | None, file_id: FileId) -> File: """Retrieve file metadata by ID.""" - try: - file_path = self._get_file_path(file_id) - return File.model_validate_json(file_path.read_text()) - except FileNotFoundError as e: - error_msg = f"File {file_id} not found" - raise NotFoundError(error_msg) from e + file_orm = self._find_by_id(workspace_id, file_id) + return file_orm.to_model() - def delete(self, file_id: FileId) -> None: + def delete( + self, workspace_id: WorkspaceId | None, file_id: FileId, force: bool = False + ) -> None: """Delete a file and its content. *Important*: We may be left with a static file that is not associated with any file metadata if this fails. 
""" - try: - file = self.retrieve(file_id) - static_path = self._get_static_file_path(file) - self._get_file_path(file_id).unlink() - if static_path.exists(): - static_path.unlink() - except FileNotFoundError as e: - error_msg = f"File {file_id} not found" - raise NotFoundError(error_msg) from e + file_orm = self._find_by_id(workspace_id, file_id) + file = file_orm.to_model() + if file.workspace_id is None and not force: + error_msg = f"Default file {file_id} cannot be deleted" + raise ForbiddenError(error_msg) + self._session.delete(file_orm) + self._session.commit() + static_path = self._get_static_file_path(file) + static_path.unlink() - def retrieve_file_content(self, file_id: FileId) -> tuple[File, Path]: + def retrieve_file_content( + self, workspace_id: WorkspaceId | None, file_id: FileId + ) -> tuple[File, Path]: """Get file metadata and path for downloading.""" - file = self.retrieve(file_id) + file = self.retrieve(workspace_id, file_id) static_path = self._get_static_file_path(file) return file, static_path async def _write_to_temp_file( self, file: UploadFile, - ) -> tuple[FileCreateParams, Path]: + ) -> tuple[FileCreate, Path]: size = 0 - self._static_dir.mkdir(parents=True, exist_ok=True) temp_file = tempfile.NamedTemporaryFile( delete=False, - dir=self._static_dir, suffix=".temp", ) temp_path = Path(temp_file.name) @@ -106,24 +139,32 @@ async def _write_to_temp_file( if size > MAX_FILE_SIZE: raise FileTooLargeError(MAX_FILE_SIZE) mime_type = file.content_type or "application/octet-stream" - params = FileCreateParams( + params = FileCreate( filename=file.filename, size=size, media_type=mime_type, ) return params, temp_path - def create(self, params: FileCreateParams, path: Path) -> File: - file_model = File.create(params) - self._static_dir.mkdir(parents=True, exist_ok=True) + def create( + self, workspace_id: WorkspaceId | None, params: FileCreate, path: Path + ) -> File: + """Create a file and its content. 
+ + *Important*: We may be left with a static file that is not associated with any + file metadata if this fails. + """ + file_model = File.create(workspace_id, params) static_path = self._get_static_file_path(file_model) shutil.move(path, static_path) - self._save(file_model, new=True) - + file_orm = FileOrm.from_model(file_model) + self._session.add(file_orm) + self._session.commit() return file_model async def upload_file( self, + workspace_id: WorkspaceId | None, file: UploadFile, ) -> File: """Upload a file. @@ -134,7 +175,7 @@ async def upload_file( temp_path: Path | None = None try: params, temp_path = await self._write_to_temp_file(file) - file_model = self.create(params, temp_path) + file_model = self.create(workspace_id, params, temp_path) except Exception: logger.exception("Failed to upload file") raise @@ -143,9 +184,3 @@ async def upload_file( finally: if temp_path: temp_path.unlink(missing_ok=True) - - def _save(self, file: File, new: bool = False) -> None: - self._files_dir.mkdir(parents=True, exist_ok=True) - file_path = self._get_file_path(file.id, new=new) - content = file.model_dump_json() - file_path.write_text(content, encoding="utf-8") diff --git a/src/askui/chat/migrations/shared/files/__init__.py b/src/askui/chat/migrations/shared/files/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/migrations/shared/files/models.py b/src/askui/chat/migrations/shared/files/models.py new file mode 100644 index 00000000..2c1d9026 --- /dev/null +++ b/src/askui/chat/migrations/shared/files/models.py @@ -0,0 +1,27 @@ +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, BeforeValidator, Field + +from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 +from askui.chat.migrations.shared.utils import build_prefixer + +FileIdV1 = Annotated[ + str, Field(pattern=r"^file_[a-z0-9]+$"), BeforeValidator(build_prefixer("file")) +] + + +class FileV1(BaseModel): + id: FileIdV1 + object: 
Literal["file"] = "file" + created_at: UnixDatetimeV1 + filename: str = Field(min_length=1) + size: int = Field(ge=0) + media_type: str + workspace_id: WorkspaceIdV1 | None = Field(default=None, exclude=True) + + def to_db_dict(self) -> dict[str, Any]: + return { + **self.model_dump(exclude={"id", "object", "workspace_id"}), + "id": self.id.removeprefix("file_"), + "workspace_id": str(self.workspace_id) if self.workspace_id else None, + } diff --git a/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py b/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py index a2f80526..6db8f789 100644 --- a/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py +++ b/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py @@ -65,7 +65,6 @@ def upgrade() -> None: data = json.loads(content) mcp_config = McpConfigV1.model_validate(data) mcp_configs_batch.append(mcp_config) - if len(mcp_configs_batch) >= BATCH_SIZE: _insert_mcp_configs_batch(mcp_configs_table, mcp_configs_batch) mcp_configs_batch.clear() diff --git a/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py b/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py new file mode 100644 index 00000000..932e1a6e --- /dev/null +++ b/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py @@ -0,0 +1,38 @@ +"""create_files_table + +Revision ID: 8d9e0f1a2b3c +Revises: 7c3d4e5f6a7b +Create Date: 2025-01-27 11:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "8d9e0f1a2b3c" +down_revision: Union[str, None] = "7c3d4e5f6a7b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "files", + sa.Column("id", sa.String(24), nullable=False, primary_key=True), + sa.Column("workspace_id", sa.Uuid(), nullable=True, index=True), + sa.Column("created_at", sa.Integer(), nullable=False), + sa.Column("filename", sa.String(), nullable=False), + sa.Column("size", sa.Integer(), nullable=False), + sa.Column("media_type", sa.String(), nullable=False), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("files") + # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py b/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py new file mode 100644 index 00000000..8a2efcc1 --- /dev/null +++ b/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py @@ -0,0 +1,119 @@ +"""import_json_files + +Revision ID: 9e0f1a2b3c4d +Revises: 8d9e0f1a2b3c +Create Date: 2025-01-27 11:01:00.000000 + +""" + +import json +import logging +import mimetypes +from typing import Sequence, Union + +from alembic import op +from sqlalchemy import MetaData, Table + +from askui.chat.migrations.shared.files.models import FileV1 +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "9e0f1a2b3c4d" +down_revision: Union[str, None] = "8d9e0f1a2b3c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + + +BATCH_SIZE = 100 + + +def _insert_files_batch(files_table: Table, files_batch: list[FileV1]) -> None: + """Insert a batch of files into the database.""" + op.bulk_insert( + files_table, + [file.to_db_dict() for file in files_batch], + ) + + +settings = SettingsV1() +workspaces_dir = settings.data_dir / "workspaces" + + +def upgrade() -> None: # noqa: C901 + """Import existing files from JSON files in workspace static directories.""" + + # Skip if workspaces directory doesn't exist (e.g., first-time setup) + if not workspaces_dir.exists(): + return + + # Get the table from the current database schema + connection = op.get_bind() + files_table = Table("files", MetaData(), autoload_with=connection) + + # Process files in batches + files_batch: list[FileV1] = [] + + # Iterate through all workspace directories + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + workspace_id = workspace_dir.name + files_dir = workspace_dir / "files" + + if not files_dir.exists(): + continue + + # Get all JSON files in the static directory + json_files = list(files_dir.glob("*.json")) + + for json_file in json_files: + try: + content = json_file.read_text(encoding="utf-8").strip() + data = json.loads(content) + file = FileV1.model_validate({**data, "workspace_id": workspace_id}) + files_batch.append(file) + if len(files_batch) >= BATCH_SIZE: + _insert_files_batch(files_table, files_batch) + files_batch.clear() + except Exception: # noqa: PERF203 + error_msg = "Failed to import file" + logger.exception(error_msg, extra={"json_file": str(json_file)}) + continue + + # Insert remaining files in the final batch + if files_batch: + _insert_files_batch(files_table, files_batch) + + +def downgrade() -> None: + """Recreate 
JSON files for files during downgrade.""" + + connection = op.get_bind() + files_table = Table("files", MetaData(), autoload_with=connection) + + # Fetch all files from the database + result = connection.execute(files_table.select()) + rows = result.fetchall() + if not rows: + return + + for row in rows: + try: + file_model: FileV1 = FileV1.model_validate(row, from_attributes=True) + if file_model.workspace_id: + files_dir = workspaces_dir / str(file_model.workspace_id) / "files" + else: + files_dir = settings.data_dir / "files" + files_dir.mkdir(parents=True, exist_ok=True) + json_path = files_dir / f"{file_model.id}.json" + if json_path.exists(): + continue + with json_path.open("w", encoding="utf-8") as f: + f.write(json.dumps(file_model.model_dump())) + except Exception as e: # noqa: PERF203 + error_msg = f"Failed to export row to json: {e}" + logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) + continue diff --git a/src/askui/chat/migrations/versions/a0f1a2b3c4d5_remove_files_dirs.py b/src/askui/chat/migrations/versions/a0f1a2b3c4d5_remove_files_dirs.py new file mode 100644 index 00000000..7ac2f4f4 --- /dev/null +++ b/src/askui/chat/migrations/versions/a0f1a2b3c4d5_remove_files_dirs.py @@ -0,0 +1,68 @@ +"""remove_files_dirs + +Revision ID: a0f1a2b3c4d5 +Revises: 9e0f1a2b3c4d +Create Date: 2025-01-27 11:02:00.000000 + +""" + +import logging +import shutil +from typing import Sequence, Union + +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "a0f1a2b3c4d5" +down_revision: Union[str, None] = "9e0f1a2b3c4d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + +settings = SettingsV1() +workspaces_dir = settings.data_dir / "workspaces" + + +def upgrade() -> None: + """Remove JSON files from workspace static directories after successful migration.""" + + # Skip if workspaces directory doesn't exist + if not workspaces_dir.exists(): + logger.info( + "Workspaces directory does not exist, skipping removal", + extra={"workspaces_dir": str(workspaces_dir)}, + ) + return + + # Iterate through all workspace directories + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + files_dir = workspace_dir / "files" + if not files_dir.exists(): + logger.info( + "Files directory does not exist, skipping removal", + extra={"files_dir": str(files_dir)}, + ) + continue + + try: + shutil.rmtree(files_dir) + logger.info( + "Successfully removed files directory", + extra={"files_dir": str(files_dir)}, + ) + except Exception as e: # noqa: PERF203 + error_msg = "Failed to remove files directory" + logger.exception(error_msg, extra={"files_dir": str(files_dir)}) + raise RuntimeError(error_msg) from e + + +def downgrade() -> None: + """Recreate JSON files in workspace static directories during downgrade.""" + + # This is handled by the import_json_files migration downgrade + # No need to recreate files here as they will be recreated when downgrading + # the import_json_files migration From 4dd2f972d3d3d9dd28928ea9696cd3611a1458bd Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Fri, 24 Oct 2025 23:14:36 +0200 Subject: [PATCH 09/14] feat(chat): migrate messages, runs, threads from json persitence to sqlite --- src/askui/chat/api/assistants/orms.py | 3 +- src/askui/chat/api/assistants/service.py | 24 +- src/askui/chat/api/db/engine.py | 11 +- src/askui/chat/api/files/orms.py 
| 3 +- src/askui/chat/api/mcp_configs/orms.py | 5 +- src/askui/chat/api/mcp_configs/service.py | 10 +- .../chat/api/messages/chat_history_manager.py | 28 +- src/askui/chat/api/messages/dependencies.py | 10 +- src/askui/chat/api/messages/models.py | 20 +- src/askui/chat/api/messages/orms.py | 55 ++++ src/askui/chat/api/messages/router.py | 36 ++- src/askui/chat/api/messages/service.py | 114 +++++--- src/askui/chat/api/runs/dependencies.py | 9 +- src/askui/chat/api/runs/events/service.py | 21 +- src/askui/chat/api/runs/models.py | 156 ++++++++-- src/askui/chat/api/runs/orms.py | 52 ++++ src/askui/chat/api/runs/router.py | 56 ++-- src/askui/chat/api/runs/runner/runner.py | 103 ++++--- src/askui/chat/api/runs/service.py | 182 ++++++------ src/askui/chat/api/settings.py | 2 +- .../chat/api/telemetry/logs/structlog.py | 6 +- src/askui/chat/api/threads/dependencies.py | 18 +- src/askui/chat/api/threads/facade.py | 53 +--- src/askui/chat/api/threads/models.py | 24 +- src/askui/chat/api/threads/orms.py | 31 ++ src/askui/chat/api/threads/router.py | 29 +- src/askui/chat/api/threads/service.py | 129 ++++----- src/askui/chat/migrations/env.py | 5 + .../migrations/shared/assistants/models.py | 4 +- .../chat/migrations/shared/files/models.py | 4 +- .../migrations/shared/mcp_configs/models.py | 4 +- .../migrations/shared/messages/__init__.py | 0 .../chat/migrations/shared/messages/models.py | 166 +++++++++++ .../chat/migrations/shared/runs/__init__.py | 0 .../chat/migrations/shared/runs/models.py | 74 +++++ .../migrations/shared/threads/__init__.py | 0 .../chat/migrations/shared/threads/models.py | 25 ++ .../057f82313448_import_json_assistants.py | 21 +- .../1a2b3c4d5e6f_create_threads_table.py | 36 +++ .../2b3c4d5e6f7a_create_messages_table.py | 60 ++++ .../37007a499ca7_remove_assistants_dir.py | 51 ---- ...37007a499ca7_soft_delete_assistants_dir.py | 87 ++++++ .../3c4d5e6f7a8b_create_runs_table.py | 72 +++++ .../4d5e6f7a8b9c_import_json_threads.py | 120 ++++++++ 
.../5e6f7a8b9c0d_import_json_messages.py | 266 ++++++++++++++++++ .../6b2c3d4e5f6a_import_json_mcp_configs.py | 25 +- .../versions/6f7a8b9c0d1e_import_json_runs.py | 242 ++++++++++++++++ .../7a8b9c0d1e2f_soft_delete_threads_dirs.py | 97 +++++++ .../7c3d4e5f6a7b_remove_mcp_configs_dir.py | 51 ---- ...c3d4e5f6a7b_soft_delete_mcp_configs_dir.py | 86 ++++++ .../8b9c0d1e2f3a_soft_delete_messages_dirs.py | 97 +++++++ .../9c0d1e2f3a4b_soft_delete_runs_dirs.py | 97 +++++++ .../9e0f1a2b3c4d_import_json_files.py | 24 +- .../a0f1a2b3c4d5_remove_files_dirs.py | 68 ----- .../a0f1a2b3c4d5_soft_delete_files_dirs.py | 108 +++++++ src/askui/utils/datetime_utils.py | 1 + 56 files changed, 2426 insertions(+), 655 deletions(-) create mode 100644 src/askui/chat/api/messages/orms.py create mode 100644 src/askui/chat/api/runs/orms.py create mode 100644 src/askui/chat/api/threads/orms.py create mode 100644 src/askui/chat/migrations/shared/messages/__init__.py create mode 100644 src/askui/chat/migrations/shared/messages/models.py create mode 100644 src/askui/chat/migrations/shared/runs/__init__.py create mode 100644 src/askui/chat/migrations/shared/runs/models.py create mode 100644 src/askui/chat/migrations/shared/threads/__init__.py create mode 100644 src/askui/chat/migrations/shared/threads/models.py create mode 100644 src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py create mode 100644 src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py delete mode 100644 src/askui/chat/migrations/versions/37007a499ca7_remove_assistants_dir.py create mode 100644 src/askui/chat/migrations/versions/37007a499ca7_soft_delete_assistants_dir.py create mode 100644 src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py create mode 100644 src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py create mode 100644 src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py create mode 100644 
src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py create mode 100644 src/askui/chat/migrations/versions/7a8b9c0d1e2f_soft_delete_threads_dirs.py delete mode 100644 src/askui/chat/migrations/versions/7c3d4e5f6a7b_remove_mcp_configs_dir.py create mode 100644 src/askui/chat/migrations/versions/7c3d4e5f6a7b_soft_delete_mcp_configs_dir.py create mode 100644 src/askui/chat/migrations/versions/8b9c0d1e2f3a_soft_delete_messages_dirs.py create mode 100644 src/askui/chat/migrations/versions/9c0d1e2f3a4b_soft_delete_runs_dirs.py delete mode 100644 src/askui/chat/migrations/versions/a0f1a2b3c4d5_remove_files_dirs.py create mode 100644 src/askui/chat/migrations/versions/a0f1a2b3c4d5_soft_delete_files_dirs.py diff --git a/src/askui/chat/api/assistants/orms.py b/src/askui/chat/api/assistants/orms.py index 4c84f875..c59d4ea9 100644 --- a/src/askui/chat/api/assistants/orms.py +++ b/src/askui/chat/api/assistants/orms.py @@ -30,8 +30,7 @@ class AssistantOrm(Base): @classmethod def from_model(cls, model: Assistant) -> "AssistantOrm": return cls( - **model.model_dump(exclude={"object", "created_at"}), - created_at=model.created_at, + **model.model_dump(exclude={"object"}), ) def to_model(self) -> Assistant: diff --git a/src/askui/chat/api/assistants/service.py b/src/askui/chat/api/assistants/service.py index 91154e51..e963503d 100644 --- a/src/askui/chat/api/assistants/service.py +++ b/src/askui/chat/api/assistants/service.py @@ -53,15 +53,15 @@ def _find_by_id( def retrieve( self, workspace_id: WorkspaceId | None, assistant_id: AssistantId ) -> Assistant: - assistant_model = self._find_by_id(workspace_id, assistant_id) - return assistant_model.to_model() + assistant_orm = self._find_by_id(workspace_id, assistant_id) + return assistant_orm.to_model() def create( self, workspace_id: WorkspaceId | None, params: AssistantCreate ) -> Assistant: assistant = Assistant.create(workspace_id, params) - assistant_model = AssistantOrm.from_model(assistant) - 
self._session.add(assistant_model) + assistant_orm = AssistantOrm.from_model(assistant) + self._session.add(assistant_orm) self._session.commit() return assistant @@ -72,14 +72,14 @@ def modify( params: AssistantModify, force: bool = False, ) -> Assistant: - assistant_model = self._find_by_id(workspace_id, assistant_id) - if assistant_model.workspace_id is None and not force: + assistant_orm = self._find_by_id(workspace_id, assistant_id) + if assistant_orm.workspace_id is None and not force: error_msg = f"Default assistant {assistant_id} cannot be modified" raise ForbiddenError(error_msg) - assistant_model.update(params.model_dump()) + assistant_orm.update(params.model_dump()) self._session.commit() - self._session.refresh(assistant_model) - return assistant_model.to_model() + self._session.refresh(assistant_orm) + return assistant_orm.to_model() def delete( self, @@ -87,9 +87,9 @@ def delete( assistant_id: AssistantId, force: bool = False, ) -> None: - assistant_model = self._find_by_id(workspace_id, assistant_id) - if assistant_model.workspace_id is None and not force: + assistant_orm = self._find_by_id(workspace_id, assistant_id) + if assistant_orm.workspace_id is None and not force: error_msg = f"Default assistant {assistant_id} cannot be deleted" raise ForbiddenError(error_msg) - self._session.delete(assistant_model) + self._session.delete(assistant_orm) self._session.commit() diff --git a/src/askui/chat/api/db/engine.py b/src/askui/chat/api/db/engine.py index 4ed4fc43..87931455 100644 --- a/src/askui/chat/api/db/engine.py +++ b/src/askui/chat/api/db/engine.py @@ -1,6 +1,8 @@ import logging +from sqlite3 import Connection as SQLite3Connection +from typing import Any -from sqlalchemy import create_engine +from sqlalchemy import Engine, create_engine, event from askui.chat.api.dependencies import get_settings @@ -10,3 +12,10 @@ connect_args = {"check_same_thread": False} echo = logger.isEnabledFor(logging.DEBUG) engine = create_engine(settings.db.url, 
connect_args=connect_args, echo=echo) + + +@event.listens_for(Engine, "connect") +def set_sqlite_pragma(dbapi_conn: SQLite3Connection, connection_record: Any) -> None: # noqa: ARG001 + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() diff --git a/src/askui/chat/api/files/orms.py b/src/askui/chat/api/files/orms.py index 9a9ea0ad..ab0ffd94 100644 --- a/src/askui/chat/api/files/orms.py +++ b/src/askui/chat/api/files/orms.py @@ -28,8 +28,7 @@ class FileOrm(Base): @classmethod def from_model(cls, model: File) -> "FileOrm": return cls( - **model.model_dump(exclude={"object", "created_at"}), - created_at=model.created_at, + **model.model_dump(exclude={"object"}), ) def to_model(self) -> File: diff --git a/src/askui/chat/api/mcp_configs/orms.py b/src/askui/chat/api/mcp_configs/orms.py index b2bc65e1..757d5004 100644 --- a/src/askui/chat/api/mcp_configs/orms.py +++ b/src/askui/chat/api/mcp_configs/orms.py @@ -27,10 +27,7 @@ class McpConfigOrm(Base): @classmethod def from_model(cls, model: McpConfig) -> "McpConfigOrm": - return cls( - **model.model_dump(exclude={"object", "created_at"}), - created_at=model.created_at, - ) + return cls(**model.model_dump(exclude={"object"})) def to_model(self) -> McpConfig: return McpConfig.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/mcp_configs/service.py b/src/askui/chat/api/mcp_configs/service.py index 72324195..d540ee93 100644 --- a/src/askui/chat/api/mcp_configs/service.py +++ b/src/askui/chat/api/mcp_configs/service.py @@ -149,8 +149,8 @@ def delete( def seed(self) -> None: """Seed the MCP configuration service with default MCP configurations.""" for seed in self._seeds: - with self._session.begin(): - self._session.query(McpConfigOrm).filter( - McpConfigOrm.id == seed.id - ).delete() - self._session.add(McpConfigOrm.from_model(seed)) + self._session.query(McpConfigOrm).filter( + McpConfigOrm.id == seed.id + ).delete() + 
self._session.add(McpConfigOrm.from_model(seed)) + self._session.commit() diff --git a/src/askui/chat/api/messages/chat_history_manager.py b/src/askui/chat/api/messages/chat_history_manager.py index 9642257c..9120f092 100644 --- a/src/askui/chat/api/messages/chat_history_manager.py +++ b/src/askui/chat/api/messages/chat_history_manager.py @@ -1,9 +1,9 @@ from anthropic.types.beta import BetaTextBlockParam, BetaToolUnionParam -from askui.chat.api.messages.models import Message, MessageCreateParams +from askui.chat.api.messages.models import Message, MessageCreate from askui.chat.api.messages.service import MessageService from askui.chat.api.messages.translator import MessageTranslator -from askui.chat.api.models import ThreadId +from askui.chat.api.models import ThreadId, WorkspaceId from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.truncation_strategies import TruncationStrategyFactory @@ -38,6 +38,7 @@ def __init__( async def retrieve_message_params( self, + workspace_id: WorkspaceId, thread_id: ThreadId, model: str, system: str | list[BetaTextBlockParam] | None, @@ -51,36 +52,25 @@ async def retrieve_message_params( model=model, ) ) - for msg in self._message_service.iter(thread_id=thread_id): + for msg in self._message_service.iter( + workspace_id=workspace_id, thread_id=thread_id + ): anthropic_message = await self._message_translator.to_anthropic(msg) truncation_strategy.append_message(anthropic_message) return truncation_strategy.messages async def append_message( self, + workspace_id: WorkspaceId, thread_id: ThreadId, assistant_id: str | None, run_id: str, message: MessageParam, ) -> Message: - """ - Add a message to the chat history and return both the created message and original message param. - - This method creates a message in the database and returns both the created - message object and the original message parameter for further processing. - - Args: - thread_id (ThreadId): The thread ID to add the message to. 
- assistant_id (str | None): The assistant ID if the message is from an assistant. - run_id (str): The run ID associated with this message. - message (MessageParam): The message to add. - - Returns: - Message: The created message object - """ return self._message_service.create( + workspace_id=workspace_id, thread_id=thread_id, - params=MessageCreateParams( + params=MessageCreate( assistant_id=assistant_id if message.role == "assistant" else None, role=message.role, content=await self._message_content_translator.from_anthropic( diff --git a/src/askui/chat/api/messages/dependencies.py b/src/askui/chat/api/messages/dependencies.py index e22ea940..965bfb6e 100644 --- a/src/askui/chat/api/messages/dependencies.py +++ b/src/askui/chat/api/messages/dependencies.py @@ -1,8 +1,6 @@ -from pathlib import Path - from fastapi import Depends -from askui.chat.api.dependencies import WorkspaceDirDep +from askui.chat.api.db.session import SessionDep from askui.chat.api.files.dependencies import FileServiceDep from askui.chat.api.files.service import FileService from askui.chat.api.messages.chat_history_manager import ChatHistoryManager @@ -15,10 +13,10 @@ def get_message_service( - workspace_dir: Path = WorkspaceDirDep, + session: SessionDep, ) -> MessageService: - """Get MessagePersistedService instance.""" - return MessageService(workspace_dir) + """Get MessageService instance.""" + return MessageService(session) MessageServiceDep = Depends(get_message_service) diff --git a/src/askui/chat/api/messages/models.py b/src/askui/chat/api/messages/models.py index 346fff7e..80bbe365 100644 --- a/src/askui/chat/api/messages/models.py +++ b/src/askui/chat/api/messages/models.py @@ -2,7 +2,15 @@ from pydantic import BaseModel -from askui.chat.api.models import AssistantId, FileId, MessageId, RunId, ThreadId +from askui.chat.api.models import ( + AssistantId, + FileId, + MessageId, + RunId, + ThreadId, + WorkspaceId, + WorkspaceResource, +) from askui.models.shared.agent_message_param 
import ( Base64ImageSourceParam, BetaRedactedThinkingBlock, @@ -13,7 +21,6 @@ ToolUseBlockParam, UrlImageSourceParam, ) -from askui.utils.api_utils import Resource from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import generate_time_ordered_id @@ -75,21 +82,24 @@ class MessageBase(MessageParam): run_id: RunId | None = None -class MessageCreateParams(MessageBase): +class MessageCreate(MessageBase): pass -class Message(MessageBase, Resource): +class Message(MessageBase, WorkspaceResource): id: MessageId object: Literal["thread.message"] = "thread.message" created_at: UnixDatetime thread_id: ThreadId @classmethod - def create(cls, thread_id: ThreadId, params: MessageCreateParams) -> "Message": + def create( + cls, workspace_id: WorkspaceId, thread_id: ThreadId, params: MessageCreate + ) -> "Message": return cls( id=generate_time_ordered_id("msg"), created_at=now(), + workspace_id=workspace_id, thread_id=thread_id, **params.model_dump(), ) diff --git a/src/askui/chat/api/messages/orms.py b/src/askui/chat/api/messages/orms.py new file mode 100644 index 00000000..21b12c42 --- /dev/null +++ b/src/askui/chat/api/messages/orms.py @@ -0,0 +1,55 @@ +"""Message database model.""" + +from datetime import datetime +from typing import Any +from uuid import UUID + +from sqlalchemy import JSON, ForeignKey, String, Uuid +from sqlalchemy.orm import Mapped, mapped_column + +from askui.chat.api.assistants.orms import AssistantId +from askui.chat.api.db.orm.base import Base +from askui.chat.api.db.orm.types import ( + RunId, + ThreadId, + UnixDatetime, + create_prefixed_id_type, +) +from askui.chat.api.messages.models import Message + +MessageId = create_prefixed_id_type("msg") + + +class MessageOrm(Base): + """Message database model.""" + + __tablename__ = "messages" + + id: Mapped[str] = mapped_column(MessageId, primary_key=True) + thread_id: Mapped[str] = mapped_column( + ThreadId, + ForeignKey("threads.id", ondelete="CASCADE"), + nullable=False, + 
index=True, + ) + workspace_id: Mapped[UUID] = mapped_column(Uuid, nullable=False, index=True) + created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) + role: Mapped[str] = mapped_column(String, nullable=False) + content: Mapped[str | list[dict[str, Any]]] = mapped_column(JSON, nullable=False) + stop_reason: Mapped[str | None] = mapped_column(String, nullable=True) + assistant_id: Mapped[str | None] = mapped_column( + AssistantId, ForeignKey("assistants.id", ondelete="SET NULL"), nullable=True + ) + run_id: Mapped[str | None] = mapped_column( + RunId, ForeignKey("runs.id", ondelete="SET NULL"), nullable=True + ) + + @classmethod + def from_model(cls, model: Message) -> "MessageOrm": + return cls( + **model.model_dump(exclude={"object", "created_at"}), + created_at=model.created_at, + ) + + def to_model(self) -> Message: + return Message.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 4276950a..409ac3b5 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -1,12 +1,12 @@ -from fastapi import APIRouter, status +from typing import Annotated + +from fastapi import APIRouter, Header, status from askui.chat.api.dependencies import ListQueryDep from askui.chat.api.messages.dependencies import MessageServiceDep -from askui.chat.api.messages.models import Message, MessageCreateParams +from askui.chat.api.messages.models import Message, MessageCreate from askui.chat.api.messages.service import MessageService -from askui.chat.api.models import MessageId, ThreadId -from askui.chat.api.threads.dependencies import ThreadFacadeDep -from askui.chat.api.threads.facade import ThreadFacade +from askui.chat.api.models import MessageId, ThreadId, WorkspaceId from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) @@ -14,35 +14,47 @@ @router.get("") def 
list_messages( + askui_workspace: Annotated[WorkspaceId, Header()], thread_id: ThreadId, query: ListQuery = ListQueryDep, - thread_facade: ThreadFacade = ThreadFacadeDep, + message_service: MessageService = MessageServiceDep, ) -> ListResponse[Message]: - return thread_facade.list_messages(thread_id, query=query) + return message_service.list_( + workspace_id=askui_workspace, thread_id=thread_id, query=query + ) @router.post("", status_code=status.HTTP_201_CREATED) async def create_message( + askui_workspace: Annotated[WorkspaceId, Header()], thread_id: ThreadId, - params: MessageCreateParams, - thread_facade: ThreadFacade = ThreadFacadeDep, + params: MessageCreate, + message_service: MessageService = MessageServiceDep, ) -> Message: - return thread_facade.create_message(thread_id=thread_id, params=params) + return message_service.create( + workspace_id=askui_workspace, thread_id=thread_id, params=params + ) @router.get("/{message_id}") def retrieve_message( + askui_workspace: Annotated[WorkspaceId, Header()], thread_id: ThreadId, message_id: MessageId, message_service: MessageService = MessageServiceDep, ) -> Message: - return message_service.retrieve(thread_id, message_id) + return message_service.retrieve( + workspace_id=askui_workspace, thread_id=thread_id, message_id=message_id + ) @router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_message( + askui_workspace: Annotated[WorkspaceId, Header()], thread_id: ThreadId, message_id: MessageId, message_service: MessageService = MessageServiceDep, ) -> None: - message_service.delete(thread_id, message_id) + message_service.delete( + workspace_id=askui_workspace, thread_id=thread_id, message_id=message_id + ) diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index 1d2a4781..a33783bb 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -1,58 +1,88 @@ -from pathlib import Path from typing import 
Iterator -from askui.chat.api.messages.models import Message, MessageCreateParams -from askui.chat.api.models import MessageId, ThreadId +from sqlalchemy.orm import Session + +from askui.chat.api.db.queries import list_all +from askui.chat.api.messages.models import Message, MessageCreate +from askui.chat.api.messages.orms import MessageOrm +from askui.chat.api.models import MessageId, ThreadId, WorkspaceId from askui.utils.api_utils import ( LIST_LIMIT_DEFAULT, - ConflictError, ListOrder, ListQuery, ListResponse, NotFoundError, - list_resources, ) class MessageService: - def __init__(self, base_dir: Path) -> None: - self._base_dir = base_dir + """Service for managing Message resources with database persistence.""" - def get_messages_dir(self, thread_id: ThreadId) -> Path: - return self._base_dir / "messages" / thread_id + def __init__(self, session: Session) -> None: + self._session = session - def _get_message_path( - self, thread_id: ThreadId, message_id: MessageId, new: bool = False - ) -> Path: - message_path = self.get_messages_dir(thread_id) / f"{message_id}.json" - exists = message_path.exists() - if new and exists: - error_msg = f"Message {message_id} already exists in thread {thread_id}" - raise ConflictError(error_msg) - if not new and not exists: + def _find_by_id( + self, workspace_id: WorkspaceId, thread_id: ThreadId, message_id: MessageId + ) -> MessageOrm: + """Find message by ID.""" + message_orm: MessageOrm | None = ( + self._session.query(MessageOrm) + .filter( + MessageOrm.id == message_id, + MessageOrm.thread_id == thread_id, + MessageOrm.workspace_id == workspace_id, + ) + .first() + ) + if message_orm is None: error_msg = f"Message {message_id} not found in thread {thread_id}" raise NotFoundError(error_msg) - return message_path + return message_orm - def create(self, thread_id: ThreadId, params: MessageCreateParams) -> Message: - new_message = Message.create(thread_id, params) - self._save(new_message, new=True) - return new_message + def 
create( + self, + workspace_id: WorkspaceId, + thread_id: ThreadId, + params: MessageCreate, + ) -> Message: + """Create a new message.""" + message = Message.create(workspace_id, thread_id, params) + message_orm = MessageOrm.from_model(message) + self._session.add(message_orm) + self._session.commit() + return message - def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Message]: - messages_dir = self.get_messages_dir(thread_id) - return list_resources(messages_dir, query, Message) + def list_( + self, workspace_id: WorkspaceId, thread_id: ThreadId, query: ListQuery + ) -> ListResponse[Message]: + """List messages with pagination and filtering.""" + q = self._session.query(MessageOrm).filter( + MessageOrm.thread_id == thread_id, + MessageOrm.workspace_id == workspace_id, + ) + orms: list[MessageOrm] + orms, has_more = list_all(q, query, MessageOrm.id) + data = [orm.to_model() for orm in orms] + return ListResponse( + data=data, + has_more=has_more, + first_id=data[0].id if data else None, + last_id=data[-1].id if data else None, + ) def iter( self, + workspace_id: WorkspaceId, thread_id: ThreadId, order: ListOrder = "asc", batch_size: int = LIST_LIMIT_DEFAULT, ) -> Iterator[Message]: + """Iterate through messages in batches.""" has_more = True last_id: str | None = None while has_more: list_messages_response = self.list_( + workspace_id=workspace_id, thread_id=thread_id, query=ListQuery(limit=batch_size, order=order, after=last_id), ) @@ -61,23 +91,17 @@ def iter( for msg in list_messages_response.data: yield msg - def retrieve(self, thread_id: ThreadId, message_id: MessageId) -> Message: - try: - message_file = self._get_message_path(thread_id, message_id) - return Message.model_validate_json(message_file.read_text(encoding="utf-8")) - except FileNotFoundError as e: - error_msg = f"Message {message_id} not found in thread {thread_id}" - raise NotFoundError(error_msg) from e - - def delete(self, thread_id: ThreadId, message_id: MessageId) -> 
None: - try: - self._get_message_path(thread_id, message_id).unlink() - except FileNotFoundError as e: - error_msg = f"Message {message_id} not found in thread {thread_id}" - raise NotFoundError(error_msg) from e + def retrieve( + self, workspace_id: WorkspaceId, thread_id: ThreadId, message_id: MessageId + ) -> Message: + """Retrieve message by ID.""" + message_orm = self._find_by_id(workspace_id, thread_id, message_id) + return message_orm.to_model() - def _save(self, message: Message, new: bool = False) -> None: - messages_dir = self.get_messages_dir(message.thread_id) - messages_dir.mkdir(parents=True, exist_ok=True) - message_file = self._get_message_path(message.thread_id, message.id, new=new) - message_file.write_text(message.model_dump_json(), encoding="utf-8") + def delete( + self, workspace_id: WorkspaceId, thread_id: ThreadId, message_id: MessageId + ) -> None: + """Delete a message.""" + message_orm = self._find_by_id(workspace_id, thread_id, message_id) + self._session.delete(message_orm) + self._session.commit() diff --git a/src/askui/chat/api/runs/dependencies.py b/src/askui/chat/api/runs/dependencies.py index 759fbf73..0fefdaca 100644 --- a/src/askui/chat/api/runs/dependencies.py +++ b/src/askui/chat/api/runs/dependencies.py @@ -1,10 +1,9 @@ -from pathlib import Path - from fastapi import Depends from askui.chat.api.assistants.dependencies import AssistantServiceDep from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.dependencies import SettingsDep, WorkspaceDirDep +from askui.chat.api.db.session import SessionDep +from askui.chat.api.dependencies import SettingsDep from askui.chat.api.mcp_clients.dependencies import McpClientManagerManagerDep from askui.chat.api.mcp_clients.manager import McpClientManagerManager from askui.chat.api.messages.chat_history_manager import ChatHistoryManager @@ -18,14 +17,14 @@ def get_runs_service( - workspace_dir: Path = WorkspaceDirDep, + session: SessionDep, assistant_service: 
AssistantService = AssistantServiceDep, chat_history_manager: ChatHistoryManager = ChatHistoryManagerDep, mcp_client_manager_manager: McpClientManagerManager = McpClientManagerManagerDep, settings: Settings = SettingsDep, ) -> RunService: return RunService( - base_dir=workspace_dir, + session=session, assistant_service=assistant_service, mcp_client_manager_manager=mcp_client_manager_manager, chat_history_manager=chat_history_manager, diff --git a/src/askui/chat/api/runs/events/service.py b/src/askui/chat/api/runs/events/service.py index 3e07d8c3..4b9210f7 100644 --- a/src/askui/chat/api/runs/events/service.py +++ b/src/askui/chat/api/runs/events/service.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from aiofiles.threadpool.text import AsyncTextIOWrapper -from askui.chat.api.models import RunId, ThreadId +from askui.chat.api.models import RunId, ThreadId, WorkspaceId from askui.chat.api.runs.events.done_events import DoneEvent from askui.chat.api.runs.events.error_events import ( ErrorEvent, @@ -94,7 +94,9 @@ async def wait_for_new_event( class RetrieveRunService(ABC): @abstractmethod - def retrieve(self, thread_id: ThreadId, run_id: RunId) -> Run: + def retrieve( + self, workspace_id: WorkspaceId, thread_id: ThreadId, run_id: RunId + ) -> Run: raise NotImplementedError @@ -139,12 +141,14 @@ def __init__( manager: EventFileManager, run_service: RetrieveRunService, start_index: int, + workspace_id: WorkspaceId, thread_id: ThreadId, run_id: RunId, ): self._manager = manager self._run_service = run_service self._start_index = start_index + self._workspace_id = workspace_id self._thread_id = thread_id self._run_id = run_id @@ -189,7 +193,9 @@ async def read_events(self) -> AsyncIterator[Event]: # noqa: C901 "Timeout waiting for file %s to be created", self._manager.file_path, ) - if run := self._run_service.retrieve(self._thread_id, self._run_id): + if run := self._run_service.retrieve( + self._workspace_id, self._thread_id, self._run_id + ): if run.status not in ("queued", 
"in_progress"): async for event in self._iter_final_events(run): yield event @@ -225,7 +231,7 @@ async def read_events(self) -> AsyncIterator[Event]: # noqa: C901 self._manager.file_path, ) if run := self._run_service.retrieve( - self._thread_id, self._run_id + self._workspace_id, self._thread_id, self._run_id ): if run.status not in ( "queued", @@ -299,7 +305,11 @@ async def create_writer( @asynccontextmanager async def create_reader( - self, thread_id: ThreadId, run_id: RunId, start_index: int = 0 + self, + workspace_id: WorkspaceId, + thread_id: ThreadId, + run_id: RunId, + start_index: int = 0, ) -> AsyncIterator["EventReader"]: """ Create a reader context manager for reading events from a file. @@ -320,6 +330,7 @@ async def create_reader( manager=manager, run_service=self._run_service, start_index=start_index, + workspace_id=workspace_id, thread_id=thread_id, run_id=run_id, ) diff --git a/src/askui/chat/api/runs/models.py b/src/askui/chat/api/runs/models.py index 96220ffd..5699a4ef 100644 --- a/src/askui/chat/api/runs/models.py +++ b/src/askui/chat/api/runs/models.py @@ -3,11 +3,17 @@ from typing import Annotated, Literal from fastapi import Query -from pydantic import BaseModel, computed_field - -from askui.chat.api.models import AssistantId, RunId, ThreadId -from askui.chat.api.threads.models import ThreadCreateParams -from askui.utils.api_utils import ListQuery, Resource +from pydantic import BaseModel, Field, computed_field + +from askui.chat.api.models import ( + AssistantId, + RunId, + ThreadId, + WorkspaceId, + WorkspaceResource, +) +from askui.chat.api.threads.models import ThreadCreate +from askui.utils.api_utils import ListQuery from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import generate_time_ordered_id @@ -29,23 +35,92 @@ class RunError(BaseModel): code: Literal["server_error", "rate_limit_exceeded", "invalid_prompt"] -class RunBase(BaseModel): - """Base run model.""" +class RunCreate(BaseModel): + 
"""Parameters for creating a run.""" + stream: bool = False assistant_id: AssistantId -class RunCreateParams(RunBase): - """Parameters for creating a run.""" +class RunStart(BaseModel): + """Parameters for starting a run.""" - stream: bool = False + type: Literal["start"] = "start" + status: Literal["in_progress"] = "in_progress" + started_at: UnixDatetime = Field(default_factory=now) + expires_at: UnixDatetime = Field( + default_factory=lambda: now() + timedelta(minutes=10) + ) + + +class RunPing(BaseModel): + """Parameters for pinging a run.""" + + type: Literal["ping"] = "ping" + expires_at: UnixDatetime = Field( + default_factory=lambda: now() + timedelta(minutes=10) + ) + + +class RunComplete(BaseModel): + """Parameters for completing a run.""" + + type: Literal["complete"] = "complete" + status: Literal["completed"] = "completed" + completed_at: UnixDatetime = Field(default_factory=now) + + +class RunTryCancelling(BaseModel): + """Parameters for trying to cancel a run.""" + + type: Literal["try_cancelling"] = "try_cancelling" + status: Literal["cancelling"] = "cancelling" + tried_cancelling_at: UnixDatetime = Field(default_factory=now) -class ThreadAndRunCreateParams(RunCreateParams): - thread: ThreadCreateParams +class RunCancel(BaseModel): + """Parameters for canceling a run.""" + type: Literal["cancel"] = "cancel" + status: Literal["cancelled"] = "cancelled" + cancelled_at: UnixDatetime = Field(default_factory=now) -class Run(RunBase, Resource): + +class RunFail(BaseModel): + """Parameters for failing a run.""" + + type: Literal["fail"] = "fail" + status: Literal["failed"] = "failed" + failed_at: UnixDatetime = Field(default_factory=now) + last_error: RunError + + +RunModify = RunStart | RunPing | RunComplete | RunTryCancelling | RunCancel | RunFail + + +class ThreadAndRunCreate(RunCreate): + thread: ThreadCreate + + +def map_status_to_readable_description(status: RunStatus) -> str: + match status: + case "queued": + return "Run has been queued." 
+ case "in_progress": + return "Run is in progress." + case "completed": + return "Run has been completed." + case "cancelled": + return "Run has been cancelled." + case "failed": + return "Run has failed." + case "expired": + return "Run has expired." + case "cancelling": + return "Run is being cancelled." + + +class Run(WorkspaceResource): """A run execution within a thread.""" id: RunId @@ -59,11 +134,15 @@ class Run(RunBase, Resource): cancelled_at: UnixDatetime | None = None tried_cancelling_at: UnixDatetime | None = None last_error: RunError | None = None + assistant_id: AssistantId | None = None @classmethod - def create(cls, thread_id: ThreadId, params: RunCreateParams) -> "Run": + def create( + cls, workspace_id: WorkspaceId, thread_id: ThreadId, params: RunCreate + ) -> "Run": return cls( id=generate_time_ordered_id("run"), + workspace_id=workspace_id, thread_id=thread_id, created_at=now(), expires_at=now() + timedelta(minutes=10), @@ -87,22 +166,39 @@ def status(self) -> RunStatus: return "in_progress" return "queued" - def start(self) -> None: - self.started_at = now() - self.expires_at = now() + timedelta(minutes=10) - - def ping(self) -> None: - self.expires_at = now() + timedelta(minutes=10) - - def complete(self) -> None: - self.completed_at = now() - - def cancel(self) -> None: - self.cancelled_at = now() - - def fail(self, error: RunError) -> None: - self.failed_at = now() - self.last_error = error + def validate_modify(self, params: RunModify) -> None: # noqa: C901 + status_description = map_status_to_readable_description(self.status) + error_msg = status_description + match params.type: + case "start": + if self.status != "queued": + error_msg += " Cannot start it (again). Please create a new run." + raise ValueError(error_msg) + case "ping": + if self.status != "in_progress": + error_msg += " Cannot ping. Run is not in progress." 
+ raise ValueError(error_msg) + case "complete": + if self.status != "in_progress": + error_msg += " Cannot complete. Run is not in progress." + raise ValueError(error_msg) + case "try_cancelling": + if self.status not in ["queued", "in_progress"]: + error_msg += " Cannot cancel (again)." + if self.status != "cancelling": + # I think this just sounds better if this is only added if it + # is not being cancelled as it is still in progress while being + # cancelled. + error_msg += " Run is neither queued nor in progress." + raise ValueError(error_msg) + case "cancel": + if self.status not in ["queued", "in_progress", "cancelling"]: + error_msg += " Cannot cancel. Run is neither queued, in progress, nor has it been tried to be cancelled." + raise ValueError(error_msg) + case "fail": + if self.status not in ["queued", "in_progress", "cancelling"]: + error_msg += " Cannot fail. Run is neither queued, in progress, nor has it been tried to be cancelled." + raise ValueError(error_msg) @dataclass(kw_only=True) diff --git a/src/askui/chat/api/runs/orms.py b/src/askui/chat/api/runs/orms.py new file mode 100644 index 00000000..7dfd84a8 --- /dev/null +++ b/src/askui/chat/api/runs/orms.py @@ -0,0 +1,52 @@ +"""Run database model.""" + +from datetime import datetime +from typing import Any +from uuid import UUID + +from sqlalchemy import JSON, ForeignKey, Uuid +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.sql.sqltypes import String + +from askui.chat.api.assistants.orms import AssistantId +from askui.chat.api.db.orm.base import Base +from askui.chat.api.db.orm.types import ThreadId, UnixDatetime, create_prefixed_id_type +from askui.chat.api.runs.models import Run + +RunId = create_prefixed_id_type("run") + + +class RunOrm(Base): + """Run database model.""" + + __tablename__ = "runs" + + id: Mapped[str] = mapped_column(RunId, primary_key=True) + thread_id: Mapped[str] = mapped_column( + ThreadId, + ForeignKey("threads.id", ondelete="CASCADE"), + nullable=False, 
+ index=True, + ) + workspace_id: Mapped[UUID] = mapped_column(Uuid, nullable=False, index=True) + status: Mapped[str] = mapped_column(String, nullable=False, index=True) + created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) + expires_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) + started_at: Mapped[datetime | None] = mapped_column(UnixDatetime, nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(UnixDatetime, nullable=True) + failed_at: Mapped[datetime | None] = mapped_column(UnixDatetime, nullable=True) + cancelled_at: Mapped[datetime | None] = mapped_column(UnixDatetime, nullable=True) + tried_cancelling_at: Mapped[datetime | None] = mapped_column( + UnixDatetime, nullable=True + ) + last_error: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) + assistant_id: Mapped[str | None] = mapped_column( + AssistantId, ForeignKey("assistants.id", ondelete="SET NULL"), nullable=True + ) + + @classmethod + def from_model(cls, model: Run) -> "RunOrm": + return cls(**model.model_dump(exclude={"object"})) + + def to_model(self) -> Run: + return Run.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py index bca81eb2..b01bb284 100644 --- a/src/askui/chat/api/runs/router.py +++ b/src/askui/chat/api/runs/router.py @@ -1,28 +1,18 @@ from collections.abc import AsyncGenerator from typing import Annotated -from fastapi import ( - APIRouter, - BackgroundTasks, - Depends, - Header, - Path, - Query, - Response, - status, -) +from fastapi import APIRouter, BackgroundTasks, Header, Path, Query, Response, status from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel -from askui.chat.api.dependencies import ListQueryDep from askui.chat.api.models import RunId, ThreadId, WorkspaceId -from askui.chat.api.runs.models import RunCreateParams +from askui.chat.api.runs.models import RunCreate from 
askui.chat.api.threads.dependencies import ThreadFacadeDep from askui.chat.api.threads.facade import ThreadFacade -from askui.utils.api_utils import ListQuery, ListResponse +from askui.utils.api_utils import ListResponse from .dependencies import RunListQueryDep, RunServiceDep -from .models import Run, RunListQuery, ThreadAndRunCreateParams +from .models import Run, RunCancel, RunListQuery, ThreadAndRunCreate from .service import RunService router = APIRouter(tags=["runs"]) @@ -32,12 +22,12 @@ async def create_run( askui_workspace: Annotated[WorkspaceId, Header()], thread_id: Annotated[ThreadId, Path(...)], - params: RunCreateParams, + params: RunCreate, background_tasks: BackgroundTasks, - thread_facade: ThreadFacade = ThreadFacadeDep, + run_service: RunService = RunServiceDep, ) -> Response: stream = params.stream - run, async_generator = await thread_facade.create_run( + run, async_generator = await run_service.create( workspace_id=askui_workspace, thread_id=thread_id, params=params ) if stream: @@ -62,13 +52,15 @@ async def _run_async_generator() -> None: pass background_tasks.add_task(_run_async_generator) - return JSONResponse(status_code=status.HTTP_201_CREATED, content=run.model_dump()) + return JSONResponse( + status_code=status.HTTP_201_CREATED, content=run.model_dump(mode="json") + ) @router.post("/runs") async def create_thread_and_run( askui_workspace: Annotated[WorkspaceId, Header()], - params: ThreadAndRunCreateParams, + params: ThreadAndRunCreate, background_tasks: BackgroundTasks, thread_facade: ThreadFacade = ThreadFacadeDep, ) -> Response: @@ -98,11 +90,14 @@ async def _run_async_generator() -> None: pass background_tasks.add_task(_run_async_generator) - return JSONResponse(status_code=status.HTTP_201_CREATED, content=run.model_dump()) + return JSONResponse( + status_code=status.HTTP_201_CREATED, content=run.model_dump(mode="json") + ) @router.get("/threads/{thread_id}/runs/{run_id}") async def retrieve_run( + askui_workspace: 
Annotated[WorkspaceId, Header()], thread_id: Annotated[ThreadId, Path(...)], run_id: Annotated[RunId, Path(...)], stream: Annotated[bool, Query()] = False, @@ -110,11 +105,15 @@ async def retrieve_run( ) -> Response: if not stream: return JSONResponse( - content=run_service.retrieve(thread_id, run_id).model_dump(), + content=run_service.retrieve( + workspace_id=askui_workspace, thread_id=thread_id, run_id=run_id + ).model_dump(mode="json"), ) async def sse_event_stream() -> AsyncGenerator[str, None]: - async for event in run_service.retrieve_stream(thread_id, run_id): + async for event in run_service.retrieve_stream( + workspace_id=askui_workspace, thread_id=thread_id, run_id=run_id + ): data = ( event.data.model_dump_json() if isinstance(event.data, BaseModel) @@ -130,16 +129,23 @@ async def sse_event_stream() -> AsyncGenerator[str, None]: @router.get("/runs") async def list_runs( + askui_workspace: Annotated[WorkspaceId, Header()], query: RunListQuery = RunListQueryDep, - thread_facade: ThreadFacade = ThreadFacadeDep, + run_service: RunService = RunServiceDep, ) -> ListResponse[Run]: - return thread_facade.list_runs(query=query) + return run_service.list_(workspace_id=askui_workspace, query=query) @router.post("/threads/{thread_id}/runs/{run_id}/cancel") def cancel_run( + askui_workspace: Annotated[WorkspaceId, Header()], thread_id: Annotated[ThreadId, Path(...)], run_id: Annotated[RunId, Path(...)], run_service: RunService = RunServiceDep, ) -> Run: - return run_service.cancel(thread_id, run_id) + return run_service.modify( + workspace_id=askui_workspace, + thread_id=thread_id, + run_id=run_id, + params=RunCancel(), + ) diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 7ec90a8e..8bacaf3d 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -2,6 +2,7 @@ import logging from abc import ABC, abstractmethod from datetime import datetime, timezone +from typing 
import Any from anthropic.types.beta import BetaCacheControlEphemeralParam, BetaTextBlockParam from anyio.abc import ObjectStream @@ -10,7 +11,7 @@ from askui.chat.api.assistants.models import Assistant from askui.chat.api.mcp_clients.manager import McpClientManagerManager from askui.chat.api.messages.chat_history_manager import ChatHistoryManager -from askui.chat.api.models import WorkspaceId +from askui.chat.api.models import RunId, ThreadId, WorkspaceId from askui.chat.api.runs.events.done_events import DoneEvent from askui.chat.api.runs.events.error_events import ( ErrorEvent, @@ -21,7 +22,16 @@ from askui.chat.api.runs.events.message_events import MessageEvent from askui.chat.api.runs.events.run_events import RunEvent from askui.chat.api.runs.events.service import RetrieveRunService -from askui.chat.api.runs.models import Run, RunError +from askui.chat.api.runs.models import ( + Run, + RunCancel, + RunComplete, + RunError, + RunFail, + RunModify, + RunPing, + RunStart, +) from askui.chat.api.settings import Settings from askui.custom_agent import CustomAgent from askui.models.models import ModelName @@ -36,41 +46,55 @@ class RunnerRunService(RetrieveRunService, ABC): @abstractmethod - def save(self, run: Run, new: bool = False) -> None: + def modify( + self, + workspace_id: WorkspaceId, + thread_id: ThreadId, + run_id: RunId, + params: RunModify, + ) -> Run: raise NotImplementedError class Runner: def __init__( self, + run_id: RunId, + thread_id: ThreadId, workspace_id: WorkspaceId, assistant: Assistant, - run: Run, chat_history_manager: ChatHistoryManager, mcp_client_manager_manager: McpClientManagerManager, run_service: RunnerRunService, settings: Settings, ) -> None: + self._run_id = run_id self._workspace_id = workspace_id + self._thread_id = thread_id self._assistant = assistant - self._run = run self._chat_history_manager = chat_history_manager self._mcp_client_manager_manager = mcp_client_manager_manager self._run_service = run_service self._settings = 
settings - def _retrieve(self) -> Run: + def _retrieve_run(self) -> Run: return self._run_service.retrieve( - thread_id=self._run.thread_id, - run_id=self._run.id, + workspace_id=self._workspace_id, + thread_id=self._thread_id, + run_id=self._run_id, + ) + + def _modify_run(self, params: RunModify) -> Run: + return self._run_service.modify( + workspace_id=self._workspace_id, + thread_id=self._thread_id, + run_id=self._run_id, + params=params, ) def _build_system(self) -> list[BetaTextBlockParam]: metadata = { - "run_id": str(self._run.id), - "thread_id": str(self._run.thread_id), - "workspace_id": str(self._workspace_id), - "assistant_id": str(self._run.assistant_id), + **self._get_run_extra_info(), "continued_by_user_at": datetime.now(timezone.utc).strftime( "%A, %B %d, %Y %H:%M:%S %z" ), @@ -110,9 +134,10 @@ async def async_on_message( on_message_cb_param: OnMessageCbParam, ) -> MessageParam | None: created_message = await self._chat_history_manager.append_message( - thread_id=self._run.thread_id, - assistant_id=self._run.assistant_id, - run_id=self._run.id, + workspace_id=self._workspace_id, + thread_id=self._thread_id, + assistant_id=self._assistant.id, + run_id=self._run_id, message=on_message_cb_param.message, ) await send_stream.send( @@ -121,11 +146,10 @@ async def async_on_message( event="thread.message.created", ) ) - updated_run = self._retrieve() + updated_run = self._retrieve_run() if self._should_abort(updated_run): return None - updated_run.ping() - self._run_service.save(updated_run) + self._modify_run(RunPing()) return on_message_cb_param.message on_message = syncify(async_on_message) @@ -142,7 +166,8 @@ def _run_agent_inner() -> None: system = self._build_system() model = self._settings.model messages = syncify(self._chat_history_manager.retrieve_message_params)( - thread_id=self._run.thread_id, + workspace_id=self._workspace_id, + thread_id=self._thread_id, tools=tools.to_params(), system=system, model=model, @@ -165,23 +190,34 @@ def 
_run_agent_inner() -> None: await asyncify(_run_agent_inner)() + def _get_run_extra_info(self) -> dict[str, str]: + return { + "run_id": self._run_id, + "thread_id": self._thread_id, + "workspace_id": str(self._workspace_id), + "assistant_id": self._assistant.id, + } + async def run( self, send_stream: ObjectStream[Event], ) -> None: try: - self._mark_run_as_started() + updated_run = self._modify_run(RunStart()) + logger.info( + "Run started", + extra=self._get_run_extra_info(), + ) await send_stream.send( RunEvent( - data=self._run, + data=updated_run, event="thread.run.in_progress", ) ) await self._run_agent(send_stream=send_stream) - updated_run = self._retrieve() + updated_run = self._retrieve_run() if updated_run.status == "in_progress": - updated_run.complete() - self._run_service.save(updated_run) + self._modify_run(RunComplete()) await send_stream.send( RunEvent( data=updated_run, @@ -195,8 +231,7 @@ async def run( event="thread.run.cancelling", ) ) - updated_run.cancel() - self._run_service.save(updated_run) + self._modify_run(RunCancel()) await send_stream.send( RunEvent( data=updated_run, @@ -212,10 +247,14 @@ async def run( ) await send_stream.send(DoneEvent()) except Exception as e: # noqa: BLE001 - logger.exception("Exception in runner") - updated_run = self._retrieve() - updated_run.fail(RunError(message=str(e), code="server_error")) - self._run_service.save(updated_run) + logger.exception( + "Run failed", + extra=self._get_run_extra_info(), + ) + updated_run = self._retrieve_run() + self._modify_run( + RunFail(last_error=RunError(message=str(e), code="server_error")), + ) await send_stream.send( RunEvent( data=updated_run, @@ -228,9 +267,5 @@ async def run( ) ) - def _mark_run_as_started(self) -> None: - self._run.start() - self._run_service.save(self._run) - def _should_abort(self, run: Run) -> bool: return run.status in ("cancelled", "cancelling", "expired") diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 
76a442b1..2535f850 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -1,91 +1,93 @@ from collections.abc import AsyncGenerator from datetime import datetime, timezone -from pathlib import Path -from typing import Callable import anyio +from sqlalchemy import ColumnElement, or_ +from sqlalchemy.orm import Session from typing_extensions import override from askui.chat.api.assistants.service import AssistantService +from askui.chat.api.db.queries import list_all from askui.chat.api.mcp_clients.manager import McpClientManagerManager from askui.chat.api.messages.chat_history_manager import ChatHistoryManager from askui.chat.api.models import RunId, ThreadId, WorkspaceId from askui.chat.api.runs.events.events import DoneEvent, ErrorEvent, Event, RunEvent from askui.chat.api.runs.events.service import EventService -from askui.chat.api.runs.models import Run, RunCreateParams, RunListQuery +from askui.chat.api.runs.models import ( + Run, + RunCreate, + RunListQuery, + RunModify, + RunStatus, +) +from askui.chat.api.runs.orms import RunOrm from askui.chat.api.runs.runner.runner import Runner, RunnerRunService from askui.chat.api.settings import Settings -from askui.utils.api_utils import ( - ConflictError, - ListResponse, - NotFoundError, - list_resources, -) - - -def _build_run_filter_fn(query: RunListQuery) -> Callable[[Run], bool]: - def filter_fn(run: Run) -> bool: - return (query.thread is None or run.thread_id == query.thread) and ( - query.status is None or run.status in query.status - ) - - return filter_fn +from askui.utils.api_utils import ListResponse, NotFoundError class RunService(RunnerRunService): - """Service for managing Run resources with filesystem persistence.""" + """Service for managing Run resources with database persistence.""" def __init__( self, - base_dir: Path, + session: Session, assistant_service: AssistantService, mcp_client_manager_manager: McpClientManagerManager, chat_history_manager: 
ChatHistoryManager, settings: Settings, ) -> None: - self._base_dir = base_dir + self._session = session self._assistant_service = assistant_service self._mcp_client_manager_manager = mcp_client_manager_manager self._chat_history_manager = chat_history_manager self._settings = settings - self._event_service = EventService(base_dir, self) - - def get_runs_dir(self, thread_id: ThreadId) -> Path: - return self._base_dir / "runs" / thread_id - - def _get_run_path( - self, thread_id: ThreadId, run_id: RunId, new: bool = False - ) -> Path: - run_path = self.get_runs_dir(thread_id) / f"{run_id}.json" - exists = run_path.exists() - if new and exists: - error_msg = f"Run {run_id} already exists in thread {thread_id}" - raise ConflictError(error_msg) - if not new and not exists: + self._event_service = EventService(settings.data_dir, self) + + def _find_by_id( + self, workspace_id: WorkspaceId | None, thread_id: ThreadId, run_id: RunId + ) -> RunOrm: + """Find run by ID.""" + run_orm: RunOrm | None = ( + self._session.query(RunOrm) + .filter( + RunOrm.id == run_id, + RunOrm.thread_id == thread_id, + RunOrm.workspace_id == workspace_id, + ) + .first() + ) + if run_orm is None: error_msg = f"Run {run_id} not found in thread {thread_id}" raise NotFoundError(error_msg) - return run_path - - def _create(self, thread_id: ThreadId, params: RunCreateParams) -> Run: - run = Run.create(thread_id, params) - self.save(run, new=True) + return run_orm + + def _create( + self, workspace_id: WorkspaceId, thread_id: ThreadId, params: RunCreate + ) -> Run: + """Create a new run.""" + run = Run.create(workspace_id, thread_id, params) + run_orm = RunOrm.from_model(run) + self._session.add(run_orm) + self._session.commit() return run async def create( self, workspace_id: WorkspaceId, thread_id: ThreadId, - params: RunCreateParams, + params: RunCreate, ) -> tuple[Run, AsyncGenerator[Event, None]]: assistant = self._assistant_service.retrieve( workspace_id=workspace_id, 
assistant_id=params.assistant_id ) - run = self._create(thread_id, params) + run = self._create(workspace_id, thread_id, params) send_stream, receive_stream = anyio.create_memory_object_stream[Event]() runner = Runner( + run_id=run.id, + thread_id=thread_id, workspace_id=workspace_id, assistant=assistant, - run=run, chat_history_manager=self._chat_history_manager, mcp_client_manager_manager=self._mcp_client_manager_manager, run_service=self, @@ -136,47 +138,69 @@ async def run_runner() -> None: return run, event_generator() @override - def retrieve(self, thread_id: ThreadId, run_id: RunId) -> Run: - try: - run_file = self._get_run_path(thread_id, run_id) - return Run.model_validate_json(run_file.read_text()) - except FileNotFoundError as e: - error_msg = f"Run {run_id} not found in thread {thread_id}" - raise NotFoundError(error_msg) from e + def modify( + self, + workspace_id: WorkspaceId, + thread_id: ThreadId, + run_id: RunId, + params: RunModify, + ) -> Run: + run_orm = self._find_by_id(workspace_id, thread_id, run_id) + run = run_orm.to_model() + run.validate_modify(params) + run_orm.update(params.model_dump(exclude={"type"})) + self._session.commit() + self._session.refresh(run_orm) + return run_orm.to_model() + + @override + def retrieve( + self, workspace_id: WorkspaceId, thread_id: ThreadId, run_id: RunId + ) -> Run: + """Retrieve run by ID.""" + run_orm = self._find_by_id(workspace_id, thread_id, run_id) + return run_orm.to_model() async def retrieve_stream( - self, thread_id: ThreadId, run_id: RunId + self, workspace_id: WorkspaceId, thread_id: ThreadId, run_id: RunId ) -> AsyncGenerator[Event, None]: - async with self._event_service.create_reader(thread_id, run_id) as event_reader: + async with self._event_service.create_reader( + workspace_id=workspace_id, thread_id=thread_id, run_id=run_id + ) as event_reader: async for event in event_reader.read_events(): yield event - def list_(self, query: RunListQuery) -> ListResponse[Run]: + def 
_build_status_condition(self, status: RunStatus) -> ColumnElement[bool]: + match status: + case "expired": + return (RunOrm.status == "expired") | ( + (RunOrm.status.in_(["queued", "in_progress", "cancelling"])) + & (RunOrm.expires_at < datetime.now(tz=timezone.utc)) + ) + case _: + return RunOrm.status == status + + def list_( + self, workspace_id: WorkspaceId, query: RunListQuery + ) -> ListResponse[Run]: + """List runs with pagination and filtering.""" + q = self._session.query(RunOrm).filter(RunOrm.workspace_id == workspace_id) + if query.thread: - runs_dir = self.get_runs_dir(query.thread) - pattern = "*.json" - else: - runs_dir = self._base_dir / "runs" - pattern = "*/*.json" - return list_resources( - runs_dir, - query, - Run, - filter_fn=_build_run_filter_fn(query), - pattern=pattern, + q = q.filter(RunOrm.thread_id == query.thread) + + if query.status: + status_conditions = [ + self._build_status_condition(status) for status in query.status + ] + q = q.filter(or_(*status_conditions)) + + orms: list[RunOrm] + orms, has_more = list_all(q, query, RunOrm.id) + data = [orm.to_model() for orm in orms] + return ListResponse( + data=data, + has_more=has_more, + first_id=data[0].id if data else None, + last_id=data[-1].id if data else None, ) - - def cancel(self, thread_id: ThreadId, run_id: RunId) -> Run: - run = self.retrieve(thread_id, run_id) - if run.status in ("cancelled", "cancelling", "completed", "failed", "expired"): - return run - run.tried_cancelling_at = datetime.now(tz=timezone.utc) - self.save(run) - return run - - @override - def save(self, run: Run, new: bool = False) -> None: - runs_dir = self.get_runs_dir(run.thread_id) - runs_dir.mkdir(parents=True, exist_ok=True) - run_file = self._get_run_path(run.thread_id, run.id, new=new) - run_file.write_text(run.model_dump_json(), encoding="utf-8") diff --git a/src/askui/chat/api/settings.py b/src/askui/chat/api/settings.py index 72e9b6c9..3f6dc3a5 100644 --- a/src/askui/chat/api/settings.py +++ 
b/src/askui/chat/api/settings.py @@ -97,7 +97,7 @@ class Settings(BaseSettings): ) allow_origins: list[str] = Field( default_factory=lambda: [ - "http://localhost:4200", + "http://localhost", "https://app.caesr.ai", "https://app-dev.caesr.ai", "https://hub.askui.com", diff --git a/src/askui/chat/api/telemetry/logs/structlog.py b/src/askui/chat/api/telemetry/logs/structlog.py index c2f61e8f..20ddae48 100644 --- a/src/askui/chat/api/telemetry/logs/structlog.py +++ b/src/askui/chat/api/telemetry/logs/structlog.py @@ -1,6 +1,7 @@ import logging import structlog +from structlog.dev import plain_traceback from .settings import LogFormat, LogLevel, LogSettings from .structlog_processors import ( @@ -78,4 +79,7 @@ def get_format_dependent_processors( def get_renderer(log_format: LogFormat) -> structlog.types.Processor: if log_format == "JSON": return structlog.processors.JSONRenderer() - return structlog.dev.ConsoleRenderer(event_key=EVENT_KEY) + return structlog.dev.ConsoleRenderer( + event_key=EVENT_KEY, + exception_formatter=plain_traceback, + ) diff --git a/src/askui/chat/api/threads/dependencies.py b/src/askui/chat/api/threads/dependencies.py index 64ff8172..9eef2483 100644 --- a/src/askui/chat/api/threads/dependencies.py +++ b/src/askui/chat/api/threads/dependencies.py @@ -1,10 +1,6 @@ -from pathlib import Path - from fastapi import Depends -from askui.chat.api.dependencies import WorkspaceDirDep -from askui.chat.api.messages.dependencies import MessageServiceDep -from askui.chat.api.messages.service import MessageService +from askui.chat.api.db.session import SessionDep from askui.chat.api.runs.dependencies import RunServiceDep from askui.chat.api.runs.service import RunService from askui.chat.api.threads.facade import ThreadFacade @@ -12,16 +8,10 @@ def get_thread_service( - workspace_dir: Path = WorkspaceDirDep, - message_service: MessageService = MessageServiceDep, - run_service: RunService = RunServiceDep, + session: SessionDep, ) -> ThreadService: """Get 
ThreadService instance.""" - return ThreadService( - base_dir=workspace_dir, - message_service=message_service, - run_service=run_service, - ) + return ThreadService(session=session) ThreadServiceDep = Depends(get_thread_service) @@ -29,12 +19,10 @@ def get_thread_service( def get_thread_facade( thread_service: ThreadService = ThreadServiceDep, - message_service: MessageService = MessageServiceDep, run_service: RunService = RunServiceDep, ) -> ThreadFacade: return ThreadFacade( thread_service=thread_service, - message_service=message_service, run_service=run_service, ) diff --git a/src/askui/chat/api/threads/facade.py b/src/askui/chat/api/threads/facade.py index de836dd4..4cb6f705 100644 --- a/src/askui/chat/api/threads/facade.py +++ b/src/askui/chat/api/threads/facade.py @@ -1,18 +1,10 @@ from collections.abc import AsyncGenerator -from askui.chat.api.messages.models import Message, MessageCreateParams -from askui.chat.api.messages.service import MessageService -from askui.chat.api.models import ThreadId, WorkspaceId +from askui.chat.api.models import WorkspaceId from askui.chat.api.runs.events.events import Event -from askui.chat.api.runs.models import ( - Run, - RunCreateParams, - RunListQuery, - ThreadAndRunCreateParams, -) +from askui.chat.api.runs.models import Run, ThreadAndRunCreate from askui.chat.api.runs.service import RunService from askui.chat.api.threads.service import ThreadService -from askui.utils.api_utils import ListQuery, ListResponse class ThreadFacade: @@ -23,55 +15,18 @@ class ThreadFacade: def __init__( self, thread_service: ThreadService, - message_service: MessageService, run_service: RunService, ) -> None: self._thread_service = thread_service - self._message_service = message_service self._run_service = run_service - def _ensure_thread_exists(self, thread_id: ThreadId) -> None: - """Validate that a thread exists before allowing operations on it.""" - self._thread_service.retrieve(thread_id) - - def create_message( - self, thread_id: 
ThreadId, params: MessageCreateParams - ) -> Message: - """Create a message, ensuring the thread exists first.""" - self._ensure_thread_exists(thread_id) - return self._message_service.create(thread_id, params) - - async def create_run( - self, workspace_id: WorkspaceId, thread_id: ThreadId, params: RunCreateParams - ) -> tuple[Run, AsyncGenerator[Event, None]]: - """Create a run, ensuring the thread exists first.""" - self._ensure_thread_exists(thread_id) - return await self._run_service.create( - workspace_id=workspace_id, - thread_id=thread_id, - params=params, - ) - async def create_thread_and_run( - self, workspace_id: WorkspaceId, params: ThreadAndRunCreateParams + self, workspace_id: WorkspaceId, params: ThreadAndRunCreate ) -> tuple[Run, AsyncGenerator[Event, None]]: """Create a thread and a run, ensuring the thread exists first.""" - thread = self._thread_service.create(params.thread) + thread = self._thread_service.create(workspace_id, params.thread) return await self._run_service.create( workspace_id=workspace_id, thread_id=thread.id, params=params, ) - - def list_messages( - self, thread_id: ThreadId, query: ListQuery - ) -> ListResponse[Message]: - """List messages, ensuring the thread exists first.""" - self._ensure_thread_exists(thread_id) - return self._message_service.list_(thread_id, query) - - def list_runs(self, query: RunListQuery) -> ListResponse[Run]: - """List runs, ensuring the thread exists first.""" - if query.thread: - self._ensure_thread_exists(query.thread) - return self._run_service.list_(query) diff --git a/src/askui/chat/api/threads/models.py b/src/askui/chat/api/threads/models.py index 6ee1931f..fea45552 100644 --- a/src/askui/chat/api/threads/models.py +++ b/src/askui/chat/api/threads/models.py @@ -2,9 +2,8 @@ from pydantic import BaseModel -from askui.chat.api.messages.models import MessageCreateParams -from askui.chat.api.models import ThreadId -from askui.utils.api_utils import Resource +from askui.chat.api.messages.models 
import MessageCreate +from askui.chat.api.models import ThreadId, WorkspaceId, WorkspaceResource from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import generate_time_ordered_id from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven @@ -16,19 +15,19 @@ class ThreadBase(BaseModel): name: str | None = None -class ThreadCreateParams(ThreadBase): +class ThreadCreate(ThreadBase): """Parameters for creating a thread.""" - messages: list[MessageCreateParams] | None = None + messages: list[MessageCreate] | None = None -class ThreadModifyParams(BaseModelWithNotGiven): +class ThreadModify(BaseModelWithNotGiven): """Parameters for modifying a thread.""" name: str | None | NotGiven = NOT_GIVEN -class Thread(ThreadBase, Resource): +class Thread(ThreadBase, WorkspaceResource): """A chat thread/session.""" id: ThreadId @@ -36,17 +35,10 @@ class Thread(ThreadBase, Resource): created_at: UnixDatetime @classmethod - def create(cls, params: ThreadCreateParams) -> "Thread": + def create(cls, workspace_id: WorkspaceId, params: ThreadCreate) -> "Thread": return cls( id=generate_time_ordered_id("thread"), created_at=now(), + workspace_id=workspace_id, **params.model_dump(exclude={"messages"}), ) - - def modify(self, params: ThreadModifyParams) -> "Thread": - return Thread.model_validate( - { - **self.model_dump(), - **params.model_dump(), - } - ) diff --git a/src/askui/chat/api/threads/orms.py b/src/askui/chat/api/threads/orms.py new file mode 100644 index 00000000..9e2b3cfc --- /dev/null +++ b/src/askui/chat/api/threads/orms.py @@ -0,0 +1,31 @@ +"""Thread database model.""" + +from datetime import datetime +from uuid import UUID + +from sqlalchemy import String, Uuid +from sqlalchemy.orm import Mapped, mapped_column + +from askui.chat.api.db.orm.base import Base +from askui.chat.api.db.orm.types import UnixDatetime, create_prefixed_id_type +from askui.chat.api.threads.models import Thread + +ThreadId = 
create_prefixed_id_type("thread") + + +class ThreadOrm(Base): + """Thread database model.""" + + __tablename__ = "threads" + + id: Mapped[str] = mapped_column(ThreadId, primary_key=True) + workspace_id: Mapped[UUID] = mapped_column(Uuid, nullable=False, index=True) + created_at: Mapped[datetime] = mapped_column(UnixDatetime, nullable=False) + name: Mapped[str | None] = mapped_column(String, nullable=True) + + @classmethod + def from_model(cls, model: Thread) -> "ThreadOrm": + return cls(**model.model_dump(exclude={"object"})) + + def to_model(self) -> Thread: + return Thread.model_validate(self, from_attributes=True) diff --git a/src/askui/chat/api/threads/router.py b/src/askui/chat/api/threads/router.py index a9e18bf4..5ea65319 100644 --- a/src/askui/chat/api/threads/router.py +++ b/src/askui/chat/api/threads/router.py @@ -1,9 +1,11 @@ -from fastapi import APIRouter, status +from typing import Annotated + +from fastapi import APIRouter, Header, status from askui.chat.api.dependencies import ListQueryDep -from askui.chat.api.models import ThreadId +from askui.chat.api.models import ThreadId, WorkspaceId from askui.chat.api.threads.dependencies import ThreadServiceDep -from askui.chat.api.threads.models import Thread, ThreadCreateParams, ThreadModifyParams +from askui.chat.api.threads.models import Thread, ThreadCreate, ThreadModify from askui.chat.api.threads.service import ThreadService from askui.utils.api_utils import ListQuery, ListResponse @@ -12,40 +14,47 @@ @router.get("") def list_threads( + askui_workspace: Annotated[WorkspaceId, Header()], query: ListQuery = ListQueryDep, thread_service: ThreadService = ThreadServiceDep, ) -> ListResponse[Thread]: - return thread_service.list_(query=query) + return thread_service.list_(workspace_id=askui_workspace, query=query) @router.post("", status_code=status.HTTP_201_CREATED) def create_thread( - params: ThreadCreateParams, + askui_workspace: Annotated[WorkspaceId, Header()], + params: ThreadCreate, thread_service: 
ThreadService = ThreadServiceDep, ) -> Thread: - return thread_service.create(params) + return thread_service.create(workspace_id=askui_workspace, params=params) @router.get("/{thread_id}") def retrieve_thread( + askui_workspace: Annotated[WorkspaceId, Header()], thread_id: ThreadId, thread_service: ThreadService = ThreadServiceDep, ) -> Thread: - return thread_service.retrieve(thread_id) + return thread_service.retrieve(workspace_id=askui_workspace, thread_id=thread_id) @router.post("/{thread_id}") def modify_thread( + askui_workspace: Annotated[WorkspaceId, Header()], thread_id: ThreadId, - params: ThreadModifyParams, + params: ThreadModify, thread_service: ThreadService = ThreadServiceDep, ) -> Thread: - return thread_service.modify(thread_id, params) + return thread_service.modify( + workspace_id=askui_workspace, thread_id=thread_id, params=params + ) @router.delete("/{thread_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_thread( + askui_workspace: Annotated[WorkspaceId, Header()], thread_id: ThreadId, thread_service: ThreadService = ThreadServiceDep, ) -> None: - thread_service.delete(thread_id) + thread_service.delete(workspace_id=askui_workspace, thread_id=thread_id) diff --git a/src/askui/chat/api/threads/service.py b/src/askui/chat/api/threads/service.py index ca54f203..413bea1f 100644 --- a/src/askui/chat/api/threads/service.py +++ b/src/askui/chat/api/threads/service.py @@ -1,82 +1,75 @@ -import shutil -from pathlib import Path +from sqlalchemy.orm import Session -from askui.chat.api.messages.service import MessageService -from askui.chat.api.models import ThreadId -from askui.chat.api.runs.service import RunService -from askui.chat.api.threads.models import Thread, ThreadCreateParams, ThreadModifyParams -from askui.utils.api_utils import ( - ConflictError, - ListQuery, - ListResponse, - NotFoundError, - list_resources, -) +from askui.chat.api.db.queries import list_all +from askui.chat.api.models import ThreadId, WorkspaceId +from 
askui.chat.api.threads.models import Thread, ThreadCreate, ThreadModify +from askui.chat.api.threads.orms import ThreadOrm +from askui.utils.api_utils import ListQuery, ListResponse, NotFoundError class ThreadService: - """Service for managing Thread resources with filesystem persistence.""" + """Service for managing Thread resources with database persistence.""" - def __init__( - self, base_dir: Path, message_service: MessageService, run_service: RunService - ) -> None: - self._base_dir = base_dir - self._threads_dir = base_dir / "threads" - self._message_service = message_service - self._run_service = run_service + def __init__(self, session: Session) -> None: + self._session = session - def _get_thread_path(self, thread_id: ThreadId, new: bool = False) -> Path: - thread_path = self._threads_dir / f"{thread_id}.json" - exists = thread_path.exists() - if new and exists: - error_msg = f"Thread {thread_id} already exists" - raise ConflictError(error_msg) - if not new and not exists: + def _find_by_id(self, workspace_id: WorkspaceId, thread_id: ThreadId) -> ThreadOrm: + """Find thread by ID.""" + thread_orm: ThreadOrm | None = ( + self._session.query(ThreadOrm) + .filter( + ThreadOrm.id == thread_id, + ThreadOrm.workspace_id == workspace_id, + ) + .first() + ) + if thread_orm is None: error_msg = f"Thread {thread_id} not found" raise NotFoundError(error_msg) - return thread_path - - def list_(self, query: ListQuery) -> ListResponse[Thread]: - return list_resources(self._threads_dir, query, Thread) + return thread_orm - def retrieve(self, thread_id: ThreadId) -> Thread: - try: - thread_path = self._get_thread_path(thread_id) - return Thread.model_validate_json(thread_path.read_text()) - except FileNotFoundError as e: - error_msg = f"Thread {thread_id} not found" - raise NotFoundError(error_msg) from e + def list_( + self, workspace_id: WorkspaceId, query: ListQuery + ) -> ListResponse[Thread]: + """List threads with pagination and filtering.""" + q = 
self._session.query(ThreadOrm).filter( + ThreadOrm.workspace_id == workspace_id + ) + orms: list[ThreadOrm] + orms, has_more = list_all(q, query, ThreadOrm.id) + data = [orm.to_model() for orm in orms] + return ListResponse( + data=data, + has_more=has_more, + first_id=data[0].id if data else None, + last_id=data[-1].id if data else None, + ) - def create(self, params: ThreadCreateParams) -> Thread: - thread = Thread.create(params) - self._save(thread, new=True) + def retrieve(self, workspace_id: WorkspaceId, thread_id: ThreadId) -> Thread: + """Retrieve thread by ID.""" + thread_orm = self._find_by_id(workspace_id, thread_id) + return thread_orm.to_model() - if params.messages: - for message in params.messages: - self._message_service.create( - thread_id=thread.id, - params=message, - ) + def create(self, workspace_id: WorkspaceId, params: ThreadCreate) -> Thread: + """Create a new thread.""" + thread = Thread.create(workspace_id, params) + thread_orm = ThreadOrm.from_model(thread) + self._session.add(thread_orm) + self._session.commit() return thread - def modify(self, thread_id: ThreadId, params: ThreadModifyParams) -> Thread: - thread = self.retrieve(thread_id) - modified = thread.modify(params) - self._save(modified) - return modified - - def delete(self, thread_id: ThreadId) -> None: - try: - shutil.rmtree( - self._message_service.get_messages_dir(thread_id), ignore_errors=True - ) - shutil.rmtree(self._run_service.get_runs_dir(thread_id), ignore_errors=True) - self._get_thread_path(thread_id).unlink() - except FileNotFoundError as e: - error_msg = f"Thread {thread_id} not found" - raise NotFoundError(error_msg) from e + def modify( + self, workspace_id: WorkspaceId, thread_id: ThreadId, params: ThreadModify + ) -> Thread: + """Modify an existing thread.""" + thread_orm = self._find_by_id(workspace_id, thread_id) + thread_orm.update(params.model_dump()) + self._session.commit() + self._session.refresh(thread_orm) + return thread_orm.to_model() - def 
_save(self, thread: Thread, new: bool = False) -> None: - self._threads_dir.mkdir(parents=True, exist_ok=True) - thread_file = self._get_thread_path(thread.id, new=new) - thread_file.write_text(thread.model_dump_json(), encoding="utf-8") + def delete(self, workspace_id: WorkspaceId, thread_id: ThreadId) -> None: + """Delete a thread and cascade to messages and runs.""" + thread_orm = self._find_by_id(workspace_id, thread_id) + self._session.delete(thread_orm) + self._session.commit() diff --git a/src/askui/chat/migrations/env.py b/src/askui/chat/migrations/env.py index 5b83ccc7..197b0513 100644 --- a/src/askui/chat/migrations/env.py +++ b/src/askui/chat/migrations/env.py @@ -6,6 +6,11 @@ # We need to import the orms to ensure they are registered import askui.chat.api.assistants.orms +import askui.chat.api.files.orms +import askui.chat.api.mcp_configs.orms +import askui.chat.api.messages.orms +import askui.chat.api.runs.orms +import askui.chat.api.threads.orms from askui.chat.api.db.orm.base import Base from askui.chat.api.dependencies import get_settings from askui.chat.api.telemetry.logs import setup_logging diff --git a/src/askui/chat/migrations/shared/assistants/models.py b/src/askui/chat/migrations/shared/assistants/models.py index c02d896f..693e1d1d 100644 --- a/src/askui/chat/migrations/shared/assistants/models.py +++ b/src/askui/chat/migrations/shared/assistants/models.py @@ -23,7 +23,7 @@ class AssistantV1(BaseModel): def to_db_dict(self) -> dict[str, Any]: return { - **self.model_dump(exclude={"id", "object", "workspace_id"}), + **self.model_dump(exclude={"id", "object"}), "id": self.id.removeprefix("asst_"), - "workspace_id": str(self.workspace_id) if self.workspace_id else None, + "workspace_id": self.workspace_id.hex if self.workspace_id else None, } diff --git a/src/askui/chat/migrations/shared/files/models.py b/src/askui/chat/migrations/shared/files/models.py index 2c1d9026..c57595d8 100644 --- a/src/askui/chat/migrations/shared/files/models.py +++ 
b/src/askui/chat/migrations/shared/files/models.py @@ -21,7 +21,7 @@ class FileV1(BaseModel): def to_db_dict(self) -> dict[str, Any]: return { - **self.model_dump(exclude={"id", "object", "workspace_id"}), + **self.model_dump(exclude={"id", "object"}), "id": self.id.removeprefix("file_"), - "workspace_id": str(self.workspace_id) if self.workspace_id else None, + "workspace_id": self.workspace_id.hex if self.workspace_id else None, } diff --git a/src/askui/chat/migrations/shared/mcp_configs/models.py b/src/askui/chat/migrations/shared/mcp_configs/models.py index 4c960efb..0f5be421 100644 --- a/src/askui/chat/migrations/shared/mcp_configs/models.py +++ b/src/askui/chat/migrations/shared/mcp_configs/models.py @@ -35,7 +35,7 @@ class McpConfigV1(BaseModel): def to_db_dict(self) -> dict[str, Any]: return { - **self.model_dump(exclude={"id", "object", "workspace_id"}), + **self.model_dump(exclude={"id", "object"}), "id": self.id.removeprefix("mcpcnf_"), - "workspace_id": str(self.workspace_id) if self.workspace_id else None, + "workspace_id": self.workspace_id.hex if self.workspace_id else None, } diff --git a/src/askui/chat/migrations/shared/messages/__init__.py b/src/askui/chat/migrations/shared/messages/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/migrations/shared/messages/models.py b/src/askui/chat/migrations/shared/messages/models.py new file mode 100644 index 00000000..8a900bd6 --- /dev/null +++ b/src/askui/chat/migrations/shared/messages/models.py @@ -0,0 +1,166 @@ +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, BeforeValidator, Field + +from askui.chat.migrations.shared.assistants.models import AssistantIdV1 +from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 +from askui.chat.migrations.shared.runs.models import RunIdV1 +from askui.chat.migrations.shared.threads.models import ThreadIdV1 +from askui.chat.migrations.shared.utils import build_prefixer + +MessageIdV1 = 
Annotated[ + str, Field(pattern=r"^msg_[a-z0-9]+$"), BeforeValidator(build_prefixer("msg")) +] + + +class CacheControlEphemeralParamV1(BaseModel): + type: Literal["ephemeral"] = "ephemeral" + + +class CitationCharLocationParamV1(BaseModel): + cited_text: str + document_index: int + document_title: str | None = None + end_char_index: int + start_char_index: int + type: Literal["char_location"] = "char_location" + + +class CitationPageLocationParamV1(BaseModel): + cited_text: str + document_index: int + document_title: str | None = None + end_page_number: int + start_page_number: int + type: Literal["page_location"] = "page_location" + + +class CitationContentBlockLocationParamV1(BaseModel): + cited_text: str + document_index: int + document_title: str | None = None + end_block_index: int + start_block_index: int + type: Literal["content_block_location"] = "content_block_location" + + +TextCitationParamV1 = ( + CitationCharLocationParamV1 + | CitationPageLocationParamV1 + | CitationContentBlockLocationParamV1 +) + + +class UrlImageSourceParamV1(BaseModel): + type: Literal["url"] = "url" + url: str + + +class Base64ImageSourceParamV1(BaseModel): + data: str + media_type: Literal["image/jpeg", "image/png", "image/gif", "image/webp"] + type: Literal["base64"] = "base64" + + +class FileImageSourceParamV1(BaseModel): + """Image source that references a saved file.""" + + id: str # FileId equivalent + type: Literal["file"] = "file" + + +class ImageBlockParamV1(BaseModel): + source: Base64ImageSourceParamV1 | UrlImageSourceParamV1 | FileImageSourceParamV1 + type: Literal["image"] = "image" + cache_control: CacheControlEphemeralParamV1 | None = None + + +class TextBlockParamV1(BaseModel): + text: str + type: Literal["text"] = "text" + cache_control: CacheControlEphemeralParamV1 | None = None + citations: list[TextCitationParamV1] | None = None + + +class ToolResultBlockParamV1(BaseModel): + tool_use_id: str + type: Literal["tool_result"] = "tool_result" + cache_control: 
CacheControlEphemeralParamV1 | None = None + content: str | list[TextBlockParamV1 | ImageBlockParamV1] + is_error: bool = False + + +class ToolUseBlockParamV1(BaseModel): + id: str + input: object + name: str + type: Literal["tool_use"] = "tool_use" + cache_control: CacheControlEphemeralParamV1 | None = None + + +class BetaThinkingBlockV1(BaseModel): + signature: str + thinking: str + type: Literal["thinking"] + + +class BetaRedactedThinkingBlockV1(BaseModel): + data: str + type: Literal["redacted_thinking"] + + +class BetaFileDocumentSourceParamV1(BaseModel): + file_id: str + type: Literal["file"] = "file" + + +SourceV1 = BetaFileDocumentSourceParamV1 + + +class RequestDocumentBlockParamV1(BaseModel): + source: SourceV1 + type: Literal["document"] = "document" + cache_control: CacheControlEphemeralParamV1 | None = None + + +ContentBlockParamV1 = ( + ImageBlockParamV1 + | TextBlockParamV1 + | ToolResultBlockParamV1 + | ToolUseBlockParamV1 + | BetaThinkingBlockV1 + | BetaRedactedThinkingBlockV1 + | RequestDocumentBlockParamV1 +) + + +StopReasonV1 = Literal[ + "end_turn", "max_tokens", "stop_sequence", "tool_use", "pause_turn", "refusal" +] + + +class MessageV1(BaseModel): + id: MessageIdV1 + object: Literal["thread.message"] = "thread.message" + created_at: UnixDatetimeV1 + thread_id: ThreadIdV1 + role: Literal["user", "assistant"] + content: str | list[ContentBlockParamV1] + stop_reason: StopReasonV1 | None = None + assistant_id: AssistantIdV1 | None = None + run_id: RunIdV1 | None = None + workspace_id: WorkspaceIdV1 = Field(exclude=True) + + def to_db_dict(self) -> dict[str, Any]: + return { + **self.model_dump( + exclude={"id", "thread_id", "assistant_id", "run_id", "object"} + ), + "id": self.id.removeprefix("msg_"), + "thread_id": self.thread_id.removeprefix("thread_"), + "assistant_id": self.assistant_id.removeprefix("asst_") + if self.assistant_id + else None, + "run_id": self.run_id.removeprefix("run_") if self.run_id else None, + "workspace_id": 
self.workspace_id.hex, + } diff --git a/src/askui/chat/migrations/shared/runs/__init__.py b/src/askui/chat/migrations/shared/runs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/migrations/shared/runs/models.py b/src/askui/chat/migrations/shared/runs/models.py new file mode 100644 index 00000000..201f832f --- /dev/null +++ b/src/askui/chat/migrations/shared/runs/models.py @@ -0,0 +1,74 @@ +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, BeforeValidator, Field, computed_field + +from askui.chat.migrations.shared.assistants.models import AssistantIdV1 +from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 +from askui.chat.migrations.shared.threads.models import ThreadIdV1 +from askui.chat.migrations.shared.utils import build_prefixer, now_v1 + +RunStatusV1 = Literal[ + "queued", + "in_progress", + "completed", + "cancelling", + "cancelled", + "failed", + "expired", +] + + +class RunErrorV1(BaseModel): + """Error information for a failed run.""" + + message: str + code: Literal["server_error", "rate_limit_exceeded", "invalid_prompt"] + + +RunIdV1 = Annotated[ + str, Field(pattern=r"^run_[a-z0-9]+$"), BeforeValidator(build_prefixer("run")) +] + + +class RunV1(BaseModel): + id: RunIdV1 + object: Literal["thread.run"] = "thread.run" + thread_id: ThreadIdV1 + created_at: UnixDatetimeV1 + expires_at: UnixDatetimeV1 + started_at: UnixDatetimeV1 | None = None + completed_at: UnixDatetimeV1 | None = None + failed_at: UnixDatetimeV1 | None = None + cancelled_at: UnixDatetimeV1 | None = None + tried_cancelling_at: UnixDatetimeV1 | None = None + last_error: RunErrorV1 | None = None + assistant_id: AssistantIdV1 | None = None + workspace_id: WorkspaceIdV1 = Field(exclude=True) + + def to_db_dict(self) -> dict[str, Any]: + return { + **self.model_dump(exclude={"id", "thread_id", "assistant_id", "object"}), + "id": self.id.removeprefix("run_"), + "thread_id": 
self.thread_id.removeprefix("thread_"), + "assistant_id": self.assistant_id.removeprefix("asst_") + if self.assistant_id + else None, + "workspace_id": self.workspace_id.hex, + } + + @computed_field # type: ignore[prop-decorator] + @property + def status(self) -> RunStatusV1: + if self.cancelled_at: + return "cancelled" + if self.failed_at: + return "failed" + if self.completed_at: + return "completed" + if self.expires_at and self.expires_at < now_v1(): + return "expired" + if self.tried_cancelling_at: + return "cancelling" + if self.started_at: + return "in_progress" + return "queued" diff --git a/src/askui/chat/migrations/shared/threads/__init__.py b/src/askui/chat/migrations/shared/threads/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/migrations/shared/threads/models.py b/src/askui/chat/migrations/shared/threads/models.py new file mode 100644 index 00000000..5c90c549 --- /dev/null +++ b/src/askui/chat/migrations/shared/threads/models.py @@ -0,0 +1,25 @@ +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, BeforeValidator, Field + +from askui.chat.migrations.shared.models import UnixDatetimeV1, WorkspaceIdV1 +from askui.chat.migrations.shared.utils import build_prefixer + +ThreadIdV1 = Annotated[ + str, Field(pattern=r"^thread_[a-z0-9]+$"), BeforeValidator(build_prefixer("thread")) +] + + +class ThreadV1(BaseModel): + id: ThreadIdV1 + object: Literal["thread"] = "thread" + created_at: UnixDatetimeV1 + name: str | None = None + workspace_id: WorkspaceIdV1 = Field(exclude=True) + + def to_db_dict(self) -> dict[str, Any]: + return { + **self.model_dump(exclude={"id", "object"}), + "id": self.id.removeprefix("thread_"), + "workspace_id": self.workspace_id.hex, + } diff --git a/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py b/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py index 947767a6..49e73660 100644 --- 
a/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py +++ b/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py @@ -11,7 +11,7 @@ from typing import Sequence, Union from alembic import op -from sqlalchemy import MetaData, Table +from sqlalchemy import Connection, MetaData, Table from askui.chat.migrations.shared.assistants.models import AssistantV1 from askui.chat.migrations.shared.settings import SettingsV1 @@ -25,15 +25,18 @@ logger = logging.getLogger(__name__) -BATCH_SIZE = 100 +BATCH_SIZE = 1000 def _insert_assistants_batch( - assistants_table: Table, assistants_batch: list[AssistantV1] + connection: Connection, assistants_table: Table, assistants_batch: list[AssistantV1] ) -> None: - """Insert a batch of assistants into the database.""" - op.bulk_insert( - assistants_table, + """Insert a batch of assistants into the database, ignoring conflicts.""" + if not assistants_batch: + return + + connection.execute( + assistants_table.insert().prefix_with("OR REPLACE"), [assistant.to_db_dict() for assistant in assistants_batch], ) @@ -67,7 +70,7 @@ def upgrade() -> None: assistants_batch.append(assistant) if len(assistants_batch) >= BATCH_SIZE: - _insert_assistants_batch(assistants_table, assistants_batch) + _insert_assistants_batch(connection, assistants_table, assistants_batch) assistants_batch.clear() except Exception: # noqa: PERF203 error_msg = "Failed to import" @@ -76,7 +79,7 @@ def upgrade() -> None: # Insert remaining assistants in the final batch if assistants_batch: - _insert_assistants_batch(assistants_table, assistants_batch) + _insert_assistants_batch(connection, assistants_table, assistants_batch) def downgrade() -> None: @@ -102,7 +105,7 @@ def downgrade() -> None: if json_path.exists(): continue with json_path.open("w", encoding="utf-8") as f: - f.write(json.dumps(assistant.model_dump())) + f.write(assistant.model_dump_json()) except Exception as e: # noqa: PERF203 error_msg = f"Failed to export row to json: 
{e}" logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) diff --git a/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py b/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py new file mode 100644 index 00000000..7bb7c9b8 --- /dev/null +++ b/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py @@ -0,0 +1,36 @@ +"""create_threads_table + +Revision ID: 1a2b3c4d5e6f +Revises: a0f1a2b3c4d5 +Create Date: 2025-01-27 12:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "1a2b3c4d5e6f" +down_revision: Union[str, None] = "a0f1a2b3c4d5" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "threads", + sa.Column("id", sa.String(24), nullable=False, primary_key=True), + sa.Column("workspace_id", sa.Uuid(), nullable=False, index=True), + sa.Column("created_at", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("threads") + # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py b/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py new file mode 100644 index 00000000..68eb1edc --- /dev/null +++ b/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py @@ -0,0 +1,60 @@ +"""create_messages_table + +Revision ID: 2b3c4d5e6f7a +Revises: 1a2b3c4d5e6f +Create Date: 2025-01-27 12:01:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "2b3c4d5e6f7a" +down_revision: Union[str, None] = "1a2b3c4d5e6f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "messages", + sa.Column("id", sa.String(24), nullable=False, primary_key=True), + sa.Column( + "thread_id", + sa.String(24), + sa.ForeignKey( + "threads.id", ondelete="CASCADE", name="fk_messages_thread_id" + ), + nullable=False, + ), + sa.Column("workspace_id", sa.Uuid(), nullable=False, index=True), + sa.Column("created_at", sa.Integer(), nullable=False), + sa.Column("role", sa.String(), nullable=False), + sa.Column("content", sa.JSON(), nullable=False), + sa.Column("stop_reason", sa.String(), nullable=True), + sa.Column( + "assistant_id", + sa.String(24), + sa.ForeignKey( + "assistants.id", ondelete="SET NULL", name="fk_messages_assistant_id" + ), + nullable=True, + ), + sa.Column( + "run_id", + sa.String(24), + sa.ForeignKey("runs.id", ondelete="SET NULL", name="fk_messages_run_id"), + nullable=True, + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("messages") + # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/37007a499ca7_remove_assistants_dir.py b/src/askui/chat/migrations/versions/37007a499ca7_remove_assistants_dir.py deleted file mode 100644 index 18f004dd..00000000 --- a/src/askui/chat/migrations/versions/37007a499ca7_remove_assistants_dir.py +++ /dev/null @@ -1,51 +0,0 @@ -"""remove_assistants_dir - -Revision ID: 37007a499ca7 -Revises: c35e88ea9595 -Create Date: 2025-10-10 14:01:53.410908 - -""" - -import logging -import shutil -from typing import Sequence, Union - -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. 
-revision: str = "37007a499ca7" -down_revision: Union[str, None] = "c35e88ea9595" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - -settings = SettingsV1() -assistants_dir = settings.data_dir / "assistants" - - -def upgrade() -> None: - """Remove the assistants directory and all its contents.""" - - # Skip if directory doesn't exist - if not assistants_dir.exists(): - logger.info("Assistants directory does not exist, skipping removal") - return - - try: - shutil.rmtree(assistants_dir) - logger.info( - "Successfully removed assistants directory", - extra={"assistants_dir": str(assistants_dir)}, - ) - except Exception as e: - error_msg = "Failed to remove assistants directory" - logger.exception( - error_msg, - extra={"assistants_dir": str(assistants_dir)}, - ) - raise RuntimeError(error_msg) from e - - -def downgrade() -> None: - assistants_dir.mkdir(parents=True, exist_ok=True) diff --git a/src/askui/chat/migrations/versions/37007a499ca7_soft_delete_assistants_dir.py b/src/askui/chat/migrations/versions/37007a499ca7_soft_delete_assistants_dir.py new file mode 100644 index 00000000..3a6435ff --- /dev/null +++ b/src/askui/chat/migrations/versions/37007a499ca7_soft_delete_assistants_dir.py @@ -0,0 +1,87 @@ +"""soft_delete_assistants_dir + +Revision ID: 37007a499ca7 +Revises: c35e88ea9595 +Create Date: 2025-10-10 14:01:53.410908 + +""" + +import logging +from pathlib import Path +from typing import Sequence, Union + +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "37007a499ca7" +down_revision: Union[str, None] = "c35e88ea9595" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + +settings = SettingsV1() +assistants_dir = settings.data_dir / "assistants" + + +def upgrade() -> None: + """Soft delete the assistants directory by moving it to .deleted subdirectory.""" + + # Skip if directory doesn't exist + if not assistants_dir.exists(): + logger.info("Assistants directory does not exist, skipping soft delete") + return + + try: + # Create .deleted directory if it doesn't exist + deleted_dir = settings.data_dir / ".deleted" + deleted_dir.mkdir(parents=True, exist_ok=True) + + # Move assistants directory to .deleted subdirectory + deleted_assistants_dir = deleted_dir / "assistants" + if deleted_assistants_dir.exists(): + logger.info( + "Deleted assistants directory already exists, skipping soft delete", + extra={"deleted_assistants_dir": str(deleted_assistants_dir)}, + ) + return + + assistants_dir.rename(deleted_assistants_dir) + logger.info( + "Successfully soft deleted assistants directory", + extra={ + "assistants_dir": str(assistants_dir), + "deleted_assistants_dir": str(deleted_assistants_dir), + }, + ) + except Exception as e: + error_msg = "Failed to soft delete assistants directory" + logger.exception( + error_msg, + extra={"assistants_dir": str(assistants_dir)}, + ) + raise RuntimeError(error_msg) from e + + +def downgrade() -> None: + """Restore the assistants directory from .deleted subdirectory.""" + deleted_dir = settings.data_dir / ".deleted" + deleted_assistants_dir = deleted_dir / "assistants" + + if not deleted_assistants_dir.exists(): + logger.info("No deleted assistants directory found to restore") + return + + try: + deleted_assistants_dir.rename(assistants_dir) + logger.info( + "Successfully restored assistants directory", + extra={"assistants_dir": str(assistants_dir)}, + ) + except Exception as 
e: + error_msg = "Failed to restore assistants directory" + logger.exception( + error_msg, + extra={"assistants_dir": str(assistants_dir)}, + ) + raise RuntimeError(error_msg) from e diff --git a/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py b/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py new file mode 100644 index 00000000..72eab784 --- /dev/null +++ b/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py @@ -0,0 +1,72 @@ +"""create_runs_table + +Revision ID: 3c4d5e6f7a8b +Revises: 2b3c4d5e6f7a +Create Date: 2025-01-27 12:02:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "3c4d5e6f7a8b" +down_revision: Union[str, None] = "2b3c4d5e6f7a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "runs", + sa.Column("id", sa.String(24), nullable=False, primary_key=True), + sa.Column( + "thread_id", + sa.String(24), + sa.ForeignKey("threads.id", ondelete="CASCADE", name="fk_runs_thread_id"), + nullable=False, + index=True, + ), + sa.Column("workspace_id", sa.Uuid(), nullable=False, index=True), + sa.Column("created_at", sa.Integer(), nullable=False), + sa.Column("expires_at", sa.Integer(), nullable=False), + sa.Column("started_at", sa.Integer(), nullable=True), + sa.Column("completed_at", sa.Integer(), nullable=True), + sa.Column("failed_at", sa.Integer(), nullable=True), + sa.Column("cancelled_at", sa.Integer(), nullable=True), + sa.Column("tried_cancelling_at", sa.Integer(), nullable=True), + sa.Column("last_error", sa.JSON(), nullable=True), + sa.Column( + "assistant_id", + sa.String(24), + sa.ForeignKey( + "assistants.id", ondelete="SET NULL", name="fk_runs_assistant_id" + ), + nullable=True, + ), + sa.Column( + "status", + sa.Enum( + "queued", + "in_progress", + "completed", + "cancelled", + "failed", + "expired", + "cancelling", + ), + nullable=False, + index=True, + ), + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("runs") + # ### end Alembic commands ### diff --git a/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py b/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py new file mode 100644 index 00000000..b1b4a9fb --- /dev/null +++ b/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py @@ -0,0 +1,120 @@ +"""import_json_threads + +Revision ID: 4d5e6f7a8b9c +Revises: 3c4d5e6f7a8b +Create Date: 2025-01-27 12:03:00.000000 + +""" + +import json +import logging +from typing import Sequence, Union + +from alembic import op +from sqlalchemy import Connection, MetaData, Table + +from askui.chat.migrations.shared.settings import SettingsV1 +from askui.chat.migrations.shared.threads.models import ThreadV1 + +# revision identifiers, used by Alembic. +revision: str = "4d5e6f7a8b9c" +down_revision: Union[str, None] = "3c4d5e6f7a8b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + + +BATCH_SIZE = 1000 + + +def _insert_threads_batch( + connection: Connection, threads_table: Table, threads_batch: list[ThreadV1] +) -> None: + """Insert a batch of threads into the database, ignoring conflicts.""" + if not threads_batch: + return + + connection.execute( + threads_table.insert().prefix_with("OR REPLACE"), + [thread.to_db_dict() for thread in threads_batch], + ) + + +settings = SettingsV1() +workspaces_dir = settings.data_dir / "workspaces" + + +def upgrade() -> None: # noqa: C901 + """Import existing threads from JSON files in workspace directories.""" + + # Skip if workspaces directory doesn't exist (e.g., first-time setup) + if not workspaces_dir.exists(): + return + + # Get the table from the current database schema + connection = op.get_bind() + threads_table = Table("threads", MetaData(), autoload_with=connection) + + # Process threads in batches + threads_batch: list[ThreadV1] = [] + + # Iterate through all 
workspace directories + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + workspace_id = workspace_dir.name + threads_dir = workspace_dir / "threads" + + if not threads_dir.exists(): + continue + + # Get all JSON files in the threads directory + json_files = list(threads_dir.glob("*.json")) + + for json_file in json_files: + try: + content = json_file.read_text(encoding="utf-8").strip() + data = json.loads(content) + thread = ThreadV1.model_validate({**data, "workspace_id": workspace_id}) + threads_batch.append(thread) + if len(threads_batch) >= BATCH_SIZE: + _insert_threads_batch(connection, threads_table, threads_batch) + threads_batch.clear() + except Exception: # noqa: PERF203 + error_msg = "Failed to import thread" + logger.exception(error_msg, extra={"json_file": str(json_file)}) + continue + + # Insert remaining threads in the final batch + if threads_batch: + _insert_threads_batch(connection, threads_table, threads_batch) + + +def downgrade() -> None: + """Recreate JSON files for threads during downgrade.""" + + connection = op.get_bind() + threads_table = Table("threads", MetaData(), autoload_with=connection) + + # Fetch all threads from the database + result = connection.execute(threads_table.select()) + rows = result.fetchall() + if not rows: + return + + for row in rows: + try: + thread_model: ThreadV1 = ThreadV1.model_validate(row, from_attributes=True) + threads_dir = workspaces_dir / str(thread_model.workspace_id) / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + json_path = threads_dir / f"{thread_model.id}.json" + if json_path.exists(): + continue + with json_path.open("w", encoding="utf-8") as f: + f.write(thread_model.model_dump_json()) + except Exception as e: # noqa: PERF203 + error_msg = f"Failed to export row to json: {e}" + logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) + continue diff --git a/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py 
b/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py new file mode 100644 index 00000000..d5602ae4 --- /dev/null +++ b/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py @@ -0,0 +1,266 @@ +"""import_json_messages + +Revision ID: 5e6f7a8b9c0d +Revises: 4d5e6f7a8b9c +Create Date: 2025-01-27 12:04:00.000000 + +""" + +import json +import logging +from typing import Sequence, Union + +from alembic import op +from sqlalchemy import Connection, MetaData, Table, text + +from askui.chat.migrations.shared.messages.models import MessageV1 +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. +revision: str = "5e6f7a8b9c0d" +down_revision: Union[str, None] = "4d5e6f7a8b9c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + + +BATCH_SIZE = 1000 + + +def _insert_messages_batch( + connection: Connection, messages_table: Table, messages_batch: list[MessageV1] +) -> None: + """Insert a batch of messages into the database, handling foreign key violations.""" + if not messages_batch: + return + + # Validate and fix foreign key references + valid_messages = _validate_and_fix_foreign_keys(connection, messages_batch) + + if valid_messages: + connection.execute( + messages_table.insert().prefix_with("OR REPLACE"), + [message.to_db_dict() for message in valid_messages], + ) + + +def _validate_and_fix_foreign_keys( # noqa: C901 + connection: Connection, messages_batch: list[MessageV1] +) -> list[MessageV1]: + """ + Validate foreign key references and fix invalid ones. 
+ + - If thread_id is invalid: ignore the message completely + - If assistant_id is invalid: set to None + - If run_id is invalid: set to None + """ + if not messages_batch: + return [] + + # Extract all foreign key values + thread_ids = {msg.thread_id.removeprefix("thread_") for msg in messages_batch} + assistant_ids = { + msg.assistant_id.removeprefix("asst_") + for msg in messages_batch + if msg.assistant_id + } + run_ids = {msg.run_id.removeprefix("run_") for msg in messages_batch if msg.run_id} + + # Check which foreign keys exist in the database + valid_thread_ids: set[str] = set() + if thread_ids: + # Create placeholders for SQLite IN clause + placeholders = ",".join([":id" + str(i) for i in range(len(thread_ids))]) + params = {f"id{i}": thread_id for i, thread_id in enumerate(thread_ids)} + result = connection.execute( + text(f"SELECT id FROM threads WHERE id IN ({placeholders})"), params + ) + valid_thread_ids = {row[0] for row in result} + + valid_assistant_ids: set[str] = set() + if assistant_ids: + # Create placeholders for SQLite IN clause + placeholders = ",".join([":id" + str(i) for i in range(len(assistant_ids))]) + params = { + f"id{i}": assistant_id for i, assistant_id in enumerate(assistant_ids) + } + result = connection.execute( + text(f"SELECT id FROM assistants WHERE id IN ({placeholders})"), params + ) + valid_assistant_ids = {row[0] for row in result} + + valid_run_ids: set[str] = set() + if run_ids: + # Create placeholders for SQLite IN clause + placeholders = ",".join([":id" + str(i) for i in range(len(run_ids))]) + params = {f"id{i}": run_id for i, run_id in enumerate(run_ids)} + result = connection.execute( + text(f"SELECT id FROM runs WHERE id IN ({placeholders})"), params + ) + valid_run_ids = {row[0] for row in result} + + # Process each message + valid_messages: list[MessageV1] = [] + for message in messages_batch: + thread_id = message.thread_id.removeprefix("thread_") + assistant_id = ( + message.assistant_id.removeprefix("asst_") 
if message.assistant_id else None + ) + run_id = message.run_id.removeprefix("run_") if message.run_id else None + + # If thread_id is invalid, ignore the message completely + if thread_id not in valid_thread_ids: + logger.warning( + "Ignoring message with invalid thread_id (thread does not exist)", + extra={ + "message_id": message.id, + "thread_id": thread_id, + "workspace_id": str(message.workspace_id), + }, + ) + continue + + # Check and fix assistant_id and run_id + fixed_assistant_id = None + fixed_run_id = None + changes_made: list[str] = [] + + if assistant_id is not None and assistant_id not in valid_assistant_ids: + fixed_assistant_id = None + changes_made.append(f"assistant_id set to None (was: {assistant_id})") + elif assistant_id is not None: + fixed_assistant_id = assistant_id + + if run_id is not None and run_id not in valid_run_ids: + fixed_run_id = None + changes_made.append(f"run_id set to None (was: {run_id})") + elif run_id is not None: + fixed_run_id = run_id + + # Create a copy of the message with fixed foreign keys + if changes_made: + logger.info( + "Fixed foreign key references for message", + extra={ + "message_id": message.id, + "thread_id": thread_id, + "changes": changes_made, + }, + ) + + # Create new message with fixed foreign keys + fixed_message = MessageV1( + id=message.id, + object=message.object, + created_at=message.created_at, + thread_id=message.thread_id, + role=message.role, + content=message.content, + stop_reason=message.stop_reason, + assistant_id=f"asst_{fixed_assistant_id}" + if fixed_assistant_id + else None, + run_id=f"run_{fixed_run_id}" if fixed_run_id else None, + workspace_id=message.workspace_id, + ) + valid_messages.append(fixed_message) + else: + # No changes needed, use original message + valid_messages.append(message) + + return valid_messages + + +settings = SettingsV1() +workspaces_dir = settings.data_dir / "workspaces" + + +def upgrade() -> None: # noqa: C901 + """Import existing messages from JSON files 
in workspace directories.""" + # Skip if workspaces directory doesn't exist (e.g., first-time setup) + if not workspaces_dir.exists(): + return + + # Get the table from the current database schema + connection = op.get_bind() + messages_table = Table("messages", MetaData(), autoload_with=connection) + + # Process messages in batches + messages_batch: list[MessageV1] = [] + + # Iterate through all workspace directories + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + workspace_id = workspace_dir.name + messages_dir = workspace_dir / "messages" + + if not messages_dir.exists(): + continue + + # Iterate through thread directories + for thread_dir in messages_dir.iterdir(): + if not thread_dir.is_dir(): + continue + + # Get all JSON files in the thread directory + json_files = list(thread_dir.glob("*.json")) + + for json_file in json_files: + try: + content = json_file.read_text(encoding="utf-8").strip() + data = json.loads(content) + message = MessageV1.model_validate( + {**data, "workspace_id": workspace_id} + ) + messages_batch.append(message) + if len(messages_batch) >= BATCH_SIZE: + _insert_messages_batch( + connection, messages_table, messages_batch + ) + messages_batch.clear() + except Exception: # noqa: PERF203 + error_msg = "Failed to import message" + logger.exception(error_msg, extra={"json_file": str(json_file)}) + continue + + # Insert remaining messages in the final batch + if messages_batch: + _insert_messages_batch(connection, messages_table, messages_batch) + + +def downgrade() -> None: + """Recreate JSON files for messages during downgrade.""" + + connection = op.get_bind() + messages_table = Table("messages", MetaData(), autoload_with=connection) + + # Fetch all messages from the database + result = connection.execute(messages_table.select()) + rows = result.fetchall() + if not rows: + return + + for row in rows: + try: + message_model: MessageV1 = MessageV1.model_validate( + row, from_attributes=True + 
) + messages_dir = ( + workspaces_dir + / str(message_model.workspace_id) + / "messages" + / message_model.thread_id + ) + messages_dir.mkdir(parents=True, exist_ok=True) + json_path = messages_dir / f"{message_model.id}.json" + if json_path.exists(): + continue + with json_path.open("w", encoding="utf-8") as f: + f.write(message_model.model_dump_json()) + except Exception as e: # noqa: PERF203 + error_msg = f"Failed to export row to json: {e}" + logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) + continue diff --git a/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py b/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py index 6db8f789..69dd2b6b 100644 --- a/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py +++ b/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py @@ -11,7 +11,7 @@ from typing import Sequence, Union from alembic import op -from sqlalchemy import MetaData, Table +from sqlalchemy import Connection, MetaData, Table from askui.chat.migrations.shared.mcp_configs.models import McpConfigV1 from askui.chat.migrations.shared.settings import SettingsV1 @@ -25,15 +25,20 @@ logger = logging.getLogger(__name__) -BATCH_SIZE = 100 +BATCH_SIZE = 1000 def _insert_mcp_configs_batch( - mcp_configs_table: Table, mcp_configs_batch: list[McpConfigV1] + connection: Connection, + mcp_configs_table: Table, + mcp_configs_batch: list[McpConfigV1], ) -> None: - """Insert a batch of MCP configs into the database.""" - op.bulk_insert( - mcp_configs_table, + """Insert a batch of MCP configs into the database, ignoring conflicts.""" + if not mcp_configs_batch: + return + + connection.execute( + mcp_configs_table.insert().prefix_with("OR REPLACE"), [mcp_config.to_db_dict() for mcp_config in mcp_configs_batch], ) @@ -66,7 +71,9 @@ def upgrade() -> None: mcp_config = McpConfigV1.model_validate(data) mcp_configs_batch.append(mcp_config) if len(mcp_configs_batch) >= 
BATCH_SIZE: - _insert_mcp_configs_batch(mcp_configs_table, mcp_configs_batch) + _insert_mcp_configs_batch( + connection, mcp_configs_table, mcp_configs_batch + ) mcp_configs_batch.clear() except Exception: # noqa: PERF203 error_msg = "Failed to import" @@ -75,7 +82,7 @@ def upgrade() -> None: # Insert remaining MCP configs in the final batch if mcp_configs_batch: - _insert_mcp_configs_batch(mcp_configs_table, mcp_configs_batch) + _insert_mcp_configs_batch(connection, mcp_configs_table, mcp_configs_batch) def downgrade() -> None: @@ -101,7 +108,7 @@ def downgrade() -> None: if json_path.exists(): continue with json_path.open("w", encoding="utf-8") as f: - f.write(json.dumps(mcp_config.model_dump())) + f.write(mcp_config.model_dump_json()) except Exception as e: # noqa: PERF203 error_msg = f"Failed to export row to json: {e}" logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) diff --git a/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py b/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py new file mode 100644 index 00000000..f0873662 --- /dev/null +++ b/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py @@ -0,0 +1,242 @@ +"""import_json_runs + +Revision ID: 6f7a8b9c0d1e +Revises: 5e6f7a8b9c0d +Create Date: 2025-01-27 12:05:00.000000 + +""" + +import json +import logging +from typing import Sequence, Union + +from alembic import op +from sqlalchemy import Connection, MetaData, Table, text + +from askui.chat.migrations.shared.runs.models import RunV1 +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "6f7a8b9c0d1e" +down_revision: Union[str, None] = "5e6f7a8b9c0d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + + +BATCH_SIZE = 1000 + + +def _insert_runs_batch( + connection: Connection, runs_table: Table, runs_batch: list[RunV1] +) -> None: + """Insert a batch of runs into the database, handling foreign key violations.""" + if not runs_batch: + return + + # Validate and fix foreign key references + valid_runs = _validate_and_fix_foreign_keys(connection, runs_batch) + + if valid_runs: + connection.execute( + runs_table.insert().prefix_with("OR REPLACE"), + [run.to_db_dict() for run in valid_runs], + ) + + +def _validate_and_fix_foreign_keys( # noqa: C901 + connection: Connection, runs_batch: list[RunV1] +) -> list[RunV1]: + """ + Validate foreign key references and fix invalid ones. + + - If thread_id is invalid: ignore the run completely + - If assistant_id is invalid: set to None + """ + if not runs_batch: + return [] + + # Extract all foreign key values + thread_ids = {run.thread_id.removeprefix("thread_") for run in runs_batch} + assistant_ids = { + run.assistant_id.removeprefix("asst_") for run in runs_batch if run.assistant_id + } + + # Check which foreign keys exist in the database + valid_thread_ids: set[str] = set() + if thread_ids: + # Create placeholders for SQLite IN clause + placeholders = ",".join([":id" + str(i) for i in range(len(thread_ids))]) + params = {f"id{i}": thread_id for i, thread_id in enumerate(thread_ids)} + result = connection.execute( + text(f"SELECT id FROM threads WHERE id IN ({placeholders})"), params + ) + valid_thread_ids = {row[0] for row in result} + + valid_assistant_ids: set[str] = set() + if assistant_ids: + # Create placeholders for SQLite IN clause + placeholders = ",".join([":id" + str(i) for i in range(len(assistant_ids))]) + params = { + f"id{i}": assistant_id for i, assistant_id in 
enumerate(assistant_ids) + } + result = connection.execute( + text(f"SELECT id FROM assistants WHERE id IN ({placeholders})"), params + ) + valid_assistant_ids = {row[0] for row in result} + + # Process each run + valid_runs: list[RunV1] = [] + for run in runs_batch: + thread_id = run.thread_id.removeprefix("thread_") + assistant_id = ( + run.assistant_id.removeprefix("asst_") if run.assistant_id else None + ) + + # If thread_id is invalid, ignore the run completely + if thread_id not in valid_thread_ids: + logger.warning( + "Ignoring run with invalid thread_id (thread does not exist)", + extra={ + "run_id": run.id, + "thread_id": thread_id, + "workspace_id": str(run.workspace_id), + }, + ) + continue + + # Check and fix assistant_id + fixed_assistant_id = None + changes_made: list[str] = [] + + if assistant_id is not None and assistant_id not in valid_assistant_ids: + fixed_assistant_id = None + changes_made.append(f"assistant_id set to None (was: {assistant_id})") + elif assistant_id is not None: + fixed_assistant_id = assistant_id + + # Create a copy of the run with fixed foreign keys + if changes_made: + logger.info( + "Fixed foreign key references for run", + extra={ + "run_id": run.id, + "thread_id": thread_id, + "changes": changes_made, + }, + ) + + # Create new run with fixed foreign keys + fixed_run = RunV1( + id=run.id, + object=run.object, + thread_id=run.thread_id, + created_at=run.created_at, + expires_at=run.expires_at, + started_at=run.started_at, + completed_at=run.completed_at, + failed_at=run.failed_at, + cancelled_at=run.cancelled_at, + tried_cancelling_at=run.tried_cancelling_at, + last_error=run.last_error, + assistant_id=f"asst_{fixed_assistant_id}" + if fixed_assistant_id + else None, + workspace_id=run.workspace_id, + ) + valid_runs.append(fixed_run) + else: + # No changes needed, use original run + valid_runs.append(run) + + return valid_runs + + +settings = SettingsV1() +workspaces_dir = settings.data_dir / "workspaces" + + +def upgrade() 
-> None: # noqa: C901 + """Import existing runs from JSON files in workspace directories.""" + + # Skip if workspaces directory doesn't exist (e.g., first-time setup) + if not workspaces_dir.exists(): + return + + # Get the table from the current database schema + connection = op.get_bind() + runs_table = Table("runs", MetaData(), autoload_with=connection) + + # Process runs in batches + runs_batch: list[RunV1] = [] + + # Iterate through all workspace directories + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + workspace_id = workspace_dir.name + runs_dir = workspace_dir / "runs" + + if not runs_dir.exists(): + continue + + # Iterate through thread directories + for thread_dir in runs_dir.iterdir(): + if not thread_dir.is_dir(): + continue + + # Get all JSON files in the thread directory + json_files = list(thread_dir.glob("*.json")) + + for json_file in json_files: + try: + content = json_file.read_text(encoding="utf-8").strip() + data = json.loads(content) + run = RunV1.model_validate({**data, "workspace_id": workspace_id}) + runs_batch.append(run) + if len(runs_batch) >= BATCH_SIZE: + _insert_runs_batch(connection, runs_table, runs_batch) + runs_batch.clear() + except Exception: # noqa: PERF203 + error_msg = "Failed to import run" + logger.exception(error_msg, extra={"json_file": str(json_file)}) + continue + + # Insert remaining runs in the final batch + if runs_batch: + _insert_runs_batch(connection, runs_table, runs_batch) + + +def downgrade() -> None: + """Recreate JSON files for runs during downgrade.""" + + connection = op.get_bind() + runs_table = Table("runs", MetaData(), autoload_with=connection) + + # Fetch all runs from the database + result = connection.execute(runs_table.select()) + rows = result.fetchall() + if not rows: + return + + for row in rows: + try: + run_model: RunV1 = RunV1.model_validate(row, from_attributes=True) + runs_dir = ( + workspaces_dir + / str(run_model.workspace_id) + / "runs" + 
/ run_model.thread_id + ) + runs_dir.mkdir(parents=True, exist_ok=True) + json_path = runs_dir / f"{run_model.id}.json" + if json_path.exists(): + continue + with json_path.open("w", encoding="utf-8") as f: + f.write(run_model.model_dump_json()) + except Exception as e: # noqa: PERF203 + error_msg = f"Failed to export row to json: {e}" + logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) + continue diff --git a/src/askui/chat/migrations/versions/7a8b9c0d1e2f_soft_delete_threads_dirs.py b/src/askui/chat/migrations/versions/7a8b9c0d1e2f_soft_delete_threads_dirs.py new file mode 100644 index 00000000..8c351eb8 --- /dev/null +++ b/src/askui/chat/migrations/versions/7a8b9c0d1e2f_soft_delete_threads_dirs.py @@ -0,0 +1,97 @@ +"""soft_delete_threads_dirs + +Revision ID: 7a8b9c0d1e2f +Revises: 6f7a8b9c0d1e +Create Date: 2025-01-27 12:06:00.000000 + +""" + +import logging +from typing import Sequence, Union + +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "7a8b9c0d1e2f" +down_revision: Union[str, None] = "6f7a8b9c0d1e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + +settings = SettingsV1() +workspaces_dir = settings.data_dir / "workspaces" + + +def upgrade() -> None: + """Soft delete threads directories by moving them to .deleted subdirectory.""" + + # Skip if workspaces directory doesn't exist (e.g., first-time setup) + if not workspaces_dir.exists(): + return + + # Create .deleted directory if it doesn't exist + deleted_dir = settings.data_dir / ".deleted" + deleted_dir.mkdir(parents=True, exist_ok=True) + + # Soft delete threads directories from all workspaces + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + threads_dir = workspace_dir / "threads" + if threads_dir.exists(): + try: + # Create workspace-specific deleted directory + deleted_workspace_dir = deleted_dir / workspace_dir.name + deleted_workspace_dir.mkdir(parents=True, exist_ok=True) + + # Move threads directory to .deleted subdirectory + deleted_threads_dir = deleted_workspace_dir / "threads" + if deleted_threads_dir.exists(): + logger.info( + "Deleted threads directory already exists, skipping soft delete", + extra={"deleted_threads_dir": str(deleted_threads_dir)}, + ) + continue + + threads_dir.rename(deleted_threads_dir) + logger.info( + "Soft deleted threads directory", + extra={ + "threads_dir": str(threads_dir), + "deleted_threads_dir": str(deleted_threads_dir), + }, + ) + except Exception as e: + error_msg = f"Failed to soft delete threads directory: {threads_dir}" + logger.exception(error_msg, exc_info=e) + + +def downgrade() -> None: + """Restore threads directories from .deleted subdirectory.""" + deleted_dir = settings.data_dir / ".deleted" + + if not deleted_dir.exists(): + logger.info("No .deleted directory found to restore from") + return + + # Restore threads directories 
for all workspaces + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + deleted_workspace_dir = deleted_dir / workspace_dir.name + deleted_threads_dir = deleted_workspace_dir / "threads" + + if deleted_threads_dir.exists(): + threads_dir = workspace_dir / "threads" + try: + deleted_threads_dir.rename(threads_dir) + logger.info( + "Restored threads directory", + extra={"threads_dir": str(threads_dir)}, + ) + except Exception as e: + error_msg = f"Failed to restore threads directory: {threads_dir}" + logger.exception(error_msg, exc_info=e) diff --git a/src/askui/chat/migrations/versions/7c3d4e5f6a7b_remove_mcp_configs_dir.py b/src/askui/chat/migrations/versions/7c3d4e5f6a7b_remove_mcp_configs_dir.py deleted file mode 100644 index 17419cee..00000000 --- a/src/askui/chat/migrations/versions/7c3d4e5f6a7b_remove_mcp_configs_dir.py +++ /dev/null @@ -1,51 +0,0 @@ -"""remove_mcp_configs_dir - -Revision ID: 7c3d4e5f6a7b -Revises: 6b2c3d4e5f6a -Create Date: 2025-01-27 10:02:00.000000 - -""" - -import logging -import shutil -from typing import Sequence, Union - -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. 
-revision: str = "7c3d4e5f6a7b" -down_revision: Union[str, None] = "6b2c3d4e5f6a" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - -settings = SettingsV1() -mcp_configs_dir = settings.data_dir / "mcp_configs" - - -def upgrade() -> None: - """Remove the mcp_configs directory and all its contents.""" - - # Skip if directory doesn't exist - if not mcp_configs_dir.exists(): - logger.info("MCP configs directory does not exist, skipping removal") - return - - try: - shutil.rmtree(mcp_configs_dir) - logger.info( - "Successfully removed mcp_configs directory", - extra={"mcp_configs_dir": str(mcp_configs_dir)}, - ) - except Exception as e: - error_msg = "Failed to remove mcp_configs directory" - logger.exception( - error_msg, - extra={"mcp_configs_dir": str(mcp_configs_dir)}, - ) - raise RuntimeError(error_msg) from e - - -def downgrade() -> None: - mcp_configs_dir.mkdir(parents=True, exist_ok=True) diff --git a/src/askui/chat/migrations/versions/7c3d4e5f6a7b_soft_delete_mcp_configs_dir.py b/src/askui/chat/migrations/versions/7c3d4e5f6a7b_soft_delete_mcp_configs_dir.py new file mode 100644 index 00000000..c683e87d --- /dev/null +++ b/src/askui/chat/migrations/versions/7c3d4e5f6a7b_soft_delete_mcp_configs_dir.py @@ -0,0 +1,86 @@ +"""soft_delete_mcp_configs_dir + +Revision ID: 7c3d4e5f6a7b +Revises: 6b2c3d4e5f6a +Create Date: 2025-01-27 10:02:00.000000 + +""" + +import logging +from typing import Sequence, Union + +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. 
+revision: str = "7c3d4e5f6a7b" +down_revision: Union[str, None] = "6b2c3d4e5f6a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + +settings = SettingsV1() +mcp_configs_dir = settings.data_dir / "mcp_configs" + + +def upgrade() -> None: + """Soft delete the mcp_configs directory by moving it to .deleted subdirectory.""" + + # Skip if directory doesn't exist + if not mcp_configs_dir.exists(): + logger.info("MCP configs directory does not exist, skipping soft delete") + return + + try: + # Create .deleted directory if it doesn't exist + deleted_dir = settings.data_dir / ".deleted" + deleted_dir.mkdir(parents=True, exist_ok=True) + + # Move mcp_configs directory to .deleted subdirectory + deleted_mcp_configs_dir = deleted_dir / "mcp_configs" + if deleted_mcp_configs_dir.exists(): + logger.info( + "Deleted mcp_configs directory already exists, skipping soft delete", + extra={"deleted_mcp_configs_dir": str(deleted_mcp_configs_dir)}, + ) + return + + mcp_configs_dir.rename(deleted_mcp_configs_dir) + logger.info( + "Successfully soft deleted mcp_configs directory", + extra={ + "mcp_configs_dir": str(mcp_configs_dir), + "deleted_mcp_configs_dir": str(deleted_mcp_configs_dir), + }, + ) + except Exception as e: + error_msg = "Failed to soft delete mcp_configs directory" + logger.exception( + error_msg, + extra={"mcp_configs_dir": str(mcp_configs_dir)}, + ) + raise RuntimeError(error_msg) from e + + +def downgrade() -> None: + """Restore the mcp_configs directory from .deleted subdirectory.""" + deleted_dir = settings.data_dir / ".deleted" + deleted_mcp_configs_dir = deleted_dir / "mcp_configs" + + if not deleted_mcp_configs_dir.exists(): + logger.info("No deleted mcp_configs directory found to restore") + return + + try: + deleted_mcp_configs_dir.rename(mcp_configs_dir) + logger.info( + "Successfully restored mcp_configs directory", + extra={"mcp_configs_dir": 
str(mcp_configs_dir)}, + ) + except Exception as e: + error_msg = "Failed to restore mcp_configs directory" + logger.exception( + error_msg, + extra={"mcp_configs_dir": str(mcp_configs_dir)}, + ) + raise RuntimeError(error_msg) from e diff --git a/src/askui/chat/migrations/versions/8b9c0d1e2f3a_soft_delete_messages_dirs.py b/src/askui/chat/migrations/versions/8b9c0d1e2f3a_soft_delete_messages_dirs.py new file mode 100644 index 00000000..2ce14b78 --- /dev/null +++ b/src/askui/chat/migrations/versions/8b9c0d1e2f3a_soft_delete_messages_dirs.py @@ -0,0 +1,97 @@ +"""soft_delete_messages_dirs + +Revision ID: 8b9c0d1e2f3a +Revises: 7a8b9c0d1e2f +Create Date: 2025-01-27 12:07:00.000000 + +""" + +import logging +from typing import Sequence, Union + +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. +revision: str = "8b9c0d1e2f3a" +down_revision: Union[str, None] = "7a8b9c0d1e2f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + +settings = SettingsV1() +workspaces_dir = settings.data_dir / "workspaces" + + +def upgrade() -> None: + """Soft delete messages directories by moving them to .deleted subdirectory.""" + + # Skip if workspaces directory doesn't exist (e.g., first-time setup) + if not workspaces_dir.exists(): + return + + # Create .deleted directory if it doesn't exist + deleted_dir = settings.data_dir / ".deleted" + deleted_dir.mkdir(parents=True, exist_ok=True) + + # Soft delete messages directories from all workspaces + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + messages_dir = workspace_dir / "messages" + if messages_dir.exists(): + try: + # Create workspace-specific deleted directory + deleted_workspace_dir = deleted_dir / workspace_dir.name + deleted_workspace_dir.mkdir(parents=True, exist_ok=True) + + # Move messages directory to .deleted subdirectory + 
deleted_messages_dir = deleted_workspace_dir / "messages" + if deleted_messages_dir.exists(): + logger.info( + "Deleted messages directory already exists, skipping soft delete", + extra={"deleted_messages_dir": str(deleted_messages_dir)}, + ) + continue + + messages_dir.rename(deleted_messages_dir) + logger.info( + "Soft deleted messages directory", + extra={ + "messages_dir": str(messages_dir), + "deleted_messages_dir": str(deleted_messages_dir), + }, + ) + except Exception as e: + error_msg = f"Failed to soft delete messages directory: {messages_dir}" + logger.exception(error_msg, exc_info=e) + + +def downgrade() -> None: + """Restore messages directories from .deleted subdirectory.""" + deleted_dir = settings.data_dir / ".deleted" + + if not deleted_dir.exists(): + logger.info("No .deleted directory found to restore from") + return + + # Restore messages directories for all workspaces + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + deleted_workspace_dir = deleted_dir / workspace_dir.name + deleted_messages_dir = deleted_workspace_dir / "messages" + + if deleted_messages_dir.exists(): + try: + messages_dir = workspace_dir / "messages" + deleted_messages_dir.rename(messages_dir) + logger.info( + "Restored messages directory", + extra={"messages_dir": str(messages_dir)}, + ) + except Exception as e: + error_msg = f"Failed to restore messages directory: {messages_dir}" + logger.exception(error_msg, exc_info=e) diff --git a/src/askui/chat/migrations/versions/9c0d1e2f3a4b_soft_delete_runs_dirs.py b/src/askui/chat/migrations/versions/9c0d1e2f3a4b_soft_delete_runs_dirs.py new file mode 100644 index 00000000..3fde9bea --- /dev/null +++ b/src/askui/chat/migrations/versions/9c0d1e2f3a4b_soft_delete_runs_dirs.py @@ -0,0 +1,97 @@ +"""soft_delete_runs_dirs + +Revision ID: 9c0d1e2f3a4b +Revises: 8b9c0d1e2f3a +Create Date: 2025-01-27 12:08:00.000000 + +""" + +import logging +from typing import Sequence, Union + +from 
askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. +revision: str = "9c0d1e2f3a4b" +down_revision: Union[str, None] = "8b9c0d1e2f3a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + +settings = SettingsV1() +workspaces_dir = settings.data_dir / "workspaces" + + +def upgrade() -> None: + """Soft delete runs directories by moving them to .deleted subdirectory.""" + + # Skip if workspaces directory doesn't exist (e.g., first-time setup) + if not workspaces_dir.exists(): + return + + # Create .deleted directory if it doesn't exist + deleted_dir = settings.data_dir / ".deleted" + deleted_dir.mkdir(parents=True, exist_ok=True) + + # Soft delete runs directories from all workspaces + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + runs_dir = workspace_dir / "runs" + if runs_dir.exists(): + try: + # Create workspace-specific deleted directory + deleted_workspace_dir = deleted_dir / workspace_dir.name + deleted_workspace_dir.mkdir(parents=True, exist_ok=True) + + # Move runs directory to .deleted subdirectory + deleted_runs_dir = deleted_workspace_dir / "runs" + if deleted_runs_dir.exists(): + logger.info( + "Deleted runs directory already exists, skipping soft delete", + extra={"deleted_runs_dir": str(deleted_runs_dir)}, + ) + continue + + runs_dir.rename(deleted_runs_dir) + logger.info( + "Soft deleted runs directory", + extra={ + "runs_dir": str(runs_dir), + "deleted_runs_dir": str(deleted_runs_dir), + }, + ) + except Exception as e: + error_msg = f"Failed to soft delete runs directory: {runs_dir}" + logger.exception(error_msg, exc_info=e) + + +def downgrade() -> None: + """Restore runs directories from .deleted subdirectory.""" + deleted_dir = settings.data_dir / ".deleted" + + if not deleted_dir.exists(): + logger.info("No .deleted directory found to restore from") + return + 
+ # Restore runs directories for all workspaces + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + deleted_workspace_dir = deleted_dir / workspace_dir.name + deleted_runs_dir = deleted_workspace_dir / "runs" + + if deleted_runs_dir.exists(): + try: + runs_dir = workspace_dir / "runs" + deleted_runs_dir.rename(runs_dir) + logger.info( + "Restored runs directory", + extra={"runs_dir": str(runs_dir)}, + ) + except Exception as e: + error_msg = f"Failed to restore runs directory: {runs_dir}" + logger.exception(error_msg, exc_info=e) diff --git a/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py b/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py index 8a2efcc1..5776ebb5 100644 --- a/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py +++ b/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py @@ -8,11 +8,10 @@ import json import logging -import mimetypes from typing import Sequence, Union from alembic import op -from sqlalchemy import MetaData, Table +from sqlalchemy import Connection, MetaData, Table from askui.chat.migrations.shared.files.models import FileV1 from askui.chat.migrations.shared.settings import SettingsV1 @@ -26,13 +25,18 @@ logger = logging.getLogger(__name__) -BATCH_SIZE = 100 +BATCH_SIZE = 1000 -def _insert_files_batch(files_table: Table, files_batch: list[FileV1]) -> None: - """Insert a batch of files into the database.""" - op.bulk_insert( - files_table, +def _insert_files_batch( + connection: Connection, files_table: Table, files_batch: list[FileV1] +) -> None: + """Insert a batch of files into the database, ignoring conflicts.""" + if not files_batch: + return + + connection.execute( + files_table.insert().prefix_with("OR REPLACE"), [file.to_db_dict() for file in files_batch], ) @@ -76,7 +80,7 @@ def upgrade() -> None: # noqa: C901 file = FileV1.model_validate({**data, "workspace_id": workspace_id}) files_batch.append(file) if 
len(files_batch) >= BATCH_SIZE: - _insert_files_batch(files_table, files_batch) + _insert_files_batch(connection, files_table, files_batch) files_batch.clear() except Exception: # noqa: PERF203 error_msg = "Failed to import file" @@ -85,7 +89,7 @@ def upgrade() -> None: # noqa: C901 # Insert remaining files in the final batch if files_batch: - _insert_files_batch(files_table, files_batch) + _insert_files_batch(connection, files_table, files_batch) def downgrade() -> None: @@ -112,7 +116,7 @@ def downgrade() -> None: if json_path.exists(): continue with json_path.open("w", encoding="utf-8") as f: - f.write(json.dumps(file_model.model_dump())) + f.write(file_model.model_dump_json()) except Exception as e: # noqa: PERF203 error_msg = f"Failed to export row to json: {e}" logger.exception(error_msg, extra={"row": str(row)}, exc_info=e) diff --git a/src/askui/chat/migrations/versions/a0f1a2b3c4d5_remove_files_dirs.py b/src/askui/chat/migrations/versions/a0f1a2b3c4d5_remove_files_dirs.py deleted file mode 100644 index 7ac2f4f4..00000000 --- a/src/askui/chat/migrations/versions/a0f1a2b3c4d5_remove_files_dirs.py +++ /dev/null @@ -1,68 +0,0 @@ -"""remove_files_dirs - -Revision ID: a0f1a2b3c4d5 -Revises: 9e0f1a2b3c4d -Create Date: 2025-01-27 11:02:00.000000 - -""" - -import logging -import shutil -from typing import Sequence, Union - -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. 
-revision: str = "a0f1a2b3c4d5" -down_revision: Union[str, None] = "9e0f1a2b3c4d" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - -settings = SettingsV1() -workspaces_dir = settings.data_dir / "workspaces" - - -def upgrade() -> None: - """Remove JSON files from workspace static directories after successful migration.""" - - # Skip if workspaces directory doesn't exist - if not workspaces_dir.exists(): - logger.info( - "Workspaces directory does not exist, skipping removal", - extra={"workspaces_dir": str(workspaces_dir)}, - ) - return - - # Iterate through all workspace directories - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - continue - - files_dir = workspace_dir / "files" - if not files_dir.exists(): - logger.info( - "Files directory does not exist, skipping removal", - extra={"files_dir": str(files_dir)}, - ) - continue - - try: - shutil.rmtree(files_dir) - logger.info( - "Successfully removed files directory", - extra={"files_dir": str(files_dir)}, - ) - except Exception as e: # noqa: PERF203 - error_msg = "Failed to remove files directory" - logger.exception(error_msg, extra={"files_dir": str(files_dir)}) - raise RuntimeError(error_msg) from e - - -def downgrade() -> None: - """Recreate JSON files in workspace static directories during downgrade.""" - - # This is handled by the import_json_files migration downgrade - # No need to recreate files here as they will be recreated when downgrading - # the import_json_files migration diff --git a/src/askui/chat/migrations/versions/a0f1a2b3c4d5_soft_delete_files_dirs.py b/src/askui/chat/migrations/versions/a0f1a2b3c4d5_soft_delete_files_dirs.py new file mode 100644 index 00000000..75b249ea --- /dev/null +++ b/src/askui/chat/migrations/versions/a0f1a2b3c4d5_soft_delete_files_dirs.py @@ -0,0 +1,108 @@ +"""soft_delete_files_dirs + +Revision ID: a0f1a2b3c4d5 +Revises: 9e0f1a2b3c4d 
+Create Date: 2025-01-27 11:02:00.000000 + +""" + +import logging +from typing import Sequence, Union + +from askui.chat.migrations.shared.settings import SettingsV1 + +# revision identifiers, used by Alembic. +revision: str = "a0f1a2b3c4d5" +down_revision: Union[str, None] = "9e0f1a2b3c4d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +logger = logging.getLogger(__name__) + +settings = SettingsV1() +workspaces_dir = settings.data_dir / "workspaces" + + +def upgrade() -> None: + """Soft delete JSON files from workspace static directories by moving them to .deleted subdirectory.""" + + # Skip if workspaces directory doesn't exist + if not workspaces_dir.exists(): + logger.info( + "Workspaces directory does not exist, skipping soft delete", + extra={"workspaces_dir": str(workspaces_dir)}, + ) + return + + # Create .deleted directory if it doesn't exist + deleted_dir = settings.data_dir / ".deleted" + deleted_dir.mkdir(parents=True, exist_ok=True) + + # Iterate through all workspace directories + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + files_dir = workspace_dir / "files" + if not files_dir.exists(): + logger.info( + "Files directory does not exist, skipping soft delete", + extra={"files_dir": str(files_dir)}, + ) + continue + + try: + # Create workspace-specific deleted directory + deleted_workspace_dir = deleted_dir / workspace_dir.name + deleted_workspace_dir.mkdir(parents=True, exist_ok=True) + + # Move files directory to .deleted subdirectory + deleted_files_dir = deleted_workspace_dir / "files" + if deleted_files_dir.exists(): + logger.info( + "Deleted files directory already exists, skipping soft delete", + extra={"deleted_files_dir": str(deleted_files_dir)}, + ) + continue + + files_dir.rename(deleted_files_dir) + logger.info( + "Successfully soft deleted files directory", + extra={ + "files_dir": str(files_dir), + "deleted_files_dir": 
str(deleted_files_dir), + }, + ) + except Exception as e: # noqa: PERF203 + error_msg = "Failed to soft delete files directory" + logger.exception(error_msg, extra={"files_dir": str(files_dir)}) + raise RuntimeError(error_msg) from e + + +def downgrade() -> None: + """Restore JSON files in workspace static directories from .deleted subdirectory.""" + deleted_dir = settings.data_dir / ".deleted" + + if not deleted_dir.exists(): + logger.info("No .deleted directory found to restore from") + return + + # Restore files directories for all workspaces + for workspace_dir in workspaces_dir.iterdir(): + if not workspace_dir.is_dir(): + continue + + deleted_workspace_dir = deleted_dir / workspace_dir.name + deleted_files_dir = deleted_workspace_dir / "files" + + if deleted_files_dir.exists(): + try: + files_dir = workspace_dir / "files" + deleted_files_dir.rename(files_dir) + logger.info( + "Successfully restored files directory", + extra={"files_dir": str(files_dir)}, + ) + except Exception as e: + error_msg = f"Failed to restore files directory: {files_dir}" + logger.exception(error_msg, exc_info=e) diff --git a/src/askui/utils/datetime_utils.py b/src/askui/utils/datetime_utils.py index b34585a7..d3fd9d2e 100644 --- a/src/askui/utils/datetime_utils.py +++ b/src/askui/utils/datetime_utils.py @@ -8,6 +8,7 @@ PlainSerializer( lambda v: int(v.timestamp()), return_type=int, + when_used="json-unless-none", ), ] From df21b3b404efa9ab72008e5a98642009b3cd2d35 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Sun, 2 Nov 2025 23:11:37 +0100 Subject: [PATCH 10/14] chore(pdm): update pdm.lock --- pdm.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pdm.lock b/pdm.lock index d22f1460..3d2425d1 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,10 +5,10 @@ groups = ["default", "all", "android", "bedrock", "chat", "dev", "pynput", "vertex", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = 
"sha256:3b5c70118ce8b743db5aaf3f2c06765e3d77f7e1b8da13d081a731b24e4d98e7" +content_hash = "sha256:b0b4a3234caf6b5516e2ca5ea826496fd0abf8b08ee81cd92771f1b40e476b5e" [[metadata.targets]] -requires_python = ">=3.10" +requires_python = ">=3.10,<=3.13" [[package]] name = "aiofiles" From c51cb941a14527c6452ca92a026835808191255f Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Mon, 3 Nov 2025 00:34:44 +0100 Subject: [PATCH 11/14] fix(chat): fix issues created by migration --- src/askui/chat/api/dependencies.py | 10 + src/askui/chat/api/messages/dependencies.py | 5 +- src/askui/chat/api/messages/service.py | 14 + src/askui/chat/api/messages/translator.py | 73 +++- tests/integration/chat/api/conftest.py | 12 +- tests/integration/chat/api/test_files.py | 127 ++++-- .../chat/api/test_files_edge_cases.py | 43 +- .../chat/api/test_files_service.py | 181 +++------ .../integration/chat/api/test_mcp_configs.py | 283 ++++++------- tests/integration/chat/api/test_messages.py | 178 +++++--- .../api/test_request_document_translator.py | 9 +- tests/integration/chat/api/test_runs.py | 379 +++++++++++++----- tests/integration/chat/api/test_threads.py | 265 ++++++------ .../tools/testing/test_execution_tools.py | 14 +- .../tools/testing/test_feature_tools.py | 14 +- .../tools/testing/test_scenario_tools.py | 14 +- .../unit/test_request_document_translator.py | 4 +- 17 files changed, 944 insertions(+), 681 deletions(-) diff --git a/src/askui/chat/api/dependencies.py b/src/askui/chat/api/dependencies.py index 7264f598..ca42ece1 100644 --- a/src/askui/chat/api/dependencies.py +++ b/src/askui/chat/api/dependencies.py @@ -60,6 +60,16 @@ def set_env_from_headers( SetEnvFromHeadersDep = Depends(set_env_from_headers) +def get_workspace_id( + askui_workspace: Annotated[WorkspaceId | None, Header()] = None, +) -> WorkspaceId | None: + """Get workspace ID from header.""" + return askui_workspace + + +WorkspaceIdDep = Depends(get_workspace_id) + + def get_workspace_dir( askui_workspace: 
Annotated[WorkspaceId, Header()], settings: Settings = SettingsDep, diff --git a/src/askui/chat/api/messages/dependencies.py b/src/askui/chat/api/messages/dependencies.py index 965bfb6e..cfea1d91 100644 --- a/src/askui/chat/api/messages/dependencies.py +++ b/src/askui/chat/api/messages/dependencies.py @@ -1,11 +1,13 @@ from fastapi import Depends from askui.chat.api.db.session import SessionDep +from askui.chat.api.dependencies import WorkspaceIdDep from askui.chat.api.files.dependencies import FileServiceDep from askui.chat.api.files.service import FileService from askui.chat.api.messages.chat_history_manager import ChatHistoryManager from askui.chat.api.messages.service import MessageService from askui.chat.api.messages.translator import MessageTranslator +from askui.chat.api.models import WorkspaceId from askui.models.shared.truncation_strategies import ( SimpleTruncationStrategyFactory, TruncationStrategyFactory, @@ -24,8 +26,9 @@ def get_message_service( def get_message_translator( file_service: FileService = FileServiceDep, + workspace_id: WorkspaceId | None = WorkspaceIdDep, ) -> MessageTranslator: - return MessageTranslator(file_service) + return MessageTranslator(file_service, workspace_id) MessageTranslatorDep = Depends(get_message_translator) diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index a33783bb..b1c08d54 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -6,6 +6,7 @@ from askui.chat.api.messages.models import Message, MessageCreate from askui.chat.api.messages.orms import MessageOrm from askui.chat.api.models import MessageId, ThreadId, WorkspaceId +from askui.chat.api.threads.orms import ThreadOrm from askui.utils.api_utils import ( LIST_LIMIT_DEFAULT, ListOrder, @@ -46,6 +47,19 @@ def create( params: MessageCreate, ) -> Message: """Create a new message.""" + # Validate thread exists + thread_orm: ThreadOrm | None = ( + self._session.query(ThreadOrm) 
+ .filter( + ThreadOrm.id == thread_id, + ThreadOrm.workspace_id == workspace_id, + ) + .first() + ) + if thread_orm is None: + error_msg = f"Thread {thread_id} not found" + raise NotFoundError(error_msg) + message = Message.create(workspace_id, thread_id, params) message_orm = MessageOrm.from_model(message) self._session.add(message_orm) diff --git a/src/askui/chat/api/messages/translator.py b/src/askui/chat/api/messages/translator.py index 04f448ba..25de3788 100644 --- a/src/askui/chat/api/messages/translator.py +++ b/src/askui/chat/api/messages/translator.py @@ -9,6 +9,7 @@ RequestDocumentBlockParam, ToolResultBlockParam, ) +from askui.chat.api.models import WorkspaceId from askui.data_extractor import DataExtractor from askui.models.models import ModelName from askui.models.shared.agent_message_param import ( @@ -36,8 +37,11 @@ class RequestDocumentBlockParamTranslator: """Translator for RequestDocumentBlockParam to/from Anthropic format.""" - def __init__(self, file_service: FileService) -> None: + def __init__( + self, file_service: FileService, workspace_id: WorkspaceId | None + ) -> None: self._file_service = file_service + self._workspace_id = workspace_id self._data_extractor = DataExtractor() def extract_content( @@ -84,7 +88,9 @@ def extract_content( async def to_anthropic( self, block: RequestDocumentBlockParam ) -> list[AnthropicContentBlockParam]: - file, path = self._file_service.retrieve_file_content(block.source.file_id) + file, path = self._file_service.retrieve_file_content( + self._workspace_id, block.source.file_id + ) source = load_source(path) content = self.extract_content(source, block) return [ @@ -97,8 +103,11 @@ async def to_anthropic( class ImageBlockParamSourceTranslator: - def __init__(self, file_service: FileService) -> None: + def __init__( + self, file_service: FileService, workspace_id: WorkspaceId | None + ) -> None: self._file_service = file_service + self._workspace_id = workspace_id async def from_anthropic( # noqa: RET503 
self, source: UrlImageSourceParam | Base64ImageSourceParam @@ -138,7 +147,9 @@ async def to_anthropic( # noqa: RET503 if source.type == "base64": return source if source.type == "file": # noqa: RET503 - file, path = self._file_service.retrieve_file_content(source.id) + file, path = self._file_service.retrieve_file_content( + self._workspace_id, source.id + ) image = Image.open(path) return Base64ImageSourceParam( data=image_to_base64(image), @@ -147,8 +158,12 @@ async def to_anthropic( # noqa: RET503 class ImageBlockParamTranslator: - def __init__(self, file_service: FileService) -> None: - self.source_translator = ImageBlockParamSourceTranslator(file_service) + def __init__( + self, file_service: FileService, workspace_id: WorkspaceId | None + ) -> None: + self.source_translator = ImageBlockParamSourceTranslator( + file_service, workspace_id + ) async def from_anthropic(self, block: AnthropicImageBlockParam) -> ImageBlockParam: return ImageBlockParam( @@ -166,8 +181,10 @@ async def to_anthropic(self, block: ImageBlockParam) -> AnthropicImageBlockParam class ToolResultContentBlockParamTranslator: - def __init__(self, file_service: FileService) -> None: - self.image_translator = ImageBlockParamTranslator(file_service) + def __init__( + self, file_service: FileService, workspace_id: WorkspaceId | None + ) -> None: + self.image_translator = ImageBlockParamTranslator(file_service, workspace_id) async def from_anthropic( self, block: AnthropicImageBlockParam | TextBlockParam @@ -185,9 +202,11 @@ async def to_anthropic( class ToolResultContentTranslator: - def __init__(self, file_service: FileService) -> None: + def __init__( + self, file_service: FileService, workspace_id: WorkspaceId | None + ) -> None: self.block_param_translator = ToolResultContentBlockParamTranslator( - file_service + file_service, workspace_id ) async def from_anthropic( @@ -210,8 +229,12 @@ async def to_anthropic( class ToolResultBlockParamTranslator: - def __init__(self, file_service: 
FileService) -> None: - self.content_translator = ToolResultContentTranslator(file_service) + def __init__( + self, file_service: FileService, workspace_id: WorkspaceId | None + ) -> None: + self.content_translator = ToolResultContentTranslator( + file_service, workspace_id + ) async def from_anthropic( self, block: AnthropicToolResultBlockParam @@ -237,11 +260,15 @@ async def to_anthropic( class MessageContentBlockParamTranslator: - def __init__(self, file_service: FileService) -> None: - self.image_translator = ImageBlockParamTranslator(file_service) - self.tool_result_translator = ToolResultBlockParamTranslator(file_service) + def __init__( + self, file_service: FileService, workspace_id: WorkspaceId | None + ) -> None: + self.image_translator = ImageBlockParamTranslator(file_service, workspace_id) + self.tool_result_translator = ToolResultBlockParamTranslator( + file_service, workspace_id + ) self.request_document_translator = RequestDocumentBlockParamTranslator( - file_service + file_service, workspace_id ) async def from_anthropic( @@ -266,8 +293,12 @@ async def to_anthropic( class MessageContentTranslator: - def __init__(self, file_service: FileService) -> None: - self.block_param_translator = MessageContentBlockParamTranslator(file_service) + def __init__( + self, file_service: FileService, workspace_id: WorkspaceId | None + ) -> None: + self.block_param_translator = MessageContentBlockParamTranslator( + file_service, workspace_id + ) async def from_anthropic( self, content: list[AnthropicContentBlockParam] | str @@ -291,8 +322,10 @@ async def to_anthropic( class MessageTranslator: - def __init__(self, file_service: FileService) -> None: - self.content_translator = MessageContentTranslator(file_service) + def __init__( + self, file_service: FileService, workspace_id: WorkspaceId | None + ) -> None: + self.content_translator = MessageContentTranslator(file_service, workspace_id) async def from_anthropic(self, message: AnthropicMessageParam) -> MessageParam: 
return MessageParam( diff --git a/tests/integration/chat/api/conftest.py b/tests/integration/chat/api/conftest.py index 9c845c33..77be8da2 100644 --- a/tests/integration/chat/api/conftest.py +++ b/tests/integration/chat/api/conftest.py @@ -77,12 +77,16 @@ def test_headers(test_workspace_id: str) -> dict[str, str]: @pytest.fixture -def mock_file_service(temp_workspace_dir: Path) -> FileService: +def mock_file_service( + test_db_session: Session, temp_workspace_dir: Path +) -> FileService: """Create a mock file service with temporary workspace.""" - return FileService(temp_workspace_dir) + return FileService(test_db_session, temp_workspace_dir) -def create_test_app_with_overrides(workspace_path: Path) -> FastAPI: +def create_test_app_with_overrides( + test_db_session: Session, workspace_path: Path +) -> FastAPI: """Create a test app with all dependencies overridden.""" from askui.chat.api.app import app from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir @@ -96,7 +100,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing diff --git a/tests/integration/chat/api/test_files.py b/tests/integration/chat/api/test_files.py index 4496794c..8eb76e2f 100644 --- a/tests/integration/chat/api/test_files.py +++ b/tests/integration/chat/api/test_files.py @@ -2,18 +2,31 @@ import io import tempfile +from datetime import datetime, timezone from pathlib import Path +from uuid import UUID +import pytest from fastapi import status from fastapi.testclient import TestClient +from sqlalchemy.orm import Session from askui.chat.api.files.models import File +from askui.chat.api.files.orms import FileOrm from askui.chat.api.files.service import FileService +from askui.chat.api.models import FileId +from askui.utils.api_utils import NotFoundError class 
TestFilesAPI: """Test suite for the files API endpoints.""" + def _add_file_to_db(self, file: File, test_db_session: Session) -> None: + """Add a file to the test database.""" + file_orm = FileOrm.from_model(file) + test_db_session.add(file_orm) + test_db_session.commit() + def test_list_files_empty( self, test_client: TestClient, test_headers: dict[str, str] ) -> None: @@ -29,6 +42,7 @@ def test_list_files_empty( def test_list_files_with_files( self, test_headers: dict[str, str], + test_db_session: Session, ) -> None: """Test listing files when files exist.""" # Create a mock file in the temporary workspace @@ -38,16 +52,21 @@ def test_list_files_with_files( files_dir.mkdir(parents=True, exist_ok=True) # Create a mock file + workspace_id = UUID(test_headers["askui-workspace"]) mock_file = File( id="file_test123", object="file", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), filename="test.txt", size=32, media_type="text/plain", + workspace_id=workspace_id, ) (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) + # Add file to database + self._add_file_to_db(mock_file, test_db_session) + # Create a test app with overridden dependencies from askui.chat.api.app import app from askui.chat.api.dependencies import get_workspace_dir @@ -57,7 +76,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) app.dependency_overrides[get_workspace_dir] = override_workspace_dir app.dependency_overrides[get_file_service] = override_file_service @@ -76,7 +95,9 @@ def override_file_service() -> FileService: # Clean up dependency overrides app.dependency_overrides.clear() - def test_list_files_with_pagination(self, test_headers: dict[str, str]) -> None: + def test_list_files_with_pagination( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test 
listing files with pagination parameters.""" # Create multiple mock files in the temporary workspace temp_dir = tempfile.mkdtemp() @@ -85,16 +106,20 @@ def test_list_files_with_pagination(self, test_headers: dict[str, str]) -> None: files_dir.mkdir(parents=True, exist_ok=True) # Create multiple mock files + workspace_id = UUID(test_headers["askui-workspace"]) for i in range(5): mock_file = File( id=f"file_test{i}", object="file", - created_at=1234567890 + i, + created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), filename=f"test{i}.txt", size=32, media_type="text/plain", + workspace_id=workspace_id, ) (files_dir / f"file_test{i}.json").write_text(mock_file.model_dump_json()) + # Add file to database + self._add_file_to_db(mock_file, test_db_session) # Create a test app with overridden dependencies from askui.chat.api.app import app @@ -105,7 +130,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) app.dependency_overrides[get_workspace_dir] = override_workspace_dir app.dependency_overrides[get_file_service] = override_file_service @@ -140,7 +165,9 @@ def test_upload_file_success( assert "id" in data assert "created_at" in data - def test_upload_file_without_filename(self, test_headers: dict[str, str]) -> None: + def test_upload_file_without_filename( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test file upload with simple filename.""" file_content = b"test file content" # Test with a simple filename @@ -151,15 +178,11 @@ def test_upload_file_without_filename(self, test_headers: dict[str, str]) -> Non temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) - test_app = create_test_app_with_overrides(workspace_path) + test_app = create_test_app_with_overrides(test_db_session, workspace_path) with TestClient(test_app) as client: response = client.post("/v1/files", 
files=files, headers=test_headers) - if response.status_code != status.HTTP_201_CREATED: - print(f"Response status: {response.status_code}") - print(f"Response body: {response.text}") - assert response.status_code == status.HTTP_201_CREATED data = response.json() assert data["object"] == "file" @@ -182,7 +205,9 @@ def test_upload_file_large_size( data = response.json() assert "detail" in data - def test_retrieve_file_success(self, test_headers: dict[str, str]) -> None: + def test_retrieve_file_success( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test successful file retrieval.""" # Create a mock file in the temporary workspace temp_dir = tempfile.mkdtemp() @@ -191,16 +216,21 @@ def test_retrieve_file_success(self, test_headers: dict[str, str]) -> None: files_dir.mkdir(parents=True, exist_ok=True) # Create a mock file + workspace_id = UUID(test_headers["askui-workspace"]) mock_file = File( id="file_test123", object="file", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), filename="test.txt", size=32, media_type="text/plain", + workspace_id=workspace_id, ) (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) + # Add file to database + self._add_file_to_db(mock_file, test_db_session) + # Create a test app with overridden dependencies from askui.chat.api.app import app from askui.chat.api.dependencies import get_workspace_dir @@ -210,7 +240,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) app.dependency_overrides[get_workspace_dir] = override_workspace_dir app.dependency_overrides[get_file_service] = override_file_service @@ -229,7 +259,9 @@ def override_file_service() -> FileService: # Clean up dependency overrides app.dependency_overrides.clear() - def test_retrieve_file_not_found(self, test_headers: dict[str, str]) -> 
None: + def test_retrieve_file_not_found( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test file retrieval when file doesn't exist.""" # Create a test app with overridden dependencies from askui.chat.api.app import app @@ -243,7 +275,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -266,13 +298,16 @@ def override_set_env_from_headers() -> None: # Clean up dependency overrides app.dependency_overrides.clear() - def test_download_file_success(self, test_headers: dict[str, str]) -> None: + def test_download_file_success( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test successful file download.""" # Create a mock file in the temporary workspace temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) files_dir = workspace_path / "files" - static_dir = workspace_path / "static" + workspace_id = UUID(test_headers["askui-workspace"]) + static_dir = workspace_path / "workspaces" / str(workspace_id) / "static" files_dir.mkdir(parents=True, exist_ok=True) static_dir.mkdir(parents=True, exist_ok=True) @@ -280,10 +315,11 @@ def test_download_file_success(self, test_headers: dict[str, str]) -> None: mock_file = File( id="file_test123", object="file", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), filename="test.txt", size=32, media_type="text/plain", + workspace_id=workspace_id, ) (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) @@ -291,6 +327,9 @@ def test_download_file_success(self, test_headers: dict[str, str]) -> None: file_content = b"test file content" (static_dir / "file_test123.txt").write_bytes(file_content) + # Add file to database + self._add_file_to_db(mock_file, test_db_session) + # Create a test app with 
overridden dependencies from askui.chat.api.app import app from askui.chat.api.dependencies import get_workspace_dir @@ -300,7 +339,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) app.dependency_overrides[get_workspace_dir] = override_workspace_dir app.dependency_overrides[get_file_service] = override_file_service @@ -322,7 +361,9 @@ def override_file_service() -> FileService: # Clean up dependency overrides app.dependency_overrides.clear() - def test_download_file_not_found(self, test_headers: dict[str, str]) -> None: + def test_download_file_not_found( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test file download when file doesn't exist.""" # Create a test app with overridden dependencies from askui.chat.api.app import app @@ -336,7 +377,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -359,13 +400,16 @@ def override_set_env_from_headers() -> None: # Clean up dependency overrides app.dependency_overrides.clear() - def test_delete_file_success(self, test_headers: dict[str, str]) -> None: + def test_delete_file_success( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test successful file deletion.""" # Create a mock file in the temporary workspace temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) files_dir = workspace_path / "files" - static_dir = workspace_path / "static" + workspace_id = UUID(test_headers["askui-workspace"]) + static_dir = workspace_path / "workspaces" / str(workspace_id) / "static" files_dir.mkdir(parents=True, exist_ok=True) static_dir.mkdir(parents=True, exist_ok=True) @@ -373,10 +417,11 @@ def 
test_delete_file_success(self, test_headers: dict[str, str]) -> None: mock_file = File( id="file_test123", object="file", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), filename="test.txt", size=32, media_type="text/plain", + workspace_id=workspace_id, ) (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) @@ -384,6 +429,9 @@ def test_delete_file_success(self, test_headers: dict[str, str]) -> None: file_content = b"test file content" (static_dir / "file_test123.txt").write_bytes(file_content) + # Add file to database + self._add_file_to_db(mock_file, test_db_session) + # Create a test app with overridden dependencies from askui.chat.api.app import app from askui.chat.api.dependencies import get_workspace_dir @@ -392,8 +440,10 @@ def test_delete_file_success(self, test_headers: dict[str, str]) -> None: def override_workspace_dir() -> Path: return workspace_path + file_service_override = FileService(test_db_session, workspace_path) + def override_file_service() -> FileService: - return FileService(workspace_path) + return file_service_override app.dependency_overrides[get_workspace_dir] = override_workspace_dir app.dependency_overrides[get_file_service] = override_file_service @@ -404,14 +454,19 @@ def override_file_service() -> FileService: assert response.status_code == status.HTTP_204_NO_CONTENT - # Verify file is deleted - assert not (files_dir / "file_test123.json").exists() + # Verify static file is deleted (JSON files are no longer used) assert not (static_dir / "file_test123.txt").exists() + + # Verify file is deleted from database + with pytest.raises(NotFoundError): + file_service_override.retrieve(workspace_id, FileId("file_test123")) finally: # Clean up dependency overrides app.dependency_overrides.clear() - def test_delete_file_not_found(self, test_headers: dict[str, str]) -> None: + def test_delete_file_not_found( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: 
"""Test file deletion when file doesn't exist.""" # Create a test app with overridden dependencies from askui.chat.api.app import app @@ -425,7 +480,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -491,7 +546,9 @@ def test_upload_file_without_content_type( assert data["media_type"] is not None assert data["media_type"] != "" - def test_list_files_with_filtering(self, test_headers: dict[str, str]) -> None: + def test_list_files_with_filtering( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test listing files with filtering parameters.""" # Create multiple mock files in the temporary workspace temp_dir = tempfile.mkdtemp() @@ -500,16 +557,20 @@ def test_list_files_with_filtering(self, test_headers: dict[str, str]) -> None: files_dir.mkdir(parents=True, exist_ok=True) # Create multiple mock files with different timestamps + workspace_id = UUID(test_headers["askui-workspace"]) for i in range(3): mock_file = File( id=f"file_test{i}", object="file", - created_at=1234567890 + i, + created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), filename=f"test{i}.txt", size=32, media_type="text/plain", + workspace_id=workspace_id, ) (files_dir / f"file_test{i}.json").write_text(mock_file.model_dump_json()) + # Add file to database + self._add_file_to_db(mock_file, test_db_session) # Create a test app with overridden dependencies from askui.chat.api.app import app @@ -520,7 +581,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) app.dependency_overrides[get_workspace_dir] = override_workspace_dir app.dependency_overrides[get_file_service] = override_file_service diff --git 
a/tests/integration/chat/api/test_files_edge_cases.py b/tests/integration/chat/api/test_files_edge_cases.py index fd4fc9d3..dad236af 100644 --- a/tests/integration/chat/api/test_files_edge_cases.py +++ b/tests/integration/chat/api/test_files_edge_cases.py @@ -6,12 +6,15 @@ from fastapi import status from fastapi.testclient import TestClient +from sqlalchemy.orm import Session class TestFilesAPIEdgeCases: """Test suite for edge cases and error scenarios in the files API.""" - def test_upload_empty_file(self, test_headers: dict[str, str]) -> None: + def test_upload_empty_file( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test uploading an empty file.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -26,7 +29,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -52,7 +55,7 @@ def override_set_env_from_headers() -> None: app.dependency_overrides.clear() def test_upload_file_with_special_characters_in_filename( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test uploading a file with special characters in the filename.""" temp_dir = tempfile.mkdtemp() @@ -68,7 +71,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -96,7 +99,7 @@ def override_set_env_from_headers() -> None: app.dependency_overrides.clear() def test_upload_file_with_very_long_filename( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test uploading a file with a very long filename.""" temp_dir = 
tempfile.mkdtemp() @@ -112,7 +115,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -140,7 +143,7 @@ def override_set_env_from_headers() -> None: app.dependency_overrides.clear() def test_upload_file_with_unknown_mime_type( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test uploading a file with an unknown MIME type.""" temp_dir = tempfile.mkdtemp() @@ -156,7 +159,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -182,7 +185,7 @@ def override_set_env_from_headers() -> None: app.dependency_overrides.clear() def test_upload_file_with_binary_content( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test uploading a file with binary content.""" temp_dir = tempfile.mkdtemp() @@ -198,7 +201,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -250,7 +253,7 @@ def test_upload_file_with_invalid_workspace_header( assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY def test_upload_file_with_malformed_file_data( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test uploading with malformed file data.""" temp_dir = tempfile.mkdtemp() @@ -266,7 +269,7 @@ def override_workspace_dir() -> Path: return workspace_path def 
override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -288,7 +291,7 @@ def override_set_env_from_headers() -> None: app.dependency_overrides.clear() def test_upload_file_with_corrupted_content( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test uploading a file with corrupted content.""" temp_dir = tempfile.mkdtemp() @@ -304,7 +307,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -333,7 +336,7 @@ def read(self, size: int) -> bytes: # noqa: ARG002 app.dependency_overrides.clear() def test_list_files_with_invalid_pagination( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test listing files with invalid pagination parameters.""" temp_dir = tempfile.mkdtemp() @@ -349,7 +352,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing @@ -377,7 +380,7 @@ def override_set_env_from_headers() -> None: app.dependency_overrides.clear() def test_retrieve_file_with_invalid_id_format( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test retrieving a file with an invalid ID format.""" temp_dir = tempfile.mkdtemp() @@ -393,7 +396,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) 
def override_set_env_from_headers() -> None: # No-op for testing @@ -417,7 +420,7 @@ def override_set_env_from_headers() -> None: app.dependency_overrides.clear() def test_delete_file_with_invalid_id_format( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test deleting a file with an invalid ID format.""" temp_dir = tempfile.mkdtemp() @@ -433,7 +436,7 @@ def override_workspace_dir() -> Path: return workspace_path def override_file_service() -> FileService: - return FileService(workspace_path) + return FileService(test_db_session, workspace_path) def override_set_env_from_headers() -> None: # No-op for testing diff --git a/tests/integration/chat/api/test_files_service.py b/tests/integration/chat/api/test_files_service.py index 49d221b1..00d29629 100644 --- a/tests/integration/chat/api/test_files_service.py +++ b/tests/integration/chat/api/test_files_service.py @@ -6,11 +6,12 @@ import pytest from fastapi import UploadFile +from sqlalchemy.orm import Session -from askui.chat.api.files.models import File, FileCreateParams +from askui.chat.api.files.models import File, FileCreate from askui.chat.api.files.service import FileService from askui.chat.api.models import FileId -from askui.utils.api_utils import ConflictError, FileTooLargeError, NotFoundError +from askui.utils.api_utils import FileTooLargeError, NotFoundError class TestFileService: @@ -23,85 +24,51 @@ def temp_workspace_dir(self) -> Path: return Path(temp_dir) @pytest.fixture - def file_service(self, temp_workspace_dir: Path) -> FileService: + def file_service( + self, test_db_session: Session, temp_workspace_dir: Path + ) -> FileService: """Create a FileService instance with temporary workspace.""" - return FileService(temp_workspace_dir) + return FileService(test_db_session, temp_workspace_dir) @pytest.fixture - def sample_file_params(self) -> FileCreateParams: + def sample_file_params(self) -> FileCreate: """Create sample file creation 
parameters.""" - return FileCreateParams(filename="test.txt", size=32, media_type="text/plain") - - def test_get_file_path_new_file(self, file_service: FileService) -> None: - """Test getting file path for a new file.""" - file_id = FileId("file_test123") - file_path = file_service._get_file_path(file_id, new=True) - - expected_path = file_service._files_dir / "file_test123.json" - assert file_path == expected_path - - def test_get_file_path_existing_file(self, file_service: FileService) -> None: - """Test getting file path for an existing file.""" - file_id = FileId("file_test123") - - # Create the file first - file_path = file_service._files_dir / "file_test123.json" - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text('{"id": "file_test123"}') - - result_path = file_service._get_file_path(file_id, new=False) - assert result_path == file_path - - def test_get_file_path_new_file_conflict(self, file_service: FileService) -> None: - """Test that getting path for new file raises ConflictError if file exists.""" - file_id = FileId("file_test123") - - # Create the file first - file_path = file_service._files_dir / "file_test123.json" - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text('{"id": "file_test123"}') - - with pytest.raises(ConflictError): - file_service._get_file_path(file_id, new=True) - - def test_get_file_path_existing_file_not_found( - self, file_service: FileService - ) -> None: - """Test that getting path for existing file raises NotFoundError if file - doesn't exist.""" - file_id = FileId("file_test123") - - with pytest.raises(NotFoundError): - file_service._get_file_path(file_id, new=False) + return FileCreate(filename="test.txt", size=32, media_type="text/plain") def test_get_static_file_path(self, file_service: FileService) -> None: """Test getting static file path based on file extension.""" + from datetime import datetime, timezone + file = File( id="file_test123", object="file", - 
created_at=1234567890, + created_at=datetime.now(timezone.utc), filename="test.txt", size=32, media_type="text/plain", + workspace_id=None, ) static_path = file_service._get_static_file_path(file) - expected_path = file_service._static_dir / "file_test123.txt" + expected_path = file_service._data_dir / "static" / "file_test123.txt" assert static_path == expected_path def test_get_static_file_path_no_extension(self, file_service: FileService) -> None: """Test getting static file path when MIME type has no extension.""" + from datetime import datetime, timezone + file = File( id="file_test123", object="file", - created_at=1234567890, + created_at=datetime.now(timezone.utc), filename="test", size=32, media_type="application/octet-stream", + workspace_id=None, ) static_path = file_service._get_static_file_path(file) - expected_path = file_service._static_dir / "file_test123" + expected_path = file_service._data_dir / "static" / "file_test123" assert static_path == expected_path def test_list_files_empty(self, file_service: FileService) -> None: @@ -109,14 +76,14 @@ def test_list_files_empty(self, file_service: FileService) -> None: from askui.utils.api_utils import ListQuery query = ListQuery() - result = file_service.list_(query) + result = file_service.list_(None, query) assert result.object == "list" assert result.data == [] assert result.has_more is False def test_list_files_with_files( - self, file_service: FileService, sample_file_params: FileCreateParams + self, file_service: FileService, sample_file_params: FileCreate ) -> None: """Test listing files when files exist.""" from askui.utils.api_utils import ListQuery @@ -127,17 +94,17 @@ def test_list_files_with_files( temp_file.write_bytes(file_content) # Update the size to match the actual file content - params = FileCreateParams( + params = FileCreate( filename=sample_file_params.filename, size=len(file_content), media_type=sample_file_params.media_type, ) try: - file = file_service.create(params, temp_file) + 
file = file_service.create(None, params, temp_file) query = ListQuery() - result = file_service.list_(query) + result = file_service.list_(None, query) assert result.object == "list" assert len(result.data) == 1 @@ -147,7 +114,7 @@ def test_list_files_with_files( temp_file.unlink(missing_ok=True) def test_retrieve_file_success( - self, file_service: FileService, sample_file_params: FileCreateParams + self, file_service: FileService, sample_file_params: FileCreate ) -> None: """Test successful file retrieval.""" # Create a file first @@ -156,16 +123,16 @@ def test_retrieve_file_success( temp_file.write_bytes(file_content) # Update the size to match the actual file content - params = FileCreateParams( + params = FileCreate( filename=sample_file_params.filename, size=len(file_content), media_type=sample_file_params.media_type, ) try: - file = file_service.create(params, temp_file) + file = file_service.create(None, params, temp_file) - retrieved_file = file_service.retrieve(file.id) + retrieved_file = file_service.retrieve(None, file.id) assert retrieved_file.id == file.id assert retrieved_file.filename == file.filename @@ -179,38 +146,43 @@ def test_retrieve_file_not_found(self, file_service: FileService) -> None: file_id = FileId("file_nonexistent123") with pytest.raises(NotFoundError): - file_service.retrieve(file_id) + file_service.retrieve(None, file_id) def test_delete_file_success( - self, file_service: FileService, sample_file_params: FileCreateParams + self, file_service: FileService, sample_file_params: FileCreate ) -> None: """Test successful file deletion.""" + from uuid import UUID + + # Create a workspace_id for the test file (non-default files can be deleted) + workspace_id = UUID("75592acb-9f48-4a10-8331-ea8faeed54a5") + # Create a file first temp_file = Path(tempfile.mktemp()) file_content = b"test content" temp_file.write_bytes(file_content) # Update the size to match the actual file content - params = FileCreateParams( + params = FileCreate( 
filename=sample_file_params.filename, size=len(file_content), media_type=sample_file_params.media_type, ) try: - file = file_service.create(params, temp_file) + file = file_service.create(workspace_id, params, temp_file) # Verify file exists by retrieving it - retrieved_file = file_service.retrieve(file.id) + retrieved_file = file_service.retrieve(workspace_id, file.id) assert retrieved_file.id == file.id # Delete the file - file_service.delete(file.id) + file_service.delete(workspace_id, file.id) # Verify file is deleted by trying to retrieve it # (should raise NotFoundError) with pytest.raises(NotFoundError): - file_service.retrieve(file.id) + file_service.retrieve(workspace_id, file.id) finally: temp_file.unlink(missing_ok=True) @@ -219,10 +191,10 @@ def test_delete_file_not_found(self, file_service: FileService) -> None: file_id = FileId("file_nonexistent123") with pytest.raises(NotFoundError): - file_service.delete(file_id) + file_service.delete(None, file_id) def test_retrieve_file_content_success( - self, file_service: FileService, sample_file_params: FileCreateParams + self, file_service: FileService, sample_file_params: FileCreate ) -> None: """Test successful file content retrieval.""" # Create a file first @@ -231,16 +203,18 @@ def test_retrieve_file_content_success( temp_file.write_bytes(file_content) # Update the size to match the actual file content - params = FileCreateParams( + params = FileCreate( filename=sample_file_params.filename, size=len(file_content), media_type=sample_file_params.media_type, ) try: - file = file_service.create(params, temp_file) + file = file_service.create(None, params, temp_file) - retrieved_file, file_path = file_service.retrieve_file_content(file.id) + retrieved_file, file_path = file_service.retrieve_file_content( + None, file.id + ) assert retrieved_file.id == file.id assert file_path.exists() @@ -252,10 +226,10 @@ def test_retrieve_file_content_not_found(self, file_service: FileService) -> Non file_id = 
FileId("file_nonexistent123") with pytest.raises(NotFoundError): - file_service.retrieve_file_content(file_id) + file_service.retrieve_file_content(None, file_id) def test_create_file_success( - self, file_service: FileService, sample_file_params: FileCreateParams + self, file_service: FileService, sample_file_params: FileCreate ) -> None: """Test successful file creation.""" temp_file = Path(tempfile.mktemp()) @@ -264,13 +238,13 @@ def test_create_file_success( try: # Update the size to match the actual file content - params = FileCreateParams( + params = FileCreate( filename=sample_file_params.filename, size=len(file_content), media_type=sample_file_params.media_type, ) - file = file_service.create(params, temp_file) + file = file_service.create(None, params, temp_file) assert file.id.startswith("file_") assert file.filename == params.filename @@ -282,10 +256,6 @@ def test_create_file_success( assert isinstance(file.created_at, datetime) assert file.created_at > datetime(2020, 1, 1, tzinfo=timezone.utc) - # Verify metadata file was created - metadata_path = file_service._get_file_path(file.id, new=False) - assert metadata_path.exists() - # Verify static file was moved static_path = file_service._get_static_file_path(file) assert static_path.exists() @@ -299,12 +269,12 @@ def test_create_file_without_filename(self, file_service: FileService) -> None: file_content = b"test content" temp_file.write_bytes(file_content) - params = FileCreateParams( + params = FileCreate( filename=None, size=len(file_content), media_type="text/plain" ) try: - file = file_service.create(params, temp_file) + file = file_service.create(None, params, temp_file) # Should auto-generate filename with extension assert file.filename.endswith(".txt") @@ -313,43 +283,6 @@ def test_create_file_without_filename(self, file_service: FileService) -> None: finally: temp_file.unlink(missing_ok=True) - def test_save_file_new( - self, file_service: FileService, sample_file_params: FileCreateParams - ) -> 
None: - """Test saving a new file.""" - file = File.create(sample_file_params) - - file_service._save(file, new=True) - - # Verify file was saved - saved_path = file_service._get_file_path(file.id, new=False) - assert saved_path.exists() - - # Verify content is correct - saved_content = saved_path.read_text() - saved_file = File.model_validate_json(saved_content) - assert saved_file.id == file.id - assert saved_file.filename == file.filename - - def test_save_file_existing( - self, file_service: FileService, sample_file_params: FileCreateParams - ) -> None: - """Test saving an existing file.""" - file = File.create(sample_file_params) - - # Save file first time - file_service._save(file, new=True) - - # Modify and save again - file.filename = "modified.txt" - file_service._save(file, new=False) - - # Verify changes were saved - saved_path = file_service._get_file_path(file.id, new=False) - saved_content = saved_path.read_text() - saved_file = File.model_validate_json(saved_content) - assert saved_file.filename == "modified.txt" - @pytest.mark.asyncio async def test_write_to_temp_file_success(self, file_service: FileService) -> None: """Test successful writing to temporary file.""" @@ -417,17 +350,15 @@ async def test_upload_file_success(self, file_service: FileService) -> None: mock_upload_file.content_type = "text/plain" mock_upload_file.read.side_effect = [file_content, b""] - file = await file_service.upload_file(mock_upload_file) + file = await file_service.upload_file(None, mock_upload_file) assert file.filename == "test.txt" assert file.size == len(file_content) assert file.media_type == "text/plain" assert file.id.startswith("file_") - # Verify files were created - metadata_path = file_service._get_file_path(file.id, new=False) + # Verify static file was created static_path = file_service._get_static_file_path(file) - assert metadata_path.exists() assert static_path.exists() @pytest.mark.asyncio @@ -439,4 +370,4 @@ async def 
test_upload_file_upload_failure(self, file_service: FileService) -> No mock_upload_file.read.side_effect = Exception("Simulated upload failure") with pytest.raises(Exception, match="Simulated upload failure"): - await file_service.upload_file(mock_upload_file) + await file_service.upload_file(None, mock_upload_file) diff --git a/tests/integration/chat/api/test_mcp_configs.py b/tests/integration/chat/api/test_mcp_configs.py index 51328550..cb32aa5d 100644 --- a/tests/integration/chat/api/test_mcp_configs.py +++ b/tests/integration/chat/api/test_mcp_configs.py @@ -1,44 +1,45 @@ """Integration tests for the MCP configs API endpoints.""" -import tempfile -from pathlib import Path +from uuid import UUID from fastapi import status from fastapi.testclient import TestClient +from fastmcp.mcp_config import StdioMCPServer +from sqlalchemy.orm import Session from askui.chat.api.mcp_configs.models import McpConfig +from askui.chat.api.mcp_configs.orms import McpConfigOrm from askui.chat.api.mcp_configs.service import McpConfigService class TestMcpConfigsAPI: """Test suite for the MCP configs API endpoints.""" - def test_list_mcp_configs_with_configs(self, test_headers: dict[str, str]) -> None: + def test_list_mcp_configs_with_configs( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test listing MCP configs when configs exist.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone - # Create a mock MCP config - workspace_id = test_headers["askui-workspace"] + # Create a mock MCP config in the database + workspace_id = UUID(test_headers["askui-workspace"]) mock_config = McpConfig( id="mcpcnf_test123", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Test MCP Config", - mcp_server={"type": "stdio", "command": "test_command"}, + 
mcp_server=StdioMCPServer(type="stdio", command="test_command"), workspace_id=workspace_id, ) - (mcp_configs_dir / "mcpcnf_test123.json").write_text( - mock_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(mock_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -57,34 +58,31 @@ def override_mcp_config_service() -> McpConfigService: app.dependency_overrides.clear() def test_list_mcp_configs_with_pagination( - self, test_headers: dict[str, str] + self, test_db_session: Session, test_headers: dict[str, str] ) -> None: """Test listing MCP configs with pagination parameters.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone - # Create multiple mock MCP configs - workspace_id = test_headers["askui-workspace"] + # Create multiple mock MCP configs in the database + workspace_id = UUID(test_headers["askui-workspace"]) for i in range(5): mock_config = McpConfig( id=f"mcpcnf_test{i}", object="mcp_config", - created_at=1234567890 + i, + created_at=datetime.fromtimestamp(1234567890 + i, timezone.utc), name=f"Test MCP Config {i}", - mcp_server={"type": "stdio", "command": f"test_command_{i}"}, + mcp_server=StdioMCPServer(type="stdio", command=f"test_command_{i}"), workspace_id=workspace_id, ) - (mcp_configs_dir / f"mcpcnf_test{i}.json").write_text( - mock_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(mock_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app 
import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -99,16 +97,15 @@ def override_mcp_config_service() -> McpConfigService: finally: app.dependency_overrides.clear() - def test_create_mcp_config(self, test_headers: dict[str, str]) -> None: + def test_create_mcp_config( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test creating a new MCP config.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -130,16 +127,15 @@ def override_mcp_config_service() -> McpConfigService: finally: app.dependency_overrides.clear() - def test_create_mcp_config_minimal(self, test_headers: dict[str, str]) -> None: + def test_create_mcp_config_minimal( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test creating an MCP config with minimal data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -163,29 +159,29 @@ def override_mcp_config_service() -> McpConfigService: finally: app.dependency_overrides.clear() - def 
test_retrieve_mcp_config(self, test_headers: dict[str, str]) -> None: + def test_retrieve_mcp_config( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test retrieving an existing MCP config.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone mock_config = McpConfig( id="mcpcnf_test123", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Test MCP Config", - mcp_server={"type": "stdio", "command": "test_command"}, - ) - (mcp_configs_dir / "mcpcnf_test123.json").write_text( - mock_config.model_dump_json() + mcp_server=StdioMCPServer(type="stdio", command="test_command"), + workspace_id=None, ) + mcp_config_orm = McpConfigOrm.from_model(mock_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -216,31 +212,30 @@ def test_retrieve_mcp_config_not_found( data = response.json() assert "detail" in data - def test_modify_mcp_config(self, test_headers: dict[str, str]) -> None: + def test_modify_mcp_config( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test modifying an existing MCP config.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone - workspace_id = test_headers["askui-workspace"] + workspace_id = UUID(test_headers["askui-workspace"]) mock_config = 
McpConfig( id="mcpcnf_test123", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Original Name", - mcp_server={"type": "stdio", "command": "original_command"}, + mcp_server=StdioMCPServer(type="stdio", command="original_command"), workspace_id=workspace_id, ) - (mcp_configs_dir / "mcpcnf_test123.json").write_text( - mock_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(mock_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -264,31 +259,30 @@ def override_mcp_config_service() -> McpConfigService: finally: app.dependency_overrides.clear() - def test_modify_mcp_config_partial(self, test_headers: dict[str, str]) -> None: + def test_modify_mcp_config_partial( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test modifying an MCP config with partial data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone - workspace_id = test_headers["askui-workspace"] + workspace_id = UUID(test_headers["askui-workspace"]) mock_config = McpConfig( id="mcpcnf_test123", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Original Name", - mcp_server={"type": "stdio", "command": "original_command"}, + mcp_server=StdioMCPServer(type="stdio", command="original_command"), workspace_id=workspace_id, ) - (mcp_configs_dir / "mcpcnf_test123.json").write_text( - 
mock_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(mock_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -321,31 +315,30 @@ def test_modify_mcp_config_not_found( assert response.status_code == status.HTTP_404_NOT_FOUND - def test_delete_mcp_config(self, test_headers: dict[str, str]) -> None: + def test_delete_mcp_config( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test deleting an existing MCP config.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone - workspace_id = test_headers["askui-workspace"] + workspace_id = UUID(test_headers["askui-workspace"]) mock_config = McpConfig( id="mcpcnf_test123", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Test MCP Config", - mcp_server={"type": "stdio", "command": "test_command"}, + mcp_server=StdioMCPServer(type="stdio", command="test_command"), workspace_id=workspace_id, ) - (mcp_configs_dir / "mcpcnf_test123.json").write_text( - mock_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(mock_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) 
app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -371,32 +364,29 @@ def test_delete_mcp_config_not_found( assert response.status_code == status.HTTP_404_NOT_FOUND def test_modify_default_mcp_config_forbidden( - self, test_headers: dict[str, str] + self, test_db_session: Session, test_headers: dict[str, str] ) -> None: """Test that modifying a default MCP configuration returns 403 Forbidden.""" - # Create a default MCP config (no workspace_id) - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone + # Create a default MCP config (no workspace_id) in the database default_config = McpConfig( id="mcpcnf_default123", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Default MCP Config", - mcp_server={"type": "stdio", "command": "default_command"}, + mcp_server=StdioMCPServer(type="stdio", command="default_command"), workspace_id=None, # No workspace_id = default ) - (mcp_configs_dir / "mcpcnf_default123.json").write_text( - default_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(default_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -414,32 +404,29 @@ def override_mcp_config_service() -> McpConfigService: app.dependency_overrides.clear() def test_delete_default_mcp_config_forbidden( - self, test_headers: dict[str, str] + self, test_db_session: Session, test_headers: dict[str, str] ) -> None: """Test that deleting a default 
MCP configuration returns 403 Forbidden.""" - # Create a default MCP config (no workspace_id) - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone + # Create a default MCP config (no workspace_id) in the database default_config = McpConfig( id="mcpcnf_default456", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Default MCP Config", - mcp_server={"type": "stdio", "command": "default_command"}, + mcp_server=StdioMCPServer(type="stdio", command="default_command"), workspace_id=None, # No workspace_id = default ) - (mcp_configs_dir / "mcpcnf_default456.json").write_text( - default_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(default_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -456,47 +443,43 @@ def override_mcp_config_service() -> McpConfigService: app.dependency_overrides.clear() def test_list_mcp_configs_includes_default_and_workspace( - self, test_headers: dict[str, str] + self, test_db_session: Session, test_headers: dict[str, str] ) -> None: """Test that listing MCP configs includes both default and workspace-scoped ones.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone - # Create a default MCP config (no workspace_id) + # Create a default MCP config (no workspace_id) in the 
database default_config = McpConfig( id="mcpcnf_default789", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Default MCP Config", - mcp_server={"type": "stdio", "command": "default_command"}, + mcp_server=StdioMCPServer(type="stdio", command="default_command"), workspace_id=None, # No workspace_id = default ) - (mcp_configs_dir / "mcpcnf_default789.json").write_text( - default_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(default_config) + test_db_session.add(mcp_config_orm) # Create a workspace-scoped MCP config - workspace_id = test_headers["askui-workspace"] + workspace_id = UUID(test_headers["askui-workspace"]) workspace_config = McpConfig( id="mcpcnf_workspace123", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Workspace MCP Config", - mcp_server={"type": "stdio", "command": "workspace_command"}, + mcp_server=StdioMCPServer(type="stdio", command="workspace_command"), workspace_id=workspace_id, ) - (mcp_configs_dir / "mcpcnf_workspace123.json").write_text( - workspace_config.model_dump_json() - ) + workspace_config_orm = McpConfigOrm.from_model(workspace_config) + test_db_session.add(workspace_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -523,37 +506,34 @@ def override_mcp_config_service() -> McpConfigService: # Default config should not have workspace_id field (excluded when None) assert "workspace_id" not in default_config_data - assert workspace_config_data["workspace_id"] == workspace_id + assert workspace_config_data["workspace_id"] == str(workspace_id) 
finally: app.dependency_overrides.clear() def test_retrieve_default_mcp_config_success( - self, test_headers: dict[str, str] + self, test_db_session: Session, test_headers: dict[str, str] ) -> None: """Test that retrieving a default MCP configuration works.""" - # Create a default MCP config (no workspace_id) - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / "mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone + # Create a default MCP config (no workspace_id) in the database default_config = McpConfig( id="mcpcnf_defaultretrieve", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Default MCP Config", - mcp_server={"type": "stdio", "command": "default_command"}, + mcp_server=StdioMCPServer(type="stdio", command="default_command"), workspace_id=None, # No workspace_id = default ) - (mcp_configs_dir / "mcpcnf_defaultretrieve.json").write_text( - default_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(default_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -574,32 +554,29 @@ def override_mcp_config_service() -> McpConfigService: app.dependency_overrides.clear() def test_workspace_scoped_mcp_config_operations_success( - self, test_headers: dict[str, str] + self, test_db_session: Session, test_headers: dict[str, str] ) -> None: """Test that workspace-scoped MCP configs can be modified and deleted.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - mcp_configs_dir = workspace_path / 
"mcp_configs" - mcp_configs_dir.mkdir(parents=True, exist_ok=True) + from datetime import datetime, timezone - workspace_id = test_headers["askui-workspace"] + workspace_id = UUID(test_headers["askui-workspace"]) workspace_config = McpConfig( id="mcpcnf_workspaceops", object="mcp_config", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, timezone.utc), name="Workspace MCP Config", - mcp_server={"type": "stdio", "command": "workspace_command"}, + mcp_server=StdioMCPServer(type="stdio", command="workspace_command"), workspace_id=workspace_id, ) - (mcp_configs_dir / "mcpcnf_workspaceops.json").write_text( - workspace_config.model_dump_json() - ) + mcp_config_orm = McpConfigOrm.from_model(workspace_config) + test_db_session.add(mcp_config_orm) + test_db_session.commit() from askui.chat.api.app import app from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service def override_mcp_config_service() -> McpConfigService: - return McpConfigService(workspace_path, seeds=[]) + return McpConfigService(test_db_session, seeds=[]) app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service @@ -615,7 +592,7 @@ def override_mcp_config_service() -> McpConfigService: data = response.json() assert data["name"] == "Modified Workspace MCP Config" - assert data["workspace_id"] == workspace_id + assert data["workspace_id"] == str(workspace_id) # Delete the workspace MCP config response = client.delete( diff --git a/tests/integration/chat/api/test_messages.py b/tests/integration/chat/api/test_messages.py index c46ec94c..15e61246 100644 --- a/tests/integration/chat/api/test_messages.py +++ b/tests/integration/chat/api/test_messages.py @@ -1,22 +1,40 @@ """Integration tests for the messages API endpoints.""" import tempfile +from datetime import datetime, timezone from pathlib import Path -from unittest.mock import Mock +from uuid import UUID from fastapi import status from fastapi.testclient import TestClient +from sqlalchemy.orm 
import Session from askui.chat.api.messages.models import Message +from askui.chat.api.messages.orms import MessageOrm from askui.chat.api.messages.service import MessageService from askui.chat.api.threads.models import Thread +from askui.chat.api.threads.orms import ThreadOrm from askui.chat.api.threads.service import ThreadService class TestMessagesAPI: """Test suite for the messages API endpoints.""" - def test_list_messages_empty(self, test_headers: dict[str, str]) -> None: + def _add_thread_to_db(self, thread: Thread, test_db_session: Session) -> None: + """Add a thread to the test database.""" + thread_orm = ThreadOrm.from_model(thread) + test_db_session.add(thread_orm) + test_db_session.commit() + + def _add_message_to_db(self, message: Message, test_db_session: Session) -> None: + """Add a message to the test database.""" + message_orm = MessageOrm.from_model(message) + test_db_session.add(message_orm) + test_db_session.commit() + + def test_list_messages_empty( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test listing messages when no messages exist.""" # First create a thread temp_dir = tempfile.mkdtemp() @@ -24,27 +42,28 @@ def test_list_messages_empty(self, test_headers: dict[str, str]) -> None: threads_dir = workspace_path / "threads" threads_dir.mkdir(parents=True, exist_ok=True) + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + from askui.chat.api.app import app from askui.chat.api.messages.dependencies import get_message_service from askui.chat.api.threads.dependencies import get_thread_service def override_thread_service() -> ThreadService: - from 
askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) def override_message_service() -> MessageService: - return MessageService(workspace_path) + return MessageService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service app.dependency_overrides[get_message_service] = override_message_service @@ -63,7 +82,9 @@ def override_message_service() -> MessageService: finally: app.dependency_overrides.clear() - def test_list_messages_with_messages(self, test_headers: dict[str, str]) -> None: + def test_list_messages_with_messages( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test listing messages when messages exist.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -73,39 +94,45 @@ def test_list_messages_with_messages(self, test_headers: dict[str, str]) -> None messages_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + # Create a mock message + workspace_id = UUID(test_headers["askui-workspace"]) mock_message = Message( id="msg_test123", object="thread.message", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), thread_id="thread_test123", role="user", content="Hello, this is a test message", metadata={"key": "value"}, + workspace_id=workspace_id, ) (messages_dir / "msg_test123.json").write_text(mock_message.model_dump_json()) + # Add message to 
database + self._add_message_to_db(mock_message, test_db_session) + from askui.chat.api.app import app from askui.chat.api.messages.dependencies import get_message_service from askui.chat.api.threads.dependencies import get_thread_service def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) def override_message_service() -> MessageService: - return MessageService(workspace_path) + return MessageService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service app.dependency_overrides[get_message_service] = override_message_service @@ -126,7 +153,9 @@ def override_message_service() -> MessageService: finally: app.dependency_overrides.clear() - def test_list_messages_with_pagination(self, test_headers: dict[str, str]) -> None: + def test_list_messages_with_pagination( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test listing messages with pagination parameters.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -136,41 +165,46 @@ def test_list_messages_with_pagination(self, test_headers: dict[str, str]) -> No messages_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + # Create multiple mock messages + workspace_id = UUID(test_headers["askui-workspace"]) for i in range(5): mock_message = Message( id=f"msg_test{i}", 
object="thread.message", - created_at=1234567890 + i, + created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), thread_id="thread_test123", role="user" if i % 2 == 0 else "assistant", content=f"Test message {i}", + workspace_id=workspace_id, ) (messages_dir / f"msg_test{i}.json").write_text( mock_message.model_dump_json() ) + # Add message to database + self._add_message_to_db(mock_message, test_db_session) from askui.chat.api.app import app from askui.chat.api.messages.dependencies import get_message_service from askui.chat.api.threads.dependencies import get_thread_service def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) def override_message_service() -> MessageService: - return MessageService(workspace_path) + return MessageService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service app.dependency_overrides[get_message_service] = override_message_service @@ -188,7 +222,9 @@ def override_message_service() -> MessageService: finally: app.dependency_overrides.clear() - def test_create_message(self, test_headers: dict[str, str]) -> None: + def test_create_message( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test creating a new message.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -196,27 +232,28 @@ def test_create_message(self, test_headers: dict[str, str]) -> None: threads_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / 
"thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + from askui.chat.api.app import app from askui.chat.api.messages.dependencies import get_message_service from askui.chat.api.threads.dependencies import get_thread_service def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) def override_message_service() -> MessageService: - return MessageService(workspace_path) + return MessageService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service app.dependency_overrides[get_message_service] = override_message_service @@ -246,7 +283,9 @@ def override_message_service() -> MessageService: finally: app.dependency_overrides.clear() - def test_create_message_minimal(self, test_headers: dict[str, str]) -> None: + def test_create_message_minimal( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test creating a message with minimal data.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -254,27 +293,28 @@ def test_create_message_minimal(self, test_headers: dict[str, str]) -> None: threads_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + from askui.chat.api.app import app from askui.chat.api.messages.dependencies import get_message_service from 
askui.chat.api.threads.dependencies import get_thread_service def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) def override_message_service() -> MessageService: - return MessageService(workspace_path) + return MessageService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service app.dependency_overrides[get_message_service] = override_message_service @@ -310,7 +350,9 @@ def test_create_message_invalid_thread( assert response.status_code == status.HTTP_404_NOT_FOUND - def test_retrieve_message(self, test_headers: dict[str, str]) -> None: + def test_retrieve_message( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test retrieving an existing message.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -320,39 +362,45 @@ def test_retrieve_message(self, test_headers: dict[str, str]) -> None: messages_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + # Create a mock message + workspace_id = UUID(test_headers["askui-workspace"]) mock_message = Message( id="msg_test123", object="thread.message", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), thread_id="thread_test123", role="user", content="Test message content", metadata={"key": "value"}, + workspace_id=workspace_id, ) (messages_dir / 
"msg_test123.json").write_text(mock_message.model_dump_json()) + # Add message to database + self._add_message_to_db(mock_message, test_db_session) + from askui.chat.api.app import app from askui.chat.api.messages.dependencies import get_message_service from askui.chat.api.threads.dependencies import get_thread_service def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) def override_message_service() -> MessageService: - return MessageService(workspace_path) + return MessageService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service app.dependency_overrides[get_message_service] = override_message_service @@ -386,7 +434,9 @@ def test_retrieve_message_not_found( data = response.json() assert "detail" in data - def test_delete_message(self, test_headers: dict[str, str]) -> None: + def test_delete_message( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test deleting an existing message.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -396,38 +446,44 @@ def test_delete_message(self, test_headers: dict[str, str]) -> None: messages_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + # Create a mock message + workspace_id = UUID(test_headers["askui-workspace"]) mock_message = Message( id="msg_test123", object="thread.message", - 
created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), thread_id="thread_test123", role="user", content="Test message to delete", + workspace_id=workspace_id, ) (messages_dir / "msg_test123.json").write_text(mock_message.model_dump_json()) + # Add message to database + self._add_message_to_db(mock_message, test_db_session) + from askui.chat.api.app import app from askui.chat.api.messages.dependencies import get_message_service from askui.chat.api.threads.dependencies import get_thread_service def override_thread_service() -> ThreadService: - from askui.chat.api.threads.service import ThreadService - - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) def override_message_service() -> MessageService: - return MessageService(workspace_path) + return MessageService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service app.dependency_overrides[get_message_service] = override_message_service diff --git a/tests/integration/chat/api/test_request_document_translator.py b/tests/integration/chat/api/test_request_document_translator.py index ba5f1597..14e4e42c 100644 --- a/tests/integration/chat/api/test_request_document_translator.py +++ b/tests/integration/chat/api/test_request_document_translator.py @@ -7,6 +7,7 @@ import pytest from PIL import Image +from sqlalchemy.orm import Session from askui.chat.api.files.service import FileService from askui.chat.api.messages.models import RequestDocumentBlockParam @@ -28,16 +29,18 @@ def temp_dir(self) -> Generator[pathlib.Path, None, None]: shutil.rmtree(temp_dir, ignore_errors=True) @pytest.fixture - def file_service(self, temp_dir: pathlib.Path) -> FileService: + def file_service( + self, test_db_session: Session, temp_dir: pathlib.Path + ) -> FileService: """Create a FileService instance using the temporary directory.""" - return 
FileService(temp_dir) + return FileService(test_db_session, temp_dir) @pytest.fixture def translator( self, file_service: FileService ) -> RequestDocumentBlockParamTranslator: """Create a RequestDocumentBlockParamTranslator instance.""" - return RequestDocumentBlockParamTranslator(file_service) + return RequestDocumentBlockParamTranslator(file_service, None) @pytest.fixture def cache_control(self) -> CacheControlEphemeralParam: diff --git a/tests/integration/chat/api/test_runs.py b/tests/integration/chat/api/test_runs.py index b99515db..aadaf3af 100644 --- a/tests/integration/chat/api/test_runs.py +++ b/tests/integration/chat/api/test_runs.py @@ -15,8 +15,11 @@ from askui.chat.api.assistants.service import AssistantService from askui.chat.api.models import WorkspaceId from askui.chat.api.runs.models import Run +from askui.chat.api.runs.orms import RunOrm from askui.chat.api.runs.service import RunService +from askui.chat.api.settings import Settings from askui.chat.api.threads.models import Thread +from askui.chat.api.threads.orms import ThreadOrm from askui.chat.api.threads.service import ThreadService @@ -67,6 +70,21 @@ def _add_assistant_to_db( test_db_session.add(assistant_orm) test_db_session.commit() + def _add_thread_to_db(self, thread: Thread, test_db_session: Session) -> None: + """Add a thread to the test database.""" + thread_orm = ThreadOrm.from_model(thread) + test_db_session.add(thread_orm) + test_db_session.commit() + + def _add_run_to_db(self, run: Run, test_db_session: Session) -> None: + """Add a run to the test database.""" + # Need to include status (computed field) in the model dump + run_dict = run.model_dump(exclude={"object"}) + run_dict["status"] = run.status # Add computed status field + run_orm = RunOrm(**run_dict) + test_db_session.add(run_orm) + test_db_session.commit() + def _create_test_workspace(self) -> Path: """Create a temporary workspace directory for testing.""" temp_dir = tempfile.mkdtemp() @@ -76,24 +94,38 @@ def 
_create_test_workspace(self) -> Path: return workspace_path def _create_test_thread( - self, workspace_path: Path, thread_id: str = "thread_test123" - ) -> None: + self, + workspace_path: Path, + thread_id: str = "thread_test123", + test_db_session: Session | None = None, + workspace_id: UUID | None = None, + ) -> Thread: """Create a test thread in the workspace.""" threads_dir = workspace_path / "threads" + if workspace_id is None and test_db_session is not None: + # Need workspace_id if adding to DB + error_msg = "workspace_id required when test_db_session is provided" + raise ValueError(error_msg) mock_thread = Thread( id=thread_id, object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / f"{thread_id}.json").write_text(mock_thread.model_dump_json()) + if test_db_session is not None and workspace_id is not None: + self._add_thread_to_db(mock_thread, test_db_session) + return mock_thread def _create_test_run( self, workspace_path: Path, thread_id: str = "thread_test123", run_id: str = "run_test123", - ) -> None: + test_db_session: Session | None = None, + workspace_id: UUID | None = None, + ) -> Run: """Create a test run in the workspace.""" runs_dir = workspace_path / "runs" / thread_id runs_dir.mkdir(parents=True, exist_ok=True) @@ -101,14 +133,18 @@ def _create_test_run( mock_run = Run( id=run_id, object="thread.run", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), thread_id=thread_id, assistant_id="asst_test123", - expires_at=1755846718, # 10 minutes later - started_at=1234567890, - completed_at=1234567900, + expires_at=datetime.fromtimestamp(1755846718, tz=timezone.utc), + started_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), + completed_at=datetime.fromtimestamp(1234567900, tz=timezone.utc), + workspace_id=workspace_id, ) (runs_dir / f"{run_id}.json").write_text(mock_run.model_dump_json()) 
+ if test_db_session is not None and workspace_id is not None: + self._add_run_to_db(mock_run, test_db_session) + return mock_run def _setup_runs_dependencies( self, workspace_path: Path, test_db_session: Session @@ -119,26 +155,30 @@ def _setup_runs_dependencies( from askui.chat.api.threads.dependencies import get_thread_service def override_thread_service() -> ThreadService: - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: assistant_service = AssistantService(test_db_session) mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service app.dependency_overrides[get_runs_service] = override_runs_service def _create_multiple_test_runs( - self, workspace_path: Path, thread_id: str = "thread_test123", count: int = 5 + self, + workspace_path: Path, + thread_id: str = "thread_test123", + count: int = 5, + test_db_session: Session | None = None, + workspace_id: UUID | None = None, ) -> None: """Create multiple test runs in the workspace.""" runs_dir = workspace_path / "runs" / thread_id @@ -148,12 +188,17 @@ def _create_multiple_test_runs( mock_run = Run( id=f"run_test{i}", object="thread.run", - created_at=1234567890 + i, + created_at=datetime.fromtimestamp(1234567890 + i, tz=timezone.utc), thread_id=thread_id, assistant_id=f"asst_test{i}", - expires_at=1234567890 + i + 600, # 10 minutes later + expires_at=datetime.fromtimestamp( + 1234567890 + i + 600, tz=timezone.utc + ), + workspace_id=workspace_id, ) (runs_dir / 
f"run_test{i}.json").write_text(mock_run.model_dump_json()) + if test_db_session is not None and workspace_id is not None: + self._add_run_to_db(mock_run, test_db_session) def _cleanup_dependencies(self) -> None: """Clean up dependency overrides.""" @@ -194,8 +239,21 @@ def test_list_runs_with_runs( ) -> None: """Test listing runs when runs exist.""" workspace_path = self._create_test_workspace() - self._create_test_thread(workspace_path) - self._create_test_run(workspace_path) + workspace_id = UUID(test_headers["askui-workspace"]) + self._create_test_thread( + workspace_path, test_db_session=test_db_session, workspace_id=workspace_id + ) + # Add assistant for foreign key + mock_assistant = self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + self._create_test_run( + workspace_path, + test_db_session=test_db_session, + workspace_id=workspace_id, + ) self._setup_runs_dependencies(workspace_path, test_db_session) @@ -222,8 +280,22 @@ def test_list_runs_with_pagination( ) -> None: """Test listing runs with pagination parameters.""" workspace_path = self._create_test_workspace() - self._create_test_thread(workspace_path) - self._create_multiple_test_runs(workspace_path) + workspace_id = UUID(test_headers["askui-workspace"]) + self._create_test_thread( + workspace_path, test_db_session=test_db_session, workspace_id=workspace_id + ) + # Add assistants for foreign keys + for i in range(5): + mock_assistant = self._create_test_assistant( + assistant_id=f"asst_test{i}", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + self._create_multiple_test_runs( + workspace_path, + test_db_session=test_db_session, + workspace_id=workspace_id, + ) self._setup_runs_dependencies(workspace_path, test_db_session) @@ -247,10 +319,16 @@ def test_create_run( ) -> None: """Test creating a new run.""" workspace_path = self._create_test_workspace() 
- self._create_test_thread(workspace_path) + workspace_id = UUID(test_headers["askui-workspace"]) + self._create_test_thread( + workspace_path, test_db_session=test_db_session, workspace_id=workspace_id + ) self._setup_runs_dependencies(workspace_path, test_db_session) self._add_assistant_to_db( - self._create_test_assistant(assistant_id="asst_test123"), test_db_session + self._create_test_assistant( + assistant_id="asst_test123", workspace_id=workspace_id + ), + test_db_session, ) try: @@ -275,7 +353,9 @@ def test_create_run( finally: self._cleanup_dependencies() - def test_create_run_minimal(self, test_headers: dict[str, str]) -> None: + def test_create_run_minimal( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test creating a run with minimal data.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -283,14 +363,26 @@ def test_create_run_minimal(self, test_headers: dict[str, str]) -> None: threads_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + + # Add assistant to database (required for foreign key) + mock_assistant = self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + from askui.chat.api.app import app from askui.chat.api.runs.dependencies import get_runs_service from askui.chat.api.threads.dependencies import get_thread_service @@ -298,19 +390,18 @@ def test_create_run_minimal(self, test_headers: dict[str, str]) -> None: def override_thread_service() -> ThreadService: from 
askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -333,7 +424,9 @@ def override_runs_service() -> RunService: finally: app.dependency_overrides.clear() - def test_create_run_streaming(self, test_headers: dict[str, str]) -> None: + def test_create_run_streaming( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test creating a streaming run.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -341,14 +434,26 @@ def test_create_run_streaming(self, test_headers: dict[str, str]) -> None: threads_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + + # Add assistant to database (required for foreign key) + mock_assistant = self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + from askui.chat.api.app import app 
from askui.chat.api.runs.dependencies import get_runs_service from askui.chat.api.threads.dependencies import get_thread_service @@ -356,19 +461,18 @@ def test_create_run_streaming(self, test_headers: dict[str, str]) -> None: def override_thread_service() -> ThreadService: from askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -391,11 +495,21 @@ def override_runs_service() -> RunService: finally: app.dependency_overrides.clear() - def test_create_thread_and_run(self, test_headers: dict[str, str]) -> None: + def test_create_thread_and_run( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test creating a thread and run in one request.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) + # Add assistant to database (required for foreign key) + workspace_id = UUID(test_headers["askui-workspace"]) + mock_assistant = self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + from askui.chat.api.app import app from askui.chat.api.runs.dependencies import get_runs_service from askui.chat.api.threads.dependencies import get_thread_service @@ -403,19 +517,18 @@ def test_create_thread_and_run(self, test_headers: dict[str, str]) -> None: def 
override_thread_service() -> ThreadService: from askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -450,11 +563,21 @@ def override_runs_service() -> RunService: finally: app.dependency_overrides.clear() - def test_create_thread_and_run_minimal(self, test_headers: dict[str, str]) -> None: + def test_create_thread_and_run_minimal( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test creating a thread and run with minimal data.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) + # Add assistant to database (required for foreign key) + workspace_id = UUID(test_headers["askui-workspace"]) + mock_assistant = self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + from askui.chat.api.app import app from askui.chat.api.runs.dependencies import get_runs_service from askui.chat.api.threads.dependencies import get_thread_service @@ -462,19 +585,18 @@ def test_create_thread_and_run_minimal(self, test_headers: dict[str, str]) -> No def override_thread_service() -> ThreadService: from askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, 
mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -499,12 +621,20 @@ def override_runs_service() -> RunService: app.dependency_overrides.clear() def test_create_thread_and_run_streaming( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test creating a streaming thread and run.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) + # Add assistant to database (required for foreign key) + workspace_id = UUID(test_headers["askui-workspace"]) + mock_assistant = self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + from askui.chat.api.app import app from askui.chat.api.runs.dependencies import get_runs_service from askui.chat.api.threads.dependencies import get_thread_service @@ -512,19 +642,18 @@ def test_create_thread_and_run_streaming( def override_thread_service() -> ThreadService: from askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - 
base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -552,12 +681,20 @@ def override_runs_service() -> RunService: app.dependency_overrides.clear() def test_create_thread_and_run_with_messages( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test creating a thread and run with initial messages.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) + # Add assistant to database (required for foreign key) + workspace_id = UUID(test_headers["askui-workspace"]) + mock_assistant = self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + from askui.chat.api.app import app from askui.chat.api.runs.dependencies import get_runs_service from askui.chat.api.threads.dependencies import get_thread_service @@ -565,19 +702,18 @@ def test_create_thread_and_run_with_messages( def override_thread_service() -> ThreadService: from askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service 
@@ -618,7 +754,7 @@ def override_runs_service() -> RunService: app.dependency_overrides.clear() def test_create_thread_and_run_validation_error( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test creating thread and run with invalid data.""" temp_dir = tempfile.mkdtemp() @@ -631,19 +767,18 @@ def test_create_thread_and_run_validation_error( def override_thread_service() -> ThreadService: from askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -666,12 +801,20 @@ def override_runs_service() -> RunService: app.dependency_overrides.clear() def test_create_thread_and_run_empty_thread( - self, test_headers: dict[str, str] + self, test_headers: dict[str, str], test_db_session: Session ) -> None: """Test creating thread and run with completely empty thread object.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) + # Add assistant to database (required for foreign key) + workspace_id = UUID(test_headers["askui-workspace"]) + mock_assistant = self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + from askui.chat.api.app import app from askui.chat.api.runs.dependencies import get_runs_service from 
askui.chat.api.threads.dependencies import get_thread_service @@ -679,19 +822,18 @@ def test_create_thread_and_run_empty_thread( def override_thread_service() -> ThreadService: from askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -726,7 +868,9 @@ def test_create_run_invalid_thread( assert response.status_code == status.HTTP_404_NOT_FOUND - def test_retrieve_run(self, test_headers: dict[str, str]) -> None: + def test_retrieve_run( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test retrieving an existing run.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -736,27 +880,43 @@ def test_retrieve_run(self, test_headers: dict[str, str]) -> None: runs_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + # Add thread to database + self._add_thread_to_db(mock_thread, test_db_session) + + # Create and add assistant to database (required for foreign key) + mock_assistant = 
self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) + # Create a mock run mock_run = Run( id="run_test123", object="thread.run", - created_at=1234567890, + created_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), thread_id="thread_test123", assistant_id="asst_test123", - expires_at=1755846718, # 10 minutes later - started_at=1234567890, - completed_at=1234567900, + expires_at=datetime.fromtimestamp(1755846718, tz=timezone.utc), + started_at=datetime.fromtimestamp(1234567890, tz=timezone.utc), + completed_at=datetime.fromtimestamp(1234567900, tz=timezone.utc), + workspace_id=workspace_id, ) (runs_dir / "run_test123.json").write_text(mock_run.model_dump_json()) + # Add run to database + self._add_run_to_db(mock_run, test_db_session) + from askui.chat.api.app import app from askui.chat.api.runs.dependencies import get_runs_service from askui.chat.api.threads.dependencies import get_thread_service @@ -764,19 +924,18 @@ def test_retrieve_run(self, test_headers: dict[str, str]) -> None: def override_thread_service() -> ThreadService: from askui.chat.api.threads.service import ThreadService - mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -809,7 +968,9 @@ def test_retrieve_run_not_found( data = response.json() assert 
"detail" in data - def test_cancel_run(self, test_headers: dict[str, str]) -> None: + def test_cancel_run( + self, test_headers: dict[str, str], test_db_session: Session + ) -> None: """Test canceling an existing run.""" temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) @@ -819,27 +980,39 @@ def test_cancel_run(self, test_headers: dict[str, str]) -> None: runs_dir.mkdir(parents=True, exist_ok=True) # Create a mock thread + workspace_id = UUID(test_headers["askui-workspace"]) + import time + + current_time = int(time.time()) mock_thread = Thread( id="thread_test123", object="thread", - created_at=1234567890, + created_at=datetime.fromtimestamp(current_time, tz=timezone.utc), name="Test Thread", + workspace_id=workspace_id, ) (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + self._add_thread_to_db(mock_thread, test_db_session) - # Create a mock run - import time + # Create and add assistant to database (required for foreign key) + mock_assistant = self._create_test_assistant( + assistant_id="asst_test123", + workspace_id=workspace_id, + ) + self._add_assistant_to_db(mock_assistant, test_db_session) - current_time = int(time.time()) + # Create a mock run mock_run = Run( id="run_test123", object="thread.run", - created_at=current_time, + created_at=datetime.fromtimestamp(current_time, tz=timezone.utc), thread_id="thread_test123", assistant_id="asst_test123", - expires_at=current_time + 600, # 10 minutes later + expires_at=datetime.fromtimestamp(current_time + 600, tz=timezone.utc), + workspace_id=workspace_id, ) (runs_dir / "run_test123.json").write_text(mock_run.model_dump_json()) + self._add_run_to_db(mock_run, test_db_session) from askui.chat.api.app import app from askui.chat.api.runs.dependencies import get_runs_service @@ -848,19 +1021,18 @@ def test_cancel_run(self, test_headers: dict[str, str]) -> None: def override_thread_service() -> ThreadService: from askui.chat.api.threads.service import ThreadService - 
mock_message_service = Mock() - mock_run_service = Mock() - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(session=test_db_session) def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() + settings = Settings(data_dir=workspace_path) return RunService( - base_dir=workspace_path, + session=test_db_session, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, chat_history_manager=Mock(), - settings=Mock(), + settings=settings, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -876,9 +1048,8 @@ def override_runs_service() -> RunService: assert response.status_code == status.HTTP_200_OK data = response.json() assert data["id"] == "run_test123" - # The cancel operation sets tried_cancelling_at, making status - # "cancelling" - assert data["status"] == "cancelling" + # The cancel operation sets the status to "cancelled" + assert data["status"] == "cancelled" finally: app.dependency_overrides.clear() @@ -901,10 +1072,12 @@ def test_create_run_with_custom_assistant( ) -> None: """Test creating a run with a custom assistant.""" workspace_path = self._create_test_workspace() - self._create_test_thread(workspace_path) + workspace_id = UUID(test_headers["askui-workspace"]) + self._create_test_thread( + workspace_path, test_db_session=test_db_session, workspace_id=workspace_id + ) # Create a custom assistant in the database - workspace_id = UUID(test_headers["askui-workspace"]) custom_assistant = self._create_test_assistant( "asst_custom123", workspace_id=workspace_id, @@ -941,10 +1114,12 @@ def test_create_run_with_custom_assistant_empty_tools( ) -> None: """Test creating a run with a custom assistant that has empty tools.""" workspace_path = self._create_test_workspace() - self._create_test_thread(workspace_path) + workspace_id = 
UUID(test_headers["askui-workspace"]) + self._create_test_thread( + workspace_path, test_db_session=test_db_session, workspace_id=workspace_id + ) # Create a custom assistant with empty tools in the database - workspace_id = UUID(test_headers["askui-workspace"]) empty_tools_assistant = self._create_test_assistant( "asst_customempty123", workspace_id=workspace_id, diff --git a/tests/integration/chat/api/test_threads.py b/tests/integration/chat/api/test_threads.py index b2f525f3..e3cfbb72 100644 --- a/tests/integration/chat/api/test_threads.py +++ b/tests/integration/chat/api/test_threads.py @@ -1,15 +1,18 @@ """Integration tests for the threads API endpoints.""" -import tempfile -from pathlib import Path -from unittest.mock import Mock +from typing import TYPE_CHECKING +from uuid import UUID from fastapi import status from fastapi.testclient import TestClient +from sqlalchemy.orm import Session -from askui.chat.api.threads.models import Thread +from askui.chat.api.threads.models import ThreadCreate from askui.chat.api.threads.service import ThreadService +if TYPE_CHECKING: + from askui.chat.api.models import WorkspaceId + class TestThreadsAPI: """Test suite for the threads API endpoints.""" @@ -26,31 +29,26 @@ def test_list_threads_empty( assert data["data"] == [] assert data["has_more"] is False - def test_list_threads_with_threads(self, test_headers: dict[str, str]) -> None: + def test_list_threads_with_threads( + self, + test_db_session: Session, + test_headers: dict[str, str], + test_workspace_id: str, + ) -> None: """Test listing threads when threads exist.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - # Create a mock thread - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Test Thread", - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - from 
askui.chat.api.app import app from askui.chat.api.threads.dependencies import get_thread_service - # Mock the dependencies - mock_message_service = Mock() - mock_run_service = Mock() + thread_service = ThreadService(test_db_session) + workspace_id: WorkspaceId = UUID(test_workspace_id) + # Create a thread via the service + created_thread = thread_service.create( + workspace_id=workspace_id, + params=ThreadCreate(name="Test Thread"), + ) def override_thread_service() -> ThreadService: - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service @@ -62,39 +60,32 @@ def override_thread_service() -> ThreadService: data = response.json() assert data["object"] == "list" assert len(data["data"]) == 1 - assert data["data"][0]["id"] == "thread_test123" + assert data["data"][0]["id"] == created_thread.id assert data["data"][0]["name"] == "Test Thread" finally: app.dependency_overrides.clear() - def test_list_threads_with_pagination(self, test_headers: dict[str, str]) -> None: + def test_list_threads_with_pagination( + self, + test_db_session: Session, + test_headers: dict[str, str], + test_workspace_id: str, + ) -> None: """Test listing threads with pagination parameters.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - # Create multiple mock threads - for i in range(5): - mock_thread = Thread( - id=f"thread_test{i}", - object="thread", - created_at=1234567890 + i, - name=f"Test Thread {i}", - ) - (threads_dir / f"thread_test{i}.json").write_text( - mock_thread.model_dump_json() - ) - from askui.chat.api.app import app from askui.chat.api.threads.dependencies import get_thread_service - # Mock the dependencies - mock_message_service = Mock() - mock_run_service = Mock() + thread_service = ThreadService(test_db_session) + workspace_id: 
WorkspaceId = UUID(test_workspace_id) + # Create multiple threads via the service + for i in range(5): + thread_service.create( + workspace_id=workspace_id, + params=ThreadCreate(name=f"Test Thread {i}"), + ) def override_thread_service() -> ThreadService: - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service @@ -109,20 +100,15 @@ def override_thread_service() -> ThreadService: finally: app.dependency_overrides.clear() - def test_create_thread(self, test_headers: dict[str, str]) -> None: + def test_create_thread( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test creating a new thread.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - from askui.chat.api.app import app from askui.chat.api.threads.dependencies import get_thread_service - # Mock the dependencies - mock_message_service = Mock() - mock_run_service = Mock() - def override_thread_service() -> ThreadService: - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service @@ -144,20 +130,15 @@ def override_thread_service() -> ThreadService: finally: app.dependency_overrides.clear() - def test_create_thread_minimal(self, test_headers: dict[str, str]) -> None: + def test_create_thread_minimal( + self, test_db_session: Session, test_headers: dict[str, str] + ) -> None: """Test creating a thread with minimal data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - from askui.chat.api.app import app from askui.chat.api.threads.dependencies import get_thread_service - # Mock the dependencies - mock_message_service = Mock() - mock_run_service = Mock() - def override_thread_service() -> ThreadService: - return ThreadService(workspace_path, mock_message_service, mock_run_service) + 
return ThreadService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service @@ -172,42 +153,38 @@ def override_thread_service() -> ThreadService: finally: app.dependency_overrides.clear() - def test_retrieve_thread(self, test_headers: dict[str, str]) -> None: + def test_retrieve_thread( + self, + test_db_session: Session, + test_headers: dict[str, str], + test_workspace_id: str, + ) -> None: """Test retrieving an existing thread.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Test Thread", - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - from askui.chat.api.app import app from askui.chat.api.threads.dependencies import get_thread_service - # Mock the dependencies - mock_message_service = Mock() - mock_run_service = Mock() + thread_service = ThreadService(test_db_session) + workspace_id: WorkspaceId = UUID(test_workspace_id) + # Create a thread via the service + created_thread = thread_service.create( + workspace_id=workspace_id, + params=ThreadCreate(name="Test Thread"), + ) def override_thread_service() -> ThreadService: - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service try: with TestClient(app) as client: response = client.get( - "/v1/threads/thread_test123", headers=test_headers + f"/v1/threads/{created_thread.id}", headers=test_headers ) assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["id"] == "thread_test123" + assert data["id"] == created_thread.id assert data["name"] == "Test Thread" finally: app.dependency_overrides.clear() @@ -224,30 +201,26 @@ def test_retrieve_thread_not_found( data = 
response.json() assert "detail" in data - def test_modify_thread(self, test_headers: dict[str, str]) -> None: + def test_modify_thread( + self, + test_db_session: Session, + test_headers: dict[str, str], + test_workspace_id: str, + ) -> None: """Test modifying an existing thread.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Original Name", - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - from askui.chat.api.app import app from askui.chat.api.threads.dependencies import get_thread_service - # Mock the dependencies - mock_message_service = Mock() - mock_run_service = Mock() + thread_service = ThreadService(test_db_session) + workspace_id: WorkspaceId = UUID(test_workspace_id) + # Create a thread via the service + created_thread = thread_service.create( + workspace_id=workspace_id, + params=ThreadCreate(name="Original Name"), + ) def override_thread_service() -> ThreadService: - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service @@ -257,41 +230,41 @@ def override_thread_service() -> ThreadService: "name": "Modified Name", } response = client.post( - "/v1/threads/thread_test123", json=modify_data, headers=test_headers + f"/v1/threads/{created_thread.id}", + json=modify_data, + headers=test_headers, ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["name"] == "Modified Name" - assert data["id"] == "thread_test123" - assert data["created_at"] == 1234567890 + assert data["id"] == created_thread.id + # API returns Unix timestamp, convert datetime to timestamp for + # comparison + assert data["created_at"] == int(created_thread.created_at.timestamp()) 
finally: app.dependency_overrides.clear() - def test_modify_thread_partial(self, test_headers: dict[str, str]) -> None: + def test_modify_thread_partial( + self, + test_db_session: Session, + test_headers: dict[str, str], + test_workspace_id: str, + ) -> None: """Test modifying a thread with partial data.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Original Name", - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - from askui.chat.api.app import app from askui.chat.api.threads.dependencies import get_thread_service - # Mock the dependencies - mock_message_service = Mock() - mock_run_service = Mock() + thread_service = ThreadService(test_db_session) + workspace_id: WorkspaceId = UUID(test_workspace_id) + # Create a thread via the service + created_thread = thread_service.create( + workspace_id=workspace_id, + params=ThreadCreate(name="Original Name"), + ) def override_thread_service() -> ThreadService: - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service @@ -299,7 +272,9 @@ def override_thread_service() -> ThreadService: with TestClient(app) as client: modify_data = {"name": "Only Name Modified"} response = client.post( - "/v1/threads/thread_test123", json=modify_data, headers=test_headers + f"/v1/threads/{created_thread.id}", + json=modify_data, + headers=test_headers, ) assert response.status_code == status.HTTP_200_OK @@ -319,45 +294,33 @@ def test_modify_thread_not_found( assert response.status_code == status.HTTP_404_NOT_FOUND - def test_delete_thread(self, test_headers: dict[str, str]) -> None: + def test_delete_thread( + self, + test_db_session: Session, + test_headers: 
dict[str, str], + test_workspace_id: str, + ) -> None: """Test deleting an existing thread.""" - temp_dir = tempfile.mkdtemp() - workspace_path = Path(temp_dir) - threads_dir = workspace_path / "threads" - threads_dir.mkdir(parents=True, exist_ok=True) - - # Create the directories that the delete operation will try to remove - messages_dir = workspace_path / "messages" / "thread_test123" - messages_dir.mkdir(parents=True, exist_ok=True) - runs_dir = workspace_path / "runs" / "thread_test123" - runs_dir.mkdir(parents=True, exist_ok=True) - - mock_thread = Thread( - id="thread_test123", - object="thread", - created_at=1234567890, - name="Test Thread", - ) - (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) - from askui.chat.api.app import app from askui.chat.api.threads.dependencies import get_thread_service - # Mock the dependencies with proper return values - mock_message_service = Mock() - mock_message_service.get_messages_dir.return_value = messages_dir - mock_run_service = Mock() - mock_run_service.get_runs_dir.return_value = runs_dir + thread_service = ThreadService(test_db_session) + workspace_id: WorkspaceId = UUID(test_workspace_id) + # Create a thread via the service + created_thread = thread_service.create( + workspace_id=workspace_id, + params=ThreadCreate(name="Test Thread"), + ) def override_thread_service() -> ThreadService: - return ThreadService(workspace_path, mock_message_service, mock_run_service) + return ThreadService(test_db_session) app.dependency_overrides[get_thread_service] = override_thread_service try: with TestClient(app) as client: response = client.delete( - "/v1/threads/thread_test123", headers=test_headers + f"/v1/threads/{created_thread.id}", headers=test_headers ) assert response.status_code == status.HTTP_204_NO_CONTENT diff --git a/tests/integration/tools/testing/test_execution_tools.py b/tests/integration/tools/testing/test_execution_tools.py index 791904ca..07c64da2 100644 --- 
a/tests/integration/tools/testing/test_execution_tools.py +++ b/tests/integration/tools/testing/test_execution_tools.py @@ -104,7 +104,12 @@ def test_retrieve_execution( execution_obj = _create_execution(scenario, feature.id) execution = create_tool(execution_obj) retrieved = retrieve_tool(execution.id) - assert retrieved.model_dump() == execution.model_dump() + # Compare excluding created_at due to potential timestamp precision differences + exec_dict = execution.model_dump(exclude={"created_at"}) + ret_dict = retrieved.model_dump(exclude={"created_at"}) + assert ret_dict == exec_dict + # Verify created_at timestamps are close (within 1 second) + assert abs((retrieved.created_at - execution.created_at).total_seconds()) < 1 @pytest.mark.parametrize( @@ -199,7 +204,12 @@ def test_modify_execution_noop( ) -> None: execution = create_tool(_create_execution(scenario, feature.id)) modified = modify_tool(execution.id, ExecutionModifyParams()) - assert modified.model_dump() == execution.model_dump() + # Compare excluding created_at due to potential timestamp precision differences + exec_dict = execution.model_dump(exclude={"created_at"}) + mod_dict = modified.model_dump(exclude={"created_at"}) + assert mod_dict == exec_dict + # Verify created_at timestamps are close (within 1 second) + assert abs((modified.created_at - execution.created_at).total_seconds()) < 1 def test_delete_execution( diff --git a/tests/integration/tools/testing/test_feature_tools.py b/tests/integration/tools/testing/test_feature_tools.py index f406d98e..203de8fb 100644 --- a/tests/integration/tools/testing/test_feature_tools.py +++ b/tests/integration/tools/testing/test_feature_tools.py @@ -98,7 +98,12 @@ def test_retrieve_feature( ) -> None: feature = create_tool(_create_params()) retrieved = retrieve_tool(feature.id) - assert retrieved.model_dump() == feature.model_dump() + # Compare excluding created_at due to potential timestamp precision differences + feat_dict = 
feature.model_dump(exclude={"created_at"}) + ret_dict = retrieved.model_dump(exclude={"created_at"}) + assert ret_dict == feat_dict + # Verify created_at timestamps are close (within 1 second) + assert abs((retrieved.created_at - feature.created_at).total_seconds()) < 1 @pytest.mark.parametrize( @@ -175,7 +180,12 @@ def test_modify_feature_noop( ) -> None: feature = create_tool(_create_params()) modified = modify_tool(feature.id, FeatureModifyParams()) - assert modified.model_dump() == feature.model_dump() + # Compare excluding created_at due to potential timestamp precision differences + feat_dict = feature.model_dump(exclude={"created_at"}) + mod_dict = modified.model_dump(exclude={"created_at"}) + assert mod_dict == feat_dict + # Verify created_at timestamps are close (within 1 second) + assert abs((modified.created_at - feature.created_at).total_seconds()) < 1 def test_delete_feature( diff --git a/tests/integration/tools/testing/test_scenario_tools.py b/tests/integration/tools/testing/test_scenario_tools.py index 541bcfd4..427b18db 100644 --- a/tests/integration/tools/testing/test_scenario_tools.py +++ b/tests/integration/tools/testing/test_scenario_tools.py @@ -145,7 +145,12 @@ def test_retrieve_scenario( ) -> None: scenario = create_tool(_create_params(feature.id)) retrieved = retrieve_tool(scenario.id) - assert retrieved.model_dump() == scenario.model_dump() + # Compare excluding created_at due to potential timestamp precision differences + scen_dict = scenario.model_dump(exclude={"created_at"}) + ret_dict = retrieved.model_dump(exclude={"created_at"}) + assert ret_dict == scen_dict + # Verify created_at timestamps are close (within 1 second) + assert abs((retrieved.created_at - scenario.created_at).total_seconds()) < 1 @pytest.mark.parametrize( @@ -242,7 +247,12 @@ def test_modify_scenario_noop( ) -> None: scenario = create_tool(_create_params(feature.id)) modified = modify_tool(scenario.id, ScenarioModifyParams()) - assert modified.model_dump() == 
scenario.model_dump() + # Compare excluding created_at due to potential timestamp precision differences + scen_dict = scenario.model_dump(exclude={"created_at"}) + mod_dict = modified.model_dump(exclude={"created_at"}) + assert mod_dict == scen_dict + # Verify created_at timestamps are close (within 1 second) + assert abs((modified.created_at - scenario.created_at).total_seconds()) < 1 def test_delete_scenario( diff --git a/tests/unit/test_request_document_translator.py b/tests/unit/test_request_document_translator.py index 93130a89..b5fea6a2 100644 --- a/tests/unit/test_request_document_translator.py +++ b/tests/unit/test_request_document_translator.py @@ -26,7 +26,7 @@ def translator( self, file_service: MagicMock ) -> RequestDocumentBlockParamTranslator: """Create translator instance.""" - return RequestDocumentBlockParamTranslator(file_service) + return RequestDocumentBlockParamTranslator(file_service, None) @pytest.fixture def cache_control(self) -> CacheControlEphemeralParam: @@ -35,7 +35,7 @@ def cache_control(self) -> CacheControlEphemeralParam: def test_init(self, file_service: MagicMock) -> None: """Test translator initialization.""" - translator = RequestDocumentBlockParamTranslator(file_service) + translator = RequestDocumentBlockParamTranslator(file_service, None) assert translator._file_service == file_service @pytest.mark.asyncio From 9b2de657a9c0e06a58b886fbfe606622535c8443 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Mon, 3 Nov 2025 11:11:58 +0100 Subject: [PATCH 12/14] fix(chat/migrations): fix migrations - remove buggy soft deletes --> we do not yet differentiate between what has been deleted and what is still used but rolling back now is as easy as just installing an older version --> we can differentiate later (maybe even doing a hard delete later instead of a soft delete) - add and switch to debug logging statements for all control structures --- .../057f82313448_import_json_assistants.py | 12 ++ 
.../1a2b3c4d5e6f_create_threads_table.py | 4 +- .../2b3c4d5e6f7a_create_messages_table.py | 4 +- ...37007a499ca7_soft_delete_assistants_dir.py | 87 -------------- .../3c4d5e6f7a8b_create_runs_table.py | 4 +- .../4d5e6f7a8b9c_import_json_threads.py | 24 +++- .../5a1b2c3d4e5f_create_mcp_configs_table.py | 4 +- .../5e6f7a8b9c0d_import_json_messages.py | 29 ++++- .../6b2c3d4e5f6a_import_json_mcp_configs.py | 12 ++ .../versions/6f7a8b9c0d1e_import_json_runs.py | 29 ++++- .../7a8b9c0d1e2f_soft_delete_threads_dirs.py | 97 ---------------- ...c3d4e5f6a7b_soft_delete_mcp_configs_dir.py | 86 -------------- .../8b9c0d1e2f3a_soft_delete_messages_dirs.py | 97 ---------------- .../8d9e0f1a2b3c_create_files_table.py | 4 +- .../9c0d1e2f3a4b_soft_delete_runs_dirs.py | 97 ---------------- .../9e0f1a2b3c4d_import_json_files.py | 20 ++++ .../a0f1a2b3c4d5_soft_delete_files_dirs.py | 108 ------------------ .../c35e88ea9595_seed_default_assistants.py | 2 +- 18 files changed, 131 insertions(+), 589 deletions(-) delete mode 100644 src/askui/chat/migrations/versions/37007a499ca7_soft_delete_assistants_dir.py delete mode 100644 src/askui/chat/migrations/versions/7a8b9c0d1e2f_soft_delete_threads_dirs.py delete mode 100644 src/askui/chat/migrations/versions/7c3d4e5f6a7b_soft_delete_mcp_configs_dir.py delete mode 100644 src/askui/chat/migrations/versions/8b9c0d1e2f3a_soft_delete_messages_dirs.py delete mode 100644 src/askui/chat/migrations/versions/9c0d1e2f3a4b_soft_delete_runs_dirs.py delete mode 100644 src/askui/chat/migrations/versions/a0f1a2b3c4d5_soft_delete_files_dirs.py diff --git a/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py b/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py index 49e73660..2b931f1c 100644 --- a/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py +++ b/src/askui/chat/migrations/versions/057f82313448_import_json_assistants.py @@ -33,6 +33,7 @@ def _insert_assistants_batch( ) -> None: """Insert a 
batch of assistants into the database, ignoring conflicts.""" if not assistants_batch: + logger.info("No assistants to insert, skipping batch") return connection.execute( @@ -50,6 +51,10 @@ def upgrade() -> None: # Skip if directory doesn't exist (e.g., first-time setup) if not assistants_dir.exists(): + logger.info( + "Assistants directory does not exist, skipping import of assistants", + extra={"assistants_dir": str(assistants_dir)}, + ) return # Get the table from the current database schema @@ -94,6 +99,9 @@ def downgrade() -> None: result = connection.execute(assistants_table.select()) rows = result.fetchall() if not rows: + logger.info( + "No assistants found in the database, skipping export of rows to json", + ) return for row in rows: @@ -103,6 +111,10 @@ def downgrade() -> None: ) json_path = assistants_dir / f"{assistant.id}.json" if json_path.exists(): + logger.info( + "Json file for assistant already exists, skipping export of row to json", + extra={"assistant_id": assistant.id, "json_path": str(json_path)}, + ) continue with json_path.open("w", encoding="utf-8") as f: f.write(assistant.model_dump_json()) diff --git a/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py b/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py index 7bb7c9b8..011676bf 100644 --- a/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py +++ b/src/askui/chat/migrations/versions/1a2b3c4d5e6f_create_threads_table.py @@ -1,7 +1,7 @@ """create_threads_table Revision ID: 1a2b3c4d5e6f -Revises: a0f1a2b3c4d5 +Revises: 9e0f1a2b3c4d Create Date: 2025-01-27 12:00:00.000000 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. 
revision: str = "1a2b3c4d5e6f" -down_revision: Union[str, None] = "a0f1a2b3c4d5" +down_revision: Union[str, None] = "9e0f1a2b3c4d" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py b/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py index 68eb1edc..16a618c4 100644 --- a/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py +++ b/src/askui/chat/migrations/versions/2b3c4d5e6f7a_create_messages_table.py @@ -1,7 +1,7 @@ """create_messages_table Revision ID: 2b3c4d5e6f7a -Revises: 1a2b3c4d5e6f +Revises: 6f7a8b9c0d1e Create Date: 2025-01-27 12:01:00.000000 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. revision: str = "2b3c4d5e6f7a" -down_revision: Union[str, None] = "1a2b3c4d5e6f" +down_revision: Union[str, None] = "6f7a8b9c0d1e" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/src/askui/chat/migrations/versions/37007a499ca7_soft_delete_assistants_dir.py b/src/askui/chat/migrations/versions/37007a499ca7_soft_delete_assistants_dir.py deleted file mode 100644 index 3a6435ff..00000000 --- a/src/askui/chat/migrations/versions/37007a499ca7_soft_delete_assistants_dir.py +++ /dev/null @@ -1,87 +0,0 @@ -"""soft_delete_assistants_dir - -Revision ID: 37007a499ca7 -Revises: c35e88ea9595 -Create Date: 2025-10-10 14:01:53.410908 - -""" - -import logging -from pathlib import Path -from typing import Sequence, Union - -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. 
-revision: str = "37007a499ca7" -down_revision: Union[str, None] = "c35e88ea9595" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - -settings = SettingsV1() -assistants_dir = settings.data_dir / "assistants" - - -def upgrade() -> None: - """Soft delete the assistants directory by moving it to .deleted subdirectory.""" - - # Skip if directory doesn't exist - if not assistants_dir.exists(): - logger.info("Assistants directory does not exist, skipping soft delete") - return - - try: - # Create .deleted directory if it doesn't exist - deleted_dir = settings.data_dir / ".deleted" - deleted_dir.mkdir(parents=True, exist_ok=True) - - # Move assistants directory to .deleted subdirectory - deleted_assistants_dir = deleted_dir / "assistants" - if deleted_assistants_dir.exists(): - logger.info( - "Deleted assistants directory already exists, skipping soft delete", - extra={"deleted_assistants_dir": str(deleted_assistants_dir)}, - ) - return - - assistants_dir.rename(deleted_assistants_dir) - logger.info( - "Successfully soft deleted assistants directory", - extra={ - "assistants_dir": str(assistants_dir), - "deleted_assistants_dir": str(deleted_assistants_dir), - }, - ) - except Exception as e: - error_msg = "Failed to soft delete assistants directory" - logger.exception( - error_msg, - extra={"assistants_dir": str(assistants_dir)}, - ) - raise RuntimeError(error_msg) from e - - -def downgrade() -> None: - """Restore the assistants directory from .deleted subdirectory.""" - deleted_dir = settings.data_dir / ".deleted" - deleted_assistants_dir = deleted_dir / "assistants" - - if not deleted_assistants_dir.exists(): - logger.info("No deleted assistants directory found to restore") - return - - try: - deleted_assistants_dir.rename(assistants_dir) - logger.info( - "Successfully restored assistants directory", - extra={"assistants_dir": str(assistants_dir)}, - ) - except Exception as 
e: - error_msg = "Failed to restore assistants directory" - logger.exception( - error_msg, - extra={"assistants_dir": str(assistants_dir)}, - ) - raise RuntimeError(error_msg) from e diff --git a/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py b/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py index 72eab784..668c8e22 100644 --- a/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py +++ b/src/askui/chat/migrations/versions/3c4d5e6f7a8b_create_runs_table.py @@ -1,7 +1,7 @@ """create_runs_table Revision ID: 3c4d5e6f7a8b -Revises: 2b3c4d5e6f7a +Revises: 4d5e6f7a8b9c Create Date: 2025-01-27 12:02:00.000000 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. revision: str = "3c4d5e6f7a8b" -down_revision: Union[str, None] = "2b3c4d5e6f7a" +down_revision: Union[str, None] = "4d5e6f7a8b9c" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py b/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py index b1b4a9fb..deaf3d1a 100644 --- a/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py +++ b/src/askui/chat/migrations/versions/4d5e6f7a8b9c_import_json_threads.py @@ -1,7 +1,7 @@ """import_json_threads Revision ID: 4d5e6f7a8b9c -Revises: 3c4d5e6f7a8b +Revises: 1a2b3c4d5e6f Create Date: 2025-01-27 12:03:00.000000 """ @@ -18,7 +18,7 @@ # revision identifiers, used by Alembic. 
revision: str = "4d5e6f7a8b9c" -down_revision: Union[str, None] = "3c4d5e6f7a8b" +down_revision: Union[str, None] = "1a2b3c4d5e6f" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -33,6 +33,7 @@ def _insert_threads_batch( ) -> None: """Insert a batch of threads into the database, ignoring conflicts.""" if not threads_batch: + logger.info("No threads to insert, skipping batch") return connection.execute( @@ -50,6 +51,10 @@ def upgrade() -> None: # noqa: C901 # Skip if workspaces directory doesn't exist (e.g., first-time setup) if not workspaces_dir.exists(): + logger.info( + "Workspaces directory does not exist, skipping import of threads", + extra={"workspaces_dir": str(workspaces_dir)}, + ) return # Get the table from the current database schema @@ -62,12 +67,20 @@ def upgrade() -> None: # noqa: C901 # Iterate through all workspace directories for workspace_dir in workspaces_dir.iterdir(): if not workspace_dir.is_dir(): + logger.info( + "Skipping non-directory in workspaces", + extra={"path": str(workspace_dir)}, + ) continue workspace_id = workspace_dir.name threads_dir = workspace_dir / "threads" if not threads_dir.exists(): + logger.info( + "Threads directory does not exist, skipping workspace", + extra={"workspace_id": workspace_id, "threads_dir": str(threads_dir)}, + ) continue # Get all JSON files in the threads directory @@ -102,6 +115,9 @@ def downgrade() -> None: result = connection.execute(threads_table.select()) rows = result.fetchall() if not rows: + logger.info( + "No threads found in the database, skipping export of rows to json", + ) return for row in rows: @@ -111,6 +127,10 @@ def downgrade() -> None: threads_dir.mkdir(parents=True, exist_ok=True) json_path = threads_dir / f"{thread_model.id}.json" if json_path.exists(): + logger.info( + "Json file for thread already exists, skipping export of row to json", + extra={"thread_id": thread_model.id, "json_path": str(json_path)}, + ) continue 
with json_path.open("w", encoding="utf-8") as f: f.write(thread_model.model_dump_json()) diff --git a/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py b/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py index c2a2e6ba..7c3c2001 100644 --- a/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py +++ b/src/askui/chat/migrations/versions/5a1b2c3d4e5f_create_mcp_configs_table.py @@ -1,7 +1,7 @@ """create_mcp_configs_table Revision ID: 5a1b2c3d4e5f -Revises: 37007a499ca7 +Revises: c35e88ea9595 Create Date: 2025-01-27 10:00:00.000000 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. revision: str = "5a1b2c3d4e5f" -down_revision: Union[str, None] = "37007a499ca7" +down_revision: Union[str, None] = "c35e88ea9595" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py b/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py index d5602ae4..6312c65f 100644 --- a/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py +++ b/src/askui/chat/migrations/versions/5e6f7a8b9c0d_import_json_messages.py @@ -1,7 +1,7 @@ """import_json_messages Revision ID: 5e6f7a8b9c0d -Revises: 4d5e6f7a8b9c +Revises: 2b3c4d5e6f7a Create Date: 2025-01-27 12:04:00.000000 """ @@ -18,7 +18,7 @@ # revision identifiers, used by Alembic. 
revision: str = "5e6f7a8b9c0d" -down_revision: Union[str, None] = "4d5e6f7a8b9c" +down_revision: Union[str, None] = "2b3c4d5e6f7a" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -33,6 +33,7 @@ def _insert_messages_batch( ) -> None: """Insert a batch of messages into the database, handling foreign key violations.""" if not messages_batch: + logger.info("No messages to insert, skipping batch") return # Validate and fix foreign key references @@ -56,6 +57,7 @@ def _validate_and_fix_foreign_keys( # noqa: C901 - If run_id is invalid: set to None """ if not messages_batch: + logger.info("Empty message batch, nothing to validate") return [] # Extract all foreign key values @@ -180,6 +182,10 @@ def upgrade() -> None: # noqa: C901 """Import existing messages from JSON files in workspace directories.""" # Skip if workspaces directory doesn't exist (e.g., first-time setup) if not workspaces_dir.exists(): + logger.info( + "Workspaces directory does not exist, skipping import of messages", + extra={"workspaces_dir": str(workspaces_dir)}, + ) return # Get the table from the current database schema @@ -192,17 +198,29 @@ def upgrade() -> None: # noqa: C901 # Iterate through all workspace directories for workspace_dir in workspaces_dir.iterdir(): if not workspace_dir.is_dir(): + logger.info( + "Skipping non-directory in workspaces", + extra={"path": str(workspace_dir)}, + ) continue workspace_id = workspace_dir.name messages_dir = workspace_dir / "messages" if not messages_dir.exists(): + logger.info( + "Messages directory does not exist, skipping workspace", + extra={"workspace_id": workspace_id, "messages_dir": str(messages_dir)}, + ) continue # Iterate through thread directories for thread_dir in messages_dir.iterdir(): if not thread_dir.is_dir(): + logger.info( + "Skipping non-directory in messages", + extra={"path": str(thread_dir)}, + ) continue # Get all JSON files in the thread directory @@ -241,6 +259,9 @@ def 
downgrade() -> None: result = connection.execute(messages_table.select()) rows = result.fetchall() if not rows: + logger.info( + "No messages found in the database, skipping export of rows to json", + ) return for row in rows: @@ -257,6 +278,10 @@ def downgrade() -> None: messages_dir.mkdir(parents=True, exist_ok=True) json_path = messages_dir / f"{message_model.id}.json" if json_path.exists(): + logger.info( + "Json file for message already exists, skipping export of row to json", + extra={"message_id": message_model.id, "json_path": str(json_path)}, + ) continue with json_path.open("w", encoding="utf-8") as f: f.write(message_model.model_dump_json()) diff --git a/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py b/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py index 69dd2b6b..ca11ef15 100644 --- a/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py +++ b/src/askui/chat/migrations/versions/6b2c3d4e5f6a_import_json_mcp_configs.py @@ -35,6 +35,7 @@ def _insert_mcp_configs_batch( ) -> None: """Insert a batch of MCP configs into the database, ignoring conflicts.""" if not mcp_configs_batch: + logger.info("No MCP configs to insert, skipping batch") return connection.execute( @@ -52,6 +53,10 @@ def upgrade() -> None: # Skip if directory doesn't exist (e.g., first-time setup) if not mcp_configs_dir.exists(): + logger.info( + "MCP configs directory does not exist, skipping import of MCP configs", + extra={"mcp_configs_dir": str(mcp_configs_dir)}, + ) return # Get the table from the current database schema @@ -97,6 +102,9 @@ def downgrade() -> None: result = connection.execute(mcp_configs_table.select()) rows = result.fetchall() if not rows: + logger.info( + "No MCP configs found in the database, skipping export of rows to json", + ) return for row in rows: @@ -106,6 +114,10 @@ def downgrade() -> None: ) json_path = mcp_configs_dir / f"{mcp_config.id}.json" if json_path.exists(): + logger.info( + 
"Json file for mcp config already exists, skipping export of row to json", + extra={"mcp_config_id": mcp_config.id, "json_path": str(json_path)}, + ) continue with json_path.open("w", encoding="utf-8") as f: f.write(mcp_config.model_dump_json()) diff --git a/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py b/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py index f0873662..6c63c2fa 100644 --- a/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py +++ b/src/askui/chat/migrations/versions/6f7a8b9c0d1e_import_json_runs.py @@ -1,7 +1,7 @@ """import_json_runs Revision ID: 6f7a8b9c0d1e -Revises: 5e6f7a8b9c0d +Revises: 3c4d5e6f7a8b Create Date: 2025-01-27 12:05:00.000000 """ @@ -18,7 +18,7 @@ # revision identifiers, used by Alembic. revision: str = "6f7a8b9c0d1e" -down_revision: Union[str, None] = "5e6f7a8b9c0d" +down_revision: Union[str, None] = "3c4d5e6f7a8b" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -33,6 +33,7 @@ def _insert_runs_batch( ) -> None: """Insert a batch of runs into the database, handling foreign key violations.""" if not runs_batch: + logger.info("No runs to insert, skipping batch") return # Validate and fix foreign key references @@ -55,6 +56,7 @@ def _validate_and_fix_foreign_keys( # noqa: C901 - If assistant_id is invalid: set to None """ if not runs_batch: + logger.info("Empty run batch, nothing to validate") return [] # Extract all foreign key values @@ -162,6 +164,10 @@ def upgrade() -> None: # noqa: C901 # Skip if workspaces directory doesn't exist (e.g., first-time setup) if not workspaces_dir.exists(): + logger.info( + "Workspaces directory does not exist, skipping import of runs", + extra={"workspaces_dir": str(workspaces_dir)}, + ) return # Get the table from the current database schema @@ -174,17 +180,29 @@ def upgrade() -> None: # noqa: C901 # Iterate through all workspace directories for workspace_dir in 
workspaces_dir.iterdir(): if not workspace_dir.is_dir(): + logger.info( + "Skipping non-directory in workspaces", + extra={"path": str(workspace_dir)}, + ) continue workspace_id = workspace_dir.name runs_dir = workspace_dir / "runs" if not runs_dir.exists(): + logger.info( + "Runs directory does not exist, skipping workspace", + extra={"workspace_id": workspace_id, "runs_dir": str(runs_dir)}, + ) continue # Iterate through thread directories for thread_dir in runs_dir.iterdir(): if not thread_dir.is_dir(): + logger.info( + "Skipping non-directory in runs", + extra={"path": str(thread_dir)}, + ) continue # Get all JSON files in the thread directory @@ -219,6 +237,9 @@ def downgrade() -> None: result = connection.execute(runs_table.select()) rows = result.fetchall() if not rows: + logger.info( + "No runs found in the database, skipping export of rows to json", + ) return for row in rows: @@ -233,6 +254,10 @@ def downgrade() -> None: runs_dir.mkdir(parents=True, exist_ok=True) json_path = runs_dir / f"{run_model.id}.json" if json_path.exists(): + logger.info( + "Json file for run already exists, skipping export of row to json", + extra={"run_id": run_model.id, "json_path": str(json_path)}, + ) continue with json_path.open("w", encoding="utf-8") as f: f.write(run_model.model_dump_json()) diff --git a/src/askui/chat/migrations/versions/7a8b9c0d1e2f_soft_delete_threads_dirs.py b/src/askui/chat/migrations/versions/7a8b9c0d1e2f_soft_delete_threads_dirs.py deleted file mode 100644 index 8c351eb8..00000000 --- a/src/askui/chat/migrations/versions/7a8b9c0d1e2f_soft_delete_threads_dirs.py +++ /dev/null @@ -1,97 +0,0 @@ -"""soft_delete_threads_dirs - -Revision ID: 7a8b9c0d1e2f -Revises: 6f7a8b9c0d1e -Create Date: 2025-01-27 12:06:00.000000 - -""" - -import logging -from typing import Sequence, Union - -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. 
-revision: str = "7a8b9c0d1e2f" -down_revision: Union[str, None] = "6f7a8b9c0d1e" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - -settings = SettingsV1() -workspaces_dir = settings.data_dir / "workspaces" - - -def upgrade() -> None: - """Soft delete threads directories by moving them to .deleted subdirectory.""" - - # Skip if workspaces directory doesn't exist (e.g., first-time setup) - if not workspaces_dir.exists(): - return - - # Create .deleted directory if it doesn't exist - deleted_dir = settings.data_dir / ".deleted" - deleted_dir.mkdir(parents=True, exist_ok=True) - - # Soft delete threads directories from all workspaces - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - continue - - threads_dir = workspace_dir / "threads" - if threads_dir.exists(): - try: - # Create workspace-specific deleted directory - deleted_workspace_dir = deleted_dir / workspace_dir.name - deleted_workspace_dir.mkdir(parents=True, exist_ok=True) - - # Move threads directory to .deleted subdirectory - deleted_threads_dir = deleted_workspace_dir / "threads" - if deleted_threads_dir.exists(): - logger.info( - "Deleted threads directory already exists, skipping soft delete", - extra={"deleted_threads_dir": str(deleted_threads_dir)}, - ) - continue - - threads_dir.rename(deleted_threads_dir) - logger.info( - "Soft deleted threads directory", - extra={ - "threads_dir": str(threads_dir), - "deleted_threads_dir": str(deleted_threads_dir), - }, - ) - except Exception as e: - error_msg = f"Failed to soft delete threads directory: {threads_dir}" - logger.exception(error_msg, exc_info=e) - - -def downgrade() -> None: - """Restore threads directories from .deleted subdirectory.""" - deleted_dir = settings.data_dir / ".deleted" - - if not deleted_dir.exists(): - logger.info("No .deleted directory found to restore from") - return - - # Restore threads directories 
for all workspaces - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - continue - - deleted_workspace_dir = deleted_dir / workspace_dir.name - deleted_threads_dir = deleted_workspace_dir / "threads" - - if deleted_threads_dir.exists(): - threads_dir = workspace_dir / "threads" - try: - deleted_threads_dir.rename(threads_dir) - logger.info( - "Restored threads directory", - extra={"threads_dir": str(threads_dir)}, - ) - except Exception as e: - error_msg = f"Failed to restore threads directory: {threads_dir}" - logger.exception(error_msg, exc_info=e) diff --git a/src/askui/chat/migrations/versions/7c3d4e5f6a7b_soft_delete_mcp_configs_dir.py b/src/askui/chat/migrations/versions/7c3d4e5f6a7b_soft_delete_mcp_configs_dir.py deleted file mode 100644 index c683e87d..00000000 --- a/src/askui/chat/migrations/versions/7c3d4e5f6a7b_soft_delete_mcp_configs_dir.py +++ /dev/null @@ -1,86 +0,0 @@ -"""soft_delete_mcp_configs_dir - -Revision ID: 7c3d4e5f6a7b -Revises: 6b2c3d4e5f6a -Create Date: 2025-01-27 10:02:00.000000 - -""" - -import logging -from typing import Sequence, Union - -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. 
-revision: str = "7c3d4e5f6a7b" -down_revision: Union[str, None] = "6b2c3d4e5f6a" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - -settings = SettingsV1() -mcp_configs_dir = settings.data_dir / "mcp_configs" - - -def upgrade() -> None: - """Soft delete the mcp_configs directory by moving it to .deleted subdirectory.""" - - # Skip if directory doesn't exist - if not mcp_configs_dir.exists(): - logger.info("MCP configs directory does not exist, skipping soft delete") - return - - try: - # Create .deleted directory if it doesn't exist - deleted_dir = settings.data_dir / ".deleted" - deleted_dir.mkdir(parents=True, exist_ok=True) - - # Move mcp_configs directory to .deleted subdirectory - deleted_mcp_configs_dir = deleted_dir / "mcp_configs" - if deleted_mcp_configs_dir.exists(): - logger.info( - "Deleted mcp_configs directory already exists, skipping soft delete", - extra={"deleted_mcp_configs_dir": str(deleted_mcp_configs_dir)}, - ) - return - - mcp_configs_dir.rename(deleted_mcp_configs_dir) - logger.info( - "Successfully soft deleted mcp_configs directory", - extra={ - "mcp_configs_dir": str(mcp_configs_dir), - "deleted_mcp_configs_dir": str(deleted_mcp_configs_dir), - }, - ) - except Exception as e: - error_msg = "Failed to soft delete mcp_configs directory" - logger.exception( - error_msg, - extra={"mcp_configs_dir": str(mcp_configs_dir)}, - ) - raise RuntimeError(error_msg) from e - - -def downgrade() -> None: - """Restore the mcp_configs directory from .deleted subdirectory.""" - deleted_dir = settings.data_dir / ".deleted" - deleted_mcp_configs_dir = deleted_dir / "mcp_configs" - - if not deleted_mcp_configs_dir.exists(): - logger.info("No deleted mcp_configs directory found to restore") - return - - try: - deleted_mcp_configs_dir.rename(mcp_configs_dir) - logger.info( - "Successfully restored mcp_configs directory", - extra={"mcp_configs_dir": 
str(mcp_configs_dir)}, - ) - except Exception as e: - error_msg = "Failed to restore mcp_configs directory" - logger.exception( - error_msg, - extra={"mcp_configs_dir": str(mcp_configs_dir)}, - ) - raise RuntimeError(error_msg) from e diff --git a/src/askui/chat/migrations/versions/8b9c0d1e2f3a_soft_delete_messages_dirs.py b/src/askui/chat/migrations/versions/8b9c0d1e2f3a_soft_delete_messages_dirs.py deleted file mode 100644 index 2ce14b78..00000000 --- a/src/askui/chat/migrations/versions/8b9c0d1e2f3a_soft_delete_messages_dirs.py +++ /dev/null @@ -1,97 +0,0 @@ -"""soft_delete_messages_dirs - -Revision ID: 8b9c0d1e2f3a -Revises: 7a8b9c0d1e2f -Create Date: 2025-01-27 12:07:00.000000 - -""" - -import logging -from typing import Sequence, Union - -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. -revision: str = "8b9c0d1e2f3a" -down_revision: Union[str, None] = "7a8b9c0d1e2f" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - -settings = SettingsV1() -workspaces_dir = settings.data_dir / "workspaces" - - -def upgrade() -> None: - """Soft delete messages directories by moving them to .deleted subdirectory.""" - - # Skip if workspaces directory doesn't exist (e.g., first-time setup) - if not workspaces_dir.exists(): - return - - # Create .deleted directory if it doesn't exist - deleted_dir = settings.data_dir / ".deleted" - deleted_dir.mkdir(parents=True, exist_ok=True) - - # Soft delete messages directories from all workspaces - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - continue - - messages_dir = workspace_dir / "messages" - if messages_dir.exists(): - try: - # Create workspace-specific deleted directory - deleted_workspace_dir = deleted_dir / workspace_dir.name - deleted_workspace_dir.mkdir(parents=True, exist_ok=True) - - # Move messages directory to .deleted subdirectory - 
deleted_messages_dir = deleted_workspace_dir / "messages" - if deleted_messages_dir.exists(): - logger.info( - "Deleted messages directory already exists, skipping soft delete", - extra={"deleted_messages_dir": str(deleted_messages_dir)}, - ) - continue - - messages_dir.rename(deleted_messages_dir) - logger.info( - "Soft deleted messages directory", - extra={ - "messages_dir": str(messages_dir), - "deleted_messages_dir": str(deleted_messages_dir), - }, - ) - except Exception as e: - error_msg = f"Failed to soft delete messages directory: {messages_dir}" - logger.exception(error_msg, exc_info=e) - - -def downgrade() -> None: - """Restore messages directories from .deleted subdirectory.""" - deleted_dir = settings.data_dir / ".deleted" - - if not deleted_dir.exists(): - logger.info("No .deleted directory found to restore from") - return - - # Restore messages directories for all workspaces - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - continue - - deleted_workspace_dir = deleted_dir / workspace_dir.name - deleted_messages_dir = deleted_workspace_dir / "messages" - - if deleted_messages_dir.exists(): - try: - messages_dir = workspace_dir / "messages" - deleted_messages_dir.rename(messages_dir) - logger.info( - "Restored messages directory", - extra={"messages_dir": str(messages_dir)}, - ) - except Exception as e: - error_msg = f"Failed to restore messages directory: {messages_dir}" - logger.exception(error_msg, exc_info=e) diff --git a/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py b/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py index 932e1a6e..c227e0c8 100644 --- a/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py +++ b/src/askui/chat/migrations/versions/8d9e0f1a2b3c_create_files_table.py @@ -1,7 +1,7 @@ """create_files_table Revision ID: 8d9e0f1a2b3c -Revises: 7c3d4e5f6a7b +Revises: 6b2c3d4e5f6a Create Date: 2025-01-27 11:00:00.000000 """ @@ -13,7 +13,7 @@ # 
revision identifiers, used by Alembic. revision: str = "8d9e0f1a2b3c" -down_revision: Union[str, None] = "7c3d4e5f6a7b" +down_revision: Union[str, None] = "6b2c3d4e5f6a" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/src/askui/chat/migrations/versions/9c0d1e2f3a4b_soft_delete_runs_dirs.py b/src/askui/chat/migrations/versions/9c0d1e2f3a4b_soft_delete_runs_dirs.py deleted file mode 100644 index 3fde9bea..00000000 --- a/src/askui/chat/migrations/versions/9c0d1e2f3a4b_soft_delete_runs_dirs.py +++ /dev/null @@ -1,97 +0,0 @@ -"""soft_delete_runs_dirs - -Revision ID: 9c0d1e2f3a4b -Revises: 8b9c0d1e2f3a -Create Date: 2025-01-27 12:08:00.000000 - -""" - -import logging -from typing import Sequence, Union - -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. -revision: str = "9c0d1e2f3a4b" -down_revision: Union[str, None] = "8b9c0d1e2f3a" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - -settings = SettingsV1() -workspaces_dir = settings.data_dir / "workspaces" - - -def upgrade() -> None: - """Soft delete runs directories by moving them to .deleted subdirectory.""" - - # Skip if workspaces directory doesn't exist (e.g., first-time setup) - if not workspaces_dir.exists(): - return - - # Create .deleted directory if it doesn't exist - deleted_dir = settings.data_dir / ".deleted" - deleted_dir.mkdir(parents=True, exist_ok=True) - - # Soft delete runs directories from all workspaces - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - continue - - runs_dir = workspace_dir / "runs" - if runs_dir.exists(): - try: - # Create workspace-specific deleted directory - deleted_workspace_dir = deleted_dir / workspace_dir.name - deleted_workspace_dir.mkdir(parents=True, exist_ok=True) - - # Move runs directory to .deleted subdirectory - 
deleted_runs_dir = deleted_workspace_dir / "runs" - if deleted_runs_dir.exists(): - logger.info( - "Deleted runs directory already exists, skipping soft delete", - extra={"deleted_runs_dir": str(deleted_runs_dir)}, - ) - continue - - runs_dir.rename(deleted_runs_dir) - logger.info( - "Soft deleted runs directory", - extra={ - "runs_dir": str(runs_dir), - "deleted_runs_dir": str(deleted_runs_dir), - }, - ) - except Exception as e: - error_msg = f"Failed to soft delete runs directory: {runs_dir}" - logger.exception(error_msg, exc_info=e) - - -def downgrade() -> None: - """Restore runs directories from .deleted subdirectory.""" - deleted_dir = settings.data_dir / ".deleted" - - if not deleted_dir.exists(): - logger.info("No .deleted directory found to restore from") - return - - # Restore runs directories for all workspaces - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - continue - - deleted_workspace_dir = deleted_dir / workspace_dir.name - deleted_runs_dir = deleted_workspace_dir / "runs" - - if deleted_runs_dir.exists(): - try: - runs_dir = workspace_dir / "runs" - deleted_runs_dir.rename(runs_dir) - logger.info( - "Restored runs directory", - extra={"runs_dir": str(runs_dir)}, - ) - except Exception as e: - error_msg = f"Failed to restore runs directory: {runs_dir}" - logger.exception(error_msg, exc_info=e) diff --git a/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py b/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py index 5776ebb5..750cce56 100644 --- a/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py +++ b/src/askui/chat/migrations/versions/9e0f1a2b3c4d_import_json_files.py @@ -33,6 +33,7 @@ def _insert_files_batch( ) -> None: """Insert a batch of files into the database, ignoring conflicts.""" if not files_batch: + logger.info("No files to insert, skipping batch") return connection.execute( @@ -50,6 +51,10 @@ def upgrade() -> None: # noqa: C901 # Skip if workspaces 
directory doesn't exist (e.g., first-time setup) if not workspaces_dir.exists(): + logger.info( + "Workspaces directory does not exist, skipping import of files", + extra={"workspaces_dir": str(workspaces_dir)}, + ) return # Get the table from the current database schema @@ -62,12 +67,20 @@ def upgrade() -> None: # noqa: C901 # Iterate through all workspace directories for workspace_dir in workspaces_dir.iterdir(): if not workspace_dir.is_dir(): + logger.info( + "Skipping non-directory in workspaces", + extra={"path": str(workspace_dir)}, + ) continue workspace_id = workspace_dir.name files_dir = workspace_dir / "files" if not files_dir.exists(): + logger.info( + "Files directory does not exist, skipping workspace", + extra={"workspace_id": workspace_id, "files_dir": str(files_dir)}, + ) continue # Get all JSON files in the static directory @@ -102,6 +115,9 @@ def downgrade() -> None: result = connection.execute(files_table.select()) rows = result.fetchall() if not rows: + logger.info( + "No files found in the database, skipping export of rows to json", + ) return for row in rows: @@ -114,6 +130,10 @@ def downgrade() -> None: files_dir.mkdir(parents=True, exist_ok=True) json_path = files_dir / f"{file_model.id}.json" if json_path.exists(): + logger.info( + "Json file for file already exists, skipping export of row to json", + extra={"file_id": file_model.id, "json_path": str(json_path)}, + ) continue with json_path.open("w", encoding="utf-8") as f: f.write(file_model.model_dump_json()) diff --git a/src/askui/chat/migrations/versions/a0f1a2b3c4d5_soft_delete_files_dirs.py b/src/askui/chat/migrations/versions/a0f1a2b3c4d5_soft_delete_files_dirs.py deleted file mode 100644 index 75b249ea..00000000 --- a/src/askui/chat/migrations/versions/a0f1a2b3c4d5_soft_delete_files_dirs.py +++ /dev/null @@ -1,108 +0,0 @@ -"""soft_delete_files_dirs - -Revision ID: a0f1a2b3c4d5 -Revises: 9e0f1a2b3c4d -Create Date: 2025-01-27 11:02:00.000000 - -""" - -import logging -from typing 
import Sequence, Union - -from askui.chat.migrations.shared.settings import SettingsV1 - -# revision identifiers, used by Alembic. -revision: str = "a0f1a2b3c4d5" -down_revision: Union[str, None] = "9e0f1a2b3c4d" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -logger = logging.getLogger(__name__) - -settings = SettingsV1() -workspaces_dir = settings.data_dir / "workspaces" - - -def upgrade() -> None: - """Soft delete JSON files from workspace static directories by moving them to .deleted subdirectory.""" - - # Skip if workspaces directory doesn't exist - if not workspaces_dir.exists(): - logger.info( - "Workspaces directory does not exist, skipping soft delete", - extra={"workspaces_dir": str(workspaces_dir)}, - ) - return - - # Create .deleted directory if it doesn't exist - deleted_dir = settings.data_dir / ".deleted" - deleted_dir.mkdir(parents=True, exist_ok=True) - - # Iterate through all workspace directories - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - continue - - files_dir = workspace_dir / "files" - if not files_dir.exists(): - logger.info( - "Files directory does not exist, skipping soft delete", - extra={"files_dir": str(files_dir)}, - ) - continue - - try: - # Create workspace-specific deleted directory - deleted_workspace_dir = deleted_dir / workspace_dir.name - deleted_workspace_dir.mkdir(parents=True, exist_ok=True) - - # Move files directory to .deleted subdirectory - deleted_files_dir = deleted_workspace_dir / "files" - if deleted_files_dir.exists(): - logger.info( - "Deleted files directory already exists, skipping soft delete", - extra={"deleted_files_dir": str(deleted_files_dir)}, - ) - continue - - files_dir.rename(deleted_files_dir) - logger.info( - "Successfully soft deleted files directory", - extra={ - "files_dir": str(files_dir), - "deleted_files_dir": str(deleted_files_dir), - }, - ) - except Exception as e: # noqa: PERF203 - error_msg = 
"Failed to soft delete files directory" - logger.exception(error_msg, extra={"files_dir": str(files_dir)}) - raise RuntimeError(error_msg) from e - - -def downgrade() -> None: - """Restore JSON files in workspace static directories from .deleted subdirectory.""" - deleted_dir = settings.data_dir / ".deleted" - - if not deleted_dir.exists(): - logger.info("No .deleted directory found to restore from") - return - - # Restore files directories for all workspaces - for workspace_dir in workspaces_dir.iterdir(): - if not workspace_dir.is_dir(): - continue - - deleted_workspace_dir = deleted_dir / workspace_dir.name - deleted_files_dir = deleted_workspace_dir / "files" - - if deleted_files_dir.exists(): - try: - files_dir = workspace_dir / "files" - deleted_files_dir.rename(files_dir) - logger.info( - "Successfully restored files directory", - extra={"files_dir": str(files_dir)}, - ) - except Exception as e: - error_msg = f"Failed to restore files directory: {files_dir}" - logger.exception(error_msg, exc_info=e) diff --git a/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py b/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py index 9f9e0492..ad47e506 100644 --- a/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py +++ b/src/askui/chat/migrations/versions/c35e88ea9595_seed_default_assistants.py @@ -28,7 +28,7 @@ def upgrade() -> None: """Seed default assistants one by one, skipping duplicates. For each assistant in `SEEDS_V1`, insert a row into `assistants`. If a - row with the same `id` already exists, skip it and log on info level. + row with the same `id` already exists, skip it and log on debug level. 
""" connection = op.get_bind() assistants_table: Table = Table("assistants", MetaData(), autoload_with=connection) From 56db2dd517ee633f4c78fa764da66dcf6aa95842 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Mon, 3 Nov 2025 11:46:44 +0100 Subject: [PATCH 13/14] docs(migrations): add migration strategy documentation --- docs/migrations.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/migrations.md b/docs/migrations.md index cb694625..ffc76df4 100644 --- a/docs/migrations.md +++ b/docs/migrations.md @@ -24,6 +24,17 @@ The current migration history shows several real-world examples: 3. **`c35e88ea9595_seed_default_assistants.py`**: Seeds the database with default assistant configurations 4. **`37007a499ca7_remove_assistants_dir.py`**: Cleans up the old JSON-based persistence by removing the assistants directory +### Our current migration strategy + +#### Until `5e6f7a8b9c0d_import_json_messages.py` + +On Upgrade: +- We migrate from file system persistence to SQLite database persistence. We don't delete any of the files from the file system so rolling back is as easy as just installing an older version of the `askui` library. + +On Downgrade: +- This is mainly to be used by us for debugging and testing new migrations but not a user. +- We export data from database but already existing files take precedence so you may loose some data that was upgraded or deleted between the upgrade and downgrade. Also you may loose some of the data that was not originally available in the schema, e.g., global files (not scoped to workspace). + ## Automatic Migrations on Startup By default, migrations are automatically run when the chat API starts up. This ensures that users are always upgraded to the newest database schema version without manual intervention. 
From 48205f30a733d73787b303115800d01aaf67b357 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Mon, 3 Nov 2025 12:03:10 +0100 Subject: [PATCH 14/14] fix: reallow connecting from localhost:4200 localhost without port does not work --- src/askui/chat/api/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/askui/chat/api/settings.py b/src/askui/chat/api/settings.py index 3f6dc3a5..72e9b6c9 100644 --- a/src/askui/chat/api/settings.py +++ b/src/askui/chat/api/settings.py @@ -97,7 +97,7 @@ class Settings(BaseSettings): ) allow_origins: list[str] = Field( default_factory=lambda: [ - "http://localhost", + "http://localhost:4200", "https://app.caesr.ai", "https://app-dev.caesr.ai", "https://hub.askui.com",