diff --git a/packages/paykit/src/api/__tests__/define-route.test.ts b/packages/paykit/src/api/__tests__/define-route.test.ts new file mode 100644 index 0000000..0a370ca --- /dev/null +++ b/packages/paykit/src/api/__tests__/define-route.test.ts @@ -0,0 +1,65 @@ +import { describe, expect, it } from "vitest"; +import * as z from "zod"; + +import type { PayKitContext } from "../../core/context"; +import { definePayKitMethod, returnUrl } from "../define-route"; + +function createTestContext(trustedOrigins?: string[]) { + return { + options: { + database: "postgres://paykit:test@localhost:5432/paykit", + provider: { + createAdapter: () => { + throw new Error("not used in test"); + }, + id: "stripe", + name: "Stripe", + }, + trustedOrigins, + }, + } as unknown as PayKitContext; +} + +describe("api/define-route", () => { + it("resolves relative return URLs for trusted origins", async () => { + const method = definePayKitMethod( + { + input: z.object({ + successUrl: returnUrl(), + }), + }, + async (ctx) => ctx.input, + ); + + const result = await method( + createTestContext(["https://app.example.com"]), + { successUrl: "/billing/success" }, + new Request("https://app.example.com/paykit/subscribe"), + ); + + expect(result).toEqual({ + successUrl: "https://app.example.com/billing/success", + }); + }); + + it("rejects relative return URLs for untrusted origins", async () => { + const method = definePayKitMethod( + { + input: z.object({ + successUrl: returnUrl(), + }), + }, + async (ctx) => ctx.input, + ); + + await expect( + method( + createTestContext(["https://app.example.com"]), + { successUrl: "/billing/success" }, + new Request("https://evil.example.com/paykit/subscribe"), + ), + ).rejects.toMatchObject({ + code: "TRUSTED_ORIGIN_INVALID", + }); + }); +}); diff --git a/packages/paykit/src/api/define-route.ts b/packages/paykit/src/api/define-route.ts index 68c0b07..84d5016 100644 --- a/packages/paykit/src/api/define-route.ts +++ b/packages/paykit/src/api/define-route.ts @@ -169,6 +169,7 @@ export function definePayKitMethod; const customer = config.requireCustomer ? await resolveCustomer( @@ -212,6 +213,7 @@ export function definePayKitMethod, ): unknown { if (!(schema instanceof z.ZodObject) || !input || typeof input !== "object") { return input; @@ -295,12 +298,12 @@ function normalizeMethodInput( const value = normalized[field]; if (typeof value === "string") { - normalized[field] = normalizeReturnUrlValue(field, value, request, headers); + normalized[field] = normalizeReturnUrlValue(field, value, request, headers, paykit); continue; } if (value == null && shouldDefaultReturnUrlField(field)) { - normalized[field] = resolveAbsoluteUrl("/", request, headers, field); + normalized[field] = resolveAbsoluteUrl("/", request, headers, paykit, field); } } @@ -359,9 +362,10 @@ function normalizeReturnUrlValue( value: string, request?: Request, headers?: Headers, + paykit?: Pick, ): string { if (isAbsolutePath(value)) { - return resolveAbsoluteUrl(value, request, headers, field); + return resolveAbsoluteUrl(value, request, headers, paykit, field); } return value; @@ -375,9 +379,10 @@ function resolveAbsoluteUrl( value: string, request: Request | undefined, headers: Headers | undefined, + paykit: Pick | undefined, field: string, ): string { - const origin = resolveOrigin(request, headers); + const origin = resolveOrigin(request, headers, paykit); if (!origin) { throw PayKitError.from( "BAD_REQUEST", @@ -389,23 +394,54 @@ function resolveAbsoluteUrl( return new URL(value, origin).toString(); } -function resolveOrigin(request?: Request, headers?: Headers): string | null { +function resolveOrigin( + request?: Request, + headers?: Headers, + paykit?: Pick, +): string | null { + let origin: string | null = null; + if (request?.url) { - return new URL("/", request.url).toString(); + origin = new URL("/", request.url).toString(); + } else { + const explicitOrigin = headers?.get("origin"); + if (explicitOrigin && isAbsoluteUrl(explicitOrigin)) { + origin = explicitOrigin.endsWith("/") ? explicitOrigin : `${explicitOrigin}/`; + } else { + const host = headers?.get("x-forwarded-host") ?? headers?.get("host"); + if (!host) { + return null; + } + + const protocol = headers?.get("x-forwarded-proto") ?? "https"; + origin = `${protocol}://${host}/`; + } } - const explicitOrigin = headers?.get("origin"); - if (explicitOrigin && isAbsoluteUrl(explicitOrigin)) { - return explicitOrigin.endsWith("/") ? explicitOrigin : `${explicitOrigin}/`; + if (origin && paykit?.options.trustedOrigins?.length) { + assertTrustedOrigin(origin, paykit.options.trustedOrigins); } - const host = headers?.get("x-forwarded-host") ?? headers?.get("host"); - if (!host) { - return null; + return origin; +} + +function assertTrustedOrigin(origin: string, trustedOrigins: readonly string[]): void { + const normalizedOrigin = normalizeTrustedOrigin(origin); + const isAllowed = trustedOrigins.some((trustedOrigin) => { + return normalizeTrustedOrigin(trustedOrigin) === normalizedOrigin; + }); + + if (!isAllowed) { + throw PayKitError.from( + "BAD_REQUEST", + PAYKIT_ERROR_CODES.TRUSTED_ORIGIN_INVALID, + `Resolved origin "${normalizedOrigin}" is not allowed by trustedOrigins`, + ); } +} - const protocol = headers?.get("x-forwarded-proto") ?? "https"; - return `${protocol}://${host}/`; +function normalizeTrustedOrigin(origin: string): string { + return new URL(origin).origin; } function isAbsoluteUrl(value: string): boolean { diff --git a/packages/paykit/src/core/errors.ts b/packages/paykit/src/core/errors.ts index 5510a26..9637afd 100644 --- a/packages/paykit/src/core/errors.ts +++ b/packages/paykit/src/core/errors.ts @@ -34,6 +34,7 @@ export const PAYKIT_ERROR_CODES = defineErrorCodes({ CUSTOMER_ID_REQUIRED: "No customerId provided and no identify configured", SUCCESS_URL_REQUIRED: "A successUrl is required when subscribe is called without a request context", + TRUSTED_ORIGIN_INVALID: "Resolved origin is not in trustedOrigins", BASEPATH_INVALID: "basePath must start with a leading slash", TESTING_NOT_ENABLED: "Testing mode is not enabled", TEST_CLOCK_NOT_FOUND: "Customer does not have a test clock", diff --git a/packages/paykit/src/core/validate-options.ts b/packages/paykit/src/core/validate-options.ts index 783ddef..27a987f 100644 --- a/packages/paykit/src/core/validate-options.ts +++ b/packages/paykit/src/core/validate-options.ts @@ -24,4 +24,25 @@ export function assertValidPayKitOptions( if (error) { throw new Error(error); } + + for (const origin of options.trustedOrigins ?? []) { + assertValidTrustedOrigin(origin); + } +} + +function assertValidTrustedOrigin(origin: string): void { + let parsed: URL; + try { + parsed = new URL(origin); + } catch { + throw new Error( + `PayKit option \`trustedOrigins\` must contain absolute origins only. Received "${origin}".`, + ); + } + + if (parsed.pathname !== "/" || parsed.search || parsed.hash) { + throw new Error( + `PayKit option \`trustedOrigins\` must not include a path, query, or hash. Received "${origin}".`, + ); + } } diff --git a/packages/paykit/src/types/options.ts b/packages/paykit/src/types/options.ts index 50286eb..c8c25bd 100644 --- a/packages/paykit/src/types/options.ts +++ b/packages/paykit/src/types/options.ts @@ -25,6 +25,12 @@ export interface PayKitOptions { * @default "/paykit" */ basePath?: string; + /** + * Allowlist of origins that PayKit may trust when resolving relative return URLs. + * Useful to prevent host header spoofing when `successUrl`, `cancelUrl`, or `returnUrl` + * are provided as absolute paths like `/billing/success`. + */ + trustedOrigins?: string[]; identify?: (request: Request) => Promise<{ customerId: string; email?: string;