Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/mars/agent_step.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def agent(klass = nil)
end
end

def run(input)
self.class.agent.new.ask(input).content
def run(context)
self.class.agent.new.ask(context.current_input).content
end
end
end
11 changes: 9 additions & 2 deletions lib/mars/aggregator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@ def initialize(name = "Aggregator", operation: nil, **kwargs)
@operation = operation || ->(inputs) { inputs }
end

def run(inputs)
operation.call(inputs)
def run(context)
context = ensure_context(context)
operation.call(context.current_input)
end

private

def ensure_context(input)
input.is_a?(ExecutionContext) ? input : ExecutionContext.new(input: input)
end
end
end
7 changes: 4 additions & 3 deletions lib/mars/execution_context.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

module MARS
class ExecutionContext
attr_reader :current_input, :outputs, :global_state
attr_reader :outputs, :global_state
attr_accessor :current_input

def initialize(input: nil, global_state: {})
@current_input = input
Expand All @@ -19,8 +20,8 @@ def record(step_name, output)
@current_input = output
end

def fork(input: current_input)
self.class.new(input: input, global_state: global_state)
def fork(input: current_input, state: {})
self.class.new(input: input, global_state: global_state.merge(state))
end

def merge(child_contexts)
Expand Down
17 changes: 9 additions & 8 deletions lib/mars/gate.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,26 @@ def fallback(key, runnable)
def fallbacks_map
@fallbacks_map ||= {}
end

def halt_scope(scope = nil)
scope ? @halt_scope = scope : @halt_scope
end
end

def initialize(name = "Gate", check: nil, fallbacks: nil, halt_scope: nil, **kwargs)
def initialize(name = "Gate", check: nil, fallbacks: nil, **kwargs)
super(name: name, **kwargs)

@check = check || self.class.check_block
@fallbacks = fallbacks || self.class.fallbacks_map
@halt_scope = halt_scope || self.class.halt_scope || :local
end

def run(input)
def run(context)
context = ensure_context(context)
input = context.current_input
result = check.call(input)

return input unless result

branch = fallbacks[result]
raise ArgumentError, "No fallback registered for #{result.inspect}" unless branch

Halt.new(resolve_branch(branch).run(input), scope: @halt_scope)
resolve_branch(branch).run(context)
end

private
Expand All @@ -48,5 +45,9 @@ def run(input)
def resolve_branch(branch)
branch.is_a?(Class) ? branch.new : branch
end

def ensure_context(input)
input.is_a?(ExecutionContext) ? input : ExecutionContext.new(input: input)
end
end
end
15 changes: 0 additions & 15 deletions lib/mars/halt.rb

This file was deleted.

2 changes: 1 addition & 1 deletion lib/mars/runnable.rb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def initialize(name: self.class.step_name, state: {}, formatter: nil)
@formatter = formatter || self.class.formatter&.new || Formatter.new
end

def run(input)
def run(context)
raise NotImplementedError
end
end
Expand Down
33 changes: 12 additions & 21 deletions lib/mars/workflows/parallel.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,27 @@ def initialize(name, steps:, aggregator: nil, **kwargs)
@aggregator = aggregator || Aggregator.new("#{name} Aggregator")
end

def run(input)
context = ensure_context(input)
def run(context)
context = ensure_context(context)
errors = []
child_contexts = []
results = execute_steps(context, errors, child_contexts)

raise AggregateError, errors if errors.any?

context.merge(child_contexts)
aggregate_results(results)
context.current_input = results
aggregator.run(context)
end

private

attr_reader :steps, :aggregator

def aggregate_results(results)
has_global_halt = results.any? { |r| r.is_a?(Halt) && r.global? }
unwrapped = results.map { |r| r.is_a?(Halt) ? r.result : r }
result = aggregator.run(unwrapped)
has_global_halt ? Halt.new(result, scope: :global) : result
end

def execute_steps(context, errors, child_contexts)
Async do |workflow|
tasks = steps.map do |step|
child_ctx = context.fork
child_ctx = context.fork(state: step.state)
child_contexts << child_ctx

workflow.async do
Expand All @@ -54,17 +48,14 @@ def workflow_step(step, child_ctx)
step.run_before_hooks(child_ctx)

step_input = step.formatter.format_input(child_ctx)
result = step.run(step_input)
child_ctx.current_input = step_input

result = step.run(child_ctx)

if result.is_a?(Halt)
step.run_after_hooks(child_ctx, result)
result
else
formatted = step.formatter.format_output(result)
child_ctx.record(step.name, formatted)
step.run_after_hooks(child_ctx, formatted)
formatted
end
formatted = step.formatter.format_output(result)
child_ctx.record(step.name, formatted)
step.run_after_hooks(child_ctx, formatted)
formatted
end

