Skip to content

Commit a3e6ce4

Browse files
committed
fix double billing risk
1 parent 8c6f3e3 commit a3e6ce4

2 files changed

Lines changed: 212 additions & 73 deletions

File tree

apps/sim/lib/billing/threshold-billing.test.ts

Lines changed: 116 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,23 +130,39 @@ const userSubscription = {
130130
status: 'active',
131131
}
132132

133-
function buildCustomerSelectChain(customerId = 'cus_1') {
133+
function buildSelectChain<T>(rows: T[]) {
134+
const chain = {
135+
from: vi.fn(() => chain),
136+
leftJoin: vi.fn(() => chain),
137+
innerJoin: vi.fn(() => chain),
138+
where: vi.fn(() => result),
139+
}
140+
const result = {
141+
limit: vi.fn(async () => rows),
142+
then: (resolve: (value: T[]) => unknown, reject?: (reason: unknown) => unknown) =>
143+
Promise.resolve(rows).then(resolve, reject),
144+
}
145+
134146
return {
135-
from: vi.fn(() => ({
136-
where: vi.fn(() => ({
137-
limit: vi.fn(async () => [{ stripeCustomerId: customerId }]),
138-
})),
139-
})),
147+
from: chain.from,
140148
}
141149
}
142150

151+
function buildCustomerSelectChain(customerId = 'cus_1') {
152+
return buildSelectChain([{ stripeCustomerId: customerId }])
153+
}
154+
143155
function buildStatsSelectChain() {
156+
const result = {
157+
limit: mockTxStatsLimit,
158+
then: (resolve: (value: unknown[]) => unknown, reject?: (reason: unknown) => unknown) =>
159+
Promise.resolve(mockTxStatsLimit()).then(resolve, reject),
160+
}
161+
144162
return {
145163
from: vi.fn(() => ({
146164
where: vi.fn(() => ({
147-
for: vi.fn(() => ({
148-
limit: mockTxStatsLimit,
149-
})),
165+
for: vi.fn(() => result),
150166
})),
151167
})),
152168
}
@@ -223,4 +239,95 @@ describe('checkAndBillOverageThreshold', () => {
223239
expect(mockTxUpdate).not.toHaveBeenCalled()
224240
expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled()
225241
})
242+
243+
it('computes organization overage before opening the locked transaction', async () => {
244+
mockIsOrgScopedSubscription.mockReturnValue(true)
245+
mockIsOrganizationBillingBlocked.mockResolvedValue(false)
246+
mockGetOrganizationSubscriptionUsable.mockResolvedValue({
247+
plan: 'team',
248+
seats: 2,
249+
periodStart: new Date('2026-05-01T00:00:00.000Z'),
250+
periodEnd: new Date('2026-06-01T00:00:00.000Z'),
251+
stripeSubscriptionId: 'sub_team_1',
252+
stripeCustomerId: 'cus_team_1',
253+
})
254+
mockDbSelect.mockImplementationOnce(() =>
255+
buildSelectChain([
256+
{
257+
userId: 'owner-1',
258+
role: 'owner',
259+
currentPeriodCost: '350',
260+
departedMemberUsage: '25',
261+
},
262+
])
263+
)
264+
mockComputeOrgOverageAmount.mockResolvedValue({
265+
totalOverage: 250,
266+
baseSubscriptionAmount: 100,
267+
effectiveUsage: 350,
268+
})
269+
mockTxStatsLimit
270+
.mockResolvedValueOnce([
271+
{ userId: 'owner-1', currentPeriodCost: '350', billedOverageThisPeriod: '0' },
272+
])
273+
.mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '25' }])
274+
275+
await checkAndBillOverageThreshold('user-1')
276+
277+
expect(mockComputeOrgOverageAmount).toHaveBeenCalledWith({
278+
plan: 'team',
279+
seats: 2,
280+
periodStart: new Date('2026-05-01T00:00:00.000Z'),
281+
periodEnd: new Date('2026-06-01T00:00:00.000Z'),
282+
organizationId: userSubscription.referenceId,
283+
pooledCurrentPeriodCost: 350,
284+
departedMemberUsage: 25,
285+
memberIds: ['owner-1'],
286+
})
287+
expect(mockDbTransaction).toHaveBeenCalled()
288+
expect(mockComputeOrgOverageAmount.mock.invocationCallOrder[0]).toBeLessThan(
289+
mockDbTransaction.mock.invocationCallOrder[0]
290+
)
291+
expect(mockTxExecute).toHaveBeenCalledTimes(1)
292+
expect(mockEnqueueOutboxEvent).toHaveBeenCalledTimes(1)
293+
})
294+
295+
it('skips stale organization overage when locked usage inputs changed', async () => {
296+
mockIsOrgScopedSubscription.mockReturnValue(true)
297+
mockIsOrganizationBillingBlocked.mockResolvedValue(false)
298+
mockGetOrganizationSubscriptionUsable.mockResolvedValue({
299+
plan: 'team',
300+
seats: 2,
301+
periodStart: new Date('2026-05-01T00:00:00.000Z'),
302+
periodEnd: new Date('2026-06-01T00:00:00.000Z'),
303+
stripeSubscriptionId: 'sub_team_1',
304+
stripeCustomerId: 'cus_team_1',
305+
})
306+
mockDbSelect.mockImplementationOnce(() =>
307+
buildSelectChain([
308+
{
309+
userId: 'owner-1',
310+
role: 'owner',
311+
currentPeriodCost: '350',
312+
departedMemberUsage: '25',
313+
},
314+
])
315+
)
316+
mockComputeOrgOverageAmount.mockResolvedValue({
317+
totalOverage: 250,
318+
baseSubscriptionAmount: 100,
319+
effectiveUsage: 350,
320+
})
321+
mockTxStatsLimit
322+
.mockResolvedValueOnce([
323+
{ userId: 'owner-1', currentPeriodCost: '350', billedOverageThisPeriod: '0' },
324+
])
325+
.mockResolvedValueOnce([{ creditBalance: '0', departedMemberUsage: '75' }])
326+
327+
await checkAndBillOverageThreshold('user-1')
328+
329+
expect(mockDbTransaction).toHaveBeenCalled()
330+
expect(mockEnqueueOutboxEvent).not.toHaveBeenCalled()
331+
expect(mockTxUpdate).not.toHaveBeenCalled()
332+
})
226333
})

