diff --git a/README.md b/README.md index 2be0da4..f561b46 100644 --- a/README.md +++ b/README.md @@ -591,6 +591,40 @@ FileMetadata( ### Prompts +#### Save Prompt + +Use `save()` to create or update a prompt by its storage path: + +```python +from aidial_client.types.prompt import Prompt + +prompt_url = "prompts/my-bucket/my-folder/my-prompt" +prompt_payload = Prompt( + id=prompt_url, + name="my-prompt", + folder_id="my-folder", + content="You are a helpful assistant.", +) + +# Sync +saved_prompt = client.prompts.save(prompt_url, prompt=prompt_payload) +# Async +saved_prompt = await async_client.prompts.save(prompt_url, prompt=prompt_payload) +``` + +As a result, you will receive a `PromptMetadata` object: + +```python +PromptMetadata( + name="my-prompt", + parent_path="my-folder", + bucket="my-bucket", + url="prompts/my-bucket/my-folder/my-prompt", + node_type="ITEM", + resource_type="PROMPT", +) +``` + #### Get Prompt Use `get()` to fetch a single prompt by its storage path: @@ -640,6 +674,17 @@ PromptMetadata( ) ``` +#### Delete Prompt + +Use `delete()` to remove a prompt by its storage path: + +```python +# Sync +client.prompts.delete("prompts/my-bucket/my-folder/my-prompt") +# Async +await async_client.prompts.delete("prompts/my-bucket/my-folder/my-prompt") +``` + ### Applications #### List Applications diff --git a/aidial_client/resources/prompts.py b/aidial_client/resources/prompts.py index 01c763e..9690d08 100644 --- a/aidial_client/resources/prompts.py +++ b/aidial_client/resources/prompts.py @@ -1,12 +1,19 @@ from pathlib import PurePosixPath -from typing import Optional, Union +from typing import Any, Dict, Literal, Optional, Union from urllib.parse import urljoin import httpx +from aidial_client._compatibility.pydantic import PYDANTIC_V2 from aidial_client._constants import API_PREFIX -from aidial_client._exception import DialException, ResourceNotFoundError +from aidial_client._exception import ( + DialException, + EtagMismatchError, + ResourceNotFoundError, +) +from aidial_client._internal_types._generic import NoneType from aidial_client._internal_types._http_request import FinalRequestOptions +from aidial_client._utils._dict import remove_none from aidial_client.helpers.storage_resource import DialStorageResourceMixin from aidial_client.resources.base import AsyncResource, Resource from aidial_client.resources.metadata import AsyncMetadata, Metadata @@ -17,17 +24,50 @@ def _prompts_error_processor( http_status_error: httpx.HTTPStatusError, ) -> Optional[DialException]: - if http_status_error.response.status_code == 404: + if http_status_error.response.status_code == 412: + return EtagMismatchError( + message=http_status_error.response.text, + ) + elif http_status_error.response.status_code == 404: return ResourceNotFoundError( message=http_status_error.response.text, ) return None +def _prompt_to_json(prompt: Prompt) -> Dict[str, Any]: + if PYDANTIC_V2: + return prompt.model_dump(by_alias=True) # type: ignore + return prompt.dict(by_alias=True) + + class Prompts(Resource, DialStorageResourceMixin): metadata: Metadata resource_type: str = "prompts" + def save( + self, + url: Union[str, PurePosixPath], + prompt: Prompt, + etag_if_match: Optional[str] = None, + etag_if_none_match: Optional[Literal["*"]] = None, + ) -> PromptMetadata: + return self.http_client.request( + cast_to=PromptMetadata, + options=FinalRequestOptions( + method="PUT", + url=urljoin(API_PREFIX, self.get_api_path(str(url))), + json_data=_prompt_to_json(prompt), + headers=remove_none( + { + "If-Match": etag_if_match, + "If-None-Match": etag_if_none_match, + } + ), + ), + on_http_error=_prompts_error_processor, + ) + def get(self, url: Union[str, PurePosixPath]) -> Prompt: """Fetch a single prompt by its storage path.""" return self.http_client.request( @@ -39,6 +79,25 @@ def get(self, url: Union[str, PurePosixPath]) -> Prompt: on_http_error=_prompts_error_processor, ) + def delete( + self, + url: Union[str, PurePosixPath], + etag_if_match: Optional[str] = None, + ) -> None: + return self.http_client.request( + cast_to=NoneType, + options=FinalRequestOptions( + method="DELETE", + url=urljoin(API_PREFIX, self.get_api_path(str(url))), + headers=remove_none( + { + "If-Match": etag_if_match, + } + ), + ), + on_http_error=_prompts_error_processor, + ) + def get_metadata(self, url: Union[str, PurePosixPath]) -> PromptMetadata: return self.metadata.get( resource="prompts", @@ -50,6 +109,29 @@ class AsyncPrompts(AsyncResource, DialStorageResourceMixin): metadata: AsyncMetadata resource_type: str = "prompts" + async def save( + self, + url: Union[str, PurePosixPath], + prompt: Prompt, + etag_if_match: Optional[str] = None, + etag_if_none_match: Optional[Literal["*"]] = None, + ) -> PromptMetadata: + return await self.http_client.request( + cast_to=PromptMetadata, + options=FinalRequestOptions( + method="PUT", + url=urljoin(API_PREFIX, self.get_api_path(str(url))), + json_data=_prompt_to_json(prompt), + headers=remove_none( + { + "If-Match": etag_if_match, + "If-None-Match": etag_if_none_match, + } + ), + ), + on_http_error=_prompts_error_processor, + ) + async def get(self, url: Union[str, PurePosixPath]) -> Prompt: """Fetch a single prompt by its storage path.""" return await self.http_client.request( @@ -61,6 +143,25 @@ async def get(self, url: Union[str, PurePosixPath]) -> Prompt: on_http_error=_prompts_error_processor, ) + async def delete( + self, + url: Union[str, PurePosixPath], + etag_if_match: Optional[str] = None, + ) -> None: + return await self.http_client.request( + cast_to=NoneType, + options=FinalRequestOptions( + method="DELETE", + url=urljoin(API_PREFIX, self.get_api_path(str(url))), + headers=remove_none( + { + "If-Match": etag_if_match, + } + ), + ), + on_http_error=_prompts_error_processor, + ) + async def get_metadata( self, url: Union[str, PurePosixPath] ) -> PromptMetadata: diff --git a/pyproject.toml b/pyproject.toml index 8cec8b6..da90a45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ exclude = [ [tool.black] line-length = 80 +target-version = ["py310"] exclude = ''' /( \.git diff --git a/tests/integration/test_async_prompts.py b/tests/integration/test_async_prompts.py new file mode 100644 index 0000000..4b54780 --- /dev/null +++ b/tests/integration/test_async_prompts.py @@ -0,0 +1,134 @@ +import uuid + +import pytest + +from aidial_client import AsyncDial +from aidial_client._exception import EtagMismatchError, ResourceNotFoundError +from aidial_client.types.metadata import PromptMetadata +from aidial_client.types.prompt import Prompt +from tests.integration.fixtures import * # type: ignore # noqa + +PROMPT_FOLDER = "test-folder-artifacts" + + +async def _delete_if_exists(async_client: AsyncDial, url: str) -> None: + try: + await async_client.prompts.delete(url) + except ResourceNotFoundError: + pass + + +def _create_prompt( + url: str, content: str = "You are a helpful assistant." +) -> Prompt: + path_parts = url.split("/") + name = path_parts[-1] + folder_id = "/".join(path_parts[2:-1]) + return Prompt(id=url, name=name, folder_id=folder_id, content=content) + + +def _get_etag_or_skip(metadata: PromptMetadata) -> str: + etag = getattr(metadata, "etag", None) + if not etag: + pytest.skip("Prompt metadata does not include etag in this environment") + return etag + + +@pytest.mark.asyncio +async def test_save_get_delete(async_client: AsyncDial): + prompt_name = f"test-prompt-{uuid.uuid4()}" + prompt_url = str( + await async_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}" + ) + await _delete_if_exists(async_client, prompt_url) + + save_result = await async_client.prompts.save( + url=prompt_url, prompt=_create_prompt(prompt_url) + ) + assert isinstance(save_result, PromptMetadata) + assert save_result.node_type == "ITEM" + assert save_result.bucket == await async_client.my_bucket() + assert save_result.name == prompt_name + + prompt = await async_client.prompts.get(prompt_url) + assert prompt.name == prompt_name + assert prompt.content == "You are a helpful assistant." + + await async_client.prompts.delete(prompt_url) + with pytest.raises(ResourceNotFoundError): + await async_client.prompts.get(prompt_url) + + +@pytest.mark.asyncio +async def test_save_with_etag_if_match(async_client: AsyncDial): + prompt_name = f"test-prompt-{uuid.uuid4()}" + prompt_url = str( + await async_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}" + ) + await _delete_if_exists(async_client, prompt_url) + + first_save = await async_client.prompts.save( + url=prompt_url, prompt=_create_prompt(prompt_url, content="v1") + ) + first_etag = _get_etag_or_skip(first_save) + + second_save = await async_client.prompts.save( + url=prompt_url, + prompt=_create_prompt(prompt_url, content="v2"), + etag_if_match=first_etag, + ) + assert _get_etag_or_skip(second_save) != first_etag + + with pytest.raises(EtagMismatchError): + await async_client.prompts.save( + url=prompt_url, + prompt=_create_prompt(prompt_url, content="v3"), + etag_if_match="invalid_etag", + ) + + +@pytest.mark.asyncio +async def test_save_with_etag_if_none_match(async_client: AsyncDial): + prompt_name = f"test-prompt-{uuid.uuid4()}" + prompt_url = str( + await async_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}" + ) + await _delete_if_exists(async_client, prompt_url) + + await async_client.prompts.save( + url=prompt_url, + prompt=_create_prompt(prompt_url), + etag_if_none_match="*", + ) + + with pytest.raises(EtagMismatchError): + await async_client.prompts.save( + url=prompt_url, + prompt=_create_prompt(prompt_url, content="v2"), + etag_if_none_match="*", + ) + + +@pytest.mark.asyncio +async def test_delete_with_etag(async_client: AsyncDial): + prompt_name = f"test-prompt-{uuid.uuid4()}" + prompt_url = str( + await async_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}" + ) + await _delete_if_exists(async_client, prompt_url) + + save_result = await async_client.prompts.save( + url=prompt_url, prompt=_create_prompt(prompt_url) + ) + etag = _get_etag_or_skip(save_result) + + with pytest.raises(EtagMismatchError): + await async_client.prompts.delete( + url=prompt_url, + etag_if_match="invalid_etag", + ) + + await async_client.prompts.delete( + url=prompt_url, + etag_if_match=etag, + ) diff --git a/tests/integration/test_sync_prompts.py b/tests/integration/test_sync_prompts.py new file mode 100644 index 0000000..d591f23 --- /dev/null +++ b/tests/integration/test_sync_prompts.py @@ -0,0 +1,130 @@ +import uuid + +import pytest + +from aidial_client import Dial +from aidial_client._exception import EtagMismatchError, ResourceNotFoundError +from aidial_client.types.metadata import PromptMetadata +from aidial_client.types.prompt import Prompt +from tests.integration.fixtures import * # type: ignore # noqa + +PROMPT_FOLDER = "test-folder-artifacts" + + +def _delete_if_exists(sync_client: Dial, url: str) -> None: + try: + sync_client.prompts.delete(url) + except ResourceNotFoundError: + pass + + +def _create_prompt( + url: str, content: str = "You are a helpful assistant." +) -> Prompt: + path_parts = url.split("/") + name = path_parts[-1] + folder_id = "/".join(path_parts[2:-1]) + return Prompt(id=url, name=name, folder_id=folder_id, content=content) + + +def _get_etag_or_skip(metadata: PromptMetadata) -> str: + etag = getattr(metadata, "etag", None) + if not etag: + pytest.skip("Prompt metadata does not include etag in this environment") + return etag + + +def test_save_get_delete(sync_client: Dial): + prompt_name = f"test-prompt-{uuid.uuid4()}" + prompt_url = str( + sync_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}" + ) + _delete_if_exists(sync_client, prompt_url) + + save_result = sync_client.prompts.save( + url=prompt_url, prompt=_create_prompt(prompt_url) + ) + assert isinstance(save_result, PromptMetadata) + assert save_result.node_type == "ITEM" + assert save_result.bucket == sync_client.my_bucket() + assert save_result.name == prompt_name + + prompt = sync_client.prompts.get(prompt_url) + assert prompt.name == prompt_name + assert prompt.content == "You are a helpful assistant." + + sync_client.prompts.delete(prompt_url) + with pytest.raises(ResourceNotFoundError): + sync_client.prompts.get(prompt_url) + + +def test_save_with_etag_if_match(sync_client: Dial): + prompt_name = f"test-prompt-{uuid.uuid4()}" + prompt_url = str( + sync_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}" + ) + _delete_if_exists(sync_client, prompt_url) + + first_save = sync_client.prompts.save( + url=prompt_url, prompt=_create_prompt(prompt_url, content="v1") + ) + first_etag = _get_etag_or_skip(first_save) + + second_save = sync_client.prompts.save( + url=prompt_url, + prompt=_create_prompt(prompt_url, content="v2"), + etag_if_match=first_etag, + ) + assert _get_etag_or_skip(second_save) != first_etag + + with pytest.raises(EtagMismatchError): + sync_client.prompts.save( + url=prompt_url, + prompt=_create_prompt(prompt_url, content="v3"), + etag_if_match="invalid_etag", + ) + + +def test_save_with_etag_if_none_match(sync_client: Dial): + prompt_name = f"test-prompt-{uuid.uuid4()}" + prompt_url = str( + sync_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}" + ) + _delete_if_exists(sync_client, prompt_url) + + sync_client.prompts.save( + url=prompt_url, + prompt=_create_prompt(prompt_url), + etag_if_none_match="*", + ) + + with pytest.raises(EtagMismatchError): + sync_client.prompts.save( + url=prompt_url, + prompt=_create_prompt(prompt_url, content="v2"), + etag_if_none_match="*", + ) + + +def test_delete_with_etag(sync_client: Dial): + prompt_name = f"test-prompt-{uuid.uuid4()}" + prompt_url = str( + sync_client.my_prompts_home() / f"{PROMPT_FOLDER}/{prompt_name}" + ) + _delete_if_exists(sync_client, prompt_url) + + save_result = sync_client.prompts.save( + url=prompt_url, prompt=_create_prompt(prompt_url) + ) + etag = _get_etag_or_skip(save_result) + + with pytest.raises(EtagMismatchError): + sync_client.prompts.delete( + url=prompt_url, + etag_if_match="invalid_etag", + ) + + sync_client.prompts.delete( + url=prompt_url, + etag_if_match=etag, + ) diff --git a/tests/resources/test_prompts.py b/tests/resources/test_prompts.py index 2e91737..9bf6453 100644 --- a/tests/resources/test_prompts.py +++ b/tests/resources/test_prompts.py @@ -1,6 +1,18 @@ +import json +from typing import Any, Dict, List +from unittest.mock import AsyncMock, Mock + +import httpx import pytest -from aidial_client._exception import DialException, ResourceNotFoundError +from aidial_client import Dial +from aidial_client._client import AsyncDial +from aidial_client._exception import ( + DialException, + EtagMismatchError, + InvalidDialURLError, + ResourceNotFoundError, +) from aidial_client.types.metadata import PromptMetadata from aidial_client.types.prompt import Prompt from tests.client_mock import get_async_client_mock, get_client_mock @@ -29,6 +41,44 @@ } +def _make_capturing_client(captured: List[httpx.Request]) -> Dial: + client = Dial(api_key="dummy", base_url="http://dial.core") + + def send_mock(request: httpx.Request, **_: Any) -> httpx.Response: + captured.append(request) + response = httpx.Response( + status_code=200, request=request, json=PROMPT_METADATA_MOCK + ) + response.request = request + return response + + client._http_client._internal_http_client.send = send_mock + client._get_my_bucket = Mock(return_value="test-bucket") + return client + + +def _body(request: httpx.Request) -> Dict[str, Any]: + return json.loads(request.content.decode()) + + +def _make_async_capturing_client( + captured: List[httpx.Request], +) -> AsyncDial: + client = AsyncDial(api_key="dummy", base_url="http://dial.core") + + async def send_mock(request: httpx.Request, **_: Any) -> httpx.Response: + captured.append(request) + response = httpx.Response( + status_code=200, request=request, json=PROMPT_METADATA_MOCK + ) + response.request = request + return response + + client._http_client._internal_http_client.send = send_mock + client._get_my_bucket = AsyncMock(return_value="test-bucket") + return client + + # --------------------------------------------------------------------------- # prompts.get() # --------------------------------------------------------------------------- @@ -146,3 +196,263 @@ async def test_async_get_prompt_metadata(): assert isinstance(result, PromptMetadata) assert result.node_type == "ITEM" assert result.bucket == "test-bucket" + + +# --------------------------------------------------------------------------- +# prompts.save() +# --------------------------------------------------------------------------- + + +def test_save_prompt(): + client = get_client_mock(status_code=200, json_mock=PROMPT_METADATA_MOCK) + prompt = Prompt(**PROMPT_MOCK) + + result = client.prompts.save( + "prompts/test-bucket/my-folder/my-prompt", prompt=prompt + ) + + assert isinstance(result, PromptMetadata) + assert result.node_type == "ITEM" + assert result.bucket == "test-bucket" + + +@pytest.mark.asyncio +async def test_async_save_prompt(): + client = get_async_client_mock( + status_code=200, json_mock=PROMPT_METADATA_MOCK + ) + prompt = Prompt(**PROMPT_MOCK) + + result = await client.prompts.save( + "prompts/test-bucket/my-folder/my-prompt", prompt=prompt + ) + + assert isinstance(result, PromptMetadata) + assert result.node_type == "ITEM" + assert result.bucket == "test-bucket" + + +def test_save_prompt_sends_json_and_etag_headers(): + captured: List[httpx.Request] = [] + client = _make_capturing_client(captured) + prompt = Prompt(**PROMPT_MOCK) + + result = client.prompts.save( + url=client.my_prompts_home() / "my-folder/my-prompt", + prompt=prompt, + etag_if_match="etag-1", + etag_if_none_match="*", + ) + + assert isinstance(result, PromptMetadata) + assert len(captured) == 1 + request = captured[0] + assert request.method == "PUT" + assert request.url.path == "/v1/prompts/test-bucket/my-folder/my-prompt" + assert request.headers["if-match"] == "etag-1" + assert request.headers["if-none-match"] == "*" + assert _body(request) == { + "id": "prompts/test-bucket/my-folder/my-prompt", + "name": "my-prompt", + "folderId": "my-folder", + "content": "You are a helpful assistant.", + } + + +@pytest.mark.asyncio +async def test_async_save_prompt_sends_json_and_etag_headers(): + captured: List[httpx.Request] = [] + client = _make_async_capturing_client(captured) + prompt = Prompt(**PROMPT_MOCK) + + result = await client.prompts.save( + url=await client.my_prompts_home() / "my-folder/my-prompt", + prompt=prompt, + etag_if_match="etag-1", + etag_if_none_match="*", + ) + + assert isinstance(result, PromptMetadata) + assert len(captured) == 1 + request = captured[0] + assert request.method == "PUT" + assert request.url.path == "/v1/prompts/test-bucket/my-folder/my-prompt" + assert request.headers["if-match"] == "etag-1" + assert request.headers["if-none-match"] == "*" + assert _body(request) == { + "id": "prompts/test-bucket/my-folder/my-prompt", + "name": "my-prompt", + "folderId": "my-folder", + "content": "You are a helpful assistant.", + } + + +def test_save_prompt_etag_mismatch(): + client = get_client_mock( + status_code=412, + json_mock={ + "error": { + "message": "Precondition Failed", + "type": "etag_mismatch", + } + }, + ) + prompt = Prompt(**PROMPT_MOCK) + + with pytest.raises(EtagMismatchError): + client.prompts.save( + "prompts/test-bucket/my-folder/my-prompt", + prompt=prompt, + etag_if_match="invalid_etag", + ) + + +@pytest.mark.asyncio +async def test_async_save_prompt_etag_mismatch(): + client = get_async_client_mock( + status_code=412, + json_mock={ + "error": { + "message": "Precondition Failed", + "type": "etag_mismatch", + } + }, + ) + prompt = Prompt(**PROMPT_MOCK) + + with pytest.raises(EtagMismatchError): + await client.prompts.save( + "prompts/test-bucket/my-folder/my-prompt", + prompt=prompt, + etag_if_match="invalid_etag", + ) + + +def test_save_prompt_rejects_non_prompt_url(): + client = get_client_mock(status_code=200, json_mock=PROMPT_METADATA_MOCK) + prompt = Prompt(**PROMPT_MOCK) + + with pytest.raises(InvalidDialURLError, match="Invalid resource type"): + client.prompts.save("files/test-bucket/my-folder/my-prompt", prompt) + + +@pytest.mark.asyncio +async def test_async_save_prompt_rejects_non_prompt_url(): + client = get_async_client_mock( + status_code=200, json_mock=PROMPT_METADATA_MOCK + ) + prompt = Prompt(**PROMPT_MOCK) + + with pytest.raises(InvalidDialURLError, match="Invalid resource type"): + await client.prompts.save( + "files/test-bucket/my-folder/my-prompt", prompt + ) + + +# --------------------------------------------------------------------------- +# prompts.delete() +# --------------------------------------------------------------------------- + + +def test_delete_prompt(): + client = get_client_mock(status_code=200, json_mock={}) + + result = client.prompts.delete("prompts/test-bucket/my-folder/my-prompt") + assert result is None + + +@pytest.mark.asyncio +async def test_async_delete_prompt(): + client = get_async_client_mock(status_code=200, json_mock={}) + + result = await client.prompts.delete( + "prompts/test-bucket/my-folder/my-prompt" + ) + assert result is None + + +def test_delete_prompt_sends_etag_header(): + captured: List[httpx.Request] = [] + client = _make_capturing_client(captured) + + result = client.prompts.delete( + url=client.my_prompts_home() / "my-folder/my-prompt", + etag_if_match="etag-1", + ) + + assert result is None + assert len(captured) == 1 + request = captured[0] + assert request.method == "DELETE" + assert request.url.path == "/v1/prompts/test-bucket/my-folder/my-prompt" + assert request.headers["if-match"] == "etag-1" + + +@pytest.mark.asyncio +async def test_async_delete_prompt_sends_etag_header(): + captured: List[httpx.Request] = [] + client = _make_async_capturing_client(captured) + + result = await client.prompts.delete( + url=await client.my_prompts_home() / "my-folder/my-prompt", + etag_if_match="etag-1", + ) + + assert result is None + assert len(captured) == 1 + request = captured[0] + assert request.method == "DELETE" + assert request.url.path == "/v1/prompts/test-bucket/my-folder/my-prompt" + assert request.headers["if-match"] == "etag-1" + + +def test_delete_prompt_etag_mismatch(): + client = get_client_mock( + status_code=412, + json_mock={ + "error": { + "message": "Precondition Failed", + "type": "etag_mismatch", + } + }, + ) + + with pytest.raises(EtagMismatchError): + client.prompts.delete( + "prompts/test-bucket/my-folder/my-prompt", + etag_if_match="invalid_etag", + ) + + +@pytest.mark.asyncio +async def test_async_delete_prompt_etag_mismatch(): + client = get_async_client_mock( + status_code=412, + json_mock={ + "error": { + "message": "Precondition Failed", + "type": "etag_mismatch", + } + }, + ) + + with pytest.raises(EtagMismatchError): + await client.prompts.delete( + "prompts/test-bucket/my-folder/my-prompt", + etag_if_match="invalid_etag", + ) + + +def test_delete_prompt_rejects_non_prompt_url(): + client = get_client_mock(status_code=200, json_mock={}) + + with pytest.raises(InvalidDialURLError, match="Invalid resource type"): + client.prompts.delete("files/test-bucket/my-folder/my-prompt") + + +@pytest.mark.asyncio +async def test_async_delete_prompt_rejects_non_prompt_url(): + client = get_async_client_mock(status_code=200, json_mock={}) + + with pytest.raises(InvalidDialURLError, match="Invalid resource type"): + await client.prompts.delete("files/test-bucket/my-folder/my-prompt")