Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions .flake8

This file was deleted.

9 changes: 5 additions & 4 deletions .vscode/extensions.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
{
"recommendations": [
"ms-python.python",
"ms-python.black-formatter",
"ms-python.isort",
"DavidAnson.vscode-markdownlint"
"charliermarsh.ruff",
"tamasfe.even-better-toml",
"davidanson.vscode-markdownlint",
"streetsidesoftware.code-spell-checker"
]
}
}
26 changes: 18 additions & 8 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
{
"files.trimTrailingWhitespace": true,
"[toml]": {
"editor.formatOnSave": true,
"editor.tabSize": 4,
},
"[json]": {
"editor.formatOnSave": true,
},
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.tabSize": 4,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
"editor.tabSize": 4
"source.organizeImports.ruff": "explicit",
"source.fixAll.ruff": "explicit"
}
},
"python.testing.pytestArgs": ["."],
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.typeCheckingMode": "basic"
"files.insertFinalNewline": true,
"files.trimFinalNewlines": true,
"files.trimTrailingWhitespace": true,
}

3 changes: 1 addition & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ VENV_DIR ?= .venv
POETRY ?= poetry
POETRY_PYTHON ?= python

.PHONY: all init_env install clean lint format test spell_check
.PHONY: all init_env install clean lint format test

-include .env
export
Expand Down Expand Up @@ -51,4 +51,3 @@ help:
@echo '-- LINTING --'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'spell_check - run spell check'
40 changes: 23 additions & 17 deletions aidial_client/_auth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from collections.abc import Awaitable, Callable
from inspect import isawaitable
from typing import Awaitable, Callable, Dict, Optional, TypeVar, Union
from typing import TypeVar

SyncAuthValue = Union[str, Callable[[], str]]
AsyncAuthValue = Union[SyncAuthValue, Callable[[], Awaitable[str]]]
SyncAuthValue = str | Callable[[], str]
AsyncAuthValue = SyncAuthValue | Callable[[], Awaitable[str]]

AuthValueT = TypeVar(
"AuthValueT",
bound=Union[SyncAuthValue, AsyncAuthValue],
bound=SyncAuthValue | AsyncAuthValue,
)


Expand All @@ -20,7 +21,8 @@ def get_auth_value(auth_value: SyncAuthValue) -> str:
if TYPE_CHECKING:
assert_never(auth_value)
raise TypeError(
f"auth_value must be a string or a callable returning a string, got {type(auth_value).__name__}"
f"auth_value must be a string or a callable returning a string, "
f"got {type(auth_value).__name__}"
)


Expand All @@ -35,16 +37,17 @@ async def aget_auth_value(auth_value: AsyncAuthValue) -> str:
if TYPE_CHECKING:
assert_never(auth_value)
raise TypeError(
f"auth_value must be a string or a callable, got {type(auth_value).__name__}"
"auth_value must be a string or a callable, "
f"got {type(auth_value).__name__}"
)


def get_combined_auth_headers(
*,
api_key: Optional[SyncAuthValue] = None,
bearer_token: Optional[SyncAuthValue] = None,
) -> Dict[str, str]:
headers: Dict[str, str] = {}
api_key: SyncAuthValue | None = None,
bearer_token: SyncAuthValue | None = None,
) -> dict[str, str]:
headers: dict[str, str] = {}

if api_key is not None:
headers["api-key"] = get_auth_value(api_key)
Expand All @@ -58,11 +61,14 @@ def get_combined_auth_headers(

async def aget_combined_auth_headers(
*,
api_key: Optional[AsyncAuthValue] = None,
bearer_token: Optional[AsyncAuthValue] = None,
) -> Dict[str, str]:
"""Get combined authentication headers from both api_key and bearer_token (async)."""
headers: Dict[str, str] = {}
api_key: AsyncAuthValue | None = None,
bearer_token: AsyncAuthValue | None = None,
) -> dict[str, str]:
"""
Get combined authentication headers from both api_key and
bearer_token (async).
"""
headers: dict[str, str] = {}

