diff --git a/README.md b/README.md index 57253f1..b8e89b6 100644 --- a/README.md +++ b/README.md @@ -333,6 +333,34 @@ The router automatically detects and classifies failures: - `access_denied` - HTTP 403 or unauthorized - `unknown` - Other errors +### Request Header Hints + +TokenFirewall can parse optional request headers that describe how a future +smart-routing layer should treat a request: + +```javascript +const { + parseTokenFirewallHeaders, + hasTokenFirewallHeaderHints +} = require("tokenfirewall"); + +const hints = parseTokenFirewallHeaders({ + "X-TokenFirewall-Task-Type": "code_generation", + "X-TokenFirewall-Smart-Routing": "true", + "X-TokenFirewall-Tags": "prod, premium" +}); + +if (hasTokenFirewallHeaderHints(hints)) { + console.log(hints); + // { taskType: "code_generation", smartRouting: true, tags: ["prod", "premium"] } +} +``` + +Supported headers: +- `X-TokenFirewall-Task-Type` - manual task type hint +- `X-TokenFirewall-Smart-Routing` - `true`/`false` style routing toggle +- `X-TokenFirewall-Tags` - comma-separated routing or analytics tags + ### `disableModelRouter()` Disables the model router. diff --git a/src/index.ts b/src/index.ts index c769e99..f302f5f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -227,6 +227,13 @@ export async function listModels(options: Omit[0], + init?: Parameters[1] +) { + const initHints = parseTokenFirewallHeaders(init?.headers); + if (hasTokenFirewallHeaderHints(initHints)) { + return initHints; + } + + if (input instanceof Request) { + const requestHints = parseTokenFirewallHeaders(input.headers); + if (hasTokenFirewallHeaderHints(requestHints)) { + return requestHints; + } + } + + return undefined; +} + /** * Extract current model from the (possibly updated) request */ diff --git a/src/router/requestHeaders.ts b/src/router/requestHeaders.ts new file mode 100644 index 0000000..6c0a924 --- /dev/null +++ b/src/router/requestHeaders.ts @@ -0,0 +1,129 @@ +import { TokenFirewallRequestHeaderHints } from "./types"; + +export const TOKEN_FIREWALL_TASK_TYPE_HEADER = "x-tokenfirewall-task-type"; +export const TOKEN_FIREWALL_SMART_ROUTING_HEADER = "x-tokenfirewall-smart-routing"; +export const TOKEN_FIREWALL_TAGS_HEADER = "x-tokenfirewall-tags"; + +type HeaderValue = string | number | boolean | string[] | undefined | null; + +interface HeaderGetter { + get(name: string): string | null; +} + +interface HeaderForEach { + forEach(callback: (value: string, key: string) => void): void; +} + +/** + * Parse TokenFirewall smart-routing controls from request headers. + */ +export function parseTokenFirewallHeaders(headers: unknown): TokenFirewallRequestHeaderHints { + return { + taskType: normalizeHeaderText(readHeader(headers, TOKEN_FIREWALL_TASK_TYPE_HEADER)), + smartRouting: parseSmartRoutingHeader( + readHeader(headers, TOKEN_FIREWALL_SMART_ROUTING_HEADER) + ), + tags: parseTagsHeader(readHeader(headers, TOKEN_FIREWALL_TAGS_HEADER)) + }; +} + +/** + * Check whether any TokenFirewall header hints were supplied. + */ +export function hasTokenFirewallHeaderHints(hints: TokenFirewallRequestHeaderHints): boolean { + return Boolean( + hints.taskType || + hints.smartRouting !== undefined || + hints.tags.length > 0 + ); +} + +function readHeader(headers: unknown, name: string): string | undefined { + if (!headers) { + return undefined; + } + + const getter = headers as Partial; + if (typeof getter.get === "function") { + return normalizeHeaderValue(getter.get(name)); + } + + if (Array.isArray(headers)) { + const match = headers.find(([key]) => + typeof key === "string" && key.toLowerCase() === name + ); + return match ? normalizeHeaderValue(match[1]) : undefined; + } + + const forEachHeaders = headers as Partial; + if (typeof forEachHeaders.forEach === "function") { + let value: string | undefined; + forEachHeaders.forEach((headerValue, key) => { + if (key.toLowerCase() === name) { + value = normalizeHeaderValue(headerValue); + } + }); + return value; + } + + if (typeof headers === "object") { + const record = headers as Record; + const key = Object.keys(record).find((headerName) => + headerName.toLowerCase() === name + ); + return key ? normalizeHeaderValue(record[key]) : undefined; + } + + return undefined; +} + +function normalizeHeaderValue(value: HeaderValue): string | undefined { + if (Array.isArray(value)) { + return value.join(","); + } + if (value === undefined || value === null) { + return undefined; + } + return String(value); +} + +function normalizeHeaderText(value: string | undefined): string | undefined { + const normalized = value?.trim(); + return normalized ? normalized : undefined; +} + +function parseSmartRoutingHeader(value: string | undefined): boolean | undefined { + const normalized = value?.trim().toLowerCase(); + if (!normalized) { + return undefined; + } + + if (["1", "true", "yes", "on", "enabled"].includes(normalized)) { + return true; + } + + if (["0", "false", "no", "off", "disabled"].includes(normalized)) { + return false; + } + + return undefined; +} + +function parseTagsHeader(value: string | undefined): string[] { + const seen = new Set(); + const tags = value + ? value + .split(",") + .map((tag) => tag.trim()) + .filter(Boolean) + : []; + + return tags.filter((tag) => { + const key = tag.toLowerCase(); + if (seen.has(key)) { + return false; + } + seen.add(key); + return true; + }); +} diff --git a/src/router/types.ts b/src/router/types.ts index 8e333f4..6e150e8 100644 --- a/src/router/types.ts +++ b/src/router/types.ts @@ -55,6 +55,8 @@ export interface FailureContext { originalModel: string; /** Request body sent to API */ requestBody: any; + /** Optional TokenFirewall routing hints parsed from request headers */ + headerHints?: TokenFirewallRequestHeaderHints; /** Provider name */ provider: string; /** Current retry attempt count */ @@ -90,3 +92,15 @@ export interface RouterEvent { /** Maximum retries allowed */ maxRetries: number; } + +/** + * TokenFirewall-specific request header hints for smart routing integrations. + */ +export interface TokenFirewallRequestHeaderHints { + /** Optional manual task type, from X-TokenFirewall-Task-Type */ + taskType?: string; + /** Optional smart-routing toggle, from X-TokenFirewall-Smart-Routing */ + smartRouting?: boolean; + /** Optional tag list, from X-TokenFirewall-Tags */ + tags: string[]; +} diff --git a/tests/request-headers.test.js b/tests/request-headers.test.js new file mode 100644 index 0000000..c6fb8ab --- /dev/null +++ b/tests/request-headers.test.js @@ -0,0 +1,83 @@ +/** + * TokenFirewall request header parsing tests. + * + * Run: node tests/request-headers.test.js + */ + +const { + TOKEN_FIREWALL_SMART_ROUTING_HEADER, + TOKEN_FIREWALL_TAGS_HEADER, + TOKEN_FIREWALL_TASK_TYPE_HEADER, + hasTokenFirewallHeaderHints, + parseTokenFirewallHeaders, +} = require("../dist/index.js"); + +function assert(condition, message) { + if (!condition) { + throw new Error(message); + } +} + +function assertDeepEqual(actual, expected, message) { + const actualJson = JSON.stringify(actual); + const expectedJson = JSON.stringify(expected); + if (actualJson !== expectedJson) { + throw new Error(`${message}\nExpected: ${expectedJson}\nActual: ${actualJson}`); + } +} + +try { + const objectHints = parseTokenFirewallHeaders({ + "X-TokenFirewall-Task-Type": " code_generation ", + "X-TokenFirewall-Smart-Routing": "enabled", + "X-TokenFirewall-Tags": "prod, premium, PROD, api", + }); + + assertDeepEqual(objectHints, { + taskType: "code_generation", + smartRouting: true, + tags: ["prod", "premium", "api"], + }, "object headers should parse task type, smart routing, and tags"); + assert(hasTokenFirewallHeaderHints(objectHints), "object hints should be detected"); + + const disabledHints = parseTokenFirewallHeaders(new Headers({ + [TOKEN_FIREWALL_TASK_TYPE_HEADER]: "simple_chat", + [TOKEN_FIREWALL_SMART_ROUTING_HEADER]: "off", + [TOKEN_FIREWALL_TAGS_HEADER]: "support,internal", + })); + + assertDeepEqual(disabledHints, { + taskType: "simple_chat", + smartRouting: false, + tags: ["support", "internal"], + }, "Headers instances should parse case-insensitive controls"); + + const tupleHints = parseTokenFirewallHeaders([ + [TOKEN_FIREWALL_TASK_TYPE_HEADER, "math_reasoning"], + [TOKEN_FIREWALL_SMART_ROUTING_HEADER, "1"], + [TOKEN_FIREWALL_TAGS_HEADER, "analysis,priority"], + ]); + + assertDeepEqual(tupleHints, { + taskType: "math_reasoning", + smartRouting: true, + tags: ["analysis", "priority"], + }, "header tuples should parse controls"); + + const invalidHints = parseTokenFirewallHeaders({ + [TOKEN_FIREWALL_TASK_TYPE_HEADER]: " ", + [TOKEN_FIREWALL_SMART_ROUTING_HEADER]: "maybe", + [TOKEN_FIREWALL_TAGS_HEADER]: " , , ", + }); + + assertDeepEqual(invalidHints, { + smartRouting: undefined, + tags: [], + }, "empty and invalid values should be ignored safely"); + assert(!hasTokenFirewallHeaderHints(invalidHints), "empty hints should not be detected"); + + console.log("Request header parsing tests passed."); +} catch (error) { + console.error("Request header parsing tests failed:", error.message); + process.exit(1); +}