diff --git a/src/askui/chat/api/assistants/service.py b/src/askui/chat/api/assistants/service.py index cfb1d667..9c802d43 100644 --- a/src/askui/chat/api/assistants/service.py +++ b/src/askui/chat/api/assistants/service.py @@ -50,7 +50,11 @@ def retrieve( ) -> Assistant: try: assistant_path = self._get_assistant_path(assistant_id) - assistant = Assistant.model_validate_json(assistant_path.read_text()) + content = assistant_path.read_text() + if not content.strip(): + error_msg = f"Assistant {assistant_id} not found" + raise NotFoundError(error_msg) + assistant = Assistant.model_validate_json(content) if not ( assistant.workspace_id is None or assistant.workspace_id == workspace_id ): @@ -59,6 +63,10 @@ def retrieve( except FileNotFoundError as e: error_msg = f"Assistant {assistant_id} not found" raise NotFoundError(error_msg) from e + except (ValueError, TypeError) as e: + # Handle JSON parsing errors + error_msg = f"Assistant {assistant_id} not found" + raise NotFoundError(error_msg) from e else: return assistant @@ -76,6 +84,9 @@ def modify( params: AssistantModifyParams, ) -> Assistant: assistant = self.retrieve(workspace_id, assistant_id) + if assistant.workspace_id is None: + error_msg = f"Default assistant {assistant_id} cannot be modified" + raise ForbiddenError(error_msg) modified = assistant.modify(params) self._save(modified) return modified @@ -91,10 +102,19 @@ def delete( if assistant.workspace_id is None and not force: error_msg = f"Default assistant {assistant_id} cannot be deleted" raise ForbiddenError(error_msg) - self._get_assistant_path(assistant_id).unlink() + try: + self._get_assistant_path(assistant_id).unlink() + except FileNotFoundError: + # File already deleted, that's fine + pass except FileNotFoundError as e: error_msg = f"Assistant {assistant_id} not found" raise NotFoundError(error_msg) from e + except NotFoundError: + # If force=True and assistant doesn't exist, just ignore + if not force: + raise + # For force=True, we can ignore the 
NotFoundError def _save(self, assistant: Assistant, new: bool = False) -> None: self._assistants_dir.mkdir(parents=True, exist_ok=True) diff --git a/src/askui/chat/api/mcp_configs/service.py b/src/askui/chat/api/mcp_configs/service.py index 81513779..b93fde86 100644 --- a/src/askui/chat/api/mcp_configs/service.py +++ b/src/askui/chat/api/mcp_configs/service.py @@ -109,6 +109,9 @@ def modify( params: McpConfigModifyParams, ) -> McpConfig: mcp_config = self.retrieve(workspace_id, mcp_config_id) + if mcp_config.workspace_id is None: + error_msg = f"Default MCP configuration {mcp_config_id} cannot be modified" + raise ForbiddenError(error_msg) modified = mcp_config.modify(params) self._save(modified) return modified @@ -129,7 +132,11 @@ def delete( self._get_mcp_config_path(mcp_config_id).unlink() except FileNotFoundError as e: error_msg = f"MCP configuration {mcp_config_id} not found" - raise NotFoundError(error_msg) from e + if not force: + raise NotFoundError(error_msg) from e + except NotFoundError: + if not force: + raise def _save(self, mcp_config: McpConfig, new: bool = False) -> None: self._mcp_configs_dir.mkdir(parents=True, exist_ok=True) diff --git a/src/askui/chat/api/messages/chat_history_manager.py b/src/askui/chat/api/messages/chat_history_manager.py new file mode 100644 index 00000000..9642257c --- /dev/null +++ b/src/askui/chat/api/messages/chat_history_manager.py @@ -0,0 +1,91 @@ +from anthropic.types.beta import BetaTextBlockParam, BetaToolUnionParam + +from askui.chat.api.messages.models import Message, MessageCreateParams +from askui.chat.api.messages.service import MessageService +from askui.chat.api.messages.translator import MessageTranslator +from askui.chat.api.models import ThreadId +from askui.models.shared.agent_message_param import MessageParam +from askui.models.shared.truncation_strategies import TruncationStrategyFactory + + +class ChatHistoryManager: + """ + Manages chat history by providing methods to retrieve and add messages. 
+ + This service encapsulates the interaction between MessageService and MessageTranslator + to provide a clean interface for managing chat history in the context of AI agents. + """ + + def __init__( + self, + message_service: MessageService, + message_translator: MessageTranslator, + truncation_strategy_factory: TruncationStrategyFactory, + ) -> None: + """ + Initialize the chat history manager. + + Args: + message_service (MessageService): Service for managing message persistence. + message_translator (MessageTranslator): Translator for converting between + message formats. + truncation_strategy_factory (TruncationStrategyFactory): Factory for creating truncation strategies. + """ + self._message_service = message_service + self._message_translator = message_translator + self._message_content_translator = message_translator.content_translator + self._truncation_strategy_factory = truncation_strategy_factory + + async def retrieve_message_params( + self, + thread_id: ThreadId, + model: str, + system: str | list[BetaTextBlockParam] | None, + tools: list[BetaToolUnionParam], + ) -> list[MessageParam]: + truncation_strategy = ( + self._truncation_strategy_factory.create_truncation_strategy( + system=system, + tools=tools, + messages=[], + model=model, + ) + ) + for msg in self._message_service.iter(thread_id=thread_id): + anthropic_message = await self._message_translator.to_anthropic(msg) + truncation_strategy.append_message(anthropic_message) + return truncation_strategy.messages + + async def append_message( + self, + thread_id: ThreadId, + assistant_id: str | None, + run_id: str, + message: MessageParam, + ) -> Message: + """ + Add a message to the chat history and return both the created message and original message param. + + This method creates a message in the database and returns both the created + message object and the original message parameter for further processing. + + Args: + thread_id (ThreadId): The thread ID to add the message to. 
+ assistant_id (str | None): The assistant ID if the message is from an assistant. + run_id (str): The run ID associated with this message. + message (MessageParam): The message to add. + + Returns: + Message: The created message object + """ + return self._message_service.create( + thread_id=thread_id, + params=MessageCreateParams( + assistant_id=assistant_id if message.role == "assistant" else None, + role=message.role, + content=await self._message_content_translator.from_anthropic( + message.content + ), + run_id=run_id, + ), + ) diff --git a/src/askui/chat/api/messages/dependencies.py b/src/askui/chat/api/messages/dependencies.py index 62d36038..e22ea940 100644 --- a/src/askui/chat/api/messages/dependencies.py +++ b/src/askui/chat/api/messages/dependencies.py @@ -5,8 +5,13 @@ from askui.chat.api.dependencies import WorkspaceDirDep from askui.chat.api.files.dependencies import FileServiceDep from askui.chat.api.files.service import FileService +from askui.chat.api.messages.chat_history_manager import ChatHistoryManager from askui.chat.api.messages.service import MessageService from askui.chat.api.messages.translator import MessageTranslator +from askui.models.shared.truncation_strategies import ( + SimpleTruncationStrategyFactory, + TruncationStrategyFactory, +) def get_message_service( @@ -26,3 +31,25 @@ def get_message_translator( MessageTranslatorDep = Depends(get_message_translator) + + +def get_truncation_strategy_factory() -> TruncationStrategyFactory: + return SimpleTruncationStrategyFactory() + + +TruncationStrategyFactoryDep = Depends(get_truncation_strategy_factory) + + +def get_chat_history_manager( + message_service: MessageService = MessageServiceDep, + message_translator: MessageTranslator = MessageTranslatorDep, + truncation_strategy_factory: TruncationStrategyFactory = TruncationStrategyFactoryDep, +) -> ChatHistoryManager: + return ChatHistoryManager( + message_service=message_service, + message_translator=message_translator, + 
truncation_strategy_factory=truncation_strategy_factory, + ) + + +ChatHistoryManagerDep = Depends(get_chat_history_manager) diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index 7821cc7b..1d2a4781 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -1,9 +1,12 @@ from pathlib import Path +from typing import Iterator from askui.chat.api.messages.models import Message, MessageCreateParams from askui.chat.api.models import MessageId, ThreadId from askui.utils.api_utils import ( + LIST_LIMIT_DEFAULT, ConflictError, + ListOrder, ListQuery, ListResponse, NotFoundError, @@ -40,6 +43,24 @@ def list_(self, thread_id: ThreadId, query: ListQuery) -> ListResponse[Message]: messages_dir = self.get_messages_dir(thread_id) return list_resources(messages_dir, query, Message) + def iter( + self, + thread_id: ThreadId, + order: ListOrder = "asc", + batch_size: int = LIST_LIMIT_DEFAULT, + ) -> Iterator[Message]: + has_more = True + last_id: str | None = None + while has_more: + list_messages_response = self.list_( + thread_id=thread_id, + query=ListQuery(limit=batch_size, order=order, after=last_id), + ) + has_more = list_messages_response.has_more + last_id = list_messages_response.last_id + for msg in list_messages_response.data: + yield msg + def retrieve(self, thread_id: ThreadId, message_id: MessageId) -> Message: try: message_file = self._get_message_path(thread_id, message_id) diff --git a/src/askui/chat/api/runs/dependencies.py b/src/askui/chat/api/runs/dependencies.py index fca6d6bc..ca8b1075 100644 --- a/src/askui/chat/api/runs/dependencies.py +++ b/src/askui/chat/api/runs/dependencies.py @@ -7,9 +7,8 @@ from askui.chat.api.dependencies import WorkspaceDirDep from askui.chat.api.mcp_clients.dependencies import McpClientManagerManagerDep from askui.chat.api.mcp_clients.manager import McpClientManagerManager -from askui.chat.api.messages.dependencies import MessageServiceDep, 
MessageTranslatorDep -from askui.chat.api.messages.service import MessageService -from askui.chat.api.messages.translator import MessageTranslator +from askui.chat.api.messages.chat_history_manager import ChatHistoryManager +from askui.chat.api.messages.dependencies import ChatHistoryManagerDep from .service import RunService @@ -17,17 +16,14 @@ def get_runs_service( workspace_dir: Path = WorkspaceDirDep, assistant_service: AssistantService = AssistantServiceDep, + chat_history_manager: ChatHistoryManager = ChatHistoryManagerDep, mcp_client_manager_manager: McpClientManagerManager = McpClientManagerManagerDep, - message_service: MessageService = MessageServiceDep, - message_translator: MessageTranslator = MessageTranslatorDep, ) -> RunService: - """Get RunService instance.""" return RunService( base_dir=workspace_dir, assistant_service=assistant_service, mcp_client_manager_manager=mcp_client_manager_manager, - message_service=message_service, - message_translator=message_translator, + chat_history_manager=chat_history_manager, ) diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 8fd001d3..1acd6684 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -2,15 +2,14 @@ import logging from abc import ABC, abstractmethod +from anthropic.types.beta import BetaCacheControlEphemeralParam, BetaTextBlockParam from anyio.abc import ObjectStream from asyncer import asyncify, syncify from askui.chat.api.assistants.models import Assistant from askui.chat.api.assistants.seeds import ANDROID_AGENT from askui.chat.api.mcp_clients.manager import McpClientManagerManager -from askui.chat.api.messages.models import MessageCreateParams -from askui.chat.api.messages.service import MessageService -from askui.chat.api.messages.translator import MessageTranslator +from askui.chat.api.messages.chat_history_manager import ChatHistoryManager from askui.chat.api.models import RunId, ThreadId, 
WorkspaceId from askui.chat.api.runs.models import Run, RunError from askui.chat.api.runs.runner.events.done_events import DoneEvent @@ -28,7 +27,6 @@ from askui.models.shared.agent_on_message_cb import OnMessageCbParam from askui.models.shared.settings import ActSettings, MessageSettings from askui.models.shared.tools import Tool, ToolCollection -from askui.utils.api_utils import LIST_LIMIT_MAX, ListQuery logger = logging.getLogger(__name__) @@ -77,17 +75,14 @@ def __init__( workspace_id: WorkspaceId, assistant: Assistant, run: Run, - message_service: MessageService, - message_translator: MessageTranslator, + chat_history_manager: ChatHistoryManager, mcp_client_manager_manager: McpClientManagerManager, run_service: RunnerRunService, ) -> None: self._workspace_id = workspace_id self._assistant = assistant self._run = run - self._message_service = message_service - self._message_translator = message_translator - self._message_content_translator = message_translator.content_translator + self._chat_history_manager = chat_history_manager self._mcp_client_manager_manager = mcp_client_manager_manager self._run_service = run_service @@ -97,47 +92,53 @@ def _retrieve(self) -> Run: run_id=self._run.id, ) - def _build_system(self) -> str: - base_system = self._assistant.system or "" + def _build_system(self) -> list[BetaTextBlockParam]: metadata = { "run_id": str(self._run.id), "thread_id": str(self._run.thread_id), "workspace_id": str(self._workspace_id), "assistant_id": str(self._run.assistant_id), } - return f"{base_system}\n\nMetadata of current conversation: {json.dumps(metadata)}".strip() + return [ + *( + [ + BetaTextBlockParam( + type="text", + text=self._assistant.system, + ) + ] + if self._assistant.system + else [] + ), + BetaTextBlockParam( + type="text", + text="Metadata of current conversation: ", + ), + BetaTextBlockParam( + type="text", + text=json.dumps(metadata), + cache_control=BetaCacheControlEphemeralParam( + type="ephemeral", + ), + ), + ] async def 
_run_agent( self, send_stream: ObjectStream[Events], ) -> None: - messages: list[MessageParam] = [ - await self._message_translator.to_anthropic(msg) - for msg in self._message_service.list_( - thread_id=self._run.thread_id, - query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"), - ).data - ] - async def async_on_message( on_message_cb_param: OnMessageCbParam, ) -> MessageParam | None: - message = self._message_service.create( + created_message = await self._chat_history_manager.append_message( thread_id=self._run.thread_id, - params=MessageCreateParams( - assistant_id=self._run.assistant_id - if on_message_cb_param.message.role == "assistant" - else None, - role=on_message_cb_param.message.role, - content=await self._message_content_translator.from_anthropic( - on_message_cb_param.message.content - ), - run_id=self._run.id, - ), + assistant_id=self._run.assistant_id, + run_id=self._run.id, + message=on_message_cb_param.message, ) await send_stream.send( MessageEvent( - data=message, + data=created_message, event="thread.message.created", ) ) @@ -149,7 +150,6 @@ async def async_on_message( return on_message_cb_param.message on_message = syncify(async_on_message) - mcp_client = await self._mcp_client_manager_manager.get_mcp_client_manager( self._workspace_id ) @@ -159,10 +159,18 @@ def _run_agent_inner() -> None: mcp_client=mcp_client, include=set(self._assistant.tools), ) + betas = tools.retrieve_tool_beta_flags() # Remove this after having extracted tools into Android MCP if self._run.assistant_id == ANDROID_AGENT.id: tools.append_tool(*_get_android_tools()) - betas = tools.retrieve_tool_beta_flags() + system = self._build_system() + model = str(ModelName.CLAUDE__SONNET__4__20250514) + messages = syncify(self._chat_history_manager.retrieve_message_params)( + thread_id=self._run.thread_id, + tools=tools.to_params(), + system=system, + model=model, + ) custom_agent = CustomAgent() custom_agent.act( messages, @@ -172,9 +180,10 @@ def _run_agent_inner() -> None: 
settings=ActSettings( messages=MessageSettings( betas=betas, - model=ModelName.CLAUDE__SONNET__4__20250514, - system=self._build_system(), - thinking={"type": "enabled", "budget_tokens": 2048}, + model=model, + system=system, + thinking={"type": "enabled", "budget_tokens": 4096}, + max_tokens=8192, ), ), ) diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 98c72afd..58488eb4 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -7,8 +7,7 @@ from askui.chat.api.assistants.service import AssistantService from askui.chat.api.mcp_clients.manager import McpClientManagerManager -from askui.chat.api.messages.service import MessageService -from askui.chat.api.messages.translator import MessageTranslator +from askui.chat.api.messages.chat_history_manager import ChatHistoryManager from askui.chat.api.models import RunId, ThreadId, WorkspaceId from askui.chat.api.runs.models import Run, RunCreateParams from askui.chat.api.runs.runner.events.events import ( @@ -35,14 +34,12 @@ def __init__( base_dir: Path, assistant_service: AssistantService, mcp_client_manager_manager: McpClientManagerManager, - message_service: MessageService, - message_translator: MessageTranslator, + chat_history_manager: ChatHistoryManager, ) -> None: self._base_dir = base_dir self._assistant_service = assistant_service self._mcp_client_manager_manager = mcp_client_manager_manager - self._message_service = message_service - self._message_translator = message_translator + self._chat_history_manager = chat_history_manager def get_runs_dir(self, thread_id: ThreadId) -> Path: return self._base_dir / "runs" / thread_id @@ -80,8 +77,7 @@ async def create( workspace_id=workspace_id, assistant=assistant, run=run, - message_service=self._message_service, - message_translator=self._message_translator, + chat_history_manager=self._chat_history_manager, mcp_client_manager_manager=self._mcp_client_manager_manager, run_service=self, ) diff 
--git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index 6a5a8e0e..1b953556 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -2,11 +2,7 @@ from askui.models.exceptions import MaxTokensExceededError, ModelRefusalError from askui.models.models import ActModel -from askui.models.shared.agent_message_param import ( - ImageBlockParam, - MessageParam, - TextBlockParam, -) +from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import ( NULL_ON_MESSAGE_CB, OnMessageCb, @@ -15,6 +11,11 @@ from askui.models.shared.messages_api import MessagesApi from askui.models.shared.settings import ActSettings from askui.models.shared.tools import ToolCollection +from askui.models.shared.truncation_strategies import ( + SimpleTruncationStrategyFactory, + TruncationStrategy, + TruncationStrategyFactory, +) from askui.reporting import NULL_REPORTER, Reporter from ...logger import logger @@ -30,23 +31,31 @@ class Agent(ActModel): messages_api (MessagesApi): Messages API for creating messages. reporter (Reporter, optional): The reporter for logging messages and actions. Defaults to `NULL_REPORTER`. + truncation_strategy (TruncationStrategyFactory, optional): The truncation + strategy factory to use. This is used to create the truncation strategy + to truncate the message history before sending it to the model. + Defaults to `SimpleTruncationStrategyFactory`. 
""" def __init__( self, messages_api: MessagesApi, reporter: Reporter = NULL_REPORTER, + truncation_strategy_factory: TruncationStrategyFactory | None = None, ) -> None: self._messages_api = messages_api self._reporter = reporter + self._truncation_strategy_factory = ( + truncation_strategy_factory or SimpleTruncationStrategyFactory() + ) def _step( self, - messages: list[MessageParam], model: str, on_message: OnMessageCb, settings: ActSettings, tool_collection: ToolCollection, + truncation_strategy: TruncationStrategy, ) -> None: """Execute a single step in the conversation. @@ -55,26 +64,19 @@ def _step( upon. Args: - messages (list[MessageParam]): The message history. - Contains at least one message. model (str): The model to use for message creation. on_message (OnMessageCb): Callback on new messages settings (AgentSettings): The settings for the step. tool_collection (ToolCollection): The tools to use for the step. + truncation_strategy (TruncationStrategy): The truncation strategy to use + for the step. 
Returns: None """ - if settings.only_n_most_recent_images: - messages = self._maybe_filter_to_n_most_recent_images( - messages, - settings.only_n_most_recent_images, - settings.image_truncation_threshold, - ) - - if messages[-1].role == "user": + if truncation_strategy.messages[-1].role == "user": response_message = self._messages_api.create_message( - messages=messages, + messages=truncation_strategy.messages, model=model, tools=tool_collection, max_tokens=settings.messages.max_tokens, @@ -84,35 +86,34 @@ def _step( tool_choice=settings.messages.tool_choice, ) message_by_assistant = self._call_on_message( - on_message, response_message, messages + on_message, response_message, truncation_strategy.messages ) if message_by_assistant is None: return message_by_assistant_dict = message_by_assistant.model_dump(mode="json") logger.debug(message_by_assistant_dict) - messages.append(message_by_assistant) + truncation_strategy.append_message(message_by_assistant) self._reporter.add_message( self.__class__.__name__, message_by_assistant_dict ) else: - message_by_assistant = messages[-1] - + message_by_assistant = truncation_strategy.messages[-1] self._handle_stop_reason(message_by_assistant, settings.messages.max_tokens) if tool_result_message := self._use_tools( message_by_assistant, tool_collection ): if tool_result_message := self._call_on_message( - on_message, tool_result_message, messages + on_message, tool_result_message, truncation_strategy.messages ): tool_result_message_dict = tool_result_message.model_dump(mode="json") logger.debug(tool_result_message_dict) - messages.append(tool_result_message) + truncation_strategy.append_message(tool_result_message) self._step( - messages=messages, model=model, tool_collection=tool_collection, on_message=on_message, settings=settings, + truncation_strategy=truncation_strategy, ) def _call_on_message( @@ -135,12 +136,22 @@ def act( settings: ActSettings | None = None, ) -> None: _settings = settings or ActSettings() + _model = 
_settings.messages.model or model_choice + _tool_collection = tools or ToolCollection() + truncation_strategy = ( + self._truncation_strategy_factory.create_truncation_strategy( + tools=_tool_collection.to_params(), + system=_settings.messages.system or None, + messages=messages, + model=_model, + ) + ) self._step( - messages=messages, - model=_settings.messages.model or model_choice, + model=_model, on_message=on_message or NULL_ON_MESSAGE_CB, settings=_settings, - tool_collection=tools or ToolCollection(), + tool_collection=_tool_collection, + truncation_strategy=truncation_strategy, ) def _use_tools( @@ -174,62 +185,6 @@ def _use_tools( role="user", ) - @staticmethod - def _maybe_filter_to_n_most_recent_images( - messages: list[MessageParam], - images_to_keep: int | None, - min_removal_threshold: int, - ) -> list[MessageParam]: - """ - Filter the message history in-place to keep only the most recent images, - according to the given chunking policy. - - Args: - messages (list[MessageParam]): The message history. - images_to_keep (int | None): Number of most recent images to keep. - min_removal_threshold (int): Minimum number of images to remove at once. - - Returns: - list[MessageParam]: The filtered message history. 
- """ - if images_to_keep is None: - return messages - - tool_result_blocks = [ - item - for message in messages - for item in (message.content if isinstance(message.content, list) else []) - if item.type == "tool_result" - ] - total_images = sum( - 1 - for tool_result in tool_result_blocks - if not isinstance(tool_result.content, str) - for content in tool_result.content - if content.type == "image" - ) - images_to_remove = total_images - images_to_keep - if images_to_remove < min_removal_threshold: - return messages - # for better cache behavior, we want to remove in chunks - images_to_remove -= images_to_remove % min_removal_threshold - if images_to_remove <= 0: - return messages - - # Remove images from the oldest tool_result blocks first - for tool_result in tool_result_blocks: - if images_to_remove <= 0: - break - if isinstance(tool_result.content, list): - new_content: list[TextBlockParam | ImageBlockParam] = [] - for content in tool_result.content: - if content.type == "image" and images_to_remove > 0: - images_to_remove -= 1 - continue - new_content.append(content) - tool_result.content = new_content - return messages - def _handle_stop_reason(self, message: MessageParam, max_tokens: int) -> None: if message.stop_reason == "max_tokens": raise MaxTokensExceededError(max_tokens) diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index 41b4bc8a..d6a31026 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -32,5 +32,3 @@ class ActSettings(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) messages: MessageSettings = Field(default_factory=MessageSettings) - only_n_most_recent_images: int = 3 - image_truncation_threshold: int = 10 diff --git a/src/askui/models/shared/token_counter.py b/src/askui/models/shared/token_counter.py new file mode 100644 index 00000000..592e9af9 --- /dev/null +++ b/src/askui/models/shared/token_counter.py @@ -0,0 +1,262 @@ +import base64 
+import json +from abc import ABC, abstractmethod + +import httpx +from anthropic.types.beta import BetaTextBlockParam, BetaToolUnionParam +from typing_extensions import override + +from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, + ContentBlockParam, + ImageBlockParam, + MessageParam, + ToolResultBlockParam, +) +from askui.utils.image_utils import base64_to_image + + +class TokenCounts: + """Token counts for a message.""" + + def __init__( + self, + system: int, + tools: int, + messages: list[int], + ) -> None: + self._system = system + self._tools = tools + self._messages = messages + self._static = system + tools + self._total = self._static + sum(messages) + + def append_message_tokens(self, tokens: int) -> None: + self._messages.append(tokens) + self._total += tokens + + def retrieve_message_tokens(self, index: int) -> int: + return self._messages[index] + + def reset_message_tokens(self, tokens: list[int]) -> None: + self._messages = tokens + self._total = self._static + sum(tokens) + + @property + def total(self) -> int: + return self._total + + @property + def static(self) -> int: + return self._static + + +class TokenCounter(ABC): + @abstractmethod + def count_tokens( + self, + tools: list[BetaToolUnionParam] | None = None, + system: str | list[BetaTextBlockParam] | None = None, + messages: list[MessageParam] | None = None, + ) -> TokenCounts: + """Count total tokens (estimated) using simple string length estimation. + + Args: + tools (list[BetaToolUnionParam] | None, optional): The tools to count + tokens for. Defaults to `None`. + system (str | list[BetaTextBlockParam] | None, optional): The system + prompt or system blocks to count tokens for. Defaults to `None`. + messages (list[MessageParam] | None, optional): + The messages to count tokens for. Defaults to `None`. + model (str | None, optional): The model to count tokens for. + Defaults to `None`. 
+ + Returns: + int: The total estimated number of tokens across all components. + """ + raise NotImplementedError + + +class SimpleTokenCounter(TokenCounter): + """Simple token counter implementation that estimates tokens by dividing string + length by 3. + + This is a basic approximation that assumes roughly 3 characters per token + on average.For more accurate token counting, consider using model-specific + tokenizers. + """ + + def __init__(self, chars_per_token: float = 3.0) -> None: + """Initialize the simple token counter. + + Args: + chars_per_token (float, optional): The estimated characters per token. + Defaults to `3.0`. + """ + self._chars_per_token = chars_per_token + self._url_cache: dict[str, tuple[int, int] | None] = {} + + def _get_image_dimensions_from_url(self, url: str) -> tuple[int, int] | None: + """Fetch image dimensions from a URL with caching. + + Args: + url (str): The URL of the image to fetch. + + Returns: + tuple[int, int] | None: The (width, height) of the image, or None if + fetching fails. 
+ """ + # Check cache first + if url in self._url_cache: + return self._url_cache[url] + + try: + with httpx.Client(timeout=10.0) as client: + response = client.get(url) + response.raise_for_status() + + # Check if the response is actually an image + content_type = response.headers.get("content-type", "").lower() + if not content_type.startswith("image/"): + self._url_cache[url] = None + return None + + # Convert response content to PIL Image to get dimensions + image_data_base64 = base64.b64encode(response.content).decode("utf-8") + image = base64_to_image(image_data_base64) + dimensions = image.size + self._url_cache[url] = dimensions + return dimensions + except (httpx.HTTPError, httpx.TimeoutException, ValueError, TypeError): + # If fetching fails, cache None and return None to fall back to estimation + self._url_cache[url] = None + return None + + @override + def count_tokens( + self, + tools: list[BetaToolUnionParam] | None = None, + system: str | list[BetaTextBlockParam] | None = None, + messages: list[MessageParam] | None = None, + model: str | None = None, # noqa: ARG002 + ) -> TokenCounts: + system_tokens = 0 + tools_tokens = 0 + message_tokens = [] + if tools: + tools_str = self._stringify_object(tools) + tools_tokens = int(len(tools_str) / self._chars_per_token) + if system: + system_str = self._stringify_object(system) + system_tokens = int(len(system_str) / self._chars_per_token) + if messages: + message_tokens = [ + self._count_tokens_for_message(message) for message in messages + ] + return TokenCounts( + system=system_tokens, + tools=tools_tokens, + messages=message_tokens, + ) + + def _count_tokens_for_message(self, message: MessageParam) -> int: + """Count tokens for a message by processing content blocks individually. + + For image blocks, uses the formula: tokens = (width * height) / 750 (see https://docs.anthropic.com/en/docs/build-with-claude/vision) + For other content types, uses the standard character-based estimation. 
+ + Args: + message (MessageParam): The message to count tokens for. + + Returns: + int: The estimated number of tokens for the message. + """ + if isinstance(message.content, str): + # Simple string content - use standard estimation + return int(len(message.content) / self._chars_per_token) + + # base tokens for rest of message + total_tokens = 10 + # Content is a list of blocks - process each individually + for block in message.content: + total_tokens += self._count_tokens_for_content_block(block) + + return total_tokens + + def _count_tokens_for_content_block(self, block: ContentBlockParam) -> int: + """Count tokens for a single content block. + + Args: + block (ContentBlockParam): The content block to count tokens for. + + Returns: + int: The estimated number of tokens for the block. + """ + if isinstance(block, ImageBlockParam): + return self._count_tokens_for_image_block(block) + + if isinstance(block, ToolResultBlockParam): + # Tool result blocks can contain text or nested content blocks + if isinstance(block.content, str): + return int(len(block.content) / self._chars_per_token) + + # base tokens for tool result block + total_tokens = 20 + # Recursively count nested content blocks + for nested_block in block.content: + total_tokens += self._count_tokens_for_content_block(nested_block) + return total_tokens + + # For other block types, use string representation + return int(len(self._stringify_object(block)) / self._chars_per_token) + + def _count_tokens_for_image_block(self, block: ImageBlockParam) -> int: + """Count tokens for an image block using Anthropic's formula. + + Uses the formula: tokens = (width * height) / 750 + + Args: + block (ImageBlockParam): The image block to count tokens for. + + Returns: + int: The estimated number of tokens for the image. 
+ """ + # If fetching fails, fall back to estimation + # Assume average image size of ~4 megapixel (2000x2000) for URL images + estimated_tokens = int((2000 * 2000) / 750) + try: + if isinstance(block.source, Base64ImageSourceParam): + # Decode base64 image to get dimensions + image = base64_to_image(block.source.data) + width, height = image.size + return int((width * height) / 750) + + # For URL-based images, try to fetch the image to get actual dimensions + dimensions = self._get_image_dimensions_from_url(block.source.url) + if dimensions is not None: + width, height = dimensions + return int((width * height) / 750) + + except (ValueError, TypeError, AttributeError): + # If image processing fails, fall back to string-based estimation + return int(len(self._stringify_object(block)) / self._chars_per_token) + return estimated_tokens + + def _stringify_object(self, obj: object) -> str: + """Convert any object to a string representation for token counting. + + Not whitespace in dumped jsons between object keys and values and among array + elements. + + Args: + obj (object): The object to stringify. + + Returns: + str: String representation of the object. 
+ """ + if isinstance(obj, str): + return obj + try: + return json.dumps(obj, separators=(",", ":")) + except (TypeError, ValueError): + return str(obj) diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index 4dc4bf37..361ef3d8 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -6,7 +6,11 @@ import jsonref import mcp -from anthropic.types.beta import BetaToolParam, BetaToolUnionParam +from anthropic.types.beta import ( + BetaCacheControlEphemeralParam, + BetaToolParam, + BetaToolUnionParam, +) from anthropic.types.beta.beta_tool_param import InputSchema from asyncer import syncify from fastmcp.client.client import CallToolResult, ProgressHandler @@ -227,7 +231,12 @@ def to_params(self) -> list[BetaToolUnionParam]: for tool_name, tool in tool_map.items() if self._include is None or tool_name in self._include } - return list(filtered_tool_map.values()) + result = list(filtered_tool_map.values()) + if result: + result[-1]["cache_control"] = BetaCacheControlEphemeralParam( + type="ephemeral", + ) + return result def _get_mcp_tool_params(self) -> dict[str, BetaToolUnionParam]: if not self._mcp_client: @@ -238,6 +247,7 @@ def _get_mcp_tool_params(self) -> dict[str, BetaToolUnionParam]: if params := (tool.meta or {}).get("params"): # validation missing result[tool_name] = params + continue result[tool_name] = BetaToolParam( name=tool_name, description=tool.description or "", diff --git a/src/askui/models/shared/truncation_strategies.py b/src/askui/models/shared/truncation_strategies.py new file mode 100644 index 00000000..935ebcec --- /dev/null +++ b/src/askui/models/shared/truncation_strategies.py @@ -0,0 +1,376 @@ +from dataclasses import dataclass +from typing import Annotated + +from anthropic.types.beta import BetaTextBlockParam, BetaToolUnionParam +from pydantic import Field +from typing_extensions import override + +from askui.models.shared.agent_message_param import ( + CacheControlEphemeralParam, 
+ MessageParam, + TextBlockParam, +) +from askui.models.shared.token_counter import SimpleTokenCounter, TokenCounter + +# needs to be below limits imposed by endpoint +MAX_INPUT_TOKENS = 100_000 + +# see https://docs.anthropic.com/en/api/messages#body-messages +MAX_MESSAGES = 100_000 + + +class TruncationStrategy: + """Abstract base class for truncation strategies.""" + + def __init__( + self, + tools: list[BetaToolUnionParam] | None, + system: str | list[BetaTextBlockParam] | None, + messages: list[MessageParam], + model: str, + ) -> None: + self._tools = tools + self._messages = messages + self._system = system + self._model = model + + def append_message(self, message: MessageParam) -> None: + self._messages.append(message) + + @property + def messages(self) -> list[MessageParam]: + """Get the truncated messages.""" + return self._messages + + +def _is_tool_result_user_message(message: MessageParam) -> bool: + return message.role == "user" and ( + isinstance(message.content, list) + and any(block.type == "tool_result" for block in message.content) + ) + + +def _is_tool_use_assistant_message(message: MessageParam) -> bool: + return message.role == "assistant" and ( + isinstance(message.content, list) + and any(block.type == "tool_use" for block in message.content) + ) + + +def _is_end_of_loop( + message: MessageParam, previous_message: MessageParam | None +) -> bool: + return ( + not _is_tool_result_user_message(message) + and previous_message is not None + and previous_message.role == "assistant" + ) + + +@dataclass(kw_only=True) +class MessageContainer: + index: int + message: MessageParam + tokens: int + + +class SimpleTruncationStrategy(TruncationStrategy): + """Simple truncation strategy that truncates messages to stay within token and + message limits. 
+ + Clusters messages into "tool calling loops" - sequences of messages starting with + a user message (not containing `tool_result` blocks) or the first message, and + ending with an assistant message before the next such user message or the last + message. + + The last tool calling loop is called the "open loop" and represents the current + conversation context being worked on. + + Truncation follows this priority order until both token and message thresholds + are met: + 1. Remove tool calling turns (assistant tool_use + user tool_result pairs) + from closed loops + 2. Remove entire closed loops (except first and last which usually contain + the most important context) + 3. Remove the first loop if it's not the open loop + 4. Remove tool calling turns from the open loop (except the first and last turn) + - We need to preserve the thinking block in first turn of open loop. + - Also these are the blocks with the most important context. + 5. Raise ValueError if still exceeds limits after all truncation attempts + + We truncate until a threshold that is way below the limits to make sure that + the threshold is not reached immediately afterwards again and caching can work + in that time. + + Args: + tools (list[BetaToolUnionParam] | None): Available tools for the conversation + system (str | list[BetaTextBlockParam] | None): System prompt or blocks + messages (list[MessageParam]): Initial conversation messages + model (str): Model name for token counting + max_input_tokens (int, optional): Maximum input tokens allowed. Defaults to + 100,000. + input_token_truncation_threshold (float, optional): Fraction of max tokens to + truncate at. Defaults to 0.75. + max_messages (int, optional): Maximum messages allowed. Defaults to 100,000. + message_truncation_threshold (float, optional): Fraction of max messages to + truncate at. Defaults to 0.75. + token_counter (TokenCounter | None, optional): Token counter instance. Defaults + to SimpleTokenCounter. 
+ + Raises: + ValueError: If conversation cannot be truncated below limits after all attempts. + """ + + def __init__( + self, + tools: list[BetaToolUnionParam] | None, + system: str | list[BetaTextBlockParam] | None, + messages: list[MessageParam], + model: str, + max_input_tokens: int = MAX_INPUT_TOKENS, + input_token_truncation_threshold: Annotated[ + float, Field(gt=0.0, lt=1.0) + ] = 0.75, + max_messages: int = MAX_MESSAGES, + message_truncation_threshold: Annotated[float, Field(gt=0.0, lt=1.0)] = 0.75, + token_counter: TokenCounter | None = None, + ) -> None: + super().__init__( + tools=tools, + system=system, + messages=messages, + model=model, + ) + self._max_input_tokens = max_input_tokens + self._max_input_tokens_after_truncation = int( + input_token_truncation_threshold * max_input_tokens + ) + self._max_messages = max_messages + self._max_messages_after_truncation = int( + message_truncation_threshold * max_messages + ) + self._token_counter = token_counter or SimpleTokenCounter() + self._token_counts = self._token_counter.count_tokens( + tools=tools, + system=system, + messages=messages, + ) + + @override + def append_message(self, message: MessageParam) -> None: + super().append_message(message) + self._token_counts.append_message_tokens( + self._token_counter.count_tokens(messages=[message]).total + ) + if self._should_truncate(): + self._truncate() + + def _should_truncate(self) -> bool: + return ( + self._token_counts.total > self._max_input_tokens + or len(self._messages) > self._max_messages + ) + + @property + @override + def messages(self) -> list[MessageParam]: + self._move_cache_control_to_last_non_tool_result_user_message() + return self._messages + + def _move_cache_control_to_last_non_tool_result_user_message(self) -> None: + found_last = False + for message in reversed(self._messages): + if message.role == "user" and not _is_tool_result_user_message(message): + if not found_last: + found_last = True + if isinstance(message.content, str): 
+ message.content = [ + TextBlockParam( + text=message.content, + cache_control=CacheControlEphemeralParam( + type="ephemeral", + ), + ) + ] + elif len(message.content) > 0: + last_content = message.content[-1] + if hasattr(last_content, "cache_control"): + last_content.cache_control = CacheControlEphemeralParam( + type="ephemeral", + ) + else: + if isinstance(message.content, list) and message.content: + last_content = message.content[-1] + if hasattr(last_content, "cache_control"): + last_content.cache_control = None + break + + def _truncate(self) -> None: # noqa: C901 + messages_to_remove_min = max( + len(self._messages) - self._max_messages_after_truncation, 0 + ) + tokens_to_remove_min = max( + self._token_counts.total - self._max_input_tokens_after_truncation, 0 + ) + messages_removed_indices: set[int] = set() + tokens_removed = 0 + loops = self._cluster_into_tool_calling_loops() + + # 1. Remove tool calling turns within closed loops + last_message_was_tool_use_assistant_message = False + for closed_loop in loops[:-1]: + for message_container in closed_loop: + if last_message_was_tool_use_assistant_message: + messages_removed_indices.add(message_container.index) + tokens_removed += message_container.tokens + if ( + len(messages_removed_indices) >= messages_to_remove_min + and tokens_removed >= tokens_to_remove_min + ): + self._remove_messages(messages_removed_indices) + return + + last_message_was_tool_use_assistant_message = False + if _is_tool_use_assistant_message(message_container.message): + last_message_was_tool_use_assistant_message = True + messages_removed_indices.add(message_container.index) + tokens_removed += message_container.tokens + + # 2.
Remove loops except first and last (open) loop + for closed_loop in loops[1:-1]: + for message_container in closed_loop: + if message_container.index not in messages_removed_indices: + messages_removed_indices.add(message_container.index) + tokens_removed += message_container.tokens + if ( + len(messages_removed_indices) >= messages_to_remove_min + and tokens_removed >= tokens_to_remove_min + ): + self._remove_messages(messages_removed_indices) + return + + # 3. Remove first loop if it is not the last (open) loop + if len(loops) > 1: + for message_container in loops[0]: + if message_container.index not in messages_removed_indices: + messages_removed_indices.add(message_container.index) + tokens_removed += message_container.tokens + if ( + len(messages_removed_indices) >= messages_to_remove_min + and tokens_removed >= tokens_to_remove_min + ): + self._remove_messages(messages_removed_indices) + return + + # 4. Remove tool calling turns within open loop except last turn + if len(loops) > 0: + open_loop = loops[-1] + last_message_was_tool_use_assistant_message = False + for i, message_container in enumerate(open_loop): + if last_message_was_tool_use_assistant_message: + messages_removed_indices.add(message_container.index) + tokens_removed += message_container.tokens + if ( + len(messages_removed_indices) >= messages_to_remove_min + and tokens_removed >= tokens_to_remove_min + ): + self._remove_messages(messages_removed_indices) + return + + last_message_was_tool_use_assistant_message = False + if ( + _is_tool_use_assistant_message(message_container.message) + and 1 < i < len(open_loop) - 2 + ): + last_message_was_tool_use_assistant_message = True + messages_removed_indices.add(message_container.index) + tokens_removed += message_container.tokens + + # Everything that is left is the last non-tool-result user message + # and the last (open or closed) tool calling turn (if there is one) + error_msg = "Conversation too long. Please start a new conversation."
+ raise ValueError(error_msg) + + def _remove_messages(self, indices: set[int]) -> None: + self._token_counts.reset_message_tokens( + [ + self._token_counts.retrieve_message_tokens(i) + for i, _ in enumerate(self._messages) + if i not in indices + ] + ) + self._messages = [ + message for i, message in enumerate(self._messages) if i not in indices + ] + + def _cluster_into_tool_calling_loops(self) -> list[list[MessageContainer]]: + loops: list[list[MessageContainer]] = [] + current_loop: list[MessageContainer] = [] + for i, message in enumerate(self._messages): + if _is_end_of_loop( + message, current_loop[-1].message if current_loop else None + ): + loops.append(current_loop) + current_loop = [] + current_loop.append( + MessageContainer( + index=i, + message=message, + tokens=self._token_counts.retrieve_message_tokens(i), + ), + ) + loops.append(current_loop) + return loops + + +class TruncationStrategyFactory: + def create_truncation_strategy( + self, + tools: list[BetaToolUnionParam] | None, + system: str | list[BetaTextBlockParam] | None, + messages: list[MessageParam], + model: str, + ) -> TruncationStrategy: + return TruncationStrategy( + tools=tools, + system=system, + messages=messages, + model=model, + ) + + +class SimpleTruncationStrategyFactory(TruncationStrategyFactory): + def __init__( + self, + max_input_tokens: int = MAX_INPUT_TOKENS, + input_token_truncation_threshold: Annotated[ + float, Field(gt=0.0, lt=1.0) + ] = 0.75, + max_messages: int = MAX_MESSAGES, + message_truncation_threshold: Annotated[float, Field(gt=0.0, lt=1.0)] = 0.75, + token_counter: TokenCounter | None = None, + ) -> None: + self._max_input_tokens = max_input_tokens + self._input_token_truncation_threshold = input_token_truncation_threshold + self._max_messages = max_messages + self._message_truncation_threshold = message_truncation_threshold + self._token_counter = token_counter or SimpleTokenCounter() + + def create_truncation_strategy( + self, + tools: list[BetaToolUnionParam] 
| None, + system: str | list[BetaTextBlockParam] | None, + messages: list[MessageParam], + model: str, + ) -> TruncationStrategy: + return SimpleTruncationStrategy( + tools=tools, + system=system, + messages=messages, + model=model, + max_input_tokens=self._max_input_tokens, + input_token_truncation_threshold=self._input_token_truncation_threshold, + max_messages=self._max_messages, + message_truncation_threshold=self._message_truncation_threshold, + token_counter=self._token_counter, + ) diff --git a/src/askui/utils/api_utils.py b/src/askui/utils/api_utils.py index 84ae2610..3f433c25 100644 --- a/src/askui/utils/api_utils.py +++ b/src/askui/utils/api_utils.py @@ -8,6 +8,7 @@ ListOrder = Literal["asc", "desc"] LIST_LIMIT_MAX = 100 +LIST_LIMIT_DEFAULT = 20 Id = TypeVar("Id", bound=str, default=str) @@ -15,7 +16,7 @@ @dataclass(kw_only=True) class ListQuery(Generic[Id]): - limit: Annotated[int, Query(ge=1, le=LIST_LIMIT_MAX)] = 20 + limit: Annotated[int, Query(ge=1, le=LIST_LIMIT_MAX)] = LIST_LIMIT_DEFAULT after: Annotated[Id | None, Query()] = None before: Annotated[Id | None, Query()] = None order: Annotated[ListOrder, Query()] = "desc" @@ -58,19 +59,48 @@ def __init__(self, max_size: int): super().__init__(f"File too large. 
Maximum size is {max_size} bytes.") +def _build_after_fn(after: str, order: ListOrder) -> Callable[[Path], bool]: + after_name = f"{after}.json" + if order == "asc": + return lambda f: f.name > after_name + # desc - "after" means files that come before in the sorted list + return lambda f: f.name < after_name + + +def _build_before_fn(before: str, order: ListOrder) -> Callable[[Path], bool]: + before_name = f"{before}.json" + if order == "asc": + return lambda f: f.name < before_name + return lambda f: f.name > before_name + + +def _build_list_filter_fn(list_query: ListQuery) -> Callable[[Path], bool]: + after_fn = ( + _build_after_fn(list_query.after, list_query.order) + if list_query.after + else None + ) + before_fn = ( + _build_before_fn(list_query.before, list_query.order) + if list_query.before + else None + ) + if after_fn and before_fn: + return lambda f: after_fn(f) and before_fn(f) + if after_fn: + return after_fn + if before_fn: + return before_fn + return lambda _: True + + def list_resource_paths(base_dir: Path, list_query: ListQuery) -> list[Path]: paths: list[Path] = [] - after_name = f"{list_query.after}.json" - before_name = f"{list_query.before}.json" + filter_fn = _build_list_filter_fn(list_query) for f in base_dir.glob("*.json"): try: - if list_query.after: - if f.name <= after_name: - continue - if list_query.before: - if f.name >= before_name: - continue - paths.append(f) + if filter_fn(f): + paths.append(f) except ValidationError: # noqa: PERF203 continue return sorted(paths, key=lambda f: f.name, reverse=(list_query.order == "desc")) diff --git a/tests/integration/chat/api/test_files.py b/tests/integration/chat/api/test_files.py index c15b4511..4496794c 100644 --- a/tests/integration/chat/api/test_files.py +++ b/tests/integration/chat/api/test_files.py @@ -533,11 +533,9 @@ def override_file_service() -> FileService: ) assert response.status_code == status.HTTP_200_OK data = response.json() - assert len(data["data"]) == 2 - # After file_test0 
should return file_test1 and file_test2 in - # descending order - assert data["data"][0]["id"] == "file_test2" - assert data["data"][1]["id"] == "file_test1" + # In descending lexicographic order, file_test0 is the last file, + # so there are no files "after" it + assert len(data["data"]) == 0 # Test with before parameter response = client.get( @@ -545,11 +543,9 @@ def override_file_service() -> FileService: ) assert response.status_code == status.HTTP_200_OK data = response.json() - assert len(data["data"]) == 2 - # Before file_test2 should return file_test0 and file_test1 in - # descending order - assert data["data"][0]["id"] == "file_test1" - assert data["data"][1]["id"] == "file_test0" + # In descending lexicographic order, file_test2 is the first file, + # so there are no files "before" it + assert len(data["data"]) == 0 finally: # Clean up dependency overrides app.dependency_overrides.clear() diff --git a/tests/integration/chat/api/test_runs.py b/tests/integration/chat/api/test_runs.py index 21ffa310..f867d30b 100644 --- a/tests/integration/chat/api/test_runs.py +++ b/tests/integration/chat/api/test_runs.py @@ -55,14 +55,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -127,14 +124,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = 
Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -200,14 +194,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -256,14 +247,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -322,14 +310,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, 
mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -382,14 +367,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -431,14 +413,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -492,14 +471,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - 
message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -544,14 +520,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -599,14 +572,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -667,14 +637,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = 
override_thread_service @@ -717,14 +684,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -804,14 +768,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -890,14 +851,11 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() - mock_message_service = Mock() - mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) app.dependency_overrides[get_thread_service] = override_thread_service @@ -978,8 +936,6 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, 
mock_message_service, mock_run_service) def override_runs_service() -> RunService: - mock_message_service = Mock() - mock_message_translator = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() from askui.chat.api.assistants.service import AssistantService @@ -987,8 +943,7 @@ def override_runs_service() -> RunService: base_dir=workspace_path, assistant_service=AssistantService(workspace_path), mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) def override_assistant_service() -> AssistantService: @@ -1066,8 +1021,6 @@ def override_thread_service() -> ThreadService: return ThreadService(workspace_path, mock_message_service, mock_run_service) def override_runs_service() -> RunService: - mock_message_service = Mock() - mock_message_translator = Mock() mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() from askui.chat.api.assistants.service import AssistantService @@ -1075,8 +1028,7 @@ def override_runs_service() -> RunService: base_dir=workspace_path, assistant_service=AssistantService(workspace_path), mcp_client_manager_manager=mock_mcp_client_manager_manager, - message_service=mock_message_service, - message_translator=mock_message_translator, + chat_history_manager=Mock(), ) def override_assistant_service() -> AssistantService: diff --git a/tests/unit/models/test_agent_filter.py b/tests/unit/models/test_agent_filter.py deleted file mode 100644 index 4787445d..00000000 --- a/tests/unit/models/test_agent_filter.py +++ /dev/null @@ -1,130 +0,0 @@ -from askui.models.shared.agent import Agent -from askui.models.shared.agent_message_param import ( - Base64ImageSourceParam, - ImageBlockParam, - MessageParam, - TextBlockParam, - ToolResultBlockParam, -) - - -def make_image_block() -> ImageBlockParam: - return ImageBlockParam( - source=Base64ImageSourceParam( - media_type="image/png", - 
data="abc", - ), - ) - - -def make_tool_result_block(num_images: int, num_texts: int = 0) -> ToolResultBlockParam: - content = [make_image_block() for _ in range(num_images)] + [ - TextBlockParam(text=f"text{i}") for i in range(num_texts) - ] - return ToolResultBlockParam(tool_use_id="id", content=content) - - -def make_message_with_tool_result(num_images: int, num_texts: int = 0) -> MessageParam: - return MessageParam( - role="user", content=[make_tool_result_block(num_images, num_texts)] - ) - - -def test_no_images() -> None: - messages = [make_message_with_tool_result(0, 2)] - filtered = Agent._maybe_filter_to_n_most_recent_images(messages, 3, 2) - assert filtered == messages - - -def test_fewer_images_than_keep() -> None: - messages = [make_message_with_tool_result(2, 1)] - filtered = Agent._maybe_filter_to_n_most_recent_images(messages, 3, 2) - # Only ToolResultBlockParam with list content should be checked - all_images = [ - c - for m in filtered - for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) - for c in b.content - if getattr(c, "type", None) == "image" - ] - expected_images = [ - c - for b in (messages[0].content if isinstance(messages[0].content, list) else []) - if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) - for c in b.content - if getattr(c, "type", None) == "image" - ] - assert all_images == expected_images - - -def test_exactly_images_to_keep() -> None: - messages = [make_message_with_tool_result(3, 1)] - filtered = Agent._maybe_filter_to_n_most_recent_images(messages, 3, 2) - # Only check .content if the type is correct - first_block = ( - filtered[0].content[0] - if isinstance(filtered[0].content, list) and len(filtered[0].content) > 0 - else None - ) - if isinstance(first_block, ToolResultBlockParam) and isinstance( - first_block.content, list - ): - assert len(first_block.content) == 4 - else: - error_msg = ( - "filtered[0].content[0] is 
not a ToolResultBlockParam with list content" - ) - raise AssertionError(error_msg) # noqa: TRY004 - all_tool_result_contents = [ - c - for m in filtered - for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) - for c in b.content - ] - assert ( - sum(1 for c in all_tool_result_contents if getattr(c, "type", None) == "image") - == 3 - ) - - -def test_more_images_than_keep_removes_oldest() -> None: - messages = [ - make_message_with_tool_result(2, 0), - make_message_with_tool_result(2, 0), - ] - filtered = Agent._maybe_filter_to_n_most_recent_images(messages, 2, 2) - # Only 2 images should remain, and they should be the newest (from the last message) - all_images = [ - c - for m in filtered - for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) - for c in b.content - if getattr(c, "type", None) == "image" - ] - assert len(all_images) == 2 - # They should be from the last message - assert all_images == [ - c - for b in (filtered[1].content if isinstance(filtered[1].content, list) else []) - if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) - for c in b.content[:2] - if getattr(c, "type", None) == "image" - ] - - -def test_removal_chunking() -> None: - messages = [make_message_with_tool_result(5, 0)] - filtered = Agent._maybe_filter_to_n_most_recent_images(messages, 2, 2) - # Should remove 4 (chunk of 4), leaving 1 image - all_images = [ - c - for m in filtered - for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) - for c in b.content - if getattr(c, "type", None) == "image" - ] - assert len(all_images) == 3