if api_key is not None:
processed_api_key = await aget_auth_value(api_key)
Expand All @@ -77,8 +83,8 @@ async def aget_combined_auth_headers(

def validate_auth(
*,
api_key: Optional[AsyncAuthValue] = None,
bearer_token: Optional[AsyncAuthValue] = None,
api_key: AsyncAuthValue | None = None,
bearer_token: AsyncAuthValue | None = None,
) -> None:
"""Validate that at least one authentication method is provided."""
if not api_key and not bearer_token:
Expand Down
46 changes: 21 additions & 25 deletions aidial_client/_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import PurePosixPath
from typing import Dict, Generic, Optional, TypeVar, Union
from typing import Generic, TypeVar
from urllib.parse import urljoin

import openai
Expand All @@ -24,30 +24,28 @@
from aidial_client.helpers._url import enforce_trailing_slash
from aidial_client.types.bucket import AppData

_HttpClientT = TypeVar(
"_HttpClientT", bound=Union[AsyncHTTPClient, SyncHTTPClient]
)
_HttpClientT = TypeVar("_HttpClientT", bound=AsyncHTTPClient | SyncHTTPClient)


class BaseDialClient(Generic[_HttpClientT, AuthValueT], ABC):
_api_key: Optional[AuthValueT]
_bearer_token: Optional[AuthValueT]
_api_key: AuthValueT | None
_bearer_token: AuthValueT | None
_base_url: str
_http_client: _HttpClientT
_auth_headers: Dict[str, str]
_my_bucket: Optional[str]
_my_appdata: Union[AppData, None, NotGiven]
_auth_headers: dict[str, str]
_my_bucket: str | None
_my_appdata: AppData | None | NotGiven

def __init__(
self,
*,
base_url: str,
api_key: Optional[AuthValueT] = None,
bearer_token: Optional[AuthValueT] = None,
api_key: AuthValueT | None = None,
bearer_token: AuthValueT | None = None,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: Union[float, Timeout, None] = DEFAULT_TIMEOUT,
api_version: Optional[str] = None,
http_client: Optional[_HttpClientT] = None,
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
api_version: str | None = None,
http_client: _HttpClientT | None = None,
):
validate_auth(api_key=api_key, bearer_token=bearer_token)
self._api_key = api_key
Expand Down Expand Up @@ -79,12 +77,11 @@ def base_url(self) -> str:
return self._base_url

@property
def api_version(self) -> Optional[str]:
def api_version(self) -> str | None:
return self._api_version


class Dial(BaseDialClient[SyncHTTPClient, SyncAuthValue]):

def _init_resources(self) -> None:
openai_client = openai.AzureOpenAI(
api_key="-",
Expand Down Expand Up @@ -150,26 +147,25 @@ def my_conversations_home(self) -> PurePosixPath:
def my_prompts_home(self) -> PurePosixPath:
return "prompts" / PurePosixPath(self.my_bucket())

def _get_my_appdata(self) -> Optional[AppData]:
def _get_my_appdata(self) -> AppData | None:
return self.bucket.get_appdata()

def my_appdata(self) -> Optional[AppData]:
def my_appdata(self) -> AppData | None:
if isinstance(self._my_appdata, NotGiven):
self._my_appdata = self._get_my_appdata()
return self._my_appdata

def my_appdata_home(self) -> Optional[PurePosixPath]:
def my_appdata_home(self) -> PurePosixPath | None:
appdata = self.my_appdata()
if appdata:
return PurePosixPath(appdata.raw)
return None

def auth_headers(self) -> Dict[str, str]:
def auth_headers(self) -> dict[str, str]:
return self._http_client.auth_headers()


class AsyncDial(BaseDialClient[AsyncHTTPClient, AsyncAuthValue]):

def _init_resources(self) -> None:
openai_client = openai.AsyncAzureOpenAI(
api_key="-",
Expand Down Expand Up @@ -241,19 +237,19 @@ async def my_conversations_home(self) -> PurePosixPath:
async def my_prompts_home(self) -> PurePosixPath:
return "prompts" / PurePosixPath(await self.my_bucket())

async def _get_my_appdata(self) -> Optional[AppData]:
async def _get_my_appdata(self) -> AppData | None:
return await self.bucket.get_appdata()

async def my_appdata(self) -> Optional[AppData]:
async def my_appdata(self) -> AppData | None:
if isinstance(self._my_appdata, NotGiven):
self._my_appdata = await self._get_my_appdata()
return self._my_appdata

async def my_appdata_home(self) -> Optional[PurePosixPath]:
async def my_appdata_home(self) -> PurePosixPath | None:
appdata = await self.my_appdata()
if appdata:
return PurePosixPath(appdata.raw)
return None

async def auth_headers(self) -> Dict[str, str]:
async def auth_headers(self) -> dict[str, str]:
return await self._http_client.auth_headers()
14 changes: 6 additions & 8 deletions aidial_client/_client_pool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional, Union

import httpx

from aidial_client._auth import AsyncAuthValue, SyncAuthValue
Expand Down Expand Up @@ -27,10 +25,10 @@ def create_client(
self,
*,
base_url: str,
api_key: Optional[SyncAuthValue] = None,
bearer_token: Optional[SyncAuthValue] = None,
api_key: SyncAuthValue | None = None,
bearer_token: SyncAuthValue | None = None,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: Union[httpx.Timeout, float] = DEFAULT_TIMEOUT,
timeout: httpx.Timeout | float = DEFAULT_TIMEOUT,
) -> Dial:
return Dial(
base_url=base_url,
Expand Down Expand Up @@ -62,10 +60,10 @@ def create_client(
self,
*,
base_url: str,
api_key: Optional[AsyncAuthValue] = None,
bearer_token: Optional[AsyncAuthValue] = None,
api_key: AsyncAuthValue | None = None,
bearer_token: AsyncAuthValue | None = None,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: Union[httpx.Timeout, float] = DEFAULT_TIMEOUT,
timeout: httpx.Timeout | float = DEFAULT_TIMEOUT,
) -> AsyncDial:
return AsyncDial(
base_url=base_url,
Expand Down
10 changes: 5 additions & 5 deletions aidial_client/_exception.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from collections.abc import Mapping
from http import HTTPStatus
from typing import Mapping, Optional


class DialException(Exception):
def __init__(
self,
message: str,
status_code: int = 500,
type: Optional[str] = "runtime_error",
param: Optional[str] = None,
code: Optional[str] = None,
display_message: Optional[str] = None,
type: str | None = "runtime_error",
param: str | None = None,
code: str | None = None,
display_message: str | None = None,
) -> None:
self.message = message
self.status_code = status_code
Expand Down
Loading
Loading