Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions internal/plugininternal/plugin_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,17 @@ func (pm *PluginManager) RunOnModelErrorCallback(cctx agent.CallbackContext, llm
return nil, nil
}

// RunOnPipelineErrorCallback runs the OnPipelineErrorCallback for all plugins.
func (pm *PluginManager) RunOnPipelineErrorCallback(cctx agent.InvocationContext, err error) error {
for _, plugin := range pm.plugins {
callback := plugin.OnPipelineErrorCallback()
if callback != nil {
err = callback(cctx, err)
}
}
return err
}

// Close calls the CloseFunc on all registered plugins.
func (pm *PluginManager) Close() error {
var errors []error
Expand All @@ -283,6 +294,11 @@ func (pm *PluginManager) Close() error {
return nil
}

// ClearPlugins clears all registered plugins from the manager.
func (pm *PluginManager) ClearPlugins() {
pm.plugins = nil
}

func ToContext(ctx context.Context, cfg *PluginManager) context.Context {
return context.WithValue(ctx, plugincontext.PluginManagerCtxKey, cfg)
}
11 changes: 11 additions & 0 deletions plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ type Config struct {
AfterToolCallback llmagent.AfterToolCallback
OnToolErrorCallback llmagent.OnToolErrorCallback

OnPipelineErrorCallback OnPipelineErrorCallback

CloseFunc func() error
}

Expand All @@ -62,6 +64,7 @@ func New(cfg Config) (*Plugin, error) {
beforeToolCallback: cfg.BeforeToolCallback,
afterToolCallback: cfg.AfterToolCallback,
onToolErrorCallback: cfg.OnToolErrorCallback,
onPipelineErrorCallback: cfg.OnPipelineErrorCallback,
closeFunc: cfg.CloseFunc,
}

Expand Down Expand Up @@ -95,6 +98,8 @@ type Plugin struct {
afterToolCallback llmagent.AfterToolCallback
onToolErrorCallback llmagent.OnToolErrorCallback

onPipelineErrorCallback OnPipelineErrorCallback

closeFunc func() error
}

Expand Down Expand Up @@ -158,10 +163,16 @@ func (p *Plugin) OnToolErrorCallback() llmagent.OnToolErrorCallback {
return p.onToolErrorCallback
}

func (p *Plugin) OnPipelineErrorCallback() OnPipelineErrorCallback {
return p.onPipelineErrorCallback
}

type OnUserMessageCallback func(agent.InvocationContext, *genai.Content) (*genai.Content, error)

type BeforeRunCallback func(agent.InvocationContext) (*genai.Content, error)

type AfterRunCallback func(agent.InvocationContext)

type OnEventCallback func(agent.InvocationContext, *session.Event) (*session.Event, error)

type OnPipelineErrorCallback func(agent.InvocationContext, error) error
34 changes: 34 additions & 0 deletions plugin/plugin_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"google.golang.org/adk/model"
"google.golang.org/adk/plugin"
"google.golang.org/adk/runner"
"google.golang.org/adk/session"
"google.golang.org/adk/tool"
"google.golang.org/adk/tool/functiontool"
)
Expand Down Expand Up @@ -876,3 +877,36 @@ func TestModelCallbacks(t *testing.T) {
})
}
}

func TestClearPlugins(t *testing.T) {
p, err := plugin.New(plugin.Config{
Name: "test-plugin",
})
if err != nil {
t.Fatalf("failed to create plugin: %v", err)
}

model := &testutil.MockModel{}
a, err := llmagent.New(llmagent.Config{
Name: "test_agent",
Model: model,
})
if err != nil {
t.Fatalf("failed to create agent: %v", err)
}

r, err := runner.New(runner.Config{
AppName: "test-app",
Agent: a,
SessionService: session.InMemoryService(),
PluginConfig: runner.PluginConfig{
Plugins: []*plugin.Plugin{p},
},
})
if err != nil {
t.Fatalf("failed to create runner: %v", err)
}

r.ClearPlugins()
}

52 changes: 44 additions & 8 deletions runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ func (r *Runner) Run(ctx context.Context, userID, sessionID string, msg *genai.C
})
ctx, err = r.appendMessageToSession(ctx, storedSession, msg, cfg.SaveInputBlobsAsArtifacts, r.pluginManager, options.stateDelta)
if err != nil {
if r.pluginManager != nil {
err = r.pluginManager.RunOnPipelineErrorCallback(ctx, err)
}
yield(nil, err)
return
}
Expand All @@ -216,6 +219,9 @@ func (r *Runner) Run(ctx context.Context, userID, sessionID string, msg *genai.C
defer pluginManager.RunAfterRunCallback(ctx)

earlyExitResult, err := pluginManager.RunBeforeRunCallback(ctx)
if err != nil {
err = pluginManager.RunOnPipelineErrorCallback(ctx, err)
}
if earlyExitResult != nil || err != nil {
earlyExitEvent := session.NewEvent(ctx.InvocationID())
earlyExitEvent.Author = "user"
Expand All @@ -233,17 +239,25 @@ func (r *Runner) Run(ctx context.Context, userID, sessionID string, msg *genai.C

for event, err := range agentToRun.Run(ctx) {
if err != nil {
if !yield(event, err) {
return
if pluginManager != nil {
err = pluginManager.RunOnPipelineErrorCallback(ctx, err)
}
if err != nil {
if !yield(event, err) {
return
}
}
continue
}

if pluginManager != nil {
modifiedEvent, err := pluginManager.RunOnEventCallback(ctx, event)
if err != nil {
if !yield(nil, err) {
return
err = pluginManager.RunOnPipelineErrorCallback(ctx, err)
if err != nil {
if !yield(nil, err) {
return
}
}
continue
}
Expand Down Expand Up @@ -402,6 +416,9 @@ func (r *Runner) RunLive(ctx context.Context, userID, sessionID string, cfg agen

if r.pluginManager != nil {
earlyExitResult, err := r.pluginManager.RunBeforeRunCallback(iCtx)
if err != nil {
err = r.pluginManager.RunOnPipelineErrorCallback(iCtx, err)
}
if err != nil {
return nil, nil, err
}
Expand All @@ -424,6 +441,9 @@ func (r *Runner) RunLive(ctx context.Context, userID, sessionID string, cfg agen

agentSess, innerIter, err := lAgent.RunLive(iCtx)
if err != nil {
if r.pluginManager != nil {
err = r.pluginManager.RunOnPipelineErrorCallback(iCtx, err)
}
return nil, nil, err
}

Expand All @@ -437,17 +457,25 @@ func (r *Runner) RunLive(ctx context.Context, userID, sessionID string, cfg agen

for event, err := range innerIter {
if err != nil {
if !yield(nil, err) {
return
if r.pluginManager != nil {
err = r.pluginManager.RunOnPipelineErrorCallback(iCtx, err)
}
if err != nil {
if !yield(nil, err) {
return
}
}
continue
}

if r.pluginManager != nil {
modifiedEvent, pluginErr := r.pluginManager.RunOnEventCallback(iCtx, event)
if pluginErr != nil {
if !yield(nil, pluginErr) {
return
pluginErr = r.pluginManager.RunOnPipelineErrorCallback(iCtx, pluginErr)
if pluginErr != nil {
if !yield(nil, pluginErr) {
return
}
}
continue
}
Expand Down Expand Up @@ -676,3 +704,11 @@ func hasInlineData(event *session.Event) bool {
}
return false
}

// ClearPlugins clears all registered plugins from the runner's plugin manager.
func (r *Runner) ClearPlugins() {
if r.pluginManager != nil {
r.pluginManager.ClearPlugins()
}
}

Loading