Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,34 @@ The router automatically detects and classifies failures:
- `access_denied` - HTTP 403 or unauthorized
- `unknown` - Other errors

### Request Header Hints

TokenFirewall can parse optional request headers that describe how a future
smart-routing layer should treat a request:

```javascript
const {
parseTokenFirewallHeaders,
hasTokenFirewallHeaderHints
} = require("tokenfirewall");

const hints = parseTokenFirewallHeaders({
"X-TokenFirewall-Task-Type": "code_generation",
"X-TokenFirewall-Smart-Routing": "true",
"X-TokenFirewall-Tags": "prod, premium"
});

if (hasTokenFirewallHeaderHints(hints)) {
console.log(hints);
// { taskType: "code_generation", smartRouting: true, tags: ["prod", "premium"] }
}
```

Supported headers:
- `X-TokenFirewall-Task-Type` - manual task type hint
- `X-TokenFirewall-Smart-Routing` - `true`/`false` style routing toggle
- `X-TokenFirewall-Tags` - comma-separated routing or analytics tags

### `disableModelRouter()`

Disables the model router.
Expand Down
10 changes: 9 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ export async function listModels(options: Omit<ListModelsOptions, 'budgetManager

// Keep the original export for backward compatibility
export { listAvailableModels };
export {
TOKEN_FIREWALL_SMART_ROUTING_HEADER,
TOKEN_FIREWALL_TAGS_HEADER,
TOKEN_FIREWALL_TASK_TYPE_HEADER,
hasTokenFirewallHeaderHints,
parseTokenFirewallHeaders
} from "./router/requestHeaders";

// Export types for TypeScript users
export type {
Expand All @@ -249,7 +256,8 @@ export type {
FailureContext,
RoutingDecision,
RouterEvent,
ApiKeyConfig
ApiKeyConfig,
TokenFirewallRequestHeaderHints
} from "./router/types";

/**
Expand Down
28 changes: 28 additions & 0 deletions src/interceptors/fetchInterceptor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import { apiKeyManager } from "../router/apiKeyManager";
import { buildProviderHeaders, appendApiKeyToUrl } from "../router/providerHeaders";
import { transformRequest } from "../router/requestTransformer";
import { transformResponse } from "../router/responseTransformer";
import {
hasTokenFirewallHeaderHints,
parseTokenFirewallHeaders
} from "../router/requestHeaders";

let isPatched = false;
let budgetManager: BudgetManager | null = null;
Expand Down Expand Up @@ -133,6 +137,7 @@ async function fetchWithRetry(

// Extract original model and provider from request
const { originalModel, provider: originalProvider } = extractModelInfo(input, init);
const headerHints = extractTokenFirewallHeaderHints(input, init);

if (originalModel) {
attemptedModels.push(originalModel);
Expand Down Expand Up @@ -232,6 +237,7 @@ async function fetchWithRetry(
return {};
}
})() : {},
headerHints,
provider: originalProvider,
retryCount,
attemptedModels
Expand Down Expand Up @@ -345,6 +351,28 @@ function buildCrossProviderRequest(
};
}

/**
* Extract TokenFirewall-specific header hints without mutating request headers.
*/
function extractTokenFirewallHeaderHints(
input: Parameters<typeof fetch>[0],
init?: Parameters<typeof fetch>[1]
) {
const initHints = parseTokenFirewallHeaders(init?.headers);
if (hasTokenFirewallHeaderHints(initHints)) {
return initHints;
}

if (input instanceof Request) {
const requestHints = parseTokenFirewallHeaders(input.headers);
if (hasTokenFirewallHeaderHints(requestHints)) {
return requestHints;
}
}

return undefined;
}

/**
* Extract current model from the (possibly updated) request
*/
Expand Down
129 changes: 129 additions & 0 deletions src/router/requestHeaders.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import { TokenFirewallRequestHeaderHints } from "./types";

export const TOKEN_FIREWALL_TASK_TYPE_HEADER = "x-tokenfirewall-task-type";
export const TOKEN_FIREWALL_SMART_ROUTING_HEADER = "x-tokenfirewall-smart-routing";
export const TOKEN_FIREWALL_TAGS_HEADER = "x-tokenfirewall-tags";

type HeaderValue = string | number | boolean | string[] | undefined | null;

interface HeaderGetter {
get(name: string): string | null;
}

interface HeaderForEach {
forEach(callback: (value: string, key: string) => void): void;
}

/**
* Parse TokenFirewall smart-routing controls from request headers.
*/
export function parseTokenFirewallHeaders(headers: unknown): TokenFirewallRequestHeaderHints {
return {
taskType: normalizeHeaderText(readHeader(headers, TOKEN_FIREWALL_TASK_TYPE_HEADER)),
smartRouting: parseSmartRoutingHeader(
readHeader(headers, TOKEN_FIREWALL_SMART_ROUTING_HEADER)
),
tags: parseTagsHeader(readHeader(headers, TOKEN_FIREWALL_TAGS_HEADER))
};
}

/**
* Check whether any TokenFirewall header hints were supplied.
*/
export function hasTokenFirewallHeaderHints(hints: TokenFirewallRequestHeaderHints): boolean {
return Boolean(
hints.taskType ||
hints.smartRouting !== undefined ||
hints.tags.length > 0
);
}

function readHeader(headers: unknown, name: string): string | undefined {
if (!headers) {
return undefined;
}

const getter = headers as Partial<HeaderGetter>;
if (typeof getter.get === "function") {
return normalizeHeaderValue(getter.get(name));
}

if (Array.isArray(headers)) {
const match = headers.find(([key]) =>
typeof key === "string" && key.toLowerCase() === name
);
return match ? normalizeHeaderValue(match[1]) : undefined;
}

const forEachHeaders = headers as Partial<HeaderForEach>;
if (typeof forEachHeaders.forEach === "function") {
let value: string | undefined;
forEachHeaders.forEach((headerValue, key) => {
if (key.toLowerCase() === name) {
value = normalizeHeaderValue(headerValue);
}
});
return value;
}

if (typeof headers === "object") {
const record = headers as Record<string, HeaderValue>;
const key = Object.keys(record).find((headerName) =>
headerName.toLowerCase() === name
);
return key ? normalizeHeaderValue(record[key]) : undefined;
}

return undefined;
}

function normalizeHeaderValue(value: HeaderValue): string | undefined {
if (Array.isArray(value)) {
return value.join(",");
}
if (value === undefined || value === null) {
return undefined;
}
return String(value);
}

function normalizeHeaderText(value: string | undefined): string | undefined {
const normalized = value?.trim();
return normalized ? normalized : undefined;
}

function parseSmartRoutingHeader(value: string | undefined): boolean | undefined {
const normalized = value?.trim().toLowerCase();
if (!normalized) {
return undefined;
}

if (["1", "true", "yes", "on", "enabled"].includes(normalized)) {
return true;
}

if (["0", "false", "no", "off", "disabled"].includes(normalized)) {
return false;
}

return undefined;
}

function parseTagsHeader(value: string | undefined): string[] {
const seen = new Set<string>();
const tags = value
? value
.split(",")
.map((tag) => tag.trim())
.filter(Boolean)
: [];

return tags.filter((tag) => {
const key = tag.toLowerCase();
if (seen.has(key)) {
return false;
}
seen.add(key);
return true;
});
}
14 changes: 14 additions & 0 deletions src/router/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ export interface FailureContext {
originalModel: string;
/** Request body sent to API */
requestBody: any;
/** Optional TokenFirewall routing hints parsed from request headers */
headerHints?: TokenFirewallRequestHeaderHints;
/** Provider name */
provider: string;
/** Current retry attempt count */
Expand Down Expand Up @@ -90,3 +92,15 @@ export interface RouterEvent {
/** Maximum retries allowed */
maxRetries: number;
}

/**
* TokenFirewall-specific request header hints for smart routing integrations.
*/
export interface TokenFirewallRequestHeaderHints {
/** Optional manual task type, from X-TokenFirewall-Task-Type */
taskType?: string;
/** Optional smart-routing toggle, from X-TokenFirewall-Smart-Routing */
smartRouting?: boolean;
/** Optional tag list, from X-TokenFirewall-Tags */
tags: string[];
}
83 changes: 83 additions & 0 deletions tests/request-headers.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/**
* TokenFirewall request header parsing tests.
*
* Run: node tests/request-headers.test.js
*/

const {
TOKEN_FIREWALL_SMART_ROUTING_HEADER,
TOKEN_FIREWALL_TAGS_HEADER,
TOKEN_FIREWALL_TASK_TYPE_HEADER,
hasTokenFirewallHeaderHints,
parseTokenFirewallHeaders,
} = require("../dist/index.js");

function assert(condition, message) {
if (!condition) {
throw new Error(message);
}
}

function assertDeepEqual(actual, expected, message) {
const actualJson = JSON.stringify(actual);
const expectedJson = JSON.stringify(expected);
if (actualJson !== expectedJson) {
throw new Error(`${message}\nExpected: ${expectedJson}\nActual: ${actualJson}`);
}
}

try {
const objectHints = parseTokenFirewallHeaders({
"X-TokenFirewall-Task-Type": " code_generation ",
"X-TokenFirewall-Smart-Routing": "enabled",
"X-TokenFirewall-Tags": "prod, premium, PROD, api",
});

assertDeepEqual(objectHints, {
taskType: "code_generation",
smartRouting: true,
tags: ["prod", "premium", "api"],
}, "object headers should parse task type, smart routing, and tags");
assert(hasTokenFirewallHeaderHints(objectHints), "object hints should be detected");

const disabledHints = parseTokenFirewallHeaders(new Headers({
[TOKEN_FIREWALL_TASK_TYPE_HEADER]: "simple_chat",
[TOKEN_FIREWALL_SMART_ROUTING_HEADER]: "off",
[TOKEN_FIREWALL_TAGS_HEADER]: "support,internal",
}));

assertDeepEqual(disabledHints, {
taskType: "simple_chat",
smartRouting: false,
tags: ["support", "internal"],
}, "Headers instances should parse case-insensitive controls");

const tupleHints = parseTokenFirewallHeaders([
[TOKEN_FIREWALL_TASK_TYPE_HEADER, "math_reasoning"],
[TOKEN_FIREWALL_SMART_ROUTING_HEADER, "1"],
[TOKEN_FIREWALL_TAGS_HEADER, "analysis,priority"],
]);

assertDeepEqual(tupleHints, {
taskType: "math_reasoning",
smartRouting: true,
tags: ["analysis", "priority"],
}, "header tuples should parse controls");

const invalidHints = parseTokenFirewallHeaders({
[TOKEN_FIREWALL_TASK_TYPE_HEADER]: " ",
[TOKEN_FIREWALL_SMART_ROUTING_HEADER]: "maybe",
[TOKEN_FIREWALL_TAGS_HEADER]: " , , ",
});

assertDeepEqual(invalidHints, {
smartRouting: undefined,
tags: [],
}, "empty and invalid values should be ignored safely");
assert(!hasTokenFirewallHeaderHints(invalidHints), "empty hints should not be detected");

console.log("Request header parsing tests passed.");
} catch (error) {
console.error("Request header parsing tests failed:", error.message);
process.exit(1);
}