diff --git a/.changeset/parallel-tool-race-fix.md b/.changeset/parallel-tool-race-fix.md new file mode 100644 index 000000000000..b2241f3512d9 --- /dev/null +++ b/.changeset/parallel-tool-race-fix.md @@ -0,0 +1,5 @@ +--- +"ai": patch +--- + +fix(ai): resolve race condition in parallel tool execution causing stream errors diff --git a/packages/ai/src/generate-text/run-tools-transformation.test.ts b/packages/ai/src/generate-text/run-tools-transformation.test.ts index 6adade0f83d5..18470fac5ff8 100644 --- a/packages/ai/src/generate-text/run-tools-transformation.test.ts +++ b/packages/ai/src/generate-text/run-tools-transformation.test.ts @@ -1140,4 +1140,320 @@ describe('runToolsTransformation', () => { }); }); }); + + describe('parallel tool execution', () => { + it('should use toolCallId for tracking (not generateId) to handle parallel tools correctly', async () => { + // This test exposes the bug where generateId() returns the same value for all tools + // in a batch, causing the outstandingToolResults Set to only track one tool. + // The fix uses toolCall.toolCallId instead which is unique per tool call. + const constantIdGenerator = () => 'same-id-for-all'; + + const inputStream: ReadableStream = + convertArrayToReadableStream([ + { + type: 'tool-call', + toolCallId: 'unique-call-1', + toolName: 'toolA', + input: `{ "value": "a" }`, + }, + { + type: 'tool-call', + toolCallId: 'unique-call-2', + toolName: 'toolB', + input: `{ "value": "b" }`, + }, + { + type: 'tool-call', + toolCallId: 'unique-call-3', + toolName: 'toolC', + input: `{ "value": "c" }`, + }, + { + type: 'finish', + finishReason: { unified: 'tool_calls', raw: 'tool_calls' }, + usage: testUsage, + }, + ]); + + const transformedStream = runToolsTransformation({ + generateId: constantIdGenerator, + tools: { + toolA: { + title: 'Tool A', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(30); + return `${value}-result`; + }, + }, + toolB: { + title: 'Tool B', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(10); + return `${value}-result`; + }, + }, + toolC: { + title: 'Tool C', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(20); + return `${value}-result`; + }, + }, + }, + generatorStream: inputStream, + tracer: new MockTracer(), + telemetry: undefined, + messages: [], + system: undefined, + abortSignal: undefined, + repairToolCall: undefined, + experimental_context: undefined, + }); + + const result = await convertReadableStreamToArray(transformedStream); + + // All three tool results should be captured + // (Bug: without the fix, only 1 result would be captured because + // outstandingToolResults Set would use the same ID for all tools) + const toolResults = result.filter(r => r.type === 'tool-result'); + expect(toolResults).toHaveLength(3); + expect(toolResults.map(r => r.toolCallId).sort()).toEqual([ + 'unique-call-1', + 'unique-call-2', + 'unique-call-3', + ]); + + // Finish should be last + expect(result[result.length - 1]).toMatchObject({ + type: 'finish', + }); + }); + + it('should capture all results when multiple tools execute in parallel with different delays', async () => { + const inputStream: ReadableStream = + convertArrayToReadableStream([ + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'slowTool', + input: `{ "value": "slow" }`, + }, + { + type: 'tool-call', + toolCallId: 'call-2', + toolName: 'fastTool', + input: `{ "value": "fast" }`, + }, + { + type: 'tool-call', + toolCallId: 'call-3', + toolName: 'mediumTool', + input: `{ "value": "medium" }`, + }, + { + type: 'finish', + finishReason: { unified: 'tool_calls', raw: 'tool_calls' }, + usage: testUsage, + }, + ]); + + const transformedStream = runToolsTransformation({ + generateId: mockId({ prefix: 'id' }), + tools: { + slowTool: { + title: 'Slow Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(50); // Slowest + return `${value}-result`; + }, + }, + fastTool: { + title: 'Fast Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(10); // Fastest + return `${value}-result`; + }, + }, + mediumTool: { + title: 'Medium Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(30); // Medium + return `${value}-result`; + }, + }, + }, + generatorStream: inputStream, + tracer: new MockTracer(), + telemetry: undefined, + messages: [], + system: undefined, + abortSignal: undefined, + repairToolCall: undefined, + experimental_context: undefined, + }); + + const result = await convertReadableStreamToArray(transformedStream); + + // All three tool calls should be present + const toolCalls = result.filter(r => r.type === 'tool-call'); + expect(toolCalls).toHaveLength(3); + + // All three tool results should be present + const toolResults = result.filter(r => r.type === 'tool-result'); + expect(toolResults).toHaveLength(3); + expect(toolResults.map(r => r.toolCallId).sort()).toEqual([ + 'call-1', + 'call-2', + 'call-3', + ]); + + // Finish should be last + expect(result[result.length - 1]).toMatchObject({ + type: 'finish', + }); + }); + + it('should not close stream prematurely when fast tool completes before slow tool', async () => { + const executionOrder: string[] = []; + + const inputStream: ReadableStream = + convertArrayToReadableStream([ + { + type: 'tool-call', + toolCallId: 'slow-call', + toolName: 'slowTool', + input: `{ "value": "slow" }`, + }, + { + type: 'tool-call', + toolCallId: 'fast-call', + toolName: 'fastTool', + input: `{ "value": "fast" }`, + }, + { + type: 'finish', + finishReason: { unified: 'tool_calls', raw: 'tool_calls' }, + usage: testUsage, + }, + ]); + + const transformedStream = runToolsTransformation({ + generateId: mockId({ prefix: 'id' }), + tools: { + slowTool: { + title: 'Slow Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(50); + executionOrder.push('slow-completed'); + return `${value}-slow-result`; + }, + }, + fastTool: { + title: 'Fast Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(5); + executionOrder.push('fast-completed'); + return `${value}-fast-result`; + }, + }, + }, + generatorStream: inputStream, + tracer: new MockTracer(), + telemetry: undefined, + messages: [], + system: undefined, + abortSignal: undefined, + repairToolCall: undefined, + experimental_context: undefined, + }); + + const result = await convertReadableStreamToArray(transformedStream); + + // Fast tool should complete first + expect(executionOrder).toEqual(['fast-completed', 'slow-completed']); + + // Both results should be captured + const toolResults = result.filter(r => r.type === 'tool-result'); + expect(toolResults).toHaveLength(2); + expect(toolResults.map(r => r.output).sort()).toEqual([ + 'fast-fast-result', + 'slow-slow-result', + ]); + + // Stream should close properly after all tools complete + expect(result[result.length - 1]).toMatchObject({ + type: 'finish', + }); + }); + + it('should handle many parallel tool calls without losing results', async () => { + const toolCount = 10; + const toolCalls = Array.from({ length: toolCount }, (_, i) => ({ + type: 'tool-call' as const, + toolCallId: `call-${i}`, + toolName: 'parallelTool', + input: `{ "index": ${i} }`, + })); + + const inputStream: ReadableStream = + convertArrayToReadableStream([ + ...toolCalls, + { + type: 'finish', + finishReason: { unified: 'tool_calls', raw: 'tool_calls' }, + usage: testUsage, + }, + ]); + + const transformedStream = runToolsTransformation({ + generateId: mockId({ prefix: 'id' }), + tools: { + parallelTool: { + title: 'Parallel Tool', + inputSchema: z.object({ index: z.number() }), + execute: async ({ index }) => { + // Random delay to simulate real-world variance + await delay(Math.random() * 20); + return `result-${index}`; + }, + }, + }, + generatorStream: inputStream, + tracer: new MockTracer(), + telemetry: undefined, + messages: [], + system: undefined, + abortSignal: undefined, + repairToolCall: undefined, + experimental_context: undefined, + }); + + const result = await convertReadableStreamToArray(transformedStream); + + // All tool results should be captured + const toolResults = result.filter(r => r.type === 'tool-result'); + expect(toolResults).toHaveLength(toolCount); + + // Verify all results are present (order may vary) + const resultOutputs = toolResults.map(r => r.output).sort(); + const expectedOutputs = Array.from( + { length: toolCount }, + (_, i) => `result-${i}`, + ).sort(); + expect(resultOutputs).toEqual(expectedOutputs); + + // Finish should be last + expect(result[result.length - 1]).toMatchObject({ + type: 'finish', + }); + }); + }); }); diff --git a/packages/ai/src/generate-text/run-tools-transformation.ts b/packages/ai/src/generate-text/run-tools-transformation.ts index 29865865aad7..e0fe75a9d433 100644 --- a/packages/ai/src/generate-text/run-tools-transformation.ts +++ b/packages/ai/src/generate-text/run-tools-transformation.ts @@ -150,13 +150,23 @@ export function runToolsTransformation({ const toolCallsByToolCallId = new Map>(); let canClose = false; + let closed = false; // Prevent race condition when multiple tools complete simultaneously let finishChunk: | (SingleRequestTextStreamPart & { type: 'finish' }) | undefined = undefined; function attemptClose() { + // Prevent re-entry: if already closed, nothing to do + if (closed) { + return; + } + // close the tool results controller if no more outstanding tool calls if (canClose && outstandingToolResults.size === 0) { + // Mark as closed BEFORE doing any work to prevent race conditions + // where multiple finally() blocks call attemptClose() simultaneously + closed = true; + // we delay sending the finish chunk until all tool results (incl. delayed ones) // are received to ensure that the frontend receives tool results before a message // finish event arrives. @@ -309,7 +319,9 @@ export function runToolsTransformation({ // Only execute tools that are not provider-executed: if (tool.execute != null && toolCall.providerExecuted !== true) { - const toolExecutionId = generateId(); // use our own id to guarantee uniqueness + // Use toolCallId which is unique per tool call from the LLM + // (generateId() was returning the same value for multiple tools in a batch) + const toolExecutionId = toolCall.toolCallId; outstandingToolResults.add(toolExecutionId); // Note: we don't await the tool execution here (by leaving out 'await' on recordSpan), @@ -324,17 +336,26 @@ export function runToolsTransformation({ abortSignal, experimental_context, onPreliminaryToolResult: result => { - toolResultsStreamController!.enqueue(result); + // Guard against enqueueing after stream is closed (parallel tool race) + if (!closed) { + toolResultsStreamController!.enqueue(result); + } }, }) .then(result => { - toolResultsStreamController!.enqueue(result); + // Guard against enqueueing after stream is closed (parallel tool race) + if (!closed) { + toolResultsStreamController!.enqueue(result); + } }) .catch(error => { - toolResultsStreamController!.enqueue({ - type: 'error', - error, - }); + // Guard against enqueueing after stream is closed (parallel tool race) + if (!closed) { + toolResultsStreamController!.enqueue({ + type: 'error', + error, + }); + } }) .finally(() => { outstandingToolResults.delete(toolExecutionId);