diff --git a/agent/agent.go b/agent/agent.go new file mode 100644 index 0000000000..13c5900030 --- /dev/null +++ b/agent/agent.go @@ -0,0 +1,497 @@ +// Package agent provides AI agents that manage the lifecycle of services. +// Agents use tools to observe and control services, driven by a directive +// that describes their purpose. They operate externally to the services +// they manage, interacting through the registry and RPC client. +package agent + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "go-micro.dev/v5/codec/bytes" + log "go-micro.dev/v5/logger" + "go-micro.dev/v5/model" + "go-micro.dev/v5/registry" +) + +// ActivityType classifies the kind of action an agent has performed. +type ActivityType string + +const ( + // ActivityEvaluate marks a periodic evaluation cycle. + ActivityEvaluate ActivityType = "evaluate" + // ActivityPrompt marks an on-demand prompt submitted via Prompt. + ActivityPrompt ActivityType = "prompt" + // ActivityTool marks a tool invocation made by the model. + ActivityTool ActivityType = "tool" + // ActivityResponse marks a completed model response. + ActivityResponse ActivityType = "response" + // ActivityError marks an error that occurred during processing. + ActivityError ActivityType = "error" +) + +// Activity records a single action performed by the agent. +type Activity struct { + // Time is when the activity occurred. + Time time.Time + // Type classifies the activity. + Type ActivityType + // Prompt is the text of the prompt that triggered the activity (if any). + Prompt string + // Tool is the name of the tool invoked (for ActivityTool). + Tool string + // Result holds the output of a tool call or model response. + Result string + // Err holds the error returned by the model (for ActivityError); failed tool calls leave it nil and carry their message in Result. + Err error +} + +// maxActivities is the maximum number of Activity entries kept in memory. +const maxActivities = 256 + +// Agent manages the lifecycle of services. 
+// Its interface mirrors the Service interface so agents can live alongside +// services in the same runtime environment. +type Agent interface { + // Init initializes the agent with options + Init(...Option) error + // Options returns the current options + Options() Options + // Run starts the agent loop + Run() error + // Stop gracefully stops the agent + Stop() error + // String returns the agent name + String() string + // Prompt queues a user-provided prompt for the agent to process immediately. + // The call is non-blocking and returns a channel that will receive the model + // response once the prompt has been evaluated (and any requested tools have + // been executed). The channel is buffered and closed after the response is + // sent, so callers can range over it or select on it. + Prompt(text string) <-chan *model.Response + // Activity returns a chronological snapshot of recent agent activities + // (evaluations, prompts, tool calls, responses, and errors). + Activity() []Activity +} + +// agent is the default Agent implementation. +type agent struct { + opts Options + stop chan struct{} + once sync.Once + activities []Activity + actMu sync.RWMutex +} + +// New creates a new Agent with the given options. +func New(opts ...Option) Agent { + return &agent{ + opts: newOptions(opts...), + stop: make(chan struct{}), + activities: make([]Activity, 0, maxActivities), + } +} + +// Init initializes the agent with additional options. +func (a *agent) Init(opts ...Option) error { + for _, o := range opts { + o(&a.opts) + } + return nil +} + +// Options returns the current agent options. +func (a *agent) Options() Options { + return a.opts +} + +// String returns the agent name. +func (a *agent) String() string { + return a.opts.Name +} + +// Stop signals the agent to stop its run loop. +func (a *agent) Stop() error { + a.once.Do(func() { + close(a.stop) + }) + return nil +} + +// record appends act to the agent's activity log. 
+// Oldest entries are dropped once the log reaches maxActivities. +func (a *agent) record(act Activity) { + if act.Time.IsZero() { + act.Time = time.Now() + } + a.actMu.Lock() + a.activities = append(a.activities, act) + if len(a.activities) > maxActivities { + a.activities = a.activities[len(a.activities)-maxActivities:] + } + a.actMu.Unlock() +} + +// Activity returns a chronological snapshot of recent agent activities. +func (a *agent) Activity() []Activity { + a.actMu.RLock() + defer a.actMu.RUnlock() + result := make([]Activity, len(a.activities)) + copy(result, a.activities) + return result +} + +// Prompt processes a user-provided prompt immediately. +// It is non-blocking: it spawns a goroutine and returns a buffered channel +// that will receive the model response (then be closed). If no model is +// configured, the channel is closed immediately with no value. +func (a *agent) Prompt(text string) <-chan *model.Response { + ch := make(chan *model.Response, 1) + a.record(Activity{Type: ActivityPrompt, Prompt: text}) + go func() { + defer close(ch) + if a.opts.Model == nil { + return + } + tools := a.buildTools() + resp, err := a.opts.Model.Generate(a.opts.Context, &model.Request{ + SystemPrompt: a.opts.Directive, + Prompt: text, + Tools: tools, + }) + if err != nil { + a.record(Activity{Type: ActivityError, Prompt: text, Err: err}) + return + } + for _, tc := range resp.ToolCalls { + _, content := a.executeTool(tc.Name, tc.Input) + if isErrorContent(content) { + a.record(Activity{Type: ActivityError, Tool: tc.Name, Result: content}) + } else { + a.record(Activity{Type: ActivityTool, Tool: tc.Name, Result: content}) + } + } + reply := resp.Reply + if reply == "" { + reply = resp.Answer + } + a.record(Activity{Type: ActivityResponse, Prompt: text, Result: reply}) + ch <- resp + }() + return ch +} + +// Run starts the agent loop. 
The agent watches the services it manages, +// periodically evaluates their state using the AI model, and acts on +// the results via its built-in service management tools. +func (a *agent) Run() error { + logger := a.opts.Logger + logger.Logf(log.InfoLevel, "Starting [agent] %s", a.opts.Name) + + // Build the set of tools available to this agent. + tools := a.buildTools() + + ticker := time.NewTicker(a.opts.Interval) + defer ticker.Stop() + + for { + select { + case <-a.stop: + logger.Logf(log.InfoLevel, "Stopping [agent] %s", a.opts.Name) + return nil + case <-a.opts.Context.Done(): + return nil + case <-ticker.C: + if err := a.evaluate(tools); err != nil { + logger.Logf(log.ErrorLevel, "[agent] %s evaluate error: %v", a.opts.Name, err) + } + } + } +} + +// evaluate asks the model to assess the current state of the managed +// services and execute any necessary management actions. +func (a *agent) evaluate(tools []model.Tool) error { + if a.opts.Model == nil { + return nil + } + + a.record(Activity{Type: ActivityEvaluate}) + + status, err := a.serviceStatus() + if err != nil { + return err + } + + prompt := fmt.Sprintf( + "Current status of managed services: %s\n\nDirective: %s\n\nAssess the services and take any necessary management actions.", + status, a.opts.Directive, + ) + + req := &model.Request{ + SystemPrompt: a.opts.Directive, + Prompt: prompt, + Tools: tools, + } + + resp, err := a.opts.Model.Generate(a.opts.Context, req) + if err != nil { + a.record(Activity{Type: ActivityError, Err: err}) + return fmt.Errorf("model generate: %w", err) + } + + // Execute any tool calls requested by the model. 
+ for _, tc := range resp.ToolCalls { + result, content := a.executeTool(tc.Name, tc.Input) + if isErrorContent(content) { + a.record(Activity{Type: ActivityError, Tool: tc.Name, Result: content}) + } else { + a.record(Activity{Type: ActivityTool, Tool: tc.Name, Result: content}) + } + a.opts.Logger.Logf(log.DebugLevel, "[agent] %s tool %s result: %v", a.opts.Name, tc.Name, result) + } + + reply := resp.Reply + if reply == "" { + reply = resp.Answer + } + if reply != "" { + a.record(Activity{Type: ActivityResponse, Result: reply}) + } + + return nil +} + +// serviceStatus returns a JSON summary of the current state of all +// managed services by querying the registry. +func (a *agent) serviceStatus() (string, error) { + if a.opts.Registry == nil { + return "{}", nil + } + + type svcStatus struct { + Name string `json:"name"` + Running bool `json:"running"` + Version string `json:"version,omitempty"` + Nodes int `json:"nodes"` + } + + var statuses []svcStatus + + for _, name := range a.opts.Services { + svcs, err := a.opts.Registry.GetService(name) + if err != nil || len(svcs) == 0 { + statuses = append(statuses, svcStatus{Name: name, Running: false}) + continue + } + statuses = append(statuses, svcStatus{ + Name: name, + Running: true, + Version: svcs[0].Version, + Nodes: len(svcs[0].Nodes), + }) + } + + b, err := json.Marshal(statuses) + if err != nil { + return "{}", err + } + return string(b), nil +} + +// buildTools returns the set of model.Tool definitions the agent uses +// to manage its services. 
+func (a *agent) buildTools() []model.Tool { + return []model.Tool{ + { + Name: "list_services", + OriginalName: "list_services", + Description: "List all services managed by this agent along with their current status.", + Properties: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, + }, + { + Name: "get_service_status", + OriginalName: "get_service_status", + Description: "Get the detailed status of a specific service.", + Properties: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "description": "The name of the service", + }, + }, + "required": []string{"name"}, + }, + }, + { + Name: "call_service", + OriginalName: "call_service", + Description: "Make an RPC call to a service endpoint.", + Properties: map[string]any{ + "type": "object", + "properties": map[string]any{ + "service": map[string]any{ + "type": "string", + "description": "The name of the service to call", + }, + "endpoint": map[string]any{ + "type": "string", + "description": "The endpoint/method to call", + }, + "request": map[string]any{ + "type": "object", + "description": "The request payload", + }, + }, + "required": []string{"service", "endpoint"}, + }, + }, + } +} + +// executeTool dispatches a tool call by name and returns the result. 
+func (a *agent) executeTool(name string, input map[string]any) (any, string) { + switch name { + case "list_services": + status, err := a.serviceStatus() + if err != nil { + return nil, fmt.Sprintf(`{"error": %q}`, err.Error()) + } + return status, status + + case "get_service_status": + svcName, _ := input["name"].(string) + if svcName == "" { + return nil, `{"error": "name is required"}` + } + svcs, err := a.opts.Registry.GetService(svcName) + if err != nil || len(svcs) == 0 { + return nil, fmt.Sprintf(`{"name": %q, "running": false}`, svcName) + } + b, _ := json.Marshal(map[string]any{ + "name": svcName, + "running": true, + "version": svcs[0].Version, + "nodes": len(svcs[0].Nodes), + }) + return string(b), string(b) + + case "call_service": + if a.opts.Client == nil { + return nil, `{"error": "no client configured"}` + } + svcName, _ := input["service"].(string) + endpoint, _ := input["endpoint"].(string) + if svcName == "" || endpoint == "" { + return nil, `{"error": "service and endpoint are required"}` + } + + reqBody, _ := json.Marshal(input["request"]) + req := a.opts.Client.NewRequest(svcName, endpoint, &bytes.Frame{Data: reqBody}) + var rsp bytes.Frame + if err := a.opts.Client.Call(context.Background(), req, &rsp); err != nil { + return nil, fmt.Sprintf(`{"error": %q}`, err.Error()) + } + return string(rsp.Data), string(rsp.Data) + + default: + // Delegate to a custom tool handler if provided. + if a.opts.ToolHandler != nil { + result, content := a.opts.ToolHandler(name, input) + return result, content + } + return nil, fmt.Sprintf(`{"error": "unknown tool %q"}`, name) + } +} + +// isErrorContent reports whether the JSON content string returned by +// executeTool represents a tool error (i.e. contains an "error" key). 
+func isErrorContent(content string) bool { + var obj map[string]any + if err := json.Unmarshal([]byte(content), &obj); err != nil { + return false + } + _, hasErr := obj["error"] + return hasErr +} + +// DefaultAgent is the package-level default Agent instance. +var DefaultAgent Agent + +// Run starts the default agent. +func Run() error { + if DefaultAgent == nil { + return fmt.Errorf("no default agent configured") + } + return DefaultAgent.Run() +} + +// NewFunc is a constructor function for creating Agent instances. +type NewFunc func(...Option) Agent + +// Directive returns the agent's system prompt / purpose description. +// It is a convenience accessor to Options.Directive. +func Directive(a Agent) string { + return a.Options().Directive +} + +// Services returns the list of service names managed by this agent. +func Services(a Agent) []string { + return a.Options().Services +} + +// WatchServices watches the registry for changes to managed services +// and calls fn whenever a service changes. It blocks until ctx is done. +func WatchServices(ctx context.Context, reg registry.Registry, names []string, fn func(string, *registry.Result)) error { + if reg == nil { + return fmt.Errorf("registry is required") + } + + nameSet := make(map[string]struct{}, len(names)) + for _, n := range names { + nameSet[n] = struct{}{} + } + + watcher, err := reg.Watch() + if err != nil { + return err + } + + // Stop the watcher when the context is cancelled so that the + // blocking Next() call below returns promptly. + go func() { + <-ctx.Done() + watcher.Stop() + }() + + for { + res, err := watcher.Next() + if err != nil { + // A non-nil error means the watcher was stopped or failed. + // Return nil when the context was cancelled (expected shutdown). 
+ select { + case <-ctx.Done(): + return nil + default: + return err + } + } + if res == nil || res.Service == nil { + continue + } + if len(names) == 0 { + fn(res.Service.Name, res) + continue + } + if _, ok := nameSet[res.Service.Name]; ok { + fn(res.Service.Name, res) + } + } +} diff --git a/agent/agent_test.go b/agent/agent_test.go new file mode 100644 index 0000000000..f3c9792aa8 --- /dev/null +++ b/agent/agent_test.go @@ -0,0 +1,500 @@ +package agent + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go-micro.dev/v5/model" + "go-micro.dev/v5/registry" +) + +// TestNew verifies that New returns an Agent with default options applied. +func TestNew(t *testing.T) { + a := New() + require.NotNil(t, a) + assert.Equal(t, "agent", a.String()) + + opts := a.Options() + assert.NotEmpty(t, opts.Directive) + assert.Equal(t, 30*time.Second, opts.Interval) + assert.NotNil(t, opts.Context) +} + +// TestNewWithOptions verifies functional options are applied correctly. +func TestNewWithOptions(t *testing.T) { + reg := registry.NewMemoryRegistry() + a := New( + WithName("test-agent"), + WithDirective("manage my service"), + WithServices("svc-a", "svc-b"), + WithRegistry(reg), + WithInterval(5*time.Second), + ) + + require.NotNil(t, a) + assert.Equal(t, "test-agent", a.String()) + + opts := a.Options() + assert.Equal(t, "manage my service", opts.Directive) + assert.Equal(t, []string{"svc-a", "svc-b"}, opts.Services) + assert.Equal(t, 5*time.Second, opts.Interval) +} + +// TestInit verifies Init applies additional options after creation. +func TestInit(t *testing.T) { + a := New(WithName("orig")) + err := a.Init(WithName("updated")) + require.NoError(t, err) + assert.Equal(t, "updated", a.String()) +} + +// TestRunStop verifies Run starts and Stop terminates the agent cleanly. 
+func TestRunStop(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + a := New( + WithName("lifecycle-agent"), + WithContext(ctx), + WithInterval(10*time.Second), // long interval – evaluation won't run + ) + + errCh := make(chan error, 1) + go func() { + errCh <- a.Run() + }() + + // Give the goroutine a moment to start. + time.Sleep(50 * time.Millisecond) + + require.NoError(t, a.Stop()) + select { + case err := <-errCh: + assert.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("agent did not stop in time") + } +} + +// TestServiceStatus verifies serviceStatus with an in-memory registry. +func TestServiceStatus(t *testing.T) { + reg := registry.NewMemoryRegistry() + + // Register a fake service. + err := reg.Register(®istry.Service{ + Name: "greeter", + Version: "1.0.0", + Nodes: []*registry.Node{ + {Id: "greeter-1", Address: "127.0.0.1:8080"}, + }, + }) + require.NoError(t, err) + + a := &agent{ + opts: newOptions( + WithRegistry(reg), + WithServices("greeter", "missing-svc"), + ), + stop: make(chan struct{}), + } + + status, err := a.serviceStatus() + require.NoError(t, err) + assert.Contains(t, status, `"greeter"`) + assert.Contains(t, status, `"running":true`) + assert.Contains(t, status, `"missing-svc"`) + assert.Contains(t, status, `"running":false`) +} + +// TestBuildTools verifies the built-in tool definitions are well-formed. +func TestBuildTools(t *testing.T) { + a := &agent{ + opts: newOptions(), + stop: make(chan struct{}), + } + tools := a.buildTools() + assert.Len(t, tools, 3) + + names := make(map[string]bool) + for _, tool := range tools { + names[tool.Name] = true + assert.NotEmpty(t, tool.Description) + assert.NotNil(t, tool.Properties) + } + assert.True(t, names["list_services"]) + assert.True(t, names["get_service_status"]) + assert.True(t, names["call_service"]) +} + +// TestExecuteToolListServices verifies list_services returns service state. 
+func TestExecuteToolListServices(t *testing.T) { + reg := registry.NewMemoryRegistry() + err := reg.Register(®istry.Service{ + Name: "hello", + Version: "v1", + Nodes: []*registry.Node{{Id: "hello-1", Address: "127.0.0.1:9090"}}, + }) + require.NoError(t, err) + + a := &agent{ + opts: newOptions( + WithRegistry(reg), + WithServices("hello"), + ), + stop: make(chan struct{}), + } + + _, content := a.executeTool("list_services", nil) + assert.Contains(t, content, `"hello"`) +} + +// TestExecuteToolGetServiceStatus verifies get_service_status returns details. +func TestExecuteToolGetServiceStatus(t *testing.T) { + reg := registry.NewMemoryRegistry() + err := reg.Register(®istry.Service{ + Name: "store", + Version: "v2", + Nodes: []*registry.Node{{Id: "store-1", Address: "127.0.0.1:7070"}}, + }) + require.NoError(t, err) + + a := &agent{ + opts: newOptions(WithRegistry(reg)), + stop: make(chan struct{}), + } + + _, content := a.executeTool("get_service_status", map[string]any{"name": "store"}) + assert.Contains(t, content, `"running":true`) + + _, missing := a.executeTool("get_service_status", map[string]any{"name": "unknown"}) + assert.Contains(t, missing, `"running":`) + + _, noName := a.executeTool("get_service_status", map[string]any{}) + assert.Contains(t, noName, "error") +} + +// TestExecuteToolUnknownWithHandler verifies custom tool handlers are called. +func TestExecuteToolUnknownWithHandler(t *testing.T) { + called := false + a := &agent{ + opts: newOptions(WithToolHandler(func(name string, input map[string]any) (any, string) { + called = true + return nil, `{"custom": true}` + })), + stop: make(chan struct{}), + } + + _, content := a.executeTool("custom_tool", map[string]any{}) + assert.True(t, called) + assert.Contains(t, content, "custom") +} + +// TestExecuteToolUnknownNoHandler verifies unknown tools return an error when no handler is set. 
+func TestExecuteToolUnknownNoHandler(t *testing.T) { + a := &agent{opts: newOptions(), stop: make(chan struct{})} + _, content := a.executeTool("nope", nil) + assert.Contains(t, content, "error") +} + +// TestEvaluateNoModel verifies evaluate is a no-op when no model is configured. +func TestEvaluateNoModel(t *testing.T) { + a := &agent{opts: newOptions(), stop: make(chan struct{})} + err := a.evaluate(nil) + assert.NoError(t, err) +} + +// TestEvaluateWithMockModel verifies evaluate calls the model and handles tool calls. +func TestEvaluateWithMockModel(t *testing.T) { + mockModel := &mockModel{ + resp: &model.Response{ + ToolCalls: []model.ToolCall{ + {Name: "list_services", Input: map[string]any{}}, + }, + }, + } + + reg := registry.NewMemoryRegistry() + a := &agent{ + opts: newOptions( + WithModel(mockModel), + WithRegistry(reg), + ), + stop: make(chan struct{}), + } + + tools := a.buildTools() + err := a.evaluate(tools) + assert.NoError(t, err) + assert.True(t, mockModel.called) +} + +// TestDirectiveHelper verifies the Directive helper function. +func TestDirectiveHelper(t *testing.T) { + a := New(WithDirective("my directive")) + assert.Equal(t, "my directive", Directive(a)) +} + +// TestServicesHelper verifies the Services helper function. +func TestServicesHelper(t *testing.T) { + a := New(WithServices("svc1", "svc2")) + assert.Equal(t, []string{"svc1", "svc2"}, Services(a)) +} + +// TestWatchServicesContextCancel verifies WatchServices respects context cancellation. 
+func TestWatchServicesContextCancel(t *testing.T) { + reg := registry.NewMemoryRegistry() + ctx, cancel := context.WithCancel(context.Background()) + + errCh := make(chan error, 1) + go func() { + errCh <- WatchServices(ctx, reg, nil, func(name string, _ *registry.Result) {}) + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case err := <-errCh: + assert.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("WatchServices did not return after context cancel") + } +} + +// TestWatchServicesNilRegistry verifies WatchServices returns error for nil registry. +func TestWatchServicesNilRegistry(t *testing.T) { + err := WatchServices(context.Background(), nil, nil, func(string, *registry.Result) {}) + assert.Error(t, err) +} + +// TestPromptNoModel verifies Prompt closes the channel immediately when no model is set. +func TestPromptNoModel(t *testing.T) { + a := New(WithName("no-model-agent")) + + ch := a.Prompt("hello") + select { + case resp, ok := <-ch: + assert.False(t, ok, "channel should be closed with no value") + assert.Nil(t, resp) + case <-time.After(2 * time.Second): + t.Fatal("Prompt channel was not closed in time") + } +} + +// TestPromptNonBlocking verifies Prompt returns immediately. +func TestPromptNonBlocking(t *testing.T) { + // slow model — blocks for up to 5 s + slow := &slowMockModel{delay: 5 * time.Second} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + a := New(WithName("slow-agent"), WithModel(slow), WithContext(ctx)) + + start := time.Now() + ch := a.Prompt("are you there?") + elapsed := time.Since(start) + + // Prompt must return without waiting for the model. + assert.Less(t, elapsed, 500*time.Millisecond, "Prompt should be non-blocking") + + // Clean up: cancel context so the goroutine exits. + cancel() + // Drain channel. + select { + case <-ch: + case <-time.After(6 * time.Second): + } +} + +// TestPromptWithModel verifies Prompt delivers the model response on the channel. 
+func TestPromptWithModel(t *testing.T) { + expected := &model.Response{Reply: "all services healthy"} + mock := &mockModel{resp: expected} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + a := New(WithName("prompt-agent"), WithModel(mock), WithContext(ctx)) + + ch := a.Prompt("how are the services?") + select { + case resp := <-ch: + require.NotNil(t, resp) + assert.Equal(t, "all services healthy", resp.Reply) + case <-time.After(2 * time.Second): + t.Fatal("did not receive prompt response in time") + } + + // Channel should be closed after the single response. + _, ok := <-ch + assert.False(t, ok, "channel should be closed after response") +} + +// TestPromptRecordsActivity verifies that Prompt records ActivityPrompt and ActivityResponse. +func TestPromptRecordsActivity(t *testing.T) { + mock := &mockModel{resp: &model.Response{Reply: "ok"}} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + a := New(WithName("activity-agent"), WithModel(mock), WithContext(ctx)) + + ch := a.Prompt("status check") + <-ch // wait for completion + + acts := a.Activity() + require.NotEmpty(t, acts) + + types := make(map[ActivityType]int) + for _, act := range acts { + types[act.Type]++ + assert.False(t, act.Time.IsZero(), "activity should have a timestamp") + } + + assert.GreaterOrEqual(t, types[ActivityPrompt], 1, "should have at least one ActivityPrompt") + assert.GreaterOrEqual(t, types[ActivityResponse], 1, "should have at least one ActivityResponse") +} + +// TestPromptRecordsToolActivity verifies tool calls made during Prompt are recorded. 
+func TestPromptRecordsToolActivity(t *testing.T) { + reg := registry.NewMemoryRegistry() + mock := &mockModel{ + resp: &model.Response{ + Reply: "checked", + ToolCalls: []model.ToolCall{ + {Name: "list_services", Input: map[string]any{}}, + }, + }, + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + a := New( + WithName("tool-activity-agent"), + WithModel(mock), + WithRegistry(reg), + WithContext(ctx), + ) + + ch := a.Prompt("list services please") + <-ch + + acts := a.Activity() + types := make(map[ActivityType]int) + for _, act := range acts { + types[act.Type]++ + } + assert.GreaterOrEqual(t, types[ActivityTool], 1, "should record at least one tool activity") +} + +// TestActivityIsSnapshot verifies Activity returns an independent copy. +func TestActivityIsSnapshot(t *testing.T) { + a := New(WithName("snapshot-agent")) + + snap1 := a.Activity() + assert.Empty(t, snap1) + + // Directly record something. + impl, ok := a.(*agent) + require.True(t, ok, "New() must return *agent") + impl.record(Activity{Type: ActivityEvaluate}) + + snap2 := a.Activity() + assert.Len(t, snap2, 1) + + // The first snapshot is unchanged. + assert.Empty(t, snap1) +} + +// TestEvaluateRecordsActivity verifies evaluate records evaluate/tool/response activities. 
+func TestEvaluateRecordsActivity(t *testing.T) { + mock := &mockModel{ + resp: &model.Response{ + Reply: "evaluated", + ToolCalls: []model.ToolCall{ + {Name: "list_services", Input: map[string]any{}}, + }, + }, + } + reg := registry.NewMemoryRegistry() + a := &agent{ + opts: newOptions(WithModel(mock), WithRegistry(reg)), + stop: make(chan struct{}), + activities: make([]Activity, 0, maxActivities), + } + + tools := a.buildTools() + err := a.evaluate(tools) + require.NoError(t, err) + + acts := a.Activity() + types := make(map[ActivityType]int) + for _, act := range acts { + types[act.Type]++ + } + assert.GreaterOrEqual(t, types[ActivityEvaluate], 1) + assert.GreaterOrEqual(t, types[ActivityTool], 1) + assert.GreaterOrEqual(t, types[ActivityResponse], 1) +} + +// TestPromptErrorRecorded verifies that a model error is recorded as ActivityError. +func TestPromptErrorRecorded(t *testing.T) { + errModel := &mockModel{err: fmt.Errorf("model offline")} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + a := New(WithName("error-agent"), WithModel(errModel), WithContext(ctx)) + + ch := a.Prompt("hello") + <-ch // closed without a value on error + + acts := a.Activity() + types := make(map[ActivityType]int) + for _, act := range acts { + types[act.Type]++ + if act.Type == ActivityError { + assert.NotNil(t, act.Err) + } + } + assert.GreaterOrEqual(t, types[ActivityError], 1) +} + +// slowMockModel is a model.Model that blocks until its context is cancelled. 
+type slowMockModel struct { + delay time.Duration +} + +func (m *slowMockModel) Init(...model.Option) error { return nil } +func (m *slowMockModel) Options() model.Options { return model.Options{} } +func (m *slowMockModel) String() string { return "slow" } +func (m *slowMockModel) Stream(_ context.Context, _ *model.Request, _ ...model.GenerateOption) (model.Stream, error) { + return nil, nil +} +func (m *slowMockModel) Generate(ctx context.Context, _ *model.Request, _ ...model.GenerateOption) (*model.Response, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(m.delay): + return &model.Response{Reply: "done"}, nil + } +} + +// mockModel is a test double for model.Model. +type mockModel struct { + called bool + resp *model.Response + err error +} + +func (m *mockModel) Init(...model.Option) error { return nil } +func (m *mockModel) Options() model.Options { return model.Options{} } +func (m *mockModel) String() string { return "mock" } +func (m *mockModel) Stream(_ context.Context, _ *model.Request, _ ...model.GenerateOption) (model.Stream, error) { + return nil, nil +} +func (m *mockModel) Generate(_ context.Context, _ *model.Request, _ ...model.GenerateOption) (*model.Response, error) { + m.called = true + return m.resp, m.err +} diff --git a/agent/options.go b/agent/options.go new file mode 100644 index 0000000000..17a8fe33d8 --- /dev/null +++ b/agent/options.go @@ -0,0 +1,139 @@ +package agent + +import ( + "context" + "time" + + log "go-micro.dev/v5/logger" + "go-micro.dev/v5/model" + "go-micro.dev/v5/registry" + + "go-micro.dev/v5/client" +) + +// Options configures an Agent. +type Options struct { + // Name of the agent. + Name string + + // Directive is the agent's system prompt describing its purpose + // and how it should manage the services it is responsible for. + Directive string + + // Services is the list of service names this agent manages. + // An empty list means the agent watches all registered services. 
+ Services []string + + // Model is the AI model the agent uses to reason about service state. + // When nil the agent still runs and executes its built-in tools but + // does not perform AI-driven evaluation. + Model model.Model + + // Registry used to discover and watch services. + Registry registry.Registry + + // Client used to make RPC calls to services. + Client client.Client + + // Logger for agent output. + Logger log.Logger + + // Context for cancellation and deadline propagation. + Context context.Context + + // Interval between evaluation cycles. Defaults to 30 seconds. + Interval time.Duration + + // ToolHandler is an optional callback for custom tool execution. + // It is called when a tool call does not match a built-in tool. + ToolHandler model.ToolHandler +} + +// Option is a function that modifies Options. +type Option func(*Options) + +func newOptions(opts ...Option) Options { + o := Options{ + Name: "agent", + Directive: "You are an agent that manages the lifecycle of microservices. Monitor their health and take corrective action when needed.", + Context: context.Background(), + Interval: 30 * time.Second, + Registry: registry.DefaultRegistry, + Client: client.DefaultClient, + Logger: log.DefaultLogger, + } + for _, opt := range opts { + opt(&o) + } + return o +} + +// WithName sets the agent name. +func WithName(name string) Option { + return func(o *Options) { + o.Name = name + } +} + +// WithDirective sets the agent directive (system prompt). +func WithDirective(directive string) Option { + return func(o *Options) { + o.Directive = directive + } +} + +// WithServices sets the list of service names the agent manages. +func WithServices(services ...string) Option { + return func(o *Options) { + o.Services = services + } +} + +// WithModel sets the AI model used for evaluation. +func WithModel(m model.Model) Option { + return func(o *Options) { + o.Model = m + } +} + +// WithRegistry sets the registry for service discovery. 
+func WithRegistry(r registry.Registry) Option { + return func(o *Options) { + o.Registry = r + } +} + +// WithClient sets the RPC client used for service calls. +func WithClient(c client.Client) Option { + return func(o *Options) { + o.Client = c + } +} + +// WithLogger sets the logger. +func WithLogger(l log.Logger) Option { + return func(o *Options) { + o.Logger = l + } +} + +// WithContext sets the context that is passed to the model and whose cancellation ends Run. +func WithContext(ctx context.Context) Option { + return func(o *Options) { + o.Context = ctx + } +} + +// WithInterval sets the evaluation interval. +func WithInterval(d time.Duration) Option { + return func(o *Options) { + o.Interval = d + } +} + +// WithToolHandler sets a custom tool handler for unrecognized tool calls. +func WithToolHandler(h model.ToolHandler) Option { + return func(o *Options) { + o.ToolHandler = h + } +} diff --git a/micro.go b/micro.go index 6f03853d94..8d4bcdc779 100644 --- a/micro.go +++ b/micro.go @@ -4,6 +4,7 @@ package micro import ( "context" + "go-micro.dev/v5/agent" "go-micro.dev/v5/client" "go-micro.dev/v5/server" "go-micro.dev/v5/service" @@ -87,3 +88,15 @@ func RegisterHandler(s server.Server, h interface{}, opts ...server.HandlerOptio func RegisterSubscriber(topic string, s server.Server, h interface{}, opts ...server.SubscriberOption) error { return s.Subscribe(s.NewSubscriber(topic, h, opts...)) } + +// Agent is an AI-driven entity that manages the lifecycle of one or more services. +// It uses tools to observe and control services based on a directive. +type Agent = agent.Agent + +// AgentOption is a functional option for configuring an Agent. +type AgentOption = agent.Option + +// NewAgent creates and returns a new Agent. +func NewAgent(opts ...AgentOption) Agent { + return agent.New(opts...) +}