diff --git a/src/askui/models/anthropic/messages_api.py b/src/askui/models/anthropic/messages_api.py index 6a8d348a..6866f7e9 100644 --- a/src/askui/models/anthropic/messages_api.py +++ b/src/askui/models/anthropic/messages_api.py @@ -1,4 +1,5 @@ import json +from functools import cached_property from typing import Type, cast from anthropic import NOT_GIVEN, Anthropic, NotGiven @@ -80,29 +81,23 @@ def __init__(self, message: str, content: list[ContentBlockParam]) -> None: super().__init__(self.message) -class AnthropicMessagesApiSettingsBase(BaseSettings): - model_config = SettingsConfigDict(validate_by_name=True, env_prefix="ANTHROPIC_") - - messages: MessageSettings = Field(default_factory=MessageSettings) - resolution: tuple[int, int] = Field( - default_factory=lambda: (1280, 800), - description="The resolution of the screen to use for the model", +class AnthropicMessagesApiSettings(BaseSettings): + model_config = SettingsConfigDict( + validate_by_name=True, + env_prefix="ANTHROPIC__", + env_nested_delimiter="__", ) - -AnthropicMessagesApiUnauthorizedSettings = AnthropicMessagesApiSettingsBase - - -class AnthropicMessagesApiAuthorizedSettings(AnthropicMessagesApiSettingsBase): api_key: SecretStr = Field( default=..., min_length=1, + validation_alias="ANTHROPIC_API_KEY", + ) + messages: MessageSettings = Field(default_factory=MessageSettings) + resolution: tuple[int, int] = Field( + default_factory=lambda: (1280, 800), + description="The resolution of the screen to use for the model", ) - - -AnthropicMessagesApiSettings = ( - AnthropicMessagesApiAuthorizedSettings | AnthropicMessagesApiUnauthorizedSettings -) class AnthropicMessagesApi(LocateModel, GetModel, MessagesApi): @@ -111,15 +106,17 @@ def __init__( locator_serializer: VlmLocatorSerializer, settings: AnthropicMessagesApiSettings | None = None, ) -> None: - self._settings = settings or AnthropicMessagesApiUnauthorizedSettings() + self._settings_default = settings self._locator_serializer = locator_serializer - @property + @cached_property + def _settings(self) -> AnthropicMessagesApiSettings: + if self._settings_default is None: + return AnthropicMessagesApiSettings() + return self._settings_default + + @cached_property def _client(self) -> Anthropic: - if not isinstance(self._settings, AnthropicMessagesApiAuthorizedSettings): - self._settings = AnthropicMessagesApiAuthorizedSettings.model_validate( - self._settings.model_dump() - ) return Anthropic(api_key=self._settings.api_key.get_secret_value()) @override diff --git a/src/askui/models/askui/inference_api.py b/src/askui/models/askui/inference_api.py index 5553876d..1880d017 100644 --- a/src/askui/models/askui/inference_api.py +++ b/src/askui/models/askui/inference_api.py @@ -1,5 +1,6 @@ import base64 import json as json_lib +from functools import cached_property from typing import Any, Type import httpx @@ -37,24 +38,25 @@ def _is_retryable_error(exception: BaseException) -> bool: return False -class AskUiInferenceApiSettingsBase(BaseSettings): - model_config = SettingsConfigDict(validate_by_name=True, env_prefix="ASKUI_") +class AskUiInferenceApiSettings(BaseSettings): + model_config = SettingsConfigDict( + validate_by_name=True, + env_prefix="ASKUI__", + env_nested_delimiter="__", + ) inference_endpoint: HttpUrl = Field( default_factory=lambda: HttpUrl("https://inference.askui.com"), # noqa: F821 + validation_alias="ASKUI_INFERENCE_ENDPOINT", ) messages: MessageSettings = Field(default_factory=MessageSettings) - - -AskUiInferenceApiSettingsUnauthorized = AskUiInferenceApiSettingsBase - - -class AskUiInferenceApiAuthorizedSettings(AskUiInferenceApiSettingsBase): token: SecretStr = Field( default=..., + validation_alias="ASKUI_TOKEN", ) workspace_id: UUID4 = Field( default=..., + validation_alias="ASKUI_WORKSPACE_ID", ) @property @@ -71,26 +73,23 @@ def base_url(self) -> str: return f"{self.inference_endpoint}api/v1/workspaces/{self.workspace_id}" -AskUiInferenceApiSettings = ( - AskUiInferenceApiAuthorizedSettings | AskUiInferenceApiSettingsUnauthorized -) - - class AskUiInferenceApi(GetModel, LocateModel, MessagesApi): def __init__( self, locator_serializer: AskUiLocatorSerializer, settings: AskUiInferenceApiSettings | None = None, ) -> None: - self._settings = settings or AskUiInferenceApiSettingsUnauthorized() + self._settings_default = settings self._locator_serializer = locator_serializer - @property + @cached_property + def _settings(self) -> AskUiInferenceApiSettings: + if self._settings_default is None: + return AskUiInferenceApiSettings() + return self._settings_default + + @cached_property def _client(self) -> httpx.Client: - if not isinstance(self._settings, AskUiInferenceApiAuthorizedSettings): - self._settings = AskUiInferenceApiAuthorizedSettings.model_validate( - self._settings.model_dump() - ) return httpx.Client( base_url=f"{self._settings.base_url}", headers={ diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index 422c3fc6..fbda3329 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -12,7 +12,7 @@ from askui.models.askui.ai_element_utils import AiElementCollection from askui.models.askui.inference_api import ( AskUiInferenceApi, - AskUiInferenceApiAuthorizedSettings, + AskUiInferenceApiSettings, ) from askui.models.askui.model_router import AskUiModelRouter from askui.models.models import ModelName @@ -62,7 +62,7 @@ def askui_facade( ), reporter=reporter, ), - settings=AskUiInferenceApiAuthorizedSettings( + settings=AskUiInferenceApiSettings( messages=settings.messages, ), )