module Temporalio
  module Contrib
    # LLM tool-calling primitives for Temporal activities.
    #
    # Building blocks for running agentic LLM tool-use loops inside Temporal
    # activities with heartbeat checkpointing and retry/resume semantics.
    #
    # == Quick-start
    #
    #   require 'temporalio/contrib/tool_registry'
    #   # Provider adapters are not loaded by the require above; load the one you use.
    #   require 'temporalio/contrib/tool_registry/providers/anthropic'
    #
    #   registry = Temporalio::Contrib::ToolRegistry::Registry.new
    #   registry.register(name: 'get_weather', description: 'Get the weather',
    #                     input_schema: { type: 'object', properties: { city: { type: 'string' } } }) do |input|
    #     WeatherService.get(input['city'])
    #   end
    #
    #   provider = Temporalio::Contrib::ToolRegistry::Providers::AnthropicProvider.new(
    #     registry, 'You are a helpful assistant.', api_key: ENV['ANTHROPIC_API_KEY']
    #   )
    #
    #   # Inside a Temporal activity:
    #   Temporalio::Contrib::ToolRegistry::AgenticSession.run_with_session do |session|
    #     session.run_tool_loop(provider, registry, 'What is the weather in NYC?')
    #   end
    #
    # == Module-level helper
    #
    # For simple cases that do not require checkpointing (no activity context):
    #
    #   messages = Temporalio::Contrib::ToolRegistry.run_tool_loop(provider, registry, 'user prompt')
    #
    module ToolRegistry
      # Run a single (non-checkpointed) agentic tool-use loop.
      #
      # Convenience wrapper that does NOT need an active Temporal activity
      # context. Inside activities prefer {AgenticSession.run_with_session},
      # which adds heartbeat checkpointing and retry-resume.
      #
      # The loop keeps calling +provider.run_turn+ with the accumulated history
      # and +registry.defs+ until the provider reports the conversation is done;
      # there is no turn limit here.
      #
      # @param provider [Provider] LLM provider adapter.
      # @param registry [Registry] Tool registry.
      # @param prompt [String] Initial user prompt.
      # @return [Array] Full conversation message history.
      def self.run_tool_loop(provider, registry, prompt)
        history = [{ 'role' => 'user', 'content' => prompt }]
        loop do
          turn_messages, finished = provider.run_turn(history, registry.defs)
          history.concat(turn_messages)
          return history if finished
        end
      end
    end
  end
end
+ +## Install + +Add to your `Gemfile`: + +```ruby +gem 'temporalio' +``` + +Install the LLM client gem separately: + +```ruby +gem 'anthropic' # Anthropic +gem 'ruby-openai' # OpenAI +``` + +## Quickstart + +Tool definitions use [JSON Schema](https://json-schema.org/understanding-json-schema/) for `input_schema`. The quickstart uses a single string field; for richer schemas refer to the JSON Schema docs. + +```ruby +require 'temporalio/contrib/tool_registry' +require 'temporalio/contrib/tool_registry/providers/anthropic' + +include Temporalio::Contrib # brings ToolRegistry::* into scope + +activity :analyze do |prompt| + results = [] + registry = ToolRegistry::Registry.new + registry.register( + name: 'flag_issue', + description: 'Flag a problem found in the analysis', + input_schema: { + 'type' => 'object', + 'properties' => { 'description' => { 'type' => 'string' } }, + 'required' => ['description'] + } + ) do |input| + results << input['description'] + 'recorded' # this string is sent back to the LLM as the tool result + end + + provider = ToolRegistry::Providers::AnthropicProvider.new( + registry, + 'You are a code reviewer. Call flag_issue for each problem you find.', + api_key: ENV['ANTHROPIC_API_KEY'] + ) + + ToolRegistry.run_tool_loop(provider, registry, prompt) + results +end +``` + +### Selecting a model + +The default model is `"claude-sonnet-4-6"` (Anthropic) or `"gpt-4o"` (OpenAI). Override with the `model:` keyword: + +```ruby +provider = ToolRegistry::Providers::AnthropicProvider.new( + registry, + 'You are a code reviewer.', + api_key: ENV['ANTHROPIC_API_KEY'], + model: 'claude-3-5-sonnet-20241022' +) +``` + +Model IDs are defined by the provider — see Anthropic or OpenAI docs for current names. 
+ +### OpenAI + +```ruby +require 'temporalio/contrib/tool_registry/providers/openai' + +provider = ToolRegistry::Providers::OpenAIProvider.new( + registry, 'your system prompt', api_key: ENV['OPENAI_API_KEY']) +ToolRegistry.run_tool_loop(provider, registry, prompt) +``` + +## Crash-safe agentic sessions + +For multi-turn LLM conversations that must survive activity retries, use +`AgenticSession.run_with_session`. It saves conversation history via +`Temporalio::Activity::Context.current.heartbeat` on every turn and restores +it on retry. + +```ruby +require 'temporalio/contrib/tool_registry/session' + +results = ToolRegistry::AgenticSession.run_with_session do |session| + registry = ToolRegistry::Registry.new + registry.register(name: 'flag', description: '...', + input_schema: { 'type' => 'object' }) do |input| + session.add_result(input) # use add_result, not session.results << + 'ok' # this string is sent back to the LLM as the tool result + end + + provider = ToolRegistry::Providers::AnthropicProvider.new( + registry, 'your system prompt', api_key: ENV['ANTHROPIC_API_KEY']) + session.run_tool_loop(provider, registry, prompt) + session.results # return value of block = return value of run_with_session +end +``` + +## Testing without an API key + +```ruby +require 'temporalio/contrib/tool_registry' +require 'temporalio/contrib/tool_registry/testing' + +include Temporalio::Contrib::ToolRegistry # brings Registry, Testing, etc. into scope + +registry = Registry.new +registry.register(name: 'flag', description: 'd', input_schema: { 'type' => 'object' }) do |_| + 'ok' # this string is sent back to the LLM as the tool result +end + +provider = Testing::MockProvider.new( + Testing::MockResponse.tool_call('flag', { 'description' => 'stale API' }), + Testing::MockResponse.done('analysis complete') +).with_registry(registry) + +# run_tool_loop is a module method, so it is not brought in by `include`; call it by its full name. +msgs = Temporalio::Contrib::ToolRegistry.run_tool_loop(provider, registry, 'analyze') +assert msgs.length > 2 +``` + +## Integration testing with real providers + +To run the
integration tests against live Anthropic and OpenAI APIs: + +```bash +RUN_INTEGRATION_TESTS=1 \ + ANTHROPIC_API_KEY=sk-ant-... \ + OPENAI_API_KEY=sk-proj-... \ + ruby -I lib -I test test/contrib/tool_registry_test.rb +``` + +Tests skip automatically when `RUN_INTEGRATION_TESTS` is unset. Real API calls +incur billing — expect a few cents per full test run. + +## Storing application results + +`session.results` accumulates application-level +results during the tool loop. Elements are serialized to JSON inside each heartbeat +checkpoint — they must be plain Hashes with JSON-serializable values. A non-serializable +value raises a non-retryable `ApplicationError` at heartbeat time rather than silently +losing data on the next retry. + +### Storing typed results + +Convert your domain type to a plain Hash at the tool-call site and back after the session: + +```ruby +Result = Struct.new(:type, :file, keyword_init: true) + +# Inside tool handler: +session.add_result({ 'type' => 'smell', 'file' => 'foo.rb' }) + +# After session: +results = session.results.map { |h| Result.new(**h.transform_keys(&:to_sym)) } +``` + +## Per-turn LLM timeout + +Individual LLM calls inside the tool loop are unbounded by default. A hung HTTP +connection holds the activity open until Temporal's `ScheduleToCloseTimeout` +fires — potentially many minutes.
Set a per-turn timeout on the provider client: + +```ruby +provider = ToolRegistry::Providers::AnthropicProvider.new( + registry, + 'system prompt', + api_key: ENV['ANTHROPIC_API_KEY'], + timeout: 30 # seconds +) +``` + +Recommended timeouts: + +| Model type | Recommended | +|---|---| +| Standard (Claude 3.x, GPT-4o) | 30 s | +| Reasoning (o1, o3, extended thinking) | 300 s | + +### Activity-level timeout + +Set `schedule_to_close_timeout` on the activity options to bound the entire conversation: + +```ruby +workflow.execute_activity( + MyActivities.long_analysis, + prompt, + schedule_to_close_timeout: 600 # seconds +) +``` + +The per-turn client timeout and `schedule_to_close_timeout` are complementary: +- Per-turn timeout fires if one LLM call hangs (protects against a single stuck turn) +- `schedule_to_close_timeout` bounds the entire conversation including all retries (protects against runaway multi-turn loops) + +## MCP integration + +`Registry.from_mcp_tools` converts a list of MCP tool descriptors into a populated +registry. Handlers default to no-ops that return an empty string; override them with +`register` after construction. + +```ruby +# mcp_tools is an array of objects responding to :name, :description, :input_schema. +registry = ToolRegistry::Registry.from_mcp_tools(mcp_tools) + +# Override specific handlers before running the loop. +registry.register(name: 'read_file', description: '...', input_schema: { 'type' => 'object' }) do |input| + read_file(input['path']) +end +``` + +Each descriptor must respond to `name`, `description`, and `input_schema` (or +`inputSchema` for camelCase MCP objects). `input_schema` should be a Hash containing +a JSON Schema object. 
module Temporalio
  module Contrib
    module ToolRegistry
      # Abstract base class for LLM provider adapters.
      #
      # A concrete adapter overrides {#run_turn} to perform one round-trip with
      # the model: it receives the message history and the tool definitions and
      # reports the messages to append plus whether the conversation finished.
      class Provider
        # Execute one conversation turn. The base implementation always raises.
        #
        # @param messages [Array] Current message history (String-keyed).
        # @param tools [Array] Available tool definitions.
        # @return [Array(Array, Boolean)] Pair of [new_messages, done]:
        #   - new_messages: messages to append to the conversation history.
        #   - done: true if the LLM produced a final response with no pending tool calls.
        # @raise [NotImplementedError] Always, unless overridden by a subclass.
        def run_turn(messages, tools)
          missing = "#{self.class}#run_turn"
          raise NotImplementedError, "#{missing} not implemented"
        end
      end
    end
  end
