Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ XAI_API_KEY=
# If HEARTBEAT_CONTRACT_ADDRESS and HEARTBEAT_FACILITATOR_URL are set, the enclave
# signs heartbeat payloads and the facilitator relays on-chain txs.
HEARTBEAT_CONTRACT_ADDRESS=
HEARTBEAT_FACILITATOR_URL=
FACILITATOR_URL=
TEE_HEARTBEAT_INTERVAL=900
22 changes: 14 additions & 8 deletions scripts/run-enclave.sh
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ if [ -f "$ENV_FILE" ]; then
ANTHROPIC_API_KEY="$(grep -E '^ANTHROPIC_API_KEY=' "$ENV_FILE" | cut -d'=' -f2-)"
XAI_API_KEY="$(grep -E '^XAI_API_KEY=' "$ENV_FILE" | cut -d'=' -f2-)"

# Heartbeat configuration (optional — wallet key is generated inside the TEE)
# FACILITATOR_URL is used for both x402 payment verification and the heartbeat relay.
# HEARTBEAT_CONTRACT_ADDRESS and TEE_HEARTBEAT_INTERVAL are optional heartbeat parameters.
# The TEE wallet key is generated inside the enclave and never injected.
HEARTBEAT_CONTRACT_ADDRESS="$(grep -E '^HEARTBEAT_CONTRACT_ADDRESS=' "$ENV_FILE" | cut -d'=' -f2-)"
HEARTBEAT_FACILITATOR_URL="$(grep -E '^HEARTBEAT_FACILITATOR_URL=' "$ENV_FILE" | cut -d'=' -f2-)"
FACILITATOR_URL="$(grep -E '^FACILITATOR_URL=' "$ENV_FILE" | cut -d'=' -f2-)"
Comment thread
kylexqian marked this conversation as resolved.
TEE_HEARTBEAT_INTERVAL="$(grep -E '^TEE_HEARTBEAT_INTERVAL=' "$ENV_FILE" | cut -d'=' -f2-)"

# Build the JSON payload using jq for safe escaping
Expand All @@ -103,7 +105,7 @@ if [ -f "$ENV_FILE" ]; then
--arg anthropic "$ANTHROPIC_API_KEY" \
--arg xai "$XAI_API_KEY" \
--arg hb_contract "$HEARTBEAT_CONTRACT_ADDRESS" \
--arg hb_facilitator "$HEARTBEAT_FACILITATOR_URL" \
--arg facilitator "$FACILITATOR_URL" \
--arg hb_interval "$TEE_HEARTBEAT_INTERVAL" \
'{
openai_api_key: $openai,
Expand All @@ -112,7 +114,7 @@ if [ -f "$ENV_FILE" ]; then
xai_api_key: $xai
}
+ if $hb_contract != "" then {heartbeat_contract_address: $hb_contract} else {} end
+ if $hb_facilitator != "" then {heartbeat_facilitator_url: $hb_facilitator} else {} end
+ if $facilitator != "" then {facilitator_url: $facilitator} else {} end
+ if $hb_interval != "" then {tee_heartbeat_interval: $hb_interval} else {} end
')

Expand All @@ -125,18 +127,22 @@ if [ -f "$ENV_FILE" ]; then

if [ "$http_status" = "200" ]; then
echo "[ec2] API keys injected successfully."
if [ -n "$HEARTBEAT_CONTRACT_ADDRESS" ] && [ -n "$HEARTBEAT_FACILITATOR_URL" ]; then
echo "[ec2] Heartbeat service configured via facilitator relay (contract: ${HEARTBEAT_CONTRACT_ADDRESS})"
if [ -n "$HEARTBEAT_CONTRACT_ADDRESS" ]; then
if [ -n "$FACILITATOR_URL" ]; then
echo "[ec2] Heartbeat service configured via facilitator relay (contract: ${HEARTBEAT_CONTRACT_ADDRESS})"
else
echo "[ec2] Heartbeat service configured using enclave default facilitator URL (contract: ${HEARTBEAT_CONTRACT_ADDRESS})"
fi
else
echo "[ec2] Heartbeat service not configured (missing env vars)."
echo "[ec2] Heartbeat service not configured (missing HEARTBEAT_CONTRACT_ADDRESS)."
fi
else
echo "[ec2] Warning: Key injection returned HTTP $http_status. Check enclave logs."
fi

