From 1d6e7d1938aaded7a020f72e4b8084fc3e9c6cc8 Mon Sep 17 00:00:00 2001 From: almeida Date: Mon, 18 May 2026 10:00:41 -0300 Subject: [PATCH] feat(paykit): add explicit subscription cancel API --- packages/paykit/src/api/methods.ts | 3 +- .../__tests__/subscription.service.test.ts | 226 ++++++++++++++++++ .../src/subscription/subscription.api.ts | 23 +- .../src/subscription/subscription.service.ts | 139 ++++++++++- .../src/subscription/subscription.types.ts | 11 + packages/paykit/src/types/instance.ts | 49 +++- 6 files changed, 436 insertions(+), 15 deletions(-) create mode 100644 packages/paykit/src/subscription/__tests__/subscription.service.test.ts diff --git a/packages/paykit/src/api/methods.ts b/packages/paykit/src/api/methods.ts index d9b881c1..83605a51 100644 --- a/packages/paykit/src/api/methods.ts +++ b/packages/paykit/src/api/methods.ts @@ -9,7 +9,7 @@ import { upsertCustomer, } from "../customer/customer.api"; import { check, report } from "../entitlement/entitlement.api"; -import { subscribe } from "../subscription/subscription.api"; +import { cancelSubscription, subscribe } from "../subscription/subscription.api"; import { advanceTestClock, getTestClock } from "../testing/testing.api"; import type { PayKitOptions } from "../types/options"; import { receiveWebhook } from "../webhook/webhook.api"; @@ -17,6 +17,7 @@ import type { PayKitMethod } from "./define-route"; export const baseMethods = { subscribe, + cancelSubscription, customerPortal, upsertCustomer, getCustomer, diff --git a/packages/paykit/src/subscription/__tests__/subscription.service.test.ts b/packages/paykit/src/subscription/__tests__/subscription.service.test.ts new file mode 100644 index 00000000..fd8994c6 --- /dev/null +++ b/packages/paykit/src/subscription/__tests__/subscription.service.test.ts @@ -0,0 +1,226 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +import type { PayKitContext } from "../../core/context"; +import type { StoredSubscription } from "../../types/models"; +import type { NormalizedSchema } from "../../types/schema"; + +const { + getDefaultProductInGroup, + getProductByHash, + getProductByInternalId, + getProductByProviderData, + getProductFeatures, + withProviderInfo, +} = vi.hoisted(() => ({ + getDefaultProductInGroup: vi.fn(), + getProductByHash: vi.fn(), + getProductByInternalId: vi.fn(), + getProductByProviderData: vi.fn(), + getProductFeatures: vi.fn(), + withProviderInfo: vi.fn(), +})); + +vi.mock("../../product/product.service", () => ({ + getDefaultProductInGroup, + getProductByHash, + getProductByInternalId, + getProductByProviderData, + getProductFeatures, + withProviderInfo, +})); + +import { cancelPlanSubscription } from "../subscription.service"; + +const emptyProducts: NormalizedSchema = { + features: [], + plans: [], + planMap: new Map(), +}; + +function createSelectChain(result: unknown, terminalMethod: "limit" | "orderBy") { + const chain: Record = { + from: vi.fn(), + innerJoin: vi.fn(), + where: vi.fn(), + orderBy: vi.fn(), + limit: vi.fn(), + }; + + chain.from = vi.fn().mockReturnValue(chain); + chain.innerJoin = vi.fn().mockReturnValue(chain); + chain.where = vi.fn().mockReturnValue(chain); + chain.orderBy = + terminalMethod === "orderBy" + ? vi.fn().mockResolvedValue(result) + : vi.fn().mockReturnValue(chain); + chain.limit = terminalMethod === "limit" ? vi.fn().mockResolvedValue(result) : vi.fn(); + + return chain; +} + +function createUpdateChain(result?: unknown) { + const where = vi.fn().mockResolvedValue(result); + const set = vi.fn().mockReturnValue({ where }); + return { set, where }; +} + +describe("subscription/service", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("cancels the active subscription and schedules the default free plan", async () => { + const activeSubscriptionRow = { + subscription: { + cancelAtPeriodEnd: false, + canceled: false, + canceledAt: null, + createdAt: new Date("2024-01-01T00:00:00.000Z"), + currentPeriodEndAt: new Date("2024-02-01T00:00:00.000Z"), + currentPeriodStartAt: new Date("2024-01-01T00:00:00.000Z"), + customerId: "customer_123", + endedAt: null, + id: "sub_active", + productInternalId: "product_pro_internal", + providerData: { + subscriptionId: "sub_provider_123", + }, + providerId: "stripe", + quantity: 1, + scheduledProductId: null, + startedAt: new Date("2024-01-01T00:00:00.000Z"), + status: "active", + trialEndsAt: null, + updatedAt: new Date("2024-01-01T00:00:00.000Z"), + } satisfies StoredSubscription, + product: { + createdAt: new Date("2024-01-01T00:00:00.000Z"), + currency: "usd", + group: "default", + hash: "pro_hash", + id: "pro", + internalId: "product_pro_internal", + isDefault: false, + name: "Pro", + priceAmount: 1900, + priceInterval: "month", + provider: null, + updatedAt: new Date("2024-01-01T00:00:00.000Z"), + }, + }; + const activeSubscriptionSelect = createSelectChain([activeSubscriptionRow], "limit"); + const scheduledSubscriptionSelect = createSelectChain([], "orderBy"); + const insertReturning = vi.fn().mockResolvedValue([ + { + id: "sub_free_scheduled", + }, + ]); + const insertValues = vi.fn().mockReturnValue({ returning: insertReturning }); + const insert = vi.fn().mockReturnValue({ values: insertValues }); + const cancelUpdate = createUpdateChain(); + const scheduleUpdate = createUpdateChain(); + + getProductByHash.mockResolvedValue({ id: "pro" }); + withProviderInfo.mockReturnValue({ + group: "default", + id: "pro", + internalId: "product_pro_internal", + priceAmount: 1900, + priceInterval: "month", + providerProduct: { priceId: "price_pro_123" }, + }); + getDefaultProductInGroup.mockResolvedValue({ + id: "free", + internalId: "product_free_internal", + priceAmount: null, + }); + + const tx = { + insert, + query: { + subscription: { + findFirst: vi.fn().mockResolvedValue(activeSubscriptionRow.subscription), + }, + }, + select: vi.fn().mockReturnValue(scheduledSubscriptionSelect), + update: vi.fn().mockReturnValueOnce({ set: cancelUpdate.set }).mockReturnValueOnce({ + set: scheduleUpdate.set, + }), + }; + + const providerCancel = vi.fn().mockResolvedValue({ + paymentUrl: null, + subscription: null, + }); + const ctx = { + database: { + select: vi.fn().mockReturnValue(activeSubscriptionSelect), + transaction: vi.fn().mockImplementation(async (callback) => callback(tx)), + }, + logger: { + info: vi.fn(), + trace: { + run: vi.fn().mockImplementation(async (_label, fn) => fn()), + }, + }, + products: { + ...emptyProducts, + planMap: new Map([ + [ + "pro", + { + hash: "pro_hash", + id: "pro", + includes: [], + }, + ], + [ + "free", + { + hash: "free_hash", + id: "free", + includes: [], + }, + ], + ]), + }, + provider: { + cancelSubscription: providerCancel, + id: "stripe", + name: "Stripe", + }, + } as unknown as PayKitContext; + + const result = await cancelPlanSubscription(ctx, { + customerId: "customer_123", + planId: "pro", + }); + + expect(result).toEqual({ + paymentUrl: null, + requiredAction: null, + }); + expect(providerCancel).toHaveBeenCalledWith({ + currentPeriodEndAt: new Date("2024-02-01T00:00:00.000Z"), + providerSubscriptionId: "sub_provider_123", + providerSubscriptionScheduleId: null, + }); + expect(insertValues).toHaveBeenCalledWith( + expect.objectContaining({ + customerId: "customer_123", + productInternalId: "product_free_internal", + status: "scheduled", + }), + ); + expect(cancelUpdate.set).toHaveBeenCalledWith( + expect.objectContaining({ + canceled: true, + }), + ); + expect(scheduleUpdate.set).toHaveBeenCalledWith( + expect.objectContaining({ + scheduledProductId: "product_free_internal", + }), + ); + }); +}); diff --git a/packages/paykit/src/subscription/subscription.api.ts b/packages/paykit/src/subscription/subscription.api.ts index 8eeb8514..bf9da044 100644 --- a/packages/paykit/src/subscription/subscription.api.ts +++ b/packages/paykit/src/subscription/subscription.api.ts @@ -1,6 +1,6 @@ import { definePayKitMethod } from "../api/define-route"; -import { subscribeToPlan } from "./subscription.service"; -import { subscribeBodySchema } from "./subscription.types"; +import { cancelPlanSubscription, subscribeToPlan } from "./subscription.service"; +import { cancelSubscriptionBodySchema, subscribeBodySchema } from "./subscription.types"; /** Applies a subscription change for the resolved customer. */ export const subscribe = definePayKitMethod( @@ -23,3 +23,22 @@ export const subscribe = definePayKitMethod( }); }, ); + +/** Cancels the current paid subscription for the plan's group. */ +export const cancelSubscription = definePayKitMethod( + { + input: cancelSubscriptionBodySchema, + requireCustomer: true, + route: { + client: true, + method: "POST", + path: "/cancel-subscription", + }, + }, + async (ctx) => { + return cancelPlanSubscription(ctx.paykit, { + customerId: ctx.customer.id, + planId: ctx.input.planId, + }); + }, +); diff --git a/packages/paykit/src/subscription/subscription.service.ts b/packages/paykit/src/subscription/subscription.service.ts index ba4c28e3..dafa36ac 100644 --- a/packages/paykit/src/subscription/subscription.service.ts +++ b/packages/paykit/src/subscription/subscription.service.ts @@ -29,6 +29,7 @@ import type { import type { StoredSubscription } from "../types/models"; import type { NormalizedPlanFeature } from "../types/schema"; import type { + CancelSubscriptionInput, SubscribeInput, SubscribeResult, SubscriptionWithCatalog, @@ -72,6 +73,103 @@ export async function subscribeToPlan( }); } +/** Cancels the current paid subscription for the requested plan group. */ +export async function cancelPlanSubscription( + ctx: PayKitContext, + input: CancelSubscriptionInput, +): Promise { + return ctx.logger.trace.run("cancel-sub", async () => { + const startTime = Date.now(); + ctx.logger.info({ planId: input.planId, customerId: input.customerId }, "cancel started"); + + const { storedPlan } = await resolveRequestedPlan(ctx, input); + const activeSubscription = storedPlan.group + ? await getActiveSubscriptionInGroup(ctx.database, { + customerId: input.customerId, + group: storedPlan.group, + }) + : null; + + if (!activeSubscription) { + throw PayKitError.from( + "NOT_FOUND", + PAYKIT_ERROR_CODES.SUBSCRIPTION_NOT_FOUND, + `No active subscription found for plan "${input.planId}"`, + ); + } + + if (activeSubscription.priceAmount === null || activeSubscription.canceled) { + return buildSubscribeResult({ paymentUrl: null }); + } + + const activeSubscriptionRef = getProviderSubscriptionRef(activeSubscription); + if (!activeSubscriptionRef.subscriptionId) { + return buildSubscribeResult({ paymentUrl: null }); + } + + const defaultFreePlan = storedPlan.group + ? await resolveDefaultFreePlanInGroup(ctx, storedPlan.group) + : null; + const providerResult = await ctx.provider.cancelSubscription({ + currentPeriodEndAt: activeSubscription.currentPeriodEndAt, + providerSubscriptionId: activeSubscriptionRef.subscriptionId, + providerSubscriptionScheduleId: activeSubscriptionRef.subscriptionScheduleId, + }); + + await ctx.database.transaction(async (tx) => { + if (storedPlan.group) { + await clearScheduledSubscriptionsInGroup(tx, { + customerId: input.customerId, + group: storedPlan.group, + }); + } + + if (defaultFreePlan) { + await insertSubscriptionRecord(tx, { + customerId: input.customerId, + planFeatures: defaultFreePlan.planFeatures, + productInternalId: defaultFreePlan.internalId, + startedAt: activeSubscription.currentPeriodEndAt ?? null, + status: "scheduled", + }); + } + + await scheduleSubscriptionCancellation(tx, { + canceledAt: new Date(), + currentPeriodEndAt: activeSubscription.currentPeriodEndAt ?? null, + subscriptionId: activeSubscription.id, + }); + await replaceSubscriptionSchedule(tx, { + scheduledProductId: defaultFreePlan?.internalId ?? null, + subscriptionId: activeSubscription.id, + }); + + if (providerResult.subscription) { + await syncSubscriptionBillingState(tx, { + currentPeriodEndAt: providerResult.subscription.currentPeriodEndAt, + currentPeriodStartAt: providerResult.subscription.currentPeriodStartAt, + providerData: { + subscriptionId: providerResult.subscription.providerSubscriptionId, + subscriptionScheduleId: + providerResult.subscription.providerSubscriptionScheduleId ?? null, + }, + status: providerResult.subscription.status, + subscriptionId: activeSubscription.id, + }); + } + }); + + const duration = Date.now() - startTime; + ctx.logger.info({ duration }, "cancel completed"); + + return buildSubscribeResult({ + invoice: providerResult.invoice, + paymentUrl: providerResult.paymentUrl, + requiredAction: providerResult.requiredAction, + }); + }); +} + async function resolveStoredPlanFeatures( database: PayKitDatabase, productInternalId: string, @@ -94,7 +192,10 @@ async function resolveStoredPlanFeatures( })); } -export async function loadSubscribeContext(ctx: PayKitContext, input: SubscribeInput) { +async function resolveRequestedPlan( + ctx: PayKitContext, + input: { planId: string; productInternalId?: string }, +) { const providerId = ctx.provider.id; const normalizedPlan = ctx.products.planMap.get(input.planId); const matchingProduct = input.productInternalId @@ -112,8 +213,7 @@ export async function loadSubscribeContext(ctx: PayKitContext, input: SubscribeI ); } - const isFreeTarget = storedPlan.priceAmount === null; - const isPaidTarget = !isFreeTarget; + const isPaidTarget = storedPlan.priceAmount !== null; if (isPaidTarget && !storedPlan.providerProduct) { throw PayKitError.from( "INTERNAL_SERVER_ERROR", @@ -122,6 +222,39 @@ export async function loadSubscribeContext(ctx: PayKitContext, input: SubscribeI ); } + return { + normalizedPlan, + providerId, + storedPlan, + }; +} + +async function resolveDefaultFreePlanInGroup( + ctx: PayKitContext, + group: string, +): Promise<{ internalId: string; planFeatures: readonly NormalizedPlanFeature[] } | null> { + const defaultPlan = await getDefaultProductInGroup(ctx.database, group); + if (!defaultPlan || defaultPlan.priceAmount !== null) { + return null; + } + + const normalizedPlan = ctx.products.planMap.get(defaultPlan.id); + const planFeatures = normalizedPlan + ? normalizedPlan.includes + : await resolveStoredPlanFeatures(ctx.database, defaultPlan.internalId); + + return { + internalId: defaultPlan.internalId, + planFeatures, + }; +} + +export async function loadSubscribeContext(ctx: PayKitContext, input: SubscribeInput) { + const { normalizedPlan, providerId, storedPlan } = await resolveRequestedPlan(ctx, input); + + const isFreeTarget = storedPlan.priceAmount === null; + const isPaidTarget = !isFreeTarget; + await warnOnDuplicateActiveSubscriptionGroups(ctx, input.customerId); const { providerCustomerId } = await upsertProviderCustomer(ctx, { diff --git a/packages/paykit/src/subscription/subscription.types.ts b/packages/paykit/src/subscription/subscription.types.ts index d4fd1116..8e059bb7 100644 --- a/packages/paykit/src/subscription/subscription.types.ts +++ b/packages/paykit/src/subscription/subscription.types.ts @@ -12,11 +12,22 @@ export const subscribeBodySchema = z.object({ export type SubscribeBody = z.infer; +export const cancelSubscriptionBodySchema = z.object({ + planId: z.string(), +}); + +export type CancelSubscriptionBody = z.infer; + export type SubscribeInput = SubscribeBody & { customerId: string; productInternalId?: string; }; +export type CancelSubscriptionInput = CancelSubscriptionBody & { + customerId: string; + productInternalId?: string; +}; + export interface SubscribeResult { invoice?: { currency: string; diff --git a/packages/paykit/src/types/instance.ts b/packages/paykit/src/types/instance.ts index 09d6a91f..0c9c9079 100644 --- a/packages/paykit/src/types/instance.ts +++ b/packages/paykit/src/types/instance.ts @@ -59,13 +59,20 @@ type RefineServerMethodInput< successUrl: string; } : TInput - : TKey extends "check" | "report" - ? TInput extends { featureId: string } - ? Omit & { - featureId: FeatureIdFromOptions; + : TKey extends "cancelSubscription" + ? TInput extends { planId: string } + ? Omit & { + customerId: string; + planId: PlanIdFromOptions; } : TInput - : TInput; + : TKey extends "check" | "report" + ? TInput extends { featureId: string } + ? Omit & { + featureId: FeatureIdFromOptions; + } + : TInput + : TInput; type RefineClientMethodInput< TOptions extends PayKitOptions, @@ -81,11 +88,17 @@ type RefineClientMethodInput< planId: PlanIdFromOptions; } : OmitCustomerId - : TKey extends "customerPortal" - ? OmitCustomerId extends { returnUrl: string } - ? Omit, "returnUrl"> & { returnUrl?: string } + : TKey extends "cancelSubscription" + ? OmitCustomerId extends { planId: string } + ? Omit, "planId"> & { + planId: PlanIdFromOptions; + } : OmitCustomerId - : OmitCustomerId; + : TKey extends "customerPortal" + ? OmitCustomerId extends { returnUrl: string } + ? Omit, "returnUrl"> & { returnUrl?: string } + : OmitCustomerId + : OmitCustomerId; export type PayKitClientSubscribeInput = RefineClientMethodInput>; @@ -95,6 +108,24 @@ export type PayKitSubscribeInput export type PayKitSubscribeResult = InferMethodResult; +export type PayKitClientCancelSubscriptionInput = + RefineClientMethodInput< + TOptions, + "cancelSubscription", + InferMethodInput + >; + +export type PayKitCancelSubscriptionInput = + RefineServerMethodInput< + TOptions, + "cancelSubscription", + InferMethodInput + >; + +export type PayKitCancelSubscriptionResult = InferMethodResult< + RegisteredMethods["cancelSubscription"] +>; + export type PayKitCustomerInput = InferMethodInput; export type PayKitClientCustomerPortalInput = RefineClientMethodInput<