end
module Temporalio
  module Contrib
    module ToolRegistry
      module Providers
        # LLM provider adapter for the Anthropic Messages API.
        #
        # Requires the +anthropic+ gem (install it separately: gem 'anthropic').
        # The gem is loaded lazily, only when this class has to build a client.
        #
        # The registry passed to the constructor dispatches the tool calls the
        # model requests. The tool definitions advertised to the model come from
        # {#run_turn}'s +tools+ argument (typically +registry.defs+).
        class AnthropicProvider < Provider
          DEFAULT_MODEL = 'claude-sonnet-4-6'

          # @param registry [Registry] Registry used to dispatch tool calls.
          # @param system [String] System prompt sent on every turn.
          # @param model [String] Anthropic model ID.
          # @param api_key [String, nil] API key (falls back to +ANTHROPIC_API_KEY+ env var).
          # @param base_url [String, nil] Optional custom base URL.
          # @param client [Object, nil] Pre-built Anthropic client (skips key/URL/timeout).
          # @param timeout [Numeric, nil] Per-request timeout forwarded as +timeout:+ to
          #   +Anthropic::Client.new+ when this class builds the client (the README
          #   documents it in seconds). Ignored when +client+ is supplied.
          def initialize(registry, system, model: DEFAULT_MODEL, api_key: nil, base_url: nil,
                         client: nil, timeout: nil)
            super()

            @registry = registry
            @system = system
            @model = model
            @client = client || build_client(api_key, base_url, timeout)
          end

          # @see Provider#run_turn
          def run_turn(messages, tools)
            resp = @client.messages.create(
              model: @model,
              max_tokens: 4096,
              system: @system,
              messages: messages,
              tools: tools.map { |t| { name: t.name, description: t.description, input_schema: t.input_schema } }
            )

            # JSON round-trip: convert typed SDK objects into plain Hashes so the
            # history is heartbeat-safe and uniformly String-keyed across turns.
            content = JSON.parse(JSON.generate(resp.content))
            new_msgs = [{ 'role' => 'assistant', 'content' => content }]

            tool_calls = content.select { |block| block['type'] == 'tool_use' }
            return [new_msgs, true] if tool_calls.empty? || resp.stop_reason.to_s == 'end_turn'

            new_msgs << { 'role' => 'user', 'content' => tool_calls.map { |call| tool_result_for(call) } }
            [new_msgs, false]
          end

          private

          # Dispatch one +tool_use+ block and wrap the outcome as a +tool_result+
          # block. A handler error is reported back to the model (flagged with
          # +is_error+) instead of aborting the turn.
          def tool_result_for(call)
            failed = false
            output = begin
              @registry.dispatch(call['name'], call['input'])
            rescue => e # rubocop:disable Style/RescueStandardError
              failed = true
              "error: #{e.message}"
            end
            entry = { 'type' => 'tool_result', 'tool_use_id' => call['id'], 'content' => output.to_s }
            entry['is_error'] = true if failed
            entry
          end

          # Build an Anthropic client; optional settings are only passed when given.
          def build_client(api_key, base_url, timeout = nil)
            require 'anthropic'
            opts = { api_key: api_key || ENV.fetch('ANTHROPIC_API_KEY') }
            opts[:base_url] = base_url if base_url
            opts[:timeout] = timeout if timeout
            Anthropic::Client.new(**opts)
          end
        end
      end
    end
  end
