diff --git a/session/session.go b/session/session.go index 42fddef01..2c6344bb2 100644 --- a/session/session.go +++ b/session/session.go @@ -138,10 +138,6 @@ type Event struct { // NodeInfo carries the per-event metadata identifying which node in // a workflow activation emitted it. -// -// TODO(wolo): adk-python's NodeInfo also has OutputFor []string -// (fan-in re-attribution) and MessageAsOutput bool (content-as-output -// shorthand). Add as the corresponding features land in adk-go. type NodeInfo struct { // Path is the composite path of the emitting node within its // workflow activation. Empty for top-level static nodes; @@ -150,6 +146,18 @@ type NodeInfo struct { // invariants to the emitter, allowing dynamic nodes to forward // children's terminal events alongside their own. Path string `json:"path,omitempty"` + + // MessageAsOutput marks that this event's content IS the node's + // output: when set and Event.Output is nil, readers derive the + // node output from the event's model text. Mirrors adk-python's + // node_info.message_as_output. + MessageAsOutput bool `json:"messageAsOutput,omitempty"` + + // OutputFor lists the node paths this event's Output counts for: the + // emitter plus any WithUseAsOutput delegating ancestors, so one event + // stands in for a whole delegation chain rather than each level + // re-emitting a duplicate. Mirrors adk-python's node_info.output_for. + OutputFor []string `json:"outputFor,omitempty"` } // RequestInput describes a single human-in-the-loop prompt emitted diff --git a/workflow/agent_node.go b/workflow/agent_node.go index 02ae2477d..840751675 100644 --- a/workflow/agent_node.go +++ b/workflow/agent_node.go @@ -129,7 +129,9 @@ func (n *AgentNode) Run(ctx agent.InvocationContext, input any) iter.Seq2[*sessi synthesizeAgentOutput(event) - // TODO: add output validation + // The output schema (if any) is applied by the scheduler via + // ValidateOutput; synthesizeAgentOutput leaves the raw model + // text for defaultValidateOutput to project onto the schema. if !yield(event, nil) { return } @@ -139,7 +141,14 @@ func (n *AgentNode) Run(ctx agent.InvocationContext, input any) iter.Seq2[*sessi // synthesizeAgentOutput sets Event.Output from concatenated model // text on final model responses so RunNode returns the agent's -// reply instead of the zero value. +// reply instead of the zero value. Empty model text yields an empty +// "" output (a value, not "no output"), matching adk-python and +// messageAsOutput; non-model events are left untouched. +// +// It also stamps NodeInfo.MessageAsOutput so readers (live and +// resume) know this event's output was derived from the model +// message, mirroring adk-python's process_llm_agent_output which +// sets event.output and node_info.message_as_output together. func synthesizeAgentOutput(event *session.Event) { if event == nil || event.Output != nil { return @@ -147,9 +156,25 @@ func synthesizeAgentOutput(event *session.Event) { if !event.IsFinalResponse() { return } + if text, ok := messageText(event); ok { + event.Output = text + if event.NodeInfo == nil { + event.NodeInfo = &session.NodeInfo{} + } + event.NodeInfo.MessageAsOutput = true + } +} + +// messageText concatenates the non-thought model text of an event. ok +// is false when the event carries no model content, distinguishing it +// from a model message with empty text. +func messageText(event *session.Event) (text string, ok bool) { + if event == nil { + return "", false + } content := event.LLMResponse.Content if content == nil || content.Role != "model" { - return + return "", false } var b []byte for _, p := range content.Parts { @@ -158,8 +183,19 @@ func synthesizeAgentOutput(event *session.Event) { } b = append(b, p.Text...) } - if len(b) == 0 { - return + return string(b), true +} + +// childEventOutput returns the output an event carries: its Output, or +// the model text when MessageAsOutput is set. +func childEventOutput(event *session.Event) (any, bool) { + if event.Output != nil { + return event.Output, true + } + if event.NodeInfo != nil && event.NodeInfo.MessageAsOutput { + if text, ok := messageText(event); ok { + return text, true + } } - event.Output = string(b) + return nil, false } diff --git a/workflow/agent_node_test.go b/workflow/agent_node_test.go index 2d2846b39..290368693 100644 --- a/workflow/agent_node_test.go +++ b/workflow/agent_node_test.go @@ -443,4 +443,70 @@ func TestAgentNode_SynthesizesOutputFromModelText(t *testing.T) { if got, want := gotFinal.Output, "Hello, world!"; got != want { t.Errorf("final event Output = %v, want %q", got, want) } + if gotFinal.NodeInfo == nil || !gotFinal.NodeInfo.MessageAsOutput { + t.Errorf("final event NodeInfo.MessageAsOutput = %v, want true", gotFinal.NodeInfo) + } + if gotPartial.NodeInfo != nil && gotPartial.NodeInfo.MessageAsOutput { + t.Errorf("partial event MessageAsOutput = true, want false/unset") + } +} + +// TestAgentNode_StructuredOutputProjectedViaValidation verifies the +// end-to-end path that makes the validation fallback reachable: an +// AgentNode with a structured output schema yields JSON model text, +// and ValidateOutput projects it onto the schema. +func TestAgentNode_StructuredOutputProjectedViaValidation(t *testing.T) { + wrapped, err := agent.New(agent.Config{ + Name: "json-talky", + Run: func(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { + return func(yield func(*session.Event, error) bool) { + final := session.NewEvent(ctx.InvocationID()) + final.LLMResponse.Content = &genai.Content{ + Role: "model", + Parts: []*genai.Part{{Text: `{"value":"hello"}`}}, + } + yield(final, nil) + } + }, + }) + if err != nil { + t.Fatalf("agent.New: %v", err) + } + outSchema, err := jsonschema.For[testSchemaInput](nil) + if err != nil { + t.Fatalf("jsonschema.For: %v", err) + } + node, err := NewAgentNodeWithSchemas(wrapped, nil, outSchema, NodeConfig{}) + if err != nil { + t.Fatalf("NewAgentNodeWithSchemas: %v", err) + } + + mockCtx := newMockCtx(t) + mockCtx.sess = &mockSession{id: "test-session-id"} + var gotFinal *session.Event + for ev, err := range node.Run(mockCtx, "ignored") { + if err != nil { + t.Fatalf("node.Run: %v", err) + } + if !ev.LLMResponse.Partial { + gotFinal = ev + } + } + if gotFinal == nil { + t.Fatal("missing final event") + } + + // AgentNode itself only synthesizes the raw text; the projection + // onto the schema happens in ValidateOutput. + got, err := node.ValidateOutput(gotFinal.Output) + if err != nil { + t.Fatalf("ValidateOutput: %v", err) + } + gotMap, ok := got.(map[string]any) + if !ok { + t.Fatalf("ValidateOutput returned %T, want map[string]any", got) + } + if gotMap["value"] != "hello" { + t.Errorf("got %v, want value=hello", gotMap) + } } diff --git a/workflow/base_node.go b/workflow/base_node.go index ec690544b..b50eb32fe 100644 --- a/workflow/base_node.go +++ b/workflow/base_node.go @@ -15,7 +15,13 @@ package workflow import ( + "encoding/json" + "strings" + "github.com/google/jsonschema-go/jsonschema" + "google.golang.org/genai" + + "google.golang.org/adk/session" ) // BaseNode provides identity and a default Config implementation for @@ -85,12 +91,83 @@ func (b BaseNode) ValidateOutput(out any) (any, error) { return defaultValidateOutput(out, b.outputSchema) } +// defaultValidateOutput is the shared output-validation helper used by +// BaseNode.ValidateOutput. +// +// Framework control values (*session.Event, *session.RequestInput) ride +// through Event.Output on some nodes but are not user output payloads, +// so they bypass schema validation. When direct validation fails on a +// model-text output (a string, or a *genai.Content of model parts — +// see synthesizeAgentOutput), the text fallback projects it onto the +// schema. On total failure the original validation error is returned, +// not a downstream parse error. Mirrors ADK Python's +// _validate_output_data. func defaultValidateOutput(out any, schema *jsonschema.Resolved) (any, error) { if schema == nil { return out, nil } - if err := schema.Validate(out); err != nil { - return nil, err + switch out.(type) { + case *session.Event, *session.RequestInput: + return out, nil + } + err := schema.Validate(out) + if err == nil { + return out, nil + } + if text, ok := modelText(out); ok { + if v, ok := projectTextOntoSchema(text, schema); ok { + return v, nil + } + } + return nil, err +} + +// modelText extracts the model text carried by an output value: the +// string itself, or the concatenated text parts of a *genai.Content. +// ok is false for any other type. +func modelText(out any) (string, bool) { + switch v := out.(type) { + case string: + return v, true + case *genai.Content: + var text strings.Builder + for _, part := range v.Parts { + if part != nil && part.Text != "" { + text.WriteString(part.Text) + } + } + return text.String(), true + default: + return "", false + } +} + +// projectTextOntoSchema projects model text onto schema: return it +// directly for a string schema, otherwise JSON-parse and re-validate. +// ok is false when no valid value can be produced, leaving error +// reporting to the caller. +func projectTextOntoSchema(s string, schema *jsonschema.Resolved) (any, bool) { + if rootSchemaIsString(schema) { + return s, true + } + if strings.TrimSpace(s) == "" { + return nil, false + } + var parsed any + if err := json.Unmarshal([]byte(s), &parsed); err != nil { + return nil, false + } + if err := schema.Validate(parsed); err != nil { + return nil, false + } + return parsed, true +} + +// rootSchemaIsString reports whether schema's root type is "string". +func rootSchemaIsString(schema *jsonschema.Resolved) bool { + root := schema.Schema() + if root == nil { + return false } - return out, nil + return root.Type == "string" } diff --git a/workflow/base_node_test.go b/workflow/base_node_test.go index a6492890b..fede4a650 100644 --- a/workflow/base_node_test.go +++ b/workflow/base_node_test.go @@ -20,6 +20,9 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/jsonschema-go/jsonschema" + "google.golang.org/genai" + + "google.golang.org/adk/session" ) // Compile-time assertions: every built-in workflow node must satisfy @@ -205,3 +208,133 @@ func TestBaseNode_WithSchemas(t *testing.T) { t.Error("expected ValidateOutput to fail on invalid output type, but succeeded") } } + +// resolveTestSchema generates a *jsonschema.Resolved from a Go type +// for use in tests. +func resolveTestSchema[T any](t *testing.T) *jsonschema.Resolved { + t.Helper() + s, err := jsonschema.For[T](nil) + if err != nil { + t.Fatalf("jsonschema.For failed: %v", err) + } + resolved, err := s.Resolve(nil) + if err != nil { + t.Fatalf("Resolve failed: %v", err) + } + return resolved +} + +// TestDefaultValidateOutput_PassthroughTypes verifies that framework +// control values (*session.Event, *session.RequestInput) are returned +// unchanged even when a strict schema is configured: they ride through +// Event.Output on some nodes but are not user output payloads. +func TestDefaultValidateOutput_PassthroughTypes(t *testing.T) { + schema := resolveTestSchema[testSchemaInput](t) + + tests := []struct { + name string + in any + }{ + {name: "*session.Event", in: &session.Event{Author: "node"}}, + {name: "*session.RequestInput", in: &session.RequestInput{InterruptID: "approval"}}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := defaultValidateOutput(tc.in, schema) + if err != nil { + t.Fatalf("expected passthrough, got error: %v", err) + } + if got != tc.in { + t.Errorf("expected identity passthrough, got different value") + } + }) + } +} + +// TestDefaultValidateOutput_ContentFallback exercises the *genai.Content +// fallback: extract text from parts, return it directly for a string +// schema, otherwise JSON-parse and re-validate. When the fallback +// cannot produce a valid value the original validation error surfaces. +func TestDefaultValidateOutput_ContentFallback(t *testing.T) { + t.Run("string_schema_returns_text", func(t *testing.T) { + schema := resolveTestSchema[string](t) + content := &genai.Content{Parts: []*genai.Part{{Text: "hello "}, {Text: "world"}}} + got, err := defaultValidateOutput(content, schema) + if err != nil { + t.Fatalf("defaultValidateOutput failed: %v", err) + } + if got != "hello world" { + t.Errorf("got %q, want %q", got, "hello world") + } + }) + + t.Run("object_schema_parses_json", func(t *testing.T) { + schema := resolveTestSchema[testSchemaInput](t) + content := &genai.Content{Parts: []*genai.Part{{Text: `{"value":"hello"}`}}} + got, err := defaultValidateOutput(content, schema) + if err != nil { + t.Fatalf("defaultValidateOutput failed: %v", err) + } + gotMap, ok := got.(map[string]any) + if !ok { + t.Fatalf("expected map[string]any, got %T", got) + } + if gotMap["value"] != "hello" { + t.Errorf("got %v, want value=hello", gotMap) + } + }) + + t.Run("invalid_json_returns_original_error", func(t *testing.T) { + schema := resolveTestSchema[testSchemaInput](t) + content := &genai.Content{Parts: []*genai.Part{{Text: "not valid json"}}} + if _, err := defaultValidateOutput(content, schema); err == nil { + t.Fatal("expected validation error, got nil") + } + }) + + t.Run("empty_text_returns_original_error", func(t *testing.T) { + schema := resolveTestSchema[testSchemaInput](t) + content := &genai.Content{Parts: []*genai.Part{{Text: " "}}} + if _, err := defaultValidateOutput(content, schema); err == nil { + t.Fatal("expected validation error, got nil") + } + }) +} + +// TestDefaultValidateOutput_StringFallback covers the string output +// path (as produced by AgentNode): a JSON string that fails direct +// validation is parsed and projected onto a structured schema. +func TestDefaultValidateOutput_StringFallback(t *testing.T) { + t.Run("object_schema_parses_json", func(t *testing.T) { + schema := resolveTestSchema[testSchemaInput](t) + got, err := defaultValidateOutput(`{"value":"hello"}`, schema) + if err != nil { + t.Fatalf("defaultValidateOutput failed: %v", err) + } + gotMap, ok := got.(map[string]any) + if !ok { + t.Fatalf("expected map[string]any, got %T", got) + } + if gotMap["value"] != "hello" { + t.Errorf("got %v, want value=hello", gotMap) + } + }) + + t.Run("non_json_returns_original_error", func(t *testing.T) { + schema := resolveTestSchema[testSchemaInput](t) + if _, err := defaultValidateOutput("not json", schema); err == nil { + t.Fatal("expected validation error, got nil") + } + }) + + t.Run("plain_string_passes_any_schema", func(t *testing.T) { + schema := resolveTestSchema[any](t) + got, err := defaultValidateOutput("plain text", schema) + if err != nil { + t.Fatalf("defaultValidateOutput failed: %v", err) + } + if got != "plain text" { + t.Errorf("got %q, want %q", got, "plain text") + } + }) +} diff --git a/workflow/dynamic_node.go b/workflow/dynamic_node.go index 512905c02..9d4ad54e7 100644 --- a/workflow/dynamic_node.go +++ b/workflow/dynamic_node.go @@ -112,7 +112,7 @@ func (n *dynamicNode[IN, OUT]) Run(ctx agent.InvocationContext, input any) iter. emit := makeEmit(yield, parentNC) sub := newDynamicSubScheduler(parentNC, n.composePath(parentNC), emit) - orchestratorCtx := newDynamicNodeContext(parentNC, sub.parentPath, "", sub) + orchestratorCtx := newDynamicNodeContext(parentNC, sub.parentPath, "", sub, sub.outputForAncestors) out, err := n.fn(orchestratorCtx, typedInput, emit) if err != nil { @@ -125,6 +125,13 @@ func (n *dynamicNode[IN, OUT]) Run(ctx agent.InvocationContext, input any) iter. return } + // A WithUseAsOutput child already emitted this output on its own + // event (stamped for this node), so emit no duplicate terminal + // event. Mirrors adk-python's _output_delegated. + if _, delegated := sub.delegatedOutput(); delegated { + return + } + // nil output: nothing to emit as a terminal event — the body // either produced no output or already carried it on a content // event. @@ -132,11 +139,7 @@ func (n *dynamicNode[IN, OUT]) Run(ctx agent.InvocationContext, input any) iter. return } ev := session.NewEvent(parentNC.InvocationID()) - if delegated, ok := sub.delegatedOutput(); ok { - ev.Output = delegated - } else { - ev.Output = out - } + ev.Output = out ev.NodeInfo = &session.NodeInfo{Path: sub.parentPath} // TODO(wolo): validate ev.Output against n.outputSchema, // mirroring function_node.go:87-92. @@ -162,11 +165,14 @@ func (n *dynamicNode[IN, OUT]) coerceInput(input any) (IN, error) { return typed, nil } -// composePath returns this dynamic node's own composite path. Top-level -// activations get the bare Name(); nested dynamic nodes append. +// composePath returns this dynamic node's own composite path. When this +// node runs as a dynamic child, the scheduler already created its +// context with the full child path ("/@"), so that +// path is used as-is. A top-level activation has no parent path and +// gets the bare Name(). func (n *dynamicNode[IN, OUT]) composePath(parent NodeContext) string { if p := parent.Path(); p != "" { - return p + "/" + n.Name() + return p } return n.Name() } diff --git a/workflow/dynamic_scheduler.go b/workflow/dynamic_scheduler.go index 15e8c865f..644e79122 100644 --- a/workflow/dynamic_scheduler.go +++ b/workflow/dynamic_scheduler.go @@ -29,6 +29,12 @@ type dynamicSubScheduler struct { parentCtx NodeContext emitUp func(*session.Event) error + // outputForAncestors are the delegating-ancestor paths this + // activation's output also counts for, set when this dynamic node is + // itself a WithUseAsOutput child. Mirrors adk-python's + // Context._output_for_ancestors. + outputForAncestors []string + // mu guards everything below. Never held across child.Run. mu sync.Mutex // runCountByChild seeds the auto-counter per child name; the @@ -44,10 +50,9 @@ type dynamicSubScheduler struct { // outputDelegation is the at-most-one WithUseAsOutput delegation for a // parent activation. claim is set eagerly on the first delegating child // and never cleared within the activation (matching adk-python's -// _output_delegated); a second delegating child is rejected. A fresh -// sub-scheduler is built per activation, so there is nothing to reset -// across turns. hasValue is the source of truth for readability because -// nil is a valid delegated value. +// _output_delegated); a second delegating child is rejected. hasValue +// (not value != nil) is the source of truth, since nil is a valid +// delegated value. // // Methods require the enclosing scheduler's mu to be held. type outputDelegation struct { @@ -86,12 +91,17 @@ func (d *outputDelegation) output() (any, bool) { } func newDynamicSubScheduler(parent NodeContext, parentPath string, emitUp func(*session.Event) error) *dynamicSubScheduler { + var ancestors []string + if p, ok := parent.(*nodeContext); ok { + ancestors = p.outputForAncestors + } s := &dynamicSubScheduler{ - parentPath: parentPath, - parentCtx: parent, - emitUp: emitUp, - runCountByChild: map[string]int{}, - resultByPath: map[string]any{}, + parentPath: parentPath, + parentCtx: parent, + emitUp: emitUp, + outputForAncestors: ancestors, + runCountByChild: map[string]int{}, + resultByPath: map[string]any{}, } s.rehydrateCache() return s @@ -156,7 +166,13 @@ func (s *dynamicSubScheduler) runNode(child Node, input any, opts runNodeOptions } childBranch := deriveChildBranch(s.parentCtx.Branch(), name, runID, opts.useSubBranch, opts.overrideBranch) - childCtx := newDynamicNodeContext(s.parentCtx.WithBranch(childBranch), childPath, runID, s) + // A delegating child extends the chain: its own delegating children + // must count their output for this parent and its ancestors too. + var childAncestors []string + if opts.useAsOutput { + childAncestors = append([]string{s.parentPath}, s.outputForAncestors...) + } + childCtx := newDynamicNodeContext(s.parentCtx.WithBranch(childBranch), childPath, runID, s, childAncestors) // EXPERIMENTAL: stash childCtx (a *nodeContext with non-nil // subScheduler) in the embedded context.Context so tools running @@ -186,23 +202,47 @@ func (s *dynamicSubScheduler) runNode(child Node, input any, opts runNodeOptions // Stamp NodeInfo.Path so the top scheduler scopes the // child's Output/Routes to the child (not the parent's // accumulator). RequestedInput is promoted to the parent — - // see scheduler.handleEvent. Skip if the child already - // stamped NodeInfo (nested dynamic node yielding its own - // terminal event, dynamic_node.go). + // see scheduler.handleEvent. A child may set NodeInfo without + // a Path (e.g. MessageAsOutput), so fill the Path when empty + // rather than only when NodeInfo is nil; a nested dynamic node + // that already set its own Path keeps it. if ev.NodeInfo == nil { ev.NodeInfo = &session.NodeInfo{Path: childPath} + } else if ev.NodeInfo.Path == "" { + ev.NodeInfo.Path = childPath } if ev.RequestedInput != nil { interrupted = true } - if ev.Output != nil { - out = ev.Output - // A delegated child's output is re-emitted by the - // parent's terminal event; drop it here to avoid a - // duplicate. Partial/state-only events (Output == - // nil) still propagate. - if opts.useAsOutput { - continue + if childOut, ok := childEventOutput(ev); ok { + // Validate against the child's output schema on the same + // terms as the static scheduler (scheduler.handleEvent). + validated, err := child.ValidateOutput(childOut) + if err != nil { + return nil, &NodeRunError{ + ChildName: name, ChildPath: childPath, RunID: runID, + Cause: fmt.Errorf("%w: output validation failed: %v", ErrNodeFailed, err), + } + } + out = validated + // Stamp the validated value back onto Event.Output only when + // it was carried there; model-text outputs stay off the event + // (mirrors scheduler.handleEvent). + if ev.Output != nil { + ev.Output = validated + } + // Stamp OutputFor so resume can attribute the output: the + // emitter's own path plus, under delegation, this parent and + // its ancestors (the parent then suppresses its own terminal + // event). Mirrors adk-python _enrich_event. A nested child + // that already stamped its chain keeps it. + if ev.NodeInfo.OutputFor == nil { + outputFor := []string{ev.NodeInfo.Path} + if opts.useAsOutput { + outputFor = append(outputFor, s.parentPath) + outputFor = append(outputFor, s.outputForAncestors...) + } + ev.NodeInfo.OutputFor = outputFor } } if err := s.emitUp(ev); err != nil { diff --git a/workflow/dynamic_scheduler_test.go b/workflow/dynamic_scheduler_test.go index edddf0844..2517b2e45 100644 --- a/workflow/dynamic_scheduler_test.go +++ b/workflow/dynamic_scheduler_test.go @@ -18,9 +18,13 @@ import ( "errors" "iter" "strconv" + "strings" "sync" "testing" + "github.com/google/jsonschema-go/jsonschema" + "google.golang.org/genai" + "google.golang.org/adk/agent" "google.golang.org/adk/session" ) @@ -218,6 +222,73 @@ func (n *stubNode) Run(ctx agent.InvocationContext, _ any) iter.Seq2[*session.Ev } } +// newSchemaStubNode returns a stubNode carrying an output schema so the +// dynamic sub-scheduler invokes ValidateOutput on its yielded output. +func newSchemaStubNode(name string, schema *jsonschema.Resolved, out any) *stubNode { + return &stubNode{ + BaseNode: NewBaseNodeWithSchemas(name, "", NodeConfig{}, nil, schema), + out: out, + } +} + +func TestSubScheduler_RunNode_ValidatesOutput(t *testing.T) { + schema := resolveTestSchema[testSchemaInput](t) + + t.Run("valid_passes", func(t *testing.T) { + child := newSchemaStubNode("ok", schema, map[string]any{"value": "hi"}) + sub := newDynamicSubScheduler(newTopLevelCtx(t), "wf", noopEmit) + + out, err := sub.runNode(child, nil, runNodeOptions{}) + if err != nil { + t.Fatalf("runNode: %v", err) + } + gotMap, ok := out.(map[string]any) + if !ok || gotMap["value"] != "hi" { + t.Errorf("output = %v, want map value=hi", out) + } + }) + + t.Run("invalid_fails", func(t *testing.T) { + child := newSchemaStubNode("bad", schema, map[string]any{"value": 123}) + sub := newDynamicSubScheduler(newTopLevelCtx(t), "wf", noopEmit) + + _, err := sub.runNode(child, nil, runNodeOptions{}) + if !errors.Is(err, ErrNodeFailed) { + t.Fatalf("err = %v, want ErrNodeFailed", err) + } + if !strings.Contains(err.Error(), "output validation failed") { + t.Errorf("err = %q, want substring %q", err.Error(), "output validation failed") + } + }) +} + +// messageAsOutputNode emits a final model-text event whose content IS +// its output (NodeInfo.MessageAsOutput set, Event.Output nil), like an +// LlmAgent node in single_turn mode. +type messageAsOutputNode struct { + BaseNode + text string +} + +func newMessageAsOutputNode(name, text string) *messageAsOutputNode { + return &messageAsOutputNode{ + BaseNode: NewBaseNode(name, "", NodeConfig{}), + text: text, + } +} + +func (n *messageAsOutputNode) Run(agent.InvocationContext, any) iter.Seq2[*session.Event, error] { + text := n.text + return func(yield func(*session.Event, error) bool) { + ev := &session.Event{NodeInfo: &session.NodeInfo{MessageAsOutput: true}} + ev.LLMResponse.Content = &genai.Content{ + Role: "model", + Parts: []*genai.Part{{Text: text}}, + } + yield(ev, nil) + } +} + // requestInputNode emits one RequestedInput event and exits cleanly. type requestInputNode struct { BaseNode diff --git a/workflow/node_context.go b/workflow/node_context.go index 294389534..ec1138166 100644 --- a/workflow/node_context.go +++ b/workflow/node_context.go @@ -65,6 +65,11 @@ type nodeContext struct { // subScheduler is non-nil only when this context belongs to a // dynamic-node activation; RunNode uses it to schedule children. subScheduler *dynamicSubScheduler + + // outputForAncestors are the delegating-ancestor paths carried + // into this activation when it runs as a WithUseAsOutput child; + // its dynamic sub-scheduler reads them to stamp OutputFor. + outputForAncestors []string } // Compile-time: *nodeContext implements NodeContext. @@ -85,17 +90,18 @@ func newNodeContext(parent agent.InvocationContext, resumeInputs map[string]any) // dynamic node's own activation passes runID="" — it is not itself a // sub-scheduler child. Child inherits resumeInputs so HITL responses // reach dynamic children. -func newDynamicNodeContext(parent NodeContext, path, runID string, sub *dynamicSubScheduler) *nodeContext { +func newDynamicNodeContext(parent NodeContext, path, runID string, sub *dynamicSubScheduler, outputForAncestors []string) *nodeContext { var inherited map[string]any if p, ok := parent.(*nodeContext); ok { inherited = p.resumeInputs } return &nodeContext{ - InvocationContext: parent, - resumeInputs: inherited, - path: path, - runID: runID, - subScheduler: sub, + InvocationContext: parent, + resumeInputs: inherited, + path: path, + runID: runID, + subScheduler: sub, + outputForAncestors: outputForAncestors, } } @@ -131,10 +137,11 @@ func (c *nodeContext) WithBranch(branch string) NodeContext { // activations and any other workflow-specific accessors. func (c *nodeContext) WithContext(ctx context.Context) agent.InvocationContext { return &nodeContext{ - c.InvocationContext.WithContext(ctx), - c.resumeInputs, - c.path, - c.runID, - c.subScheduler, + InvocationContext: c.InvocationContext.WithContext(ctx), + resumeInputs: c.resumeInputs, + path: c.path, + runID: c.runID, + subScheduler: c.subScheduler, + outputForAncestors: c.outputForAncestors, } } diff --git a/workflow/node_context_test.go b/workflow/node_context_test.go index c90b088a2..7796be680 100644 --- a/workflow/node_context_test.go +++ b/workflow/node_context_test.go @@ -60,7 +60,7 @@ func TestNodeContext_PathAndRunID(t *testing.T) { t.Run("child populated from constructor", func(t *testing.T) { parent := newNodeContext(newMockCtx(t), nil) - child := newDynamicNodeContext(parent, "wf/fixer@2", "2", nil) + child := newDynamicNodeContext(parent, "wf/fixer@2", "2", nil, nil) if got, want := child.Path(), "wf/fixer@2"; got != want { t.Errorf("Path() = %q, want %q", got, want) } @@ -71,7 +71,7 @@ func TestNodeContext_PathAndRunID(t *testing.T) { t.Run("activation populated with empty runID", func(t *testing.T) { parent := newNodeContext(newMockCtx(t), nil) - act := newDynamicNodeContext(parent, "city_workflow", "", nil) + act := newDynamicNodeContext(parent, "city_workflow", "", nil, nil) if got, want := act.Path(), "city_workflow"; got != want { t.Errorf("Path() = %q, want %q", got, want) } @@ -86,7 +86,7 @@ func TestNodeContext_DynamicInheritsResumeInputs(t *testing.T) { sub := &dynamicSubScheduler{} t.Run("child", func(t *testing.T) { - child := newDynamicNodeContext(parent, "wf/asker@1", "1", sub) + child := newDynamicNodeContext(parent, "wf/asker@1", "1", sub, nil) if v, ok := child.ResumedInput("approval"); !ok || v != "yes" { t.Errorf("child.ResumedInput(\"approval\") = (%v, %v), want (\"yes\", true)", v, ok) } @@ -96,7 +96,7 @@ func TestNodeContext_DynamicInheritsResumeInputs(t *testing.T) { }) t.Run("activation", func(t *testing.T) { - act := newDynamicNodeContext(parent, "city_workflow", "", sub) + act := newDynamicNodeContext(parent, "city_workflow", "", sub, nil) if v, ok := act.ResumedInput("approval"); !ok || v != "yes" { t.Errorf("act.ResumedInput(\"approval\") = (%v, %v), want (\"yes\", true)", v, ok) } diff --git a/workflow/persistence.go b/workflow/persistence.go index cfc9dfb2c..e08dff1a1 100644 --- a/workflow/persistence.go +++ b/workflow/persistence.go @@ -196,8 +196,29 @@ func collectNodeOutputs(events session.Events, nodesByName map[string]Node) (out continue } completed[name] = true - if ev.Output != nil { - outputs[name] = ev.Output + // Prefer an explicit Output; otherwise derive it from the + // model message when the event is flagged MessageAsOutput, + // so a message-as-output node recovers its output on resume + // (mirrors adk-python _reconstruct_node_states' + // use_message_as_output branch). + out, ok := childEventOutput(ev) + if !ok { + continue + } + outputs[name] = out + // A delegated output also counts for the static owners of the + // OutputFor paths, so a delegating ancestor recovers it on resume + // without re-emitting. Mirrors adk-python's output_for. + if ev.NodeInfo != nil { + for _, p := range ev.NodeInfo.OutputFor { + owner := staticNodeName(p) + if owner == name { + continue + } + if _, known := nodesByName[owner]; known { + outputs[owner] = out + } + } } } return outputs, completed @@ -399,15 +420,20 @@ func firstUserInput(events session.Events) any { // LlmAgent node path, where Author == node name and no path is set. func eventNodeName(ev *session.Event) string { if ev.NodeInfo != nil && ev.NodeInfo.Path != "" { - path := ev.NodeInfo.Path - if i := strings.IndexByte(path, '/'); i >= 0 { - return path[:i] - } - return path + return staticNodeName(ev.NodeInfo.Path) } return ev.Author } +// staticNodeName returns the static graph node owning a node path: the +// first segment of a composite "parent/child@run" path. +func staticNodeName(path string) string { + if i := strings.IndexByte(path, '/'); i >= 0 { + return path[:i] + } + return path +} + // frPart returns the FunctionResponse on a part if present and keyed. func frPart(p *genai.Part) *genai.FunctionResponse { if p == nil || p.FunctionResponse == nil || p.FunctionResponse.ID == "" { diff --git a/workflow/persistence_test.go b/workflow/persistence_test.go new file mode 100644 index 000000000..0d2934e16 --- /dev/null +++ b/workflow/persistence_test.go @@ -0,0 +1,120 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package workflow + +import ( + "iter" + "testing" + + "google.golang.org/genai" + + "google.golang.org/adk/session" +) + +// sliceEvents adapts a []*session.Event to session.Events for tests. +type sliceEvents []*session.Event + +func (e sliceEvents) Len() int { return len(e) } +func (e sliceEvents) At(i int) *session.Event { return e[i] } +func (e sliceEvents) All() iter.Seq[*session.Event] { + return func(yield func(*session.Event) bool) { + for _, ev := range e { + if !yield(ev) { + return + } + } + } +} + +func modelEvent(path, text string, messageAsOutput bool) *session.Event { + ev := &session.Event{ + NodeInfo: &session.NodeInfo{Path: path, MessageAsOutput: messageAsOutput}, + } + ev.LLMResponse.Content = &genai.Content{ + Role: "model", + Parts: []*genai.Part{{Text: text}}, + } + return ev +} + +// Resume derives output from the model message when an event is +// flagged MessageAsOutput with no explicit Output (adk-python parity). +func TestCollectNodeOutputs_MessageAsOutput(t *testing.T) { + nodes := map[string]Node{"talky": &dummyNode{name: "talky"}} + + events := sliceEvents{modelEvent("talky", "Hello, world!", true)} + + outputs, completed := collectNodeOutputs(events, nodes) + + if got, want := outputs["talky"], "Hello, world!"; got != want { + t.Errorf("outputs[talky] = %#v, want %q", got, want) + } + if !completed["talky"] { + t.Errorf("completed[talky] = false, want true") + } +} + +func TestCollectNodeOutputs_MessageNotFlagged(t *testing.T) { + nodes := map[string]Node{"talky": &dummyNode{name: "talky"}} + + events := sliceEvents{modelEvent("talky", "Hello, world!", false)} + + outputs, _ := collectNodeOutputs(events, nodes) + + if _, ok := outputs["talky"]; ok { + t.Errorf("outputs[talky] = %#v, want absent", outputs["talky"]) + } +} + +func TestCollectNodeOutputs_ExplicitOutputWins(t *testing.T) { + nodes := map[string]Node{"talky": &dummyNode{name: "talky"}} + + ev := modelEvent("talky", "from message", true) + ev.Output = "explicit" + events := sliceEvents{ev} + + outputs, _ := collectNodeOutputs(events, nodes) + + if got, want := outputs["talky"], "explicit"; got != want { + t.Errorf("outputs[talky] = %#v, want %q", got, want) + } +} + +// A delegated child's output is attributed on resume to the static +// owners of every path in OutputFor, so a delegating ancestor recovers +// it without re-emitting (adk-python output_for parity). +func TestCollectNodeOutputs_OutputForAttributesAncestors(t *testing.T) { + nodes := map[string]Node{ + "child": &dummyNode{name: "child"}, + "outer": &dummyNode{name: "outer"}, + } + + ev := &session.Event{ + Output: "delegated", + NodeInfo: &session.NodeInfo{ + Path: "child/gc@1", + OutputFor: []string{"child/gc@1", "outer/child@1"}, + }, + } + + outputs, _ := collectNodeOutputs(sliceEvents{ev}, nodes) + + if got, want := outputs["child"], "delegated"; got != want { + t.Errorf("outputs[child] = %#v, want %q", got, want) + } + if got, want := outputs["outer"], "delegated"; got != want { + t.Errorf("outputs[outer] = %#v, want %q (ancestor not attributed)", got, want) + } +} diff --git a/workflow/run_node_test.go b/workflow/run_node_test.go index b39178833..b853d71f6 100644 --- a/workflow/run_node_test.go +++ b/workflow/run_node_test.go @@ -191,12 +191,91 @@ func TestRunNode_WithUseAsOutput_ChildOutputBecomesParentOutput(t *testing.T) { NodeConfig{}, ) events := drainDynamic(t, n, "") - if got := parentTerminalOutput(t, events, "orch"); got != "child_value" { - t.Errorf("parent terminal Output = %v, want %q", got, "child_value") + // Full suppression: the delegated output is carried on the child's + // own event; the parent emits no terminal event. + if got := outputBearingPaths(events); !reflect.DeepEqual(got, []string{"orch/c@1"}) { + t.Errorf("paths of events with Output = %v, want exactly [\"orch/c@1\"]", got) } - // Delegated child must not emit a duplicate output event. - if got := outputBearingPaths(events); !reflect.DeepEqual(got, []string{"orch"}) { - t.Errorf("paths of events with Output = %v, want exactly [\"orch\"]", got) + if got := parentTerminalOutput(t, events, "orch/c@1"); got != "child_value" { + t.Errorf("delegated child Output = %v, want %q", got, "child_value") + } + // The child event is stamped OutputFor with its own path plus the + // delegating parent, so resume attributes the output to both. + if got := outputForAtPath(events, "orch/c@1"); !reflect.DeepEqual(got, []string{"orch/c@1", "orch"}) { + t.Errorf("OutputFor = %v, want [orch/c@1 orch]", got) + } +} + +func TestRunNode_WithUseAsOutput_MultiLevelStampsAllAncestors(t *testing.T) { + // grandchild delegates up through child to the top orchestrator; + // the single output event is stamped for the whole chain. + grandchild := newStubNode("gc", "deep_value") + child := NewDynamicNode[string, string]( + "mid", + func(ctx NodeContext, _ string, _ func(*session.Event) error) (string, error) { + return RunNode[string](ctx, grandchild, nil, WithUseAsOutput()) + }, + NodeConfig{}, + ) + top := NewDynamicNode[string, string]( + "top", + func(ctx NodeContext, _ string, _ func(*session.Event) error) (string, error) { + return RunNode[string](ctx, child, nil, WithUseAsOutput()) + }, + NodeConfig{}, + ) + events := drainDynamic(t, top, "") + // One output event, carried on the grandchild, suppressing both + // delegating ancestors. + if got := outputBearingPaths(events); !reflect.DeepEqual(got, []string{"top/mid@1/gc@1"}) { + t.Errorf("output-bearing paths = %v, want [top/mid@1/gc@1]", got) + } + if got := outputForAtPath(events, "top/mid@1/gc@1"); !reflect.DeepEqual(got, []string{"top/mid@1/gc@1", "top/mid@1", "top"}) { + t.Errorf("OutputFor = %v, want [top/mid@1/gc@1 top/mid@1 top]", got) + } +} + +func TestRunNode_WithUseAsOutput_MessageAsOutputChildBecomesParentOutput(t *testing.T) { + // A delegated child whose message IS its output (NodeInfo. + // MessageAsOutput, no explicit Output — like an LlmAgent node) + // promotes its model text to the parent's terminal Output. + child := newMessageAsOutputNode("c", "child_text") + n := NewDynamicNode[string, string]( + "orch", + func(ctx NodeContext, _ string, _ func(*session.Event) error) (string, error) { + if _, err := RunNode[string](ctx, child, nil, WithUseAsOutput()); err != nil { + return "", err + } + return "parent_value", nil + }, + NodeConfig{}, + ) + events := drainDynamic(t, n, "") + // Full suppression: the child's own event carries the output (via + // MessageAsOutput); the parent emits nothing. + if got, ok := derivedOutputAtPath(events, "orch/c@1"); !ok || got != "child_text" { + t.Errorf("delegated child derived output = %v (ok=%v), want %q", got, ok, "child_text") + } +} + +func TestRunNode_WithUseAsOutput_MessageAsOutputEmptyTextIsValidOutput(t *testing.T) { + // Empty model text under MessageAsOutput is a valid output ("", + // not "no output"), matching adk-python. The parent's terminal + // Output must be the empty string, not nil. + child := newMessageAsOutputNode("c", "") + n := NewDynamicNode[string, string]( + "orch", + func(ctx NodeContext, _ string, _ func(*session.Event) error) (string, error) { + if _, err := RunNode[string](ctx, child, nil, WithUseAsOutput()); err != nil { + return "", err + } + return "parent_value", nil + }, + NodeConfig{}, + ) + events := drainDynamic(t, n, "") + if got, ok := derivedOutputAtPath(events, "orch/c@1"); !ok || got != "" { + t.Errorf("delegated child derived output = %#v (ok=%v), want empty string", got, ok) } } @@ -256,8 +335,10 @@ func TestRunNode_WithRunID_AndUseAsOutput_IdempotentReplay(t *testing.T) { if got := child.runCount(); got != 1 { t.Errorf("child.Run invocations = %d, want 1", got) } - if got := parentTerminalOutput(t, events, "orch"); got != "delegated_value" { - t.Errorf("parent terminal Output = %v, want %q", got, "delegated_value") + // Full suppression: the child's event carries the delegated output; + // the cached replay re-emits nothing and the parent stays silent. + if got, ok := derivedOutputAtPath(events, "orch/c@stable-id"); !ok || got != "delegated_value" { + t.Errorf("delegated child output = %v (ok=%v), want %q", got, ok, "delegated_value") } } @@ -346,6 +427,29 @@ func outputBearingPaths(events []*session.Event) []string { // parentTerminalOutput returns the Output of the last event // stamped with parentPath. +// outputForAtPath returns NodeInfo.OutputFor of the event at nodePath. +func outputForAtPath(events []*session.Event, nodePath string) []string { + for i := len(events) - 1; i >= 0; i-- { + ev := events[i] + if ev.NodeInfo != nil && ev.NodeInfo.Path == nodePath { + return ev.NodeInfo.OutputFor + } + } + return nil +} + +// derivedOutputAtPath returns the output the event at nodePath carries, +// via childEventOutput (explicit Output or MessageAsOutput-derived). +func derivedOutputAtPath(events []*session.Event, nodePath string) (any, bool) { + for i := len(events) - 1; i >= 0; i-- { + ev := events[i] + if ev.NodeInfo != nil && ev.NodeInfo.Path == nodePath { + return childEventOutput(ev) + } + } + return nil, false +} + func parentTerminalOutput(t *testing.T, events []*session.Event, parentPath string) any { t.Helper() for i := len(events) - 1; i >= 0; i-- { diff --git a/workflow/scheduler.go b/workflow/scheduler.go index ed27def6f..c61d809d8 100644 --- a/workflow/scheduler.go +++ b/workflow/scheduler.go @@ -653,9 +653,43 @@ func (s *scheduler) handleEvent(it eventItem) { if it.ev.Routes != nil { nr.setRoutingEvent(it.ev, it.nodeName) } - if it.ev.Output != nil { - nr.setOutput(it.ev.Output, it.nodeName) + if out, ok := childEventOutput(it.ev); ok { + // Validate (and optionally coerce) the output against the node's + // output schema before it is committed to the accumulator and + // forwarded to the consumer. Events without output bypass + // validation entirely. + validated, err := s.validateNodeOutput(it.nodeName, out) + if err != nil { + nr.recordErr(err) + return + } + // Write the validated value back onto the event when it is + // carried via Event.Output. Outputs derived from model text + // (MessageAsOutput) are not stamped back onto Event.Output. + if it.ev.Output != nil { + it.ev.Output = validated + } + nr.setOutput(validated, it.nodeName) + } +} + +// validateNodeOutput invokes ValidateOutput on the node identified by +// nodeName for the given output value. On validation failure the +// returned error is wrapped with the node name to aid debugging; the +// caller records it on the node-run accumulator so handleCompletion +// surfaces it as a NodeFailed transition. +func (s *scheduler) validateNodeOutput(nodeName string, out any) (any, error) { + n := s.nodesByName[nodeName] + if n == nil { + // handleEvent only runs for registered nodes; a miss means the + // registry is out of sync. Fail rather than forward unvalidated. + return nil, fmt.Errorf("output validation: node %q not found in graph", nodeName) + } + validated, err := n.ValidateOutput(out) + if err != nil { + return nil, fmt.Errorf("output validation failed for node %q: %w", nodeName, err) } + return validated, nil } // handleCompletion finalises a node's run: transitions its lifecycle diff --git a/workflow/scheduler_test.go b/workflow/scheduler_test.go index 8b1b2ff2e..6eaf1c352 100644 --- a/workflow/scheduler_test.go +++ b/workflow/scheduler_test.go @@ -20,6 +20,7 @@ import ( "fmt" "iter" "sort" + "strings" "sync/atomic" "testing" "time" @@ -60,6 +61,28 @@ func TestScheduler_LinearChain(t *testing.T) { } } +// TestScheduler_MessageAsOutput_FeedsSuccessor verifies that a node +// whose message IS its output (NodeInfo.MessageAsOutput, no explicit +// Event.Output) has its model text derived as the node output and fed +// to the successor as input. +func TestScheduler_MessageAsOutput_FeedsSuccessor(t *testing.T) { + mockCtx := newSeededMockCtx(t) + + a := newMessageAsOutputNode("A", "hello") + b := newRecordingNode("B") + b.release() + + w := mustNew(t, Chain(Start, a, b)) + + gotEvents := drain(t, w.Run(mockCtx)) + + // B echoes ":B"; the input must be A's derived output. + got := outputsOf(gotEvents) + if len(got) != 1 || got[0] != "hello:B" { + t.Errorf("outputs = %v, want [\"hello:B\"] (A's message text fed to B)", got) + } +} + // TestScheduler_FanOutConcurrency verifies that three nodes // downstream of START are mid-Run simultaneously, not serialised by // the legacy BFS. Each node blocks on its release channel until the @@ -735,3 +758,117 @@ func (n *validationTestNode) Run(ctx agent.InvocationContext, input any) iter.Se yield(ev, nil) } } + +// TestScheduler_ValidateOutput_ValidPasses verifies that a node whose +// yielded output conforms to its output_schema is forwarded unchanged. +func TestScheduler_ValidateOutput_ValidPasses(t *testing.T) { + mockCtx := newSeededMockCtx(t) + schema := resolveTestSchema[testSchemaInput](t) + n := newSchemaValidatedNode("n", schema, map[string]any{"value": "hello"}) + + w := mustNew(t, []Edge{{From: Start, To: n}}) + + events := drain(t, w.Run(mockCtx)) + if got, want := len(events), 1; got != want { + t.Fatalf("event count = %d, want %d", got, want) + } + gotMap, ok := events[0].Output.(map[string]any) + if !ok { + t.Fatalf("Output type = %T, want map[string]any", events[0].Output) + } + if gotMap["value"] != "hello" { + t.Errorf("Output[value] = %v, want %q", gotMap["value"], "hello") + } +} + +// TestScheduler_ValidateOutput_InvalidEndsActivation verifies that a +// node yielding output that fails its output_schema surfaces a +// validation error and does not transition to NodeCompleted. +func TestScheduler_ValidateOutput_InvalidEndsActivation(t *testing.T) { + mockCtx := newSeededMockCtx(t) + schema := resolveTestSchema[testSchemaInput](t) + n := newSchemaValidatedNode("n", schema, map[string]any{"value": 123}) + + w := mustNew(t, []Edge{{From: Start, To: n}}) + + gotErr := drainErr(t, w.Run(mockCtx)) + if gotErr == nil { + t.Fatal("expected validation error, got nil") + } + if wantSubstr := `output validation failed for node "n"`; !strings.Contains(gotErr.Error(), wantSubstr) { + t.Errorf("error = %q, want substring %q", gotErr.Error(), wantSubstr) + } +} + +// TestScheduler_ValidateOutput_NoOutputSkipsValidation verifies that +// events without Output (progress events) are forwarded without +// invoking ValidateOutput, even under a schema that would reject nil. +func TestScheduler_ValidateOutput_NoOutputSkipsValidation(t *testing.T) { + mockCtx := newSeededMockCtx(t) + schema := resolveTestSchema[testSchemaInput](t) + n := &progressThenSchemaOutputNode{ + BaseNode: NewBaseNodeWithSchemas("n", "", NodeConfig{}, nil, schema), + progress: 3, + output: map[string]any{"value": "hello"}, + } + + w := mustNew(t, []Edge{{From: Start, To: n}}) + + events := drain(t, w.Run(mockCtx)) + if got, want := len(events), 4; got != want { + t.Fatalf("event count = %d, want %d", got, want) + } + for i := 0; i < 3; i++ { + if events[i].Output != nil { + t.Errorf("event %d Output = %v, want nil (progress)", i, events[i].Output) + } + } + if events[3].Output == nil { + t.Errorf("last event Output = nil, want validated map") + } +} + +// schemaValidatedNode yields one event whose Output is the supplied +// value; its BaseNode carries an output schema so the scheduler runs +// ValidateOutput on the yielded value. +type schemaValidatedNode struct { + BaseNode + output any +} + +func newSchemaValidatedNode(name string, schema *jsonschema.Resolved, output any) *schemaValidatedNode { + return &schemaValidatedNode{ + BaseNode: NewBaseNodeWithSchemas(name, "", NodeConfig{}, nil, schema), + output: output, + } +} + +func (n *schemaValidatedNode) Run(ctx agent.InvocationContext, _ any) iter.Seq2[*session.Event, error] { + return func(yield func(*session.Event, error) bool) { + ev := session.NewEvent(ctx.InvocationID()) + ev.Output = n.output + yield(ev, nil) + } +} + +// progressThenSchemaOutputNode yields `progress` output-less events +// followed by one carrying `output`, to verify the scheduler skips +// ValidateOutput on output-less events. +type progressThenSchemaOutputNode struct { + BaseNode + progress int + output any +} + +func (n *progressThenSchemaOutputNode) Run(ctx agent.InvocationContext, _ any) iter.Seq2[*session.Event, error] { + return func(yield func(*session.Event, error) bool) { + for i := 0; i < n.progress; i++ { + if !yield(session.NewEvent(ctx.InvocationID()), nil) { + return + } + } + ev := session.NewEvent(ctx.InvocationID()) + ev.Output = n.output + yield(ev, nil) + } +}