diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 835c886..273d4c9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,4 +15,7 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Run unit tests - run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tests/test_pricing.py -v --import-mode=importlib + run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tee_gateway/test/test_price_feed.py tests/test_pricing.py -v --import-mode=importlib + # To also run integration tests (real CoinGecko network calls), add: + # env: + # RUN_INTEGRATION_TESTS: "1" diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index fd36b24..d5662fc 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -34,9 +34,10 @@ from x402.server import x402ResourceServerSync from x402.session import SessionStore import x402.http.middleware.flask as x402_flask -import types as _types -from .util import dynamic_session_cost_calculator +from .util import calculate_session_cost +from .model_registry import get_model_config +from .price_feed import OPGPriceFeed from .definitions import ( EVM_PAYMENT_ADDRESS, BASE_MAINNET_NETWORK, @@ -107,6 +108,13 @@ def _shutdown_heartbeat(): atexit.register(_shutdown_heartbeat) +# --------------------------------------------------------------------------- +# OPG price feed — start before x402 middleware so the first request can be +# priced correctly. Runs as a daemon thread; no cleanup needed on exit. 
def _session_cost_calculator(ctx: dict) -> int:
    """Compute the settled cost of one session request.

    Delegates to ``calculate_session_cost`` with the middleware-supplied
    context and the live price getter ``_price_feed.get_price``.

    Args:
        ctx: Callback context dict handed over by the payment middleware.

    Returns:
        The value returned by ``calculate_session_cost``.

    Raises:
        Exception: Whatever ``calculate_session_cost`` raises; it is logged
            at CRITICAL with a traceback and then re-raised unchanged.
    """
    # Post-inference cost calculation — response already sent to client.
    # Predictable failures (unknown price, unknown model) are blocked by the
    # pre-inference gate; any exception here indicates a provider-side error
    # (e.g. missing usage field in the LLM response). The x402 middleware
    # swallows the exception in close(), so the client is not charged.
    # Log CRITICAL so provider errors are never silently missed.
    #
    # NOTE(review): what the middleware does with an exception raised here is
    # not visible in this module. If it falls back to the static
    # ``cost_per_request`` instead of skipping settlement, the "NOT charged"
    # wording below is wrong — confirm against the installed x402 version.
    try:
        return calculate_session_cost(ctx, _price_feed.get_price)
    except Exception as exc:
        logger.critical(
            "Post-inference cost calculation failed (provider error) — "
            "client was NOT charged: %s",
            exc,
            exc_info=True,
        )
        raise
ValueError for an unrecognised model name, KeyError for missing -# usage data), the exception is swallowed and the middleware silently falls back -# to the static session maximum (CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND / -# CHAT_COMPLETIONS_USDC_AMOUNT). That silent fallback means: -# • The client is charged the full pre-check cap instead of actual usage. -# • The server has no visible indication that pricing failed. +# Pre-inference pricing gate # -# The fix -# ------- -# We replace _resolve_session_request_cost with our own implementation that is -# identical to upstream, except the cost-calculator call is NOT wrapped in a -# try/except. Any exception from dynamic_session_cost_calculator() therefore -# propagates up through the middleware and Flask, producing a proper HTTP 500 -# response to the client instead of an incorrect silent charge. +# In the upto session scheme the response is streamed to the client before +# cost is settled, so a post-inference pricing failure cannot be surfaced as +# an HTTP error. Instead we validate everything that can be checked up-front +# and reject the request early if pricing would fail: +# 1. Price feed has a valid OPG/USD price (CoinGecko fetch succeeded). +# 2. The requested model is in the registry (has a known per-token price). # --------------------------------------------------------------------------- -def _strict_resolve_session_request_cost( - self, - *, - method: str, - path: str, - request_body_bytes: bytes, - response_body_bytes: bytes, - payment_payload: object, - payment_requirements: object, - status_code: int | None, - output_object: object = None, - is_streaming: bool = False, -) -> int: - """Replacement for PaymentMiddleware._resolve_session_request_cost. - - Identical to the upstream implementation except that exceptions raised by - the dynamic cost calculator are NOT caught. This means a request whose - cost cannot be determined (unknown model, missing usage data, etc.) 
will - result in a 500 error rather than silently falling back to the static cap - amount and charging the user an incorrect amount. - """ - from x402.http.middleware.flask import _parse_json_bytes as _x402_parse_json # noqa: PLC0415 - - default_cost = self._get_session_cost(payment_requirements) - if not self._should_charge_response(status_code): - return default_cost - if not callable(self._session_cost_calculator): - return default_cost - - request_object = _x402_parse_json(request_body_bytes) - response_object = ( - output_object - if output_object is not None - else _x402_parse_json(response_body_bytes) - ) - - callback_context = { - "method": method, - "path": path, - "status_code": status_code, - "is_streaming": is_streaming, - "request_body_bytes": request_body_bytes, - "response_body_bytes": response_body_bytes, - "request_json": request_object - if isinstance(request_object, (dict, list)) - else None, - "response_json": response_object - if isinstance(response_object, (dict, list)) - else None, - "response_object": response_object, - "payment_payload": payment_payload, - "payment_requirements": payment_requirements, - "default_cost": default_cost, - } - - # Do NOT catch exceptions here — let them propagate so the request fails - # with a 500 rather than silently charging the static fallback amount. 
@application.before_request
def _check_pricing_ready():
    """Reject inference requests up-front when they could not be priced later.

    Only ``/v1/chat/completions`` and ``/v1/completions`` are gated; every
    other path passes straight through.

    Returns:
        ``None`` to let the request proceed, otherwise a ``(response, status)``
        tuple:

        * ``503`` when ``_price_feed.get_price()`` raises ``ValueError``
          (no usable OPG/USD price).
        * ``400`` when the request names a model that is not a string or that
          ``get_model_config`` rejects with ``ValueError``.
    """
    if request.path not in ("/v1/chat/completions", "/v1/completions"):
        return None
    try:
        _price_feed.get_price()
    except ValueError as exc:
        logger.warning("Rejecting inference request — price feed unavailable: %s", exc)
        return jsonify({"error": f"Pricing unavailable: {exc}"}), 503

    body = request.get_json(silent=True, cache=True)
    if not isinstance(body, dict):
        # Missing, unparseable, or non-object JSON (list, string, number):
        # there is no "model" field to validate here, so leave rejection of
        # the malformed body to the route handler instead of crashing with
        # AttributeError on ``.get``.
        return None

    model = body.get("model")
    if not model:
        return None

    # A non-string model can never be a registry key; treat it as unsupported
    # rather than letting the lookup fail with an unexpected exception type.
    supported = isinstance(model, str)
    if supported:
        try:
            get_model_config(model)
        except ValueError:
            supported = False
    if not supported:
        return jsonify({"error": f"Model '{model}' is not supported"}), 400
    return None
+""" + +from dataclasses import dataclass +from datetime import datetime, timezone +from decimal import Decimal + + +# --------------------------------------------------------------------------- +# CoinGecko API +# --------------------------------------------------------------------------- +COINGECKO_BASE_URL = "https://api.coingecko.com/api/v3" +COINGECKO_PLATFORM = "base" # Base mainnet platform identifier on CoinGecko +FETCH_TIMEOUT = 10 # seconds per HTTP request + +# --------------------------------------------------------------------------- +# Refresh / retry defaults +# --------------------------------------------------------------------------- +DEFAULT_REFRESH_INTERVAL = 300 # 5 minutes between background refresh cycles +DEFAULT_MAX_RETRIES = 3 # attempts per refresh cycle before giving up +DEFAULT_RETRY_DELAY = 10 # seconds between retry attempts within a cycle + +# --------------------------------------------------------------------------- +# TGE (Token Generation Event) fallback +# --------------------------------------------------------------------------- +# Before the TGE cutover, OPG is not yet listed on CoinGecko. Return a fixed +# fallback price so inference requests can be priced immediately at launch. +# After the cutover, the live CoinGecko price is used. +TGE_CUTOVER_UTC = datetime(2026, 4, 21, 12, 30, 0, tzinfo=timezone.utc) +TGE_FALLBACK_PRICE_USD = Decimal("0.10") + +# --------------------------------------------------------------------------- +# Stale-price thresholds +# --------------------------------------------------------------------------- +# get_price() logs WARNING when last successful fetch is older than +# STALE_WARNING_MULTIPLIER × refresh_interval seconds. +STALE_WARNING_MULTIPLIER = 2 + +# get_price() raises ValueError when last successful fetch is older than +# STALE_PRICE_MAX_AGE seconds — at this point the cached price is considered +# too outdated to use for billing. 
@dataclass(frozen=True)
class PriceFeedConfig:
    """Runtime configuration for the OPG price feed background service.

    Instances are immutable (``frozen=True``). Each field defaults to the
    matching ``DEFAULT_*`` module constant.

    NOTE(review): ``OPGPriceFeed.__init__`` takes these same three values as
    separate keyword arguments; nothing visible in this change consumes this
    dataclass — confirm whether it is meant to be wired in.
    """

    # Seconds between background refresh cycles.
    refresh_interval: int = DEFAULT_REFRESH_INTERVAL
    # Fetch attempts per refresh cycle before giving up.
    max_retries: int = DEFAULT_MAX_RETRIES
    # Seconds to wait between retry attempts within a cycle.
    retry_delay: float = DEFAULT_RETRY_DELAY