end
module Temporalio
  module Contrib
    module ToolRegistry
      module Providers
        # LLM provider adapter for the OpenAI Chat Completions API.
        #
        # Requires the +ruby-openai+ gem (install it separately: gem 'ruby-openai').
        # The gem is loaded lazily, only when this class has to build a client, so
        # a pre-built (or stub) client can be injected without the gem installed.
        #
        # The registry passed to the constructor dispatches the tool calls the
        # model requests. The tool definitions advertised to the model come from
        # {#run_turn}'s +tools+ argument (typically +registry.defs+).
        class OpenAIProvider < Provider
          DEFAULT_MODEL = 'gpt-4o'

          # @param registry [Registry] Registry used to dispatch tool calls.
          # @param system [String] System prompt prepended as a system message.
          # @param model [String] OpenAI model ID.
          # @param api_key [String, nil] API key (falls back to +OPENAI_API_KEY+ env var).
          # @param base_url [String, nil] Optional custom base URL.
          # @param client [Object, nil] Pre-built OpenAI client (skips key/URL/timeout).
          # @param timeout [Numeric, nil] Per-request timeout forwarded as
          #   +request_timeout:+ to +OpenAI::Client.new+ when this class builds the
          #   client. Ignored when +client+ is supplied.
          def initialize(registry, system, model: DEFAULT_MODEL, api_key: nil, base_url: nil,
                         client: nil, timeout: nil)
            super()

            @registry = registry
            @system = system
            @model = model
            @client = client || build_client(api_key, base_url, timeout)
          end

          # @see Provider#run_turn
          def run_turn(messages, tools)
            params = { model: @model, messages: [{ 'role' => 'system', 'content' => @system }] + messages }
            tool_defs = tools.map { |t| function_def(t) }
            # NOTE(review): the Chat Completions API is believed to reject an empty
            # +tools+ array, so the key is omitted when no tools are registered.
            params[:tools] = tool_defs unless tool_defs.empty?

            resp = @client.chat(parameters: params)
            choice = resp.dig('choices', 0)
            msg = choice&.dig('message') || {}
            tool_calls = msg['tool_calls'] || []

            assistant = { 'role' => 'assistant', 'content' => msg['content'] }
            assistant['tool_calls'] = tool_calls.map { |tc| normalize_tool_call(tc) } unless tool_calls.empty?
            new_msgs = [assistant]

            finish_reason = choice&.dig('finish_reason') || ''
            return [new_msgs, true] if tool_calls.empty? || %w[stop length].include?(finish_reason)

            tool_calls.each { |tc| new_msgs << tool_message_for(tc) }
            [new_msgs, false]
          end

          private

          # Tool definition in OpenAI function-calling format.
          def function_def(tool)
            {
              'type' => 'function',
              'function' => {
                'name' => tool.name,
                'description' => tool.description,
                'parameters' => tool.input_schema
              }
            }
          end

          # Keep only the fields of a tool call that are echoed back in the history.
          def normalize_tool_call(tool_call)
            {
              'id' => tool_call['id'],
              'type' => 'function',
              'function' => {
                'name' => tool_call.dig('function', 'name'),
                'arguments' => tool_call.dig('function', 'arguments')
              }
            }
          end

          # Dispatch one tool call and wrap the outcome as a +tool+ message.
          # Argument parsing happens inside the rescue so that malformed JSON from
          # the model is reported back as an error string, exactly like a handler
          # failure, instead of aborting the whole loop.
          def tool_message_for(tool_call)
            output = begin
              input = JSON.parse(tool_call.dig('function', 'arguments') || '{}')
              @registry.dispatch(tool_call.dig('function', 'name'), input)
            rescue => e # rubocop:disable Style/RescueStandardError
              "error: #{e.message}"
            end
            { 'role' => 'tool', 'tool_call_id' => tool_call['id'], 'content' => output.to_s }
          end

          # Build an OpenAI client; optional settings are only passed when given.
          def build_client(api_key, base_url, timeout = nil)
            require 'openai'
            opts = { access_token: api_key || ENV.fetch('OPENAI_API_KEY') }
            opts[:uri_base] = base_url if base_url
            opts[:request_timeout] = timeout if timeout
            OpenAI::Client.new(**opts)
          end
        end
      end
    end
  end
end
module Temporalio
  module Contrib
    module ToolRegistry
      # Immutable definition of a single LLM-callable tool.
      ToolDef = Data.define(:name, :description, :input_schema)

      # Registry that maps tool names to handlers.
      #
      # Tool names are unique: registering a name again replaces the earlier
      # definition and handler while keeping the tool's original position.
      class Registry
        def initialize
          @defs = []
          @handlers = {}
        end

        # Build a Registry from a list of MCP tool descriptors.
        #
        # Each descriptor must respond to +name+ and to +input_schema+ (or
        # +inputSchema+ for camelCase MCP objects); +description+ is optional.
        # A missing schema defaults to an empty object schema and a missing
        # description to an empty string. Every tool gets a no-op handler that
        # returns an empty string — override it with {#register} afterwards.
        #
        # @param tools [Array] MCP tool descriptor objects.
        # @return [Registry]
        def self.from_mcp_tools(tools)
          tools.each_with_object(new) do |tool, registry|
            schema = (tool.respond_to?(:input_schema) ? tool.input_schema : tool.inputSchema) ||
                     { 'type' => 'object', 'properties' => {} }
            description = (tool.description if tool.respond_to?(:description)) || ''
            registry.register(name: tool.name, description:, input_schema: schema) { |_input| '' }
          end
        end

        # Register a tool with the given name, description, and JSON Schema for its input.
        # The block receives a Hash of parsed arguments and must return a String result.
        #
        # Registering an already-known name overrides that tool in place rather
        # than appending a second definition with the same name.
        #
        # @param name [String] Tool name.
        # @param description [String] Human-readable description.
        # @param input_schema [Hash] JSON Schema for the tool's input object.
        # @yield [Hash] Called with the parsed input when the tool is invoked.
        # @return [self]
        # @raise [ArgumentError] If no block is given.
        def register(name:, description:, input_schema:, &handler)
          raise ArgumentError, 'Block required' unless handler

          defn = ToolDef.new(name: name.to_s, description: description.to_s, input_schema:)
          position = @defs.index { |existing| existing.name == defn.name }
          if position
            @defs[position] = defn
          else
            @defs << defn
          end
          @handlers[defn.name] = handler
          self
        end

        # Dispatch a tool call by name. Raises KeyError if the tool is not registered.
        #
        # @param name [String] Tool name.
        # @param input [Hash] Parsed input arguments.
        # @return [String] Tool result.
        def dispatch(name, input)
          handler = @handlers.fetch(name.to_s) { raise KeyError, "Unknown tool: #{name}" }
          handler.call(input)
        end

        # @return [Array] Frozen copy of all registered tool definitions.
        def defs
          @defs.dup.freeze
        end

        # @return [Array] Tool definitions in Anthropic API format.
        def to_anthropic
          @defs.map do |t|
            { 'name' => t.name, 'description' => t.description, 'input_schema' => t.input_schema }
          end
        end

        # @return [Array] Tool definitions in OpenAI function-calling format.
        def to_openai
          @defs.map do |t|
            {
              'type' => 'function',
              'function' => {
                'name' => t.name,
                'description' => t.description,
                'parameters' => t.input_schema
              }
            }
          end
        end
      end
    end
  end
