diff --git a/README.md b/README.md index 65418b68..1f566365 100644 --- a/README.md +++ b/README.md @@ -446,7 +446,7 @@ You can use Vision Agent with [OpenRouter](https://openrouter.ai/) to access a w ```python from askui import VisionAgent from askui.models import ( - OpenRouterGetModel, + OpenRouterModel, OpenRouterSettings, ModelRegistry, ) @@ -454,7 +454,7 @@ from askui.models import ( # Register OpenRouter model in the registry custom_models: ModelRegistry = { - "my-custom-model": OpenRouterGetModel( + "my-custom-model": OpenRouterModel( OpenRouterSettings( model="anthropic/claude-opus-4", ) diff --git a/src/askui/models/__init__.py b/src/askui/models/__init__.py index 5cdcdd61..925be473 100644 --- a/src/askui/models/__init__.py +++ b/src/askui/models/__init__.py @@ -11,7 +11,7 @@ OnMessageCb, Point, ) -from .openrouter.handler import OpenRouterGetModel +from .openrouter.model import OpenRouterModel from .openrouter.settings import OpenRouterSettings from .shared.computer_agent_message_param import ( Base64ImageSourceParam, @@ -28,6 +28,7 @@ ToolUseBlockParam, UrlImageSourceParam, ) +from .shared.settings import ChatCompletionsCreateSettings __all__ = [ "ActModel", @@ -54,6 +55,7 @@ "ToolResultBlockParam", "ToolUseBlockParam", "UrlImageSourceParam", - "OpenRouterGetModel", + "OpenRouterModel", "OpenRouterSettings", + "ChatCompletionsCreateSettings", ] diff --git a/src/askui/models/anthropic/handler.py b/src/askui/models/anthropic/model.py similarity index 75% rename from src/askui/models/anthropic/handler.py rename to src/askui/models/anthropic/model.py index 2856c0df..dcebe735 100644 --- a/src/askui/models/anthropic/handler.py +++ b/src/askui/models/anthropic/model.py @@ -21,6 +21,7 @@ ModelName, Point, ) +from askui.models.shared.prompts import SYSTEM_PROMPT_GET, build_system_prompt_locate from askui.models.types.response_schemas import ResponseSchema from askui.utils.image_utils import ( ImageSource, @@ -47,8 +48,8 @@ def _inference( ) -> 
list[anthropic.types.ContentBlock]: message = self._client.messages.create( model=model, - max_tokens=self._settings.max_tokens, - temperature=self._settings.temperature, + max_tokens=self._settings.chat_completions_create_settings.max_tokens, + temperature=self._settings.chat_completions_create_settings.temperature, system=system_prompt, messages=[ { @@ -87,12 +88,11 @@ def locate( prompt = f"Click on {locator_serialized}" screen_width = self._settings.resolution[0] screen_height = self._settings.resolution[1] - system_prompt = f"Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try taking another screenshot.\n* The screen's resolution is {screen_width}x{screen_height}.\n* The display number is 0\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. 
Don't click boxes on their edges unless asked.\n" # noqa: E501 scaled_image = scale_image_with_padding(image.root, screen_width, screen_height) response = self._inference( image_to_base64(scaled_image), prompt, - system_prompt, + build_system_prompt_locate(str(screen_width), str(screen_height)), model=ANTHROPIC_MODEL_NAME_MAPPING[ModelName(model_choice)], ) assert len(response) > 0 @@ -129,11 +129,10 @@ def get( max_width=self._settings.resolution[0], max_height=self._settings.resolution[1], ) - system_prompt = "You are an agent to process screenshots and answer questions about things on the screen or extract information from it. Answer only with the response to the question and keep it short and precise." # noqa: E501 response = self._inference( base64_image=image_to_base64(scaled_image), prompt=query, - system_prompt=system_prompt, + system_prompt=SYSTEM_PROMPT_GET, model=ANTHROPIC_MODEL_NAME_MAPPING[ModelName(model_choice)], ) if len(response) == 0: diff --git a/src/askui/models/anthropic/settings.py b/src/askui/models/anthropic/settings.py index e804495d..d58219de 100644 --- a/src/askui/models/anthropic/settings.py +++ b/src/askui/models/anthropic/settings.py @@ -2,12 +2,14 @@ from pydantic_settings import BaseSettings from askui.models.shared.computer_agent import ComputerAgentSettingsBase +from askui.models.shared.settings import ChatCompletionsCreateSettings COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" class AnthropicSettings(BaseSettings): api_key: SecretStr = Field( + default=..., min_length=1, validation_alias="ANTHROPIC_API_KEY", ) @@ -19,8 +21,10 @@ class ClaudeSettingsBase(BaseModel): class ClaudeSettings(ClaudeSettingsBase): resolution: tuple[int, int] = Field(default_factory=lambda: (1280, 800)) - max_tokens: int = 1000 - temperature: float = 0.0 + chat_completions_create_settings: ChatCompletionsCreateSettings = Field( + default_factory=ChatCompletionsCreateSettings, + description="Settings for ChatCompletions", + ) class 
ClaudeComputerAgentSettings(ComputerAgentSettingsBase, ClaudeSettingsBase): diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index 2f14c0a1..be5cbab4 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -37,7 +37,7 @@ from ..logger import logger from .anthropic.computer_agent import ClaudeComputerAgent -from .anthropic.handler import ClaudeHandler +from .anthropic.model import ClaudeHandler from .askui.inference_api import AskUiInferenceApi, AskUiSettings from .ui_tars_ep.ui_tars_api import UiTarsApiHandler, UiTarsApiHandlerSettings diff --git a/src/askui/models/openrouter/handler.py b/src/askui/models/openrouter/handler.py deleted file mode 100644 index f9f7f905..00000000 --- a/src/askui/models/openrouter/handler.py +++ /dev/null @@ -1,77 +0,0 @@ -import os -from typing import Type - -from openai import OpenAI -from typing_extensions import override - -from askui.models.exceptions import QueryNoResponseError -from askui.models.models import GetModel -from askui.models.types.response_schemas import ResponseSchema -from askui.utils.image_utils import ImageSource - -from .prompts import PROMPT_QA -from .settings import OpenRouterSettings - - -class OpenRouterGetModel(GetModel): - def __init__(self, settings: OpenRouterSettings): - self._settings = settings - - _open_router_key = os.getenv("OPEN_ROUTER_API_KEY") - if _open_router_key is None: - error_msg = "OPEN_ROUTER_API_KEY is not set" - raise ValueError(error_msg) - - self._client = OpenAI( - api_key=_open_router_key, - base_url="https://openrouter.ai/api/v1", - ) - - def _predict(self, image_url: str, instruction: str, prompt: str) -> str | None: - chat_completion = self._client.chat.completions.create( - model=self._settings.model, - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url, - }, - }, - {"type": "text", "text": prompt + instruction}, - ], - } - ], - top_p=None, - 
temperature=None, - max_tokens=150, - stream=False, - seed=None, - stop=None, - frequency_penalty=None, - presence_penalty=None, - ) - return chat_completion.choices[0].message.content - - @override - def get( - self, - query: str, - image: ImageSource, - response_schema: Type[ResponseSchema] | None, - model_choice: str, - ) -> ResponseSchema | str: - if response_schema is not None: - error_msg = f'Response schema is not supported for model "{model_choice}"' - raise NotImplementedError(error_msg) - response = self._predict( - image_url=image.to_data_url(), - instruction=query, - prompt=PROMPT_QA, - ) - if response is None: - error_msg = f'No response from model "{model_choice}" to query: "{query}"' - raise QueryNoResponseError(error_msg, query) - return response diff --git a/src/askui/models/openrouter/model.py b/src/askui/models/openrouter/model.py new file mode 100644 index 00000000..9d671c16 --- /dev/null +++ b/src/askui/models/openrouter/model.py @@ -0,0 +1,185 @@ +import json +from typing import TYPE_CHECKING, Any, Optional, Type + +import openai +from openai import OpenAI +from typing_extensions import override + +from askui.logger import logger +from askui.models.exceptions import QueryNoResponseError +from askui.models.models import GetModel +from askui.models.shared.prompts import SYSTEM_PROMPT_GET +from askui.models.types.response_schemas import ResponseSchema, to_response_schema +from askui.utils.image_utils import ImageSource + +from .settings import OpenRouterSettings + +if TYPE_CHECKING: + from openai.types.chat.completion_create_params import ResponseFormat + + +def _clean_schema_refs(schema: dict[str, Any] | list[Any]) -> None: + """Remove title fields that are at the same level as $ref fields as they are not supported by OpenAI.""" # noqa: E501 + if isinstance(schema, dict): + if "$ref" in schema and "title" in schema: + del schema["title"] + for value in schema.values(): + if isinstance(value, (dict, list)): + _clean_schema_refs(value) + elif 
isinstance(schema, list): + for item in schema: + if isinstance(item, (dict, list)): + _clean_schema_refs(item) + + +class OpenRouterModel(GetModel): + """ + This class implements the GetModel interface for the OpenRouter API. + + Args: + settings (OpenRouterSettings): The settings for the OpenRouter model. + + Example: + ```python + from askui import VisionAgent + from askui.models import ( + OpenRouterModel, + OpenRouterSettings, + ModelRegistry, + ) + + + # Register OpenRouter model in the registry + custom_models: ModelRegistry = { + "my-custom-model": OpenRouterModel( + OpenRouterSettings( + model="anthropic/claude-opus-4", + ) + ), + } + + with VisionAgent(models=custom_models, model={"get":"my-custom-model"}) as agent: + result = agent.get("What is the main heading on the screen?") + print(result) + ``` + """ # noqa: E501 + + def __init__( + self, + settings: OpenRouterSettings | None = None, + client: Optional[OpenAI] = None, + ): + self._settings = settings or OpenRouterSettings() + + self._client = ( + client + if client is not None + else OpenAI( + api_key=self._settings.open_router_api_key.get_secret_value(), + base_url=str(self._settings.base_url), + ) + ) + + def _predict( + self, + image_url: str, + instruction: str, + prompt: str, + response_schema: type[ResponseSchema] | None, + ) -> str | None | ResponseSchema: + extra_body: dict[str, object] = {} + + if len(self._settings.models) > 0: + extra_body["models"] = self._settings.models + + _response_schema = ( + to_response_schema(response_schema) if response_schema else None + ) + + response_format: openai.NotGiven | ResponseFormat = openai.NOT_GIVEN + if _response_schema is not None: + extra_body["provider"] = {"require_parameters": True} + schema = _response_schema.model_json_schema() + _clean_schema_refs(schema) + + defs = schema.pop("$defs", None) + schema_response_wrapper = { + "type": "object", + "properties": {"response": schema}, + "additionalProperties": False, + "required": ["response"],
+ } + if defs: + schema_response_wrapper["$defs"] = defs + response_format = { + "type": "json_schema", + "json_schema": { + "name": "user_json_schema", + "schema": schema_response_wrapper, + "strict": True, + }, + } + + chat_completion = self._client.chat.completions.create( + model=self._settings.model, + extra_body=extra_body, + response_format=response_format, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + {"type": "text", "text": prompt + instruction}, + ], + } + ], + stream=False, + top_p=self._settings.chat_completions_create_settings.top_p, + temperature=self._settings.chat_completions_create_settings.temperature, + max_tokens=self._settings.chat_completions_create_settings.max_tokens, + seed=self._settings.chat_completions_create_settings.seed, + stop=self._settings.chat_completions_create_settings.stop, + frequency_penalty=self._settings.chat_completions_create_settings.frequency_penalty, + presence_penalty=self._settings.chat_completions_create_settings.presence_penalty, + ) + + model_response = chat_completion.choices[0].message.content + + if _response_schema is not None and model_response is not None: + try: + response_json = json.loads(model_response) + except json.JSONDecodeError: + error_msg = f"Expected JSON, but model {self._settings.model} returned: {model_response}" # noqa: E501 + logger.error(error_msg) + raise ValueError(error_msg) from None + + validated_response = _response_schema.model_validate( + response_json["response"] + ) + return validated_response.root + + return model_response + + @override + def get( + self, + query: str, + image: ImageSource, + response_schema: Type[ResponseSchema] | None, + model_choice: str, + ) -> ResponseSchema | str: + response = self._predict( + image_url=image.to_data_url(), + instruction=query, + prompt=SYSTEM_PROMPT_GET, + response_schema=response_schema, + ) + if response is None: + error_msg = f'No response from model 
"{model_choice}" to query: "{query}"' + raise QueryNoResponseError(error_msg, query) + return response diff --git a/src/askui/models/openrouter/prompts.py b/src/askui/models/openrouter/prompts.py deleted file mode 100644 index 2ed79fd4..00000000 --- a/src/askui/models/openrouter/prompts.py +++ /dev/null @@ -1 +0,0 @@ -PROMPT_QA = "You are an agent to process screenshots and answer questions about things on the screen or extract information from it. Answer only with the response to the question and keep it short and precise." # noqa: E501 diff --git a/src/askui/models/openrouter/settings.py b/src/askui/models/openrouter/settings.py index 569a8854..a99dc350 100644 --- a/src/askui/models/openrouter/settings.py +++ b/src/askui/models/openrouter/settings.py @@ -1,5 +1,34 @@ -from pydantic import BaseModel, Field +from pydantic import Field, HttpUrl, SecretStr +from pydantic_settings import BaseSettings +from askui.models.shared.settings import ChatCompletionsCreateSettings -class OpenRouterSettings(BaseModel): - model: str = Field(..., description="OpenRouter model name") + +class OpenRouterSettings(BaseSettings): + """ + Settings for OpenRouter API configuration. + + Args: + model (str): OpenRouter model name. See https://openrouter.ai/models + models (list[str]): OpenRouter model names + base_url (HttpUrl): OpenRouter base URL. 
Defaults to https://openrouter.ai/api/v1 + chat_completions_create_settings (ChatCompletionsCreateSettings): Settings for ChatCompletions + """ # noqa: E501 + + model: str = Field(default="openrouter/auto", description="OpenRouter model name") + models: list[str] = Field( + default_factory=list, description="OpenRouter model names" + ) + open_router_api_key: SecretStr = Field( + default=..., + description="API key for OpenRouter authentication", + validation_alias="OPEN_ROUTER_API_KEY", + ) + base_url: HttpUrl = Field( + default_factory=lambda: HttpUrl("https://openrouter.ai/api/v1"), + description="OpenRouter base URL", + ) + chat_completions_create_settings: ChatCompletionsCreateSettings = Field( + default_factory=ChatCompletionsCreateSettings, + description="Settings for ChatCompletions", + ) diff --git a/src/askui/models/shared/prompts.py b/src/askui/models/shared/prompts.py new file mode 100644 index 00000000..4c922b22 --- /dev/null +++ b/src/askui/models/shared/prompts.py @@ -0,0 +1,5 @@ +SYSTEM_PROMPT_GET = "You are an agent to process screenshots and answer questions about things on the screen or extract information from it. Answer only with the response to the question and keep it short and precise." # noqa: E501 + + +def build_system_prompt_locate(screen_width: str, screen_height: str) -> str: + return f"Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. 
if you click on Firefox and a window doesn't open, try taking another screenshot.\n* The screen's resolution is {screen_width}x{screen_height}.\n* The display number is 0\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\n" # noqa: E501 diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py new file mode 100644 index 00000000..f0842764 --- /dev/null +++ b/src/askui/models/shared/settings.py @@ -0,0 +1,78 @@ +from pydantic import BaseModel, Field + + +class ChatCompletionsCreateSettings(BaseModel): + """ + Settings for creating chat completions. + + Args: + top_p (float | None, optional): An alternative to sampling with temperature, + called nucleus sampling, where the model considers the results of the tokens + with top_p probability mass. So `0.1` means only the tokens comprising + the top 10% probability mass are considered. We generally recommend + altering this or `temperature` but not both. + Defaults to `None`. + + temperature (float, optional): What sampling temperature to use, + between `0` and `2`. Higher values like `0.8` will make the output more + random, while lower values like `0.2` will make it more focused and + deterministic. We generally recommend altering this or `top_p` but not both. + Defaults to `0.0`. + + max_tokens (int, optional): The maximum number of tokens that can be generated + in the chat completion. This value can be used to control costs for text + generated via API. 
This value is now deprecated in favor of + `max_completion_tokens` for some models. + Defaults to `1000`. + + seed (int | None, optional): If specified, the system will make a best effort + to sample deterministically, such that repeated requests with the same seed + and parameters should return the same result. Determinism is not guaranteed. + Defaults to `None`. + + stop (str | list[str] | None, optional): Up to 4 sequences where the API + will stop generating further tokens. The returned text will not contain the + stop sequence. + Defaults to `None`. + + frequency_penalty (float | None, optional): Number between `-2.0` and `2.0`. + Positive values penalize new tokens based on their existing frequency + in the text so far, decreasing the model's likelihood to repeat the same + line verbatim. + Defaults to `None`. + + presence_penalty (float | None, optional): Number between `-2.0` and `2.0`. + Positive values penalize new tokens based on whether they appear in the text + so far, increasing the model's likelihood to talk about new topics. + Defaults to `None`. + + Returns: + ChatCompletionsCreateSettings: The settings object for chat completions. 
+ + Example: + ```python + settings = ChatCompletionsCreateSettings(top_p=0.9, temperature=0.7) + ``` + """ + + top_p: float | None = Field( + default=None, + ) + temperature: float = Field( + default=0.0, + ) + max_tokens: int = Field( + default=1000, + ) + seed: int | None = Field( + default=None, + ) + stop: str | list[str] | None = Field( + default=None, + ) + frequency_penalty: float | None = Field( + default=None, + ) + presence_penalty: float | None = Field( + default=None, + ) diff --git a/tests/integration/models/openrouter/__init__.py b/tests/integration/models/openrouter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/models/openrouter/conftest.py b/tests/integration/models/openrouter/conftest.py new file mode 100644 index 00000000..bfd76cfc --- /dev/null +++ b/tests/integration/models/openrouter/conftest.py @@ -0,0 +1,41 @@ +from typing import cast +from unittest.mock import MagicMock + +import pytest +from PIL import Image as PILImage +from pytest_mock import MockerFixture + +from askui.models.openrouter.model import OpenRouterModel +from askui.models.openrouter.settings import OpenRouterSettings +from askui.utils.image_utils import ImageSource + + +@pytest.fixture(autouse=True) +def set_env_variable(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OPEN_ROUTER_API_KEY", "test_openrouter_api_key") + + +@pytest.fixture +def settings() -> OpenRouterSettings: + return OpenRouterSettings( + model="test-model", + ) + + +@pytest.fixture +def mock_openai_client(mocker: MockerFixture) -> MagicMock: + return cast("MagicMock", mocker.MagicMock()) + + +@pytest.fixture +def openrouter_model( + settings: OpenRouterSettings, mock_openai_client: MagicMock +) -> OpenRouterModel: + return OpenRouterModel(settings=settings, client=mock_openai_client) + + +@pytest.fixture +def image_source_github_login_screenshot( + github_login_screenshot: PILImage.Image, +) -> ImageSource: + return 
ImageSource(root=github_login_screenshot) diff --git a/tests/integration/models/openrouter/test_openrouter.py b/tests/integration/models/openrouter/test_openrouter.py new file mode 100644 index 00000000..25615610 --- /dev/null +++ b/tests/integration/models/openrouter/test_openrouter.py @@ -0,0 +1,113 @@ +import json +from typing import Any +from unittest.mock import Mock + +import pytest + +from askui.models.exceptions import QueryNoResponseError +from askui.models.openrouter.model import OpenRouterModel +from askui.models.types.response_schemas import ResponseSchemaBase +from askui.utils.image_utils import ImageSource + + +class TestResponse(ResponseSchemaBase): + text: str + number: int + + +def _create_mock_completion(content: str | None) -> Any: + """Create a mock object that mimics the OpenAI ChatCompletion response.""" + mock_message = Mock() + mock_message.content = content + mock_choice = Mock() + mock_choice.message = mock_message + mock_completion = Mock() + mock_completion.choices = [mock_choice] + return mock_completion + + +def test_basic_query_returns_string( + mock_openai_client: Mock, + openrouter_model: OpenRouterModel, + image_source_github_login_screenshot: ImageSource, +) -> None: + mock_openai_client.chat.completions.create.return_value = _create_mock_completion( + "Test response" + ) + + result = openrouter_model.get( + query="What is in the image?", + image=image_source_github_login_screenshot, + response_schema=None, + model_choice="test-model", + ) + + assert isinstance(result, str) + assert result == "Test response" + mock_openai_client.chat.completions.create.assert_called_once() + + +def test_query_with_response_schema_returns_validated_object( + mock_openai_client: Mock, + openrouter_model: OpenRouterModel, + image_source_github_login_screenshot: ImageSource, +) -> None: + mock_response = { + "response": { + "text": "Test text", + "number": 42, + } + } + mock_openai_client.chat.completions.create.return_value = _create_mock_completion( 
+ json.dumps(mock_response) + ) + + result = openrouter_model.get( + query="What is in the image?", + image=image_source_github_login_screenshot, + response_schema=TestResponse, + model_choice="test-model", + ) + + assert isinstance(result, TestResponse) + assert result.text == "Test text" + assert result.number == 42 + mock_openai_client.chat.completions.create.assert_called_once() + + +def test_no_response_from_model( + mock_openai_client: Mock, + openrouter_model: OpenRouterModel, + image_source_github_login_screenshot: ImageSource, +) -> None: + mock_openai_client.chat.completions.create.return_value = _create_mock_completion( + None + ) + + with pytest.raises(QueryNoResponseError): + openrouter_model.get( + query="What is in the image?", + image=image_source_github_login_screenshot, + response_schema=None, + model_choice="test-model", + ) + mock_openai_client.chat.completions.create.assert_called_once() + + +def test_malformed_json_from_model( + mock_openai_client: Mock, + openrouter_model: OpenRouterModel, + image_source_github_login_screenshot: ImageSource, +) -> None: + mock_openai_client.chat.completions.create.return_value = _create_mock_completion( + "Invalid JSON {" + ) + + with pytest.raises(ValueError): + openrouter_model.get( + query="What is in the image?", + image=image_source_github_login_screenshot, + response_schema=TestResponse, + model_choice="test-model", + ) + mock_openai_client.chat.completions.create.assert_called_once()