# Clear key variables from this shell immediately after use
unset OPENAI_API_KEY GOOGLE_API_KEY ANTHROPIC_API_KEY XAI_API_KEY
unset HEARTBEAT_CONTRACT_ADDRESS HEARTBEAT_FACILITATOR_URL TEE_HEARTBEAT_INTERVAL
unset HEARTBEAT_CONTRACT_ADDRESS FACILITATOR_URL TEE_HEARTBEAT_INTERVAL
fi
else
echo "[ec2] No .env file found at $ENV_FILE"
Expand Down
256 changes: 135 additions & 121 deletions tee_gateway/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,83 +108,151 @@ def _shutdown_heartbeat():

atexit.register(_shutdown_heartbeat)


# ---------------------------------------------------------------------------
# OPG price feed — start before x402 middleware so the first request can be
# priced correctly. Runs as a daemon thread; no cleanup needed on exit.
# ---------------------------------------------------------------------------
# Module-level price feed instance; _session_cost_calculator reads prices from
# it through _price_feed.get_price.
_price_feed = OPGPriceFeed()
_price_feed.start()

# x402 plumbing: a facilitator HTTP client built from the module-level
# FACILITATOR_URL, the resource server wrapping it, and the session store that
# is handed to payment_middleware() together with the server.
facilitator = HTTPFacilitatorClientSync(FacilitatorConfig(url=FACILITATOR_URL))
server = x402ResourceServerSync(facilitator)
store = SessionStore()

# Exact scheme registration for the same network constant.
server.register(BASE_MAINNET_NETWORK, ExactEvmServerScheme())

# Upto scheme registrations (permit2-based, variable settlement)
server.register(BASE_MAINNET_NETWORK, UptoEvmServerScheme())

# Payment requirements for each protected route, keyed by "<METHOD> <path>".
# Both entries accept a single "upto" option paid to EVM_PAYMENT_ADDRESS in the
# asset at BASE_MAINNET_OPG_ADDRESS on BASE_MAINNET_NETWORK; they differ only
# in the max-spend constant used as the amount and in the description.
routes = {
"POST /v1/chat/completions": RouteConfig(
accepts=[
PaymentOption(
scheme="upto",
pay_to=EVM_PAYMENT_ADDRESS,
price=AssetAmount(
amount=CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND,
asset=BASE_MAINNET_OPG_ADDRESS,
extra={
"name": "OpenGradient",
"version": "1",
"assetTransferMethod": "permit2",
},
),
network=BASE_MAINNET_NETWORK,
),
],
# A fresh copy of the extension declaration is built for each route.
extensions={
**declare_erc20_approval_gas_sponsoring_extension(),
},
mime_type="application/json",
description="Chat completion",
),
"POST /v1/completions": RouteConfig(
accepts=[
PaymentOption(
scheme="upto",
pay_to=EVM_PAYMENT_ADDRESS,
price=AssetAmount(
amount=COMPLETIONS_OPG_SESSION_MAX_SPEND,
asset=BASE_MAINNET_OPG_ADDRESS,
extra={
"name": "OpenGradient",
"version": "1",
"assetTransferMethod": "permit2",
},
),
network=BASE_MAINNET_NETWORK,
),
],
extensions={
**declare_erc20_approval_gas_sponsoring_extension(),
},
mime_type="application/json",
description="Completion",
),
}

# ---------------------------------------------------------------------------
# x402 read-body patch
#
# Ensures that non-payment 0-length requests can bypass the middleware without
# errors. Applied at module load so it is in place before the middleware
# instance is created at injection time.
# ---------------------------------------------------------------------------
_original_read_body_bytes = x402_flask._read_body_bytes


def _patched_read_body_bytes(environ):
    """Return the request body bytes, short-circuiting empty requests.

    Replacement for ``x402_flask._read_body_bytes``. The original reader is
    only consulted when the request declares a positive ``CONTENT_LENGTH``.

    Args:
        environ: WSGI environ mapping for the current request.

    Returns:
        ``b""`` when ``CONTENT_LENGTH`` is missing, empty, ``None``,
        non-numeric, zero, or negative; otherwise the result of the original
        ``_read_body_bytes`` for ``environ``.
    """
    try:
        declared_length = int(environ.get("CONTENT_LENGTH") or 0)
    except (ValueError, TypeError):
        # A malformed header is treated the same as "no body".
        declared_length = 0

    if declared_length > 0:
        return _original_read_body_bytes(environ)
    return b""


x402_flask._read_body_bytes = _patched_read_body_bytes


