From 4047a251c869d1ac20048600ad0731eea757a6c1 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 21:38:40 -0700 Subject: [PATCH 01/17] feat: Add CoinGecko OPG/USD price feed background service MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the hardcoded mock price (Decimal("1")) with a real CoinGecko price feed that runs as a background daemon thread. Key behaviour: - Fetches OPG/USD price from CoinGecko /simple/token_price/base at startup and every 5 minutes thereafter (well within free-tier rate limits) - Retries up to 3 times per refresh cycle with 10s delay between attempts; exits retry loop immediately on 429 (rate-limited) to avoid hammering - Retains the last known good price on exhausted retries so live traffic is not disrupted by a transient CoinGecko outage - Logs a WARNING when the cached price is older than 2× the refresh interval (background loop may be stuck) - Tracks last_success, last_error, consecutive_failures, total_fetches, total_errors via get_status() / get_price_feed_status() - get_price() raises ValueError when no price has ever been fetched, which propagates through dynamic_session_cost_calculator and the existing strict _resolve_session_request_cost monkey-patch to return HTTP 500 rather than silently charging an incorrect amount Also adds: - 29 unit tests (all mocked, no network required) covering the fetch helper, retry logic, rate-limit handling, stale warning, stats, and module-level singleton functions - 4 integration tests that hit the live CoinGecko API; OPG-specific tests skip gracefully until the token is fully indexed Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 7 + tee_gateway/opg_price_feed.py | 285 +++++++++++++ tee_gateway/test/test_opg_price_feed.py | 387 ++++++++++++++++++ .../test/test_opg_price_feed_integration.py | 127 ++++++ tee_gateway/util.py | 37 +- 5 files changed, 808 insertions(+), 35 deletions(-) create mode 100644 tee_gateway/opg_price_feed.py create mode 100644 tee_gateway/test/test_opg_price_feed.py create mode 100644 tee_gateway/test/test_opg_price_feed_integration.py diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 1ba44c2..a070128 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -37,6 +37,7 @@ import types as _types from .util import dynamic_session_cost_calculator +from .opg_price_feed import start_price_feed from .definitions import ( EVM_PAYMENT_ADDRESS, BASE_MAINNET_NETWORK, @@ -106,6 +107,12 @@ def _shutdown_heartbeat(): atexit.register(_shutdown_heartbeat) +# --------------------------------------------------------------------------- +# OPG price feed — start before x402 middleware so the first request can be +# priced correctly. Runs as a daemon thread; no cleanup needed on exit. +# --------------------------------------------------------------------------- +start_price_feed() + facilitator = HTTPFacilitatorClientSync(FacilitatorConfig(url=FACILITATOR_URL)) server = x402ResourceServerSync(facilitator) store = SessionStore() diff --git a/tee_gateway/opg_price_feed.py b/tee_gateway/opg_price_feed.py new file mode 100644 index 0000000..4c18f2a --- /dev/null +++ b/tee_gateway/opg_price_feed.py @@ -0,0 +1,285 @@ +""" +Background OPG/USD price feed using the CoinGecko public API. + +Runs as a daemon thread that proactively refreshes the OPG token price at a +configurable interval, with retry on per-cycle fetch failure and early exit on +rate limiting. + +Usage +----- +Call ``start_price_feed()`` once during application startup (e.g. in +``__main__.py``). The dynamic cost calculator in ``util.py`` then calls +``get_opg_price_usd()`` to obtain the latest cached price. If no price has +been fetched yet, ``get_opg_price_usd()`` raises ``ValueError``, which +propagates through ``dynamic_session_cost_calculator`` and the strict +``_resolve_session_request_cost`` monkey-patch in ``__main__.py`` to produce +an HTTP 500 rather than silently charging an incorrect amount. +""" + +import logging +import threading +import time +from decimal import Decimal +from typing import Any, Optional + +import requests + +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS + +logger = logging.getLogger("llm_server.opg_price_feed") + +COINGECKO_BASE_URL = "https://api.coingecko.com/api/v3" +COINGECKO_PLATFORM = "base" + +# How long to wait between background refreshes (seconds). +DEFAULT_REFRESH_INTERVAL = 300 # 5 minutes — well within CoinGecko free-tier limits + +# Per-refresh retry settings. +DEFAULT_MAX_RETRIES = 3 +DEFAULT_RETRY_DELAY = 10 # seconds between retry attempts within a single refresh cycle + +# HTTP request timeout for each CoinGecko call. +FETCH_TIMEOUT = 10 + +# Warn in get_price() when the last successful fetch is this many intervals old. +STALE_WARNING_MULTIPLIER = 2 + + +class OPGPriceFeed: + """Fetches and caches the OPG/USD price from CoinGecko in a background thread.""" + + def __init__( + self, + refresh_interval: int = DEFAULT_REFRESH_INTERVAL, + max_retries: int = DEFAULT_MAX_RETRIES, + retry_delay: float = DEFAULT_RETRY_DELAY, + ) -> None: + self._refresh_interval = refresh_interval + self._max_retries = max_retries + self._retry_delay = retry_delay + + self._price: Optional[Decimal] = None + self._lock = threading.Lock() + self._thread: Optional[threading.Thread] = None + + # Status tracking — updated under _lock on every refresh cycle outcome. + self.last_success: Optional[float] = None # epoch seconds of last good fetch + self.last_error: Optional[str] = None # description of last failure (if any) + self.consecutive_failures: int = 0 # reset to 0 on any successful fetch + self.total_fetches: int = 0 # cumulative successful fetches + self.total_errors: int = 0 # cumulative failed refresh cycles + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def start(self) -> None: + """Perform an initial price fetch then launch the background refresh loop. + + If the initial fetch fails after all retries the feed still starts — + ``get_price()`` will raise ``ValueError`` until the background loop + eventually succeeds. + """ + self._refresh_price() + self._thread = threading.Thread( + target=self._run, name="opg-price-feed", daemon=True + ) + self._thread.start() + logger.info( + "OPG price feed started (refresh_interval=%ds, max_retries=%d)", + self._refresh_interval, + self._max_retries, + ) + + def get_price(self) -> Decimal: + """Return the latest cached OPG/USD price. + + Raises ``ValueError`` if no price has been successfully fetched yet. + Logs a warning (but still returns the price) if the cached value is + older than ``STALE_WARNING_MULTIPLIER * refresh_interval`` seconds — + this indicates the background loop has missed at least one refresh + cycle and may be experiencing persistent errors. + """ + now = time.time() + with self._lock: + if self._price is None: + raise ValueError( + "OPG price not yet available — " + "price feed has not completed a successful fetch" + ) + if self.last_success is not None: + age = now - self.last_success + stale_threshold = self._refresh_interval * STALE_WARNING_MULTIPLIER + if age > stale_threshold: + logger.warning( + "OPG price data is stale: last successful fetch was %.0fs ago " + "(threshold: %.0fs); consecutive failures: %d", + age, + stale_threshold, + self.consecutive_failures, + ) + return self._price + + def get_status(self) -> dict[str, Any]: + """Return a health snapshot suitable for logging or a /health endpoint.""" + with self._lock: + return { + "price_usd": float(self._price) if self._price is not None else None, + "last_success": self.last_success, + "last_error": self.last_error, + "consecutive_failures": self.consecutive_failures, + "total_fetches": self.total_fetches, + "total_errors": self.total_errors, + "refresh_interval": self._refresh_interval, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _run(self) -> None: + while True: + time.sleep(self._refresh_interval) + self._refresh_price() + + def _refresh_price(self) -> None: + """Attempt to fetch a fresh price, retrying on transient failure. + + - On success: updates the cached price and resets ``consecutive_failures``. + - On HTTP 429: logs a rate-limit warning and exits the retry loop early + (no point hammering a rate-limited API). + - On exhausted retries: increments ``consecutive_failures`` and retains + the last known good price so live traffic is not disrupted by a + transient CoinGecko outage. + """ + last_exc: Optional[Exception] = None + + for attempt in range(1, self._max_retries + 1): + try: + price = _fetch_from_coingecko() + with self._lock: + self._price = price + self.last_success = time.time() + self.consecutive_failures = 0 + self.total_fetches += 1 + logger.info( + "OPG price updated: $%.6f USD (attempt %d/%d)", + float(price), + attempt, + self._max_retries, + ) + return + except requests.exceptions.HTTPError as exc: + last_exc = exc + status_code = ( + exc.response.status_code if exc.response is not None else None + ) + if status_code == 429: + logger.warning( + "CoinGecko rate limit hit (429) on attempt %d/%d; " + "skipping remaining retries for this cycle", + attempt, + self._max_retries, + ) + break + logger.warning( + "OPG price fetch attempt %d/%d failed (HTTP %s): %s", + attempt, + self._max_retries, + status_code, + exc, + ) + except Exception as exc: + last_exc = exc + logger.warning( + "OPG price fetch attempt %d/%d failed: %s", + attempt, + self._max_retries, + exc, + ) + + if attempt < self._max_retries: + time.sleep(self._retry_delay) + + # All attempts exhausted (or rate-limited out) — record the failure. + with self._lock: + self.total_errors += 1 + self.consecutive_failures += 1 + self.last_error = str(last_exc) if last_exc is not None else "unknown error" + + logger.error( + "OPG price refresh failed (consecutive failures: %d); " + "retaining last known price (%s)", + self.consecutive_failures, + self._price, + ) + + +def _fetch_from_coingecko() -> Decimal: + """Single CoinGecko HTTP call. Raises on any error.""" + url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" + params = { + "contract_addresses": BASE_MAINNET_OPG_ADDRESS, + "vs_currencies": "usd", + } + response = requests.get(url, params=params, timeout=FETCH_TIMEOUT) + response.raise_for_status() + + data: dict = response.json() + # CoinGecko keys the result by the lowercased contract address. + price_entry = data.get(BASE_MAINNET_OPG_ADDRESS.lower()) + if not isinstance(price_entry, dict) or "usd" not in price_entry: + raise ValueError( + f"Unexpected CoinGecko response for {BASE_MAINNET_OPG_ADDRESS}: {data!r}" + ) + + return Decimal(str(price_entry["usd"])) + + +# --------------------------------------------------------------------------- +# Module-level singleton — initialised by start_price_feed() +# --------------------------------------------------------------------------- + +_feed: Optional[OPGPriceFeed] = None + + +def start_price_feed( + refresh_interval: int = DEFAULT_REFRESH_INTERVAL, + max_retries: int = DEFAULT_MAX_RETRIES, + retry_delay: float = DEFAULT_RETRY_DELAY, +) -> None: + """Create and start the global OPG price feed. Call once at app startup.""" + global _feed + if _feed is not None: + logger.info("OPG price feed already running, skipping") + return + _feed = OPGPriceFeed( + refresh_interval=refresh_interval, + max_retries=max_retries, + retry_delay=retry_delay, + ) + _feed.start() + + +def get_opg_price_usd() -> Decimal: + """Return the current OPG/USD price from the running price feed. + + Raises ``ValueError`` if the feed has not been started or has not yet + completed a successful fetch. + """ + if _feed is None: + raise ValueError( + "OPG price feed has not been started — " + "call start_price_feed() at app startup" + ) + return _feed.get_price() + + +def get_price_feed_status() -> dict[str, Any]: + """Return the current health snapshot of the price feed. + + Returns ``{"status": "not_started"}`` if the feed has never been started. + """ + if _feed is None: + return {"status": "not_started"} + return _feed.get_status() diff --git a/tee_gateway/test/test_opg_price_feed.py b/tee_gateway/test/test_opg_price_feed.py new file mode 100644 index 0000000..64377af --- /dev/null +++ b/tee_gateway/test/test_opg_price_feed.py @@ -0,0 +1,387 @@ +""" +Unit tests for tee_gateway.opg_price_feed. + +All external HTTP calls are mocked — no network access required. + +Test classes +------------ +TestFetchFromCoinGecko — the raw _fetch_from_coingecko() helper +TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() logic (retry, rate-limit, stats) +TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) +TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots +TestModuleLevelFunctions — start_price_feed() / get_opg_price_usd() / get_price_feed_status() +""" + +import time +import unittest +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import requests + +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS +from tee_gateway.opg_price_feed import ( + OPGPriceFeed, + _fetch_from_coingecko, + get_opg_price_usd, + get_price_feed_status, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +OPG_ADDRESS_LOWER = BASE_MAINNET_OPG_ADDRESS.lower() +SAMPLE_PRICE = Decimal("0.042") +SAMPLE_PRICE_FLOAT = 0.042 + + +def _mock_response(status_code: int = 200, json_body: dict | None = None) -> MagicMock: + """Build a minimal mock requests.Response.""" + mock = MagicMock() + mock.status_code = status_code + mock.json.return_value = json_body or {} + if status_code >= 400: + http_err = requests.exceptions.HTTPError(response=mock) + mock.raise_for_status.side_effect = http_err + else: + mock.raise_for_status.return_value = None + return mock + + +def _coingecko_success_body() -> dict: + return {OPG_ADDRESS_LOWER: {"usd": SAMPLE_PRICE_FLOAT}} + + +# --------------------------------------------------------------------------- +# TestFetchFromCoinGecko +# --------------------------------------------------------------------------- + + +class TestFetchFromCoinGecko(unittest.TestCase): + """Tests for the _fetch_from_coingecko() free function.""" + + @patch("tee_gateway.opg_price_feed.requests.get") + def test_happy_path_returns_decimal(self, mock_get): + mock_get.return_value = _mock_response(200, _coingecko_success_body()) + price = _fetch_from_coingecko() + self.assertIsInstance(price, Decimal) + self.assertEqual(price, Decimal(str(SAMPLE_PRICE_FLOAT))) + + @patch("tee_gateway.opg_price_feed.requests.get") + def test_passes_correct_params(self, mock_get): + mock_get.return_value = _mock_response(200, _coingecko_success_body()) + _fetch_from_coingecko() + _, kwargs = mock_get.call_args + self.assertIn("contract_addresses", kwargs["params"]) + self.assertEqual(kwargs["params"]["vs_currencies"], "usd") + self.assertIn( + "base", kwargs["url"] if "url" in kwargs else mock_get.call_args[0][0] + ) + + @patch("tee_gateway.opg_price_feed.requests.get") + def test_raises_on_http_500(self, mock_get): + mock_get.return_value = _mock_response(500) + with self.assertRaises(requests.exceptions.HTTPError): + _fetch_from_coingecko() + + @patch("tee_gateway.opg_price_feed.requests.get") + def test_raises_on_http_429(self, mock_get): + mock_get.return_value = _mock_response(429) + with self.assertRaises(requests.exceptions.HTTPError) as ctx: + _fetch_from_coingecko() + self.assertEqual(ctx.exception.response.status_code, 429) + + @patch("tee_gateway.opg_price_feed.requests.get") + def test_raises_on_empty_response_body(self, mock_get): + mock_get.return_value = _mock_response(200, {}) + with self.assertRaises(ValueError, msg="should raise when address key absent"): + _fetch_from_coingecko() + + @patch("tee_gateway.opg_price_feed.requests.get") + def test_raises_when_usd_key_missing(self, mock_get): + # Address present but no 'usd' field + mock_get.return_value = _mock_response(200, {OPG_ADDRESS_LOWER: {"eur": 0.04}}) + with self.assertRaises(ValueError): + _fetch_from_coingecko() + + @patch("tee_gateway.opg_price_feed.requests.get") + def test_raises_on_network_error(self, mock_get): + mock_get.side_effect = requests.exceptions.ConnectionError("timeout") + with self.assertRaises(requests.exceptions.ConnectionError): + _fetch_from_coingecko() + + +# --------------------------------------------------------------------------- +# TestOPGPriceFeedRefresh +# --------------------------------------------------------------------------- + + +class TestOPGPriceFeedRefresh(unittest.TestCase): + """Tests for OPGPriceFeed._refresh_price() — retry logic, rate-limit, stats.""" + + def _feed(self, **kwargs) -> OPGPriceFeed: + defaults = {"refresh_interval": 300, "max_retries": 3, "retry_delay": 0} + defaults.update(kwargs) + return OPGPriceFeed(**defaults) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_successful_refresh_sets_price(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = self._feed() + feed._refresh_price() + self.assertEqual(feed._price, SAMPLE_PRICE) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_successful_refresh_updates_stats(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = self._feed() + feed._refresh_price() + self.assertEqual(feed.total_fetches, 1) + self.assertEqual(feed.total_errors, 0) + self.assertEqual(feed.consecutive_failures, 0) + self.assertIsNotNone(feed.last_success) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_retry_on_transient_failure_then_success(self, mock_fetch): + # Fail twice, succeed on third attempt + mock_fetch.side_effect = [ + ValueError("transient"), + ValueError("transient"), + SAMPLE_PRICE, + ] + feed = self._feed(max_retries=3, retry_delay=0) + feed._refresh_price() + self.assertEqual(feed._price, SAMPLE_PRICE) + self.assertEqual(mock_fetch.call_count, 3) + self.assertEqual(feed.total_fetches, 1) + self.assertEqual(feed.total_errors, 0) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_exhausted_retries_records_error_stats(self, mock_fetch): + mock_fetch.side_effect = ValueError("always fails") + feed = self._feed(max_retries=3, retry_delay=0) + feed._refresh_price() + self.assertEqual(feed.total_errors, 1) + self.assertEqual(feed.consecutive_failures, 1) + self.assertIsNotNone(feed.last_error) + self.assertEqual(feed.total_fetches, 0) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_exhausted_retries_keeps_last_known_price(self, mock_fetch): + # Seed a previous price, then let all retries fail + feed = self._feed(max_retries=2, retry_delay=0) + feed._price = SAMPLE_PRICE # simulate a previously fetched price + feed.last_success = time.time() + mock_fetch.side_effect = ValueError("fail") + feed._refresh_price() + # Price must not be cleared + self.assertEqual(feed._price, SAMPLE_PRICE) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_success_after_failures_resets_consecutive_failures(self, mock_fetch): + feed = self._feed(max_retries=1, retry_delay=0) + # First cycle fails + mock_fetch.side_effect = ValueError("fail") + feed._refresh_price() + self.assertEqual(feed.consecutive_failures, 1) + # Second cycle succeeds + mock_fetch.side_effect = None + mock_fetch.return_value = SAMPLE_PRICE + feed._refresh_price() + self.assertEqual(feed.consecutive_failures, 0) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_rate_limit_breaks_retry_loop_immediately(self, mock_fetch): + # Build a 429 HTTPError + resp = MagicMock() + resp.status_code = 429 + http_err = requests.exceptions.HTTPError(response=resp) + mock_fetch.side_effect = http_err + + feed = self._feed(max_retries=3, retry_delay=0) + feed._refresh_price() + + # Should only attempt once — 429 means no further retries + self.assertEqual(mock_fetch.call_count, 1) + self.assertEqual(feed.total_errors, 1) + + @patch("tee_gateway.opg_price_feed.time.sleep") + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_retry_delay_called_between_attempts(self, mock_fetch, mock_sleep): + mock_fetch.side_effect = [ValueError("fail"), ValueError("fail"), SAMPLE_PRICE] + feed = self._feed(max_retries=3, retry_delay=5) + feed._refresh_price() + # sleep should be called between attempts 1→2 and 2→3 (not after success) + self.assertEqual(mock_sleep.call_count, 2) + mock_sleep.assert_called_with(5) + + @patch("tee_gateway.opg_price_feed.time.sleep") + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_no_sleep_after_last_failed_attempt(self, mock_fetch, mock_sleep): + mock_fetch.side_effect = ValueError("always fails") + feed = self._feed(max_retries=3, retry_delay=5) + feed._refresh_price() + # 3 attempts: sleep after attempt 1, sleep after attempt 2, NO sleep after attempt 3 + self.assertEqual(mock_sleep.call_count, 2) + + +# --------------------------------------------------------------------------- +# TestOPGPriceFeedGetPrice +# --------------------------------------------------------------------------- + + +class TestOPGPriceFeedGetPrice(unittest.TestCase): + """Tests for OPGPriceFeed.get_price() behaviour.""" + + def test_raises_before_any_successful_fetch(self): + feed = OPGPriceFeed() + with self.assertRaises(ValueError) as ctx: + feed.get_price() + self.assertIn("not yet available", str(ctx.exception)) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_returns_price_after_successful_refresh(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(retry_delay=0) + feed._refresh_price() + self.assertEqual(feed.get_price(), SAMPLE_PRICE) + + @patch("tee_gateway.opg_price_feed.time.time") + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_warns_when_price_is_stale(self, mock_fetch, mock_time): + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) + + # Simulate fetch at t=0 + mock_time.return_value = 0.0 + feed._refresh_price() + + # Advance time past stale threshold (300 * 2 = 600s) + mock_time.return_value = 601.0 + + with self.assertLogs("llm_server.opg_price_feed", level="WARNING") as log_ctx: + price = feed.get_price() + + self.assertEqual(price, SAMPLE_PRICE) + self.assertTrue(any("stale" in line.lower() for line in log_ctx.output)) + + @patch("tee_gateway.opg_price_feed.time.time") + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_no_stale_warning_when_price_is_fresh(self, mock_fetch, mock_time): + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) + + mock_time.return_value = 0.0 + feed._refresh_price() + mock_time.return_value = 100.0 # well within threshold + + # Should not emit any WARNING + import logging + + with self.assertLogs("llm_server.opg_price_feed", level="DEBUG") as log_ctx: + # Emit a debug line ourselves so assertLogs doesn't raise on empty + logging.getLogger("llm_server.opg_price_feed").debug("sentinel") + feed.get_price() + + warning_lines = [ + line for line in log_ctx.output if "WARNING" in line and "stale" in line.lower() + ] + self.assertEqual(warning_lines, []) + + +# --------------------------------------------------------------------------- +# TestOPGPriceFeedStatus +# --------------------------------------------------------------------------- + + +class TestOPGPriceFeedStatus(unittest.TestCase): + """Tests for OPGPriceFeed.get_status() snapshot.""" + + def test_initial_status_has_no_price(self): + feed = OPGPriceFeed() + status = feed.get_status() + self.assertIsNone(status["price_usd"]) + self.assertIsNone(status["last_success"]) + self.assertEqual(status["consecutive_failures"], 0) + self.assertEqual(status["total_fetches"], 0) + self.assertEqual(status["total_errors"], 0) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_status_reflects_successful_fetch(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(retry_delay=0) + feed._refresh_price() + status = feed.get_status() + self.assertAlmostEqual(status["price_usd"], float(SAMPLE_PRICE), places=6) + self.assertIsNotNone(status["last_success"]) + self.assertEqual(status["total_fetches"], 1) + self.assertEqual(status["consecutive_failures"], 0) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_status_reflects_failed_cycle(self, mock_fetch): + mock_fetch.side_effect = ValueError("fail") + feed = OPGPriceFeed(max_retries=1, retry_delay=0) + feed._refresh_price() + status = feed.get_status() + self.assertIsNone(status["price_usd"]) + self.assertEqual(status["total_errors"], 1) + self.assertEqual(status["consecutive_failures"], 1) + self.assertIsNotNone(status["last_error"]) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_status_includes_refresh_interval(self, mock_fetch): + feed = OPGPriceFeed(refresh_interval=600) + self.assertEqual(feed.get_status()["refresh_interval"], 600) + + @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + def test_status_accumulates_multiple_error_cycles(self, mock_fetch): + mock_fetch.side_effect = ValueError("fail") + feed = OPGPriceFeed(max_retries=1, retry_delay=0) + feed._refresh_price() + feed._refresh_price() + feed._refresh_price() + status = feed.get_status() + self.assertEqual(status["total_errors"], 3) + self.assertEqual(status["consecutive_failures"], 3) + + +# --------------------------------------------------------------------------- +# TestModuleLevelFunctions +# --------------------------------------------------------------------------- + + +class TestModuleLevelFunctions(unittest.TestCase): + """Tests for the module-level singleton helpers.""" + + def test_get_opg_price_usd_raises_when_feed_is_none(self): + with patch("tee_gateway.opg_price_feed._feed", None): + with self.assertRaises(ValueError) as ctx: + get_opg_price_usd() + self.assertIn("not been started", str(ctx.exception)) + + def test_get_opg_price_usd_delegates_to_feed(self): + mock_feed = MagicMock() + mock_feed.get_price.return_value = SAMPLE_PRICE + with patch("tee_gateway.opg_price_feed._feed", mock_feed): + price = get_opg_price_usd() + self.assertEqual(price, SAMPLE_PRICE) + mock_feed.get_price.assert_called_once() + + def test_get_price_feed_status_when_feed_is_none(self): + with patch("tee_gateway.opg_price_feed._feed", None): + status = get_price_feed_status() + self.assertEqual(status, {"status": "not_started"}) + + def test_get_price_feed_status_delegates_to_feed(self): + expected = {"price_usd": 0.042, "total_fetches": 5} + mock_feed = MagicMock() + mock_feed.get_status.return_value = expected + with patch("tee_gateway.opg_price_feed._feed", mock_feed): + status = get_price_feed_status() + self.assertEqual(status, expected) + mock_feed.get_status.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/tee_gateway/test/test_opg_price_feed_integration.py b/tee_gateway/test/test_opg_price_feed_integration.py new file mode 100644 index 0000000..2e35bd6 --- /dev/null +++ b/tee_gateway/test/test_opg_price_feed_integration.py @@ -0,0 +1,127 @@ +""" +Integration tests for tee_gateway.opg_price_feed. + +These tests make REAL network calls to the CoinGecko public API. + +Expected behaviour +------------------ +* ``TestCoinGeckoConnectivity`` — passes when the CoinGecko API is reachable. + Skips on network errors or rate-limiting (429). +* ``TestOPGPriceFetchLive`` — skips when OPG is not yet priced on CoinGecko's + Base platform (CoinGecko currently returns an empty price entry for the + token). Will pass automatically once the token is fully listed. + +Run with:: + + uv run pytest tee_gateway/test/test_opg_price_feed_integration.py -v +""" + +import unittest +from decimal import Decimal + +import requests + +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS +from tee_gateway.opg_price_feed import ( + COINGECKO_BASE_URL, + COINGECKO_PLATFORM, + FETCH_TIMEOUT, + _fetch_from_coingecko, +) + + +def _get(url: str, **kwargs) -> requests.Response: + """Wrapper that skips the test on network errors or rate-limiting.""" + try: + resp = requests.get(url, timeout=FETCH_TIMEOUT, **kwargs) + except requests.exceptions.RequestException as exc: + raise unittest.SkipTest(f"Network unavailable: {exc}") from exc + if resp.status_code == 429: + raise unittest.SkipTest( + "CoinGecko rate limit hit (429) — re-run after a short wait" + ) + return resp + + +class TestCoinGeckoConnectivity(unittest.TestCase): + """Verify that the CoinGecko API endpoint is reachable and well-formed.""" + + def test_ping_endpoint_reachable(self): + """CoinGecko /ping should return {gecko_says: ...}.""" + resp = _get(f"{COINGECKO_BASE_URL}/ping") + self.assertEqual(resp.status_code, 200) + self.assertIn("gecko_says", resp.json()) + + def test_base_platform_endpoint_returns_200(self): + """The token_price/base endpoint should respond with HTTP 200 for a known token.""" + url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" + # USDC on Base mainnet — reliably indexed on CoinGecko. + usdc_base = "0x833589fcd6edb6e08f4c7c32d4f71b54bda02913" + resp = _get(url, params={"contract_addresses": usdc_base, "vs_currencies": "usd"}) + self.assertEqual( + resp.status_code, + 200, + f"Expected 200 from CoinGecko, got {resp.status_code}: {resp.text[:200]}", + ) + data = resp.json() + self.assertIsInstance(data, dict) + self.assertIn(usdc_base, data, "USDC should be indexed on Base platform") + self.assertIn("usd", data[usdc_base], "USDC price entry should have 'usd' key") + + +class TestOPGPriceFetchLive(unittest.TestCase): + """Live fetch of the OPG token price. + + Both tests skip gracefully when OPG is not yet fully priced on CoinGecko + (currently returns ``{address: {}}`` with no 'usd' key). They will pass + automatically once the token is listed with a live price. + """ + + def test_opg_response_structure(self): + """Inspect the raw CoinGecko response for the OPG contract address.""" + url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" + resp = _get( + url, + params={ + "contract_addresses": BASE_MAINNET_OPG_ADDRESS, + "vs_currencies": "usd", + }, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + print(f"\nCoinGecko response for OPG ({BASE_MAINNET_OPG_ADDRESS}): {data}") # noqa: T201 + + opg_lower = BASE_MAINNET_OPG_ADDRESS.lower() + price_entry = data.get(opg_lower) + # CoinGecko returns the address key with {} when the token is known but + # not yet priced — skip in that case rather than fail. + if not price_entry or "usd" not in price_entry: + self.skipTest( + f"OPG not yet priced on CoinGecko Base platform " + f"(response: {data!r}). Will pass once the token is fully listed." + ) + self.assertIsInstance(price_entry["usd"], (int, float)) + + def test_opg_price_fetch_live(self): + """End-to-end: _fetch_from_coingecko() returns a positive Decimal price.""" + try: + price = _fetch_from_coingecko() + except requests.exceptions.HTTPError as exc: + if exc.response is not None and exc.response.status_code == 429: + self.skipTest("CoinGecko rate limit — re-run after a short wait") + raise + except ValueError as exc: + if "Unexpected CoinGecko response" in str(exc): + self.skipTest( + f"OPG ({BASE_MAINNET_OPG_ADDRESS}) not yet priced on " + f"CoinGecko Base platform. Details: {exc}" + ) + raise + + self.assertIsInstance(price, Decimal) + self.assertGreater(price, Decimal("0"), "Price must be positive") + print(f"\nLive OPG price: ${price} USD") # noqa: T201 + + +if __name__ == "__main__": + unittest.main() diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 47559d9..a242a5c 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -2,8 +2,6 @@ from tee_gateway import typing_utils import logging -import threading -import time from decimal import Decimal, InvalidOperation, ROUND_CEILING from typing import Any @@ -160,42 +158,11 @@ def _deserialize_dict(data, boxed_type): ) from tee_gateway.model_registry import get_model_config # noqa: E402 -TOKEN_A_PRICE_CACHE_TTL_SECONDS = 60 - -_token_price_cache: dict[str, Any] = { - "value": Decimal("1"), - "updated_at": 0.0, -} -_token_price_lock = threading.Lock() - - -def _fetch_token_a_price_usd_mock() -> Decimal: - """Return the USD price of the payment token used for cost calculation. - - Currently returns a fixed 1:1 ratio, which is correct for USDC-denominated - payments (1 USDC ≈ $1 USD). For OPG-denominated payments, replace this - with a live price feed (e.g. a DEX oracle or CoinGecko API call) that - returns the current OPG/USD exchange rate so that token amounts are - calculated correctly against the model's USD pricing. - """ - return Decimal("1") - def get_token_a_price_usd() -> Decimal: - now = time.time() - with _token_price_lock: - cached_value = _token_price_cache.get("value") - cached_at = float(_token_price_cache.get("updated_at") or 0.0) - if ( - isinstance(cached_value, Decimal) - and (now - cached_at) < TOKEN_A_PRICE_CACHE_TTL_SECONDS - ): - return cached_value + from tee_gateway.opg_price_feed import get_opg_price_usd # noqa: PLC0415 - value = _fetch_token_a_price_usd_mock() - _token_price_cache["value"] = value - _token_price_cache["updated_at"] = now - return value + return get_opg_price_usd() def _as_dict(value: Any) -> dict[str, Any] | None: From 5ff9fc9f2db726ac50972785c9553797c1ff02b1 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 21:55:59 -0700 Subject: [PATCH 02/17] refactor: Reorganize price feed into tee_gateway/price_feed/ package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moves opg_price_feed.py into a dedicated package matching the layout of the existing tee_gateway/heartbeat/ package: tee_gateway/price_feed/ __init__.py — re-exports public API (OPGPriceFeed, PriceFeedConfig, start_price_feed, get_opg_price_usd, get_price_feed_status) config.py — all constants (COINGECKO_BASE_URL, COINGECKO_PLATFORM, FETCH_TIMEOUT, refresh/retry defaults, stale threshold) plus the PriceFeedConfig frozen dataclass feed.py — OPGPriceFeed class, fetch_opg_price(), singleton helpers Also renames the test files to match: test_opg_price_feed.py -> test_price_feed.py test_opg_price_feed_integration.py -> test_price_feed_integration.py No behaviour changes — all 29 unit tests still pass. Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 2 +- tee_gateway/price_feed/__init__.py | 15 ++ tee_gateway/price_feed/config.py | 36 +++++ .../{opg_price_feed.py => price_feed/feed.py} | 47 +++--- ...t_opg_price_feed.py => test_price_feed.py} | 134 ++++++++---------- ...tion.py => test_price_feed_integration.py} | 16 ++- tee_gateway/util.py | 2 +- 7 files changed, 142 insertions(+), 110 deletions(-) create mode 100644 tee_gateway/price_feed/__init__.py create mode 100644 tee_gateway/price_feed/config.py rename tee_gateway/{opg_price_feed.py => price_feed/feed.py} (90%) rename tee_gateway/test/{test_opg_price_feed.py => test_price_feed.py} (75%) rename tee_gateway/test/{test_opg_price_feed_integration.py => test_price_feed_integration.py} (91%) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index a070128..ab84846 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -37,7 +37,7 @@ import types as _types from .util import dynamic_session_cost_calculator -from .opg_price_feed import start_price_feed +from .price_feed import start_price_feed from .definitions import ( EVM_PAYMENT_ADDRESS, BASE_MAINNET_NETWORK, diff --git a/tee_gateway/price_feed/__init__.py b/tee_gateway/price_feed/__init__.py new file mode 100644 index 0000000..13aa282 --- /dev/null +++ b/tee_gateway/price_feed/__init__.py @@ -0,0 +1,15 @@ +from .config import PriceFeedConfig +from .feed import ( + OPGPriceFeed, + get_opg_price_usd, + get_price_feed_status, + start_price_feed, +) + +__all__ = [ + "OPGPriceFeed", + "PriceFeedConfig", + "get_opg_price_usd", + "get_price_feed_status", + "start_price_feed", +] diff --git a/tee_gateway/price_feed/config.py b/tee_gateway/price_feed/config.py new file mode 100644 index 0000000..46b2268 --- /dev/null +++ b/tee_gateway/price_feed/config.py @@ -0,0 +1,36 @@ +""" +Configuration constants and dataclass for the OPG price feed. +""" + +from dataclasses import dataclass + + +# --------------------------------------------------------------------------- +# CoinGecko API +# --------------------------------------------------------------------------- +COINGECKO_BASE_URL = "https://api.coingecko.com/api/v3" +COINGECKO_PLATFORM = "base" # Base mainnet platform identifier on CoinGecko +FETCH_TIMEOUT = 10 # seconds per HTTP request + +# --------------------------------------------------------------------------- +# Refresh / retry defaults +# --------------------------------------------------------------------------- +DEFAULT_REFRESH_INTERVAL = 300 # 5 minutes between background refresh cycles +DEFAULT_MAX_RETRIES = 3 # attempts per refresh cycle before giving up +DEFAULT_RETRY_DELAY = 10 # seconds between retry attempts within a cycle + +# --------------------------------------------------------------------------- +# Stale-price warning threshold +# --------------------------------------------------------------------------- +# get_price() logs WARNING when last successful fetch is older than +# STALE_WARNING_MULTIPLIER × refresh_interval seconds. +STALE_WARNING_MULTIPLIER = 2 + + +@dataclass(frozen=True) +class PriceFeedConfig: + """Runtime configuration for the OPG price feed background service.""" + + refresh_interval: int = DEFAULT_REFRESH_INTERVAL + max_retries: int = DEFAULT_MAX_RETRIES + retry_delay: float = DEFAULT_RETRY_DELAY diff --git a/tee_gateway/opg_price_feed.py b/tee_gateway/price_feed/feed.py similarity index 90% rename from tee_gateway/opg_price_feed.py rename to tee_gateway/price_feed/feed.py index 4c18f2a..078df96 100644 --- a/tee_gateway/opg_price_feed.py +++ b/tee_gateway/price_feed/feed.py @@ -25,24 +25,18 @@ import requests from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS +from tee_gateway.price_feed.config import ( + COINGECKO_BASE_URL, + COINGECKO_PLATFORM, + DEFAULT_MAX_RETRIES, + DEFAULT_REFRESH_INTERVAL, + DEFAULT_RETRY_DELAY, + FETCH_TIMEOUT, + STALE_WARNING_MULTIPLIER, + PriceFeedConfig, +) -logger = logging.getLogger("llm_server.opg_price_feed") - -COINGECKO_BASE_URL = "https://api.coingecko.com/api/v3" -COINGECKO_PLATFORM = "base" - -# How long to wait between background refreshes (seconds). -DEFAULT_REFRESH_INTERVAL = 300 # 5 minutes — well within CoinGecko free-tier limits - -# Per-refresh retry settings. -DEFAULT_MAX_RETRIES = 3 -DEFAULT_RETRY_DELAY = 10 # seconds between retry attempts within a single refresh cycle - -# HTTP request timeout for each CoinGecko call. -FETCH_TIMEOUT = 10 - -# Warn in get_price() when the last successful fetch is this many intervals old. -STALE_WARNING_MULTIPLIER = 2 +logger = logging.getLogger("llm_server.price_feed") class OPGPriceFeed: @@ -156,7 +150,7 @@ def _refresh_price(self) -> None: for attempt in range(1, self._max_retries + 1): try: - price = _fetch_from_coingecko() + price = fetch_opg_price() with self._lock: self._price = price self.last_success = time.time() @@ -215,8 +209,8 @@ def _refresh_price(self) -> None: ) -def _fetch_from_coingecko() -> Decimal: - """Single CoinGecko HTTP call. Raises on any error.""" +def fetch_opg_price() -> Decimal: + """Fetch the current OPG/USD price from CoinGecko. Raises on any error.""" url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" params = { "contract_addresses": BASE_MAINNET_OPG_ADDRESS, @@ -243,20 +237,17 @@ def _fetch_from_coingecko() -> Decimal: _feed: Optional[OPGPriceFeed] = None -def start_price_feed( - refresh_interval: int = DEFAULT_REFRESH_INTERVAL, - max_retries: int = DEFAULT_MAX_RETRIES, - retry_delay: float = DEFAULT_RETRY_DELAY, -) -> None: +def start_price_feed(config: Optional[PriceFeedConfig] = None) -> None: """Create and start the global OPG price feed. Call once at app startup.""" global _feed if _feed is not None: logger.info("OPG price feed already running, skipping") return + cfg = config or PriceFeedConfig() _feed = OPGPriceFeed( - refresh_interval=refresh_interval, - max_retries=max_retries, - retry_delay=retry_delay, + refresh_interval=cfg.refresh_interval, + max_retries=cfg.max_retries, + retry_delay=cfg.retry_delay, ) _feed.start() diff --git a/tee_gateway/test/test_opg_price_feed.py b/tee_gateway/test/test_price_feed.py similarity index 75% rename from tee_gateway/test/test_opg_price_feed.py rename to tee_gateway/test/test_price_feed.py index 64377af..92c435e 100644 --- a/tee_gateway/test/test_opg_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -1,14 +1,14 @@ """ -Unit tests for tee_gateway.opg_price_feed. +Unit tests for tee_gateway.price_feed. All external HTTP calls are mocked — no network access required. Test classes ------------ -TestFetchFromCoinGecko — the raw _fetch_from_coingecko() helper -TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() logic (retry, rate-limit, stats) +TestFetchOPGPrice — the raw fetch_opg_price() helper in feed.py +TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() (retry, rate-limit, stats) TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) -TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots +TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots TestModuleLevelFunctions — start_price_feed() / get_opg_price_usd() / get_price_feed_status() """ @@ -20,12 +20,12 @@ import requests from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS -from tee_gateway.opg_price_feed import ( +from tee_gateway.price_feed import ( OPGPriceFeed, - _fetch_from_coingecko, get_opg_price_usd, get_price_feed_status, ) +from tee_gateway.price_feed.feed import fetch_opg_price # --------------------------------------------------------------------------- # Helpers @@ -35,6 +35,9 @@ SAMPLE_PRICE = Decimal("0.042") SAMPLE_PRICE_FLOAT = 0.042 +# Patch target prefix — all mocks go through the feed module. +_FEED = "tee_gateway.price_feed.feed" + def _mock_response(status_code: int = 200, json_body: dict | None = None) -> MagicMock: """Build a minimal mock requests.Response.""" @@ -54,24 +57,24 @@ def _coingecko_success_body() -> dict: # --------------------------------------------------------------------------- -# TestFetchFromCoinGecko +# TestFetchOPGPrice # --------------------------------------------------------------------------- -class TestFetchFromCoinGecko(unittest.TestCase): - """Tests for the _fetch_from_coingecko() free function.""" +class TestFetchOPGPrice(unittest.TestCase): + """Tests for the fetch_opg_price() free function in feed.py.""" - @patch("tee_gateway.opg_price_feed.requests.get") + @patch(f"{_FEED}.requests.get") def test_happy_path_returns_decimal(self, mock_get): mock_get.return_value = _mock_response(200, _coingecko_success_body()) - price = _fetch_from_coingecko() + price = fetch_opg_price() self.assertIsInstance(price, Decimal) self.assertEqual(price, Decimal(str(SAMPLE_PRICE_FLOAT))) - @patch("tee_gateway.opg_price_feed.requests.get") + @patch(f"{_FEED}.requests.get") def test_passes_correct_params(self, mock_get): mock_get.return_value = _mock_response(200, _coingecko_success_body()) - _fetch_from_coingecko() + fetch_opg_price() _, kwargs = mock_get.call_args self.assertIn("contract_addresses", kwargs["params"]) self.assertEqual(kwargs["params"]["vs_currencies"], "usd") @@ -79,37 +82,36 @@ def test_passes_correct_params(self, mock_get): "base", kwargs["url"] if "url" in kwargs else mock_get.call_args[0][0] ) - @patch("tee_gateway.opg_price_feed.requests.get") + @patch(f"{_FEED}.requests.get") def test_raises_on_http_500(self, mock_get): mock_get.return_value = _mock_response(500) with self.assertRaises(requests.exceptions.HTTPError): - _fetch_from_coingecko() + fetch_opg_price() - @patch("tee_gateway.opg_price_feed.requests.get") + @patch(f"{_FEED}.requests.get") def test_raises_on_http_429(self, mock_get): mock_get.return_value = _mock_response(429) with self.assertRaises(requests.exceptions.HTTPError) as ctx: - _fetch_from_coingecko() + fetch_opg_price() self.assertEqual(ctx.exception.response.status_code, 429) - @patch("tee_gateway.opg_price_feed.requests.get") + @patch(f"{_FEED}.requests.get") def test_raises_on_empty_response_body(self, mock_get): mock_get.return_value = _mock_response(200, {}) with self.assertRaises(ValueError, msg="should raise when address key absent"): - _fetch_from_coingecko() + fetch_opg_price() - @patch("tee_gateway.opg_price_feed.requests.get") + @patch(f"{_FEED}.requests.get") def test_raises_when_usd_key_missing(self, mock_get): - # Address present but no 'usd' field mock_get.return_value = _mock_response(200, {OPG_ADDRESS_LOWER: {"eur": 0.04}}) with self.assertRaises(ValueError): - _fetch_from_coingecko() + fetch_opg_price() - @patch("tee_gateway.opg_price_feed.requests.get") + @patch(f"{_FEED}.requests.get") def test_raises_on_network_error(self, mock_get): mock_get.side_effect = requests.exceptions.ConnectionError("timeout") with self.assertRaises(requests.exceptions.ConnectionError): - _fetch_from_coingecko() + fetch_opg_price() # --------------------------------------------------------------------------- @@ -125,14 +127,14 @@ def _feed(self, **kwargs) -> OPGPriceFeed: defaults.update(kwargs) return OPGPriceFeed(**defaults) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_successful_refresh_sets_price(self, mock_fetch): mock_fetch.return_value = SAMPLE_PRICE feed = self._feed() feed._refresh_price() self.assertEqual(feed._price, SAMPLE_PRICE) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_successful_refresh_updates_stats(self, mock_fetch): mock_fetch.return_value = SAMPLE_PRICE feed = self._feed() @@ -142,9 +144,8 @@ def test_successful_refresh_updates_stats(self, mock_fetch): self.assertEqual(feed.consecutive_failures, 0) self.assertIsNotNone(feed.last_success) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_retry_on_transient_failure_then_success(self, mock_fetch): - # Fail twice, succeed on third attempt mock_fetch.side_effect = [ ValueError("transient"), ValueError("transient"), @@ -157,7 +158,7 @@ def test_retry_on_transient_failure_then_success(self, mock_fetch): self.assertEqual(feed.total_fetches, 1) self.assertEqual(feed.total_errors, 0) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_exhausted_retries_records_error_stats(self, mock_fetch): mock_fetch.side_effect = ValueError("always fails") feed = self._feed(max_retries=3, retry_delay=0) @@ -167,62 +168,51 @@ def test_exhausted_retries_records_error_stats(self, mock_fetch): self.assertIsNotNone(feed.last_error) self.assertEqual(feed.total_fetches, 0) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_exhausted_retries_keeps_last_known_price(self, mock_fetch): - # Seed a previous price, then let all retries fail feed = self._feed(max_retries=2, retry_delay=0) - feed._price = SAMPLE_PRICE # simulate a previously fetched price + feed._price = SAMPLE_PRICE feed.last_success = time.time() mock_fetch.side_effect = ValueError("fail") feed._refresh_price() - # Price must not be cleared self.assertEqual(feed._price, SAMPLE_PRICE) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_success_after_failures_resets_consecutive_failures(self, mock_fetch): feed = self._feed(max_retries=1, retry_delay=0) - # First cycle fails mock_fetch.side_effect = ValueError("fail") feed._refresh_price() self.assertEqual(feed.consecutive_failures, 1) - # Second cycle succeeds mock_fetch.side_effect = None mock_fetch.return_value = SAMPLE_PRICE feed._refresh_price() self.assertEqual(feed.consecutive_failures, 0) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_rate_limit_breaks_retry_loop_immediately(self, mock_fetch): - # Build a 429 HTTPError resp = MagicMock() resp.status_code = 429 - http_err = requests.exceptions.HTTPError(response=resp) - mock_fetch.side_effect = http_err - + mock_fetch.side_effect = requests.exceptions.HTTPError(response=resp) feed = self._feed(max_retries=3, retry_delay=0) feed._refresh_price() - - # Should only attempt once — 429 means no further retries self.assertEqual(mock_fetch.call_count, 1) self.assertEqual(feed.total_errors, 1) - @patch("tee_gateway.opg_price_feed.time.sleep") - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.time.sleep") + @patch(f"{_FEED}.fetch_opg_price") def test_retry_delay_called_between_attempts(self, mock_fetch, mock_sleep): mock_fetch.side_effect = [ValueError("fail"), ValueError("fail"), SAMPLE_PRICE] feed = self._feed(max_retries=3, retry_delay=5) feed._refresh_price() - # sleep should be called between attempts 1→2 and 2→3 (not after success) self.assertEqual(mock_sleep.call_count, 2) mock_sleep.assert_called_with(5) - @patch("tee_gateway.opg_price_feed.time.sleep") - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.time.sleep") + @patch(f"{_FEED}.fetch_opg_price") def test_no_sleep_after_last_failed_attempt(self, mock_fetch, mock_sleep): mock_fetch.side_effect = ValueError("always fails") feed = self._feed(max_retries=3, retry_delay=5) feed._refresh_price() - # 3 attempts: sleep after attempt 1, sleep after attempt 2, NO sleep after attempt 3 self.assertEqual(mock_sleep.call_count, 2) @@ -240,35 +230,36 @@ def test_raises_before_any_successful_fetch(self): feed.get_price() self.assertIn("not yet available", str(ctx.exception)) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_returns_price_after_successful_refresh(self, mock_fetch): mock_fetch.return_value = SAMPLE_PRICE feed = OPGPriceFeed(retry_delay=0) feed._refresh_price() self.assertEqual(feed.get_price(), SAMPLE_PRICE) - @patch("tee_gateway.opg_price_feed.time.time") - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.time.time") + @patch(f"{_FEED}.fetch_opg_price") def test_warns_when_price_is_stale(self, mock_fetch, mock_time): mock_fetch.return_value = SAMPLE_PRICE feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) - # Simulate fetch at t=0 mock_time.return_value = 0.0 feed._refresh_price() - # Advance time past stale threshold (300 * 2 = 600s) + # Advance past stale threshold (300 * 2 = 600s) mock_time.return_value = 601.0 - with self.assertLogs("llm_server.opg_price_feed", level="WARNING") as log_ctx: + with self.assertLogs("llm_server.price_feed", level="WARNING") as log_ctx: price = feed.get_price() self.assertEqual(price, SAMPLE_PRICE) self.assertTrue(any("stale" in line.lower() for line in log_ctx.output)) - @patch("tee_gateway.opg_price_feed.time.time") - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.time.time") + @patch(f"{_FEED}.fetch_opg_price") def test_no_stale_warning_when_price_is_fresh(self, mock_fetch, mock_time): + import logging + mock_fetch.return_value = SAMPLE_PRICE feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) @@ -276,16 +267,14 @@ def test_no_stale_warning_when_price_is_fresh(self, mock_fetch, mock_time): feed._refresh_price() mock_time.return_value = 100.0 # well within threshold - # Should not emit any WARNING - import logging - - with self.assertLogs("llm_server.opg_price_feed", level="DEBUG") as log_ctx: - # Emit a debug line ourselves so assertLogs doesn't raise on empty - logging.getLogger("llm_server.opg_price_feed").debug("sentinel") + with self.assertLogs("llm_server.price_feed", level="DEBUG") as log_ctx: + logging.getLogger("llm_server.price_feed").debug("sentinel") feed.get_price() warning_lines = [ - line for line in log_ctx.output if "WARNING" in line and "stale" in line.lower() + line + for line in log_ctx.output + if "WARNING" in line and "stale" in line.lower() ] self.assertEqual(warning_lines, []) @@ -307,7 +296,7 @@ def test_initial_status_has_no_price(self): self.assertEqual(status["total_fetches"], 0) self.assertEqual(status["total_errors"], 0) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_status_reflects_successful_fetch(self, mock_fetch): mock_fetch.return_value = SAMPLE_PRICE feed = OPGPriceFeed(retry_delay=0) @@ -318,7 +307,7 @@ def test_status_reflects_successful_fetch(self, mock_fetch): self.assertEqual(status["total_fetches"], 1) self.assertEqual(status["consecutive_failures"], 0) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_status_reflects_failed_cycle(self, mock_fetch): mock_fetch.side_effect = ValueError("fail") feed = OPGPriceFeed(max_retries=1, retry_delay=0) @@ -329,12 +318,11 @@ def test_status_reflects_failed_cycle(self, mock_fetch): self.assertEqual(status["consecutive_failures"], 1) self.assertIsNotNone(status["last_error"]) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") - def test_status_includes_refresh_interval(self, mock_fetch): + def test_status_includes_refresh_interval(self): feed = OPGPriceFeed(refresh_interval=600) self.assertEqual(feed.get_status()["refresh_interval"], 600) - @patch("tee_gateway.opg_price_feed._fetch_from_coingecko") + @patch(f"{_FEED}.fetch_opg_price") def test_status_accumulates_multiple_error_cycles(self, mock_fetch): mock_fetch.side_effect = ValueError("fail") feed = OPGPriceFeed(max_retries=1, retry_delay=0) @@ -355,7 +343,7 @@ class TestModuleLevelFunctions(unittest.TestCase): """Tests for the module-level singleton helpers.""" def test_get_opg_price_usd_raises_when_feed_is_none(self): - with patch("tee_gateway.opg_price_feed._feed", None): + with patch(f"{_FEED}._feed", None): with self.assertRaises(ValueError) as ctx: get_opg_price_usd() self.assertIn("not been started", str(ctx.exception)) @@ -363,13 +351,13 @@ def test_get_opg_price_usd_raises_when_feed_is_none(self): def test_get_opg_price_usd_delegates_to_feed(self): mock_feed = MagicMock() mock_feed.get_price.return_value = SAMPLE_PRICE - with patch("tee_gateway.opg_price_feed._feed", mock_feed): + with patch(f"{_FEED}._feed", mock_feed): price = get_opg_price_usd() self.assertEqual(price, SAMPLE_PRICE) mock_feed.get_price.assert_called_once() def test_get_price_feed_status_when_feed_is_none(self): - with patch("tee_gateway.opg_price_feed._feed", None): + with patch(f"{_FEED}._feed", None): status = get_price_feed_status() self.assertEqual(status, {"status": "not_started"}) @@ -377,7 +365,7 @@ def test_get_price_feed_status_delegates_to_feed(self): expected = {"price_usd": 0.042, "total_fetches": 5} mock_feed = MagicMock() mock_feed.get_status.return_value = expected - with patch("tee_gateway.opg_price_feed._feed", mock_feed): + with patch(f"{_FEED}._feed", mock_feed): status = get_price_feed_status() self.assertEqual(status, expected) mock_feed.get_status.assert_called_once() diff --git a/tee_gateway/test/test_opg_price_feed_integration.py b/tee_gateway/test/test_price_feed_integration.py similarity index 91% rename from tee_gateway/test/test_opg_price_feed_integration.py rename to tee_gateway/test/test_price_feed_integration.py index 2e35bd6..363a311 100644 --- a/tee_gateway/test/test_opg_price_feed_integration.py +++ b/tee_gateway/test/test_price_feed_integration.py @@ -1,5 +1,5 @@ """ -Integration tests for tee_gateway.opg_price_feed. +Integration tests for tee_gateway.price_feed. These tests make REAL network calls to the CoinGecko public API. @@ -13,7 +13,7 @@ Run with:: - uv run pytest tee_gateway/test/test_opg_price_feed_integration.py -v + uv run pytest tee_gateway/test/test_price_feed_integration.py -v """ import unittest @@ -22,12 +22,12 @@ import requests from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS -from tee_gateway.opg_price_feed import ( +from tee_gateway.price_feed.config import ( COINGECKO_BASE_URL, COINGECKO_PLATFORM, FETCH_TIMEOUT, - _fetch_from_coingecko, ) +from tee_gateway.price_feed.feed import fetch_opg_price def _get(url: str, **kwargs) -> requests.Response: @@ -57,7 +57,9 @@ def test_base_platform_endpoint_returns_200(self): url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" # USDC on Base mainnet — reliably indexed on CoinGecko. usdc_base = "0x833589fcd6edb6e08f4c7c32d4f71b54bda02913" - resp = _get(url, params={"contract_addresses": usdc_base, "vs_currencies": "usd"}) + resp = _get( + url, params={"contract_addresses": usdc_base, "vs_currencies": "usd"} + ) self.assertEqual( resp.status_code, 200, @@ -103,9 +105,9 @@ def test_opg_response_structure(self): self.assertIsInstance(price_entry["usd"], (int, float)) def test_opg_price_fetch_live(self): - """End-to-end: _fetch_from_coingecko() returns a positive Decimal price.""" + """End-to-end: fetch_opg_price() returns a positive Decimal price.""" try: - price = _fetch_from_coingecko() + price = fetch_opg_price() except requests.exceptions.HTTPError as exc: if exc.response is not None and exc.response.status_code == 429: self.skipTest("CoinGecko rate limit — re-run after a short wait") diff --git a/tee_gateway/util.py b/tee_gateway/util.py index a242a5c..443c540 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -160,7 +160,7 @@ def _deserialize_dict(data, boxed_type): def get_token_a_price_usd() -> Decimal: - from tee_gateway.opg_price_feed import get_opg_price_usd # noqa: PLC0415 + from tee_gateway.price_feed import get_opg_price_usd # noqa: PLC0415 return get_opg_price_usd() From 69f715026f740bdbfa5ef42f3b8335410a1c859f Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 22:08:13 -0700 Subject: [PATCH 03/17] refactor: dependency injection for OPG price feed via make_cost_calculator Replace module-level singleton in price_feed with a factory/closure pattern. make_cost_calculator(price_feed) in util.py binds an OPGPriceFeed instance explicitly, eliminating hidden global state. Tests updated: TestModuleLevelFunctions removed, TestMakeCostCalculator added (10 cases). Also fix missing Any import in feed.py and unused PriceFeedConfig import. Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 9 +- tee_gateway/price_feed/__init__.py | 10 +- tee_gateway/price_feed/feed.py | 57 +-------- tee_gateway/test/test_price_feed.py | 179 ++++++++++++++++++++++------ tee_gateway/util.py | 108 +++++++++-------- 5 files changed, 210 insertions(+), 153 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index ab84846..4d593a8 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -36,8 +36,8 @@ import x402.http.middleware.flask as x402_flask import types as _types -from .util import dynamic_session_cost_calculator -from .price_feed import start_price_feed +from .util import make_cost_calculator +from .price_feed import OPGPriceFeed from .definitions import ( EVM_PAYMENT_ADDRESS, BASE_MAINNET_NETWORK, @@ -111,7 +111,8 @@ def _shutdown_heartbeat(): # OPG price feed — start before x402 middleware so the first request can be # priced correctly. Runs as a daemon thread; no cleanup needed on exit. # --------------------------------------------------------------------------- -start_price_feed() +_price_feed = OPGPriceFeed() +_price_feed.start() facilitator = HTTPFacilitatorClientSync(FacilitatorConfig(url=FACILITATOR_URL)) server = x402ResourceServerSync(facilitator) @@ -387,7 +388,7 @@ def _patched_read_body_bytes(environ): session_store=store, cost_per_request=100000000000000, # static precheck/fallback estimate session_idle_timeout=100, - session_cost_calculator=dynamic_session_cost_calculator, + session_cost_calculator=make_cost_calculator(_price_feed), ) # --------------------------------------------------------------------------- diff --git a/tee_gateway/price_feed/__init__.py b/tee_gateway/price_feed/__init__.py index 13aa282..1349825 100644 --- a/tee_gateway/price_feed/__init__.py +++ b/tee_gateway/price_feed/__init__.py @@ -1,15 +1,7 @@ from .config import PriceFeedConfig -from .feed import ( - OPGPriceFeed, - get_opg_price_usd, - get_price_feed_status, - start_price_feed, -) +from .feed import OPGPriceFeed __all__ = [ "OPGPriceFeed", "PriceFeedConfig", - "get_opg_price_usd", - "get_price_feed_status", - "start_price_feed", ] diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py index 078df96..1906860 100644 --- a/tee_gateway/price_feed/feed.py +++ b/tee_gateway/price_feed/feed.py @@ -7,13 +7,9 @@ Usage ----- -Call ``start_price_feed()`` once during application startup (e.g. in -``__main__.py``). The dynamic cost calculator in ``util.py`` then calls -``get_opg_price_usd()`` to obtain the latest cached price. If no price has -been fetched yet, ``get_opg_price_usd()`` raises ``ValueError``, which -propagates through ``dynamic_session_cost_calculator`` and the strict -``_resolve_session_request_cost`` monkey-patch in ``__main__.py`` to produce -an HTTP 500 rather than silently charging an incorrect amount. +Create an ``OPGPriceFeed`` instance in the application entry point, call +``start()``, then pass it explicitly to wherever the price is needed (e.g. +``make_cost_calculator`` in ``util.py``). """ import logging @@ -33,7 +29,6 @@ DEFAULT_RETRY_DELAY, FETCH_TIMEOUT, STALE_WARNING_MULTIPLIER, - PriceFeedConfig, ) logger = logging.getLogger("llm_server.price_feed") @@ -228,49 +223,3 @@ def fetch_opg_price() -> Decimal: ) return Decimal(str(price_entry["usd"])) - - -# --------------------------------------------------------------------------- -# Module-level singleton — initialised by start_price_feed() -# --------------------------------------------------------------------------- - -_feed: Optional[OPGPriceFeed] = None - - -def start_price_feed(config: Optional[PriceFeedConfig] = None) -> None: - """Create and start the global OPG price feed. Call once at app startup.""" - global _feed - if _feed is not None: - logger.info("OPG price feed already running, skipping") - return - cfg = config or PriceFeedConfig() - _feed = OPGPriceFeed( - refresh_interval=cfg.refresh_interval, - max_retries=cfg.max_retries, - retry_delay=cfg.retry_delay, - ) - _feed.start() - - -def get_opg_price_usd() -> Decimal: - """Return the current OPG/USD price from the running price feed. - - Raises ``ValueError`` if the feed has not been started or has not yet - completed a successful fetch. - """ - if _feed is None: - raise ValueError( - "OPG price feed has not been started — " - "call start_price_feed() at app startup" - ) - return _feed.get_price() - - -def get_price_feed_status() -> dict[str, Any]: - """Return the current health snapshot of the price feed. - - Returns ``{"status": "not_started"}`` if the feed has never been started. - """ - if _feed is None: - return {"status": "not_started"} - return _feed.get_status() diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py index 92c435e..f43555c 100644 --- a/tee_gateway/test/test_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -1,5 +1,5 @@ """ -Unit tests for tee_gateway.price_feed. +Unit tests for tee_gateway.price_feed and tee_gateway.util.make_cost_calculator. All external HTTP calls are mocked — no network access required. @@ -9,7 +9,7 @@ TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() (retry, rate-limit, stats) TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots -TestModuleLevelFunctions — start_price_feed() / get_opg_price_usd() / get_price_feed_status() +TestMakeCostCalculator — make_cost_calculator() factory and the returned closure """ import time @@ -20,12 +20,9 @@ import requests from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS -from tee_gateway.price_feed import ( - OPGPriceFeed, - get_opg_price_usd, - get_price_feed_status, -) +from tee_gateway.price_feed import OPGPriceFeed from tee_gateway.price_feed.feed import fetch_opg_price +from tee_gateway.util import make_cost_calculator # --------------------------------------------------------------------------- # Helpers @@ -335,40 +332,154 @@ def test_status_accumulates_multiple_error_cycles(self, mock_fetch): # --------------------------------------------------------------------------- -# TestModuleLevelFunctions +# TestMakeCostCalculator # --------------------------------------------------------------------------- - -class TestModuleLevelFunctions(unittest.TestCase): - """Tests for the module-level singleton helpers.""" - - def test_get_opg_price_usd_raises_when_feed_is_none(self): - with patch(f"{_FEED}._feed", None): - with self.assertRaises(ValueError) as ctx: - get_opg_price_usd() - self.assertIn("not been started", str(ctx.exception)) - - def test_get_opg_price_usd_delegates_to_feed(self): +_ASSET_ADDR = "0xdeadbeef" +_ASSET_ADDR_LOWER = _ASSET_ADDR.lower() +_ASSET_DECIMALS = 18 + + +def _make_payment_requirements(asset: str = _ASSET_ADDR) -> dict: + return {"asset": asset, "price": {"amount": "1000000000000000000", "asset": asset}} + + +def _make_context( + model: str = "gpt-4.1-mini", + input_tokens: int = 100, + output_tokens: int = 50, + price_usd: Decimal = Decimal("0.10"), + asset: str = _ASSET_ADDR, +) -> dict: + return { + "request_json": {"model": model}, + "response_json": { + "model": model, + "usage": { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + }, + }, + "payment_requirements": _make_payment_requirements(asset), + "method": "POST", + "path": "/v1/chat/completions", + "status_code": 200, + "is_streaming": False, + "request_body_bytes": b"", + "response_body_bytes": b"", + "default_cost": 10**18, + } + + +class TestMakeCostCalculator(unittest.TestCase): + """Tests for make_cost_calculator() and the returned closure.""" + + def _calculator(self, price_usd: Decimal = Decimal("0.10")): mock_feed = MagicMock() - mock_feed.get_price.return_value = SAMPLE_PRICE - with patch(f"{_FEED}._feed", mock_feed): - price = get_opg_price_usd() - self.assertEqual(price, SAMPLE_PRICE) + mock_feed.get_price.return_value = price_usd + return make_cost_calculator(mock_feed), mock_feed + + def _patch_definitions(self): + """Patch ASSET_DECIMALS_BY_ADDRESS so the test asset is recognised.""" + return patch( + "tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", + {_ASSET_ADDR_LOWER: _ASSET_DECIMALS}, + ) + + def _patch_model( + self, input_price: str = "0.000001", output_price: str = "0.000002" + ): + """Patch get_model_config to return a predictable pricing config.""" + cfg = MagicMock() + cfg.input_price_usd = Decimal(input_price) + cfg.output_price_usd = Decimal(output_price) + return patch("tee_gateway.util.get_model_config", return_value=cfg) + + def test_calls_price_feed_get_price(self): + calc, mock_feed = self._calculator() + with self._patch_definitions(), self._patch_model(): + calc(_make_context()) mock_feed.get_price.assert_called_once() - def test_get_price_feed_status_when_feed_is_none(self): - with patch(f"{_FEED}._feed", None): - status = get_price_feed_status() - self.assertEqual(status, {"status": "not_started"}) + def test_returns_positive_int(self): + calc, _ = self._calculator() + with self._patch_definitions(), self._patch_model(): + result = calc(_make_context()) + self.assertIsInstance(result, int) + self.assertGreaterEqual(result, 0) + + def test_zero_tokens_returns_zero(self): + calc, _ = self._calculator() + with self._patch_definitions(), self._patch_model(): + result = calc(_make_context(input_tokens=0, output_tokens=0)) + self.assertEqual(result, 0) - def test_get_price_feed_status_delegates_to_feed(self): - expected = {"price_usd": 0.042, "total_fetches": 5} + def test_raises_when_price_feed_raises(self): mock_feed = MagicMock() - mock_feed.get_status.return_value = expected - with patch(f"{_FEED}._feed", mock_feed): - status = get_price_feed_status() - self.assertEqual(status, expected) - mock_feed.get_status.assert_called_once() + mock_feed.get_price.side_effect = ValueError("price not available") + calc = make_cost_calculator(mock_feed) + with self._patch_definitions(), self._patch_model(): + with self.assertRaises( + ValueError, msg="error from get_price must propagate" + ): + calc(_make_context()) + + def test_raises_when_non_positive_price(self): + calc, _ = self._calculator(price_usd=Decimal("0")) + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calc(_make_context()) + + def test_raises_when_request_json_missing(self): + calc, _ = self._calculator() + ctx = _make_context() + ctx["request_json"] = None + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calc(ctx) + + def test_raises_when_usage_missing(self): + calc, _ = self._calculator() + ctx = _make_context() + ctx["response_json"] = {"model": "gpt-4.1-mini"} # no usage key + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calc(ctx) + + def test_raises_when_asset_unknown(self): + calc, _ = self._calculator() + ctx = _make_context(asset="0xunknown") + with ( + patch("tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", {}), + self._patch_model(), + ): + with self.assertRaises(ValueError): + calc(ctx) + + def test_cost_scales_with_token_count(self): + calc, _ = self._calculator() + with self._patch_definitions(), self._patch_model(): + cost_small = calc(_make_context(input_tokens=10, output_tokens=5)) + cost_large = calc(_make_context(input_tokens=1000, output_tokens=500)) + self.assertGreater(cost_large, cost_small) + + def test_each_call_uses_independent_closure(self): + feed_a = MagicMock() + feed_a.get_price.return_value = Decimal("0.10") + feed_b = MagicMock() + feed_b.get_price.return_value = Decimal("0.20") + + calc_a = make_cost_calculator(feed_a) + calc_b = make_cost_calculator(feed_b) + + with self._patch_definitions(), self._patch_model(): + cost_a = calc_a(_make_context()) + cost_b = calc_b(_make_context()) + + # Higher token price → lower cost in smallest units for same USD amount. + self.assertGreater(cost_a, cost_b) + feed_a.get_price.assert_called_once() + feed_b.get_price.assert_called_once() if __name__ == "__main__": diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 443c540..a72ba03 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -3,7 +3,7 @@ from tee_gateway import typing_utils import logging from decimal import Decimal, InvalidOperation, ROUND_CEILING -from typing import Any +from typing import Any, Callable, Protocol logger = logging.getLogger("llm_server.dynamic_pricing") @@ -159,12 +159,6 @@ def _deserialize_dict(data, boxed_type): from tee_gateway.model_registry import get_model_config # noqa: E402 -def get_token_a_price_usd() -> Decimal: - from tee_gateway.price_feed import get_opg_price_usd # noqa: PLC0415 - - return get_opg_price_usd() - - def _as_dict(value: Any) -> dict[str, Any] | None: if value is None: return None @@ -271,53 +265,63 @@ def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: return ASSET_DECIMALS_BY_ADDRESS[asset_lower] -def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: - """Compute UPTO per-request cost in token smallest units from actual usage. +class _PriceSource(Protocol): + """Structural type for anything that can provide a current token price.""" + + def get_price(self) -> Decimal: ... - Raises ValueError on any missing or unrecognised input — no silent fallback. + +def make_cost_calculator( + price_feed: _PriceSource, +) -> Callable[[dict[str, Any]], int]: + """Return a session cost calculator bound to the given price feed. + + The returned callable is passed directly to the x402 payment middleware as + ``session_cost_calculator``. Raising ``ValueError`` from the inner + function propagates through the strict monkey-patch in ``__main__.py`` and + produces an HTTP 500 rather than silently charging an incorrect amount. """ - request_json = context.get("request_json") - response_json = context.get("response_json") - if not isinstance(request_json, dict) or not isinstance(response_json, dict): - raise ValueError( - "dynamic_session_cost_calculator requires both request_json and response_json" + def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: + request_json = context.get("request_json") + response_json = context.get("response_json") + + if not isinstance(request_json, dict) or not isinstance(response_json, dict): + raise ValueError( + "dynamic_session_cost_calculator requires both request_json and response_json" + ) + + model = _extract_model_from_context(request_json, response_json) + cfg = get_model_config(model) + input_tokens, output_tokens = _extract_usage_tokens(response_json) + + total_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( + Decimal(output_tokens) * cfg.output_price_usd + ) + token_price_usd = price_feed.get_price() + if token_price_usd <= 0: + raise ValueError(f"Token price is non-positive: {token_price_usd}") + + token_amount = total_usd / token_price_usd + decimals = _extract_asset_decimals_from_requirements( + context.get("payment_requirements") + ) + scale = Decimal(10) ** decimals + cost_smallest_units = int( + (token_amount * scale).to_integral_value(rounding=ROUND_CEILING) + ) + + logger.info( + "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " + "total_usd=%s token_price_usd=%s decimals=%d cost=%d", + model, + input_tokens, + output_tokens, + str(total_usd), + str(token_price_usd), + decimals, + cost_smallest_units, ) + return max(0, cost_smallest_units) - model = _extract_model_from_context(request_json, response_json) - - # get_model_config raises ValueError for unknown models — no fallback - cfg = get_model_config(model) - - input_tokens, output_tokens = _extract_usage_tokens(response_json) - - input_rate = cfg.input_price_usd - output_rate = cfg.output_price_usd - - total_usd = (Decimal(input_tokens) * input_rate) + ( - Decimal(output_tokens) * output_rate - ) - token_price_usd = get_token_a_price_usd() - if token_price_usd <= 0: - raise ValueError(f"Token A price is non-positive: {token_price_usd}") - - token_amount = total_usd / token_price_usd - decimals = _extract_asset_decimals_from_requirements( - context.get("payment_requirements") - ) - scale = Decimal(10) ** decimals - cost_smallest_units = int( - (token_amount * scale).to_integral_value(rounding=ROUND_CEILING) - ) - - logger.info( - "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d total_usd=%s token_price_usd=%s decimals=%d cost=%d", - model, - input_tokens, - output_tokens, - str(total_usd), - str(token_price_usd), - decimals, - cost_smallest_units, - ) - return max(0, cost_smallest_units) + return dynamic_session_cost_calculator From 7d520805090fc8bdec917fdd3303abca375bac06 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 22:41:16 -0700 Subject: [PATCH 04/17] refactor: replace make_cost_calculator factory with calculate_session_cost + named function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the factory/closure pattern and _PriceSource Protocol. util.py now exports calculate_session_cost(context, get_price) — a plain function that accepts a Callable[[], Decimal]. __main__.py wires it via a named _session_cost_calculator function that passes _price_feed.get_price directly. Tests updated to match the new signature. Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 9 ++- tee_gateway/test/test_price_feed.py | 93 +++++++++++-------------- tee_gateway/util.py | 103 +++++++++++++--------------- 3 files changed, 93 insertions(+), 112 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 4d593a8..771efb6 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -36,7 +36,7 @@ import x402.http.middleware.flask as x402_flask import types as _types -from .util import make_cost_calculator +from .util import calculate_session_cost from .price_feed import OPGPriceFeed from .definitions import ( EVM_PAYMENT_ADDRESS, @@ -381,6 +381,11 @@ def _patched_read_body_bytes(environ): x402_flask._read_body_bytes = _patched_read_body_bytes + +def _session_cost_calculator(ctx: dict) -> int: + return calculate_session_cost(ctx, _price_feed.get_price) + + _payment_mw = payment_middleware( application, routes=routes, @@ -388,7 +393,7 @@ def _patched_read_body_bytes(environ): session_store=store, cost_per_request=100000000000000, # static precheck/fallback estimate session_idle_timeout=100, - session_cost_calculator=make_cost_calculator(_price_feed), + session_cost_calculator=_session_cost_calculator, ) # --------------------------------------------------------------------------- diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py index f43555c..962abd6 100644 --- a/tee_gateway/test/test_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -9,7 +9,7 @@ TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() (retry, rate-limit, stats) TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots -TestMakeCostCalculator — make_cost_calculator() factory and the returned closure +TestCalculateSessionCost — calculate_session_cost(context, get_price) in util.py """ import time @@ -22,7 +22,7 @@ from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS from tee_gateway.price_feed import OPGPriceFeed from tee_gateway.price_feed.feed import fetch_opg_price -from tee_gateway.util import make_cost_calculator +from tee_gateway.util import calculate_session_cost # --------------------------------------------------------------------------- # Helpers @@ -371,16 +371,15 @@ def _make_context( } -class TestMakeCostCalculator(unittest.TestCase): - """Tests for make_cost_calculator() and the returned closure.""" +def _make_get_price(price_usd: Decimal = Decimal("0.10")) -> MagicMock: + mock = MagicMock(return_value=price_usd) + return mock + - def _calculator(self, price_usd: Decimal = Decimal("0.10")): - mock_feed = MagicMock() - mock_feed.get_price.return_value = price_usd - return make_cost_calculator(mock_feed), mock_feed +class TestCalculateSessionCost(unittest.TestCase): + """Tests for calculate_session_cost(context, get_price).""" def _patch_definitions(self): - """Patch ASSET_DECIMALS_BY_ADDRESS so the test asset is recognised.""" return patch( "tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", {_ASSET_ADDR_LOWER: _ASSET_DECIMALS}, @@ -389,97 +388,83 @@ def _patch_definitions(self): def _patch_model( self, input_price: str = "0.000001", output_price: str = "0.000002" ): - """Patch get_model_config to return a predictable pricing config.""" cfg = MagicMock() cfg.input_price_usd = Decimal(input_price) cfg.output_price_usd = Decimal(output_price) return patch("tee_gateway.util.get_model_config", return_value=cfg) - def test_calls_price_feed_get_price(self): - calc, mock_feed = self._calculator() + def test_calls_get_price(self): + get_price = _make_get_price() with self._patch_definitions(), self._patch_model(): - calc(_make_context()) - mock_feed.get_price.assert_called_once() + calculate_session_cost(_make_context(), get_price) + get_price.assert_called_once() def test_returns_positive_int(self): - calc, _ = self._calculator() with self._patch_definitions(), self._patch_model(): - result = calc(_make_context()) + result = calculate_session_cost(_make_context(), _make_get_price()) self.assertIsInstance(result, int) self.assertGreaterEqual(result, 0) def test_zero_tokens_returns_zero(self): - calc, _ = self._calculator() with self._patch_definitions(), self._patch_model(): - result = calc(_make_context(input_tokens=0, output_tokens=0)) + result = calculate_session_cost( + _make_context(input_tokens=0, output_tokens=0), _make_get_price() + ) self.assertEqual(result, 0) - def test_raises_when_price_feed_raises(self): - mock_feed = MagicMock() - mock_feed.get_price.side_effect = ValueError("price not available") - calc = make_cost_calculator(mock_feed) + def test_raises_when_get_price_raises(self): + get_price = MagicMock(side_effect=ValueError("price not available")) with self._patch_definitions(), self._patch_model(): - with self.assertRaises( - ValueError, msg="error from get_price must propagate" - ): - calc(_make_context()) + with self.assertRaises(ValueError): + calculate_session_cost(_make_context(), get_price) def test_raises_when_non_positive_price(self): - calc, _ = self._calculator(price_usd=Decimal("0")) with self._patch_definitions(), self._patch_model(): with self.assertRaises(ValueError): - calc(_make_context()) + calculate_session_cost(_make_context(), _make_get_price(Decimal("0"))) def test_raises_when_request_json_missing(self): - calc, _ = self._calculator() ctx = _make_context() ctx["request_json"] = None with self._patch_definitions(), self._patch_model(): with self.assertRaises(ValueError): - calc(ctx) + calculate_session_cost(ctx, _make_get_price()) def test_raises_when_usage_missing(self): - calc, _ = self._calculator() ctx = _make_context() - ctx["response_json"] = {"model": "gpt-4.1-mini"} # no usage key + ctx["response_json"] = {"model": "gpt-4.1-mini"} with self._patch_definitions(), self._patch_model(): with self.assertRaises(ValueError): - calc(ctx) + calculate_session_cost(ctx, _make_get_price()) def test_raises_when_asset_unknown(self): - calc, _ = self._calculator() ctx = _make_context(asset="0xunknown") with ( patch("tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", {}), self._patch_model(), ): with self.assertRaises(ValueError): - calc(ctx) + calculate_session_cost(ctx, _make_get_price()) def test_cost_scales_with_token_count(self): - calc, _ = self._calculator() with self._patch_definitions(), self._patch_model(): - cost_small = calc(_make_context(input_tokens=10, output_tokens=5)) - cost_large = calc(_make_context(input_tokens=1000, output_tokens=500)) + cost_small = calculate_session_cost( + _make_context(input_tokens=10, output_tokens=5), _make_get_price() + ) + cost_large = calculate_session_cost( + _make_context(input_tokens=1000, output_tokens=500), _make_get_price() + ) self.assertGreater(cost_large, cost_small) - def test_each_call_uses_independent_closure(self): - feed_a = MagicMock() - feed_a.get_price.return_value = Decimal("0.10") - feed_b = MagicMock() - feed_b.get_price.return_value = Decimal("0.20") - - calc_a = make_cost_calculator(feed_a) - calc_b = make_cost_calculator(feed_b) - + def test_higher_token_price_yields_lower_cost(self): with self._patch_definitions(), self._patch_model(): - cost_a = calc_a(_make_context()) - cost_b = calc_b(_make_context()) - - # Higher token price → lower cost in smallest units for same USD amount. - self.assertGreater(cost_a, cost_b) - feed_a.get_price.assert_called_once() - feed_b.get_price.assert_called_once() + cost_cheap = calculate_session_cost( + _make_context(), _make_get_price(Decimal("0.10")) + ) + cost_expensive = calculate_session_cost( + _make_context(), _make_get_price(Decimal("0.20")) + ) + self.assertGreater(cost_cheap, cost_expensive) if __name__ == "__main__": diff --git a/tee_gateway/util.py b/tee_gateway/util.py index a72ba03..5bcb245 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -3,7 +3,7 @@ from tee_gateway import typing_utils import logging from decimal import Decimal, InvalidOperation, ROUND_CEILING -from typing import Any, Callable, Protocol +from typing import Any, Callable logger = logging.getLogger("llm_server.dynamic_pricing") @@ -265,63 +265,54 @@ def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: return ASSET_DECIMALS_BY_ADDRESS[asset_lower] -class _PriceSource(Protocol): - """Structural type for anything that can provide a current token price.""" +def calculate_session_cost( + context: dict[str, Any], get_price: Callable[[], Decimal] +) -> int: + """Calculate the x402 session cost in token smallest units for a completed request. - def get_price(self) -> Decimal: ... - - -def make_cost_calculator( - price_feed: _PriceSource, -) -> Callable[[dict[str, Any]], int]: - """Return a session cost calculator bound to the given price feed. - - The returned callable is passed directly to the x402 payment middleware as - ``session_cost_calculator``. Raising ``ValueError`` from the inner - function propagates through the strict monkey-patch in ``__main__.py`` and - produces an HTTP 500 rather than silently charging an incorrect amount. + ``get_price`` is called on every invocation to fetch the current OPG/USD + price — pass ``price_feed.get_price`` so the latest cached value is used. + Raises ``ValueError`` on any missing/invalid data; the strict monkey-patch + in ``__main__.py`` propagates this as HTTP 500 rather than silently + charging the static fallback amount. """ + request_json = context.get("request_json") + response_json = context.get("response_json") - def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: - request_json = context.get("request_json") - response_json = context.get("response_json") - - if not isinstance(request_json, dict) or not isinstance(response_json, dict): - raise ValueError( - "dynamic_session_cost_calculator requires both request_json and response_json" - ) - - model = _extract_model_from_context(request_json, response_json) - cfg = get_model_config(model) - input_tokens, output_tokens = _extract_usage_tokens(response_json) - - total_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( - Decimal(output_tokens) * cfg.output_price_usd - ) - token_price_usd = price_feed.get_price() - if token_price_usd <= 0: - raise ValueError(f"Token price is non-positive: {token_price_usd}") - - token_amount = total_usd / token_price_usd - decimals = _extract_asset_decimals_from_requirements( - context.get("payment_requirements") - ) - scale = Decimal(10) ** decimals - cost_smallest_units = int( - (token_amount * scale).to_integral_value(rounding=ROUND_CEILING) - ) - - logger.info( - "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " - "total_usd=%s token_price_usd=%s decimals=%d cost=%d", - model, - input_tokens, - output_tokens, - str(total_usd), - str(token_price_usd), - decimals, - cost_smallest_units, + if not isinstance(request_json, dict) or not isinstance(response_json, dict): + raise ValueError( + "calculate_session_cost requires both request_json and response_json" ) - return max(0, cost_smallest_units) - return dynamic_session_cost_calculator + model = _extract_model_from_context(request_json, response_json) + cfg = get_model_config(model) + input_tokens, output_tokens = _extract_usage_tokens(response_json) + + total_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( + Decimal(output_tokens) * cfg.output_price_usd + ) + token_price_usd = get_price() + if token_price_usd <= 0: + raise ValueError(f"Token price is non-positive: {token_price_usd}") + + token_amount = total_usd / token_price_usd + decimals = _extract_asset_decimals_from_requirements( + context.get("payment_requirements") + ) + scale = Decimal(10) ** decimals + cost_smallest_units = int( + (token_amount * scale).to_integral_value(rounding=ROUND_CEILING) + ) + + logger.info( + "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " + "total_usd=%s token_price_usd=%s decimals=%d cost=%d", + model, + input_tokens, + output_tokens, + str(total_usd), + str(token_price_usd), + decimals, + cost_smallest_units, + ) + return max(0, cost_smallest_units) From c0d8373895ca4b5080e48f321aa5d431bc928211 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 22:43:05 -0700 Subject: [PATCH 05/17] feat: include price feed status in /health response Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 771efb6..7336a21 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -310,6 +310,7 @@ def health(): "status": "OK", "version": "1.0.0", "tee_enabled": True, + "price_feed": _price_feed.get_status(), }, 200 From f1b6526a9a08a566aab0d2431229ff419b161c3d Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 22:50:02 -0700 Subject: [PATCH 06/17] test: verify calculate_session_cost fetches live price on every call Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/test/test_price_feed.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py index 962abd6..0b08a78 100644 --- a/tee_gateway/test/test_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -466,6 +466,16 @@ def test_higher_token_price_yields_lower_cost(self): ) self.assertGreater(cost_cheap, cost_expensive) + def test_uses_current_price_on_each_call(self): + """get_price is called fresh every invocation — price changes are picked up.""" + get_price = MagicMock(side_effect=[Decimal("0.10"), Decimal("0.20")]) + with self._patch_definitions(), self._patch_model(): + cost_first = calculate_session_cost(_make_context(), get_price) + cost_second = calculate_session_cost(_make_context(), get_price) + self.assertEqual(get_price.call_count, 2) + # Price doubled → cost should halve (same USD spend, twice the token price). + self.assertGreater(cost_first, cost_second) + if __name__ == "__main__": unittest.main() From 092f14d8ee7b4e1e68df961bc0cc2921039fa2b3 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 23:04:03 -0700 Subject: [PATCH 07/17] fix: update test_pricing.py for calculate_session_cost signature change Replace dynamic_session_cost_calculator import with calculate_session_cost and pass _get_price (OPG=$1.00) to all call sites. Co-Authored-By: Claude Sonnet 4.6 --- tests/test_pricing.py | 46 ++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/tests/test_pricing.py b/tests/test_pricing.py index d1b5f25..37507eb 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -3,7 +3,7 @@ Tests verify that: - Every user-facing model name resolves to the correct ModelConfig - - dynamic_session_cost_calculator produces the right amount in OPG token + - calculate_session_cost produces the right amount in OPG token smallest-units for supported models - Edge cases (no usage, unknown model, bad context) are handled correctly """ @@ -16,7 +16,11 @@ _MODEL_LOOKUP, get_model_config, ) -from tee_gateway.util import dynamic_session_cost_calculator +from tee_gateway.util import calculate_session_cost + +# All pricing tests assume OPG = $1.00 so USD cost == OPG token amount. +_OPG_PRICE_USD = Decimal("1") +_get_price = lambda: _OPG_PRICE_USD # noqa: E731 # --------------------------------------------------------------------------- @@ -209,8 +213,8 @@ class TestDynamicSessionCostCalculatorOPG(unittest.TestCase): """dynamic_session_cost_calculator with OPG (18 decimals).""" def _calc(self, model, input_tokens, output_tokens): - return dynamic_session_cost_calculator( - _ctx(model, input_tokens, output_tokens, _opg_requirements()) + return calculate_session_cost( + _ctx(model, input_tokens, output_tokens, _opg_requirements()), _get_price ) # ── OpenAI ────────────────────────────────────────────────────────────── @@ -355,7 +359,7 @@ class TestDynamicSessionCostCalculatorEdgeCases(unittest.TestCase): """Edge cases for dynamic_session_cost_calculator.""" def test_zero_tokens_returns_zero(self): - cost = dynamic_session_cost_calculator(_ctx("claude-sonnet-4-5", 0, 0)) + cost = calculate_session_cost(_ctx("claude-sonnet-4-5", 0, 0), _get_price) self.assertEqual(cost, 0) def test_missing_usage_raises(self): @@ -365,24 +369,24 @@ def test_missing_usage_raises(self): "payment_requirements": _opg_requirements(), } with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_unknown_asset_raises(self): ctx = _ctx("claude-sonnet-4-5", 100, 100) ctx["payment_requirements"] = {"asset": "0xdeadbeef", "amount": "1000"} with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_missing_asset_raises(self): ctx = _ctx("claude-sonnet-4-5", 100, 100) ctx["payment_requirements"] = {"amount": "1000"} # no asset with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_unknown_model_raises_value_error(self): ctx = _ctx("gpt-4o", 100, 100) with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_missing_request_json_raises_value_error(self): ctx = { @@ -394,7 +398,7 @@ def test_missing_request_json_raises_value_error(self): "payment_requirements": _opg_requirements(), } with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_model_from_request_takes_priority(self): """request_json model name is used even if response_json has a different model.""" @@ -406,7 +410,7 @@ def test_model_from_request_takes_priority(self): }, "payment_requirements": _opg_requirements(), } - cost = dynamic_session_cost_calculator(ctx) + cost = calculate_session_cost(ctx, _get_price) # Should be priced as Haiku (from request), not Sonnet haiku_cost = _expected_cost_opg("claude-haiku-4-5", 1000, 500) self.assertEqual(cost, haiku_cost) @@ -414,29 +418,31 @@ def test_model_from_request_takes_priority(self): def test_rounding_ceiling(self): """Fractional token costs are always rounded UP.""" # 1 output token of Haiku: 0.000005 USD = 5e12 wei — exact, no rounding needed - cost = dynamic_session_cost_calculator(_ctx("claude-haiku-4-5", 0, 1)) + cost = calculate_session_cost(_ctx("claude-haiku-4-5", 0, 1), _get_price) self.assertEqual(cost, 5_000_000_000_000) # 1 input token of Gemini Flash Lite: 0.0000001 USD = 1e11 wei — exact - cost = dynamic_session_cost_calculator(_ctx("gemini-2.5-flash-lite", 1, 0)) + cost = calculate_session_cost(_ctx("gemini-2.5-flash-lite", 1, 0), _get_price) self.assertEqual(cost, 100_000_000_000) def test_model_name_case_insensitive(self): """Model names are normalized to lowercase before lookup.""" - cost_lower = dynamic_session_cost_calculator( - _ctx("claude-sonnet-4-5", 100, 100) + cost_lower = calculate_session_cost( + _ctx("claude-sonnet-4-5", 100, 100), _get_price ) - cost_upper = dynamic_session_cost_calculator( - _ctx("CLAUDE-SONNET-4-5", 100, 100) + cost_upper = calculate_session_cost( + _ctx("CLAUDE-SONNET-4-5", 100, 100), _get_price ) self.assertEqual(cost_lower, cost_upper) def test_sonnet_4_0_hyphen_vs_dot_same_cost(self): """claude-sonnet-4-0 and claude-4.0-sonnet are the same model.""" - cost_hyphen = dynamic_session_cost_calculator( - _ctx("claude-sonnet-4-0", 1000, 500) + cost_hyphen = calculate_session_cost( + _ctx("claude-sonnet-4-0", 1000, 500), _get_price + ) + cost_dot = calculate_session_cost( + _ctx("claude-4.0-sonnet", 1000, 500), _get_price ) - cost_dot = dynamic_session_cost_calculator(_ctx("claude-4.0-sonnet", 1000, 500)) self.assertEqual(cost_hyphen, cost_dot) From 996a950d6d831cdbb3043417d0384bcec915f44a Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 23:11:11 -0700 Subject: [PATCH 08/17] fix: add attr-defined to type: ignore on monkey-patch line Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 7336a21..916de83 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -487,7 +487,7 @@ def _strict_resolve_session_request_cost( return self._coerce_non_negative_int(dynamic_cost) -_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign] +_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign,attr-defined] _strict_resolve_session_request_cost, _payment_mw ) From 3e1b81cd06db4dd9b9979c43512ca812a9f6796e Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 23:25:38 -0700 Subject: [PATCH 09/17] fix: address Copilot review comments on price feed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - start() is now non-blocking: initial fetch runs inside the background thread instead of blocking the caller - start() is idempotent: duplicate calls are a no-op if thread is alive - Clear last_error on successful refresh so /health reflects recovery - Gate integration tests behind RUN_INTEGRATION_TESTS env var to prevent real CoinGecko calls in CI by default - Fix stale docstring references to make_cost_calculator → calculate_session_cost Co-Authored-By: Claude Sonnet 4.6 --- .github/workflows/test.yml | 3 +++ tee_gateway/price_feed/feed.py | 23 ++++++++++++------- tee_gateway/test/test_price_feed.py | 2 +- .../test/test_price_feed_integration.py | 4 ++++ 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 835c886..9c257b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,3 +16,6 @@ jobs: - uses: astral-sh/setup-uv@v5 - name: Run unit tests run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tests/test_pricing.py -v --import-mode=importlib + # To also run integration tests (real CoinGecko network calls), add: + # env: + # RUN_INTEGRATION_TESTS: "1" diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py index 1906860..904fe4e 100644 --- a/tee_gateway/price_feed/feed.py +++ b/tee_gateway/price_feed/feed.py @@ -9,7 +9,7 @@ ----- Create an ``OPGPriceFeed`` instance in the application entry point, call ``start()``, then pass it explicitly to wherever the price is needed (e.g. -``make_cost_calculator`` in ``util.py``). +``calculate_session_cost(...)`` in ``util.py``). """ import logging @@ -63,15 +63,20 @@ def __init__( # ------------------------------------------------------------------ def start(self) -> None: - """Perform an initial price fetch then launch the background refresh loop. + """Launch the background refresh loop, including the initial price fetch. - If the initial fetch fails after all retries the feed still starts — - ``get_price()`` will raise ``ValueError`` until the background loop - eventually succeeds. + The initial fetch runs inside the background thread so startup is + non-blocking. ``get_price()`` will raise ``ValueError`` until the + first fetch completes; any error propagates as HTTP 500 via the + strict cost-resolution patch in ``__main__.py``. + + Idempotent — calling ``start()`` on an already-running feed is a no-op. """ - self._refresh_price() + if self._thread is not None and self._thread.is_alive(): + logger.info("OPG price feed already running, ignoring duplicate start()") + return self._thread = threading.Thread( - target=self._run, name="opg-price-feed", daemon=True + target=self._run_with_initial_fetch, name="opg-price-feed", daemon=True ) self._thread.start() logger.info( @@ -126,7 +131,8 @@ def get_status(self) -> dict[str, Any]: # Internal helpers # ------------------------------------------------------------------ - def _run(self) -> None: + def _run_with_initial_fetch(self) -> None: + self._refresh_price() while True: time.sleep(self._refresh_interval) self._refresh_price() @@ -149,6 +155,7 @@ def _refresh_price(self) -> None: with self._lock: self._price = price self.last_success = time.time() + self.last_error = None self.consecutive_failures = 0 self.total_fetches += 1 logger.info( diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py index 0b08a78..04c8c02 100644 --- a/tee_gateway/test/test_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -1,5 +1,5 @@ """ -Unit tests for tee_gateway.price_feed and tee_gateway.util.make_cost_calculator. +Unit tests for tee_gateway.price_feed and tee_gateway.util.calculate_session_cost. All external HTTP calls are mocked — no network access required. diff --git a/tee_gateway/test/test_price_feed_integration.py b/tee_gateway/test/test_price_feed_integration.py index 363a311..2da9db1 100644 --- a/tee_gateway/test/test_price_feed_integration.py +++ b/tee_gateway/test/test_price_feed_integration.py @@ -16,11 +16,15 @@ uv run pytest tee_gateway/test/test_price_feed_integration.py -v """ +import os import unittest from decimal import Decimal import requests +if not os.getenv("RUN_INTEGRATION_TESTS"): + raise unittest.SkipTest("Set RUN_INTEGRATION_TESTS=1 to run integration tests") + from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS from tee_gateway.price_feed.config import ( COINGECKO_BASE_URL, From 340476c836fe6d1a48ea9e327b7b8cdeee473b8e Mon Sep 17 00:00:00 2001 From: Kyle Qian Date: Sun, 19 Apr 2026 23:32:05 -0700 Subject: [PATCH 10/17] Update .github/workflows/test.yml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9c257b8..273d4c9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Run unit tests - run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tests/test_pricing.py -v --import-mode=importlib + run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tee_gateway/test/test_price_feed.py tests/test_pricing.py -v --import-mode=importlib # To also run integration tests (real CoinGecko network calls), add: # env: # RUN_INTEGRATION_TESTS: "1" From be53f7649fbb358f2a85857a80e07fa83034a19b Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 23:34:04 -0700 Subject: [PATCH 11/17] fix: update stale dynamic_session_cost_calculator references Rename test classes and update comment/error message in __main__.py to match the current calculate_session_cost / _session_cost_calculator names. Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 4 ++-- tests/test_pricing.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 916de83..b48682d 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -415,7 +415,7 @@ def _session_cost_calculator(ctx: dict) -> int: # ------- # We replace _resolve_session_request_cost with our own implementation that is # identical to upstream, except the cost-calculator call is NOT wrapped in a -# try/except. Any exception from dynamic_session_cost_calculator() therefore +# try/except. Any exception from _session_cost_calculator() therefore # propagates up through the middleware and Flask, producing a proper HTTP 500 # response to the client instead of an incorrect silent charge. # --------------------------------------------------------------------------- @@ -481,7 +481,7 @@ def _strict_resolve_session_request_cost( dynamic_cost = self._session_cost_calculator(callback_context) if dynamic_cost is None: raise ValueError( - f"dynamic_session_cost_calculator returned None for {method} {path}; " + f"_session_cost_calculator returned None for {method} {path}; " "cannot determine request cost" ) return self._coerce_non_negative_int(dynamic_cost) diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 37507eb..5419782 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -209,8 +209,8 @@ def test_unknown_sonnet_variant_raises(self): # --------------------------------------------------------------------------- -class TestDynamicSessionCostCalculatorOPG(unittest.TestCase): - """dynamic_session_cost_calculator with OPG (18 decimals).""" +class TestCalculateSessionCostOPG(unittest.TestCase): + """calculate_session_cost with OPG (18 decimals).""" def _calc(self, model, input_tokens, output_tokens): return calculate_session_cost( @@ -355,8 +355,8 @@ def test_grok_4_fast_cheaper_than_grok_4(self): self.assertLess(fast, full) -class TestDynamicSessionCostCalculatorEdgeCases(unittest.TestCase): - """Edge cases for dynamic_session_cost_calculator.""" +class TestCalculateSessionCostEdgeCases(unittest.TestCase): + """Edge cases for calculate_session_cost.""" def test_zero_tokens_returns_zero(self): cost = calculate_session_cost(_ctx("claude-sonnet-4-5", 0, 0), _get_price) From 5891728e69a9378288902630756373a89ab09e43 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 19 Apr 2026 23:58:17 -0700 Subject: [PATCH 12/17] feat: add pre-inference pricing gate and remove ineffective monkey-patch Replace the _strict_resolve_session_request_cost monkey-patch (which was patching a method not called in the upto session flow) with a proper before_request hook that rejects inference requests early if pricing would fail: - 503 if the OPG price feed has no valid price yet - 400 if the requested model is not in the registry _session_cost_calculator now logs CRITICAL with full traceback on any post-inference cost failure (e.g. missing usage field) so uncharged requests are never silently missed. Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 126 ++++++++++++---------------------------- 1 file changed, 37 insertions(+), 89 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index b48682d..ab42e60 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -34,9 +34,9 @@ from x402.server import x402ResourceServerSync from x402.session import SessionStore import x402.http.middleware.flask as x402_flask -import types as _types from .util import calculate_session_cost +from .model_registry import get_model_config from .price_feed import OPGPriceFeed from .definitions import ( EVM_PAYMENT_ADDRESS, @@ -384,7 +384,18 @@ def _patched_read_body_bytes(environ): def _session_cost_calculator(ctx: dict) -> int: - return calculate_session_cost(ctx, _price_feed.get_price) + # Post-inference cost calculation — response already sent to client. + # Failures here (e.g. missing usage field in LLM response) cannot be + # returned as errors; log CRITICAL so they are never silently missed. + try: + return calculate_session_cost(ctx, _price_feed.get_price) + except Exception as exc: + logger.critical( + "Post-inference cost calculation failed — client was NOT charged: %s", + exc, + exc_info=True, + ) + raise _payment_mw = payment_middleware( @@ -398,98 +409,35 @@ def _session_cost_calculator(ctx: dict) -> int: ) # --------------------------------------------------------------------------- -# Strict cost-resolution patch -# -# Why this exists -# --------------- -# The upstream x402 PaymentMiddleware._resolve_session_request_cost wraps the -# call to the session_cost_calculator in a broad try/except. If the calculator -# raises (e.g. ValueError for an unrecognised model name, KeyError for missing -# usage data), the exception is swallowed and the middleware silently falls back -# to the static session maximum (CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND / -# CHAT_COMPLETIONS_USDC_AMOUNT). That silent fallback means: -# • The client is charged the full pre-check cap instead of actual usage. -# • The server has no visible indication that pricing failed. +# Pre-inference pricing gate # -# The fix -# ------- -# We replace _resolve_session_request_cost with our own implementation that is -# identical to upstream, except the cost-calculator call is NOT wrapped in a -# try/except. Any exception from _session_cost_calculator() therefore -# propagates up through the middleware and Flask, producing a proper HTTP 500 -# response to the client instead of an incorrect silent charge. +# In the upto session scheme the response is streamed to the client before +# cost is settled, so a post-inference pricing failure cannot be surfaced as +# an HTTP error. Instead we validate everything that can be checked up-front +# and reject the request early if pricing would fail: +# 1. Price feed has a valid OPG/USD price (CoinGecko fetch succeeded). +# 2. The requested model is in the registry (has a known per-token price). # --------------------------------------------------------------------------- -def _strict_resolve_session_request_cost( - self, - *, - method: str, - path: str, - request_body_bytes: bytes, - response_body_bytes: bytes, - payment_payload: object, - payment_requirements: object, - status_code: int | None, - output_object: object = None, - is_streaming: bool = False, -) -> int: - """Replacement for PaymentMiddleware._resolve_session_request_cost. - - Identical to the upstream implementation except that exceptions raised by - the dynamic cost calculator are NOT caught. This means a request whose - cost cannot be determined (unknown model, missing usage data, etc.) will - result in a 500 error rather than silently falling back to the static cap - amount and charging the user an incorrect amount. - """ - from x402.http.middleware.flask import _parse_json_bytes as _x402_parse_json # noqa: PLC0415 - - default_cost = self._get_session_cost(payment_requirements) - if not self._should_charge_response(status_code): - return default_cost - if not callable(self._session_cost_calculator): - return default_cost - - request_object = _x402_parse_json(request_body_bytes) - response_object = ( - output_object - if output_object is not None - else _x402_parse_json(response_body_bytes) - ) - - callback_context = { - "method": method, - "path": path, - "status_code": status_code, - "is_streaming": is_streaming, - "request_body_bytes": request_body_bytes, - "response_body_bytes": response_body_bytes, - "request_json": request_object - if isinstance(request_object, (dict, list)) - else None, - "response_json": response_object - if isinstance(response_object, (dict, list)) - else None, - "response_object": response_object, - "payment_payload": payment_payload, - "payment_requirements": payment_requirements, - "default_cost": default_cost, - } - - # Do NOT catch exceptions here — let them propagate so the request fails - # with a 500 rather than silently charging the static fallback amount. - dynamic_cost = self._session_cost_calculator(callback_context) - if dynamic_cost is None: - raise ValueError( - f"_session_cost_calculator returned None for {method} {path}; " - "cannot determine request cost" - ) - return self._coerce_non_negative_int(dynamic_cost) - +@application.before_request +def _check_pricing_ready(): + if request.path not in ("/v1/chat/completions", "/v1/completions"): + return + try: + _price_feed.get_price() + except ValueError as exc: + logger.warning("Rejecting inference request — price feed unavailable: %s", exc) + return jsonify({"error": f"Pricing unavailable: {exc}"}), 503 + + body = request.get_json(silent=True, cache=True) or {} + model = body.get("model") + if model: + try: + get_model_config(model) + except ValueError: + return jsonify({"error": f"Model '{model}' is not supported"}), 400 -_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign,attr-defined] - _strict_resolve_session_request_cost, _payment_mw -) logger.info("x402 payment middleware initialized") From ba2258fa9a4c2c62c9d1e60aba50444ebe192a4c Mon Sep 17 00:00:00 2001 From: kylexqian Date: Mon, 20 Apr 2026 00:14:29 -0700 Subject: [PATCH 13/17] fix: address Copilot review comments (docstrings, defensive coding, thread safety) - Update start() docstring: replace monkey-patch reference with pre-inference gate - Make fetch_opg_price() validate response.json() is a dict before calling .get(), raising ValueError instead of AttributeError on unexpected CoinGecko response shapes - Make start() idempotency check thread-safe under _lock - Clarify CRITICAL log: provider error, x402 swallows exception so client is not charged - Update calculate_session_cost docstring: replace monkey-patch reference with pre-inference gate and CRITICAL log behavior Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 10 +++++++--- tee_gateway/price_feed/feed.py | 20 ++++++++++++++------ tee_gateway/util.py | 7 ++++--- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index ab42e60..7ad75d4 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -385,13 +385,17 @@ def _patched_read_body_bytes(environ): def _session_cost_calculator(ctx: dict) -> int: # Post-inference cost calculation — response already sent to client. - # Failures here (e.g. missing usage field in LLM response) cannot be - # returned as errors; log CRITICAL so they are never silently missed. + # Predictable failures (unknown price, unknown model) are blocked by the + # pre-inference gate; any exception here indicates a provider-side error + # (e.g. missing usage field in the LLM response). The x402 middleware + # swallows the exception in close(), so the client is not charged. + # Log CRITICAL so provider errors are never silently missed. try: return calculate_session_cost(ctx, _price_feed.get_price) except Exception as exc: logger.critical( - "Post-inference cost calculation failed — client was NOT charged: %s", + "Post-inference cost calculation failed (provider error) — " + "client was NOT charged: %s", exc, exc_info=True, ) diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py index 904fe4e..85fc0f6 100644 --- a/tee_gateway/price_feed/feed.py +++ b/tee_gateway/price_feed/feed.py @@ -67,14 +67,18 @@ def start(self) -> None: The initial fetch runs inside the background thread so startup is non-blocking. ``get_price()`` will raise ``ValueError`` until the - first fetch completes; any error propagates as HTTP 500 via the - strict cost-resolution patch in ``__main__.py``. + first successful fetch completes; until then, inference requests are + rejected by the pre-inference pricing gate in ``__main__.py``. Idempotent — calling ``start()`` on an already-running feed is a no-op. + Thread-safe: the check-and-start is performed under ``_lock``. """ - if self._thread is not None and self._thread.is_alive(): - logger.info("OPG price feed already running, ignoring duplicate start()") - return + with self._lock: + if self._thread is not None and self._thread.is_alive(): + logger.info( + "OPG price feed already running, ignoring duplicate start()" + ) + return self._thread = threading.Thread( target=self._run_with_initial_fetch, name="opg-price-feed", daemon=True ) @@ -221,7 +225,11 @@ def fetch_opg_price() -> Decimal: response = requests.get(url, params=params, timeout=FETCH_TIMEOUT) response.raise_for_status() - data: dict = response.json() + data: Any = response.json() + if not isinstance(data, dict): + raise ValueError( + f"Unexpected CoinGecko response for {BASE_MAINNET_OPG_ADDRESS}: {data!r}" + ) # CoinGecko keys the result by the lowercased contract address. price_entry = data.get(BASE_MAINNET_OPG_ADDRESS.lower()) if not isinstance(price_entry, dict) or "usd" not in price_entry: diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 5bcb245..1c9047e 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -272,9 +272,10 @@ def calculate_session_cost( ``get_price`` is called on every invocation to fetch the current OPG/USD price — pass ``price_feed.get_price`` so the latest cached value is used. - Raises ``ValueError`` on any missing/invalid data; the strict monkey-patch - in ``__main__.py`` propagates this as HTTP 500 rather than silently - charging the static fallback amount. + Raises ``ValueError`` on any missing/invalid data. Predictable failures + (unavailable price, unknown model) are blocked before inference by the + pre-inference gate in ``__main__.py``; post-inference failures are logged + as CRITICAL by the caller and the client is not charged. """ request_json = context.get("request_json") response_json = context.get("response_json") From 3b1da9ac1bf4e196e1e9dc2943d1db166a57ff20 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Mon, 20 Apr 2026 14:27:47 -0700 Subject: [PATCH 14/17] feat: TGE fallback price and CoinGecko sanity checks Before the TGE cutover (2026-04-21 12:30 UTC) get_price() returns a fixed $0.10 fallback so inference requests can be priced before OPG is listed on CoinGecko. After the cutover the live cached price is used. Also adds a guard in fetch_opg_price() rejecting non-positive or non-finite prices from the API response. Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/price_feed/config.py | 11 +++++++++++ tee_gateway/price_feed/feed.py | 28 ++++++++++++++++++++++------ tee_gateway/test/test_price_feed.py | 23 ++++++++++++++++++++--- 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/tee_gateway/price_feed/config.py b/tee_gateway/price_feed/config.py index 46b2268..fe4c980 100644 --- a/tee_gateway/price_feed/config.py +++ b/tee_gateway/price_feed/config.py @@ -3,6 +3,8 @@ """ from dataclasses import dataclass +from datetime import datetime, timezone +from decimal import Decimal # --------------------------------------------------------------------------- @@ -19,6 +21,15 @@ DEFAULT_MAX_RETRIES = 3 # attempts per refresh cycle before giving up DEFAULT_RETRY_DELAY = 10 # seconds between retry attempts within a cycle +# --------------------------------------------------------------------------- +# TGE (Token Generation Event) fallback +# --------------------------------------------------------------------------- +# Before the TGE cutover, OPG is not yet listed on CoinGecko. Return a fixed +# fallback price so inference requests can be priced immediately at launch. +# After the cutover, the live CoinGecko price is used. +TGE_CUTOVER_UTC = datetime(2026, 4, 21, 12, 30, 0, tzinfo=timezone.utc) +TGE_FALLBACK_PRICE_USD = Decimal("0.10") + # --------------------------------------------------------------------------- # Stale-price warning threshold # --------------------------------------------------------------------------- diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py index 85fc0f6..b3d7369 100644 --- a/tee_gateway/price_feed/feed.py +++ b/tee_gateway/price_feed/feed.py @@ -15,6 +15,7 @@ import logging import threading import time +from datetime import datetime, timezone from decimal import Decimal from typing import Any, Optional @@ -29,6 +30,8 @@ DEFAULT_RETRY_DELAY, FETCH_TIMEOUT, STALE_WARNING_MULTIPLIER, + TGE_CUTOVER_UTC, + TGE_FALLBACK_PRICE_USD, ) logger = logging.getLogger("llm_server.price_feed") @@ -92,12 +95,19 @@ def start(self) -> None: def get_price(self) -> Decimal: """Return the latest cached OPG/USD price. - Raises ``ValueError`` if no price has been successfully fetched yet. - Logs a warning (but still returns the price) if the cached value is - older than ``STALE_WARNING_MULTIPLIER * refresh_interval`` seconds — - this indicates the background loop has missed at least one refresh - cycle and may be experiencing persistent errors. + Before the TGE cutover (``TGE_CUTOVER_UTC``), returns the fixed + ``TGE_FALLBACK_PRICE_USD`` so requests can be priced before OPG is + listed on CoinGecko. After the cutover the live cached price is used. + + Raises ``ValueError`` if no price has been successfully fetched yet + (post-TGE only). Logs a warning (but still returns the price) if the + cached value is older than ``STALE_WARNING_MULTIPLIER * refresh_interval`` + seconds — this indicates the background loop has missed at least one + refresh cycle and may be experiencing persistent errors. """ + if datetime.now(timezone.utc) < TGE_CUTOVER_UTC: + return TGE_FALLBACK_PRICE_USD + now = time.time() with self._lock: if self._price is None: @@ -237,4 +247,10 @@ def fetch_opg_price() -> Decimal: f"Unexpected CoinGecko response for {BASE_MAINNET_OPG_ADDRESS}: {data!r}" ) - return Decimal(str(price_entry["usd"])) + price = Decimal(str(price_entry["usd"])) + if not price.is_finite() or price <= 0: + raise ValueError( + f"Invalid price from CoinGecko for {BASE_MAINNET_OPG_ADDRESS}: " + f"{price_entry['usd']!r}" + ) + return price diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py index 04c8c02..b8eea81 100644 --- a/tee_gateway/test/test_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -14,6 +14,7 @@ import time import unittest +from datetime import datetime, timezone from decimal import Decimal from unittest.mock import MagicMock, patch @@ -35,6 +36,9 @@ # Patch target prefix — all mocks go through the feed module. _FEED = "tee_gateway.price_feed.feed" +# A datetime well after the TGE cutover so get_price() uses the cached price. +_POST_TGE = datetime(2026, 4, 22, 0, 0, 0, tzinfo=timezone.utc) + def _mock_response(status_code: int = 200, json_body: dict | None = None) -> MagicMock: """Build a minimal mock requests.Response.""" @@ -104,6 +108,13 @@ def test_raises_when_usd_key_missing(self, mock_get): with self.assertRaises(ValueError): fetch_opg_price() + @patch(f"{_FEED}.requests.get") + def test_raises_when_address_entry_is_empty_dict(self, mock_get): + """CoinGecko returns {address: {}} for known-but-unpriced tokens (current OPG behaviour).""" + mock_get.return_value = _mock_response(200, {OPG_ADDRESS_LOWER: {}}) + with self.assertRaises(ValueError, msg="empty price entry should raise"): + fetch_opg_price() + @patch(f"{_FEED}.requests.get") def test_raises_on_network_error(self, mock_get): mock_get.side_effect = requests.exceptions.ConnectionError("timeout") @@ -221,22 +232,28 @@ def test_no_sleep_after_last_failed_attempt(self, mock_fetch, mock_sleep): class TestOPGPriceFeedGetPrice(unittest.TestCase): """Tests for OPGPriceFeed.get_price() behaviour.""" - def test_raises_before_any_successful_fetch(self): + @patch(f"{_FEED}.datetime") + def test_raises_before_any_successful_fetch(self, mock_dt): + mock_dt.now.return_value = _POST_TGE feed = OPGPriceFeed() with self.assertRaises(ValueError) as ctx: feed.get_price() self.assertIn("not yet available", str(ctx.exception)) + @patch(f"{_FEED}.datetime") @patch(f"{_FEED}.fetch_opg_price") - def test_returns_price_after_successful_refresh(self, mock_fetch): + def test_returns_price_after_successful_refresh(self, mock_fetch, mock_dt): + mock_dt.now.return_value = _POST_TGE mock_fetch.return_value = SAMPLE_PRICE feed = OPGPriceFeed(retry_delay=0) feed._refresh_price() self.assertEqual(feed.get_price(), SAMPLE_PRICE) + @patch(f"{_FEED}.datetime") @patch(f"{_FEED}.time.time") @patch(f"{_FEED}.fetch_opg_price") - def test_warns_when_price_is_stale(self, mock_fetch, mock_time): + def test_warns_when_price_is_stale(self, mock_fetch, mock_time, mock_dt): + mock_dt.now.return_value = _POST_TGE mock_fetch.return_value = SAMPLE_PRICE feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) From b2e970236cee8a1ba78f13b846233c4d156725de Mon Sep 17 00:00:00 2001 From: kylexqian Date: Mon, 20 Apr 2026 14:37:00 -0700 Subject: [PATCH 15/17] feat: expire stale OPG price after 4 hours MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit get_price() now raises ValueError if the cached price is older than 4 hours, preventing billing on a price that is too outdated. A warning is still logged at the existing 2 × refresh_interval threshold (~10 min) as an early signal. The pre-inference gate in __main__.py will surface this as a 503. Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/price_feed/config.py | 7 +++++- tee_gateway/price_feed/feed.py | 33 +++++++++++++++++++---------- tee_gateway/test/test_price_feed.py | 18 ++++++++++++++++ 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/tee_gateway/price_feed/config.py b/tee_gateway/price_feed/config.py index fe4c980..a8742dc 100644 --- a/tee_gateway/price_feed/config.py +++ b/tee_gateway/price_feed/config.py @@ -31,12 +31,17 @@ TGE_FALLBACK_PRICE_USD = Decimal("0.10") # --------------------------------------------------------------------------- -# Stale-price warning threshold +# Stale-price thresholds # --------------------------------------------------------------------------- # get_price() logs WARNING when last successful fetch is older than # STALE_WARNING_MULTIPLIER × refresh_interval seconds. STALE_WARNING_MULTIPLIER = 2 +# get_price() raises ValueError when last successful fetch is older than +# STALE_PRICE_MAX_AGE seconds — at this point the cached price is considered +# too outdated to use for billing. +STALE_PRICE_MAX_AGE = 4 * 60 * 60 # 4 hours + @dataclass(frozen=True) class PriceFeedConfig: diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py index b3d7369..bbfb86b 100644 --- a/tee_gateway/price_feed/feed.py +++ b/tee_gateway/price_feed/feed.py @@ -29,6 +29,7 @@ DEFAULT_REFRESH_INTERVAL, DEFAULT_RETRY_DELAY, FETCH_TIMEOUT, + STALE_PRICE_MAX_AGE, STALE_WARNING_MULTIPLIER, TGE_CUTOVER_UTC, TGE_FALLBACK_PRICE_USD, @@ -115,17 +116,27 @@ def get_price(self) -> Decimal: "OPG price not yet available — " "price feed has not completed a successful fetch" ) - if self.last_success is not None: - age = now - self.last_success - stale_threshold = self._refresh_interval * STALE_WARNING_MULTIPLIER - if age > stale_threshold: - logger.warning( - "OPG price data is stale: last successful fetch was %.0fs ago " - "(threshold: %.0fs); consecutive failures: %d", - age, - stale_threshold, - self.consecutive_failures, - ) + if self.last_success is None: + raise ValueError( + "OPG price not yet available — " + "price feed has not completed a successful fetch" + ) + age = now - self.last_success + if age > STALE_PRICE_MAX_AGE: + raise ValueError( + f"OPG price data expired: last successful fetch was {age:.0f}s ago " + f"(max: {STALE_PRICE_MAX_AGE}s); consecutive failures: " + f"{self.consecutive_failures}" + ) + stale_threshold = self._refresh_interval * STALE_WARNING_MULTIPLIER + if age > stale_threshold: + logger.warning( + "OPG price data is stale: last successful fetch was %.0fs ago " + "(threshold: %.0fs); consecutive failures: %d", + age, + stale_threshold, + self.consecutive_failures, + ) return self._price def get_status(self) -> dict[str, Any]: diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py index b8eea81..fe9bb8a 100644 --- a/tee_gateway/test/test_price_feed.py +++ b/tee_gateway/test/test_price_feed.py @@ -269,6 +269,24 @@ def test_warns_when_price_is_stale(self, mock_fetch, mock_time, mock_dt): self.assertEqual(price, SAMPLE_PRICE) self.assertTrue(any("stale" in line.lower() for line in log_ctx.output)) + @patch(f"{_FEED}.datetime") + @patch(f"{_FEED}.time.time") + @patch(f"{_FEED}.fetch_opg_price") + def test_raises_when_price_exceeds_max_age(self, mock_fetch, mock_time, mock_dt): + mock_dt.now.return_value = _POST_TGE + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(retry_delay=0) + + mock_time.return_value = 0.0 + feed._refresh_price() + + # Advance past the 4-hour max age + mock_time.return_value = 4 * 60 * 60 + 1.0 + + with self.assertRaises(ValueError) as ctx: + feed.get_price() + self.assertIn("expired", str(ctx.exception)) + @patch(f"{_FEED}.time.time") @patch(f"{_FEED}.fetch_opg_price") def test_no_stale_warning_when_price_is_fresh(self, mock_fetch, mock_time): From 6289250bf543c0d0f9cc73320a1e17c3fbbbbe21 Mon Sep 17 00:00:00 2001 From: Kyle Qian Date: Mon, 20 Apr 2026 16:40:22 -0700 Subject: [PATCH 16/17] Update tee_gateway/price_feed/feed.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tee_gateway/price_feed/feed.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py index bbfb86b..6875f68 100644 --- a/tee_gateway/price_feed/feed.py +++ b/tee_gateway/price_feed/feed.py @@ -83,10 +83,12 @@ def start(self) -> None: "OPG price feed already running, ignoring duplicate start()" ) return - self._thread = threading.Thread( - target=self._run_with_initial_fetch, name="opg-price-feed", daemon=True - ) - self._thread.start() + self._thread = threading.Thread( + target=self._run_with_initial_fetch, + name="opg-price-feed", + daemon=True, + ) + self._thread.start() logger.info( "OPG price feed started (refresh_interval=%ds, max_retries=%d)", self._refresh_interval, From caf076cf499f4705dcd0dc758bd71fb6e5c9c428 Mon Sep 17 00:00:00 2001 From: Kyle Qian Date: Mon, 20 Apr 2026 16:41:05 -0700 Subject: [PATCH 17/17] Update tee_gateway/util.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tee_gateway/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 1c9047e..ac79cd6 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -306,7 +306,7 @@ def calculate_session_cost( ) logger.info( - "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d " + "CALCULATE_SESSION_COST model=%s input_tokens=%d output_tokens=%d " "total_usd=%s token_price_usd=%s decimals=%d cost=%d", model, input_tokens,