diff --git a/agent/workflowagent/hitl_test.go b/agent/workflowagent/hitl_test.go index aa529cb6f..e2ff07dd6 100644 --- a/agent/workflowagent/hitl_test.go +++ b/agent/workflowagent/hitl_test.go @@ -26,6 +26,7 @@ import ( "google.golang.org/genai" "google.golang.org/adk/agent" + "google.golang.org/adk/model" "google.golang.org/adk/session" "google.golang.org/adk/workflow" ) @@ -336,18 +337,20 @@ func TestWorkflowAgent_RunThenResume_DynamicNodeOrchestrator(t *testing.T) { // Test fixtures and helpers // ============================================================================= -// fakeSession is a minimal session.Session implementation that -// faithfully models the AppendEvent-gated state contract used by -// session.InMemoryService and persistent backends: state mutations -// are applied only when applyStateDelta is called for an event -// (the unit-test analogue of session.Service.AppendEvent), never -// via direct State.Set from inside the agent. +// fakeSession is a minimal session.Session that records appended +// events, as the real session services do. HITL resume reconstructs +// paused state from this event history (Workflow.ReconstructRunState), +// so the test must append every yielded event — and the inbound user +// FunctionResponse on a resume turn — into the session. // -// drainAgent (this file) calls applyStateDelta for every event the -// agent yields, simulating what the runner does in production. +// drainAgent (this file) appends every event the agent yields, and +// appendUserMessage records the inbound resume message, together +// simulating what the runner does in production. type fakeSession struct { session.Session - state *fakeSessionState + state *fakeSessionState + mu sync.Mutex + events []*session.Event } func newFakeSession() *fakeSession { @@ -357,18 +360,55 @@ func newFakeSession() *fakeSession { func (s *fakeSession) ID() string { return "test-session-id" } func (s *fakeSession) State() session.State { return s.state } -// applyStateDelta merges any Actions.StateDelta on the supplied -// event into the underlying state map. Mirrors what -// inMemoryService.AppendEvent does for session-scoped (no -// app:/user:/temp: prefix) keys; HITL persistence uses such keys. -func (s *fakeSession) applyStateDelta(ev *session.Event) { - if ev == nil || len(ev.Actions.StateDelta) == 0 { +func (s *fakeSession) Events() session.Events { + s.mu.Lock() + defer s.mu.Unlock() + return fakeEvents(append([]*session.Event(nil), s.events...)) +} + +// appendEvent records an event in history (the test analogue of +// session.Service.AppendEvent) and applies any StateDelta. +func (s *fakeSession) appendEvent(ev *session.Event) { + if ev == nil { + return + } + s.mu.Lock() + s.events = append(s.events, ev) + s.mu.Unlock() + if len(ev.Actions.StateDelta) > 0 { + s.state.mu.Lock() + for k, v := range ev.Actions.StateDelta { + s.state.m[k] = v + } + s.state.mu.Unlock() + } +} + +// appendUserMessage records an inbound user message as a "user" +// event so a resume turn's FunctionResponse is visible to +// ReconstructRunState, mirroring the runner appending the user turn. +func (s *fakeSession) appendUserMessage(msg *genai.Content) { + if msg == nil { return } - s.state.mu.Lock() - defer s.state.mu.Unlock() - for k, v := range ev.Actions.StateDelta { - s.state.m[k] = v + ev := session.NewEvent("test-invocation-id") + ev.Author = "user" + ev.LLMResponse = model.LLMResponse{Content: msg} + s.appendEvent(ev) +} + +// fakeEvents is a session.Events over a fixed slice. +type fakeEvents []*session.Event + +func (e fakeEvents) Len() int { return len(e) } +func (e fakeEvents) At(i int) *session.Event { return e[i] } +func (e fakeEvents) All() iter.Seq[*session.Event] { + return func(yield func(*session.Event) bool) { + for _, ev := range e { + if !yield(ev) { + return + } + } } } @@ -461,6 +501,12 @@ func makeAgent(t *testing.T, edges []workflow.Edge) agent.Agent { // pause/resume round-trips through fakeSessionState as they would // in production. func newMockCtx(sess session.Session, agt agent.Agent, msg *genai.Content) *MockInvocationContext { + // Append the inbound user turn to history first, as the runner + // does in production, so a resume turn's FunctionResponse is + // visible to ReconstructRunState. + if fs, ok := sess.(*fakeSession); ok { + fs.appendUserMessage(msg) + } return &MockInvocationContext{ Context: context.TODO(), sess: sess, @@ -469,13 +515,11 @@ func newMockCtx(sess session.Session, agt agent.Agent, msg *genai.Content) *Mock } } -// drainAgent consumes the agent's iter.Seq2, collecting events, -// and applies each event's StateDelta to sess. The applyStateDelta -// step replaces the AppendEvent-side state propagation that the -// real runner performs; without it state writes from the agent -// would never become visible to subsequent calls. Fails the test -// if the iterator yields a non-nil error that the test did not -// opt into via wantErr. +// drainAgent consumes the agent's iter.Seq2, collecting events and +// appending each to sess. The append step is the test analogue of +// the runner's AppendEvent: it builds the session history that the +// next turn's ReconstructRunState reads. Fails the test if the +// iterator yields a non-nil error the test did not opt into. func drainAgent(t *testing.T, sess *fakeSession, seq iter.Seq2[*session.Event, error], wantErr error) []*session.Event { t.Helper() var got []*session.Event @@ -488,7 +532,7 @@ func drainAgent(t *testing.T, sess *fakeSession, seq iter.Seq2[*session.Event, e continue } got = append(got, ev) - sess.applyStateDelta(ev) + sess.appendEvent(ev) } switch { case wantErr == nil && sawErr != nil: diff --git a/agent/workflowagent/workflow.go b/agent/workflowagent/workflow.go index 5fbcc7610..ebb32da6e 100644 --- a/agent/workflowagent/workflow.go +++ b/agent/workflowagent/workflow.go @@ -95,7 +95,12 @@ type workflowAgent struct { // Workflow.Run (every other turn). func (a *workflowAgent) run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { return func(yield func(*session.Event, error) bool) { - if responses, state, ok := a.detectResume(ctx); ok { + responses, state, ok, err := a.detectResume(ctx) + if err != nil { + yield(nil, err) + return + } + if ok { for ev, err := range a.workflow.Resume(ctx, state, responses) { if !yield(ev, err) { return @@ -116,10 +121,10 @@ func (a *workflowAgent) run(ctx agent.InvocationContext) iter.Seq2[*session.Even // responses map keyed by InterruptID (suitable for // Workflow.Resume), the RunState loaded from session, and true if // this turn is a resume; (nil, nil, false) for a fresh turn. -func (a *workflowAgent) detectResume(ctx agent.InvocationContext) (map[string]any, *workflow.RunState, bool) { +func (a *workflowAgent) detectResume(ctx agent.InvocationContext) (map[string]any, *workflow.RunState, bool, error) { frs := utils.FunctionResponses(ctx.UserContent()) if len(frs) == 0 { - return nil, nil, false + return nil, nil, false, nil } responses := map[string]any{} @@ -130,17 +135,20 @@ func (a *workflowAgent) detectResume(ctx agent.InvocationContext) (map[string]an responses[fr.ID] = decodeWorkflowInputResponse(fr) } if len(responses) == 0 { - return nil, nil, false + return nil, nil, false, nil } - state, err := workflow.LoadRunState(ctx.Session(), a.workflow.Name()) - if err != nil || state == nil { - // No persisted state means there is nothing to resume; - // fall through to a fresh Workflow.Run. - return nil, nil, false + state, err := a.workflow.ReconstructRunState(ctx.Session()) + if err != nil { + // A bad resume (e.g. failed schema validation) must fail, + // not silently fall through to a fresh Run. + return nil, nil, false, err + } + if state == nil { + return nil, nil, false, nil } - return responses, state, true + return responses, state, true, nil } // decodeWorkflowInputResponse extracts the user-supplied payload diff --git a/session/database/service_test.go b/session/database/service_test.go index 90de7adcb..3898e9cb2 100644 --- a/session/database/service_test.go +++ b/session/database/service_test.go @@ -31,6 +31,49 @@ func Test_databaseService(t *testing.T) { }) } +// TestDatabaseService_AppendEvent_WorkflowFieldsRoundTrip guards that +// the storage layer serializes/deserializes the workflow event fields +// (NodeInfo, RequestedInput, Routes); dropping them breaks HITL resume. +func TestDatabaseService_AppendEvent_WorkflowFieldsRoundTrip(t *testing.T) { + ctx := t.Context() + s := emptyService(t) + + created, err := s.Create(ctx, &session.CreateRequest{AppName: "app", UserID: "user"}) + if err != nil { + t.Fatalf("Create: %v", err) + } + + event := &session.Event{ + ID: "wf_event", + Author: "agent", + NodeInfo: &session.NodeInfo{Path: "ask_name"}, + RequestedInput: &session.RequestInput{InterruptID: "ask_name", Message: "What's your name?"}, + Routes: []string{"route_a"}, + } + if err := s.AppendEvent(ctx, created.Session, event); err != nil { + t.Fatalf("AppendEvent: %v", err) + } + + got, err := s.Get(ctx, &session.GetRequest{AppName: "app", UserID: "user", SessionID: created.Session.ID()}) + if err != nil { + t.Fatalf("Get: %v", err) + } + evs := got.Session.Events() + if evs.Len() != 1 { + t.Fatalf("got %d events, want 1", evs.Len()) + } + ev := evs.At(0) + if ev.NodeInfo == nil || ev.NodeInfo.Path != "ask_name" { + t.Errorf("NodeInfo not persisted: %#v", ev.NodeInfo) + } + if ev.RequestedInput == nil || ev.RequestedInput.InterruptID != "ask_name" { + t.Errorf("RequestedInput not persisted: %#v", ev.RequestedInput) + } + if len(ev.Routes) != 1 || ev.Routes[0] != "route_a" { + t.Errorf("Routes not persisted: %#v", ev.Routes) + } +} + func emptyService(t *testing.T) *databaseService { t.Helper() gormConfig := &gorm.Config{ diff --git a/session/database/storage_session.go b/session/database/storage_session.go index 548e393c2..f8cd51e8b 100644 --- a/session/database/storage_session.go +++ b/session/database/storage_session.go @@ -81,6 +81,8 @@ type storageEvent struct { LongRunningToolIDsJSON dynamicJSON RoutesJSON dynamicJSON OutputJSON dynamicJSON + NodeInfoJSON dynamicJSON + RequestedInputJSON dynamicJSON Branch *string Timestamp time.Time `gorm:"precision:6"` @@ -153,6 +155,22 @@ func createStorageEvent(session session.Session, event *session.Event) (*storage storageEv.OutputJSON = outputJSON } + if event.NodeInfo != nil { + nodeInfoJSON, err := json.Marshal(event.NodeInfo) + if err != nil { + return nil, fmt.Errorf("failed to marshal node info: %w", err) + } + storageEv.NodeInfoJSON = nodeInfoJSON + } + + if event.RequestedInput != nil { + reqInputJSON, err := json.Marshal(event.RequestedInput) + if err != nil { + return nil, fmt.Errorf("failed to marshal requested input: %w", err) + } + storageEv.RequestedInputJSON = reqInputJSON + } + // Handle optional fields by taking the address of the value. // An empty string from the event becomes a nil pointer in storage. if event.Branch != "" { @@ -282,6 +300,20 @@ func createEventFromStorageEvent(se *storageEvent) (*session.Event, error) { } } + var nodeInfo *session.NodeInfo + if se.NodeInfoJSON != nil { + if err := json.Unmarshal([]byte(se.NodeInfoJSON), &nodeInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal node info: %w", err) + } + } + + var requestedInput *session.RequestInput + if se.RequestedInputJSON != nil { + if err := json.Unmarshal([]byte(se.RequestedInputJSON), &requestedInput); err != nil { + return nil, fmt.Errorf("failed to unmarshal requested input: %w", err) + } + } + // --- Handle simple pointer fields (dereference or use zero value) --- // Use the helper to safely get the value or its zero-value default branch := derefOrZero(se.Branch) @@ -302,6 +334,8 @@ func createEventFromStorageEvent(se *storageEvent) (*session.Event, error) { Routes: routes, Branch: branch, Output: output, + NodeInfo: nodeInfo, + RequestedInput: requestedInput, LLMResponse: model.LLMResponse{ Content: content, GroundingMetadata: groundingMetadata, diff --git a/session/inmemory.go b/session/inmemory.go index 46c509c56..be8bf11a4 100644 --- a/session/inmemory.go +++ b/session/inmemory.go @@ -238,8 +238,11 @@ func (s *inMemoryService) AppendEvent(ctx context.Context, curSession Session, e SkipSummarization: event.Actions.SkipSummarization, }, LongRunningToolIDs: slices.Clone(event.LongRunningToolIDs), + Routes: slices.Clone(event.Routes), + RequestedInput: event.RequestedInput, LLMResponse: event.LLMResponse, Output: event.Output, + NodeInfo: event.NodeInfo, } // update the in-memory session service diff --git a/session/inmemory_test.go b/session/inmemory_test.go index 1d178be89..82fbed105 100644 --- a/session/inmemory_test.go +++ b/session/inmemory_test.go @@ -78,6 +78,51 @@ func Test_inMemoryService_CreateConcurrentAccess(t *testing.T) { } } +// TestInMemorySession_AppendEvent_WorkflowFieldsRoundTrip guards that +// AppendEvent persists the workflow event fields (NodeInfo, +// RequestedInput, Routes) — workflow resume rehydrates node state from +// them, and a manual event copy that drops them breaks HITL resume. +func TestInMemorySession_AppendEvent_WorkflowFieldsRoundTrip(t *testing.T) { + ctx := t.Context() + service := session.InMemoryService() + + createResp, err := service.Create(ctx, &session.CreateRequest{AppName: "app", UserID: "user"}) + if err != nil { + t.Fatalf("Create: %v", err) + } + sess := createResp.Session + + event := &session.Event{ + ID: "wf_event", + Author: "agent", + NodeInfo: &session.NodeInfo{Path: "ask_name"}, + RequestedInput: &session.RequestInput{InterruptID: "ask_name", Message: "What's your name?"}, + Routes: []string{"route_a"}, + } + if err := service.AppendEvent(ctx, sess, event); err != nil { + t.Fatalf("AppendEvent: %v", err) + } + + got, err := service.Get(ctx, &session.GetRequest{AppName: "app", UserID: "user", SessionID: sess.ID()}) + if err != nil { + t.Fatalf("Get: %v", err) + } + evs := got.Session.Events() + if evs.Len() != 1 { + t.Fatalf("got %d events, want 1", evs.Len()) + } + ev := evs.At(0) + if ev.NodeInfo == nil || ev.NodeInfo.Path != "ask_name" { + t.Errorf("NodeInfo not persisted: %#v", ev.NodeInfo) + } + if ev.RequestedInput == nil || ev.RequestedInput.InterruptID != "ask_name" { + t.Errorf("RequestedInput not persisted: %#v", ev.RequestedInput) + } + if len(ev.Routes) != 1 || ev.Routes[0] != "route_a" { + t.Errorf("Routes not persisted: %#v", ev.Routes) + } +} + func TestInMemorySession_AppendEvent_Deadlock(t *testing.T) { ctx := t.Context() service := session.InMemoryService() diff --git a/workflow/dynamic_scheduler.go b/workflow/dynamic_scheduler.go index 469663daf..f7c383fb1 100644 --- a/workflow/dynamic_scheduler.go +++ b/workflow/dynamic_scheduler.go @@ -41,13 +41,41 @@ type dynamicSubScheduler struct { } func newDynamicSubScheduler(parent NodeContext, parentPath string, emitUp func(*session.Event) error) *dynamicSubScheduler { - return &dynamicSubScheduler{ + s := &dynamicSubScheduler{ parentPath: parentPath, parentCtx: parent, emitUp: emitUp, runCountByChild: map[string]int{}, resultByPath: map[string]any{}, } + s.rehydrateCache() + return s +} + +// rehydrateCache repopulates resultByPath from session history so a +// resumed orchestrator (which re-runs from the top) serves already +// completed children from cache instead of re-executing them. Each +// child's terminal event carries its childPath in NodeInfo.Path and a +// non-nil Output; keyed by childPath, so only stable WithRunID calls +// hit (auto-counter ids regenerate per activation and miss). +func (s *dynamicSubScheduler) rehydrateCache() { + sess := s.parentCtx.Session() + if sess == nil { + return + } + prefix := s.parentPath + "/" + s.mu.Lock() + defer s.mu.Unlock() + for ev := range sess.Events().All() { + if ev == nil || ev.Output == nil || ev.NodeInfo == nil { + continue + } + if !strings.HasPrefix(ev.NodeInfo.Path, prefix) { + continue + } + // Last write wins, matching live execution order. + s.resultByPath[ev.NodeInfo.Path] = ev.Output + } } // runNode executes child once and classifies the outcome: HITL → diff --git a/workflow/hitl_test.go b/workflow/hitl_test.go index 3a3fff7e9..0ef5b19a7 100644 --- a/workflow/hitl_test.go +++ b/workflow/hitl_test.go @@ -157,12 +157,11 @@ func TestScheduler_HitlNode_PreservesExplicitInterruptID(t *testing.T) { } } -// TestScheduler_HitlNode_MultipleRequestsFails verifies the -// single-request-per-activation invariant: a node yielding two -// RequestedInput events surfaces ErrMultipleInputRequests at -// completion and is treated as failed, so it does not silently -// park in NodeWaiting. -func TestScheduler_HitlNode_MultipleRequestsFails(t *testing.T) { +// TestScheduler_HitlNode_MultipleRequestsPark verifies a node may +// raise more than one interrupt in a single activation: both are +// recorded on NodeState.Interrupts and the node parks NodeWaiting +// (matching adk-python, which accumulates a set of interrupt IDs). +func TestScheduler_HitlNode_MultipleRequestsPark(t *testing.T) { mockCtx := newSeededMockCtx(t) asker := newHitlNode("asker", func(ctx agent.InvocationContext, _ any, yield func(*session.Event, error) bool) { @@ -172,9 +171,19 @@ func TestScheduler_HitlNode_MultipleRequestsFails(t *testing.T) { w := mustNew(t, []Edge{{From: Start, To: asker}}) - gotErr := drainErr(t, w.Run(mockCtx)) - if !errors.Is(gotErr, ErrMultipleInputRequests) { - t.Errorf("Run error = %v, want ErrMultipleInputRequests (node must fail, not park)", gotErr) + // drain fails the test on any error; multiple interrupts must + // park cleanly rather than surface an error. Both pause events + // carry their interrupt on LongRunningToolIDs — the signal the + // scheduler accumulates and history rehydration reads back. + events := drain(t, w.Run(mockCtx)) + got := map[string]bool{} + for _, ev := range events { + for _, id := range ev.LongRunningToolIDs { + got[id] = true + } + } + if !got["first"] || !got["second"] { + t.Errorf("long-running interrupts = %v, want both first and second", got) } } diff --git a/workflow/persistence.go b/workflow/persistence.go index 268fe2fa8..cfc9dfb2c 100644 --- a/workflow/persistence.go +++ b/workflow/persistence.go @@ -16,86 +16,457 @@ package workflow import ( "encoding/json" - "errors" "fmt" + "strings" + + "github.com/google/jsonschema-go/jsonschema" + "google.golang.org/genai" "google.golang.org/adk/session" ) -// runStateSessionKeyPrefix is the prefix used by RunStateSessionKey -// for namespacing workflow RunStates inside session.State. -const runStateSessionKeyPrefix = "adk.workflow.runstate." - -// RunStateSessionKey returns the session.State key under which a -// workflow's RunState is persisted between invocations. Namespaced -// by workflow name so multiple workflows in the same session do -// not collide. -func RunStateSessionKey(workflowName string) string { - return runStateSessionKeyPrefix + workflowName -} - -// LoadRunState reads and decodes the workflow's persisted RunState -// from the given session. Returns (nil, nil) when no state has -// been stored yet, so callers can distinguish "nothing to resume" -// from "load failed". An empty workflowName disables persistence -// and always returns (nil, nil). -func LoadRunState(sess session.Session, workflowName string) (*RunState, error) { - if sess == nil || workflowName == "" { +// nodeScanState accumulates, per node, what the session history says +// about a paused run. Mirrors adk-python's _ChildScanState. +type nodeScanState struct { + // interrupts are the long-running tool IDs the node raised + // (insertion-ordered for stable reconstruction). + interrupts []string + seen map[string]struct{} + // resolved maps an interrupt ID to the (last) user response. + resolved map[string]any + // resolvedCount maps an interrupt ID to how many user + // FunctionResponse events in history resolved it. 1 means the + // response arrived this turn for the first time; >1 means a + // duplicate resume replayed an already-consumed response. Lets + // Resume tell a genuine first resume from an idempotent no-op. + resolvedCount map[string]int + // schemas maps an interrupt ID to its declared response schema, + // re-extracted from the pause FunctionCall args. + schemas map[string]*jsonschema.Schema + branch string +} + +func (s *nodeScanState) addInterrupt(id string) { + if s.seen == nil { + s.seen = map[string]struct{}{} + } + if _, ok := s.seen[id]; ok { + return + } + s.seen[id] = struct{}{} + s.interrupts = append(s.interrupts, id) +} + +// ReconstructRunState rebuilds the paused RunState by scanning session +// history instead of loading a persisted blob, mirroring adk-python's +// rehydration (workflow/utils/_rehydration_utils.py: +// _reconstruct_node_states + _workflow.py:_infer_node_state). +// +// For each node it collects the long-running interrupts it raised +// (Event.LongRunningToolIDs, attributed by event node path), the user +// FunctionResponses that resolved them, and each interrupt's declared +// response schema. inferNodeState then maps that scan to a NodeState +// (WAITING / PENDING+ResumedInputs / COMPLETED+Output). Returns +// (nil, nil) when no node has interrupt history. +func (w *Workflow) ReconstructRunState(sess session.Session) (*RunState, error) { + if sess == nil { return nil, nil } - state := sess.State() + nodesByName := buildNodesByName(w.graph) + events := sess.Events() + + // Stage 1: scan history into a per-node view of the pause + // (interrupts raised, responses that resolved them, schemas). + scans := scanHistory(events, nodesByName) + + // Stage 2: gather the inputs inferNodeState needs to rebuild a + // re-entry node's input: every node's cached output, the set of + // nodes that ran, and the workflow's seed input. + nodeOutputs, completed := collectNodeOutputs(events, nodesByName) + workflowInput := firstUserInput(events) + + // Stage 3: turn each interrupted node's scan into a NodeState. + state, err := w.buildRunState(scans, nodesByName, nodeOutputs, workflowInput) + if err != nil { + return nil, err + } if state == nil { return nil, nil } - raw, err := state.Get(RunStateSessionKey(workflowName)) - if err != nil { - if errors.Is(err, session.ErrStateKeyNotExist) { - return nil, nil + + // WAITING nodes have not finished, so Resume must not treat them + // as already-run; the rest stay in completed to skip their + // successors. + for name, ns := range state.Nodes { + if ns.Status == NodeWaiting { + delete(completed, name) } - return nil, err } + state.completed = completed + return state, nil +} + +// scanHistory walks session events once and returns, per static graph +// node, what history says about a paused run: the long-running +// interrupts it raised, the user responses that resolved them, and +// each interrupt's declared response schema. Only nodes with +// interrupt history are returned. +func scanHistory(events session.Events, nodesByName map[string]Node) map[string]*nodeScanState { + scans := map[string]*nodeScanState{} + interruptOwner := map[string]string{} // interrupt ID -> node name + scanFor := func(name string) *nodeScanState { + s := scans[name] + if s == nil { + s = &nodeScanState{resolved: map[string]any{}, resolvedCount: map[string]int{}, schemas: map[string]*jsonschema.Schema{}} + scans[name] = s + } + return s + } + + for i := 0; i < events.Len(); i++ { + ev := events.At(i) + if ev == nil { + continue + } + + // A user FunctionResponse resolves an interrupt — not the + // tool's own initial "pending" response (authored by the + // node). Mirrors adk-python's event.author == 'user' gate. + // Last response per interrupt wins, so a retry after a + // rejected payload supersedes the earlier one. + if ev.Author == "user" && ev.Content != nil { + for _, p := range ev.Content.Parts { + fr := frPart(p) + if fr == nil { + continue + } + owner, ok := interruptOwner[fr.ID] + if !ok { + continue + } + sf := scanFor(owner) + sf.resolved[fr.ID] = unwrapResponse(fr.Response) + sf.resolvedCount[fr.ID]++ + } + continue + } + + // Interrupts the node raised, attributed to the static graph + // node that emitted the event (NodeInfo.Path; dynamic children + // fold into their static ancestor — see eventNodeName). + owner := eventNodeName(ev) + if _, ok := nodesByName[owner]; !ok { + continue + } + s := scanFor(owner) + if ev.Output != nil { + s.branch = ev.Branch + } + for _, id := range ev.LongRunningToolIDs { + if id == "" { + continue + } + s.addInterrupt(id) + if s.branch == "" { + s.branch = ev.Branch + } + interruptOwner[id] = owner + if sc := schemaFromEvent(ev, id); sc != nil { + s.schemas[id] = sc + } + } + } + return scans +} + +// collectNodeOutputs walks history once and returns each graph node's +// last cached output plus the set of nodes that emitted any event. +// The outputs feed predecessor-input reconstruction for re-entry +// nodes; completed lets Resume skip already-run successors. +func collectNodeOutputs(events session.Events, nodesByName map[string]Node) (outputs map[string]any, completed map[string]bool) { + outputs = map[string]any{} + completed = map[string]bool{} + for i := 0; i < events.Len(); i++ { + ev := events.At(i) + if ev == nil { + continue + } + name := eventNodeName(ev) + if _, ok := nodesByName[name]; !ok { + continue + } + completed[name] = true + if ev.Output != nil { + outputs[name] = ev.Output + } + } + return outputs, completed +} + +// buildRunState maps each interrupted node's scan to a NodeState via +// inferNodeState. Returns (nil, nil) when no node has interrupt +// history, matching the "nothing to resume" case. +func (w *Workflow) buildRunState(scans map[string]*nodeScanState, nodesByName map[string]Node, nodeOutputs map[string]any, workflowInput any) (*RunState, error) { + var state *RunState + for nodeName, scan := range scans { + if len(scan.interrupts) == 0 { + continue + } + ns, err := w.inferNodeState(nodesByName[nodeName], scan, nodeOutputs, workflowInput) + if err != nil { + return nil, err + } + if ns == nil { + continue + } + if state == nil { + state = NewRunState() + } + state.Nodes[nodeName] = ns + } + return state, nil +} + +// unresolvedInterrupts returns the interrupts the node raised that no +// user response has resolved yet, preserving insertion order. +func unresolvedInterrupts(scan *nodeScanState) []string { + unresolved := make([]string, 0, len(scan.interrupts)) + for _, id := range scan.interrupts { + if _, done := scan.resolved[id]; !done { + unresolved = append(unresolved, id) + } + } + return unresolved +} + +// rerunsOnResume reports whether the node opted into re-entry mode +// (NodeConfig.RerunOnResume), in which Resume re-runs the node with +// the user responses rather than handing off to its successors. +func rerunsOnResume(node Node) bool { + if node == nil { + return false + } + r := node.Config().RerunOnResume + return r != nil && *r +} - // raw is JSON-encoded []byte (or its base64-string form when - // the session backend round-trips StateDelta through JSON). - decode := func(b []byte) (*RunState, error) { - var state RunState - if err := json.Unmarshal(b, &state); err != nil { - return nil, fmt.Errorf("workflow: decode run state: %w", err) +// validateResolved validates each surviving (last-wins) response +// against its declared schema and returns the responses keyed by +// interrupt ID. A superseded invalid payload never reaches here. +func validateResolved(scan *nodeScanState) (map[string]any, error) { + resumed := map[string]any{} + for id, resp := range scan.resolved { + if sc := scan.schemas[id]; sc != nil { + validated, err := validateResumeResponse(resp, sc) + if err != nil { + return nil, fmt.Errorf("%w: interrupt %q: %w", ErrInvalidResumeResponse, id, err) + } + resp = validated } - return &state, nil + resumed[id] = resp + } + return resumed, nil +} + +// inferNodeState maps a node's scan to a NodeState, mirroring +// adk-python _infer_node_state. +// +// Status priority: +// - unresolved interrupts, re-run + some resolved -> NodePending +// (partial resume: re-run with the resolved responses) +// - unresolved interrupts otherwise -> NodeWaiting +// - all resolved, re-run -> NodePending (re-entry) +// - all resolved, handoff -> NodeCompleted +// with Output = the response (forwarded to successors by Resume) +func (w *Workflow) inferNodeState(node Node, scan *nodeScanState, nodeOutputs map[string]any, workflowInput any) (*NodeState, error) { + unresolved := unresolvedInterrupts(scan) + reenter := rerunsOnResume(node) + + resumed, err := validateResolved(scan) + if err != nil { + return nil, err } - switch v := raw.(type) { - case []byte: - return decode(v) - case string: - return decode([]byte(v)) + + ns := &NodeState{Branch: scan.branch, interruptSchemas: scan.schemas} + + switch { + case len(unresolved) > 0 && reenter && len(resumed) > 0: + // Partial resume: re-run with resolved responses so the node + // can proceed or re-interrupt. + ns.Status = NodePending + ns.ResumedInputs = resumed + ns.Interrupts = unresolved + ns.Input, ns.TriggeredBy = w.predecessorInput(node, nodeOutputs, workflowInput) + case len(unresolved) > 0: + // Still waiting for the remaining interrupts. + ns.Status = NodeWaiting + ns.Interrupts = unresolved + if len(resumed) > 0 { + ns.ResumedInputs = resumed + } + case reenter: + // All resolved, re-entry: re-run with the responses. + ns.Status = NodePending + ns.ResumedInputs = resumed + ns.Input, ns.TriggeredBy = w.predecessorInput(node, nodeOutputs, workflowInput) default: - return nil, fmt.Errorf("workflow: run state has unexpected type %T (want []byte or string)", raw) + // All resolved, handoff: the node is done; its output is the + // response, which Resume forwards to successors. Keep the + // resolved responses so Resume can gate the idempotent + // successor trigger on this turn's responses. + ns.Status = NodeCompleted + ns.Output = resumeOutput(resumed) + ns.ResumedInputs = resumed + // A response seen for the first time this turn (count == 1) + // marks a genuine first resume; a duplicate turn replays an + // already-counted response (>= 2) and must stay a no-op. + for id := range resumed { + if scan.resolvedCount[id] == 1 { + ns.answeredThisTurn = true + break + } + } + } + return ns, nil +} + +// predecessorInput walks incoming edges backward to find a resuming +// node's input: a predecessor's cached output, else the workflow seed +// input for a START successor. Mirrors adk-python +// _find_predecessor_input. +func (w *Workflow) predecessorInput(node Node, nodeOutputs map[string]any, workflowInput any) (any, string) { + if node == nil { + return nil, "" + } + incoming := w.graph.predecessorsOf(node) + if len(incoming) == 0 { + return nil, "" + } + for _, e := range incoming { + from := e.From.Name() + if from != Start.Name() { + if out, ok := nodeOutputs[from]; ok { + return out, from + } + } + } + for _, e := range incoming { + if e.From.Name() == Start.Name() { + return workflowInput, Start.Name() + } + } + return nodeOutputs[incoming[0].From.Name()], incoming[0].From.Name() +} + +// firstUserInput returns the seed workflow input: the text of the +// first user event in history (the original prompt), used as the +// START successor's input on re-entry. Resume turns (user +// FunctionResponses) are skipped. +func firstUserInput(events session.Events) any { + for i := 0; i < events.Len(); i++ { + ev := events.At(i) + if ev == nil || ev.Author != "user" || ev.Content == nil { + continue + } + var text string + hasFR := false + for _, p := range ev.Content.Parts { + if p == nil { + continue + } + if p.FunctionResponse != nil { + hasFR = true + } + text += p.Text + } + if hasFR { + continue + } + if text != "" { + return text + } } + return nil } -// NewRunStateEvent builds a session.Event whose Actions.StateDelta -// carries the workflow's serialised RunState. Workflow.Run and -// Workflow.Resume yield this event before returning so the -// surrounding event-append pipeline can persist the state -// alongside every other delta-based update. +// eventNodeName returns the name of the static graph node that owns +// ev, for attribution during rehydration. // -// Persistence backends apply state mutations only when they -// observe them on Event.Actions.StateDelta during the append -// path; a direct session.State().Set updates the per-invocation -// copy but is not propagated to storage. Returning nil for an -// empty workflowName lets callers use NewRunStateEvent -// unconditionally and skip the yield when persistence is not -// desired. -func NewRunStateEvent(invocationID, workflowName string, state *RunState) (*session.Event, error) { - if workflowName == "" || state == nil { - return nil, nil +// Static node events are stamped with NodeInfo.Path == node name. A +// dynamic child invoked via RunNode carries a hierarchical path like +// "parent/child@1"; its interrupt is owned by the nearest static +// ancestor (the first path segment). Falls back to Author for the +// LlmAgent node path, where Author == node name and no path is set. +func eventNodeName(ev *session.Event) string { + if ev.NodeInfo != nil && ev.NodeInfo.Path != "" { + path := ev.NodeInfo.Path + if i := strings.IndexByte(path, '/'); i >= 0 { + return path[:i] + } + return path } - bytes, err := json.Marshal(state) - if err != nil { - return nil, fmt.Errorf("workflow: encode run state: %w", err) + return ev.Author +} + +// frPart returns the FunctionResponse on a part if present and keyed. +func frPart(p *genai.Part) *genai.FunctionResponse { + if p == nil || p.FunctionResponse == nil || p.FunctionResponse.ID == "" { + return nil + } + return p.FunctionResponse +} + +// schemaFromEvent re-extracts the response schema for interrupt id +// from the pause event (RequestedInput or the adk_request_input +// FunctionCall args), mirroring adk-python _extract_schema_from_event. +// The schema lives only in the events; it is not persisted. +func schemaFromEvent(ev *session.Event, id string) *jsonschema.Schema { + if ev.RequestedInput != nil && ev.RequestedInput.InterruptID == id { + return ev.RequestedInput.ResponseSchema + } + if ev.Content == nil { + return nil + } + for _, p := range ev.Content.Parts { + if p == nil { + continue + } + fc := p.FunctionCall + if fc == nil || fc.Name != WorkflowInputFunctionCallName || fc.ID != id { + continue + } + if raw, ok := fc.Args["responseSchema"]; ok { + if sc, ok := raw.(*jsonschema.Schema); ok { + return sc + } + } + } + return nil +} + +// unwrapResponse extracts the original value from a FunctionResponse +// payload. A sole single-key wrapper — {"result": v} (adk-python), +// {"response": v} or {"payload": v} (adk-go) — is unwrapped, with +// string values JSON-parsed when possible; anything else passes +// through. Mirrors adk-python _unwrap_response, extended with the +// adk-go keys for cross-runtime sessions. +func unwrapResponse(data map[string]any) any { + if len(data) != 1 { + return data + } + for _, key := range []string{"result", "response", "payload"} { + v, ok := data[key] + if !ok { + continue + } + if s, isStr := v.(string); isStr { + var parsed any + if err := json.Unmarshal([]byte(s), &parsed); err == nil { + return parsed + } + return s + } + return v } - ev := session.NewEvent(invocationID) - ev.Actions.StateDelta[RunStateSessionKey(workflowName)] = bytes - return ev, nil + return data } diff --git a/workflow/resume.go b/workflow/resume.go index c9827f89d..5f92bc241 100644 --- a/workflow/resume.go +++ b/workflow/resume.go @@ -93,76 +93,102 @@ func (w *Workflow) Resume( resp any } var deferredHandoffs []deferredHandoff - - // Pass 1: dispatch every waiting asker matched by - // responses (handoff → defer; re-entry → reschedule now). scheduled := 0 + + // Act on each node the rehydration reconstructed, but only + // for interrupts answered in THIS turn (present in responses). + // Gating on the current turn's responses keeps Resume + // idempotent: a duplicate turn whose responses target an + // already-consumed interrupt reschedules nothing. Mirrors + // adk-python gating _extract_resume_output on ctx.resume_inputs. for name, ns := range state.Nodes { - if ns.Status != NodeWaiting || ns.PendingRequest == nil { - continue - } - resp, ok := responses[ns.PendingRequest.InterruptID] - if !ok { + node := s.nodesByName[name] + if node == nil { continue } - // Schema validation: surface a typed error and leave - // the node parked so the caller can retry. - if ns.PendingRequest.ResponseSchema != nil { - validated, err := validateResumeResponse(resp, ns.PendingRequest.ResponseSchema) - if err != nil { - if !yield(nil, fmt.Errorf("%w: node %q: %w", ErrInvalidResumeResponse, name, err)) { - return + // Which of this node's interrupts were answered this turn? + answeredNow := false + for id := range ns.ResumedInputs { + if _, ok := responses[id]; ok { + answeredNow = true + break + } + } + // WAITING nodes whose response arrived this turn but is not + // yet in history (the runner node path passes responses + // directly): fold it into ResumedInputs after validation. + freshMatched := map[string]any{} + if ns.Status == NodeWaiting { + schemaErr := false + for _, id := range ns.Interrupts { + resp, ok := responses[id] + if !ok { + continue + } + if sc := ns.interruptSchemas[id]; sc != nil { + validated, err := validateResumeResponse(resp, sc) + if err != nil { + if !yield(nil, fmt.Errorf("%w: node %q: %w", ErrInvalidResumeResponse, name, err)) { + return + } + schemaErr = true + break + } + resp = validated } + freshMatched[id] = resp + } + if schemaErr { continue } - resp = validated } - - node := s.nodesByName[name] - if node == nil { + if !answeredNow && len(freshMatched) == 0 { continue } - // Snapshot InterruptID before consuming PendingRequest; - // re-entry mode passes it through resumeInputs. - interruptID := ns.PendingRequest.InterruptID - - // Consume PendingRequest before scheduling. A duplicate - // Resume with the same InterruptID will skip this node - // because PendingRequest is now nil. - ns.PendingRequest = nil - ns.Status = NodePending - + reenter := false if r := node.Config().RerunOnResume; r != nil && *r { - // Re-entry mode: re-activate the asker with its - // original input; the response is delivered via - // ctx.ResumedInput(InterruptID), not via the - // input parameter. Successors fire only when the - // re-entry activation produces an output. - // - // Accumulate into ns.ResumedInputs so a node that - // yields multiple RequestInputs across resume - // cycles sees every prior response, not just the - // most recent one. The map is cleared when the - // node transitions to NodeCompleted. + reenter = true + } + + if reenter || ns.Status == NodePending { + // Re-entry: re-activate with the resolved responses + // delivered via ctx.ResumedInput. if ns.ResumedInputs == nil { ns.ResumedInputs = map[string]any{} } - ns.ResumedInputs[interruptID] = resp + for id, resp := range freshMatched { + ns.ResumedInputs[id] = resp + } + ns.Status = NodePending s.scheduleResumedNode(node, ns.Input, ns.TriggeredBy, ns.Branch, ns.ResumedInputs) + scheduled++ } else { - // Handoff mode: promote the asker as if it had - // emitted resp as its output. Recording Output - // lets the join barrier read it back without a - // special case for "completed via resume". + // Handoff: the response is the asker's output for its + // successors; the asker does not re-run. + out := ns.Output + if len(freshMatched) > 0 { + out = resumeOutput(freshMatched) + } ns.Status = NodeCompleted - ns.Output = resp + ns.Output = out + ns.Interrupts = nil deferredHandoffs = append(deferredHandoffs, deferredHandoff{ - node: node, resp: resp, + node: node, resp: out, }) + // A matched asker is itself an effective resume even + // when terminal (no successors to count in Pass 2): + // without this a single-asker workflow would wrongly + // report ErrNothingToResume. answeredThisTurn gates on + // the response being new this turn (rehydration sets it + // from resolvedCount), so a duplicate resume stays a + // no-op. freshMatched covers the runner-direct path + // where the response is not yet in history. + if ns.answeredThisTurn || len(freshMatched) > 0 { + scheduled++ + } } - scheduled++ } // Pass 2: walk successors of the deferred handoffs. @@ -183,7 +209,14 @@ func (w *Workflow) Resume( parentBranch = ns.Branch } for _, succ := range findSuccessors(s.graph, s.state, h.node, h.resp, nil, parentBranch) { + // Skip a successor that already produced output on a + // prior turn: re-triggering it would re-run completed + // work (a duplicate resume). Keeps Resume idempotent. + if state.completed[succ.node.Name()] { + continue + } s.scheduleNode(succ.node, succ.input, succ.triggeredBy, succ.branch) + scheduled++ } } @@ -194,13 +227,20 @@ func (w *Workflow) Resume( s.run(yield) s.wg.Wait() + } +} - // Persist the post-resume state via a session.Event with - // StateDelta. If new nodes paused during this Resume the - // next turn will see them; if the run completed the state - // reflects that too. - yieldRunStateEvent(ctx, w.name, s.state, yield) +// resumeOutput collapses a node's matched interrupt responses into a +// single handoff output: one response forwards its value directly, +// several forward the whole map. Mirrors adk-python +// _extract_resume_output. +func resumeOutput(matched map[string]any) any { + if len(matched) == 1 { + for _, v := range matched { + return v + } } + return matched } // validateResumeResponse coerces resp into the type described by diff --git a/workflow/scheduler.go b/workflow/scheduler.go index cc8910246..0ef4076b8 100644 --- a/workflow/scheduler.go +++ b/workflow/scheduler.go @@ -45,11 +45,6 @@ var ( // than one event whose Routes field is set. A node activation // may emit at most one routing decision. ErrMultipleRoutingEvents = errors.New("workflow: node produced multiple events with route tags; only one event per execution can specify routes") - - // ErrMultipleInputRequests is returned when a node yields more - // than one event whose RequestedInput field is set. A node - // activation may issue at most one human-input request. - ErrMultipleInputRequests = errors.New("workflow: node produced multiple events with RequestedInput; only one human-input request per execution is allowed") ) // scheduler drives a single Workflow.Run invocation. It owns the @@ -129,12 +124,16 @@ type pendingActivation struct { // without overwriting the first value, and the consumer surfaces // the error at completion. type nodeRun struct { - routingEvent *session.Event // at most one; multiple is an error - output any // single Event.Output; nil if hasOutput is false - hasOutput bool // distinguishes "no output yet" from "output was nil" - inputRequest *session.RequestInput // at most one human-input request; multiple is an error - err error // set on duplicate output, duplicate routing event, or duplicate input request - branch string // composite branch assigned at scheduling; used to stamp Event.Branch when the node leaves it empty + routingEvent *session.Event // at most one; multiple is an error + output any // single Event.Output; nil if hasOutput is false + hasOutput bool // distinguishes "no output yet" from "output was nil" + err error // set on duplicate output or duplicate routing event + branch string // composite branch assigned at scheduling; used to stamp Event.Branch when the node leaves it empty + + // interruptIDs are unresolved long-running tool call IDs raised by + // the node's events. A non-empty set at completion parks the node + // in NodeWaiting. Mirrors adk-python's Context._interrupt_ids. + interruptIDs map[string]struct{} } // recordErr stores err as the accumulator's first error. Subsequent @@ -156,18 +155,27 @@ func (nr *nodeRun) setRoutingEvent(ev *session.Event, nodeName string) { nr.routingEvent = ev } -// setInputRequest stores req as the node's single in-flight -// human-input request. A second call records -// ErrMultipleInputRequests instead of overwriting; the consumer -// surfaces the error at completion and the node ends up -// NodeFailed (the waiting branch is gated on nr.err == nil so a -// node that requested twice does not silently park). -func (nr *nodeRun) setInputRequest(req *session.RequestInput, nodeName string) { - if nr.inputRequest != nil { - nr.recordErr(fmt.Errorf("%w: node %q", ErrMultipleInputRequests, nodeName)) +// trackInterrupts accumulates the node's long-running tool call IDs. +// +// It does NOT resolve an interrupt from a FunctionResponse seen in the +// same run: that is the tool's own initial "pending" response, not the +// reply. The real reply arrives on a later turn — a fresh run that does +// not re-raise the call, so its interrupt set is empty and the node +// completes. Mirrors adk-python: the long-running call ends the turn +// and the pause persists until a new invocation answers it. +func (nr *nodeRun) trackInterrupts(ev *session.Event) { + if ev == nil { return } - nr.inputRequest = req + for _, id := range ev.LongRunningToolIDs { + if id == "" { + continue + } + if nr.interruptIDs == nil { + nr.interruptIDs = map[string]struct{}{} + } + nr.interruptIDs[id] = struct{}{} + } } // setOutput stores out as the node's single output value. A second @@ -189,9 +197,17 @@ type queueItem interface{ isQueueItem() } // consumer. nodeName is required so the consumer can correlate the // event with the right nodeRun without relying on channel-FIFO- // per-task semantics (which Go channels do not provide). +// +// processed, when non-nil, is a back-pressure handshake: the producing +// goroutine blocks until the consumer closes it (after the event is +// yielded and persisted), so a non-partial function-response is in the +// session before the node's flow rebuilds the next model request. +// Mirrors adk-python's enqueue_event/processed_signal handshake. Nil +// for partial events, which are fire-and-forget. type eventItem struct { - nodeName string - ev *session.Event + nodeName string + ev *session.Event + processed chan struct{} } func (eventItem) isQueueItem() {} @@ -445,12 +461,27 @@ func runNode( completion.err = err return } + // Block on non-partial events until the consumer has persisted + // them (see eventItem.processed). Partial events are + // fire-and-forget. + var processed chan struct{} + if ev != nil && !ev.LLMResponse.Partial { + processed = make(chan struct{}) + } select { - case out <- eventItem{nodeName: name, ev: ev}: + case out <- eventItem{nodeName: name, ev: ev, processed: processed}: case <-ctx.Done(): completion.err = ctx.Err() return } + if processed != nil { + select { + case <-processed: + case <-ctx.Done(): + completion.err = ctx.Err() + return + } + } } // If the node's iter returned cleanly but the context was // cancelled or its deadline elapsed, surface that as the @@ -523,6 +554,12 @@ func (s *scheduler) run(yield func(*session.Event, error) bool) { s.cancelAll() } } + // Release the producer's handshake now that the event is + // yielded and persisted. Always signal — even when + // draining — so a blocked producer does not leak. + if it.processed != nil { + close(it.processed) + } case completionItem: err := s.handleCompletion(it, !draining) if err != nil && pendingErr == nil { @@ -586,12 +623,24 @@ func (s *scheduler) handleEvent(it eventItem) { path = it.ev.NodeInfo.Path } isDescendant := path != "" && path != it.nodeName - if it.ev.RequestedInput != nil { - nr.setInputRequest(it.ev.RequestedInput, it.nodeName) - } + // Track long-running interrupts before the descendant + // short-circuit so a dynamic child's pause is promoted to the + // parent node (a RequestInput pause rides on LongRunningToolIDs). + nr.trackInterrupts(it.ev) if isDescendant { return } + // Stamp the node name onto the event so history rehydration can + // attribute it back to this node. Static nodes leave Path empty + // (the node name is not otherwise on the event — Author is the + // workflow agent, not the node). Matches adk-python, which sets + // node_info.path on every event and attributes by it. + if path == "" { + if it.ev.NodeInfo == nil { + it.ev.NodeInfo = &session.NodeInfo{} + } + it.ev.NodeInfo.Path = it.nodeName + } if it.ev.Routes != nil { nr.setRoutingEvent(it.ev, it.nodeName) } @@ -664,14 +713,22 @@ func (s *scheduler) handleCompletion(it completionItem, scheduleSuccessors bool) return nr.err } - // Happy path: decide between NodeWaiting (a recorded human- - // input request) or NodeCompleted. The waiting branch fires - // regardless of the scheduleSuccessors flag — a request that - // survived the run must be observable in RunState even when - // the consumer is draining, so the caller can persist it. - if nr != nil && nr.inputRequest != nil { + // Happy path: decide between NodeWaiting (an open interrupt) or + // NodeCompleted. The waiting branch fires regardless of the + // scheduleSuccessors flag — an interrupt that survived the run + // must be observable in RunState even when the consumer is + // draining. + // + // Long-running-tool pause (incl. RequestInput, which rides on + // LongRunningToolIDs): park WAITING with the open interrupt IDs + // so resume can match a human's FunctionResponse back to this + // node. Mirrors adk-python _handle_completion. + if nr != nil && len(nr.interruptIDs) > 0 { ns.Status = NodeWaiting - ns.PendingRequest = nr.inputRequest + ns.Interrupts = ns.Interrupts[:0] + for id := range nr.interruptIDs { + ns.Interrupts = append(ns.Interrupts, id) + } return nil } diff --git a/workflow/state.go b/workflow/state.go index d2a5d9880..fdfb30d3a 100644 --- a/workflow/state.go +++ b/workflow/state.go @@ -14,7 +14,7 @@ package workflow -import "google.golang.org/adk/session" +import "github.com/google/jsonschema-go/jsonschema" // NodeStatus is the lifecycle status of a node in the workflow graph. // @@ -109,11 +109,29 @@ type NodeState struct { // derivation to remain stable across pause/resume turns. Branch string `json:"branch,omitempty"` - // PendingRequest, when non-nil, carries the human-input request - // the node emitted before pausing. Non-nil iff Status == - // NodeWaiting and the wait was caused by a human-input request - // (as opposed to a fan-in barrier). - PendingRequest *session.RequestInput `json:"pendingRequest,omitempty"` + // Interrupts holds the long-running tool call IDs the node is + // waiting on. Non-empty iff Status == NodeWaiting due to a + // long-running tool pause; lets resume match a human's + // FunctionResponse to the node. Mirrors adk-python + // NodeState.interrupts. + Interrupts []string `json:"interrupts,omitempty"` + + // interruptSchemas maps an interrupt ID to its declared response + // schema, re-extracted from the pause event during rehydration. + // Not persisted: the schema lives only in the events and is + // rebuilt each turn (matching adk-python, which keeps no schema + // on NodeState). Consumed by Resume to validate the payload. + interruptSchemas map[string]*jsonschema.Schema + + // answeredThisTurn is true when this node's interrupt was + // resolved by a user response that appeared in history for the + // first time on the current resume turn (resolvedCount == 1), as + // opposed to a duplicate resume that replays an already-consumed + // response. Not persisted; rebuilt each turn from event history. + // Lets Resume count a terminal handoff asker (no successors) as + // an effective resume on its first turn while staying a no-op on + // duplicates (idempotency). + answeredThisTurn bool // Attempt is the number of times this node has been failed. Attempt int `json:"attempt,omitempty"` @@ -139,6 +157,12 @@ type RunState struct { // Nodes is the per-node lifecycle map. Absent entries are // inactive. Nodes map[string]*NodeState `json:"nodes,omitempty"` + + // completed is the set of node names that already produced an + // output in session history. Reconstructed by ReconstructRunState + // and used by Resume to avoid re-triggering a handoff successor + // that already ran on a prior turn (idempotency). Not persisted. + completed map[string]bool } // NewRunState returns an empty state with the Nodes map diff --git a/workflow/workflow.go b/workflow/workflow.go index d34c3511a..1aed0106d 100644 --- a/workflow/workflow.go +++ b/workflow/workflow.go @@ -260,38 +260,7 @@ func (w *Workflow) RunNode(ctx agent.InvocationContext, input any) iter.Seq2[*se // All goroutines have returned; ensure no leak. s.wg.Wait() - - // Persist the run state so a follow-up turn can call - // Workflow.Resume with the recovered NodeWaiting set. - // Emitted as a session.Event with StateDelta so the - // surrounding event-append pipeline can propagate it to - // storage; see NewRunStateEvent for why a direct - // State.Set is not sufficient. - yieldRunStateEvent(ctx, w.name, s.state, yield) - } -} - -// yieldRunStateEvent emits a session.Event carrying the workflow's -// serialised RunState in Actions.StateDelta. No-op when the -// workflow is anonymous (no name → no persistence) or when the -// caller has stopped consuming the iterator. See NewRunStateEvent -// for why the state must be delivered as an event rather than via -// a direct State.Set call. -func yieldRunStateEvent( - ctx agent.InvocationContext, - workflowName string, - state *RunState, - yield func(*session.Event, error) bool, -) { - ev, err := NewRunStateEvent(ctx.InvocationID(), workflowName, state) - if err != nil { - yield(nil, err) - return - } - if ev == nil { - return } - yield(ev, nil) } // userInput extracts the workflow's seed input from the