def _session_cost_calculator(ctx: dict) -> int:
    """Compute the session cost after inference has completed.

    Delegates to ``calculate_session_cost`` with ``_price_feed.get_price`` as
    the price lookup.

    This runs post-inference: the response has already been sent to the
    client. Predictable failures (unknown price, unknown model) are blocked by
    the pre-inference gate, so any exception here indicates a provider-side
    error (e.g. a missing usage field in the LLM response). The x402
    middleware swallows the exception in close(), so the client is not
    charged.

    Args:
        ctx: Session context dict supplied by the payment middleware.

    Returns:
        The cost returned by ``calculate_session_cost``.

    Raises:
        Exception: Whatever ``calculate_session_cost`` raised, re-raised
            unchanged after being logged.
    """
    try:
        return calculate_session_cost(ctx, _price_feed.get_price)
    except Exception as exc:
        # Log CRITICAL so provider errors are never silently missed.
        logger.critical(
            "Post-inference cost calculation failed (provider error) — "
            "client was NOT charged: %s",
            exc,
            exc_info=True,
        )
        raise


# ---------------------------------------------------------------------------
# One-time provider key injection
# One-time runtime configuration injection
# ---------------------------------------------------------------------------
_keys_initialized: bool = False
_keys_lock = threading.Lock()


def _init_payment_middleware(facilitator_url: str) -> None:
    """Build and attach the x402 payment middleware to the running Flask app.

    Called once from set_provider_keys() after the facilitator URL is known.
    Swaps application.wsgi_app so all subsequent requests flow through it.

    Args:
        facilitator_url: URL given to ``FacilitatorConfig`` for the facilitator
            HTTP client; it is also written to the log on success.
    """
    facilitator = HTTPFacilitatorClientSync(FacilitatorConfig(url=facilitator_url))
    server = x402ResourceServerSync(facilitator)
    store = SessionStore()

    server.register(BASE_MAINNET_NETWORK, ExactEvmServerScheme())
    server.register(BASE_MAINNET_NETWORK, UptoEvmServerScheme())

    def _upto_route(max_spend, description: str):
        """Return the RouteConfig shared by both routes for one spend cap."""
        return RouteConfig(
            accepts=[
                PaymentOption(
                    scheme="upto",
                    pay_to=EVM_PAYMENT_ADDRESS,
                    price=AssetAmount(
                        amount=max_spend,
                        asset=BASE_MAINNET_OPG_ADDRESS,
                        extra={
                            "name": "OpenGradient",
                            "version": "1",
                            "assetTransferMethod": "permit2",
                        },
                    ),
                    network=BASE_MAINNET_NETWORK,
                ),
            ],
            # Built per call so each route gets its own extensions dict.
            extensions={
                **declare_erc20_approval_gas_sponsoring_extension(),
            },
            mime_type="application/json",
            description=description,
        )

    routes = {
        "POST /v1/chat/completions": _upto_route(
            CHAT_COMPLETIONS_OPG_SESSION_MAX_SPEND, "Chat completion"
        ),
        "POST /v1/completions": _upto_route(
            COMPLETIONS_OPG_SESSION_MAX_SPEND, "Completion"
        ),
    }

    # Return value intentionally discarded — PaymentMiddleware.__init__ self-wires
    # by setting application.wsgi_app = self._wsgi_middleware internally.
    payment_middleware(
        application,
        routes=routes,
        server=server,
        session_store=store,
        cost_per_request=100000000000000,  # static precheck/fallback estimate
        session_idle_timeout=100,
        session_cost_calculator=_session_cost_calculator,
    )
    logger.info(
        "x402 payment middleware initialized with facilitator: %s", facilitator_url
    )


def set_provider_keys():
"""
POST /v1/keys — inject LLM provider API keys into the enclave.
Can only be called once; subsequent calls return HTTP 409.
POST /v1/keys — inject runtime configuration into the enclave.

Accepts LLM provider API keys, a shared facilitator_url (used for both
x402 payment verification and the heartbeat relay), and optional heartbeat
parameters. Can only be called once; subsequent calls return HTTP 409.
"""
global _keys_initialized

Expand All @@ -207,15 +275,12 @@ def set_provider_keys():
)
set_provider_config(provider_config)

facilitator_url = body.get("facilitator_url") or FACILITATOR_URL

