Surface x402 payment details in LLM HTTP error messages.

  * src/opengradient/client/llm.py: the request path no longer calls
    response.raise_for_status(). It reads the body first and, for any status
    code >= 400, raises httpx.HTTPStatusError with a message built by the new
    _format_http_error helper (status code, URL, decoded `payment-required`
    header, response body). _decode_payment_required base64-decodes the header
    and pretty-prints it as JSON, falling back to the raw header value when
    decoding fails.
  * tests/llm_test.py: _FakeResponse gains `headers`, `request` and `url`
    attributes, which the new error path reads.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Hunk
line counts were checked against the hunk headers, but the leading indentation
of the Python lines is reconstructed -- confirm against the target files (or
apply with whitespace-insensitive matching).
NOTE(review): _format_http_error emits the "Payment-Required:" line for every
error status, including responses that carry no such header (value is then
empty).
NOTE(review): the broad `except Exception` in _decode_payment_required looks
like a deliberate best-effort fallback so that building the error message
never masks the HTTP error itself -- confirm that is the intent.

diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py
index 2d03512..e7e5337 100644
--- a/src/opengradient/client/llm.py
+++ b/src/opengradient/client/llm.py
@@ -1,5 +1,6 @@
 """LLM chat and completion via TEE-verified execution with x402 payments."""
 
+import base64
 import json
 import logging
 from dataclasses import dataclass
@@ -35,6 +36,7 @@
 _COMPLETION_ENDPOINT = "/v1/completions"
 _REQUEST_TIMEOUT = 60
 
+
 @dataclass(frozen=True)
 class _ChatParams:
     """Bundles the common parameters for chat/completion requests."""
@@ -385,8 +387,13 @@ async def _request() -> TextGenerationOutput:
                 headers=self._headers(params.x402_settlement_mode),
                 timeout=_REQUEST_TIMEOUT,
             )
-            response.raise_for_status()
             raw_body = await response.aread()
+            if response.status_code >= 400:
+                raise httpx.HTTPStatusError(
+                    _format_http_error(response, raw_body),
+                    request=response.request,
+                    response=response,
+                )
             result = json.loads(raw_body.decode())
 
             choices = result.get("choices")
@@ -532,3 +539,23 @@ async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk
                 chunk.tee_endpoint = tee.endpoint
                 chunk.tee_payment_address = tee.payment_address
             yield chunk
+
+
+def _decode_payment_required(header_value: Optional[str]) -> str:
+    """Decode the base64-encoded JSON in the `payment-required` response header."""
+    if not header_value:
+        return ""
+    try:
+        decoded = base64.b64decode(header_value).decode("utf-8")
+        return json.dumps(json.loads(decoded), indent=2)
+    except Exception:
+        return header_value
+
+
+def _format_http_error(response: httpx.Response, body: bytes) -> str:
+    """Build an error message that surfaces the x402 payment-required details."""
+    return (
+        f"HTTP {response.status_code} from {response.url}\n"
+        f"Payment-Required: {_decode_payment_required(response.headers.get('payment-required'))}\n"
+        f"Body: {body.decode(errors='replace')}"
+    )
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 8e5aba2..5309f28 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -7,7 +7,7 @@
 import json
 import ssl
 from contextlib import asynccontextmanager
-from typing import List
+from typing import Dict, List
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
@@ -80,6 +80,9 @@ class _FakeResponse:
     def __init__(self, status_code: int, body: bytes):
         self.status_code = status_code
         self._body = body
+        self.headers: Dict[str, str] = {}
+        self.request = MagicMock()
+        self.url = "https://test.tee.server/v1/chat/completions"
 
     def raise_for_status(self):
         pass