diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 43387ec9..b6b04ed8 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -1,7 +1,6 @@ import time import types from abc import ABC -from pathlib import Path from typing import Annotated, Optional, Type, overload from dotenv import load_dotenv @@ -9,6 +8,7 @@ from typing_extensions import Self from askui.container import telemetry +from askui.data_extractor import DataExtractor from askui.locators.locators import Locator from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCb @@ -17,7 +17,7 @@ from askui.tools.agent_os import AgentOs from askui.tools.android.agent_os import AndroidAgentOs from askui.utils.image_utils import ImageSource -from askui.utils.source_utils import InputSource, load_image_source, load_source +from askui.utils.source_utils import InputSource, load_image_source from .logger import configure_logging, logger from .models import ModelComposition @@ -65,6 +65,9 @@ def __init__( on_exception_types=(ElementNotFoundError,), ) self._model_choice = self._init_model_choice(model) + self._data_extractor = DataExtractor( + reporter=self._reporter, models=models or {} + ) def _init_model_router( self, @@ -333,36 +336,14 @@ class LinkedListNode(ResponseSchemaBase): print(text) ``` """ - logger.debug("VisionAgent received instruction to get '%s'", query) - _source = ( - ImageSource(self._agent_os.screenshot()) - if source is None - else load_source(source) - ) - - # Prepare message content with file path if available - user_message_content = f'get: "{query}"' + ( - f" from '{source}'" if isinstance(source, (str, Path)) else "" - ) - - self._reporter.add_message( - "User", - user_message_content, - image=_source.root if isinstance(_source, ImageSource) else None, - ) - response = self._model_router.get( - source=_source, + _source = source or ImageSource(self._agent_os.screenshot()) + _model = model or self._model_choice["get"] + return self._data_extractor.get( query=query, + source=_source, + model=_model, response_schema=response_schema, - model_choice=model or self._model_choice["get"], - ) - message_content = ( - str(response) - if isinstance(response, (str, bool, int, float)) - else response.model_dump() ) - self._reporter.add_message("Agent", message_content) - return response @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def _locate( diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py index 8689c0f4..dd0a29a5 100644 --- a/src/askui/chat/api/app.py +++ b/src/askui/chat/api/app.py @@ -4,6 +4,7 @@ from fastapi import APIRouter, FastAPI, HTTPException, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from fastmcp import FastMCP from askui.chat.api.assistants.dependencies import get_assistant_service from askui.chat.api.assistants.router import router as assistants_router diff --git a/src/askui/chat/api/assistants/models.py b/src/askui/chat/api/assistants/models.py index ba741a17..91548d79 100644 --- a/src/askui/chat/api/assistants/models.py +++ b/src/askui/chat/api/assistants/models.py @@ -1,6 +1,6 @@ from typing import Literal -from pydantic import BaseModel +from pydantic import BaseModel, Field from askui.chat.api.models import AssistantId from askui.utils.api_utils import Resource @@ -15,6 +15,8 @@ class AssistantBase(BaseModel): name: str | None = None description: str | None = None avatar: str | None = None + tools: list[str] = Field(default_factory=list) + system: str | None = None class AssistantCreateParams(AssistantBase): @@ -27,6 +29,8 @@ class AssistantModifyParams(BaseModelWithNotGiven): name: str | NotGiven = NOT_GIVEN description: str | NotGiven = NOT_GIVEN avatar: str | NotGiven = NOT_GIVEN + tools: list[str] | NotGiven = NOT_GIVEN + system: str | NotGiven = NOT_GIVEN class Assistant(AssistantBase, Resource): diff --git a/src/askui/chat/api/messages/models.py b/src/askui/chat/api/messages/models.py index 53e3915d..346fff7e 100644 --- a/src/askui/chat/api/messages/models.py +++ b/src/askui/chat/api/messages/models.py @@ -18,6 +18,20 @@ from askui.utils.id_utils import generate_time_ordered_id +class BetaFileDocumentSourceParam(BaseModel): + file_id: str + type: Literal["file"] = "file" + + +Source = BetaFileDocumentSourceParam + + +class RequestDocumentBlockParam(BaseModel): + source: Source + type: Literal["document"] = "document" + cache_control: CacheControlEphemeralParam | None = None + + class FileImageSourceParam(BaseModel): """Image source that references a saved file.""" @@ -46,6 +60,7 @@ class ToolResultBlockParam(BaseModel): | ToolUseBlockParam | BetaThinkingBlock | BetaRedactedThinkingBlock + | RequestDocumentBlockParam ) diff --git a/src/askui/chat/api/messages/translator.py b/src/askui/chat/api/messages/translator.py index dd399bf7..04d550b4 100644 --- a/src/askui/chat/api/messages/translator.py +++ b/src/askui/chat/api/messages/translator.py @@ -6,8 +6,11 @@ FileImageSourceParam, ImageBlockParam, MessageParam, + RequestDocumentBlockParam, ToolResultBlockParam, ) +from askui.data_extractor import DataExtractor +from askui.models.models import ModelName from askui.models.shared.agent_message_param import ( Base64ImageSourceParam, TextBlockParam, @@ -25,7 +28,72 @@ from askui.models.shared.agent_message_param import ( ToolResultBlockParam as AnthropicToolResultBlockParam, ) -from askui.utils.image_utils import image_to_base64 +from askui.utils.excel_utils import OfficeDocumentSource +from askui.utils.image_utils import ImageSource, image_to_base64 +from askui.utils.source_utils import Source, load_source + + +class RequestDocumentBlockParamTranslator: + """Translator for RequestDocumentBlockParam to/from Anthropic format.""" + + def __init__(self, file_service: FileService) -> None: + self._file_service = file_service + self._data_extractor = DataExtractor() + + def extract_content( + self, source: Source, block: RequestDocumentBlockParam + ) -> list[AnthropicContentBlockParam]: + if isinstance(source, ImageSource): + return [ + AnthropicImageBlockParam( + source=Base64ImageSourceParam( + data=source.to_base64(), + media_type="image/png", + ), + type="image", + cache_control=block.cache_control, + ) + ] + if isinstance(source, OfficeDocumentSource): + with source.reader as r: + data = r.read() + return [ + TextBlockParam( + text=data.decode(), + type="text", + cache_control=block.cache_control, + ) + ] + text = self._data_extractor.get( + query="""Extract all the content of the PDF to Markdown format. + Preserve layout and formatting as much as possible, e.g., representing + tables as HTML tables. For all images, videos, figures, extract text + from it and describe what you are seeing, e.g., what is shown in the + image or figure, and include that description.""", + source=source, + model=ModelName.ASKUI, + ) + return [ + TextBlockParam( + text=text, + type="text", + cache_control=block.cache_control, + ) + ] + + async def to_anthropic( + self, block: RequestDocumentBlockParam + ) -> list[AnthropicContentBlockParam]: + file, path = self._file_service.retrieve_file_content(block.source.file_id) + source = load_source(path) + content = self.extract_content(source, block) + return [ + TextBlockParam( + text=file.model_dump_json(), + type="text", + cache_control=block.cache_control, + ), + ] + content class ImageBlockParamSourceTranslator: @@ -172,24 +240,29 @@ class MessageContentBlockParamTranslator: def __init__(self, file_service: FileService) -> None: self.image_translator = ImageBlockParamTranslator(file_service) self.tool_result_translator = ToolResultBlockParamTranslator(file_service) + self.request_document_translator = RequestDocumentBlockParamTranslator( + file_service + ) async def from_anthropic( self, block: AnthropicContentBlockParam - ) -> ContentBlockParam: + ) -> list[ContentBlockParam]: if block.type == "image": - return await self.image_translator.from_anthropic(block) + return [await self.image_translator.from_anthropic(block)] if block.type == "tool_result": - return await self.tool_result_translator.from_anthropic(block) - return block + return [await self.tool_result_translator.from_anthropic(block)] + return [block] async def to_anthropic( self, block: ContentBlockParam - ) -> AnthropicContentBlockParam: + ) -> list[AnthropicContentBlockParam]: if block.type == "image": - return await self.image_translator.to_anthropic(block) + return [await self.image_translator.to_anthropic(block)] if block.type == "tool_result": - return await self.tool_result_translator.to_anthropic(block) - return block + return [await self.tool_result_translator.to_anthropic(block)] + if block.type == "document": + return await self.request_document_translator.to_anthropic(block) + return [block] class MessageContentTranslator: @@ -201,18 +274,20 @@ async def from_anthropic( ) -> list[ContentBlockParam] | str: if isinstance(content, str): return content - return [ + lists_of_blocks = [ await self.block_param_translator.from_anthropic(block) for block in content ] + return [block for sublist in lists_of_blocks for block in sublist] async def to_anthropic( self, content: list[ContentBlockParam] | str ) -> list[AnthropicContentBlockParam] | str: if isinstance(content, str): return content - return [ + lists_of_blocks = [ await self.block_param_translator.to_anthropic(block) for block in content ] + return [block for sublist in lists_of_blocks for block in sublist] class MessageTranslator: diff --git a/src/askui/chat/api/runs/dependencies.py b/src/askui/chat/api/runs/dependencies.py index afb01cab..fd361cbc 100644 --- a/src/askui/chat/api/runs/dependencies.py +++ b/src/askui/chat/api/runs/dependencies.py @@ -2,6 +2,8 @@ from fastapi import Depends +from askui.chat.api.assistants.dependencies import AssistantServiceDep +from askui.chat.api.assistants.service import AssistantService from askui.chat.api.dependencies import WorkspaceDirDep from askui.chat.api.messages.dependencies import MessageServiceDep, MessageTranslatorDep from askui.chat.api.messages.service import MessageService @@ -12,11 +14,14 @@ def get_runs_service( workspace_dir: Path = WorkspaceDirDep, + assistant_service: AssistantService = AssistantServiceDep, message_service: MessageService = MessageServiceDep, message_translator: MessageTranslator = MessageTranslatorDep, ) -> RunService: """Get RunService instance.""" - return RunService(workspace_dir, message_service, message_translator) + return RunService( + workspace_dir, assistant_service, message_service, message_translator + ) RunServiceDep = Depends(get_runs_service) diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index bf250fab..003eb1ce 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Literal, Sequence +import anthropic import anyio from anyio.abc import ObjectStream from asyncer import asyncify, syncify @@ -12,6 +13,7 @@ from askui.agent import VisionAgent from askui.android_agent import AndroidVisionAgent +from askui.chat.api.assistants.models import Assistant from askui.chat.api.assistants.seeds import ( ANDROID_VISION_AGENT, ASKUI_VISION_AGENT, @@ -35,6 +37,7 @@ from askui.chat.api.runs.runner.events.events import Events from askui.chat.api.runs.runner.events.message_events import MessageEvent from askui.chat.api.runs.runner.events.run_events import RunEvent +from askui.custom_agent import CustomAgent from askui.models.shared.agent_message_param import ( Base64ImageSourceParam, ImageBlockParam, @@ -42,6 +45,7 @@ TextBlockParam, ) from askui.models.shared.agent_on_message_cb import OnMessageCbParam +from askui.models.shared.settings import ActSettings, MessageSettings from askui.models.shared.tools import ToolCollection from askui.tools.pynput_agent_os import PynputAgentOs from askui.utils.api_utils import ( @@ -94,11 +98,13 @@ def get_mcp_client( class Runner: def __init__( self, + assistant: Assistant, run: Run, base_dir: Path, message_service: MessageService, message_translator: MessageTranslator, ) -> None: + self._assistant = assistant self._run = run self._base_dir = base_dir self._message_service = message_service @@ -263,16 +269,19 @@ async def _run_askui_web_testing_agent( async def _run_agent( self, - agent_type: Literal["android", "vision", "web", "web_testing"], + agent_type: Literal["android", "vision", "web", "web_testing", "custom"], send_stream: ObjectStream[Events], mcp_client: McpClient | None, ) -> None: - tools = ToolCollection(mcp_client=mcp_client) + tools = ToolCollection( + mcp_client=mcp_client, + include=set(self._assistant.tools) if self._assistant.tools else None, + ) messages: list[MessageParam] = [ await self._message_translator.to_anthropic(msg) for msg in self._message_service.list_( thread_id=self._run.thread_id, - query=ListQuery(limit=LIST_LIMIT_MAX), + query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"), ).data ] @@ -333,12 +342,31 @@ def _run_agent_inner() -> None: ) return - with VisionAgent() as agent: - agent.act( - messages, - on_message=on_message, - tools=tools, - ) + if agent_type == "vision": + with VisionAgent() as agent: + agent.act( + messages, + on_message=on_message, + tools=tools, + ) + return + + _tools = ToolCollection( + mcp_client=mcp_client, + include=set(self._assistant.tools), + ) + custom_agent = CustomAgent() + custom_agent.act( + messages, + on_message=on_message, + tools=_tools, + settings=ActSettings( + messages=MessageSettings( + system=self._assistant.system or anthropic.NOT_GIVEN, + thinking={"type": "enabled", "budget_tokens": 2048}, + ), + ), + ) await asyncify(_run_agent_inner)() @@ -377,6 +405,12 @@ async def run( send_stream, mcp_client, ) + else: + await self._run_agent( + agent_type="custom", + send_stream=send_stream, + mcp_client=mcp_client, + ) updated_run = self._retrieve() if updated_run.status == "in_progress": updated_run.completed_at = datetime.now(tz=timezone.utc) diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 9c5aed5f..b908f068 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -4,6 +4,7 @@ import anyio +from askui.chat.api.assistants.service import AssistantService from askui.chat.api.messages.service import MessageService from askui.chat.api.messages.translator import MessageTranslator from askui.chat.api.models import RunId, ThreadId @@ -28,10 +29,12 @@ class RunService: def __init__( self, base_dir: Path, + assistant_service: AssistantService, message_service: MessageService, message_translator: MessageTranslator, ) -> None: self._base_dir = base_dir + self._assistant_service = assistant_service self._message_service = message_service self._message_translator = message_translator @@ -59,10 +62,15 @@ def _create(self, thread_id: ThreadId, params: RunCreateParams) -> Run: async def create( self, thread_id: ThreadId, params: RunCreateParams ) -> tuple[Run, AsyncGenerator[Events, None]]: + assistant = self._assistant_service.retrieve(params.assistant_id) run = self._create(thread_id, params) send_stream, receive_stream = anyio.create_memory_object_stream[Events]() runner = Runner( - run, self._base_dir, self._message_service, self._message_translator + assistant, + run, + self._base_dir, + self._message_service, + self._message_translator, ) async def event_generator() -> AsyncGenerator[Events, None]: diff --git a/src/askui/custom_agent.py b/src/askui/custom_agent.py new file mode 100644 index 00000000..a93fb15a --- /dev/null +++ b/src/askui/custom_agent.py @@ -0,0 +1,57 @@ +from typing import Annotated + +from pydantic import ConfigDict, Field, validate_call + +from askui.container import telemetry +from askui.models.models import ModelName +from askui.models.shared.agent_message_param import MessageParam +from askui.models.shared.agent_on_message_cb import OnMessageCb +from askui.models.shared.settings import ActSettings +from askui.models.shared.tools import Tool, ToolCollection + +from .models.model_router import ModelRouter, initialize_default_model_registry +from .reporting import NullReporter + + +class CustomAgent: + def __init__(self) -> None: + self._model_router = self._init_model_router() + + def _init_model_router( + self, + ) -> ModelRouter: + reporter = NullReporter() + models = initialize_default_model_registry( + reporter=reporter, + ) + return ModelRouter( + reporter=reporter, + models=models, + ) + + @telemetry.record_call(exclude={"messages", "on_message", "settings", "tools"}) + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def act( + self, + messages: Annotated[list[MessageParam], Field(min_length=1)], + model: str | None = None, + on_message: OnMessageCb | None = None, + tools: list[Tool] | ToolCollection | None = None, + settings: ActSettings | None = None, + ) -> None: + _settings = settings or ActSettings() + _tools = self._build_tools(tools) + self._model_router.act( + messages=messages, + model_choice=model or ModelName.CLAUDE__SONNET__4__20250514, + on_message=on_message, + settings=_settings, + tools=_tools, + ) + + def _build_tools(self, tools: list[Tool] | ToolCollection | None) -> ToolCollection: + if isinstance(tools, list): + return ToolCollection(tools=tools) + if isinstance(tools, ToolCollection): + return tools + return ToolCollection() diff --git a/src/askui/data_extractor.py b/src/askui/data_extractor.py new file mode 100644 index 00000000..adab01d5 --- /dev/null +++ b/src/askui/data_extractor.py @@ -0,0 +1,95 @@ +from pathlib import Path +from typing import Annotated, Type, overload + +from PIL import Image as PILImage +from pydantic import Field + +from askui.models.models import ModelRegistry +from askui.reporting import NULL_REPORTER, Reporter +from askui.utils.image_utils import ImageSource +from askui.utils.source_utils import InputSource, Source, load_source + +from .logger import logger +from .models.model_router import ModelRouter, initialize_default_model_registry +from .models.types.response_schemas import ResponseSchema + + +class DataExtractor: + def __init__( + self, + reporter: Reporter = NULL_REPORTER, + models: ModelRegistry | None = None, + ) -> None: + self._reporter = reporter + self._model_router = self._init_model_router( + reporter=reporter, + models=models or {}, + ) + + def _init_model_router( + self, + reporter: Reporter, + models: ModelRegistry, + ) -> ModelRouter: + _models = initialize_default_model_registry( + reporter=reporter, + ) + _models.update(models) + return ModelRouter( + reporter=reporter, + models=_models, + ) + + @overload + def get( + self, + query: Annotated[str, Field(min_length=1)], + source: InputSource | Source, + model: str, + response_schema: None = None, + ) -> str: ... + @overload + def get( + self, + query: Annotated[str, Field(min_length=1)], + source: InputSource | Source, + model: str, + response_schema: Type[ResponseSchema], + ) -> ResponseSchema: ... + def get( + self, + query: Annotated[str, Field(min_length=1)], + source: InputSource | Source, + model: str, + response_schema: Type[ResponseSchema] | None = None, + ) -> ResponseSchema | str: + logger.debug("Received instruction to get '%s'", query) + _source = ( + load_source(source) + if isinstance(source, (str, Path, PILImage.Image)) + else source + ) + + # Prepare message content with file path if available + user_message_content = f'get: "{query}"' + ( + f" from '{source}'" if isinstance(source, (str, Path)) else "" + ) + + self._reporter.add_message( + "User", + user_message_content, + image=_source.root if isinstance(_source, ImageSource) else None, + ) + response = self._model_router.get( + source=_source, + query=query, + response_schema=response_schema, + model_choice=model, + ) + message_content = ( + str(response) + if isinstance(response, (str, bool, int, float)) + else response.model_dump() + ) + self._reporter.add_message("Agent", message_content) + return response diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index c8976e6b..2d2422c1 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -154,10 +154,12 @@ def __init__( self, tools: list[Tool] | None = None, mcp_client: Client[ClientTransportT] | None = None, + include: set[str] | None = None, ) -> None: _tools = tools or [] self._tool_map = {tool.to_params()["name"]: tool for tool in _tools} self._mcp_client = mcp_client + self._include = include def to_params(self) -> list[BetaToolUnionParam]: tool_map = { @@ -167,7 +169,12 @@ def to_params(self) -> list[BetaToolUnionParam]: for tool_name, tool in self._tool_map.items() }, } - return list(tool_map.values()) + filtered_tool_map = { + tool_name: tool + for tool_name, tool in tool_map.items() + if self._include is None or tool_name in self._include + } + return list(filtered_tool_map.values()) def _get_mcp_tool_params(self) -> dict[str, BetaToolUnionParam]: if not self._mcp_client: diff --git a/src/askui/tools/mcp/servers/stdio.py b/src/askui/tools/mcp/servers/stdio.py index 7785f5be..02d3f02d 100644 --- a/src/askui/tools/mcp/servers/stdio.py +++ b/src/askui/tools/mcp/servers/stdio.py @@ -11,5 +11,10 @@ def test_stdio_tool() -> str: return "I am a test stdio tool" +@mcp.tool +def list_values() -> list[str]: + return ["Optimism", "Creativity", "Intelligence"] + + if __name__ == "__main__": mcp.run(transport="stdio", show_banner=False) diff --git a/tests/integration/chat/api/test_assistants.py b/tests/integration/chat/api/test_assistants.py index f1b1409f..e6e23e66 100644 --- a/tests/integration/chat/api/test_assistants.py +++ b/tests/integration/chat/api/test_assistants.py @@ -185,6 +185,82 @@ def override_assistant_service() -> AssistantService: finally: app.dependency_overrides.clear() + def test_create_assistant_with_tools_and_system( + self, test_headers: dict[str, str] + ) -> None: + """Test creating a new assistant with tools and system prompt.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.post( + "/v1/assistants", + headers=test_headers, + json={ + "name": "Custom Assistant", + "description": "A custom assistant with tools", + "tools": ["tool1", "tool2", "tool3"], + "system": "You are a helpful custom assistant.", + }, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "Custom Assistant" + assert data["description"] == "A custom assistant with tools" + assert data["tools"] == ["tool1", "tool2", "tool3"] + assert data["system"] == "You are a helpful custom assistant." + assert "id" in data + assert "created_at" in data + finally: + app.dependency_overrides.clear() + + def test_create_assistant_with_empty_tools( + self, test_headers: dict[str, str] + ) -> None: + """Test creating a new assistant with empty tools list.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + + # Create a test app with overridden dependencies + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.post( + "/v1/assistants", + headers=test_headers, + json={ + "name": "Empty Tools Assistant", + "tools": [], + }, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "Empty Tools Assistant" + assert data["tools"] == [] + assert "id" in data + assert "created_at" in data + finally: + app.dependency_overrides.clear() + def test_retrieve_assistant(self, test_headers: dict[str, str]) -> None: """Test retrieving an existing assistant.""" temp_dir = tempfile.mkdtemp() @@ -284,6 +360,57 @@ def override_assistant_service() -> AssistantService: finally: app.dependency_overrides.clear() + def test_modify_assistant_with_tools_and_system( + self, test_headers: dict[str, str] + ) -> None: + """Test modifying an assistant with tools and system prompt.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + assistants_dir = workspace_path / "assistants" + assistants_dir.mkdir(parents=True, exist_ok=True) + + mock_assistant = Assistant( + id="asst_test123", + object="assistant", + created_at=1234567890, + name="Original Name", + description="Original description", + ) + (assistants_dir / "asst_test123.json").write_text( + mock_assistant.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + + def override_assistant_service() -> AssistantService: + return AssistantService(workspace_path) + + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + modify_data = { + "name": "Modified Name", + "tools": ["new_tool1", "new_tool2"], + "system": "You are a modified custom assistant.", + } + response = client.post( + "/v1/assistants/asst_test123", + json=modify_data, + headers=test_headers, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Modified Name" + assert data["tools"] == ["new_tool1", "new_tool2"] + assert data["system"] == "You are a modified custom assistant." + assert data["id"] == "asst_test123" + assert data["created_at"] == 1234567890 + finally: + app.dependency_overrides.clear() + def test_modify_assistant_partial(self, test_headers: dict[str, str]) -> None: """Test modifying an assistant with partial data.""" temp_dir = tempfile.mkdtemp() diff --git a/tests/integration/chat/api/test_request_document_translator.py b/tests/integration/chat/api/test_request_document_translator.py new file mode 100644 index 00000000..ba5f1597 --- /dev/null +++ b/tests/integration/chat/api/test_request_document_translator.py @@ -0,0 +1,273 @@ +"""Integration tests for RequestDocumentBlockParamTranslator.""" + +import pathlib +import shutil +import tempfile +from typing import Generator + +import pytest +from PIL import Image + +from askui.chat.api.files.service import FileService +from askui.chat.api.messages.models import RequestDocumentBlockParam +from askui.chat.api.messages.translator import RequestDocumentBlockParamTranslator +from askui.models.shared.agent_message_param import CacheControlEphemeralParam +from askui.utils.excel_utils import OfficeDocumentSource +from askui.utils.image_utils import ImageSource + + +class TestRequestDocumentBlockParamTranslator: + """Integration tests for RequestDocumentBlockParamTranslator with real files.""" + + @pytest.fixture + def temp_dir(self) -> Generator[pathlib.Path, None, None]: + """Create a temporary directory for test files.""" + temp_dir = pathlib.Path(tempfile.mkdtemp()) + yield temp_dir + # Cleanup: remove the temporary directory and all its contents + shutil.rmtree(temp_dir, ignore_errors=True) + + @pytest.fixture + def file_service(self, temp_dir: pathlib.Path) -> FileService: + """Create a FileService instance using the temporary directory.""" + return FileService(temp_dir) + + @pytest.fixture + def translator( + self, file_service: FileService + ) -> RequestDocumentBlockParamTranslator: + """Create a RequestDocumentBlockParamTranslator instance.""" + return RequestDocumentBlockParamTranslator(file_service) + + @pytest.fixture + def cache_control(self) -> CacheControlEphemeralParam: + """Sample cache control parameter.""" + return CacheControlEphemeralParam(type="ephemeral") + + def test_extract_content_from_image( + self, + translator: RequestDocumentBlockParamTranslator, + path_fixtures_github_com__icon: pathlib.Path, + temp_dir: pathlib.Path, + cache_control: CacheControlEphemeralParam, + ) -> None: + """Test extracting content from an image file.""" + # Copy the fixture image to the temporary directory + temp_image_path = temp_dir / "test_icon.png" + shutil.copy2(path_fixtures_github_com__icon, temp_image_path) + + # Create a document block with cache control + document_block = RequestDocumentBlockParam( + source={"file_id": "image123", "type": "file"}, + type="document", + cache_control=cache_control, + ) + + # Load the image source using PIL Image from the temporary file + pil_image = Image.open(temp_image_path) + image_source = ImageSource(pil_image) + + # Extract content + result = translator.extract_content(image_source, document_block) + + # Should return a list with one image block + assert isinstance(result, list) + assert len(result) == 1 + + # First element should be an image block + image_block = result[0] + assert image_block.type == "image" + assert image_block.cache_control == cache_control + + # Check the source is base64 encoded + assert image_block.source.type == "base64" + assert image_block.source.media_type == "image/png" + assert isinstance(image_block.source.data, str) + assert len(image_block.source.data) > 0 + + def test_extract_content_from_excel( + self, + translator: RequestDocumentBlockParamTranslator, + path_fixtures_dummy_excel: pathlib.Path, + temp_dir: pathlib.Path, + cache_control: CacheControlEphemeralParam, + ) -> None: + """Test extracting content from an Excel file.""" + # Copy the fixture Excel file to the temporary directory + temp_excel_path = temp_dir / "test_data.xlsx" + shutil.copy2(path_fixtures_dummy_excel, temp_excel_path) + + # Create a document block with cache control + document_block = RequestDocumentBlockParam( + source={"file_id": "excel123", "type": "file"}, + type="document", + cache_control=cache_control, + ) + + # Load the Excel source from the temporary file + excel_source = OfficeDocumentSource(root=temp_excel_path) + + # Extract content + result = translator.extract_content(excel_source, document_block) + + # Should return a list with one text block + assert isinstance(result, list) + assert len(result) == 1 + + # First element should be a text block + text_block = result[0] + assert text_block.type == "text" + assert text_block.cache_control == cache_control + + # Check the text content + assert isinstance(text_block.text, str) + assert len(text_block.text) > 0 + + def test_extract_content_from_word( + self, + translator: RequestDocumentBlockParamTranslator, + path_fixtures_dummy_doc: pathlib.Path, + temp_dir: pathlib.Path, + cache_control: CacheControlEphemeralParam, + ) -> None: + """Test extracting content from a Word document.""" + # Copy the fixture Word file to the temporary directory + temp_doc_path = temp_dir / "test_document.docx" + shutil.copy2(path_fixtures_dummy_doc, temp_doc_path) + + # Create a document block with cache control + document_block = RequestDocumentBlockParam( + source={"file_id": "word123", "type": "file"}, + type="document", + cache_control=cache_control, + ) + + # Load the Word source from the temporary file + word_source = OfficeDocumentSource(root=temp_doc_path) + + # Extract content + result = translator.extract_content(word_source, document_block) + + # Should return a list with one text block + assert isinstance(result, list) + assert len(result) == 1 + + # First element should be a text block + text_block = result[0] + assert text_block.type == "text" + assert text_block.cache_control == cache_control + + # Check the text content + assert isinstance(text_block.text, str) + assert len(text_block.text) > 0 + + def test_extract_content_from_image_no_cache_control( + self, + translator: RequestDocumentBlockParamTranslator, + path_fixtures_github_com__icon: pathlib.Path, + temp_dir: pathlib.Path, + ) -> None: + """Test extracting content from an image file without cache control.""" + # Copy the fixture image to the temporary directory + temp_image_path = temp_dir / "test_icon_no_cache.png" + shutil.copy2(path_fixtures_github_com__icon, temp_image_path) + + # Create a document block without cache control + document_block = RequestDocumentBlockParam( + source={"file_id": "image123", "type": "file"}, + type="document", + ) + + # Load the image source using PIL Image from the temporary file + pil_image = Image.open(temp_image_path) + image_source = ImageSource(pil_image) + + # Extract content + result = translator.extract_content(image_source, document_block) + + # Should return a list with one image block + assert isinstance(result, list) + assert len(result) == 1 + + # First element should be an image block + image_block = result[0] + assert image_block.type == "image" + assert image_block.cache_control is None + + # Check the source is base64 encoded + assert image_block.source.type == "base64" + assert image_block.source.media_type == "image/png" + assert isinstance(image_block.source.data, str) + assert len(image_block.source.data) > 0 + + def test_extract_content_from_excel_no_cache_control( + self, + translator: RequestDocumentBlockParamTranslator, + path_fixtures_dummy_excel: pathlib.Path, + temp_dir: pathlib.Path, + ) -> None: + """Test extracting content from an Excel file without cache control.""" + # Copy the fixture Excel file to the temporary directory + temp_excel_path = temp_dir / "test_data_no_cache.xlsx" + shutil.copy2(path_fixtures_dummy_excel, temp_excel_path) + + # Create a document block without cache control + document_block = RequestDocumentBlockParam( + source={"file_id": "excel123", "type": "file"}, + type="document", + ) + + # Load the Excel source from the temporary file + excel_source = OfficeDocumentSource(root=temp_excel_path) + + # Extract content + result = translator.extract_content(excel_source, document_block) + + # Should return a list with one text block + assert isinstance(result, list) + assert len(result) == 1 + + # First element should be a text block + text_block = result[0] + assert text_block.type == "text" + assert text_block.cache_control is None + + # Check the text content + assert isinstance(text_block.text, str) + assert len(text_block.text) > 0 + + def test_extract_content_from_word_no_cache_control( + self, + translator: RequestDocumentBlockParamTranslator, + path_fixtures_dummy_doc: pathlib.Path, + temp_dir: pathlib.Path, + ) -> None: + """Test extracting content from a Word document without cache control.""" + # Copy the fixture Word file to the temporary directory + temp_doc_path = temp_dir / "test_document_no_cache.docx" + shutil.copy2(path_fixtures_dummy_doc, temp_doc_path) + + # Create a document block without cache control + document_block = RequestDocumentBlockParam( + source={"file_id": "word123", "type": "file"}, + type="document", + ) + + # Load the Word source from the temporary file + word_source = OfficeDocumentSource(root=temp_doc_path) + + # Extract content + result = translator.extract_content(word_source, document_block) + + # Should return a list with one text block + assert isinstance(result, list) + assert len(result) == 1 + + # First element should be a text block + text_block = result[0] + assert text_block.type == "text" + assert text_block.cache_control is None + + # Check the text content + assert isinstance(text_block.text, str) + assert len(text_block.text) > 0 diff --git a/tests/integration/chat/api/test_runs.py b/tests/integration/chat/api/test_runs.py index 6c31d4c4..95304abe 100644 --- a/tests/integration/chat/api/test_runs.py +++ b/tests/integration/chat/api/test_runs.py @@ -7,6 +7,7 @@ from fastapi import status from fastapi.testclient import TestClient +from askui.chat.api.assistants.service import AssistantService from askui.chat.api.runs.models import Run from askui.chat.api.runs.service import RunService from askui.chat.api.threads.models import Thread @@ -44,10 +45,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -110,10 +115,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -177,10 +186,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -227,10 +240,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -287,10 +304,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -341,10 +362,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -384,10 +409,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -439,10 +468,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -485,10 +518,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -534,10 +571,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -596,10 +637,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -640,10 +685,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -721,10 +770,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -801,10 +854,14 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: + mock_assistant_service = Mock() mock_message_service = Mock() mock_message_translator = Mock() return RunService( - workspace_path, mock_message_service, mock_message_translator + workspace_path, + mock_assistant_service, + mock_message_service, + mock_message_translator, ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -836,3 +893,175 @@ def test_cancel_run_not_found( ) assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_create_run_with_custom_assistant( + self, test_headers: dict[str, str] + ) -> None: + """Test creating a run with a custom assistant.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + assistants_dir = workspace_path / "assistants" + assistants_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create a mock custom assistant + from askui.chat.api.assistants.models import Assistant + + mock_assistant = Assistant( + id="asst_custom123", + object="assistant", + created_at=1234567890, + name="Custom Assistant", + tools=["tool1", "tool2"], + system="You are a custom assistant.", + ) + (assistants_dir / "asst_custom123.json").write_text( + mock_assistant.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + from askui.chat.api.assistants.service import AssistantService + + return RunService( + workspace_path, + AssistantService(workspace_path), + mock_message_service, + mock_message_translator, + ) + + def override_assistant_service() -> AssistantService: + from askui.chat.api.assistants.service import AssistantService + + return AssistantService(workspace_path) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.post( + "/v1/threads/thread_test123/runs", + headers=test_headers, + json={"assistant_id": "asst_custom123"}, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["assistant_id"] == "asst_custom123" + assert data["thread_id"] == "thread_test123" + assert data["status"] == "queued" + assert "id" in data + assert "created_at" in data + finally: + app.dependency_overrides.clear() + + def test_create_run_with_custom_assistant_empty_tools( + self, test_headers: dict[str, str] + ) -> None: + """Test creating a run with a custom assistant that has empty tools.""" + temp_dir = tempfile.mkdtemp() + workspace_path = Path(temp_dir) + threads_dir = workspace_path / "threads" + threads_dir.mkdir(parents=True, exist_ok=True) + assistants_dir = workspace_path / "assistants" + assistants_dir.mkdir(parents=True, exist_ok=True) + + # Create a mock thread + mock_thread = Thread( + id="thread_test123", + object="thread", + created_at=1234567890, + name="Test Thread", + ) + (threads_dir / "thread_test123.json").write_text(mock_thread.model_dump_json()) + + # Create a mock custom assistant with empty tools + from askui.chat.api.assistants.models import Assistant + + mock_assistant = Assistant( + id="asst_customempty123", + object="assistant", + created_at=1234567890, + name="Empty Tools Assistant", + tools=[], + system="You are an assistant with no tools.", + ) + (assistants_dir / "asst_customempty123.json").write_text( + mock_assistant.model_dump_json() + ) + + from askui.chat.api.app import app + from askui.chat.api.assistants.dependencies import get_assistant_service + from askui.chat.api.runs.dependencies import get_runs_service + from askui.chat.api.threads.dependencies import get_thread_service + + def override_thread_service() -> ThreadService: + from askui.chat.api.threads.service import ThreadService + + mock_message_service = Mock() + mock_run_service = Mock() + return ThreadService(workspace_path, mock_message_service, mock_run_service) + + def override_runs_service() -> RunService: + mock_message_service = Mock() + mock_message_translator = Mock() + from askui.chat.api.assistants.service import AssistantService + + return RunService( + workspace_path, + AssistantService(workspace_path), + mock_message_service, + mock_message_translator, + ) + + def override_assistant_service() -> AssistantService: + from askui.chat.api.assistants.service import AssistantService + + return AssistantService(workspace_path) + + app.dependency_overrides[get_thread_service] = override_thread_service + app.dependency_overrides[get_runs_service] = override_runs_service + app.dependency_overrides[get_assistant_service] = override_assistant_service + + try: + with TestClient(app) as client: + response = client.post( + "/v1/threads/thread_test123/runs", + headers=test_headers, + json={"assistant_id": "asst_customempty123"}, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["assistant_id"] == "asst_customempty123" + assert data["thread_id"] == "thread_test123" + assert data["status"] == "queued" + assert "id" in data + assert "created_at" in data + finally: + app.dependency_overrides.clear() diff --git a/tests/unit/test_request_document_translator.py b/tests/unit/test_request_document_translator.py new file mode 100644 index 00000000..93130a89 --- /dev/null +++ b/tests/unit/test_request_document_translator.py @@ -0,0 +1,134 @@ +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +import pytest_mock + +from askui.chat.api.messages.models import RequestDocumentBlockParam +from askui.chat.api.messages.translator import RequestDocumentBlockParamTranslator +from askui.models.shared.agent_message_param import ( + CacheControlEphemeralParam, + TextBlockParam, +) +from askui.utils.pdf_utils import PdfSource + + +class TestRequestDocumentBlockParamTranslator: + """Test cases for RequestDocumentBlockParamTranslator.""" + + @pytest.fixture + def file_service(self) -> MagicMock: + """Mock file service.""" + return MagicMock() + + @pytest.fixture + def translator( + self, file_service: MagicMock + ) -> RequestDocumentBlockParamTranslator: + """Create translator instance.""" + return RequestDocumentBlockParamTranslator(file_service) + + @pytest.fixture + def cache_control(self) -> CacheControlEphemeralParam: + """Sample cache control parameter.""" + return CacheControlEphemeralParam(type="ephemeral") + + def test_init(self, file_service: MagicMock) -> None: + """Test translator initialization.""" + translator = RequestDocumentBlockParamTranslator(file_service) + assert translator._file_service == file_service + + @pytest.mark.asyncio + async def test_to_anthropic_success( + self, + translator: RequestDocumentBlockParamTranslator, + cache_control: CacheControlEphemeralParam, + mocker: pytest_mock.MockerFixture, + ) -> None: + """Test successful conversion to Anthropic format.""" + document_block = RequestDocumentBlockParam( + source={"file_id": "xyz789", "type": "file"}, + type="document", + cache_control=cache_control, + ) + + # Mock the file service response + mock_file = MagicMock() + mock_file.model_dump_json.return_value = '{"id": "xyz789", "name": "test.pdf"}' + mock_path = Path("/tmp/test.pdf") + mocker.patch.object( + translator._file_service, + "retrieve_file_content", + return_value=(mock_file, mock_path), + ) + + # Mock the load_source function to avoid filesystem access + mock_pdf_source = PdfSource(root=mock_path) + mocker.patch( + "askui.chat.api.messages.translator.load_source", + return_value=mock_pdf_source, + ) + + # Mock the extract_content method to return a simple text block + mock_text_block = TextBlockParam( + text="Extracted text content", type="text", cache_control=cache_control + ) + mocker.patch.object( + translator, "extract_content", return_value=[mock_text_block] + ) + + result = await translator.to_anthropic(document_block) + + assert isinstance(result, list) + assert len(result) == 2 # file info + extracted content + # First element should be the file info as TextBlockParam + assert isinstance(result[0], TextBlockParam) + assert result[0].type == "text" + assert result[0].cache_control == cache_control + # Second element should be the extracted content + assert result[1] == mock_text_block + + @pytest.mark.asyncio + async def test_to_anthropic_no_cache_control( + self, + translator: RequestDocumentBlockParamTranslator, + mocker: pytest_mock.MockerFixture, + ) -> None: + """Test conversion without cache control.""" + document_block = RequestDocumentBlockParam( + source={"file_id": "def456", "type": "file"}, + type="document", + ) + + # Mock the file service response + mock_file = MagicMock() + mock_file.model_dump_json.return_value = '{"id": "def456", "name": "test.pdf"}' + mock_path = Path("/tmp/test.pdf") + mocker.patch.object( + translator._file_service, + "retrieve_file_content", + return_value=(mock_file, mock_path), + ) + + # Mock the load_source function to avoid filesystem access + mock_pdf_source = PdfSource(root=mock_path) + mocker.patch( + "askui.chat.api.messages.translator.load_source", + return_value=mock_pdf_source, + ) + + # Mock the extract_content method to return a simple text block + mock_text_block = TextBlockParam(text="Extracted text content", type="text") + mocker.patch.object( + translator, "extract_content", return_value=[mock_text_block] + ) + + result = await translator.to_anthropic(document_block) + + assert isinstance(result, list) + assert len(result) == 2 # file info + extracted content + # First element should be the file info as TextBlockParam + assert isinstance(result[0], TextBlockParam) + assert result[0].cache_control is None + # Second element should be the extracted content + assert result[1] == mock_text_block