diff --git a/internal/plugininternal/plugin_manager.go b/internal/plugininternal/plugin_manager.go index 51347acbc..08a21c278 100644 --- a/internal/plugininternal/plugin_manager.go +++ b/internal/plugininternal/plugin_manager.go @@ -269,6 +269,17 @@ func (pm *PluginManager) RunOnModelErrorCallback(cctx agent.CallbackContext, llm return nil, nil } +// RunOnPipelineErrorCallback runs the OnPipelineErrorCallback for all plugins. +func (pm *PluginManager) RunOnPipelineErrorCallback(cctx agent.InvocationContext, err error) error { + for _, plugin := range pm.plugins { + callback := plugin.OnPipelineErrorCallback() + if callback != nil { + err = callback(cctx, err) + } + } + return err +} + // Close calls the CloseFunc on all registered plugins. func (pm *PluginManager) Close() error { var errors []error @@ -283,6 +294,11 @@ func (pm *PluginManager) Close() error { return nil } +// ClearPlugins clears all registered plugins from the manager. +func (pm *PluginManager) ClearPlugins() { + pm.plugins = nil +} + func ToContext(ctx context.Context, cfg *PluginManager) context.Context { return context.WithValue(ctx, plugincontext.PluginManagerCtxKey, cfg) } diff --git a/plugin/plugin.go b/plugin/plugin.go index 162e73e68..7d98c657c 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -44,6 +44,8 @@ type Config struct { AfterToolCallback llmagent.AfterToolCallback OnToolErrorCallback llmagent.OnToolErrorCallback + OnPipelineErrorCallback OnPipelineErrorCallback + CloseFunc func() error } @@ -62,6 +64,7 @@ func New(cfg Config) (*Plugin, error) { beforeToolCallback: cfg.BeforeToolCallback, afterToolCallback: cfg.AfterToolCallback, onToolErrorCallback: cfg.OnToolErrorCallback, + onPipelineErrorCallback: cfg.OnPipelineErrorCallback, closeFunc: cfg.CloseFunc, } @@ -95,6 +98,8 @@ type Plugin struct { afterToolCallback llmagent.AfterToolCallback onToolErrorCallback llmagent.OnToolErrorCallback + onPipelineErrorCallback OnPipelineErrorCallback + closeFunc func() error } @@ -158,6 +163,10 @@ func (p *Plugin) OnToolErrorCallback() llmagent.OnToolErrorCallback { return p.onToolErrorCallback } +func (p *Plugin) OnPipelineErrorCallback() OnPipelineErrorCallback { + return p.onPipelineErrorCallback +} + type OnUserMessageCallback func(agent.InvocationContext, *genai.Content) (*genai.Content, error) type BeforeRunCallback func(agent.InvocationContext) (*genai.Content, error) @@ -165,3 +174,5 @@ type BeforeRunCallback func(agent.InvocationContext) (*genai.Content, error) type AfterRunCallback func(agent.InvocationContext) type OnEventCallback func(agent.InvocationContext, *session.Event) (*session.Event, error) + +type OnPipelineErrorCallback func(agent.InvocationContext, error) error diff --git a/plugin/plugin_manager_test.go b/plugin/plugin_manager_test.go index 81c256125..0cef7a6f9 100644 --- a/plugin/plugin_manager_test.go +++ b/plugin/plugin_manager_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/adk/model" "google.golang.org/adk/plugin" "google.golang.org/adk/runner" + "google.golang.org/adk/session" "google.golang.org/adk/tool" "google.golang.org/adk/tool/functiontool" ) @@ -876,3 +877,36 @@ func TestModelCallbacks(t *testing.T) { }) } } + +func TestClearPlugins(t *testing.T) { + p, err := plugin.New(plugin.Config{ + Name: "test-plugin", + }) + if err != nil { + t.Fatalf("failed to create plugin: %v", err) + } + + model := &testutil.MockModel{} + a, err := llmagent.New(llmagent.Config{ + Name: "test_agent", + Model: model, + }) + if err != nil { + t.Fatalf("failed to create agent: %v", err) + } + + r, err := runner.New(runner.Config{ + AppName: "test-app", + Agent: a, + SessionService: session.InMemoryService(), + PluginConfig: runner.PluginConfig{ + Plugins: []*plugin.Plugin{p}, + }, + }) + if err != nil { + t.Fatalf("failed to create runner: %v", err) + } + + r.ClearPlugins() +} + diff --git a/runner/runner.go b/runner/runner.go index fd910ee4f..93fe6fb7a 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -205,6 +205,9 @@ func (r *Runner) Run(ctx context.Context, userID, sessionID string, msg *genai.C }) ctx, err = r.appendMessageToSession(ctx, storedSession, msg, cfg.SaveInputBlobsAsArtifacts, r.pluginManager, options.stateDelta) if err != nil { + if r.pluginManager != nil { + err = r.pluginManager.RunOnPipelineErrorCallback(ctx, err) + } yield(nil, err) return } @@ -216,6 +219,9 @@ func (r *Runner) Run(ctx context.Context, userID, sessionID string, msg *genai.C defer pluginManager.RunAfterRunCallback(ctx) earlyExitResult, err := pluginManager.RunBeforeRunCallback(ctx) + if err != nil { + err = pluginManager.RunOnPipelineErrorCallback(ctx, err) + } if earlyExitResult != nil || err != nil { earlyExitEvent := session.NewEvent(ctx.InvocationID()) earlyExitEvent.Author = "user" @@ -233,8 +239,13 @@ func (r *Runner) Run(ctx context.Context, userID, sessionID string, msg *genai.C for event, err := range agentToRun.Run(ctx) { if err != nil { - if !yield(event, err) { - return + if pluginManager != nil { + err = pluginManager.RunOnPipelineErrorCallback(ctx, err) + } + if err != nil { + if !yield(event, err) { + return + } } continue } @@ -242,8 +253,11 @@ func (r *Runner) Run(ctx context.Context, userID, sessionID string, msg *genai.C if pluginManager != nil { modifiedEvent, err := pluginManager.RunOnEventCallback(ctx, event) if err != nil { - if !yield(nil, err) { - return + err = pluginManager.RunOnPipelineErrorCallback(ctx, err) + if err != nil { + if !yield(nil, err) { + return + } } continue } @@ -402,6 +416,9 @@ func (r *Runner) RunLive(ctx context.Context, userID, sessionID string, cfg agen if r.pluginManager != nil { earlyExitResult, err := r.pluginManager.RunBeforeRunCallback(iCtx) + if err != nil { + err = r.pluginManager.RunOnPipelineErrorCallback(iCtx, err) + } if err != nil { return nil, nil, err } @@ -424,6 +441,9 @@ func (r *Runner) RunLive(ctx context.Context, userID, sessionID string, cfg agen agentSess, innerIter, err := lAgent.RunLive(iCtx) if err != nil { + if r.pluginManager != nil { + err = r.pluginManager.RunOnPipelineErrorCallback(iCtx, err) + } return nil, nil, err } @@ -437,8 +457,13 @@ func (r *Runner) RunLive(ctx context.Context, userID, sessionID string, cfg agen for event, err := range innerIter { if err != nil { - if !yield(nil, err) { - return + if r.pluginManager != nil { + err = r.pluginManager.RunOnPipelineErrorCallback(iCtx, err) + } + if err != nil { + if !yield(nil, err) { + return + } } continue } @@ -446,8 +471,11 @@ func (r *Runner) RunLive(ctx context.Context, userID, sessionID string, cfg agen if r.pluginManager != nil { modifiedEvent, pluginErr := r.pluginManager.RunOnEventCallback(iCtx, event) if pluginErr != nil { - if !yield(nil, pluginErr) { - return + pluginErr = r.pluginManager.RunOnPipelineErrorCallback(iCtx, pluginErr) + if pluginErr != nil { + if !yield(nil, pluginErr) { + return + } } continue } @@ -676,3 +704,11 @@ func hasInlineData(event *session.Event) bool { } return false } + +// ClearPlugins clears all registered plugins from the runner's plugin manager. +func (r *Runner) ClearPlugins() { + if r.pluginManager != nil { + r.pluginManager.ClearPlugins() + } +} + diff --git a/runner/runner_test.go b/runner/runner_test.go index 03eb70ec4..c0a497649 100644 --- a/runner/runner_test.go +++ b/runner/runner_test.go @@ -28,6 +28,7 @@ import ( "google.golang.org/adk/agent/llmagent" "google.golang.org/adk/artifact" "google.golang.org/adk/model" + "google.golang.org/adk/plugin" "google.golang.org/adk/session" ) @@ -447,3 +448,206 @@ func TestRunner_AutoCreateSession(t *testing.T) { }) } } + +func TestRunner_OnPipelineErrorCallback(t *testing.T) { + appName := "testApp" + userID := "testUser" + sessionID := "testSession" + + t.Run("BeforeRunCallback error triggers OnPipelineErrorCallback and is propagated", func(t *testing.T) { + ctx := t.Context() + sessionService := session.InMemoryService() + _, err := sessionService.Create(ctx, &session.CreateRequest{ + AppName: appName, + UserID: userID, + SessionID: sessionID, + }) + if err != nil { + t.Fatal(err) + } + + originalErr := fmt.Errorf("before run error") + interceptedErr := fmt.Errorf("intercepted before run error") + + var callbackCalled bool + var receivedErr error + + p, err := plugin.New(plugin.Config{ + Name: "test_plugin", + BeforeRunCallback: func(ctx agent.InvocationContext) (*genai.Content, error) { + return nil, originalErr + }, + OnPipelineErrorCallback: func(ctx agent.InvocationContext, err error) error { + callbackCalled = true + receivedErr = err + return interceptedErr + }, + }) + if err != nil { + t.Fatal(err) + } + + testAgent := must(agent.New(agent.Config{Name: "test_agent"})) + r, err := New(Config{ + AppName: appName, + Agent: testAgent, + SessionService: sessionService, + PluginConfig: PluginConfig{ + Plugins: []*plugin.Plugin{p}, + }, + }) + if err != nil { + t.Fatal(err) + } + + var runErr error + for _, err := range r.Run(ctx, userID, sessionID, &genai.Content{Parts: []*genai.Part{{Text: "hello"}}}, agent.RunConfig{}) { + if err != nil { + runErr = err + } + } + + if !callbackCalled { + t.Error("OnPipelineErrorCallback was not called") + } + if receivedErr != originalErr { + t.Errorf("expected received error %v, got %v", originalErr, receivedErr) + } + if runErr != interceptedErr { + t.Errorf("expected run error %v, got %v", interceptedErr, runErr) + } + }) + + t.Run("OnUserMessageCallback error triggers OnPipelineErrorCallback and is propagated", func(t *testing.T) { + ctx := t.Context() + sessionService := session.InMemoryService() + _, err := sessionService.Create(ctx, &session.CreateRequest{ + AppName: appName, + UserID: userID, + SessionID: sessionID, + }) + if err != nil { + t.Fatal(err) + } + + originalErr := fmt.Errorf("user message error") + interceptedErr := fmt.Errorf("intercepted user message error") + + var callbackCalled bool + var receivedErr error + + p, err := plugin.New(plugin.Config{ + Name: "test_plugin", + OnUserMessageCallback: func(ctx agent.InvocationContext, msg *genai.Content) (*genai.Content, error) { + return nil, originalErr + }, + OnPipelineErrorCallback: func(ctx agent.InvocationContext, err error) error { + callbackCalled = true + receivedErr = err + return interceptedErr + }, + }) + if err != nil { + t.Fatal(err) + } + + testAgent := must(agent.New(agent.Config{Name: "test_agent"})) + r, err := New(Config{ + AppName: appName, + Agent: testAgent, + SessionService: sessionService, + PluginConfig: PluginConfig{ + Plugins: []*plugin.Plugin{p}, + }, + }) + if err != nil { + t.Fatal(err) + } + + var runErr error + for _, err := range r.Run(ctx, userID, sessionID, &genai.Content{Parts: []*genai.Part{{Text: "hello"}}}, agent.RunConfig{}) { + if err != nil { + runErr = err + } + } + + if !callbackCalled { + t.Error("OnPipelineErrorCallback was not called") + } + if receivedErr == nil || !strings.Contains(receivedErr.Error(), originalErr.Error()) { + t.Errorf("expected received error to contain %v, got %v", originalErr, receivedErr) + } + if runErr != interceptedErr { + t.Errorf("expected run error %v, got %v", interceptedErr, runErr) + } + }) + + t.Run("Agent execution loop error triggers OnPipelineErrorCallback and is propagated", func(t *testing.T) { + ctx := t.Context() + sessionService := session.InMemoryService() + _, err := sessionService.Create(ctx, &session.CreateRequest{ + AppName: appName, + UserID: userID, + SessionID: sessionID, + }) + if err != nil { + t.Fatal(err) + } + + originalErr := fmt.Errorf("agent run error") + interceptedErr := fmt.Errorf("intercepted agent run error") + + var callbackCalled bool + var receivedErr error + + p, err := plugin.New(plugin.Config{ + Name: "test_plugin", + OnPipelineErrorCallback: func(ctx agent.InvocationContext, err error) error { + callbackCalled = true + receivedErr = err + return interceptedErr + }, + }) + if err != nil { + t.Fatal(err) + } + + testAgent := must(agent.New(agent.Config{ + Name: "test_agent", + Run: func(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { + return func(yield func(*session.Event, error) bool) { + yield(nil, originalErr) + } + }, + })) + + r, err := New(Config{ + AppName: appName, + Agent: testAgent, + SessionService: sessionService, + PluginConfig: PluginConfig{ + Plugins: []*plugin.Plugin{p}, + }, + }) + if err != nil { + t.Fatal(err) + } + + var runErr error + for _, err := range r.Run(ctx, userID, sessionID, &genai.Content{Parts: []*genai.Part{{Text: "hello"}}}, agent.RunConfig{}) { + if err != nil { + runErr = err + } + } + + if !callbackCalled { + t.Error("OnPipelineErrorCallback was not called") + } + if receivedErr != originalErr { + t.Errorf("expected received error %v, got %v", originalErr, receivedErr) + } + if runErr != interceptedErr { + t.Errorf("expected run error %v, got %v", interceptedErr, runErr) + } + }) +} diff --git a/server/adkrest/internal/services/debugtelemetry_test.go b/server/adkrest/internal/services/debugtelemetry_test.go index 3ebb6ed8a..a46abf10f 100644 --- a/server/adkrest/internal/services/debugtelemetry_test.go +++ b/server/adkrest/internal/services/debugtelemetry_test.go @@ -229,6 +229,7 @@ func TestDebugTelemetryGetSpansBySessionID(t *testing.T) { cmpopts.IgnoreFields(DebugSpan{}, "StartTime", "EndTime", "TraceID", "SpanID", "ParentSpanID"), cmpopts.IgnoreFields(DebugLog{}, "ObservedTimestamp", "TraceID", "SpanID"), cmpopts.EquateEmpty(), + cmpopts.SortSlices(func(x, y DebugSpan) bool { return x.Name < y.Name }), } // Validate session spans @@ -366,6 +367,7 @@ func TestDebugTelemetryGetSpansByEventID(t *testing.T) { cmpopts.IgnoreFields(DebugSpan{}, "StartTime", "EndTime", "ParentSpanID", "TraceID", "SpanID"), cmpopts.IgnoreFields(DebugLog{}, "ObservedTimestamp", "TraceID", "SpanID"), cmpopts.EquateEmpty(), + cmpopts.SortSlices(func(x, y DebugSpan) bool { return x.Name < y.Name }), } // Validate event spans