diff --git a/README.md b/README.md index 57253f1..02b02e8 100644 --- a/README.md +++ b/README.md @@ -286,6 +286,9 @@ interface ModelRouterOptions { strategy: "fallback" | "context" | "cost"; // Routing strategy fallbackMap?: Record; // Fallback model map maxRetries?: number; // Max retry attempts (default: 1) + cacheRoutingDecisions?: boolean; // Cache repeated decisions (default: true) + routingCacheTtlMs?: number; // Cache TTL in ms (default: 5 minutes) + maxRoutingCacheSize?: number; // Max cached decisions (default: 1000) } ``` @@ -301,7 +304,8 @@ createModelRouter({ "gpt-4o": ["gpt-4o-mini", "gpt-3.5-turbo"], "claude-3-5-sonnet-20241022": ["claude-3-5-haiku-20241022"] }, - maxRetries: 2 + maxRetries: 2, + cacheRoutingDecisions: true }); patchGlobalFetch(); @@ -324,6 +328,13 @@ patchGlobalFetch(); - Selects cheaper model from same provider - Best for: Cost optimization, rate limit handling +### Performance + +- Repeated routing decisions are cached by default for 5 minutes. +- Cached decisions are bounded by `maxRoutingCacheSize` to avoid unbounded memory growth. +- The cache key includes the strategy, failure type, provider, model, retry count, attempted models, and request prompt fingerprint. +- Set `cacheRoutingDecisions: false` to disable caching for workloads that require fully uncached routing decisions. + ### Error Detection The router automatically detects and classifies failures: diff --git a/src/router/errorDetector.ts b/src/router/errorDetector.ts index 54efdf5..563ef36 100644 --- a/src/router/errorDetector.ts +++ b/src/router/errorDetector.ts @@ -211,11 +211,12 @@ export class ErrorDetector { if (typeof data === "object" && data !== null) { const errorMessage = data.error?.message || data.message || ""; const errorCode = data.error?.code || data.code || ""; + const lowerMessage = errorMessage.toLowerCase(); return ( errorCode === "model_not_found" || - errorMessage.toLowerCase().includes("model") && - errorMessage.toLowerCase().includes("not found") + lowerMessage.includes("model") && + lowerMessage.includes("not found") ); } diff --git a/src/router/modelRouter.ts b/src/router/modelRouter.ts index 4b6f8ea..ed703f3 100644 --- a/src/router/modelRouter.ts +++ b/src/router/modelRouter.ts @@ -8,6 +8,14 @@ import { errorDetector } from "./errorDetector"; import { fallbackStrategy, contextStrategy, costStrategy } from "./routingStrategies"; import { apiKeyManager } from "./apiKeyManager"; +const DEFAULT_ROUTING_CACHE_TTL_MS = 5 * 60 * 1000; +const DEFAULT_ROUTING_CACHE_SIZE = 1000; + +interface CachedRoutingDecision { + decision: RoutingDecision; + expiresAt: number; +} + /** * Intelligent Model Router * Handles automatic retries and model switching on failures @@ -17,12 +25,19 @@ export class ModelRouter { private fallbackMap: Record; private maxRetries: number; private crossProviderEnabled: boolean; + private cacheRoutingDecisions: boolean; + private routingCacheTtlMs: number; + private maxRoutingCacheSize: number; + private decisionCache = new Map(); constructor(options: ModelRouterOptions) { this.strategy = options.strategy; this.fallbackMap = options.fallbackMap || {}; this.maxRetries = options.maxRetries ?? 1; this.crossProviderEnabled = options.enableCrossProvider ?? false; + this.cacheRoutingDecisions = options.cacheRoutingDecisions ?? true; + this.routingCacheTtlMs = options.routingCacheTtlMs ?? DEFAULT_ROUTING_CACHE_TTL_MS; + this.maxRoutingCacheSize = options.maxRoutingCacheSize ?? DEFAULT_ROUTING_CACHE_SIZE; // Register API keys if provided if (options.apiKeys) { @@ -40,6 +55,14 @@ export class ModelRouter { throw new Error("TokenFirewall Router: maxRetries must be non-negative"); } + if (this.routingCacheTtlMs < 0) { + throw new Error("TokenFirewall Router: routingCacheTtlMs must be non-negative"); + } + + if (this.maxRoutingCacheSize < 0) { + throw new Error("TokenFirewall Router: maxRoutingCacheSize must be non-negative"); + } + if (this.maxRetries > 5) { console.warn( "TokenFirewall Router: maxRetries > 5 may cause excessive API calls" @@ -90,6 +113,11 @@ export class ModelRouter { // Detect failure type const failureType = errorDetector.detectFailureType(context.error); + const cacheKey = this.buildDecisionCacheKey(context, failureType); + const cachedDecision = this.getCachedDecision(cacheKey); + if (cachedDecision) { + return cachedDecision; + } // Select routing strategy const decision = this.selectStrategy(context, failureType); @@ -103,21 +131,26 @@ export class ModelRouter { // Prevent retrying same model if (decision.retry && context.attemptedModels.includes(decision.nextModel!)) { - return { + const finalDecision = { retry: false, reason: `Model ${decision.nextModel} has already been attempted` }; + this.setCachedDecision(cacheKey, finalDecision); + return finalDecision; } // Prevent switching back to original model (circular retry) if (decision.retry && decision.nextModel === context.originalModel && context.retryCount > 0) { - return { + const finalDecision = { retry: false, reason: `Cannot switch back to original model ${context.originalModel}` }; + this.setCachedDecision(cacheKey, finalDecision); + return finalDecision; } - return decision; + this.setCachedDecision(cacheKey, decision); + return { ...decision }; } /** @@ -164,4 +197,118 @@ export class ModelRouter { public isCrossProviderEnabled(): boolean { return this.crossProviderEnabled; } + + /** + * Get current routing decision cache size. + */ + public getRoutingCacheSize(): number { + return this.decisionCache.size; + } + + /** + * Clear cached routing decisions. + */ + public clearRoutingCache(): void { + this.decisionCache.clear(); + } + + private buildDecisionCacheKey( + context: FailureContext, + failureType: string + ): string | null { + if (!this.cacheRoutingDecisions || this.routingCacheTtlMs === 0 || this.maxRoutingCacheSize === 0) { + return null; + } + + const attemptedModels = context.attemptedModels.join("|"); + const requestFingerprint = this.getRequestFingerprint(context.requestBody); + return [ + this.strategy, + failureType, + context.provider, + context.originalModel, + context.retryCount, + attemptedModels, + requestFingerprint + ].join("::"); + } + + private getRequestFingerprint(requestBody: any): string { + if (requestBody === null || requestBody === undefined) { + return ""; + } + + if (typeof requestBody === "string") { + return requestBody.slice(0, 2048); + } + + if (typeof requestBody !== "object") { + return String(requestBody); + } + + const body = requestBody as { + prompt?: unknown; + input?: unknown; + messages?: Array<{ content?: unknown }>; + model?: unknown; + }; + + if (typeof body.prompt === "string") { + return body.prompt.slice(0, 2048); + } + + if (typeof body.input === "string") { + return body.input.slice(0, 2048); + } + + if (Array.isArray(body.messages)) { + return body.messages + .map(message => String(message.content ?? "")) + .join("\n") + .slice(0, 2048); + } + + if (typeof body.model === "string") { + return body.model; + } + + return ""; + } + + private getCachedDecision(cacheKey: string | null): RoutingDecision | null { + if (!cacheKey) { + return null; + } + + const cached = this.decisionCache.get(cacheKey); + if (!cached) { + return null; + } + + if (cached.expiresAt <= Date.now()) { + this.decisionCache.delete(cacheKey); + return null; + } + + return { ...cached.decision }; + } + + private setCachedDecision(cacheKey: string | null, decision: RoutingDecision): void { + if (!cacheKey) { + return; + } + + while (this.decisionCache.size >= this.maxRoutingCacheSize) { + const oldestKey = this.decisionCache.keys().next().value; + if (!oldestKey) { + break; + } + this.decisionCache.delete(oldestKey); + } + + this.decisionCache.set(cacheKey, { + decision: { ...decision }, + expiresAt: Date.now() + this.routingCacheTtlMs + }); + } } diff --git a/src/router/routingStrategies.ts b/src/router/routingStrategies.ts index ca5533a..2876db0 100644 --- a/src/router/routingStrategies.ts +++ b/src/router/routingStrategies.ts @@ -12,6 +12,7 @@ export function fallbackStrategy( fallbackMap: Record ): RoutingDecision { const { originalModel, attemptedModels } = context; + const attempted = new Set(attemptedModels); // Get fallback list for this model const fallbacks = fallbackMap[originalModel]; @@ -24,7 +25,7 @@ export function fallbackStrategy( } // Find first fallback that hasn't been attempted - const nextModel = fallbacks.find(model => !attemptedModels.includes(model)); + const nextModel = fallbacks.find(model => !attempted.has(model)); if (!nextModel) { return { @@ -49,6 +50,7 @@ export function contextStrategy( failureType: FailureType ): RoutingDecision { const { originalModel, provider, attemptedModels } = context; + const attempted = new Set(attemptedModels); // Only applicable for context overflow if (failureType !== "context_overflow") { @@ -80,20 +82,19 @@ export function contextStrategy( // Filter models with larger context that haven't been attempted const largerContextModels = availableModels - .filter((model: string) => { + .map((model: string) => { const limit = contextRegistry.getContextLimit(provider, model); + return { model, limit }; + }) + .filter(({ model, limit }) => { return ( limit !== undefined && limit > currentLimit && - !attemptedModels.includes(model) && + !attempted.has(model) && model !== originalModel // Don't suggest the same model ); }) - .sort((a: string, b: string) => { - const limitA = contextRegistry.getContextLimit(provider, a) || 0; - const limitB = contextRegistry.getContextLimit(provider, b) || 0; - return limitA - limitB; // Sort ascending (smallest upgrade first) - }); + .sort((a, b) => (a.limit || 0) - (b.limit || 0)); // Sort ascending (smallest upgrade first) if (largerContextModels.length === 0) { return { @@ -102,8 +103,7 @@ export function contextStrategy( }; } - const nextModel = largerContextModels[0]; - const nextLimit = contextRegistry.getContextLimit(provider, nextModel); + const { model: nextModel, limit: nextLimit } = largerContextModels[0]; return { retry: true, @@ -121,6 +121,7 @@ export function costStrategy( failureType: FailureType ): RoutingDecision { const { originalModel, provider, attemptedModels } = context; + const attempted = new Set(attemptedModels); // Get current model's pricing let currentPricing; @@ -149,25 +150,25 @@ export function costStrategy( // Find cheaper models that haven't been attempted const cheaperModels = providerModels .filter((model: string) => { - if (attemptedModels.includes(model) || model === originalModel) { + if (attempted.has(model) || model === originalModel) { return false; } + return true; + }) + .map((model: string) => { try { const pricing = pricingRegistry.getPricing(provider, model); const avgCost = (pricing.input + pricing.output) / 2; - return avgCost < currentAvgCost; + return { model, avgCost }; } catch { - return false; + return null; } }) - .sort((a: string, b: string) => { - const pricingA = pricingRegistry.getPricing(provider, a); - const pricingB = pricingRegistry.getPricing(provider, b); - const avgCostA = (pricingA.input + pricingA.output) / 2; - const avgCostB = (pricingB.input + pricingB.output) / 2; - return avgCostA - avgCostB; // Sort ascending (cheapest first) - }); + .filter((entry): entry is { model: string; avgCost: number } => ( + entry !== null && entry.avgCost < currentAvgCost + )) + .sort((a, b) => a.avgCost - b.avgCost); // Sort ascending (cheapest first) if (cheaperModels.length === 0) { return { @@ -176,7 +177,7 @@ export function costStrategy( }; } - const nextModel = cheaperModels[0]; + const nextModel = cheaperModels[0].model; return { retry: true, diff --git a/src/router/types.ts b/src/router/types.ts index 8e333f4..13cce14 100644 --- a/src/router/types.ts +++ b/src/router/types.ts @@ -39,6 +39,12 @@ export interface ModelRouterOptions { fallbackMap?: Record; /** Maximum number of retry attempts (default: 1) */ maxRetries?: number; + /** Cache repeated routing decisions for identical failure contexts (default: true) */ + cacheRoutingDecisions?: boolean; + /** Routing decision cache TTL in milliseconds (default: 5 minutes) */ + routingCacheTtlMs?: number; + /** Maximum cached routing decisions retained in memory (default: 1000) */ + maxRoutingCacheSize?: number; /** API keys for cross-provider fallback */ apiKeys?: ApiKeyConfig; /** Enable cross-provider fallback (default: false) */ diff --git a/tests/performance-optimization.test.js b/tests/performance-optimization.test.js new file mode 100644 index 0000000..45a13dc --- /dev/null +++ b/tests/performance-optimization.test.js @@ -0,0 +1,98 @@ +/** + * Smart routing performance optimization tests. + * + * Run: node tests/performance-optimization.test.js + */ + +const { createModelRouter, disableModelRouter } = require("../dist/index.js"); + +const colors = { + reset: "\x1b[0m", + green: "\x1b[32m", + red: "\x1b[31m", + cyan: "\x1b[36m", +}; + +function log(message, color = "reset") { + console.log(`${colors[color]}${message}${colors.reset}`); +} + +function assert(condition, message) { + if (!condition) { + throw new Error(message); + } +} + +function createContext(prompt = "Review this function for bugs") { + return { + error: { status: 429 }, + originalModel: "gpt-4o", + requestBody: { + model: "gpt-4o", + messages: [{ role: "user", content: prompt }], + }, + provider: "openai", + retryCount: 0, + attemptedModels: ["gpt-4o"], + }; +} + +async function run() { + console.clear(); + log("\nSmart Routing Performance Optimization Tests\n", "cyan"); + + const router = createModelRouter({ + strategy: "fallback", + fallbackMap: { + "gpt-4o": ["gpt-4o-mini", "claude-3-5-sonnet-20241022"], + }, + maxRetries: 2, + routingCacheTtlMs: 5 * 60 * 1000, + maxRoutingCacheSize: 64, + }); + + const firstDecision = router.handleFailure(createContext()); + assert(firstDecision.retry === true, "first routing decision should retry"); + assert(firstDecision.nextModel === "gpt-4o-mini", "first fallback should be gpt-4o-mini"); + assert(router.getRoutingCacheSize() === 1, "first decision should populate the cache"); + + firstDecision.nextModel = "mutated-model"; + const cachedDecision = router.handleFailure(createContext()); + assert( + cachedDecision.nextModel === "gpt-4o-mini", + "cached decisions should be returned as defensive copies" + ); + + router.handleFailure(createContext("Summarize this invoice")); + assert(router.getRoutingCacheSize() === 2, "different prompts should get distinct cache entries"); + + const iterations = 25000; + const context = createContext(); + const start = process.hrtime.bigint(); + for (let i = 0; i < iterations; i++) { + router.handleFailure(context); + } + const elapsedMs = Number(process.hrtime.bigint() - start) / 1e6; + const averageMs = elapsedMs / iterations; + + assert( + averageMs < 10, + `routing overhead averaged ${averageMs.toFixed(4)}ms, expected under 10ms` + ); + + router.clearRoutingCache(); + assert(router.getRoutingCacheSize() === 0, "clearRoutingCache should empty cached decisions"); + + disableModelRouter(); + + log(" ✓ decision cache stores repeated routing results", "green"); + log(" ✓ cached decisions are defensive copies", "green"); + log(" ✓ repeated routing overhead stays below 10ms", "green"); + log(`\nAverage routing overhead: ${averageMs.toFixed(4)}ms over ${iterations} iterations\n`, "cyan"); +} + +run().catch(error => { + disableModelRouter(); + log(`\nPerformance optimization test failed: ${error.message}`, "red"); + process.exit(1); +});