def ensure_context(input)
Expand Down
20 changes: 5 additions & 15 deletions lib/mars/workflows/sequential.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,16 @@ def initialize(name, steps:, **kwargs)
@steps = steps
end

def run(input)
context = ensure_context(input)
def run(context)
context = ensure_context(context)

@steps.each do |step|
step.run_before_hooks(context)

step_input = step.formatter.format_input(context)
result = step.run(step_input)

if result.is_a?(Halt)
if result.global?
step.run_after_hooks(context, result)
return result
end

formatted = step.formatter.format_output(result.result)
context.record(step.name, formatted)
step.run_after_hooks(context, formatted)
break
end
context.current_input = step_input

result = step.run(context)

formatted = step.formatter.format_output(result)
context.record(step.name, formatted)
Expand Down
2 changes: 1 addition & 1 deletion spec/mars/agent_step_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

it "creates a new agent instance and calls ask" do
step = step_class.new
result = step.run("hello")
result = step.run(MARS::ExecutionContext.new(input: "hello"))

expect(result).to eq("agent response")
expect(mock_agent_class).to have_received(:new)
Expand Down
4 changes: 2 additions & 2 deletions spec/mars/execution_context_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@
expect(child.current_input).to eq("custom")
end

it "shares global_state with the parent" do
it "does not share global_state with the parent" do
context = described_class.new(input: "query", global_state: { shared: true })
child = context.fork

child.global_state[:added_by_child] = true

expect(context.global_state[:added_by_child]).to be(true)
expect(context.global_state[:added_by_child]).to be_nil
end

it "has independent outputs from the parent" do
Expand Down
60 changes: 10 additions & 50 deletions spec/mars/gate_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
RSpec.describe MARS::Gate do
let(:fallback_step) do
Class.new(MARS::Runnable) do
def run(input)
"fallback: #{input}"
def run(context)
"fallback: #{context.current_input}"
end
end.new
end

let(:error_step) do
Class.new(MARS::Runnable) do
def run(input)
"error: #{input}"
def run(context)
"error: #{context.current_input}"
end
end.new
end
Expand All @@ -29,16 +29,15 @@ def run(input)
expect(gate.run("hello")).to eq("hello")
end

it "halts with fallback result when check returns a key" do
it "returns the fallback branch result when check returns a registered key" do
gate = described_class.new(
"FailGate",
check: ->(_input) { :fail },
fallbacks: { fail: fallback_step }
)

result = gate.run("hello")
expect(result).to be_a(MARS::Halt)
expect(result.result).to eq("fallback: hello")
expect(result).to eq("fallback: hello")
end

it "raises when check returns an unregistered key" do
Expand All @@ -60,16 +59,15 @@ def run(input)

input = { error_type: :auth }
result = gate.run(input)
expect(result).to be_a(MARS::Halt)
expect(result.result).to eq("error: #{input}")
expect(result).to eq("error: #{input}")
end
end

context "with class-level DSL" do
let(:fallback_cls) do
Class.new(MARS::Runnable) do
def run(input)
"handled: #{input}"
def run(context)
"handled: #{context.current_input}"
end
end
end
Expand All @@ -83,45 +81,7 @@ def run(input)

gate = gate_class.new("DSLGate")
expect(gate.run("hi")).to eq("hi")
expect(gate.run("longstring").result).to eq("handled: longstring")
end

it "supports halt_scope DSL" do
cls = fallback_cls
gate_class = Class.new(described_class) do
check { |_input| :fail }
fallback :fail, cls
halt_scope :global
end

result = gate_class.new("GlobalGate").run("test")
expect(result).to be_a(MARS::Halt)
expect(result).to be_global
end
end

context "with halt scope" do
it "defaults to local scope" do
gate = described_class.new(
"LocalGate",
check: ->(_input) { :fail },
fallbacks: { fail: fallback_step }
)

result = gate.run("hello")
expect(result).to be_local
end

it "respects constructor halt_scope" do
gate = described_class.new(
"GlobalGate",
check: ->(_input) { :fail },
fallbacks: { fail: fallback_step },
halt_scope: :global
)

result = gate.run("hello")
expect(result).to be_global
expect(gate.run("longstring")).to eq("handled: longstring")
end
end
end
Expand Down
26 changes: 0 additions & 26 deletions spec/mars/halt_spec.rb

This file was deleted.

Loading
Loading