/**
 * Deducts credits from a user for a request whose cost in USD is already known.
 *
 * The cost is converted to credits with `CREDITS_PER_USD`, rounded to the
 * nearest whole credit, and spent against the user that matches
 * `matrixUserId`. Failures are caught here instead of being propagated: an
 * invalid cost is logged as a warning and skipped, while any other error
 * (including a missing user) is logged and reported to Sentry.
 *
 * @param dbAdapter - Database adapter used for the user lookup and the deduction.
 * @param matrixUserId - Matrix ID of the user to charge.
 * @param costInUsd - Cost in US dollars; expected to be a finite, non-negative number.
 */
export async function spendUsageCost(
  dbAdapter: DBAdapter,
  matrixUserId: string,
  costInUsd: number,
) {
  try {
    // The runtime type check stays even though the parameter is typed,
    // because the value may originate from an untyped API response.
    const isValidCost =
      typeof costInUsd === 'number' &&
      Number.isFinite(costInUsd) &&
      costInUsd >= 0;

    if (!isValidCost) {
      log.warn(
        `Invalid costInUsd value: ${costInUsd} for user ${matrixUserId}, skipping`,
      );
      return;
    }

    const account = await getUserByMatrixUserId(dbAdapter, matrixUserId);
    if (!account) {
      throw new Error(
        `should not happen: user with matrix id ${matrixUserId} not found in the users table`,
      );
    }

    // NOTE(review): Math.round means a cost worth less than half a credit
    // deducts 0 credits — confirm this is the intended billing policy.
    const creditsToDeduct = Math.round(costInUsd * CREDITS_PER_USD);
    await spendCredits(dbAdapter, account.id, creditsToDeduct);
  } catch (failure) {
    log.error(
      `Failed to spend usage cost (matrixUserId: ${matrixUserId}, costInUsd: ${costInUsd}):`,
      failure,
    );
    Sentry.captureException(failure);
  }
}
the credit strategy - const mockResponse = { id: generationId }; - await endpointConfig.creditStrategy.saveUsageCost( - dbAdapter, - matrixUserId, - mockResponse, - ); + // Save cost in the background so we don't block the stream on OpenRouter's generation cost API. + // Chain per-user promises so costs are recorded sequentially. + const previousPromise = + pendingCostPromises.get(matrixUserId) ?? Promise.resolve(); + const costPromise = previousPromise + .then(() => + endpointConfig.creditStrategy.saveUsageCost( + dbAdapter, + matrixUserId, + { id: generationId }, + ), + ) + .finally(() => { + if (pendingCostPromises.get(matrixUserId) === costPromise) { + pendingCostPromises.delete(matrixUserId); + } + }); + pendingCostPromises.set(matrixUserId, costPromise); } ctxt.res.write(`data: [DONE]\n\n`); return 'stop'; @@ -328,7 +343,22 @@ export default function handleRequestForward({ return; } - // 4. Check user has sufficient credits using credit strategy + // 4. Wait for any pending cost from a previous request to be recorded + const pendingCost = pendingCostPromises.get(matrixUserId); + if (pendingCost) { + try { + await pendingCost; + } catch (e) { + log.error('Error waiting for pending cost:', e); + await sendResponseForSystemError( + ctxt, + 'There was an error saving your Boxel credits usage. Try again or contact support if the problem persists.', + ); + return; + } + } + + // 5. Check user has sufficient credits using credit strategy const creditValidation = await destinationConfig.creditStrategy.validateCredits( dbAdapter, @@ -469,12 +499,47 @@ export default function handleRequestForward({ const responseData = await externalResponse.json(); - // 6. Calculate and deduct credits using credit strategy - await destinationConfig.creditStrategy.saveUsageCost( - dbAdapter, - matrixUserId, - responseData, - ); + // 6. Deduct credits in the background using the cost from the response, + // or fall back to saveUsageCost when the cost is not provided. 
/**
 * Charges the user for a request whose cost in USD is already known, by
 * delegating to the billing package's `spendUsageCost` (imported here as
 * `spendUsageCostFromBilling`).
 *
 * @param dbAdapter - Database adapter passed through to the billing helper.
 * @param matrixUserId - Matrix ID of the user to charge.
 * @param costInUsd - Cost of the request in US dollars.
 * @returns A promise that settles once the billing helper has finished.
 */
async spendUsageCost(
  dbAdapter: DBAdapter,
  matrixUserId: string,
  costInUsd: number,
): Promise<void> {
  // `Promise` needs its type argument to compile; the method yields no value.
  await spendUsageCostFromBilling(dbAdapter, matrixUserId, costInUsd);
}
input : input.toString(); - if (url.includes('/generation?id=')) { - return new Response(JSON.stringify(mockCostResponse), { - status: 200, - headers: { 'content-type': 'application/json' }, - }); - } else if (url.includes('/chat/completions')) { + if (url.includes('/chat/completions')) { return new Response(JSON.stringify(mockOpenRouterResponse), { status: 200, headers: { 'content-type': 'application/json' }, @@ -207,36 +194,39 @@ module(basename(__filename), function () { // Verify fetch was called correctly (allowing unrelated fetches) const calls = mockFetch.getCalls(); - const chatCallIndex = calls.findIndex((call) => { + const chatCall = calls.find((call) => { const url = call.args[0]; const href = typeof url === 'string' ? url : url?.toString(); return Boolean(href && href.includes('/chat/completions')); }); - const generationCallIndex = calls.findIndex((call) => { - const url = call.args[0]; - const href = typeof url === 'string' ? url : url?.toString(); - return Boolean(href && href.includes('/generation?id=')); - }); - assert.true(chatCallIndex >= 0, 'Fetch should call chat completions'); - assert.true( - generationCallIndex >= 0, - 'Fetch should call generation cost API', - ); - assert.true( - chatCallIndex < generationCallIndex, - 'Generation cost should be fetched after chat completions', - ); + assert.ok(chatCall, 'Fetch should call chat completions'); // Verify authorization header was set correctly - const firstCallHeaders = calls[chatCallIndex].args[1] - ?.headers as Record; - // Note: The actual authorization header will include the JWT token, not the API key - // The API key is added by the proxy handler, not the test + const chatCallHeaders = chatCall!.args[1]?.headers as Record< + string, + string + >; assert.true( - firstCallHeaders?.Authorization?.startsWith('Bearer '), + chatCallHeaders?.Authorization?.startsWith('Bearer '), 'Should set authorization header', ); + + // Verify credits were deducted (0.003 USD * 1000 = 3 credits) + const user 
= await getUserByMatrixUserId( + dbAdapter, + '@testuser:localhost', + ); + await waitUntil( + async () => { + const credits = await sumUpCreditsLedger(dbAdapter, { + creditType: ['extra_credit', 'extra_credit_used'], + userId: user!.id, + }); + return credits === 47; + }, + { timeoutMessage: 'Credits should be deducted (50 - 3 = 47)' }, + ); } finally { mockFetch.restore(); global.fetch = originalFetch;