From 0f9726f612bc874309ecc8e60ea8b80bdceb2306 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Wed, 1 Apr 2026 18:53:25 +0530 Subject: [PATCH 01/13] init token update --- tee_gateway/__main__.py | 20 ++++++++++++++++++-- tee_gateway/definitions.py | 11 +++++++++-- tests/test_pricing.py | 4 ++-- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index a27f426..8afcdfb 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -39,7 +39,9 @@ BASE_TESTNET_NETWORK, EVM_PAYMENT_ADDRESS, USDC_ADDRESS, - BASE_OPG_ADDRESS, + BASE_TESTNET_OPG_ADDRESS, + BASE_MAINNET_NETWORK, + BASE_MAINNET_OPG_ADDRESS, CHAT_COMPLETIONS_USDC_AMOUNT, CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, COMPLETIONS_USDC_AMOUNT, @@ -138,7 +140,7 @@ def _shutdown_heartbeat(): pay_to=EVM_PAYMENT_ADDRESS, price=AssetAmount( amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, - asset=BASE_OPG_ADDRESS, + asset=BASE_TESTNET_OPG_ADDRESS, extra={ "name": "OPG", "version": "2", @@ -147,6 +149,20 @@ def _shutdown_heartbeat(): ), network=BASE_TESTNET_NETWORK, ), + PaymentOption( + scheme="upto", + pay_to=EVM_PAYMENT_ADDRESS, + price=AssetAmount( + amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, + asset=BASE_MAINNET_OPG_ADDRESS, + extra={ + "name": "OPG", + "version": "2", + "assetTransferMethod": "permit2", + }, + ), + network=BASE_MAINNET_NETWORK, + ), ], mime_type="application/json", description="Chat completion", diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index 6ff58de..2c14a4e 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -25,6 +25,9 @@ # Base Testnet — where OPG payments are accepted BASE_TESTNET_NETWORK: str = "eip155:84532" +# Base Mainnet — where OPG payments are accepted +BASE_MAINNET_NETWORK: str = "eip155:8453" + # --------------------------------------------------------------------------- # Payment recipient # --------------------------------------------------------------------------- @@ -45,7 +48,10 @@ USDC_ADDRESS: str = "0x094E464A23B90A71a0894D5D1e5D470FfDD074e1" # OpenGradient token (OPG) on Base Testnet -BASE_OPG_ADDRESS: str = "0x240b09731D96979f50B2C649C9CE10FcF9C7987F" +BASE_TESTNET_OPG_ADDRESS: str = "0x240b09731D96979f50B2C649C9CE10FcF9C7987F" + +# OpenGradient token (OPG) on Base Mainnet +BASE_MAINNET_OPG_ADDRESS: str = "0x5feCcD17C393CaF1001D18164236A37E731FCb9d" # --------------------------------------------------------------------------- # Token decimal places @@ -54,7 +60,8 @@ # Maps lowercase contract address → number of decimals for unit conversion. ASSET_DECIMALS_BY_ADDRESS: dict[str, int] = { USDC_ADDRESS.lower(): 6, # USDC / OUSDC standard: 6 decimals - BASE_OPG_ADDRESS.lower(): 18, # OPG: 18 decimals (ERC-20 standard) + BASE_TESTNET_OPG_ADDRESS.lower(): 18, # OPG: 18 decimals (ERC-20 standard) + BASE_MAINNET_OPG_ADDRESS.lower(): 18, # OPG: 18 decimals (ERC-20 standard) } # Fallback for any asset not explicitly listed above diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 2b397c3..1a2f528 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -11,7 +11,7 @@ import unittest from decimal import Decimal -from tee_gateway.definitions import BASE_OPG_ADDRESS, USDC_ADDRESS +from tee_gateway.definitions import BASE_TESTNET_OPG_ADDRESS, USDC_ADDRESS from tee_gateway.model_registry import ( _MODEL_LOOKUP, get_model_config, @@ -26,7 +26,7 @@ def _opg_requirements() -> dict: """Fake PaymentRequirements dict for OPG (18 decimals).""" - return {"asset": BASE_OPG_ADDRESS, "amount": "50000000000000000"} + return {"asset": BASE_TESTNET_OPG_ADDRESS, "amount": "50000000000000000"} def _usdc_requirements() -> dict: From dac78ea91b4281f6668270e58b0ec387b8264028 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Thu, 9 Apr 2026 19:09:22 +0530 Subject: [PATCH 02/13] base mainnet changes --- pyproject.toml | 2 +- tee_gateway/__main__.py | 70 ++++++++++++++++++++++++----------------- uv.lock | 18 ++++++++--- 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5f27fb3..285ca61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "setuptools>=21.0.0", "Flask>=3.0.0", "gunicorn>=23.0.0", - "og-x402[evm]==0.0.1.dev6", + "og-x402[evm]>=0.0.1.dev7", "fastapi>=0.128.0", "uvicorn[standard]>=0.40.0", "pydantic>=2.12.5", diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 8afcdfb..c92c4e7 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -27,24 +27,22 @@ from x402.http.types import RouteConfig from x402.mechanisms.evm.exact import ExactEvmServerScheme from x402.mechanisms.evm.upto import UptoEvmServerScheme +from x402.extensions.eip2612_gas_sponsoring import declare_eip2612_gas_sponsoring_extension +from x402.extensions.erc20_approval_gas_sponsoring import declare_erc20_approval_gas_sponsoring_extension from x402.schemas import AssetAmount from x402.server import x402ResourceServerSync from x402.session import SessionStore -import types as _types import x402.http.middleware.flask as x402_flask +import types as _types from .util import dynamic_session_cost_calculator from .definitions import ( - EVM_NETWORK, BASE_TESTNET_NETWORK, EVM_PAYMENT_ADDRESS, - USDC_ADDRESS, BASE_TESTNET_OPG_ADDRESS, BASE_MAINNET_NETWORK, BASE_MAINNET_OPG_ADDRESS, - CHAT_COMPLETIONS_USDC_AMOUNT, CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, - COMPLETIONS_USDC_AMOUNT, FACILITATOR_URL, ) @@ -113,28 +111,16 @@ def _shutdown_heartbeat(): server = x402ResourceServerSync(facilitator) store = SessionStore() -server.register(EVM_NETWORK, ExactEvmServerScheme()) server.register(BASE_TESTNET_NETWORK, ExactEvmServerScheme()) -server.register(EVM_NETWORK, UptoEvmServerScheme()) +server.register(BASE_MAINNET_NETWORK, ExactEvmServerScheme()) + +# Upto scheme registrations (permit2-based, variable settlement) server.register(BASE_TESTNET_NETWORK, UptoEvmServerScheme()) +server.register(BASE_MAINNET_NETWORK, UptoEvmServerScheme()) routes = { "POST /v1/chat/completions": RouteConfig( accepts=[ - PaymentOption( - scheme="upto", - pay_to=EVM_PAYMENT_ADDRESS, - price=AssetAmount( - amount=CHAT_COMPLETIONS_USDC_AMOUNT, - asset=USDC_ADDRESS, - extra={ - "name": "OUSDC", - "version": "2", - "assetTransferMethod": "permit2", - }, - ), - network=EVM_NETWORK, - ), PaymentOption( scheme="upto", pay_to=EVM_PAYMENT_ADDRESS, @@ -142,8 +128,8 @@ def _shutdown_heartbeat(): amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, asset=BASE_TESTNET_OPG_ADDRESS, extra={ - "name": "OPG", - "version": "2", + "name": "OpenGradient", + "version": "1", "assetTransferMethod": "permit2", }, ), @@ -156,14 +142,18 @@ def _shutdown_heartbeat(): amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, asset=BASE_MAINNET_OPG_ADDRESS, extra={ - "name": "OPG", - "version": "2", + "name": "OpenGradient", + "version": "1", "assetTransferMethod": "permit2", }, ), network=BASE_MAINNET_NETWORK, ), ], + extensions={ + **declare_eip2612_gas_sponsoring_extension(), + **declare_erc20_approval_gas_sponsoring_extension(), + }, mime_type="application/json", description="Chat completion", ), @@ -173,13 +163,35 @@ def _shutdown_heartbeat(): scheme="upto", pay_to=EVM_PAYMENT_ADDRESS, price=AssetAmount( - amount=COMPLETIONS_USDC_AMOUNT, - asset=USDC_ADDRESS, - extra={"name": "USDC", "version": "2"}, + amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, + asset=BASE_TESTNET_OPG_ADDRESS, + extra={ + "name": "OpenGradient", + "version": "1", + "assetTransferMethod": "permit2", + }, ), - network=EVM_NETWORK, + network=BASE_TESTNET_NETWORK, ), + PaymentOption( + scheme="upto", + pay_to=EVM_PAYMENT_ADDRESS, + price=AssetAmount( + amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, + asset=BASE_MAINNET_OPG_ADDRESS, + extra={ + "name": "OpenGradient", + "version": "1", + "assetTransferMethod": "permit2", + }, + ), + network=BASE_MAINNET_NETWORK, + ), ], + extensions={ + **declare_eip2612_gas_sponsoring_extension(), + **declare_erc20_approval_gas_sponsoring_extension(), + }, mime_type="application/json", description="Completion", ), diff --git a/uv.lock b/uv.lock index 4026c9d..9f67ce2 100644 --- a/uv.lock +++ b/uv.lock @@ -1242,17 +1242,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] +[[package]] +name = "nest-asyncio" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418, upload-time = "2024-01-21T14:25:19.227Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, +] + [[package]] name = "og-x402" -version = "0.0.1.dev6" +version = "0.0.1.dev7" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "nest-asyncio" }, { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/70/aa/2b616b9be6dfa4dfee98bde3ed20dd41cb446d0569e0069c1d6c11faa032/og_x402-0.0.1.dev6.tar.gz", hash = "sha256:140c4b725f372e81f4a3c2caf392f58b6fcf242bc51a1c3a6417f58e3ef9e347", size = 900115, upload-time = "2026-03-30T07:13:25.623Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e2/45/5fa555d9c8ba319780b6f12d8f7222e31019e90b032b85ce9a7f9a08a9f4/og_x402-0.0.1.dev7.tar.gz", hash = "sha256:a4bce840c07b783d14debad0c11941c8660ea8d80c423164e0e18da012bf92d6", size = 1306033, upload-time = "2026-04-09T11:42:59.898Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/03/5e/a64de6f29eb80bb180288297882d5aba2a894363622d4f94417b420cf0b5/og_x402-0.0.1.dev6-py3-none-any.whl", hash = "sha256:2a1f962fa2a50d02f28421199027245d5c5f013f36a143ec2f184a546325f1bd", size = 952670, upload-time = "2026-03-30T07:13:00.408Z" }, + { url = "https://files.pythonhosted.org/packages/36/a0/30f882406ba5e52913127f1e85c6fd2ec977291273be0fdc89980dcb7c4c/og_x402-0.0.1.dev7-py3-none-any.whl", hash = "sha256:947f91a13134350997fcac2a78c435587111d6ea7d69d8a4a96fbf63d21306c7", size = 1386876, upload-time = "2026-04-09T11:42:57.667Z" }, ] [package.optional-dependencies] @@ -1881,7 +1891,7 @@ requires-dist = [ { name = "langchain-google-genai", specifier = ">=4.2.0" }, { name = "langchain-openai", specifier = ">=0.3.35" }, { name = "langchain-xai", specifier = ">=0.2.5" }, - { name = "og-x402", extras = ["evm"], specifier = "==0.0.1.dev6" }, + { name = "og-x402", extras = ["evm"], specifier = ">=0.0.1.dev7" }, { name = "openai", specifier = ">=2.15.0" }, { name = "psutil", specifier = ">=7.2.1" }, { name = "pydantic", specifier = ">=2.12.5" }, From ce3a491c8e876d11ff8ef42b3cd7a458854863b5 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Thu, 9 Apr 2026 20:09:49 +0530 Subject: [PATCH 03/13] addr updates --- tee_gateway/definitions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index 2c14a4e..4337e87 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -37,7 +37,7 @@ # your own instance. EVM_PAYMENT_ADDRESS: str = os.getenv( "EVM_PAYMENT_ADDRESS", - "0x40eFb45552EDfB2502D90A657a8ab41F03ec460d", + "0x9deEBB5D1b22e4a6e027977CeAd13893A7E4cC1a", ) # --------------------------------------------------------------------------- From 015edb96aedce1712c4655b0267c3ce7869c4f17 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Fri, 10 Apr 2026 01:15:45 +0530 Subject: [PATCH 04/13] dep updates --- pyproject.toml | 2 +- uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 285ca61..42d915e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "setuptools>=21.0.0", "Flask>=3.0.0", "gunicorn>=23.0.0", - "og-x402[evm]>=0.0.1.dev7", + "og-x402[evm]>=0.0.1.dev8", "fastapi>=0.128.0", "uvicorn[standard]>=0.40.0", "pydantic>=2.12.5", diff --git a/uv.lock b/uv.lock index 9f67ce2..5029d5d 100644 --- a/uv.lock +++ b/uv.lock @@ -1253,16 +1253,16 @@ wheels = [ [[package]] name = "og-x402" -version = "0.0.1.dev7" +version = "0.0.1.dev8" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nest-asyncio" }, { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e2/45/5fa555d9c8ba319780b6f12d8f7222e31019e90b032b85ce9a7f9a08a9f4/og_x402-0.0.1.dev7.tar.gz", hash = "sha256:a4bce840c07b783d14debad0c11941c8660ea8d80c423164e0e18da012bf92d6", size = 1306033, upload-time = "2026-04-09T11:42:59.898Z" } +sdist = { url = "https://files.pythonhosted.org/packages/94/9e/1d718f3e0f7a6f6fd53c8a183c1794bc4aa15d986b0faa76139d5b04096b/og_x402-0.0.1.dev8.tar.gz", hash = "sha256:9d02c2c81112b7a612cd1aea03c09af75fc75d70766d042b5ddcc82ee7d8f98a", size = 1306960, upload-time = "2026-04-09T19:44:24.966Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/36/a0/30f882406ba5e52913127f1e85c6fd2ec977291273be0fdc89980dcb7c4c/og_x402-0.0.1.dev7-py3-none-any.whl", hash = "sha256:947f91a13134350997fcac2a78c435587111d6ea7d69d8a4a96fbf63d21306c7", size = 1386876, upload-time = "2026-04-09T11:42:57.667Z" }, + { url = "https://files.pythonhosted.org/packages/c8/0e/48facce5d73330d1cb79bbd67eda9c94b9786ea86f433338ee4423a6b1d0/og_x402-0.0.1.dev8-py3-none-any.whl", hash = "sha256:2b5b9601a6d312f9b1cf68967eaf98229eb203c54ca403e46994d6eed2488ccc", size = 1387989, upload-time = "2026-04-09T19:44:23.174Z" }, ] [package.optional-dependencies] @@ -1891,7 +1891,7 @@ requires-dist = [ { name = "langchain-google-genai", specifier = ">=4.2.0" }, { name = "langchain-openai", specifier = ">=0.3.35" }, { name = "langchain-xai", specifier = ">=0.2.5" }, - { name = "og-x402", extras = ["evm"], specifier = ">=0.0.1.dev7" }, + { name = "og-x402", extras = ["evm"], specifier = ">=0.0.1.dev8" }, { name = "openai", specifier = ">=2.15.0" }, { name = "psutil", specifier = ">=7.2.1" }, { name = "pydantic", specifier = ">=2.12.5" }, From 6bd40759658c33a146893714163c7308d36e657a Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Fri, 10 Apr 2026 16:11:12 +0530 Subject: [PATCH 05/13] updates --- pyproject.toml | 2 +- uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 42d915e..027e43e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "setuptools>=21.0.0", "Flask>=3.0.0", "gunicorn>=23.0.0", - "og-x402[evm]>=0.0.1.dev8", + "og-x402[evm]>=0.0.1.dev9", "fastapi>=0.128.0", "uvicorn[standard]>=0.40.0", "pydantic>=2.12.5", diff --git a/uv.lock b/uv.lock index 5029d5d..4cab594 100644 --- a/uv.lock +++ b/uv.lock @@ -1253,16 +1253,16 @@ wheels = [ [[package]] name = "og-x402" -version = "0.0.1.dev8" +version = "0.0.1.dev9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nest-asyncio" }, { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/9e/1d718f3e0f7a6f6fd53c8a183c1794bc4aa15d986b0faa76139d5b04096b/og_x402-0.0.1.dev8.tar.gz", hash = "sha256:9d02c2c81112b7a612cd1aea03c09af75fc75d70766d042b5ddcc82ee7d8f98a", size = 1306960, upload-time = "2026-04-09T19:44:24.966Z" } +sdist = { url = "https://files.pythonhosted.org/packages/97/f5/02e7b68af825c200da2aa88292f2c07823d321a4fd9e2a3d20130358fc10/og_x402-0.0.1.dev9.tar.gz", hash = "sha256:d3cfd05443636712cb1277e3d904b878d875a60b3728d64265098ea06eeb116b", size = 1312652, upload-time = "2026-04-10T10:40:13.267Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/0e/48facce5d73330d1cb79bbd67eda9c94b9786ea86f433338ee4423a6b1d0/og_x402-0.0.1.dev8-py3-none-any.whl", hash = "sha256:2b5b9601a6d312f9b1cf68967eaf98229eb203c54ca403e46994d6eed2488ccc", size = 1387989, upload-time = "2026-04-09T19:44:23.174Z" }, + { url = "https://files.pythonhosted.org/packages/8b/08/f5a05fc8454541e96650d44bf15b34491505d0e4f1e9e77b26c804fbbdd3/og_x402-0.0.1.dev9-py3-none-any.whl", hash = "sha256:2db171be2526aa13a1243255538d185c4f1f6106f615eff532d1720a89672034", size = 1392934, upload-time = "2026-04-10T10:40:11.224Z" }, ] [package.optional-dependencies] @@ -1891,7 +1891,7 @@ requires-dist = [ { name = "langchain-google-genai", specifier = ">=4.2.0" }, { name = "langchain-openai", specifier = ">=0.3.35" }, { name = "langchain-xai", specifier = ">=0.2.5" }, - { name = "og-x402", extras = ["evm"], specifier = ">=0.0.1.dev8" }, + { name = "og-x402", extras = ["evm"], specifier = ">=0.0.1.dev9" }, { name = "openai", specifier = ">=2.15.0" }, { name = "psutil", specifier = ">=7.2.1" }, { name = "pydantic", specifier = ">=2.12.5" }, From 7d3ec36da1bd6e9580677368d3316be5ac5fc1ac Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Sun, 12 Apr 2026 11:21:21 +0530 Subject: [PATCH 06/13] token addr update --- tee_gateway/definitions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index 4337e87..7ff3111 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -51,7 +51,7 @@ BASE_TESTNET_OPG_ADDRESS: str = "0x240b09731D96979f50B2C649C9CE10FcF9C7987F" # OpenGradient token (OPG) on Base Mainnet -BASE_MAINNET_OPG_ADDRESS: str = "0x5feCcD17C393CaF1001D18164236A37E731FCb9d" +BASE_MAINNET_OPG_ADDRESS: str = "0xFbC2051AE2265686a469421b2C5A2D5462FbF5eB" # --------------------------------------------------------------------------- # Token decimal places From 793cc931d6b6b4ec918fe07fe640cd9f658bf1a2 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Sun, 12 Apr 2026 11:53:54 -0700 Subject: [PATCH 07/13] Remove unused USDC var and lint --- tee_gateway/__main__.py | 10 +++++++--- tee_gateway/definitions.py | 1 - tests/test_pricing.py | 2 +- uv.lock | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index c92c4e7..ea43a7e 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -27,8 +27,12 @@ from x402.http.types import RouteConfig from x402.mechanisms.evm.exact import ExactEvmServerScheme from x402.mechanisms.evm.upto import UptoEvmServerScheme -from x402.extensions.eip2612_gas_sponsoring import declare_eip2612_gas_sponsoring_extension -from x402.extensions.erc20_approval_gas_sponsoring import declare_erc20_approval_gas_sponsoring_extension +from x402.extensions.eip2612_gas_sponsoring import ( + declare_eip2612_gas_sponsoring_extension, +) +from x402.extensions.erc20_approval_gas_sponsoring import ( + declare_erc20_approval_gas_sponsoring_extension, +) from x402.schemas import AssetAmount from x402.server import x402ResourceServerSync from x402.session import SessionStore @@ -186,7 +190,7 @@ def _shutdown_heartbeat(): }, ), network=BASE_MAINNET_NETWORK, - ), + ), ], extensions={ **declare_eip2612_gas_sponsoring_extension(), diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index aac907c..6517185 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -53,7 +53,6 @@ # Maps lowercase contract address → number of decimals for unit conversion. ASSET_DECIMALS_BY_ADDRESS: dict[str, int] = { - USDC_ADDRESS.lower(): 6, # USDC / OUSDC standard: 6 decimals BASE_TESTNET_OPG_ADDRESS.lower(): 18, # OPG: 18 decimals (ERC-20 standard) BASE_MAINNET_OPG_ADDRESS.lower(): 18, # OPG: 18 decimals (ERC-20 standard) } diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 3629ebd..088a849 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -11,7 +11,7 @@ import unittest from decimal import Decimal -from tee_gateway.definitions import BASE_TESTNET_OPG_ADDRESS, USDC_ADDRESS +from tee_gateway.definitions import BASE_TESTNET_OPG_ADDRESS from tee_gateway.model_registry import ( _MODEL_LOOKUP, get_model_config, diff --git a/uv.lock b/uv.lock index 8dad936..bfb2691 100644 --- a/uv.lock +++ b/uv.lock @@ -1891,7 +1891,7 @@ requires-dist = [ { name = "langchain-google-genai", specifier = ">=4.2.1" }, { name = "langchain-openai", specifier = ">=1.1.12" }, { name = "langchain-xai", specifier = ">=1.2.2" }, - { name = "og-x402", extras = ["evm"], specifier = "==0.0.1.dev9" }, + { name = "og-x402", extras = ["evm"], specifier = ">=0.0.1.dev9" }, { name = "openai", specifier = ">=2.15.0" }, { name = "psutil", specifier = ">=7.2.1" }, { name = "pydantic", specifier = ">=2.12.5" }, From fa2c5457acfa749eeecf7a84712245b965e932aa Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Tue, 14 Apr 2026 15:58:53 +0530 Subject: [PATCH 08/13] extension change --- tee_gateway/__main__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index ea43a7e..935f077 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -27,9 +27,6 @@ from x402.http.types import RouteConfig from x402.mechanisms.evm.exact import ExactEvmServerScheme from x402.mechanisms.evm.upto import UptoEvmServerScheme -from x402.extensions.eip2612_gas_sponsoring import ( - declare_eip2612_gas_sponsoring_extension, -) from x402.extensions.erc20_approval_gas_sponsoring import ( declare_erc20_approval_gas_sponsoring_extension, ) @@ -155,7 +152,6 @@ def _shutdown_heartbeat(): ), ], extensions={ - **declare_eip2612_gas_sponsoring_extension(), **declare_erc20_approval_gas_sponsoring_extension(), }, mime_type="application/json", @@ -193,7 +189,6 @@ def _shutdown_heartbeat(): ), ], extensions={ - **declare_eip2612_gas_sponsoring_extension(), **declare_erc20_approval_gas_sponsoring_extension(), }, mime_type="application/json", From baf01e90028146b567b4316713872c35ef8a607c Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Fri, 17 Apr 2026 14:36:33 +0530 Subject: [PATCH 09/13] mainnet changes --- tee_gateway/__main__.py | 32 -------------------------------- tee_gateway/definitions.py | 6 ------ tests/test_pricing.py | 4 ++-- 3 files changed, 2 insertions(+), 40 deletions(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 935f077..1ba44c2 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -38,9 +38,7 @@ from .util import dynamic_session_cost_calculator from .definitions import ( - BASE_TESTNET_NETWORK, EVM_PAYMENT_ADDRESS, - BASE_TESTNET_OPG_ADDRESS, BASE_MAINNET_NETWORK, BASE_MAINNET_OPG_ADDRESS, CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, @@ -112,30 +110,14 @@ def _shutdown_heartbeat(): server = x402ResourceServerSync(facilitator) store = SessionStore() -server.register(BASE_TESTNET_NETWORK, ExactEvmServerScheme()) server.register(BASE_MAINNET_NETWORK, ExactEvmServerScheme()) # Upto scheme registrations (permit2-based, variable settlement) -server.register(BASE_TESTNET_NETWORK, UptoEvmServerScheme()) server.register(BASE_MAINNET_NETWORK, UptoEvmServerScheme()) routes = { "POST /v1/chat/completions": RouteConfig( accepts=[ - PaymentOption( - scheme="upto", - pay_to=EVM_PAYMENT_ADDRESS, - price=AssetAmount( - amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, - asset=BASE_TESTNET_OPG_ADDRESS, - extra={ - "name": "OpenGradient", - "version": "1", - "assetTransferMethod": "permit2", - }, - ), - network=BASE_TESTNET_NETWORK, - ), PaymentOption( scheme="upto", pay_to=EVM_PAYMENT_ADDRESS, @@ -159,20 +141,6 @@ def _shutdown_heartbeat(): ), "POST /v1/completions": RouteConfig( accepts=[ - PaymentOption( - scheme="upto", - pay_to=EVM_PAYMENT_ADDRESS, - price=AssetAmount( - amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, - asset=BASE_TESTNET_OPG_ADDRESS, - extra={ - "name": "OpenGradient", - "version": "1", - "assetTransferMethod": "permit2", - }, - ), - network=BASE_TESTNET_NETWORK, - ), PaymentOption( scheme="upto", pay_to=EVM_PAYMENT_ADDRESS, diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index 6517185..ca316bd 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -19,8 +19,6 @@ # Network IDs (EIP-155 chain identifiers) # --------------------------------------------------------------------------- -# Base Testnet — where OPG payments are accepted -BASE_TESTNET_NETWORK: str = "eip155:84532" # Base Mainnet — where OPG payments are accepted BASE_MAINNET_NETWORK: str = "eip155:8453" @@ -41,9 +39,6 @@ # ERC-20 token contract addresses # --------------------------------------------------------------------------- -# OpenGradient token (OPG) on Base Testnet -BASE_TESTNET_OPG_ADDRESS: str = "0x240b09731D96979f50B2C649C9CE10FcF9C7987F" - # OpenGradient token (OPG) on Base Mainnet BASE_MAINNET_OPG_ADDRESS: str = "0xFbC2051AE2265686a469421b2C5A2D5462FbF5eB" @@ -53,7 +48,6 @@ # Maps lowercase contract address → number of decimals for unit conversion. ASSET_DECIMALS_BY_ADDRESS: dict[str, int] = { - BASE_TESTNET_OPG_ADDRESS.lower(): 18, # OPG: 18 decimals (ERC-20 standard) BASE_MAINNET_OPG_ADDRESS.lower(): 18, # OPG: 18 decimals (ERC-20 standard) } diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 088a849..d1b5f25 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -11,7 +11,7 @@ import unittest from decimal import Decimal -from tee_gateway.definitions import BASE_TESTNET_OPG_ADDRESS +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS from tee_gateway.model_registry import ( _MODEL_LOOKUP, get_model_config, @@ -26,7 +26,7 @@ def _opg_requirements() -> dict: """Fake PaymentRequirements dict for OPG (18 decimals).""" - return {"asset": BASE_TESTNET_OPG_ADDRESS, "amount": "50000000000000000"} + return {"asset": BASE_MAINNET_OPG_ADDRESS, "amount": "50000000000000000"} def _ctx(model: str, input_tokens: int, output_tokens: int, requirements=None) -> dict: From 151191a2ebf8212b06baf769aedec4beb8e51356 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Mon, 20 Apr 2026 17:24:27 -0700 Subject: [PATCH 10/13] Fix lint --- tee_gateway/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 1ba44c2..c9e8f28 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -473,7 +473,7 @@ def _strict_resolve_session_request_cost( return self._coerce_non_negative_int(dynamic_cost) -_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign] +_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign, attr-defined] _strict_resolve_session_request_cost, _payment_mw ) From e8f76e2111d9e1da167248362e6747a7c7b3424e Mon Sep 17 00:00:00 2001 From: kylexqian Date: Mon, 20 Apr 2026 18:34:50 -0700 Subject: [PATCH 11/13] fix: use COMPLETIONS_OPG_SESSION_MAX_SPEND for /v1/completions route Co-Authored-By: Claude Sonnet 4.6 --- tee_gateway/__main__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index c9e8f28..fd36b24 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -42,6 +42,7 @@ BASE_MAINNET_NETWORK, BASE_MAINNET_OPG_ADDRESS, CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, + COMPLETIONS_OPG_SESSION_MAX_SPEND, FACILITATOR_URL, ) @@ -145,7 +146,7 @@ def _shutdown_heartbeat(): scheme="upto", pay_to=EVM_PAYMENT_ADDRESS, price=AssetAmount( - amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, + amount=COMPLETIONS_OPG_SESSION_MAX_SPEND, asset=BASE_MAINNET_OPG_ADDRESS, extra={ "name": "OpenGradient", From e100e2d5eb426fbf1ad1ef3e52af5c413e88522a Mon Sep 17 00:00:00 2001 From: Kyle Qian Date: Mon, 20 Apr 2026 19:53:05 -0700 Subject: [PATCH 12/13] feat: coingecko opg price feed (#56) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add CoinGecko OPG/USD price feed background service Replaces the hardcoded mock price (Decimal("1")) with a real CoinGecko price feed that runs as a background daemon thread. Key behaviour: - Fetches OPG/USD price from CoinGecko /simple/token_price/base at startup and every 5 minutes thereafter (well within free-tier rate limits) - Retries up to 3 times per refresh cycle with 10s delay between attempts; exits retry loop immediately on 429 (rate-limited) to avoid hammering - Retains the last known good price on exhausted retries so live traffic is not disrupted by a transient CoinGecko outage - Logs a WARNING when the cached price is older than 2× the refresh interval (background loop may be stuck) - Tracks last_success, last_error, consecutive_failures, total_fetches, total_errors via get_status() / get_price_feed_status() - get_price() raises ValueError when no price has ever been fetched, which propagates through dynamic_session_cost_calculator and the existing strict _resolve_session_request_cost monkey-patch to return HTTP 500 rather than silently charging an incorrect amount Also adds: - 29 unit tests (all mocked, no network required) covering the fetch helper, retry logic, rate-limit handling, stale warning, stats, and module-level singleton functions - 4 integration tests that hit the live CoinGecko API; OPG-specific tests skip gracefully until the token is fully indexed Co-Authored-By: Claude Sonnet 4.6 * refactor: Reorganize price feed into tee_gateway/price_feed/ package Moves opg_price_feed.py into a dedicated package matching the layout of the existing tee_gateway/heartbeat/ package: tee_gateway/price_feed/ __init__.py — re-exports public API (OPGPriceFeed, PriceFeedConfig, start_price_feed, get_opg_price_usd, get_price_feed_status) config.py — all constants (COINGECKO_BASE_URL, COINGECKO_PLATFORM, FETCH_TIMEOUT, refresh/retry defaults, stale threshold) plus the PriceFeedConfig frozen dataclass feed.py — OPGPriceFeed class, fetch_opg_price(), singleton helpers Also renames the test files to match: test_opg_price_feed.py -> test_price_feed.py test_opg_price_feed_integration.py -> test_price_feed_integration.py No behaviour changes — all 29 unit tests still pass. Co-Authored-By: Claude Sonnet 4.6 * refactor: dependency injection for OPG price feed via make_cost_calculator Replace module-level singleton in price_feed with a factory/closure pattern. make_cost_calculator(price_feed) in util.py binds an OPGPriceFeed instance explicitly, eliminating hidden global state. Tests updated: TestModuleLevelFunctions removed, TestMakeCostCalculator added (10 cases). Also fix missing Any import in feed.py and unused PriceFeedConfig import. Co-Authored-By: Claude Sonnet 4.6 * refactor: replace make_cost_calculator factory with calculate_session_cost + named function Remove the factory/closure pattern and _PriceSource Protocol. util.py now exports calculate_session_cost(context, get_price) — a plain function that accepts a Callable[[], Decimal]. __main__.py wires it via a named _session_cost_calculator function that passes _price_feed.get_price directly. Tests updated to match the new signature. Co-Authored-By: Claude Sonnet 4.6 * feat: include price feed status in /health response Co-Authored-By: Claude Sonnet 4.6 * test: verify calculate_session_cost fetches live price on every call Co-Authored-By: Claude Sonnet 4.6 * fix: update test_pricing.py for calculate_session_cost signature change Replace dynamic_session_cost_calculator import with calculate_session_cost and pass _get_price (OPG=$1.00) to all call sites. Co-Authored-By: Claude Sonnet 4.6 * fix: add attr-defined to type: ignore on monkey-patch line Co-Authored-By: Claude Sonnet 4.6 * fix: address Copilot review comments on price feed - start() is now non-blocking: initial fetch runs inside the background thread instead of blocking the caller - start() is idempotent: duplicate calls are a no-op if thread is alive - Clear last_error on successful refresh so /health reflects recovery - Gate integration tests behind RUN_INTEGRATION_TESTS env var to prevent real CoinGecko calls in CI by default - Fix stale docstring references to make_cost_calculator → calculate_session_cost Co-Authored-By: Claude Sonnet 4.6 * Update .github/workflows/test.yml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: update stale dynamic_session_cost_calculator references Rename test classes and update comment/error message in __main__.py to match the current calculate_session_cost / _session_cost_calculator names. Co-Authored-By: Claude Sonnet 4.6 * feat: add pre-inference pricing gate and remove ineffective monkey-patch Replace the _strict_resolve_session_request_cost monkey-patch (which was patching a method not called in the upto session flow) with a proper before_request hook that rejects inference requests early if pricing would fail: - 503 if the OPG price feed has no valid price yet - 400 if the requested model is not in the registry _session_cost_calculator now logs CRITICAL with full traceback on any post-inference cost failure (e.g. missing usage field) so uncharged requests are never silently missed. Co-Authored-By: Claude Sonnet 4.6 * fix: address Copilot review comments (docstrings, defensive coding, thread safety) - Update start() docstring: replace monkey-patch reference with pre-inference gate - Make fetch_opg_price() validate response.json() is a dict before calling .get(), raising ValueError instead of AttributeError on unexpected CoinGecko response shapes - Make start() idempotency check thread-safe under _lock - Clarify CRITICAL log: provider error, x402 swallows exception so client is not charged - Update calculate_session_cost docstring: replace monkey-patch reference with pre-inference gate and CRITICAL log behavior Co-Authored-By: Claude Sonnet 4.6 * feat: TGE fallback price and CoinGecko sanity checks Before the TGE cutover (2026-04-21 12:30 UTC) get_price() returns a fixed $0.10 fallback so inference requests can be priced before OPG is listed on CoinGecko. After the cutover the live cached price is used. Also adds a guard in fetch_opg_price() rejecting non-positive or non-finite prices from the API response. Co-Authored-By: Claude Sonnet 4.6 * feat: expire stale OPG price after 4 hours get_price() now raises ValueError if the cached price is older than 4 hours, preventing billing on a price that is too outdated. A warning is still logged at the existing 2 × refresh_interval threshold (~10 min) as an early signal. The pre-inference gate in __main__.py will surface this as a 503. Co-Authored-By: Claude Sonnet 4.6 * Update tee_gateway/price_feed/feed.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tee_gateway/util.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .github/workflows/test.yml | 5 +- tee_gateway/__main__.py | 146 ++--- tee_gateway/price_feed/__init__.py | 7 + tee_gateway/price_feed/config.py | 52 ++ tee_gateway/price_feed/feed.py | 269 +++++++++ tee_gateway/test/test_price_feed.py | 516 ++++++++++++++++++ .../test/test_price_feed_integration.py | 133 +++++ tee_gateway/util.py | 73 +-- tests/test_pricing.py | 54 +- 9 files changed, 1085 insertions(+), 170 deletions(-) create mode 100644 tee_gateway/price_feed/__init__.py create mode 100644 tee_gateway/price_feed/config.py create mode 100644 tee_gateway/price_feed/feed.py create mode 100644 tee_gateway/test/test_price_feed.py create mode 100644 tee_gateway/test/test_price_feed_integration.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 835c886..273d4c9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,4 +15,7 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Run unit tests - run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tests/test_pricing.py -v --import-mode=importlib + run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tee_gateway/test/test_price_feed.py tests/test_pricing.py -v --import-mode=importlib + # To also run integration tests (real CoinGecko network calls), add: + # env: + # RUN_INTEGRATION_TESTS: "1" diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index fd36b24..d5662fc 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -34,9 +34,10 @@ from x402.server import x402ResourceServerSync from x402.session import SessionStore import x402.http.middleware.flask as x402_flask -import types as _types -from .util import dynamic_session_cost_calculator +from .util import calculate_session_cost +from .model_registry import get_model_config +from .price_feed import OPGPriceFeed from .definitions import ( EVM_PAYMENT_ADDRESS, BASE_MAINNET_NETWORK, @@ -107,6 +108,13 @@ def _shutdown_heartbeat(): atexit.register(_shutdown_heartbeat) +# --------------------------------------------------------------------------- +# OPG price feed — start before x402 middleware so the first request can be +# priced correctly. Runs as a daemon thread; no cleanup needed on exit. +# --------------------------------------------------------------------------- +_price_feed = OPGPriceFeed() +_price_feed.start() + facilitator = HTTPFacilitatorClientSync(FacilitatorConfig(url=FACILITATOR_URL)) server = x402ResourceServerSync(facilitator) store = SessionStore() @@ -303,6 +311,7 @@ def health(): "status": "OK", "version": "1.0.0", "tee_enabled": True, + "price_feed": _price_feed.get_status(), }, 200 @@ -374,6 +383,26 @@ def _patched_read_body_bytes(environ): x402_flask._read_body_bytes = _patched_read_body_bytes + +def _session_cost_calculator(ctx: dict) -> int: + # Post-inference cost calculation — response already sent to client. + # Predictable failures (unknown price, unknown model) are blocked by the + # pre-inference gate; any exception here indicates a provider-side error + # (e.g. missing usage field in the LLM response). The x402 middleware + # swallows the exception in close(), so the client is not charged. + # Log CRITICAL so provider errors are never silently missed. + try: + return calculate_session_cost(ctx, _price_feed.get_price) + except Exception as exc: + logger.critical( + "Post-inference cost calculation failed (provider error) — " + "client was NOT charged: %s", + exc, + exc_info=True, + ) + raise + + _payment_mw = payment_middleware( application, routes=routes, @@ -381,102 +410,39 @@ def _patched_read_body_bytes(environ): session_store=store, cost_per_request=100000000000000, # static precheck/fallback estimate session_idle_timeout=100, - session_cost_calculator=dynamic_session_cost_calculator, + session_cost_calculator=_session_cost_calculator, ) # --------------------------------------------------------------------------- -# Strict cost-resolution patch -# -# Why this exists -# --------------- -# The upstream x402 PaymentMiddleware._resolve_session_request_cost wraps the -# call to the session_cost_calculator in a broad try/except. If the calculator -# raises (e.g. ValueError for an unrecognised model name, KeyError for missing -# usage data), the exception is swallowed and the middleware silently falls back -# to the static session maximum (CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND / -# CHAT_COMPLETIONS_USDC_AMOUNT). That silent fallback means: -# • The client is charged the full pre-check cap instead of actual usage. -# • The server has no visible indication that pricing failed. +# Pre-inference pricing gate # -# The fix -# ------- -# We replace _resolve_session_request_cost with our own implementation that is -# identical to upstream, except the cost-calculator call is NOT wrapped in a -# try/except. Any exception from dynamic_session_cost_calculator() therefore -# propagates up through the middleware and Flask, producing a proper HTTP 500 -# response to the client instead of an incorrect silent charge. +# In the upto session scheme the response is streamed to the client before +# cost is settled, so a post-inference pricing failure cannot be surfaced as +# an HTTP error. Instead we validate everything that can be checked up-front +# and reject the request early if pricing would fail: +# 1. Price feed has a valid OPG/USD price (CoinGecko fetch succeeded). +# 2. The requested model is in the registry (has a known per-token price). # --------------------------------------------------------------------------- -def _strict_resolve_session_request_cost( - self, - *, - method: str, - path: str, - request_body_bytes: bytes, - response_body_bytes: bytes, - payment_payload: object, - payment_requirements: object, - status_code: int | None, - output_object: object = None, - is_streaming: bool = False, -) -> int: - """Replacement for PaymentMiddleware._resolve_session_request_cost. - - Identical to the upstream implementation except that exceptions raised by - the dynamic cost calculator are NOT caught. This means a request whose - cost cannot be determined (unknown model, missing usage data, etc.) will - result in a 500 error rather than silently falling back to the static cap - amount and charging the user an incorrect amount. - """ - from x402.http.middleware.flask import _parse_json_bytes as _x402_parse_json # noqa: PLC0415 - - default_cost = self._get_session_cost(payment_requirements) - if not self._should_charge_response(status_code): - return default_cost - if not callable(self._session_cost_calculator): - return default_cost - - request_object = _x402_parse_json(request_body_bytes) - response_object = ( - output_object - if output_object is not None - else _x402_parse_json(response_body_bytes) - ) - - callback_context = { - "method": method, - "path": path, - "status_code": status_code, - "is_streaming": is_streaming, - "request_body_bytes": request_body_bytes, - "response_body_bytes": response_body_bytes, - "request_json": request_object - if isinstance(request_object, (dict, list)) - else None, - "response_json": response_object - if isinstance(response_object, (dict, list)) - else None, - "response_object": response_object, - "payment_payload": payment_payload, - "payment_requirements": payment_requirements, - "default_cost": default_cost, - } - - # Do NOT catch exceptions here — let them propagate so the request fails - # with a 500 rather than silently charging the static fallback amount. - dynamic_cost = self._session_cost_calculator(callback_context) - if dynamic_cost is None: - raise ValueError( - f"dynamic_session_cost_calculator returned None for {method} {path}; " - "cannot determine request cost" - ) - return self._coerce_non_negative_int(dynamic_cost) - +@application.before_request +def _check_pricing_ready(): + if request.path not in ("/v1/chat/completions", "/v1/completions"): + return + try: + _price_feed.get_price() + except ValueError as exc: + logger.warning("Rejecting inference request — price feed unavailable: %s", exc) + return jsonify({"error": f"Pricing unavailable: {exc}"}), 503 + + body = request.get_json(silent=True, cache=True) or {} + model = body.get("model") + if model: + try: + get_model_config(model) + except ValueError: + return jsonify({"error": f"Model '{model}' is not supported"}), 400 -_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign, attr-defined] - _strict_resolve_session_request_cost, _payment_mw -) logger.info("x402 payment middleware initialized") diff --git a/tee_gateway/price_feed/__init__.py b/tee_gateway/price_feed/__init__.py new file mode 100644 index 0000000..1349825 --- /dev/null +++ b/tee_gateway/price_feed/__init__.py @@ -0,0 +1,7 @@ +from .config import PriceFeedConfig +from .feed import OPGPriceFeed + +__all__ = [ + "OPGPriceFeed", + "PriceFeedConfig", +] diff --git a/tee_gateway/price_feed/config.py b/tee_gateway/price_feed/config.py new file mode 100644 index 0000000..a8742dc --- /dev/null +++ b/tee_gateway/price_feed/config.py @@ -0,0 +1,52 @@ +""" +Configuration constants and dataclass for the OPG price feed. +""" + +from dataclasses import dataclass +from datetime import datetime, timezone +from decimal import Decimal + + +# --------------------------------------------------------------------------- +# CoinGecko API +# --------------------------------------------------------------------------- +COINGECKO_BASE_URL = "https://api.coingecko.com/api/v3" +COINGECKO_PLATFORM = "base" # Base mainnet platform identifier on CoinGecko +FETCH_TIMEOUT = 10 # seconds per HTTP request + +# --------------------------------------------------------------------------- +# Refresh / retry defaults +# --------------------------------------------------------------------------- +DEFAULT_REFRESH_INTERVAL = 300 # 5 minutes between background refresh cycles +DEFAULT_MAX_RETRIES = 3 # attempts per refresh cycle before giving up +DEFAULT_RETRY_DELAY = 10 # seconds between retry attempts within a cycle + +# --------------------------------------------------------------------------- +# TGE (Token Generation Event) fallback +# --------------------------------------------------------------------------- +# Before the TGE cutover, OPG is not yet listed on CoinGecko. Return a fixed +# fallback price so inference requests can be priced immediately at launch. +# After the cutover, the live CoinGecko price is used. +TGE_CUTOVER_UTC = datetime(2026, 4, 21, 12, 30, 0, tzinfo=timezone.utc) +TGE_FALLBACK_PRICE_USD = Decimal("0.10") + +# --------------------------------------------------------------------------- +# Stale-price thresholds +# --------------------------------------------------------------------------- +# get_price() logs WARNING when last successful fetch is older than +# STALE_WARNING_MULTIPLIER × refresh_interval seconds. +STALE_WARNING_MULTIPLIER = 2 + +# get_price() raises ValueError when last successful fetch is older than +# STALE_PRICE_MAX_AGE seconds — at this point the cached price is considered +# too outdated to use for billing. +STALE_PRICE_MAX_AGE = 4 * 60 * 60 # 4 hours + + +@dataclass(frozen=True) +class PriceFeedConfig: + """Runtime configuration for the OPG price feed background service.""" + + refresh_interval: int = DEFAULT_REFRESH_INTERVAL + max_retries: int = DEFAULT_MAX_RETRIES + retry_delay: float = DEFAULT_RETRY_DELAY diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py new file mode 100644 index 0000000..6875f68 --- /dev/null +++ b/tee_gateway/price_feed/feed.py @@ -0,0 +1,269 @@ +""" +Background OPG/USD price feed using the CoinGecko public API. + +Runs as a daemon thread that proactively refreshes the OPG token price at a +configurable interval, with retry on per-cycle fetch failure and early exit on +rate limiting. + +Usage +----- +Create an ``OPGPriceFeed`` instance in the application entry point, call +``start()``, then pass it explicitly to wherever the price is needed (e.g. +``calculate_session_cost(...)`` in ``util.py``). +""" + +import logging +import threading +import time +from datetime import datetime, timezone +from decimal import Decimal +from typing import Any, Optional + +import requests + +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS +from tee_gateway.price_feed.config import ( + COINGECKO_BASE_URL, + COINGECKO_PLATFORM, + DEFAULT_MAX_RETRIES, + DEFAULT_REFRESH_INTERVAL, + DEFAULT_RETRY_DELAY, + FETCH_TIMEOUT, + STALE_PRICE_MAX_AGE, + STALE_WARNING_MULTIPLIER, + TGE_CUTOVER_UTC, + TGE_FALLBACK_PRICE_USD, +) + +logger = logging.getLogger("llm_server.price_feed") + + +class OPGPriceFeed: + """Fetches and caches the OPG/USD price from CoinGecko in a background thread.""" + + def __init__( + self, + refresh_interval: int = DEFAULT_REFRESH_INTERVAL, + max_retries: int = DEFAULT_MAX_RETRIES, + retry_delay: float = DEFAULT_RETRY_DELAY, + ) -> None: + self._refresh_interval = refresh_interval + self._max_retries = max_retries + self._retry_delay = retry_delay + + self._price: Optional[Decimal] = None + self._lock = threading.Lock() + self._thread: Optional[threading.Thread] = None + + # Status tracking — updated under _lock on every refresh cycle outcome. + self.last_success: Optional[float] = None # epoch seconds of last good fetch + self.last_error: Optional[str] = None # description of last failure (if any) + self.consecutive_failures: int = 0 # reset to 0 on any successful fetch + self.total_fetches: int = 0 # cumulative successful fetches + self.total_errors: int = 0 # cumulative failed refresh cycles + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def start(self) -> None: + """Launch the background refresh loop, including the initial price fetch. + + The initial fetch runs inside the background thread so startup is + non-blocking. ``get_price()`` will raise ``ValueError`` until the + first successful fetch completes; until then, inference requests are + rejected by the pre-inference pricing gate in ``__main__.py``. + + Idempotent — calling ``start()`` on an already-running feed is a no-op. + Thread-safe: the check-and-start is performed under ``_lock``. + """ + with self._lock: + if self._thread is not None and self._thread.is_alive(): + logger.info( + "OPG price feed already running, ignoring duplicate start()" + ) + return + self._thread = threading.Thread( + target=self._run_with_initial_fetch, + name="opg-price-feed", + daemon=True, + ) + self._thread.start() + logger.info( + "OPG price feed started (refresh_interval=%ds, max_retries=%d)", + self._refresh_interval, + self._max_retries, + ) + + def get_price(self) -> Decimal: + """Return the latest cached OPG/USD price. + + Before the TGE cutover (``TGE_CUTOVER_UTC``), returns the fixed + ``TGE_FALLBACK_PRICE_USD`` so requests can be priced before OPG is + listed on CoinGecko. After the cutover the live cached price is used. + + Raises ``ValueError`` if no price has been successfully fetched yet + (post-TGE only). Logs a warning (but still returns the price) if the + cached value is older than ``STALE_WARNING_MULTIPLIER * refresh_interval`` + seconds — this indicates the background loop has missed at least one + refresh cycle and may be experiencing persistent errors. + """ + if datetime.now(timezone.utc) < TGE_CUTOVER_UTC: + return TGE_FALLBACK_PRICE_USD + + now = time.time() + with self._lock: + if self._price is None: + raise ValueError( + "OPG price not yet available — " + "price feed has not completed a successful fetch" + ) + if self.last_success is None: + raise ValueError( + "OPG price not yet available — " + "price feed has not completed a successful fetch" + ) + age = now - self.last_success + if age > STALE_PRICE_MAX_AGE: + raise ValueError( + f"OPG price data expired: last successful fetch was {age:.0f}s ago " + f"(max: {STALE_PRICE_MAX_AGE}s); consecutive failures: " + f"{self.consecutive_failures}" + ) + stale_threshold = self._refresh_interval * STALE_WARNING_MULTIPLIER + if age > stale_threshold: + logger.warning( + "OPG price data is stale: last successful fetch was %.0fs ago " + "(threshold: %.0fs); consecutive failures: %d", + age, + stale_threshold, + self.consecutive_failures, + ) + return self._price + + def get_status(self) -> dict[str, Any]: + """Return a health snapshot suitable for logging or a /health endpoint.""" + with self._lock: + return { + "price_usd": float(self._price) if self._price is not None else None, + "last_success": self.last_success, + "last_error": self.last_error, + "consecutive_failures": self.consecutive_failures, + "total_fetches": self.total_fetches, + "total_errors": self.total_errors, + "refresh_interval": self._refresh_interval, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _run_with_initial_fetch(self) -> None: + self._refresh_price() + while True: + time.sleep(self._refresh_interval) + self._refresh_price() + + def _refresh_price(self) -> None: + """Attempt to fetch a fresh price, retrying on transient failure. + + - On success: updates the cached price and resets ``consecutive_failures``. + - On HTTP 429: logs a rate-limit warning and exits the retry loop early + (no point hammering a rate-limited API). + - On exhausted retries: increments ``consecutive_failures`` and retains + the last known good price so live traffic is not disrupted by a + transient CoinGecko outage. + """ + last_exc: Optional[Exception] = None + + for attempt in range(1, self._max_retries + 1): + try: + price = fetch_opg_price() + with self._lock: + self._price = price + self.last_success = time.time() + self.last_error = None + self.consecutive_failures = 0 + self.total_fetches += 1 + logger.info( + "OPG price updated: $%.6f USD (attempt %d/%d)", + float(price), + attempt, + self._max_retries, + ) + return + except requests.exceptions.HTTPError as exc: + last_exc = exc + status_code = ( + exc.response.status_code if exc.response is not None else None + ) + if status_code == 429: + logger.warning( + "CoinGecko rate limit hit (429) on attempt %d/%d; " + "skipping remaining retries for this cycle", + attempt, + self._max_retries, + ) + break + logger.warning( + "OPG price fetch attempt %d/%d failed (HTTP %s): %s", + attempt, + self._max_retries, + status_code, + exc, + ) + except Exception as exc: + last_exc = exc + logger.warning( + "OPG price fetch attempt %d/%d failed: %s", + attempt, + self._max_retries, + exc, + ) + + if attempt < self._max_retries: + time.sleep(self._retry_delay) + + # All attempts exhausted (or rate-limited out) — record the failure. + with self._lock: + self.total_errors += 1 + self.consecutive_failures += 1 + self.last_error = str(last_exc) if last_exc is not None else "unknown error" + + logger.error( + "OPG price refresh failed (consecutive failures: %d); " + "retaining last known price (%s)", + self.consecutive_failures, + self._price, + ) + + +def fetch_opg_price() -> Decimal: + """Fetch the current OPG/USD price from CoinGecko. Raises on any error.""" + url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" + params = { + "contract_addresses": BASE_MAINNET_OPG_ADDRESS, + "vs_currencies": "usd", + } + response = requests.get(url, params=params, timeout=FETCH_TIMEOUT) + response.raise_for_status() + + data: Any = response.json() + if not isinstance(data, dict): + raise ValueError( + f"Unexpected CoinGecko response for {BASE_MAINNET_OPG_ADDRESS}: {data!r}" + ) + # CoinGecko keys the result by the lowercased contract address. + price_entry = data.get(BASE_MAINNET_OPG_ADDRESS.lower()) + if not isinstance(price_entry, dict) or "usd" not in price_entry: + raise ValueError( + f"Unexpected CoinGecko response for {BASE_MAINNET_OPG_ADDRESS}: {data!r}" + ) + + price = Decimal(str(price_entry["usd"])) + if not price.is_finite() or price <= 0: + raise ValueError( + f"Invalid price from CoinGecko for {BASE_MAINNET_OPG_ADDRESS}: " + f"{price_entry['usd']!r}" + ) + return price diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py new file mode 100644 index 0000000..fe9bb8a --- /dev/null +++ b/tee_gateway/test/test_price_feed.py @@ -0,0 +1,516 @@ +""" +Unit tests for tee_gateway.price_feed and tee_gateway.util.calculate_session_cost. + +All external HTTP calls are mocked — no network access required. + +Test classes +------------ +TestFetchOPGPrice — the raw fetch_opg_price() helper in feed.py +TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() (retry, rate-limit, stats) +TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) +TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots +TestCalculateSessionCost — calculate_session_cost(context, get_price) in util.py +""" + +import time +import unittest +from datetime import datetime, timezone +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import requests + +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS +from tee_gateway.price_feed import OPGPriceFeed +from tee_gateway.price_feed.feed import fetch_opg_price +from tee_gateway.util import calculate_session_cost + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +OPG_ADDRESS_LOWER = BASE_MAINNET_OPG_ADDRESS.lower() +SAMPLE_PRICE = Decimal("0.042") +SAMPLE_PRICE_FLOAT = 0.042 + +# Patch target prefix — all mocks go through the feed module. +_FEED = "tee_gateway.price_feed.feed" + +# A datetime well after the TGE cutover so get_price() uses the cached price. +_POST_TGE = datetime(2026, 4, 22, 0, 0, 0, tzinfo=timezone.utc) + + +def _mock_response(status_code: int = 200, json_body: dict | None = None) -> MagicMock: + """Build a minimal mock requests.Response.""" + mock = MagicMock() + mock.status_code = status_code + mock.json.return_value = json_body or {} + if status_code >= 400: + http_err = requests.exceptions.HTTPError(response=mock) + mock.raise_for_status.side_effect = http_err + else: + mock.raise_for_status.return_value = None + return mock + + +def _coingecko_success_body() -> dict: + return {OPG_ADDRESS_LOWER: {"usd": SAMPLE_PRICE_FLOAT}} + + +# --------------------------------------------------------------------------- +# TestFetchOPGPrice +# --------------------------------------------------------------------------- + + +class TestFetchOPGPrice(unittest.TestCase): + """Tests for the fetch_opg_price() free function in feed.py.""" + + @patch(f"{_FEED}.requests.get") + def test_happy_path_returns_decimal(self, mock_get): + mock_get.return_value = _mock_response(200, _coingecko_success_body()) + price = fetch_opg_price() + self.assertIsInstance(price, Decimal) + self.assertEqual(price, Decimal(str(SAMPLE_PRICE_FLOAT))) + + @patch(f"{_FEED}.requests.get") + def test_passes_correct_params(self, mock_get): + mock_get.return_value = _mock_response(200, _coingecko_success_body()) + fetch_opg_price() + _, kwargs = mock_get.call_args + self.assertIn("contract_addresses", kwargs["params"]) + self.assertEqual(kwargs["params"]["vs_currencies"], "usd") + self.assertIn( + "base", kwargs["url"] if "url" in kwargs else mock_get.call_args[0][0] + ) + + @patch(f"{_FEED}.requests.get") + def test_raises_on_http_500(self, mock_get): + mock_get.return_value = _mock_response(500) + with self.assertRaises(requests.exceptions.HTTPError): + fetch_opg_price() + + @patch(f"{_FEED}.requests.get") + def test_raises_on_http_429(self, mock_get): + mock_get.return_value = _mock_response(429) + with self.assertRaises(requests.exceptions.HTTPError) as ctx: + fetch_opg_price() + self.assertEqual(ctx.exception.response.status_code, 429) + + @patch(f"{_FEED}.requests.get") + def test_raises_on_empty_response_body(self, mock_get): + mock_get.return_value = _mock_response(200, {}) + with self.assertRaises(ValueError, msg="should raise when address key absent"): + fetch_opg_price() + + @patch(f"{_FEED}.requests.get") + def test_raises_when_usd_key_missing(self, mock_get): + mock_get.return_value = _mock_response(200, {OPG_ADDRESS_LOWER: {"eur": 0.04}}) + with self.assertRaises(ValueError): + fetch_opg_price() + + @patch(f"{_FEED}.requests.get") + def test_raises_when_address_entry_is_empty_dict(self, mock_get): + """CoinGecko returns {address: {}} for known-but-unpriced tokens (current OPG behaviour).""" + mock_get.return_value = _mock_response(200, {OPG_ADDRESS_LOWER: {}}) + with self.assertRaises(ValueError, msg="empty price entry should raise"): + fetch_opg_price() + + @patch(f"{_FEED}.requests.get") + def test_raises_on_network_error(self, mock_get): + mock_get.side_effect = requests.exceptions.ConnectionError("timeout") + with self.assertRaises(requests.exceptions.ConnectionError): + fetch_opg_price() + + +# --------------------------------------------------------------------------- +# TestOPGPriceFeedRefresh +# --------------------------------------------------------------------------- + + +class TestOPGPriceFeedRefresh(unittest.TestCase): + """Tests for OPGPriceFeed._refresh_price() — retry logic, rate-limit, stats.""" + + def _feed(self, **kwargs) -> OPGPriceFeed: + defaults = {"refresh_interval": 300, "max_retries": 3, "retry_delay": 0} + defaults.update(kwargs) + return OPGPriceFeed(**defaults) + + @patch(f"{_FEED}.fetch_opg_price") + def test_successful_refresh_sets_price(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = self._feed() + feed._refresh_price() + self.assertEqual(feed._price, SAMPLE_PRICE) + + @patch(f"{_FEED}.fetch_opg_price") + def test_successful_refresh_updates_stats(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = self._feed() + feed._refresh_price() + self.assertEqual(feed.total_fetches, 1) + self.assertEqual(feed.total_errors, 0) + self.assertEqual(feed.consecutive_failures, 0) + self.assertIsNotNone(feed.last_success) + + @patch(f"{_FEED}.fetch_opg_price") + def test_retry_on_transient_failure_then_success(self, mock_fetch): + mock_fetch.side_effect = [ + ValueError("transient"), + ValueError("transient"), + SAMPLE_PRICE, + ] + feed = self._feed(max_retries=3, retry_delay=0) + feed._refresh_price() + self.assertEqual(feed._price, SAMPLE_PRICE) + self.assertEqual(mock_fetch.call_count, 3) + self.assertEqual(feed.total_fetches, 1) + self.assertEqual(feed.total_errors, 0) + + @patch(f"{_FEED}.fetch_opg_price") + def test_exhausted_retries_records_error_stats(self, mock_fetch): + mock_fetch.side_effect = ValueError("always fails") + feed = self._feed(max_retries=3, retry_delay=0) + feed._refresh_price() + self.assertEqual(feed.total_errors, 1) + self.assertEqual(feed.consecutive_failures, 1) + self.assertIsNotNone(feed.last_error) + self.assertEqual(feed.total_fetches, 0) + + @patch(f"{_FEED}.fetch_opg_price") + def test_exhausted_retries_keeps_last_known_price(self, mock_fetch): + feed = self._feed(max_retries=2, retry_delay=0) + feed._price = SAMPLE_PRICE + feed.last_success = time.time() + mock_fetch.side_effect = ValueError("fail") + feed._refresh_price() + self.assertEqual(feed._price, SAMPLE_PRICE) + + @patch(f"{_FEED}.fetch_opg_price") + def test_success_after_failures_resets_consecutive_failures(self, mock_fetch): + feed = self._feed(max_retries=1, retry_delay=0) + mock_fetch.side_effect = ValueError("fail") + feed._refresh_price() + self.assertEqual(feed.consecutive_failures, 1) + mock_fetch.side_effect = None + mock_fetch.return_value = SAMPLE_PRICE + feed._refresh_price() + self.assertEqual(feed.consecutive_failures, 0) + + @patch(f"{_FEED}.fetch_opg_price") + def test_rate_limit_breaks_retry_loop_immediately(self, mock_fetch): + resp = MagicMock() + resp.status_code = 429 + mock_fetch.side_effect = requests.exceptions.HTTPError(response=resp) + feed = self._feed(max_retries=3, retry_delay=0) + feed._refresh_price() + self.assertEqual(mock_fetch.call_count, 1) + self.assertEqual(feed.total_errors, 1) + + @patch(f"{_FEED}.time.sleep") + @patch(f"{_FEED}.fetch_opg_price") + def test_retry_delay_called_between_attempts(self, mock_fetch, mock_sleep): + mock_fetch.side_effect = [ValueError("fail"), ValueError("fail"), SAMPLE_PRICE] + feed = self._feed(max_retries=3, retry_delay=5) + feed._refresh_price() + self.assertEqual(mock_sleep.call_count, 2) + mock_sleep.assert_called_with(5) + + @patch(f"{_FEED}.time.sleep") + @patch(f"{_FEED}.fetch_opg_price") + def test_no_sleep_after_last_failed_attempt(self, mock_fetch, mock_sleep): + mock_fetch.side_effect = ValueError("always fails") + feed = self._feed(max_retries=3, retry_delay=5) + feed._refresh_price() + self.assertEqual(mock_sleep.call_count, 2) + + +# --------------------------------------------------------------------------- +# TestOPGPriceFeedGetPrice +# --------------------------------------------------------------------------- + + +class TestOPGPriceFeedGetPrice(unittest.TestCase): + """Tests for OPGPriceFeed.get_price() behaviour.""" + + @patch(f"{_FEED}.datetime") + def test_raises_before_any_successful_fetch(self, mock_dt): + mock_dt.now.return_value = _POST_TGE + feed = OPGPriceFeed() + with self.assertRaises(ValueError) as ctx: + feed.get_price() + self.assertIn("not yet available", str(ctx.exception)) + + @patch(f"{_FEED}.datetime") + @patch(f"{_FEED}.fetch_opg_price") + def test_returns_price_after_successful_refresh(self, mock_fetch, mock_dt): + mock_dt.now.return_value = _POST_TGE + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(retry_delay=0) + feed._refresh_price() + self.assertEqual(feed.get_price(), SAMPLE_PRICE) + + @patch(f"{_FEED}.datetime") + @patch(f"{_FEED}.time.time") + @patch(f"{_FEED}.fetch_opg_price") + def test_warns_when_price_is_stale(self, mock_fetch, mock_time, mock_dt): + mock_dt.now.return_value = _POST_TGE + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) + + mock_time.return_value = 0.0 + feed._refresh_price() + + # Advance past stale threshold (300 * 2 = 600s) + mock_time.return_value = 601.0 + + with self.assertLogs("llm_server.price_feed", level="WARNING") as log_ctx: + price = feed.get_price() + + self.assertEqual(price, SAMPLE_PRICE) + self.assertTrue(any("stale" in line.lower() for line in log_ctx.output)) + + @patch(f"{_FEED}.datetime") + @patch(f"{_FEED}.time.time") + @patch(f"{_FEED}.fetch_opg_price") + def test_raises_when_price_exceeds_max_age(self, mock_fetch, mock_time, mock_dt): + mock_dt.now.return_value = _POST_TGE + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(retry_delay=0) + + mock_time.return_value = 0.0 + feed._refresh_price() + + # Advance past the 4-hour max age + mock_time.return_value = 4 * 60 * 60 + 1.0 + + with self.assertRaises(ValueError) as ctx: + feed.get_price() + self.assertIn("expired", str(ctx.exception)) + + @patch(f"{_FEED}.time.time") + @patch(f"{_FEED}.fetch_opg_price") + def test_no_stale_warning_when_price_is_fresh(self, mock_fetch, mock_time): + import logging + + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) + + mock_time.return_value = 0.0 + feed._refresh_price() + mock_time.return_value = 100.0 # well within threshold + + with self.assertLogs("llm_server.price_feed", level="DEBUG") as log_ctx: + logging.getLogger("llm_server.price_feed").debug("sentinel") + feed.get_price() + + warning_lines = [ + line + for line in log_ctx.output + if "WARNING" in line and "stale" in line.lower() + ] + self.assertEqual(warning_lines, []) + + +# --------------------------------------------------------------------------- +# TestOPGPriceFeedStatus +# --------------------------------------------------------------------------- + + +class TestOPGPriceFeedStatus(unittest.TestCase): + """Tests for OPGPriceFeed.get_status() snapshot.""" + + def test_initial_status_has_no_price(self): + feed = OPGPriceFeed() + status = feed.get_status() + self.assertIsNone(status["price_usd"]) + self.assertIsNone(status["last_success"]) + self.assertEqual(status["consecutive_failures"], 0) + self.assertEqual(status["total_fetches"], 0) + self.assertEqual(status["total_errors"], 0) + + @patch(f"{_FEED}.fetch_opg_price") + def test_status_reflects_successful_fetch(self, mock_fetch): + mock_fetch.return_value = SAMPLE_PRICE + feed = OPGPriceFeed(retry_delay=0) + feed._refresh_price() + status = feed.get_status() + self.assertAlmostEqual(status["price_usd"], float(SAMPLE_PRICE), places=6) + self.assertIsNotNone(status["last_success"]) + self.assertEqual(status["total_fetches"], 1) + self.assertEqual(status["consecutive_failures"], 0) + + @patch(f"{_FEED}.fetch_opg_price") + def test_status_reflects_failed_cycle(self, mock_fetch): + mock_fetch.side_effect = ValueError("fail") + feed = OPGPriceFeed(max_retries=1, retry_delay=0) + feed._refresh_price() + status = feed.get_status() + self.assertIsNone(status["price_usd"]) + self.assertEqual(status["total_errors"], 1) + self.assertEqual(status["consecutive_failures"], 1) + self.assertIsNotNone(status["last_error"]) + + def test_status_includes_refresh_interval(self): + feed = OPGPriceFeed(refresh_interval=600) + self.assertEqual(feed.get_status()["refresh_interval"], 600) + + @patch(f"{_FEED}.fetch_opg_price") + def test_status_accumulates_multiple_error_cycles(self, mock_fetch): + mock_fetch.side_effect = ValueError("fail") + feed = OPGPriceFeed(max_retries=1, retry_delay=0) + feed._refresh_price() + feed._refresh_price() + feed._refresh_price() + status = feed.get_status() + self.assertEqual(status["total_errors"], 3) + self.assertEqual(status["consecutive_failures"], 3) + + +# --------------------------------------------------------------------------- +# TestMakeCostCalculator +# --------------------------------------------------------------------------- + +_ASSET_ADDR = "0xdeadbeef" +_ASSET_ADDR_LOWER = _ASSET_ADDR.lower() +_ASSET_DECIMALS = 18 + + +def _make_payment_requirements(asset: str = _ASSET_ADDR) -> dict: + return {"asset": asset, "price": {"amount": "1000000000000000000", "asset": asset}} + + +def _make_context( + model: str = "gpt-4.1-mini", + input_tokens: int = 100, + output_tokens: int = 50, + price_usd: Decimal = Decimal("0.10"), + asset: str = _ASSET_ADDR, +) -> dict: + return { + "request_json": {"model": model}, + "response_json": { + "model": model, + "usage": { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + }, + }, + "payment_requirements": _make_payment_requirements(asset), + "method": "POST", + "path": "/v1/chat/completions", + "status_code": 200, + "is_streaming": False, + "request_body_bytes": b"", + "response_body_bytes": b"", + "default_cost": 10**18, + } + + +def _make_get_price(price_usd: Decimal = Decimal("0.10")) -> MagicMock: + mock = MagicMock(return_value=price_usd) + return mock + + +class TestCalculateSessionCost(unittest.TestCase): + """Tests for calculate_session_cost(context, get_price).""" + + def _patch_definitions(self): + return patch( + "tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", + {_ASSET_ADDR_LOWER: _ASSET_DECIMALS}, + ) + + def _patch_model( + self, input_price: str = "0.000001", output_price: str = "0.000002" + ): + cfg = MagicMock() + cfg.input_price_usd = Decimal(input_price) + cfg.output_price_usd = Decimal(output_price) + return patch("tee_gateway.util.get_model_config", return_value=cfg) + + def test_calls_get_price(self): + get_price = _make_get_price() + with self._patch_definitions(), self._patch_model(): + calculate_session_cost(_make_context(), get_price) + get_price.assert_called_once() + + def test_returns_positive_int(self): + with self._patch_definitions(), self._patch_model(): + result = calculate_session_cost(_make_context(), _make_get_price()) + self.assertIsInstance(result, int) + self.assertGreaterEqual(result, 0) + + def test_zero_tokens_returns_zero(self): + with self._patch_definitions(), self._patch_model(): + result = calculate_session_cost( + _make_context(input_tokens=0, output_tokens=0), _make_get_price() + ) + self.assertEqual(result, 0) + + def test_raises_when_get_price_raises(self): + get_price = MagicMock(side_effect=ValueError("price not available")) + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calculate_session_cost(_make_context(), get_price) + + def test_raises_when_non_positive_price(self): + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calculate_session_cost(_make_context(), _make_get_price(Decimal("0"))) + + def test_raises_when_request_json_missing(self): + ctx = _make_context() + ctx["request_json"] = None + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calculate_session_cost(ctx, _make_get_price()) + + def test_raises_when_usage_missing(self): + ctx = _make_context() + ctx["response_json"] = {"model": "gpt-4.1-mini"} + with self._patch_definitions(), self._patch_model(): + with self.assertRaises(ValueError): + calculate_session_cost(ctx, _make_get_price()) + + def test_raises_when_asset_unknown(self): + ctx = _make_context(asset="0xunknown") + with ( + patch("tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", {}), + self._patch_model(), + ): + with self.assertRaises(ValueError): + calculate_session_cost(ctx, _make_get_price()) + + def test_cost_scales_with_token_count(self): + with self._patch_definitions(), self._patch_model(): + cost_small = calculate_session_cost( + _make_context(input_tokens=10, output_tokens=5), _make_get_price() + ) + cost_large = calculate_session_cost( + _make_context(input_tokens=1000, output_tokens=500), _make_get_price() + ) + self.assertGreater(cost_large, cost_small) + + def test_higher_token_price_yields_lower_cost(self): + with self._patch_definitions(), self._patch_model(): + cost_cheap = calculate_session_cost( + _make_context(), _make_get_price(Decimal("0.10")) + ) + cost_expensive = calculate_session_cost( + _make_context(), _make_get_price(Decimal("0.20")) + ) + self.assertGreater(cost_cheap, cost_expensive) + + def test_uses_current_price_on_each_call(self): + """get_price is called fresh every invocation — price changes are picked up.""" + get_price = MagicMock(side_effect=[Decimal("0.10"), Decimal("0.20")]) + with self._patch_definitions(), self._patch_model(): + cost_first = calculate_session_cost(_make_context(), get_price) + cost_second = calculate_session_cost(_make_context(), get_price) + self.assertEqual(get_price.call_count, 2) + # Price doubled → cost should halve (same USD spend, twice the token price). + self.assertGreater(cost_first, cost_second) + + +if __name__ == "__main__": + unittest.main() diff --git a/tee_gateway/test/test_price_feed_integration.py b/tee_gateway/test/test_price_feed_integration.py new file mode 100644 index 0000000..2da9db1 --- /dev/null +++ b/tee_gateway/test/test_price_feed_integration.py @@ -0,0 +1,133 @@ +""" +Integration tests for tee_gateway.price_feed. + +These tests make REAL network calls to the CoinGecko public API. + +Expected behaviour +------------------ +* ``TestCoinGeckoConnectivity`` — passes when the CoinGecko API is reachable. + Skips on network errors or rate-limiting (429). +* ``TestOPGPriceFetchLive`` — skips when OPG is not yet priced on CoinGecko's + Base platform (CoinGecko currently returns an empty price entry for the + token). Will pass automatically once the token is fully listed. + +Run with:: + + uv run pytest tee_gateway/test/test_price_feed_integration.py -v +""" + +import os +import unittest +from decimal import Decimal + +import requests + +if not os.getenv("RUN_INTEGRATION_TESTS"): + raise unittest.SkipTest("Set RUN_INTEGRATION_TESTS=1 to run integration tests") + +from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS +from tee_gateway.price_feed.config import ( + COINGECKO_BASE_URL, + COINGECKO_PLATFORM, + FETCH_TIMEOUT, +) +from tee_gateway.price_feed.feed import fetch_opg_price + + +def _get(url: str, **kwargs) -> requests.Response: + """Wrapper that skips the test on network errors or rate-limiting.""" + try: + resp = requests.get(url, timeout=FETCH_TIMEOUT, **kwargs) + except requests.exceptions.RequestException as exc: + raise unittest.SkipTest(f"Network unavailable: {exc}") from exc + if resp.status_code == 429: + raise unittest.SkipTest( + "CoinGecko rate limit hit (429) — re-run after a short wait" + ) + return resp + + +class TestCoinGeckoConnectivity(unittest.TestCase): + """Verify that the CoinGecko API endpoint is reachable and well-formed.""" + + def test_ping_endpoint_reachable(self): + """CoinGecko /ping should return {gecko_says: ...}.""" + resp = _get(f"{COINGECKO_BASE_URL}/ping") + self.assertEqual(resp.status_code, 200) + self.assertIn("gecko_says", resp.json()) + + def test_base_platform_endpoint_returns_200(self): + """The token_price/base endpoint should respond with HTTP 200 for a known token.""" + url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" + # USDC on Base mainnet — reliably indexed on CoinGecko. + usdc_base = "0x833589fcd6edb6e08f4c7c32d4f71b54bda02913" + resp = _get( + url, params={"contract_addresses": usdc_base, "vs_currencies": "usd"} + ) + self.assertEqual( + resp.status_code, + 200, + f"Expected 200 from CoinGecko, got {resp.status_code}: {resp.text[:200]}", + ) + data = resp.json() + self.assertIsInstance(data, dict) + self.assertIn(usdc_base, data, "USDC should be indexed on Base platform") + self.assertIn("usd", data[usdc_base], "USDC price entry should have 'usd' key") + + +class TestOPGPriceFetchLive(unittest.TestCase): + """Live fetch of the OPG token price. + + Both tests skip gracefully when OPG is not yet fully priced on CoinGecko + (currently returns ``{address: {}}`` with no 'usd' key). They will pass + automatically once the token is listed with a live price. + """ + + def test_opg_response_structure(self): + """Inspect the raw CoinGecko response for the OPG contract address.""" + url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" + resp = _get( + url, + params={ + "contract_addresses": BASE_MAINNET_OPG_ADDRESS, + "vs_currencies": "usd", + }, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + print(f"\nCoinGecko response for OPG ({BASE_MAINNET_OPG_ADDRESS}): {data}") # noqa: T201 + + opg_lower = BASE_MAINNET_OPG_ADDRESS.lower() + price_entry = data.get(opg_lower) + # CoinGecko returns the address key with {} when the token is known but + # not yet priced — skip in that case rather than fail. + if not price_entry or "usd" not in price_entry: + self.skipTest( + f"OPG not yet priced on CoinGecko Base platform " + f"(response: {data!r}). Will pass once the token is fully listed." + ) + self.assertIsInstance(price_entry["usd"], (int, float)) + + def test_opg_price_fetch_live(self): + """End-to-end: fetch_opg_price() returns a positive Decimal price.""" + try: + price = fetch_opg_price() + except requests.exceptions.HTTPError as exc: + if exc.response is not None and exc.response.status_code == 429: + self.skipTest("CoinGecko rate limit — re-run after a short wait") + raise + except ValueError as exc: + if "Unexpected CoinGecko response" in str(exc): + self.skipTest( + f"OPG ({BASE_MAINNET_OPG_ADDRESS}) not yet priced on " + f"CoinGecko Base platform. Details: {exc}" + ) + raise + + self.assertIsInstance(price, Decimal) + self.assertGreater(price, Decimal("0"), "Price must be positive") + print(f"\nLive OPG price: ${price} USD") # noqa: T201 + + +if __name__ == "__main__": + unittest.main() diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 47559d9..ac79cd6 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -2,10 +2,8 @@ from tee_gateway import typing_utils import logging -import threading -import time from decimal import Decimal, InvalidOperation, ROUND_CEILING -from typing import Any +from typing import Any, Callable logger = logging.getLogger("llm_server.dynamic_pricing") @@ -160,43 +158,6 @@ def _deserialize_dict(data, boxed_type): ) from tee_gateway.model_registry import get_model_config # noqa: E402 -TOKEN_A_PRICE_CACHE_TTL_SECONDS = 60 - -_token_price_cache: dict[str, Any] = { - "value": Decimal("1"), - "updated_at": 0.0, -} -_token_price_lock = threading.Lock() - - -def _fetch_token_a_price_usd_mock() -> Decimal: - """Return the USD price of the payment token used for cost calculation. - - Currently returns a fixed 1:1 ratio, which is correct for USDC-denominated - payments (1 USDC ≈ $1 USD). For OPG-denominated payments, replace this - with a live price feed (e.g. a DEX oracle or CoinGecko API call) that - returns the current OPG/USD exchange rate so that token amounts are - calculated correctly against the model's USD pricing. - """ - return Decimal("1") - - -def get_token_a_price_usd() -> Decimal: - now = time.time() - with _token_price_lock: - cached_value = _token_price_cache.get("value") - cached_at = float(_token_price_cache.get("updated_at") or 0.0) - if ( - isinstance(cached_value, Decimal) - and (now - cached_at) < TOKEN_A_PRICE_CACHE_TTL_SECONDS - ): - return cached_value - - value = _fetch_token_a_price_usd_mock() - _token_price_cache["value"] = value - _token_price_cache["updated_at"] = now - return value - def _as_dict(value: Any) -> dict[str, Any] | None: if value is None: @@ -304,35 +265,36 @@ def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: return ASSET_DECIMALS_BY_ADDRESS[asset_lower] -def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: - """Compute UPTO per-request cost in token smallest units from actual usage. +def calculate_session_cost( + context: dict[str, Any], get_price: Callable[[], Decimal] +) -> int: + """Calculate the x402 session cost in token smallest units for a completed request. - Raises ValueError on any missing or unrecognised input — no silent fallback. + ``get_price`` is called on every invocation to fetch the current OPG/USD + price — pass ``price_feed.get_price`` so the latest cached value is used. + Raises ``ValueError`` on any missing/invalid data. Predictable failures + (unavailable price, unknown model) are blocked before inference by the + pre-inference gate in ``__main__.py``; post-inference failures are logged + as CRITICAL by the caller and the client is not charged. """ request_json = context.get("request_json") response_json = context.get("response_json") if not isinstance(request_json, dict) or not isinstance(response_json, dict): raise ValueError( - "dynamic_session_cost_calculator requires both request_json and response_json" + "calculate_session_cost requires both request_json and response_json" ) model = _extract_model_from_context(request_json, response_json) - - # get_model_config raises ValueError for unknown models — no fallback cfg = get_model_config(model) - input_tokens, output_tokens = _extract_usage_tokens(response_json) - input_rate = cfg.input_price_usd - output_rate = cfg.output_price_usd - - total_usd = (Decimal(input_tokens) * input_rate) + ( - Decimal(output_tokens) * output_rate + total_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( + Decimal(output_tokens) * cfg.output_price_usd ) - token_price_usd = get_token_a_price_usd() + token_price_usd = get_price() if token_price_usd <= 0: - raise ValueError(f"Token A price is non-positive: {token_price_usd}") + raise ValueError(f"Token price is non-positive: {token_price_usd}") token_amount = total_usd / token_price_usd decimals = _extract_asset_decimals_from_requirements( @@ -344,7 +306,8 @@ def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: ) logger.info( - "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d total_usd=%s token_price_usd=%s decimals=%d cost=%d", + "CALCULATE_SESSION_COST model=%s input_tokens=%d output_tokens=%d " + "total_usd=%s token_price_usd=%s decimals=%d cost=%d", model, input_tokens, output_tokens, diff --git a/tests/test_pricing.py b/tests/test_pricing.py index d1b5f25..5419782 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -3,7 +3,7 @@ Tests verify that: - Every user-facing model name resolves to the correct ModelConfig - - dynamic_session_cost_calculator produces the right amount in OPG token + - calculate_session_cost produces the right amount in OPG token smallest-units for supported models - Edge cases (no usage, unknown model, bad context) are handled correctly """ @@ -16,7 +16,11 @@ _MODEL_LOOKUP, get_model_config, ) -from tee_gateway.util import dynamic_session_cost_calculator +from tee_gateway.util import calculate_session_cost + +# All pricing tests assume OPG = $1.00 so USD cost == OPG token amount. +_OPG_PRICE_USD = Decimal("1") +_get_price = lambda: _OPG_PRICE_USD # noqa: E731 # --------------------------------------------------------------------------- @@ -205,12 +209,12 @@ def test_unknown_sonnet_variant_raises(self): # --------------------------------------------------------------------------- -class TestDynamicSessionCostCalculatorOPG(unittest.TestCase): - """dynamic_session_cost_calculator with OPG (18 decimals).""" +class TestCalculateSessionCostOPG(unittest.TestCase): + """calculate_session_cost with OPG (18 decimals).""" def _calc(self, model, input_tokens, output_tokens): - return dynamic_session_cost_calculator( - _ctx(model, input_tokens, output_tokens, _opg_requirements()) + return calculate_session_cost( + _ctx(model, input_tokens, output_tokens, _opg_requirements()), _get_price ) # ── OpenAI ────────────────────────────────────────────────────────────── @@ -351,11 +355,11 @@ def test_grok_4_fast_cheaper_than_grok_4(self): self.assertLess(fast, full) -class TestDynamicSessionCostCalculatorEdgeCases(unittest.TestCase): - """Edge cases for dynamic_session_cost_calculator.""" +class TestCalculateSessionCostEdgeCases(unittest.TestCase): + """Edge cases for calculate_session_cost.""" def test_zero_tokens_returns_zero(self): - cost = dynamic_session_cost_calculator(_ctx("claude-sonnet-4-5", 0, 0)) + cost = calculate_session_cost(_ctx("claude-sonnet-4-5", 0, 0), _get_price) self.assertEqual(cost, 0) def test_missing_usage_raises(self): @@ -365,24 +369,24 @@ def test_missing_usage_raises(self): "payment_requirements": _opg_requirements(), } with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_unknown_asset_raises(self): ctx = _ctx("claude-sonnet-4-5", 100, 100) ctx["payment_requirements"] = {"asset": "0xdeadbeef", "amount": "1000"} with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_missing_asset_raises(self): ctx = _ctx("claude-sonnet-4-5", 100, 100) ctx["payment_requirements"] = {"amount": "1000"} # no asset with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_unknown_model_raises_value_error(self): ctx = _ctx("gpt-4o", 100, 100) with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_missing_request_json_raises_value_error(self): ctx = { @@ -394,7 +398,7 @@ def test_missing_request_json_raises_value_error(self): "payment_requirements": _opg_requirements(), } with self.assertRaises(ValueError): - dynamic_session_cost_calculator(ctx) + calculate_session_cost(ctx, _get_price) def test_model_from_request_takes_priority(self): """request_json model name is used even if response_json has a different model.""" @@ -406,7 +410,7 @@ def test_model_from_request_takes_priority(self): }, "payment_requirements": _opg_requirements(), } - cost = dynamic_session_cost_calculator(ctx) + cost = calculate_session_cost(ctx, _get_price) # Should be priced as Haiku (from request), not Sonnet haiku_cost = _expected_cost_opg("claude-haiku-4-5", 1000, 500) self.assertEqual(cost, haiku_cost) @@ -414,29 +418,31 @@ def test_model_from_request_takes_priority(self): def test_rounding_ceiling(self): """Fractional token costs are always rounded UP.""" # 1 output token of Haiku: 0.000005 USD = 5e12 wei — exact, no rounding needed - cost = dynamic_session_cost_calculator(_ctx("claude-haiku-4-5", 0, 1)) + cost = calculate_session_cost(_ctx("claude-haiku-4-5", 0, 1), _get_price) self.assertEqual(cost, 5_000_000_000_000) # 1 input token of Gemini Flash Lite: 0.0000001 USD = 1e11 wei — exact - cost = dynamic_session_cost_calculator(_ctx("gemini-2.5-flash-lite", 1, 0)) + cost = calculate_session_cost(_ctx("gemini-2.5-flash-lite", 1, 0), _get_price) self.assertEqual(cost, 100_000_000_000) def test_model_name_case_insensitive(self): """Model names are normalized to lowercase before lookup.""" - cost_lower = dynamic_session_cost_calculator( - _ctx("claude-sonnet-4-5", 100, 100) + cost_lower = calculate_session_cost( + _ctx("claude-sonnet-4-5", 100, 100), _get_price ) - cost_upper = dynamic_session_cost_calculator( - _ctx("CLAUDE-SONNET-4-5", 100, 100) + cost_upper = calculate_session_cost( + _ctx("CLAUDE-SONNET-4-5", 100, 100), _get_price ) self.assertEqual(cost_lower, cost_upper) def test_sonnet_4_0_hyphen_vs_dot_same_cost(self): """claude-sonnet-4-0 and claude-4.0-sonnet are the same model.""" - cost_hyphen = dynamic_session_cost_calculator( - _ctx("claude-sonnet-4-0", 1000, 500) + cost_hyphen = calculate_session_cost( + _ctx("claude-sonnet-4-0", 1000, 500), _get_price + ) + cost_dot = calculate_session_cost( + _ctx("claude-4.0-sonnet", 1000, 500), _get_price ) - cost_dot = dynamic_session_cost_calculator(_ctx("claude-4.0-sonnet", 1000, 500)) self.assertEqual(cost_hyphen, cost_dot) From 460b2ed0f83ded02f7f480f66277aa462b5629fe Mon Sep 17 00:00:00 2001 From: Kyle Qian Date: Mon, 20 Apr 2026 22:07:12 -0700 Subject: [PATCH 13/13] Revert "feat: coingecko opg price feed (#56)" This reverts commit e100e2d5eb426fbf1ad1ef3e52af5c413e88522a. --- .github/workflows/test.yml | 5 +- tee_gateway/__main__.py | 146 +++-- tee_gateway/price_feed/__init__.py | 7 - tee_gateway/price_feed/config.py | 52 -- tee_gateway/price_feed/feed.py | 269 --------- tee_gateway/test/test_price_feed.py | 516 ------------------ .../test/test_price_feed_integration.py | 133 ----- tee_gateway/util.py | 73 ++- tests/test_pricing.py | 54 +- 9 files changed, 170 insertions(+), 1085 deletions(-) delete mode 100644 tee_gateway/price_feed/__init__.py delete mode 100644 tee_gateway/price_feed/config.py delete mode 100644 tee_gateway/price_feed/feed.py delete mode 100644 tee_gateway/test/test_price_feed.py delete mode 100644 tee_gateway/test/test_price_feed_integration.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 273d4c9..835c886 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,4 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Run unit tests - run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tee_gateway/test/test_price_feed.py tests/test_pricing.py -v --import-mode=importlib - # To also run integration tests (real CoinGecko network calls), add: - # env: - # RUN_INTEGRATION_TESTS: "1" + run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tests/test_pricing.py -v --import-mode=importlib diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index d5662fc..fd36b24 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -34,10 +34,9 @@ from x402.server import x402ResourceServerSync from x402.session import SessionStore import x402.http.middleware.flask as x402_flask +import types as _types -from .util import calculate_session_cost -from .model_registry import get_model_config -from .price_feed import OPGPriceFeed +from .util import dynamic_session_cost_calculator from .definitions import ( EVM_PAYMENT_ADDRESS, BASE_MAINNET_NETWORK, @@ -108,13 +107,6 @@ def _shutdown_heartbeat(): atexit.register(_shutdown_heartbeat) -# --------------------------------------------------------------------------- -# OPG price feed — start before x402 middleware so the first request can be -# priced correctly. Runs as a daemon thread; no cleanup needed on exit. -# --------------------------------------------------------------------------- -_price_feed = OPGPriceFeed() -_price_feed.start() - facilitator = HTTPFacilitatorClientSync(FacilitatorConfig(url=FACILITATOR_URL)) server = x402ResourceServerSync(facilitator) store = SessionStore() @@ -311,7 +303,6 @@ def health(): "status": "OK", "version": "1.0.0", "tee_enabled": True, - "price_feed": _price_feed.get_status(), }, 200 @@ -383,26 +374,6 @@ def _patched_read_body_bytes(environ): x402_flask._read_body_bytes = _patched_read_body_bytes - -def _session_cost_calculator(ctx: dict) -> int: - # Post-inference cost calculation — response already sent to client. - # Predictable failures (unknown price, unknown model) are blocked by the - # pre-inference gate; any exception here indicates a provider-side error - # (e.g. missing usage field in the LLM response). The x402 middleware - # swallows the exception in close(), so the client is not charged. - # Log CRITICAL so provider errors are never silently missed. - try: - return calculate_session_cost(ctx, _price_feed.get_price) - except Exception as exc: - logger.critical( - "Post-inference cost calculation failed (provider error) — " - "client was NOT charged: %s", - exc, - exc_info=True, - ) - raise - - _payment_mw = payment_middleware( application, routes=routes, @@ -410,39 +381,102 @@ def _session_cost_calculator(ctx: dict) -> int: session_store=store, cost_per_request=100000000000000, # static precheck/fallback estimate session_idle_timeout=100, - session_cost_calculator=_session_cost_calculator, + session_cost_calculator=dynamic_session_cost_calculator, ) # --------------------------------------------------------------------------- -# Pre-inference pricing gate +# Strict cost-resolution patch +# +# Why this exists +# --------------- +# The upstream x402 PaymentMiddleware._resolve_session_request_cost wraps the +# call to the session_cost_calculator in a broad try/except. If the calculator +# raises (e.g. ValueError for an unrecognised model name, KeyError for missing +# usage data), the exception is swallowed and the middleware silently falls back +# to the static session maximum (CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND / +# CHAT_COMPLETIONS_USDC_AMOUNT). That silent fallback means: +# • The client is charged the full pre-check cap instead of actual usage. +# • The server has no visible indication that pricing failed. # -# In the upto session scheme the response is streamed to the client before -# cost is settled, so a post-inference pricing failure cannot be surfaced as -# an HTTP error. Instead we validate everything that can be checked up-front -# and reject the request early if pricing would fail: -# 1. Price feed has a valid OPG/USD price (CoinGecko fetch succeeded). -# 2. The requested model is in the registry (has a known per-token price). +# The fix +# ------- +# We replace _resolve_session_request_cost with our own implementation that is +# identical to upstream, except the cost-calculator call is NOT wrapped in a +# try/except. Any exception from dynamic_session_cost_calculator() therefore +# propagates up through the middleware and Flask, producing a proper HTTP 500 +# response to the client instead of an incorrect silent charge. # --------------------------------------------------------------------------- -@application.before_request -def _check_pricing_ready(): - if request.path not in ("/v1/chat/completions", "/v1/completions"): - return - try: - _price_feed.get_price() - except ValueError as exc: - logger.warning("Rejecting inference request — price feed unavailable: %s", exc) - return jsonify({"error": f"Pricing unavailable: {exc}"}), 503 - - body = request.get_json(silent=True, cache=True) or {} - model = body.get("model") - if model: - try: - get_model_config(model) - except ValueError: - return jsonify({"error": f"Model '{model}' is not supported"}), 400 +def _strict_resolve_session_request_cost( + self, + *, + method: str, + path: str, + request_body_bytes: bytes, + response_body_bytes: bytes, + payment_payload: object, + payment_requirements: object, + status_code: int | None, + output_object: object = None, + is_streaming: bool = False, +) -> int: + """Replacement for PaymentMiddleware._resolve_session_request_cost. + + Identical to the upstream implementation except that exceptions raised by + the dynamic cost calculator are NOT caught. This means a request whose + cost cannot be determined (unknown model, missing usage data, etc.) will + result in a 500 error rather than silently falling back to the static cap + amount and charging the user an incorrect amount. + """ + from x402.http.middleware.flask import _parse_json_bytes as _x402_parse_json # noqa: PLC0415 + + default_cost = self._get_session_cost(payment_requirements) + if not self._should_charge_response(status_code): + return default_cost + if not callable(self._session_cost_calculator): + return default_cost + + request_object = _x402_parse_json(request_body_bytes) + response_object = ( + output_object + if output_object is not None + else _x402_parse_json(response_body_bytes) + ) + callback_context = { + "method": method, + "path": path, + "status_code": status_code, + "is_streaming": is_streaming, + "request_body_bytes": request_body_bytes, + "response_body_bytes": response_body_bytes, + "request_json": request_object + if isinstance(request_object, (dict, list)) + else None, + "response_json": response_object + if isinstance(response_object, (dict, list)) + else None, + "response_object": response_object, + "payment_payload": payment_payload, + "payment_requirements": payment_requirements, + "default_cost": default_cost, + } + + # Do NOT catch exceptions here — let them propagate so the request fails + # with a 500 rather than silently charging the static fallback amount. + dynamic_cost = self._session_cost_calculator(callback_context) + if dynamic_cost is None: + raise ValueError( + f"dynamic_session_cost_calculator returned None for {method} {path}; " + "cannot determine request cost" + ) + return self._coerce_non_negative_int(dynamic_cost) + + +_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign, attr-defined] + _strict_resolve_session_request_cost, _payment_mw +) logger.info("x402 payment middleware initialized") diff --git a/tee_gateway/price_feed/__init__.py b/tee_gateway/price_feed/__init__.py deleted file mode 100644 index 1349825..0000000 --- a/tee_gateway/price_feed/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .config import PriceFeedConfig -from .feed import OPGPriceFeed - -__all__ = [ - "OPGPriceFeed", - "PriceFeedConfig", -] diff --git a/tee_gateway/price_feed/config.py b/tee_gateway/price_feed/config.py deleted file mode 100644 index a8742dc..0000000 --- a/tee_gateway/price_feed/config.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Configuration constants and dataclass for the OPG price feed. -""" - -from dataclasses import dataclass -from datetime import datetime, timezone -from decimal import Decimal - - -# --------------------------------------------------------------------------- -# CoinGecko API -# --------------------------------------------------------------------------- -COINGECKO_BASE_URL = "https://api.coingecko.com/api/v3" -COINGECKO_PLATFORM = "base" # Base mainnet platform identifier on CoinGecko -FETCH_TIMEOUT = 10 # seconds per HTTP request - -# --------------------------------------------------------------------------- -# Refresh / retry defaults -# --------------------------------------------------------------------------- -DEFAULT_REFRESH_INTERVAL = 300 # 5 minutes between background refresh cycles -DEFAULT_MAX_RETRIES = 3 # attempts per refresh cycle before giving up -DEFAULT_RETRY_DELAY = 10 # seconds between retry attempts within a cycle - -# --------------------------------------------------------------------------- -# TGE (Token Generation Event) fallback -# --------------------------------------------------------------------------- -# Before the TGE cutover, OPG is not yet listed on CoinGecko. Return a fixed -# fallback price so inference requests can be priced immediately at launch. -# After the cutover, the live CoinGecko price is used. -TGE_CUTOVER_UTC = datetime(2026, 4, 21, 12, 30, 0, tzinfo=timezone.utc) -TGE_FALLBACK_PRICE_USD = Decimal("0.10") - -# --------------------------------------------------------------------------- -# Stale-price thresholds -# --------------------------------------------------------------------------- -# get_price() logs WARNING when last successful fetch is older than -# STALE_WARNING_MULTIPLIER × refresh_interval seconds. -STALE_WARNING_MULTIPLIER = 2 - -# get_price() raises ValueError when last successful fetch is older than -# STALE_PRICE_MAX_AGE seconds — at this point the cached price is considered -# too outdated to use for billing. -STALE_PRICE_MAX_AGE = 4 * 60 * 60 # 4 hours - - -@dataclass(frozen=True) -class PriceFeedConfig: - """Runtime configuration for the OPG price feed background service.""" - - refresh_interval: int = DEFAULT_REFRESH_INTERVAL - max_retries: int = DEFAULT_MAX_RETRIES - retry_delay: float = DEFAULT_RETRY_DELAY diff --git a/tee_gateway/price_feed/feed.py b/tee_gateway/price_feed/feed.py deleted file mode 100644 index 6875f68..0000000 --- a/tee_gateway/price_feed/feed.py +++ /dev/null @@ -1,269 +0,0 @@ -""" -Background OPG/USD price feed using the CoinGecko public API. - -Runs as a daemon thread that proactively refreshes the OPG token price at a -configurable interval, with retry on per-cycle fetch failure and early exit on -rate limiting. - -Usage ------ -Create an ``OPGPriceFeed`` instance in the application entry point, call -``start()``, then pass it explicitly to wherever the price is needed (e.g. -``calculate_session_cost(...)`` in ``util.py``). -""" - -import logging -import threading -import time -from datetime import datetime, timezone -from decimal import Decimal -from typing import Any, Optional - -import requests - -from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS -from tee_gateway.price_feed.config import ( - COINGECKO_BASE_URL, - COINGECKO_PLATFORM, - DEFAULT_MAX_RETRIES, - DEFAULT_REFRESH_INTERVAL, - DEFAULT_RETRY_DELAY, - FETCH_TIMEOUT, - STALE_PRICE_MAX_AGE, - STALE_WARNING_MULTIPLIER, - TGE_CUTOVER_UTC, - TGE_FALLBACK_PRICE_USD, -) - -logger = logging.getLogger("llm_server.price_feed") - - -class OPGPriceFeed: - """Fetches and caches the OPG/USD price from CoinGecko in a background thread.""" - - def __init__( - self, - refresh_interval: int = DEFAULT_REFRESH_INTERVAL, - max_retries: int = DEFAULT_MAX_RETRIES, - retry_delay: float = DEFAULT_RETRY_DELAY, - ) -> None: - self._refresh_interval = refresh_interval - self._max_retries = max_retries - self._retry_delay = retry_delay - - self._price: Optional[Decimal] = None - self._lock = threading.Lock() - self._thread: Optional[threading.Thread] = None - - # Status tracking — updated under _lock on every refresh cycle outcome. - self.last_success: Optional[float] = None # epoch seconds of last good fetch - self.last_error: Optional[str] = None # description of last failure (if any) - self.consecutive_failures: int = 0 # reset to 0 on any successful fetch - self.total_fetches: int = 0 # cumulative successful fetches - self.total_errors: int = 0 # cumulative failed refresh cycles - - # ------------------------------------------------------------------ - # Public interface - # ------------------------------------------------------------------ - - def start(self) -> None: - """Launch the background refresh loop, including the initial price fetch. - - The initial fetch runs inside the background thread so startup is - non-blocking. ``get_price()`` will raise ``ValueError`` until the - first successful fetch completes; until then, inference requests are - rejected by the pre-inference pricing gate in ``__main__.py``. - - Idempotent — calling ``start()`` on an already-running feed is a no-op. - Thread-safe: the check-and-start is performed under ``_lock``. - """ - with self._lock: - if self._thread is not None and self._thread.is_alive(): - logger.info( - "OPG price feed already running, ignoring duplicate start()" - ) - return - self._thread = threading.Thread( - target=self._run_with_initial_fetch, - name="opg-price-feed", - daemon=True, - ) - self._thread.start() - logger.info( - "OPG price feed started (refresh_interval=%ds, max_retries=%d)", - self._refresh_interval, - self._max_retries, - ) - - def get_price(self) -> Decimal: - """Return the latest cached OPG/USD price. - - Before the TGE cutover (``TGE_CUTOVER_UTC``), returns the fixed - ``TGE_FALLBACK_PRICE_USD`` so requests can be priced before OPG is - listed on CoinGecko. After the cutover the live cached price is used. - - Raises ``ValueError`` if no price has been successfully fetched yet - (post-TGE only). Logs a warning (but still returns the price) if the - cached value is older than ``STALE_WARNING_MULTIPLIER * refresh_interval`` - seconds — this indicates the background loop has missed at least one - refresh cycle and may be experiencing persistent errors. - """ - if datetime.now(timezone.utc) < TGE_CUTOVER_UTC: - return TGE_FALLBACK_PRICE_USD - - now = time.time() - with self._lock: - if self._price is None: - raise ValueError( - "OPG price not yet available — " - "price feed has not completed a successful fetch" - ) - if self.last_success is None: - raise ValueError( - "OPG price not yet available — " - "price feed has not completed a successful fetch" - ) - age = now - self.last_success - if age > STALE_PRICE_MAX_AGE: - raise ValueError( - f"OPG price data expired: last successful fetch was {age:.0f}s ago " - f"(max: {STALE_PRICE_MAX_AGE}s); consecutive failures: " - f"{self.consecutive_failures}" - ) - stale_threshold = self._refresh_interval * STALE_WARNING_MULTIPLIER - if age > stale_threshold: - logger.warning( - "OPG price data is stale: last successful fetch was %.0fs ago " - "(threshold: %.0fs); consecutive failures: %d", - age, - stale_threshold, - self.consecutive_failures, - ) - return self._price - - def get_status(self) -> dict[str, Any]: - """Return a health snapshot suitable for logging or a /health endpoint.""" - with self._lock: - return { - "price_usd": float(self._price) if self._price is not None else None, - "last_success": self.last_success, - "last_error": self.last_error, - "consecutive_failures": self.consecutive_failures, - "total_fetches": self.total_fetches, - "total_errors": self.total_errors, - "refresh_interval": self._refresh_interval, - } - - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - - def _run_with_initial_fetch(self) -> None: - self._refresh_price() - while True: - time.sleep(self._refresh_interval) - self._refresh_price() - - def _refresh_price(self) -> None: - """Attempt to fetch a fresh price, retrying on transient failure. - - - On success: updates the cached price and resets ``consecutive_failures``. - - On HTTP 429: logs a rate-limit warning and exits the retry loop early - (no point hammering a rate-limited API). - - On exhausted retries: increments ``consecutive_failures`` and retains - the last known good price so live traffic is not disrupted by a - transient CoinGecko outage. - """ - last_exc: Optional[Exception] = None - - for attempt in range(1, self._max_retries + 1): - try: - price = fetch_opg_price() - with self._lock: - self._price = price - self.last_success = time.time() - self.last_error = None - self.consecutive_failures = 0 - self.total_fetches += 1 - logger.info( - "OPG price updated: $%.6f USD (attempt %d/%d)", - float(price), - attempt, - self._max_retries, - ) - return - except requests.exceptions.HTTPError as exc: - last_exc = exc - status_code = ( - exc.response.status_code if exc.response is not None else None - ) - if status_code == 429: - logger.warning( - "CoinGecko rate limit hit (429) on attempt %d/%d; " - "skipping remaining retries for this cycle", - attempt, - self._max_retries, - ) - break - logger.warning( - "OPG price fetch attempt %d/%d failed (HTTP %s): %s", - attempt, - self._max_retries, - status_code, - exc, - ) - except Exception as exc: - last_exc = exc - logger.warning( - "OPG price fetch attempt %d/%d failed: %s", - attempt, - self._max_retries, - exc, - ) - - if attempt < self._max_retries: - time.sleep(self._retry_delay) - - # All attempts exhausted (or rate-limited out) — record the failure. - with self._lock: - self.total_errors += 1 - self.consecutive_failures += 1 - self.last_error = str(last_exc) if last_exc is not None else "unknown error" - - logger.error( - "OPG price refresh failed (consecutive failures: %d); " - "retaining last known price (%s)", - self.consecutive_failures, - self._price, - ) - - -def fetch_opg_price() -> Decimal: - """Fetch the current OPG/USD price from CoinGecko. Raises on any error.""" - url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" - params = { - "contract_addresses": BASE_MAINNET_OPG_ADDRESS, - "vs_currencies": "usd", - } - response = requests.get(url, params=params, timeout=FETCH_TIMEOUT) - response.raise_for_status() - - data: Any = response.json() - if not isinstance(data, dict): - raise ValueError( - f"Unexpected CoinGecko response for {BASE_MAINNET_OPG_ADDRESS}: {data!r}" - ) - # CoinGecko keys the result by the lowercased contract address. - price_entry = data.get(BASE_MAINNET_OPG_ADDRESS.lower()) - if not isinstance(price_entry, dict) or "usd" not in price_entry: - raise ValueError( - f"Unexpected CoinGecko response for {BASE_MAINNET_OPG_ADDRESS}: {data!r}" - ) - - price = Decimal(str(price_entry["usd"])) - if not price.is_finite() or price <= 0: - raise ValueError( - f"Invalid price from CoinGecko for {BASE_MAINNET_OPG_ADDRESS}: " - f"{price_entry['usd']!r}" - ) - return price diff --git a/tee_gateway/test/test_price_feed.py b/tee_gateway/test/test_price_feed.py deleted file mode 100644 index fe9bb8a..0000000 --- a/tee_gateway/test/test_price_feed.py +++ /dev/null @@ -1,516 +0,0 @@ -""" -Unit tests for tee_gateway.price_feed and tee_gateway.util.calculate_session_cost. - -All external HTTP calls are mocked — no network access required. - -Test classes ------------- -TestFetchOPGPrice — the raw fetch_opg_price() helper in feed.py -TestOPGPriceFeedRefresh — OPGPriceFeed._refresh_price() (retry, rate-limit, stats) -TestOPGPriceFeedGetPrice — OPGPriceFeed.get_price() (stale warning, ValueError before fetch) -TestOPGPriceFeedStatus — OPGPriceFeed.get_status() snapshots -TestCalculateSessionCost — calculate_session_cost(context, get_price) in util.py -""" - -import time -import unittest -from datetime import datetime, timezone -from decimal import Decimal -from unittest.mock import MagicMock, patch - -import requests - -from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS -from tee_gateway.price_feed import OPGPriceFeed -from tee_gateway.price_feed.feed import fetch_opg_price -from tee_gateway.util import calculate_session_cost - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -OPG_ADDRESS_LOWER = BASE_MAINNET_OPG_ADDRESS.lower() -SAMPLE_PRICE = Decimal("0.042") -SAMPLE_PRICE_FLOAT = 0.042 - -# Patch target prefix — all mocks go through the feed module. -_FEED = "tee_gateway.price_feed.feed" - -# A datetime well after the TGE cutover so get_price() uses the cached price. -_POST_TGE = datetime(2026, 4, 22, 0, 0, 0, tzinfo=timezone.utc) - - -def _mock_response(status_code: int = 200, json_body: dict | None = None) -> MagicMock: - """Build a minimal mock requests.Response.""" - mock = MagicMock() - mock.status_code = status_code - mock.json.return_value = json_body or {} - if status_code >= 400: - http_err = requests.exceptions.HTTPError(response=mock) - mock.raise_for_status.side_effect = http_err - else: - mock.raise_for_status.return_value = None - return mock - - -def _coingecko_success_body() -> dict: - return {OPG_ADDRESS_LOWER: {"usd": SAMPLE_PRICE_FLOAT}} - - -# --------------------------------------------------------------------------- -# TestFetchOPGPrice -# --------------------------------------------------------------------------- - - -class TestFetchOPGPrice(unittest.TestCase): - """Tests for the fetch_opg_price() free function in feed.py.""" - - @patch(f"{_FEED}.requests.get") - def test_happy_path_returns_decimal(self, mock_get): - mock_get.return_value = _mock_response(200, _coingecko_success_body()) - price = fetch_opg_price() - self.assertIsInstance(price, Decimal) - self.assertEqual(price, Decimal(str(SAMPLE_PRICE_FLOAT))) - - @patch(f"{_FEED}.requests.get") - def test_passes_correct_params(self, mock_get): - mock_get.return_value = _mock_response(200, _coingecko_success_body()) - fetch_opg_price() - _, kwargs = mock_get.call_args - self.assertIn("contract_addresses", kwargs["params"]) - self.assertEqual(kwargs["params"]["vs_currencies"], "usd") - self.assertIn( - "base", kwargs["url"] if "url" in kwargs else mock_get.call_args[0][0] - ) - - @patch(f"{_FEED}.requests.get") - def test_raises_on_http_500(self, mock_get): - mock_get.return_value = _mock_response(500) - with self.assertRaises(requests.exceptions.HTTPError): - fetch_opg_price() - - @patch(f"{_FEED}.requests.get") - def test_raises_on_http_429(self, mock_get): - mock_get.return_value = _mock_response(429) - with self.assertRaises(requests.exceptions.HTTPError) as ctx: - fetch_opg_price() - self.assertEqual(ctx.exception.response.status_code, 429) - - @patch(f"{_FEED}.requests.get") - def test_raises_on_empty_response_body(self, mock_get): - mock_get.return_value = _mock_response(200, {}) - with self.assertRaises(ValueError, msg="should raise when address key absent"): - fetch_opg_price() - - @patch(f"{_FEED}.requests.get") - def test_raises_when_usd_key_missing(self, mock_get): - mock_get.return_value = _mock_response(200, {OPG_ADDRESS_LOWER: {"eur": 0.04}}) - with self.assertRaises(ValueError): - fetch_opg_price() - - @patch(f"{_FEED}.requests.get") - def test_raises_when_address_entry_is_empty_dict(self, mock_get): - """CoinGecko returns {address: {}} for known-but-unpriced tokens (current OPG behaviour).""" - mock_get.return_value = _mock_response(200, {OPG_ADDRESS_LOWER: {}}) - with self.assertRaises(ValueError, msg="empty price entry should raise"): - fetch_opg_price() - - @patch(f"{_FEED}.requests.get") - def test_raises_on_network_error(self, mock_get): - mock_get.side_effect = requests.exceptions.ConnectionError("timeout") - with self.assertRaises(requests.exceptions.ConnectionError): - fetch_opg_price() - - -# --------------------------------------------------------------------------- -# TestOPGPriceFeedRefresh -# --------------------------------------------------------------------------- - - -class TestOPGPriceFeedRefresh(unittest.TestCase): - """Tests for OPGPriceFeed._refresh_price() — retry logic, rate-limit, stats.""" - - def _feed(self, **kwargs) -> OPGPriceFeed: - defaults = {"refresh_interval": 300, "max_retries": 3, "retry_delay": 0} - defaults.update(kwargs) - return OPGPriceFeed(**defaults) - - @patch(f"{_FEED}.fetch_opg_price") - def test_successful_refresh_sets_price(self, mock_fetch): - mock_fetch.return_value = SAMPLE_PRICE - feed = self._feed() - feed._refresh_price() - self.assertEqual(feed._price, SAMPLE_PRICE) - - @patch(f"{_FEED}.fetch_opg_price") - def test_successful_refresh_updates_stats(self, mock_fetch): - mock_fetch.return_value = SAMPLE_PRICE - feed = self._feed() - feed._refresh_price() - self.assertEqual(feed.total_fetches, 1) - self.assertEqual(feed.total_errors, 0) - self.assertEqual(feed.consecutive_failures, 0) - self.assertIsNotNone(feed.last_success) - - @patch(f"{_FEED}.fetch_opg_price") - def test_retry_on_transient_failure_then_success(self, mock_fetch): - mock_fetch.side_effect = [ - ValueError("transient"), - ValueError("transient"), - SAMPLE_PRICE, - ] - feed = self._feed(max_retries=3, retry_delay=0) - feed._refresh_price() - self.assertEqual(feed._price, SAMPLE_PRICE) - self.assertEqual(mock_fetch.call_count, 3) - self.assertEqual(feed.total_fetches, 1) - self.assertEqual(feed.total_errors, 0) - - @patch(f"{_FEED}.fetch_opg_price") - def test_exhausted_retries_records_error_stats(self, mock_fetch): - mock_fetch.side_effect = ValueError("always fails") - feed = self._feed(max_retries=3, retry_delay=0) - feed._refresh_price() - self.assertEqual(feed.total_errors, 1) - self.assertEqual(feed.consecutive_failures, 1) - self.assertIsNotNone(feed.last_error) - self.assertEqual(feed.total_fetches, 0) - - @patch(f"{_FEED}.fetch_opg_price") - def test_exhausted_retries_keeps_last_known_price(self, mock_fetch): - feed = self._feed(max_retries=2, retry_delay=0) - feed._price = SAMPLE_PRICE - feed.last_success = time.time() - mock_fetch.side_effect = ValueError("fail") - feed._refresh_price() - self.assertEqual(feed._price, SAMPLE_PRICE) - - @patch(f"{_FEED}.fetch_opg_price") - def test_success_after_failures_resets_consecutive_failures(self, mock_fetch): - feed = self._feed(max_retries=1, retry_delay=0) - mock_fetch.side_effect = ValueError("fail") - feed._refresh_price() - self.assertEqual(feed.consecutive_failures, 1) - mock_fetch.side_effect = None - mock_fetch.return_value = SAMPLE_PRICE - feed._refresh_price() - self.assertEqual(feed.consecutive_failures, 0) - - @patch(f"{_FEED}.fetch_opg_price") - def test_rate_limit_breaks_retry_loop_immediately(self, mock_fetch): - resp = MagicMock() - resp.status_code = 429 - mock_fetch.side_effect = requests.exceptions.HTTPError(response=resp) - feed = self._feed(max_retries=3, retry_delay=0) - feed._refresh_price() - self.assertEqual(mock_fetch.call_count, 1) - self.assertEqual(feed.total_errors, 1) - - @patch(f"{_FEED}.time.sleep") - @patch(f"{_FEED}.fetch_opg_price") - def test_retry_delay_called_between_attempts(self, mock_fetch, mock_sleep): - mock_fetch.side_effect = [ValueError("fail"), ValueError("fail"), SAMPLE_PRICE] - feed = self._feed(max_retries=3, retry_delay=5) - feed._refresh_price() - self.assertEqual(mock_sleep.call_count, 2) - mock_sleep.assert_called_with(5) - - @patch(f"{_FEED}.time.sleep") - @patch(f"{_FEED}.fetch_opg_price") - def test_no_sleep_after_last_failed_attempt(self, mock_fetch, mock_sleep): - mock_fetch.side_effect = ValueError("always fails") - feed = self._feed(max_retries=3, retry_delay=5) - feed._refresh_price() - self.assertEqual(mock_sleep.call_count, 2) - - -# --------------------------------------------------------------------------- -# TestOPGPriceFeedGetPrice -# --------------------------------------------------------------------------- - - -class TestOPGPriceFeedGetPrice(unittest.TestCase): - """Tests for OPGPriceFeed.get_price() behaviour.""" - - @patch(f"{_FEED}.datetime") - def test_raises_before_any_successful_fetch(self, mock_dt): - mock_dt.now.return_value = _POST_TGE - feed = OPGPriceFeed() - with self.assertRaises(ValueError) as ctx: - feed.get_price() - self.assertIn("not yet available", str(ctx.exception)) - - @patch(f"{_FEED}.datetime") - @patch(f"{_FEED}.fetch_opg_price") - def test_returns_price_after_successful_refresh(self, mock_fetch, mock_dt): - mock_dt.now.return_value = _POST_TGE - mock_fetch.return_value = SAMPLE_PRICE - feed = OPGPriceFeed(retry_delay=0) - feed._refresh_price() - self.assertEqual(feed.get_price(), SAMPLE_PRICE) - - @patch(f"{_FEED}.datetime") - @patch(f"{_FEED}.time.time") - @patch(f"{_FEED}.fetch_opg_price") - def test_warns_when_price_is_stale(self, mock_fetch, mock_time, mock_dt): - mock_dt.now.return_value = _POST_TGE - mock_fetch.return_value = SAMPLE_PRICE - feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) - - mock_time.return_value = 0.0 - feed._refresh_price() - - # Advance past stale threshold (300 * 2 = 600s) - mock_time.return_value = 601.0 - - with self.assertLogs("llm_server.price_feed", level="WARNING") as log_ctx: - price = feed.get_price() - - self.assertEqual(price, SAMPLE_PRICE) - self.assertTrue(any("stale" in line.lower() for line in log_ctx.output)) - - @patch(f"{_FEED}.datetime") - @patch(f"{_FEED}.time.time") - @patch(f"{_FEED}.fetch_opg_price") - def test_raises_when_price_exceeds_max_age(self, mock_fetch, mock_time, mock_dt): - mock_dt.now.return_value = _POST_TGE - mock_fetch.return_value = SAMPLE_PRICE - feed = OPGPriceFeed(retry_delay=0) - - mock_time.return_value = 0.0 - feed._refresh_price() - - # Advance past the 4-hour max age - mock_time.return_value = 4 * 60 * 60 + 1.0 - - with self.assertRaises(ValueError) as ctx: - feed.get_price() - self.assertIn("expired", str(ctx.exception)) - - @patch(f"{_FEED}.time.time") - @patch(f"{_FEED}.fetch_opg_price") - def test_no_stale_warning_when_price_is_fresh(self, mock_fetch, mock_time): - import logging - - mock_fetch.return_value = SAMPLE_PRICE - feed = OPGPriceFeed(refresh_interval=300, retry_delay=0) - - mock_time.return_value = 0.0 - feed._refresh_price() - mock_time.return_value = 100.0 # well within threshold - - with self.assertLogs("llm_server.price_feed", level="DEBUG") as log_ctx: - logging.getLogger("llm_server.price_feed").debug("sentinel") - feed.get_price() - - warning_lines = [ - line - for line in log_ctx.output - if "WARNING" in line and "stale" in line.lower() - ] - self.assertEqual(warning_lines, []) - - -# --------------------------------------------------------------------------- -# TestOPGPriceFeedStatus -# --------------------------------------------------------------------------- - - -class TestOPGPriceFeedStatus(unittest.TestCase): - """Tests for OPGPriceFeed.get_status() snapshot.""" - - def test_initial_status_has_no_price(self): - feed = OPGPriceFeed() - status = feed.get_status() - self.assertIsNone(status["price_usd"]) - self.assertIsNone(status["last_success"]) - self.assertEqual(status["consecutive_failures"], 0) - self.assertEqual(status["total_fetches"], 0) - self.assertEqual(status["total_errors"], 0) - - @patch(f"{_FEED}.fetch_opg_price") - def test_status_reflects_successful_fetch(self, mock_fetch): - mock_fetch.return_value = SAMPLE_PRICE - feed = OPGPriceFeed(retry_delay=0) - feed._refresh_price() - status = feed.get_status() - self.assertAlmostEqual(status["price_usd"], float(SAMPLE_PRICE), places=6) - self.assertIsNotNone(status["last_success"]) - self.assertEqual(status["total_fetches"], 1) - self.assertEqual(status["consecutive_failures"], 0) - - @patch(f"{_FEED}.fetch_opg_price") - def test_status_reflects_failed_cycle(self, mock_fetch): - mock_fetch.side_effect = ValueError("fail") - feed = OPGPriceFeed(max_retries=1, retry_delay=0) - feed._refresh_price() - status = feed.get_status() - self.assertIsNone(status["price_usd"]) - self.assertEqual(status["total_errors"], 1) - self.assertEqual(status["consecutive_failures"], 1) - self.assertIsNotNone(status["last_error"]) - - def test_status_includes_refresh_interval(self): - feed = OPGPriceFeed(refresh_interval=600) - self.assertEqual(feed.get_status()["refresh_interval"], 600) - - @patch(f"{_FEED}.fetch_opg_price") - def test_status_accumulates_multiple_error_cycles(self, mock_fetch): - mock_fetch.side_effect = ValueError("fail") - feed = OPGPriceFeed(max_retries=1, retry_delay=0) - feed._refresh_price() - feed._refresh_price() - feed._refresh_price() - status = feed.get_status() - self.assertEqual(status["total_errors"], 3) - self.assertEqual(status["consecutive_failures"], 3) - - -# --------------------------------------------------------------------------- -# TestMakeCostCalculator -# --------------------------------------------------------------------------- - -_ASSET_ADDR = "0xdeadbeef" -_ASSET_ADDR_LOWER = _ASSET_ADDR.lower() -_ASSET_DECIMALS = 18 - - -def _make_payment_requirements(asset: str = _ASSET_ADDR) -> dict: - return {"asset": asset, "price": {"amount": "1000000000000000000", "asset": asset}} - - -def _make_context( - model: str = "gpt-4.1-mini", - input_tokens: int = 100, - output_tokens: int = 50, - price_usd: Decimal = Decimal("0.10"), - asset: str = _ASSET_ADDR, -) -> dict: - return { - "request_json": {"model": model}, - "response_json": { - "model": model, - "usage": { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - }, - }, - "payment_requirements": _make_payment_requirements(asset), - "method": "POST", - "path": "/v1/chat/completions", - "status_code": 200, - "is_streaming": False, - "request_body_bytes": b"", - "response_body_bytes": b"", - "default_cost": 10**18, - } - - -def _make_get_price(price_usd: Decimal = Decimal("0.10")) -> MagicMock: - mock = MagicMock(return_value=price_usd) - return mock - - -class TestCalculateSessionCost(unittest.TestCase): - """Tests for calculate_session_cost(context, get_price).""" - - def _patch_definitions(self): - return patch( - "tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", - {_ASSET_ADDR_LOWER: _ASSET_DECIMALS}, - ) - - def _patch_model( - self, input_price: str = "0.000001", output_price: str = "0.000002" - ): - cfg = MagicMock() - cfg.input_price_usd = Decimal(input_price) - cfg.output_price_usd = Decimal(output_price) - return patch("tee_gateway.util.get_model_config", return_value=cfg) - - def test_calls_get_price(self): - get_price = _make_get_price() - with self._patch_definitions(), self._patch_model(): - calculate_session_cost(_make_context(), get_price) - get_price.assert_called_once() - - def test_returns_positive_int(self): - with self._patch_definitions(), self._patch_model(): - result = calculate_session_cost(_make_context(), _make_get_price()) - self.assertIsInstance(result, int) - self.assertGreaterEqual(result, 0) - - def test_zero_tokens_returns_zero(self): - with self._patch_definitions(), self._patch_model(): - result = calculate_session_cost( - _make_context(input_tokens=0, output_tokens=0), _make_get_price() - ) - self.assertEqual(result, 0) - - def test_raises_when_get_price_raises(self): - get_price = MagicMock(side_effect=ValueError("price not available")) - with self._patch_definitions(), self._patch_model(): - with self.assertRaises(ValueError): - calculate_session_cost(_make_context(), get_price) - - def test_raises_when_non_positive_price(self): - with self._patch_definitions(), self._patch_model(): - with self.assertRaises(ValueError): - calculate_session_cost(_make_context(), _make_get_price(Decimal("0"))) - - def test_raises_when_request_json_missing(self): - ctx = _make_context() - ctx["request_json"] = None - with self._patch_definitions(), self._patch_model(): - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _make_get_price()) - - def test_raises_when_usage_missing(self): - ctx = _make_context() - ctx["response_json"] = {"model": "gpt-4.1-mini"} - with self._patch_definitions(), self._patch_model(): - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _make_get_price()) - - def test_raises_when_asset_unknown(self): - ctx = _make_context(asset="0xunknown") - with ( - patch("tee_gateway.util.ASSET_DECIMALS_BY_ADDRESS", {}), - self._patch_model(), - ): - with self.assertRaises(ValueError): - calculate_session_cost(ctx, _make_get_price()) - - def test_cost_scales_with_token_count(self): - with self._patch_definitions(), self._patch_model(): - cost_small = calculate_session_cost( - _make_context(input_tokens=10, output_tokens=5), _make_get_price() - ) - cost_large = calculate_session_cost( - _make_context(input_tokens=1000, output_tokens=500), _make_get_price() - ) - self.assertGreater(cost_large, cost_small) - - def test_higher_token_price_yields_lower_cost(self): - with self._patch_definitions(), self._patch_model(): - cost_cheap = calculate_session_cost( - _make_context(), _make_get_price(Decimal("0.10")) - ) - cost_expensive = calculate_session_cost( - _make_context(), _make_get_price(Decimal("0.20")) - ) - self.assertGreater(cost_cheap, cost_expensive) - - def test_uses_current_price_on_each_call(self): - """get_price is called fresh every invocation — price changes are picked up.""" - get_price = MagicMock(side_effect=[Decimal("0.10"), Decimal("0.20")]) - with self._patch_definitions(), self._patch_model(): - cost_first = calculate_session_cost(_make_context(), get_price) - cost_second = calculate_session_cost(_make_context(), get_price) - self.assertEqual(get_price.call_count, 2) - # Price doubled → cost should halve (same USD spend, twice the token price). - self.assertGreater(cost_first, cost_second) - - -if __name__ == "__main__": - unittest.main() diff --git a/tee_gateway/test/test_price_feed_integration.py b/tee_gateway/test/test_price_feed_integration.py deleted file mode 100644 index 2da9db1..0000000 --- a/tee_gateway/test/test_price_feed_integration.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Integration tests for tee_gateway.price_feed. - -These tests make REAL network calls to the CoinGecko public API. - -Expected behaviour ------------------- -* ``TestCoinGeckoConnectivity`` — passes when the CoinGecko API is reachable. - Skips on network errors or rate-limiting (429). -* ``TestOPGPriceFetchLive`` — skips when OPG is not yet priced on CoinGecko's - Base platform (CoinGecko currently returns an empty price entry for the - token). Will pass automatically once the token is fully listed. - -Run with:: - - uv run pytest tee_gateway/test/test_price_feed_integration.py -v -""" - -import os -import unittest -from decimal import Decimal - -import requests - -if not os.getenv("RUN_INTEGRATION_TESTS"): - raise unittest.SkipTest("Set RUN_INTEGRATION_TESTS=1 to run integration tests") - -from tee_gateway.definitions import BASE_MAINNET_OPG_ADDRESS -from tee_gateway.price_feed.config import ( - COINGECKO_BASE_URL, - COINGECKO_PLATFORM, - FETCH_TIMEOUT, -) -from tee_gateway.price_feed.feed import fetch_opg_price - - -def _get(url: str, **kwargs) -> requests.Response: - """Wrapper that skips the test on network errors or rate-limiting.""" - try: - resp = requests.get(url, timeout=FETCH_TIMEOUT, **kwargs) - except requests.exceptions.RequestException as exc: - raise unittest.SkipTest(f"Network unavailable: {exc}") from exc - if resp.status_code == 429: - raise unittest.SkipTest( - "CoinGecko rate limit hit (429) — re-run after a short wait" - ) - return resp - - -class TestCoinGeckoConnectivity(unittest.TestCase): - """Verify that the CoinGecko API endpoint is reachable and well-formed.""" - - def test_ping_endpoint_reachable(self): - """CoinGecko /ping should return {gecko_says: ...}.""" - resp = _get(f"{COINGECKO_BASE_URL}/ping") - self.assertEqual(resp.status_code, 200) - self.assertIn("gecko_says", resp.json()) - - def test_base_platform_endpoint_returns_200(self): - """The token_price/base endpoint should respond with HTTP 200 for a known token.""" - url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" - # USDC on Base mainnet — reliably indexed on CoinGecko. - usdc_base = "0x833589fcd6edb6e08f4c7c32d4f71b54bda02913" - resp = _get( - url, params={"contract_addresses": usdc_base, "vs_currencies": "usd"} - ) - self.assertEqual( - resp.status_code, - 200, - f"Expected 200 from CoinGecko, got {resp.status_code}: {resp.text[:200]}", - ) - data = resp.json() - self.assertIsInstance(data, dict) - self.assertIn(usdc_base, data, "USDC should be indexed on Base platform") - self.assertIn("usd", data[usdc_base], "USDC price entry should have 'usd' key") - - -class TestOPGPriceFetchLive(unittest.TestCase): - """Live fetch of the OPG token price. - - Both tests skip gracefully when OPG is not yet fully priced on CoinGecko - (currently returns ``{address: {}}`` with no 'usd' key). They will pass - automatically once the token is listed with a live price. - """ - - def test_opg_response_structure(self): - """Inspect the raw CoinGecko response for the OPG contract address.""" - url = f"{COINGECKO_BASE_URL}/simple/token_price/{COINGECKO_PLATFORM}" - resp = _get( - url, - params={ - "contract_addresses": BASE_MAINNET_OPG_ADDRESS, - "vs_currencies": "usd", - }, - ) - self.assertEqual(resp.status_code, 200) - data = resp.json() - print(f"\nCoinGecko response for OPG ({BASE_MAINNET_OPG_ADDRESS}): {data}") # noqa: T201 - - opg_lower = BASE_MAINNET_OPG_ADDRESS.lower() - price_entry = data.get(opg_lower) - # CoinGecko returns the address key with {} when the token is known but - # not yet priced — skip in that case rather than fail. - if not price_entry or "usd" not in price_entry: - self.skipTest( - f"OPG not yet priced on CoinGecko Base platform " - f"(response: {data!r}). Will pass once the token is fully listed." - ) - self.assertIsInstance(price_entry["usd"], (int, float)) - - def test_opg_price_fetch_live(self): - """End-to-end: fetch_opg_price() returns a positive Decimal price.""" - try: - price = fetch_opg_price() - except requests.exceptions.HTTPError as exc: - if exc.response is not None and exc.response.status_code == 429: - self.skipTest("CoinGecko rate limit — re-run after a short wait") - raise - except ValueError as exc: - if "Unexpected CoinGecko response" in str(exc): - self.skipTest( - f"OPG ({BASE_MAINNET_OPG_ADDRESS}) not yet priced on " - f"CoinGecko Base platform. Details: {exc}" - ) - raise - - self.assertIsInstance(price, Decimal) - self.assertGreater(price, Decimal("0"), "Price must be positive") - print(f"\nLive OPG price: ${price} USD") # noqa: T201 - - -if __name__ == "__main__": - unittest.main() diff --git a/tee_gateway/util.py b/tee_gateway/util.py index ac79cd6..47559d9 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -2,8 +2,10 @@ from tee_gateway import typing_utils import logging +import threading +import time from decimal import Decimal, InvalidOperation, ROUND_CEILING -from typing import Any, Callable +from typing import Any logger = logging.getLogger("llm_server.dynamic_pricing") @@ -158,6 +160,43 @@ def _deserialize_dict(data, boxed_type): ) from tee_gateway.model_registry import get_model_config # noqa: E402 +TOKEN_A_PRICE_CACHE_TTL_SECONDS = 60 + +_token_price_cache: dict[str, Any] = { + "value": Decimal("1"), + "updated_at": 0.0, +} +_token_price_lock = threading.Lock() + + +def _fetch_token_a_price_usd_mock() -> Decimal: + """Return the USD price of the payment token used for cost calculation. + + Currently returns a fixed 1:1 ratio, which is correct for USDC-denominated + payments (1 USDC ≈ $1 USD). For OPG-denominated payments, replace this + with a live price feed (e.g. a DEX oracle or CoinGecko API call) that + returns the current OPG/USD exchange rate so that token amounts are + calculated correctly against the model's USD pricing. + """ + return Decimal("1") + + +def get_token_a_price_usd() -> Decimal: + now = time.time() + with _token_price_lock: + cached_value = _token_price_cache.get("value") + cached_at = float(_token_price_cache.get("updated_at") or 0.0) + if ( + isinstance(cached_value, Decimal) + and (now - cached_at) < TOKEN_A_PRICE_CACHE_TTL_SECONDS + ): + return cached_value + + value = _fetch_token_a_price_usd_mock() + _token_price_cache["value"] = value + _token_price_cache["updated_at"] = now + return value + def _as_dict(value: Any) -> dict[str, Any] | None: if value is None: @@ -265,36 +304,35 @@ def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: return ASSET_DECIMALS_BY_ADDRESS[asset_lower] -def calculate_session_cost( - context: dict[str, Any], get_price: Callable[[], Decimal] -) -> int: - """Calculate the x402 session cost in token smallest units for a completed request. +def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: + """Compute UPTO per-request cost in token smallest units from actual usage. - ``get_price`` is called on every invocation to fetch the current OPG/USD - price — pass ``price_feed.get_price`` so the latest cached value is used. - Raises ``ValueError`` on any missing/invalid data. Predictable failures - (unavailable price, unknown model) are blocked before inference by the - pre-inference gate in ``__main__.py``; post-inference failures are logged - as CRITICAL by the caller and the client is not charged. + Raises ValueError on any missing or unrecognised input — no silent fallback. """ request_json = context.get("request_json") response_json = context.get("response_json") if not isinstance(request_json, dict) or not isinstance(response_json, dict): raise ValueError( - "calculate_session_cost requires both request_json and response_json" + "dynamic_session_cost_calculator requires both request_json and response_json" ) model = _extract_model_from_context(request_json, response_json) + + # get_model_config raises ValueError for unknown models — no fallback cfg = get_model_config(model) + input_tokens, output_tokens = _extract_usage_tokens(response_json) - total_usd = (Decimal(input_tokens) * cfg.input_price_usd) + ( - Decimal(output_tokens) * cfg.output_price_usd + input_rate = cfg.input_price_usd + output_rate = cfg.output_price_usd + + total_usd = (Decimal(input_tokens) * input_rate) + ( + Decimal(output_tokens) * output_rate ) - token_price_usd = get_price() + token_price_usd = get_token_a_price_usd() if token_price_usd <= 0: - raise ValueError(f"Token price is non-positive: {token_price_usd}") + raise ValueError(f"Token A price is non-positive: {token_price_usd}") token_amount = total_usd / token_price_usd decimals = _extract_asset_decimals_from_requirements( @@ -306,8 +344,7 @@ def calculate_session_cost( ) logger.info( - "CALCULATE_SESSION_COST model=%s input_tokens=%d output_tokens=%d " - "total_usd=%s token_price_usd=%s decimals=%d cost=%d", + "DYNAMIC_SESSION_COST model=%s input_tokens=%d output_tokens=%d total_usd=%s token_price_usd=%s decimals=%d cost=%d", model, input_tokens, output_tokens, diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 5419782..d1b5f25 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -3,7 +3,7 @@ Tests verify that: - Every user-facing model name resolves to the correct ModelConfig - - calculate_session_cost produces the right amount in OPG token + - dynamic_session_cost_calculator produces the right amount in OPG token smallest-units for supported models - Edge cases (no usage, unknown model, bad context) are handled correctly """ @@ -16,11 +16,7 @@ _MODEL_LOOKUP, get_model_config, ) -from tee_gateway.util import calculate_session_cost - -# All pricing tests assume OPG = $1.00 so USD cost == OPG token amount. -_OPG_PRICE_USD = Decimal("1") -_get_price = lambda: _OPG_PRICE_USD # noqa: E731 +from tee_gateway.util import dynamic_session_cost_calculator # --------------------------------------------------------------------------- @@ -209,12 +205,12 @@ def test_unknown_sonnet_variant_raises(self): # --------------------------------------------------------------------------- -class TestCalculateSessionCostOPG(unittest.TestCase): - """calculate_session_cost with OPG (18 decimals).""" +class TestDynamicSessionCostCalculatorOPG(unittest.TestCase): + """dynamic_session_cost_calculator with OPG (18 decimals).""" def _calc(self, model, input_tokens, output_tokens): - return calculate_session_cost( - _ctx(model, input_tokens, output_tokens, _opg_requirements()), _get_price + return dynamic_session_cost_calculator( + _ctx(model, input_tokens, output_tokens, _opg_requirements()) ) # ── OpenAI ────────────────────────────────────────────────────────────── @@ -355,11 +351,11 @@ def test_grok_4_fast_cheaper_than_grok_4(self): self.assertLess(fast, full) -class TestCalculateSessionCostEdgeCases(unittest.TestCase): - """Edge cases for calculate_session_cost.""" +class TestDynamicSessionCostCalculatorEdgeCases(unittest.TestCase): + """Edge cases for dynamic_session_cost_calculator.""" def test_zero_tokens_returns_zero(self): - cost = calculate_session_cost(_ctx("claude-sonnet-4-5", 0, 0), _get_price) + cost = dynamic_session_cost_calculator(_ctx("claude-sonnet-4-5", 0, 0)) self.assertEqual(cost, 0) def test_missing_usage_raises(self): @@ -369,24 +365,24 @@ def test_missing_usage_raises(self): "payment_requirements": _opg_requirements(), } with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) + dynamic_session_cost_calculator(ctx) def test_unknown_asset_raises(self): ctx = _ctx("claude-sonnet-4-5", 100, 100) ctx["payment_requirements"] = {"asset": "0xdeadbeef", "amount": "1000"} with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) + dynamic_session_cost_calculator(ctx) def test_missing_asset_raises(self): ctx = _ctx("claude-sonnet-4-5", 100, 100) ctx["payment_requirements"] = {"amount": "1000"} # no asset with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) + dynamic_session_cost_calculator(ctx) def test_unknown_model_raises_value_error(self): ctx = _ctx("gpt-4o", 100, 100) with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) + dynamic_session_cost_calculator(ctx) def test_missing_request_json_raises_value_error(self): ctx = { @@ -398,7 +394,7 @@ def test_missing_request_json_raises_value_error(self): "payment_requirements": _opg_requirements(), } with self.assertRaises(ValueError): - calculate_session_cost(ctx, _get_price) + dynamic_session_cost_calculator(ctx) def test_model_from_request_takes_priority(self): """request_json model name is used even if response_json has a different model.""" @@ -410,7 +406,7 @@ def test_model_from_request_takes_priority(self): }, "payment_requirements": _opg_requirements(), } - cost = calculate_session_cost(ctx, _get_price) + cost = dynamic_session_cost_calculator(ctx) # Should be priced as Haiku (from request), not Sonnet haiku_cost = _expected_cost_opg("claude-haiku-4-5", 1000, 500) self.assertEqual(cost, haiku_cost) @@ -418,31 +414,29 @@ def test_model_from_request_takes_priority(self): def test_rounding_ceiling(self): """Fractional token costs are always rounded UP.""" # 1 output token of Haiku: 0.000005 USD = 5e12 wei — exact, no rounding needed - cost = calculate_session_cost(_ctx("claude-haiku-4-5", 0, 1), _get_price) + cost = dynamic_session_cost_calculator(_ctx("claude-haiku-4-5", 0, 1)) self.assertEqual(cost, 5_000_000_000_000) # 1 input token of Gemini Flash Lite: 0.0000001 USD = 1e11 wei — exact - cost = calculate_session_cost(_ctx("gemini-2.5-flash-lite", 1, 0), _get_price) + cost = dynamic_session_cost_calculator(_ctx("gemini-2.5-flash-lite", 1, 0)) self.assertEqual(cost, 100_000_000_000) def test_model_name_case_insensitive(self): """Model names are normalized to lowercase before lookup.""" - cost_lower = calculate_session_cost( - _ctx("claude-sonnet-4-5", 100, 100), _get_price + cost_lower = dynamic_session_cost_calculator( + _ctx("claude-sonnet-4-5", 100, 100) ) - cost_upper = calculate_session_cost( - _ctx("CLAUDE-SONNET-4-5", 100, 100), _get_price + cost_upper = dynamic_session_cost_calculator( + _ctx("CLAUDE-SONNET-4-5", 100, 100) ) self.assertEqual(cost_lower, cost_upper) def test_sonnet_4_0_hyphen_vs_dot_same_cost(self): """claude-sonnet-4-0 and claude-4.0-sonnet are the same model.""" - cost_hyphen = calculate_session_cost( - _ctx("claude-sonnet-4-0", 1000, 500), _get_price - ) - cost_dot = calculate_session_cost( - _ctx("claude-4.0-sonnet", 1000, 500), _get_price + cost_hyphen = dynamic_session_cost_calculator( + _ctx("claude-sonnet-4-0", 1000, 500) ) + cost_dot = dynamic_session_cost_calculator(_ctx("claude-4.0-sonnet", 1000, 500)) self.assertEqual(cost_hyphen, cost_dot)