Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ interface ModelRouterOptions {
strategy: "fallback" | "context" | "cost"; // Routing strategy
fallbackMap?: Record<string, string[]>; // Fallback model map
maxRetries?: number; // Max retry attempts (default: 1)
cacheRoutingDecisions?: boolean; // Cache repeated decisions (default: true)
routingCacheTtlMs?: number; // Cache TTL in ms (default: 5 minutes)
maxRoutingCacheSize?: number; // Max cached decisions (default: 1000)
}
```

Expand All @@ -301,7 +304,8 @@ createModelRouter({
"gpt-4o": ["gpt-4o-mini", "gpt-3.5-turbo"],
"claude-3-5-sonnet-20241022": ["claude-3-5-haiku-20241022"]
},
maxRetries: 2
maxRetries: 2,
cacheRoutingDecisions: true
});

patchGlobalFetch();
Expand All @@ -324,6 +328,13 @@ patchGlobalFetch();
- Selects cheaper model from same provider
- Best for: Cost optimization, rate limit handling

### Performance

- Repeated routing decisions are cached by default for 5 minutes.
- Cached decisions are bounded by `maxRoutingCacheSize` to avoid unbounded memory growth.
- The cache key includes the strategy, failure type, provider, model, retry count, attempted models, and request prompt fingerprint.
- Set `cacheRoutingDecisions: false` to disable caching for workloads that require fully uncached routing decisions.

### Error Detection

The router automatically detects and classifies failures:
Expand Down
5 changes: 3 additions & 2 deletions src/router/errorDetector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,12 @@ export class ErrorDetector {
if (typeof data === "object" && data !== null) {
const errorMessage = data.error?.message || data.message || "";
const errorCode = data.error?.code || data.code || "";
const lowerMessage = errorMessage.toLowerCase();

return (
errorCode === "model_not_found" ||
errorMessage.toLowerCase().includes("model") &&
errorMessage.toLowerCase().includes("not found")
lowerMessage.includes("model") &&
lowerMessage.includes("not found")
);
}

Expand Down
153 changes: 150 additions & 3 deletions src/router/modelRouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ import { errorDetector } from "./errorDetector";
import { fallbackStrategy, contextStrategy, costStrategy } from "./routingStrategies";
import { apiKeyManager } from "./apiKeyManager";

const DEFAULT_ROUTING_CACHE_TTL_MS = 5 * 60 * 1000;
const DEFAULT_ROUTING_CACHE_SIZE = 1000;

interface CachedRoutingDecision {
decision: RoutingDecision;
expiresAt: number;
}

/**
* Intelligent Model Router
* Handles automatic retries and model switching on failures
Expand All @@ -17,12 +25,19 @@ export class ModelRouter {
private fallbackMap: Record<string, string[]>;
private maxRetries: number;
private crossProviderEnabled: boolean;
private cacheRoutingDecisions: boolean;
private routingCacheTtlMs: number;
private maxRoutingCacheSize: number;
private decisionCache = new Map<string, CachedRoutingDecision>();

constructor(options: ModelRouterOptions) {
this.strategy = options.strategy;
this.fallbackMap = options.fallbackMap || {};
this.maxRetries = options.maxRetries ?? 1;
this.crossProviderEnabled = options.enableCrossProvider ?? false;
this.cacheRoutingDecisions = options.cacheRoutingDecisions ?? true;
this.routingCacheTtlMs = options.routingCacheTtlMs ?? DEFAULT_ROUTING_CACHE_TTL_MS;
this.maxRoutingCacheSize = options.maxRoutingCacheSize ?? DEFAULT_ROUTING_CACHE_SIZE;

// Register API keys if provided
if (options.apiKeys) {
Expand All @@ -40,6 +55,14 @@ export class ModelRouter {
throw new Error("TokenFirewall Router: maxRetries must be non-negative");
}

if (this.routingCacheTtlMs < 0) {
throw new Error("TokenFirewall Router: routingCacheTtlMs must be non-negative");
}

if (this.maxRoutingCacheSize < 0) {
throw new Error("TokenFirewall Router: maxRoutingCacheSize must be non-negative");
}

if (this.maxRetries > 5) {
console.warn(
"TokenFirewall Router: maxRetries > 5 may cause excessive API calls"
Expand Down Expand Up @@ -90,6 +113,11 @@ export class ModelRouter {

// Detect failure type
const failureType = errorDetector.detectFailureType(context.error);
const cacheKey = this.buildDecisionCacheKey(context, failureType);
const cachedDecision = this.getCachedDecision(cacheKey);
if (cachedDecision) {
return cachedDecision;
}

// Select routing strategy
const decision = this.selectStrategy(context, failureType);
Expand All @@ -103,21 +131,26 @@ export class ModelRouter {

// Prevent retrying same model
if (decision.retry && context.attemptedModels.includes(decision.nextModel!)) {
return {
const finalDecision = {
retry: false,
reason: `Model ${decision.nextModel} has already been attempted`
};
this.setCachedDecision(cacheKey, finalDecision);
return finalDecision;
}

// Prevent switching back to original model (circular retry)
if (decision.retry && decision.nextModel === context.originalModel && context.retryCount > 0) {
return {
const finalDecision = {
retry: false,
reason: `Cannot switch back to original model ${context.originalModel}`
};
this.setCachedDecision(cacheKey, finalDecision);
return finalDecision;
}

return decision;
this.setCachedDecision(cacheKey, decision);
return { ...decision };
}

/**
Expand Down Expand Up @@ -164,4 +197,118 @@ export class ModelRouter {
public isCrossProviderEnabled(): boolean {
return this.crossProviderEnabled;
}

/**
* Get current routing decision cache size.
*/
public getRoutingCacheSize(): number {
return this.decisionCache.size;
}

/**
* Clear cached routing decisions.
*/
public clearRoutingCache(): void {
this.decisionCache.clear();
}

private buildDecisionCacheKey(
context: FailureContext,
failureType: string
): string | null {
if (!this.cacheRoutingDecisions || this.routingCacheTtlMs === 0 || this.maxRoutingCacheSize === 0) {
return null;
}

const attemptedModels = context.attemptedModels.join("|");
const requestFingerprint = this.getRequestFingerprint(context.requestBody);
return [
this.strategy,
failureType,
context.provider,
context.originalModel,
context.retryCount,
attemptedModels,
requestFingerprint
].join("::");
}

private getRequestFingerprint(requestBody: any): string {
if (requestBody === null || requestBody === undefined) {
return "";
}

if (typeof requestBody === "string") {
return requestBody.slice(0, 2048);
}

if (typeof requestBody !== "object") {
return String(requestBody);
}

const body = requestBody as {
prompt?: unknown;
input?: unknown;
messages?: Array<{ content?: unknown }>;
model?: unknown;
};

if (typeof body.prompt === "string") {
return body.prompt.slice(0, 2048);
}

if (typeof body.input === "string") {
return body.input.slice(0, 2048);
}

if (Array.isArray(body.messages)) {
return body.messages
.map(message => String(message.content ?? ""))
.join("\n")
.slice(0, 2048);
}

if (typeof body.model === "string") {
return body.model;
}

return "";
}

private getCachedDecision(cacheKey: string | null): RoutingDecision | null {
if (!cacheKey) {
return null;
}

const cached = this.decisionCache.get(cacheKey);
if (!cached) {
return null;
}

if (cached.expiresAt <= Date.now()) {
this.decisionCache.delete(cacheKey);
return null;
}

return { ...cached.decision };
}

private setCachedDecision(cacheKey: string | null, decision: RoutingDecision): void {
if (!cacheKey) {
return;
}

while (this.decisionCache.size >= this.maxRoutingCacheSize) {
const oldestKey = this.decisionCache.keys().next().value;
if (!oldestKey) {
break;
}
this.decisionCache.delete(oldestKey);
}

this.decisionCache.set(cacheKey, {
decision: { ...decision },
expiresAt: Date.now() + this.routingCacheTtlMs
});
}
}
43 changes: 22 additions & 21 deletions src/router/routingStrategies.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export function fallbackStrategy(
fallbackMap: Record<string, string[]>
): RoutingDecision {
const { originalModel, attemptedModels } = context;
const attempted = new Set(attemptedModels);

// Get fallback list for this model
const fallbacks = fallbackMap[originalModel];
Expand All @@ -24,7 +25,7 @@ export function fallbackStrategy(
}

// Find first fallback that hasn't been attempted
const nextModel = fallbacks.find(model => !attemptedModels.includes(model));
const nextModel = fallbacks.find(model => !attempted.has(model));

if (!nextModel) {
return {
Expand All @@ -49,6 +50,7 @@ export function contextStrategy(
failureType: FailureType
): RoutingDecision {
const { originalModel, provider, attemptedModels } = context;
const attempted = new Set(attemptedModels);

// Only applicable for context overflow
if (failureType !== "context_overflow") {
Expand Down Expand Up @@ -80,20 +82,19 @@ export function contextStrategy(

// Filter models with larger context that haven't been attempted
const largerContextModels = availableModels
.filter((model: string) => {
.map((model: string) => {
const limit = contextRegistry.getContextLimit(provider, model);
return { model, limit };
})
.filter(({ model, limit }) => {
return (
limit !== undefined &&
limit > currentLimit &&
!attemptedModels.includes(model) &&
!attempted.has(model) &&
model !== originalModel // Don't suggest the same model
);
})
.sort((a: string, b: string) => {
const limitA = contextRegistry.getContextLimit(provider, a) || 0;
const limitB = contextRegistry.getContextLimit(provider, b) || 0;
return limitA - limitB; // Sort ascending (smallest upgrade first)
});
.sort((a, b) => (a.limit || 0) - (b.limit || 0)); // Sort ascending (smallest upgrade first)

if (largerContextModels.length === 0) {
return {
Expand All @@ -102,8 +103,7 @@ export function contextStrategy(
};
}

const nextModel = largerContextModels[0];
const nextLimit = contextRegistry.getContextLimit(provider, nextModel);
const { model: nextModel, limit: nextLimit } = largerContextModels[0];

return {
retry: true,
Expand All @@ -121,6 +121,7 @@ export function costStrategy(
failureType: FailureType
): RoutingDecision {
const { originalModel, provider, attemptedModels } = context;
const attempted = new Set(attemptedModels);

// Get current model's pricing
let currentPricing;
Expand Down Expand Up @@ -149,25 +150,25 @@ export function costStrategy(
// Find cheaper models that haven't been attempted
const cheaperModels = providerModels
.filter((model: string) => {
if (attemptedModels.includes(model) || model === originalModel) {
if (attempted.has(model) || model === originalModel) {
return false;
}

return true;
})
.map((model: string) => {
try {
const pricing = pricingRegistry.getPricing(provider, model);
const avgCost = (pricing.input + pricing.output) / 2;
return avgCost < currentAvgCost;
return { model, avgCost };
} catch {
return false;
return null;
}
})
.sort((a: string, b: string) => {
const pricingA = pricingRegistry.getPricing(provider, a);
const pricingB = pricingRegistry.getPricing(provider, b);
const avgCostA = (pricingA.input + pricingA.output) / 2;
const avgCostB = (pricingB.input + pricingB.output) / 2;
return avgCostA - avgCostB; // Sort ascending (cheapest first)
});
.filter((entry): entry is { model: string; avgCost: number } => (
entry !== null && entry.avgCost < currentAvgCost
))
.sort((a, b) => a.avgCost - b.avgCost); // Sort ascending (cheapest first)

if (cheaperModels.length === 0) {
return {
Expand All @@ -176,7 +177,7 @@ export function costStrategy(
};
}

const nextModel = cheaperModels[0];
const nextModel = cheaperModels[0].model;

return {
retry: true,
Expand Down
6 changes: 6 additions & 0 deletions src/router/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ export interface ModelRouterOptions {
fallbackMap?: Record<string, string[]>;
/** Maximum number of retry attempts (default: 1) */
maxRetries?: number;
/** Cache repeated routing decisions for identical failure contexts (default: true) */
cacheRoutingDecisions?: boolean;
/** Routing decision cache TTL in milliseconds (default: 5 minutes) */
routingCacheTtlMs?: number;
/** Maximum cached routing decisions retained in memory (default: 1000) */
maxRoutingCacheSize?: number;
/** API keys for cross-provider fallback */
apiKeys?: ApiKeyConfig;
/** Enable cross-provider fallback (default: false) */
Expand Down
Loading