diff --git a/lib/mars/gate.rb b/lib/mars/gate.rb index 5418d73..71fa591 100644 --- a/lib/mars/gate.rb +++ b/lib/mars/gate.rb @@ -27,10 +27,9 @@ def initialize(name = "Gate", check: nil, fallbacks: nil, **kwargs) def run(context) context = ensure_context(context) - input = context.current_input - result = check.call(input) + result = check.call(context) - return input if result.nil? || result == :default + return context if result.nil? || result == :default branch = fallbacks[result] raise ArgumentError, "No fallback registered for #{result.inspect}" unless branch diff --git a/spec/mars/gate_spec.rb b/spec/mars/gate_spec.rb index 4d2f472..2414aa3 100644 --- a/spec/mars/gate_spec.rb +++ b/spec/mars/gate_spec.rb @@ -22,17 +22,17 @@ def run(context) it "passes through when check returns falsy" do gate = described_class.new( "PassGate", - check: ->(_input) {}, + check: ->(_context) {}, fallbacks: { fail: fallback_step } ) - expect(gate.run("hello")).to eq("hello") + expect(gate.run("hello").current_input).to eq("hello") end it "returns the fallback branch result when check returns a registered key" do gate = described_class.new( "FailGate", - check: ->(_input) { :fail }, + check: ->(_context) { :fail }, fallbacks: { fail: fallback_step } ) @@ -43,7 +43,7 @@ def run(context) it "raises when check returns an unregistered key" do gate = described_class.new( "BadGate", - check: ->(_input) { :unknown }, + check: ->(_context) { :unknown }, fallbacks: { fail: fallback_step } ) @@ -53,7 +53,7 @@ def run(context) it "selects among multiple fallbacks" do gate = described_class.new( "MultiFallback", - check: ->(input) { input[:error_type] }, + check: ->(context) { context.current_input[:error_type] }, fallbacks: { timeout: fallback_step, auth: error_step } ) @@ -75,12 +75,12 @@ def run(context) it "uses check and fallback DSL" do cls = fallback_cls gate_class = Class.new(described_class) do - check { |input| :invalid if input.length > 5 } + check { |context| :invalid if context.current_input.length > 5 } fallback :invalid, cls end gate = gate_class.new("DSLGate") - expect(gate.run("hi")).to eq("hi") + expect(gate.run("hi").current_input).to eq("hi") expect(gate.run("longstring")).to eq("handled: longstring") end end