From 58ef65b948c0257c6a37a54157e69f6b6d3d1602 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 19 Aug 2025 11:30:51 +0200 Subject: [PATCH 01/11] feat(chat): add mcp configs api --- src/askui/chat/api/mcp_configs/router.py | 26 ++++++------------------ 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/src/askui/chat/api/mcp_configs/router.py b/src/askui/chat/api/mcp_configs/router.py index 62386785..ed30f571 100644 --- a/src/askui/chat/api/mcp_configs/router.py +++ b/src/askui/chat/api/mcp_configs/router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, status from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep from askui.chat.api.mcp_configs.models import ( @@ -8,7 +8,7 @@ ) from askui.chat.api.mcp_configs.service import McpConfigService from askui.chat.api.models import ListQueryDep, McpConfigId -from askui.utils.api_utils import LimitReachedError, ListQuery, ListResponse +from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/mcp-configs", tags=["mcp-configs"]) @@ -28,12 +28,7 @@ def create_mcp_config( mcp_config_service: McpConfigService = McpConfigServiceDep, ) -> McpConfig: """Create a new MCP configuration.""" - try: - return mcp_config_service.create(params) - except LimitReachedError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) - ) from e + return mcp_config_service.create(params) @router.get("/{mcp_config_id}", response_model_exclude_none=True) @@ -42,10 +37,7 @@ def retrieve_mcp_config( mcp_config_service: McpConfigService = McpConfigServiceDep, ) -> McpConfig: """Get an MCP configuration by ID.""" - try: - return mcp_config_service.retrieve(mcp_config_id) - except FileNotFoundError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + return mcp_config_service.retrieve(mcp_config_id) @router.patch("/{mcp_config_id}", response_model_exclude_none=True) @@ -55,10 +47,7 @@ def modify_mcp_config( mcp_config_service: McpConfigService = McpConfigServiceDep, ) -> McpConfig: """Update an MCP configuration.""" - try: - return mcp_config_service.modify(mcp_config_id, params) - except FileNotFoundError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + return mcp_config_service.modify(mcp_config_id, params) @router.delete("/{mcp_config_id}", status_code=status.HTTP_204_NO_CONTENT) @@ -67,7 +56,4 @@ def delete_mcp_config( mcp_config_service: McpConfigService = McpConfigServiceDep, ) -> None: """Delete an MCP configuration.""" - try: - mcp_config_service.delete(mcp_config_id) - except FileNotFoundError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + mcp_config_service.delete(mcp_config_id) From 7a7c7fa8063d0ba5a73a2ded614059841acaeaf6 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 20 Aug 2025 09:27:29 +0200 Subject: [PATCH 02/11] refactor: make saving of messages more scalable/efficient - save them into individual files instead of a single JSONL file --- src/askui/chat/api/messages/service.py | 85 ++++++------- src/askui/chat/api/threads/service.py | 11 +- tests/integration/chat/api/__init__.py | 1 + .../chat/api/test_messages_service.py | 118 +++++++++++++++++ .../chat/api/test_threads_service.py | 119 ++++++++++++++++++ 5 files changed, 282 insertions(+), 52 deletions(-) create mode 100644 tests/integration/chat/api/__init__.py create mode 100644 tests/integration/chat/api/test_messages_service.py create mode 100644 tests/integration/chat/api/test_threads_service.py diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index 6eca288c..1af2c778 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -2,11 +2,11 @@ from pathlib import Path from typing import Literal -from pydantic import Field +from pydantic import Field, ValidationError from askui.chat.api.models import AssistantId, MessageId, RunId, ThreadId from askui.models.shared.agent_message_param import MessageParam -from askui.utils.api_utils import LIST_LIMIT_MAX, ListQuery +from askui.utils.api_utils import ListQuery, ListResponse, list_resource_paths from askui.utils.datetime_utils import UnixDatetime from askui.utils.id_utils import generate_time_ordered_id @@ -38,65 +38,56 @@ def __init__(self, base_dir: Path) -> None: base_dir: Base directory to store message data """ self._base_dir = base_dir - self._threads_dir = base_dir / "threads" + self._base_messages_dir = base_dir / "messages" def create(self, thread_id: ThreadId, request: MessageCreateRequest) -> Message: - messages = self.list_(thread_id, ListQuery(limit=LIST_LIMIT_MAX, order="asc")) new_message = Message( **request.model_dump(), thread_id=thread_id, ) - self.save(thread_id, messages + [new_message]) + self._save(new_message) return new_message def delete(self, thread_id: ThreadId, message_id: MessageId) -> None: - messages = self.list_(thread_id, ListQuery(limit=LIST_LIMIT_MAX, order="asc")) - filtered_messages = [m for m in messages if m.id != message_id] - if len(filtered_messages) == len(messages): + message_file = self._get_message_path(thread_id, message_id) + if not message_file.exists(): error_msg = f"Message {message_id} not found in thread {thread_id}" raise ValueError(error_msg) - self.save(thread_id, filtered_messages) + message_file.unlink() - def list_(self, thread_id: ThreadId, query: ListQuery) -> list[Message]: - thread_file = self._threads_dir / f"{thread_id}.jsonl" - if not thread_file.exists(): - error_msg = f"Thread {thread_id} not found" - raise FileNotFoundError(error_msg) + def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Message]: + messages_dir = self.get_thread_messages_dir(thread_id) + if not messages_dir.exists(): + return ListResponse(data=[]) + message_paths = list_resource_paths(messages_dir, query) messages: list[Message] = [] - with thread_file.open("r", encoding="utf-8") as f: - for line in f: - msg = Message.model_validate_json(line) + for message_file in message_paths: + try: + msg = Message.model_validate_json(message_file.read_text()) messages.append(msg) - - # Sort by creation date - messages = sorted( - messages, key=lambda m: m.created_at, reverse=(query.order == "desc") + except ValidationError: # noqa: PERF203 + continue + has_more = len(messages) > query.limit + messages = messages[: query.limit] + return ListResponse( + data=messages, + first_id=messages[0].id if messages else None, + last_id=messages[-1].id if messages else None, + has_more=has_more, ) - # Apply before/after filters - if query.after: - messages = [m for m in messages if m.id > query.after] - if query.before: - messages = [m for m in messages if m.id < query.before] - - # Apply limit - return messages[: query.limit] - - def _get_thread_path(self, thread_id: ThreadId) -> Path: - thread_path = self._threads_dir / f"{thread_id}.jsonl" - if not thread_path.exists(): - error_msg = f"Thread {thread_id} not found" - raise FileNotFoundError(error_msg) - return thread_path - - def save(self, thread_id: ThreadId, messages: list[Message]) -> None: - if len(messages) > LIST_LIMIT_MAX: - error_msg = f"Thread {thread_id} has too many messages" - raise ValueError(error_msg) - messages = sorted(messages, key=lambda m: m.created_at) - thread_path = self._get_thread_path(thread_id) - with thread_path.open("w", encoding="utf-8") as f: - for msg in messages: - f.write(msg.model_dump_json()) - f.write("\n") + def get_thread_messages_dir(self, thread_id: ThreadId) -> Path: + """Get the directory path for a specific message.""" + return self._base_messages_dir / thread_id + + def _get_message_path(self, thread_id: ThreadId, message_id: MessageId) -> Path: + """Get the file path for a specific message.""" + return self.get_thread_messages_dir(thread_id) / f"{message_id}.json" + + def _save(self, message: Message) -> None: + """Save a single message to its own JSON file.""" + messages_dir = self.get_thread_messages_dir(message.thread_id) + messages_dir.mkdir(parents=True, exist_ok=True) + message_file = self._get_message_path(message.thread_id, message.id) + message_file.write_text(message.model_dump_json(), encoding="utf-8") diff --git a/src/askui/chat/api/threads/service.py b/src/askui/chat/api/threads/service.py index 6a2c879b..f00f8ecd 100644 --- a/src/askui/chat/api/threads/service.py +++ b/src/askui/chat/api/threads/service.py @@ -1,3 +1,4 @@ +import shutil from datetime import datetime, timezone from pathlib import Path from typing import Literal @@ -58,8 +59,6 @@ def create(self, request: ThreadCreateRequest) -> Thread: self._threads_dir.mkdir(parents=True, exist_ok=True) thread_file = self._threads_dir / f"{thread.id}.json" thread_file.write_text(thread.model_dump_json()) - thread_messages_file = self._threads_dir / f"{thread.id}.jsonl" - thread_messages_file.touch() if request.messages: for message in request.messages: self._message_service.create( @@ -139,9 +138,11 @@ def delete(self, thread_id: ThreadId) -> None: if not thread_file.exists(): error_msg = f"Thread {thread_id} not found" raise FileNotFoundError(error_msg) - thread_messages_file = self._threads_dir / f"{thread_id}.jsonl" - if thread_messages_file.exists(): - thread_messages_file.unlink() + + messages_dir = self._message_service.get_thread_messages_dir(thread_id) + if messages_dir.exists(): + shutil.rmtree(messages_dir) + thread_file.unlink() def modify(self, thread_id: ThreadId, request: ThreadModifyRequest) -> Thread: diff --git a/tests/integration/chat/api/__init__.py b/tests/integration/chat/api/__init__.py new file mode 100644 index 00000000..13477e6b --- /dev/null +++ b/tests/integration/chat/api/__init__.py @@ -0,0 +1 @@ +# Chat API integration tests diff --git a/tests/integration/chat/api/test_messages_service.py b/tests/integration/chat/api/test_messages_service.py new file mode 100644 index 00000000..2a62ef76 --- /dev/null +++ b/tests/integration/chat/api/test_messages_service.py @@ -0,0 +1,118 @@ +"""Integration tests for the MessageService with JSON file persistence.""" + +import json +import tempfile +from pathlib import Path +from typing import Generator + +import pytest + +from askui.chat.api.messages.service import MessageCreateRequest, MessageService +from askui.chat.api.models import ThreadId + + +@pytest.fixture +def temp_base_dir() -> Generator[Path, None, None]: + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def message_service(temp_base_dir: Path) -> MessageService: + """Create a MessageService instance with temporary storage.""" + return MessageService(temp_base_dir) + + +@pytest.fixture +def thread_id() -> ThreadId: + """Create a test thread ID.""" + return "thread_test123" + + +class TestMessageServiceJSONPersistence: + """Test MessageService with JSON file persistence.""" + + def test_create_message_creates_individual_json_file( + self, message_service: MessageService, thread_id: ThreadId + ) -> None: + """Test that creating a message creates an individual JSON file.""" + request = MessageCreateRequest(role="user", content="Hello, world!") + + message = message_service.create(thread_id, request) + + # Check that the message directory was created + messages_dir = message_service.get_thread_messages_dir(thread_id) + assert messages_dir.exists() + + # Check that the message file was created + message_file = message_service._get_message_path(thread_id, message.id) + assert message_file.exists() + + # Verify the file contains the correct JSON data + with message_file.open("r") as f: + data = json.load(f) + assert data["role"] == "user" + assert data["content"] == "Hello, world!" + assert data["id"] == message.id + assert data["thread_id"] == thread_id + + def test_list_messages_reads_from_json_files( + self, message_service: MessageService, thread_id: ThreadId + ) -> None: + """Test that listing messages reads from individual JSON files.""" + # Create multiple messages + messages = [] + for i in range(3): + request = MessageCreateRequest( + role="user" if i % 2 == 0 else "assistant", content=f"Message {i}" + ) + message = message_service.create(thread_id, request) + messages.append(message) + + # List messages + from askui.utils.api_utils import ListQuery + + query = ListQuery(limit=10, order="asc") + response = message_service.list_(thread_id, query) + + # Verify all messages were found + assert len(response.data) == 3 + + # Verify messages are sorted by creation time + assert response.data[0].created_at <= response.data[1].created_at + assert response.data[1].created_at <= response.data[2].created_at + + def test_delete_message_removes_json_file( + self, message_service: MessageService, thread_id: ThreadId + ) -> None: + """Test that deleting a message removes its JSON file.""" + request = MessageCreateRequest(role="user", content="Delete me") + + message = message_service.create(thread_id, request) + message_file = message_service._get_message_path(thread_id, message.id) + assert message_file.exists() + + # Delete the message + message_service.delete(thread_id, message.id) + + # Verify the file was removed + assert not message_file.exists() + + def test_directory_structure_is_correct( + self, message_service: MessageService, thread_id: ThreadId + ) -> None: + """Test that the directory structure follows the expected pattern.""" + request = MessageCreateRequest(role="user", content="Test message") + + message_service.create(thread_id, request) + + # Check directory structure - messages are stored in base_dir/messages/thread_id/ + messages_dir = message_service.get_thread_messages_dir(thread_id) + + assert messages_dir.exists() + + # Check that there's a JSON file in the messages directory + json_files = list(messages_dir.glob("*.json")) + assert len(json_files) == 1 + assert json_files[0].suffix == ".json" diff --git a/tests/integration/chat/api/test_threads_service.py b/tests/integration/chat/api/test_threads_service.py new file mode 100644 index 00000000..08692f34 --- /dev/null +++ b/tests/integration/chat/api/test_threads_service.py @@ -0,0 +1,119 @@ +"""Integration tests for the ThreadService with JSON file persistence.""" + +import tempfile +from pathlib import Path +from typing import Generator + +import pytest + +from askui.chat.api.messages.service import MessageCreateRequest, MessageService +from askui.chat.api.threads.service import ThreadCreateRequest, ThreadService + + +@pytest.fixture +def temp_base_dir() -> Generator[Path, None, None]: + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def message_service(temp_base_dir: Path) -> MessageService: + """Create a MessageService instance with temporary storage.""" + return MessageService(temp_base_dir) + + +@pytest.fixture +def thread_service( + temp_base_dir: Path, message_service: MessageService +) -> ThreadService: + """Create a ThreadService instance with temporary storage.""" + return ThreadService(temp_base_dir, message_service) + + +class TestThreadServiceJSONPersistence: + """Test ThreadService with JSON file persistence.""" + + def test_create_thread_creates_directory_structure( + self, thread_service: ThreadService + ) -> None: + """Test that creating a thread creates the proper directory structure.""" + request = ThreadCreateRequest(name="Test Thread") + + thread = thread_service.create(request) + + # Check that thread metadata file was created + thread_file = thread_service._base_dir / "threads" / f"{thread.id}.json" + assert thread_file.exists() + + # Check that messages directory was created (by creating a message) + # The ThreadService doesn't create the messages directory until a message is added + message_request = MessageCreateRequest(role="user", content="Test message") + thread_service._message_service.create(thread.id, message_request) + + thread_messages_dir = thread_service._message_service.get_thread_messages_dir( + thread.id + ) + assert thread_messages_dir.exists() + + # Verify thread metadata content + with thread_file.open("r") as f: + import json + + data = json.load(f) + assert data["name"] == "Test Thread" + assert data["id"] == thread.id + + def test_create_thread_with_messages(self, thread_service: ThreadService) -> None: + """Test that creating a thread with messages works correctly.""" + messages = [ + MessageCreateRequest(role="user", content="Hello"), + MessageCreateRequest(role="assistant", content="Hi there!"), + ] + request = ThreadCreateRequest(name="Thread with Messages", messages=messages) + + thread = thread_service.create(request) + + # Check that messages were created + thread_messages_dir = thread_service._message_service.get_thread_messages_dir( + thread.id + ) + json_files = list(thread_messages_dir.glob("*.json")) + assert len(json_files) == 2 + + # Verify message content + for json_file in json_files: + with json_file.open("r") as f: + import json + + data = json.load(f) + assert data["thread_id"] == thread.id + assert data["role"] in ["user", "assistant"] + + def test_delete_thread_removes_all_files( + self, thread_service: ThreadService + ) -> None: + """Test that deleting a thread removes all associated files.""" + request = ThreadCreateRequest(name="Thread to Delete") + thread = thread_service.create(request) + + # Add a message + message_request = MessageCreateRequest(role="user", content="Test message") + thread_service._message_service.create(thread.id, message_request) + + # Verify files exist + thread_file = thread_service._base_dir / "threads" / f"{thread.id}.json" + assert thread_file.exists() + + # The thread directory itself doesn't exist, only the messages directory + messages_dir = thread_service._message_service.get_thread_messages_dir( + thread.id + ) + assert messages_dir.exists() + + # Delete thread + thread_service.delete(thread.id) + + # Verify all files were removed + assert not thread_file.exists() + assert not messages_dir.exists() From 829b26e09e2c75549af383e3c71083e06683dfc8 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 20 Aug 2025 09:38:34 +0200 Subject: [PATCH 03/11] fix(chat): utf-8 encoding in all files --- src/askui/chat/api/assistants/service.py | 8 ++++---- src/askui/chat/api/mcp_configs/service.py | 8 ++++++-- src/askui/chat/api/messages/service.py | 4 +++- src/askui/chat/api/runs/runner/runner.py | 4 ++-- src/askui/chat/api/runs/service.py | 8 ++++---- src/askui/chat/api/threads/service.py | 8 ++++---- 6 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/askui/chat/api/assistants/service.py b/src/askui/chat/api/assistants/service.py index 52c2cce3..9a62cc20 100644 --- a/src/askui/chat/api/assistants/service.py +++ b/src/askui/chat/api/assistants/service.py @@ -54,7 +54,7 @@ def list_(self, query: ListQuery) -> ListResponse[Assistant]: assistant_files = list(self._assistants_dir.glob("*.json")) assistants: list[Assistant] = [] for f in assistant_files: - with f.open("r") as file: + with f.open("r", encoding="utf-8") as file: assistants.append(Assistant.model_validate_json(file.read())) # Sort by creation date @@ -95,7 +95,7 @@ def retrieve(self, assistant_id: str) -> Assistant: error_msg = f"Assistant {assistant_id} not found" raise FileNotFoundError(error_msg) - with assistant_file.open("r") as f: + with assistant_file.open("r", encoding="utf-8") as f: return Assistant.model_validate_json(f.read()) def create(self, request: CreateAssistantRequest) -> Assistant: @@ -118,7 +118,7 @@ def _save(self, assistant: Assistant) -> None: """Save an assistant to the file system.""" self._assistants_dir.mkdir(parents=True, exist_ok=True) assistant_file = self._assistants_dir / f"{assistant.id}.json" - with assistant_file.open("w") as f: + with assistant_file.open("w", encoding="utf-8") as f: f.write(assistant.model_dump_json()) def modify(self, assistant_id: str, request: AssistantModifyRequest) -> Assistant: @@ -142,7 +142,7 @@ def modify(self, assistant_id: str, request: AssistantModifyRequest) -> Assistan if not isinstance(request.avatar, DoNotPatch): assistant.avatar = request.avatar assistant_file = self._assistants_dir / f"{assistant_id}.json" - with assistant_file.open("w") as f: + with assistant_file.open("w", encoding="utf-8") as f: f.write(assistant.model_dump_json()) return assistant diff --git a/src/askui/chat/api/mcp_configs/service.py b/src/askui/chat/api/mcp_configs/service.py index abff9d55..6b3651a3 100644 --- a/src/askui/chat/api/mcp_configs/service.py +++ b/src/askui/chat/api/mcp_configs/service.py @@ -36,7 +36,9 @@ def list_( mcp_configs: list[McpConfig] = [] for f in mcp_config_paths: try: - mcp_config = McpConfig.model_validate_json(f.read_text()) + mcp_config = McpConfig.model_validate_json( + f.read_text(encoding="utf-8") + ) mcp_configs.append(mcp_config) except ValidationError: # noqa: PERF203 continue @@ -54,7 +56,9 @@ def retrieve(self, mcp_config_id: McpConfigId) -> McpConfig: if not mcp_config_file.exists(): error_msg = f"MCP configuration {mcp_config_id} not found" raise NotFoundError(error_msg) - return McpConfig.model_validate_json(mcp_config_file.read_text()) + return McpConfig.model_validate_json( + mcp_config_file.read_text(encoding="utf-8") + ) def _check_limit(self) -> None: limit = LIST_LIMIT_MAX diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index 1af2c778..01b8b338 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -64,7 +64,9 @@ def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Message]: messages: list[Message] = [] for message_file in message_paths: try: - msg = Message.model_validate_json(message_file.read_text()) + msg = Message.model_validate_json( + message_file.read_text(encoding="utf-8") + ) messages.append(msg) except ValidationError: # noqa: PERF203 continue diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index f8dd63b4..e2f76dbd 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -397,10 +397,10 @@ def _should_abort(self, run: Run) -> bool: def _update_run_file(self, run: Run) -> None: run_file = self._runs_dir / f"{run.thread_id}__{run.id}.json" - with run_file.open("w") as f: + with run_file.open("w", encoding="utf-8") as f: f.write(run.model_dump_json()) def _retrieve_run(self) -> Run: run_file = self._runs_dir / f"{self._run.thread_id}__{self._run.id}.json" - with run_file.open("r") as f: + with run_file.open("r", encoding="utf-8") as f: return Run.model_validate_json(f.read()) diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index e5c4a647..6c66a143 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -91,13 +91,13 @@ async def run_runner() -> None: def _update_run_file(self, run: Run) -> None: run_file = self._run_path(run.thread_id, run.id) - with run_file.open("w") as f: + with run_file.open("w", encoding="utf-8") as f: f.write(run.model_dump_json()) def retrieve(self, run_id: RunId) -> Run: # Find the file by run_id for f in self._runs_dir.glob(f"*__{run_id}.json"): - with f.open("r") as file: + with f.open("r", encoding="utf-8") as file: return Run.model_validate_json(file.read()) error_msg = f"Run {run_id} not found" raise FileNotFoundError(error_msg) @@ -119,7 +119,7 @@ def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Run]: runs: list[Run] = [] for f in run_files: - with f.open("r") as file: + with f.open("r", encoding="utf-8") as file: runs.append(Run.model_validate_json(file.read())) # Sort by creation date @@ -151,7 +151,7 @@ def cancel(self, run_id: RunId) -> Run: return run run.tried_cancelling_at = datetime.now(tz=timezone.utc) for f in self._runs_dir.glob(f"*__{run_id}.json"): - with f.open("w") as file: + with f.open("w", encoding="utf-8") as file: file.write(run.model_dump_json()) return run # Find the file by run_id diff --git a/src/askui/chat/api/threads/service.py b/src/askui/chat/api/threads/service.py index f00f8ecd..542bd405 100644 --- a/src/askui/chat/api/threads/service.py +++ b/src/askui/chat/api/threads/service.py @@ -58,7 +58,7 @@ def create(self, request: ThreadCreateRequest) -> Thread: thread = Thread(name=request.name) self._threads_dir.mkdir(parents=True, exist_ok=True) thread_file = self._threads_dir / f"{thread.id}.json" - thread_file.write_text(thread.model_dump_json()) + thread_file.write_text(thread.model_dump_json(), encoding="utf-8") if request.messages: for message in request.messages: self._message_service.create( @@ -83,7 +83,7 @@ def list_(self, query: ListQuery) -> ListResponse[Thread]: thread_files = list(self._threads_dir.glob("*.json")) threads: list[Thread] = [] for f in thread_files: - thread = Thread.model_validate_json(f.read_text()) + thread = Thread.model_validate_json(f.read_text(encoding="utf-8")) threads.append(thread) # Sort by creation date @@ -123,7 +123,7 @@ def retrieve(self, thread_id: ThreadId) -> Thread: if not thread_file.exists(): error_msg = f"Thread {thread_id} not found" raise FileNotFoundError(error_msg) - return Thread.model_validate_json(thread_file.read_text()) + return Thread.model_validate_json(thread_file.read_text(encoding="utf-8")) def delete(self, thread_id: ThreadId) -> None: """Delete a thread and all its associated files. @@ -156,5 +156,5 @@ def modify(self, thread_id: ThreadId, request: ThreadModifyRequest) -> Thread: if not isinstance(request.name, DoNotPatch): thread.name = request.name thread_file = self._threads_dir / f"{thread_id}.json" - thread_file.write_text(thread.model_dump_json()) + thread_file.write_text(thread.model_dump_json(), encoding="utf-8") return thread From dfeb4d04927b996185d64c6c007c7f6729fa8115 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 20 Aug 2025 12:30:10 +0200 Subject: [PATCH 04/11] feat(chat): scope resources to workspaces with exception of "assistants" --- src/askui/chat/api/dependencies.py | 17 ++++++++++++++++- src/askui/chat/api/mcp_configs/dependencies.py | 9 +++++---- src/askui/chat/api/messages/dependencies.py | 9 +++++---- src/askui/chat/api/runs/dependencies.py | 9 +++++---- src/askui/chat/api/threads/dependencies.py | 9 +++++---- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/src/askui/chat/api/dependencies.py b/src/askui/chat/api/dependencies.py index ea540592..2005cfe4 100644 --- a/src/askui/chat/api/dependencies.py +++ b/src/askui/chat/api/dependencies.py @@ -1,7 +1,8 @@ import os +from pathlib import Path from typing import Annotated, Optional -from fastapi import Depends, Header +from fastapi import Depends, Header, HTTPException from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer from pydantic import UUID4 @@ -55,3 +56,17 @@ def set_env_from_headers( SetEnvFromHeadersDep = Depends(set_env_from_headers) + + +def get_workspace_dir( + settings: Settings = SettingsDep, + askui_workspace: Annotated[str | None, Header()] = None, +) -> Path: + if not askui_workspace: + raise HTTPException( + status_code=400, detail="AskUI-Workspace header is required" + ) + return settings.data_dir / "workspaces" / askui_workspace + + +WorkspaceDirDep = Depends(get_workspace_dir) diff --git a/src/askui/chat/api/mcp_configs/dependencies.py b/src/askui/chat/api/mcp_configs/dependencies.py index 3ed0a11b..11e6b4ef 100644 --- a/src/askui/chat/api/mcp_configs/dependencies.py +++ b/src/askui/chat/api/mcp_configs/dependencies.py @@ -1,13 +1,14 @@ +from pathlib import Path + from fastapi import Depends -from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.dependencies import WorkspaceDirDep from askui.chat.api.mcp_configs.service import McpConfigService -from askui.chat.api.settings import Settings -def get_mcp_config_service(settings: Settings = SettingsDep) -> McpConfigService: +def get_mcp_config_service(workspace_dir: Path = WorkspaceDirDep) -> McpConfigService: """Get McpConfigService instance.""" - return McpConfigService(settings.data_dir) + return McpConfigService(workspace_dir) McpConfigServiceDep = Depends(get_mcp_config_service) diff --git a/src/askui/chat/api/messages/dependencies.py b/src/askui/chat/api/messages/dependencies.py index 51bff5af..9db192a3 100644 --- a/src/askui/chat/api/messages/dependencies.py +++ b/src/askui/chat/api/messages/dependencies.py @@ -1,15 +1,16 @@ +from pathlib import Path + from fastapi import Depends -from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.dependencies import WorkspaceDirDep from askui.chat.api.messages.service import MessageService -from askui.chat.api.settings import Settings def get_message_service( - settings: Settings = SettingsDep, + workspace_dir: Path = WorkspaceDirDep, ) -> MessageService: """Get MessagePersistedService instance.""" - return MessageService(settings.data_dir) + return MessageService(workspace_dir) MessageServiceDep = Depends(get_message_service) diff --git a/src/askui/chat/api/runs/dependencies.py b/src/askui/chat/api/runs/dependencies.py index 772c2545..dd37c09a 100644 --- a/src/askui/chat/api/runs/dependencies.py +++ b/src/askui/chat/api/runs/dependencies.py @@ -1,14 +1,15 @@ +from pathlib import Path + from fastapi import Depends -from askui.chat.api.dependencies import SettingsDep -from askui.chat.api.settings import Settings +from askui.chat.api.dependencies import WorkspaceDirDep from .service import RunService -def get_runs_service(settings: Settings = SettingsDep) -> RunService: +def get_runs_service(workspace_dir: Path = WorkspaceDirDep) -> RunService: """Get RunService instance.""" - return RunService(settings.data_dir) + return RunService(workspace_dir) RunServiceDep = Depends(get_runs_service) diff --git a/src/askui/chat/api/threads/dependencies.py b/src/askui/chat/api/threads/dependencies.py index 46f0840d..b52df3ca 100644 --- a/src/askui/chat/api/threads/dependencies.py +++ b/src/askui/chat/api/threads/dependencies.py @@ -1,19 +1,20 @@ +from pathlib import Path + from fastapi import Depends -from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.dependencies import WorkspaceDirDep from askui.chat.api.messages.dependencies import MessageServiceDep from askui.chat.api.messages.service import MessageService -from askui.chat.api.settings import Settings from askui.chat.api.threads.service import ThreadService def get_thread_service( - settings: Settings = SettingsDep, + workspace_dir: Path = WorkspaceDirDep, message_service: MessageService = MessageServiceDep, ) -> ThreadService: """Get ThreadService instance.""" return ThreadService( - base_dir=settings.data_dir, + base_dir=workspace_dir, message_service=message_service, ) From 11cb9be611419ce6a40def9fd3fc31490bd46922 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 20 Aug 2025 13:51:30 +0200 Subject: [PATCH 05/11] refactor(chat): simplify messages service - more consistent with other services --- src/askui/chat/api/messages/router.py | 21 ++++----------- src/askui/chat/api/messages/service.py | 34 +++++++++++++++++------- src/askui/chat/api/runs/runner/runner.py | 5 ++-- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 65b8153e..51e07d8b 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -7,7 +7,7 @@ MessageService, ) from askui.chat.api.models import ListQueryDep, MessageId, ThreadId -from askui.utils.api_utils import ListQuery, ListResponse +from askui.utils.api_utils import ListQuery, ListResponse, NotFoundError router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) @@ -20,13 +20,7 @@ def list_messages( ) -> ListResponse[Message]: """List all messages in a thread.""" try: - messages = message_service.list_(thread_id, query=query) - return ListResponse( - data=messages, - first_id=messages[0].id if messages else None, - last_id=messages[-1].id if messages else None, - has_more=len(messages) > query.limit, - ) + return message_service.list_(thread_id, query=query) except FileNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) from e @@ -52,13 +46,8 @@ def retrieve_message( ) -> Message: """Get a specific message from a thread.""" try: - messages = message_service.list_(thread_id=thread_id, query=ListQuery(limit=1)) - for msg in messages: - if msg.id == message_id: - return msg - error_msg = f"Message {message_id} not found in thread {thread_id}" - raise HTTPException(status_code=404, detail=error_msg) - except FileNotFoundError as e: + return message_service.retrieve(thread_id, message_id) + except NotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) from e @@ -71,5 +60,5 @@ def delete_message( """Delete a message from a thread.""" try: message_service.delete(thread_id, message_id) - except FileNotFoundError as e: + except NotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) from e diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index 01b8b338..269f6c92 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -6,7 +6,13 @@ from askui.chat.api.models import AssistantId, MessageId, RunId, ThreadId from askui.models.shared.agent_message_param import MessageParam -from askui.utils.api_utils import ListQuery, ListResponse, list_resource_paths +from askui.utils.api_utils import ( + ConflictError, + ListQuery, + ListResponse, + NotFoundError, + list_resource_paths, +) from askui.utils.datetime_utils import UnixDatetime from askui.utils.id_utils import generate_time_ordered_id @@ -45,14 +51,11 @@ def create(self, thread_id: ThreadId, request: MessageCreateRequest) -> Message: **request.model_dump(), thread_id=thread_id, ) - self._save(new_message) + self._save(new_message, new=True) return new_message def delete(self, thread_id: ThreadId, message_id: MessageId) -> None: message_file = self._get_message_path(thread_id, message_id) - if not message_file.exists(): - error_msg = f"Message {message_id} not found in thread {thread_id}" - raise ValueError(error_msg) message_file.unlink() def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Message]: @@ -79,17 +82,30 @@ def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Message]: has_more=has_more, ) + def retrieve(self, thread_id: ThreadId, message_id: MessageId) -> Message: + message_file = self._get_message_path(thread_id, message_id) + return Message.model_validate_json(message_file.read_text(encoding="utf-8")) + def get_thread_messages_dir(self, thread_id: ThreadId) -> Path: """Get the directory path for a specific message.""" return self._base_messages_dir / thread_id - def _get_message_path(self, thread_id: ThreadId, message_id: MessageId) -> Path: + def _get_message_path( + self, thread_id: ThreadId, message_id: MessageId, new: bool = False + ) -> Path: """Get the file path for a specific message.""" - return self.get_thread_messages_dir(thread_id) / f"{message_id}.json" + message_path = self.get_thread_messages_dir(thread_id) / f"{message_id}.json" + if new and message_path.exists(): + error_msg = f"Message {message_id} already exists in thread {thread_id}" + raise ConflictError(error_msg) + if not new and not message_path.exists(): + error_msg = f"Message {message_id} not found in thread {thread_id}" + raise NotFoundError(error_msg) + return message_path - def _save(self, message: Message) -> None: + def _save(self, message: Message, new: bool = False) -> None: """Save a single message to its own JSON file.""" messages_dir = self.get_thread_messages_dir(message.thread_id) messages_dir.mkdir(parents=True, exist_ok=True) - message_file = self._get_message_path(message.thread_id, message.id) + message_file = self._get_message_path(message.thread_id, message.id, new=new) message_file.write_text(message.model_dump_json(), encoding="utf-8") diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index e2f76dbd..82206a60 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -1,5 +1,4 @@ import logging -import time from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Literal, Sequence @@ -234,8 +233,8 @@ async def _run_agent( ) for msg in self._msg_service.list_( thread_id=self._run.thread_id, - query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"), - ) + query=ListQuery(limit=LIST_LIMIT_MAX), + ).data ] async def async_on_message( From e1c6cf7bb7d5e957f3cfbcd59bd02881528f5c38 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 20 Aug 2025 13:52:07 +0200 Subject: [PATCH 06/11] fix(chat): handle case where no MCP configs are available - `fastmcp.Client` raises error if no MCP configs are available --- src/askui/chat/api/runs/runner/runner.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 82206a60..38354ebf 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -64,7 +64,7 @@ def build_fast_mcp_config(mcp_configs: Sequence[McpConfig]) -> MCPConfig: def get_mcp_client( base_dir: Path, -) -> McpClient: +) -> McpClient | None: """Get an MCP client from all available MCP configs. *Important*: This function can only handle up to 100 MCP server configs. Tool names @@ -80,7 +80,7 @@ def get_mcp_client( mcp_config_service = McpConfigService(base_dir) mcp_configs = mcp_config_service.list_(ListQuery(limit=LIST_LIMIT_MAX, order="asc")) fast_mcp_config = build_fast_mcp_config(mcp_configs.data) - return Client(fast_mcp_config) + return Client(fast_mcp_config) if fast_mcp_config.mcpServers else None class Runner: @@ -184,7 +184,7 @@ async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: ) async def _run_askui_android_agent( - self, send_stream: ObjectStream[Events], mcp_client: McpClient + self, send_stream: ObjectStream[Events], mcp_client: McpClient | None ) -> None: await self._run_agent( agent_type="android", @@ -193,7 +193,7 @@ async def _run_askui_android_agent( ) async def _run_askui_vision_agent( - self, send_stream: ObjectStream[Events], mcp_client: McpClient + self, send_stream: ObjectStream[Events], mcp_client: McpClient | None ) -> None: await self._run_agent( agent_type="vision", @@ -202,7 +202,7 @@ async def _run_askui_vision_agent( ) async def _run_askui_web_agent( - self, send_stream: ObjectStream[Events], mcp_client: McpClient + self, send_stream: ObjectStream[Events], mcp_client: McpClient | None ) -> None: await self._run_agent( agent_type="web", @@ -211,7 +211,7 @@ async def _run_askui_web_agent( ) async def _run_askui_web_testing_agent( - self, send_stream: ObjectStream[Events], mcp_client: McpClient + self, send_stream: ObjectStream[Events], mcp_client: McpClient | None ) -> None: await self._run_agent( agent_type="web_testing", @@ -223,7 +223,7 @@ async def _run_agent( self, agent_type: Literal["android", "vision", "web", "web_testing"], send_stream: ObjectStream[Events], - mcp_client: McpClient, + mcp_client: McpClient | None, ) -> None: tools = ToolCollection(mcp_client=mcp_client) messages: list[MessageParam] = [ From 670f66525485a9b9beb0227775c6c9166d128c66 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 20 Aug 2025 13:53:11 +0200 Subject: [PATCH 07/11] fix: fix ordering of ids - already move to bson.ObjectId to prepare move to MongoDB --- pdm.lock | 15 ++++++++++++++- pyproject.toml | 1 + src/askui/utils/id_utils.py | 13 +++---------- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/pdm.lock b/pdm.lock index b7b75bde..c5315d63 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "android", "chat", "dev", "pynput", "test", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:0ca715599b1575a162ce51176bd474c59544297145f8d4ecd862bd7c14223775" +content_hash = "sha256:d3a2688df580a571a134c1a518396bcf1106a9e725e9466b90a3edc64473bcdf" [[metadata.targets]] requires_python = ">=3.10" @@ -186,6 +186,19 @@ files = [ {file = "black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666"}, ] +[[package]] +name = "bson" +version = "0.5.10" +summary = "BSON codec for Python" +groups = ["default"] +dependencies = [ + "python-dateutil>=2.4.0", + "six>=1.9.0", +] +files = [ + {file = "bson-0.5.10.tar.gz", hash = "sha256:d6511b2ab051139a9123c184de1a04227262173ad593429d21e443d6462d6590"}, +] + [[package]] name = "cachetools" version = "5.5.2" diff --git a/pyproject.toml b/pyproject.toml index 943037bd..f08e3937 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "filetype>=1.2.0", "markitdown[xls,xlsx,docx]>=0.1.2", "asyncer==0.0.8", + "bson>=0.5.10", ] requires-python = ">=3.10" readme = "README.md" diff --git a/src/askui/utils/id_utils.py b/src/askui/utils/id_utils.py index ecec3427..91681c8f 100644 --- a/src/askui/utils/id_utils.py +++ b/src/askui/utils/id_utils.py @@ -1,8 +1,6 @@ -import base64 -import os -import time from typing import Any +import bson from pydantic import Field @@ -15,13 +13,8 @@ def generate_time_ordered_id(prefix: str) -> str: Returns: str: Time-ordered ID string """ - timestamp = int(time.time() * 1000) - timestamp_b32 = ( - base64.b32encode(str(timestamp).encode()).decode().rstrip("=").lower() - ) - random_bytes = os.urandom(12) - random_b32 = base64.b32encode(random_bytes).decode().rstrip("=").lower() - return f"{prefix}_{timestamp_b32}{random_b32}" + + return f"{prefix}_{str(bson.ObjectId())}" def IdField(prefix: str) -> Any: From 785667b3aa7986d1443edd30a2e38e36fbff5e1c Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 21 Aug 2025 09:29:25 +0200 Subject: [PATCH 08/11] fix: CRUD resource operations - listing was not working - made all services, routers, models as consistent as possible - readd assistant endpoints --- mypy.ini | 3 + pyproject.toml | 1 + src/askui/chat/api/app.py | 35 ++- src/askui/chat/api/assistants/models.py | 48 ++++- src/askui/chat/api/assistants/router.py | 69 +++--- src/askui/chat/api/assistants/seeds.py | 6 + src/askui/chat/api/assistants/service.py | 200 +++++------------- src/askui/chat/api/dependencies.py | 6 +- src/askui/chat/api/mcp_configs/models.py | 23 +- src/askui/chat/api/mcp_configs/router.py | 3 +- src/askui/chat/api/mcp_configs/service.py | 84 +++----- src/askui/chat/api/messages/models.py | 33 +++ src/askui/chat/api/messages/router.py | 37 +--- src/askui/chat/api/messages/service.py | 117 +++------- src/askui/chat/api/models.py | 7 - src/askui/chat/api/runs/models.py | 50 +++-- src/askui/chat/api/runs/router.py | 42 +--- .../api/runs/runner/events/message_events.py | 2 +- src/askui/chat/api/runs/runner/runner.py | 71 ++++--- src/askui/chat/api/runs/service.py | 145 +++++-------- src/askui/chat/api/threads/models.py | 52 +++++ src/askui/chat/api/threads/router.py | 45 ++-- src/askui/chat/api/threads/service.py | 190 +++++------------ src/askui/tools/testing/execution_models.py | 4 +- src/askui/tools/testing/execution_service.py | 106 +++++----- src/askui/tools/testing/feature_models.py | 4 +- src/askui/tools/testing/feature_service.py | 85 ++++---- src/askui/tools/testing/scenario_models.py | 4 +- src/askui/tools/testing/scenario_service.py | 109 +++++----- src/askui/utils/api_utils.py | 56 ++++- .../chat/api/test_messages_service.py | 16 +- .../chat/api/test_threads_service.py | 35 ++- 32 files changed, 783 insertions(+), 905 deletions(-) create mode 100644 src/askui/chat/api/messages/models.py create mode 100644 src/askui/chat/api/threads/models.py diff --git a/mypy.ini b/mypy.ini index 3e2989af..3e06fd5c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -24,3 +24,6 @@ namespace_packages = true [mypy-jsonref.*] ignore_missing_imports = true + +[mypy-bson.*] +ignore_missing_imports = true diff --git a/pyproject.toml b/pyproject.toml index f08e3937..0f35433c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ python_files = ["test_*.py"] python_functions = ["test_*"] testpaths = ["tests"] timeout = 60 +asyncio_default_fixture_loop_scope = "session" [tool.ruff] # Exclude a variety of commonly ignored directories. diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py index 8acfaf14..f5586a2c 100644 --- a/src/askui/chat/api/app.py +++ b/src/askui/chat/api/app.py @@ -1,8 +1,9 @@ from contextlib import asynccontextmanager from typing import AsyncGenerator -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter, FastAPI, Request, status from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from askui.chat.api.assistants.dependencies import get_assistant_service from askui.chat.api.assistants.router import router as assistants_router @@ -12,6 +13,7 @@ from askui.chat.api.messages.router import router as messages_router from askui.chat.api.runs.router import router as runs_router from askui.chat.api.threads.router import router as threads_router +from askui.utils.api_utils import ConflictError, LimitReachedError, NotFoundError @asynccontextmanager @@ -38,6 +40,37 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 allow_headers=["*"], ) + +@app.exception_handler(NotFoundError) +def not_found_error_handler( + request: Request, # noqa: ARG001 + exc: NotFoundError, +) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, content={"detail": str(exc)} + ) + + +@app.exception_handler(ConflictError) +def conflict_error_handler( + request: Request, # noqa: ARG001 + exc: ConflictError, +) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_409_CONFLICT, content={"detail": str(exc)} + ) + + +@app.exception_handler(LimitReachedError) +def limit_reached_error_handler( + request: Request, # noqa: ARG001 + exc: LimitReachedError, +) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, content={"detail": str(exc)} + ) + + # Include routers v1_router = APIRouter(prefix="/v1") v1_router.include_router(assistants_router) diff --git a/src/askui/chat/api/assistants/models.py b/src/askui/chat/api/assistants/models.py index ce16e8f7..ba741a17 100644 --- a/src/askui/chat/api/assistants/models.py +++ b/src/askui/chat/api/assistants/models.py @@ -1,17 +1,53 @@ from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel +from askui.chat.api.models import AssistantId +from askui.utils.api_utils import Resource from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import generate_time_ordered_id +from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven -class Assistant(BaseModel): - """An assistant that can be used in a thread.""" +class AssistantBase(BaseModel): + """Base assistant model.""" - id: str = Field(default_factory=lambda: generate_time_ordered_id("asst")) - created_at: UnixDatetime = Field(default_factory=now) name: str | None = None description: str | None = None + avatar: str | None = None + + +class AssistantCreateParams(AssistantBase): + """Parameters for creating an assistant.""" + + +class AssistantModifyParams(BaseModelWithNotGiven): + """Parameters for modifying an assistant.""" + + name: str | NotGiven = NOT_GIVEN + description: str | NotGiven = NOT_GIVEN + avatar: str | NotGiven = NOT_GIVEN + + +class Assistant(AssistantBase, Resource): + """An assistant that can be used in a thread.""" + + id: AssistantId object: Literal["assistant"] = "assistant" - avatar: str | None = Field(default=None, description="URL of the avatar image") + created_at: UnixDatetime + + @classmethod + def create(cls, params: AssistantCreateParams) -> "Assistant": + return cls( + id=generate_time_ordered_id("asst"), + created_at=now(), + **params.model_dump(), + ) + + def modify(self, params: AssistantModifyParams) -> "Assistant": + return Assistant.model_validate( + { + **self.model_dump(), + **params.model_dump(), + } + ) diff --git a/src/askui/chat/api/assistants/router.py b/src/askui/chat/api/assistants/router.py index b8e73d77..515ae6d7 100644 --- a/src/askui/chat/api/assistants/router.py +++ b/src/askui/chat/api/assistants/router.py @@ -1,12 +1,13 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, status -# from fastapi import status from askui.chat.api.assistants.dependencies import AssistantServiceDep -from askui.chat.api.assistants.models import Assistant -from askui.chat.api.assistants.service import ( - AssistantService, # AssistantModifyRequest, CreateAssistantRequest, +from askui.chat.api.assistants.models import ( + Assistant, + AssistantCreateParams, + AssistantModifyParams, ) -from askui.chat.api.models import ListQueryDep +from askui.chat.api.assistants.service import AssistantService +from askui.chat.api.models import AssistantId, ListQueryDep from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/assistants", tags=["assistants"]) @@ -17,51 +18,37 @@ def list_assistants( query: ListQuery = ListQueryDep, assistant_service: AssistantService = AssistantServiceDep, ) -> ListResponse[Assistant]: - """List all assistants.""" return assistant_service.list_(query=query) -# @router.post("", status_code=status.HTTP_201_CREATED) -# def create_assistant( -# request: CreateAssistantRequest, -# assistant_service: AssistantService = AssistantServiceDep, -# ) -> Assistant: -# """Create a new assistant.""" -# return assistant_service.create(request) +@router.post("", status_code=status.HTTP_201_CREATED) +def create_assistant( + params: AssistantCreateParams, + assistant_service: AssistantService = AssistantServiceDep, +) -> Assistant: + return assistant_service.create(params) @router.get("/{assistant_id}") def retrieve_assistant( - assistant_id: str, + assistant_id: AssistantId, assistant_service: AssistantService = AssistantServiceDep, ) -> Assistant: - """Get an assistant by ID.""" - try: - return assistant_service.retrieve(assistant_id) - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e + return assistant_service.retrieve(assistant_id) -# @router.post("/{assistant_id}") -# def modify_assistant( -# assistant_id: str, -# request: AssistantModifyRequest, -# assistant_service: AssistantService = AssistantServiceDep, -# ) -> Assistant: -# """Update an assistant.""" -# try: -# return assistant_service.modify(assistant_id, request) -# except FileNotFoundError as e: -# raise HTTPException(status_code=404, detail=str(e)) from e +@router.post("/{assistant_id}") +def modify_assistant( + assistant_id: AssistantId, + params: AssistantModifyParams, + assistant_service: AssistantService = AssistantServiceDep, +) -> Assistant: + return assistant_service.modify(assistant_id, params) -# @router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT) -# def delete_assistant( -# assistant_id: str, -# assistant_service: AssistantService = AssistantServiceDep, -# ) -> None: -# """Delete an assistant.""" -# try: -# assistant_service.delete(assistant_id) -# except FileNotFoundError as e: -# raise HTTPException(status_code=404, detail=str(e)) from e +@router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_assistant( + assistant_id: AssistantId, + assistant_service: AssistantService = AssistantServiceDep, +) -> None: + assistant_service.delete(assistant_id) diff --git a/src/askui/chat/api/assistants/seeds.py b/src/askui/chat/api/assistants/seeds.py index 8f098a4d..7a873315 100644 --- a/src/askui/chat/api/assistants/seeds.py +++ b/src/askui/chat/api/assistants/seeds.py @@ -1,31 +1,37 @@ from askui.chat.api.assistants.models import Assistant +from askui.utils.datetime_utils import now ASKUI_VISION_AGENT = Assistant( id="asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmca", + created_at=now(), name="AskUI Vision Agent", avatar="data:image/svg+xml;base64,PHN2ZyAgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIgogIHdpZHRoPSIyNCIKICBoZWlnaHQ9IjI0IgogIHZpZXdCb3g9IjAgMCAyNCAyNCIKICBmaWxsPSJub25lIgogIHN0cm9rZT0iIzAwMCIgc3R5bGU9ImJhY2tncm91bmQtY29sb3I6ICNmZmY7IGJvcmRlci1yYWRpdXM6IDJweCIKICBzdHJva2Utd2lkdGg9IjIiCiAgc3Ryb2tlLWxpbmVjYXA9InJvdW5kIgogIHN0cm9rZS1saW5lam9pbj0icm91bmQiCj4KICA8cGF0aCBkPSJNMTIgOFY0SDgiIC8+CiAgPHJlY3Qgd2lkdGg9IjE2IiBoZWlnaHQ9IjEyIiB4PSI0IiB5PSI4IiByeD0iMiIgLz4KICA8cGF0aCBkPSJNMiAxNGgyIiAvPgogIDxwYXRoIGQ9Ik0yMCAxNGgyIiAvPgogIDxwYXRoIGQ9Ik0xNSAxM3YyIiAvPgogIDxwYXRoIGQ9Ik05IDEzdjIiIC8+Cjwvc3ZnPgo=", ) HUMAN_DEMONSTRATION_AGENT = Assistant( id="asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmcb", + created_at=now(), name="Human DemonstrationAgent", avatar="data:image/svg+xml;base64,PHN2ZyAgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIgogIHdpZHRoPSIyNCIKICBoZWlnaHQ9IjI0IgogIHZpZXdCb3g9IjAgMCAyNCAyNCIKICBmaWxsPSJub25lIgogIHN0cm9rZT0iIzAwMCIgc3R5bGU9ImJhY2tncm91bmQtY29sb3I6ICNmZmY7IGJvcmRlci1yYWRpdXM6IDJweCIKICBzdHJva2Utd2lkdGg9IjIiCiAgc3Ryb2tlLWxpbmVjYXA9InJvdW5kIgogIHN0cm9rZS1saW5lam9pbj0icm91bmQiCj4KICA8cGF0aCBkPSJNMTkgMjF2LTJhNCA0IDAgMCAwLTQtNEg5YTQgNCAwIDAgMC00IDR2MiIgLz4KICA8Y2lyY2xlIGN4PSIxMiIgY3k9IjciIHI9IjQiIC8+Cjwvc3ZnPgo=", ) ANDROID_VISION_AGENT = Assistant( id="asst_78da09fbf1ed43c7826fb1686f89f541", + created_at=now(), name="AskUI Android Vision Agent", avatar="data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciICB2aWV3Qm94PSIwIDAgNDggNDgiIHdpZHRoPSIyNXB4IiBoZWlnaHQ9IjI1cHgiPjxwYXRoIGQ9Ik0gMzIuNTE5NTMxIDAuOTgyNDIxODggQSAxLjUwMDE1IDEuNTAwMTUgMCAwIDAgMzEuMjc5Mjk3IDEuNjI4OTA2MiBMIDI5LjQzNzUgNC4yMDg5ODQ0IEMgMjcuNzgwMjA3IDMuNDQwNTAwNiAyNS45NDE5MSAzIDI0IDMgQyAyMi4wNTgwOSAzIDIwLjIxOTc5MyAzLjQ0MDUwMDYgMTguNTYyNSA0LjIwODk4NDQgTCAxNi43MjA3MDMgMS42Mjg5MDYyIEEgMS41MDAxNSAxLjUwMDE1IDAgMCAwIDE1LjQzNTU0NyAwLjk4NDM3NSBBIDEuNTAwMTUgMS41MDAxNSAwIDAgMCAxNC4yNzkyOTcgMy4zNzEwOTM4IEwgMTYgNS43NzkyOTY5IEMgMTMuMTM4ODk2IDguMDI0NzU4MiAxMS4yNDUxODggMTEuNDM2MDIgMTEuMDM1MTU2IDE1LjI5MTAxNiBDIDEwLjU1MzI2IDE1LjExMjgxOCAxMC4wNDA0MDggMTUgOS41IDE1IEMgNy4wMzI0OTkxIDE1IDUgMTcuMDMyNDk5IDUgMTkuNSBMIDUgMzAuNSBDIDUgMzIuOTY3NTAxIDcuMDMyNDk5MSAzNSA5LjUgMzUgQyAxMC4wOTAzMTMgMzUgMTAuNjUzMjI5IDM0Ljg3ODc0OSAxMS4xNzE4NzUgMzQuNjY3OTY5IEMgMTEuNTY0MzM2IDM2LjA3MjEwNSAxMi42MzEzMzMgMzcuMTk2OTk0IDE0IDM3LjY5MzM1OSBMIDE0IDQxLjUgQyAxNCA0My45Njc1MDEgMTYuMDMyNDk5IDQ2IDE4LjUgNDYgQyAyMC45Njc1MDEgNDYgMjMgNDMuOTY3NTAxIDIzIDQxLjUgTCAyMyAzOCBMIDI1IDM4IEwgMjUgNDEuNSBDIDI1IDQzLjk2NzUwMSAyNy4wMzI0OTkgNDYgMjkuNSA0NiBDIDMxLjk2NzUwMSA0NiAzNCA0My45Njc1MDEgMzQgNDEuNSBMIDM0IDM3LjY5MzM1OSBDIDM1LjM2ODY2NyAzNy4xOTY5OTQgMzYuNDM1NjY0IDM2LjA3MjEwNSAzNi44MjgxMjUgMzQuNjY3OTY5IEMgMzcuMzQ2NzcxIDM0Ljg3ODc0OSAzNy45MDk2ODcgMzUgMzguNSAzNSBDIDQwLjk2NzUwMSAzNSA0MyAzMi45Njc1MDEgNDMgMzAuNSBMIDQzIDE5LjUgQyA0MyAxNy4wMzI0OTkgNDAuOTY3NTAxIDE1IDM4LjUgMTUgQyAzNy45NTk1OTIgMTUgMzcuNDQ2NzQgMTUuMTEyODE4IDM2Ljk2NDg0NCAxNS4yOTEwMTYgQyAzNi43NTQ4MTIgMTEuNDM2MDIgMzQuODYxMTA0IDguMDI0NzU4MiAzMiA1Ljc3OTI5NjkgTCAzMy43MjA3MDMgMy4zNzEwOTM4IEEgMS41MDAxNSAxLjUwMDE1IDAgMCAwIDMyLjUxOTUzMSAwLjk4MjQyMTg4IHogTSAyNCA2IEMgMjkuMTg1MTI3IDYgMzMuMjc2NzI3IDkuOTU3NTEzMiAzMy43OTg4MjggMTUgTCAxNC4yMDExNzIgMTUgQyAxNC43MjMyNzMgOS45NTc1MTMyIDE4LjgxNDg3MyA2IDI0IDYgeiBNIDE5LjUgMTAgQSAxLjUgMS41IDAgMCAwIDE5LjUgMTMgQSAxLjUgMS41IDAgMCAwIDE5LjUgMTAgeiBNIDI4LjUgMTAgQSAxLjUgMS41IDAgMCAwIDI4LjUgMTMgQSAxLjUgMS41IDAgMCAwIDI4LjUgMTAgeiBNIDkuNSAxOCBDIDEwLjM0NjQ5OSAxOCAxMSAxOC42NTM1MDEgMTEgMTkuNSBMIDExIDMwLjUgQyAxMSAzMS4zNDY0OTkgMTAuMzQ2NDk5IDMyIDkuNSAzMiBDIDguNjUzNTAwOSAzMiA4IDMxLjM0NjQ5OSA4IDMwLjUgTCA4IDE5LjUgQyA4IDE4LjY1MzUwMSA4LjY1MzUwMDkgMTggOS41IDE4IHogTSAxNCAxOCBMIDM0IDE4IEwgMzQgMTkuNSBMIDM0IDMwLjUgTCAzNCAzMy41IEMgMzQgMzQuMzQ2NDk5IDMzLjM0NjQ5OSAzNSAzMi41IDM1IEwgMjUgMzUgTCAyMyAzNSBMIDE1LjUgMzUgQyAxNC42NTM1MDEgMzUgMTQgMzQuMzQ2NDk5IDE0IDMzLjUgTCAxNCAzMC41IEwgMTQgMTkuNSBMIDE0IDE4IHogTSAzOC41IDE4IEMgMzkuMzQ2NDk5IDE4IDQwIDE4LjY1MzUwMSA0MCAxOS41IEwgNDAgMzAuNSBDIDQwIDMxLjM0NjQ5OSAzOS4zNDY0OTkgMzIgMzguNSAzMiBDIDM3LjY1MzUwMSAzMiAzNyAzMS4zNDY0OTkgMzcgMzAuNSBMIDM3IDE5LjUgQyAzNyAxOC42NTM1MDEgMzcuNjUzNTAxIDE4IDM4LjUgMTggeiBNIDE3IDM4IEwgMjAgMzggTCAyMCA0MS41IEMgMjAgNDIuMzQ2NDk5IDE5LjM0NjQ5OSA0MyAxOC41IDQzIEMgMTcuNjUzNTAxIDQzIDE3IDQyLjM0NjQ5OSAxNyA0MS41IEwgMTcgMzggeiBNIDI4IDM4IEwgMzEgMzggTCAzMSA0MS41IEMgMzEgNDIuMzQ2NDk5IDMwLjM0NjQ5OSA0MyAyOS41IDQzIEMgMjguNjUzNTAxIDQzIDI4IDQyLjM0NjQ5OSAyOCA0MS41IEwgMjggMzggeiIvPjwvc3ZnPg==", ) ASKUI_WEB_AGENT = Assistant( id="asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmcc", + created_at=now(), name="AskUI Web Vision Agent", avatar="data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSI0MDAiIGhlaWdodD0iNDAwIiB2aWV3Qm94PSIwIDAgNDAwIDQwMCIgZmlsbD0ibm9uZSI+CjxwYXRoIGQ9Ik0xMzYuNDQ0IDIyMS41NTZDMTIzLjU1OCAyMjUuMjEzIDExNS4xMDQgMjMxLjYyNSAxMDkuNTM1IDIzOC4wMzJDMTE0Ljg2OSAyMzMuMzY0IDEyMi4wMTQgMjI5LjA4IDEzMS42NTIgMjI2LjM0OEMxNDEuNTEgMjIzLjU1NCAxNDkuOTIgMjIzLjU3NCAxNTYuODY5IDIyNC45MTVWMjE5LjQ4MUMxNTAuOTQxIDIxOC45MzkgMTQ0LjE0NSAyMTkuMzcxIDEzNi40NDQgMjIxLjU1NlpNMTA4Ljk0NiAxNzUuODc2TDYxLjA4OTUgMTg4LjQ4NEM2MS4wODk1IDE4OC40ODQgNjEuOTYxNyAxODkuNzE2IDYzLjU3NjcgMTkxLjM2TDEwNC4xNTMgMTgwLjY2OEMxMDQuMTUzIDE4MC42NjggMTAzLjU3OCAxODguMDc3IDk4LjU4NDcgMTk0LjcwNUMxMDguMDMgMTg3LjU1OSAxMDguOTQ2IDE3NS44NzYgMTA4Ljk0NiAxNzUuODc2Wk0xNDkuMDA1IDI4OC4zNDdDODEuNjU4MiAzMDYuNDg2IDQ2LjAyNzIgMjI4LjQzOCAzNS4yMzk2IDE4Ny45MjhDMzAuMjU1NiAxNjkuMjI5IDI4LjA3OTkgMTU1LjA2NyAyNy41IDE0NS45MjhDMjcuNDM3NyAxNDQuOTc5IDI3LjQ2NjUgMTQ0LjE3OSAyNy41MzM2IDE0My40NDZDMjQuMDQgMTQzLjY1NyAyMi4zNjc0IDE0NS40NzMgMjIuNzA3NyAxNTAuNzIxQzIzLjI4NzYgMTU5Ljg1NSAyNS40NjMzIDE3NC4wMTYgMzAuNDQ3MyAxOTIuNzIxQzQxLjIzMDEgMjMzLjIyNSA3Ni44NjU5IDMxMS4yNzMgMTQ0LjIxMyAyOTMuMTM0QzE1OC44NzIgMjg5LjE4NSAxNjkuODg1IDI4MS45OTIgMTc4LjE1MiAyNzIuODFDMTcwLjUzMiAyNzkuNjkyIDE2MC45OTUgMjg1LjExMiAxNDkuMDA1IDI4OC4zNDdaTTE2MS42NjEgMTI4LjExVjEzMi45MDNIMTg4LjA3N0MxODcuNTM1IDEzMS4yMDYgMTg2Ljk4OSAxMjkuNjc3IDE4Ni40NDcgMTI4LjExSDE2MS42NjFaIiBmaWxsPSIjMkQ0NTUyIi8+CjxwYXRoIGQ9Ik0xOTMuOTgxIDE2Ny41ODRDMjA1Ljg2MSAxNzAuOTU4IDIxMi4xNDQgMTc5LjI4NyAyMTUuNDY1IDE4Ni42NThMMjI4LjcxMSAxOTAuNDJDMjI4LjcxMSAxOTAuNDIgMjI2LjkwNCAxNjQuNjIzIDIwMy41NyAxNTcuOTk1QzE4MS43NDEgMTUxLjc5MyAxNjguMzA4IDE3MC4xMjQgMTY2LjY3NCAxNzIuNDk2QzE3My4wMjQgMTY3Ljk3MiAxODIuMjk3IDE2NC4yNjggMTkzLjk4MSAxNjcuNTg0Wk0yOTkuNDIyIDE4Ni43NzdDMjc3LjU3MyAxODAuNTQ3IDI2NC4xNDUgMTk4LjkxNiAyNjIuNTM1IDIwMS4yNTVDMjY4Ljg5IDE5Ni43MzYgMjc4LjE1OCAxOTMuMDMxIDI4OS44MzcgMTk2LjM2MkMzMDEuNjk4IDE5OS43NDEgMzA3Ljk3NiAyMDguMDYgMzExLjMwNyAyMTUuNDM2TDMyNC41NzIgMjE5LjIxMkMzMjQuNTcyIDIxOS4yMTIgMzIyLjczNiAxOTMuNDEgMjk5LjQyMiAxODYuNzc3Wk0yODYuMjYyIDI1NC43OTVMMTc2LjA3MiAyMjMuOTlDMTc2LjA3MiAyMjMuOTkgMTc3LjI2NSAyMzAuMDM4IDE4MS44NDIgMjM3Ljg2OUwyNzQuNjE3IDI2My44MDVDMjgyLjI1NSAyNTkuMzg2IDI4Ni4yNjIgMjU0Ljc5NSAyODYuMjYyIDI1NC43OTVaTTIwOS44NjcgMzIxLjEwMkMxMjIuNjE4IDI5Ny43MSAxMzMuMTY2IDE4Ni41NDMgMTQ3LjI4NCAxMzMuODY1QzE1My4wOTcgMTEyLjE1NiAxNTkuMDczIDk2LjAyMDMgMTY0LjAyOSA4NS4yMDRDMTYxLjA3MiA4NC41OTUzIDE1OC42MjMgODYuMTUyOSAxNTYuMjAzIDkxLjA3NDZDMTUwLjk0MSAxMDEuNzQ3IDE0NC4yMTIgMTE5LjEyNCAxMzcuNyAxNDMuNDVDMTIzLjU4NiAxOTYuMTI3IDExMy4wMzggMzA3LjI5IDIwMC4yODMgMzMwLjY4MkMyNDEuNDA2IDM0MS42OTkgMjczLjQ0MiAzMjQuOTU1IDI5Ny4zMjMgMjk4LjY1OUMyNzQuNjU1IDMxOS4xOSAyNDUuNzE0IDMzMC43MDEgMjA5Ljg2NyAzMjEuMTAyWiIgZmlsbD0iIzJENDU1MiIvPgo8cGF0aCBkPSJNMTYxLjY2MSAyNjIuMjk2VjIzOS44NjNMOTkuMzMyNCAyNTcuNTM3Qzk5LjMzMjQgMjU3LjUzNyAxMDMuOTM4IDIzMC43NzcgMTM2LjQ0NCAyMjEuNTU2QzE0Ni4zMDIgMjE4Ljc2MiAxNTQuNzEzIDIxOC43ODEgMTYxLjY2MSAyMjAuMTIzVjEyOC4xMUgxOTIuODY5QzE4OS40NzEgMTE3LjYxIDE4Ni4xODQgMTA5LjUyNiAxODMuNDIzIDEwMy45MDlDMTc4Ljg1NiA5NC42MTIgMTc0LjE3NCAxMDAuNzc1IDE2My41NDUgMTA5LjY2NUMxNTYuMDU5IDExNS45MTkgMTM3LjEzOSAxMjkuMjYxIDEwOC42NjggMTM2LjkzM0M4MC4xOTY2IDE0NC42MSA1Ny4xNzkgMTQyLjU3NCA0Ny41NzUyIDE0MC45MTFDMzMuOTYwMSAxMzguNTYyIDI2LjgzODcgMTM1LjU3MiAyNy41MDQ5IDE0NS45MjhDMjguMDg0NyAxNTUuMDYyIDMwLjI2MDUgMTY5LjIyNCAzNS4yNDQ1IDE4Ny45MjhDNDYuMDI3MiAyMjguNDMzIDgxLjY2MyAzMDYuNDgxIDE0OS4wMSAyODguMzQyQzE2Ni42MDIgMjgzLjYwMiAxNzkuMDE5IDI3NC4yMzMgMTg3LjYyNiAyNjIuMjkxSDE2MS42NjFWMjYyLjI5NlpNNjEuMDg0OCAxODguNDg0TDEwOC45NDYgMTc1Ljg3NkMxMDguOTQ2IDE3NS44NzYgMTA3LjU1MSAxOTQuMjg4IDg5LjYwODcgMTk5LjAxOEM3MS42NjE0IDIwMy43NDMgNjEuMDg0OCAxODguNDg0IDYxLjA4NDggMTg4LjQ4NFoiIGZpbGw9IiNFMjU3NEMiLz4KPHBhdGggZD0iTTM0MS43ODYgMTI5LjE3NEMzMjkuMzQ1IDEzMS4zNTUgMjk5LjQ5OCAxMzQuMDcyIDI2Mi42MTIgMTI0LjE4NUMyMjUuNzE2IDExNC4zMDQgMjAxLjIzNiA5Ny4wMjI0IDE5MS41MzcgODguODk5NEMxNzcuNzg4IDc3LjM4MzQgMTcxLjc0IDY5LjM4MDIgMTY1Ljc4OCA4MS40ODU3QzE2MC41MjYgOTIuMTYzIDE1My43OTcgMTA5LjU0IDE0Ny4yODQgMTMzLjg2NkMxMzMuMTcxIDE4Ni41NDMgMTIyLjYyMyAyOTcuNzA2IDIwOS44NjcgMzIxLjA5OEMyOTcuMDkzIDM0NC40NyAzNDMuNTMgMjQyLjkyIDM1Ny42NDQgMTkwLjIzOEMzNjQuMTU3IDE2NS45MTcgMzY3LjAxMyAxNDcuNSAzNjcuNzk5IDEzNS42MjVDMzY4LjY5NSAxMjIuMTczIDM1OS40NTUgMTI2LjA3OCAzNDEuNzg2IDEyOS4xNzRaTTE2Ni40OTcgMTcyLjc1NkMxNjYuNDk3IDE3Mi43NTYgMTgwLjI0NiAxNTEuMzcyIDIwMy41NjUgMTU4QzIyNi44OTkgMTY0LjYyOCAyMjguNzA2IDE5MC40MjUgMjI4LjcwNiAxOTAuNDI1TDE2Ni40OTcgMTcyLjc1NlpNMjIzLjQyIDI2OC43MTNDMTgyLjQwMyAyNTYuNjk4IDE3Ni4wNzcgMjIzLjk5IDE3Ni4wNzcgMjIzLjk5TDI4Ni4yNjIgMjU0Ljc5NkMyODYuMjYyIDI1NC43OTEgMjY0LjAyMSAyODAuNTc4IDIyMy40MiAyNjguNzEzWk0yNjIuMzc3IDIwMS40OTVDMjYyLjM3NyAyMDEuNDk1IDI3Ni4xMDcgMTgwLjEyNiAyOTkuNDIyIDE4Ni43NzNDMzIyLjczNiAxOTMuNDExIDMyNC41NzIgMjE5LjIwOCAzMjQuNTcyIDIxOS4yMDhMMjYyLjM3NyAyMDEuNDk1WiIgZmlsbD0iIzJFQUQzMyIvPgo8cGF0aCBkPSJNMTM5Ljg4IDI0Ni4wNEw5OS4zMzI0IDI1Ny41MzJDOTkuMzMyNCAyNTcuNTMyIDEwMy43MzcgMjMyLjQ0IDEzMy42MDcgMjIyLjQ5NkwxMTAuNjQ3IDEzNi4zM0wxMDguNjYzIDEzNi45MzNDODAuMTkxOCAxNDQuNjExIDU3LjE3NDIgMTQyLjU3NCA0Ny41NzA0IDE0MC45MTFDMzMuOTU1NCAxMzguNTYzIDI2LjgzNCAxMzUuNTcyIDI3LjUwMDEgMTQ1LjkyOUMyOC4wOCAxNTUuMDYzIDMwLjI1NTcgMTY5LjIyNCAzNS4yMzk3IDE4Ny45MjlDNDYuMDIyNSAyMjguNDMzIDgxLjY1ODMgMzA2LjQ4MSAxNDkuMDA1IDI4OC4zNDJMMTUwLjk4OSAyODcuNzE5TDEzOS44OCAyNDYuMDRaTTYxLjA4NDggMTg4LjQ4NUwxMDguOTQ2IDE3NS44NzZDMTA4Ljk0NiAxNzUuODc2IDEwNy41NTEgMTk0LjI4OCA4OS42MDg3IDE5OS4wMThDNzEuNjYxNSAyMDMuNzQzIDYxLjA4NDggMTg4LjQ4NSA2MS4wODQ4IDE4OC40ODVaIiBmaWxsPSIjRDY1MzQ4Ii8+CjxwYXRoIGQ9Ik0yMjUuMjcgMjY5LjE2M0wyMjMuNDE1IDI2OC43MTJDMTgyLjM5OCAyNTYuNjk4IDE3Ni4wNzIgMjIzLjk5IDE3Ni4wNzIgMjIzLjk5TDIzMi44OSAyMzkuODcyTDI2Mi45NzEgMTI0LjI4MUwyNjIuNjA3IDEyNC4xODVDMjI1LjcxMSAxMTQuMzA0IDIwMS4yMzIgOTcuMDIyNCAxOTEuNTMyIDg4Ljg5OTRDMTc3Ljc4MyA3Ny4zODM0IDE3MS43MzUgNjkuMzgwMiAxNjUuNzgzIDgxLjQ4NTdDMTYwLjUyNiA5Mi4xNjMgMTUzLjc5NyAxMDkuNTQgMTQ3LjI4NCAxMzMuODY2QzEzMy4xNzEgMTg2LjU0MyAxMjIuNjIzIDI5Ny43MDYgMjA5Ljg2NyAzMjEuMDk3TDIxMS42NTUgMzIxLjVMMjI1LjI3IDI2OS4xNjNaTTE2Ni40OTcgMTcyLjc1NkMxNjYuNDk3IDE3Mi43NTYgMTgwLjI0NiAxNTEuMzcyIDIwMy41NjUgMTU4QzIyNi44OTkgMTY0LjYyOCAyMjguNzA2IDE5MC40MjUgMjI4LjcwNiAxOTAuNDI1TDE2Ni40OTcgMTcyLjc1NloiIGZpbGw9IiMxRDhEMjIiLz4KPHBhdGggZD0iTTE0MS45NDYgMjQ1LjQ1MUwxMzEuMDcyIDI0OC41MzdDMTMzLjY0MSAyNjMuMDE5IDEzOC4xNjkgMjc2LjkxNyAxNDUuMjc2IDI4OS4xOTVDMTQ2LjUxMyAyODguOTIyIDE0Ny43NCAyODguNjg3IDE0OSAyODguMzQyQzE1Mi4zMDIgMjg3LjQ1MSAxNTUuMzY0IDI4Ni4zNDggMTU4LjMxMiAyODUuMTQ1QzE1MC4zNzEgMjczLjM2MSAxNDUuMTE4IDI1OS43ODkgMTQxLjk0NiAyNDUuNDUxWk0xMzcuNyAxNDMuNDUxQzEzMi4xMTIgMTY0LjMwNyAxMjcuMTEzIDE5NC4zMjYgMTI4LjQ4OSAyMjQuNDM2QzEzMC45NTIgMjIzLjM2NyAxMzMuNTU0IDIyMi4zNzEgMTM2LjQ0NCAyMjEuNTUxTDEzOC40NTcgMjIxLjEwMUMxMzYuMDAzIDE4OC45MzkgMTQxLjMwOCAxNTYuMTY1IDE0Ny4yODQgMTMzLjg2NkMxNDguNzk5IDEyOC4yMjUgMTUwLjMxOCAxMjIuOTc4IDE1MS44MzIgMTE4LjA4NUMxNDkuMzkzIDExOS42MzcgMTQ2Ljc2NyAxMjEuMjI4IDE0My43NzYgMTIyLjg2N0MxNDEuNzU5IDEyOS4wOTMgMTM5LjcyMiAxMzUuODk4IDEzNy43IDE0My40NTFaIiBmaWxsPSIjQzA0QjQxIi8+Cjwvc3ZnPg==", ) ASKUI_WEB_TESTING_AGENT = Assistant( id="asst_ge3tiojsga3dgnruge3di2u5ov36shedkcslxnmcd", + created_at=now(), name="AskUI Web Testing Agent", avatar="data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHhtbG5zOnhsaW5rPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5L3hsaW5rIiB2aWV3Qm94PSIwIDAgMjcgMjciIGFyaWEtaGlkZGVuPSJ0cnVlIiByb2xlPSJpbWciIGNsYXNzPSJpY29uaWZ5IGljb25pZnktLXR3ZW1vamkiIHByZXNlcnZlQXNwZWN0UmF0aW89InhNaWRZTWlkIG1lZXQiPjxwYXRoIGZpbGw9IiNDQ0Q2REQiIGQ9Ik0xMC45MjIgMTAuODEgMTkuMTAyIDIuNjI5bDUuMjIxIDUuMjIxIC04LjE4MSA4LjE4MXoiLz48cGF0aCBmaWxsPSIjNjhFMDkwIiBkPSJNNi4wNzcgMjUuNzk5QzEuODc1IDI1LjUgMS4xMjUgMjIuNTQ3IDEuMjI2IDIwLjk0OWMwLjI0MSAtMy44MDMgMTEuNzAxIC0xMi40MTMgMTEuNzAxIC0xMi40MTNsOS4zODggMS40NDhjMC4wMDEgMCAtMTMuMDQyIDE2LjA0NCAtMTYuMjM3IDE1LjgxNiIvPjxwYXRoIGZpbGw9IiM4ODk5QTYiIGQ9Ik0yNC4yNDUgMi43ODFDMjIuMDU0IDAuNTkgMTkuNTc4IC0wLjQ4NyAxOC43MTUgMC4zNzdjLTAuMDEgMC4wMSAtMC4wMTcgMC4wMjMgLTAuMDI2IDAuMDMzIC0wLjAwNSAwLjAwNSAtMC4wMTEgMC4wMDYgLTAuMDE2IDAuMDExTDEuNzIxIDE3LjM3M2E1LjU3MiA1LjU3MiAwIDAgMCAtMS42NDMgMy45NjZjMCAxLjQ5OCAwLjU4NCAyLjkwNiAxLjY0MyAzLjk2NWE1LjU3MiA1LjU3MiAwIDAgMCAzLjk2NiAxLjY0MyA1LjU3MiA1LjU3MiAwIDAgMCAzLjk2NSAtMS42NDJsMTYuOTUzIC0xNi45NTNjMC4wMDUgLTAuMDA1IDAuMDA3IC0wLjAxMiAwLjAxMSAtMC4wMTcgMC4wMSAtMC4wMDkgMC4wMjIgLTAuMDE1IDAuMDMyIC0wLjAyNSAwLjg2MyAtMC44NjIgLTAuMjE0IC0zLjMzOCAtMi40MDUgLTUuNTI5TTguMDYzIDIzLjcxNGMtMC42MzQgMC42MzQgLTEuNDc4IDAuOTgzIC0yLjM3NCAwLjk4M3MtMS43NDEgLTAuMzUgLTIuMzc1IC0wLjk4NGEzLjMzOCAzLjMzOCAwIDAgMSAtMC45ODQgLTIuMzc1YzAgLTAuODk3IDAuMzUgLTEuNzQgMC45ODMgLTIuMzc0TDE5LjA1OSAzLjIxOGMwLjQ2NyAwLjg1OCAxLjE3IDEuNzk2IDIuMDYyIDIuNjg4czEuODMgMS41OTUgMi42ODggMi4wNjJ6Ii8+PHBhdGggZmlsbD0iIzE3QkY2MyIgZD0iTTIxLjg5NyA5Ljg1OGMtMC4wNDQgMC4yODQgLTEuOTcgMC41NjMgLTQuMjY4IDAuMjU3cy00LjExMiAtMC45MTcgLTQuMDUyIC0xLjM2NSAxLjk3IC0wLjU2MyA0LjI2OCAtMC4yNTcgNC4xMjEgMC45MTggNC4wNTIgMS4zNjVNOC4xMyAxNy40MzVhMC41OTYgMC41OTYgMCAxIDEgLTAuODQyIC0wLjg0MyAwLjU5NiAwLjU5NiAwIDAgMSAwLjg0MiAwLjg0M20yLjQ4OCAxLjk2MWEwLjk3NCAwLjk3NCAwIDEgMSAtMS4zNzYgLTEuMzc3IDAuOTc0IDAuOTc0IDAgMCAxIDEuMzc2IDEuMzc3bTEuMjU4IC0zLjk5M2EwLjkxNiAwLjkxNiAwIDAgMSAtMS4yOTQgLTEuMjk0IDAuOTE1IDAuOTE1IDAgMSAxIDEuMjk0IDEuMjk0bS01LjE1MSA2LjY0NGExLjExNyAxLjExNyAwIDEgMSAtMS41NzkgLTEuNTc5IDEuMTE3IDEuMTE3IDAgMCAxIDEuNTc5IDEuNTc5bTguNTQ3IC02Ljg2OGEwLjc5NCAwLjc5NCAwIDEgMSAtMS4xMjIgLTEuMTIzIDAuNzk0IDAuNzk0IDAgMCAxIDEuMTIyIDEuMTIzbS0wLjkwNSAtMy4yMTZhMC41MiAwLjUyIDAgMSAxIC0wLjczNCAtMC43MzUgMC41MiAwLjUyIDAgMCAxIDAuNzM0IDAuNzM1Ii8+PHBhdGggdHJhbnNmb3JtPSJyb3RhdGUoLTQ1LjAwMSAzMC44MTcgNS4yMjMpIiBmaWxsPSIjQ0NENkREIiBjeD0iMzAuODE3IiBjeT0iNS4yMjMiIHJ4PSIxLjE4NCIgcnk9IjQuODQ3IiBkPSJNMjQuMDAxIDMuOTE3QTAuODg4IDMuNjM1IDAgMCAxIDIzLjExMyA3LjU1M0EwLjg4OCAzLjYzNSAwIDAgMSAyMi4yMjUgMy45MTdBMC44ODggMy42MzUgMCAwIDEgMjQuMDAxIDMuOTE3eiIvPjwvc3ZnPg==", ) diff --git a/src/askui/chat/api/assistants/service.py b/src/askui/chat/api/assistants/service.py index 9a62cc20..cd5393a3 100644 --- a/src/askui/chat/api/assistants/service.py +++ b/src/askui/chat/api/assistants/service.py @@ -1,167 +1,77 @@ from pathlib import Path -from pydantic import BaseModel, Field - -from askui.chat.api.assistants.models import Assistant +from askui.chat.api.assistants.models import ( + Assistant, + AssistantCreateParams, + AssistantModifyParams, +) from askui.chat.api.assistants.seeds import SEEDS -from askui.chat.api.models import DO_NOT_PATCH, DoNotPatch -from askui.utils.api_utils import ListQuery, ListResponse - - -class CreateAssistantRequest(BaseModel): - """Request model for creating an assistant.""" - - name: str | None = None - description: str | None = None - avatar: str | None = Field(default=None, description="URL of the avatar image") - - -class AssistantModifyRequest(BaseModel): - """Request model for updating an assistant.""" - - name: str | None | DoNotPatch = DO_NOT_PATCH - description: str | None | DoNotPatch = DO_NOT_PATCH - avatar: str | None | DoNotPatch = Field( - default=DO_NOT_PATCH, description="URL of the avatar image" - ) +from askui.chat.api.models import AssistantId +from askui.utils.api_utils import ( + ConflictError, + ListQuery, + ListResponse, + NotFoundError, + list_resources, +) class AssistantService: - """Service for managing assistants.""" - def __init__(self, base_dir: Path) -> None: - """Initialize assistant service. - - Args: - base_dir: Base directory to store assistant data - """ self._base_dir = base_dir self._assistants_dir = base_dir / "assistants" - def list_(self, query: ListQuery) -> ListResponse[Assistant]: - """List all available assistants. - - Args: - query (ListQuery): Query parameters for listing assistants - - Returns: - ListResponse[Assistant]: ListResponse containing assistants sorted by - creation date - """ - if not self._assistants_dir.exists(): - return ListResponse(data=[]) - - assistant_files = list(self._assistants_dir.glob("*.json")) - assistants: list[Assistant] = [] - for f in assistant_files: - with f.open("r", encoding="utf-8") as file: - assistants.append(Assistant.model_validate_json(file.read())) - - # Sort by creation date - assistants = sorted( - assistants, key=lambda a: a.created_at, reverse=(query.order == "desc") - ) - - # Apply before/after filters - if query.after: - assistants = [a for a in assistants if a.id > query.after] - if query.before: - assistants = [a for a in assistants if a.id < query.before] - - # Apply limit - assistants = assistants[: query.limit] - - return ListResponse( - data=assistants, - first_id=assistants[0].id if assistants else None, - last_id=assistants[-1].id if assistants else None, - has_more=len(assistant_files) > query.limit, - ) - - def retrieve(self, assistant_id: str) -> Assistant: - """Retrieve an assistant by ID. - - Args: - assistant_id: ID of assistant to retrieve - - Returns: - Assistant object - - Raises: - FileNotFoundError: If assistant doesn't exist - """ - assistant_file = self._assistants_dir / f"{assistant_id}.json" - if not assistant_file.exists(): + def _get_assistant_path(self, assistant_id: AssistantId, new: bool = False) -> Path: + assistant_path = self._assistants_dir / f"{assistant_id}.json" + exists = assistant_path.exists() + if new and exists: + error_msg = f"Assistant {assistant_id} already exists" + raise ConflictError(error_msg) + if not new and not exists: error_msg = f"Assistant {assistant_id} not found" - raise FileNotFoundError(error_msg) + raise NotFoundError(error_msg) + return assistant_path - with assistant_file.open("r", encoding="utf-8") as f: - return Assistant.model_validate_json(f.read()) - - def create(self, request: CreateAssistantRequest) -> Assistant: - """Create a new assistant. + def list_(self, query: ListQuery) -> ListResponse[Assistant]: + return list_resources(self._assistants_dir, query, Assistant) - Args: - request: Assistant creation request + def retrieve(self, assistant_id: AssistantId) -> Assistant: + try: + assistant_path = self._get_assistant_path(assistant_id) + return Assistant.model_validate_json(assistant_path.read_text()) + except FileNotFoundError as e: + error_msg = f"Assistant {assistant_id} not found" + raise NotFoundError(error_msg) from e - Returns: - Created assistant object - """ - assistant = Assistant( - name=request.name, - description=request.description, - ) - self._save(assistant) + def create(self, params: AssistantCreateParams) -> Assistant: + assistant = Assistant.create(params) + self._save(assistant, new=True) return assistant - def _save(self, assistant: Assistant) -> None: - """Save an assistant to the file system.""" - self._assistants_dir.mkdir(parents=True, exist_ok=True) - assistant_file = self._assistants_dir / f"{assistant.id}.json" - with assistant_file.open("w", encoding="utf-8") as f: - f.write(assistant.model_dump_json()) - - def modify(self, assistant_id: str, request: AssistantModifyRequest) -> Assistant: - """Update an existing assistant. - - Args: - assistant_id: ID of assistant to modify - request: Assistant modify request - - Returns: - Updated assistant object - - Raises: - FileNotFoundError: If assistant doesn't exist - """ + def modify( + self, assistant_id: AssistantId, params: AssistantModifyParams + ) -> Assistant: assistant = self.retrieve(assistant_id) - if not isinstance(request.name, DoNotPatch): - assistant.name = request.name - if not isinstance(request.description, DoNotPatch): - assistant.description = request.description - if not isinstance(request.avatar, DoNotPatch): - assistant.avatar = request.avatar - assistant_file = self._assistants_dir / f"{assistant_id}.json" - with assistant_file.open("w", encoding="utf-8") as f: - f.write(assistant.model_dump_json()) - return assistant - - def delete(self, assistant_id: str) -> None: - """Delete an assistant. - - Args: - assistant_id: ID of assistant to delete - - Raises: - FileNotFoundError: If assistant doesn't exist - """ - assistant_file = self._assistants_dir / f"{assistant_id}.json" - if not assistant_file.exists(): + modified = assistant.modify(params) + self._save(modified) + return modified + + def delete(self, assistant_id: AssistantId) -> None: + try: + self._get_assistant_path(assistant_id).unlink() + except FileNotFoundError as e: error_msg = f"Assistant {assistant_id} not found" - raise FileNotFoundError(error_msg) - assistant_file.unlink() + raise NotFoundError(error_msg) from e + + def _save(self, assistant: Assistant, new: bool = False) -> None: + self._assistants_dir.mkdir(parents=True, exist_ok=True) + assistant_file = self._get_assistant_path(assistant.id, new=new) + assistant_file.write_text(assistant.model_dump_json(), encoding="utf-8") def seed(self) -> None: """Seed the assistant service with default assistants.""" for seed in SEEDS: - self._save(seed) + try: + self._save(seed, new=True) + except ConflictError: # noqa: PERF203 + self._save(seed) diff --git a/src/askui/chat/api/dependencies.py b/src/askui/chat/api/dependencies.py index 2005cfe4..b67f627c 100644 --- a/src/askui/chat/api/dependencies.py +++ b/src/askui/chat/api/dependencies.py @@ -59,13 +59,9 @@ def set_env_from_headers( def get_workspace_dir( + askui_workspace: Annotated[str, Header()], settings: Settings = SettingsDep, - askui_workspace: Annotated[str | None, Header()] = None, ) -> Path: - if not askui_workspace: - raise HTTPException( - status_code=400, detail="AskUI-Workspace header is required" - ) return settings.data_dir / "workspaces" / askui_workspace diff --git a/src/askui/chat/api/mcp_configs/models.py b/src/askui/chat/api/mcp_configs/models.py index b74cc3f7..a219437c 100644 --- a/src/askui/chat/api/mcp_configs/models.py +++ b/src/askui/chat/api/mcp_configs/models.py @@ -1,9 +1,10 @@ from typing import Literal from fastmcp.mcp_config import RemoteMCPServer, StdioMCPServer -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel from askui.chat.api.models import McpConfigId +from askui.utils.api_utils import Resource from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import generate_time_ordered_id from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven @@ -11,30 +12,30 @@ McpServer = StdioMCPServer | RemoteMCPServer -class McpConfigCreateParams(BaseModel): - """Parameters for creating an MCP configuration.""" +class McpConfigBase(BaseModel): + """Base MCP configuration model.""" name: str mcp_server: McpServer +class McpConfigCreateParams(McpConfigBase): + """Parameters for creating an MCP configuration.""" + + class McpConfigModifyParams(BaseModelWithNotGiven): """Parameters for modifying an MCP configuration.""" name: str | NotGiven = NOT_GIVEN - mcp_server: McpServer | NotGiven = Field(default=NOT_GIVEN) + mcp_server: McpServer | NotGiven = NOT_GIVEN -class McpConfig(BaseModel): +class McpConfig(McpConfigBase, Resource): """An MCP configuration that can be stored and managed.""" - id: McpConfigId = Field( - default_factory=lambda: generate_time_ordered_id("mcp_config") - ) - created_at: UnixDatetime = Field(default_factory=now) - name: str + id: McpConfigId object: Literal["mcp_config"] = "mcp_config" - mcp_server: McpServer = Field(description="The MCP server configuration") + created_at: UnixDatetime @classmethod def create(cls, params: McpConfigCreateParams) -> "McpConfig": diff --git a/src/askui/chat/api/mcp_configs/router.py b/src/askui/chat/api/mcp_configs/router.py index ed30f571..2bd96609 100644 --- a/src/askui/chat/api/mcp_configs/router.py +++ b/src/askui/chat/api/mcp_configs/router.py @@ -18,7 +18,6 @@ def list_mcp_configs( query: ListQuery = ListQueryDep, mcp_config_service: McpConfigService = McpConfigServiceDep, ) -> ListResponse[McpConfig]: - """List all MCP configurations.""" return mcp_config_service.list_(query=query) @@ -40,7 +39,7 @@ def retrieve_mcp_config( return mcp_config_service.retrieve(mcp_config_id) -@router.patch("/{mcp_config_id}", response_model_exclude_none=True) +@router.post("/{mcp_config_id}", response_model_exclude_none=True) def modify_mcp_config( mcp_config_id: McpConfigId, params: McpConfigModifyParams, diff --git a/src/askui/chat/api/mcp_configs/service.py b/src/askui/chat/api/mcp_configs/service.py index 6b3651a3..957ad26c 100644 --- a/src/askui/chat/api/mcp_configs/service.py +++ b/src/askui/chat/api/mcp_configs/service.py @@ -9,56 +9,42 @@ ListQuery, ListResponse, NotFoundError, - list_resource_paths, + list_resources, ) from .models import McpConfig, McpConfigCreateParams, McpConfigId, McpConfigModifyParams class McpConfigService: - """ - Service for managing McpConfig resources with filesystem persistence. - - Args: - base_dir (Path): Base directory for storing MCP configuration data. - """ + """Service for managing McpConfig resources with filesystem persistence.""" def __init__(self, base_dir: Path) -> None: self._base_dir = base_dir self._mcp_configs_dir = base_dir / "mcp_configs" - self._mcp_configs_dir.mkdir(parents=True, exist_ok=True) - def list_( - self, - query: ListQuery, - ) -> ListResponse[McpConfig]: - mcp_config_paths = list_resource_paths(self._mcp_configs_dir, query) - mcp_configs: list[McpConfig] = [] - for f in mcp_config_paths: - try: - mcp_config = McpConfig.model_validate_json( - f.read_text(encoding="utf-8") - ) - mcp_configs.append(mcp_config) - except ValidationError: # noqa: PERF203 - continue - has_more = len(mcp_configs) > query.limit - mcp_configs = mcp_configs[: query.limit] - return ListResponse( - data=mcp_configs, - first_id=mcp_configs[0].id if mcp_configs else None, - last_id=mcp_configs[-1].id if mcp_configs else None, - has_more=has_more, - ) + def _get_mcp_config_path( + self, mcp_config_id: McpConfigId, new: bool = False + ) -> Path: + mcp_config_path = self._mcp_configs_dir / f"{mcp_config_id}.json" + exists = mcp_config_path.exists() + if new and exists: + error_msg = f"MCP configuration {mcp_config_id} already exists" + raise ConflictError(error_msg) + if not new and not exists: + error_msg = f"MCP configuration {mcp_config_id} not found" + raise NotFoundError(error_msg) + return mcp_config_path + + def list_(self, query: ListQuery) -> ListResponse[McpConfig]: + return list_resources(self._mcp_configs_dir, query, McpConfig) def retrieve(self, mcp_config_id: McpConfigId) -> McpConfig: - mcp_config_file = self._mcp_configs_dir / f"{mcp_config_id}.json" - if not mcp_config_file.exists(): + try: + mcp_config_path = self._get_mcp_config_path(mcp_config_id) + return McpConfig.model_validate_json(mcp_config_path.read_text()) + except FileNotFoundError as e: error_msg = f"MCP configuration {mcp_config_id} not found" - raise NotFoundError(error_msg) - return McpConfig.model_validate_json( - mcp_config_file.read_text(encoding="utf-8") - ) + raise NotFoundError(error_msg) from e def _check_limit(self) -> None: limit = LIST_LIMIT_MAX @@ -86,22 +72,18 @@ def modify( return modified def delete(self, mcp_config_id: McpConfigId) -> None: - mcp_config_file = self._mcp_configs_dir / f"{mcp_config_id}.json" - if not mcp_config_file.exists(): + try: + self._get_mcp_config_path(mcp_config_id).unlink() + except FileNotFoundError as e: error_msg = f"MCP configuration {mcp_config_id} not found" - raise NotFoundError(error_msg) - mcp_config_file.unlink() + raise NotFoundError(error_msg) from e def _save(self, mcp_config: McpConfig, new: bool = False) -> None: - """Save an MCP configuration to the file system.""" self._mcp_configs_dir.mkdir(parents=True, exist_ok=True) - mcp_config_file = self._mcp_configs_dir / f"{mcp_config.id}.json" - if new and mcp_config_file.exists(): - error_msg = f"MCP configuration {mcp_config.id} already exists" - raise ConflictError(error_msg) - with mcp_config_file.open("w", encoding="utf-8") as f: - f.write( - mcp_config.model_dump_json( - exclude_unset=True, exclude_none=True, exclude_defaults=True - ) - ) + mcp_config_file = self._get_mcp_config_path(mcp_config.id, new=new) + mcp_config_file.write_text( + mcp_config.model_dump_json( + exclude_unset=True, exclude_none=True, exclude_defaults=True + ), + encoding="utf-8", + ) diff --git a/src/askui/chat/api/messages/models.py b/src/askui/chat/api/messages/models.py new file mode 100644 index 00000000..8dd9cc67 --- /dev/null +++ b/src/askui/chat/api/messages/models.py @@ -0,0 +1,33 @@ +from typing import Literal + +from askui.chat.api.models import AssistantId, MessageId, RunId, ThreadId +from askui.models.shared.agent_message_param import MessageParam +from askui.utils.api_utils import Resource +from askui.utils.datetime_utils import UnixDatetime, now +from askui.utils.id_utils import generate_time_ordered_id + + +class MessageBase(MessageParam): + assistant_id: AssistantId | None = None + object: Literal["thread.message"] = "thread.message" + role: Literal["user", "assistant"] + run_id: RunId | None = None + + +class MessageCreateParams(MessageBase): + pass + + +class Message(MessageBase, Resource): + id: MessageId + created_at: UnixDatetime + thread_id: ThreadId + + @classmethod + def create(cls, thread_id: ThreadId, params: MessageCreateParams) -> "Message": + return cls( + id=generate_time_ordered_id("msg"), + created_at=now(), + thread_id=thread_id, + **params.model_dump(), + ) diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 51e07d8b..317acab6 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -1,13 +1,10 @@ -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, status from askui.chat.api.messages.dependencies import MessageServiceDep -from askui.chat.api.messages.service import ( - Message, - MessageCreateRequest, - MessageService, -) +from askui.chat.api.messages.models import Message, MessageCreateParams +from askui.chat.api.messages.service import MessageService from askui.chat.api.models import ListQueryDep, MessageId, ThreadId -from askui.utils.api_utils import ListQuery, ListResponse, NotFoundError +from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) @@ -18,24 +15,16 @@ def list_messages( query: ListQuery = ListQueryDep, message_service: MessageService = MessageServiceDep, ) -> ListResponse[Message]: - """List all messages in a thread.""" - try: - return message_service.list_(thread_id, query=query) - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e + return message_service.list_(thread_id, query=query) @router.post("", status_code=status.HTTP_201_CREATED) async def create_message( thread_id: ThreadId, - request: MessageCreateRequest, + params: MessageCreateParams, message_service: MessageService = MessageServiceDep, ) -> Message: - """Create a new message in a thread.""" - try: - return message_service.create(thread_id=thread_id, request=request) - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e + return message_service.create(thread_id=thread_id, params=params) @router.get("/{message_id}") @@ -44,11 +33,7 @@ def retrieve_message( message_id: MessageId, message_service: MessageService = MessageServiceDep, ) -> Message: - """Get a specific message from a thread.""" - try: - return message_service.retrieve(thread_id, message_id) - except NotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e + return message_service.retrieve(thread_id, message_id) @router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT) @@ -57,8 +42,4 @@ def delete_message( message_id: MessageId, message_service: MessageService = MessageServiceDep, ) -> None: - """Delete a message from a thread.""" - try: - message_service.delete(thread_id, message_id) - except NotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e + message_service.delete(thread_id, message_id) diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index 269f6c92..12c59339 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -1,111 +1,62 @@ -from datetime import datetime, timezone from pathlib import Path -from typing import Literal -from pydantic import Field, ValidationError - -from askui.chat.api.models import AssistantId, MessageId, RunId, ThreadId -from askui.models.shared.agent_message_param import MessageParam +from askui.chat.api.messages.models import Message, MessageCreateParams +from askui.chat.api.models import MessageId, ThreadId from askui.utils.api_utils import ( ConflictError, ListQuery, ListResponse, NotFoundError, - list_resource_paths, + list_resources, ) -from askui.utils.datetime_utils import UnixDatetime -from askui.utils.id_utils import generate_time_ordered_id - - -class MessageBase(MessageParam): - assistant_id: AssistantId | None = None - object: Literal["thread.message"] = "thread.message" - role: Literal["user", "assistant"] - run_id: RunId | None = None - - -class Message(MessageBase): - id: MessageId = Field(default_factory=lambda: generate_time_ordered_id("msg")) - thread_id: ThreadId - created_at: UnixDatetime = Field( - default_factory=lambda: datetime.now(tz=timezone.utc) - ) - - -class MessageCreateRequest(MessageBase): - pass class MessageService: def __init__(self, base_dir: Path) -> None: - """Initialize message service. - - Args: - base_dir: Base directory to store message data - """ self._base_dir = base_dir - self._base_messages_dir = base_dir / "messages" - - def create(self, thread_id: ThreadId, request: MessageCreateRequest) -> Message: - new_message = Message( - **request.model_dump(), - thread_id=thread_id, - ) - self._save(new_message, new=True) - return new_message - - def delete(self, thread_id: ThreadId, message_id: MessageId) -> None: - message_file = self._get_message_path(thread_id, message_id) - message_file.unlink() - - def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Message]: - messages_dir = self.get_thread_messages_dir(thread_id) - if not messages_dir.exists(): - return ListResponse(data=[]) - message_paths = list_resource_paths(messages_dir, query) - messages: list[Message] = [] - for message_file in message_paths: - try: - msg = Message.model_validate_json( - message_file.read_text(encoding="utf-8") - ) - messages.append(msg) - except ValidationError: # noqa: PERF203 - continue - has_more = len(messages) > query.limit - messages = messages[: query.limit] - return ListResponse( - data=messages, - first_id=messages[0].id if messages else None, - last_id=messages[-1].id if messages else None, - has_more=has_more, - ) - - def retrieve(self, thread_id: ThreadId, message_id: MessageId) -> Message: - message_file = self._get_message_path(thread_id, message_id) - return Message.model_validate_json(message_file.read_text(encoding="utf-8")) - - def get_thread_messages_dir(self, thread_id: ThreadId) -> Path: - """Get the directory path for a specific message.""" - return self._base_messages_dir / thread_id + def get_messages_dir(self, thread_id: ThreadId) -> Path: + return self._base_dir / "threads" / thread_id / "messages" def _get_message_path( self, thread_id: ThreadId, message_id: MessageId, new: bool = False ) -> Path: - """Get the file path for a specific message.""" - message_path = self.get_thread_messages_dir(thread_id) / f"{message_id}.json" - if new and message_path.exists(): + message_path = self.get_messages_dir(thread_id) / f"{message_id}.json" + exists = message_path.exists() + if new and exists: error_msg = f"Message {message_id} already exists in thread {thread_id}" raise ConflictError(error_msg) - if not new and not message_path.exists(): + if not new and not exists: error_msg = f"Message {message_id} not found in thread {thread_id}" raise NotFoundError(error_msg) return message_path + def create(self, thread_id: ThreadId, params: MessageCreateParams) -> Message: + new_message = Message.create(thread_id, params) + self._save(new_message, new=True) + return new_message + + def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Message]: + messages_dir = self.get_messages_dir(thread_id) + return list_resources(messages_dir, query, Message) + + def retrieve(self, thread_id: ThreadId, message_id: MessageId) -> Message: + try: + message_file = self._get_message_path(thread_id, message_id) + return Message.model_validate_json(message_file.read_text(encoding="utf-8")) + except FileNotFoundError as e: + error_msg = f"Message {message_id} not found in thread {thread_id}" + raise NotFoundError(error_msg) from e + + def delete(self, thread_id: ThreadId, message_id: MessageId) -> None: + try: + self._get_message_path(thread_id, message_id).unlink() + except FileNotFoundError as e: + error_msg = f"Message {message_id} not found in thread {thread_id}" + raise NotFoundError(error_msg) from e + def _save(self, message: Message, new: bool = False) -> None: - """Save a single message to its own JSON file.""" - messages_dir = self.get_thread_messages_dir(message.thread_id) + messages_dir = self.get_messages_dir(message.thread_id) messages_dir.mkdir(parents=True, exist_ok=True) message_file = self._get_message_path(message.thread_id, message.id, new=new) message_file.write_text(message.model_dump_json(), encoding="utf-8") diff --git a/src/askui/chat/api/models.py b/src/askui/chat/api/models.py index 2e222c2d..85235550 100644 --- a/src/askui/chat/api/models.py +++ b/src/askui/chat/api/models.py @@ -12,10 +12,3 @@ ListQueryDep = Depends(ListQuery) - - -class DoNotPatch(BaseModel): - pass - - -DO_NOT_PATCH = DoNotPatch() diff --git a/src/askui/chat/api/runs/models.py b/src/askui/chat/api/runs/models.py index 9d851547..bbecab0c 100644 --- a/src/askui/chat/api/runs/models.py +++ b/src/askui/chat/api/runs/models.py @@ -4,7 +4,8 @@ from pydantic import BaseModel, Field, computed_field from askui.chat.api.models import AssistantId, RunId, ThreadId -from askui.utils.datetime_utils import UnixDatetime +from askui.utils.api_utils import Resource +from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import generate_time_ordered_id RunStatus = Literal[ @@ -19,27 +20,48 @@ class RunError(BaseModel): + """Error information for a failed run.""" + message: str code: Literal["server_error", "rate_limit_exceeded", "invalid_prompt"] -class Run(BaseModel): +class RunBase(BaseModel): + """Base run model.""" + assistant_id: AssistantId - cancelled_at: UnixDatetime | None = None - completed_at: UnixDatetime | None = None - created_at: UnixDatetime = Field( - default_factory=lambda: datetime.now(tz=timezone.utc) - ) - expires_at: UnixDatetime = Field( - default_factory=lambda: datetime.now(tz=timezone.utc) + timedelta(minutes=10) - ) - failed_at: UnixDatetime | None = None - id: RunId = Field(default_factory=lambda: generate_time_ordered_id("run")) - last_error: RunError | None = None + + +class RunCreateParams(RunBase): + """Parameters for creating a run.""" + + stream: bool = False + + +class Run(RunBase, Resource): + """A run execution within a thread.""" + + id: RunId object: Literal["thread.run"] = "thread.run" - started_at: UnixDatetime | None = None thread_id: ThreadId + created_at: UnixDatetime + expires_at: UnixDatetime + started_at: UnixDatetime | None = None + completed_at: UnixDatetime | None = None + failed_at: UnixDatetime | None = None + cancelled_at: UnixDatetime | None = None tried_cancelling_at: UnixDatetime | None = None + last_error: RunError | None = None + + @classmethod + def create(cls, thread_id: ThreadId, params: RunCreateParams) -> "Run": + return cls( + id=generate_time_ordered_id("run"), + thread_id=thread_id, + created_at=now(), + expires_at=datetime.now(tz=timezone.utc) + timedelta(minutes=10), + **params.model_dump(exclude={"stream"}), + ) @computed_field # type: ignore[prop-decorator] @property diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py index 74db267e..e7aa687f 100644 --- a/src/askui/chat/api/runs/router.py +++ b/src/askui/chat/api/runs/router.py @@ -1,20 +1,12 @@ from collections.abc import AsyncGenerator from typing import Annotated -from fastapi import ( - APIRouter, - BackgroundTasks, - Body, - HTTPException, - Path, - Response, - status, -) +from fastapi import APIRouter, BackgroundTasks, Body, Path, Response, status from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel from askui.chat.api.models import ListQueryDep, RunId, ThreadId -from askui.chat.api.runs.service import CreateRunRequest +from askui.chat.api.runs.models import RunCreateParams from askui.utils.api_utils import ListQuery, ListResponse from .dependencies import RunServiceDep @@ -27,15 +19,12 @@ @router.post("") async def create_run( thread_id: Annotated[ThreadId, Path(...)], - request: Annotated[CreateRunRequest, Body(...)], + params: RunCreateParams, background_tasks: BackgroundTasks, run_service: RunService = RunServiceDep, ) -> Response: - """ - Create a new run for a given thread. - """ - stream = request.stream - run, async_generator = await run_service.create(thread_id, request) + stream = params.stream + run, async_generator = await run_service.create(thread_id, params) if stream: async def sse_event_stream() -> AsyncGenerator[str, None]: @@ -63,16 +52,11 @@ async def _run_async_generator() -> None: @router.get("/{run_id}") def retrieve_run( + thread_id: Annotated[ThreadId, Path(...)], run_id: Annotated[RunId, Path(...)], run_service: RunService = RunServiceDep, ) -> Run: - """ - Retrieve a run by its ID. - """ - try: - return run_service.retrieve(run_id) - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e + return run_service.retrieve(thread_id, run_id) @router.get("") @@ -81,21 +65,13 @@ def list_runs( query: ListQuery = ListQueryDep, run_service: RunService = RunServiceDep, ) -> ListResponse[Run]: - """ - List runs, optionally filtered by thread. - """ return run_service.list_(thread_id, query=query) @router.post("/{run_id}/cancel") def cancel_run( + thread_id: Annotated[ThreadId, Path(...)], run_id: Annotated[RunId, Path(...)], run_service: RunService = RunServiceDep, ) -> Run: - """ - Cancel a run by its ID. - """ - try: - return run_service.cancel(run_id) - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e + return run_service.cancel(thread_id, run_id) diff --git a/src/askui/chat/api/runs/runner/events/message_events.py b/src/askui/chat/api/runs/runner/events/message_events.py index 51807d84..54a5c802 100644 --- a/src/askui/chat/api/runs/runner/events/message_events.py +++ b/src/askui/chat/api/runs/runner/events/message_events.py @@ -1,6 +1,6 @@ from typing import Literal -from askui.chat.api.messages.service import Message +from askui.chat.api.messages.models import Message from askui.chat.api.runs.runner.events.event_base import EventBase diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 38354ebf..7c82b291 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -21,7 +21,9 @@ ) from askui.chat.api.mcp_configs.models import McpConfig from askui.chat.api.mcp_configs.service import McpConfigService -from askui.chat.api.messages.service import MessageCreateRequest, MessageService +from askui.chat.api.messages.models import MessageCreateParams +from askui.chat.api.messages.service import MessageService +from askui.chat.api.models import RunId, ThreadId from askui.chat.api.runs.models import Run, RunError from askui.chat.api.runs.runner.events.done_events import DoneEvent from askui.chat.api.runs.runner.events.error_events import ( @@ -41,7 +43,12 @@ from askui.models.shared.agent_on_message_cb import OnMessageCbParam from askui.models.shared.tools import ToolCollection from askui.tools.pynput_agent_os import PynputAgentOs -from askui.utils.api_utils import LIST_LIMIT_MAX, ListQuery +from askui.utils.api_utils import ( + LIST_LIMIT_MAX, + ConflictError, + ListQuery, + NotFoundError, +) from askui.utils.image_utils import ImageSource from askui.web_agent import WebVisionAgent from askui.web_testing_agent import WebTestingAgent @@ -87,14 +94,38 @@ class Runner: def __init__(self, run: Run, base_dir: Path) -> None: self._run = run self._base_dir = base_dir - self._runs_dir = base_dir / "runs" self._msg_service = MessageService(self._base_dir) self._agent_os = PynputAgentOs() + def get_runs_dir(self, thread_id: ThreadId) -> Path: + return self._base_dir / "threads" / thread_id / "runs" + + def _get_run_path( + self, thread_id: ThreadId, run_id: RunId, new: bool = False + ) -> Path: + run_path = self.get_runs_dir(thread_id) / f"{run_id}.json" + if new and run_path.exists(): + error_msg = f"Run {run_id} already exists in thread {thread_id}" + raise ConflictError(error_msg) + if not new and not run_path.exists(): + error_msg = f"Run {run_id} not found in thread {thread_id}" + raise NotFoundError(error_msg) + return run_path + + def _save(self, run: Run, new: bool = False) -> None: + runs_dir = self.get_runs_dir(run.thread_id) + runs_dir.mkdir(parents=True, exist_ok=True) + run_file = self._get_run_path(run.thread_id, run.id, new=new) + run_file.write_text(run.model_dump_json(), encoding="utf-8") + + def _retrieve(self) -> Run: + run_file = self._get_run_path(self._run.thread_id, self._run.id) + return Run.model_validate_json(run_file.read_text(encoding="utf-8")) + async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: message = self._msg_service.create( thread_id=self._run.thread_id, - request=MessageCreateRequest( + params=MessageCreateParams( role="user", content=[ TextBlockParam( @@ -116,7 +147,7 @@ async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: await anyio.sleep(0.1) recorded_events: list[InputEvent] = [] while True: - updated_run = self._retrieve_run() + updated_run = self._retrieve() if self._should_abort(updated_run): break while event := self._agent_os.poll_event(): @@ -131,7 +162,7 @@ async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: ) message = self._msg_service.create( thread_id=self._run.thread_id, - request=MessageCreateRequest( + params=MessageCreateParams( role="user", content=[ ImageBlockParam( @@ -165,7 +196,7 @@ async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: text = "Nevermind, I didn't do anything." message = self._msg_service.create( thread_id=self._run.thread_id, - request=MessageCreateRequest( + params=MessageCreateParams( role="user", content=[ TextBlockParam( @@ -242,7 +273,7 @@ async def async_on_message( ) -> MessageParam | None: message = self._msg_service.create( thread_id=self._run.thread_id, - request=MessageCreateRequest( + params=MessageCreateParams( assistant_id=self._run.assistant_id if on_message_cb_param.message.role == "assistant" else None, @@ -257,7 +288,7 @@ async def async_on_message( event="thread.message.created", ) ) - updated_run = self._retrieve_run() + updated_run = self._retrieve() if self._should_abort(updated_run): return None return on_message_cb_param.message @@ -336,10 +367,10 @@ async def run( send_stream, mcp_client, ) - updated_run = self._retrieve_run() + updated_run = self._retrieve() if updated_run.status == "in_progress": updated_run.completed_at = datetime.now(tz=timezone.utc) - self._update_run_file(updated_run) + self._save(updated_run) await send_stream.send( RunEvent( data=updated_run, @@ -354,7 +385,7 @@ async def run( ) ) updated_run.cancelled_at = datetime.now(tz=timezone.utc) - self._update_run_file(updated_run) + self._save(updated_run) await send_stream.send( RunEvent( data=updated_run, @@ -371,10 +402,10 @@ async def run( await send_stream.send(DoneEvent()) except Exception as e: # noqa: BLE001 logger.exception("Exception in runner") - updated_run = self._retrieve_run() + updated_run = self._retrieve() updated_run.failed_at = datetime.now(tz=timezone.utc) updated_run.last_error = RunError(message=str(e), code="server_error") - self._update_run_file(updated_run) + self._save(updated_run) await send_stream.send( RunEvent( data=updated_run, @@ -389,17 +420,7 @@ async def run( def _mark_run_as_started(self) -> None: self._run.started_at = datetime.now(tz=timezone.utc) - self._update_run_file(self._run) + self._save(self._run) def _should_abort(self, run: Run) -> bool: return run.status in ("cancelled", "cancelling", "expired") - - def _update_run_file(self, run: Run) -> None: - run_file = self._runs_dir / f"{run.thread_id}__{run.id}.json" - with run_file.open("w", encoding="utf-8") as f: - f.write(run.model_dump_json()) - - def _retrieve_run(self) -> Run: - run_file = self._runs_dir / f"{self._run.thread_id}__{self._run.id}.json" - with run_file.open("r", encoding="utf-8") as f: - return Run.model_validate_json(f.read()) diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 6c66a143..3f5f0be2 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -3,77 +3,77 @@ from pathlib import Path import anyio -from pydantic import BaseModel -from askui.chat.api.models import AssistantId, RunId, ThreadId -from askui.chat.api.runs.models import Run +from askui.chat.api.models import RunId, ThreadId +from askui.chat.api.runs.models import Run, RunCreateParams from askui.chat.api.runs.runner.events import Events from askui.chat.api.runs.runner.events.done_events import DoneEvent from askui.chat.api.runs.runner.events.error_events import ErrorEvent from askui.chat.api.runs.runner.events.run_events import RunEvent from askui.chat.api.runs.runner.runner import Runner -from askui.utils.api_utils import ListQuery, ListResponse - - -class CreateRunRequest(BaseModel): - assistant_id: AssistantId - stream: bool = True +from askui.utils.api_utils import ( + ConflictError, + ListQuery, + ListResponse, + NotFoundError, + list_resources, +) class RunService: - """ - Service for managing runs. Handles creation, retrieval, listing, and - cancellation of runs. - """ + """Service for managing Run resources with filesystem persistence.""" def __init__(self, base_dir: Path) -> None: self._base_dir = base_dir - self._runs_dir = base_dir / "runs" - def _run_path(self, thread_id: ThreadId, run_id: RunId) -> Path: - return self._runs_dir / f"{thread_id}__{run_id}.json" - - def _create_run(self, thread_id: ThreadId, request: CreateRunRequest) -> Run: - run = Run(thread_id=thread_id, assistant_id=request.assistant_id) - self._runs_dir.mkdir(parents=True, exist_ok=True) - self._update_run_file(run) + def get_runs_dir(self, thread_id: ThreadId) -> Path: + return self._base_dir / "threads" / thread_id / "runs" + + def _get_run_path( + self, thread_id: ThreadId, run_id: RunId, new: bool = False + ) -> Path: + run_path = self.get_runs_dir(thread_id) / f"{run_id}.json" + exists = run_path.exists() + if new and exists: + error_msg = f"Run {run_id} already exists in thread {thread_id}" + raise ConflictError(error_msg) + if not new and not exists: + error_msg = f"Run {run_id} not found in thread {thread_id}" + raise NotFoundError(error_msg) + return run_path + + def _create(self, thread_id: ThreadId, params: RunCreateParams) -> Run: + run = Run.create(thread_id, params) + self._save(run, new=True) return run async def create( - self, thread_id: ThreadId, request: CreateRunRequest + self, thread_id: ThreadId, params: RunCreateParams ) -> tuple[Run, AsyncGenerator[Events, None]]: - run = self._create_run(thread_id, request) + run = self._create(thread_id, params) send_stream, receive_stream = anyio.create_memory_object_stream[Events]() runner = Runner(run, self._base_dir) async def event_generator() -> AsyncGenerator[Events, None]: try: yield RunEvent( - # run already in progress instead of queued which is - # different from OpenAI data=run, event="thread.run.created", ) yield RunEvent( - # run already in progress instead of queued which is - # different from OpenAI data=run, event="thread.run.queued", ) - # Start the runner in a background task async def run_runner() -> None: try: await runner.run(send_stream) # type: ignore[arg-type] finally: await send_stream.aclose() - # Create a task group to manage the runner and event processing async with anyio.create_task_group() as tg: - # Start the runner in the background tg.start_soon(run_runner) - # Process events from the stream while True: try: event = await receive_stream.receive() @@ -89,71 +89,28 @@ async def run_runner() -> None: return run, event_generator() - def _update_run_file(self, run: Run) -> None: - run_file = self._run_path(run.thread_id, run.id) - with run_file.open("w", encoding="utf-8") as f: - f.write(run.model_dump_json()) - - def retrieve(self, run_id: RunId) -> Run: - # Find the file by run_id - for f in self._runs_dir.glob(f"*__{run_id}.json"): - with f.open("r", encoding="utf-8") as file: - return Run.model_validate_json(file.read()) - error_msg = f"Run {run_id} not found" - raise FileNotFoundError(error_msg) + def retrieve(self, thread_id: ThreadId, run_id: RunId) -> Run: + try: + run_file = self._get_run_path(thread_id, run_id) + return Run.model_validate_json(run_file.read_text()) + except FileNotFoundError as e: + error_msg = f"Run {run_id} not found in thread {thread_id}" + raise NotFoundError(error_msg) from e def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Run]: - """List runs, optionally filtered by thread. - - Args: - thread_id (ThreadId): ID of thread to filter runs by - query (ListQuery): Query parameters for listing runs - - Returns: - ListResponse[Run]: ListResponse containing runs sorted by creation date - """ - if not self._runs_dir.exists(): - return ListResponse(data=[]) - - run_files = list(self._runs_dir.glob(f"{thread_id}__*.json")) - - runs: list[Run] = [] - for f in run_files: - with f.open("r", encoding="utf-8") as file: - runs.append(Run.model_validate_json(file.read())) - - # Sort by creation date - runs = sorted( - runs, - key=lambda r: r.created_at, - reverse=(query.order == "desc"), - ) - - # Apply before/after filters - if query.after: - runs = [r for r in runs if r.id > query.after] - if query.before: - runs = [r for r in runs if r.id < query.before] - - # Apply limit - runs = runs[: query.limit] - - return ListResponse( - data=runs, - first_id=runs[0].id if runs else None, - last_id=runs[-1].id if runs else None, - has_more=len(run_files) > query.limit, - ) - - def cancel(self, run_id: RunId) -> Run: - run = self.retrieve(run_id) + runs_dir = self.get_runs_dir(thread_id) + return list_resources(runs_dir, query, Run) + + def cancel(self, thread_id: ThreadId, run_id: RunId) -> Run: + run = self.retrieve(thread_id, run_id) if run.status in ("cancelled", "cancelling", "completed", "failed", "expired"): return run run.tried_cancelling_at = datetime.now(tz=timezone.utc) - for f in self._runs_dir.glob(f"*__{run_id}.json"): - with f.open("w", encoding="utf-8") as file: - file.write(run.model_dump_json()) - return run - # Find the file by run_id - error_msg = f"Run {run_id} not found" - raise FileNotFoundError(error_msg) + self._save(run) + return run + + def _save(self, run: Run, new: bool = False) -> None: + runs_dir = self.get_runs_dir(run.thread_id) + runs_dir.mkdir(parents=True, exist_ok=True) + run_file = self._get_run_path(run.thread_id, run.id, new=new) + run_file.write_text(run.model_dump_json(), encoding="utf-8") diff --git a/src/askui/chat/api/threads/models.py b/src/askui/chat/api/threads/models.py new file mode 100644 index 00000000..6ee1931f --- /dev/null +++ b/src/askui/chat/api/threads/models.py @@ -0,0 +1,52 @@ +from typing import Literal + +from pydantic import BaseModel + +from askui.chat.api.messages.models import MessageCreateParams +from askui.chat.api.models import ThreadId +from askui.utils.api_utils import Resource +from askui.utils.datetime_utils import UnixDatetime, now +from askui.utils.id_utils import generate_time_ordered_id +from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven + + +class ThreadBase(BaseModel): + """Base thread model.""" + + name: str | None = None + + +class ThreadCreateParams(ThreadBase): + """Parameters for creating a thread.""" + + messages: list[MessageCreateParams] | None = None + + +class ThreadModifyParams(BaseModelWithNotGiven): + """Parameters for modifying a thread.""" + + name: str | None | NotGiven = NOT_GIVEN + + +class Thread(ThreadBase, Resource): + """A chat thread/session.""" + + id: ThreadId + object: Literal["thread"] = "thread" + created_at: UnixDatetime + + @classmethod + def create(cls, params: ThreadCreateParams) -> "Thread": + return cls( + id=generate_time_ordered_id("thread"), + created_at=now(), + **params.model_dump(exclude={"messages"}), + ) + + def modify(self, params: ThreadModifyParams) -> "Thread": + return Thread.model_validate( + { + **self.model_dump(), + **params.model_dump(), + } + ) diff --git a/src/askui/chat/api/threads/router.py b/src/askui/chat/api/threads/router.py index fd3fcfce..5808bac9 100644 --- a/src/askui/chat/api/threads/router.py +++ b/src/askui/chat/api/threads/router.py @@ -1,13 +1,9 @@ -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, status from askui.chat.api.models import ListQueryDep, ThreadId from askui.chat.api.threads.dependencies import ThreadServiceDep -from askui.chat.api.threads.service import ( - Thread, - ThreadCreateRequest, - ThreadModifyRequest, - ThreadService, -) +from askui.chat.api.threads.models import Thread, ThreadCreateParams, ThreadModifyParams +from askui.chat.api.threads.service import ThreadService from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/threads", tags=["threads"]) @@ -18,17 +14,15 @@ def list_threads( query: ListQuery = ListQueryDep, thread_service: ThreadService = ThreadServiceDep, ) -> ListResponse[Thread]: - """List all threads.""" return thread_service.list_(query=query) @router.post("", status_code=status.HTTP_201_CREATED) def create_thread( - request: ThreadCreateRequest, + params: ThreadCreateParams, thread_service: ThreadService = ThreadServiceDep, ) -> Thread: - """Create a new thread.""" - return thread_service.create(request=request) + return thread_service.create(params) @router.get("/{thread_id}") @@ -36,30 +30,21 @@ def retrieve_thread( thread_id: ThreadId, thread_service: ThreadService = ThreadServiceDep, ) -> Thread: - """Get a thread by ID.""" - try: - return thread_service.retrieve(thread_id) - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e + return thread_service.retrieve(thread_id) -@router.delete("/{thread_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_thread( +@router.post("/{thread_id}") +def modify_thread( thread_id: ThreadId, + params: ThreadModifyParams, thread_service: ThreadService = ThreadServiceDep, -) -> None: - """Delete a thread.""" - try: - thread_service.delete(thread_id) - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e +) -> Thread: + return thread_service.modify(thread_id, params) -@router.post("/{thread_id}") -def modify_thread( +@router.delete("/{thread_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_thread( thread_id: ThreadId, - request: ThreadModifyRequest, thread_service: ThreadService = ThreadServiceDep, -) -> Thread: - """Modify a thread.""" - return thread_service.modify(thread_id, request) +) -> None: + thread_service.delete(thread_id) diff --git a/src/askui/chat/api/threads/service.py b/src/askui/chat/api/threads/service.py index 542bd405..66fde497 100644 --- a/src/askui/chat/api/threads/service.py +++ b/src/askui/chat/api/threads/service.py @@ -1,160 +1,76 @@ import shutil -from datetime import datetime, timezone from pathlib import Path -from typing import Literal -from pydantic import BaseModel, Field - -from askui.chat.api.messages.service import MessageCreateRequest, MessageService -from askui.chat.api.models import DoNotPatch, ThreadId -from askui.utils.api_utils import ListQuery, ListResponse -from askui.utils.datetime_utils import UnixDatetime -from askui.utils.id_utils import generate_time_ordered_id - - -class Thread(BaseModel): - """A chat thread/session.""" - - id: ThreadId = Field(default_factory=lambda: generate_time_ordered_id("thread")) - created_at: UnixDatetime = Field( - default_factory=lambda: datetime.now(tz=timezone.utc) - ) - name: str | None = None - object: Literal["thread"] = "thread" - - -class ThreadCreateRequest(BaseModel): - name: str | None = None - messages: list[MessageCreateRequest] | None = None - - -class ThreadModifyRequest(BaseModel): - name: str | None | DoNotPatch = DoNotPatch() +from askui.chat.api.messages.models import MessageCreateParams +from askui.chat.api.messages.service import MessageService +from askui.chat.api.models import ThreadId +from askui.chat.api.threads.models import Thread, ThreadCreateParams, ThreadModifyParams +from askui.utils.api_utils import ( + ConflictError, + ListQuery, + ListResponse, + NotFoundError, + list_resources, +) class ThreadService: - """Service for managing chat threads/sessions.""" + """Service for managing Thread resources with filesystem persistence.""" - def __init__( - self, - base_dir: Path, - message_service: MessageService, - ) -> None: - """Initialize thread service. - - Args: - base_dir: Base directory to store thread data - """ + def __init__(self, base_dir: Path, message_service: MessageService) -> None: self._base_dir = base_dir self._threads_dir = base_dir / "threads" self._message_service = message_service - def create(self, request: ThreadCreateRequest) -> Thread: - """Create a new thread. - - Returns: - Created thread object - """ - thread = Thread(name=request.name) - self._threads_dir.mkdir(parents=True, exist_ok=True) - thread_file = self._threads_dir / f"{thread.id}.json" - thread_file.write_text(thread.model_dump_json(), encoding="utf-8") - if request.messages: - for message in request.messages: - self._message_service.create( - thread_id=thread.id, - request=message, - ) - return thread + def _get_thread_path(self, thread_id: ThreadId, new: bool = False) -> Path: + thread_path = self._threads_dir / f"{thread_id}.json" + exists = thread_path.exists() + if new and exists: + error_msg = f"Thread {thread_id} already exists" + raise ConflictError(error_msg) + if not new and not exists: + error_msg = f"Thread {thread_id} not found" + raise NotFoundError(error_msg) + return thread_path def list_(self, query: ListQuery) -> ListResponse[Thread]: - """List all available threads. - - Args: - query (ListQuery): Query parameters for listing threads - - Returns: - ListResponse[Thread]: ListResponse containing threads sorted by creation - date - """ - if not self._threads_dir.exists(): - return ListResponse(data=[]) - - thread_files = list(self._threads_dir.glob("*.json")) - threads: list[Thread] = [] - for f in thread_files: - thread = Thread.model_validate_json(f.read_text(encoding="utf-8")) - threads.append(thread) - - # Sort by creation date - threads = sorted( - threads, key=lambda t: t.created_at, reverse=(query.order == "desc") - ) - - # Apply before/after filters - if query.after: - threads = [t for t in threads if t.id > query.after] - if query.before: - threads = [t for t in threads if t.id < query.before] - - # Apply limit - threads = threads[: query.limit] - - return ListResponse( - data=threads, - first_id=threads[0].id if threads else None, - last_id=threads[-1].id if threads else None, - has_more=len(thread_files) > query.limit, - ) + return list_resources(self._threads_dir, query, Thread) def retrieve(self, thread_id: ThreadId) -> Thread: - """Retrieve a thread by ID. + try: + thread_path = self._get_thread_path(thread_id) + return Thread.model_validate_json(thread_path.read_text()) + except FileNotFoundError as e: + error_msg = f"Thread {thread_id} not found" + raise NotFoundError(error_msg) from e - Args: - thread_id: ID of thread to retrieve + def create(self, params: ThreadCreateParams) -> Thread: + thread = Thread.create(params) + self._save(thread, new=True) - Returns: - Thread object + if params.messages: + for message in params.messages: + self._message_service.create( + thread_id=thread.id, + params=message, + ) + return thread - Raises: - FileNotFoundError: If thread doesn't exist - """ - thread_file = self._threads_dir / f"{thread_id}.json" - if not thread_file.exists(): - error_msg = f"Thread {thread_id} not found" - raise FileNotFoundError(error_msg) - return Thread.model_validate_json(thread_file.read_text(encoding="utf-8")) + def modify(self, thread_id: ThreadId, params: ThreadModifyParams) -> Thread: + thread = self.retrieve(thread_id) + modified = thread.modify(params) + self._save(modified) + return modified def delete(self, thread_id: ThreadId) -> None: - """Delete a thread and all its associated files. - - Args: - thread_id (ThreadId): ID of thread to delete - - Raises: - FileNotFoundError: If thread doesn't exist - """ - thread_file = self._threads_dir / f"{thread_id}.json" - if not thread_file.exists(): + try: + shutil.rmtree(self._threads_dir / thread_id) + self._get_thread_path(thread_id).unlink() + except FileNotFoundError as e: error_msg = f"Thread {thread_id} not found" - raise FileNotFoundError(error_msg) - - messages_dir = self._message_service.get_thread_messages_dir(thread_id) - if messages_dir.exists(): - shutil.rmtree(messages_dir) - - thread_file.unlink() + raise NotFoundError(error_msg) from e - def modify(self, thread_id: ThreadId, request: ThreadModifyRequest) -> Thread: - """Modify a thread. - - Args: - thread_id (ThreadId): ID of thread to modify - request (ThreadModifyRequest): Request containing the new name - """ - thread = self.retrieve(thread_id) - if not isinstance(request.name, DoNotPatch): - thread.name = request.name - thread_file = self._threads_dir / f"{thread_id}.json" + def _save(self, thread: Thread, new: bool = False) -> None: + self._threads_dir.mkdir(parents=True, exist_ok=True) + thread_file = self._get_thread_path(thread.id, new=new) thread_file.write_text(thread.model_dump_json(), encoding="utf-8") - return thread diff --git a/src/askui/tools/testing/execution_models.py b/src/askui/tools/testing/execution_models.py index 021a32c7..bc5f2f5c 100644 --- a/src/askui/tools/testing/execution_models.py +++ b/src/askui/tools/testing/execution_models.py @@ -6,7 +6,7 @@ from askui.tools.testing.feature_models import FeatureId from askui.tools.testing.scenario_models import ExampleIndex, ScenarioId -from askui.utils.api_utils import ListQuery +from askui.utils.api_utils import ListQuery, Resource from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import IdField, generate_time_ordered_id from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven @@ -72,7 +72,7 @@ class ModifyExecutionStepParams(BaseModelWithNotGiven): status: ExecutionStatus | NotGiven = NOT_GIVEN -class Execution(BaseModel): +class Execution(Resource): """ A structured representation of an execution result for a scenario or scenario outline example. diff --git a/src/askui/tools/testing/execution_service.py b/src/askui/tools/testing/execution_service.py index b9bf27ba..22639b2e 100644 --- a/src/askui/tools/testing/execution_service.py +++ b/src/askui/tools/testing/execution_service.py @@ -1,12 +1,11 @@ from pathlib import Path - -from pydantic import ValidationError +from typing import Callable from askui.utils.api_utils import ( ConflictError, ListResponse, NotFoundError, - list_resource_paths, + list_resources, ) from askui.utils.not_given import NOT_GIVEN @@ -18,63 +17,55 @@ ) -class ExecutionService: - """ - Service for managing Execution resources with filesystem persistence. +def _build_execution_filter_fn( + query: ExecutionListQuery, +) -> Callable[[Execution], bool]: + def filter_fn(execution: Execution) -> bool: + return ( + (query.feature == NOT_GIVEN or execution.feature == query.feature) + and (query.scenario == NOT_GIVEN or execution.scenario == query.scenario) + and (query.example == NOT_GIVEN or execution.example == query.example) + ) + + return filter_fn + - Args: - base_dir (Path): Base directory for storing execution data. - """ +class ExecutionService: + """Service for managing Execution resources with filesystem persistence.""" def __init__(self, base_dir: Path) -> None: self._base_dir = base_dir self._executions_dir = base_dir / "executions" - self._executions_dir.mkdir(parents=True, exist_ok=True) - def list_( - self, - query: ExecutionListQuery, - ) -> ListResponse[Execution]: - execution_paths = list_resource_paths(self._executions_dir, query) - executions: list[Execution] = [] - for f in execution_paths: - try: - execution = Execution.model_validate_json(f.read_text()) - if ( - (query.feature == NOT_GIVEN or execution.feature == query.feature) - and ( - query.scenario == NOT_GIVEN - or execution.scenario == query.scenario - ) - and ( - query.example == NOT_GIVEN or execution.example == query.example - ) - ): - executions.append(execution) - except ValidationError: # noqa: PERF203 - continue - has_more = len(executions) > query.limit - executions = executions[: query.limit] - return ListResponse( - data=executions, - first_id=executions[0].id if executions else None, - last_id=executions[-1].id if executions else None, - has_more=has_more, + def _get_execution_path(self, execution_id: ExecutionId, new: bool = False) -> Path: + execution_path = self._executions_dir / f"{execution_id}.json" + exists = execution_path.exists() + if new and exists: + error_msg = f"Execution {execution_id} already exists" + raise ConflictError(error_msg) + if not new and not exists: + error_msg = f"Execution {execution_id} not found" + raise NotFoundError(error_msg) + return execution_path + + def list_(self, query: ExecutionListQuery) -> ListResponse[Execution]: + return list_resources( + base_dir=self._executions_dir, + query=query, + resource_type=Execution, + filter_fn=_build_execution_filter_fn(query), ) def retrieve(self, execution_id: ExecutionId) -> Execution: - execution_file = self._executions_dir / f"{execution_id}.json" - if not execution_file.exists(): + try: + execution_path = self._get_execution_path(execution_id) + return Execution.model_validate_json(execution_path.read_text()) + except FileNotFoundError as e: error_msg = f"Execution {execution_id} not found" - raise NotFoundError(error_msg) - return Execution.model_validate_json(execution_file.read_text()) + raise NotFoundError(error_msg) from e def create(self, execution: Execution) -> Execution: - execution_file = self._executions_dir / f"{execution.id}.json" - if execution_file.exists(): - error_msg = f"Execution {execution.id} already exists" - raise ConflictError(error_msg) - execution_file.write_text(execution.model_dump_json()) + self._save(execution, new=True) return execution def modify( @@ -84,14 +75,15 @@ def modify( modified = execution.modify(params) return self._save(modified) - def _save(self, execution: Execution) -> Execution: - execution_file = self._executions_dir / f"{execution.id}.json" - execution_file.write_text(execution.model_dump_json()) - return execution - def delete(self, execution_id: ExecutionId) -> None: - execution_file = self._executions_dir / f"{execution_id}.json" - if not execution_file.exists(): + try: + self._get_execution_path(execution_id).unlink() + except FileNotFoundError as e: error_msg = f"Execution {execution_id} not found" - raise NotFoundError(error_msg) - execution_file.unlink() + raise NotFoundError(error_msg) from e + + def _save(self, execution: Execution, new: bool = False) -> Execution: + self._executions_dir.mkdir(parents=True, exist_ok=True) + execution_file = self._get_execution_path(execution.id, new=new) + execution_file.write_text(execution.model_dump_json(), encoding="utf-8") + return execution diff --git a/src/askui/tools/testing/feature_models.py b/src/askui/tools/testing/feature_models.py index 1084b0cb..7ea51d08 100644 --- a/src/askui/tools/testing/feature_models.py +++ b/src/askui/tools/testing/feature_models.py @@ -4,7 +4,7 @@ from fastapi import Query from pydantic import BaseModel, Field -from askui.utils.api_utils import ListQuery +from askui.utils.api_utils import ListQuery, Resource from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import IdField, generate_time_ordered_id from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven @@ -37,7 +37,7 @@ class FeatureListQuery(ListQuery): tags: Annotated[list[str] | NotGiven, Query()] = NOT_GIVEN -class Feature(BaseModel): +class Feature(Resource): """ A structured representation of a feature used for BDD test automation. diff --git a/src/askui/tools/testing/feature_service.py b/src/askui/tools/testing/feature_service.py index 63790322..ba9520db 100644 --- a/src/askui/tools/testing/feature_service.py +++ b/src/askui/tools/testing/feature_service.py @@ -1,12 +1,11 @@ from pathlib import Path - -from pydantic import ValidationError +from typing import Callable from askui.utils.api_utils import ( ConflictError, ListResponse, NotFoundError, - list_resource_paths, + list_resources, ) from askui.utils.not_given import NOT_GIVEN @@ -19,69 +18,69 @@ ) -class FeatureService: - """ - Service for managing Feature resources with filesystem persistence. +def _build_feature_filter_fn( + query: FeatureListQuery, +) -> Callable[[Feature], bool]: + def filter_fn(feature: Feature) -> bool: + return query.tags == NOT_GIVEN or any(tag in feature.tags for tag in query.tags) - Args: - base_dir (Path): Base directory for storing feature data. - """ + return filter_fn + +class FeatureService: def __init__(self, base_dir: Path) -> None: self._base_dir = base_dir self._features_dir = base_dir / "features" - self._features_dir.mkdir(parents=True, exist_ok=True) + + def _get_feature_path(self, feature_id: FeatureId, new: bool = False) -> Path: + feature_path = self._features_dir / f"{feature_id}.json" + exists = feature_path.exists() + if new and exists: + error_msg = f"Feature {feature_id} already exists" + raise ConflictError(error_msg) + if not new and not exists: + error_msg = f"Feature {feature_id} not found" + raise NotFoundError(error_msg) + return feature_path def list_( self, query: FeatureListQuery, ) -> ListResponse[Feature]: - feature_paths = list_resource_paths(self._features_dir, query) - features: list[Feature] = [] - for f in feature_paths: - try: - feature = Feature.model_validate_json(f.read_text()) - if query.tags == NOT_GIVEN or any( - tag in feature.tags for tag in query.tags - ): - features.append(feature) - except ValidationError: # noqa: PERF203 - continue - has_more = len(features) > query.limit - features = features[: query.limit] - return ListResponse( - data=features, - first_id=features[0].id if features else None, - last_id=features[-1].id if features else None, - has_more=has_more, + return list_resources( + base_dir=self._features_dir, + query=query, + resource_type=Feature, + filter_fn=_build_feature_filter_fn(query), ) def retrieve(self, feature_id: FeatureId) -> Feature: - feature_file = self._features_dir / f"{feature_id}.json" - if not feature_file.exists(): + try: + feature_path = self._get_feature_path(feature_id) + return Feature.model_validate_json(feature_path.read_text()) + except FileNotFoundError as e: error_msg = f"Feature {feature_id} not found" - raise NotFoundError(error_msg) - return Feature.model_validate_json(feature_file.read_text()) + raise NotFoundError(error_msg) from e def create(self, params: FeatureCreateParams) -> Feature: feature = Feature.create(params) - feature_file = self._features_dir / f"{feature.id}.json" - if feature_file.exists(): - error_msg = f"Feature {feature.id} already exists" - raise ConflictError(error_msg) - feature_file.write_text(feature.model_dump_json()) + self._save(feature, new=True) return feature def modify(self, feature_id: FeatureId, params: FeatureModifyParams) -> Feature: feature = self.retrieve(feature_id) modified = feature.modify(params) - feature_file = self._features_dir / f"{feature_id}.json" - feature_file.write_text(modified.model_dump_json()) + self._save(modified) return modified def delete(self, feature_id: FeatureId) -> None: - feature_file = self._features_dir / f"{feature_id}.json" - if not feature_file.exists(): + try: + self._get_feature_path(feature_id).unlink() + except FileNotFoundError as e: error_msg = f"Feature {feature_id} not found" - raise NotFoundError(error_msg) - feature_file.unlink() + raise NotFoundError(error_msg) from e + + def _save(self, feature: Feature, new: bool = False) -> None: + self._features_dir.mkdir(parents=True, exist_ok=True) + feature_file = self._get_feature_path(feature.id, new=new) + feature_file.write_text(feature.model_dump_json(), encoding="utf-8") diff --git a/src/askui/tools/testing/scenario_models.py b/src/askui/tools/testing/scenario_models.py index 1d90c619..d6f5abc4 100644 --- a/src/askui/tools/testing/scenario_models.py +++ b/src/askui/tools/testing/scenario_models.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from askui.tools.testing.feature_models import FeatureId -from askui.utils.api_utils import ListQuery +from askui.utils.api_utils import ListQuery, Resource from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import IdField, generate_time_ordered_id from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven @@ -62,7 +62,7 @@ class ScenarioListQuery(ListQuery): tags: Annotated[list[str] | NotGiven, Query()] = NOT_GIVEN -class Scenario(BaseModel): +class Scenario(Resource): """ A structured representation of a scenario or scenario outline for BDD test automation. diff --git a/src/askui/tools/testing/scenario_service.py b/src/askui/tools/testing/scenario_service.py index 492eb0a7..e0680683 100644 --- a/src/askui/tools/testing/scenario_service.py +++ b/src/askui/tools/testing/scenario_service.py @@ -1,12 +1,11 @@ from pathlib import Path - -from pydantic import ValidationError +from typing import Callable from askui.utils.api_utils import ( ConflictError, ListResponse, NotFoundError, - list_resource_paths, + list_resources, ) from askui.utils.not_given import NOT_GIVEN @@ -19,76 +18,74 @@ ) -class ScenarioService: - """ - Service for managing Scenario resources with filesystem persistence. +def _build_scenario_filter_fn( + query: ScenarioListQuery, +) -> Callable[[Scenario], bool]: + def filter_fn(scenario: Scenario) -> bool: + tags_matched = query.tags == NOT_GIVEN or any( + tag in scenario.tags for tag in query.tags + ) + feature_matched = ( + query.feature is NOT_GIVEN or scenario.feature == query.feature + ) + return tags_matched and feature_matched + + return filter_fn - Args: - base_dir (Path): Base directory for storing scenario data. - """ + +class ScenarioService: + """Service for managing Scenario resources with filesystem persistence.""" def __init__(self, base_dir: Path) -> None: self._base_dir = base_dir self._scenarios_dir = base_dir / "scenarios" - self._scenarios_dir.mkdir(parents=True, exist_ok=True) - def list_( - self, - query: ScenarioListQuery, - ) -> ListResponse[Scenario]: - scenario_paths = list_resource_paths(self._scenarios_dir, query) - scenarios: list[Scenario] = [] - for f in scenario_paths: - try: - scenario = Scenario.model_validate_json(f.read_text()) - tags_matched = query.tags == NOT_GIVEN or any( - tag in scenario.tags for tag in query.tags - ) - feature_matched = ( - query.feature is NOT_GIVEN or scenario.feature == query.feature - ) - if tags_matched and feature_matched: - scenarios.append(scenario) - except ValidationError: # noqa: PERF203 - continue - has_more = len(scenarios) > query.limit - scenarios = scenarios[: query.limit] - return ListResponse( - data=scenarios, - first_id=scenarios[0].id if scenarios else None, - last_id=scenarios[-1].id if scenarios else None, - has_more=has_more, + def _get_scenario_path(self, scenario_id: ScenarioId, new: bool = False) -> Path: + scenario_path = self._scenarios_dir / f"{scenario_id}.json" + exists = scenario_path.exists() + if new and exists: + error_msg = f"Scenario {scenario_id} already exists" + raise ConflictError(error_msg) + if not new and not exists: + error_msg = f"Scenario {scenario_id} not found" + raise NotFoundError(error_msg) + return scenario_path + + def list_(self, query: ScenarioListQuery) -> ListResponse[Scenario]: + return list_resources( + base_dir=self._scenarios_dir, + query=query, + resource_type=Scenario, + filter_fn=_build_scenario_filter_fn(query), ) def retrieve(self, scenario_id: ScenarioId) -> Scenario: - scenario_file = self._scenarios_dir / f"{scenario_id}.json" - if not scenario_file.exists(): + try: + scenario_path = self._get_scenario_path(scenario_id) + return Scenario.model_validate_json(scenario_path.read_text()) + except FileNotFoundError as e: error_msg = f"Scenario {scenario_id} not found" - raise NotFoundError(error_msg) - return Scenario.model_validate_json(scenario_file.read_text()) + raise NotFoundError(error_msg) from e def create(self, params: ScenarioCreateParams) -> Scenario: scenario = Scenario.create(params) - scenario_file = self._scenarios_dir / f"{scenario.id}.json" - if scenario_file.exists(): - error_msg = f"Scenario {scenario.id} already exists" - raise ConflictError(error_msg) - scenario_file.write_text(scenario.model_dump_json()) + self._save(scenario, new=True) return scenario def modify(self, scenario_id: ScenarioId, params: ScenarioModifyParams) -> Scenario: scenario = self.retrieve(scenario_id) - updated = scenario.modify(params) - return self._save(updated) - - def _save(self, scenario: Scenario) -> Scenario: - scenario_file = self._scenarios_dir / f"{scenario.id}.json" - scenario_file.write_text(scenario.model_dump_json()) - return scenario + modified = scenario.modify(params) + return self._save(modified) def delete(self, scenario_id: ScenarioId) -> None: - scenario_file = self._scenarios_dir / f"{scenario_id}.json" - if not scenario_file.exists(): + try: + self._get_scenario_path(scenario_id).unlink() + except FileNotFoundError as e: error_msg = f"Scenario {scenario_id} not found" - raise NotFoundError(error_msg) - scenario_file.unlink() + raise NotFoundError(error_msg) from e + + def _save(self, scenario: Scenario, new: bool = False) -> Scenario: + self._scenarios_dir.mkdir(parents=True, exist_ok=True) + scenario_file = self._get_scenario_path(scenario.id, new=new) + scenario_file.write_text(scenario.model_dump_json(), encoding="utf-8") + return scenario diff --git a/src/askui/utils/api_utils.py b/src/askui/utils/api_utils.py index f1b994a5..8eade9bc 100644 --- a/src/askui/utils/api_utils.py +++ b/src/askui/utils/api_utils.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from pathlib import Path -from typing import Annotated, Generic, Literal, Sequence +from typing import Annotated, Callable, Generic, Literal, Sequence, Type from fastapi import Query from pydantic import BaseModel, ValidationError @@ -50,15 +50,65 @@ class NotFoundError(ApiError): def list_resource_paths(base_dir: Path, list_query: ListQuery) -> list[Path]: paths: list[Path] = [] + after_name = f"{list_query.after}.json" + before_name = f"{list_query.before}.json" for f in base_dir.glob("*.json"): try: if list_query.after: - if f.name <= list_query.after: + if f.name <= after_name: continue if list_query.before: - if f.name >= list_query.before: + if f.name >= before_name: continue paths.append(f) except ValidationError: # noqa: PERF203 continue return sorted(paths, key=lambda f: f.name, reverse=(list_query.order == "desc")) + + +class Resource(BaseModel): + id: str + + +ResourceT = TypeVar("ResourceT", bound=Resource) + + +def list_resources( + base_dir: Path, + query: ListQuery, + resource_type: Type[ResourceT], + filter_fn: Callable[[ResourceT], bool] | None = None, +) -> ListResponse[ResourceT]: + """ + List resources from a directory. + + Args: + base_dir: The base directory to list resources from. + query: The query to filter resources. + resource_type: The type of resource to list. + filter_fn: A function to filter resources. If it returns `False`, + the resource is not included in the list. + + Returns: + A list of resources. + """ + resource_paths = list_resource_paths(base_dir, query) + resources: list[ResourceT] = [] + for resource_file in resource_paths: + try: + resource = resource_type.model_validate_json( + resource_file.read_text(encoding="utf-8") + ) + if filter_fn and not filter_fn(resource): + continue + resources.append(resource) + except (ValidationError, FileNotFoundError): # noqa: PERF203 + continue + has_more = len(resources) > query.limit + resources = resources[: query.limit] + return ListResponse( + data=resources, + first_id=resources[0].id if resources else None, + last_id=resources[-1].id if resources else None, + has_more=has_more, + ) diff --git a/tests/integration/chat/api/test_messages_service.py b/tests/integration/chat/api/test_messages_service.py index 2a62ef76..e58be196 100644 --- a/tests/integration/chat/api/test_messages_service.py +++ b/tests/integration/chat/api/test_messages_service.py @@ -7,7 +7,8 @@ import pytest -from askui.chat.api.messages.service import MessageCreateRequest, MessageService +from askui.chat.api.messages.models import MessageCreateParams +from askui.chat.api.messages.service import MessageService from askui.chat.api.models import ThreadId @@ -37,12 +38,12 @@ def test_create_message_creates_individual_json_file( self, message_service: MessageService, thread_id: ThreadId ) -> None: """Test that creating a message creates an individual JSON file.""" - request = MessageCreateRequest(role="user", content="Hello, world!") + request = MessageCreateParams(role="user", content="Hello, world!") message = message_service.create(thread_id, request) # Check that the message directory was created - messages_dir = message_service.get_thread_messages_dir(thread_id) + messages_dir = message_service.get_messages_dir(thread_id) assert messages_dir.exists() # Check that the message file was created @@ -64,7 +65,7 @@ def test_list_messages_reads_from_json_files( # Create multiple messages messages = [] for i in range(3): - request = MessageCreateRequest( + request = MessageCreateParams( role="user" if i % 2 == 0 else "assistant", content=f"Message {i}" ) message = message_service.create(thread_id, request) @@ -87,7 +88,7 @@ def test_delete_message_removes_json_file( self, message_service: MessageService, thread_id: ThreadId ) -> None: """Test that deleting a message removes its JSON file.""" - request = MessageCreateRequest(role="user", content="Delete me") + request = MessageCreateParams(role="user", content="Delete me") message = message_service.create(thread_id, request) message_file = message_service._get_message_path(thread_id, message.id) @@ -103,12 +104,11 @@ def test_directory_structure_is_correct( self, message_service: MessageService, thread_id: ThreadId ) -> None: """Test that the directory structure follows the expected pattern.""" - request = MessageCreateRequest(role="user", content="Test message") + request = MessageCreateParams(role="user", content="Test message") message_service.create(thread_id, request) - # Check directory structure - messages are stored in base_dir/messages/thread_id/ - messages_dir = message_service.get_thread_messages_dir(thread_id) + messages_dir = message_service.get_messages_dir(thread_id) assert messages_dir.exists() diff --git a/tests/integration/chat/api/test_threads_service.py b/tests/integration/chat/api/test_threads_service.py index 08692f34..dde56975 100644 --- a/tests/integration/chat/api/test_threads_service.py +++ b/tests/integration/chat/api/test_threads_service.py @@ -6,8 +6,9 @@ import pytest -from askui.chat.api.messages.service import MessageCreateRequest, MessageService -from askui.chat.api.threads.service import ThreadCreateRequest, ThreadService +from askui.chat.api.messages.models import MessageCreateParams +from askui.chat.api.messages.service import MessageService +from askui.chat.api.threads.service import ThreadCreateParams, ThreadService @pytest.fixture @@ -38,7 +39,7 @@ def test_create_thread_creates_directory_structure( self, thread_service: ThreadService ) -> None: """Test that creating a thread creates the proper directory structure.""" - request = ThreadCreateRequest(name="Test Thread") + request = ThreadCreateParams(name="Test Thread") thread = thread_service.create(request) @@ -47,11 +48,12 @@ def test_create_thread_creates_directory_structure( assert thread_file.exists() # Check that messages directory was created (by creating a message) - # The ThreadService doesn't create the messages directory until a message is added - message_request = MessageCreateRequest(role="user", content="Test message") + # The ThreadService doesn't create the messages directory until a message is + # added + message_request = MessageCreateParams(role="user", content="Test message") thread_service._message_service.create(thread.id, message_request) - thread_messages_dir = thread_service._message_service.get_thread_messages_dir( + thread_messages_dir = thread_service._message_service.get_messages_dir( thread.id ) assert thread_messages_dir.exists() @@ -67,15 +69,15 @@ def test_create_thread_creates_directory_structure( def test_create_thread_with_messages(self, thread_service: ThreadService) -> None: """Test that creating a thread with messages works correctly.""" messages = [ - MessageCreateRequest(role="user", content="Hello"), - MessageCreateRequest(role="assistant", content="Hi there!"), + MessageCreateParams(role="user", content="Hello"), + MessageCreateParams(role="assistant", content="Hi there!"), ] - request = ThreadCreateRequest(name="Thread with Messages", messages=messages) + request = ThreadCreateParams(name="Thread with Messages", messages=messages) thread = thread_service.create(request) # Check that messages were created - thread_messages_dir = thread_service._message_service.get_thread_messages_dir( + thread_messages_dir = thread_service._message_service.get_messages_dir( thread.id ) json_files = list(thread_messages_dir.glob("*.json")) @@ -94,21 +96,18 @@ def test_delete_thread_removes_all_files( self, thread_service: ThreadService ) -> None: """Test that deleting a thread removes all associated files.""" - request = ThreadCreateRequest(name="Thread to Delete") + request = ThreadCreateParams( + name="Thread to Delete", + messages=[MessageCreateParams(role="user", content="Test message")], + ) thread = thread_service.create(request) - # Add a message - message_request = MessageCreateRequest(role="user", content="Test message") - thread_service._message_service.create(thread.id, message_request) - # Verify files exist thread_file = thread_service._base_dir / "threads" / f"{thread.id}.json" assert thread_file.exists() # The thread directory itself doesn't exist, only the messages directory - messages_dir = thread_service._message_service.get_thread_messages_dir( - thread.id - ) + messages_dir = thread_service._message_service.get_messages_dir(thread.id) assert messages_dir.exists() # Delete thread From 05bd4804c20733ef962918442df25321c8cb3262 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 21 Aug 2025 23:25:28 +0200 Subject: [PATCH 09/11] feat(chat): add files api --- pdm.lock | 6 +- src/askui/chat/api/app.py | 36 +- src/askui/chat/api/assistants/router.py | 3 +- src/askui/chat/api/dependencies.py | 6 +- src/askui/chat/api/files/__init__.py | 0 src/askui/chat/api/files/dependencies.py | 14 + src/askui/chat/api/files/models.py | 42 ++ src/askui/chat/api/files/router.py | 57 ++ src/askui/chat/api/files/service.py | 149 +++++ src/askui/chat/api/mcp_configs/router.py | 3 +- src/askui/chat/api/messages/dependencies.py | 12 + src/askui/chat/api/messages/models.py | 55 +- src/askui/chat/api/messages/router.py | 3 +- src/askui/chat/api/messages/service.py | 2 +- src/askui/chat/api/messages/translator.py | 238 ++++++++ src/askui/chat/api/models.py | 20 +- src/askui/chat/api/runs/dependencies.py | 11 +- src/askui/chat/api/runs/router.py | 5 +- src/askui/chat/api/runs/runner/runner.py | 64 +- src/askui/chat/api/runs/service.py | 17 +- src/askui/chat/api/threads/dependencies.py | 4 + src/askui/chat/api/threads/router.py | 3 +- src/askui/chat/api/threads/service.py | 10 +- src/askui/utils/api_utils.py | 6 + tests/integration/chat/__init__.py | 1 + tests/integration/chat/api/conftest.py | 76 +++ tests/integration/chat/api/test_files.py | 555 ++++++++++++++++++ .../chat/api/test_files_edge_cases.py | 457 ++++++++++++++ .../chat/api/test_files_service.py | 442 ++++++++++++++ .../chat/api/test_messages_service.py | 118 ---- .../chat/api/test_threads_service.py | 118 ---- 31 files changed, 2233 insertions(+), 300 deletions(-) create mode 100644 src/askui/chat/api/files/__init__.py create mode 100644 src/askui/chat/api/files/dependencies.py create mode 100644 src/askui/chat/api/files/models.py create mode 100644 src/askui/chat/api/files/router.py create mode 100644 src/askui/chat/api/files/service.py create mode 100644 src/askui/chat/api/messages/translator.py create mode 100644 tests/integration/chat/__init__.py create mode 100644 tests/integration/chat/api/conftest.py create mode 100644 tests/integration/chat/api/test_files.py create mode 100644 tests/integration/chat/api/test_files_edge_cases.py create mode 100644 tests/integration/chat/api/test_files_service.py delete mode 100644 tests/integration/chat/api/test_messages_service.py delete mode 100644 tests/integration/chat/api/test_threads_service.py diff --git a/pdm.lock b/pdm.lock index c5315d63..c8dd7c2c 100644 --- a/pdm.lock +++ b/pdm.lock @@ -26,7 +26,7 @@ files = [ [[package]] name = "anthropic" -version = "0.54.0" +version = "0.64.0" requires_python = ">=3.8" summary = "The official Python library for the anthropic API" groups = ["default"] @@ -40,8 +40,8 @@ dependencies = [ "typing-extensions<5,>=4.10", ] files = [ - {file = "anthropic-0.54.0-py3-none-any.whl", hash = "sha256:c1062a0a905daeec17ca9c06c401e4b3f24cb0495841d29d752568a1d4018d56"}, - {file = "anthropic-0.54.0.tar.gz", hash = "sha256:5e6f997d97ce8e70eac603c3ec2e7f23addeff953fbbb76b19430562bb6ba815"}, + {file = "anthropic-0.64.0-py3-none-any.whl", hash = "sha256:6f5f7d913a6a95eb7f8e1bda4e75f76670e8acd8d4cd965e02e2a256b0429dd1"}, + {file = "anthropic-0.64.0.tar.gz", hash = "sha256:3d496c91a63dff64f451b3e8e4b238a9640bf87b0c11d0b74ddc372ba5a3fe58"}, ] [[package]] diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py index f5586a2c..8689c0f4 100644 --- a/src/askui/chat/api/app.py +++ b/src/askui/chat/api/app.py @@ -1,19 +1,25 @@ from contextlib import asynccontextmanager from typing import AsyncGenerator -from fastapi import APIRouter, FastAPI, Request, status +from fastapi import APIRouter, FastAPI, HTTPException, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from askui.chat.api.assistants.dependencies import get_assistant_service from askui.chat.api.assistants.router import router as assistants_router from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings +from askui.chat.api.files.router import router as files_router from askui.chat.api.health.router import router as health_router from askui.chat.api.mcp_configs.router import router as mcp_configs_router from askui.chat.api.messages.router import router as messages_router from askui.chat.api.runs.router import router as runs_router from askui.chat.api.threads.router import router as threads_router -from askui.utils.api_utils import ConflictError, LimitReachedError, NotFoundError +from askui.utils.api_utils import ( + ConflictError, + FileTooLargeError, + LimitReachedError, + NotFoundError, +) @asynccontextmanager @@ -71,6 +77,31 @@ def limit_reached_error_handler( ) +@app.exception_handler(FileTooLargeError) +def file_too_large_error_handler( + request: Request, # noqa: ARG001 + exc: FileTooLargeError, +) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + content={"detail": str(exc)}, + ) + + +@app.exception_handler(Exception) +def catch_all_exception_handler( + request: Request, # noqa: ARG001 + exc: Exception, +) -> JSONResponse: + if isinstance(exc, HTTPException): + raise exc + + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "Internal server error"}, + ) + + # Include routers v1_router = APIRouter(prefix="/v1") v1_router.include_router(assistants_router) @@ -78,5 +109,6 @@ def limit_reached_error_handler( v1_router.include_router(messages_router) v1_router.include_router(runs_router) v1_router.include_router(mcp_configs_router) +v1_router.include_router(files_router) v1_router.include_router(health_router) app.include_router(v1_router) diff --git a/src/askui/chat/api/assistants/router.py b/src/askui/chat/api/assistants/router.py index 515ae6d7..4a14f3ba 100644 --- a/src/askui/chat/api/assistants/router.py +++ b/src/askui/chat/api/assistants/router.py @@ -7,7 +7,8 @@ AssistantModifyParams, ) from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.models import AssistantId, ListQueryDep +from askui.chat.api.dependencies import ListQueryDep +from askui.chat.api.models import AssistantId from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/assistants", tags=["assistants"]) diff --git a/src/askui/chat/api/dependencies.py b/src/askui/chat/api/dependencies.py index b67f627c..30d585c6 100644 --- a/src/askui/chat/api/dependencies.py +++ b/src/askui/chat/api/dependencies.py @@ -2,11 +2,12 @@ from pathlib import Path from typing import Annotated, Optional -from fastapi import Depends, Header, HTTPException +from fastapi import Depends, Header from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer from pydantic import UUID4 from askui.chat.api.settings import Settings +from askui.utils.api_utils import ListQuery def get_settings() -> Settings: @@ -66,3 +67,6 @@ def get_workspace_dir( WorkspaceDirDep = Depends(get_workspace_dir) + + +ListQueryDep = Depends(ListQuery) diff --git a/src/askui/chat/api/files/__init__.py b/src/askui/chat/api/files/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/files/dependencies.py b/src/askui/chat/api/files/dependencies.py new file mode 100644 index 00000000..75f2f39c --- /dev/null +++ b/src/askui/chat/api/files/dependencies.py @@ -0,0 +1,14 @@ +from pathlib import Path + +from fastapi import Depends + +from askui.chat.api.dependencies import WorkspaceDirDep +from askui.chat.api.files.service import FileService + + +def get_file_service(workspace_dir: Path = WorkspaceDirDep) -> FileService: + """Get FileService instance.""" + return FileService(workspace_dir) + + +FileServiceDep = Depends(get_file_service) diff --git a/src/askui/chat/api/files/models.py b/src/askui/chat/api/files/models.py new file mode 100644 index 00000000..cf55c127 --- /dev/null +++ b/src/askui/chat/api/files/models.py @@ -0,0 +1,42 @@ +import mimetypes +from typing import Literal + +from pydantic import BaseModel, Field + +from askui.chat.api.models import FileId +from askui.utils.api_utils import Resource +from askui.utils.datetime_utils import UnixDatetime, now +from askui.utils.id_utils import generate_time_ordered_id + + +class FileBase(BaseModel): + """Base file model.""" + + size: int = Field(description="In bytes", ge=0) + media_type: str + + +class FileCreateParams(FileBase): + filename: str | None = None + + +class File(FileBase, Resource): + """A file that can be stored and managed.""" + + id: FileId + object: Literal["file"] = "file" + created_at: UnixDatetime + filename: str = Field(min_length=1) + + @classmethod + def create(cls, params: FileCreateParams) -> "File": + id_ = generate_time_ordered_id("file") + filename = ( + params.filename or f"{id_}{mimetypes.guess_extension(params.media_type)}" + ) + return cls( + id=id_, + created_at=now(), + filename=filename, + **params.model_dump(exclude={"filename"}), + ) diff --git a/src/askui/chat/api/files/router.py b/src/askui/chat/api/files/router.py new file mode 100644 index 00000000..3ebcf8cd --- /dev/null +++ b/src/askui/chat/api/files/router.py @@ -0,0 +1,57 @@ +from fastapi import APIRouter, UploadFile, status +from fastapi.responses import FileResponse + +from askui.chat.api.dependencies import ListQueryDep +from askui.chat.api.files.dependencies import FileServiceDep +from askui.chat.api.files.models import File as FileModel +from askui.chat.api.files.service import FileService +from askui.chat.api.models import FileId +from askui.utils.api_utils import ListQuery, ListResponse + +router = APIRouter(prefix="/files", tags=["files"]) + + +@router.get("") +def list_files( + query: ListQuery = ListQueryDep, + file_service: FileService = FileServiceDep, +) -> ListResponse[FileModel]: + """List all files.""" + return file_service.list_(query=query) + + +@router.post("", status_code=status.HTTP_201_CREATED) +async def upload_file( + file: UploadFile, + file_service: FileService = FileServiceDep, +) -> FileModel: + """Upload a new file.""" + return await file_service.upload_file(file) + + +@router.get("/{file_id}") +def retrieve_file( + file_id: FileId, + file_service: FileService = FileServiceDep, +) -> FileModel: + """Get file metadata by ID.""" + return file_service.retrieve(file_id) + + +@router.get("/{file_id}/content") +def download_file( + file_id: FileId, + file_service: FileService = FileServiceDep, +) -> FileResponse: + """Retrieve a file by ID.""" + file, file_path = file_service.retrieve_file_content(file_id) + return FileResponse(file_path, media_type=file.media_type, filename=file.filename) + + +@router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_file( + file_id: FileId, + file_service: FileService = FileServiceDep, +) -> None: + """Delete a file by ID.""" + file_service.delete(file_id) diff --git a/src/askui/chat/api/files/service.py b/src/askui/chat/api/files/service.py new file mode 100644 index 00000000..4b94d14f --- /dev/null +++ b/src/askui/chat/api/files/service.py @@ -0,0 +1,149 @@ +import mimetypes +import shutil +import tempfile +from pathlib import Path + +from fastapi import UploadFile + +from askui.chat.api.files.models import File, FileCreateParams +from askui.chat.api.models import FileId +from askui.logger import logger +from askui.utils.api_utils import ( + ConflictError, + FileTooLargeError, + ListQuery, + ListResponse, + NotFoundError, + list_resources, +) + +# Constants +MAX_FILE_SIZE = 20 * 1024 * 1024 # 20MB supported +CHUNK_SIZE = 1024 * 1024 # 1MB for uploading and downloading + + +class FileService: + """Service for managing File resources with filesystem persistence.""" + + def __init__(self, base_dir: Path) -> None: + self._base_dir = base_dir + self._files_dir = base_dir / "files" + self._static_dir = base_dir / "static" + + def _get_file_path(self, file_id: FileId, new: bool = False) -> Path: + """Get the path for file metadata.""" + file_path = self._files_dir / f"{file_id}.json" + exists = file_path.exists() + if new and exists: + error_msg = f"File {file_id} already exists" + raise ConflictError(error_msg) + if not new and not exists: + error_msg = f"File {file_id} not found" + raise NotFoundError(error_msg) + return file_path + + def _get_static_file_path(self, file: File) -> Path: + """Get the path for the static file based on extension.""" + # For application/octet-stream, don't add .bin extension + extension = "" + if file.media_type != "application/octet-stream": + extension = mimetypes.guess_extension(file.media_type) or "" + return self._static_dir / f"{file.id}{extension}" + + def list_(self, query: ListQuery) -> ListResponse[File]: + """List files with pagination and filtering.""" + return list_resources(self._files_dir, query, File) + + def retrieve(self, file_id: FileId) -> File: + """Retrieve file metadata by ID.""" + try: + file_path = self._get_file_path(file_id) + return File.model_validate_json(file_path.read_text()) + except FileNotFoundError as e: + error_msg = f"File {file_id} not found" + raise NotFoundError(error_msg) from e + + def delete(self, file_id: FileId) -> None: + """Delete a file and its content. + + *Important*: We may be left with a static file that is not associated with any + file metadata if this fails. + """ + try: + file = self.retrieve(file_id) + static_path = self._get_static_file_path(file) + self._get_file_path(file_id).unlink() + if static_path.exists(): + static_path.unlink() + except FileNotFoundError as e: + error_msg = f"File {file_id} not found" + raise NotFoundError(error_msg) from e + + def retrieve_file_content(self, file_id: FileId) -> tuple[File, Path]: + """Get file metadata and path for downloading.""" + file = self.retrieve(file_id) + static_path = self._get_static_file_path(file) + return file, static_path + + async def _write_to_temp_file( + self, + file: UploadFile, + ) -> tuple[FileCreateParams, Path]: + size = 0 + self._static_dir.mkdir(parents=True, exist_ok=True) + temp_file = tempfile.NamedTemporaryFile( + delete=False, + dir=self._static_dir, + suffix=".temp", + ) + temp_path = Path(temp_file.name) + with temp_file: + while chunk := await file.read(CHUNK_SIZE): + temp_file.write(chunk) + size += len(chunk) + if size > MAX_FILE_SIZE: + raise FileTooLargeError(MAX_FILE_SIZE) + mime_type = file.content_type or "application/octet-stream" + params = FileCreateParams( + filename=file.filename, + size=size, + media_type=mime_type, + ) + return params, temp_path + + def create(self, params: FileCreateParams, path: Path) -> File: + file_model = File.create(params) + self._static_dir.mkdir(parents=True, exist_ok=True) + static_path = self._get_static_file_path(file_model) + shutil.move(path, static_path) + self._save(file_model, new=True) + + return file_model + + async def upload_file( + self, + file: UploadFile, + ) -> File: + """Upload a file. + + *Important*: We may be left with a static file that is not associated with any + file metadata if this fails. + """ + temp_path: Path | None = None + try: + params, temp_path = await self._write_to_temp_file(file) + file_model = self.create(params, temp_path) + except Exception as e: + logger.error(f"Failed to upload file: {e}", exc_info=True) + raise + else: + return file_model + finally: + if temp_path: + temp_path.unlink(missing_ok=True) + + def _save(self, file: File, new: bool = False) -> None: + self._files_dir.mkdir(parents=True, exist_ok=True) + file_path = self._get_file_path(file.id, new=new) + content = file.model_dump_json() + file_path.write_text(content, encoding="utf-8") diff --git a/src/askui/chat/api/mcp_configs/router.py b/src/askui/chat/api/mcp_configs/router.py index 2bd96609..e360f510 100644 --- a/src/askui/chat/api/mcp_configs/router.py +++ b/src/askui/chat/api/mcp_configs/router.py @@ -1,5 +1,6 @@ from fastapi import APIRouter, status +from askui.chat.api.dependencies import ListQueryDep from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep from askui.chat.api.mcp_configs.models import ( McpConfig, @@ -7,7 +8,7 @@ McpConfigModifyParams, ) from askui.chat.api.mcp_configs.service import McpConfigService -from askui.chat.api.models import ListQueryDep, McpConfigId +from askui.chat.api.models import McpConfigId from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/mcp-configs", tags=["mcp-configs"]) diff --git a/src/askui/chat/api/messages/dependencies.py b/src/askui/chat/api/messages/dependencies.py index 9db192a3..62d36038 100644 --- a/src/askui/chat/api/messages/dependencies.py +++ b/src/askui/chat/api/messages/dependencies.py @@ -3,7 +3,10 @@ from fastapi import Depends from askui.chat.api.dependencies import WorkspaceDirDep +from askui.chat.api.files.dependencies import FileServiceDep +from askui.chat.api.files.service import FileService from askui.chat.api.messages.service import MessageService +from askui.chat.api.messages.translator import MessageTranslator def get_message_service( @@ -14,3 +17,12 @@ def get_message_service( MessageServiceDep = Depends(get_message_service) + + +def get_message_translator( + file_service: FileService = FileServiceDep, +) -> MessageTranslator: + return MessageTranslator(file_service) + + +MessageTranslatorDep = Depends(get_message_translator) diff --git a/src/askui/chat/api/messages/models.py b/src/askui/chat/api/messages/models.py index 8dd9cc67..53e3915d 100644 --- a/src/askui/chat/api/messages/models.py +++ b/src/askui/chat/api/messages/models.py @@ -1,16 +1,62 @@ from typing import Literal -from askui.chat.api.models import AssistantId, MessageId, RunId, ThreadId -from askui.models.shared.agent_message_param import MessageParam +from pydantic import BaseModel + +from askui.chat.api.models import AssistantId, FileId, MessageId, RunId, ThreadId +from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, + BetaRedactedThinkingBlock, + BetaThinkingBlock, + CacheControlEphemeralParam, + StopReason, + TextBlockParam, + ToolUseBlockParam, + UrlImageSourceParam, +) from askui.utils.api_utils import Resource from askui.utils.datetime_utils import UnixDatetime, now from askui.utils.id_utils import generate_time_ordered_id +class FileImageSourceParam(BaseModel): + """Image source that references a saved file.""" + + id: FileId + type: Literal["file"] = "file" + + +class ImageBlockParam(BaseModel): + source: Base64ImageSourceParam | UrlImageSourceParam | FileImageSourceParam + type: Literal["image"] = "image" + cache_control: CacheControlEphemeralParam | None = None + + +class ToolResultBlockParam(BaseModel): + tool_use_id: str + type: Literal["tool_result"] = "tool_result" + cache_control: CacheControlEphemeralParam | None = None + content: str | list[TextBlockParam | ImageBlockParam] + is_error: bool = False + + +ContentBlockParam = ( + ImageBlockParam + | TextBlockParam + | ToolResultBlockParam + | ToolUseBlockParam + | BetaThinkingBlock + | BetaRedactedThinkingBlock +) + + +class MessageParam(BaseModel): + role: Literal["user", "assistant"] + content: str | list[ContentBlockParam] + stop_reason: StopReason | None = None + + class MessageBase(MessageParam): assistant_id: AssistantId | None = None - object: Literal["thread.message"] = "thread.message" - role: Literal["user", "assistant"] run_id: RunId | None = None @@ -20,6 +66,7 @@ class MessageCreateParams(MessageBase): class Message(MessageBase, Resource): id: MessageId + object: Literal["thread.message"] = "thread.message" created_at: UnixDatetime thread_id: ThreadId diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 317acab6..82e75d3a 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -1,9 +1,10 @@ from fastapi import APIRouter, status +from askui.chat.api.dependencies import ListQueryDep from askui.chat.api.messages.dependencies import MessageServiceDep from askui.chat.api.messages.models import Message, MessageCreateParams from askui.chat.api.messages.service import MessageService -from askui.chat.api.models import ListQueryDep, MessageId, ThreadId +from askui.chat.api.models import MessageId, ThreadId from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index 12c59339..7821cc7b 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -16,7 +16,7 @@ def __init__(self, base_dir: Path) -> None: self._base_dir = base_dir def get_messages_dir(self, thread_id: ThreadId) -> Path: - return self._base_dir / "threads" / thread_id / "messages" + return self._base_dir / "messages" / thread_id def _get_message_path( self, thread_id: ThreadId, message_id: MessageId, new: bool = False diff --git a/src/askui/chat/api/messages/translator.py b/src/askui/chat/api/messages/translator.py new file mode 100644 index 00000000..8c24a752 --- /dev/null +++ b/src/askui/chat/api/messages/translator.py @@ -0,0 +1,238 @@ +from io import BytesIO +from pathlib import Path + +from fastapi import UploadFile +from fastapi.datastructures import Headers +from PIL import Image + +from askui.chat.api.files.service import FileService +from askui.chat.api.messages.models import ( + ContentBlockParam, + FileImageSourceParam, + ImageBlockParam, + MessageParam, + ToolResultBlockParam, +) +from askui.logger import logger +from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, + TextBlockParam, + UrlImageSourceParam, +) +from askui.models.shared.agent_message_param import ( + ContentBlockParam as AnthropicContentBlockParam, +) +from askui.models.shared.agent_message_param import ( + ImageBlockParam as AnthropicImageBlockParam, +) +from askui.models.shared.agent_message_param import ( + MessageParam as AnthropicMessageParam, +) +from askui.models.shared.agent_message_param import ( + ToolResultBlockParam as AnthropicToolResultBlockParam, +) +from askui.utils.image_utils import base64_to_image, image_to_base64 + + +class ImageBlockParamSourceTranslator: + def __init__(self, file_service: FileService) -> None: + self._file_service = file_service + + async def from_anthropic( + self, source: UrlImageSourceParam | Base64ImageSourceParam + ) -> UrlImageSourceParam | Base64ImageSourceParam | FileImageSourceParam: + if source.type == "url": + return source + if source.type == "base64": # noqa: RET503 + try: + image = base64_to_image(source.data) + bytes_io = BytesIO() + image.save(bytes_io, format="PNG") + bytes_io.seek(0) + file = await self._file_service.upload_file( + file=UploadFile( + file=bytes_io, + headers=Headers( + { + "Content-Type": "image/png", + } + ), + ) + ) + except Exception as e: # noqa: BLE001 + logger.warning(f"Failed to save image: {e}", exc_info=True) + return source + else: + return FileImageSourceParam(id=file.id, type="file") + + async def to_anthropic( + self, + source: UrlImageSourceParam | Base64ImageSourceParam | FileImageSourceParam, + ) -> UrlImageSourceParam | Base64ImageSourceParam: + if source.type == "url": + return source + if source.type == "base64": + return source + if source.type == "file": # noqa: RET503 + file, path = self._file_service.retrieve_file_content(source.id) + image = Image.open(path) + return Base64ImageSourceParam( + data=image_to_base64(image), + media_type=file.media_type, + ) + + +class ImageBlockParamTranslator: + def __init__(self, file_service: FileService) -> None: + self.source_translator = ImageBlockParamSourceTranslator(file_service) + + async def from_anthropic(self, block: AnthropicImageBlockParam) -> ImageBlockParam: + return ImageBlockParam( + source=await self.source_translator.from_anthropic(block.source), + type="image", + cache_control=block.cache_control, + ) + + async def to_anthropic(self, block: ImageBlockParam) -> AnthropicImageBlockParam: + return AnthropicImageBlockParam( + source=await self.source_translator.to_anthropic(block.source), + type="image", + cache_control=block.cache_control, + ) + + +class ToolResultContentBlockParamTranslator: + def __init__(self, file_service: FileService) -> None: + self.image_translator = ImageBlockParamTranslator(file_service) + + async def from_anthropic( + self, block: AnthropicImageBlockParam | TextBlockParam + ) -> ImageBlockParam | TextBlockParam: + if block.type == "image": + return await self.image_translator.from_anthropic(block) + return block + + async def to_anthropic( + self, block: ImageBlockParam | TextBlockParam + ) -> AnthropicImageBlockParam | TextBlockParam: + if block.type == "image": + return await self.image_translator.to_anthropic(block) + return block + + +class ToolResultContentTranslator: + def __init__(self, file_service: FileService) -> None: + self.block_param_translator = ToolResultContentBlockParamTranslator( + file_service + ) + + async def from_anthropic( + self, content: str | list[AnthropicImageBlockParam | TextBlockParam] + ) -> str | list[ImageBlockParam | TextBlockParam]: + if isinstance(content, str): + return content + return [ + await self.block_param_translator.from_anthropic(block) for block in content + ] + + async def to_anthropic( + self, content: str | list[ImageBlockParam | TextBlockParam] + ) -> str | list[AnthropicImageBlockParam | TextBlockParam]: + if isinstance(content, str): + return content + return [ + await self.block_param_translator.to_anthropic(block) for block in content + ] + + +class ToolResultBlockParamTranslator: + def __init__(self, file_service: FileService) -> None: + self.content_translator = ToolResultContentTranslator(file_service) + + async def from_anthropic( + self, block: AnthropicToolResultBlockParam + ) -> ToolResultBlockParam: + return ToolResultBlockParam( + tool_use_id=block.tool_use_id, + type="tool_result", + cache_control=block.cache_control, + content=await self.content_translator.from_anthropic(block.content), + is_error=block.is_error, + ) + + async def to_anthropic( + self, block: ToolResultBlockParam + ) -> AnthropicToolResultBlockParam: + return AnthropicToolResultBlockParam( + tool_use_id=block.tool_use_id, + type="tool_result", + cache_control=block.cache_control, + content=await self.content_translator.to_anthropic(block.content), + is_error=block.is_error, + ) + + +class MessageContentBlockParamTranslator: + def __init__(self, file_service: FileService) -> None: + self.image_translator = ImageBlockParamTranslator(file_service) + self.tool_result_translator = ToolResultBlockParamTranslator(file_service) + + async def from_anthropic( + self, block: AnthropicContentBlockParam + ) -> ContentBlockParam: + if block.type == "image": + return await self.image_translator.from_anthropic(block) + if block.type == "tool_result": + return await self.tool_result_translator.from_anthropic(block) + return block + + async def to_anthropic( + self, block: ContentBlockParam + ) -> AnthropicContentBlockParam: + if block.type == "image": + return await self.image_translator.to_anthropic(block) + if block.type == "tool_result": + return await self.tool_result_translator.to_anthropic(block) + return block + + +class MessageContentTranslator: + def __init__(self, file_service: FileService) -> None: + self.block_param_translator = MessageContentBlockParamTranslator(file_service) + + async def from_anthropic( + self, content: list[AnthropicContentBlockParam] | str + ) -> list[ContentBlockParam] | str: + if isinstance(content, str): + return content + return [ + await self.block_param_translator.from_anthropic(block) for block in content + ] + + async def to_anthropic( + self, content: list[ContentBlockParam] | str + ) -> list[AnthropicContentBlockParam] | str: + if isinstance(content, str): + return content + return [ + await self.block_param_translator.to_anthropic(block) for block in content + ] + + +class MessageTranslator: + def __init__(self, file_service: FileService) -> None: + self.content_translator = MessageContentTranslator(file_service) + + async def from_anthropic(self, message: AnthropicMessageParam) -> MessageParam: + return MessageParam( + role=message.role, + content=await self.content_translator.from_anthropic(message.content), + stop_reason=message.stop_reason, + ) + + async def to_anthropic(self, message: MessageParam) -> AnthropicMessageParam: + return AnthropicMessageParam( + role=message.role, + content=await self.content_translator.to_anthropic(message.content), + stop_reason=message.stop_reason, + ) diff --git a/src/askui/chat/api/models.py b/src/askui/chat/api/models.py index 85235550..0a38d712 100644 --- a/src/askui/chat/api/models.py +++ b/src/askui/chat/api/models.py @@ -1,14 +1,10 @@ -from fastapi import Depends -from pydantic import BaseModel +from typing import Annotated -from askui.utils.api_utils import ListQuery +from askui.utils.id_utils import IdField -AssistantId = str -McpConfigId = str -FileId = str -MessageId = str -RunId = str -ThreadId = str - - -ListQueryDep = Depends(ListQuery) +AssistantId = Annotated[str, IdField("asst")] +McpConfigId = Annotated[str, IdField("mcpcnf")] +FileId = Annotated[str, IdField("file")] +MessageId = Annotated[str, IdField("msg")] +RunId = Annotated[str, IdField("run")] +ThreadId = Annotated[str, IdField("thread")] diff --git a/src/askui/chat/api/runs/dependencies.py b/src/askui/chat/api/runs/dependencies.py index dd37c09a..afb01cab 100644 --- a/src/askui/chat/api/runs/dependencies.py +++ b/src/askui/chat/api/runs/dependencies.py @@ -3,13 +3,20 @@ from fastapi import Depends from askui.chat.api.dependencies import WorkspaceDirDep +from askui.chat.api.messages.dependencies import MessageServiceDep, MessageTranslatorDep +from askui.chat.api.messages.service import MessageService +from askui.chat.api.messages.translator import MessageTranslator from .service import RunService -def get_runs_service(workspace_dir: Path = WorkspaceDirDep) -> RunService: +def get_runs_service( + workspace_dir: Path = WorkspaceDirDep, + message_service: MessageService = MessageServiceDep, + message_translator: MessageTranslator = MessageTranslatorDep, +) -> RunService: """Get RunService instance.""" - return RunService(workspace_dir) + return RunService(workspace_dir, message_service, message_translator) RunServiceDep = Depends(get_runs_service) diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py index e7aa687f..279cd0c4 100644 --- a/src/askui/chat/api/runs/router.py +++ b/src/askui/chat/api/runs/router.py @@ -1,11 +1,12 @@ from collections.abc import AsyncGenerator from typing import Annotated -from fastapi import APIRouter, BackgroundTasks, Body, Path, Response, status +from fastapi import APIRouter, BackgroundTasks, Path, Response, status from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel -from askui.chat.api.models import ListQueryDep, RunId, ThreadId +from askui.chat.api.dependencies import ListQueryDep +from askui.chat.api.models import RunId, ThreadId from askui.chat.api.runs.models import RunCreateParams from askui.utils.api_utils import ListQuery, ListResponse diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 7c82b291..bf250fab 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -23,6 +23,7 @@ from askui.chat.api.mcp_configs.service import McpConfigService from askui.chat.api.messages.models import MessageCreateParams from askui.chat.api.messages.service import MessageService +from askui.chat.api.messages.translator import MessageTranslator from askui.chat.api.models import RunId, ThreadId from askui.chat.api.runs.models import Run, RunError from askui.chat.api.runs.runner.events.done_events import DoneEvent @@ -91,14 +92,22 @@ def get_mcp_client( class Runner: - def __init__(self, run: Run, base_dir: Path) -> None: + def __init__( + self, + run: Run, + base_dir: Path, + message_service: MessageService, + message_translator: MessageTranslator, + ) -> None: self._run = run self._base_dir = base_dir - self._msg_service = MessageService(self._base_dir) + self._message_service = message_service + self._message_translator = message_translator + self._message_content_translator = message_translator.content_translator self._agent_os = PynputAgentOs() def get_runs_dir(self, thread_id: ThreadId) -> Path: - return self._base_dir / "threads" / thread_id / "runs" + return self._base_dir / "runs" / thread_id def _get_run_path( self, thread_id: ThreadId, run_id: RunId, new: bool = False @@ -123,7 +132,7 @@ def _retrieve(self) -> Run: return Run.model_validate_json(run_file.read_text(encoding="utf-8")) async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: - message = self._msg_service.create( + message = self._message_service.create( thread_id=self._run.thread_id, params=MessageCreateParams( role="user", @@ -160,26 +169,28 @@ async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: if event.button != "unknown" else "a mouse button" ) - message = self._msg_service.create( + message = self._message_service.create( thread_id=self._run.thread_id, params=MessageCreateParams( role="user", - content=[ - ImageBlockParam( - type="image", - source=Base64ImageSourceParam( - data=ImageSource(screenshot).to_base64(), - media_type="image/png", + content=await self._message_content_translator.from_anthropic( + [ + ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + data=ImageSource(screenshot).to_base64(), + media_type="image/png", + ), ), - ), - TextBlockParam( - type="text", - text=( - f"I moved the mouse to x={event.x}, " - f"y={event.y} and clicked {button}." + TextBlockParam( + type="text", + text=( + f"I moved the mouse to x={event.x}, " + f"y={event.y} and clicked {button}." + ), ), - ), - ], + ] + ), run_id=self._run.id, ), ) @@ -194,7 +205,7 @@ async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: self._agent_os.stop_listening() if len(recorded_events) == 0: text = "Nevermind, I didn't do anything." - message = self._msg_service.create( + message = self._message_service.create( thread_id=self._run.thread_id, params=MessageCreateParams( role="user", @@ -258,11 +269,8 @@ async def _run_agent( ) -> None: tools = ToolCollection(mcp_client=mcp_client) messages: list[MessageParam] = [ - MessageParam( - role=msg.role, - content=msg.content, - ) - for msg in self._msg_service.list_( + await self._message_translator.to_anthropic(msg) + for msg in self._message_service.list_( thread_id=self._run.thread_id, query=ListQuery(limit=LIST_LIMIT_MAX), ).data @@ -271,14 +279,16 @@ async def _run_agent( async def async_on_message( on_message_cb_param: OnMessageCbParam, ) -> MessageParam | None: - message = self._msg_service.create( + message = self._message_service.create( thread_id=self._run.thread_id, params=MessageCreateParams( assistant_id=self._run.assistant_id if on_message_cb_param.message.role == "assistant" else None, role=on_message_cb_param.message.role, - content=on_message_cb_param.message.content, + content=await self._message_content_translator.from_anthropic( + on_message_cb_param.message.content + ), run_id=self._run.id, ), ) diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 3f5f0be2..9c5aed5f 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -4,6 +4,8 @@ import anyio +from askui.chat.api.messages.service import MessageService +from askui.chat.api.messages.translator import MessageTranslator from askui.chat.api.models import RunId, ThreadId from askui.chat.api.runs.models import Run, RunCreateParams from askui.chat.api.runs.runner.events import Events @@ -23,11 +25,18 @@ class RunService: """Service for managing Run resources with filesystem persistence.""" - def __init__(self, base_dir: Path) -> None: + def __init__( + self, + base_dir: Path, + message_service: MessageService, + message_translator: MessageTranslator, + ) -> None: self._base_dir = base_dir + self._message_service = message_service + self._message_translator = message_translator def get_runs_dir(self, thread_id: ThreadId) -> Path: - return self._base_dir / "threads" / thread_id / "runs" + return self._base_dir / "runs" / thread_id def _get_run_path( self, thread_id: ThreadId, run_id: RunId, new: bool = False @@ -52,7 +61,9 @@ async def create( ) -> tuple[Run, AsyncGenerator[Events, None]]: run = self._create(thread_id, params) send_stream, receive_stream = anyio.create_memory_object_stream[Events]() - runner = Runner(run, self._base_dir) + runner = Runner( + run, self._base_dir, self._message_service, self._message_translator + ) async def event_generator() -> AsyncGenerator[Events, None]: try: diff --git a/src/askui/chat/api/threads/dependencies.py b/src/askui/chat/api/threads/dependencies.py index b52df3ca..7c396bd5 100644 --- a/src/askui/chat/api/threads/dependencies.py +++ b/src/askui/chat/api/threads/dependencies.py @@ -5,17 +5,21 @@ from askui.chat.api.dependencies import WorkspaceDirDep from askui.chat.api.messages.dependencies import MessageServiceDep from askui.chat.api.messages.service import MessageService +from askui.chat.api.runs.dependencies import RunServiceDep +from askui.chat.api.runs.service import RunService from askui.chat.api.threads.service import ThreadService def get_thread_service( workspace_dir: Path = WorkspaceDirDep, message_service: MessageService = MessageServiceDep, + run_service: RunService = RunServiceDep, ) -> ThreadService: """Get ThreadService instance.""" return ThreadService( base_dir=workspace_dir, message_service=message_service, + run_service=run_service, ) diff --git a/src/askui/chat/api/threads/router.py b/src/askui/chat/api/threads/router.py index 5808bac9..a9e18bf4 100644 --- a/src/askui/chat/api/threads/router.py +++ b/src/askui/chat/api/threads/router.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, status -from askui.chat.api.models import ListQueryDep, ThreadId +from askui.chat.api.dependencies import ListQueryDep +from askui.chat.api.models import ThreadId from askui.chat.api.threads.dependencies import ThreadServiceDep from askui.chat.api.threads.models import Thread, ThreadCreateParams, ThreadModifyParams from askui.chat.api.threads.service import ThreadService diff --git a/src/askui/chat/api/threads/service.py b/src/askui/chat/api/threads/service.py index 66fde497..d58f89e6 100644 --- a/src/askui/chat/api/threads/service.py +++ b/src/askui/chat/api/threads/service.py @@ -1,9 +1,9 @@ import shutil from pathlib import Path -from askui.chat.api.messages.models import MessageCreateParams from askui.chat.api.messages.service import MessageService from askui.chat.api.models import ThreadId +from askui.chat.api.runs.service import RunService from askui.chat.api.threads.models import Thread, ThreadCreateParams, ThreadModifyParams from askui.utils.api_utils import ( ConflictError, @@ -17,10 +17,13 @@ class ThreadService: """Service for managing Thread resources with filesystem persistence.""" - def __init__(self, base_dir: Path, message_service: MessageService) -> None: + def __init__( + self, base_dir: Path, message_service: MessageService, run_service: RunService + ) -> None: self._base_dir = base_dir self._threads_dir = base_dir / "threads" self._message_service = message_service + self._run_service = run_service def _get_thread_path(self, thread_id: ThreadId, new: bool = False) -> Path: thread_path = self._threads_dir / f"{thread_id}.json" @@ -64,7 +67,8 @@ def modify(self, thread_id: ThreadId, params: ThreadModifyParams) -> Thread: def delete(self, thread_id: ThreadId) -> None: try: - shutil.rmtree(self._threads_dir / thread_id) + shutil.rmtree(self._message_service.get_messages_dir(thread_id)) + shutil.rmtree(self._run_service.get_runs_dir(thread_id)) self._get_thread_path(thread_id).unlink() except FileNotFoundError as e: error_msg = f"Thread {thread_id} not found" diff --git a/src/askui/utils/api_utils.py b/src/askui/utils/api_utils.py index 8eade9bc..a699415e 100644 --- a/src/askui/utils/api_utils.py +++ b/src/askui/utils/api_utils.py @@ -48,6 +48,12 @@ class NotFoundError(ApiError): pass +class FileTooLargeError(ApiError): + def __init__(self, max_size: int): + self.max_size = max_size + super().__init__(f"File too large. Maximum size is {max_size} bytes.") + + def list_resource_paths(base_dir: Path, list_query: ListQuery) -> list[Path]: paths: list[Path] = [] after_name = f"{list_query.after}.json" diff --git a/tests/integration/chat/__init__.py b/tests/integration/chat/__init__.py new file mode 100644 index 00000000..baa44a5b --- /dev/null +++ b/tests/integration/chat/__init__.py @@ -0,0 +1 @@ +"""Chat integration tests.""" diff --git a/tests/integration/chat/api/conftest.py b/tests/integration/chat/api/conftest.py new file mode 100644 index 00000000..a9272840 --- /dev/null +++ b/tests/integration/chat/api/conftest.py @@ -0,0 +1,76 @@ +"""Chat API integration test configuration and fixtures.""" + +import tempfile +import uuid +from pathlib import Path + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from askui.chat.api.app import app +from askui.chat.api.files.service import FileService + + +@pytest.fixture +def test_app() -> FastAPI: + """Get the FastAPI test application.""" + return app + + +@pytest.fixture +def test_client(test_app: FastAPI) -> TestClient: + """Get a test client for the FastAPI application.""" + return TestClient(test_app) + + +@pytest.fixture +def temp_workspace_dir() -> Path: + """Create a temporary workspace directory for testing.""" + temp_dir = tempfile.mkdtemp() + return Path(temp_dir) + + +@pytest.fixture +def test_workspace_id() -> str: + """Get a test workspace ID.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def test_headers(test_workspace_id: str) -> dict[str, str]: + """Get test headers with workspace ID.""" + return {"askui-workspace": test_workspace_id} + + +@pytest.fixture +def mock_file_service(temp_workspace_dir: Path) -> FileService: + """Create a mock file service with temporary workspace.""" + return FileService(temp_workspace_dir) + + +def create_test_app_with_overrides(workspace_path: Path) -> FastAPI: + """Create a test app with all dependencies overridden.""" + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + # Create a copy of the app to avoid modifying the global one + test_app = FastAPI() + test_app.router = app.router + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + test_app.dependency_overrides[get_workspace_dir] = override_workspace_dir + test_app.dependency_overrides[get_file_service] = override_file_service + test_app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + return test_app diff --git a/tests/integration/chat/api/test_files.py b/tests/integration/chat/api/test_files.py new file mode 100644 index 00000000..99e6de6a --- /dev/null +++ b/tests/integration/chat/api/test_files.py @@ -0,0 +1,555 @@ +"""Integration tests for the files API endpoints.""" + +import io +import tempfile +from pathlib import Path + +from fastapi import status +from fastapi.testclient import TestClient + +from askui.chat.api.files.models import File +from askui.chat.api.files.service import FileService + + +class TestFilesAPI: + """Test suite for the files API endpoints.""" + + def test_list_files_empty( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test listing files when no files exist.""" + response = test_client.get("/v1/files", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert data["data"] == [] + assert data["has_more"] is False + + def test_list_files_with_files( + self, + test_headers: dict[str, str], + ) -> None: + """Test listing files when files exist.""" + # Create a mock file in the temporary workspace + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + files_dir = workspace_path / "files" + files_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock file + mock_file = File( + id="file_test123", + object="file", + created_at=1234567890, + filename="test.txt", + size=32, + media_type="text/plain", + ) + (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + + try: + with TestClient(app) as client: + response = client.get("/v1/files", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "file_test123" + assert data["data"][0]["filename"] == "test.txt" + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_list_files_with_pagination(self, test_headers: dict[str, str]) -> None: + """Test listing files with pagination parameters.""" + # Create multiple mock files in the temporary workspace + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + files_dir = workspace_path / "files" + files_dir.mkdir(parents=True, exist_ok=True) + + # Create multiple mock files + for i in range(5): + mock_file = File( + id=f"file_test{i}", + object="file", + created_at=1234567890 + i, + filename=f"test{i}.txt", + size=32, + media_type="text/plain", + ) + (files_dir / f"file_test{i}.json").write_text(mock_file.model_dump_json()) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + + try: + with TestClient(app) as client: + response = client.get("/v1/files?limit=2", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 2 + assert data["has_more"] is True + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_upload_file_success( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test successful file upload.""" + file_content = b"test file content" + files = {"file": ("test.txt", io.BytesIO(file_content), "text/plain")} + + response = test_client.post("/v1/files", files=files, headers=test_headers) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["object"] == "file" + assert data["filename"] == "test.txt" + assert data["size"] == len(file_content) + assert data["media_type"] == "text/plain" + assert "id" in data + assert "created_at" in data + + def test_upload_file_without_filename(self, test_headers: dict[str, str]) -> None: + """Test file upload with simple filename.""" + file_content = b"test file content" + # Test with a simple filename + files = {"file": ("test", io.BytesIO(file_content), "text/plain")} + + # Create a test app with overridden dependencies + from integration.chat.api.conftest import create_test_app_with_overrides + + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + test_app = create_test_app_with_overrides(workspace_path) + + with TestClient(test_app) as client: + response = client.post("/v1/files", files=files, headers=test_headers) + + if response.status_code != status.HTTP_201_CREATED: + print(f"Response status: {response.status_code}") + print(f"Response body: {response.text}") + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["object"] == "file" + # Should use the provided filename + assert data["filename"] == "test" + assert data["size"] == len(file_content) + assert data["media_type"] == "text/plain" + + def test_upload_file_large_size( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test file upload with file exceeding size limit.""" + # Create a file larger than 20MB + large_content = b"x" * (21 * 1024 * 1024) + files = {"file": ("large.txt", io.BytesIO(large_content), "text/plain")} + + response = test_client.post("/v1/files", files=files, headers=test_headers) + + assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE + data = response.json() + assert "detail" in data + + def test_retrieve_file_success(self, test_headers: dict[str, str]) -> None: + """Test successful file retrieval.""" + # Create a mock file in the temporary workspace + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + files_dir = workspace_path / "files" + files_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock file + mock_file = File( + id="file_test123", + object="file", + created_at=1234567890, + filename="test.txt", + size=32, + media_type="text/plain", + ) + (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + + try: + with TestClient(app) as client: + response = client.get("/v1/files/file_test123", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == "file_test123" + assert data["filename"] == "test.txt" + assert data["size"] == 32 + assert data["media_type"] == "text/plain" + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_retrieve_file_not_found(self, test_headers: dict[str, str]) -> None: + """Test file retrieval when file doesn't exist.""" + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + response = client.get( + "/v1/files/file_nonexistent123", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_download_file_success(self, test_headers: dict[str, str]) -> None: + """Test successful file download.""" + # Create a mock file in the temporary workspace + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + files_dir = workspace_path / "files" + static_dir = workspace_path / "static" + files_dir.mkdir(parents=True, exist_ok=True) + static_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock file + mock_file = File( + id="file_test123", + object="file", + created_at=1234567890, + filename="test.txt", + size=32, + media_type="text/plain", + ) + (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) + + # Create the actual file content + file_content = b"test file content" + (static_dir / "file_test123.txt").write_bytes(file_content) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/files/file_test123/content", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + assert response.content == file_content + assert response.headers["content-type"].startswith("text/plain") + assert ( + response.headers["content-disposition"] + == 'attachment; filename="test.txt"' + ) + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_download_file_not_found(self, test_headers: dict[str, str]) -> None: + """Test file download when file doesn't exist.""" + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + response = client.get( + "/v1/files/file_nonexistent123/content", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_delete_file_success(self, test_headers: dict[str, str]) -> None: + """Test successful file deletion.""" + # Create a mock file in the temporary workspace + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + files_dir = workspace_path / "files" + static_dir = workspace_path / "static" + files_dir.mkdir(parents=True, exist_ok=True) + static_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock file + mock_file = File( + id="file_test123", + object="file", + created_at=1234567890, + filename="test.txt", + size=32, + media_type="text/plain", + ) + (files_dir / "file_test123.json").write_text(mock_file.model_dump_json()) + + # Create the actual file content + file_content = b"test file content" + (static_dir / "file_test123.txt").write_bytes(file_content) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + + try: + with TestClient(app) as client: + response = client.delete("/v1/files/file_test123", headers=test_headers) + + assert response.status_code == status.HTTP_204_NO_CONTENT + + # Verify file is deleted + assert not (files_dir / "file_test123.json").exists() + assert not (static_dir / "file_test123.txt").exists() + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_delete_file_not_found(self, test_headers: dict[str, str]) -> None: + """Test file deletion when file doesn't exist.""" + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + response = client.delete( + "/v1/files/file_nonexistent123", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_upload_different_file_types( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test uploading different file types.""" + # Test JSON file + json_content = b'{"key": "value"}' + json_files = { + "file": ("data.json", io.BytesIO(json_content), "application/json") + } + + response = test_client.post("/v1/files", files=json_files, headers=test_headers) + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["media_type"] == "application/json" + assert data["filename"] == "data.json" + + # Test PDF file + pdf_content = b"%PDF-1.4\ntest pdf content" + pdf_files = { + "file": ("document.pdf", io.BytesIO(pdf_content), "application/pdf") + } + + response = test_client.post("/v1/files", files=pdf_files, headers=test_headers) + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["media_type"] == "application/pdf" + assert data["filename"] == "document.pdf" + + def test_upload_file_without_content_type( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test file upload without content type.""" + file_content = b"test file content" + files = {"file": ("test.txt", io.BytesIO(file_content), None)} + + response = test_client.post("/v1/files", files=files, headers=test_headers) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + # FastAPI might infer content type from filename, so we just check it's not None + assert data["media_type"] is not None + assert data["media_type"] != "" + + def test_list_files_with_filtering(self, test_headers: dict[str, str]) -> None: + """Test listing files with filtering parameters.""" + # Create multiple mock files in the temporary workspace + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + files_dir = workspace_path / "files" + files_dir.mkdir(parents=True, exist_ok=True) + + # Create multiple mock files with different timestamps + for i in range(3): + mock_file = File( + id=f"file_test{i}", + object="file", + created_at=1234567890 + i, + filename=f"test{i}.txt", + size=32, + media_type="text/plain", + ) + (files_dir / f"file_test{i}.json").write_text(mock_file.model_dump_json()) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + + try: + with TestClient(app) as client: + # Test with after parameter + response = client.get( + "/v1/files?after=file_test0", headers=test_headers + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 2 + # After file_test0 should return file_test1 and file_test2 in + # descending order + assert data["data"][0]["id"] == "file_test2" + assert data["data"][1]["id"] == "file_test1" + + # Test with before parameter + response = client.get( + "/v1/files?before=file_test2", headers=test_headers + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 2 + # Before file_test2 should return file_test0 and file_test1 in + # descending order + assert data["data"][0]["id"] == "file_test1" + assert data["data"][1]["id"] == "file_test0" + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() diff --git a/tests/integration/chat/api/test_files_edge_cases.py b/tests/integration/chat/api/test_files_edge_cases.py new file mode 100644 index 00000000..fd4fc9d3 --- /dev/null +++ b/tests/integration/chat/api/test_files_edge_cases.py @@ -0,0 +1,457 @@ +"""Edge case and error scenario tests for the files API endpoints.""" + +import io +import tempfile +from pathlib import Path + +from fastapi import status +from fastapi.testclient import TestClient + + +class TestFilesAPIEdgeCases: + """Test suite for edge cases and error scenarios in the files API.""" + + def test_upload_empty_file(self, test_headers: dict[str, str]) -> None: + """Test uploading an empty file.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + empty_content = b"" + files = {"file": ("empty.txt", io.BytesIO(empty_content), "text/plain")} + + response = client.post("/v1/files", files=files, headers=test_headers) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["size"] == 0 + assert data["filename"] == "empty.txt" + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_upload_file_with_special_characters_in_filename( + self, test_headers: dict[str, str] + ) -> None: + """Test uploading a file with special characters in the filename.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + file_content = b"test content" + special_filename = "file with spaces & special chars!@#$%^&*().txt" + files = { + "file": (special_filename, io.BytesIO(file_content), "text/plain") + } + + response = client.post("/v1/files", files=files, headers=test_headers) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["filename"] == special_filename + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_upload_file_with_very_long_filename( + self, test_headers: dict[str, str] + ) -> None: + """Test uploading a file with a very long filename.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + file_content = b"test content" + long_filename = "a" * 255 + ".txt" # Very long filename + files = { + "file": (long_filename, io.BytesIO(file_content), "text/plain") + } + + response = client.post("/v1/files", files=files, headers=test_headers) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["filename"] == long_filename + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_upload_file_with_unknown_mime_type( + self, test_headers: dict[str, str] + ) -> None: + """Test uploading a file with an unknown MIME type.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + file_content = b"test content" + unknown_mime = "application/unknown-type" + files = {"file": ("test.xyz", io.BytesIO(file_content), unknown_mime)} + + response = client.post("/v1/files", files=files, headers=test_headers) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["media_type"] == unknown_mime + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_upload_file_with_binary_content( + self, test_headers: dict[str, str] + ) -> None: + """Test uploading a file with binary content.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + # Create binary content (PNG header) + binary_content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + files = {"file": ("test.png", io.BytesIO(binary_content), "image/png")} + + response = client.post("/v1/files", files=files, headers=test_headers) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["media_type"] == "image/png" + assert data["size"] == len(binary_content) + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_upload_file_without_workspace_header( + self, test_client: TestClient + ) -> None: + """Test uploading a file without workspace header.""" + file_content = b"test content" + files = {"file": ("test.txt", io.BytesIO(file_content), "text/plain")} + + response = test_client.post("/v1/files", files=files) + + # Should fail due to missing workspace header + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_upload_file_with_invalid_workspace_header( + self, test_client: TestClient + ) -> None: + """Test uploading a file with an invalid workspace header.""" + file_content = b"test content" + files = {"file": ("test.txt", io.BytesIO(file_content), "text/plain")} + invalid_headers = {"askui-workspace": "invalid-uuid"} + + response = test_client.post("/v1/files", files=files, headers=invalid_headers) + + # Should fail due to invalid workspace format + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_upload_file_with_malformed_file_data( + self, test_headers: dict[str, str] + ) -> None: + """Test uploading with malformed file data.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + # Send request without file data + response = client.post("/v1/files", headers=test_headers) + + # Should fail due to missing file + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_upload_file_with_corrupted_content( + self, test_headers: dict[str, str] + ) -> None: + """Test uploading a file with corrupted content.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + # Create a file-like object that raises an error when read + class CorruptedFile: + def read(self, size: int) -> bytes: # noqa: ARG002 + error_msg = "Simulated corruption" + raise IOError(error_msg) + + files = {"file": ("corrupted.txt", CorruptedFile(), "text/plain")} + + response = client.post("/v1/files", files=files, headers=test_headers) # type: ignore[arg-type] + + # Should fail due to corruption - FastAPI returns 400 for this case + assert response.status_code == status.HTTP_400_BAD_REQUEST + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_list_files_with_invalid_pagination( + self, test_headers: dict[str, str] + ) -> None: + """Test listing files with invalid pagination parameters.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + # Test with negative limit + response = client.get("/v1/files?limit=-1", headers=test_headers) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + # Test with zero limit + response = client.get("/v1/files?limit=0", headers=test_headers) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + # Test with very large limit + response = client.get("/v1/files?limit=10000", headers=test_headers) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_retrieve_file_with_invalid_id_format( + self, test_headers: dict[str, str] + ) -> None: + """Test retrieving a file with an invalid ID format.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + # Test with empty ID - FastAPI returns 200 for this (lists files) + response = client.get("/v1/files/", headers=test_headers) + assert response.status_code == status.HTTP_200_OK + + # Test with ID containing invalid characters - should fail validation + response = client.get("/v1/files/file@#$%", headers=test_headers) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() + + def test_delete_file_with_invalid_id_format( + self, test_headers: dict[str, str] + ) -> None: + """Test deleting a file with an invalid ID format.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_workspace_dir + from askui.chat.api.files.dependencies import get_file_service + from askui.chat.api.files.service import FileService + + def override_workspace_dir() -> Path: + return workspace_path + + def override_file_service() -> FileService: + return FileService(workspace_path) + + def override_set_env_from_headers() -> None: + # No-op for testing + pass + + app.dependency_overrides[get_workspace_dir] = override_workspace_dir + app.dependency_overrides[get_file_service] = override_file_service + app.dependency_overrides[SetEnvFromHeadersDep] = override_set_env_from_headers + + try: + with TestClient(app) as client: + # Test with empty ID - FastAPI returns 405 Method Not Allowed for this + response = client.delete("/v1/files/", headers=test_headers) + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + + # Test with ID containing invalid characters - should fail validation + response = client.delete("/v1/files/file@#$%", headers=test_headers) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + finally: + # Clean up dependency overrides + app.dependency_overrides.clear() diff --git a/tests/integration/chat/api/test_files_service.py b/tests/integration/chat/api/test_files_service.py new file mode 100644 index 00000000..49d221b1 --- /dev/null +++ b/tests/integration/chat/api/test_files_service.py @@ -0,0 +1,442 @@ +"""Integration tests for the FileService class.""" + +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock + +import pytest +from fastapi import UploadFile + +from askui.chat.api.files.models import File, FileCreateParams +from askui.chat.api.files.service import FileService +from askui.chat.api.models import FileId +from askui.utils.api_utils import ConflictError, FileTooLargeError, NotFoundError + + +class TestFileService: + """Test suite for the FileService class.""" + + @pytest.fixture + def temp_workspace_dir(self) -> Path: + """Create a temporary workspace directory for testing.""" + temp_dir = tempfile.mkdtemp() + return Path(temp_dir) + + @pytest.fixture + def file_service(self, temp_workspace_dir: Path) -> FileService: + """Create a FileService instance with temporary workspace.""" + return FileService(temp_workspace_dir) + + @pytest.fixture + def sample_file_params(self) -> FileCreateParams: + """Create sample file creation parameters.""" + return FileCreateParams(filename="test.txt", size=32, media_type="text/plain") + + def test_get_file_path_new_file(self, file_service: FileService) -> None: + """Test getting file path for a new file.""" + file_id = FileId("file_test123") + file_path = file_service._get_file_path(file_id, new=True) + + expected_path = file_service._files_dir / "file_test123.json" + assert file_path == expected_path + + def test_get_file_path_existing_file(self, file_service: FileService) -> None: + """Test getting file path for an existing file.""" + file_id = FileId("file_test123") + + # Create the file first + file_path = file_service._files_dir / "file_test123.json" + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text('{"id": "file_test123"}') + + result_path = file_service._get_file_path(file_id, new=False) + assert result_path == file_path + + def test_get_file_path_new_file_conflict(self, file_service: FileService) -> None: + """Test that getting path for new file raises ConflictError if file exists.""" + file_id = FileId("file_test123") + + # Create the file first + file_path = file_service._files_dir / "file_test123.json" + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text('{"id": "file_test123"}') + + with pytest.raises(ConflictError): + file_service._get_file_path(file_id, new=True) + + def test_get_file_path_existing_file_not_found( + self, file_service: FileService + ) -> None: + """Test that getting path for existing file raises NotFoundError if file + doesn't exist.""" + file_id = FileId("file_test123") + + with pytest.raises(NotFoundError): + file_service._get_file_path(file_id, new=False) + + def test_get_static_file_path(self, file_service: FileService) -> None: + """Test getting static file path based on file extension.""" + file = File( + id="file_test123", + object="file", + created_at=1234567890, + filename="test.txt", + size=32, + media_type="text/plain", + ) + + static_path = file_service._get_static_file_path(file) + expected_path = file_service._static_dir / "file_test123.txt" + assert static_path == expected_path + + def test_get_static_file_path_no_extension(self, file_service: FileService) -> None: + """Test getting static file path when MIME type has no extension.""" + file = File( + id="file_test123", + object="file", + created_at=1234567890, + filename="test", + size=32, + media_type="application/octet-stream", + ) + + static_path = file_service._get_static_file_path(file) + expected_path = file_service._static_dir / "file_test123" + assert static_path == expected_path + + def test_list_files_empty(self, file_service: FileService) -> None: + """Test listing files when no files exist.""" + from askui.utils.api_utils import ListQuery + + query = ListQuery() + result = file_service.list_(query) + + assert result.object == "list" + assert result.data == [] + assert result.has_more is False + + def test_list_files_with_files( + self, file_service: FileService, sample_file_params: FileCreateParams + ) -> None: + """Test listing files when files exist.""" + from askui.utils.api_utils import ListQuery + + # Create a file first + temp_file = Path(tempfile.mktemp()) + file_content = b"test content" + temp_file.write_bytes(file_content) + + # Update the size to match the actual file content + params = FileCreateParams( + filename=sample_file_params.filename, + size=len(file_content), + media_type=sample_file_params.media_type, + ) + + try: + file = file_service.create(params, temp_file) + + query = ListQuery() + result = file_service.list_(query) + + assert result.object == "list" + assert len(result.data) == 1 + assert result.data[0].id == file.id + assert result.data[0].filename == file.filename + finally: + temp_file.unlink(missing_ok=True) + + def test_retrieve_file_success( + self, file_service: FileService, sample_file_params: FileCreateParams + ) -> None: + """Test successful file retrieval.""" + # Create a file first + temp_file = Path(tempfile.mktemp()) + file_content = b"test content" + temp_file.write_bytes(file_content) + + # Update the size to match the actual file content + params = FileCreateParams( + filename=sample_file_params.filename, + size=len(file_content), + media_type=sample_file_params.media_type, + ) + + try: + file = file_service.create(params, temp_file) + + retrieved_file = file_service.retrieve(file.id) + + assert retrieved_file.id == file.id + assert retrieved_file.filename == file.filename + assert retrieved_file.size == file.size + assert retrieved_file.media_type == file.media_type + finally: + temp_file.unlink(missing_ok=True) + + def test_retrieve_file_not_found(self, file_service: FileService) -> None: + """Test file retrieval when file doesn't exist.""" + file_id = FileId("file_nonexistent123") + + with pytest.raises(NotFoundError): + file_service.retrieve(file_id) + + def test_delete_file_success( + self, file_service: FileService, sample_file_params: FileCreateParams + ) -> None: + """Test successful file deletion.""" + # Create a file first + temp_file = Path(tempfile.mktemp()) + file_content = b"test content" + temp_file.write_bytes(file_content) + + # Update the size to match the actual file content + params = FileCreateParams( + filename=sample_file_params.filename, + size=len(file_content), + media_type=sample_file_params.media_type, + ) + + try: + file = file_service.create(params, temp_file) + + # Verify file exists by retrieving it + retrieved_file = file_service.retrieve(file.id) + assert retrieved_file.id == file.id + + # Delete the file + file_service.delete(file.id) + + # Verify file is deleted by trying to retrieve it + # (should raise NotFoundError) + with pytest.raises(NotFoundError): + file_service.retrieve(file.id) + finally: + temp_file.unlink(missing_ok=True) + + def test_delete_file_not_found(self, file_service: FileService) -> None: + """Test file deletion when file doesn't exist.""" + file_id = FileId("file_nonexistent123") + + with pytest.raises(NotFoundError): + file_service.delete(file_id) + + def test_retrieve_file_content_success( + self, file_service: FileService, sample_file_params: FileCreateParams + ) -> None: + """Test successful file content retrieval.""" + # Create a file first + temp_file = Path(tempfile.mktemp()) + file_content = b"test content" + temp_file.write_bytes(file_content) + + # Update the size to match the actual file content + params = FileCreateParams( + filename=sample_file_params.filename, + size=len(file_content), + media_type=sample_file_params.media_type, + ) + + try: + file = file_service.create(params, temp_file) + + retrieved_file, file_path = file_service.retrieve_file_content(file.id) + + assert retrieved_file.id == file.id + assert file_path.exists() + finally: + temp_file.unlink(missing_ok=True) + + def test_retrieve_file_content_not_found(self, file_service: FileService) -> None: + """Test file content retrieval when file doesn't exist.""" + file_id = FileId("file_nonexistent123") + + with pytest.raises(NotFoundError): + file_service.retrieve_file_content(file_id) + + def test_create_file_success( + self, file_service: FileService, sample_file_params: FileCreateParams + ) -> None: + """Test successful file creation.""" + temp_file = Path(tempfile.mktemp()) + file_content = b"test content" + temp_file.write_bytes(file_content) + + try: + # Update the size to match the actual file content + params = FileCreateParams( + filename=sample_file_params.filename, + size=len(file_content), + media_type=sample_file_params.media_type, + ) + + file = file_service.create(params, temp_file) + + assert file.id.startswith("file_") + assert file.filename == params.filename + assert file.size == params.size + assert file.media_type == params.media_type + # created_at is a datetime, compare with timezone-aware datetime + from datetime import datetime, timezone + + assert isinstance(file.created_at, datetime) + assert file.created_at > datetime(2020, 1, 1, tzinfo=timezone.utc) + + # Verify metadata file was created + metadata_path = file_service._get_file_path(file.id, new=False) + assert metadata_path.exists() + + # Verify static file was moved + static_path = file_service._get_static_file_path(file) + assert static_path.exists() + + finally: + temp_file.unlink(missing_ok=True) + + def test_create_file_without_filename(self, file_service: FileService) -> None: + """Test file creation without filename.""" + temp_file = Path(tempfile.mktemp()) + file_content = b"test content" + temp_file.write_bytes(file_content) + + params = FileCreateParams( + filename=None, size=len(file_content), media_type="text/plain" + ) + + try: + file = file_service.create(params, temp_file) + + # Should auto-generate filename with extension + assert file.filename.endswith(".txt") + assert file.filename.startswith("file_") + + finally: + temp_file.unlink(missing_ok=True) + + def test_save_file_new( + self, file_service: FileService, sample_file_params: FileCreateParams + ) -> None: + """Test saving a new file.""" + file = File.create(sample_file_params) + + file_service._save(file, new=True) + + # Verify file was saved + saved_path = file_service._get_file_path(file.id, new=False) + assert saved_path.exists() + + # Verify content is correct + saved_content = saved_path.read_text() + saved_file = File.model_validate_json(saved_content) + assert saved_file.id == file.id + assert saved_file.filename == file.filename + + def test_save_file_existing( + self, file_service: FileService, sample_file_params: FileCreateParams + ) -> None: + """Test saving an existing file.""" + file = File.create(sample_file_params) + + # Save file first time + file_service._save(file, new=True) + + # Modify and save again + file.filename = "modified.txt" + file_service._save(file, new=False) + + # Verify changes were saved + saved_path = file_service._get_file_path(file.id, new=False) + saved_content = saved_path.read_text() + saved_file = File.model_validate_json(saved_content) + assert saved_file.filename == "modified.txt" + + @pytest.mark.asyncio + async def test_write_to_temp_file_success(self, file_service: FileService) -> None: + """Test successful writing to temporary file.""" + file_content = b"test file content" + mock_upload_file = AsyncMock(spec=UploadFile) + mock_upload_file.content_type = "text/plain" + mock_upload_file.filename = None + mock_upload_file.read.side_effect = [ + file_content, + b"", + ] # Read content, then empty + + params, temp_path = await file_service._write_to_temp_file(mock_upload_file) + + assert params.filename is None # No filename provided + assert params.size == len(file_content) + assert params.media_type == "text/plain" + assert temp_path.exists() + assert temp_path.read_bytes() == file_content + + # Cleanup + temp_path.unlink() + + @pytest.mark.asyncio + async def test_write_to_temp_file_large_size( + self, file_service: FileService + ) -> None: + """Test writing to temporary file with size exceeding limit.""" + # Create content larger than 20MB + large_content = b"x" * (21 * 1024 * 1024) + mock_upload_file = AsyncMock(spec=UploadFile) + mock_upload_file.content_type = "text/plain" + mock_upload_file.filename = "test.txt" + mock_upload_file.read.side_effect = [ + large_content, # Read all content at once + ] + + with pytest.raises(FileTooLargeError): + await file_service._write_to_temp_file(mock_upload_file) + + @pytest.mark.asyncio + async def test_write_to_temp_file_no_content_type( + self, file_service: FileService + ) -> None: + """Test writing to temporary file without content type.""" + file_content = b"test content" + mock_upload_file = AsyncMock(spec=UploadFile) + mock_upload_file.content_type = None + mock_upload_file.filename = "test.txt" + mock_upload_file.read.side_effect = [file_content, b""] + + params, temp_path = await file_service._write_to_temp_file(mock_upload_file) + + assert params.media_type == "application/octet-stream" # Default fallback + + # Cleanup + temp_path.unlink() + + @pytest.mark.asyncio + async def test_upload_file_success(self, file_service: FileService) -> None: + """Test successful file upload.""" + file_content = b"test file content" + mock_upload_file = AsyncMock(spec=UploadFile) + mock_upload_file.filename = "test.txt" + mock_upload_file.content_type = "text/plain" + mock_upload_file.read.side_effect = [file_content, b""] + + file = await file_service.upload_file(mock_upload_file) + + assert file.filename == "test.txt" + assert file.size == len(file_content) + assert file.media_type == "text/plain" + assert file.id.startswith("file_") + + # Verify files were created + metadata_path = file_service._get_file_path(file.id, new=False) + static_path = file_service._get_static_file_path(file) + assert metadata_path.exists() + assert static_path.exists() + + @pytest.mark.asyncio + async def test_upload_file_upload_failure(self, file_service: FileService) -> None: + """Test file upload when writing fails.""" + mock_upload_file = AsyncMock(spec=UploadFile) + mock_upload_file.filename = "test.txt" + mock_upload_file.content_type = "text/plain" + mock_upload_file.read.side_effect = Exception("Simulated upload failure") + + with pytest.raises(Exception, match="Simulated upload failure"): + await file_service.upload_file(mock_upload_file) diff --git a/tests/integration/chat/api/test_messages_service.py b/tests/integration/chat/api/test_messages_service.py deleted file mode 100644 index e58be196..00000000 --- a/tests/integration/chat/api/test_messages_service.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Integration tests for the MessageService with JSON file persistence.""" - -import json -import tempfile -from pathlib import Path -from typing import Generator - -import pytest - -from askui.chat.api.messages.models import MessageCreateParams -from askui.chat.api.messages.service import MessageService -from askui.chat.api.models import ThreadId - - -@pytest.fixture -def temp_base_dir() -> Generator[Path, None, None]: - """Create a temporary directory for testing.""" - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - - -@pytest.fixture -def message_service(temp_base_dir: Path) -> MessageService: - """Create a MessageService instance with temporary storage.""" - return MessageService(temp_base_dir) - - -@pytest.fixture -def thread_id() -> ThreadId: - """Create a test thread ID.""" - return "thread_test123" - - -class TestMessageServiceJSONPersistence: - """Test MessageService with JSON file persistence.""" - - def test_create_message_creates_individual_json_file( - self, message_service: MessageService, thread_id: ThreadId - ) -> None: - """Test that creating a message creates an individual JSON file.""" - request = MessageCreateParams(role="user", content="Hello, world!") - - message = message_service.create(thread_id, request) - - # Check that the message directory was created - messages_dir = message_service.get_messages_dir(thread_id) - assert messages_dir.exists() - - # Check that the message file was created - message_file = message_service._get_message_path(thread_id, message.id) - assert message_file.exists() - - # Verify the file contains the correct JSON data - with message_file.open("r") as f: - data = json.load(f) - assert data["role"] == "user" - assert data["content"] == "Hello, world!" - assert data["id"] == message.id - assert data["thread_id"] == thread_id - - def test_list_messages_reads_from_json_files( - self, message_service: MessageService, thread_id: ThreadId - ) -> None: - """Test that listing messages reads from individual JSON files.""" - # Create multiple messages - messages = [] - for i in range(3): - request = MessageCreateParams( - role="user" if i % 2 == 0 else "assistant", content=f"Message {i}" - ) - message = message_service.create(thread_id, request) - messages.append(message) - - # List messages - from askui.utils.api_utils import ListQuery - - query = ListQuery(limit=10, order="asc") - response = message_service.list_(thread_id, query) - - # Verify all messages were found - assert len(response.data) == 3 - - # Verify messages are sorted by creation time - assert response.data[0].created_at <= response.data[1].created_at - assert response.data[1].created_at <= response.data[2].created_at - - def test_delete_message_removes_json_file( - self, message_service: MessageService, thread_id: ThreadId - ) -> None: - """Test that deleting a message removes its JSON file.""" - request = MessageCreateParams(role="user", content="Delete me") - - message = message_service.create(thread_id, request) - message_file = message_service._get_message_path(thread_id, message.id) - assert message_file.exists() - - # Delete the message - message_service.delete(thread_id, message.id) - - # Verify the file was removed - assert not message_file.exists() - - def test_directory_structure_is_correct( - self, message_service: MessageService, thread_id: ThreadId - ) -> None: - """Test that the directory structure follows the expected pattern.""" - request = MessageCreateParams(role="user", content="Test message") - - message_service.create(thread_id, request) - - messages_dir = message_service.get_messages_dir(thread_id) - - assert messages_dir.exists() - - # Check that there's a JSON file in the messages directory - json_files = list(messages_dir.glob("*.json")) - assert len(json_files) == 1 - assert json_files[0].suffix == ".json" diff --git a/tests/integration/chat/api/test_threads_service.py b/tests/integration/chat/api/test_threads_service.py deleted file mode 100644 index dde56975..00000000 --- a/tests/integration/chat/api/test_threads_service.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Integration tests for the ThreadService with JSON file persistence.""" - -import tempfile -from pathlib import Path -from typing import Generator - -import pytest - -from askui.chat.api.messages.models import MessageCreateParams -from askui.chat.api.messages.service import MessageService -from askui.chat.api.threads.service import ThreadCreateParams, ThreadService - - -@pytest.fixture -def temp_base_dir() -> Generator[Path, None, None]: - """Create a temporary directory for testing.""" - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - - -@pytest.fixture -def message_service(temp_base_dir: Path) -> MessageService: - """Create a MessageService instance with temporary storage.""" - return MessageService(temp_base_dir) - - -@pytest.fixture -def thread_service( - temp_base_dir: Path, message_service: MessageService -) -> ThreadService: - """Create a ThreadService instance with temporary storage.""" - return ThreadService(temp_base_dir, message_service) - - -class TestThreadServiceJSONPersistence: - """Test ThreadService with JSON file persistence.""" - - def test_create_thread_creates_directory_structure( - self, thread_service: ThreadService - ) -> None: - """Test that creating a thread creates the proper directory structure.""" - request = ThreadCreateParams(name="Test Thread") - - thread = thread_service.create(request) - - # Check that thread metadata file was created - thread_file = thread_service._base_dir / "threads" / f"{thread.id}.json" - assert thread_file.exists() - - # Check that messages directory was created (by creating a message) - # The ThreadService doesn't create the messages directory until a message is - # added - message_request = MessageCreateParams(role="user", content="Test message") - thread_service._message_service.create(thread.id, message_request) - - thread_messages_dir = thread_service._message_service.get_messages_dir( - thread.id - ) - assert thread_messages_dir.exists() - - # Verify thread metadata content - with thread_file.open("r") as f: - import json - - data = json.load(f) - assert data["name"] == "Test Thread" - assert data["id"] == thread.id - - def test_create_thread_with_messages(self, thread_service: ThreadService) -> None: - """Test that creating a thread with messages works correctly.""" - messages = [ - MessageCreateParams(role="user", content="Hello"), - MessageCreateParams(role="assistant", content="Hi there!"), - ] - request = ThreadCreateParams(name="Thread with Messages", messages=messages) - - thread = thread_service.create(request) - - # Check that messages were created - thread_messages_dir = thread_service._message_service.get_messages_dir( - thread.id - ) - json_files = list(thread_messages_dir.glob("*.json")) - assert len(json_files) == 2 - - # Verify message content - for json_file in json_files: - with json_file.open("r") as f: - import json - - data = json.load(f) - assert data["thread_id"] == thread.id - assert data["role"] in ["user", "assistant"] - - def test_delete_thread_removes_all_files( - self, thread_service: ThreadService - ) -> None: - """Test that deleting a thread removes all associated files.""" - request = ThreadCreateParams( - name="Thread to Delete", - messages=[MessageCreateParams(role="user", content="Test message")], - ) - thread = thread_service.create(request) - - # Verify files exist - thread_file = thread_service._base_dir / "threads" / f"{thread.id}.json" - assert thread_file.exists() - - # The thread directory itself doesn't exist, only the messages directory - messages_dir = thread_service._message_service.get_messages_dir(thread.id) - assert messages_dir.exists() - - # Delete thread - thread_service.delete(thread.id) - - # Verify all files were removed - assert not thread_file.exists() - assert not messages_dir.exists() From 3bc8fecff5af5133d48c726f5d2b578d7317f196 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 21 Aug 2025 23:30:00 +0200 Subject: [PATCH 10/11] chore(chat): disable translation until implemented in frontend --- src/askui/chat/api/messages/translator.py | 50 +++++++++++------------ 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/src/askui/chat/api/messages/translator.py b/src/askui/chat/api/messages/translator.py index 8c24a752..dd399bf7 100644 --- a/src/askui/chat/api/messages/translator.py +++ b/src/askui/chat/api/messages/translator.py @@ -1,8 +1,3 @@ -from io import BytesIO -from pathlib import Path - -from fastapi import UploadFile -from fastapi.datastructures import Headers from PIL import Image from askui.chat.api.files.service import FileService @@ -13,7 +8,6 @@ MessageParam, ToolResultBlockParam, ) -from askui.logger import logger from askui.models.shared.agent_message_param import ( Base64ImageSourceParam, TextBlockParam, @@ -31,7 +25,7 @@ from askui.models.shared.agent_message_param import ( ToolResultBlockParam as AnthropicToolResultBlockParam, ) -from askui.utils.image_utils import base64_to_image, image_to_base64 +from askui.utils.image_utils import image_to_base64 class ImageBlockParamSourceTranslator: @@ -44,26 +38,28 @@ async def from_anthropic( if source.type == "url": return source if source.type == "base64": # noqa: RET503 - try: - image = base64_to_image(source.data) - bytes_io = BytesIO() - image.save(bytes_io, format="PNG") - bytes_io.seek(0) - file = await self._file_service.upload_file( - file=UploadFile( - file=bytes_io, - headers=Headers( - { - "Content-Type": "image/png", - } - ), - ) - ) - except Exception as e: # noqa: BLE001 - logger.warning(f"Failed to save image: {e}", exc_info=True) - return source - else: - return FileImageSourceParam(id=file.id, type="file") + # Readd translation to FileImageSourceParam as soon as we support it in frontend + return source + # try: + # image = base64_to_image(source.data) + # bytes_io = BytesIO() + # image.save(bytes_io, format="PNG") + # bytes_io.seek(0) + # file = await self._file_service.upload_file( + # file=UploadFile( + # file=bytes_io, + # headers=Headers( + # { + # "Content-Type": "image/png", + # } + # ), + # ) + # ) + # except Exception as e: # noqa: BLE001 + # logger.warning(f"Failed to save image: {e}", exc_info=True) + # return source + # else: + # return FileImageSourceParam(id=file.id, type="file") async def to_anthropic( self, From 2e70aabcaf998e91fff2641e31cfffd74b631699 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Fri, 22 Aug 2025 08:56:18 +0200 Subject: [PATCH 11/11] fix(chat): fix mcpcnf id & checking thread existence when working with subresources --- src/askui/chat/api/mcp_configs/models.py | 2 +- src/askui/chat/api/messages/router.py | 10 +- src/askui/chat/api/runs/router.py | 10 +- src/askui/chat/api/threads/dependencies.py | 16 + src/askui/chat/api/threads/facade.py | 56 ++ tests/integration/chat/api/test_assistants.py | 384 +++++++++++++ tests/integration/chat/api/test_files.py | 2 +- tests/integration/chat/api/test_health.py | 49 ++ .../integration/chat/api/test_mcp_configs.py | 373 ++++++++++++ tests/integration/chat/api/test_messages.py | 456 +++++++++++++++ tests/integration/chat/api/test_runs.py | 537 ++++++++++++++++++ tests/integration/chat/api/test_threads.py | 376 ++++++++++++ 12 files changed, 2261 insertions(+), 10 deletions(-) create mode 100644 src/askui/chat/api/threads/facade.py create mode 100644 tests/integration/chat/api/test_assistants.py create mode 100644 tests/integration/chat/api/test_health.py create mode 100644 tests/integration/chat/api/test_mcp_configs.py create mode 100644 tests/integration/chat/api/test_messages.py create mode 100644 tests/integration/chat/api/test_runs.py create mode 100644 tests/integration/chat/api/test_threads.py diff --git a/src/askui/chat/api/mcp_configs/models.py b/src/askui/chat/api/mcp_configs/models.py index a219437c..c5c61fe4 100644 --- a/src/askui/chat/api/mcp_configs/models.py +++ b/src/askui/chat/api/mcp_configs/models.py @@ -40,7 +40,7 @@ class McpConfig(McpConfigBase, Resource): @classmethod def create(cls, params: McpConfigCreateParams) -> "McpConfig": return cls( - id=generate_time_ordered_id("mcp_config"), + id=generate_time_ordered_id("mcpcnf"), created_at=now(), **params.model_dump(), ) diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 82e75d3a..4276950a 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -5,6 +5,8 @@ from askui.chat.api.messages.models import Message, MessageCreateParams from askui.chat.api.messages.service import MessageService from askui.chat.api.models import MessageId, ThreadId +from askui.chat.api.threads.dependencies import ThreadFacadeDep +from askui.chat.api.threads.facade import ThreadFacade from askui.utils.api_utils import ListQuery, ListResponse router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) @@ -14,18 +16,18 @@ def list_messages( thread_id: ThreadId, query: ListQuery = ListQueryDep, - message_service: MessageService = MessageServiceDep, + thread_facade: ThreadFacade = ThreadFacadeDep, ) -> ListResponse[Message]: - return message_service.list_(thread_id, query=query) + return thread_facade.list_messages(thread_id, query=query) @router.post("", status_code=status.HTTP_201_CREATED) async def create_message( thread_id: ThreadId, params: MessageCreateParams, - message_service: MessageService = MessageServiceDep, + thread_facade: ThreadFacade = ThreadFacadeDep, ) -> Message: - return message_service.create(thread_id=thread_id, params=params) + return thread_facade.create_message(thread_id=thread_id, params=params) @router.get("/{message_id}") diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py index 279cd0c4..d6f8f33d 100644 --- a/src/askui/chat/api/runs/router.py +++ b/src/askui/chat/api/runs/router.py @@ -8,6 +8,8 @@ from askui.chat.api.dependencies import ListQueryDep from askui.chat.api.models import RunId, ThreadId from askui.chat.api.runs.models import RunCreateParams +from askui.chat.api.threads.dependencies import ThreadFacadeDep +from askui.chat.api.threads.facade import ThreadFacade from askui.utils.api_utils import ListQuery, ListResponse from .dependencies import RunServiceDep @@ -22,10 +24,10 @@ async def create_run( thread_id: Annotated[ThreadId, Path(...)], params: RunCreateParams, background_tasks: BackgroundTasks, - run_service: RunService = RunServiceDep, + thread_facade: ThreadFacade = ThreadFacadeDep, ) -> Response: stream = params.stream - run, async_generator = await run_service.create(thread_id, params) + run, async_generator = await thread_facade.create_run(thread_id, params) if stream: async def sse_event_stream() -> AsyncGenerator[str, None]: @@ -64,9 +66,9 @@ def retrieve_run( def list_runs( thread_id: Annotated[ThreadId, Path(...)], query: ListQuery = ListQueryDep, - run_service: RunService = RunServiceDep, + thread_facade: ThreadFacade = ThreadFacadeDep, ) -> ListResponse[Run]: - return run_service.list_(thread_id, query=query) + return thread_facade.list_runs(thread_id, query=query) @router.post("/{run_id}/cancel") diff --git a/src/askui/chat/api/threads/dependencies.py b/src/askui/chat/api/threads/dependencies.py index 7c396bd5..64ff8172 100644 --- a/src/askui/chat/api/threads/dependencies.py +++ b/src/askui/chat/api/threads/dependencies.py @@ -7,6 +7,7 @@ from askui.chat.api.messages.service import MessageService from askui.chat.api.runs.dependencies import RunServiceDep from askui.chat.api.runs.service import RunService +from askui.chat.api.threads.facade import ThreadFacade from askui.chat.api.threads.service import ThreadService @@ -24,3 +25,18 @@ def get_thread_service( ThreadServiceDep = Depends(get_thread_service) + + +def get_thread_facade( + thread_service: ThreadService = ThreadServiceDep, + message_service: MessageService = MessageServiceDep, + run_service: RunService = RunServiceDep, +) -> ThreadFacade: + return ThreadFacade( + thread_service=thread_service, + message_service=message_service, + run_service=run_service, + ) + + +ThreadFacadeDep = Depends(get_thread_facade) diff --git a/src/askui/chat/api/threads/facade.py b/src/askui/chat/api/threads/facade.py new file mode 100644 index 00000000..1d3b47cf --- /dev/null +++ b/src/askui/chat/api/threads/facade.py @@ -0,0 +1,56 @@ +from collections.abc import AsyncGenerator + +from askui.chat.api.messages.models import Message, MessageCreateParams +from askui.chat.api.messages.service import MessageService +from askui.chat.api.models import ThreadId +from askui.chat.api.runs.models import Run, RunCreateParams +from askui.chat.api.runs.runner.events.events import Events +from askui.chat.api.runs.service import RunService +from askui.chat.api.threads.service import ThreadService +from askui.utils.api_utils import ListQuery, ListResponse + + +class ThreadFacade: + """ + Facade service that coordinates operations across threads, messages, and runs. + """ + + def __init__( + self, + thread_service: ThreadService, + message_service: MessageService, + run_service: RunService, + ) -> None: + self._thread_service = thread_service + self._message_service = message_service + self._run_service = run_service + + def _ensure_thread_exists(self, thread_id: ThreadId) -> None: + """Validate that a thread exists before allowing operations on it.""" + self._thread_service.retrieve(thread_id) + + def create_message( + self, thread_id: ThreadId, params: MessageCreateParams + ) -> Message: + """Create a message, ensuring the thread exists first.""" + self._ensure_thread_exists(thread_id) + return self._message_service.create(thread_id, params) + + async def create_run( + self, thread_id: ThreadId, params: RunCreateParams + ) -> tuple[Run, AsyncGenerator[Events, None]]: + """Create a run, ensuring the thread exists first.""" + self._ensure_thread_exists(thread_id) + return await self._run_service.create(thread_id, params) + + def list_messages( + self, thread_id: ThreadId, query: ListQuery + ) -> ListResponse[Message]: + """List messages, ensuring the thread exists first.""" + self._ensure_thread_exists(thread_id) + return self._message_service.list_(thread_id, query) + + def list_runs(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Run]: + """List runs, ensuring the thread exists first.""" + self._ensure_thread_exists(thread_id) + return self._run_service.list_(thread_id, query) diff --git a/tests/integration/chat/api/test_assistants.py b/tests/integration/chat/api/test_assistants.py new file mode 100644 index 00000000..f1b1409f --- /dev/null +++ b/tests/integration/chat/api/test_assistants.py @@ -0,0 +1,384 @@ +"""Integration tests for the assistants API endpoints.""" + +import tempfile +from pathlib import Path + +from fastapi import status +from fastapi.testclient import TestClient + +from askui.chat.api.assistants.models import Assistant +from askui.chat.api.assistants.service import AssistantService + + +class TestAssistantsAPI: + """Test suite for the assistants API endpoints.""" + + def test_list_assistants_empty(self, test_headers: dict[str, str]) -> None: + """Test listing assistants when no assistants exist.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.get("/v1/assistants", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert data["data"] == [] + assert data["has_more"] is False + finally: + app.dependency_overrides.clear() + + def test_list_assistants_with_assistants( + self, test_headers: dict[str, str] + ) -> None: + """Test listing assistants when assistants exist.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + assistants_dir = workspace_path / "assistants" + assistants_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock assistant + mock_assistant = Assistant( + id="asst_test123", + object="assistant", + created_at=1234567890, + name="Test Assistant", + description="A test assistant", + avatar="test_avatar.png", + ) + (assistants_dir / "asst_test123.json").write_text( + mock_assistant.model_dump_json() + ) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.get("/v1/assistants", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "asst_test123" + assert data["data"][0]["name"] == "Test Assistant" + finally: + app.dependency_overrides.clear() + + def test_list_assistants_with_pagination( + self, test_headers: dict[str, str] + ) -> None: + """Test listing assistants with pagination parameters.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + assistants_dir = workspace_path / "assistants" + assistants_dir.mkdir(parents=True, exist_ok=True) + + # Create multiple mock assistants + for i in range(5): + mock_assistant = Assistant( + id=f"asst_test{i}", + object="assistant", + created_at=1234567890 + i, + name=f"Test Assistant {i}", + description=f"Test assistant {i}", + ) + (assistants_dir / f"asst_test{i}.json").write_text( + mock_assistant.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.get("/v1/assistants?limit=3", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 3 + assert data["has_more"] is True + finally: + app.dependency_overrides.clear() + + def test_create_assistant(self, test_headers: dict[str, str]) -> None: + """Test creating a new assistant.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + assistant_data = { + "name": "New Test Assistant", + "description": "A newly created test assistant", + "avatar": "new_avatar.png", + } + response = client.post( + "/v1/assistants", json=assistant_data, headers=test_headers + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "New Test Assistant" + assert data["description"] == "A newly created test assistant" + assert data["avatar"] == "new_avatar.png" + assert data["object"] == "assistant" + assert "id" in data + assert "created_at" in data + finally: + app.dependency_overrides.clear() + + def test_create_assistant_minimal(self, test_headers: dict[str, str]) -> None: + """Test creating an assistant with minimal data.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.post("/v1/assistants", json={}, headers=test_headers) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["object"] == "assistant" + assert data["name"] is None + assert data["description"] is None + assert data["avatar"] is None + finally: + app.dependency_overrides.clear() + + def test_retrieve_assistant(self, test_headers: dict[str, str]) -> None: + """Test retrieving an existing assistant.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + assistants_dir = workspace_path / "assistants" + assistants_dir.mkdir(parents=True, exist_ok=True) + + mock_assistant = Assistant( + id="asst_test123", + object="assistant", + created_at=1234567890, + name="Test Assistant", + description="A test assistant", + ) + (assistants_dir / "asst_test123.json").write_text( + mock_assistant.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/assistants/asst_test123", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == "asst_test123" + assert data["name"] == "Test Assistant" + assert data["description"] == "A test assistant" + finally: + app.dependency_overrides.clear() + + def test_retrieve_assistant_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test retrieving a non-existent assistant.""" + response = test_client.get( + "/v1/assistants/asst_nonexistent123", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + + def test_modify_assistant(self, test_headers: dict[str, str]) -> None: + """Test modifying an existing assistant.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + assistants_dir = workspace_path / "assistants" + assistants_dir.mkdir(parents=True, exist_ok=True) + + mock_assistant = Assistant( + id="asst_test123", + object="assistant", + created_at=1234567890, + name="Original Name", + description="Original description", + ) + (assistants_dir / "asst_test123.json").write_text( + mock_assistant.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + modify_data = { + "name": "Modified Name", + "description": "Modified description", + } + response = client.post( + "/v1/assistants/asst_test123", + json=modify_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Modified Name" + assert data["description"] == "Modified description" + assert data["id"] == "asst_test123" + assert data["created_at"] == 1234567890 + finally: + app.dependency_overrides.clear() + + def test_modify_assistant_partial(self, test_headers: dict[str, str]) -> None: + """Test modifying an assistant with partial data.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + assistants_dir = workspace_path / "assistants" + assistants_dir.mkdir(parents=True, exist_ok=True) + + mock_assistant = Assistant( + id="asst_test123", + object="assistant", + created_at=1234567890, + name="Original Name", + description="Original description", + ) + (assistants_dir / "asst_test123.json").write_text( + mock_assistant.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + modify_data = {"name": "Only Name Modified"} + response = client.post( + "/v1/assistants/asst_test123", + json=modify_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Only Name Modified" + assert data["description"] == "Original description" # Unchanged + finally: + app.dependency_overrides.clear() + + def test_modify_assistant_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test modifying a non-existent assistant.""" + modify_data = {"name": "Modified Name"} + response = test_client.post( + "/v1/assistants/asst_nonexistent123", json=modify_data, headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_delete_assistant(self, test_headers: dict[str, str]) -> None: + """Test deleting an existing assistant.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + assistants_dir = workspace_path / "assistants" + assistants_dir.mkdir(parents=True, exist_ok=True) + + mock_assistant = Assistant( + id="asst_test123", + object="assistant", + created_at=1234567890, + name="Test Assistant", + ) + (assistants_dir / "asst_test123.json").write_text( + mock_assistant.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.delete( + "/v1/assistants/asst_test123", headers=test_headers + ) + + assert response.status_code == status.HTTP_204_NO_CONTENT + assert response.content == b"" + finally: + app.dependency_overrides.clear() + + def test_delete_assistant_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test deleting a non-existent assistant.""" + response = test_client.delete( + "/v1/assistants/asst_nonexistent123", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/integration/chat/api/test_files.py b/tests/integration/chat/api/test_files.py index 99e6de6a..c15b4511 100644 --- a/tests/integration/chat/api/test_files.py +++ b/tests/integration/chat/api/test_files.py @@ -147,7 +147,7 @@ def test_upload_file_without_filename(self, test_headers: dict[str, str]) -> Non files = {"file": ("test", io.BytesIO(file_content), "text/plain")} # Create a test app with overridden dependencies - from integration.chat.api.conftest import create_test_app_with_overrides + from .conftest import create_test_app_with_overrides temp_dir = tempfile.mkdtemp() workspace_path = Path(temp_dir) diff --git a/tests/integration/chat/api/test_health.py b/tests/integration/chat/api/test_health.py new file mode 100644 index 00000000..74ad60d2 --- /dev/null +++ b/tests/integration/chat/api/test_health.py @@ -0,0 +1,49 @@ +"""Integration tests for the health API endpoint.""" + +from fastapi import status +from fastapi.testclient import TestClient + + +class TestHealthAPI: + """Test suite for the health API endpoint.""" + + def test_health_check( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test the health check endpoint.""" + response = test_client.get("/v1/health", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["status"] == "OK" + + def test_health_check_without_headers(self, test_client: TestClient) -> None: + """Test the health check endpoint without workspace headers.""" + response = test_client.get("/v1/health") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["status"] == "OK" + + def test_health_check_response_structure( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test that the health check response has the correct structure.""" + response = test_client.get("/v1/health", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + + # Check that only the expected fields are present + assert set(data.keys()) == {"status"} + assert isinstance(data["status"], str) + assert data["status"] == "OK" + + def test_health_check_content_type( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test that the health check response has the correct content type.""" + response = test_client.get("/v1/health", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + assert response.headers["content-type"] == "application/json" diff --git a/tests/integration/chat/api/test_mcp_configs.py b/tests/integration/chat/api/test_mcp_configs.py new file mode 100644 index 00000000..d855a1aa --- /dev/null +++ b/tests/integration/chat/api/test_mcp_configs.py @@ -0,0 +1,373 @@ +"""Integration tests for the MCP configs API endpoints.""" + +import tempfile +from pathlib import Path + +from fastapi import status +from fastapi.testclient import TestClient + +from askui.chat.api.mcp_configs.models import McpConfig +from askui.chat.api.mcp_configs.service import McpConfigService + + +class TestMcpConfigsAPI: + """Test suite for the MCP configs API endpoints.""" + + def test_list_mcp_configs_empty( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test listing MCP configs when no configs exist.""" + response = test_client.get("/v1/mcp-configs", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert data["data"] == [] + assert data["has_more"] is False + + def test_list_mcp_configs_with_configs(self, test_headers: dict[str, str]) -> None: + """Test listing MCP configs when configs exist.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + mcp_configs_dir = workspace_path / "mcp_configs" + mcp_configs_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock MCP config + mock_config = McpConfig( + id="mcpcnf_test123", + object="mcp_config", + created_at=1234567890, + name="Test MCP Config", + mcp_server={"type": "stdio", "command": "test_command"}, + ) + (mcp_configs_dir / "mcpcnf_test123.json").write_text( + mock_config.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service + + def override_mcp_config_service() -> McpConfigService: + return McpConfigService(workspace_path) + + app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service + + try: + with TestClient(app) as client: + response = client.get("/v1/mcp-configs", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "mcpcnf_test123" + assert data["data"][0]["name"] == "Test MCP Config" + assert data["data"][0]["mcp_server"]["type"] == "stdio" + finally: + app.dependency_overrides.clear() + + def test_list_mcp_configs_with_pagination( + self, test_headers: dict[str, str] + ) -> None: + """Test listing MCP configs with pagination parameters.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + mcp_configs_dir = workspace_path / "mcp_configs" + mcp_configs_dir.mkdir(parents=True, exist_ok=True) + + # Create multiple mock MCP configs + for i in range(5): + mock_config = McpConfig( + id=f"mcpcnf_test{i}", + object="mcp_config", + created_at=1234567890 + i, + name=f"Test MCP Config {i}", + mcp_server={"type": "stdio", "command": f"test_command_{i}"}, + ) + (mcp_configs_dir / f"mcpcnf_test{i}.json").write_text( + mock_config.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service + + def override_mcp_config_service() -> McpConfigService: + return McpConfigService(workspace_path) + + app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service + + try: + with TestClient(app) as client: + response = client.get("/v1/mcp-configs?limit=3", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 3 + assert data["has_more"] is True + finally: + app.dependency_overrides.clear() + + def test_create_mcp_config(self, test_headers: dict[str, str]) -> None: + """Test creating a new MCP config.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + from askui.chat.api.app import app + from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service + + def override_mcp_config_service() -> McpConfigService: + return McpConfigService(workspace_path) + + app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service + + try: + with TestClient(app) as client: + config_data = { + "name": "New MCP Config", + "mcp_server": {"type": "stdio", "command": "new_command"}, + } + response = client.post( + "/v1/mcp-configs", json=config_data, headers=test_headers + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "New MCP Config" + assert data["mcp_server"]["type"] == "stdio" + assert data["mcp_server"]["command"] == "new_command" + finally: + app.dependency_overrides.clear() + + def test_create_mcp_config_minimal(self, test_headers: dict[str, str]) -> None: + """Test creating an MCP config with minimal data.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + from askui.chat.api.app import app + from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service + + def override_mcp_config_service() -> McpConfigService: + return McpConfigService(workspace_path) + + app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service + + try: + with TestClient(app) as client: + response = client.post( + "/v1/mcp-configs", + json={ + "name": "Minimal Config", + "mcp_server": {"type": "stdio", "command": "minimal"}, + }, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["object"] == "mcp_config" + assert data["name"] == "Minimal Config" + assert data["mcp_server"]["type"] == "stdio" + assert data["mcp_server"]["command"] == "minimal" + finally: + app.dependency_overrides.clear() + + def test_retrieve_mcp_config(self, test_headers: dict[str, str]) -> None: + """Test retrieving an existing MCP config.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + mcp_configs_dir = workspace_path / "mcp_configs" + mcp_configs_dir.mkdir(parents=True, exist_ok=True) + + mock_config = McpConfig( + id="mcpcnf_test123", + object="mcp_config", + created_at=1234567890, + name="Test MCP Config", + mcp_server={"type": "stdio", "command": "test_command"}, + ) + (mcp_configs_dir / "mcpcnf_test123.json").write_text( + mock_config.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service + + def override_mcp_config_service() -> McpConfigService: + return McpConfigService(workspace_path) + + app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/mcp-configs/mcpcnf_test123", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == "mcpcnf_test123" + assert data["name"] == "Test MCP Config" + assert data["mcp_server"]["type"] == "stdio" + assert data["mcp_server"]["command"] == "test_command" + finally: + app.dependency_overrides.clear() + + def test_retrieve_mcp_config_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test retrieving a non-existent MCP config.""" + response = test_client.get( + "/v1/mcp-configs/mcpcnf_nonexistent123", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + + def test_modify_mcp_config(self, test_headers: dict[str, str]) -> None: + """Test modifying an existing MCP config.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + mcp_configs_dir = workspace_path / "mcp_configs" + mcp_configs_dir.mkdir(parents=True, exist_ok=True) + + mock_config = McpConfig( + id="mcpcnf_test123", + object="mcp_config", + created_at=1234567890, + name="Original Name", + mcp_server={"type": "stdio", "command": "original_command"}, + ) + (mcp_configs_dir / "mcpcnf_test123.json").write_text( + mock_config.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service + + def override_mcp_config_service() -> McpConfigService: + return McpConfigService(workspace_path) + + app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service + + try: + with TestClient(app) as client: + modify_data = { + "name": "Modified Name", + "mcp_server": {"type": "stdio", "command": "modified_command"}, + } + response = client.post( + "/v1/mcp-configs/mcpcnf_test123", + json=modify_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Modified Name" + assert data["mcp_server"]["type"] == "stdio" + assert data["mcp_server"]["command"] == "modified_command" + finally: + app.dependency_overrides.clear() + + def test_modify_mcp_config_partial(self, test_headers: dict[str, str]) -> None: + """Test modifying an MCP config with partial data.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + mcp_configs_dir = workspace_path / "mcp_configs" + mcp_configs_dir.mkdir(parents=True, exist_ok=True) + + mock_config = McpConfig( + id="mcpcnf_test123", + object="mcp_config", + created_at=1234567890, + name="Original Name", + mcp_server={"type": "stdio", "command": "original_command"}, + ) + (mcp_configs_dir / "mcpcnf_test123.json").write_text( + mock_config.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service + + def override_mcp_config_service() -> McpConfigService: + return McpConfigService(workspace_path) + + app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service + + try: + with TestClient(app) as client: + modify_data = {"name": "Only Name Modified"} + response = client.post( + "/v1/mcp-configs/mcpcnf_test123", + json=modify_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Only Name Modified" + + finally: + app.dependency_overrides.clear() + + def test_modify_mcp_config_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test modifying a non-existent MCP config.""" + modify_data = {"name": "Modified Name"} + response = test_client.post( + "/v1/mcp-configs/mcpcnf_nonexistent123", + json=modify_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_delete_mcp_config(self, test_headers: dict[str, str]) -> None: + """Test deleting an existing MCP config.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + mcp_configs_dir = workspace_path / "mcp_configs" + mcp_configs_dir.mkdir(parents=True, exist_ok=True) + + mock_config = McpConfig( + id="mcpcnf_test123", + object="mcp_config", + created_at=1234567890, + name="Test MCP Config", + mcp_server={"type": "stdio", "command": "test_command"}, + ) + (mcp_configs_dir / "mcpcnf_test123.json").write_text( + mock_config.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service + + def override_mcp_config_service() -> McpConfigService: + return McpConfigService(workspace_path) + + app.dependency_overrides[get_mcp_config_service] = override_mcp_config_service + + try: + with TestClient(app) as client: + response = client.delete( + "/v1/mcp-configs/mcpcnf_test123", headers=test_headers + ) + + assert response.status_code == status.HTTP_204_NO_CONTENT + assert response.content == b"" + finally: + app.dependency_overrides.clear() + + def test_delete_mcp_config_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test deleting a non-existent MCP config.""" + response = test_client.delete( + "/v1/mcp-configs/mcpcnf_nonexistent123", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/integration/chat/api/test_messages.py b/tests/integration/chat/api/test_messages.py new file mode 100644 index 00000000..c46ec94c --- /dev/null +++ b/tests/integration/chat/api/test_messages.py @@ -0,0 +1,456 @@ +"""Integration tests for the messages API endpoints.""" + +import tempfile +from pathlib import Path +from unittest.mock import Mock + +from fastapi import status +from fastapi.testclient import TestClient + +from askui.chat.api.messages.models import Message +from askui.chat.api.messages.service import MessageService +from askui.chat.api.threads.models import Thread +from askui.chat.api.threads.service import ThreadService + + +class TestMessagesAPI: + """Test suite for the messages API endpoints.""" + + def test_list_messages_empty(self, test_headers: dict[str, str]) -> None: + """Test listing messages when no messages exist.""" + # First create a thread + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.messages.dependencies import get_message_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_message_service() -> MessageService: + return MessageService(workspace_path) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_message_service] = override_message_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/threads/thread_test123/messages", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert data["data"] == [] + assert data["has_more"] is False + finally: + app.dependency_overrides.clear() + + def test_list_messages_with_messages(self, test_headers: dict[str, str]) -> None: + """Test listing messages when messages exist.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + messages_dir = workspace_path / "messages" / "thread_test123" + messages_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create a mock message + mock_message = Message( + id="msg_test123", + object="thread.message", + created_at=1234567890, + thread_id="thread_test123", + role="user", + content="Hello, this is a test message", + metadata={"key": "value"}, + ) + (messages_dir / "msg_test123.json").write_text(mock_message.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.messages.dependencies import get_message_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_message_service() -> MessageService: + return MessageService(workspace_path) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_message_service] = override_message_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/threads/thread_test123/messages", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "msg_test123" + assert data["data"][0]["content"] == "Hello, this is a test message" + assert data["data"][0]["role"] == "user" + finally: + app.dependency_overrides.clear() + + def test_list_messages_with_pagination(self, test_headers: dict[str, str]) -> None: + """Test listing messages with pagination parameters.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + messages_dir = workspace_path / "messages" / "thread_test123" + messages_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create multiple mock messages + for i in range(5): + mock_message = Message( + id=f"msg_test{i}", + object="thread.message", + created_at=1234567890 + i, + thread_id="thread_test123", + role="user" if i % 2 == 0 else "assistant", + content=f"Test message {i}", + ) + (messages_dir / f"msg_test{i}.json").write_text( + mock_message.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.messages.dependencies import get_message_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_message_service() -> MessageService: + return MessageService(workspace_path) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_message_service] = override_message_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/threads/thread_test123/messages?limit=3", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 3 + assert data["has_more"] is True + finally: + app.dependency_overrides.clear() + + def test_create_message(self, test_headers: dict[str, str]) -> None: + """Test creating a new message.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.messages.dependencies import get_message_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_message_service() -> MessageService: + return MessageService(workspace_path) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_message_service] = override_message_service + + try: + with TestClient(app) as client: + message_data = { + "role": "user", + "content": "Hello, this is a new message", + "metadata": {"key": "value", "number": 42}, + } + response = client.post( + "/v1/threads/thread_test123/messages", + json=message_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["role"] == "user" + assert data["content"] == "Hello, this is a new message" + + assert data["object"] == "thread.message" + assert data["thread_id"] == "thread_test123" + assert "id" in data + assert "created_at" in data + finally: + app.dependency_overrides.clear() + + def test_create_message_minimal(self, test_headers: dict[str, str]) -> None: + """Test creating a message with minimal data.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.messages.dependencies import get_message_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_message_service() -> MessageService: + return MessageService(workspace_path) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_message_service] = override_message_service + + try: + with TestClient(app) as client: + message_data = {"role": "user", "content": "Minimal message"} + response = client.post( + "/v1/threads/thread_test123/messages", + json=message_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["object"] == "thread.message" + assert data["role"] == "user" + assert data["content"] == "Minimal message" + + finally: + app.dependency_overrides.clear() + + def test_create_message_invalid_thread( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test creating a message in a non-existent thread.""" + message_data = {"role": "user", "content": "Test message"} + response = test_client.post( + "/v1/threads/thread_nonexistent123/messages", + json=message_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_retrieve_message(self, test_headers: dict[str, str]) -> None: + """Test retrieving an existing message.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + messages_dir = workspace_path / "messages" / "thread_test123" + messages_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create a mock message + mock_message = Message( + id="msg_test123", + object="thread.message", + created_at=1234567890, + thread_id="thread_test123", + role="user", + content="Test message content", + metadata={"key": "value"}, + ) + (messages_dir / "msg_test123.json").write_text(mock_message.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.messages.dependencies import get_message_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_message_service() -> MessageService: + return MessageService(workspace_path) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_message_service] = override_message_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/threads/thread_test123/messages/msg_test123", + headers=test_headers, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == "msg_test123" + assert data["content"] == "Test message content" + assert data["role"] == "user" + assert data["thread_id"] == "thread_test123" + finally: + app.dependency_overrides.clear() + + def test_retrieve_message_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test retrieving a non-existent message.""" + response = test_client.get( + "/v1/threads/thread_test123/messages/msg_nonexistent123", + headers=test_headers, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + + def test_delete_message(self, test_headers: dict[str, str]) -> None: + """Test deleting an existing message.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + messages_dir = workspace_path / "messages" / "thread_test123" + messages_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create a mock message + mock_message = Message( + id="msg_test123", + object="thread.message", + created_at=1234567890, + thread_id="thread_test123", + role="user", + content="Test message to delete", + ) + (messages_dir / "msg_test123.json").write_text(mock_message.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.messages.dependencies import get_message_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_message_service() -> MessageService: + return MessageService(workspace_path) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_message_service] = override_message_service + + try: + with TestClient(app) as client: + response = client.delete( + "/v1/threads/thread_test123/messages/msg_test123", + headers=test_headers, + ) + + assert response.status_code == status.HTTP_204_NO_CONTENT + assert response.content == b"" + finally: + app.dependency_overrides.clear() + + def test_delete_message_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test deleting a non-existent message.""" + response = test_client.delete( + "/v1/threads/thread_test123/messages/msg_nonexistent123", + headers=test_headers, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/integration/chat/api/test_runs.py b/tests/integration/chat/api/test_runs.py new file mode 100644 index 00000000..c74ab1de --- /dev/null +++ b/tests/integration/chat/api/test_runs.py @@ -0,0 +1,537 @@ +"""Integration tests for the runs API endpoints.""" + +import tempfile +from pathlib import Path +from unittest.mock import Mock + +from fastapi import status +from fastapi.testclient import TestClient + +from askui.chat.api.runs.models import Run +from askui.chat.api.runs.service import RunService +from askui.chat.api.threads.models import Thread +from askui.chat.api.threads.service import ThreadService + + +class TestRunsAPI: + """Test suite for the runs API endpoints.""" + + def test_list_runs_empty(self, test_headers: dict[str, str]) -> None: + """Test listing runs when no runs exist.""" + # First create a thread + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + return RunService( + workspace_path, mock_message_service, mock_message_translator + ) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/threads/thread_test123/runs", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert data["data"] == [] + assert data["has_more"] is False + finally: + app.dependency_overrides.clear() + + def test_list_runs_with_runs(self, test_headers: dict[str, str]) -> None: + """Test listing runs when runs exist.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + runs_dir = workspace_path / "runs" / "thread_test123" + runs_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create a mock run + mock_run = Run( + id="run_test123", + object="thread.run", + created_at=1234567890, + thread_id="thread_test123", + assistant_id="asst_test123", + expires_at=1755846718, # 10 minutes later + started_at=1234567890, + completed_at=1234567900, + ) + (runs_dir / "run_test123.json").write_text(mock_run.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + return RunService( + workspace_path, mock_message_service, mock_message_translator + ) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/threads/thread_test123/runs", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "run_test123" + assert data["data"][0]["status"] == "completed" + assert data["data"][0]["assistant_id"] == "asst_test123" + finally: + app.dependency_overrides.clear() + + def test_list_runs_with_pagination(self, test_headers: dict[str, str]) -> None: + """Test listing runs with pagination parameters.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + runs_dir = workspace_path / "runs" / "thread_test123" + runs_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create multiple mock runs + for i in range(5): + mock_run = Run( + id=f"run_test{i}", + object="thread.run", + created_at=1234567890 + i, + thread_id="thread_test123", + assistant_id=f"asst_test{i}", + expires_at=1234567890 + i + 600, # 10 minutes later + ) + (runs_dir / f"run_test{i}.json").write_text(mock_run.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + return RunService( + workspace_path, mock_message_service, mock_message_translator + ) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/threads/thread_test123/runs?limit=3", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 3 + assert data["has_more"] is True + finally: + app.dependency_overrides.clear() + + def test_create_run(self, test_headers: dict[str, str]) -> None: + """Test creating a new run.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + return RunService( + workspace_path, mock_message_service, mock_message_translator + ) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + + try: + with TestClient(app) as client: + run_data = { + "assistant_id": "asst_test123", + "stream": False, + "metadata": {"key": "value", "number": 42}, + } + response = client.post( + "/v1/threads/thread_test123/runs", + json=run_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["assistant_id"] == "asst_test123" + assert data["thread_id"] == "thread_test123" + assert data["object"] == "thread.run" + assert "id" in data + assert "created_at" in data + finally: + app.dependency_overrides.clear() + + def test_create_run_minimal(self, test_headers: dict[str, str]) -> None: + """Test creating a run with minimal data.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + return RunService( + workspace_path, mock_message_service, mock_message_translator + ) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + + try: + with TestClient(app) as client: + run_data = {"assistant_id": "asst_test123"} + response = client.post( + "/v1/threads/thread_test123/runs", + json=run_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["object"] == "thread.run" + assert data["assistant_id"] == "asst_test123" + # stream field is not returned in the response + finally: + app.dependency_overrides.clear() + + def test_create_run_streaming(self, test_headers: dict[str, str]) -> None: + """Test creating a streaming run.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + return RunService( + workspace_path, mock_message_service, mock_message_translator + ) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + + try: + with TestClient(app) as client: + run_data = { + "assistant_id": "asst_test123", + "stream": True, + } + response = client.post( + "/v1/threads/thread_test123/runs", + json=run_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_201_CREATED + assert "text/event-stream" in response.headers["content-type"] + finally: + app.dependency_overrides.clear() + + def test_create_run_invalid_thread( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test creating a run in a non-existent thread.""" + run_data = {"assistant_id": "asst_test123"} + response = test_client.post( + "/v1/threads/thread_nonexistent123/runs", + json=run_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_retrieve_run(self, test_headers: dict[str, str]) -> None: + """Test retrieving an existing run.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + runs_dir = workspace_path / "runs" / "thread_test123" + runs_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create a mock run + mock_run = Run( + id="run_test123", + object="thread.run", + created_at=1234567890, + thread_id="thread_test123", + assistant_id="asst_test123", + expires_at=1755846718, # 10 minutes later + started_at=1234567890, + completed_at=1234567900, + ) + (runs_dir / "run_test123.json").write_text(mock_run.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + return RunService( + workspace_path, mock_message_service, mock_message_translator + ) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/threads/thread_test123/runs/run_test123", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == "run_test123" + assert data["status"] == "completed" + assert data["assistant_id"] == "asst_test123" + assert data["thread_id"] == "thread_test123" + finally: + app.dependency_overrides.clear() + + def test_retrieve_run_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test retrieving a non-existent run.""" + response = test_client.get( + "/v1/threads/thread_test123/runs/run_nonexistent123", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + + def test_cancel_run(self, test_headers: dict[str, str]) -> None: + """Test canceling an existing run.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + runs_dir = workspace_path / "runs" / "thread_test123" + runs_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create a mock run + mock_run = Run( + id="run_test123", + object="thread.run", + created_at=1234567890, + thread_id="thread_test123", + assistant_id="asst_test123", + expires_at=1755846718, # 10 minutes later + ) + (runs_dir / "run_test123.json").write_text(mock_run.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + return RunService( + workspace_path, mock_message_service, mock_message_translator + ) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + + try: + with TestClient(app) as client: + response = client.post( + "/v1/threads/thread_test123/runs/run_test123/cancel", + headers=test_headers, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == "run_test123" + # The cancel operation sets tried_cancelling_at, making status + # "cancelling" + assert data["status"] == "cancelling" + finally: + app.dependency_overrides.clear() + + def test_cancel_run_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test canceling a non-existent run.""" + response = test_client.post( + "/v1/threads/thread_test123/runs/run_nonexistent123/cancel", + headers=test_headers, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/integration/chat/api/test_threads.py b/tests/integration/chat/api/test_threads.py new file mode 100644 index 00000000..b2f525f3 --- /dev/null +++ b/tests/integration/chat/api/test_threads.py @@ -0,0 +1,376 @@ +"""Integration tests for the threads API endpoints.""" + +import tempfile +from pathlib import Path +from unittest.mock import Mock + +from fastapi import status +from fastapi.testclient import TestClient + +from askui.chat.api.threads.models import Thread +from askui.chat.api.threads.service import ThreadService + + +class TestThreadsAPI: + """Test suite for the threads API endpoints.""" + + def test_list_threads_empty( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test listing threads when no threads exist.""" + response = test_client.get("/v1/threads", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert data["data"] == [] + assert data["has_more"] is False + + def test_list_threads_with_threads(self, test_headers: dict[str, str]) -> None: + """Test listing threads when threads exist.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.threads.dependencies import get_thread_service + + # Mock the dependencies + mock_message_service = Mock() + mock_run_service = Mock() + + def override_thread_service() -> ThreadService: + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + app.dependency_overrides[get_thread_service] = override_thread_service + + try: + with TestClient(app) as client: + response = client.get("/v1/threads", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "thread_test123" + assert data["data"][0]["name"] == "Test Thread" + finally: + app.dependency_overrides.clear() + + def test_list_threads_with_pagination(self, test_headers: dict[str, str]) -> None: + """Test listing threads with pagination parameters.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + # Create multiple mock threads + for i in range(5): + mock_thread = Thread( + id=f"thread_test{i}", + object="thread", + created_at=1234567890 + i, + name=f"Test Thread {i}", + ) + (threads_dir / f"thread_test{i}.json").write_text( + mock_thread.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.threads.dependencies import get_thread_service + + # Mock the dependencies + mock_message_service = Mock() + mock_run_service = Mock() + + def override_thread_service() -> ThreadService: + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + app.dependency_overrides[get_thread_service] = override_thread_service + + try: + with TestClient(app) as client: + response = client.get("/v1/threads?limit=3", headers=test_headers) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) == 3 + assert data["has_more"] is True + finally: + app.dependency_overrides.clear() + + def test_create_thread(self, test_headers: dict[str, str]) -> None: + """Test creating a new thread.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + from askui.chat.api.app import app + from askui.chat.api.threads.dependencies import get_thread_service + + # Mock the dependencies + mock_message_service = Mock() + mock_run_service = Mock() + + def override_thread_service() -> ThreadService: + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + app.dependency_overrides[get_thread_service] = override_thread_service + + try: + with TestClient(app) as client: + thread_data = { + "name": "New Test Thread", + } + response = client.post( + "/v1/threads", json=thread_data, headers=test_headers + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "New Test Thread" + assert data["object"] == "thread" + assert "id" in data + assert "created_at" in data + finally: + app.dependency_overrides.clear() + + def test_create_thread_minimal(self, test_headers: dict[str, str]) -> None: + """Test creating a thread with minimal data.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + from askui.chat.api.app import app + from askui.chat.api.threads.dependencies import get_thread_service + + # Mock the dependencies + mock_message_service = Mock() + mock_run_service = Mock() + + def override_thread_service() -> ThreadService: + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + app.dependency_overrides[get_thread_service] = override_thread_service + + try: + with TestClient(app) as client: + response = client.post("/v1/threads", json={}, headers=test_headers) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["object"] == "thread" + assert data["name"] is None + finally: + app.dependency_overrides.clear() + + def test_retrieve_thread(self, test_headers: dict[str, str]) -> None: + """Test retrieving an existing thread.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.threads.dependencies import get_thread_service + + # Mock the dependencies + mock_message_service = Mock() + mock_run_service = Mock() + + def override_thread_service() -> ThreadService: + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + app.dependency_overrides[get_thread_service] = override_thread_service + + try: + with TestClient(app) as client: + response = client.get( + "/v1/threads/thread_test123", headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == "thread_test123" + assert data["name"] == "Test Thread" + finally: + app.dependency_overrides.clear() + + def test_retrieve_thread_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test retrieving a non-existent thread.""" + response = test_client.get( + "/v1/threads/thread_nonexistent123", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert "detail" in data + + def test_modify_thread(self, test_headers: dict[str, str]) -> None: + """Test modifying an existing thread.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Original Name", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.threads.dependencies import get_thread_service + + # Mock the dependencies + mock_message_service = Mock() + mock_run_service = Mock() + + def override_thread_service() -> ThreadService: + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + app.dependency_overrides[get_thread_service] = override_thread_service + + try: + with TestClient(app) as client: + modify_data = { + "name": "Modified Name", + } + response = client.post( + "/v1/threads/thread_test123", json=modify_data, headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Modified Name" + assert data["id"] == "thread_test123" + assert data["created_at"] == 1234567890 + finally: + app.dependency_overrides.clear() + + def test_modify_thread_partial(self, test_headers: dict[str, str]) -> None: + """Test modifying a thread with partial data.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Original Name", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.threads.dependencies import get_thread_service + + # Mock the dependencies + mock_message_service = Mock() + mock_run_service = Mock() + + def override_thread_service() -> ThreadService: + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + app.dependency_overrides[get_thread_service] = override_thread_service + + try: + with TestClient(app) as client: + modify_data = {"name": "Only Name Modified"} + response = client.post( + "/v1/threads/thread_test123", json=modify_data, headers=test_headers + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Only Name Modified" + finally: + app.dependency_overrides.clear() + + def test_modify_thread_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test modifying a non-existent thread.""" + modify_data = {"name": "Modified Name"} + response = test_client.post( + "/v1/threads/thread_nonexistent123", json=modify_data, headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_delete_thread(self, test_headers: dict[str, str]) -> None: + """Test deleting an existing thread.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + + # Create the directories that the delete operation will try to remove + messages_dir = workspace_path / "messages" / "thread_test123" + messages_dir.mkdir(parents=True, exist_ok=True) + runs_dir = workspace_path / "runs" / "thread_test123" + runs_dir.mkdir(parents=True, exist_ok=True) + + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + from askui.chat.api.app import app + from askui.chat.api.threads.dependencies import get_thread_service + + # Mock the dependencies with proper return values + mock_message_service = Mock() + mock_message_service.get_messages_dir.return_value = messages_dir + mock_run_service = Mock() + mock_run_service.get_runs_dir.return_value = runs_dir + + def override_thread_service() -> ThreadService: + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + app.dependency_overrides[get_thread_service] = override_thread_service + + try: + with TestClient(app) as client: + response = client.delete( + "/v1/threads/thread_test123", headers=test_headers + ) + + assert response.status_code == status.HTTP_204_NO_CONTENT + assert response.content == b"" + finally: + app.dependency_overrides.clear() + + def test_delete_thread_not_found( + self, test_client: TestClient, test_headers: dict[str, str] + ) -> None: + """Test deleting a non-existent thread.""" + response = test_client.delete( + "/v1/threads/thread_nonexistent123", headers=test_headers + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND