Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions src/askui/models/anthropic/messages_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from functools import cached_property
from typing import Type, cast

from anthropic import NOT_GIVEN, Anthropic, NotGiven
Expand Down Expand Up @@ -80,29 +81,23 @@ def __init__(self, message: str, content: list[ContentBlockParam]) -> None:
super().__init__(self.message)


class AnthropicMessagesApiSettingsBase(BaseSettings):
model_config = SettingsConfigDict(validate_by_name=True, env_prefix="ANTHROPIC_")

messages: MessageSettings = Field(default_factory=MessageSettings)
resolution: tuple[int, int] = Field(
default_factory=lambda: (1280, 800),
description="The resolution of the screen to use for the model",
class AnthropicMessagesApiSettings(BaseSettings):
model_config = SettingsConfigDict(
validate_by_name=True,
env_prefix="ANTHROPIC__",
env_nested_delimiter="__",
)


AnthropicMessagesApiUnauthorizedSettings = AnthropicMessagesApiSettingsBase


class AnthropicMessagesApiAuthorizedSettings(AnthropicMessagesApiSettingsBase):
api_key: SecretStr = Field(
default=...,
min_length=1,
validation_alias="ANTHROPIC_API_KEY",
)
messages: MessageSettings = Field(default_factory=MessageSettings)
resolution: tuple[int, int] = Field(
default_factory=lambda: (1280, 800),
description="The resolution of the screen to use for the model",
)


AnthropicMessagesApiSettings = (
AnthropicMessagesApiAuthorizedSettings | AnthropicMessagesApiUnauthorizedSettings
)


class AnthropicMessagesApi(LocateModel, GetModel, MessagesApi):
Expand All @@ -111,15 +106,17 @@ def __init__(
locator_serializer: VlmLocatorSerializer,
settings: AnthropicMessagesApiSettings | None = None,
) -> None:
self._settings = settings or AnthropicMessagesApiUnauthorizedSettings()
self._settings_default = settings
self._locator_serializer = locator_serializer

@property
@cached_property
def _settings(self) -> AnthropicMessagesApiSettings:
if self._settings_default is None:
return AnthropicMessagesApiSettings()
return self._settings_default

@cached_property
def _client(self) -> Anthropic:
if not isinstance(self._settings, AnthropicMessagesApiAuthorizedSettings):
self._settings = AnthropicMessagesApiAuthorizedSettings.model_validate(
self._settings.model_dump()
)
return Anthropic(api_key=self._settings.api_key.get_secret_value())

@override
Expand Down
37 changes: 18 additions & 19 deletions src/askui/models/askui/inference_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import json as json_lib
from functools import cached_property
from typing import Any, Type

import httpx
Expand Down Expand Up @@ -37,24 +38,25 @@ def _is_retryable_error(exception: BaseException) -> bool:
return False


class AskUiInferenceApiSettingsBase(BaseSettings):
model_config = SettingsConfigDict(validate_by_name=True, env_prefix="ASKUI_")
class AskUiInferenceApiSettings(BaseSettings):
model_config = SettingsConfigDict(
validate_by_name=True,
env_prefix="ASKUI__",
env_nested_delimiter="__",
)

inference_endpoint: HttpUrl = Field(
default_factory=lambda: HttpUrl("https://inference.askui.com"), # noqa: F821
validation_alias="ASKUI_INFERENCE_ENDPOINT",
)
messages: MessageSettings = Field(default_factory=MessageSettings)


AskUiInferenceApiSettingsUnauthorized = AskUiInferenceApiSettingsBase


class AskUiInferenceApiAuthorizedSettings(AskUiInferenceApiSettingsBase):
token: SecretStr = Field(
default=...,
validation_alias="ASKUI_TOKEN",
)
workspace_id: UUID4 = Field(
default=...,
validation_alias="ASKUI_WORKSPACE_ID",
)

@property
Expand All @@ -71,26 +73,23 @@ def base_url(self) -> str:
return f"{self.inference_endpoint}api/v1/workspaces/{self.workspace_id}"


AskUiInferenceApiSettings = (
AskUiInferenceApiAuthorizedSettings | AskUiInferenceApiSettingsUnauthorized
)


class AskUiInferenceApi(GetModel, LocateModel, MessagesApi):
def __init__(
self,
locator_serializer: AskUiLocatorSerializer,
settings: AskUiInferenceApiSettings | None = None,
) -> None:
self._settings = settings or AskUiInferenceApiSettingsUnauthorized()
self._settings_default = settings
self._locator_serializer = locator_serializer

@property
@cached_property
def _settings(self) -> AskUiInferenceApiSettings:
if self._settings_default is None:
return AskUiInferenceApiSettings()
return self._settings_default

@cached_property
def _client(self) -> httpx.Client:
if not isinstance(self._settings, AskUiInferenceApiAuthorizedSettings):
self._settings = AskUiInferenceApiAuthorizedSettings.model_validate(
self._settings.model_dump()
)
return httpx.Client(
base_url=f"{self._settings.base_url}",
headers={
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/agent/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from askui.models.askui.ai_element_utils import AiElementCollection
from askui.models.askui.inference_api import (
AskUiInferenceApi,
AskUiInferenceApiAuthorizedSettings,
AskUiInferenceApiSettings,
)
from askui.models.askui.model_router import AskUiModelRouter
from askui.models.models import ModelName
Expand Down Expand Up @@ -62,7 +62,7 @@ def askui_facade(
),
reporter=reporter,
),
settings=AskUiInferenceApiAuthorizedSettings(
settings=AskUiInferenceApiSettings(
messages=settings.messages,
),
)
Expand Down