Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 71 additions & 27 deletions agent/workflowagent/hitl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"google.golang.org/genai"

"google.golang.org/adk/agent"
"google.golang.org/adk/model"
"google.golang.org/adk/session"
"google.golang.org/adk/workflow"
)
Expand Down Expand Up @@ -336,18 +337,20 @@ func TestWorkflowAgent_RunThenResume_DynamicNodeOrchestrator(t *testing.T) {
// Test fixtures and helpers
// =============================================================================

// fakeSession is a minimal session.Session implementation that
// faithfully models the AppendEvent-gated state contract used by
// session.InMemoryService and persistent backends: state mutations
// are applied only when applyStateDelta is called for an event
// (the unit-test analogue of session.Service.AppendEvent), never
// via direct State.Set from inside the agent.
// fakeSession is a minimal session.Session that records appended
// events, as the real session services do. HITL resume reconstructs
// paused state from this event history (Workflow.ReconstructRunState),
// so the test must append every yielded event — and the inbound user
// FunctionResponse on a resume turn — into the session.
//
// drainAgent (this file) calls applyStateDelta for every event the
// agent yields, simulating what the runner does in production.
// drainAgent (this file) appends every event the agent yields, and
// appendUserMessage records the inbound resume message, together
// simulating what the runner does in production.
type fakeSession struct {
session.Session
state *fakeSessionState
state *fakeSessionState
mu sync.Mutex
events []*session.Event
}

func newFakeSession() *fakeSession {
Expand All @@ -357,18 +360,55 @@ func newFakeSession() *fakeSession {
func (s *fakeSession) ID() string { return "test-session-id" }
func (s *fakeSession) State() session.State { return s.state }

// applyStateDelta merges any Actions.StateDelta on the supplied
// event into the underlying state map. Mirrors what
// inMemoryService.AppendEvent does for session-scoped (no
// app:/user:/temp: prefix) keys; HITL persistence uses such keys.
func (s *fakeSession) applyStateDelta(ev *session.Event) {
if ev == nil || len(ev.Actions.StateDelta) == 0 {
func (s *fakeSession) Events() session.Events {
s.mu.Lock()
defer s.mu.Unlock()
return fakeEvents(append([]*session.Event(nil), s.events...))
}

// appendEvent records an event in history (the test analogue of
// session.Service.AppendEvent) and applies any StateDelta.
func (s *fakeSession) appendEvent(ev *session.Event) {
if ev == nil {
return
}
s.mu.Lock()
s.events = append(s.events, ev)
s.mu.Unlock()
if len(ev.Actions.StateDelta) > 0 {
s.state.mu.Lock()
for k, v := range ev.Actions.StateDelta {
s.state.m[k] = v
}
s.state.mu.Unlock()
}
}

// appendUserMessage records an inbound user message as a "user"
// event so a resume turn's FunctionResponse is visible to
// ReconstructRunState, mirroring the runner appending the user turn.
func (s *fakeSession) appendUserMessage(msg *genai.Content) {
if msg == nil {
return
}
s.state.mu.Lock()
defer s.state.mu.Unlock()
for k, v := range ev.Actions.StateDelta {
s.state.m[k] = v
ev := session.NewEvent("test-invocation-id")
ev.Author = "user"
ev.LLMResponse = model.LLMResponse{Content: msg}
s.appendEvent(ev)
}

// fakeEvents is a session.Events over a fixed slice.
type fakeEvents []*session.Event

func (e fakeEvents) Len() int { return len(e) }
func (e fakeEvents) At(i int) *session.Event { return e[i] }
func (e fakeEvents) All() iter.Seq[*session.Event] {
return func(yield func(*session.Event) bool) {
for _, ev := range e {
if !yield(ev) {
return
}
}
}
}

Expand Down Expand Up @@ -461,6 +501,12 @@ func makeAgent(t *testing.T, edges []workflow.Edge) agent.Agent {
// pause/resume round-trips through fakeSessionState as they would
// in production.
func newMockCtx(sess session.Session, agt agent.Agent, msg *genai.Content) *MockInvocationContext {
// Append the inbound user turn to history first, as the runner
// does in production, so a resume turn's FunctionResponse is
// visible to ReconstructRunState.
if fs, ok := sess.(*fakeSession); ok {
fs.appendUserMessage(msg)
}
return &MockInvocationContext{
Context: context.TODO(),
sess: sess,
Expand All @@ -469,13 +515,11 @@ func newMockCtx(sess session.Session, agt agent.Agent, msg *genai.Content) *Mock
}
}

// drainAgent consumes the agent's iter.Seq2, collecting events,
// and applies each event's StateDelta to sess. The applyStateDelta
// step replaces the AppendEvent-side state propagation that the
// real runner performs; without it state writes from the agent
// would never become visible to subsequent calls. Fails the test
// if the iterator yields a non-nil error that the test did not
// opt into via wantErr.
// drainAgent consumes the agent's iter.Seq2, collecting events and
// appending each to sess. The append step is the test analogue of
// the runner's AppendEvent: it builds the session history that the
// next turn's ReconstructRunState reads. Fails the test if the
// iterator yields a non-nil error the test did not opt into.
func drainAgent(t *testing.T, sess *fakeSession, seq iter.Seq2[*session.Event, error], wantErr error) []*session.Event {
t.Helper()
var got []*session.Event
Expand All @@ -488,7 +532,7 @@ func drainAgent(t *testing.T, sess *fakeSession, seq iter.Seq2[*session.Event, e
continue
}
got = append(got, ev)
sess.applyStateDelta(ev)
sess.appendEvent(ev)
}
switch {
case wantErr == nil && sawErr != nil:
Expand Down
28 changes: 18 additions & 10 deletions agent/workflowagent/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ type workflowAgent struct {
// Workflow.Run (every other turn).
func (a *workflowAgent) run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] {
return func(yield func(*session.Event, error) bool) {
if responses, state, ok := a.detectResume(ctx); ok {
responses, state, ok, err := a.detectResume(ctx)
if err != nil {
yield(nil, err)
return
}
if ok {
for ev, err := range a.workflow.Resume(ctx, state, responses) {
if !yield(ev, err) {
return
Expand All @@ -116,10 +121,10 @@ func (a *workflowAgent) run(ctx agent.InvocationContext) iter.Seq2[*session.Even
// responses map keyed by InterruptID (suitable for
// Workflow.Resume), the RunState loaded from session, and true if
// this turn is a resume; (nil, nil, false) for a fresh turn.
func (a *workflowAgent) detectResume(ctx agent.InvocationContext) (map[string]any, *workflow.RunState, bool) {
func (a *workflowAgent) detectResume(ctx agent.InvocationContext) (map[string]any, *workflow.RunState, bool, error) {
frs := utils.FunctionResponses(ctx.UserContent())
if len(frs) == 0 {
return nil, nil, false
return nil, nil, false, nil
}

responses := map[string]any{}
Expand All @@ -130,17 +135,20 @@ func (a *workflowAgent) detectResume(ctx agent.InvocationContext) (map[string]an
responses[fr.ID] = decodeWorkflowInputResponse(fr)
}
if len(responses) == 0 {
return nil, nil, false
return nil, nil, false, nil
}

state, err := workflow.LoadRunState(ctx.Session(), a.workflow.Name())
if err != nil || state == nil {
// No persisted state means there is nothing to resume;
// fall through to a fresh Workflow.Run.
return nil, nil, false
state, err := a.workflow.ReconstructRunState(ctx.Session())
if err != nil {
// A bad resume (e.g. failed schema validation) must fail,
// not silently fall through to a fresh Run.
return nil, nil, false, err
}
if state == nil {
return nil, nil, false, nil
}

return responses, state, true
return responses, state, true, nil
}

// decodeWorkflowInputResponse extracts the user-supplied payload
Expand Down
43 changes: 43 additions & 0 deletions session/database/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,49 @@ func Test_databaseService(t *testing.T) {
})
}

// TestDatabaseService_AppendEvent_WorkflowFieldsRoundTrip guards that
// the storage layer serializes/deserializes the workflow event fields
// (NodeInfo, RequestedInput, Routes); dropping them breaks HITL resume.
func TestDatabaseService_AppendEvent_WorkflowFieldsRoundTrip(t *testing.T) {
ctx := t.Context()
s := emptyService(t)

created, err := s.Create(ctx, &session.CreateRequest{AppName: "app", UserID: "user"})
if err != nil {
t.Fatalf("Create: %v", err)
}

event := &session.Event{
ID: "wf_event",
Author: "agent",
NodeInfo: &session.NodeInfo{Path: "ask_name"},
RequestedInput: &session.RequestInput{InterruptID: "ask_name", Message: "What's your name?"},
Routes: []string{"route_a"},
}
if err := s.AppendEvent(ctx, created.Session, event); err != nil {
t.Fatalf("AppendEvent: %v", err)
}

got, err := s.Get(ctx, &session.GetRequest{AppName: "app", UserID: "user", SessionID: created.Session.ID()})
if err != nil {
t.Fatalf("Get: %v", err)
}
evs := got.Session.Events()
if evs.Len() != 1 {
t.Fatalf("got %d events, want 1", evs.Len())
}
ev := evs.At(0)
if ev.NodeInfo == nil || ev.NodeInfo.Path != "ask_name" {
t.Errorf("NodeInfo not persisted: %#v", ev.NodeInfo)
}
if ev.RequestedInput == nil || ev.RequestedInput.InterruptID != "ask_name" {
t.Errorf("RequestedInput not persisted: %#v", ev.RequestedInput)
}
if len(ev.Routes) != 1 || ev.Routes[0] != "route_a" {
t.Errorf("Routes not persisted: %#v", ev.Routes)
}
}

func emptyService(t *testing.T) *databaseService {
t.Helper()
gormConfig := &gorm.Config{
Expand Down
34 changes: 34 additions & 0 deletions session/database/storage_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ type storageEvent struct {
LongRunningToolIDsJSON dynamicJSON
RoutesJSON dynamicJSON
OutputJSON dynamicJSON
NodeInfoJSON dynamicJSON
RequestedInputJSON dynamicJSON
Branch *string
Timestamp time.Time `gorm:"precision:6"`

Expand Down Expand Up @@ -153,6 +155,22 @@ func createStorageEvent(session session.Session, event *session.Event) (*storage
storageEv.OutputJSON = outputJSON
}

if event.NodeInfo != nil {
nodeInfoJSON, err := json.Marshal(event.NodeInfo)
if err != nil {
return nil, fmt.Errorf("failed to marshal node info: %w", err)
}
storageEv.NodeInfoJSON = nodeInfoJSON
}

if event.RequestedInput != nil {
reqInputJSON, err := json.Marshal(event.RequestedInput)
if err != nil {
return nil, fmt.Errorf("failed to marshal requested input: %w", err)
}
storageEv.RequestedInputJSON = reqInputJSON
}

// Handle optional fields by taking the address of the value.
// An empty string from the event becomes a nil pointer in storage.
if event.Branch != "" {
Expand Down Expand Up @@ -282,6 +300,20 @@ func createEventFromStorageEvent(se *storageEvent) (*session.Event, error) {
}
}

var nodeInfo *session.NodeInfo
if se.NodeInfoJSON != nil {
if err := json.Unmarshal([]byte(se.NodeInfoJSON), &nodeInfo); err != nil {
return nil, fmt.Errorf("failed to unmarshal node info: %w", err)
}
}

var requestedInput *session.RequestInput
if se.RequestedInputJSON != nil {
if err := json.Unmarshal([]byte(se.RequestedInputJSON), &requestedInput); err != nil {
return nil, fmt.Errorf("failed to unmarshal requested input: %w", err)
}
}

// --- Handle simple pointer fields (dereference or use zero value) ---
// Use the helper to safely get the value or its zero-value default
branch := derefOrZero(se.Branch)
Expand All @@ -302,6 +334,8 @@ func createEventFromStorageEvent(se *storageEvent) (*session.Event, error) {
Routes: routes,
Branch: branch,
Output: output,
NodeInfo: nodeInfo,
RequestedInput: requestedInput,
LLMResponse: model.LLMResponse{
Content: content,
GroundingMetadata: groundingMetadata,
Expand Down
3 changes: 3 additions & 0 deletions session/inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,11 @@ func (s *inMemoryService) AppendEvent(ctx context.Context, curSession Session, e
SkipSummarization: event.Actions.SkipSummarization,
},
LongRunningToolIDs: slices.Clone(event.LongRunningToolIDs),
Routes: slices.Clone(event.Routes),
RequestedInput: event.RequestedInput,
LLMResponse: event.LLMResponse,
Output: event.Output,
NodeInfo: event.NodeInfo,
}

// update the in-memory session service
Expand Down
45 changes: 45 additions & 0 deletions session/inmemory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,51 @@ func Test_inMemoryService_CreateConcurrentAccess(t *testing.T) {
}
}

// TestInMemorySession_AppendEvent_WorkflowFieldsRoundTrip guards that
// AppendEvent persists the workflow event fields (NodeInfo,
// RequestedInput, Routes) — workflow resume rehydrates node state from
// them, and a manual event copy that drops them breaks HITL resume.
func TestInMemorySession_AppendEvent_WorkflowFieldsRoundTrip(t *testing.T) {
ctx := t.Context()
service := session.InMemoryService()

createResp, err := service.Create(ctx, &session.CreateRequest{AppName: "app", UserID: "user"})
if err != nil {
t.Fatalf("Create: %v", err)
}
sess := createResp.Session

event := &session.Event{
ID: "wf_event",
Author: "agent",
NodeInfo: &session.NodeInfo{Path: "ask_name"},
RequestedInput: &session.RequestInput{InterruptID: "ask_name", Message: "What's your name?"},
Routes: []string{"route_a"},
}
if err := service.AppendEvent(ctx, sess, event); err != nil {
t.Fatalf("AppendEvent: %v", err)
}

got, err := service.Get(ctx, &session.GetRequest{AppName: "app", UserID: "user", SessionID: sess.ID()})
if err != nil {
t.Fatalf("Get: %v", err)
}
evs := got.Session.Events()
if evs.Len() != 1 {
t.Fatalf("got %d events, want 1", evs.Len())
}
ev := evs.At(0)
if ev.NodeInfo == nil || ev.NodeInfo.Path != "ask_name" {
t.Errorf("NodeInfo not persisted: %#v", ev.NodeInfo)
}
if ev.RequestedInput == nil || ev.RequestedInput.InterruptID != "ask_name" {
t.Errorf("RequestedInput not persisted: %#v", ev.RequestedInput)
}
if len(ev.Routes) != 1 || ev.Routes[0] != "route_a" {
t.Errorf("Routes not persisted: %#v", ev.Routes)
}
}

func TestInMemorySession_AppendEvent_Deadlock(t *testing.T) {
ctx := t.Context()
service := session.InMemoryService()
Expand Down
Loading
Loading