end
module Temporalio
  module Contrib
    module ToolRegistry
      # Holds conversation state across a multi-turn LLM tool-use loop.
      #
      # On activity retry, {run_with_session} restores the session from the last
      # heartbeat checkpoint so the conversation resumes mid-turn rather than
      # restarting from the beginning.
      class AgenticSession
        # @return [Array] Full conversation history (String-keyed, JSON-safe).
        attr_reader :messages

        # @return [Array] Application-level results from tool calls.
        attr_reader :results

        # Run the block with a durable, checkpointed LLM session.
        #
        # Reads the first heartbeat detail from the current activity context. A
        # Hash is treated as a checkpoint and restored; any other non-nil value
        # is logged as corrupt and the session starts fresh.
        #
        # Must be called inside a Temporal activity (requires an active
        # {Activity::Context}).
        #
        # @yield [AgenticSession] Freshly created (or restored) session.
        # @return [Object] The block's return value.
        def self.run_with_session
          ctx = Activity::Context.current
          session = new
          saved = ctx.info.heartbeat_details&.first
          case saved
          when Hash then session.send(:restore, saved)
          when nil then nil
          else ctx.logger.warn("AgenticSession: corrupt checkpoint (#{saved.class}), starting fresh")
          end
          yield session
        end

        def initialize
          @messages = []
          @results = []
        end

        # Append an application-level result record.
        #
        # @param result_hash [Hash] JSON-serializable result.
        def add_result(result_hash)
          @results << result_hash
        end

        # Run the agentic tool-use loop to completion.
        #
        # If {messages} is empty (fresh start), +prompt+ becomes the first user
        # message; otherwise the restored conversation is resumed (retry case).
        #
        # Each turn checkpoints via {#checkpoint} before calling the provider.
        # Per the original design notes, cancellation surfaces as +CanceledError+
        # through the next blocking call, so no explicit check follows the heartbeat.
        #
        # @param provider [Provider] LLM provider adapter.
        # @param registry [Registry] Tool registry whose definitions are passed to the LLM.
        # @param prompt [String] Initial user prompt (ignored on retry).
        # @return [Array] The conversation history ({messages}), matching the
        #   module-level +ToolRegistry.run_tool_loop+ helper.
        def run_tool_loop(provider, registry, prompt)
          @messages << { 'role' => 'user', 'content' => prompt } if @messages.empty?

          loop do
            checkpoint
            new_msgs, done = provider.run_turn(@messages, registry.defs)
            @messages.concat(new_msgs)
            break if done
          end
          @messages
        end

        # Heartbeat the current session state to Temporal.
        #
        # Every result is test-serialized first so a bad value fails loudly now
        # instead of being lost on the next retry.
        #
        # @raise [Temporalio::Error::ApplicationError] (non-retryable) if any result is not
        #   JSON-serializable.
        def checkpoint
          @results.each_with_index do |result, i|
            JSON.generate(result)
          rescue TypeError, JSON::GeneratorError => e
            raise Temporalio::Error::ApplicationError.new(
              "AgenticSession: results[#{i}] is not JSON-serializable: #{e}. " \
              'Store only Hash values with JSON-serializable content.',
              non_retryable: true
            )
          end
          # Explicit braces: a braceless string-keyed hash is passed as keyword
          # arguments in Ruby 3 and raises ArgumentError against a heartbeat
          # signature that declares keyword parameters.
          Activity::Context.current.heartbeat({ 'version' => 1, 'messages' => @messages, 'results' => @results })
        end

        private

        # Load state from a checkpoint Hash. A checkpoint without a version is
        # accepted with a warning; any version other than 1 is rejected and the
        # session is left empty.
        def restore(checkpoint)
          return unless checkpoint.is_a?(Hash)

          version = checkpoint['version']
          if version.nil?
            Activity::Context.current.logger.warn(
              'AgenticSession: checkpoint has no version field — may be from an older release'
            )
          elsif version != 1
            Activity::Context.current.logger.warn(
              "AgenticSession: checkpoint version #{version}, expected 1 — starting fresh"
            )
            return
          end

          @messages = Array(checkpoint['messages'])
          @results = Array(checkpoint['results'])
        end
      end
    end
  end
end
module Temporalio
  module Contrib
    module ToolRegistry
      module Testing
        # A recorded tool dispatch (name, input, result).
        DispatchCall = Data.define(:name, :input, :result)

        # A canned response to be returned by {MockProvider}.
        #
        # Use the factory methods {done} and {tool_call} to build instances.
        class MockResponse
          # @return [Symbol] :done or :tool_call
          attr_reader :type
          # @return [String, nil] Text content for :done responses.
          attr_reader :text
          # @return [String, nil] Tool name for :tool_call responses.
          attr_reader :tool_name
          # @return [Hash, nil] Tool input for :tool_call responses.
          attr_reader :input
          # @return [String, nil] Explicit call ID (auto-generated if nil).
          attr_reader :call_id

          # Build a "done" response with the given text.
          #
          # @param text [String]
          # @return [MockResponse]
          def self.done(text)
            new(:done, text:)
          end

          # Build a "tool_call" response.
          #
          # @param tool_name [String] Tool to call.
          # @param input [Hash] Input for the tool.
          # @param call_id [String, nil] Optional explicit call ID.
          # @return [MockResponse]
          def self.tool_call(tool_name, input, call_id = nil)
            new(:tool_call, tool_name:, input:, call_id:)
          end

          def initialize(type, text: nil, tool_name: nil, input: nil, call_id: nil)
            @type = type
            @text = text
            @tool_name = tool_name
            @input = input
            @call_id = call_id
          end
        end

        # A {Provider} backed by pre-scripted {MockResponse} values.
        #
        # Responses are consumed in order, one per {#run_turn}. A tool-call
        # response needs a registry wired via {#with_registry}: the provider
        # dispatches the call and appends the result as a user message.
        class MockProvider < Provider
          def initialize(*responses)
            super()
            @responses = responses.dup
            @registry = nil
          end

          # Wire a registry so tool-call responses can dispatch tools.
          #
          # @param registry [Registry]
          # @return [self]
          def with_registry(registry)
            @registry = registry
            self
          end

          # @see Provider#run_turn
          def run_turn(_messages, _tools)
            raise 'MockProvider: no more responses' if @responses.empty?

            resp = @responses.shift
            case resp.type
            when :done then done_turn(resp)
            when :tool_call then tool_call_turn(resp)
            else raise "MockProvider: unknown response type #{resp.type.inspect}"
            end
          end

          private

          # Final assistant text message; ends the loop.
          def done_turn(resp)
            text_block = { 'type' => 'text', 'text' => resp.text }
            [[{ 'role' => 'assistant', 'content' => [text_block] }], true]
          end

          # Assistant tool_use message followed by the dispatched tool_result.
          # Handler errors are turned into an "error: ..." result string.
          def tool_call_turn(resp)
            reg = @registry || raise('MockProvider: tool_call response requires a registry (use #with_registry)')
            call_id = resp.call_id || "mock-call-#{SecureRandom.hex(4)}"
            tool_block = { 'type' => 'tool_use', 'id' => call_id, 'name' => resp.tool_name, 'input' => resp.input }
            result = begin
              reg.dispatch(resp.tool_name, resp.input)
            rescue => e # rubocop:disable Style/RescueStandardError
              "error: #{e.message}"
            end
            result_block = { 'type' => 'tool_result', 'tool_use_id' => call_id, 'content' => result.to_s }
            [
              [
                { 'role' => 'assistant', 'content' => [tool_block] },
                { 'role' => 'user', 'content' => [result_block] }
              ],
              false
            ]
          end
        end

        # A {Registry} subclass that records every {dispatch} call.
        class FakeToolRegistry < Registry
          # @return [Array] All recorded dispatch calls.
          def calls
            @calls ||= []
          end

          # Only successful dispatches are recorded; a raising handler propagates
          # before anything is appended.
          #
          # @see Registry#dispatch
          def dispatch(name, input)
            result = super
            calls << DispatchCall.new(name:, input:, result:)
            result
          end
        end

        # A {AgenticSession} whose {run_tool_loop} is a no-op for isolation tests.
        #
        # Useful when you want to test code that *holds* a session reference
        # without actually driving the LLM loop.
        class MockAgenticSession < AgenticSession
          # @return [String, nil] The prompt passed to the last {run_tool_loop} call.
          attr_reader :captured_prompt

          # Override: records the prompt but does not call the LLM.
          #
          # Accepts the parent's +(provider, registry, prompt)+ signature. The
          # older four-argument form +(provider, registry, system, prompt)+ is
          # still tolerated; in both cases the last argument is the prompt.
          def run_tool_loop(_provider, _registry, *prompt_args)
            @captured_prompt = prompt_args.last
          end

          # Expose the results array for pre-seeding in tests.
          #
          # @return [Array]
          def mutable_results
            @results
          end

          # Former name of {#mutable_results}, kept for existing callers. It used
          # to read an instance variable the session never sets.
          alias mutable_issues mutable_results
        end

        # A {Provider} decorator that raises after a given number of turns.
        #
        # Useful for testing retry / checkpoint-restore behaviour.
        class CrashAfterTurns < Provider
          # @param turns [Integer] Number of successful turns before crashing.
          # @param delegate [Provider, nil] Underlying provider. If nil a
          #   {MockProvider} with a single :done response is used internally.
          def initialize(turns, delegate = nil)
            super()
            @turns = turns
            @delegate = delegate || MockProvider.new(MockResponse.done('ok'))
            @count = 0
          end

          # @see Provider#run_turn
          def run_turn(messages, tools)
            @count += 1
            raise "CrashAfterTurns: crashed after #{@turns} turns" if @count > @turns

            @delegate.run_turn(messages, tools)
          end
        end
      end
    end
  end