threading.Lock() + self._thread: Optional[threading.Thread] = None + + # Status tracking — updated under _lock on every refresh cycle outcome. + self.last_success: Optional[float] = None # epoch seconds of last good fetch + self.last_error: Optional[str] = None # description of last failure (if any) + self.consecutive_failures: int = 0 # reset to 0 on any successful fetch + self.total_fetches: int = 0 # cumulative successful fetches + self.total_errors: int = 0 # cumulative failed refresh cycles + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def start(self) -> None: + """Launch the background refresh loop, including the initial price fetch. + + The initial fetch runs inside the background thread so startup is + non-blocking. ``get_price()`` will raise ``ValueError`` until the + first successful fetch completes; until then, inference requests are + rejected by the pre-inference pricing gate in ``__main__.py``. + + Idempotent — calling ``start()`` on an already-running feed is a no-op. + Thread-safe: the check-and-start is performed under ``_lock``. + """ + with self._lock: + if self._thread is not None and self._thread.is_alive(): + logger.info( + "OPG price feed already running, ignoring duplicate start()" + ) + return + self._thread = threading.Thread( + target=self._run_with_initial_fetch, + name="opg-price-feed", + daemon=True, + ) + self._thread.start() + logger.info( + "OPG price feed started (refresh_interval=%ds, max_retries=%d)", + self._refresh_interval, + self._max_retries, + ) + + def get_price(self) -> Decimal: + """Return the latest cached OPG/USD price. + + Before the TGE cutover (``TGE_CUTOVER_UTC``), returns the fixed + ``TGE_FALLBACK_PRICE_USD`` so requests can be priced before OPG is + listed on CoinGecko. After the cutover the live cached price is used. 
+ + Raises ``ValueError`` if no price has been successfully fetched yet + (post-TGE only). Logs a warning (but still returns the price) if the + cached value is older than ``STALE_WARNING_MULTIPLIER * refresh_interval`` + seconds — this indicates the background loop has missed at least one + refresh cycle and may be experiencing persistent errors. + """ + if datetime.now(timezone.utc) < TGE_CUTOVER_UTC: + return TGE_FALLBACK_PRICE_USD + + now = time.time() + with self._lock: + if self._price is None: + raise ValueError( + "OPG price not yet available — " + "price feed has not completed a successful fetch" + ) + if self.last_success is None: + raise ValueError( + "OPG price not yet available — " + "price feed has not completed a successful fetch" + ) + age = now - self.last_success + if age > STALE_PRICE_MAX_AGE: + raise ValueError( + f"OPG price data expired: last successful fetch was {age:.0f}s ago " + f"(max: {STALE_PRICE_MAX_AGE}s); consecutive failures: " + f"{self.consecutive_failures}" + ) + stale_threshold = self._refresh_interval * STALE_WARNING_MULTIPLIER + if age > stale_threshold: + logger.warning( + "OPG price data is stale: last successful fetch was %.0fs ago " + "(threshold: %.0fs); consecutive failures: %d", + age, + stale_threshold, + self.consecutive_failures, + ) + return self._price + + def get_status(self) -> dict[str, Any]: + """Return a health snapshot suitable for logging or a /health endpoint.""" + with self._lock: + return { + "price_usd": float(self._price) if self._price is not None else None, + "last_success": self.last_success, + "last_error": self.last_error, + "consecutive_failures": self.consecutive_failures, + "total_fetches": self.total_fetches, + "total_errors": self.total_errors, + "refresh_interval": self._refresh_interval, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _run_with_initial_fetch(self) 
-> None: + self._refresh_price() + while True: + time.sleep(self._refresh_interval) + self._refresh_price() + + def _refresh_price(self) -> None: + """Attempt to fetch a fresh price, retrying on transient failure. + + - On success: updates the cached price and resets ``consecutive_failures``. + - On HTTP 429: logs a rate-limit warning and exits the retry loop early + (no point hammering a rate-limited API). + - On exhausted retries: increments ``consecutive_failures`` and retains + the last known good price so live traffic is not disrupted by a + transient CoinGecko outage. + """ + last_exc: Optional[Exception] = None + + for attempt in range(1, self._max_retries + 1): + try: + price = fetch_opg_price() + with self._lock: + self._price = price + self.last_success = time.time() + self.last_error = None + self.consecutive_failures = 0 + self.total_fetches += 1 + logger.info( + "OPG price updated: $%.6f USD (attempt %d/%d)", + float(price), + attempt, + self._max_retries, + ) + return + except requests.exceptions.HTTPError as exc: + last_exc = exc + status_code = ( + exc.response.status_code if exc.response is not None else None + ) + if status_code == 429: + logger.warning( + "CoinGecko rate limit hit (429) on attempt %d/%d; " + "skipping remaining retries for this cycle", + attempt, + self._max_retries, + ) + break + logger.warning( + "OPG price fetch attempt %d/%d failed (HTTP %s): %s", + attempt, + self._max_retries, + status_code, + exc, + ) + except Exception as exc: + last_exc = exc + logger.warning( + "OPG price fetch attempt %d/%d failed: %s", + attempt, + self._max_retries, + exc, + ) + + if attempt < self._max_retries: + time.sleep(self._retry_delay) + + # All attempts exhausted (or rate-limited out) — record the failure. 
def fetch_opg_price() -> Decimal:
    """Fetch the current OPG/USD price from CoinGecko.

    Queries ``/simple/token_price/{COINGECKO_PLATFORM}`` for
    ``BASE_MAINNET_OPG_ADDRESS`` with a ``FETCH_TIMEOUT``-second timeout.

    Returns:
        The USD price as a finite, strictly positive ``Decimal``.

    Raises:
        requests.exceptions.RequestException: On network failure or a non-2xx
            status (via ``raise_for_status``).
        ValueError: When the response body is not the expected
            ``{address: {"usd": <number>}}`` shape, or when the price is
            missing, non-numeric, non-finite, or not positive.
    """
    url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}"
    params = {
        "contract_addresses": BASE_MAINNET_OPG_ADDRESS,
        "vs_currencies": "usd",
    }
    response = requests.get(url, params=params, timeout=FETCH_TIMEOUT)
    response.raise_for_status()

    data: Any = response.json()
    if not isinstance(data, dict):
        raise ValueError(
            f"Unexpected CoinGecko response for {BASE_MAINNET_OPG_ADDRESS}: {data!r}"
        )
    # CoinGecko keys the result by the lowercased contract address.
    price_entry = data.get(BASE_MAINNET_OPG_ADDRESS.lower())
    if not isinstance(price_entry, dict) or "usd" not in price_entry:
        raise ValueError(
            f"Unexpected CoinGecko response for {BASE_MAINNET_OPG_ADDRESS}: {data!r}"
        )

    raw_price = price_entry["usd"]
    try:
        price = Decimal(str(raw_price))
    except ArithmeticError:
        # decimal.InvalidOperation (an ArithmeticError subclass) is raised for
        # values such as null or a non-numeric string; report it the same way
        # as every other malformed payload.
        price = None
    if price is None or not price.is_finite() or price <= 0:
        raise ValueError(
            f"Invalid price from CoinGecko for {BASE_MAINNET_OPG_ADDRESS}: "
            f"{raw_price!r}"
        )
    return price
+ +Test classes +------------ +TestFetchOPGPrice — the raw fetch_opg_price() helper in feed.py +TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() (retry, rate-limit, stats) +TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) +TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots +TestCalculateSessionCost — calculate_session_cost(context, get_price) in util.py +""" + +import time +import unittest +from datetime import datetime, timezone +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import requests + +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS +from tee_gateway.price_feed import OPGPriceFeed +from tee_gateway.price_feed.feed import fetch_opg_price +from tee_gateway.util import calculate_session_cost + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +OPG_ADDRESS_LOWER = BASE_MAINNET_OPG_ADDRESS.lower() +SAMPLE_PRICE = Decimal("0.042") +SAMPLE_PRICE_FLOAT = 0.042 + +# Patch target prefix — all mocks go through the feed module. +_FEED = "tee_gateway.price_feed.feed" + +# A datetime well after the TGE cutover so get_price() uses the cached price. 
+_POST_TGE = datetime(2026, 4, 22, 0, 0, 0, tzinfo=timezone.utc) + + +def _mock_response(status_code: int = 200, json_body: dict | None = None) -> MagicMock: + """Build a minimal mock requests.Response.""" + mock = MagicMock() + mock.status_code = status_code + mock.json.return_value = json_body or {} + if status_code >= 400: + http_err = requests.exceptions.HTTPError(response=mock) + mock.raise_for_status.side_effect = http_err + else: + mock.raise_for_status.return_value = None + return mock + + +def _coingecko_success_body() -> dict: + return {OPG_ADDRESS_LOWER: {"usd": SAMPLE_PRICE_FLOAT}} + + +# --------------------------------------------------------------------------- +# TestFetchOPGPrice +# --------------------------------------------------------------------------- + + +class TestFetchOPGPrice(unittest.TestCase): + """Tests for the fetch_opg_price() free function in feed.py.""" + + @patch(f"{_FEED}.requests.get") + def test_happy_path_returns_decimal(self, mock_get): + mock_get.return_value = _mock_response(200, _coingecko_success_body()) + price = fetch_opg_price() + self.assertIsInstance(price, Decimal) + self.assertEqual(price, Decimal(str(SAMPLE_PRICE_FLOAT))) + + @patch(f"{_FEED}.requests.get") + def test_passes_correct_params(self, mock_get): + mock_get.return_value = _mock_response(200, _coingecko_success_body()) + fetch_opg_price() + _, kwargs = mock_get.call_args + self.assertIn("contract_addresses", kwargs["params"]) + self.assertEqual(kwargs["params"]["vs_currencies"], "usd") + self.assertIn( + "base", kwargs["url"] if "url" in kwargs else mock_get.call_args[0][0] + ) + + @patch(f"{_FEED}.requests.get") + def test_raises_on_http_500(self, mock_get): + mock_get.return_value = _mock_response(500) + with self.assertRaises(requests.exceptions.HTTPError): + fetch_opg_price() + + @patch(f"{_FEED}.requests.get") + def test_raises_on_http_429(self, mock_get): + mock_get.return_value = _mock_response(429) + with 
self.assertRaises(requests.exceptions.HTTPError) as ctx: + fetch_opg_price() + self.assertEqual(ctx.exception.response.status_code, 429) + + @patch(f"{_FEED}.requests.get") + def test_raises_on_empty_response_body(self, mock_get): + mock_get.return_value = _mock_response(200, {}) + with self.assertRaises(ValueError, msg="should raise when address key absent"): + fetch_opg_price() + + @patch(f"{_FEED}.requests.get") + def test_raises_when_usd_key_missing(self, mock_get): + mock_get.return_value = _mock_response(200, {OPG_ADDRESS_LOWER: {"eur": 0.04}}) + with self.assertRaises(ValueError): + fetch_opg_price() + + @patch(f"{_FEED}.requests.get") + def test_raises_when_address_entry_is_empty_dict(self, mock_get): + """CoinGecko returns {address: {}} for known-but-unpriced tokens (current OPG behaviour).""" + mock_get.return_value = _mock_response(200, {OPG_ADDRESS_LOWER: {}}) + with self.assertRaises(ValueError, msg="empty price entry should raise"): + fetch_opg_price() + + @patch(f"{_FEED}.requests.get") + def test_raises_on_network_error(self, mock_get): + mock_get.side_effect = requests.exceptions.ConnectionError("timeout") + with self.assertRaises(requests.exceptions.ConnectionError): + fetch_opg_price() + + +# --------------------------------------------------------------------------- +# TestOPGPriceFeedRefresh +# --------------------------------------------------------------------------- + + +class TestOPGPriceFeedRefresh(unittest.TestCase): + """Tests for OPGPriceFeed._refresh_price() — retry logic, rate-limit, stats.""" + + def _feed(self, **kwargs) -> OPGPriceFeed: + defaults = {"refresh_interval": 300, "max_retries": 3, "retry_delay": 0} + defaults.update(kwargs) + return OPGPriceFeed(**defaults) + + @patch(f"{_FEED}.fetch_opg_price") + def test_successful_refresh_sets_price(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = self._feed() + feed._refresh_price() + self.assertEqual(feed._price, SAMPLE_PRICE) + + 
@patch(f"{_FEED}.fetch_opg_price") + def test_successful_refresh_updates_stats(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = self._feed() + feed._refresh_price() + self.assertEqual(feed.total_fetches, 1) + self.assertEqual(feed.total_errors, 0) + self.assertEqual(feed.consecutive_failures, 0) + self.assertIsNotNone(feed.last_success) + + @patch(f"{_FEED}.fetch_opg_price") + def test_retry_on_transient_failure_then_success(self, mock_fetch): + mock_fetch.side_effect = [ + ValueError("transient"), + ValueError("transient"), + SAMPLE_PRICE, + ] + feed = self._feed(max_retries=3, retry_delay=0) + feed._refresh_price() + self.assertEqual(feed._price, SAMPLE_PRICE) + self.assertEqual(mock_fetch.call_count, 3) + self.assertEqual(feed.total_fetches, 1) + self.assertEqual(feed.total_errors, 0) + + @patch(f"{_FEED}.fetch_opg_price") + def test_exhausted_retries_records_error_stats(self, mock_fetch): + mock_fetch.side_effect = ValueError("always fails") + feed = self._feed(max_retries=3, retry_delay=0) + feed._refresh_price() + self.assertEqual(feed.total_errors, 1) + self.assertEqual(feed.consecutive_failures, 1) + self.assertIsNotNone(feed.last_error) + self.assertEqual(feed.total_fetches, 0) + + @patch(f"{_FEED}.fetch_opg_price") + def test_exhausted_retries_keeps_last_known_price(self, mock_fetch): + feed = self._feed(max_retries=2, retry_delay=0) + feed._price = SAMPLE_PRICE + feed.last_success = time.time() + mock_fetch.side_effect = ValueError("fail") + feed._refresh_price() + self.assertEqual(feed._price, SAMPLE_PRICE) + + @patch(f"{_FEED}.fetch_opg_price") + def test_success_after_failures_resets_consecutive_failures(self, mock_fetch): + feed = self._feed(max_retries=1, retry_delay=0) + mock_fetch.side_effect = ValueError("fail") + feed._refresh_price() + self.assertEqual(feed.consecutive_failures, 1) + mock_fetch.side_effect = None + mock_fetch.return_value = SAMPLE_PRICE + feed._refresh_price() + self.assertEqual(feed.consecutive_failures, 0) 
+ + @patch(f"{_FEED}.fetch_opg_price") + def test_rate_limit_breaks_retry_loop_immediately(self, mock_fetch): + resp = MagicMock() + resp.status_code = 429 + mock_fetch.side_effect = requests.exceptions.HTTPError(response=resp) + feed = self._feed(max_retries=3, retry_delay=0) + feed._refresh_price() + self.assertEqual(mock_fetch.call_count, 1) + self.assertEqual(feed.total_errors, 1) + + @patch(f"{_FEED}.time.sleep") + @patch(f"{_FEED}.fetch_opg_price") + def test_retry_delay_called_between_attempts(self, mock_fetch, mock_sleep): + mock_fetch.side_effect = [ValueError("fail"), ValueError("fail"), SAMPLE_PRICE] + feed = self._feed(max_retries=3, retry_delay=5) + feed._refresh_price() + self.assertEqual(mock_sleep.call_count, 2) + mock_sleep.assert_called_with(5) + + @patch(f"{_FEED}.time.sleep") + @patch(f"{_FEED}.fetch_opg_price") + def test_no_sleep_after_last_failed_attempt(self, mock_fetch, mock_sleep): + mock_fetch.side_effect = ValueError("always fails") + feed = self._feed(max_retries=3, retry_delay=5) + feed._refresh_price() + self.assertEqual(mock_sleep.call_count, 2) + + +# --------------------------------------------------------------------------- +# TestOPGPriceFeedGetPrice +# --------------------------------------------------------------------------- + + +class TestOPGPriceFeedGetPrice(unittest.TestCase): + """Tests for OPGPriceFeed.get_price() behaviour.""" + + @patch(f"{_FEED}.datetime") + def test_raises_before_any_successful_fetch(self, mock_dt): + mock_dt.now.return_value = _POST_TGE + feed = OPGPriceFeed() + with self.assertRaises(ValueError) as ctx: + feed.get_price() + self.assertIn("not yet available", str(ctx.exception)) + + @patch(f"{_FEED}.datetime") + @patch(f"{_FEED}.fetch_opg_price") + def test_returns_price_after_successful_refresh(self, mock_fetch, mock_dt): + mock_dt.now.return_value = _POST_TGE + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(retry_delay=0) + feed._refresh_price() + self.assertEqual(feed.get_price(), 
SAMPLE_PRICE) + + @patch(f"{_FEED}.datetime") + @patch(f"{_FEED}.time.time") + @patch(f"{_FEED}.fetch_opg_price") + def test_warns_when_price_is_stale(self, mock_fetch, mock_time, mock_dt): + mock_dt.now.return_value = _POST_TGE + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) + + mock_time.return_value = 0.0 + feed._refresh_price() + + # Advance past stale threshold (300 * 2 = 600s) + mock_time.return_value = 601.0 + + with self.assertLogs("llm_server.price_feed", level="WARNING") as log_ctx: + price = feed.get_price() + + self.assertEqual(price, SAMPLE_PRICE) + self.assertTrue(any("stale" in line.lower() for line in log_ctx.output)) + + @patch(f"{_FEED}.datetime") + @patch(f"{_FEED}.time.time") + @patch(f"{_FEED}.fetch_opg_price") + def test_raises_when_price_exceeds_max_age(self, mock_fetch, mock_time, mock_dt): + mock_dt.now.return_value = _POST_TGE + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(retry_delay=0) + + mock_time.return_value = 0.0 + feed._refresh_price() + + # Advance past the 4-hour max age + mock_time.return_value = 4 * 60 * 60 + 1.0 + + with self.assertRaises(ValueError) as ctx: + feed.get_price() + self.assertIn("expired", str(ctx.exception)) + + @patch(f"{_FEED}.time.time") + @patch(f"{_FEED}.fetch_opg_price") + def test_no_stale_warning_when_price_is_fresh(self, mock_fetch, mock_time): + import logging + + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) + + mock_time.return_value = 0.0 + feed._refresh_price() + mock_time.return_value = 100.0 # well within threshold + + with self.assertLogs("llm_server.price_feed", level="DEBUG") as log_ctx: + logging.getLogger("llm_server.price_feed").debug("sentinel") + feed.get_price() + + warning_lines = [ + line + for line in log_ctx.output + if "WARNING" in line and "stale" in line.lower() + ] + self.assertEqual(warning_lines, []) + + +# 
--------------------------------------------------------------------------- +# TestOPGPriceFeedStatus +# --------------------------------------------------------------------------- + + +class TestOPGPriceFeedStatus(unittest.TestCase): + """Tests for OPGPriceFeed.get_status() snapshot.""" + + def test_initial_status_has_no_price(self): + feed = OPGPriceFeed() + status = feed.get_status() + self.assertIsNone(status["price_usd"]) + self.assertIsNone(status["last_success"]) + self.assertEqual(status["consecutive_failures"], 0) + self.assertEqual(status["total_fetches"], 0) + self.assertEqual(status["total_errors"], 0) + + @patch(f"{_FEED}.fetch_opg_price") + def test_status_reflects_successful_fetch(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(retry_delay=0) + feed._refresh_price() + status = feed.get_status() + self.assertAlmostEqual(status["price_usd"], float(SAMPLE_PRICE), places=6) + self.assertIsNotNone(status["last_success"]) + self.assertEqual(status["total_fetches"], 1) + self.assertEqual(status["consecutive_failures"], 0) + + @patch(f"{_FEED}.fetch_opg_price") + def test_status_reflects_failed_cycle(self, mock_fetch): + mock_fetch.side_effect = ValueError("fail") + feed = OPGPriceFeed(max_retries=1, retry_delay=0) + feed._refresh_price() + status = feed.get_status() + self.assertIsNone(status["price_usd"]) + self.assertEqual(status["total_errors"], 1) + self.assertEqual(status["consecutive_failures"], 1) + self.assertIsNotNone(status["last_error"]) + + def test_status_includes_refresh_interval(self): + feed = OPGPriceFeed(refresh_interval=600) + self.assertEqual(feed.get_status()["refresh_interval"], 600) + + @patch(f"{_FEED}.fetch_opg_price") + def test_status_accumulates_multiple_error_cycles(self, mock_fetch): + mock_fetch.side_effect = ValueError("fail") + feed = OPGPriceFeed(max_retries=1, retry_delay=0) + feed._refresh_price() + feed._refresh_price() + feed._refresh_price() + status = feed.get_status() + 
self.assertEqual(status["total_errors"], 3) + self.assertEqual(status["consecutive_failures"], 3) + + +# --------------------------------------------------------------------------- +# TestCalculateSessionCost +# --------------------------------------------------------------------------- + +_ASSET_ADDR = "0xdeadbeef" +_ASSET_ADDR_LOWER = _ASSET_ADDR.lower() +_ASSET_DECIMALS = 18 + + +def _make_payment_requirements(asset: str = _ASSET_ADDR) -> dict: + return {"asset": asset, "price": {"amount": "1000000000000000000", "asset": asset}} + + +def _make_context( + model: str = "gpt-4.1-mini", + input_tokens: int = 100, + output_tokens: int = 50, + price_usd: Decimal = Decimal("0.10"), + asset: str = _ASSET_ADDR, +) -> dict: + return { + "request_json": {"model": model}, + "response_json": { + "model": model, + "usage": { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + }, + }, + "payment_requirements": _make_payment_requirements(asset), + "method": "POST", + "path": "/v1/chat/completions", + "status_code": 200, + "is_streaming": False, + "request_body_bytes": b"", + "response_body_bytes": b"", + "default_cost": 10**18, + } + + +def _make_get_price(price_usd: Decimal = Decimal("0.10")) -> MagicMock: + mock = MagicMock(return_value=price_usd) + return mock + + +class TestCalculateSessionCost(unittest.TestCase): + """Tests for calculate_session_cost(context, get_price).""" + + def _patch_definitions(self): + return patch( + "tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", + {_ASSET_ADDR_LOWER: _ASSET_DECIMALS}, + ) + + def _patch_model( + self, input_price: str = "0.000001", output_price: str = "0.000002" + ): + cfg = MagicMock() + cfg.input_price_usd = Decimal(input_price) + cfg.output_price_usd = Decimal(output_price) + return patch("tee_gateway.util.get_model_config", return_value=cfg) + + def test_calls_get_price(self): + get_price = _make_get_price() + with self._patch_definitions(), self._patch_model(): + calculate_session_cost(_make_context(),
get_price) + get_price.assert_called_once() + + def test_returns_positive_int(self): + with self._patch_definitions(), self._patch_model(): + result = calculate_session_cost(_make_context(), _make_get_price()) + self.assertIsInstance(result, int) + self.assertGreaterEqual(result, 0) + + def test_zero_tokens_returns_zero(self): + with self._patch_definitions(), self._patch_model(): + result = calculate_session_cost( + _make_context(input_tokens=0, output_tokens=0), _make_get_price() + ) + self.assertEqual(result, 0) + + def test_raises_when_get_price_raises(self): + get_price = MagicMock(side_effect=ValueError("price not available")) + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calculate_session_cost(_make_context(), get_price) + + def test_raises_when_non_positive_price(self): + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calculate_session_cost(_make_context(), _make_get_price(Decimal("0"))) + + def test_raises_when_request_json_missing(self): + ctx = _make_context() + ctx["request_json"] = None + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calculate_session_cost(ctx, _make_get_price()) + + def test_raises_when_usage_missing(self): + ctx = _make_context() + ctx["response_json"] = {"model": "gpt-4.1-mini"} + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calculate_session_cost(ctx, _make_get_price()) + + def test_raises_when_asset_unknown(self): + ctx = _make_context(asset="0xunknown") + with ( + patch("tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", {}), + self._patch_model(), + ): + with self.assertRaises(ValueError): + calculate_session_cost(ctx, _make_get_price()) + + def test_cost_scales_with_token_count(self): + with self._patch_definitions(), self._patch_model(): + cost_small = calculate_session_cost( + _make_context(input_tokens=10, output_tokens=5), _make_get_price() + 
) + cost_large = calculate_session_cost( + _make_context(input_tokens=1000, output_tokens=500), _make_get_price() + ) + self.assertGreater(cost_large, cost_small) + + def test_higher_token_price_yields_lower_cost(self): + with self._patch_definitions(), self._patch_model(): + cost_cheap = calculate_session_cost( + _make_context(), _make_get_price(Decimal("0.10")) + ) + cost_expensive = calculate_session_cost( + _make_context(), _make_get_price(Decimal("0.20")) + ) + self.assertGreater(cost_cheap, cost_expensive) + + def test_uses_current_price_on_each_call(self): + """get_price is called fresh every invocation — price changes are picked up.""" + get_price = MagicMock(side_effect=[Decimal("0.10"), Decimal("0.20")]) + with self._patch_definitions(), self._patch_model(): + cost_first = calculate_session_cost(_make_context(), get_price) + cost_second = calculate_session_cost(_make_context(), get_price) + self.assertEqual(get_price.call_count, 2) + # Price doubled → cost should halve (same USD spend, twice the token price). + self.assertGreater(cost_first, cost_second) + + +if __name__ == "__main__": + unittest.main() diff --git a/tee_gateway/test/test_price_feed_integration.py b/tee_gateway/test/test_price_feed_integration.py new file mode 100644 index 0000000..2da9db1 --- /dev/null +++ b/tee_gateway/test/test_price_feed_integration.py @@ -0,0 +1,133 @@ +""" +Integration tests for tee_gateway.price_feed. + +These tests make REAL network calls to the CoinGecko public API. + +Expected behaviour +------------------ +* ``TestCoinGeckoConnectivity`` — passes when the CoinGecko API is reachable. + Skips on network errors or rate-limiting (429). +* ``TestOPGPriceFetchLive`` — skips when OPG is not yet priced on CoinGecko's + Base platform (CoinGecko currently returns an empty price entry for the + token). Will pass automatically once the token is fully listed. 
+ +Run with:: + + RUN_INTEGRATION_TESTS=1 uv run pytest tee_gateway/test/test_price_feed_integration.py -v +""" + +import os +import unittest +from decimal import Decimal + +import requests + +if not os.getenv("RUN_INTEGRATION_TESTS"): + raise unittest.SkipTest("Set RUN_INTEGRATION_TESTS=1 to run integration tests") + +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS +from tee_gateway.price_feed.config import ( + COINGECKO_BASE_URL, + COINGECKO_PLATFORM, + FETCH_TIMEOUT, +) +from tee_gateway.price_feed.feed import fetch_opg_price + + +def _get(url: str, **kwargs) -> requests.Response: + """Wrapper that skips the test on network errors or rate-limiting.""" + try: + resp = requests.get(url, timeout=FETCH_TIMEOUT, **kwargs) + except requests.exceptions.RequestException as exc: + raise unittest.SkipTest(f"Network unavailable: {exc}") from exc + if resp.status_code == 429: + raise unittest.SkipTest( + "CoinGecko rate limit hit (429) — re-run after a short wait" + ) + return resp + + +class TestCoinGeckoConnectivity(unittest.TestCase): + """Verify that the CoinGecko API endpoint is reachable and well-formed.""" + + def test_ping_endpoint_reachable(self): + """CoinGecko /ping should return {gecko_says: ...}.""" + resp = _get(f"{COINGECKO_BASE_URL}/ping") + self.assertEqual(resp.status_code, 200) + self.assertIn("gecko_says", resp.json()) + + def test_base_platform_endpoint_returns_200(self): + """The token_price/base endpoint should respond with HTTP 200 for a known token.""" + url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" + # USDC on Base mainnet — reliably indexed on CoinGecko.
+ usdc_base = "0x833589fcd6edb6e08f4c7c32d4f71b54bda02913" + resp = _get( + url, params={"contract_addresses": usdc_base, "vs_currencies": "usd"} + ) + self.assertEqual( + resp.status_code, + 200, + f"Expected 200 from CoinGecko, got {resp.status_code}: {resp.text[:200]}", + ) + data = resp.json() + self.assertIsInstance(data, dict) + self.assertIn(usdc_base, data, "USDC should be indexed on Base platform") + self.assertIn("usd", data[usdc_base], "USDC price entry should have 'usd' key") + + +class TestOPGPriceFetchLive(unittest.TestCase): + """Live fetch of the OPG token price. + + Both tests skip gracefully when OPG is not yet fully priced on CoinGecko + (currently returns ``{address: {}}`` with no 'usd' key). They will pass + automatically once the token is listed with a live price. + """ + + def test_opg_response_structure(self): + """Inspect the raw CoinGecko response for the OPG contract address.""" + url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" + resp = _get( + url, + params={ + "contract_addresses": BASE_MAINNET_OPG_ADDRESS, + "vs_currencies": "usd", + }, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + print(f"\nCoinGecko response for OPG ({BASE_MAINNET_OPG_ADDRESS}): {data}") # noqa: T201 + + opg_lower = BASE_MAINNET_OPG_ADDRESS.lower() + price_entry = data.get(opg_lower) + # CoinGecko returns the address key with {} when the token is known but + # not yet priced — skip in that case rather than fail. + if not price_entry or "usd" not in price_entry: + self.skipTest( + f"OPG not yet priced on CoinGecko Base platform " + f"(response: {data!r}). Will pass once the token is fully listed." 
+ ) + self.assertIsInstance(price_entry["usd"], (int, float)) + + def test_opg_price_fetch_live(self): + """End-to-end: fetch_opg_price() returns a positive Decimal price.""" + try: + price = fetch_opg_price() + except requests.exceptions.HTTPError as exc: + if exc.response is not None and exc.response.status_code == 429: + self.skipTest("CoinGecko rate limit — re-run after a short wait") + raise + except ValueError as exc: + if "Unexpected CoinGecko response" in str(exc): + self.skipTest( + f"OPG ({BASE_MAINNET_OPG_ADDRESS}) not yet priced on " + f"CoinGecko Base platform. Details: {exc}" + ) + raise + + self.assertIsInstance(price, Decimal) + self.assertGreater(price, Decimal("0"), "Price must be positive") + print(f"\nLive OPG price: ${price} USD") # noqa: T201 + + +if __name__ == "__main__": + unittest.main() diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 47559d9..ac79cd6 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -2,10 +2,8 @@ from tee_gateway import typing_utils import logging -import threading -import time from decimal import Decimal, InvalidOperation, ROUND_CEILING -from typing import Any +from typing import Any, Callable logger = logging.getLogger("llm_server.dynamic_pricing") @@ -160,43 +158,6 @@ def _deserialize_dict(data, boxed_type): ) from tee_gateway.model_registry import get_model_config # noqa: E402 -TOKEN_A_PRICE_CACHE_TTL_SECONDS = 60 - -_token_price_cache: dict[str, Any] = { - "value": Decimal("1"), - "updated_at": 0.0, -} -_token_price_lock = threading.Lock() - - -def _fetch_token_a_price_usd_mock() -> Decimal: - """Return the USD price of the payment token used for cost calculation. - - Currently returns a fixed 1:1 ratio, which is correct for USDC-denominated - payments (1 USDC ≈ $1 USD). For OPG-denominated payments, replace this - with a live price feed (e.g. 
a DEX oracle or CoinGecko API call) that - returns the current OPG/USD exchange rate so that token amounts are - calculated correctly against the model's USD pricing. - """ - return Decimal("1") - - -def get_token_a_price_usd() -> Decimal: - now = time.time() - with _token_price_lock: - cached_value = _token_price_cache.get("value") - cached_at = float(_token_price_cache.get("updated_at") or 0.0) - if ( - isinstance(cached_value, Decimal) - and (now - cached_at) < TOKEN_A_PRICE_CACHE_TTL_SECONDS - ): - return cached_value - - value = _fetch_token_a_price_usd_mock() - _token_price_cache["value"] = value - _token_price_cache["updated_at"] = now - return value - def _as_dict(value: Any) -> dict[str, Any] | None: if value is None: @@ -304,35 +265,36 @@ def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: return ASSET_DECIMALS_BY_ADDRESS[asset_lower] -def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: - """Compute UPTO per-request cost in token smallest units from actual usage. +def calculate_session_cost( + context: dict[str, Any], get_price: Callable[[], Decimal] +) -> int: + """Calculate the x402 session cost in token smallest units for a completed request. - Raises ValueError on any missing or unrecognised input — no silent fallback. + ``get_price`` is called on every invocation to fetch the current OPG/USD + price — pass ``price_feed.get_price`` so the latest cached value is used. + Raises ``ValueError`` on any missing/invalid data. Predictable failures + (unavailable price, unknown model) are blocked before inference by the + pre-inference gate in ``__main__.py``; post-inference failures are logged + as CRITICAL by the caller and the client is not charged. 
""" request_json = context.get("request_json") response_json = context.get("response_json") if not isinstance(request_json, dict) or not isinstance(response_json, dict): raise ValueError( - "dynamic_session_cost_calculator requires both request_json and response_json" + "calculate_session_cost requires both request_json and response_json" ) model = _extract_model_from_context(request_json, response_json) - - # get_model_config raises ValueError for unknown models — no fallback cfg = get_model_config(model) - input_tokens, output_tokens = _extract_usage_tokens(response_json) - input_rate = cfg.input_price_usd - output_rate = cfg.output_price_usd - - total_usd = (Decimal(input_tokens) * input_rate) + ( - Decimal(output_tokens) * output_rate + total_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( + Decimal(output_tokens) * cfg.output_price_usd ) - token_price_usd = get_token_a_price_usd() + token_price_usd = get_price() if token_price_usd <= 0: - raise ValueError(f"Token A price is non-positive: {token_price_usd}") + raise ValueError(f"Token price is non-positive: {token_price_usd}") token_amount = total_usd / token_price_usd decimals = _extract_asset_decimals_from_requirements( @@ -344,7 +306,8 @@ def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: ) logger.info( - "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d total_usd=%s token_price_usd=%s decimals=%d cost=%d", + "CALCULATE_SESSION_COST model=%s input_tokens=%d output_tokens=%d " + "total_usd=%s token_price_usd=%s decimals=%d cost=%d", model, input_tokens, output_tokens, diff --git a/tests/test_pricing.py b/tests/test_pricing.py index d1b5f25..5419782 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -3,7 +3,7 @@ Tests verify that: - Every user-facing model name resolves to the correct ModelConfig - - dynamic_session_cost_calculator produces the right amount in OPG token + - calculate_session_cost produces the right amount in OPG token smallest-units for 
supported models - Edge cases (no usage, unknown model, bad context) are handled correctly """ @@ -16,7 +16,11 @@ _MODEL_LOOKUP, get_model_config, ) -from tee_gateway.util import dynamic_session_cost_calculator +from tee_gateway.util import calculate_session_cost + +# All pricing tests assume OPG = $1.00 so USD cost == OPG token amount. +_OPG_PRICE_USD = Decimal("1") +_get_price = lambda: _OPG_PRICE_USD # noqa: E731 # --------------------------------------------------------------------------- @@ -205,12 +209,12 @@ def test_unknown_sonnet_variant_raises(self): # --------------------------------------------------------------------------- -class TestDynamicSessionCostCalculatorOPG(unittest.TestCase): - """dynamic_session_cost_calculator with OPG (18 decimals).""" +class TestCalculateSessionCostOPG(unittest.TestCase): + """calculate_session_cost with OPG (18 decimals).""" def _calc(self, model, input_tokens, output_tokens): - return dynamic_session_cost_calculator( - _ctx(model, input_tokens, output_tokens, _opg_requirements()) + return calculate_session_cost( + _ctx(model, input_tokens, output_tokens, _opg_requirements()), _get_price ) # ── OpenAI ────────────────────────────────────────────────────────────── @@ -351,11 +355,11 @@ def test_grok_4_fast_cheaper_than_grok_4(self): self.assertLess(fast, full) -class TestDynamicSessionCostCalculatorEdgeCases(unittest.TestCase): - """Edge cases for dynamic_session_cost_calculator.""" +class TestCalculateSessionCostEdgeCases(unittest.TestCase): + """Edge cases for calculate_session_cost.""" def test_zero_tokens_returns_zero(self): - cost = dynamic_session_cost_calculator(_ctx("claude-sonnet-4-5", 0, 0)) + cost = calculate_session_cost(_ctx("claude-sonnet-4-5", 0, 0), _get_price) self.assertEqual(cost, 0) def test_missing_usage_raises(self): @@ -365,24 +369,24 @@ def test_missing_usage_raises(self): "payment_requirements": _opg_requirements(), } with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + 
calculate_session_cost(ctx, _get_price) def test_unknown_asset_raises(self): ctx = _ctx("claude-sonnet-4-5", 100, 100) ctx["payment_requirements"] = {"asset": "0xdeadbeef", "amount": "1000"} with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_missing_asset_raises(self): ctx = _ctx("claude-sonnet-4-5", 100, 100) ctx["payment_requirements"] = {"amount": "1000"} # no asset with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_unknown_model_raises_value_error(self): ctx = _ctx("gpt-4o", 100, 100) with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_missing_request_json_raises_value_error(self): ctx = { @@ -394,7 +398,7 @@ def test_missing_request_json_raises_value_error(self): "payment_requirements": _opg_requirements(), } with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_model_from_request_takes_priority(self): """request_json model name is used even if response_json has a different model.""" @@ -406,7 +410,7 @@ def test_model_from_request_takes_priority(self): }, "payment_requirements": _opg_requirements(), } - cost = dynamic_session_cost_calculator(ctx) + cost = calculate_session_cost(ctx, _get_price) # Should be priced as Haiku (from request), not Sonnet haiku_cost = _expected_cost_opg("claude-haiku-4-5", 1000, 500) self.assertEqual(cost, haiku_cost) @@ -414,29 +418,31 @@ def test_model_from_request_takes_priority(self): def test_rounding_ceiling(self): """Fractional token costs are always rounded UP.""" # 1 output token of Haiku: 0.000005 USD = 5e12 wei — exact, no rounding needed - cost = dynamic_session_cost_calculator(_ctx("claude-haiku-4-5", 0, 1)) + cost = calculate_session_cost(_ctx("claude-haiku-4-5", 0, 1), _get_price) self.assertEqual(cost, 5_000_000_000_000) # 1 
input token of Gemini Flash Lite: 0.0000001 USD = 1e11 wei — exact - cost = dynamic_session_cost_calculator(_ctx("gemini-2.5-flash-lite", 1, 0)) + cost = calculate_session_cost(_ctx("gemini-2.5-flash-lite", 1, 0), _get_price) self.assertEqual(cost, 100_000_000_000) def test_model_name_case_insensitive(self): """Model names are normalized to lowercase before lookup.""" - cost_lower = dynamic_session_cost_calculator( - _ctx("claude-sonnet-4-5", 100, 100) + cost_lower = calculate_session_cost( + _ctx("claude-sonnet-4-5", 100, 100), _get_price ) - cost_upper = dynamic_session_cost_calculator( - _ctx("CLAUDE-SONNET-4-5", 100, 100) + cost_upper = calculate_session_cost( + _ctx("CLAUDE-SONNET-4-5", 100, 100), _get_price ) self.assertEqual(cost_lower, cost_upper) def test_sonnet_4_0_hyphen_vs_dot_same_cost(self): """claude-sonnet-4-0 and claude-4.0-sonnet are the same model.""" - cost_hyphen = dynamic_session_cost_calculator( - _ctx("claude-sonnet-4-0", 1000, 500) + cost_hyphen = calculate_session_cost( + _ctx("claude-sonnet-4-0", 1000, 500), _get_price + ) + cost_dot = calculate_session_cost( + _ctx("claude-4.0-sonnet", 1000, 500), _get_price ) - cost_dot = dynamic_session_cost_calculator(_ctx("claude-4.0-sonnet", 1000, 500)) self.assertEqual(cost_hyphen, cost_dot)