Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/paykit/src/providers/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ export interface ProviderRequiredAction {
type: string;
}

export interface ProviderCheckoutCustomer {
email?: string;
name?: string;
}

export interface ProviderSubscription {
cancelAtPeriodEnd: boolean;
canceledAt?: Date | null;
Expand Down Expand Up @@ -111,6 +116,7 @@ export interface PaymentProvider {
}): Promise<{ url: string }>;

createSubscriptionCheckout(data: {
customer?: ProviderCheckoutCustomer;
providerCustomerId: string;
providerProduct: Record<string, string>;
successUrl: string;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import { beforeEach, describe, expect, it, vi } from "vitest";

import type { PayKitContext } from "../../core/context";
import type { Customer } from "../../types/models";
import type { NormalizedSchema } from "../../types/schema";

const {
getCustomerByIdOrThrow,
upsertProviderCustomer,
getDefaultPaymentMethod,
getProductByHash,
getProductByInternalId,
getProductFeatures,
getDefaultProductInGroup,
getProductByProviderData,
withProviderInfo,
} = vi.hoisted(() => ({
getCustomerByIdOrThrow: vi.fn(),
upsertProviderCustomer: vi.fn(),
getDefaultPaymentMethod: vi.fn(),
getProductByHash: vi.fn(),
getProductByInternalId: vi.fn(),
getProductFeatures: vi.fn(),
getDefaultProductInGroup: vi.fn(),
getProductByProviderData: vi.fn(),
withProviderInfo: vi.fn(),
}));

vi.mock("../../customer/customer.service", () => ({
findCustomerByProviderCustomerId: vi.fn(),
getCustomerByIdOrThrow,
upsertProviderCustomer,
}));

vi.mock("../../payment-method/payment-method.service", () => ({
getDefaultPaymentMethod,
}));

vi.mock("../../product/product.service", () => ({
getDefaultProductInGroup,
getProductByHash,
getProductByInternalId,
getProductByProviderData,
getProductFeatures,
withProviderInfo,
}));

import { subscribeToPlan } from "../subscription.service";

const emptyProducts: NormalizedSchema = {
features: [],
plans: [],
planMap: new Map(),
};

function createCustomerRow(overrides: Partial<Customer> = {}): Customer {
const now = new Date("2024-01-01T00:00:00.000Z");

return {
createdAt: now,
deletedAt: null,
email: null,
id: "customer_123",
metadata: null,
name: null,
provider: {},
updatedAt: now,
...overrides,
};
}

function createSelectChain(result: unknown, terminalMethod: "where" | "orderBy" | "limit") {
const chain: Record<string, unknown> = {
from: vi.fn(),
innerJoin: vi.fn(),
where: vi.fn(),
orderBy: vi.fn(),
limit: vi.fn(),
};

chain.from = vi.fn().mockReturnValue(chain);
chain.innerJoin = vi.fn().mockReturnValue(chain);
chain.where =
terminalMethod === "where" ? vi.fn().mockResolvedValue(result) : vi.fn().mockReturnValue(chain);
chain.orderBy =
terminalMethod === "orderBy"
? vi.fn().mockResolvedValue(result)
: vi.fn().mockReturnValue(chain);
chain.limit = terminalMethod === "limit" ? vi.fn().mockResolvedValue(result) : vi.fn();

return chain;
}

describe("subscription/service", () => {
beforeEach(() => {
vi.clearAllMocks();
});

it("passes customer details into subscription checkout", async () => {
const customer = createCustomerRow({
email: "billing@example.com",
name: "Billing User",
});
const storedPlan = {
group: "default",
id: "pro",
internalId: "product_internal_123",
priceAmount: 1900,
priceInterval: "month",
providerProduct: { priceId: "price_123" },
};
const createSubscriptionCheckout = vi.fn().mockResolvedValue({
paymentUrl: "https://checkout.example.com/session",
providerCheckoutSessionId: "cs_123",
});
const warningSelect = createSelectChain([], "where");
const activeSubscriptionSelect = createSelectChain([], "limit");
const scheduledSubscriptionSelect = createSelectChain([], "orderBy");

getCustomerByIdOrThrow.mockResolvedValue(customer);
upsertProviderCustomer.mockResolvedValue({
customerId: "customer_123",
providerCustomer: { id: "cus_123" },
providerCustomerId: "cus_123",
});
getDefaultPaymentMethod.mockResolvedValue(null);
getProductByHash.mockResolvedValue({ id: "pro" });
withProviderInfo.mockReturnValue(storedPlan);

const ctx = {
database: {
select: vi
.fn()
.mockReturnValueOnce(warningSelect)
.mockReturnValueOnce(activeSubscriptionSelect)
.mockReturnValueOnce(scheduledSubscriptionSelect),
},
logger: {
info: vi.fn(),
trace: {
run: vi.fn().mockImplementation(async (_label, fn) => fn()),
},
warn: vi.fn(),
},
options: {
provider: {
createAdapter: vi.fn(),
id: "stripe",
name: "Stripe",
},
},
products: {
...emptyProducts,
planMap: new Map([
[
"pro",
{
hash: "plan_hash_123",
id: "pro",
includes: [],
},
],
]),
},
provider: {
createSubscription: vi.fn(),
createSubscriptionCheckout,
id: "stripe",
name: "Stripe",
},
} as unknown as PayKitContext;

const result = await subscribeToPlan(ctx, {
customerId: "customer_123",
forceCheckout: true,
planId: "pro",
successUrl: "https://example.com/success",
});

expect(result).toEqual({
paymentUrl: "https://checkout.example.com/session",
requiredAction: null,
});
expect(createSubscriptionCheckout).toHaveBeenCalledWith({
cancelUrl: undefined,
customer: {
email: "billing@example.com",
name: "Billing User",
},
metadata: {
paykit_customer_id: "customer_123",
paykit_intent: "subscribe",
paykit_plan_id: "pro",
paykit_product_internal_id: "product_internal_123",
},
providerCustomerId: "cus_123",
providerProduct: { priceId: "price_123" },
successUrl: "https://example.com/success",
});
});
});
8 changes: 8 additions & 0 deletions packages/paykit/src/subscription/subscription.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type { PayKitContext } from "../core/context";
import { PayKitError, PAYKIT_ERROR_CODES } from "../core/errors";
import { generateId } from "../core/utils";
import {
getCustomerByIdOrThrow,
findCustomerByProviderCustomerId,
upsertProviderCustomer,
} from "../customer/customer.service";
Expand Down Expand Up @@ -124,8 +125,10 @@ export async function loadSubscribeContext(ctx: PayKitContext, input: SubscribeI

await warnOnDuplicateActiveSubscriptionGroups(ctx, input.customerId);

const customer = await getCustomerByIdOrThrow(ctx.database, input.customerId);
const { providerCustomerId } = await upsertProviderCustomer(ctx, {
customerId: input.customerId,
customerRow: customer,
});
const hasDefaultPaymentMethod =
(await getDefaultPaymentMethod(ctx.database, {
Expand Down Expand Up @@ -160,6 +163,10 @@ export async function loadSubscribeContext(ctx: PayKitContext, input: SubscribeI
return {
activeSubscription,
cancelUrl: input.cancelUrl,
customer: {
email: customer.email ?? undefined,
name: customer.name ?? undefined,
},
customerId: input.customerId,
isFreeTarget,
isPaidTarget,
Expand Down Expand Up @@ -1068,6 +1075,7 @@ async function createCheckoutSubscribe(
): Promise<SubscribeResult> {
const checkoutResult = await ctx.provider.createSubscriptionCheckout({
cancelUrl: subCtx.cancelUrl,
customer: subCtx.customer,
metadata: {
paykit_customer_id: subCtx.customerId,
paykit_intent: "subscribe",
Expand Down
2 changes: 2 additions & 0 deletions packages/polar/src/polar-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,10 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme

async createSubscriptionCheckout(data) {
const checkout = await client.checkouts.create({
customerEmail: data.customer?.email,
products: [data.providerProduct.productId!],
customerId: data.providerCustomerId,
customerName: data.customer?.name,
metadata: data.metadata,
successUrl: data.successUrl,
});
Expand Down
23 changes: 23 additions & 0 deletions packages/stripe/src/__tests__/stripe.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,29 @@ describe("providers/stripe", () => {
expect(params.managed_payments).toBeUndefined();
});

it("passes customer email to subscription checkout when available", async () => {
const createSession = vi
.fn()
.mockResolvedValue({ id: "cs_123", url: "https://checkout.stripe.com/x" });
const runtime = createCheckoutRuntime(createSession, false);

await runtime.createSubscriptionCheckout({
cancelUrl: "https://example.com/cancel",
customer: {
email: "billing@example.com",
name: "Billing User",
},
metadata: {},
providerCustomerId: "cus_123",
providerProduct: { priceId: "price_123" },
successUrl: "https://example.com/success",
});

expect(createSession).toHaveBeenCalledWith(
expect.objectContaining({ customer_email: "billing@example.com" }),
);
});

it("throws when managedPayments is enabled without the preview apiVersion", () => {
expect(() =>
stripe({
Expand Down
1 change: 1 addition & 0 deletions packages/stripe/src/stripe-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions):
cancel_url: data.cancelUrl ?? data.successUrl,
client_reference_id: data.providerCustomerId,
customer: data.providerCustomerId,
customer_email: data.customer?.email,
line_items: [{ price: data.providerProduct.priceId, quantity: 1 }],
metadata: data.metadata,
mode: "subscription",
Expand Down