Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions src/askui/models/askui/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
BetaThinkingConfigParam,
BetaToolChoiceParam,
)
from pydantic import UUID4, Field, HttpUrl, SecretStr
from pydantic import UUID4, Field, HttpUrl, SecretStr, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from typing_extensions import override
from typing_extensions import Self, override

from askui.locators.locators import Locator
from askui.locators.serializers import AskUiLocatorSerializer, AskUiSerializedLocator
Expand Down Expand Up @@ -51,26 +51,37 @@ class AskUiInferenceApiSettings(BaseSettings):
validation_alias="ASKUI_INFERENCE_ENDPOINT",
)
messages: MessageSettings = Field(default_factory=MessageSettings)
authorization: str | NotGiven = Field(
authorization: SecretStr | NotGiven = Field(
default=NOT_GIVEN,
description=(
"The authorization header to use for the AskUI Inference API. "
"If not provided, the token will be used to generate the header."
),
)
token: SecretStr = Field(
default=...,
token: SecretStr | NotGiven = Field(
default=NOT_GIVEN,
validation_alias="ASKUI_TOKEN",
)
workspace_id: UUID4 = Field(
default=...,
validation_alias="ASKUI_WORKSPACE_ID",
)

@model_validator(mode="after")
def check_authorization(self) -> "Self":
if self.authorization == NOT_GIVEN and self.token == NOT_GIVEN:
error_message = (
'Either authorization ("ASKUI__AUTHORIZATION" environment variable) '
'or token ("ASKUI_TOKEN" environment variable) must be provided'
)
raise ValueError(error_message)
return self

@property
def authorization_header(self) -> str:
if self.authorization:
return self.authorization
return self.authorization.get_secret_value()
assert not isinstance(self.token, NotGiven), "Token is not set"
token_str = self.token.get_secret_value()
token_base64 = base64.b64encode(token_str.encode()).decode()
return f"Basic {token_base64}"
Expand Down