end
# Activity context double: records heartbeats and logged warnings in memory
# and serves pre-seeded heartbeat details through its info object.
class FakeContext < Temporalio::Activity::Context
  attr_reader :heartbeats, :warnings, :info, :logger

  def initialize(heartbeat_details: [])
    @warnings = []
    @heartbeats = []
    @logger = FakeLogger.new(@warnings)
    @info = FakeInfo.new(heartbeat_details)
  end

  # Each call is stored as the array of details it was given.
  def heartbeat(*details)
    heartbeats.push(details)
  end
end

# Stub logger that appends every warning to the array it was built with.
class FakeLogger
  def initialize(warnings)
    @sink = warnings
  end

  def warn(msg)
    @sink.push(msg)
  end
end

# Minimal Info that returns pre-seeded heartbeat details directly; the
# +hints+ keyword is accepted and ignored.
class FakeInfo
  def initialize(details)
    @seeded = details
  end

  def heartbeat_details(hints: nil)
    @seeded
  end
end

# Minimal executor whose #activity_context returns the given FakeContext.
class FakeExecutor
  attr_reader :activity_context

  def initialize(context)
    @activity_context = context
  end
end

# Run a block inside a fake activity context.
+    #
+    # Installs a FakeExecutor wrapping +ctx+ under
+    # Thread.current[:temporal_activity_executor], yields +ctx+, and restores
+    # the previous value afterwards (also when the block raises).
+    def in_activity(ctx = FakeContext.new)
+      original = Thread.current[:temporal_activity_executor]
+      Thread.current[:temporal_activity_executor] = FakeExecutor.new(ctx)
+      yield ctx
+    ensure
+      Thread.current[:temporal_activity_executor] = original
+    end
+
+    # ── AgenticSession ────────────────────────────────────────────────────────
+
+    # A new session starts its history with the prompt as a user message.
+    def test_fresh_start_seeds_user_message
+      registry = Registry.new
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('done'))
+      captured = nil
+
+      in_activity do
+        session = AgenticSession.new
+        session.run_tool_loop(provider, registry, 'hello')
+        captured = session.messages
+      end
+
+      assert_equal 'user', captured[0]['role']
+      assert_equal 'hello', captured[0]['content']
+      assert_equal 'assistant', captured[1]['role']
+    end
+
+    # When a checkpoint already holds messages, the new prompt is not used.
+    def test_existing_messages_skip_prompt
+      checkpoint = {
+        'messages' => [{ 'role' => 'user', 'content' => 'original' }],
+        'issues' => []
+      }
+      ctx = FakeContext.new(heartbeat_details: [checkpoint])
+      registry = Registry.new
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('ok'))
+      captured = nil
+
+      in_activity(ctx) do
+        AgenticSession.run_with_session do |session|
+          session.run_tool_loop(provider, registry, 'ignored')
+          captured = session.messages
+        end
+      end
+
+      assert_equal 'original', captured[0]['content']
+      assert_equal 'assistant', captured[1]['role']
+    end
+
+    # Both scripted tool calls reach the registered handler, in order.
+    def test_tool_call_dispatched
+      collected = []
+      fake_reg = Testing::FakeToolRegistry.new
+      fake_reg.register(name: 'collect', description: 'd', input_schema: { 'type' => 'object' }) do |input|
+        collected << input['v']
+        'ok'
+      end
+      provider = Testing::MockProvider.new(
+        Testing::MockResponse.tool_call('collect', { 'v' => 'first' }),
+        Testing::MockResponse.tool_call('collect', { 'v' => 'second' }),
+        Testing::MockResponse.done('done')
+      ).with_registry(fake_reg)
+
+      captured = nil
+      in_activity do
+        session = AgenticSession.new
+        session.run_tool_loop(provider, fake_reg, 'go')
+        captured = session.messages
+      end
+
+      assert_equal %w[first second], collected
+      # user + (assistant + user)*2 + final_assistant = at least 5
+      assert_operator captured.size, :>=, 5
+    end
+
+    # A single-turn loop heartbeats exactly one checkpoint payload.
+    def test_checkpoint_called_each_turn
+      registry = Registry.new
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('x'))
+      ctx = FakeContext.new
+
+      in_activity(ctx) do
+        session = AgenticSession.new
+        session.run_tool_loop(provider, registry, 'prompt')
+      end
+
+      # checkpoint is called once before the first (and only) turn
+      assert_equal 1, ctx.heartbeats.size
+      # FakeContext#heartbeat stores the splatted args, so the payload is the
+      # first element of the first recorded heartbeat.
+      detail = ctx.heartbeats.first.first
+      assert_kind_of Hash, detail
+      assert_includes detail, 'messages'
+      assert_includes detail, 'issues'
+    end
+
+    def test_add_issue
+      session = AgenticSession.new
+      session.add_issue({ 'type' => 'error', 'msg' => 'oops' })
+      assert_equal 1, session.issues.size
+      assert_equal 'error', session.issues.first['type']
+    end
+
+    # ── run_with_session ──────────────────────────────────────────────────────
+
+    # Without heartbeat details the session starts from the given prompt.
+    def test_run_with_session_fresh_start
+      registry = Registry.new
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('done'))
+      captured = nil
+
+      in_activity do
+        AgenticSession.run_with_session do |session|
+          session.run_tool_loop(provider, registry, 'hello')
+          captured = session.messages
+        end
+      end
+
+      refute_nil captured
+      assert_equal 'hello', captured[0]['content']
+    end
+
+    # Messages and issues from a Hash checkpoint are both restored.
+    def test_run_with_session_restores_checkpoint
+      checkpoint = {
+        'messages' => [{ 'role' => 'user', 'content' => 'restored' }],
+        'issues' => [{ 'code' => 42 }]
+      }
+      ctx = FakeContext.new(heartbeat_details: [checkpoint])
+      registry = Registry.new
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('done'))
+      captured_messages = nil
+      captured_issues = nil
+
+      in_activity(ctx) do
+        AgenticSession.run_with_session do |session|
+          session.run_tool_loop(provider, registry, 'ignored')
+          captured_messages = session.messages
+          captured_issues = session.issues
+        end
+      end
+
+      assert_equal 'restored', captured_messages[0]['content']
+      assert_equal [{ 'code' => 42 }], captured_issues
+    end
+
+    # ── Checkpoint round-trip test (T6) ──────────────────────────────────────
+
+    def test_checkpoint_round_trip_preserves_tool_calls
+      # Simulate a real Temporal round-trip: the heartbeat payload is JSON-serialized
+      # and deserialized by the Temporal server between activity attempts.
+      tool_calls = [
+        {
+          'id' => 'call_abc',
+          'type' => 'function',
+          'function' => { 'name' => 'my_tool', 'arguments' => '{"x":1}' }
+        }
+      ]
+      assistant_msg = { 'role' => 'assistant', 'tool_calls' => tool_calls }
+      issue = { 'type' => 'smell', 'file' => 'foo.rb' }
+
+      ctx = FakeContext.new
+      in_activity(ctx) do
+        session = AgenticSession.new
+        # State is injected directly so the test exercises checkpoint alone,
+        # without running a tool loop.
+        session.instance_variable_set(:@messages, [assistant_msg])
+        session.instance_variable_set(:@issues, [issue])
+        session.checkpoint
+      end
+
+      assert_equal 1, ctx.heartbeats.size
+      raw_payload = ctx.heartbeats.first.first
+
+      # Simulate JSON round-trip as Temporal would apply between activity attempts.
+      json = JSON.generate(raw_payload)
+      restored_payload = JSON.parse(json)
+
+      assert_equal 'assistant', restored_payload['messages'][0]['role']
+      tool_calls_restored = restored_payload['messages'][0]['tool_calls']
+      assert_instance_of Array, tool_calls_restored
+      assert_equal 1, tool_calls_restored.size
+      assert_equal 'call_abc', tool_calls_restored[0]['id']
+      assert_equal 'my_tool', tool_calls_restored[0]['function']['name']
+      assert_equal 'smell', restored_payload['issues'][0]['type']
+      assert_equal 'foo.rb', restored_payload['issues'][0]['file']
+    end
+
+    def test_run_with_session_ignores_non_hash_checkpoint
+      # A non-Hash heartbeat detail (here an Integer) → fresh start
+      ctx = FakeContext.new(heartbeat_details: [42])
+      registry = Registry.new
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('done'))
+      captured = nil
+
+      in_activity(ctx) do
+        AgenticSession.run_with_session do |session|
+          session.run_tool_loop(provider, registry, 'fresh')
+          captured = session.messages
+        end
+      end
+
+      assert_equal 'fresh', captured[0]['content']
+    end
+  end
+end
diff --git a/temporalio/test/contrib/tool_registry_test.rb b/temporalio/test/contrib/tool_registry_test.rb
new file mode 100644
index 00000000..dc036da1
--- /dev/null
+++ b/temporalio/test/contrib/tool_registry_test.rb
@@ -0,0 +1,267 @@
+# frozen_string_literal: true
+
+require 'minitest/autorun'
+require 'temporalio/contrib/tool_registry'
+require 'temporalio/contrib/tool_registry/testing'
+
+module Contrib
+  # Tests for Registry, module-level run_tool_loop, and Provider base class.
+  class ToolRegistryTest < Minitest::Test
+    Registry = Temporalio::Contrib::ToolRegistry::Registry
+    ToolDef = Temporalio::Contrib::ToolRegistry::ToolDef
+    Provider = Temporalio::Contrib::ToolRegistry::Provider
+    Testing = Temporalio::Contrib::ToolRegistry::Testing
+
+    # ── Registry ──────────────────────────────────────────────────────────────
+
+    def test_register_and_dispatch
+      r = Registry.new
+      r.register(name: 'echo', description: 'Echo input', input_schema: { 'type' => 'object' }) do |input|
+        input['value']
+      end
+      assert_equal 'hello', r.dispatch('echo', { 'value' => 'hello' })
+    end
+
+    def test_dispatch_unknown_tool_raises
+      r = Registry.new
+      assert_raises(KeyError) { r.dispatch('missing', {}) }
+    end
+
+    def test_defs_returns_frozen_copy
+      r = Registry.new
+      r.register(name: 'a', description: 'd', input_schema: {}) { 'ok' }
+      defs = r.defs
+      assert_equal 1, defs.size
+      assert_instance_of ToolDef, defs.first
+      assert_equal 'a', defs.first.name
+      assert_predicate defs, :frozen?
+    end
+
+    # A defs snapshot taken earlier must not see later registrations.
+    def test_defs_copy_is_independent
+      r = Registry.new
+      r.register(name: 'a', description: 'd', input_schema: {}) { 'ok' }
+      defs1 = r.defs
+      r.register(name: 'b', description: 'd', input_schema: {}) { 'ok' }
+      assert_equal 1, defs1.size
+      assert_equal 2, r.defs.size
+    end
+
+    def test_register_requires_block
+      r = Registry.new
+      assert_raises(ArgumentError) { r.register(name: 'x', description: 'd', input_schema: {}) }
+    end
+
+    def test_to_anthropic
+      r = Registry.new
+      r.register(name: 'ping', description: 'Ping', input_schema: { 'type' => 'object' }) { 'pong' }
+      result = r.to_anthropic
+      assert_equal 1, result.size
+      assert_equal 'ping', result.first['name']
+      assert_equal 'Ping', result.first['description']
+      assert_equal({ 'type' => 'object' }, result.first['input_schema'])
+    end
+
+    def test_to_openai
+      r = Registry.new
+      r.register(name: 'ping', description: 'Ping', input_schema: { 'type' => 'object' }) { 'pong' }
+      result = r.to_openai
+      assert_equal 1, result.size
+      item = result.first
+      assert_equal 'function', item['type']
+      assert_equal 'ping', item.dig('function', 'name')
+      assert_equal 'Ping', item.dig('function', 'description')
+      assert_equal({ 'type' => 'object' }, item.dig('function', 'parameters'))
+    end
+
+    def test_to_openai_empty
+      r = Registry.new
+      assert_equal [], r.to_openai
+    end
+
+    # ── ToolDef ───────────────────────────────────────────────────────────────
+
+    def test_tool_def_is_immutable
+      defn = ToolDef.new(name: 'x', description: 'd', input_schema: {})
+      assert_raises(NoMethodError) { defn.name = 'y' }
+    end
+
+    # ── Provider abstract base ────────────────────────────────────────────────
+
+    def test_provider_run_turn_raises_not_implemented
+      p = Provider.new
+      assert_raises(NotImplementedError) { p.run_turn([], []) }
+    end
+
+    # ── Module-level run_tool_loop ─────────────────────────────────────────────
+
+    def test_module_run_tool_loop_fresh
+      registry = Registry.new
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('result'))
+
+      messages = Temporalio::Contrib::ToolRegistry.run_tool_loop(provider, registry, 'user prompt')
+
+      assert_equal 2, messages.size
+      assert_equal 'user', messages[0]['role']
+      assert_equal 'user prompt', messages[0]['content']
+      assert_equal 'assistant', messages[1]['role']
+    end
+
+    def test_module_run_tool_loop_with_tool_call
+      collected = []
+      registry = Testing::FakeToolRegistry.new
+      registry.register(name: 'collect', description: 'd', input_schema: { 'type' => 'object' }) do |input|
+        collected << input['v']
+        'collected'
+      end
+      provider = Testing::MockProvider.new(
+        Testing::MockResponse.tool_call('collect', { 'v' => 'item' }),
+        Testing::MockResponse.done('done')
+      ).with_registry(registry)
+
+      messages = Temporalio::Contrib::ToolRegistry.run_tool_loop(provider, registry, 'go')
+
+      assert_equal ['item'], collected
+      assert_operator messages.size, :>, 2
+    end
+
+    # ── AnthropicProvider is_error / handler error tests ─────────────────────────
+
+    # Builds an AnthropicProvider around a stub client whose messages.create
+    # always returns one tool_use block for +tool_name+ (stop_reason
+    # 'tool_use'), so no network call is made.
+    #
+    # @param registry [Registry] Registry holding the tool under test.
+    # @param tool_name [String] Name placed in the fake tool_use block.
+    def fake_anthropic_provider(registry, tool_name)
+      # Required lazily so the rest of this file loads without the provider.
+      require 'temporalio/contrib/tool_registry/providers/anthropic'
+
+      fake_resp = Struct.new(:content, :stop_reason).new(
+        [{ 'type' => 'tool_use', 'id' => 'c1', 'name' => tool_name, 'input' => {} }],
+        'tool_use'
+      )
+      mock_messages = Object.new
+      mock_messages.define_singleton_method(:create) { |**_| fake_resp }
+      mock_client = Object.new
+      mock_client.define_singleton_method(:messages) { mock_messages }
+
+      Temporalio::Contrib::ToolRegistry::Providers::AnthropicProvider.new(
+        registry, 'sys', client: mock_client
+      )
+    end
+
+    def test_anthropic_handler_error_sets_is_error
+      registry = Registry.new
+      registry.register(name: 'boom', description: 'd', input_schema: {}) do |_|
+        raise 'intentional failure'
+      end
+      provider = fake_anthropic_provider(registry, 'boom')
+
+      new_msgs, done = provider.run_turn([], registry.defs)
+
+      refute done
+      assert_equal 2, new_msgs.size
+
+      tool_result_msg = new_msgs[1]
+      assert_equal 'user', tool_result_msg['role']
+      tool_results = tool_result_msg['content']
+      assert_equal 1, tool_results.size
+      assert_equal 'tool_result', tool_results[0]['type']
+      assert_equal true, tool_results[0]['is_error']
+      assert_includes tool_results[0]['content'], 'intentional failure'
+    end
+
+    def test_anthropic_handler_success_no_is_error
+      registry = Registry.new
+      registry.register(name: 'ok_tool', description: 'd', input_schema: {}) { |_| 'result' }
+      provider = fake_anthropic_provider(registry, 'ok_tool')
+
+      new_msgs, = provider.run_turn([], registry.defs)
+
+      tool_results = new_msgs[1]['content']
+      assert_equal 1, tool_results.size
+      refute tool_results[0].key?('is_error'), 'is_error should not be present on success'
+    end
+
+    # ── Integration tests (skipped unless RUN_INTEGRATION_TESTS is set) ─────────
+
+    # Returns [registry, collected]: a registry with one 'record' tool and the
+    # array that tool appends each received 'value' to.
+    def make_record_registry
+      collected = []
+      registry = Registry.new
+      registry.register(
+        name: 'record',
+        description: 'Record a value',
+        input_schema: {
+          'type' => 'object',
+          'properties' => { 'value' => { 'type' => 'string' } },
+          'required' => ['value']
+        }
+      ) do |input|
+        collected << input['value']
+        'recorded'
+      end
+      [registry, collected]
+    end
+
+    def test_integration_anthropic
+      skip 'RUN_INTEGRATION_TESTS not set' unless ENV['RUN_INTEGRATION_TESTS']
+      api_key = ENV['ANTHROPIC_API_KEY']
+      skip 'ANTHROPIC_API_KEY not set' unless api_key
+
+      require 'temporalio/contrib/tool_registry/providers/anthropic'
+      registry, collected = make_record_registry
+      provider = Temporalio::Contrib::ToolRegistry::Providers::AnthropicProvider.new(
+        registry,
+        "You must call record() exactly once with value='hello'.",
+        api_key: api_key
+      )
+      Temporalio::Contrib::ToolRegistry.run_tool_loop(
+        provider, registry,
+        "Please call the record tool with value='hello'."
+      )
+      assert_includes collected, 'hello'
+    end
+
+    def test_integration_openai
+      skip 'RUN_INTEGRATION_TESTS not set' unless ENV['RUN_INTEGRATION_TESTS']
+      api_key = ENV['OPENAI_API_KEY']
+      skip 'OPENAI_API_KEY not set' unless api_key
+
+      require 'temporalio/contrib/tool_registry/providers/openai'
+      registry, collected = make_record_registry
+      provider = Temporalio::Contrib::ToolRegistry::Providers::OpenAIProvider.new(
+        registry,
+        "You must call record() exactly once with value='hello'.",
+        api_key: api_key
+      )
+      Temporalio::Contrib::ToolRegistry.run_tool_loop(
+        provider, registry,
+        "Please call the record tool with value='hello'."
+      )
+      assert_includes collected, 'hello'
+    end
+
+    # ── from_mcp_tools ────────────────────────────────────────────────────────
+
+    def test_from_mcp_tools
+      t1 = Struct.new(:name, :description, :input_schema).new(
+        'read_file', 'Read a file',
+        { 'type' => 'object', 'properties' => { 'path' => { 'type' => 'string' } } }
+      )
+      t2 = Struct.new(:name, :description, :input_schema).new('list_dir', nil, nil)
+
+      reg = Registry.from_mcp_tools([t1, t2])
+
+      assert_equal 2, reg.defs.size
+      assert_equal 'read_file', reg.defs[0].name
+      assert_equal 'Read a file', reg.defs[0].description
+      assert_equal 'list_dir', reg.defs[1].name
+      assert_equal 'object', reg.defs[1].input_schema['type'] # nil schema → empty object schema
+      assert_equal '', reg.dispatch('read_file', { 'path' => '/etc/hosts' })
+    end
+  end
+end
diff --git a/temporalio/test/contrib/tool_registry_testing_test.rb b/temporalio/test/contrib/tool_registry_testing_test.rb
new file mode 100644
index 00000000..4962f479
--- /dev/null
+++ b/temporalio/test/contrib/tool_registry_testing_test.rb
@@ -0,0 +1,177 @@
+# frozen_string_literal: true
+
+require 'minitest/autorun'
+require 'temporalio/contrib/tool_registry'
+require 'temporalio/contrib/tool_registry/testing'
+
+module Contrib
+  # Tests for the testing utilities: MockResponse, MockProvider, FakeToolRegistry,
+  # MockAgenticSession, and CrashAfterTurns.
+  class ToolRegistryTestingTest < Minitest::Test
+    Registry = Temporalio::Contrib::ToolRegistry::Registry
+    Testing = Temporalio::Contrib::ToolRegistry::Testing
+
+    # ── MockResponse ──────────────────────────────────────────────────────────
+
+    def test_mock_response_done
+      r = Testing::MockResponse.done('all done')
+      assert_equal :done, r.type
+      assert_equal 'all done', r.text
+      assert_nil r.tool_name
+    end
+
+    def test_mock_response_tool_call
+      r = Testing::MockResponse.tool_call('search', { 'q' => 'ruby' })
+      assert_equal :tool_call, r.type
+      assert_equal 'search', r.tool_name
+      assert_equal({ 'q' => 'ruby' }, r.input)
+      assert_nil r.call_id
+    end
+
+    def test_mock_response_tool_call_with_id
+      r = Testing::MockResponse.tool_call('search', { 'q' => 'ruby' }, 'call-123')
+      assert_equal 'call-123', r.call_id
+    end
+
+    # ── MockProvider ──────────────────────────────────────────────────────────
+
+    def test_mock_provider_done_response
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('finished'))
+      msgs, done = provider.run_turn([], [])
+      assert done
+      assert_equal 1, msgs.size
+      assert_equal 'assistant', msgs.first['role']
+    end
+
+    def test_mock_provider_tool_call_without_registry_raises
+      provider = Testing::MockProvider.new(Testing::MockResponse.tool_call('x', {}))
+      assert_raises(RuntimeError) { provider.run_turn([], []) }
+    end
+
+    # First turn dispatches the scripted tool call; second turn finishes.
+    def test_mock_provider_tool_call_dispatches
+      registry = Testing::FakeToolRegistry.new
+      registry.register(name: 'greet', description: 'd', input_schema: {}) { |i| "Hello #{i['name']}" }
+
+      provider = Testing::MockProvider.new(
+        Testing::MockResponse.tool_call('greet', { 'name' => 'World' }),
+        Testing::MockResponse.done('bye')
+      ).with_registry(registry)
+
+      msgs1, done1 = provider.run_turn([], [])
+      refute done1
+      assert_equal 2, msgs1.size # assistant + tool_result user message
+      assert_equal 'assistant', msgs1[0]['role']
+      assert_equal 'user', msgs1[1]['role']
+
+      # The final :done response yields a single assistant message, matching
+      # test_mock_provider_done_response.
+      msgs2, done2 = provider.run_turn([], [])
+      assert done2
+      assert_equal 1, msgs2.size
+      assert_equal 'assistant', msgs2.first['role']
+    end
+
+    def test_mock_provider_exhausted_raises
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('only one'))
+      provider.run_turn([], [])
+      assert_raises(RuntimeError) { provider.run_turn([], []) }
+    end
+
+    def test_mock_provider_uses_explicit_call_id
+      registry = Testing::FakeToolRegistry.new
+      registry.register(name: 't', description: 'd', input_schema: {}) { 'result' }
+
+      provider = Testing::MockProvider.new(
+        Testing::MockResponse.tool_call('t', {}, 'fixed-id')
+      ).with_registry(registry)
+
+      msgs, = provider.run_turn([], [])
+      # First message is assistant with tool_use content block containing our id
+      tool_block = msgs[0]['content'].first
+      assert_equal 'fixed-id', tool_block['id']
+      # Second message is tool_result referencing the same id
+      result_block = msgs[1]['content'].first
+      assert_equal 'fixed-id', result_block['tool_use_id']
+    end
+
+    # ── FakeToolRegistry ──────────────────────────────────────────────────────
+
+    # Every dispatch is recorded with its tool name and result.
+    def test_fake_tool_registry_records_calls
+      r = Testing::FakeToolRegistry.new
+      r.register(name: 'add', description: 'add', input_schema: {}) { |i| i['a'] + i['b'] }
+
+      r.dispatch('add', { 'a' => 1, 'b' => 2 })
+      r.dispatch('add', { 'a' => 10, 'b' => 20 })
+
+      assert_equal 2, r.calls.size
+      assert_equal 'add', r.calls.first.name
+      assert_equal 3, r.calls.first.result
+      assert_equal 30, r.calls.last.result
+    end
+
+    def test_fake_tool_registry_inherits_dispatch
+      r = Testing::FakeToolRegistry.new
+      r.register(name: 'upper', description: 'u', input_schema: {}) { |i| i['s'].upcase }
+      assert_equal 'HELLO', r.dispatch('upper', { 's' => 'hello' })
+    end
+
+    def test_fake_tool_registry_unknown_tool_raises
+      r = Testing::FakeToolRegistry.new
+      assert_raises(KeyError) { r.dispatch('no_such', {}) }
+    end
+
+    # ── MockAgenticSession ────────────────────────────────────────────────────
+
+    def test_mock_agentic_session_captures_prompt
+      registry = Registry.new
+      provider = Testing::MockProvider.new(Testing::MockResponse.done('x'))
+      session = Testing::MockAgenticSession.new
+      session.run_tool_loop(provider, registry, 'sys', 'my prompt')
+      assert_equal 'my prompt', session.captured_prompt
+    end
+
+    def test_mock_agentic_session_run_tool_loop_is_noop
+      registry = Registry.new
+      provider = Testing::MockProvider.new # no responses — would crash if called
+      session = Testing::MockAgenticSession.new
+      session.run_tool_loop(provider, registry, 'sys', 'whatever')
+      assert_empty session.messages # not modified
+    end
+
+    def test_mock_agentic_session_mutable_issues
+      session = Testing::MockAgenticSession.new
+      session.mutable_issues << { 'type' => 'seed' }
+      assert_equal 1, session.issues.size
+    end
+
+    # ── CrashAfterTurns ───────────────────────────────────────────────────────
+
+    # With turns = 2 the third call is the first one that raises.
+    def test_crash_after_turns_crashes_on_nth_plus_one
+      provider = Testing::CrashAfterTurns.new(
+        2,
+        Testing::MockProvider.new(
+          Testing::MockResponse.done('t1'),
+          Testing::MockResponse.done('t2'),
+          Testing::MockResponse.done('t3')
+        )
+      )
+
+      provider.run_turn([], []) # turn 1 — ok
+      provider.run_turn([], []) # turn 2 — ok
+      assert_raises(RuntimeError) { provider.run_turn([], []) } # turn 3 — crash
+    end
+
+    def test_crash_after_turns_delegates
+      inner = Testing::MockProvider.new(Testing::MockResponse.done('hello'))
+      provider = Testing::CrashAfterTurns.new(5, inner)
+      msgs, done = provider.run_turn([], [])
+      assert done
+      assert_equal 'assistant', msgs.first['role']
+    end
+
+    # ── DispatchCall ──────────────────────────────────────────────────────────
+
+    def test_dispatch_call_fields
+      call = Testing::DispatchCall.new(name: 'fn', input: { 'x' => 1 }, result: 'out')
+      assert_equal 'fn', call.name
+      assert_equal({ 'x' => 1 }, call.input)
+      assert_equal 'out', call.result
+    end
+  end
+end