From a28d57daf49924caa8abbba40ac0ba72db28d04b Mon Sep 17 00:00:00 2001 From: ProfSynapse Date: Mon, 6 Apr 2026 18:04:24 -0400 Subject: [PATCH] =?UTF-8?q?fix:=20Mistral=20adapter=20improvements=20?= =?UTF-8?q?=E2=80=94=20tool=20calling,=20model=20catalog,=20streaming?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add prepareMessages() to normalize conversation history for Mistral API (role mapping, tool_call_id, content coercion) - Update MistralModels catalog to April 2026 (Mistral Large 256K context, add Devstral, Codestral, Pixtral, Magistral models) - Fix ProviderHttpClient: add node:http/https streaming fallback with proper timeout and chunked transfer handling - Fix ToolContinuationService: pass empty string instead of userPrompt in recursive tool continuation to avoid duplicate context - Fix OpenAIContextBuilder: add getToolName() helper, improve tool result message formatting - Add unit tests for MistralAdapter, ProviderHttpClient, OpenAIContextBuilder - Add Mistral live integration test Co-Authored-By: Claude Opus 4.6 (1M context) --- .../chat/builders/OpenAIContextBuilder.ts | 65 ++--- .../llm/adapters/mistral/MistralAdapter.ts | 243 +++++++++++++++++- .../llm/adapters/mistral/MistralModels.ts | 140 +++++++++- .../llm/adapters/shared/ProviderHttpClient.ts | 98 ++++++- .../llm/core/ToolContinuationService.ts | 12 +- tests/integration/MistralChatLive.test.ts | 117 +++++++++ tests/unit/MistralAdapter.test.ts | 88 +++++++ tests/unit/OpenAIContextBuilder.test.ts | 59 +++++ tests/unit/ProviderHttpClient.test.ts | 65 +++++ 9 files changed, 819 insertions(+), 68 deletions(-) create mode 100644 tests/integration/MistralChatLive.test.ts create mode 100644 tests/unit/MistralAdapter.test.ts create mode 100644 tests/unit/OpenAIContextBuilder.test.ts diff --git a/src/services/chat/builders/OpenAIContextBuilder.ts b/src/services/chat/builders/OpenAIContextBuilder.ts index e8847058d..72d7a70f2 100644 --- a/src/services/chat/builders/OpenAIContextBuilder.ts +++ b/src/services/chat/builders/OpenAIContextBuilder.ts @@ -28,8 +28,12 @@ type ReasoningToolCallLike = { thought_signature?: string; }; -export class OpenAIContextBuilder implements IContextBuilder { - readonly provider = 'openai'; +export class OpenAIContextBuilder implements IContextBuilder { + readonly provider = 'openai'; + + private getToolName(toolCall: ToolCall): string { + return toolCall.function?.name || toolCall.name || ''; + } /** * Validate if a message should be included in LLM context @@ -96,17 +100,18 @@ export class OpenAIContextBuilder implements IContextBuilder { }); // Add tool result messages with proper tool_call_id - msg.toolCalls.forEach((toolCall: ToolCall) => { - const resultContent = toolCall.success !== false - ? JSON.stringify(toolCall.result || {}) - : JSON.stringify({ error: toolCall.error || 'Tool execution failed' }); - - messages.push({ - role: 'tool', - tool_call_id: toolCall.id, - content: resultContent - }); - }); + msg.toolCalls.forEach((toolCall: ToolCall) => { + const resultContent = toolCall.success !== false + ? JSON.stringify(toolCall.result || {}) + : JSON.stringify({ error: toolCall.error || 'Tool execution failed' }); + + messages.push({ + role: 'tool', + tool_call_id: toolCall.id, + name: this.getToolName(toolCall), + content: resultContent + }); + }); } else { if (msg.content && msg.content.trim()) { messages.push({ role: 'assistant', content: msg.content }); @@ -164,16 +169,17 @@ export class OpenAIContextBuilder implements IContextBuilder { // Add tool result messages toolResults.forEach((result, index) => { const toolCall = toolCalls[index]; - const resultContent = result.success - ? JSON.stringify(result.result || {}) - : JSON.stringify({ error: result.error || 'Tool execution failed' }); - - messages.push({ - role: 'tool', - tool_call_id: toolCall.id, - content: resultContent - }); - }); + const resultContent = result.success + ? JSON.stringify(result.result || {}) + : JSON.stringify({ error: result.error || 'Tool execution failed' }); + + messages.push({ + role: 'tool', + tool_call_id: toolCall.id, + name: toolCall.function?.name || '', + content: resultContent + }); + }); return messages; } @@ -202,12 +208,13 @@ export class OpenAIContextBuilder implements IContextBuilder { // Add tool result messages toolResults.forEach((result, index) => { const toolCall = toolCalls[index]; - messages.push({ - role: 'tool', - tool_call_id: toolCall.id, - content: result.success - ? JSON.stringify(result.result || {}) - : JSON.stringify({ error: result.error || 'Tool execution failed' }) + messages.push({ + role: 'tool', + tool_call_id: toolCall.id, + name: toolCall.function?.name || '', + content: result.success + ? JSON.stringify(result.result || {}) + : JSON.stringify({ error: result.error || 'Tool execution failed' }) }); }); diff --git a/src/services/llm/adapters/mistral/MistralAdapter.ts b/src/services/llm/adapters/mistral/MistralAdapter.ts index 60c5cf5bd..bc3f84a97 100644 --- a/src/services/llm/adapters/mistral/MistralAdapter.ts +++ b/src/services/llm/adapters/mistral/MistralAdapter.ts @@ -49,7 +49,10 @@ interface MistralMessageContentPart { } interface MistralMessage { + role?: string; content?: string | MistralMessageContentPart[]; + name?: string; + tool_call_id?: string; toolCalls?: Array>; tool_calls?: Array>; } @@ -84,6 +87,15 @@ type MistralStreamChunk = { }; }; +interface MistralNormalizedToolCall { + id: string; + type: 'function'; + function: { + name: string; + arguments: string; + }; +} + export class MistralAdapter extends BaseAdapter { readonly name = 'mistral'; readonly baseUrl = 'https://api.mistral.ai'; @@ -113,6 +125,12 @@ export class MistralAdapter extends BaseAdapter { */ async* generateStreamAsync(prompt: string, options?: GenerateOptions): AsyncGenerator { try { + const messages = this.prepareMessages( + options?.conversationHistory && options.conversationHistory.length > 0 + ? options.conversationHistory + : this.buildMessages(prompt, options?.systemPrompt) + ); + const nodeStream = await this.requestStream({ url: `${this.baseUrl}/v1/chat/completions`, operation: 'streaming generation', @@ -123,9 +141,7 @@ export class MistralAdapter extends BaseAdapter { }, body: JSON.stringify({ model: options?.model || this.currentModel, - messages: options?.conversationHistory && options.conversationHistory.length > 0 - ? options.conversationHistory - : this.buildMessages(prompt, options?.systemPrompt), + messages, temperature: options?.temperature, max_tokens: options?.maxTokens, top_p: options?.topP, @@ -208,13 +224,16 @@ export class MistralAdapter extends BaseAdapter { */ private async generateWithChatCompletions(prompt: string, options?: GenerateOptions): Promise { const model = options?.model || this.currentModel; + const messages = this.prepareMessages( + options?.conversationHistory && options.conversationHistory.length > 0 + ? options.conversationHistory + : this.buildMessages(prompt, options?.systemPrompt) + ); // Build request body with snake_case keys matching the Mistral REST API const requestBody: Record = { model, - messages: options?.conversationHistory && options.conversationHistory.length > 0 - ? options.conversationHistory - : this.buildMessages(prompt, options?.systemPrompt), + messages, temperature: options?.temperature, max_tokens: options?.maxTokens, top_p: options?.topP, @@ -289,6 +308,218 @@ export class MistralAdapter extends BaseAdapter { }); } + private prepareMessages(messages: Array>): Array> { + return this.normalizeMessagesForMistral(messages); + } + + private normalizeMessagesForMistral(messages: Array>): Array> { + const normalizedMessages: Array> = []; + const normalizedToolCallIds = new Map(); + const toolNamesById = new Map(); + + for (const message of messages) { + const role = typeof message.role === 'string' ? message.role : undefined; + if (!role) { + continue; + } + + if (role === 'assistant') { + const toolCalls = this.normalizeAssistantToolCalls(message, normalizedToolCallIds, toolNamesById); + normalizedMessages.push({ + role, + content: this.stringifyMessageContent(message.content), + ...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {}) + }); + continue; + } + + if (role === 'tool') { + const rawToolCallId = typeof message.tool_call_id === 'string' ? message.tool_call_id : ''; + const normalizedToolCallId = this.normalizeToolCallId( + rawToolCallId, + normalizedToolCallIds, + `tool-${normalizedMessages.length}` + ); + const inferredToolName = this.toOptionalString(message.name) || toolNamesById.get(normalizedToolCallId) || ''; + + normalizedMessages.push({ + role, + tool_call_id: normalizedToolCallId, + content: this.stringifyMessageContent(message.content), + ...(inferredToolName ? { name: inferredToolName } : {}) + }); + continue; + } + + normalizedMessages.push({ + role, + content: this.stringifyMessageContent(message.content) + }); + } + + return normalizedMessages; + } + + private normalizeAssistantToolCalls( + message: Record, + normalizedToolCallIds: Map, + toolNamesById: Map + ): MistralNormalizedToolCall[] { + const rawToolCalls = this.getRawToolCalls(message); + const normalizedToolCalls: MistralNormalizedToolCall[] = []; + + for (const [index, rawToolCall] of rawToolCalls.entries()) { + if (!rawToolCall || typeof rawToolCall !== 'object') { + continue; + } + + const functionPayload = this.getFunctionPayload(rawToolCall); + const toolName = this.toOptionalString(functionPayload?.name); + if (!toolName) { + continue; + } + + const normalizedId = this.normalizeToolCallId( + this.toOptionalString((rawToolCall as { id?: unknown }).id), + normalizedToolCallIds, + `${toolName}-${index}` + ); + toolNamesById.set(normalizedId, toolName); + + normalizedToolCalls.push({ + id: normalizedId, + type: 'function', + function: { + name: toolName, + arguments: this.normalizeArguments(functionPayload?.arguments) + } + }); + } + + return normalizedToolCalls; + } + + private getRawToolCalls(message: Record): Array> { + const toolCalls = message.tool_calls; + if (Array.isArray(toolCalls)) { + return toolCalls.filter((toolCall): toolCall is Record => !!toolCall && typeof toolCall === 'object'); + } + + const camelToolCalls = message.toolCalls; + if (Array.isArray(camelToolCalls)) { + return camelToolCalls.filter((toolCall): toolCall is Record => !!toolCall && typeof toolCall === 'object'); + } + + return []; + } + + private getFunctionPayload(toolCall: Record): { name?: string; arguments?: unknown } | undefined { + const rawFunction = toolCall.function; + if (rawFunction && typeof rawFunction === 'object' && !Array.isArray(rawFunction)) { + return rawFunction as { name?: string; arguments?: unknown }; + } + + const toolName = this.toOptionalString(toolCall.name); + if (!toolName) { + return undefined; + } + + return { + name: toolName, + arguments: toolCall.arguments + }; + } + + private normalizeArguments(argumentsValue: unknown): string { + if (typeof argumentsValue === 'string') { + return argumentsValue; + } + if (argumentsValue === undefined) { + return '{}'; + } + + try { + return JSON.stringify(argumentsValue); + } catch { + return '{}'; + } + } + + private normalizeToolCallId( + rawId: string | undefined, + normalizedToolCallIds: Map, + fallbackSeed: string + ): string { + const originalId = rawId || ''; + if (originalId && normalizedToolCallIds.has(originalId)) { + return normalizedToolCallIds.get(originalId) || ''; + } + + const candidate = originalId.replace(/[^A-Za-z0-9]/g, ''); + if (candidate.length === 9) { + normalizedToolCallIds.set(originalId, candidate); + return candidate; + } + + const normalizedId = this.generateMistralToolCallId(originalId || fallbackSeed); + if (originalId) { + normalizedToolCallIds.set(originalId, normalizedId); + } + return normalizedId; + } + + private generateMistralToolCallId(seed: string): string { + const source = seed || 'mistraltoolcall'; + let hash = 0; + for (let index = 0; index < source.length; index++) { + hash = ((hash << 5) - hash) + source.charCodeAt(index); + hash |= 0; + } + + const alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'; + let value = Math.abs(hash); + let output = ''; + + for (let index = 0; index < 9; index++) { + const charIndex = value % alphabet.length; + output += alphabet.charAt(charIndex); + value = Math.floor(value / alphabet.length); + } + + return output; + } + + private stringifyMessageContent(content: unknown): string { + if (typeof content === 'string') { + return content; + } + + if (Array.isArray(content)) { + return content + .filter((chunk): chunk is MistralMessageContentPart => !!chunk && typeof chunk === 'object') + .filter(chunk => chunk.type === 'text') + .map(chunk => chunk.text || '') + .join(''); + } + + if (content === undefined || content === null) { + return ''; + } + + try { + return JSON.stringify(content); + } catch { + if (typeof content === 'number' || typeof content === 'boolean' || typeof content === 'bigint') { + return `${content}`; + } + return ''; + } + } + + private toOptionalString(value: unknown): string | undefined { + return typeof value === 'string' && value.length > 0 ? value : undefined; + } + private extractToolCalls(message: MistralMessage | undefined): Array> { return message?.toolCalls || []; } diff --git a/src/services/llm/adapters/mistral/MistralModels.ts b/src/services/llm/adapters/mistral/MistralModels.ts index 97b359c9a..846e60514 100644 --- a/src/services/llm/adapters/mistral/MistralModels.ts +++ b/src/services/llm/adapters/mistral/MistralModels.ts @@ -1,6 +1,6 @@ /** * Mistral Model Specifications - * Updated June 17, 2025 with latest Mistral releases + * Updated April 6, 2026 against current Mistral Docs model catalog. */ import { ModelSpec } from '../modelTypes'; @@ -10,10 +10,10 @@ export const MISTRAL_MODELS: ModelSpec[] = [ provider: 'mistral', name: 'Mistral Large Latest', apiName: 'mistral-large-latest', - contextWindow: 128000, + contextWindow: 256000, maxTokens: 8192, - inputCostPerMillion: 2.00, - outputCostPerMillion: 6.00, + inputCostPerMillion: 0.50, + outputCostPerMillion: 1.50, capabilities: { supportsJSON: true, supportsImages: true, @@ -40,14 +40,78 @@ export const MISTRAL_MODELS: ModelSpec[] = [ }, { provider: 'mistral', - name: 'Mistral Saba', - apiName: 'mistral-saba-latest', + name: 'Mistral Small Latest', + apiName: 'mistral-small-latest', contextWindow: 128000, - maxTokens: 4096, + maxTokens: 8192, + inputCostPerMillion: 0.10, + outputCostPerMillion: 0.30, + capabilities: { + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'mistral', + name: 'Devstral Latest', + apiName: 'devstral-latest', + contextWindow: 256000, + maxTokens: 8192, + inputCostPerMillion: 0.40, + outputCostPerMillion: 2.00, + capabilities: { + supportsJSON: true, + supportsImages: false, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'mistral', + name: 'Ministral 3 14B Latest', + apiName: 'ministral-14b-latest', + contextWindow: 256000, + maxTokens: 8192, inputCostPerMillion: 0.20, - outputCostPerMillion: 0.60, + outputCostPerMillion: 0.20, + capabilities: { + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'mistral', + name: 'Ministral 3 8B Latest', + apiName: 'ministral-8b-latest', + contextWindow: 256000, + maxTokens: 8192, + inputCostPerMillion: 0.15, + outputCostPerMillion: 0.15, + capabilities: { + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'mistral', + name: 'Ministral 3 3B Latest', + apiName: 'ministral-3b-latest', + contextWindow: 256000, + maxTokens: 8192, + inputCostPerMillion: 0.10, + outputCostPerMillion: 0.10, capabilities: { - supportsJSON: false, + supportsJSON: true, supportsImages: true, supportsFunctions: true, supportsStreaming: true, @@ -58,18 +122,66 @@ export const MISTRAL_MODELS: ModelSpec[] = [ provider: 'mistral', name: 'Magistral Medium', apiName: 'magistral-medium-latest', - contextWindow: 40000, - maxTokens: 40000, + contextWindow: 128000, + maxTokens: 32768, inputCostPerMillion: 2.00, outputCostPerMillion: 5.00, capabilities: { supportsJSON: true, - supportsImages: false, - supportsFunctions: false, + supportsImages: true, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: true + } + }, + { + provider: 'mistral', + name: 'Magistral Small', + apiName: 'magistral-small-latest', + contextWindow: 128000, + maxTokens: 32768, + inputCostPerMillion: 0.50, + outputCostPerMillion: 1.50, + capabilities: { + supportsJSON: true, + supportsImages: true, + supportsFunctions: true, supportsStreaming: true, supportsThinking: true } + }, + { + provider: 'mistral', + name: 'Codestral Latest', + apiName: 'codestral-latest', + contextWindow: 128000, + maxTokens: 8192, + inputCostPerMillion: 0.30, + outputCostPerMillion: 0.90, + capabilities: { + supportsJSON: true, + supportsImages: false, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } + }, + { + provider: 'mistral', + name: 'Voxtral Small Latest', + apiName: 'voxtral-small-latest', + contextWindow: 32000, + maxTokens: 8192, + inputCostPerMillion: 0.10, + outputCostPerMillion: 0.30, + capabilities: { + supportsJSON: true, + supportsImages: false, + supportsFunctions: true, + supportsStreaming: true, + supportsThinking: false + } } ]; -export const MISTRAL_DEFAULT_MODEL = 'mistral-large-latest'; \ No newline at end of file +export const MISTRAL_DEFAULT_MODEL = 'mistral-large-latest'; diff --git a/src/services/llm/adapters/shared/ProviderHttpClient.ts b/src/services/llm/adapters/shared/ProviderHttpClient.ts index 84989a089..96ec318a4 100644 --- a/src/services/llm/adapters/shared/ProviderHttpClient.ts +++ b/src/services/llm/adapters/shared/ProviderHttpClient.ts @@ -53,6 +53,29 @@ interface ErrorLikeResponse { json: unknown; } +type NodeRequireLike = (moduleName: string) => unknown; +type NodeHttpModule = { + request: ( + options: { + hostname: string; + port: string | number; + path: string; + method: string; + headers: Record; + }, + callback: (response: NodeJS.ReadableStream & { + statusCode?: number; + statusMessage?: string; + }) => void + ) => { + write: (chunk: string) => void; + end: () => void; + destroy: (error?: Error) => void; + setTimeout: (timeoutMs: number, callback: () => void) => void; + on: (event: string, handler: (error: unknown) => void) => void; + }; +}; + export class ProviderHttpError extends Error { response: ErrorLikeResponse; @@ -85,6 +108,40 @@ function enforceHttps(url: string): void { } export class ProviderHttpClient { + private static getNodeRequire(): NodeRequireLike | null { + const globalRequire = typeof globalThis === 'object' + ? (globalThis as { require?: unknown }).require + : undefined; + const candidate = typeof require === 'function' ? require : globalRequire; + return typeof candidate === 'function' ? (candidate as NodeRequireLike) : null; + } + + private static loadNodeBuiltin(moduleName: string): T | null { + if (!hasNodeRuntime()) { + return null; + } + + const nodeRequire = this.getNodeRequire(); + if (!nodeRequire) { + return null; + } + + try { + return nodeRequire(moduleName) as T; + } catch { + const fallbackName = moduleName.replace(/^node:/, ''); + if (fallbackName === moduleName) { + return null; + } + + try { + return nodeRequire(fallbackName) as T; + } catch { + return null; + } + } + } + static async request( config: ProviderHttpRequest ): Promise> { @@ -187,10 +244,14 @@ export class ProviderHttpClient { const parsed = new URL(config.url); const isHttps = parsed.protocol === 'https:'; - // Dynamically import Node.js modules (available in Electron) - const nodeModule = isHttps - ? await import('node:https') - : await import('node:http'); + // Use CommonJS require here. Dynamic ESM imports of node: builtins are blocked + // in Obsidian's app:// renderer even when Node runtime is available. + const nodeModule = this.loadNodeBuiltin( + isHttps ? 'node:https' : 'node:http' + ); + if (!nodeModule) { + return this.requestStreamBufferedFallback(config); + } const timeoutMs = config.timeoutMs ?? 120_000; @@ -225,7 +286,9 @@ export class ProviderHttpClient { } )); }); - res.on('error', (err) => reject(err)); + res.on('error', (err) => { + reject(err instanceof Error ? err : new Error(String(err))); + }); return; } @@ -293,16 +356,25 @@ export class ProviderHttpClient { ); } - // Wrap the buffered text as a minimal readable stream - const { Readable } = await import('node:stream'); - const readable = new Readable({ - read() { - this.push(Buffer.from(response.text, 'utf-8')); - this.push(null); + // Return a minimal async-iterable stream shape so adapters can keep using + // the same SSE parsing path without depending on node:stream in the renderer. + const bufferedStream: AsyncIterable = { + [Symbol.asyncIterator]() { + let emitted = false; + return { + next: () => { + if (emitted) { + return Promise.resolve({ done: true, value: undefined }); + } + + emitted = true; + return Promise.resolve({ done: false, value: response.text }); + } + }; } - }); + }; - return readable; + return bufferedStream as NodeJS.ReadableStream; } private static async requestOnce( diff --git a/src/services/llm/core/ToolContinuationService.ts b/src/services/llm/core/ToolContinuationService.ts index e75be6320..5827505b7 100644 --- a/src/services/llm/core/ToolContinuationService.ts +++ b/src/services/llm/core/ToolContinuationService.ts @@ -423,12 +423,12 @@ export class ToolContinuationService { } // Build continuation for recursive pingpong - const recursiveContinuationOptions = this.messageBuilder.buildContinuationOptions( - provider, - userPrompt, - recursiveToolCalls, - recursiveToolResults, - previousMessages, + const recursiveContinuationOptions = this.messageBuilder.buildContinuationOptions( + provider, + '', + recursiveToolCalls, + recursiveToolResults, + previousMessages, generateOptions, options ); diff --git a/tests/integration/MistralChatLive.test.ts b/tests/integration/MistralChatLive.test.ts new file mode 100644 index 000000000..1fde62ac3 --- /dev/null +++ b/tests/integration/MistralChatLive.test.ts @@ -0,0 +1,117 @@ +/** + * Live integration test for Mistral chat + tool continuation. + * + * Requires: + * MISTRAL_API_KEY=... + * + * Run: + * source .env && npx jest tests/integration/MistralChatLive.test.ts --runInBand --no-coverage --verbose + */ + +import { MistralAdapter } from '../../src/services/llm/adapters/mistral/MistralAdapter'; +import type { GenerateOptions, ToolCall } from '../../src/services/llm/adapters/types'; + +type StreamCapture = { + content: string; + toolCalls: ToolCall[]; +}; + +async function collectStream( + adapter: MistralAdapter, + prompt: string, + options?: GenerateOptions +): Promise { + let content = ''; + let toolCalls: ToolCall[] = []; + + for await (const chunk of adapter.generateStreamAsync(prompt, options)) { + if (chunk.content) { + content += chunk.content; + } + + if (chunk.toolCalls && chunk.toolCalls.length > 0) { + toolCalls = chunk.toolCalls; + } + } + + return { content, toolCalls }; +} + +const mistralKey = process.env.MISTRAL_API_KEY; + +describe('Live Chat: Mistral tool continuations', () => { + const runTest = mistralKey ? it : it.skip; + + runTest('executes a real tool call and accepts the continuation payload', async () => { + const adapter = new MistralAdapter(mistralKey!); + const tools = [ + { + type: 'function' as const, + function: { + name: 'get_weather', + description: 'Look up the current weather for a city.', + parameters: { + type: 'object', + properties: { + city: { + type: 'string', + description: 'City name' + } + }, + required: ['city'] + } + } + } + ]; + + const firstPass = await collectStream( + adapter, + 'You must call the get_weather tool exactly once for New York City. Do not answer directly before the tool call.', + { + model: 'mistral-large-latest', + tools, + temperature: 0 + } + ); + + expect(firstPass.toolCalls.length).toBeGreaterThan(0); + expect(firstPass.toolCalls[0]?.function.name).toBe('get_weather'); + + const firstToolCall = firstPass.toolCalls[0]; + const continuationHistory = [ + { + role: 'assistant', + content: '', + tool_calls: [ + { + id: firstToolCall.id, + type: 'function', + function: { + name: firstToolCall.function.name, + arguments: firstToolCall.function.arguments + } + } + ] + }, + { + role: 'tool', + tool_call_id: firstToolCall.id, + name: firstToolCall.function.name, + content: JSON.stringify({ + city: 'New York City', + temperature_f: 72, + condition: 'Sunny' + }) + } + ]; + + const secondPass = await collectStream(adapter, '', { + model: 'mistral-large-latest', + tools, + conversationHistory: continuationHistory, + temperature: 0 + }); + + expect(secondPass.content.trim().length).toBeGreaterThan(0); + }, 120_000); +}); diff --git a/tests/unit/MistralAdapter.test.ts b/tests/unit/MistralAdapter.test.ts new file mode 100644 index 000000000..930897a1a --- /dev/null +++ b/tests/unit/MistralAdapter.test.ts @@ -0,0 +1,88 @@ +import { MistralAdapter } from '../../src/services/llm/adapters/mistral/MistralAdapter'; +import { GenerateOptions, StreamChunk } from '../../src/services/llm/adapters/types'; + +class TestMistralAdapter extends MistralAdapter { + lastStreamBody: Record | undefined; + + protected override requestStream(config: { + body: string; + }): Promise { + this.lastStreamBody = JSON.parse(config.body) as Record; + return Promise.resolve({} as NodeJS.ReadableStream); + } + + protected override async* processNodeStream(): AsyncGenerator { + yield { + content: '', + complete: true + }; + } +} + +describe('MistralAdapter', () => { + it('normalizes continuation messages to Mistral-safe tool payloads', async () => { + const adapter = new TestMistralAdapter('test-key'); + const conversationHistory = [ + { + role: 'assistant', + content: '', + tool_calls: [ + { + id: 'tool-call_12345', + type: 'function', + function: { + name: 'storageManager_list', + arguments: '{"path":"/"}' + }, + index: 0, + reasoning_details: [{ ignored: true }] + } + ] + }, + { + role: 'tool', + tool_call_id: 'tool-call_12345', + content: '{"files":[]}' + } + ]; + + const options: GenerateOptions = { + model: 'mistral-large-latest', + conversationHistory + }; + + for await (const _chunk of adapter.generateStreamAsync('', options)) { + // Exhaust the generator to capture the serialized request body. + } + + const body = adapter.lastStreamBody; + expect(body).toBeDefined(); + + const messages = body?.messages as Array>; + expect(Array.isArray(messages)).toBe(true); + expect(messages).toHaveLength(2); + + const assistantMessage = messages[0]; + const toolMessage = messages[1]; + const assistantToolCalls = assistantMessage?.tool_calls as Array>; + const assistantToolCall = assistantToolCalls?.[0]; + const normalizedId = assistantToolCall?.id as string; + + expect(normalizedId).toMatch(/^[A-Za-z0-9]{9}$/); + expect(assistantToolCall).toEqual({ + id: normalizedId, + type: 'function', + function: { + name: 'storageManager_list', + arguments: '{"path":"/"}' + } + }); + + expect(toolMessage).toEqual({ + role: 'tool', + tool_call_id: normalizedId, + name: 'storageManager_list', + content: '{"files":[]}' + }); + }); +}); diff --git a/tests/unit/OpenAIContextBuilder.test.ts b/tests/unit/OpenAIContextBuilder.test.ts new file mode 100644 index 000000000..96d9bcca1 --- /dev/null +++ b/tests/unit/OpenAIContextBuilder.test.ts @@ -0,0 +1,59 @@ +import { OpenAIContextBuilder } from '../../src/services/chat/builders/OpenAIContextBuilder'; +import { ConversationData } from '../../src/types/chat/ChatTypes'; + +describe('OpenAIContextBuilder', () => { + it('includes tool names on tool result messages', () => { + const builder = new OpenAIContextBuilder(); + const conversation: ConversationData = { + id: 'conv-1', + title: 'Test', + created: 1, + updated: 1, + messages: [ + { + id: 'assistant-1', + role: 'assistant', + content: '', + timestamp: 1, + conversationId: 'conv-1', + toolCalls: [ + { + id: 'call-1', + type: 'function', + function: { + name: 'storageManager_list', + arguments: '{"path":"/"}' + }, + result: { files: [] }, + success: true + } + ] + } + ] + }; + + const messages = builder.buildContext(conversation); + expect(messages).toEqual([ + { + role: 'assistant', + content: '', + tool_calls: [ + { + id: 'call-1', + type: 'function', + function: { + name: 'storageManager_list', + arguments: '{"path":"/"}' + } + } + ] + }, + { + role: 'tool', + tool_call_id: 'call-1', + name: 'storageManager_list', + content: '{"files":[]}' + } + ]); + }); +}); diff --git a/tests/unit/ProviderHttpClient.test.ts b/tests/unit/ProviderHttpClient.test.ts index beda3b62b..b01b2dad0 100644 --- a/tests/unit/ProviderHttpClient.test.ts +++ b/tests/unit/ProviderHttpClient.test.ts @@ -12,6 +12,11 @@ describe('ProviderHttpClient', () => { message: string; }; + type ProviderHttpClientInternals = { + getNodeRequire: () => ((moduleName: string) => unknown) | null; + loadNodeBuiltin: (moduleName: string) => T | null; + }; + beforeEach(() => { __setRequestUrlMock(async () => ({ status: 200, @@ -280,4 +285,64 @@ describe('ProviderHttpClient', () => { ) ).toThrow('Custom validation error'); }); + + it('falls back from node: builtin names to bare module names', () => { + const requireMock = jest.fn((moduleName: string) => { + if (moduleName === 'node:http') { + throw new Error('node: prefix unsupported'); + } + + if (moduleName === 'http') { + return { request: jest.fn() }; + } + + throw new Error(`unexpected module ${moduleName}`); + }); + + const internals = ProviderHttpClient as unknown as ProviderHttpClientInternals; + const requireSpy = jest.spyOn(internals, 'getNodeRequire').mockReturnValue(requireMock); + + try { + const module = internals.loadNodeBuiltin<{ request: unknown }>('node:http'); + expect(module).toEqual({ request: expect.any(Function) }); + expect(requireMock).toHaveBeenNthCalledWith(1, 'node:http'); + expect(requireMock).toHaveBeenNthCalledWith(2, 'http'); + } finally { + requireSpy.mockRestore(); + } + }); + + it('uses buffered fallback when node transport is unavailable', async () => { + __setRequestUrlMock(async () => ({ + status: 200, + headers: { 'content-type': 'text/event-stream' }, + text: 'data: {"choices":[{"delta":{"content":"hello"}}]}\n\ndata: [DONE]\n\n', + json: null, + arrayBuffer: new ArrayBuffer(0) + })); + + const internals = ProviderHttpClient as unknown as ProviderHttpClientInternals; + const loadBuiltinSpy = jest.spyOn(internals, 'loadNodeBuiltin').mockReturnValue(null); + + try { + const stream = await ProviderHttpClient.requestStream({ + url: 'http://127.0.0.1:1234/v1/chat/completions', + provider: 'lmstudio', + operation: 'streaming generation', + method: 'POST', + body: '{}' + }); + + const chunks: string[] = []; + for await (const chunk of stream as AsyncIterable) { + chunks.push(chunk); + } + + expect(chunks).toEqual([ + 'data: {"choices":[{"delta":{"content":"hello"}}]}\n\ndata: [DONE]\n\n' + ]); + } finally { + loadBuiltinSpy.mockRestore(); + } + }); });