Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 36 additions & 29 deletions src/services/chat/builders/OpenAIContextBuilder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ type ReasoningToolCallLike = {
thought_signature?: string;
};

export class OpenAIContextBuilder implements IContextBuilder {
readonly provider = 'openai';
export class OpenAIContextBuilder implements IContextBuilder {
readonly provider = 'openai';

private getToolName(toolCall: ToolCall): string {
return toolCall.function?.name || toolCall.name || '';
}

/**
* Validate if a message should be included in LLM context
Expand Down Expand Up @@ -96,17 +100,18 @@ export class OpenAIContextBuilder implements IContextBuilder {
});

// Add tool result messages with proper tool_call_id
msg.toolCalls.forEach((toolCall: ToolCall) => {
const resultContent = toolCall.success !== false
? JSON.stringify(toolCall.result || {})
: JSON.stringify({ error: toolCall.error || 'Tool execution failed' });

messages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: resultContent
});
});
msg.toolCalls.forEach((toolCall: ToolCall) => {
const resultContent = toolCall.success !== false
? JSON.stringify(toolCall.result || {})
: JSON.stringify({ error: toolCall.error || 'Tool execution failed' });

messages.push({
role: 'tool',
tool_call_id: toolCall.id,
name: this.getToolName(toolCall),
content: resultContent
});
});
} else {
if (msg.content && msg.content.trim()) {
messages.push({ role: 'assistant', content: msg.content });
Expand Down Expand Up @@ -164,16 +169,17 @@ export class OpenAIContextBuilder implements IContextBuilder {
// Add tool result messages
toolResults.forEach((result, index) => {
const toolCall = toolCalls[index];
const resultContent = result.success
? JSON.stringify(result.result || {})
: JSON.stringify({ error: result.error || 'Tool execution failed' });

messages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: resultContent
});
});
const resultContent = result.success
? JSON.stringify(result.result || {})
: JSON.stringify({ error: result.error || 'Tool execution failed' });

messages.push({
role: 'tool',
tool_call_id: toolCall.id,
name: toolCall.function?.name || '',
content: resultContent
});
});

return messages;
}
Expand Down Expand Up @@ -202,12 +208,13 @@ export class OpenAIContextBuilder implements IContextBuilder {
// Add tool result messages
toolResults.forEach((result, index) => {
const toolCall = toolCalls[index];
messages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: result.success
? JSON.stringify(result.result || {})
: JSON.stringify({ error: result.error || 'Tool execution failed' })
messages.push({
role: 'tool',
tool_call_id: toolCall.id,
name: toolCall.function?.name || '',
content: result.success
? JSON.stringify(result.result || {})
: JSON.stringify({ error: result.error || 'Tool execution failed' })
});
});

Expand Down
243 changes: 237 additions & 6 deletions src/services/llm/adapters/mistral/MistralAdapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ interface MistralMessageContentPart {
}

interface MistralMessage {
role?: string;
content?: string | MistralMessageContentPart[];
name?: string;
tool_call_id?: string;
toolCalls?: Array<Record<string, unknown>>;
tool_calls?: Array<Record<string, unknown>>;
}
Expand Down Expand Up @@ -84,6 +87,15 @@ type MistralStreamChunk = {
};
};

interface MistralNormalizedToolCall {
id: string;
type: 'function';
function: {
name: string;
arguments: string;
};
}

export class MistralAdapter extends BaseAdapter {
readonly name = 'mistral';
readonly baseUrl = 'https://api.mistral.ai';
Expand Down Expand Up @@ -113,6 +125,12 @@ export class MistralAdapter extends BaseAdapter {
*/
async* generateStreamAsync(prompt: string, options?: GenerateOptions): AsyncGenerator<StreamChunk, void, unknown> {
try {
const messages = this.prepareMessages(
options?.conversationHistory && options.conversationHistory.length > 0
? options.conversationHistory
: this.buildMessages(prompt, options?.systemPrompt)
);

const nodeStream = await this.requestStream({
url: `${this.baseUrl}/v1/chat/completions`,
operation: 'streaming generation',
Expand All @@ -123,9 +141,7 @@ export class MistralAdapter extends BaseAdapter {
},
body: JSON.stringify({
model: options?.model || this.currentModel,
messages: options?.conversationHistory && options.conversationHistory.length > 0
? options.conversationHistory
: this.buildMessages(prompt, options?.systemPrompt),
messages,
temperature: options?.temperature,
max_tokens: options?.maxTokens,
top_p: options?.topP,
Expand Down Expand Up @@ -208,13 +224,16 @@ export class MistralAdapter extends BaseAdapter {
*/
private async generateWithChatCompletions(prompt: string, options?: GenerateOptions): Promise<LLMResponse> {
const model = options?.model || this.currentModel;
const messages = this.prepareMessages(
options?.conversationHistory && options.conversationHistory.length > 0
? options.conversationHistory
: this.buildMessages(prompt, options?.systemPrompt)
);

// Build request body with snake_case keys matching the Mistral REST API
const requestBody: Record<string, unknown> = {
model,
messages: options?.conversationHistory && options.conversationHistory.length > 0
? options.conversationHistory
: this.buildMessages(prompt, options?.systemPrompt),
messages,
temperature: options?.temperature,
max_tokens: options?.maxTokens,
top_p: options?.topP,
Expand Down Expand Up @@ -289,6 +308,218 @@ export class MistralAdapter extends BaseAdapter {
});
}

private prepareMessages(messages: Array<Record<string, unknown>>): Array<Record<string, unknown>> {
return this.normalizeMessagesForMistral(messages);
}

private normalizeMessagesForMistral(messages: Array<Record<string, unknown>>): Array<Record<string, unknown>> {
const normalizedMessages: Array<Record<string, unknown>> = [];
const normalizedToolCallIds = new Map<string, string>();
const toolNamesById = new Map<string, string>();

for (const message of messages) {
const role = typeof message.role === 'string' ? message.role : undefined;
if (!role) {
continue;
}

if (role === 'assistant') {
const toolCalls = this.normalizeAssistantToolCalls(message, normalizedToolCallIds, toolNamesById);
normalizedMessages.push({
role,
content: this.stringifyMessageContent(message.content),
...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {})
});
continue;
}

if (role === 'tool') {
const rawToolCallId = typeof message.tool_call_id === 'string' ? message.tool_call_id : '';
const normalizedToolCallId = this.normalizeToolCallId(
rawToolCallId,
normalizedToolCallIds,
`tool-${normalizedMessages.length}`
);
const inferredToolName = this.toOptionalString(message.name) || toolNamesById.get(normalizedToolCallId) || '';

normalizedMessages.push({
role,
tool_call_id: normalizedToolCallId,
content: this.stringifyMessageContent(message.content),
...(inferredToolName ? { name: inferredToolName } : {})
});
continue;
}

normalizedMessages.push({
role,
content: this.stringifyMessageContent(message.content)
});
}

return normalizedMessages;
}

private normalizeAssistantToolCalls(
message: Record<string, unknown>,
normalizedToolCallIds: Map<string, string>,
toolNamesById: Map<string, string>
): MistralNormalizedToolCall[] {
const rawToolCalls = this.getRawToolCalls(message);
const normalizedToolCalls: MistralNormalizedToolCall[] = [];

for (const [index, rawToolCall] of rawToolCalls.entries()) {
if (!rawToolCall || typeof rawToolCall !== 'object') {
continue;
}

const functionPayload = this.getFunctionPayload(rawToolCall);
const toolName = this.toOptionalString(functionPayload?.name);
if (!toolName) {
continue;
}

const normalizedId = this.normalizeToolCallId(
this.toOptionalString((rawToolCall as { id?: unknown }).id),
normalizedToolCallIds,
`${toolName}-${index}`
);
toolNamesById.set(normalizedId, toolName);

normalizedToolCalls.push({
id: normalizedId,
type: 'function',
function: {
name: toolName,
arguments: this.normalizeArguments(functionPayload?.arguments)
}
});
}

return normalizedToolCalls;
}

private getRawToolCalls(message: Record<string, unknown>): Array<Record<string, unknown>> {
const toolCalls = message.tool_calls;
if (Array.isArray(toolCalls)) {
return toolCalls.filter((toolCall): toolCall is Record<string, unknown> => !!toolCall && typeof toolCall === 'object');
}

const camelToolCalls = message.toolCalls;
if (Array.isArray(camelToolCalls)) {
return camelToolCalls.filter((toolCall): toolCall is Record<string, unknown> => !!toolCall && typeof toolCall === 'object');
}

return [];
}

private getFunctionPayload(toolCall: Record<string, unknown>): { name?: string; arguments?: unknown } | undefined {
const rawFunction = toolCall.function;
if (rawFunction && typeof rawFunction === 'object' && !Array.isArray(rawFunction)) {
return rawFunction as { name?: string; arguments?: unknown };
}

const toolName = this.toOptionalString(toolCall.name);
if (!toolName) {
return undefined;
}

return {
name: toolName,
arguments: toolCall.arguments
};
}

private normalizeArguments(argumentsValue: unknown): string {
if (typeof argumentsValue === 'string') {
return argumentsValue;
}
if (argumentsValue === undefined) {
return '{}';
}

try {
return JSON.stringify(argumentsValue);
} catch {
return '{}';
}
}

private normalizeToolCallId(
rawId: string | undefined,
normalizedToolCallIds: Map<string, string>,
fallbackSeed: string
): string {
const originalId = rawId || '';
if (originalId && normalizedToolCallIds.has(originalId)) {
return normalizedToolCallIds.get(originalId) || '';
}

const candidate = originalId.replace(/[^A-Za-z0-9]/g, '');
if (candidate.length === 9) {
normalizedToolCallIds.set(originalId, candidate);
return candidate;
}

const normalizedId = this.generateMistralToolCallId(originalId || fallbackSeed);
if (originalId) {
normalizedToolCallIds.set(originalId, normalizedId);
}
return normalizedId;
}

private generateMistralToolCallId(seed: string): string {
const source = seed || 'mistraltoolcall';
let hash = 0;
for (let index = 0; index < source.length; index++) {
hash = ((hash << 5) - hash) + source.charCodeAt(index);
hash |= 0;
}

const alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789';
let value = Math.abs(hash);
let output = '';

for (let index = 0; index < 9; index++) {
const charIndex = value % alphabet.length;
output += alphabet.charAt(charIndex);
value = Math.floor(value / alphabet.length);
}

return output;
}

private stringifyMessageContent(content: unknown): string {
if (typeof content === 'string') {
return content;
}

if (Array.isArray(content)) {
return content
.filter((chunk): chunk is MistralMessageContentPart => !!chunk && typeof chunk === 'object')
.filter(chunk => chunk.type === 'text')
.map(chunk => chunk.text || '')
.join('');
}

if (content === undefined || content === null) {
return '';
}

try {
return JSON.stringify(content);
} catch {
if (typeof content === 'number' || typeof content === 'boolean' || typeof content === 'bigint') {
return `${content}`;
}
return '';
}
}

private toOptionalString(value: unknown): string | undefined {
return typeof value === 'string' && value.length > 0 ? value : undefined;
}

private extractToolCalls(message: MistralMessage | undefined): Array<Record<string, unknown>> {
return message?.toolCalls || [];
}
Expand Down
Loading