apps/sim/lib/billing/threshold-billing.ts

Lines changed: 96 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const logger = createLogger('ThresholdBilling')
2323

2424
const OVERAGE_THRESHOLD = envNumber(env.OVERAGE_THRESHOLD_DOLLARS, DEFAULT_OVERAGE_THRESHOLD)
2525
const USER_STATS_LOCK_TIMEOUT_MS = 5_000
26+
const USAGE_TOTAL_EPSILON = 0.000001
2627

2728
export async function checkAndBillOverageThreshold(userId: string): Promise<void> {
2829
try {
@@ -247,23 +248,30 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string):
247248
return
248249
}
249250

250-
const members = await db
251-
.select({ userId: member.userId, role: member.role })
251+
const memberUsageRows = await db
252+
.select({
253+
userId: member.userId,
254+
role: member.role,
255+
currentPeriodCost: userStats.currentPeriodCost,
256+
departedMemberUsage: organization.departedMemberUsage,
257+
})
252258
.from(member)
259+
.leftJoin(userStats, eq(member.userId, userStats.userId))
260+
.innerJoin(organization, eq(organization.id, member.organizationId))
253261
.where(eq(member.organizationId, organizationId))
254262

255263
logger.debug('Found organization members', {
256264
organizationId,
257-
memberCount: members.length,
258-
members: members.map((m) => ({ userId: m.userId, role: m.role })),
265+
memberCount: memberUsageRows.length,
266+
members: memberUsageRows.map((m) => ({ userId: m.userId, role: m.role })),
259267
})
260268

