diff --git a/src/client/index.ts b/src/client/index.ts index 874b5648..885c0a6e 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -99,9 +99,10 @@ import type { Output, } from "./types.js"; import { streamText } from "./streamText.js"; -import { errorToString, willContinue } from "./utils.js"; +import { errorToString, hasSuccessfulToolCall, willContinue } from "./utils.js"; export { stepCountIs } from "ai"; +export { hasSuccessfulToolCall }; export { docsToModelMessages, toModelMessage, diff --git a/src/client/utils.ts b/src/client/utils.ts index abda1e16..d5c3b454 100644 --- a/src/client/utils.ts +++ b/src/client/utils.ts @@ -1,5 +1,19 @@ import type { StepResult, StopCondition } from "ai"; +/** + * A stop condition that only matches tool calls which completed + * successfully (i.e. produced a `tool-result`, not a `tool-error`). + * + * Use this instead of the AI SDK's `hasToolCall` when you want the + * agent to retry on argument validation failures rather than stopping. + */ +export function hasSuccessfulToolCall(toolName: string): StopCondition { + return ({ steps }) => + steps[steps.length - 1]?.toolResults?.some( + (result) => result.toolName === toolName, + ) ?? false; +} + export async function willContinue( steps: StepResult[], @@ -9,8 +23,15 @@ export async function willContinue( // we aren't doing another round after a tool result // TODO: whether to handle continuing after too much context used.. if (step.finishReason !== "tool-calls") return false; + // Count both successful results and errors as completed outputs. + // In AI SDK v6, failed tool calls produce tool-error content parts + // instead of tool-result, so only checking toolResults misses them. + const completedOutputs = + step.content?.filter( + (p) => p.type === "tool-result" || p.type === "tool-error", + ).length ?? step.toolResults.length; // we don't have a tool result, so we'll wait for more - if (step.toolCalls.length > step.toolResults.length) return false; + if (step.toolCalls.length > completedOutputs) return false; if (Array.isArray(stopWhen)) { return (await Promise.all(stopWhen.map(async (s) => s({ steps })))).every( (stop) => !stop,