diff --git a/src/index.ts b/src/index.ts index c769e99..9805a96 100644 --- a/src/index.ts +++ b/src/index.ts @@ -9,6 +9,7 @@ import { contextRegistry } from "./introspection/contextRegistry"; import { ModelRouter } from "./router/modelRouter"; import { ModelRouterOptions, ApiKeyConfig } from "./router/types"; import { apiKeyManager } from "./router/apiKeyManager"; +import { detectByPatterns, defaultPatternDefinitions } from "./router/patternDetector"; let globalBudgetManager: BudgetManager | null = null; let globalModelRouter: ModelRouter | null = null; @@ -227,6 +228,7 @@ export async function listModels(options: Omit b.priority - a.priority + ); + + let timedOut = false; + let best: PatternDetectionResult | null = null; + + for (const definition of definitions) { + if (Date.now() - startedAt > timeoutMs) { + timedOut = true; + break; + } + + const matchedPatterns: string[] = []; + for (const pattern of definition.patterns) { + if (Date.now() - startedAt > timeoutMs) { + timedOut = true; + break; + } + pattern.lastIndex = 0; + if (pattern.test(source)) { + matchedPatterns.push(pattern.toString()); + } + } + + const matchedKeywords = definition.keywords.filter((keyword) => + lowerSource.includes(keyword.toLowerCase()) + ); + + const confidence = calculateConfidence(matchedPatterns.length, matchedKeywords.length); + if (confidence <= 0) { + continue; + } + + const candidate: PatternDetectionResult = { + taskType: definition.taskType, + model: definition.model, + reason: definition.reason, + confidence, + matchedPatterns, + matchedKeywords, + timedOut, + elapsedMs: Date.now() - startedAt, + }; + + if ( + best === null || + candidate.confidence > best.confidence || + (candidate.confidence === best.confidence && + definition.priority > getDefinitionPriority(best.taskType, definitions)) + ) { + best = candidate; + } + + if (timedOut) { + break; + } + } + + if (best && best.confidence >= MIN_RECOMMENDATION_CONFIDENCE) { + return { + ...best, + timedOut, + elapsedMs: Date.now() - startedAt, + }; + } + + return createUnknownResult( + timedOut ? "Pattern detection timed out before a confident match" : "No confident pattern match", + timedOut, + startedAt + ); +} + +function calculateConfidence(patternMatches: number, keywordMatches: number): number { + const score = patternMatches * 0.55 + keywordMatches * 0.12; + return Math.min(0.99, Number(score.toFixed(2))); +} + +function getDefinitionPriority( + taskType: PatternTaskType | "unknown", + definitions: PatternDefinition[] +): number { + return definitions.find((definition) => definition.taskType === taskType)?.priority ?? -1; +} + +function createUnknownResult( + reason: string, + timedOut: boolean, + startedAt: number +): PatternDetectionResult { + return { + taskType: "unknown", + reason, + confidence: 0, + matchedPatterns: [], + matchedKeywords: [], + timedOut, + elapsedMs: Date.now() - startedAt, + }; +} diff --git a/tests/pattern-detector.test.js b/tests/pattern-detector.test.js new file mode 100644 index 0000000..ab5dd85 --- /dev/null +++ b/tests/pattern-detector.test.js @@ -0,0 +1,110 @@ +/** + * Pattern detector tests. + * + * Run: node tests/pattern-detector.test.js + */ + +const { detectByPatterns } = require("../dist/index.js"); + +const colors = { + reset: "\x1b[0m", + green: "\x1b[32m", + red: "\x1b[31m", + cyan: "\x1b[36m", +}; + +const results = { total: 0, passed: 0, failed: 0, errors: [] }; + +function log(message, color = "reset") { + console.log(`${colors[color]}${message}${colors.reset}`); +} + +function record(name, passed, detail = "") { + results.total += 1; + if (passed) { + results.passed += 1; + log(` ✓ ${name}`, "green"); + } else { + results.failed += 1; + results.errors.push({ name, detail }); + log(` ✗ ${name}${detail ? `: ${detail}` : ""}`, "red"); + } +} + +function expectTask(name, prompt, expectedTask) { + const detection = detectByPatterns(prompt); + record( + name, + detection.taskType === expectedTask && detection.confidence > 0, + `expected ${expectedTask}, got ${detection.taskType}` + ); +} + +console.log("\n" + "=".repeat(70)); +log("PATTERN DETECTOR TESTS", "cyan"); +console.log("=".repeat(70)); + +expectTask( + "Detects code generation prompts", + "Write a TypeScript function that validates API keys", + "code_generation" +); + +expectTask( + "Detects code review prompts", + "Please review this code and find potential bugs", + "code_review" +); + +expectTask( + "Detects math reasoning prompts", + "Solve this equation and calculate the compound interest", + "math_reasoning" +); + +expectTask( + "Detects document analysis prompts", + "Analyze this PDF document and extract the key findings", + "document_analysis" +); + +expectTask( + "Detects simple chat prompts", + "Hello, thanks for your help", + "simple_chat" +); + +expectTask( + "Detects data extraction prompts", + "Extract all email addresses and dates from this text", + "data_extraction" +); + +expectTask( + "Detects Chinese language prompts", + "请帮我写一个排序算法", + "chinese_language" +); + +const unknown = detectByPatterns(" "); +record( + "Empty prompt returns unknown", + unknown.taskType === "unknown" && unknown.confidence === 0, + `got ${unknown.taskType} with confidence ${unknown.confidence}` +); + +const longInput = `${"x".repeat(100000)} Write code for a function`; +const bounded = detectByPatterns(longInput, { timeoutMs: 5, maxPromptLength: 1000 }); +record( + "Long input scan stays bounded", + bounded.elapsedMs < 100 && Array.isArray(bounded.matchedPatterns), + `elapsed ${bounded.elapsedMs}ms` +); + +console.log("\n" + "-".repeat(70)); +log(`Passed: ${results.passed}/${results.total}`, results.failed === 0 ? "green" : "red"); + +if (results.failed > 0) { + console.error(results.errors); + process.exit(1); +}