261-
if (members.length === 0) {
269+
if (memberUsageRows.length === 0) {
262270
logger.warn('No members found for organization', { organizationId })
263271
return
264272
}
265273

266-
const owner = members.find((m) => m.role === 'owner')
274+
const owner = memberUsageRows.find((m) => m.role === 'owner')
267275
if (!owner) {
268276
logger.error(
269277
'Organization has no owner when running threshold billing — data integrity issue, skipping',
@@ -277,16 +285,72 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string):
277285
ownerId: owner.userId,
278286
})
279287

288+
const memberIds = memberUsageRows.map((m) => m.userId)
289+
let pooledCurrentPeriodCost = 0
290+
for (const stats of memberUsageRows) {
291+
pooledCurrentPeriodCost += toNumber(toDecimal(stats.currentPeriodCost))
292+
}
293+
294+
const departedMemberUsage = toNumber(toDecimal(memberUsageRows[0].departedMemberUsage))
295+
296+
const {
297+
totalOverage: currentOverage,
298+
baseSubscriptionAmount: basePrice,
299+
effectiveUsage: effectiveTeamUsage,
300+
} = await computeOrgOverageAmount({
301+
plan: orgSubscription.plan,
302+
seats: orgSubscription.seats ?? null,
303+
periodStart: orgSubscription.periodStart ?? null,
304+
periodEnd: orgSubscription.periodEnd ?? null,
305+
organizationId,
306+
pooledCurrentPeriodCost,
307+
departedMemberUsage,
308+
memberIds,
309+
})
310+
311+
if (currentOverage < threshold) {
312+
logger.debug('Organization threshold billing check below threshold before locking', {
313+
organizationId,
314+
totalTeamUsage: pooledCurrentPeriodCost + departedMemberUsage,
315+
effectiveTeamUsage,
316+
basePrice,
317+
currentOverage,
318+
threshold,
319+
})
320+
return
321+
}
322+
323+
// Validate Stripe identifiers BEFORE mutating credits/trackers.
324+
const stripeSubscriptionId = orgSubscription.stripeSubscriptionId
325+
if (!stripeSubscriptionId) {
326+
logger.error('No Stripe subscription ID for organization', { organizationId })
327+
return
328+
}
329+
330+
const customerId = orgSubscription.stripeCustomerId
331+
if (!customerId) {
332+
logger.error('No Stripe customer ID for organization', { organizationId })
333+
return
334+
}
335+
336+
const periodEnd = orgSubscription.periodEnd
337+
? Math.floor(orgSubscription.periodEnd.getTime() / 1000)
338+
: Math.floor(Date.now() / 1000)
339+
const billingPeriod = new Date(periodEnd * 1000).toISOString().slice(0, 7)
340+
const totalOverageCents = Math.round(currentOverage * 100)
341+
280342
await db.transaction(async (tx) => {
281343
await tx.execute(sql.raw(`SET LOCAL lock_timeout = '${USER_STATS_LOCK_TIMEOUT_MS}ms'`))
282344

283-
// Lock both owner stats and organization rows
284-
const ownerStatsLock = await tx
285-
.select()
345+
const lockedMemberStats = await tx
346+
.select({
347+
userId: userStats.userId,
348+
currentPeriodCost: userStats.currentPeriodCost,
349+
billedOverageThisPeriod: userStats.billedOverageThisPeriod,
350+
})
286351
.from(userStats)
287-
.where(eq(userStats.userId, owner.userId))
352+
.where(inArray(userStats.userId, memberIds))
288353
.for('update')
289-
.limit(1)
290354

291355
const orgLock = await tx
292356
.select()
@@ -295,7 +359,8 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string):
295359
.for('update')
296360
.limit(1)
297361

298-
if (ownerStatsLock.length === 0) {
362+
const ownerStatsLock = lockedMemberStats.find((stats) => stats.userId === owner.userId)
363+
if (!ownerStatsLock) {
299364
logger.error('Owner stats not found', { organizationId, ownerId: owner.userId })
300365
return
301366
}
@@ -305,42 +370,27 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string):
305370
return
306371
}
307372

