Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/askui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ModelName,
ModelRegistry,
OnMessageCb,
OnMessageCbParam,
Point,
TextBlockParam,
TextCitationParam,
Expand Down Expand Up @@ -64,6 +65,7 @@
"ModelRegistry",
"ModifierKey",
"OnMessageCb",
"OnMessageCbParam",
"PcKey",
"Point",
"ResponseSchema",
Expand Down
3 changes: 2 additions & 1 deletion src/askui/chat/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from askui.chat.api.assistants.dependencies import get_assistant_service
from askui.chat.api.assistants.router import router as assistants_router
from askui.chat.api.dependencies import get_settings
from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings
from askui.chat.api.health.router import router as health_router
from askui.chat.api.messages.router import router as messages_router
from askui.chat.api.runs.router import router as runs_router
Expand All @@ -24,6 +24,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
title="AskUI Chat API",
version="0.1.0",
lifespan=lifespan,
dependencies=[SetEnvFromHeadersDep],
)

# Add CORS middleware
Expand Down
48 changes: 47 additions & 1 deletion src/askui/chat/api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from fastapi import Depends
import os
from typing import Annotated, Optional

from fastapi import Depends, Header
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
from pydantic import UUID4

from askui.chat.api.settings import Settings

Expand All @@ -9,3 +14,44 @@ def get_settings() -> Settings:


SettingsDep = Depends(get_settings)


http_bearer = HTTPBearer(scheme_name="Bearer", auto_error=False)
api_key_header = APIKeyHeader(
name="Authorization", auto_error=False, scheme_name="Basic"
)


def get_authorization(
bearer_auth: Annotated[
Optional[HTTPAuthorizationCredentials], Depends(http_bearer)
] = None,
api_key_auth: Annotated[Optional[str], Depends(api_key_header)] = None,
) -> Optional[str]:
if bearer_auth:
return f"{bearer_auth.scheme} {bearer_auth.credentials}"
if api_key_auth:
return api_key_auth
return None


def set_env_from_headers(
authorization: Annotated[Optional[str], Depends(get_authorization)] = None,
askui_workspace: Annotated[UUID4 | None, Header()] = None,
) -> None:
"""
Set environment variables from Authorization and AskUI-Workspace headers.

Args:
authorization (str | None, optional): Authorization header.
Defaults to `None`.
askui_workspace (UUID4 | None, optional): Workspace ID from AskUI-Workspace header.
Defaults to `None`.
"""
if authorization:
os.environ["ASKUI__AUTHORIZATION"] = authorization
if askui_workspace:
os.environ["ASKUI_WORKSPACE_ID"] = str(askui_workspace)


SetEnvFromHeadersDep = Depends(set_env_from_headers)
9 changes: 5 additions & 4 deletions src/askui/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
ModelDefinition,
ModelName,
ModelRegistry,
OnMessageCb,
Point,
)
from .openrouter.model import OpenRouterModel
Expand All @@ -28,11 +27,13 @@
ToolUseBlockParam,
UrlImageSourceParam,
)
from .shared.agent_on_message_cb import OnMessageCb, OnMessageCbParam

__all__ = [
"ActModel",
"Base64ImageSourceParam",
"CacheControlEphemeralParam",
"ChatCompletionsCreateSettings",
"CitationCharLocationParam",
"CitationContentBlockLocationParam",
"CitationPageLocationParam",
Expand All @@ -48,13 +49,13 @@
"ModelName",
"ModelRegistry",
"OnMessageCb",
"OnMessageCbParam",
"OpenRouterModel",
"OpenRouterSettings",
"Point",
"TextBlockParam",
"TextCitationParam",
"ToolResultBlockParam",
"ToolUseBlockParam",
"UrlImageSourceParam",
"OpenRouterModel",
"OpenRouterSettings",
"ChatCompletionsCreateSettings",
]
10 changes: 10 additions & 0 deletions src/askui/models/askui/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,21 @@ class AskUiInferenceApiSettings(BaseSettings):
validate_by_name=True,
env_prefix="ASKUI__",
env_nested_delimiter="__",
arbitrary_types_allowed=True,
)

inference_endpoint: HttpUrl = Field(
default_factory=lambda: HttpUrl("https://inference.askui.com"), # noqa: F821
validation_alias="ASKUI_INFERENCE_ENDPOINT",
)
messages: MessageSettings = Field(default_factory=MessageSettings)
authorization: str | NotGiven = Field(
default=NOT_GIVEN,
description=(
"The authorization header to use for the AskUI Inference API. "
"If not provided, the token will be used to generate the header."
),
)
token: SecretStr = Field(
default=...,
validation_alias="ASKUI_TOKEN",
Expand All @@ -61,6 +69,8 @@ class AskUiInferenceApiSettings(BaseSettings):

@property
def authorization_header(self) -> str:
if self.authorization:
return self.authorization
token_str = self.token.get_secret_value()
token_base64 = base64.b64encode(token_str.encode()).decode()
return f"Basic {token_base64}"
Expand Down