diff --git a/lib/mars/agent_step.rb b/lib/mars/agent_step.rb index a519997..36244e6 100644 --- a/lib/mars/agent_step.rb +++ b/lib/mars/agent_step.rb @@ -8,8 +8,8 @@ def agent(klass = nil) end end - def run(input) - self.class.agent.new.ask(input).content + def run(context) + self.class.agent.new.ask(context.current_input).content end end end diff --git a/lib/mars/aggregator.rb b/lib/mars/aggregator.rb index d21b3bd..843b28e 100644 --- a/lib/mars/aggregator.rb +++ b/lib/mars/aggregator.rb @@ -10,8 +10,15 @@ def initialize(name = "Aggregator", operation: nil, **kwargs) @operation = operation || ->(inputs) { inputs } end - def run(inputs) - operation.call(inputs) + def run(context) + context = ensure_context(context) + operation.call(context.current_input) + end + + private + + def ensure_context(input) + input.is_a?(ExecutionContext) ? input : ExecutionContext.new(input: input) end end end diff --git a/lib/mars/execution_context.rb b/lib/mars/execution_context.rb index 2e4b7be..8467acb 100644 --- a/lib/mars/execution_context.rb +++ b/lib/mars/execution_context.rb @@ -2,7 +2,8 @@ module MARS class ExecutionContext - attr_reader :current_input, :outputs, :global_state + attr_reader :outputs, :global_state + attr_accessor :current_input def initialize(input: nil, global_state: {}) @current_input = input @@ -19,8 +20,8 @@ def record(step_name, output) @current_input = output end - def fork(input: current_input) - self.class.new(input: input, global_state: global_state) + def fork(input: current_input, state: {}) + self.class.new(input: input, global_state: global_state.merge(state)) end def merge(child_contexts) diff --git a/lib/mars/gate.rb b/lib/mars/gate.rb index 5ce7502..195716b 100644 --- a/lib/mars/gate.rb +++ b/lib/mars/gate.rb @@ -16,21 +16,18 @@ def fallback(key, runnable) def fallbacks_map @fallbacks_map ||= {} end - - def halt_scope(scope = nil) - scope ? @halt_scope = scope : @halt_scope - end end - def initialize(name = "Gate", check: nil, fallbacks: nil, halt_scope: nil, **kwargs) + def initialize(name = "Gate", check: nil, fallbacks: nil, **kwargs) super(name: name, **kwargs) @check = check || self.class.check_block @fallbacks = fallbacks || self.class.fallbacks_map - @halt_scope = halt_scope || self.class.halt_scope || :local end - def run(input) + def run(context) + context = ensure_context(context) + input = context.current_input result = check.call(input) return input unless result @@ -38,7 +35,7 @@ def run(input) branch = fallbacks[result] raise ArgumentError, "No fallback registered for #{result.inspect}" unless branch - Halt.new(resolve_branch(branch).run(input), scope: @halt_scope) + resolve_branch(branch).run(context) end private @@ -48,5 +45,9 @@ def run(input) def resolve_branch(branch) branch.is_a?(Class) ? branch.new : branch end + + def ensure_context(input) + input.is_a?(ExecutionContext) ? input : ExecutionContext.new(input: input) + end end end diff --git a/lib/mars/halt.rb b/lib/mars/halt.rb deleted file mode 100644 index 043e80e..0000000 --- a/lib/mars/halt.rb +++ /dev/null @@ -1,15 +0,0 @@ -# frozen_string_literal: true - -module MARS - class Halt - attr_reader :result, :scope - - def initialize(result, scope: :local) - @result = result - @scope = scope - end - - def local? = scope == :local - def global? = scope == :global - end -end diff --git a/lib/mars/runnable.rb b/lib/mars/runnable.rb index ea071e5..8996e69 100644 --- a/lib/mars/runnable.rb +++ b/lib/mars/runnable.rb @@ -26,7 +26,7 @@ def initialize(name: self.class.step_name, state: {}, formatter: nil) @formatter = formatter || self.class.formatter&.new || Formatter.new end - def run(input) + def run(context) raise NotImplementedError end end diff --git a/lib/mars/workflows/parallel.rb b/lib/mars/workflows/parallel.rb index 4e147d2..1038b64 100644 --- a/lib/mars/workflows/parallel.rb +++ b/lib/mars/workflows/parallel.rb @@ -10,8 +10,8 @@ def initialize(name, steps:, aggregator: nil, **kwargs) @aggregator = aggregator || Aggregator.new("#{name} Aggregator") end - def run(input) - context = ensure_context(input) + def run(context) + context = ensure_context(context) errors = [] child_contexts = [] results = execute_steps(context, errors, child_contexts) @@ -19,24 +19,18 @@ def run(input) raise AggregateError, errors if errors.any? context.merge(child_contexts) - aggregate_results(results) + context.current_input = results + aggregator.run(context) end private attr_reader :steps, :aggregator - def aggregate_results(results) - has_global_halt = results.any? { |r| r.is_a?(Halt) && r.global? } - unwrapped = results.map { |r| r.is_a?(Halt) ? r.result : r } - result = aggregator.run(unwrapped) - has_global_halt ? Halt.new(result, scope: :global) : result - end - def execute_steps(context, errors, child_contexts) Async do |workflow| tasks = steps.map do |step| - child_ctx = context.fork + child_ctx = context.fork(state: step.state) child_contexts << child_ctx workflow.async do @@ -54,17 +48,14 @@ def workflow_step(step, child_ctx) step.run_before_hooks(child_ctx) step_input = step.formatter.format_input(child_ctx) - result = step.run(step_input) + child_ctx.current_input = step_input + + result = step.run(child_ctx) - if result.is_a?(Halt) - step.run_after_hooks(child_ctx, result) - result - else - formatted = step.formatter.format_output(result) - child_ctx.record(step.name, formatted) - step.run_after_hooks(child_ctx, formatted) - formatted - end + formatted = step.formatter.format_output(result) + child_ctx.record(step.name, formatted) + step.run_after_hooks(child_ctx, formatted) + formatted end def ensure_context(input) diff --git a/lib/mars/workflows/sequential.rb b/lib/mars/workflows/sequential.rb index 008dd18..f8539b9 100644 --- a/lib/mars/workflows/sequential.rb +++ b/lib/mars/workflows/sequential.rb @@ -9,26 +9,16 @@ def initialize(name, steps:, **kwargs) @steps = steps end - def run(input) - context = ensure_context(input) + def run(context) + context = ensure_context(context) @steps.each do |step| step.run_before_hooks(context) step_input = step.formatter.format_input(context) - result = step.run(step_input) - - if result.is_a?(Halt) - if result.global? - step.run_after_hooks(context, result) - return result - end - - formatted = step.formatter.format_output(result.result) - context.record(step.name, formatted) - step.run_after_hooks(context, formatted) - break - end + context.current_input = step_input + + result = step.run(context) formatted = step.formatter.format_output(result) context.record(step.name, formatted) diff --git a/spec/mars/agent_step_spec.rb b/spec/mars/agent_step_spec.rb index 18244a5..af197ce 100644 --- a/spec/mars/agent_step_spec.rb +++ b/spec/mars/agent_step_spec.rb @@ -39,7 +39,7 @@ it "creates a new agent instance and calls ask" do step = step_class.new - result = step.run("hello") + result = step.run(MARS::ExecutionContext.new(input: "hello")) expect(result).to eq("agent response") expect(mock_agent_class).to have_received(:new) diff --git a/spec/mars/execution_context_spec.rb b/spec/mars/execution_context_spec.rb index d30d47d..d2ba095 100644 --- a/spec/mars/execution_context_spec.rb +++ b/spec/mars/execution_context_spec.rb @@ -79,13 +79,13 @@ expect(child.current_input).to eq("custom") end - it "shares global_state with the parent" do + it "does not share global_state with the parent" do context = described_class.new(input: "query", global_state: { shared: true }) child = context.fork child.global_state[:added_by_child] = true - expect(context.global_state[:added_by_child]).to be(true) + expect(context.global_state[:added_by_child]).to be_nil end it "has independent outputs from the parent" do diff --git a/spec/mars/gate_spec.rb b/spec/mars/gate_spec.rb index cea8849..4d2f472 100644 --- a/spec/mars/gate_spec.rb +++ b/spec/mars/gate_spec.rb @@ -3,16 +3,16 @@ RSpec.describe MARS::Gate do let(:fallback_step) do Class.new(MARS::Runnable) do - def run(input) - "fallback: #{input}" + def run(context) + "fallback: #{context.current_input}" end end.new end let(:error_step) do Class.new(MARS::Runnable) do - def run(input) - "error: #{input}" + def run(context) + "error: #{context.current_input}" end end.new end @@ -29,7 +29,7 @@ def run(input) expect(gate.run("hello")).to eq("hello") end - it "halts with fallback result when check returns a key" do + it "returns the fallback branch result when check returns a registered key" do gate = described_class.new( "FailGate", check: ->(_input) { :fail }, @@ -37,8 +37,7 @@ def run(input) ) result = gate.run("hello") - expect(result).to be_a(MARS::Halt) - expect(result.result).to eq("fallback: hello") + expect(result).to eq("fallback: hello") end it "raises when check returns an unregistered key" do @@ -60,16 +59,15 @@ def run(input) input = { error_type: :auth } result = gate.run(input) - expect(result).to be_a(MARS::Halt) - expect(result.result).to eq("error: #{input}") + expect(result).to eq("error: #{input}") end end context "with class-level DSL" do let(:fallback_cls) do Class.new(MARS::Runnable) do - def run(input) - "handled: #{input}" + def run(context) + "handled: #{context.current_input}" end end end @@ -83,45 +81,7 @@ def run(input) gate = gate_class.new("DSLGate") expect(gate.run("hi")).to eq("hi") - expect(gate.run("longstring").result).to eq("handled: longstring") - end - - it "supports halt_scope DSL" do - cls = fallback_cls - gate_class = Class.new(described_class) do - check { |_input| :fail } - fallback :fail, cls - halt_scope :global - end - - result = gate_class.new("GlobalGate").run("test") - expect(result).to be_a(MARS::Halt) - expect(result).to be_global - end - end - - context "with halt scope" do - it "defaults to local scope" do - gate = described_class.new( - "LocalGate", - check: ->(_input) { :fail }, - fallbacks: { fail: fallback_step } - ) - - result = gate.run("hello") - expect(result).to be_local - end - - it "respects constructor halt_scope" do - gate = described_class.new( - "GlobalGate", - check: ->(_input) { :fail }, - fallbacks: { fail: fallback_step }, - halt_scope: :global - ) - - result = gate.run("hello") - expect(result).to be_global + expect(gate.run("longstring")).to eq("handled: longstring") end end end diff --git a/spec/mars/halt_spec.rb b/spec/mars/halt_spec.rb deleted file mode 100644 index da3b9c5..0000000 --- a/spec/mars/halt_spec.rb +++ /dev/null @@ -1,26 +0,0 @@ -# frozen_string_literal: true - -RSpec.describe MARS::Halt do - describe "#scope" do - it "defaults to :local" do - halt = described_class.new("result") - expect(halt.scope).to eq(:local) - expect(halt).to be_local - expect(halt).not_to be_global - end - - it "can be set to :global" do - halt = described_class.new("result", scope: :global) - expect(halt.scope).to eq(:global) - expect(halt).to be_global - expect(halt).not_to be_local - end - end - - describe "#result" do - it "stores the result" do - halt = described_class.new("hello") - expect(halt.result).to eq("hello") - end - end -end diff --git a/spec/mars/workflows/parallel_spec.rb b/spec/mars/workflows/parallel_spec.rb index 97407c8..7b4afb8 100644 --- a/spec/mars/workflows/parallel_spec.rb +++ b/spec/mars/workflows/parallel_spec.rb @@ -10,9 +10,9 @@ def initialize(value, **kwargs) @value = value end - def run(input) + def run(context) sleep 0.1 - input + @value + context.current_input + @value end end end @@ -24,8 +24,8 @@ def initialize(multiplier, **kwargs) @multiplier = multiplier end - def run(input) - input * @multiplier + def run(context) + context.current_input * @multiplier end end end @@ -37,7 +37,7 @@ def initialize(message, **kwargs) @message = message end - def run(_input) + def run(_context) raise StandardError, @message end end @@ -80,11 +80,11 @@ def run(_input) it "records outputs in context per step" do step1 = Class.new(MARS::Runnable) do - def run(input) = "from_step1:#{input}" + def run(input) = "from_step1:#{input.current_input}" end.new(name: "step1") step2 = Class.new(MARS::Runnable) do - def run(input) = "from_step2:#{input}" + def run(input) = "from_step2:#{input.current_input}" end.new(name: "step2") context = MARS::ExecutionContext.new(input: "hello") @@ -97,11 +97,11 @@ def run(input) = "from_step2:#{input}" it "forks context so parallel steps get independent current_input" do step1 = Class.new(MARS::Runnable) do - def run(input) = "#{input}_modified" + def run(context) = "#{context.current_input}_modified" end.new(name: "step1") step2 = Class.new(MARS::Runnable) do - def run(input) = "#{input}_also_modified" + def run(context) = "#{context.current_input}_also_modified" end.new(name: "step2") context = MARS::ExecutionContext.new(input: "original") @@ -115,7 +115,7 @@ def run(input) = "#{input}_also_modified" it "shares global_state across forked contexts" do step1 = Class.new(MARS::Runnable) do - def run(_input) + def run(_context) "done" end end.new(name: "step1") @@ -135,7 +135,7 @@ def format_output(output) end step = Class.new(MARS::Runnable) do - def run(input) = "result:#{input}" + def run(context) = "result:#{context.current_input}" end.new(name: "step", formatter: uppercase_formatter.new) workflow = described_class.new("fmt_workflow", steps: [step]) @@ -150,7 +150,7 @@ def run(input) = "result:#{input}" before_run { |_ctx, step| hook_log << "before:#{step.name}" } after_run { |_ctx, _result, step| hook_log << "after:#{step.name}" } - def run(input) = input + def run(context) = context.current_input end step = step_class.new(name: "hooked") @@ -160,50 +160,6 @@ def run(input) = input expect(hook_log).to eq(["before:hooked", "after:hooked"]) end - it "unwraps local halts and returns plain result" do - gate = MARS::Gate.new( - "local_branch", - check: ->(_input) { :branch }, - fallbacks: { - branch: Class.new(MARS::Runnable) do - def run(input) - "branched:#{input}" - end - end.new(name: "branch_step") - } - ) - add_five = add_step_class.new(5, name: "add_five") - - workflow = described_class.new("halt_workflow", steps: [gate, add_five]) - - result = workflow.run(10) - expect(result).not_to be_a(MARS::Halt) - expect(result).to eq(["branched:10", 15]) - end - - it "propagates global halt to parent workflow" do - gate = MARS::Gate.new( - "global_branch", - check: ->(_input) { :branch }, - fallbacks: { - branch: Class.new(MARS::Runnable) do - def run(input) - "branched:#{input}" - end - end.new(name: "branch_step") - }, - halt_scope: :global - ) - add_five = add_step_class.new(5, name: "add_five") - - workflow = described_class.new("halt_workflow", steps: [gate, add_five]) - - result = workflow.run(10) - expect(result).to be_a(MARS::Halt) - expect(result).to be_global - expect(result.result).to eq(["branched:10", 15]) - end - it "propagates errors from steps" do add_step = add_step_class.new(5, name: "add") error_step = error_step_class.new("Step failed", name: "error_step_one") diff --git a/spec/mars/workflows/sequential_spec.rb b/spec/mars/workflows/sequential_spec.rb index b2ea2b5..b61d7f5 100644 --- a/spec/mars/workflows/sequential_spec.rb +++ b/spec/mars/workflows/sequential_spec.rb @@ -9,7 +9,7 @@ def initialize(value, **kwargs) end def run(input) - input + @value + input.current_input + @value end end end @@ -22,7 +22,7 @@ def initialize(multiplier, **kwargs) end def run(input) - input * @multiplier + input.current_input * @multiplier end end end @@ -67,11 +67,11 @@ def run(_input) it "records outputs in context accessible by step name" do step1 = Class.new(MARS::Runnable) do - def run(input) = "from_step1:#{input}" + def run(input) = "from_step1:#{input.current_input}" end.new(name: "step1") step2 = Class.new(MARS::Runnable) do - def run(input) = "from_step2:#{input}" + def run(input) = "from_step2:#{input.current_input}" end.new(name: "step2") context = MARS::ExecutionContext.new(input: "hello") @@ -84,7 +84,7 @@ def run(input) = "from_step2:#{input}" it "wraps raw input in ExecutionContext automatically" do step = Class.new(MARS::Runnable) do - def run(input) = "processed:#{input}" + def run(input) = "processed:#{input.current_input}" end.new(name: "step") workflow = described_class.new("auto_wrap", steps: [step]) @@ -100,7 +100,7 @@ def format_output(output) end step = Class.new(MARS::Runnable) do - def run(input) = "result:#{input}" + def run(input) = "result:#{input.current_input}" end.new(name: "step", formatter: uppercase_formatter.new) workflow = described_class.new("fmt_workflow", steps: [step]) @@ -115,7 +115,7 @@ def run(input) = "result:#{input}" before_run { |_ctx, step| hook_log << "before:#{step.name}" } after_run { |_ctx, _result, step| hook_log << "after:#{step.name}" } - def run(input) = input + def run(input) = input.current_input end step = step_class.new(name: "hooked") @@ -125,104 +125,6 @@ def run(input) = input expect(hook_log).to eq(["before:hooked", "after:hooked"]) end - it "halts locally when a gate triggers with local scope" do - add_five = add_step_class.new(5, name: "add_five") - gate = MARS::Gate.new( - "local_gate", - check: ->(_input) { :branch }, - fallbacks: { - branch: Class.new(MARS::Runnable) do - def run(input) - "branched:#{input}" - end - end.new(name: "branch_step") - } - ) - multiply_three = multiply_step_class.new(3, name: "multiply_three") - - workflow = described_class.new("halt_workflow", steps: [add_five, gate, multiply_three]) - - # 10 + 5 = 15, gate branches -> "branched:15", multiply_three is never reached - result = workflow.run(10) - expect(result).to eq("branched:15") - expect(result).not_to be_a(MARS::Halt) - end - - it "propagates global halt without unwrapping" do - add_five = add_step_class.new(5, name: "add_five") - gate = MARS::Gate.new( - "global_gate", - check: ->(_input) { :branch }, - fallbacks: { - branch: Class.new(MARS::Runnable) do - def run(input) - "branched:#{input}" - end - end.new(name: "branch_step") - }, - halt_scope: :global - ) - multiply_three = multiply_step_class.new(3, name: "multiply_three") - - workflow = described_class.new("halt_workflow", steps: [add_five, gate, multiply_three]) - - result = workflow.run(10) - expect(result).to be_a(MARS::Halt) - expect(result).to be_global - expect(result.result).to eq("branched:15") - end - - it "propagates global halt through nested sequential workflows" do - inner_gate = MARS::Gate.new( - "inner_gate", - check: ->(_input) { :stop }, - fallbacks: { - stop: Class.new(MARS::Runnable) do - def run(input) - "stopped:#{input}" - end - end.new(name: "stop_step") - }, - halt_scope: :global - ) - - inner = described_class.new("inner", steps: [inner_gate]) - after_inner = add_step_class.new(100, name: "after_inner") - outer = described_class.new("outer", steps: [inner, after_inner]) - - result = outer.run(1) - expect(result).to be_a(MARS::Halt) - expect(result.result).to eq("stopped:1") - end - - it "consumes local halt — outer workflow continues" do - inner_gate = MARS::Gate.new( - "inner_gate", - check: ->(_input) { :stop }, - fallbacks: { - stop: Class.new(MARS::Runnable) do - def run(input) - "stopped:#{input}" - end - end.new(name: "stop_step") - } - ) - - inner = described_class.new("inner", steps: [inner_gate]) - - string_step = Class.new(MARS::Runnable) do - def run(input) - "after:#{input}" - end - end.new(name: "after_step") - - outer = described_class.new("outer", steps: [inner, string_step]) - - result = outer.run(1) - expect(result).to eq("after:stopped:1") - expect(result).not_to be_a(MARS::Halt) - end - it "propagates errors from steps" do add_step = add_step_class.new(5, name: "add") error_step = error_step_class.new("Step failed", name: "error")