Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,40 @@ FileMetadata(

### Prompts

#### Save Prompt

Use `save()` to create or update a prompt by its storage path:

```python
from aidial_client.types.prompt import Prompt

prompt_url = "prompts/my-bucket/my-folder/my-prompt"
prompt_payload = Prompt(
id=prompt_url,
name="my-prompt",
folder_id="my-folder",
content="You are a helpful assistant.",
)

# Sync
saved_prompt = client.prompts.save(prompt_url, prompt=prompt_payload)
# Async
saved_prompt = await async_client.prompts.save(prompt_url, prompt=prompt_payload)
```

As a result, you will receive a `PromptMetadata` object:

```python
PromptMetadata(
name="my-prompt",
parent_path="my-folder",
bucket="my-bucket",
url="prompts/my-bucket/my-folder/my-prompt",
node_type="ITEM",
resource_type="PROMPT",
)
```

#### Get Prompt

Use `get()` to fetch a single prompt by its storage path:
Expand Down Expand Up @@ -640,6 +674,17 @@ PromptMetadata(
)
```

#### Delete Prompt

Use `delete()` to remove a prompt by its storage path:

```python
# Sync
client.prompts.delete("prompts/my-bucket/my-folder/my-prompt")
# Async
await async_client.prompts.delete("prompts/my-bucket/my-folder/my-prompt")
```

### Applications

#### List Applications
Expand Down
107 changes: 104 additions & 3 deletions aidial_client/resources/prompts.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from pathlib import PurePosixPath
from typing import Optional, Union
from typing import Any, Dict, Literal, Optional, Union
from urllib.parse import urljoin

import httpx

from aidial_client._compatibility.pydantic import PYDANTIC_V2
from aidial_client._constants import API_PREFIX
from aidial_client._exception import DialException, ResourceNotFoundError
from aidial_client._exception import (
DialException,
EtagMismatchError,
ResourceNotFoundError,
)
from aidial_client._internal_types._generic import NoneType
from aidial_client._internal_types._http_request import FinalRequestOptions
from aidial_client._utils._dict import remove_none
from aidial_client.helpers.storage_resource import DialStorageResourceMixin
from aidial_client.resources.base import AsyncResource, Resource
from aidial_client.resources.metadata import AsyncMetadata, Metadata
Expand All @@ -17,17 +24,50 @@
def _prompts_error_processor(
http_status_error: httpx.HTTPStatusError,
) -> Optional[DialException]:
if http_status_error.response.status_code == 404:
if http_status_error.response.status_code == 412:
return EtagMismatchError(
message=http_status_error.response.text,
)
elif http_status_error.response.status_code == 404:
return ResourceNotFoundError(
message=http_status_error.response.text,
)
return None


def _prompt_to_json(prompt: Prompt) -> Dict[str, Any]:
if PYDANTIC_V2:
return prompt.model_dump(by_alias=True) # type: ignore
return prompt.dict(by_alias=True)


class Prompts(Resource, DialStorageResourceMixin):
metadata: Metadata
resource_type: str = "prompts"

def save(
self,
url: Union[str, PurePosixPath],
prompt: Prompt,
etag_if_match: Optional[str] = None,
etag_if_none_match: Optional[Literal["*"]] = None,
) -> PromptMetadata:
return self.http_client.request(
cast_to=PromptMetadata,
options=FinalRequestOptions(
method="PUT",
url=urljoin(API_PREFIX, self.get_api_path(str(url))),
json_data=_prompt_to_json(prompt),
headers=remove_none(
{
"If-Match": etag_if_match,
"If-None-Match": etag_if_none_match,
}
),
),
on_http_error=_prompts_error_processor,
)

def get(self, url: Union[str, PurePosixPath]) -> Prompt:
"""Fetch a single prompt by its storage path."""
return self.http_client.request(
Expand All @@ -39,6 +79,25 @@ def get(self, url: Union[str, PurePosixPath]) -> Prompt:
on_http_error=_prompts_error_processor,
)

def delete(
self,
url: Union[str, PurePosixPath],
etag_if_match: Optional[str] = None,
) -> None:
return self.http_client.request(
cast_to=NoneType,
options=FinalRequestOptions(
method="DELETE",
url=urljoin(API_PREFIX, self.get_api_path(str(url))),
headers=remove_none(
{
"If-Match": etag_if_match,
}
),
),
on_http_error=_prompts_error_processor,
)

def get_metadata(self, url: Union[str, PurePosixPath]) -> PromptMetadata:
return self.metadata.get(
resource="prompts",
Expand All @@ -50,6 +109,29 @@ class AsyncPrompts(AsyncResource, DialStorageResourceMixin):
metadata: AsyncMetadata
resource_type: str = "prompts"

async def save(
self,
url: Union[str, PurePosixPath],
prompt: Prompt,
etag_if_match: Optional[str] = None,
etag_if_none_match: Optional[Literal["*"]] = None,
) -> PromptMetadata:
return await self.http_client.request(
cast_to=PromptMetadata,
options=FinalRequestOptions(
method="PUT",
url=urljoin(API_PREFIX, self.get_api_path(str(url))),
json_data=_prompt_to_json(prompt),
headers=remove_none(
{
"If-Match": etag_if_match,
"If-None-Match": etag_if_none_match,
}
),
),
on_http_error=_prompts_error_processor,
)

async def get(self, url: Union[str, PurePosixPath]) -> Prompt:
"""Fetch a single prompt by its storage path."""
return await self.http_client.request(
Expand All @@ -61,6 +143,25 @@ async def get(self, url: Union[str, PurePosixPath]) -> Prompt:
on_http_error=_prompts_error_processor,
)

async def delete(
self,
url: Union[str, PurePosixPath],
etag_if_match: Optional[str] = None,
) -> None:
return await self.http_client.request(
cast_to=NoneType,
options=FinalRequestOptions(
method="DELETE",
url=urljoin(API_PREFIX, self.get_api_path(str(url))),
headers=remove_none(
{
"If-Match": etag_if_match,
}
),
),
on_http_error=_prompts_error_processor,
)

async def get_metadata(
self, url: Union[str, PurePosixPath]
) -> PromptMetadata:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ exclude = [

[tool.black]
line-length = 80
target-version = ["py310"]
exclude = '''
/(
\.git
Expand Down
134 changes: 134 additions & 0 deletions tests/integration/test_async_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import uuid

import pytest

from aidial_client import AsyncDial
from aidial_client._exception import EtagMismatchError, ResourceNotFoundError
from aidial_client.types.metadata import PromptMetadata
from aidial_client.types.prompt import Prompt
from tests.integration.fixtures import * # type: ignore # noqa

PROMPT_FOLDER = "test-folder-artifacts"


async def _delete_if_exists(async_client: AsyncDial, url: str) -> None:
try:
await async_client.prompts.delete(url)
except ResourceNotFoundError:
pass


def _create_prompt(
url: str, content: str = "You are a helpful assistant."
) -> Prompt:
path_parts = url.split("/")
name = path_parts[-1]
folder_id = "/".join(path_parts[2:-1])
return Prompt(id=url, name=name, folder_id=folder_id, content=content)


def _get_etag_or_skip(metadata: PromptMetadata) -> str:
etag = getattr(metadata, "etag", None)
if not etag:
pytest.skip("Prompt metadata does not include etag in this environment")
return etag


@pytest.mark.asyncio
async def test_save_get_delete(async_client: AsyncDial):
prompt_name = f"test-prompt-{uuid.uuid4()}"
prompt_url = str(
await async_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}"
)
await _delete_if_exists(async_client, prompt_url)

save_result = await async_client.prompts.save(
url=prompt_url, prompt=_create_prompt(prompt_url)
)
assert isinstance(save_result, PromptMetadata)
assert save_result.node_type == "ITEM"
assert save_result.bucket == await async_client.my_bucket()
assert save_result.name == prompt_name

prompt = await async_client.prompts.get(prompt_url)
assert prompt.name == prompt_name
assert prompt.content == "You are a helpful assistant."

await async_client.prompts.delete(prompt_url)
with pytest.raises(ResourceNotFoundError):
await async_client.prompts.get(prompt_url)


@pytest.mark.asyncio
async def test_save_with_etag_if_match(async_client: AsyncDial):
prompt_name = f"test-prompt-{uuid.uuid4()}"
prompt_url = str(
await async_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}"
)
await _delete_if_exists(async_client, prompt_url)

first_save = await async_client.prompts.save(
url=prompt_url, prompt=_create_prompt(prompt_url, content="v1")
)
first_etag = _get_etag_or_skip(first_save)

second_save = await async_client.prompts.save(
url=prompt_url,
prompt=_create_prompt(prompt_url, content="v2"),
etag_if_match=first_etag,
)
assert _get_etag_or_skip(second_save) != first_etag

with pytest.raises(EtagMismatchError):
await async_client.prompts.save(
url=prompt_url,
prompt=_create_prompt(prompt_url, content="v3"),
etag_if_match="invalid_etag",
)


@pytest.mark.asyncio
async def test_save_with_etag_if_none_match(async_client: AsyncDial):
prompt_name = f"test-prompt-{uuid.uuid4()}"
prompt_url = str(
await async_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}"
)
await _delete_if_exists(async_client, prompt_url)

await async_client.prompts.save(
url=prompt_url,
prompt=_create_prompt(prompt_url),
etag_if_none_match="*",
)

with pytest.raises(EtagMismatchError):
await async_client.prompts.save(
url=prompt_url,
prompt=_create_prompt(prompt_url, content="v2"),
etag_if_none_match="*",
)


@pytest.mark.asyncio
async def test_delete_with_etag(async_client: AsyncDial):
prompt_name = f"test-prompt-{uuid.uuid4()}"
prompt_url = str(
await async_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}"
)
await _delete_if_exists(async_client, prompt_url)

save_result = await async_client.prompts.save(
url=prompt_url, prompt=_create_prompt(prompt_url)
)
etag = _get_etag_or_skip(save_result)

with pytest.raises(EtagMismatchError):
await async_client.prompts.delete(
url=prompt_url,
etag_if_match="invalid_etag",
)

await async_client.prompts.delete(
url=prompt_url,
etag_if_match=etag,
)
Loading
Loading