diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 835c886..ee74830 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,4 +15,12 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Run unit tests - run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tests/test_pricing.py -v --import-mode=importlib + run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tee_gateway/test/test_util.py tests/test_pricing.py -v --import-mode=importlib + + integration-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v5 + - name: Run integration tests (live CoinGecko API) + run: uv run --group test pytest tests/test_integration.py -v -m integration --import-mode=importlib diff --git a/pyproject.toml b/pyproject.toml index 19a5164..b549ccf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "setuptools>=21.0.0", "Flask>=3.0.0", "gunicorn>=23.0.0", - "og-x402[evm]==0.0.1.dev6", + "og-x402[evm]>=0.0.1.dev9", "fastapi>=0.128.0", "uvicorn[standard]>=0.40.0", "pydantic>=2.12.5", @@ -58,6 +58,12 @@ exclude = [ "**/site-packages", ] +[tool.pytest.ini_options] +pythonpath = ["."] +markers = [ + "integration: tests that require live network access (deselect with '-m not integration')", +] + [tool.uv] # Pre-release needed for og-test-v2-x402==0.0.11.dev5 prerelease = "allow" diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 18f8117..7d6b63a 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -27,19 +27,20 @@ from x402.http.types import RouteConfig from x402.mechanisms.evm.exact import ExactEvmServerScheme from x402.mechanisms.evm.upto import UptoEvmServerScheme +from x402.extensions.erc20_approval_gas_sponsoring import ( + declare_erc20_approval_gas_sponsoring_extension, +) from x402.schemas import AssetAmount from x402.server import x402ResourceServerSync from x402.session import SessionStore -import types as _types import x402.http.middleware.flask as x402_flask from .util import dynamic_session_cost_calculator from .definitions import ( - BASE_TESTNET_NETWORK, EVM_PAYMENT_ADDRESS, - BASE_OPG_ADDRESS, + BASE_MAINNET_NETWORK, + BASE_MAINNET_OPG_ADDRESS, CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, - COMPLETIONS_OPG_SESSION_MAX_SPEND, FACILITATOR_URL, ) @@ -108,8 +109,10 @@ def _shutdown_heartbeat(): server = x402ResourceServerSync(facilitator) store = SessionStore() -server.register(BASE_TESTNET_NETWORK, ExactEvmServerScheme()) -server.register(BASE_TESTNET_NETWORK, UptoEvmServerScheme()) +server.register(BASE_MAINNET_NETWORK, ExactEvmServerScheme()) + +# Upto scheme registrations (permit2-based, variable settlement) +server.register(BASE_MAINNET_NETWORK, UptoEvmServerScheme()) routes = { "POST /v1/chat/completions": RouteConfig( @@ -119,16 +122,19 @@ def _shutdown_heartbeat(): pay_to=EVM_PAYMENT_ADDRESS, price=AssetAmount( amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, - asset=BASE_OPG_ADDRESS, + asset=BASE_MAINNET_OPG_ADDRESS, extra={ - "name": "OPG", - "version": "2", + "name": "OpenGradient", + "version": "1", "assetTransferMethod": "permit2", }, ), - network=BASE_TESTNET_NETWORK, + network=BASE_MAINNET_NETWORK, ), ], + extensions={ + **declare_erc20_approval_gas_sponsoring_extension(), + }, mime_type="application/json", description="Chat completion", ), @@ -138,17 +144,20 @@ def _shutdown_heartbeat(): scheme="upto", pay_to=EVM_PAYMENT_ADDRESS, price=AssetAmount( - amount=COMPLETIONS_OPG_SESSION_MAX_SPEND, - asset=BASE_OPG_ADDRESS, + amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, + asset=BASE_MAINNET_OPG_ADDRESS, extra={ - "name": "OPG", - "version": "2", + "name": "OpenGradient", + "version": "1", "assetTransferMethod": "permit2", }, ), - network=BASE_TESTNET_NETWORK, + network=BASE_MAINNET_NETWORK, ), ], + extensions={ + **declare_erc20_approval_gas_sponsoring_extension(), + }, mime_type="application/json", description="Completion", ), @@ -374,99 +383,21 @@ def _patched_read_body_bytes(environ): ) # --------------------------------------------------------------------------- -# Strict cost-resolution patch +# Cost-resolution behaviour note # -# Why this exists -# --------------- -# The upstream x402 PaymentMiddleware._resolve_session_request_cost wraps the -# call to the session_cost_calculator in a broad try/except. If the calculator -# raises (e.g. ValueError for an unrecognised model name, KeyError for missing -# usage data), the exception is swallowed and the middleware silently falls back -# to the static session maximum (CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND / -# CHAT_COMPLETIONS_USDC_AMOUNT). That silent fallback means: -# • The client is charged the full pre-check cap instead of actual usage. -# • The server has no visible indication that pricing failed. +# dynamic_session_cost_calculator() is invoked by PaymentMiddleware via +# _accumulate_session_cost(), which is itself called from +# StreamingSessionResponse.close(). That close() wraps the call in a broad +# try/except, so any exception raised by the calculator (e.g. ValueError for +# an unrecognised model or missing usage data) is logged but otherwise +# swallowed — the session cost is simply not incremented for that request. # -# The fix -# ------- -# We replace _resolve_session_request_cost with our own implementation that is -# identical to upstream, except the cost-calculator call is NOT wrapped in a -# try/except. Any exception from dynamic_session_cost_calculator() therefore -# propagates up through the middleware and Flask, producing a proper HTTP 500 -# response to the client instead of an incorrect silent charge. +# A previous version of this file monkey-patched _resolve_session_request_cost +# to let exceptions propagate, but that method no longer exists in the library. +# If stricter error propagation is needed in future, patch _accumulate_session_cost +# or StreamingSessionResponse.close() instead. # --------------------------------------------------------------------------- - -def _strict_resolve_session_request_cost( - self, - *, - method: str, - path: str, - request_body_bytes: bytes, - response_body_bytes: bytes, - payment_payload: object, - payment_requirements: object, - status_code: int | None, - output_object: object = None, - is_streaming: bool = False, -) -> int: - """Replacement for PaymentMiddleware._resolve_session_request_cost. - - Identical to the upstream implementation except that exceptions raised by - the dynamic cost calculator are NOT caught. This means a request whose - cost cannot be determined (unknown model, missing usage data, etc.) will - result in a 500 error rather than silently falling back to the static cap - amount and charging the user an incorrect amount. - """ - from x402.http.middleware.flask import _parse_json_bytes as _x402_parse_json # noqa: PLC0415 - - default_cost = self._get_session_cost(payment_requirements) - if not self._should_charge_response(status_code): - return default_cost - if not callable(self._session_cost_calculator): - return default_cost - - request_object = _x402_parse_json(request_body_bytes) - response_object = ( - output_object - if output_object is not None - else _x402_parse_json(response_body_bytes) - ) - - callback_context = { - "method": method, - "path": path, - "status_code": status_code, - "is_streaming": is_streaming, - "request_body_bytes": request_body_bytes, - "response_body_bytes": response_body_bytes, - "request_json": request_object - if isinstance(request_object, (dict, list)) - else None, - "response_json": response_object - if isinstance(response_object, (dict, list)) - else None, - "response_object": response_object, - "payment_payload": payment_payload, - "payment_requirements": payment_requirements, - "default_cost": default_cost, - } - - # Do NOT catch exceptions here — let them propagate so the request fails - # with a 500 rather than silently charging the static fallback amount. - dynamic_cost = self._session_cost_calculator(callback_context) - if dynamic_cost is None: - raise ValueError( - f"dynamic_session_cost_calculator returned None for {method} {path}; " - "cannot determine request cost" - ) - return self._coerce_non_negative_int(dynamic_cost) - - -_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign] - _strict_resolve_session_request_cost, _payment_mw -) - logger.info("x402 payment middleware initialized") if __name__ == "__main__": diff --git a/tee_gateway/config.py b/tee_gateway/config.py index f6a842d..8fb25c1 100644 --- a/tee_gateway/config.py +++ b/tee_gateway/config.py @@ -9,7 +9,34 @@ from dataclasses import dataclass from typing import Optional +# --------------------------------------------------------------------------- +# OPG / token price feed +# --------------------------------------------------------------------------- + +# How long (seconds) to reuse a cached price before fetching a fresh one. +# At 120 s the gateway makes at most 30 CoinGecko calls/hour — well within +# the free-tier limit (30/min). +OPG_PRICE_CACHE_TTL_SECONDS: int = 120 + +# Number of times to retry a failed CoinGecko fetch before giving up. +# Each attempt uses the same 5-second timeout; retries are immediate (no backoff). +OPG_PRICE_FETCH_RETRIES: int = 3 + +# CoinGecko coin ID for the OPG token. +# https://www.coingecko.com/en/coins/opengradient +OPG_PRICE_COINGECKO_ID: str = "opengradient" + +# Sanity bounds for the fetched token price. +# Used in integration tests to catch obviously wrong API responses +# (wrong currency, implausibly large value). +# Update when OPG establishes a trading range. +OPG_PRICE_SANITY_MAX_USD: str = ( + "1000000" # $1 000 000 — rules out obviously corrupt data +) + +# --------------------------------------------------------------------------- # Heartbeat defaults +# --------------------------------------------------------------------------- DEFAULT_HEARTBEAT_INTERVAL = 900 # 15 minutes DEFAULT_HEARTBEAT_BUFFER = ( 300 # 5 minutes — subtracted from time.time() to compensate for enclave clock drift diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index 49fb406..24c2d3b 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -30,6 +30,7 @@ convert_messages, extract_usage, ) +from tee_gateway.util import validate_pricing_preflight logger = logging.getLogger(__name__) @@ -46,6 +47,12 @@ def create_chat_completion(body): connexion.request.get_json() ) + try: + validate_pricing_preflight(chat_request.model) + except ValueError as exc: + logger.error("Pricing preflight failed for model %r: %s", chat_request.model, exc) + return {"error": "Bad Request", "message": str(exc)}, 400 + if chat_request.stream: return _create_streaming_response(chat_request) else: diff --git a/tee_gateway/controllers/completions_controller.py b/tee_gateway/controllers/completions_controller.py index ba941fe..1b3efd7 100644 --- a/tee_gateway/controllers/completions_controller.py +++ b/tee_gateway/controllers/completions_controller.py @@ -10,6 +10,7 @@ from tee_gateway.tee_manager import get_tee_keys, compute_tee_msg_hash from tee_gateway.llm_backend import get_chat_model_cached, extract_usage +from tee_gateway.util import validate_pricing_preflight logger = logging.getLogger(__name__) @@ -21,6 +22,12 @@ def create_completion(body): else: return {"error": "Request must be application/json"}, 415 + try: + validate_pricing_preflight(body.model) + except ValueError as exc: + logger.error("Pricing preflight failed for model %r: %s", body.model, exc) + return {"error": "Bad Request", "message": str(exc)}, 400 + try: request_dict = { "model": body.model, diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index 83d9d8c..ca316bd 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -19,8 +19,9 @@ # Network IDs (EIP-155 chain identifiers) # --------------------------------------------------------------------------- -# Base Testnet — where OPG payments are accepted -BASE_TESTNET_NETWORK: str = "eip155:84532" + +# Base Mainnet — where OPG payments are accepted +BASE_MAINNET_NETWORK: str = "eip155:8453" # --------------------------------------------------------------------------- # Payment recipient @@ -31,15 +32,15 @@ # your own instance. EVM_PAYMENT_ADDRESS: str = os.getenv( "EVM_PAYMENT_ADDRESS", - "0x40eFb45552EDfB2502D90A657a8ab41F03ec460d", + "0x9deEBB5D1b22e4a6e027977CeAd13893A7E4cC1a", ) # --------------------------------------------------------------------------- # ERC-20 token contract addresses # --------------------------------------------------------------------------- -# OpenGradient token (OPG) on Base Testnet -BASE_OPG_ADDRESS: str = "0x240b09731D96979f50B2C649C9CE10FcF9C7987F" +# OpenGradient token (OPG) on Base Mainnet +BASE_MAINNET_OPG_ADDRESS: str = "0xFbC2051AE2265686a469421b2C5A2D5462FbF5eB" # --------------------------------------------------------------------------- # Token decimal places @@ -47,7 +48,7 @@ # Maps lowercase contract address → number of decimals for unit conversion. ASSET_DECIMALS_BY_ADDRESS: dict[str, int] = { - BASE_OPG_ADDRESS.lower(): 18, # OPG: 18 decimals (ERC-20 standard) + BASE_MAINNET_OPG_ADDRESS.lower(): 18, # OPG: 18 decimals (ERC-20 standard) } # Fallback for any asset not explicitly listed above diff --git a/tee_gateway/test/test_tool_forwarding.py b/tee_gateway/test/test_tool_forwarding.py index f4ca16f..ac2c250 100644 --- a/tee_gateway/test/test_tool_forwarding.py +++ b/tee_gateway/test/test_tool_forwarding.py @@ -1,4 +1,5 @@ import unittest +from decimal import Decimal from unittest.mock import patch, Mock from tee_gateway.controllers.chat_controller import ( @@ -11,6 +12,21 @@ ChatCompletionRequestFunctionMessage, ) +# Pin the token price for all integration-style tests in this file so they +# never hit the real CoinGecko API. The price feed is tested separately in +# test_util.py and test_integration.py. +_price_patcher = patch( + "tee_gateway.util.get_token_a_price_usd", return_value=Decimal("1") +) + + +def setUpModule(): + _price_patcher.start() + + +def tearDownModule(): + _price_patcher.stop() + # --------------------------------------------------------------------------- # Shared helpers diff --git a/tee_gateway/test/test_util.py b/tee_gateway/test/test_util.py new file mode 100644 index 0000000..374c4f0 --- /dev/null +++ b/tee_gateway/test/test_util.py @@ -0,0 +1,309 @@ +""" +Tests for tee_gateway.util — OPG price fetching, caching, and dynamic cost calculation. + +All tests are fully offline: urllib.request.urlopen is patched so no real +network call is ever made. +""" + +import json +import unittest +from decimal import Decimal +from unittest.mock import MagicMock, patch + +from tee_gateway.config import OPG_PRICE_COINGECKO_ID +from tee_gateway.util import ( + _fetch_opg_price_usd, + _token_price_cache, + dynamic_session_cost_calculator, + get_token_a_price_usd, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_urlopen_response(price: float) -> MagicMock: + """Return a mock context-manager that urlopen returns with a CoinGecko payload.""" + body = json.dumps({OPG_PRICE_COINGECKO_ID: {"usd": price}}).encode() + mock_resp = MagicMock() + mock_resp.read.return_value = body + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + return mock_resp + + +def _reset_price_cache() -> None: + """Reset module-level price cache to pristine state between tests.""" + _token_price_cache["last_good"] = None + _token_price_cache["updated_at"] = 0.0 + + +# --------------------------------------------------------------------------- +# _fetch_opg_price_usd +# --------------------------------------------------------------------------- + + +class TestFetchOPGPrice(unittest.TestCase): + """_fetch_opg_price_usd makes one HTTP call and returns a Decimal price.""" + + def setUp(self): + _reset_price_cache() + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_returns_decimal_price(self, mock_urlopen): + mock_urlopen.return_value = _make_urlopen_response(3000.50) + price = _fetch_opg_price_usd() + self.assertIsInstance(price, Decimal) + self.assertEqual(price, Decimal("3000.5")) + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_uses_configured_coingecko_id_in_url(self, mock_urlopen): + mock_urlopen.return_value = _make_urlopen_response(3000.0) + _fetch_opg_price_usd() + call_args = mock_urlopen.call_args + # First positional arg is the Request object + req = call_args[0][0] + self.assertIn(OPG_PRICE_COINGECKO_ID, req.full_url) + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_raises_on_non_positive_price(self, mock_urlopen): + mock_urlopen.return_value = _make_urlopen_response(0.0) + with self.assertRaises(ValueError): + _fetch_opg_price_usd() + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_raises_on_network_error(self, mock_urlopen): + mock_urlopen.side_effect = OSError("connection refused") + with self.assertRaises(OSError): + _fetch_opg_price_usd() + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_raises_on_malformed_json(self, mock_urlopen): + mock_resp = MagicMock() + mock_resp.read.return_value = b"not-json" + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + mock_urlopen.return_value = mock_resp + with self.assertRaises(Exception): + _fetch_opg_price_usd() + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_raises_when_coin_has_no_price(self, mock_urlopen): + """CoinGecko returns the coin but without a 'usd' key (no trading price yet). + + This is a deterministic failure — retrying won't help, so it must raise + immediately after the first call without consuming the retry budget. + """ + body = json.dumps({OPG_PRICE_COINGECKO_ID: {}}).encode() + mock_resp = MagicMock() + mock_resp.read.return_value = body + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + mock_urlopen.return_value = mock_resp + with self.assertRaises(ValueError) as ctx: + _fetch_opg_price_usd() + self.assertIn("no price", str(ctx.exception)) + self.assertEqual( + mock_urlopen.call_count, 1 + ) # no retries for deterministic error + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_retries_on_failure_then_succeeds(self, mock_urlopen): + """Succeeds on the final attempt after earlier failures.""" + from tee_gateway.config import OPG_PRICE_FETCH_RETRIES + + mock_urlopen.side_effect = [OSError("timeout")] * ( + OPG_PRICE_FETCH_RETRIES - 1 + ) + [_make_urlopen_response(3000.0)] + price = _fetch_opg_price_usd() + self.assertEqual(price, Decimal("3000.0")) + self.assertEqual(mock_urlopen.call_count, OPG_PRICE_FETCH_RETRIES) + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_raises_after_all_retries_exhausted(self, mock_urlopen): + """Raises the last exception once all retry attempts are used up.""" + from tee_gateway.config import OPG_PRICE_FETCH_RETRIES + + mock_urlopen.side_effect = OSError("connection refused") + with self.assertRaises(OSError): + _fetch_opg_price_usd() + self.assertEqual(mock_urlopen.call_count, OPG_PRICE_FETCH_RETRIES) + + +# --------------------------------------------------------------------------- +# get_token_a_price_usd — caching behaviour +# --------------------------------------------------------------------------- + + +class TestGetTokenAPriceUSD(unittest.TestCase): + """get_token_a_price_usd must respect the TTL and fallback gracefully.""" + + def setUp(self): + _reset_price_cache() + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_returns_fetched_price_on_cold_cache(self, mock_urlopen): + mock_urlopen.return_value = _make_urlopen_response(3000.0) + price = get_token_a_price_usd() + self.assertEqual(price, Decimal("3000.0")) + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_cache_hit_skips_second_network_call(self, mock_urlopen): + mock_urlopen.return_value = _make_urlopen_response(3000.0) + get_token_a_price_usd() # populates cache + get_token_a_price_usd() # should hit cache + self.assertEqual(mock_urlopen.call_count, 1) + + @patch("tee_gateway.util.urllib.request.urlopen") + @patch("tee_gateway.util.time") + def test_cache_expires_after_ttl(self, mock_time, mock_urlopen): + mock_urlopen.return_value = _make_urlopen_response(3000.0) + mock_time.time.return_value = 1_000_000.0 + get_token_a_price_usd() + + # Advance time past TTL + from tee_gateway.config import OPG_PRICE_CACHE_TTL_SECONDS + + mock_time.time.return_value = 1_000_000.0 + OPG_PRICE_CACHE_TTL_SECONDS + 1 + mock_urlopen.return_value = _make_urlopen_response(3500.0) + price = get_token_a_price_usd() + + self.assertEqual(mock_urlopen.call_count, 2) + self.assertEqual(price, Decimal("3500.0")) + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_raises_on_refresh_failure(self, mock_urlopen): + """If the cache is expired and a refresh fails, the error is raised immediately.""" + mock_urlopen.return_value = _make_urlopen_response(3000.0) + get_token_a_price_usd() # populate cache + + # Force cache to appear expired then make the refresh fail + _token_price_cache["updated_at"] = 0.0 + mock_urlopen.side_effect = OSError("network down") + with self.assertRaises(OSError): + get_token_a_price_usd() + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_raises_when_never_fetched_and_network_fails(self, mock_urlopen): + """With empty cache and a failing network, the exception propagates.""" + mock_urlopen.side_effect = OSError("network down") + with self.assertRaises(OSError): + get_token_a_price_usd() + + +# --------------------------------------------------------------------------- +# dynamic_session_cost_calculator — end-to-end with mocked price +# --------------------------------------------------------------------------- + + +class TestDynamicSessionCostCalculator(unittest.TestCase): + """Full pipeline: token counts + model pricing + OPG price → on-chain units.""" + + def setUp(self): + _reset_price_cache() + + def _make_context( + self, + model: str = "gpt-4.1", + prompt_tokens: int = 100, + completion_tokens: int = 50, + asset: str = "0xFbC2051AE2265686a469421b2C5A2D5462FbF5eB", # OPG mainnet + ) -> dict: + return { + "request_json": {"model": model}, + "response_json": { + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } + }, + "payment_requirements": {"asset": asset}, + } + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_returns_positive_integer(self, mock_urlopen): + mock_urlopen.return_value = _make_urlopen_response(3000.0) + cost = dynamic_session_cost_calculator(self._make_context()) + self.assertIsInstance(cost, int) + self.assertGreater(cost, 0) + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_higher_token_price_reduces_cost(self, mock_urlopen): + """When OPG is worth more ($/token), fewer tokens are charged.""" + mock_urlopen.return_value = _make_urlopen_response(1000.0) + _reset_price_cache() + cost_cheap = dynamic_session_cost_calculator(self._make_context()) + + mock_urlopen.return_value = _make_urlopen_response(5000.0) + _reset_price_cache() + cost_expensive = dynamic_session_cost_calculator(self._make_context()) + + self.assertGreater(cost_cheap, cost_expensive) + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_more_tokens_increases_cost(self, mock_urlopen): + mock_urlopen.return_value = _make_urlopen_response(3000.0) + cost_small = dynamic_session_cost_calculator( + self._make_context(prompt_tokens=10, completion_tokens=5) + ) + _reset_price_cache() + mock_urlopen.return_value = _make_urlopen_response(3000.0) + cost_large = dynamic_session_cost_calculator( + self._make_context(prompt_tokens=1000, completion_tokens=500) + ) + self.assertGreater(cost_large, cost_small) + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_cost_scales_correctly(self, mock_urlopen): + """Spot-check the math for gpt-4.1 at a known token price of $3 000. + + gpt-4.1 input: $0.000002/token, output: $0.000008/token + 100 input + 50 output = $0.0002 + $0.0004 = $0.0006 USD + At token price = $3 000: 0.0006 / 3000 = 0.0000002 tokens + In smallest units (10^18 decimals): 200_000_000_000 + """ + mock_urlopen.return_value = _make_urlopen_response(3000.0) + cost = dynamic_session_cost_calculator( + self._make_context(model="gpt-4.1", prompt_tokens=100, completion_tokens=50) + ) + self.assertEqual(cost, 200_000_000_000) + + @patch("tee_gateway.util.urllib.request.urlopen") + def test_zero_tokens_returns_zero(self, mock_urlopen): + mock_urlopen.return_value = _make_urlopen_response(3000.0) + cost = dynamic_session_cost_calculator( + self._make_context(prompt_tokens=0, completion_tokens=0) + ) + self.assertEqual(cost, 0) + + def test_raises_on_unknown_model(self): + ctx = self._make_context(model="not-a-real-model") + with self.assertRaises(ValueError): + dynamic_session_cost_calculator(ctx) + + def test_raises_when_usage_missing(self): + ctx = { + "request_json": {"model": "gpt-4.1"}, + "response_json": {}, + "payment_requirements": { + "asset": "0xFbC2051AE2265686a469421b2C5A2D5462FbF5eB" + }, + } + with self.assertRaises(ValueError): + dynamic_session_cost_calculator(ctx) + + def test_raises_when_request_json_missing(self): + with self.assertRaises(ValueError): + dynamic_session_cost_calculator( + { + "response_json": { + "usage": {"prompt_tokens": 1, "completion_tokens": 1} + } + } + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 47559d9..effbcf7 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -155,46 +155,106 @@ def _deserialize_dict(data, boxed_type): return {k: _deserialize(v, boxed_type) for k, v in data.items()} +import json # noqa: E402 +import urllib.request # noqa: E402 + +from tee_gateway.config import ( # noqa: E402 + OPG_PRICE_CACHE_TTL_SECONDS, + OPG_PRICE_COINGECKO_ID, + OPG_PRICE_FETCH_RETRIES, +) from tee_gateway.definitions import ( # noqa: E402 ASSET_DECIMALS_BY_ADDRESS, ) from tee_gateway.model_registry import get_model_config # noqa: E402 -TOKEN_A_PRICE_CACHE_TTL_SECONDS = 60 - +# Cache layout: +# "last_good" – most recent successfully fetched price (Decimal | None) +# "updated_at" – epoch seconds of last successful fetch (float) _token_price_cache: dict[str, Any] = { - "value": Decimal("1"), + "last_good": None, "updated_at": 0.0, } _token_price_lock = threading.Lock() -def _fetch_token_a_price_usd_mock() -> Decimal: - """Return the USD price of the payment token used for cost calculation. +def _fetch_opg_price_usd() -> Decimal: + """Fetch the OPG/USD price from CoinGecko, retrying up to OPG_PRICE_FETCH_RETRIES times. - Currently returns a fixed 1:1 ratio, which is correct for USDC-denominated - payments (1 USDC ≈ $1 USD). For OPG-denominated payments, replace this - with a live price feed (e.g. a DEX oracle or CoinGecko API call) that - returns the current OPG/USD exchange rate so that token amounts are - calculated correctly against the model's USD pricing. + The token queried is controlled by OPG_PRICE_COINGECKO_ID in config.py. + Raises ValueError if the token is listed but has no price data yet. + Raises the last exception if all network attempts fail. """ - return Decimal("1") + url = ( + f"https://api.coingecko.com/api/v3/simple/price" + f"?ids={OPG_PRICE_COINGECKO_ID}&vs_currencies=usd" + ) + last_exc: Exception = RuntimeError("no attempts made") + for attempt in range(1, OPG_PRICE_FETCH_RETRIES + 1): + try: + req = urllib.request.Request(url, headers={"User-Agent": "tee-gateway/1.0"}) + with urllib.request.urlopen(req, timeout=5) as resp: + data: dict[str, Any] = json.loads(resp.read()) + coin_data = data.get(OPG_PRICE_COINGECKO_ID) + if not isinstance(coin_data, dict) or "usd" not in coin_data: + # Deterministic failure — the coin is listed but has no price. + # Retrying won't help, so raise immediately without consuming + # the remaining retry budget. + logger.error( + "CoinGecko returned no USD price for '%s' — token may not have a trading price yet. Response: %r", + OPG_PRICE_COINGECKO_ID, + data, + ) + raise ValueError( + f"CoinGecko returned no price for '{OPG_PRICE_COINGECKO_ID}' — " + f"token may not have a trading price yet: {data!r}" + ) + price = Decimal(str(coin_data["usd"])) + if price <= 0: + raise ValueError( + f"CoinGecko returned non-positive price for '{OPG_PRICE_COINGECKO_ID}': {price}" + ) + return price + except ValueError: + raise + except Exception as exc: + last_exc = exc + logger.warning( + "CoinGecko price fetch attempt %d/%d failed: %s", + attempt, + OPG_PRICE_FETCH_RETRIES, + exc, + ) + raise last_exc def get_token_a_price_usd() -> Decimal: + """Return the current OPG/USD price, refreshing at most once per TTL window. + + Strategy: + - Return cached price immediately if it was fetched within the TTL. + - On TTL expiry, attempt a fresh CoinGecko fetch. + - Success → update cache, return new price. + - Failure → raise immediately; no silent fallback. + This means at most one network call every TTL window regardless of request + volume, and inference is blocked (400 returned) if the price cannot be fetched. + """ now = time.time() with _token_price_lock: - cached_value = _token_price_cache.get("value") + last_good: Decimal | None = _token_price_cache.get("last_good") # type: ignore[assignment] cached_at = float(_token_price_cache.get("updated_at") or 0.0) - if ( - isinstance(cached_value, Decimal) - and (now - cached_at) < TOKEN_A_PRICE_CACHE_TTL_SECONDS - ): - return cached_value - value = _fetch_token_a_price_usd_mock() - _token_price_cache["value"] = value + if last_good is not None and (now - cached_at) < OPG_PRICE_CACHE_TTL_SECONDS: + return last_good + + value = _fetch_opg_price_usd() + _token_price_cache["last_good"] = value _token_price_cache["updated_at"] = now + logger.info( + "OPG price refreshed: $%s (via CoinGecko '%s')", + value, + OPG_PRICE_COINGECKO_ID, + ) return value @@ -304,6 +364,21 @@ def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: return ASSET_DECIMALS_BY_ADDRESS[asset_lower] +def validate_pricing_preflight(model: str) -> None: + """Validate that this request can be priced before any LLM call is made. + + Raises ValueError if the model is not in the registry. + Raises (propagates) whatever get_token_a_price_usd raises if the price + feed is unavailable — e.g. network down or token has no trading price yet. + + Call this at the top of each request handler so that a pricing failure + returns a proper error to the client rather than silently producing free + inference after the response has already been sent. + """ + get_model_config(model) # raises ValueError for unknown models + get_token_a_price_usd() # raises if price is unavailable + + def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: """Compute UPTO per-request cost in token smallest units from actual usage. diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..cd77983 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,155 @@ +""" +Integration tests — require live network access. + +These tests hit the real CoinGecko API and are intentionally excluded from +the standard unit-test run. Opt in with: + + pytest -m integration tests/test_integration.py + +In CI, these run in a separate job (see .github/workflows/test.yml). + +NOTE: OPG (opengradient) is listed on CoinGecko but currently has no trading +price. The fetch tests below verify the correct error behaviour until a price +becomes available. Tests that require a live price are skipped automatically +when OPG has no price data. +""" + +import pytest +from decimal import Decimal + + +def _opg_has_price() -> bool: + """Return True if CoinGecko currently reports a price for OPG.""" + try: + from tee_gateway.util import _fetch_opg_price_usd + + _fetch_opg_price_usd() + return True + except Exception: + return False + + +requires_opg_price = pytest.mark.skipif( + not _opg_has_price(), + reason="OPG has no trading price on CoinGecko yet", +) + + +@pytest.mark.integration +class TestCoinGeckoPriceFeed: + """Verify the live OPG price fetch end-to-end via the configured CoinGecko token.""" + + def test_coingecko_slug_is_recognised(self): + """CoinGecko must recognise the OPG slug (i.e. return a dict for the coin, + even if price data is absent). A completely unknown slug returns an empty dict.""" + import json + import urllib.request + + from tee_gateway.config import OPG_PRICE_COINGECKO_ID + + url = ( + f"https://api.coingecko.com/api/v3/simple/price" + f"?ids={OPG_PRICE_COINGECKO_ID}&vs_currencies=usd" + ) + req = urllib.request.Request(url, headers={"User-Agent": "tee-gateway/1.0"}) + with urllib.request.urlopen(req, timeout=10) as resp: + data = json.loads(resp.read()) + + assert OPG_PRICE_COINGECKO_ID in data, ( + f"CoinGecko did not recognise slug '{OPG_PRICE_COINGECKO_ID}'. " + f"Response: {data!r}" + ) + + def test_fetch_raises_clear_error_when_no_price(self): + """When OPG has no trading price, _fetch_opg_price_usd must raise ValueError + with a message indicating the price is unavailable — not a bare KeyError.""" + import urllib.error + + if _opg_has_price(): + pytest.skip("OPG now has a price — this test is no longer applicable") + + from tee_gateway.util import _fetch_opg_price_usd + + try: + _fetch_opg_price_usd() + pytest.fail("Expected ValueError but no exception was raised") + except ValueError as exc: + assert "no price" in str(exc), f"Unexpected ValueError message: {exc}" + except urllib.error.HTTPError as exc: + if exc.code == 429: + pytest.skip(f"CoinGecko rate-limited the integration test run: {exc}") + raise + + @requires_opg_price + def test_fetch_returns_positive_decimal(self): + """_fetch_opg_price_usd must return a positive Decimal from CoinGecko.""" + from tee_gateway.util import _fetch_opg_price_usd + + price = _fetch_opg_price_usd() + assert isinstance(price, Decimal) + assert price > 0, f"Expected positive price, got {price}" + + @requires_opg_price + def test_price_is_within_sanity_bounds(self): + """Fetched price must not exceed the configured sanity ceiling.""" + from tee_gateway.config import OPG_PRICE_SANITY_MAX_USD + from tee_gateway.util import _fetch_opg_price_usd + + price = _fetch_opg_price_usd() + max_price = Decimal(OPG_PRICE_SANITY_MAX_USD) + assert 0 < price < max_price, ( + f"Price ${price} is outside the expected range (0, ${OPG_PRICE_SANITY_MAX_USD})" + ) + + @requires_opg_price + def test_get_token_a_price_usd_returns_cached_value(self): + """get_token_a_price_usd must return the same value on two rapid calls + (second call must hit the cache, not make a second network request).""" + from tee_gateway.util import _token_price_cache, get_token_a_price_usd + + # Reset cache so first call is a fresh fetch + _token_price_cache["last_good"] = None + _token_price_cache["updated_at"] = 0.0 + + first = get_token_a_price_usd() + second = get_token_a_price_usd() + + assert first == second, "Cache should return the same price on the second call" + assert first > 0 + + @requires_opg_price + def test_dynamic_cost_uses_live_price(self): + """Full pipeline: token counts + live token price → positive on-chain units.""" + from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS + from tee_gateway.util import ( + _token_price_cache, + dynamic_session_cost_calculator, + get_token_a_price_usd, + ) + + # Reset cache to force a live fetch + _token_price_cache["last_good"] = None + _token_price_cache["updated_at"] = 0.0 + + ctx = { + "request_json": {"model": "gpt-4.1"}, + "response_json": { + "usage": {"prompt_tokens": 1000, "completion_tokens": 500} + }, + "payment_requirements": {"asset": BASE_MAINNET_OPG_ADDRESS}, + } + + cost = dynamic_session_cost_calculator(ctx) + live_price = get_token_a_price_usd() + + assert isinstance(cost, int) + assert cost > 0 + + # Sanity: cost should be far less than 1 full OPG (10^18 units) + # for a small request at any plausible token price + assert cost < 10**18, f"Cost {cost} seems too large for a small request" + + from tee_gateway.config import OPG_PRICE_COINGECKO_ID + + print(f"\nLive price ({OPG_PRICE_COINGECKO_ID}): ${live_price}") + print(f"Cost for gpt-4.1 (1000 input + 500 output tokens): {cost} OPG units") diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 2f3ed90..2f1ef50 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -10,14 +10,35 @@ import unittest from decimal import Decimal +from unittest.mock import patch -from tee_gateway.definitions import BASE_OPG_ADDRESS +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS from tee_gateway.model_registry import ( _MODEL_LOOKUP, get_model_config, ) from tee_gateway.util import dynamic_session_cost_calculator +# --------------------------------------------------------------------------- +# Pin token price to $1 for all unit tests in this file. +# +# test_pricing.py exists to verify the USD-per-token math for each model. +# It should not depend on a live price feed — that is tested separately in +# test_integration.py. A $1 pin keeps the hardcoded expected wei values valid +# and makes the test suite fully deterministic and offline. +# --------------------------------------------------------------------------- +_price_patcher = patch( + "tee_gateway.util.get_token_a_price_usd", return_value=Decimal("1") +) + + +def setUpModule(): + _price_patcher.start() + + +def tearDownModule(): + _price_patcher.stop() + # --------------------------------------------------------------------------- # Helpers @@ -26,7 +47,7 @@ def _opg_requirements() -> dict: """Fake PaymentRequirements dict for OPG (18 decimals).""" - return {"asset": BASE_OPG_ADDRESS, "amount": "50000000000000000"} + return {"asset": BASE_MAINNET_OPG_ADDRESS, "amount": "50000000000000000"} def _ctx(model: str, input_tokens: int, output_tokens: int, requirements=None) -> dict: diff --git a/uv.lock b/uv.lock index af18737..bfb2691 100644 --- a/uv.lock +++ b/uv.lock @@ -1242,17 +1242,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] +[[package]] +name = "nest-asyncio" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418, upload-time = "2024-01-21T14:25:19.227Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, +] + [[package]] name = "og-x402" -version = "0.0.1.dev6" +version = "0.0.1.dev9" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "nest-asyncio" }, { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/70/aa/2b616b9be6dfa4dfee98bde3ed20dd41cb446d0569e0069c1d6c11faa032/og_x402-0.0.1.dev6.tar.gz", hash = "sha256:140c4b725f372e81f4a3c2caf392f58b6fcf242bc51a1c3a6417f58e3ef9e347", size = 900115, upload-time = "2026-03-30T07:13:25.623Z" } +sdist = { url = "https://files.pythonhosted.org/packages/97/f5/02e7b68af825c200da2aa88292f2c07823d321a4fd9e2a3d20130358fc10/og_x402-0.0.1.dev9.tar.gz", hash = "sha256:d3cfd05443636712cb1277e3d904b878d875a60b3728d64265098ea06eeb116b", size = 1312652, upload-time = "2026-04-10T10:40:13.267Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/03/5e/a64de6f29eb80bb180288297882d5aba2a894363622d4f94417b420cf0b5/og_x402-0.0.1.dev6-py3-none-any.whl", hash = "sha256:2a1f962fa2a50d02f28421199027245d5c5f013f36a143ec2f184a546325f1bd", size = 952670, upload-time = "2026-03-30T07:13:00.408Z" }, + { url = "https://files.pythonhosted.org/packages/8b/08/f5a05fc8454541e96650d44bf15b34491505d0e4f1e9e77b26c804fbbdd3/og_x402-0.0.1.dev9-py3-none-any.whl", hash = "sha256:2db171be2526aa13a1243255538d185c4f1f6106f615eff532d1720a89672034", size = 1392934, upload-time = "2026-04-10T10:40:11.224Z" }, ] [package.optional-dependencies] @@ -1881,7 +1891,7 @@ requires-dist = [ { name = "langchain-google-genai", specifier = ">=4.2.1" }, { name = "langchain-openai", specifier = ">=1.1.12" }, { name = "langchain-xai", specifier = ">=1.2.2" }, - { name = "og-x402", extras = ["evm"], specifier = "==0.0.1.dev6" }, + { name = "og-x402", extras = ["evm"], specifier = ">=0.0.1.dev9" }, { name = "openai", specifier = ">=2.15.0" }, { name = "psutil", specifier = ">=7.2.1" }, { name = "pydantic", specifier = ">=2.12.5" },