308-
let pooledCurrentPeriodCost = toNumber(toDecimal(ownerStatsLock[0].currentPeriodCost))
309-
const totalBilledOverage = toNumber(toDecimal(ownerStatsLock[0].billedOverageThisPeriod))
310-
const orgCreditBalance = toNumber(toDecimal(orgLock[0].creditBalance))
311-
312-
const nonOwnerIds = members.filter((m) => m.userId !== owner.userId).map((m) => m.userId)
313-
314-
if (nonOwnerIds.length > 0) {
315-
const memberStatsRows = await tx
316-
.select({
317-
userId: userStats.userId,
318-
currentPeriodCost: userStats.currentPeriodCost,
319-
})
320-
.from(userStats)
321-
.where(inArray(userStats.userId, nonOwnerIds))
322-
323-
for (const stats of memberStatsRows) {
324-
pooledCurrentPeriodCost += toNumber(toDecimal(stats.currentPeriodCost))
325-
}
373+
let lockedPooledCurrentPeriodCost = 0
374+
for (const stats of lockedMemberStats) {
375+
lockedPooledCurrentPeriodCost += toNumber(toDecimal(stats.currentPeriodCost))
376+
}
377+
const lockedDepartedMemberUsage = toNumber(toDecimal(orgLock[0].departedMemberUsage))
378+
if (
379+
Math.abs(lockedPooledCurrentPeriodCost - pooledCurrentPeriodCost) > USAGE_TOTAL_EPSILON ||
380+
Math.abs(lockedDepartedMemberUsage - departedMemberUsage) > USAGE_TOTAL_EPSILON
381+
) {
382+
logger.debug('Organization usage changed during threshold billing check; retry later', {
383+
organizationId,
384+
pooledCurrentPeriodCost,
385+
lockedPooledCurrentPeriodCost,
386+
departedMemberUsage,
387+
lockedDepartedMemberUsage,
388+
})
389+
return
326390
}
327391

328-
const departedMemberUsage = toNumber(toDecimal(orgLock[0].departedMemberUsage))
329-
330-
const {
331-
totalOverage: currentOverage,
332-
baseSubscriptionAmount: basePrice,
333-
effectiveUsage: effectiveTeamUsage,
334-
} = await computeOrgOverageAmount({
335-
plan: orgSubscription.plan,
336-
seats: orgSubscription.seats ?? null,
337-
periodStart: orgSubscription.periodStart ?? null,
338-
periodEnd: orgSubscription.periodEnd ?? null,
339-
organizationId,
340-
pooledCurrentPeriodCost,
341-
departedMemberUsage,
342-
memberIds: members.map((m) => m.userId),
343-
})
392+
const totalBilledOverage = toNumber(toDecimal(ownerStatsLock.billedOverageThisPeriod))
393+
const orgCreditBalance = toNumber(toDecimal(orgLock[0].creditBalance))
344394

345395
const unbilledOverage = Math.max(0, currentOverage - totalBilledOverage)
346396

@@ -359,19 +409,6 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string):
359409
return
360410
}
361411

362-
// Validate Stripe identifiers BEFORE mutating credits/trackers.
363-
const stripeSubscriptionId = orgSubscription.stripeSubscriptionId
364-
if (!stripeSubscriptionId) {
365-
logger.error('No Stripe subscription ID for organization', { organizationId })
366-
return
367-
}
368-
369-
const customerId = orgSubscription.stripeCustomerId
370-
if (!customerId) {
371-
logger.error('No Stripe customer ID for organization', { organizationId })
372-
return
373-
}
374-
375412
let amountToBill = unbilledOverage
376413
let creditsApplied = 0
377414

@@ -410,12 +447,7 @@ async function checkAndBillOrganizationOverageThreshold(organizationId: string):
410447
return
411448
}
412449

413-
const periodEnd = orgSubscription.periodEnd
414-
? Math.floor(orgSubscription.periodEnd.getTime() / 1000)
415-
: Math.floor(Date.now() / 1000)
416-
const billingPeriod = new Date(periodEnd * 1000).toISOString().slice(0, 7)
417450
const amountCents = Math.round(amountToBill * 100)
418-
const totalOverageCents = Math.round(currentOverage * 100)
419451

420452
// Bump billed tracker and enqueue Stripe invoice atomically.
421453
// See user-path above for the full retry-invariant reasoning.

0 commit comments

Comments
 (0)