Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/opengradient/client/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""LLM chat and completion via TEE-verified execution with x402 payments."""

import base64
import json
import logging
from dataclasses import dataclass
Expand Down Expand Up @@ -35,6 +36,7 @@
_COMPLETION_ENDPOINT = "/v1/completions"
_REQUEST_TIMEOUT = 60


@dataclass(frozen=True)
class _ChatParams:
"""Bundles the common parameters for chat/completion requests."""
Expand Down Expand Up @@ -385,8 +387,13 @@ async def _request() -> TextGenerationOutput:
headers=self._headers(params.x402_settlement_mode),
timeout=_REQUEST_TIMEOUT,
)
response.raise_for_status()
raw_body = await response.aread()
if response.status_code >= 400:
raise httpx.HTTPStatusError(
_format_http_error(response, raw_body),
request=response.request,
response=response,
)
result = json.loads(raw_body.decode())

choices = result.get("choices")
Expand Down Expand Up @@ -532,3 +539,23 @@ async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk
chunk.tee_endpoint = tee.endpoint
chunk.tee_payment_address = tee.payment_address
yield chunk


def _decode_payment_required(header_value: Optional[str]) -> str:
"""Decode the base64-encoded JSON in the `payment-required` response header."""
if not header_value:
return "<missing>"
try:
decoded = base64.b64decode(header_value).decode("utf-8")
return json.dumps(json.loads(decoded), indent=2)
except Exception:
return header_value


def _format_http_error(response: httpx.Response, body: bytes) -> str:
"""Build an error message that surfaces the x402 payment-required details."""
return (
f"HTTP {response.status_code} from {response.url}\n"
f"Payment-Required: {_decode_payment_required(response.headers.get('payment-required'))}\n"
f"Body: {body.decode(errors='replace')}"
)
5 changes: 4 additions & 1 deletion tests/llm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import ssl
from contextlib import asynccontextmanager
from typing import List
from typing import Dict, List
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
Expand Down Expand Up @@ -80,6 +80,9 @@ class _FakeResponse:
def __init__(self, status_code: int, body: bytes):
self.status_code = status_code
self._body = body
self.headers: Dict[str, str] = {}
self.request = MagicMock()
self.url = "https://test.tee.server/v1/chat/completions"

def raise_for_status(self):
pass
Expand Down
Loading