diff --git a/package-lock.json b/package-lock.json index 81a0b60e..cab349f1 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,6 +10,7 @@ "license": "Apache-2.0", "devDependencies": { "@ai-sdk/anthropic": "^3.0.13", + "@ai-sdk/google": "^3.0.30", "@ai-sdk/groq": "^3.0.8", "@ai-sdk/openai": "^3.0.10", "@ai-sdk/provider": "^3.0.3", @@ -172,6 +173,51 @@ "node": ">= 20" } }, + "node_modules/@ai-sdk/google": { + "version": "3.0.30", + "resolved": "https://registry.npmjs.org/@ai-sdk/google/-/google-3.0.30.tgz", + "integrity": "sha512-ZzG6dU0XUSSXbxQJJTQUFpWeKkfzdpR7IykEZwaiaW5d+3u3RZ/zkRiGwAOcUpLp6k0eMd+IJF4looJv21ecxw==", + "dev": true, + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.15" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/google/node_modules/@ai-sdk/provider": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.8.tgz", + "integrity": "sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==", + "dev": true, + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/google/node_modules/@ai-sdk/provider-utils": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.15.tgz", + "integrity": "sha512-8XiKWbemmCbvNN0CLR9u3PQiet4gtEVIrX4zzLxnCj06AwsEDJwJVBbKrEI4t6qE8XRSIvU2irka0dcpziKW6w==", + "dev": true, + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.6" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "node_modules/@ai-sdk/groq": { "version": "3.0.8", "resolved": "https://registry.npmjs.org/@ai-sdk/groq/-/groq-3.0.8.tgz", diff --git a/package.json b/package.json index 2c8c5d1b..1caa40fc 100644 --- a/package.json +++ 
b/package.json @@ -85,6 +85,7 @@ }, "devDependencies": { "@ai-sdk/anthropic": "^3.0.13", + "@ai-sdk/google": "^3.0.30", "@ai-sdk/groq": "^3.0.8", "@ai-sdk/openai": "^3.0.10", "@ai-sdk/provider": "^3.0.3", diff --git a/src/client/approval.test.ts b/src/client/approval.test.ts new file mode 100644 index 00000000..371b93d1 --- /dev/null +++ b/src/client/approval.test.ts @@ -0,0 +1,254 @@ +import { describe, expect, test } from "vitest"; +import { Agent, createTool } from "./index.js"; +import type { + DataModelFromSchemaDefinition, + ApiFromModules, + ActionBuilder, +} from "convex/server"; +import { anyApi, actionGeneric } from "convex/server"; +import { defineSchema } from "convex/server"; +import { stepCountIs, type LanguageModelUsage } from "ai"; +import { components, initConvexTest } from "./setup.test.js"; +import { z } from "zod/v4"; +import { mockModel } from "./mockModel.js"; +import type { UsageHandler } from "./types.js"; + +const schema = defineSchema({}); +type DataModel = DataModelFromSchemaDefinition<typeof schema>; +const action = actionGeneric as ActionBuilder<DataModel, "public">; + +// Tool that always requires approval +const deleteFileTool = createTool({ + description: "Delete a file", + inputSchema: z.object({ filename: z.string() }), + needsApproval: () => true, + execute: async (_ctx, input) => `Deleted: ${input.filename}`, +}); + +// Track usage handler calls to verify the full flow is exercised +const usageCalls: LanguageModelUsage[] = []; +const testUsageHandler: UsageHandler = async (_ctx, args) => { + usageCalls.push(args.usage); +}; + +function getApprovalIdFromSavedMessages( + savedMessages: + | Array<{ + message?: { content: unknown }; + }> + | undefined, +): string { + const approvalRequest = savedMessages + ?.flatMap((savedMessage) => + Array.isArray(savedMessage.message?.content) + ?
savedMessage.message.content + : [], + ) + .find((part) => { + const maybeApproval = part as { type?: unknown }; + return maybeApproval.type === "tool-approval-request"; + }) as { approvalId?: unknown } | undefined; + if (typeof approvalRequest?.approvalId !== "string") { + throw new Error("No approval request found in saved messages"); + } + return approvalRequest.approvalId; +} + +// --- Agents (separate mock model instances to avoid shared callIndex) --- + +const approvalAgent = new Agent(components.agent, { + name: "approval-test", + instructions: "You delete files when asked.", + tools: { deleteFile: deleteFileTool }, + languageModel: mockModel({ + contentSteps: [ + // Step 1: model makes a tool call (LanguageModelV3 uses `input` as JSON string) + [ + { + type: "tool-call", + toolCallId: "tc-approve", + toolName: "deleteFile", + input: JSON.stringify({ filename: "test.txt" }), + }, + ], + // Step 2: after tool execution, model responds with text + [{ type: "text", text: "Done! I deleted test.txt." }], + ], + }), + stopWhen: stepCountIs(5), + usageHandler: testUsageHandler, +}); + +const denialAgent = new Agent(components.agent, { + name: "denial-test", + instructions: "You delete files when asked.", + tools: { deleteFile: deleteFileTool }, + languageModel: mockModel({ + contentSteps: [ + [ + { + type: "tool-call", + toolCallId: "tc-deny", + toolName: "deleteFile", + input: JSON.stringify({ filename: "secret.txt" }), + }, + ], + [{ type: "text", text: "OK, I won't delete that file." 
}], + ], + }), + stopWhen: stepCountIs(5), + usageHandler: testUsageHandler, +}); + +// --- Test helpers --- + +export const testApproveFlow = action({ + args: {}, + handler: async (ctx) => { + const { thread } = await approvalAgent.createThread(ctx, { userId: "u1" }); + + // Step 1: Generate text — model returns tool call, SDK sees needsApproval → stops + const result1 = await thread.generateText({ + prompt: "Delete test.txt", + }); + + const approvalId = getApprovalIdFromSavedMessages(result1.savedMessages); + + // Step 2: Approve the tool call + const { messageId } = await approvalAgent.approveToolCall(ctx, { + threadId: thread.threadId, + approvalId, + }); + + // Step 3: Continue generation — SDK executes tool, model responds + const result2 = await thread.generateText({ + promptMessageId: messageId, + }); + + // Verify thread has all messages persisted + const allMessages = await approvalAgent.listMessages(ctx, { + threadId: thread.threadId, + paginationOpts: { cursor: null, numItems: 20 }, + }); + + return { + approvalId, + firstText: result1.text, + secondText: result2.text, + firstSavedCount: result1.savedMessages?.length ?? 0, + secondSavedCount: result2.savedMessages?.length ?? 
0, + totalThreadMessages: allMessages.page.length, + threadMessageRoles: allMessages.page.map((m) => m.message?.role), + usageCallCount: usageCalls.length, + // Verify usage data includes detail fields (AI SDK v6) + lastUsage: usageCalls.at(-1), + }; + }, +}); + +export const testDenyFlow = action({ + args: {}, + handler: async (ctx) => { + const { thread } = await denialAgent.createThread(ctx, { userId: "u2" }); + + // Step 1: Generate — model returns tool call, approval requested + const result1 = await thread.generateText({ + prompt: "Delete secret.txt", + }); + + const approvalId = getApprovalIdFromSavedMessages(result1.savedMessages); + + // Step 2: Deny the tool call + const { messageId } = await denialAgent.denyToolCall(ctx, { + threadId: thread.threadId, + approvalId, + reason: "This file is important", + }); + + // Step 3: Continue generation — SDK creates execution-denied, model responds + const result2 = await thread.generateText({ + promptMessageId: messageId, + }); + + // Verify thread state + const allMessages = await denialAgent.listMessages(ctx, { + threadId: thread.threadId, + paginationOpts: { cursor: null, numItems: 20 }, + }); + + return { + approvalId, + firstText: result1.text, + secondText: result2.text, + totalThreadMessages: allMessages.page.length, + threadMessageRoles: allMessages.page.map((m) => m.message?.role), + usageCallCount: usageCalls.length, + lastUsage: usageCalls.at(-1), + }; + }, +}); + +const testApi: ApiFromModules<{ + fns: { + testApproveFlow: typeof testApproveFlow; + testDenyFlow: typeof testDenyFlow; + }; +}>["fns"] = anyApi["approval.test"] as any; + +describe("Tool Approval Workflow", () => { + test("approve: generate → approval request → approve → tool executes → final text", async () => { + usageCalls.length = 0; + const t = initConvexTest(schema); + const result = await t.action(testApi.testApproveFlow, {}); + + expect(result.approvalId).toBeDefined(); + // First call produces no text (just a tool call) + 
expect(result.firstText).toBe(""); + // Second call produces the final text + expect(result.secondText).toBe("Done! I deleted test.txt."); + // First call: user message + assistant (tool-call + approval-request) + expect(result.firstSavedCount).toBeGreaterThanOrEqual(2); + // Second call: tool-result + assistant text + expect(result.secondSavedCount).toBeGreaterThanOrEqual(1); + // Thread should have (ascending): user, assistant(tool-call+approval), + // tool(approval-response), tool(tool-result), assistant(text) + // listMessages returns descending order: + expect(result.threadMessageRoles).toEqual([ + "assistant", // final text + "tool", // tool-result + "tool", // approval-response + "assistant", // tool-call + approval-request + "user", // prompt + ]); + // Usage handler should be called for each generateText call + expect(result.usageCallCount).toBeGreaterThanOrEqual(2); + // Usage data should include AI SDK v6 detail fields + expect(result.lastUsage).toBeDefined(); + expect(result.lastUsage!.inputTokenDetails).toBeDefined(); + expect(result.lastUsage!.outputTokenDetails).toBeDefined(); + }); + + test("deny: generate → approval request → deny → model acknowledges denial", async () => { + usageCalls.length = 0; + const t = initConvexTest(schema); + const result = await t.action(testApi.testDenyFlow, {}); + + expect(result.approvalId).toBeDefined(); + expect(result.firstText).toBe(""); + expect(result.secondText).toBe("OK, I won't delete that file."); + // Same message ordering as approve flow: + // user, assistant(tool-call+approval), tool(denial-response), + // tool(execution-denied result), assistant(text) + expect(result.threadMessageRoles).toEqual([ + "assistant", + "tool", + "tool", + "assistant", + "user", + ]); + // Usage handler exercised + expect(result.usageCallCount).toBeGreaterThanOrEqual(2); + expect(result.lastUsage!.inputTokenDetails).toBeDefined(); + expect(result.lastUsage!.outputTokenDetails).toBeDefined(); + }); +}); diff --git 
a/src/client/index.ts b/src/client/index.ts index 75222764..3678e1a5 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1016,6 +1016,74 @@ export class Agent< ); } + /** + * Approve a tool call that requires human approval. + * Saves a `tool-approval-response` message to the thread. + * After calling this, call `agent.streamText` or `agent.generateText` + * with `promptMessageId` set to the returned `messageId` to continue + * generation — the AI SDK will automatically execute the approved tool. + * + * @param ctx A ctx object from a mutation or action. + * @param args.threadId The thread containing the tool call. + * @param args.approvalId The approval ID from the tool-approval-request part. + * @param args.reason Optional reason for approval. + * @returns The messageId of the saved approval response message. + */ + async approveToolCall( + ctx: MutationCtx | ActionCtx, + args: { threadId: string; approvalId: string; reason?: string }, + ): Promise<{ messageId: string }> { + return this.respondToToolCallApproval(ctx, { ...args, approved: true }); + } + + /** + * Deny a tool call that requires human approval. + * Saves a `tool-approval-response` message to the thread. + * After calling this, call `agent.streamText` or `agent.generateText` + * with `promptMessageId` set to the returned `messageId` to continue + * generation — the AI SDK will automatically create an `execution-denied` + * result and let the model respond accordingly. + * + * @param ctx A ctx object from a mutation or action. + * @param args.threadId The thread containing the tool call. + * @param args.approvalId The approval ID from the tool-approval-request part. + * @param args.reason Optional reason for denial. + * @returns The messageId of the saved denial response message. 
+ */ + async denyToolCall( + ctx: MutationCtx | ActionCtx, + args: { threadId: string; approvalId: string; reason?: string }, + ): Promise<{ messageId: string }> { + return this.respondToToolCallApproval(ctx, { ...args, approved: false }); + } + + private async respondToToolCallApproval( + ctx: MutationCtx | ActionCtx, + args: { + threadId: string; + approvalId: string; + approved: boolean; + reason?: string; + }, + ): Promise<{ messageId: string }> { + const { messageId } = await this.saveMessage(ctx, { + threadId: args.threadId, + skipEmbeddings: true, + message: { + role: "tool", + content: [ + { + type: "tool-approval-response", + approvalId: args.approvalId, + approved: args.approved, + reason: args.reason, + }, + ], + }, + }); + return { messageId }; + } + /** * Explicitly save a "step" created by the AI SDK. * @param ctx The ctx argument to a mutation or action. diff --git a/src/client/mockModel.ts b/src/client/mockModel.ts index 728b8c86..f9dc51a2 100644 --- a/src/client/mockModel.ts +++ b/src/client/mockModel.ts @@ -16,8 +16,15 @@ const DEFAULT_USAGE = { outputTokens: 10, inputTokens: 3, totalTokens: 13, - inputTokenDetails: undefined, - outputTokenDetails: undefined, + inputTokenDetails: { + noCacheTokens: 3, + cacheReadTokens: 0, + cacheWriteTokens: 0, + }, + outputTokenDetails: { + textTokens: 10, + reasoningTokens: 0, + }, }; export type MockModelArgs = { diff --git a/src/client/start.ts b/src/client/start.ts index 55c1bc15..e08dae11 100644 --- a/src/client/start.ts +++ b/src/client/start.ts @@ -211,6 +211,13 @@ export async function startGeneration< }; } } + // Track how many response messages we've already saved across steps. + // step.response.messages is cumulative — each step appends to it. + // We need to know which messages are new in each step to serialize + // only the new ones (important for tool approval flows where the SDK + // may add extra messages like approval tool-results). 
+ let previousResponseMessageCount = 0; + return { args: aiArgs, order: order ?? 0, @@ -236,20 +243,28 @@ export async function startGeneration< finishStreamId?: string, ) => { if (threadId && saveMessages !== "none") { - const serialized = - "object" in toSave - ? await serializeObjectResult( - ctx, - component, - toSave.object, - activeModel, - ) - : await serializeNewMessagesInStep( - ctx, - component, - toSave.step, - activeModel, - ); + let serialized; + if ("object" in toSave) { + serialized = await serializeObjectResult( + ctx, + component, + toSave.object, + activeModel, + ); + } else { + const allResponseMessages = toSave.step.response.messages; + const newResponseMessages = allResponseMessages.slice( + previousResponseMessageCount, + ); + previousResponseMessageCount = allResponseMessages.length; + serialized = await serializeNewMessagesInStep( + ctx, + component, + toSave.step, + activeModel, + newResponseMessages, + ); + } const embeddings = await embedMessages( ctx, { threadId, ...opts, userId }, diff --git a/src/client/streamText.ts b/src/client/streamText.ts index a2d04a31..b1276832 100644 --- a/src/client/streamText.ts +++ b/src/client/streamText.ts @@ -160,8 +160,16 @@ export async function streamText< !options.saveStreamDeltas.returnImmediately) || options?.saveStreamDeltas === true ) { - await stream; - await result.consumeStream(); + try { + await stream; + await result.consumeStream(); + } catch (e) { + // If the stream errored (e.g. onStepFinish threw), the DeltaStreamer's + // finish() was never called, leaving the streaming message stuck in + // "streaming" state. Clean it up by marking it as aborted. + await streamer?.fail(e instanceof Error ? e.message : String(e)); + throw e; + } } // If we deferred the final step save, do it now with atomic stream finish. 
diff --git a/src/component/messages.ts b/src/component/messages.ts index 91f11ebd..1b9a9a7d 100644 --- a/src/component/messages.ts +++ b/src/component/messages.ts @@ -278,7 +278,7 @@ async function addMessagesHandler( order: pendingMessage.order, stepOrder: pendingMessage.stepOrder, }); - toReturn.push(pendingMessage); + toReturn.push((await ctx.db.get(pendingMessage._id))!); continue; } if (message.message.role === "user") { diff --git a/src/mapping.ts b/src/mapping.ts index 025a1441..a8473e4f 100644 --- a/src/mapping.ts +++ b/src/mapping.ts @@ -211,6 +211,16 @@ export async function serializeNewMessagesInStep( component: AgentComponent, step: StepResult<ToolSet>, model: ModelOrMetadata | undefined, + /** + * If provided, these are the new response messages for this step + * (pre-sliced by the caller). When not provided, falls back to the + * existing heuristic of slicing the last 1-2 messages. + * + * This is needed for tool approval flows where the SDK adds extra + * messages (e.g. approval tool-results) at the beginning of + * responseMessages that the old slice(-1/-2) logic would miss. + */ + newResponseMessages?: ModelMessage[], ): Promise<{ messages: MessageWithMetadata[] }> { // If there are tool results, there's another message with the tool results // ref: https://github.com/vercel/ai/blob/main/packages/ai/src/generate-text/to-response-messages.ts#L120 @@ -228,13 +238,21 @@ export async function serializeNewMessagesInStep( sources: hasToolMessage ?
undefined : step.sources, } satisfies Omit<MessageWithMetadata, "message">; const toolFields = { sources: step.sources }; + + // Determine which messages to serialize for this step + let messagesToSerialize: ModelMessage[]; + if (newResponseMessages) { + messagesToSerialize = newResponseMessages; + } else if (hasToolMessage) { + messagesToSerialize = step.response.messages.slice(-2); + } else if (step.content.length) { + messagesToSerialize = step.response.messages.slice(-1); + } else { + messagesToSerialize = [{ role: "assistant" as const, content: [] }]; + } + const messages: MessageWithMetadata[] = await Promise.all( - (hasToolMessage - ? step.response.messages.slice(-2) - : step.content.length - ? step.response.messages.slice(-1) - : [{ role: "assistant" as const, content: [] }] - ).map(async (msg): Promise<MessageWithMetadata> => { + messagesToSerialize.map(async (msg): Promise<MessageWithMetadata> => { const { message, fileIds } = await serializeMessage(ctx, component, msg); return parse(vMessageWithMetadata, { message, diff --git a/src/react/useDeltaStreams.ts b/src/react/useDeltaStreams.ts index 799feca1..c01094a5 100644 --- a/src/react/useDeltaStreams.ts +++ b/src/react/useDeltaStreams.ts @@ -83,6 +83,12 @@ export function useDeltaStreams< ), ); + // When no active streams remain, clear the stale state so we stop + // returning old streaming UIMessages. + if (streamMessages !== undefined && streamMessages.length === 0) { + state.deltaStreams = undefined; + } + // Get the deltas for all the active streams, if any. const cursorQuery = useQuery( query,