# Build heartbeat config from request body (optional)
contract_address = body.get("heartbeat_contract_address")
facilitator_url = (
body.get("heartbeat_facilitator_url")
or os.getenv("FACILITATOR_URL")
or FACILITATOR_URL
)
heartbeat_config: HeartbeatConfig | None = None
if contract_address and facilitator_url:
if contract_address:
interval_raw = body.get(
"tee_heartbeat_interval", DEFAULT_HEARTBEAT_INTERVAL
)
Expand Down Expand Up @@ -262,14 +327,11 @@ def _set(val: str | None) -> str:
logger.info(
" xai_api_key : %s", _set(provider_config.xai_api_key)
)
logger.info(" facilitator_url : %s", facilitator_url)
logger.info(
" heartbeat_contract_address : %s",
_set(heartbeat_config.contract_address if heartbeat_config else None),
)
logger.info(
" heartbeat_facilitator_url : %s",
_set(heartbeat_config.facilitator_url if heartbeat_config else None),
)
logger.info(
" tee_heartbeat_interval : %s",
heartbeat_config.interval if heartbeat_config else "900 (default)",
Expand All @@ -284,6 +346,8 @@ def _set(val: str | None) -> str:
except Exception as e:
logger.warning(f"Heartbeat initialization failed: {e}")

_init_payment_middleware(facilitator_url)

_keys_initialized = True

providers_set = [
Expand Down Expand Up @@ -359,59 +423,11 @@ def create_app():


# ---------------------------------------------------------------------------
# WSGI application + x402 payment middleware
# WSGI application
# ---------------------------------------------------------------------------

# Create the WSGI application
application = create_app()

# This patch ensures that non-payment 0-length requests can still bypass the middleware
_original_read_body_bytes = x402_flask._read_body_bytes


def _patched_read_body_bytes(environ):
    """Return the request body bytes, short-circuiting empty requests.

    Replacement for ``x402_flask._read_body_bytes``. The original reader is
    only consulted when the request declares a positive ``CONTENT_LENGTH``.

    Args:
        environ: WSGI environ mapping for the current request.

    Returns:
        ``b""`` when ``CONTENT_LENGTH`` is missing, empty, ``None``,
        non-numeric, zero, or negative; otherwise the result of the original
        ``_read_body_bytes`` for ``environ``.
    """
    try:
        declared_length = int(environ.get("CONTENT_LENGTH") or 0)
    except (ValueError, TypeError):
        # A malformed header is treated the same as "no body".
        declared_length = 0

    if declared_length > 0:
        return _original_read_body_bytes(environ)
    return b""


x402_flask._read_body_bytes = _patched_read_body_bytes


def _session_cost_calculator(ctx: dict) -> int:
    """Compute the session cost after inference has completed.

    Delegates to ``calculate_session_cost`` with ``_price_feed.get_price`` as
    the price lookup.

    This runs post-inference: the response has already been sent to the
    client. Predictable failures (unknown price, unknown model) are blocked by
    the pre-inference gate, so any exception here indicates a provider-side
    error (e.g. a missing usage field in the LLM response). The x402
    middleware swallows the exception in close(), so the client is not
    charged.

    Args:
        ctx: Session context dict supplied by the payment middleware.

    Returns:
        The cost returned by ``calculate_session_cost``.

    Raises:
        Exception: Whatever ``calculate_session_cost`` raised, re-raised
            unchanged after being logged.
    """
    try:
        return calculate_session_cost(ctx, _price_feed.get_price)
    except Exception as exc:
        # Log CRITICAL so provider errors are never silently missed.
        logger.critical(
            "Post-inference cost calculation failed (provider error) — "
            "client was NOT charged: %s",
            exc,
            exc_info=True,
        )
        raise


# Wrap the WSGI application with the x402 payment middleware at import time,
# using the module-level routes, resource server and session store.
_payment_mw = payment_middleware(
application,
routes=routes,
server=server,
session_store=store,
cost_per_request=100000000000000, # static precheck/fallback estimate
# NOTE(review): timeout units are not visible here — presumably seconds; confirm.
session_idle_timeout=100,
# Called to price a finished session; see _session_cost_calculator.
session_cost_calculator=_session_cost_calculator,
)

# ---------------------------------------------------------------------------
# Pre-inference pricing gate
Expand Down Expand Up @@ -444,8 +460,6 @@ def _check_pricing_ready():
return jsonify({"error": f"Model '{model}' is not supported"}), 400


logger.info("x402 payment middleware initialized")

if __name__ == "__main__":
port = int(os.getenv("API_SERVER_PORT", "8000"))
host = os.getenv("API_SERVER_HOST", "0.0.0.0")
Expand Down
Loading
Loading