From 5f87c9bffdbbf69ccad94c56158654000937627e Mon Sep 17 00:00:00 2001 From: Daniil Yarmalkevich Date: Thu, 28 May 2026 12:58:40 +0300 Subject: [PATCH 1/2] feat: add client_channel.interact() for JSON-RPC over SSE #87 Implements the /v1/ops/client-channel/interact endpoint as a thin transport in the DIAL Python client, so other apps no longer need to re-implement the SSE + JSON-RPC plumbing currently duplicated in ai-dial-quickapps-backend. - types/client_channel.py: JsonRpcRequest/Response/Error (pydantic v1/v2 compatible, smart_union to preserve int ids). - resources/client_channel.py: ClientChannel/AsyncClientChannel.interact() taking a single request or a batch, returning List[JsonRpcResponse]. Rejects id=None (would be a JSON-RPC notification with no response). - _http_client/_sse.py: minimal SSE data-event parser (sync + async). - _http_client/{_sync,_async}.py: stream_sse() context manager wrapping httpx streaming responses with auth, timeout sentinel, and DialException translation for transport errors. --- README.md | 59 ++++ aidial_client/__init__.py | 8 + aidial_client/_client.py | 6 + aidial_client/_http_client/_async.py | 62 +++- aidial_client/_http_client/_sse.py | 36 +++ aidial_client/_http_client/_sync.py | 53 +++- aidial_client/resources/__init__.py | 6 + aidial_client/resources/client_channel.py | 134 +++++++++ aidial_client/types/client_channel.py | 33 +++ tests/resources/test_client_channel.py | 342 ++++++++++++++++++++++ 10 files changed, 737 insertions(+), 2 deletions(-) create mode 100644 aidial_client/_http_client/_sse.py create mode 100644 aidial_client/resources/client_channel.py create mode 100644 aidial_client/types/client_channel.py create mode 100644 tests/resources/test_client_channel.py diff --git a/README.md b/README.md index f561b46..2149cf9 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,8 @@ - [Get Toolset by Id](#get-toolset-by-id) - [Resource Permissions](#resource-permissions) - [Grant Permissions](#grant-permissions) + - [Client Channel](#client-channel) + - [Send a JSON-RPC Request](#send-a-json-rpc-request) - [Client Pool](#client-pool) - [Synchronous Client Pool](#synchronous-client-pool) - [Asynchronous Client Pool](#asynchronous-client-pool) @@ -854,6 +856,63 @@ await async_client.resource_permissions.grant( The method returns `None` on success and raises `DialException` on HTTP error. +### Client Channel + +DIAL Core's client channel API (`/v1/ops/client-channel/*`) lets deployments send JSON-RPC requests to an interactive client (e.g. the chat UI) and receive their responses. The client must be subscribed to a channel; the channel id is propagated to the deployment via the `X-DIAL-CLIENT-CHANNEL-ID` forwarded header on the inbound request. + +#### Send a JSON-RPC Request + +Use `client_channel.interact()` to send a JSON-RPC request (or a batch) to the channel and wait for the response(s). The method returns a `List[JsonRpcResponse]` containing one entry per request. Server-emitted order is **not** guaranteed to match request order — correlate each response with its request via the `id` field, not by positional index. Every request must have a non-`None` `id`; JSON-RPC notifications are not supported here (the server does not respond to them, so `interact()` would block). + +```python +from aidial_client import JsonRpcRequest + +# Sync — single request +responses = client.client_channel.interact( + channel_id="", + request=JsonRpcRequest( + method="toolset/signin", + params={"toolsetId": "toolsets/public/my-toolset"}, + id="1", + ), + timeout=120.0, +) + +# Async — batched requests +responses = await async_client.client_channel.interact( + channel_id="", + request=[ + JsonRpcRequest( + method="toolset/signin", + params={"toolsetId": "toolsets/public/toolset-a"}, + id="1", + ), + JsonRpcRequest( + method="toolset/signin", + params={"toolsetId": "toolsets/public/toolset-b"}, + id="2", + ), + ], +) +``` + +Each element is a `JsonRpcResponse`: + +```python +JsonRpcResponse( + jsonrpc="2.0", + result="success", # or None, with `error` populated instead + error=None, # JsonRpcError(code=..., message=..., data=...) on failure + id="1", +) +``` + +- `channel_id` — required; the channel id received via the `X-DIAL-CLIENT-CHANNEL-ID` header on the inbound request. +- `request` — a single `JsonRpcRequest` or a list of them. The wire body is an object or an array accordingly. +- `timeout` — wall-clock timeout in seconds (or an `httpx.Timeout`); defaults to the client-wide timeout. Useful for interactive flows where the channel may take a while to respond. + +Raises `DialException` if the HTTP status is not 2xx (e.g. unauthorized, missing channel) or if the stream closes before a data event arrives. Raises `ParsingDataError` if the response payload is not a valid JSON-RPC response. + ### Client Pool When you need to create multiple DIAL clients and wish to enhance performance by reusing the HTTP connection for the same DIAL instance, consider using synchronous and asynchronous **client pools**. diff --git a/aidial_client/__init__.py b/aidial_client/__init__.py index 397016c..c023459 100644 --- a/aidial_client/__init__.py +++ b/aidial_client/__init__.py @@ -9,6 +9,11 @@ ParsingDataError, ResourceNotFoundError, ) +from aidial_client.types.client_channel import ( + JsonRpcError, + JsonRpcRequest, + JsonRpcResponse, +) from aidial_client.types.model import ModelInfo, ModelLimits, ModelPricing from aidial_client.types.toolset import ToolsetInfo @@ -30,4 +35,7 @@ "ModelInfo", "ModelPricing", "ModelLimits", + "JsonRpcRequest", + "JsonRpcResponse", + "JsonRpcError", ] diff --git a/aidial_client/_client.py b/aidial_client/_client.py index 3486c7e..5c4907d 100644 --- a/aidial_client/_client.py +++ b/aidial_client/_client.py @@ -119,6 +119,9 @@ def _init_resources(self) -> None: self.resource_permissions = resources.ResourcePermissions( http_client=self._http_client ) + self.client_channel = resources.ClientChannel( + http_client=self._http_client + ) def _create_http_client(self) -> SyncHTTPClient: return SyncHTTPClient( @@ -207,6 +210,9 @@ def _init_resources(self) -> None: self.resource_permissions = resources.AsyncResourcePermissions( http_client=self._http_client ) + self.client_channel = resources.AsyncClientChannel( + http_client=self._http_client + ) def _create_http_client(self) -> AsyncHTTPClient: return AsyncHTTPClient( diff --git a/aidial_client/_http_client/_async.py b/aidial_client/_http_client/_async.py index 8417940..78406b2 100644 --- a/aidial_client/_http_client/_async.py +++ b/aidial_client/_http_client/_async.py @@ -1,12 +1,23 @@ import asyncio +from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Callable, Dict, Optional, Type +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Mapping, + Optional, + Type, + Union, +) import httpx from aidial_client._auth import AsyncAuthValue, aget_combined_auth_headers from aidial_client._exception import DialException from aidial_client._http_client._base import BaseHTTPClient +from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven from aidial_client._internal_types._generic import ResponseT from aidial_client._internal_types._http_request import FinalRequestOptions from aidial_client._log import logger @@ -108,3 +119,52 @@ async def request( raise raised_error from err return process_block_response(cast_to=cast_to, response=response) + + @asynccontextmanager + async def stream_sse( + self, + *, + method: str, + url: str, + json_data: Any, + headers: Optional[Mapping[str, str]] = None, + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> AsyncIterator[httpx.Response]: + """Open an SSE streaming response. Yields the open httpx.Response. + + Auth headers are merged in. On non-2xx, reads the body and raises + a DialException; transport errors (timeouts, network failures) are + also wrapped so the caller always sees DialException. Retries are + not performed for streaming requests. + + ``timeout`` defaults to the client-wide timeout; pass an explicit + ``None`` (or ``httpx.Timeout(None)``) for no timeout. + """ + merged_headers = {**(await self.auth_headers()), **(headers or {})} + effective_timeout = ( + self._timeout if isinstance(timeout, NotGiven) else timeout + ) + try: + async with self._internal_http_client.stream( + method=method, + url=self._prepare_url(url), + headers=merged_headers, + json=json_data, + timeout=effective_timeout, + ) as response: + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: + try: + await response.aread() + except httpx.HTTPError: + pass + raise self._make_dial_error_from_response( + err.response + ) from err + yield response + except httpx.TimeoutException as err: + raise DialException( + message="Request timed out", + status_code=HTTPStatus.REQUEST_TIMEOUT, + ) from err diff --git a/aidial_client/_http_client/_sse.py b/aidial_client/_http_client/_sse.py new file mode 100644 index 0000000..ecbc4f1 --- /dev/null +++ b/aidial_client/_http_client/_sse.py @@ -0,0 +1,36 @@ +from typing import AsyncIterator, Iterator, List + + +def _strip_field(line: str, prefix: str) -> str: + """Strip a single leading U+0020 SPACE after the field colon, per the SSE spec.""" + value = line[len(prefix) :] + return value[1:] if value.startswith(" ") else value + + +def iter_data_events(lines: Iterator[str]) -> Iterator[str]: + """Yield the payload of each complete ``data:`` event from an SSE line stream. + + An event is complete when a blank line follows the ``data:`` line(s). Per + the SSE dispatch rule, a buffer that has not been terminated by a blank + line is discarded (we do NOT flush partial events at end of stream). + Comment lines (``:``) and other field names are ignored. + """ + buffer: List[str] = [] + for line in lines: + if line == "": + if buffer: + yield "\n".join(buffer) + buffer = [] + elif line.startswith("data:"): + buffer.append(_strip_field(line, "data:")) + + +async def aiter_data_events(lines: AsyncIterator[str]) -> AsyncIterator[str]: + buffer: List[str] = [] + async for line in lines: + if line == "": + if buffer: + yield "\n".join(buffer) + buffer = [] + elif line.startswith("data:"): + buffer.append(_strip_field(line, "data:")) diff --git a/aidial_client/_http_client/_sync.py b/aidial_client/_http_client/_sync.py index b330c12..9151659 100644 --- a/aidial_client/_http_client/_sync.py +++ b/aidial_client/_http_client/_sync.py @@ -1,12 +1,14 @@ import time +from contextlib import contextmanager from http import HTTPStatus -from typing import Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Type, Union import httpx from aidial_client._auth import SyncAuthValue, get_combined_auth_headers from aidial_client._exception import DialException from aidial_client._http_client._base import BaseHTTPClient +from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven from aidial_client._internal_types._generic import ResponseT from aidial_client._internal_types._http_request import FinalRequestOptions from aidial_client._log import logger @@ -108,3 +110,52 @@ def request( raise raised_error from err return process_block_response(cast_to=cast_to, response=response) + + @contextmanager + def stream_sse( + self, + *, + method: str, + url: str, + json_data: Any, + headers: Optional[Mapping[str, str]] = None, + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> Iterator[httpx.Response]: + """Open an SSE streaming response. Yields the open httpx.Response. + + Auth headers are merged in. On non-2xx, reads the body and raises + a DialException; transport errors (timeouts, network failures) are + also wrapped so the caller always sees DialException. Retries are + not performed for streaming requests. + + ``timeout`` defaults to the client-wide timeout; pass an explicit + ``None`` (or ``httpx.Timeout(None)``) for no timeout. + """ + merged_headers = {**self.auth_headers(), **(headers or {})} + effective_timeout = ( + self._timeout if isinstance(timeout, NotGiven) else timeout + ) + try: + with self._internal_http_client.stream( + method=method, + url=self._prepare_url(url), + headers=merged_headers, + json=json_data, + timeout=effective_timeout, + ) as response: + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: + try: + response.read() + except httpx.HTTPError: + pass + raise self._make_dial_error_from_response( + err.response + ) from err + yield response + except httpx.TimeoutException as err: + raise DialException( + message="Request timed out", + status_code=HTTPStatus.REQUEST_TIMEOUT, + ) from err diff --git a/aidial_client/resources/__init__.py b/aidial_client/resources/__init__.py index 8e587ca..ed55b9a 100644 --- a/aidial_client/resources/__init__.py +++ b/aidial_client/resources/__init__.py @@ -1,3 +1,7 @@ +from aidial_client.resources.client_channel import ( + AsyncClientChannel, + ClientChannel, +) from aidial_client.resources.deployments import AsyncDeployments, Deployments from aidial_client.resources.metadata import AsyncMetadata, Metadata from aidial_client.resources.model import AsyncModel, Model @@ -34,4 +38,6 @@ "AsyncModel", "ResourcePermissions", "AsyncResourcePermissions", + "ClientChannel", + "AsyncClientChannel", ] diff --git a/aidial_client/resources/client_channel.py b/aidial_client/resources/client_channel.py new file mode 100644 index 0000000..10713cc --- /dev/null +++ b/aidial_client/resources/client_channel.py @@ -0,0 +1,134 @@ +import json +from http import HTTPStatus +from typing import Any, List, Union + +import httpx + +from aidial_client._compatibility.pydantic_v1 import ValidationError +from aidial_client._exception import ( + DialException, + InvalidRequestError, + ParsingDataError, +) +from aidial_client._http_client._sse import aiter_data_events, iter_data_events +from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven +from aidial_client.resources.base import AsyncResource, Resource +from aidial_client.types.client_channel import JsonRpcRequest, JsonRpcResponse + +CLIENT_CHANNEL_HEADER = "X-DIAL-CLIENT-CHANNEL-ID" +_INTERACT_URL = "v1/ops/client-channel/interact" + + +def _serialize( + request: Union[JsonRpcRequest, List[JsonRpcRequest]], +) -> Any: + requests = request if isinstance(request, list) else [request] + for r in requests: + if r.id is None: + raise InvalidRequestError( + "JsonRpcRequest.id is required for client_channel.interact(): " + "a request without an id is a JSON-RPC notification, which the " + "server does not respond to — the call would block until the " + "stream closes." + ) + if isinstance(request, list): + return [r.dict(exclude_none=True) for r in request] + return request.dict(exclude_none=True) + + +def _parse_responses(payload: str) -> List[JsonRpcResponse]: + try: + data = json.loads(payload) + except json.JSONDecodeError as err: + raise ParsingDataError( + message=f"Malformed JSON in client-channel interact response: {err}" + ) from err + items = data if isinstance(data, list) else [data] + parsed: List[JsonRpcResponse] = [] + for item in items: + if not isinstance(item, dict): + raise ParsingDataError( + message=( + "Invalid JSON-RPC response in client-channel interact: " + f"expected object, got {type(item).__name__}" + ) + ) + try: + response = JsonRpcResponse(**item) + except (TypeError, ValidationError) as err: + raise ParsingDataError( + message=( + "Invalid JSON-RPC response in client-channel interact: " + f"{err}" + ) + ) from err + if response.result is None and response.error is None: + raise ParsingDataError( + message=( + "Invalid JSON-RPC response in client-channel interact: " + "must contain either 'result' or 'error'" + ) + ) + parsed.append(response) + return parsed + + +def _no_data_error() -> DialException: + return DialException( + message="Client-channel interact stream closed without a data event", + status_code=HTTPStatus.GATEWAY_TIMEOUT, + ) + + +class ClientChannel(Resource): + def interact( + self, + *, + channel_id: str, + request: Union[JsonRpcRequest, List[JsonRpcRequest]], + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> List[JsonRpcResponse]: + """Send a JSON-RPC request (or batch) over the client channel and + wait for the corresponding response event. + + Returns a list of ``JsonRpcResponse``. The wire response preserves the + server-emitted order, which is not guaranteed to match request order; + callers should correlate each response with its request via the + ``id`` field, not by positional index. + + Raises ``DialException`` if the HTTP status is not 2xx, the stream + closes without a response, or a transport error (timeout, network + failure) occurs. Raises ``InvalidRequestError`` if any request has + ``id=None`` (JSON-RPC notifications are not supported here — the + server does not respond to them). + """ + with self.http_client.stream_sse( + method="POST", + url=_INTERACT_URL, + json_data=_serialize(request), + headers={CLIENT_CHANNEL_HEADER: channel_id}, + timeout=timeout, + ) as response: + for payload in iter_data_events(response.iter_lines()): + return _parse_responses(payload) + raise _no_data_error() + + +class AsyncClientChannel(AsyncResource): + async def interact( + self, + *, + channel_id: str, + request: Union[JsonRpcRequest, List[JsonRpcRequest]], + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> List[JsonRpcResponse]: + async with self.http_client.stream_sse( + method="POST", + url=_INTERACT_URL, + json_data=_serialize(request), + headers={CLIENT_CHANNEL_HEADER: channel_id}, + timeout=timeout, + ) as response: + async for payload in aiter_data_events(response.aiter_lines()): + return _parse_responses(payload) + raise _no_data_error() diff --git a/aidial_client/types/client_channel.py b/aidial_client/types/client_channel.py new file mode 100644 index 0000000..b284823 --- /dev/null +++ b/aidial_client/types/client_channel.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, List, Literal, Optional, Union + +from aidial_client._compatibility.pydantic_v1 import BaseModel, Extra + + +class JsonRpcError(BaseModel): + code: int + message: str + data: Optional[Any] = None + + class Config: + extra = Extra.allow + + +class JsonRpcRequest(BaseModel): + jsonrpc: Literal["2.0"] = "2.0" + method: str + params: Optional[Union[List[Any], Dict[str, Any]]] = None + id: Optional[Union[int, str]] = None + + class Config: + smart_union = True + + +class JsonRpcResponse(BaseModel): + jsonrpc: Literal["2.0"] = "2.0" + result: Optional[Any] = None + error: Optional[JsonRpcError] = None + id: Optional[Union[int, str]] = None + + class Config: + smart_union = True + extra = Extra.allow diff --git a/tests/resources/test_client_channel.py b/tests/resources/test_client_channel.py new file mode 100644 index 0000000..e7ab4bf --- /dev/null +++ b/tests/resources/test_client_channel.py @@ -0,0 +1,342 @@ +import json +from http import HTTPStatus +from typing import Any, List + +import httpx +import pytest + +from aidial_client import Dial +from aidial_client._client import AsyncDial +from aidial_client._exception import ( + DialException, + InvalidRequestError, + ParsingDataError, +) +from aidial_client.types.client_channel import ( + JsonRpcError, + JsonRpcRequest, + JsonRpcResponse, +) +from tests.client_mock import ( + MockStreamIterator, + get_async_client_mock, + get_client_mock, +) + +SINGLE_SUCCESS = {"jsonrpc": "2.0", "result": "success", "id": "1"} +BATCH_PAYLOAD = [ + {"jsonrpc": "2.0", "result": "success", "id": "1"}, + {"jsonrpc": "2.0", "result": "denied", "id": "2"}, +] +ERROR_PAYLOAD = { + "jsonrpc": "2.0", + "error": {"code": -32000, "message": "boom"}, + "id": "1", +} + + +def _sse_chunks(*lines: str) -> List[bytes]: + """Encode a sequence of SSE lines as one byte stream chunk.""" + return [("\n".join(lines) + "\n").encode()] + + +def _data(payload: Any) -> str: + return f"data: {json.dumps(payload)}" + + +def _single_event(payload: Any) -> List[bytes]: + return _sse_chunks(": heartbeat", "", _data(payload), "") + + +def test_interact_single_request_sync(): + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(SINGLE_SUCCESS) + ) + responses = client.client_channel.interact( + channel_id="abc", + request=JsonRpcRequest( + method="toolset/signin", params={"toolsetId": "X"}, id="1" + ), + ) + assert len(responses) == 1 + assert isinstance(responses[0], JsonRpcResponse) + assert responses[0].result == "success" + assert responses[0].id == "1" + assert responses[0].error is None + + +@pytest.mark.asyncio +async def test_interact_single_request_async(): + client = get_async_client_mock( + status_code=200, stream_chunks_mock=_single_event(SINGLE_SUCCESS) + ) + responses = await client.client_channel.interact( + channel_id="abc", + request=JsonRpcRequest( + method="toolset/signin", params={"toolsetId": "X"}, id="1" + ), + ) + assert len(responses) == 1 + assert responses[0].result == "success" + + +def test_interact_batch_request_sync(): + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(BATCH_PAYLOAD) + ) + requests = [ + JsonRpcRequest( + method="toolset/signin", params={"toolsetId": "A"}, id="1" + ), + JsonRpcRequest( + method="toolset/signin", params={"toolsetId": "B"}, id="2" + ), + ] + responses = client.client_channel.interact( + channel_id="abc", request=requests + ) + assert len(responses) == 2 + assert responses[0].result == "success" + assert responses[0].id == "1" + assert responses[1].result == "denied" + assert responses[1].id == "2" + + +@pytest.mark.asyncio +async def test_interact_batch_request_async(): + client = get_async_client_mock( + status_code=200, stream_chunks_mock=_single_event(BATCH_PAYLOAD) + ) + requests = [ + JsonRpcRequest(method="m", id="1"), + JsonRpcRequest(method="m", id="2"), + ] + responses = await client.client_channel.interact( + channel_id="abc", request=requests + ) + assert len(responses) == 2 + assert [r.id for r in responses] == ["1", "2"] + + +def test_interact_error_response_sync(): + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(ERROR_PAYLOAD) + ) + responses = client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + assert len(responses) == 1 + assert responses[0].result is None + assert isinstance(responses[0].error, JsonRpcError) + assert responses[0].error.code == -32000 + assert responses[0].error.message == "boom" + + +@pytest.mark.asyncio +async def test_interact_error_response_async(): + client = get_async_client_mock( + status_code=200, stream_chunks_mock=_single_event(ERROR_PAYLOAD) + ) + responses = await client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + assert responses[0].error is not None + assert responses[0].error.code == -32000 + + +def test_interact_heartbeats_skipped_sync(): + chunks = _sse_chunks( + ": heartbeat", + "", + ": heartbeat", + "", + ": heartbeat", + "", + _data(SINGLE_SUCCESS), + "", + ) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + responses = client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + assert responses[0].result == "success" + + +def test_interact_malformed_json_raises_sync(): + chunks = _sse_chunks("data: not-json", "") + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + + +def test_interact_no_data_event_raises_sync(): + chunks = _sse_chunks(": heartbeat", "", ": heartbeat", "") + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(DialException) as exc_info: + client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + assert exc_info.value.status_code == HTTPStatus.GATEWAY_TIMEOUT + + +def test_interact_truncated_stream_does_not_yield_phantom_event_sync(): + # No trailing blank line — incomplete event must NOT be flushed. + chunks = _sse_chunks('data: {"jsonrpc":"2.0","resu') + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(DialException) as exc_info: + client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + assert exc_info.value.status_code == HTTPStatus.GATEWAY_TIMEOUT + + +def test_interact_http_error_raises_sync(): + body = json.dumps({"error": {"message": "Unauthorized"}}).encode() + client = get_client_mock(status_code=401, stream_chunks_mock=[body]) + with pytest.raises(DialException) as exc_info: + client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + assert exc_info.value.status_code == 401 + assert exc_info.value.message == "Unauthorized" + + +@pytest.mark.asyncio +async def test_interact_http_error_raises_async(): + body = json.dumps({"error": {"message": "Unauthorized"}}).encode() + client = get_async_client_mock(status_code=401, stream_chunks_mock=[body]) + with pytest.raises(DialException) as exc_info: + await client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + assert exc_info.value.status_code == 401 + + +def test_interact_sends_channel_header_and_body_sync(): + captured: dict = {} + + def send_mock(request: httpx.Request, **kwargs): + captured["request"] = request + return httpx.Response( + status_code=200, + request=request, + stream=MockStreamIterator( + mock_chunks=_single_event(SINGLE_SUCCESS) + ), + ) + + client = Dial(api_key="dummy", base_url="http://dial.core") + client._http_client._internal_http_client.send = send_mock + + client.client_channel.interact( + channel_id="my-channel", + request=JsonRpcRequest(method="toolset/signin", id="1"), + ) + + request = captured["request"] + assert request.headers["X-DIAL-CLIENT-CHANNEL-ID"] == "my-channel" + assert request.headers["api-key"] == "dummy" + assert request.url.path == "/v1/ops/client-channel/interact" + body = json.loads(request.content) + assert body == {"jsonrpc": "2.0", "method": "toolset/signin", "id": "1"} + + +def test_interact_rejects_notification_request_sync(): + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(SINGLE_SUCCESS) + ) + with pytest.raises(InvalidRequestError): + client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m") + ) + + +def test_interact_rejects_batch_with_notification_sync(): + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(BATCH_PAYLOAD) + ) + with pytest.raises(InvalidRequestError): + client.client_channel.interact( + channel_id="abc", + request=[ + JsonRpcRequest(method="m", id="1"), + JsonRpcRequest(method="m"), + ], + ) + + +def test_interact_rejects_heartbeat_shaped_payload_sync(): + chunks = _single_event({"type": "heartbeat", "ts": 1234}) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + + +def test_interact_rejects_null_payload_sync(): + chunks = _sse_chunks("data: null", "") + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel.interact( + channel_id="abc", request=JsonRpcRequest(method="m", id="1") + ) + + +def test_interact_preserves_integer_id_sync(): + captured: dict = {} + + def send_mock(request: httpx.Request, **kwargs): + captured["request"] = request + return httpx.Response( + status_code=200, + request=request, + stream=MockStreamIterator( + mock_chunks=_single_event( + {"jsonrpc": "2.0", "result": "ok", "id": 42} + ) + ), + ) + + client = Dial(api_key="dummy", base_url="http://dial.core") + client._http_client._internal_http_client.send = send_mock + + responses = client.client_channel.interact( + channel_id="x", + request=JsonRpcRequest(method="m", id=42), + ) + body = json.loads(captured["request"].content) + assert body["id"] == 42 and isinstance(body["id"], int) + assert responses[0].id == 42 and isinstance(responses[0].id, int) + + +@pytest.mark.asyncio +async def test_interact_batch_body_serialized_as_array_async(): + captured: dict = {} + + async def send_mock(request: httpx.Request, **kwargs): + captured["request"] = request + return httpx.Response( + status_code=200, + request=request, + stream=MockStreamIterator(mock_chunks=_single_event(BATCH_PAYLOAD)), + ) + + client = AsyncDial(api_key="dummy", base_url="http://dial.core") + client._http_client._internal_http_client.send = send_mock + + await client.client_channel.interact( + channel_id="ch", + request=[ + JsonRpcRequest(method="m", id="1"), + JsonRpcRequest(method="m", id="2"), + ], + ) + + body = json.loads(captured["request"].content) + assert isinstance(body, list) + assert len(body) == 2 + assert body[0]["id"] == "1" + assert body[1]["id"] == "2" From d2b525e08ecfaa842b29c4a4ea1153598aca0c6f Mon Sep 17 00:00:00 2001 From: Daniil Yarmalkevich Date: Thu, 28 May 2026 20:26:26 +0300 Subject: [PATCH 2/2] refactor: hide JSON-RPC behind typed signin_toolsets API #87 Per @adubovik's review on PR #100, replaces the public `interact(JsonRpcRequest)` surface with a typed `signin_toolsets(channel_id, toolset_ids) -> dict[str, SigninResult]`. JSON-RPC plumbing moves to `_internal_types/_json_rpc.py`; only `SigninResult` is exported publicly. Review-comment fixes also applied (C1-C8): except Exception fallback in stream_sse, presence-based result/error XOR via pydantic root_validator (so `{"result": null}` parses), private `_CLIENT_CHANNEL_HEADER`, types moved out of client_channel.py, required `jsonrpc`/`id` on response, parsing via a pydantic root model, `_serialize` -> `_serialize_requests`, SSE warning on uncommitted trailing buffer. Additional correctness fixes from the post-redesign review pass: reject str/duplicate toolset_ids and materialize iterators; surface server-level JSON-RPC errors (id=null) as DialException; always emit array wire body; narrow stream_sse catch from Exception to httpx.HTTPError so user-code exceptions inside the with-body propagate. --- README.md | 64 ++- aidial_client/__init__.py | 10 +- aidial_client/_http_client/_async.py | 2 + aidial_client/_http_client/_sse.py | 11 + aidial_client/_http_client/_sync.py | 2 + aidial_client/_internal_types/_json_rpc.py | 75 ++++ aidial_client/resources/client_channel.py | 228 +++++++--- aidial_client/types/client_channel.py | 36 +- tests/resources/test_client_channel.py | 495 ++++++++++++--------- 9 files changed, 579 insertions(+), 344 deletions(-) create mode 100644 aidial_client/_internal_types/_json_rpc.py diff --git a/README.md b/README.md index 2149cf9..98ffb61 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ - [Resource Permissions](#resource-permissions) - [Grant Permissions](#grant-permissions) - [Client Channel](#client-channel) - - [Send a JSON-RPC Request](#send-a-json-rpc-request) + - [Sign In to Toolsets](#sign-in-to-toolsets) - [Client Pool](#client-pool) - [Synchronous Client Pool](#synchronous-client-pool) - [Asynchronous Client Pool](#asynchronous-client-pool) @@ -858,60 +858,52 @@ The method returns `None` on success and raises `DialException` on HTTP error. ### Client Channel -DIAL Core's client channel API (`/v1/ops/client-channel/*`) lets deployments send JSON-RPC requests to an interactive client (e.g. the chat UI) and receive their responses. The client must be subscribed to a channel; the channel id is propagated to the deployment via the `X-DIAL-CLIENT-CHANNEL-ID` forwarded header on the inbound request. +DIAL Core's [client channel API](https://dialx.ai/universal_chat_api.yaml) lets a deployment ask an interactive client (e.g. the chat UI) to take some action and report the result back. The channel id is propagated to the deployment via the `X-DIAL-CLIENT-CHANNEL-ID` forwarded header on the inbound request. -#### Send a JSON-RPC Request +#### Sign In to Toolsets -Use `client_channel.interact()` to send a JSON-RPC request (or a batch) to the channel and wait for the response(s). The method returns a `List[JsonRpcResponse]` containing one entry per request. Server-emitted order is **not** guaranteed to match request order — correlate each response with its request via the `id` field, not by positional index. Every request must have a non-`None` `id`; JSON-RPC notifications are not supported here (the server does not respond to them, so `interact()` would block). +Use `client_channel.signin_toolsets()` to request interactive sign-in for one or more toolsets on the active client channel. The method returns a `dict[str, SigninResult]` mapping each input toolset id to its outcome — responses are correlated by the client, so the caller never has to deal with the underlying JSON-RPC ids. ```python -from aidial_client import JsonRpcRequest +from aidial_client import SigninResult -# Sync — single request -responses = client.client_channel.interact( +# Sync +results = client.client_channel.signin_toolsets( channel_id="", - request=JsonRpcRequest( - method="toolset/signin", - params={"toolsetId": "toolsets/public/my-toolset"}, - id="1", - ), + toolset_ids=[ + "toolsets/public/toolset-a", + "toolsets/public/toolset-b", + ], timeout=120.0, ) -# Async — batched requests -responses = await async_client.client_channel.interact( +# Async +results = await async_client.client_channel.signin_toolsets( channel_id="", - request=[ - JsonRpcRequest( - method="toolset/signin", - params={"toolsetId": "toolsets/public/toolset-a"}, - id="1", - ), - JsonRpcRequest( - method="toolset/signin", - params={"toolsetId": "toolsets/public/toolset-b"}, - id="2", - ), - ], + toolset_ids=["toolsets/public/my-toolset"], ) ``` -Each element is a `JsonRpcResponse`: +Each value is a `SigninResult` enum: ```python -JsonRpcResponse( - jsonrpc="2.0", - result="success", # or None, with `error` populated instead - error=None, # JsonRpcError(code=..., message=..., data=...) on failure - id="1", -) +{ + "toolsets/public/toolset-a": SigninResult.SUCCESS, + "toolsets/public/toolset-b": SigninResult.DENIED, +} ``` +- `SigninResult.SUCCESS` — the user signed in. +- `SigninResult.DENIED` — the user declined. +- `SigninResult.ERROR` — the server returned a JSON-RPC error, or the response was missing/unrecognized. + +Arguments: + - `channel_id` — required; the channel id received via the `X-DIAL-CLIENT-CHANNEL-ID` header on the inbound request. -- `request` — a single `JsonRpcRequest` or a list of them. The wire body is an object or an array accordingly. -- `timeout` — wall-clock timeout in seconds (or an `httpx.Timeout`); defaults to the client-wide timeout. Useful for interactive flows where the channel may take a while to respond. +- `toolset_ids` — sequence of toolset ids to request sign-in for; an empty sequence returns `{}` without contacting the server. +- `timeout` — optional `float` seconds or `httpx.Timeout`; defaults to the client-wide timeout. Useful for interactive flows where the user may take a while to respond. -Raises `DialException` if the HTTP status is not 2xx (e.g. unauthorized, missing channel) or if the stream closes before a data event arrives. Raises `ParsingDataError` if the response payload is not a valid JSON-RPC response. +Raises `DialException` on HTTP errors (e.g. unauthorized, missing channel), transport failures (timeouts, network errors), or if the SSE stream closes without a response event. ### Client Pool diff --git a/aidial_client/__init__.py b/aidial_client/__init__.py index c023459..edf195f 100644 --- a/aidial_client/__init__.py +++ b/aidial_client/__init__.py @@ -9,11 +9,7 @@ ParsingDataError, ResourceNotFoundError, ) -from aidial_client.types.client_channel import ( - JsonRpcError, - JsonRpcRequest, - JsonRpcResponse, -) +from aidial_client.types.client_channel import SigninResult from aidial_client.types.model import ModelInfo, ModelLimits, ModelPricing from aidial_client.types.toolset import ToolsetInfo @@ -35,7 +31,5 @@ "ModelInfo", "ModelPricing", "ModelLimits", - "JsonRpcRequest", - "JsonRpcResponse", - "JsonRpcError", + "SigninResult", ] diff --git a/aidial_client/_http_client/_async.py b/aidial_client/_http_client/_async.py index 78406b2..5dcf37c 100644 --- a/aidial_client/_http_client/_async.py +++ b/aidial_client/_http_client/_async.py @@ -168,3 +168,5 @@ async def stream_sse( message="Request timed out", status_code=HTTPStatus.REQUEST_TIMEOUT, ) from err + except httpx.HTTPError as err: + raise DialException(message=f"Request failed: {err}") from err diff --git a/aidial_client/_http_client/_sse.py b/aidial_client/_http_client/_sse.py index ecbc4f1..fa776bb 100644 --- a/aidial_client/_http_client/_sse.py +++ b/aidial_client/_http_client/_sse.py @@ -1,5 +1,12 @@ from typing import AsyncIterator, Iterator, List +from aidial_client._log import logger + +_UNCOMMITTED_BUFFER_WARNING = ( + "Uncommitted data chunks in SSE stream " + "(stream ended without a terminating blank line); discarding." +) + def _strip_field(line: str, prefix: str) -> str: """Strip a single leading U+0020 SPACE after the field colon, per the SSE spec.""" @@ -23,6 +30,8 @@ def iter_data_events(lines: Iterator[str]) -> Iterator[str]: buffer = [] elif line.startswith("data:"): buffer.append(_strip_field(line, "data:")) + if buffer: + logger.warning(_UNCOMMITTED_BUFFER_WARNING) async def aiter_data_events(lines: AsyncIterator[str]) -> AsyncIterator[str]: @@ -34,3 +43,5 @@ async def aiter_data_events(lines: AsyncIterator[str]) -> AsyncIterator[str]: buffer = [] elif line.startswith("data:"): buffer.append(_strip_field(line, "data:")) + if buffer: + logger.warning(_UNCOMMITTED_BUFFER_WARNING) diff --git a/aidial_client/_http_client/_sync.py b/aidial_client/_http_client/_sync.py index 9151659..583de97 100644 --- a/aidial_client/_http_client/_sync.py +++ b/aidial_client/_http_client/_sync.py @@ -159,3 +159,5 @@ def stream_sse( message="Request timed out", status_code=HTTPStatus.REQUEST_TIMEOUT, ) from err + except httpx.HTTPError as err: + raise DialException(message=f"Request failed: {err}") from err diff --git a/aidial_client/_internal_types/_json_rpc.py b/aidial_client/_internal_types/_json_rpc.py new file mode 100644 index 0000000..90b9ffa --- /dev/null +++ b/aidial_client/_internal_types/_json_rpc.py @@ -0,0 +1,75 @@ +from typing import Any, Dict, List, Literal, Optional, Union + +from aidial_client._compatibility.pydantic_v1 import ( + BaseModel, + Extra, + Field, + root_validator, +) + + +class JsonRpcError(BaseModel): + code: int + message: str + data: Optional[Any] = None + + class Config: + extra = Extra.allow + + +class JsonRpcRequest(BaseModel): + jsonrpc: Literal["2.0"] = "2.0" + method: str + params: Optional[Union[List[Any], Dict[str, Any]]] = None + id: Optional[Union[int, str]] = None + + class Config: + smart_union = True + + +class JsonRpcResponse(BaseModel): + jsonrpc: Literal["2.0"] + result: Optional[Any] = None + error: Optional[JsonRpcError] = None + id: Optional[Union[int, str]] = Field(...) + + class Config: + smart_union = True + extra = Extra.allow + + @root_validator(pre=True) + def _validate_result_xor_error(cls, values): + """Per JSON-RPC 2.0 (https://www.jsonrpc.org/specification#response_object), + either ``result`` or ``error`` MUST be included (presence-wise — ``null`` + is a valid result value), and both MUST NOT be included. + """ + if not isinstance(values, dict): + return values + has_result = "result" in values + has_error = "error" in values + if has_result and has_error: + raise ValueError( + "JSON-RPC response must not contain both 'result' and 'error'" + ) + if not has_result and not has_error: + raise ValueError( + "JSON-RPC response must contain either 'result' or 'error'" + ) + return values + + +class JsonRpcResponses(BaseModel): + """Pydantic root model that accepts a single JSON-RPC response object or + a batch array, normalizing both to a list via the ``responses`` property. + """ + + __root__: Union[JsonRpcResponse, List[JsonRpcResponse]] + + class Config: + smart_union = True + + @property + def responses(self) -> List[JsonRpcResponse]: + if isinstance(self.__root__, list): + return self.__root__ + return [self.__root__] diff --git a/aidial_client/resources/client_channel.py b/aidial_client/resources/client_channel.py index 10713cc..9135262 100644 --- a/aidial_client/resources/client_channel.py +++ b/aidial_client/resources/client_channel.py @@ -1,6 +1,5 @@ -import json from http import HTTPStatus -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union import httpx @@ -12,65 +11,58 @@ ) from aidial_client._http_client._sse import aiter_data_events, iter_data_events from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven +from aidial_client._internal_types._json_rpc import ( + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponses, +) from aidial_client.resources.base import AsyncResource, Resource -from aidial_client.types.client_channel import JsonRpcRequest, JsonRpcResponse +from aidial_client.types.client_channel import SigninResult -CLIENT_CHANNEL_HEADER = "X-DIAL-CLIENT-CHANNEL-ID" +_CLIENT_CHANNEL_HEADER = "X-DIAL-CLIENT-CHANNEL-ID" _INTERACT_URL = "v1/ops/client-channel/interact" +_SIGNIN_METHOD = "toolset/signin" -def _serialize( - request: Union[JsonRpcRequest, List[JsonRpcRequest]], -) -> Any: - requests = request if isinstance(request, list) else [request] - for r in requests: - if r.id is None: - raise InvalidRequestError( - "JsonRpcRequest.id is required for client_channel.interact(): " - "a request without an id is a JSON-RPC notification, which the " - "server does not respond to — the call would block until the " - "stream closes." - ) - if isinstance(request, list): - return [r.dict(exclude_none=True) for r in request] - return request.dict(exclude_none=True) +def _normalize_toolset_ids(toolset_ids: Sequence[str]) -> List[str]: + """Validate ``toolset_ids`` and return a stable list. + + Catches three caller mistakes that would otherwise produce silent garbage: + a single string (str is itself a ``Sequence[str]``), a one-shot iterable + (consumed by the build step, leaving the mapping step with nothing), and + duplicate ids (the per-toolset result dict cannot represent two outcomes + for the same key). + """ + if isinstance(toolset_ids, str): + raise InvalidRequestError( + "toolset_ids must be a sequence of toolset ids, not a single str" + ) + materialized = list(toolset_ids) + if len(set(materialized)) != len(materialized): + raise InvalidRequestError("toolset_ids must not contain duplicates") + return materialized + + +def _serialize_requests(requests: Sequence[JsonRpcRequest]) -> Any: + """Serialize a sequence of JsonRpcRequest to the wire form. + + Always emits an array. DIAL Core accepts both an object and an array + body, but emitting a consistent shape avoids the "wire shape depends + on count" footgun and keeps the empty-input case safe. + """ + return [r.dict(exclude_none=True) for r in requests] def _parse_responses(payload: str) -> List[JsonRpcResponse]: try: - data = json.loads(payload) - except json.JSONDecodeError as err: + return JsonRpcResponses.parse_raw(payload).responses + except (ValidationError, ValueError) as err: raise ParsingDataError( - message=f"Malformed JSON in client-channel interact response: {err}" - ) from err - items = data if isinstance(data, list) else [data] - parsed: List[JsonRpcResponse] = [] - for item in items: - if not isinstance(item, dict): - raise ParsingDataError( - message=( - "Invalid JSON-RPC response in client-channel interact: " - f"expected object, got {type(item).__name__}" - ) - ) - try: - response = JsonRpcResponse(**item) - except (TypeError, ValidationError) as err: - raise ParsingDataError( - message=( - "Invalid JSON-RPC response in client-channel interact: " - f"{err}" - ) - ) from err - if response.result is None and response.error is None: - raise ParsingDataError( - message=( - "Invalid JSON-RPC response in client-channel interact: " - "must contain either 'result' or 'error'" - ) + message=( + "Invalid JSON-RPC response in client-channel interact: " + f"{err}" ) - parsed.append(response) - return parsed + ) from err def _no_data_error() -> DialException: @@ -80,33 +72,107 @@ def _no_data_error() -> DialException: ) +def _raise_if_batch_error(responses: Sequence[JsonRpcResponse]) -> None: + """Per JSON-RPC 2.0, a response with ``id=null`` indicates the server + could not associate the response with any request (parse error, invalid + batch, etc.). Surface that as a ``DialException`` instead of silently + mapping every toolset to ERROR. + """ + for r in responses: + if r.id is None and r.error is not None: + raise DialException( + message=( + f"Server-level JSON-RPC error " + f"({r.error.code}): {r.error.message}" + ), + status_code=HTTPStatus.BAD_GATEWAY, + ) + + +_RESULT_TO_OUTCOME = { + SigninResult.SUCCESS.value: SigninResult.SUCCESS, + SigninResult.DENIED.value: SigninResult.DENIED, +} + + +def _outcome_for(response: Optional[JsonRpcResponse]) -> SigninResult: + if response is None or response.error is not None: + return SigninResult.ERROR + if not isinstance(response.result, str): + return SigninResult.ERROR + return _RESULT_TO_OUTCOME.get(response.result, SigninResult.ERROR) + + +def _build_signin_requests( + toolset_ids: Sequence[str], +) -> List[JsonRpcRequest]: + return [ + JsonRpcRequest( + method=_SIGNIN_METHOD, + params={"toolsetId": tid}, + id=str(idx), + ) + for idx, tid in enumerate(toolset_ids, start=1) + ] + + +def _map_signin_results( + toolset_ids: Sequence[str], + responses: Sequence[JsonRpcResponse], +) -> "dict[str, SigninResult]": + by_id = {str(r.id): r for r in responses if r.id is not None} + return { + tid: _outcome_for(by_id.get(str(idx))) + for idx, tid in enumerate(toolset_ids, start=1) + } + + class ClientChannel(Resource): - def interact( + def signin_toolsets( self, *, channel_id: str, - request: Union[JsonRpcRequest, List[JsonRpcRequest]], + toolset_ids: Sequence[str], timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, - ) -> List[JsonRpcResponse]: - """Send a JSON-RPC request (or batch) over the client channel and - wait for the corresponding response event. - - Returns a list of ``JsonRpcResponse``. The wire response preserves the - server-emitted order, which is not guaranteed to match request order; - callers should correlate each response with its request via the - ``id`` field, not by positional index. - - Raises ``DialException`` if the HTTP status is not 2xx, the stream - closes without a response, or a transport error (timeout, network - failure) occurs. Raises ``InvalidRequestError`` if any request has - ``id=None`` (JSON-RPC notifications are not supported here — the - server does not respond to them). + ) -> "dict[str, SigninResult]": + """Request interactive sign-in for one or more toolsets on the given + client channel and return the per-toolset outcome. + + ``toolset_ids`` are typically DIAL toolset ids (e.g. + ``"toolsets/public/my-toolset"``). The returned dict has one entry + per input id; toolsets for which the server does not produce a + response are mapped to :class:`SigninResult.ERROR`. Iteration order + of the returned dict matches the order of ``toolset_ids``. + + Raises :class:`InvalidRequestError` if ``toolset_ids`` is a plain + string or contains duplicates. Raises :class:`DialException` on HTTP + errors, transport failures, server-level JSON-RPC errors (e.g. parse + error returned with ``id=null``), or if the SSE stream closes + without a response event. """ + ids = _normalize_toolset_ids(toolset_ids) + if not ids: + return {} + responses = self._interact( + channel_id=channel_id, + requests=_build_signin_requests(ids), + timeout=timeout, + ) + _raise_if_batch_error(responses) + return _map_signin_results(ids, responses) + + def _interact( + self, + *, + channel_id: str, + requests: Sequence[JsonRpcRequest], + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> List[JsonRpcResponse]: with self.http_client.stream_sse( method="POST", url=_INTERACT_URL, - json_data=_serialize(request), - headers={CLIENT_CHANNEL_HEADER: channel_id}, + json_data=_serialize_requests(requests), + headers={_CLIENT_CHANNEL_HEADER: channel_id}, timeout=timeout, ) as response: for payload in iter_data_events(response.iter_lines()): @@ -115,18 +181,36 @@ def interact( class AsyncClientChannel(AsyncResource): - async def interact( + async def signin_toolsets( + self, + *, + channel_id: str, + toolset_ids: Sequence[str], + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> "dict[str, SigninResult]": + ids = _normalize_toolset_ids(toolset_ids) + if not ids: + return {} + responses = await self._interact( + channel_id=channel_id, + requests=_build_signin_requests(ids), + timeout=timeout, + ) + _raise_if_batch_error(responses) + return _map_signin_results(ids, responses) + + async def _interact( self, *, channel_id: str, - request: Union[JsonRpcRequest, List[JsonRpcRequest]], + requests: Sequence[JsonRpcRequest], timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, ) -> List[JsonRpcResponse]: async with self.http_client.stream_sse( method="POST", url=_INTERACT_URL, - json_data=_serialize(request), - headers={CLIENT_CHANNEL_HEADER: channel_id}, + json_data=_serialize_requests(requests), + headers={_CLIENT_CHANNEL_HEADER: channel_id}, timeout=timeout, ) as response: async for payload in aiter_data_events(response.aiter_lines()): diff --git a/aidial_client/types/client_channel.py b/aidial_client/types/client_channel.py index b284823..fa322db 100644 --- a/aidial_client/types/client_channel.py +++ b/aidial_client/types/client_channel.py @@ -1,33 +1,9 @@ -from typing import Any, Dict, List, Literal, Optional, Union +from enum import Enum -from aidial_client._compatibility.pydantic_v1 import BaseModel, Extra +class SigninResult(str, Enum): + """Outcome of an interactive sign-in request for a single toolset.""" -class JsonRpcError(BaseModel): - code: int - message: str - data: Optional[Any] = None - - class Config: - extra = Extra.allow - - -class JsonRpcRequest(BaseModel): - jsonrpc: Literal["2.0"] = "2.0" - method: str - params: Optional[Union[List[Any], Dict[str, Any]]] = None - id: Optional[Union[int, str]] = None - - class Config: - smart_union = True - - -class JsonRpcResponse(BaseModel): - jsonrpc: Literal["2.0"] = "2.0" - result: Optional[Any] = None - error: Optional[JsonRpcError] = None - id: Optional[Union[int, str]] = None - - class Config: - smart_union = True - extra = Extra.allow + SUCCESS = "success" + DENIED = "denied" + ERROR = "error" diff --git a/tests/resources/test_client_channel.py b/tests/resources/test_client_channel.py index e7ab4bf..d85d911 100644 --- a/tests/resources/test_client_channel.py +++ b/tests/resources/test_client_channel.py @@ -1,39 +1,25 @@ import json +import logging from http import HTTPStatus from typing import Any, List import httpx import pytest -from aidial_client import Dial +from aidial_client import Dial, SigninResult from aidial_client._client import AsyncDial from aidial_client._exception import ( DialException, InvalidRequestError, ParsingDataError, ) -from aidial_client.types.client_channel import ( - JsonRpcError, - JsonRpcRequest, - JsonRpcResponse, -) +from aidial_client._internal_types._json_rpc import JsonRpcRequest from tests.client_mock import ( MockStreamIterator, get_async_client_mock, get_client_mock, ) -SINGLE_SUCCESS = {"jsonrpc": "2.0", "result": "success", "id": "1"} -BATCH_PAYLOAD = [ - {"jsonrpc": "2.0", "result": "success", "id": "1"}, - {"jsonrpc": "2.0", "result": "denied", "id": "2"}, -] -ERROR_PAYLOAD = { - "jsonrpc": "2.0", - "error": {"code": -32000, "message": "boom"}, - "id": "1", -} - def _sse_chunks(*lines: str) -> List[bytes]: """Encode a sequence of SSE lines as one byte stream chunk.""" @@ -45,194 +31,265 @@ def _data(payload: Any) -> str: def _single_event(payload: Any) -> List[bytes]: - return _sse_chunks(": heartbeat", "", _data(payload), "") + return _sse_chunks(_data(payload), "") + + +def _signin_response(id_: str, result: str) -> dict: + return {"jsonrpc": "2.0", "id": id_, "result": result} + +def _signin_error_response(id_: str, message: str = "boom") -> dict: + return { + "jsonrpc": "2.0", + "id": id_, + "error": {"code": -32000, "message": message}, + } -def test_interact_single_request_sync(): + +# ---------------------------------------------------------------------------- +# signin_toolsets — happy paths +# ---------------------------------------------------------------------------- + + +def test_signin_single_toolset_success_sync(): client = get_client_mock( - status_code=200, stream_chunks_mock=_single_event(SINGLE_SUCCESS) + status_code=200, + stream_chunks_mock=_single_event(_signin_response("1", "success")), ) - responses = client.client_channel.interact( - channel_id="abc", - request=JsonRpcRequest( - method="toolset/signin", params={"toolsetId": "X"}, id="1" - ), + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["toolsets/public/a"] ) - assert len(responses) == 1 - assert isinstance(responses[0], JsonRpcResponse) - assert responses[0].result == "success" - assert responses[0].id == "1" - assert responses[0].error is None + assert out == {"toolsets/public/a": SigninResult.SUCCESS} @pytest.mark.asyncio -async def test_interact_single_request_async(): +async def test_signin_single_toolset_success_async(): client = get_async_client_mock( - status_code=200, stream_chunks_mock=_single_event(SINGLE_SUCCESS) + status_code=200, + stream_chunks_mock=_single_event(_signin_response("1", "success")), ) - responses = await client.client_channel.interact( - channel_id="abc", - request=JsonRpcRequest( - method="toolset/signin", params={"toolsetId": "X"}, id="1" - ), + out = await client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["toolsets/public/a"] ) - assert len(responses) == 1 - assert responses[0].result == "success" + assert out == {"toolsets/public/a": SigninResult.SUCCESS} -def test_interact_batch_request_sync(): +def test_signin_batch_mixed_outcomes_sync(): + payload = [ + _signin_response("1", "success"), + _signin_response("2", "denied"), + _signin_error_response("3"), + ] client = get_client_mock( - status_code=200, stream_chunks_mock=_single_event(BATCH_PAYLOAD) + status_code=200, stream_chunks_mock=_single_event(payload) ) - requests = [ - JsonRpcRequest( - method="toolset/signin", params={"toolsetId": "A"}, id="1" - ), - JsonRpcRequest( - method="toolset/signin", params={"toolsetId": "B"}, id="2" - ), - ] - responses = client.client_channel.interact( - channel_id="abc", request=requests + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b", "c"] ) - assert len(responses) == 2 - assert responses[0].result == "success" - assert responses[0].id == "1" - assert responses[1].result == "denied" - assert responses[1].id == "2" + assert out == { + "a": SigninResult.SUCCESS, + "b": SigninResult.DENIED, + "c": SigninResult.ERROR, + } @pytest.mark.asyncio -async def test_interact_batch_request_async(): +async def test_signin_batch_mixed_outcomes_async(): + payload = [ + _signin_response("1", "success"), + _signin_response("2", "denied"), + ] client = get_async_client_mock( - status_code=200, stream_chunks_mock=_single_event(BATCH_PAYLOAD) + status_code=200, stream_chunks_mock=_single_event(payload) ) - requests = [ - JsonRpcRequest(method="m", id="1"), - JsonRpcRequest(method="m", id="2"), - ] - responses = await client.client_channel.interact( - channel_id="abc", request=requests + out = await client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b"] ) - assert len(responses) == 2 - assert [r.id for r in responses] == ["1", "2"] + assert out == {"a": SigninResult.SUCCESS, "b": SigninResult.DENIED} -def test_interact_error_response_sync(): +def test_signin_out_of_order_responses_matched_by_id_sync(): + # Server returns responses in arrival order, NOT request order. + payload = [ + _signin_response("2", "denied"), + _signin_response("1", "success"), + ] client = get_client_mock( - status_code=200, stream_chunks_mock=_single_event(ERROR_PAYLOAD) + status_code=200, stream_chunks_mock=_single_event(payload) ) - responses = client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b"] ) - assert len(responses) == 1 - assert responses[0].result is None - assert isinstance(responses[0].error, JsonRpcError) - assert responses[0].error.code == -32000 - assert responses[0].error.message == "boom" + assert out == {"a": SigninResult.SUCCESS, "b": SigninResult.DENIED} -@pytest.mark.asyncio -async def test_interact_error_response_async(): - client = get_async_client_mock( - status_code=200, stream_chunks_mock=_single_event(ERROR_PAYLOAD) +def test_signin_missing_response_for_toolset_maps_to_error_sync(): + payload = [_signin_response("1", "success")] + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(payload) ) - responses = await client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b"] ) - assert responses[0].error is not None - assert responses[0].error.code == -32000 + assert out == {"a": SigninResult.SUCCESS, "b": SigninResult.ERROR} -def test_interact_heartbeats_skipped_sync(): - chunks = _sse_chunks( - ": heartbeat", - "", - ": heartbeat", - "", - ": heartbeat", - "", - _data(SINGLE_SUCCESS), - "", +def test_signin_unknown_result_string_maps_to_error_sync(): + payload = [{"jsonrpc": "2.0", "id": "1", "result": "weird-value"}] + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(payload) ) - client = get_client_mock(status_code=200, stream_chunks_mock=chunks) - responses = client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] ) - assert responses[0].result == "success" + assert out == {"a": SigninResult.ERROR} -def test_interact_malformed_json_raises_sync(): - chunks = _sse_chunks("data: not-json", "") - client = get_client_mock(status_code=200, stream_chunks_mock=chunks) - with pytest.raises(ParsingDataError): - client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") +def test_signin_empty_toolset_list_returns_empty_dict_sync(): + client = get_client_mock(status_code=200, stream_chunks_mock=[b""]) + out = client.client_channel.signin_toolsets(channel_id="ch", toolset_ids=[]) + assert out == {} + + +def test_signin_rejects_single_string_as_toolset_ids_sync(): + # A plain str satisfies Sequence[str] at runtime; reject explicitly so + # it doesn't iterate the string and send one request per character. + client = get_client_mock(status_code=200, stream_chunks_mock=[b""]) + with pytest.raises(InvalidRequestError): + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids="toolsets/public/x" ) -def test_interact_no_data_event_raises_sync(): - chunks = _sse_chunks(": heartbeat", "", ": heartbeat", "") +def test_signin_rejects_duplicate_toolset_ids_sync(): + client = get_client_mock(status_code=200, stream_chunks_mock=[b""]) + with pytest.raises(InvalidRequestError): + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "a"] + ) + + +def test_signin_accepts_iterator_as_toolset_ids_sync(): + # A one-shot iterable would silently produce {} without materialization; + # the wrapper must list() the input before using it twice. + payload = [ + _signin_response("1", "success"), + _signin_response("2", "denied"), + ] + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(payload) + ) + out = client.client_channel.signin_toolsets( + channel_id="ch", + toolset_ids=iter(["a", "b"]), # type: ignore[arg-type] + ) + assert out == {"a": SigninResult.SUCCESS, "b": SigninResult.DENIED} + + +def test_signin_batch_level_error_raises_dial_exception_sync(): + # Server-level JSON-RPC error: id=null with an error object (e.g. parse + # error -32700). Must raise instead of silently mapping all toolsets + # to SigninResult.ERROR. + chunks = _single_event( + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32700, "message": "Parse error"}, + } + ) client = get_client_mock(status_code=200, stream_chunks_mock=chunks) with pytest.raises(DialException) as exc_info: - client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] ) - assert exc_info.value.status_code == HTTPStatus.GATEWAY_TIMEOUT + assert "Parse error" in exc_info.value.message + assert "-32700" in exc_info.value.message -def test_interact_truncated_stream_does_not_yield_phantom_event_sync(): - # No trailing blank line — incomplete event must NOT be flushed. - chunks = _sse_chunks('data: {"jsonrpc":"2.0","resu') +# ---------------------------------------------------------------------------- +# signin_toolsets — transport errors +# ---------------------------------------------------------------------------- + + +def test_signin_no_data_event_raises_sync(): + chunks = _sse_chunks(": heartbeat", "", ": heartbeat", "") client = get_client_mock(status_code=200, stream_chunks_mock=chunks) with pytest.raises(DialException) as exc_info: - client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] ) assert exc_info.value.status_code == HTTPStatus.GATEWAY_TIMEOUT -def test_interact_http_error_raises_sync(): +def test_signin_http_401_raises_with_message_sync(): body = json.dumps({"error": {"message": "Unauthorized"}}).encode() client = get_client_mock(status_code=401, stream_chunks_mock=[body]) with pytest.raises(DialException) as exc_info: - client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] ) assert exc_info.value.status_code == 401 assert exc_info.value.message == "Unauthorized" @pytest.mark.asyncio -async def test_interact_http_error_raises_async(): +async def test_signin_http_401_raises_async(): body = json.dumps({"error": {"message": "Unauthorized"}}).encode() client = get_async_client_mock(status_code=401, stream_chunks_mock=[body]) with pytest.raises(DialException) as exc_info: - await client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") + await client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] ) assert exc_info.value.status_code == 401 + assert exc_info.value.message == "Unauthorized" + + +def test_signin_unknown_transport_error_wrapped_sync(): + client = get_client_mock( + status_code=200, exception_mock=httpx.ConnectError("boom") + ) + with pytest.raises(DialException) as exc_info: + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] + ) + assert "boom" in exc_info.value.message + assert "Request failed" in exc_info.value.message -def test_interact_sends_channel_header_and_body_sync(): +def test_signin_timeout_wrapped_sync(): + client = get_client_mock( + status_code=200, exception_mock=httpx.ReadTimeout("slow") + ) + with pytest.raises(DialException) as exc_info: + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] + ) + assert exc_info.value.status_code == HTTPStatus.REQUEST_TIMEOUT + + +# ---------------------------------------------------------------------------- +# Wire-format checks +# ---------------------------------------------------------------------------- + + +def test_signin_sends_channel_header_and_jsonrpc_body_sync(): captured: dict = {} - def send_mock(request: httpx.Request, **kwargs): + def send_mock(request: httpx.Request, **_kwargs): captured["request"] = request return httpx.Response( status_code=200, request=request, stream=MockStreamIterator( - mock_chunks=_single_event(SINGLE_SUCCESS) + mock_chunks=_single_event(_signin_response("1", "success")) ), ) client = Dial(api_key="dummy", base_url="http://dial.core") client._http_client._internal_http_client.send = send_mock - client.client_channel.interact( - channel_id="my-channel", - request=JsonRpcRequest(method="toolset/signin", id="1"), + client.client_channel.signin_toolsets( + channel_id="my-channel", toolset_ids=["toolsets/public/x"] ) request = captured["request"] @@ -240,103 +297,145 @@ def send_mock(request: httpx.Request, **kwargs): assert request.headers["api-key"] == "dummy" assert request.url.path == "/v1/ops/client-channel/interact" body = json.loads(request.content) - assert body == {"jsonrpc": "2.0", "method": "toolset/signin", "id": "1"} - - -def test_interact_rejects_notification_request_sync(): - client = get_client_mock( - status_code=200, stream_chunks_mock=_single_event(SINGLE_SUCCESS) - ) - with pytest.raises(InvalidRequestError): - client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m") - ) - - -def test_interact_rejects_batch_with_notification_sync(): - client = get_client_mock( - status_code=200, stream_chunks_mock=_single_event(BATCH_PAYLOAD) - ) - with pytest.raises(InvalidRequestError): - client.client_channel.interact( - channel_id="abc", - request=[ - JsonRpcRequest(method="m", id="1"), - JsonRpcRequest(method="m"), - ], - ) - - -def test_interact_rejects_heartbeat_shaped_payload_sync(): - chunks = _single_event({"type": "heartbeat", "ts": 1234}) - client = get_client_mock(status_code=200, stream_chunks_mock=chunks) - with pytest.raises(ParsingDataError): - client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") - ) - - -def test_interact_rejects_null_payload_sync(): - chunks = _sse_chunks("data: null", "") - client = get_client_mock(status_code=200, stream_chunks_mock=chunks) - with pytest.raises(ParsingDataError): - client.client_channel.interact( - channel_id="abc", request=JsonRpcRequest(method="m", id="1") - ) + # Wire body is always an array, even for a single request. + assert body == [ + { + "jsonrpc": "2.0", + "method": "toolset/signin", + "params": {"toolsetId": "toolsets/public/x"}, + "id": "1", + } + ] -def test_interact_preserves_integer_id_sync(): +@pytest.mark.asyncio +async def test_signin_batch_body_serialized_as_array_async(): captured: dict = {} - def send_mock(request: httpx.Request, **kwargs): + async def send_mock(request: httpx.Request, **_kwargs): captured["request"] = request return httpx.Response( status_code=200, request=request, stream=MockStreamIterator( mock_chunks=_single_event( - {"jsonrpc": "2.0", "result": "ok", "id": 42} + [ + _signin_response("1", "success"), + _signin_response("2", "success"), + ] ) ), ) - client = Dial(api_key="dummy", base_url="http://dial.core") + client = AsyncDial(api_key="dummy", base_url="http://dial.core") client._http_client._internal_http_client.send = send_mock - responses = client.client_channel.interact( - channel_id="x", - request=JsonRpcRequest(method="m", id=42), + await client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b"] ) + body = json.loads(captured["request"].content) - assert body["id"] == 42 and isinstance(body["id"], int) - assert responses[0].id == 42 and isinstance(responses[0].id, int) + assert isinstance(body, list) and len(body) == 2 + assert body[0]["params"] == {"toolsetId": "a"} + assert body[1]["params"] == {"toolsetId": "b"} -@pytest.mark.asyncio -async def test_interact_batch_body_serialized_as_array_async(): - captured: dict = {} +# ---------------------------------------------------------------------------- +# Internal _interact — protocol-level coverage +# ---------------------------------------------------------------------------- - async def send_mock(request: httpx.Request, **kwargs): - captured["request"] = request - return httpx.Response( - status_code=200, - request=request, - stream=MockStreamIterator(mock_chunks=_single_event(BATCH_PAYLOAD)), + +def test_interact_result_null_is_valid_sync(): + # {result: null} is a successful response per JSON-RPC spec — must not + # raise ParsingDataError as the old code did. + chunks = _single_event({"jsonrpc": "2.0", "id": "1", "result": None}) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + responses = client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + assert len(responses) == 1 + assert responses[0].result is None + assert responses[0].error is None + + +def test_interact_response_missing_id_raises_parsing_error_sync(): + chunks = _single_event({"jsonrpc": "2.0", "result": "ok"}) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + + +def test_interact_response_with_both_result_and_error_raises_sync(): + chunks = _single_event( + { + "jsonrpc": "2.0", + "id": "1", + "result": "ok", + "error": {"code": -1, "message": "x"}, + } + ) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + + +def test_interact_response_with_neither_result_nor_error_raises_sync(): + chunks = _single_event({"jsonrpc": "2.0", "id": "1"}) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + + +def test_interact_malformed_json_raises_sync(): + chunks = _sse_chunks("data: not-json", "") + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], ) - client = AsyncDial(api_key="dummy", base_url="http://dial.core") - client._http_client._internal_http_client.send = send_mock - await client.client_channel.interact( +def test_interact_heartbeats_skipped_sync(): + chunks = _sse_chunks( + ": heartbeat", + "", + ": heartbeat", + "", + _data(_signin_response("1", "success")), + "", + ) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + responses = client.client_channel._interact( channel_id="ch", - request=[ - JsonRpcRequest(method="m", id="1"), - JsonRpcRequest(method="m", id="2"), - ], + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], ) + assert responses[0].result == "success" - body = json.loads(captured["request"].content) - assert isinstance(body, list) - assert len(body) == 2 - assert body[0]["id"] == "1" - assert body[1]["id"] == "2" + +def test_interact_truncated_stream_warns_and_no_phantom_event_sync(caplog): + # No trailing blank line — incomplete event must NOT be flushed, and a + # warning must be emitted by the SSE parser. + chunks = _sse_chunks('data: {"jsonrpc":"2.0","resu') + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with caplog.at_level(logging.WARNING, logger="aidial_client"): + with pytest.raises(DialException) as exc_info: + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + assert exc_info.value.status_code == HTTPStatus.GATEWAY_TIMEOUT + assert any( + "Uncommitted data chunks in SSE stream" in rec.message + for rec in caplog.records + )