diff --git a/pyproject.toml b/pyproject.toml
index d1bfc8b..17a8731 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,8 +27,8 @@ dependencies = [
     "langchain>=0.3.7",
     "openai>=1.58.1",
     "pydantic>=2.9.2",
-    "og-x402>=0.0.2.dev1",
-    "og-x402[extensions]>=0.0.2.dev1",
+    "og-x402>=0.0.2.dev2",
+    "og-x402[extensions]>=0.0.2.dev2",
 ]
 
 [project.optional-dependencies]
diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py
index 0c01281..c69185d 100644
--- a/src/opengradient/client/llm.py
+++ b/src/opengradient/client/llm.py
@@ -149,6 +149,18 @@ def _headers(self, settlement_mode: x402SettlementMode) -> Dict[str, str]:
             "X-SETTLEMENT-TYPE": settlement_mode.value,
         }
 
+    @staticmethod
+    def _data_settlement_transaction_hash(response: httpx.Response) -> Optional[str]:
+        """Return the X402_DATA_SETTLEMENT_TX_HASH_HEADER response header, or None if absent."""
+        value: Optional[str] = response.headers.get(X402_DATA_SETTLEMENT_TX_HASH_HEADER)
+        return value
+
+    @staticmethod
+    def _data_settlement_blob_id(response: httpx.Response) -> Optional[str]:
+        """Return the X402_DATA_SETTLEMENT_BLOB_ID_HEADER response header, or None if absent."""
+        value: Optional[str] = response.headers.get(X402_DATA_SETTLEMENT_BLOB_ID_HEADER)
+        return value
+
     def _chat_payload(self, params: _ChatParams, messages: List[Dict], stream: bool = False) -> Dict:
         payload: Dict = {
             "model": params.model,
@@ -287,8 +299,8 @@ async def _request() -> TextGenerationOutput:
             raw_body = await response.aread()
             result = json.loads(raw_body.decode())
             return TextGenerationOutput(
-                data_settlement_transaction_hash=response.headers.get(X402_DATA_SETTLEMENT_TX_HASH_HEADER),
-                data_settlement_blob_id=response.headers.get(X402_DATA_SETTLEMENT_BLOB_ID_HEADER),
+                data_settlement_transaction_hash=self._data_settlement_transaction_hash(response),
+                data_settlement_blob_id=self._data_settlement_blob_id(response),
                 completion_output=result.get("completion"),
                 tee_signature=result.get("tee_signature"),
                 tee_timestamp=result.get("tee_timestamp"),
@@ -411,8 +423,8 @@ async def _request() -> TextGenerationOutput:
                 ).strip()
 
             return TextGenerationOutput(
-                data_settlement_transaction_hash=response.headers.get(X402_DATA_SETTLEMENT_TX_HASH_HEADER),
-                data_settlement_blob_id=response.headers.get(X402_DATA_SETTLEMENT_BLOB_ID_HEADER),
+                data_settlement_transaction_hash=self._data_settlement_transaction_hash(response),
+                data_settlement_blob_id=self._data_settlement_blob_id(response),
                 finish_reason=choices[0].get("finish_reason"),
                 chat_output=message,
                 usage=result.get("usage"),
@@ -511,6 +523,8 @@ async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk
             raise RuntimeError(f"TEE LLM streaming request failed with status {status_code}: {body.decode('utf-8', errors='replace')}")
 
         buffer = b""
+        # The final chunk is held back so it is emitted after any chunks that follow it on the wire.
+        pending_final_chunk: Optional[StreamChunk] = None
         async for raw_chunk in response.aiter_raw():
             if not raw_chunk:
                 continue
@@ -532,6 +546,8 @@ async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk
 
                 data_str = decoded[6:].strip()
                 if data_str == "[DONE]":
+                    if pending_final_chunk is not None:
+                        yield pending_final_chunk
                     return
 
                 try:
@@ -541,13 +557,29 @@ async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk
 
                 chunk = StreamChunk.from_sse_data(data)
                 if chunk.is_final:
-                    chunk.data_settlement_transaction_hash = response.headers.get(X402_DATA_SETTLEMENT_TX_HASH_HEADER)
-                    chunk.data_settlement_blob_id = response.headers.get(X402_DATA_SETTLEMENT_BLOB_ID_HEADER)
+                    # Keep a value already set on the chunk; otherwise fall back to the response header.
+                    chunk.data_settlement_transaction_hash = (
+                        chunk.data_settlement_transaction_hash
+                        or self._data_settlement_transaction_hash(response)
+                    )
+                    chunk.data_settlement_blob_id = (
+                        chunk.data_settlement_blob_id or self._data_settlement_blob_id(response)
+                    )
                     chunk.tee_id = tee.tee_id
                     chunk.tee_endpoint = tee.endpoint
                     chunk.tee_payment_address = tee.payment_address
+                    # A later final chunk must not silently replace an earlier one:
+                    # emit the chunk already being held before deferring the new one.
+                    if pending_final_chunk is not None:
+                        yield pending_final_chunk
+                    pending_final_chunk = chunk
+                    continue
                 yield chunk
 
+        # Stream ended without a "[DONE]" marker: still emit the deferred final chunk.
+        if pending_final_chunk is not None:
+            yield pending_final_chunk
+
 
 def _decode_payment_required(header_value: Optional[str]) -> str:
     """Decode the base64-encoded JSON in the `payment-required` response header."""
diff --git a/uv.lock b/uv.lock
index 3dc10c2..bbfb293 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1950,8 +1950,8 @@ requires-dist = [
     { name = "langgraph", marker = "extra == 'dev'" },
     { name = "mypy", marker = "extra == 'dev'" },
     { name = "numpy", specifier = ">=1.26.4" },
-    { name = "og-x402", specifier = ">=0.0.2.dev1" },
-    { name = "og-x402", extras = ["extensions"], specifier = ">=0.0.2.dev1" },
+    { name = "og-x402", specifier = ">=0.0.2.dev2" },
+    { name = "og-x402", extras = ["extensions"], specifier = ">=0.0.2.dev2" },
     { name = "openai", specifier = ">=1.58.1" },
     { name = "pdoc3", marker = "extra == 'dev'", specifier = "==0.10.0" },
     { name = "pydantic", specifier = ">=2.9.2" },