From 0355c22e52562ab8ad290b8dfac5e18107b886df Mon Sep 17 00:00:00 2001 From: Devansh Soni Date: Thu, 28 May 2026 22:34:39 +0530 Subject: [PATCH 1/6] Add files via upload --- src/router/smartRouter.ts | 123 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 src/router/smartRouter.ts diff --git a/src/router/smartRouter.ts b/src/router/smartRouter.ts new file mode 100644 index 0000000..12f8487 --- /dev/null +++ b/src/router/smartRouter.ts @@ -0,0 +1,123 @@ +import { detectProvider } from "./providerDetector"; + +export interface TaskClassificationConfig { + [taskType: string]: { + model: string; + reason: string; + keywords?: string[]; + patterns?: RegExp[]; + priority?: number; + }; +} + +export interface TaskDetection { + taskType: string; + confidence: number; + selectedModel: string; + reason: string; +} + +export const defaultTaskClassification: TaskClassificationConfig = { + code_generation: { + model: "claude-3-5-sonnet-20241022", + reason: "Claude excels at code generation", + keywords: ["write code", "create function", "implement", "build", "develop", "program"], + patterns: [/write.*code/i, /create.*function/i, /implement.*class/i, /write.*function/i], + priority: 10 + }, + math_reasoning: { + model: "o1-mini", + reason: "o1 designed for reasoning", + keywords: ["calculate", "solve", "compute", "equation", "formula", "math"], + patterns: [/solve.*equation/i, /calculate/i, /mathematical/i], + priority: 9 + }, + document_analysis: { + model: "gemini-2.5-pro", + reason: "Gemini has 2M token context window", + keywords: ["summarize document", "analyze document", "extract from", "review document"], + patterns: [/summarize.*document/i, /analyze.*pdf/i, /extract.*information/i], + priority: 8 + }, + simple_chat: { + model: "gpt-4o-mini", + reason: "Cost-effective for simple conversation", + keywords: ["hello", "hi", "how are you", "thanks", "thank you", "help"], + patterns: [/^(hi|hello|hey)/i, /how.*are.*you/i, /thank/i], + priority: 1 + } +}; + +export function extractPrompt(requestBody: any, provider: string | null): string { + if (!requestBody) return ""; + let promptText = ""; + + try { + if ((provider === "openai" || provider === "anthropic") && Array.isArray(requestBody.messages)) { + const msgs = requestBody.messages; + if (msgs.length > 0) { + promptText = msgs[msgs.length - 1].content || ""; + } + } else if (provider === "gemini" && Array.isArray(requestBody.contents)) { + const contents = requestBody.contents; + if (contents.length > 0) { + const parts = contents[contents.length - 1].parts; + if (Array.isArray(parts) && parts.length > 0) { + promptText = parts[parts.length - 1].text || ""; + } + } + } + } catch (e) { + // Parsing failed, return empty string + } + + return promptText; +} + +export function classifyTask( + prompt: string, + config: TaskClassificationConfig = defaultTaskClassification, + confidenceThreshold: number = 0.7 +): TaskDetection | null { + if (!prompt) return null; + + let bestMatch: TaskDetection | null = null; + let highestConfidence = 0; + + for (const [taskType, rules] of Object.entries(config)) { + let confidence = 0; + + // Keyword matching + if (rules.keywords) { + const matchedKeywords = rules.keywords.filter(kw => prompt.toLowerCase().includes(kw.toLowerCase())); + if (matchedKeywords.length > 0) { + confidence += 0.5 + (0.1 * matchedKeywords.length); + } + } + + // Pattern matching + if (rules.patterns) { + for (const pattern of rules.patterns) { + if (pattern.test(prompt)) { + confidence += 0.8; + break; + } + } + } + + // Cap confidence + confidence = Math.min(confidence, 1.0); + + if (confidence >= confidenceThreshold && confidence > highestConfidence) { + highestConfidence = confidence; + bestMatch = { + taskType, + confidence, + selectedModel: rules.model, + reason: rules.reason + }; + } + } + + return bestMatch; +} From 35c6d1399be50d3f4ebab2dba3789803c190a75b Mon Sep 17 00:00:00 2001 From: Devansh Soni Date: Thu, 28 May 2026 22:35:24 +0530 Subject: [PATCH 2/6] Add files via upload --- tests/smartRouting.test.js | 183 +++++++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 tests/smartRouting.test.js diff --git a/tests/smartRouting.test.js b/tests/smartRouting.test.js new file mode 100644 index 0000000..e373ea1 --- /dev/null +++ b/tests/smartRouting.test.js @@ -0,0 +1,183 @@ +// Mock global fetch BEFORE requiring the module +let interceptedUrl = ""; +let interceptedHeaders = null; +let interceptedBody = null; + +globalThis.fetch = async (input, init) => { + interceptedUrl = typeof input === 'string' ? input : input.url; + interceptedHeaders = init?.headers; + if (init?.body) { + interceptedBody = JSON.parse(init.body); + } + return { + ok: true, + json: async () => ({}), + clone: function() { return this; } + }; +}; + +const { + createModelRouter, + registerApiKeys, + patchGlobalFetch, + disableModelRouter, + unpatchGlobalFetch, +} = require("../dist/index.js"); + +const { extractPrompt, classifyTask } = require("../dist/router/smartRouter.js"); + +const colors = { + reset: '\x1b[0m', + green: '\x1b[32m', + red: '\x1b[31m', + cyan: '\x1b[36m', + yellow: '\x1b[33m', +}; + +function log(msg, color = 'reset') { + console.log(`${colors[color]}${msg}${colors.reset}`); +} + +const testResults = { total: 0, passed: 0, failed: 0, errors: [] }; + +function recordTest(name, passed, error = null) { + testResults.total++; + if (passed) { + testResults.passed++; + log(` ✓ ${name}`, 'green'); + } else { + testResults.failed++; + log(` ✗ ${name}`, 'red'); + if (error) { + testResults.errors.push({ test: name, error: error.message || error }); + } + } +} + +function section(title) { + console.log('\n' + '='.repeat(70)); + log(title, 'cyan'); + console.log('='.repeat(70)); +} + +// Test 1: Extract Prompt +async function testExtractPrompt() { + section('TEST SUITE 1: Prompt Extraction'); + + // Test OpenAI/Anthropic structure + try { + const openaiBody = { + messages: [{ role: "user", content: "Write a function" }] + }; + const prompt = extractPrompt(openaiBody, "openai"); + recordTest('Extract from OpenAI', prompt === "Write a function"); + } catch (err) { + recordTest('Extract from OpenAI', false, err); + } + + // Test Gemini structure + try { + const geminiBody = { + contents: [{ parts: [{ text: "Hello Gemini" }] }] + }; + const prompt = extractPrompt(geminiBody, "gemini"); + recordTest('Extract from Gemini', prompt === "Hello Gemini"); + } catch (err) { + recordTest('Extract from Gemini', false, err); + } +} + +// Test 2: Task Classification +async function testClassification() { + section('TEST SUITE 2: Task Classification'); + try { + const detection1 = classifyTask("Write a Python function to sort"); + recordTest('Classify code generation', detection1 && detection1.taskType === 'code_generation'); + + const detection2 = classifyTask("Calculate the math equation 2+2"); + recordTest('Classify math reasoning', detection2 && detection2.taskType === 'math_reasoning'); + + const detection3 = classifyTask("summarize this document please"); + recordTest('Classify document analysis', detection3 && detection3.taskType === 'document_analysis'); + + const detection4 = classifyTask("hi how are you"); + recordTest('Classify simple chat', detection4 && detection4.taskType === 'simple_chat'); + } catch (err) { + recordTest('Task Classification tests', false, err); + } +} +// Test 3: Fetch Interceptor Smart Routing +async function testFetchIntegration() { + section('TEST SUITE 3: Fetch Interception integration'); + + interceptedUrl = ""; + interceptedHeaders = null; + interceptedBody = null; + + registerApiKeys({ + anthropic: 'test-anthropic', + openai: 'test-openai' + }); + + createModelRouter({ + strategy: "smart", + enableCrossProvider: true, + }); + + patchGlobalFetch(); + try { + // We send an OpenAI request with a code prompt + // The smart router should intercept it, classify it as 'code_generation', + // and since default code_generation is 'claude-3-5-sonnet-20241022' (Anthropic), + // it should cross-provider fallback it to anthropic! + await globalThis.fetch("https://api.openai.com/v1/chat/completions", { + method: "POST", + headers: { + "Authorization": "Bearer test-key" + }, + body: JSON.stringify({ + model: "gpt-4o", + messages: [{ role: "user", content: "write code for sorting" }] + }) + }); + + const isAnthropicUrl = interceptedUrl.includes("api.anthropic.com"); + const isClaudeModel = interceptedBody.model === "claude-3-5-sonnet-20241022"; + + recordTest('Fetch Interceptor correctly overrides model', isClaudeModel); + recordTest('Fetch Interceptor correctly handles cross-provider URL override', isAnthropicUrl); + } catch (err) { + recordTest('Fetch Interception integration', false, err); + } finally { + disableModelRouter(); + } +} + +async function runAllTests() { + console.clear(); + log('\n╔════════════════════════════════════════════════════════════════════╗', 'cyan'); + log('║ Smart Routing - Test Suite ║', 'cyan'); + log('╚════════════════════════════════════════════════════════════════════╝\n', 'cyan'); + + await testExtractPrompt(); + await testClassification(); + await testFetchIntegration(); + + section('TEST SUMMARY'); + log(`Total Tests: ${testResults.total}`, 'cyan'); + log(`Passed: ${testResults.passed}`, 'green'); + log(`Failed: ${testResults.failed}`, testResults.failed > 0 ? 'red' : 'green'); + + if (testResults.errors.length > 0) { + console.log('\n' + '─'.repeat(70)); + log('FAILED TESTS:', 'red'); + testResults.errors.forEach((err, idx) => { + log(`${idx + 1}. ${err.test}`, 'red'); + log(` ${err.error}`, 'yellow'); + }); + } + + process.exit(testResults.failed > 0 ? 1 : 0); +} + +runAllTests(); From 4357f5a27c13d0e264fdca4e40d4ca78b548f214 Mon Sep 17 00:00:00 2001 From: Devansh Soni Date: Thu, 28 May 2026 22:35:54 +0530 Subject: [PATCH 3/6] Add files via upload --- src/index.ts | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/index.ts b/src/index.ts index c769e99..d88e136 100644 --- a/src/index.ts +++ b/src/index.ts @@ -252,6 +252,17 @@ export type { ApiKeyConfig } from "./router/types"; +export type { + TaskClassificationConfig, + TaskDetection +} from "./router/smartRouter"; + +export { + defaultTaskClassification, + classifyTask, + extractPrompt +} from "./router/smartRouter"; + /** * Model configuration for bulk registration */ From 8148b4fa5894fc4695e586da72a3d75db0bc2434 Mon Sep 17 00:00:00 2001 From: Devansh Soni Date: Thu, 28 May 2026 22:38:50 +0530 Subject: [PATCH 4/6] Add files via upload --- src/router/types.ts | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/router/types.ts b/src/router/types.ts index 8e333f4..ee233c6 100644 --- a/src/router/types.ts +++ b/src/router/types.ts @@ -5,7 +5,7 @@ /** * Routing strategy types */ -export type RoutingStrategy = "fallback" | "context" | "cost"; +export type RoutingStrategy = "fallback" | "context" | "cost" | "smart"; /** * Failure types detected by error detector @@ -43,6 +43,12 @@ export interface ModelRouterOptions { apiKeys?: ApiKeyConfig; /** Enable cross-provider fallback (default: false) */ enableCrossProvider?: boolean; + /** Custom task classification config for smart strategy */ + taskClassification?: any; + /** Confidence threshold for smart strategy */ + confidenceThreshold?: number; + /** Default model for smart strategy if no match found */ + defaultModel?: string; } /** From 395b3197a6c51924765b3e89b99460f0b03e8847 Mon Sep 17 00:00:00 2001 From: Devansh Soni Date: Thu, 28 May 2026 22:53:30 +0530 Subject: [PATCH 5/6] Add files via upload --- src/router/modelRouter.ts | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/router/modelRouter.ts b/src/router/modelRouter.ts index 4b6f8ea..4f6c896 100644 --- a/src/router/modelRouter.ts +++ b/src/router/modelRouter.ts @@ -17,12 +17,18 @@ export class ModelRouter { private fallbackMap: Record; private maxRetries: number; private crossProviderEnabled: boolean; + private taskClassification?: any; + private confidenceThreshold?: number; + private defaultModel?: string; constructor(options: ModelRouterOptions) { this.strategy = options.strategy; this.fallbackMap = options.fallbackMap || {}; this.maxRetries = options.maxRetries ?? 1; this.crossProviderEnabled = options.enableCrossProvider ?? false; + this.taskClassification = options.taskClassification; + this.confidenceThreshold = options.confidenceThreshold; + this.defaultModel = options.defaultModel; // Register API keys if provided if (options.apiKeys) { @@ -137,6 +143,9 @@ export class ModelRouter { case "cost": return costStrategy(context, failureType as any); + case "smart": + return { retry: false, reason: "Smart routing handles requests pre-flight, no post-failure fallback defined" }; + default: throw new Error( `TokenFirewall Router: Unknown strategy "${this.strategy}"` @@ -164,4 +173,16 @@ export class ModelRouter { public isCrossProviderEnabled(): boolean { return this.crossProviderEnabled; } + + public getTaskClassification() { + return this.taskClassification; + } + + public getConfidenceThreshold() { + return this.confidenceThreshold; + } + + public getDefaultModel() { + return this.defaultModel; + } } From 98e8f092439ccdf2da2b945b33bbbb26ccd93c09 Mon Sep 17 00:00:00 2001 From: Devansh Soni Date: Thu, 28 May 2026 22:56:01 +0530 Subject: [PATCH 6/6] updated logic for fetch interceptor --- src/interceptors/fetchInterceptor.ts | 61 +++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/src/interceptors/fetchInterceptor.ts b/src/interceptors/fetchInterceptor.ts index 3048ec4..8d3ab9e 100644 --- a/src/interceptors/fetchInterceptor.ts +++ b/src/interceptors/fetchInterceptor.ts @@ -8,6 +8,7 @@ import { apiKeyManager } from "../router/apiKeyManager"; import { buildProviderHeaders, appendApiKeyToUrl } from "../router/providerHeaders"; import { transformRequest } from "../router/requestTransformer"; import { transformResponse } from "../router/responseTransformer"; +import { extractPrompt, classifyTask, defaultTaskClassification } from "../router/smartRouter"; let isPatched = false; let budgetManager: BudgetManager | null = null; @@ -75,7 +76,7 @@ async function standardFetch( // Process response and track budget BEFORE returning try { const responseData = await clonedResponse.json(); - + // Try to process with adapter registry const normalizedUsage = adapterRegistry.process(responseData); @@ -132,13 +133,9 @@ async function fetchWithRetry( let currentInput = input; // Extract original model and provider from request - const { originalModel, provider: originalProvider } = extractModelInfo(input, init); - - if (originalModel) { - attemptedModels.push(originalModel); - } + let { originalModel, provider: originalProvider } = extractModelInfo(input, init); - // Parse the original request body for potential cross-provider transformations + // Parse the original request body for potential cross-provider transformations or smart routing let originalRequestBody: any = null; if (init?.body) { try { @@ -148,6 +145,48 @@ async function fetchWithRetry( } } + // --- SMART ROUTING --- + if (modelRouter && modelRouter.getStrategy() === 'smart') { + const prompt = extractPrompt(originalRequestBody, originalProvider); + if (prompt) { + const detection = classifyTask( + prompt, + modelRouter.getTaskClassification() || defaultTaskClassification, + modelRouter.getConfidenceThreshold() || 0.7 + ); + + const newModel = detection?.selectedModel || modelRouter.getDefaultModel(); + + if (newModel && newModel !== originalModel) { + const nextProvider = detectProvider(newModel); + if (nextProvider && originalProvider && nextProvider !== originalProvider && modelRouter.isCrossProviderEnabled()) { + const updated = buildCrossProviderRequest(originalRequestBody, originalProvider, nextProvider, newModel); + if (updated) { + currentInput = updated.url; + currentInit = updated.init; + originalModel = newModel; + originalProvider = nextProvider; + if (currentInit.body) { + try { originalRequestBody = JSON.parse(currentInit.body as string); } catch { } + } + } + } else if (nextProvider === originalProvider) { + const updated = updateRequestModel(currentInput, currentInit, newModel, originalProvider); + currentInput = updated.input; + currentInit = updated.init; + originalModel = newModel; + if (currentInit?.body) { + try { originalRequestBody = JSON.parse(currentInit.body as string); } catch { } + } + } + } + } + } + + if (originalModel) { + attemptedModels.push(originalModel); + } + while (retryCount <= (modelRouter?.getMaxRetries() || 0)) { try { // Make the request @@ -163,7 +202,7 @@ async function fetchWithRetry( } catch { // Response is not JSON or already consumed } - + // Throw proper Error instance with structured data const errorObj = { status: response.status, @@ -418,7 +457,7 @@ function extractModelInfo( // Extract model from request body (for non-Gemini providers) if (!model) { let body: string | null = null; - + // Get body from init or Request object if (init?.body) { body = typeof init.body === 'string' ? init.body : null; @@ -427,7 +466,7 @@ function extractModelInfo( // So we skip body parsing for Request objects without explicit init.body body = null; } - + if (body) { try { const bodyObj = JSON.parse(body); @@ -462,7 +501,7 @@ function updateRequestModel( // Update URL for Gemini (model is in URL path) if (provider === 'gemini') { const newUrl = url.replace(/\/models\/[^:]+:/, `/models/${newModel}:`); - + // If input was a Request object, we need to create a new Request with updated URL if (input instanceof Request) { // Clone the request with new URL