Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ export type {
ApiKeyConfig
} from "./router/types";

export type {
TaskClassificationConfig,
TaskDetection
} from "./router/smartRouter";

export {
defaultTaskClassification,
classifyTask,
extractPrompt
} from "./router/smartRouter";

/**
* Model configuration for bulk registration
*/
Expand Down
61 changes: 50 additions & 11 deletions src/interceptors/fetchInterceptor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { apiKeyManager } from "../router/apiKeyManager";
import { buildProviderHeaders, appendApiKeyToUrl } from "../router/providerHeaders";
import { transformRequest } from "../router/requestTransformer";
import { transformResponse } from "../router/responseTransformer";
import { extractPrompt, classifyTask, defaultTaskClassification } from "../router/smartRouter";

let isPatched = false;
let budgetManager: BudgetManager | null = null;
Expand Down Expand Up @@ -75,7 +76,7 @@ async function standardFetch(
// Process response and track budget BEFORE returning
try {
const responseData = await clonedResponse.json();

// Try to process with adapter registry
const normalizedUsage = adapterRegistry.process(responseData);

Expand Down Expand Up @@ -132,13 +133,9 @@ async function fetchWithRetry(
let currentInput = input;

// Extract original model and provider from request
const { originalModel, provider: originalProvider } = extractModelInfo(input, init);

if (originalModel) {
attemptedModels.push(originalModel);
}
let { originalModel, provider: originalProvider } = extractModelInfo(input, init);

// Parse the original request body for potential cross-provider transformations
// Parse the original request body for potential cross-provider transformations or smart routing
let originalRequestBody: any = null;
if (init?.body) {
try {
Expand All @@ -148,6 +145,48 @@ async function fetchWithRetry(
}
}

// --- SMART ROUTING ---
if (modelRouter && modelRouter.getStrategy() === 'smart') {
const prompt = extractPrompt(originalRequestBody, originalProvider);
if (prompt) {
const detection = classifyTask(
prompt,
modelRouter.getTaskClassification() || defaultTaskClassification,
modelRouter.getConfidenceThreshold() || 0.7
);

const newModel = detection?.selectedModel || modelRouter.getDefaultModel();

if (newModel && newModel !== originalModel) {
const nextProvider = detectProvider(newModel);
if (nextProvider && originalProvider && nextProvider !== originalProvider && modelRouter.isCrossProviderEnabled()) {
const updated = buildCrossProviderRequest(originalRequestBody, originalProvider, nextProvider, newModel);
if (updated) {
currentInput = updated.url;
currentInit = updated.init;
originalModel = newModel;
originalProvider = nextProvider;
if (currentInit.body) {
try { originalRequestBody = JSON.parse(currentInit.body as string); } catch { }
}
}
} else if (nextProvider === originalProvider) {
const updated = updateRequestModel(currentInput, currentInit, newModel, originalProvider);
currentInput = updated.input;
currentInit = updated.init;
originalModel = newModel;
if (currentInit?.body) {
try { originalRequestBody = JSON.parse(currentInit.body as string); } catch { }
}
}
}
}
}

if (originalModel) {
attemptedModels.push(originalModel);
}

while (retryCount <= (modelRouter?.getMaxRetries() || 0)) {
try {
// Make the request
Expand All @@ -163,7 +202,7 @@ async function fetchWithRetry(
} catch {
// Response is not JSON or already consumed
}

// Throw proper Error instance with structured data
const errorObj = {
status: response.status,
Expand Down Expand Up @@ -418,7 +457,7 @@ function extractModelInfo(
// Extract model from request body (for non-Gemini providers)
if (!model) {
let body: string | null = null;

// Get body from init or Request object
if (init?.body) {
body = typeof init.body === 'string' ? init.body : null;
Expand All @@ -427,7 +466,7 @@ function extractModelInfo(
// So we skip body parsing for Request objects without explicit init.body
body = null;
}

if (body) {
try {
const bodyObj = JSON.parse(body);
Expand Down Expand Up @@ -462,7 +501,7 @@ function updateRequestModel(
// Update URL for Gemini (model is in URL path)
if (provider === 'gemini') {
const newUrl = url.replace(/\/models\/[^:]+:/, `/models/${newModel}:`);

// If input was a Request object, we need to create a new Request with updated URL
if (input instanceof Request) {
// Clone the request with new URL
Expand Down
21 changes: 21 additions & 0 deletions src/router/modelRouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@ export class ModelRouter {
private fallbackMap: Record<string, string[]>;
private maxRetries: number;
private crossProviderEnabled: boolean;
private taskClassification?: any;
private confidenceThreshold?: number;
private defaultModel?: string;

constructor(options: ModelRouterOptions) {
this.strategy = options.strategy;
this.fallbackMap = options.fallbackMap || {};
this.maxRetries = options.maxRetries ?? 1;
this.crossProviderEnabled = options.enableCrossProvider ?? false;
this.taskClassification = options.taskClassification;
this.confidenceThreshold = options.confidenceThreshold;
this.defaultModel = options.defaultModel;

// Register API keys if provided
if (options.apiKeys) {
Expand Down Expand Up @@ -137,6 +143,9 @@ export class ModelRouter {
case "cost":
return costStrategy(context, failureType as any);

case "smart":
return { retry: false, reason: "Smart routing handles requests pre-flight, no post-failure fallback defined" };

default:
throw new Error(
`TokenFirewall Router: Unknown strategy "${this.strategy}"`
Expand Down Expand Up @@ -164,4 +173,16 @@ export class ModelRouter {
public isCrossProviderEnabled(): boolean {
return this.crossProviderEnabled;
}

/**
 * Returns the stored task-classification config (assigned in the constructor
 * from `options.taskClassification`). May be undefined when the option was
 * not supplied.
 */
public getTaskClassification() {
return this.taskClassification;
}

/**
 * Returns the stored confidence threshold (assigned in the constructor from
 * `options.confidenceThreshold`). May be undefined when the option was not
 * supplied.
 */
public getConfidenceThreshold() {
return this.confidenceThreshold;
}

/**
 * Returns the stored default model (assigned in the constructor from
 * `options.defaultModel`). May be undefined when the option was not supplied.
 */
public getDefaultModel() {
return this.defaultModel;
}
}
123 changes: 123 additions & 0 deletions src/router/smartRouter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import { detectProvider } from "./providerDetector";

/**
 * Maps a task-type name to the rule used to recognise that task in a prompt
 * and the model a matching prompt should be routed to.
 */
export interface TaskClassificationConfig {
[taskType: string]: {
/** Model identifier reported as `selectedModel` when this task type is the best match. */
model: string;
/** Human-readable justification, copied into the detection result. */
reason: string;
/** Phrases looked for in the prompt as case-insensitive substrings. */
keywords?: string[];
/** Regular expressions matched against the prompt. */
patterns?: RegExp[];
/** Optional numeric priority; the built-in defaults range from 1 (simple_chat) to 10 (code_generation). */
priority?: number;
};
}

/**
 * Result of classifying a prompt: which task type matched and which model
 * its rule selects.
 */
export interface TaskDetection {
/** Key of the matching entry in the classification config. */
taskType: string;
/** Match score computed by the classifier; capped at 1.0. */
confidence: number;
/** The `model` of the matching rule. */
selectedModel: string;
/** The `reason` of the matching rule. */
reason: string;
}

/**
 * Built-in task rules, used when the caller does not provide a custom
 * classification config.
 *
 * Entry order is significant: the classifier walks the entries in insertion
 * order, so keep the most specific task types first.
 */
export const defaultTaskClassification: TaskClassificationConfig = {
  // Requests to write or implement code.
  code_generation: {
    model: "claude-3-5-sonnet-20241022",
    reason: "Claude excels at code generation",
    keywords: [
      "write code",
      "create function",
      "implement",
      "build",
      "develop",
      "program"
    ],
    patterns: [
      /write.*code/i,
      /create.*function/i,
      /implement.*class/i,
      /write.*function/i
    ],
    priority: 10
  },

  // Calculation and equation-solving requests.
  math_reasoning: {
    model: "o1-mini",
    reason: "o1 designed for reasoning",
    keywords: [
      "calculate",
      "solve",
      "compute",
      "equation",
      "formula",
      "math"
    ],
    patterns: [
      /solve.*equation/i,
      /calculate/i,
      /mathematical/i
    ],
    priority: 9
  },

  // Summarising or extracting information from documents.
  document_analysis: {
    model: "gemini-2.5-pro",
    reason: "Gemini has 2M token context window",
    keywords: [
      "summarize document",
      "analyze document",
      "extract from",
      "review document"
    ],
    patterns: [
      /summarize.*document/i,
      /analyze.*pdf/i,
      /extract.*information/i
    ],
    priority: 8
  },

  // Greetings and small talk.
  simple_chat: {
    model: "gpt-4o-mini",
    reason: "Cost-effective for simple conversation",
    keywords: [
      "hello",
      "hi",
      "how are you",
      "thanks",
      "thank you",
      "help"
    ],
    patterns: [
      /^(hi|hello|hey)/i,
      /how.*are.*you/i,
      /thank/i
    ],
    priority: 1
  }
};

/**
 * Reduces a chat message's `content` field to plain text.
 *
 * A string is returned unchanged. An array of content parts/blocks is reduced
 * to its textual entries (bare strings, or objects carrying a string `text`
 * field) joined with newlines; non-text parts such as images are ignored.
 * Anything else yields "".
 */
function messageContentToText(content: unknown): string {
  if (typeof content === "string") return content;
  if (!Array.isArray(content)) return "";

  const texts: string[] = [];
  for (const part of content) {
    if (typeof part === "string") {
      texts.push(part);
    } else if (part && typeof part.text === "string") {
      texts.push(part.text);
    }
  }
  return texts.join("\n");
}

/**
 * Pulls the text of the most recent message out of a provider request body so
 * that it can be classified.
 *
 * - `openai` / `anthropic`: reads `messages[last].content`. String content is
 *   returned as-is; array content (multi-part messages) is reduced to its
 *   text parts.
 * - `gemini`: reads `contents[last].parts[last].text`.
 *
 * @param requestBody Parsed JSON request body; may be null/undefined.
 * @param provider Provider whose request format `requestBody` follows.
 * @returns The prompt text, or "" when none can be found. The result is
 *   always a string, so callers may safely call string methods on it.
 */
export function extractPrompt(requestBody: any, provider: string | null): string {
  if (!requestBody) return "";

  try {
    if ((provider === "openai" || provider === "anthropic") && Array.isArray(requestBody.messages)) {
      const messages = requestBody.messages;
      return messageContentToText(messages[messages.length - 1]?.content);
    }

    if (provider === "gemini" && Array.isArray(requestBody.contents)) {
      const contents = requestBody.contents;
      const parts = contents[contents.length - 1]?.parts;
      if (Array.isArray(parts) && parts.length > 0) {
        const text = parts[parts.length - 1]?.text;
        // Non-text parts (e.g. inline data) carry no `text`; never leak a non-string.
        return typeof text === "string" ? text : "";
      }
    }
  } catch {
    // Best effort: a malformed body simply yields no prompt.
  }

  return "";
}

/**
 * Scores `prompt` against every task type in `config` and returns the best
 * match, or null when nothing reaches the threshold.
 *
 * Scoring per task type:
 * - keywords: when at least one keyword occurs in the prompt (case-insensitive
 *   substring), add 0.5 plus 0.1 per matched keyword;
 * - patterns: when any pattern matches, add 0.8 (counted once);
 * - the sum is capped at 1.0.
 *
 * A task type is eligible only if its score is positive and at least
 * `confidenceThreshold`. The highest score wins; equal scores are resolved by
 * the larger `priority` (missing priority counts as 0), and any remaining tie
 * by config order (first entry wins).
 *
 * @param prompt Text to classify; an empty prompt yields null.
 * @param config Task rules to evaluate.
 * @param confidenceThreshold Minimum score a task type needs to be selected.
 * @returns The winning detection, or null.
 */
export function classifyTask(
  prompt: string,
  config: TaskClassificationConfig = defaultTaskClassification,
  confidenceThreshold: number = 0.7
): TaskDetection | null {
  if (!prompt) return null;

  // Lower-case once instead of once per keyword.
  const loweredPrompt = prompt.toLowerCase();

  let bestMatch: TaskDetection | null = null;
  let bestPriority = 0;

  for (const [taskType, rules] of Object.entries(config)) {
    let confidence = 0;

    // Keyword matching
    const keywordHits = (rules.keywords ?? []).filter(kw =>
      loweredPrompt.includes(kw.toLowerCase())
    ).length;
    if (keywordHits > 0) {
      confidence += 0.5 + (0.1 * keywordHits);
    }

    // Pattern matching. String#search always starts at index 0 and leaves the
    // regex's lastIndex untouched, so patterns carrying the `g` or `y` flag
    // give the same answer on every call (RegExp#test would not).
    const patternHit = (rules.patterns ?? []).some(pattern => prompt.search(pattern) !== -1);
    if (patternHit) {
      confidence += 0.8;
    }

    // Cap confidence
    confidence = Math.min(confidence, 1.0);

    if (!(confidence > 0) || !(confidence >= confidenceThreshold)) {
      continue;
    }

    const priority = rules.priority ?? 0;
    const isBetter =
      bestMatch === null ||
      confidence > bestMatch.confidence ||
      (confidence === bestMatch.confidence && priority > bestPriority);

    if (isBetter) {
      bestPriority = priority;
      bestMatch = {
        taskType,
        confidence,
        selectedModel: rules.model,
        reason: rules.reason
      };
    }
  }

  return bestMatch;
}
8 changes: 7 additions & 1 deletion src/router/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
/**
* Routing strategy types
*/
export type RoutingStrategy = "fallback" | "context" | "cost";
export type RoutingStrategy = "fallback" | "context" | "cost" | "smart";

/**
* Failure types detected by error detector
Expand Down Expand Up @@ -43,6 +43,12 @@ export interface ModelRouterOptions {
apiKeys?: ApiKeyConfig;
/** Enable cross-provider fallback (default: false) */
enableCrossProvider?: boolean;
/** Custom task classification config for smart strategy */
taskClassification?: any;
/** Confidence threshold for smart strategy */
confidenceThreshold?: number;
/** Default model for smart strategy if no match found */
defaultModel?: string;
}

/**
Expand Down
Loading