Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions packages/paykit/src/api/__tests__/define-route.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { describe, expect, it } from "vitest";
import * as z from "zod";

import type { PayKitContext } from "../../core/context";
import { definePayKitMethod, returnUrl } from "../define-route";

function createTestContext(trustedOrigins?: string[]) {
return {
options: {
database: "postgres://paykit:test@localhost:5432/paykit",
provider: {
createAdapter: () => {
throw new Error("not used in test");
},
id: "stripe",
name: "Stripe",
},
trustedOrigins,
},
} as unknown as PayKitContext;
}

describe("api/define-route", () => {
it("resolves relative return URLs for trusted origins", async () => {
const method = definePayKitMethod(
{
input: z.object({
successUrl: returnUrl(),
}),
},
async (ctx) => ctx.input,
);

const result = await method(
createTestContext(["https://app.example.com"]),
{ successUrl: "/billing/success" },
new Request("https://app.example.com/paykit/subscribe"),
);

expect(result).toEqual({
successUrl: "https://app.example.com/billing/success",
});
});

it("rejects relative return URLs for untrusted origins", async () => {
const method = definePayKitMethod(
{
input: z.object({
successUrl: returnUrl(),
}),
},
async (ctx) => ctx.input,
);

await expect(
method(
createTestContext(["https://app.example.com"]),
{ successUrl: "/billing/success" },
new Request("https://evil.example.com/paykit/subscribe"),
),
).rejects.toMatchObject({
code: "TRUSTED_ORIGIN_INVALID",
});
});
});
64 changes: 50 additions & 14 deletions packages/paykit/src/api/define-route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ export function definePayKitMethod<const TConfig extends PayKitMethodConfig, TRe
stripCustomerId(input),
request,
request?.headers,
paykit,
) as InferMethodInput<TConfig>;
const customer = config.requireCustomer
? await resolveCustomer(
Expand Down Expand Up @@ -212,6 +213,7 @@ export function definePayKitMethod<const TConfig extends PayKitMethodConfig, TRe
: ctx.body,
ctx.request,
ctx.headers,
ctx.context,
);
const customer = config.requireCustomer
? await resolveCustomer(ctx.context, ctx.request)
Expand Down Expand Up @@ -280,6 +282,7 @@ function normalizeMethodInput(
input: unknown,
request?: Request,
headers?: Headers,
paykit?: Pick<PayKitContext, "options">,
): unknown {
if (!(schema instanceof z.ZodObject) || !input || typeof input !== "object") {
return input;
Expand All @@ -295,12 +298,12 @@ function normalizeMethodInput(
const value = normalized[field];

if (typeof value === "string") {
normalized[field] = normalizeReturnUrlValue(field, value, request, headers);
normalized[field] = normalizeReturnUrlValue(field, value, request, headers, paykit);
continue;
}

if (value == null && shouldDefaultReturnUrlField(field)) {
normalized[field] = resolveAbsoluteUrl("/", request, headers, field);
normalized[field] = resolveAbsoluteUrl("/", request, headers, paykit, field);
}
}

Expand Down Expand Up @@ -359,9 +362,10 @@ function normalizeReturnUrlValue(
value: string,
request?: Request,
headers?: Headers,
paykit?: Pick<PayKitContext, "options">,
): string {
if (isAbsolutePath(value)) {
return resolveAbsoluteUrl(value, request, headers, field);
return resolveAbsoluteUrl(value, request, headers, paykit, field);
}

return value;
Expand All @@ -375,9 +379,10 @@ function resolveAbsoluteUrl(
value: string,
request: Request | undefined,
headers: Headers | undefined,
paykit: Pick<PayKitContext, "options"> | undefined,
field: string,
): string {
const origin = resolveOrigin(request, headers);
const origin = resolveOrigin(request, headers, paykit);
if (!origin) {
throw PayKitError.from(
"BAD_REQUEST",
Expand All @@ -389,23 +394,54 @@ function resolveAbsoluteUrl(
return new URL(value, origin).toString();
}

function resolveOrigin(request?: Request, headers?: Headers): string | null {
function resolveOrigin(
request?: Request,
headers?: Headers,
paykit?: Pick<PayKitContext, "options">,
): string | null {
let origin: string | null = null;

if (request?.url) {
return new URL("/", request.url).toString();
origin = new URL("/", request.url).toString();
} else {
const explicitOrigin = headers?.get("origin");
if (explicitOrigin && isAbsoluteUrl(explicitOrigin)) {
origin = explicitOrigin.endsWith("/") ? explicitOrigin : `${explicitOrigin}/`;
} else {
const host = headers?.get("x-forwarded-host") ?? headers?.get("host");
if (!host) {
return null;
}

const protocol = headers?.get("x-forwarded-proto") ?? "https";
origin = `${protocol}://${host}/`;
}
}

const explicitOrigin = headers?.get("origin");
if (explicitOrigin && isAbsoluteUrl(explicitOrigin)) {
return explicitOrigin.endsWith("/") ? explicitOrigin : `${explicitOrigin}/`;
if (origin && paykit?.options.trustedOrigins?.length) {
assertTrustedOrigin(origin, paykit.options.trustedOrigins);
}

const host = headers?.get("x-forwarded-host") ?? headers?.get("host");
if (!host) {
return null;
return origin;
}

function assertTrustedOrigin(origin: string, trustedOrigins: readonly string[]): void {
const normalizedOrigin = normalizeTrustedOrigin(origin);
const isAllowed = trustedOrigins.some((trustedOrigin) => {
return normalizeTrustedOrigin(trustedOrigin) === normalizedOrigin;
});

if (!isAllowed) {
throw PayKitError.from(
"BAD_REQUEST",
PAYKIT_ERROR_CODES.TRUSTED_ORIGIN_INVALID,
`Resolved origin "${normalizedOrigin}" is not allowed by trustedOrigins`,
);
}
}

const protocol = headers?.get("x-forwarded-proto") ?? "https";
return `${protocol}://${host}/`;
function normalizeTrustedOrigin(origin: string): string {
return new URL(origin).origin;
}

function isAbsoluteUrl(value: string): boolean {
Expand Down
1 change: 1 addition & 0 deletions packages/paykit/src/core/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export const PAYKIT_ERROR_CODES = defineErrorCodes({
CUSTOMER_ID_REQUIRED: "No customerId provided and no identify configured",
SUCCESS_URL_REQUIRED:
"A successUrl is required when subscribe is called without a request context",
TRUSTED_ORIGIN_INVALID: "Resolved origin is not in trustedOrigins",
BASEPATH_INVALID: "basePath must start with a leading slash",
TESTING_NOT_ENABLED: "Testing mode is not enabled",
TEST_CLOCK_NOT_FOUND: "Customer does not have a test clock",
Expand Down
21 changes: 21 additions & 0 deletions packages/paykit/src/core/validate-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,25 @@ export function assertValidPayKitOptions(
if (error) {
throw new Error(error);
}

for (const origin of options.trustedOrigins ?? []) {
assertValidTrustedOrigin(origin);
}
}

function assertValidTrustedOrigin(origin: string): void {
let parsed: URL;
try {
parsed = new URL(origin);
} catch {
throw new Error(
`PayKit option \`trustedOrigins\` must contain absolute origins only. Received "${origin}".`,
);
}

if (parsed.pathname !== "/" || parsed.search || parsed.hash) {
throw new Error(
`PayKit option \`trustedOrigins\` must not include a path, query, or hash. Received "${origin}".`,
);
}
}
6 changes: 6 additions & 0 deletions packages/paykit/src/types/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ export interface PayKitOptions {
* @default "/paykit"
*/
basePath?: string;
/**
* Allowlist of origins that PayKit may trust when resolving relative return URLs.
* Useful to prevent host header spoofing when `successUrl`, `cancelUrl`, or `returnUrl`
* are provided as absolute paths like `/billing/success`.
*/
trustedOrigins?: string[];
identify?: (request: Request) => Promise<{
customerId: string;
email?: string;
Expand Down