From d763a120a99255b6c9e472b7439bce118411b6ae Mon Sep 17 00:00:00 2001 From: venkat1701 Date: Sat, 24 Jan 2026 14:19:34 +0530 Subject: [PATCH 1/7] Implement pluggable memory system with SQLite driver and associated configurations --- cagent-schema.json | 88 +++++++++++++ examples/memory_demo.yaml | 80 ++++++++++++ pkg/config/latest/types.go | 44 +++++++ pkg/config/latest/validate.go | 102 ++++++++++++++- pkg/memory/adapter.go | 49 ++++++++ pkg/memory/driver.go | 88 +++++++++++++ pkg/memory/memory_test.go | 210 +++++++++++++++++++++++++++++++ pkg/memory/registry.go | 31 +++++ pkg/memory/sqlite/driver.go | 123 ++++++++++++++++++ pkg/memory/sqlite/driver_test.go | 156 +++++++++++++++++++++++ pkg/memory/sqlite/init.go | 7 ++ pkg/teamloader/registry.go | 17 ++- pkg/teamloader/teamloader.go | 46 +++++++ 13 files changed, 1035 insertions(+), 6 deletions(-) create mode 100644 examples/memory_demo.yaml create mode 100644 pkg/memory/adapter.go create mode 100644 pkg/memory/driver.go create mode 100644 pkg/memory/memory_test.go create mode 100644 pkg/memory/registry.go create mode 100644 pkg/memory/sqlite/driver.go create mode 100644 pkg/memory/sqlite/driver_test.go create mode 100644 pkg/memory/sqlite/init.go diff --git a/cagent-schema.json b/cagent-schema.json index 677be8384..0322603c6 100644 --- a/cagent-schema.json +++ b/cagent-schema.json @@ -51,6 +51,13 @@ "$ref": "#/definitions/RAGConfig" } }, + "memory": { + "type": "object", + "description": "Map of memory scope configurations for pluggable memory backends", + "additionalProperties": { + "$ref": "#/definitions/MemoryConfig" + } + }, "metadata": { "$ref": "#/definitions/Metadata", "description": "Configuration metadata" @@ -277,6 +284,13 @@ "type": "string" } }, + "memory": { + "type": "array", + "description": "List of memory scopes to use for this agent", + "items": { + "type": "string" + } + }, "hooks": { "$ref": "#/definitions/HooksConfig", "description": "Lifecycle hooks for executing shell commands at various points in the agent's execution" @@ -1297,6 +1311,80 @@ "strategies" ], "additionalProperties": false + }, + "MemoryConfig": { + "type": "object", + "description": "Memory scope configuration for pluggable memory backends supporting long-term (RAG-style) and short-term (whiteboard) strategies", + "required": ["kind"], + "properties": { + "kind": { + "type": "string", + "description": "Memory backend type", + "enum": ["sqlite", "neo4j", "qdrant", "redis", "whiteboard"] + }, + "strategy": { + "type": "string", + "description": "Memory strategy: long_term (persistent RAG-style) or short_term (ephemeral whiteboard)", + "enum": ["long_term", "short_term"] + }, + "description": { + "type": "string", + "description": "Human-readable description of this memory scope" + }, + "path": { + "type": "string", + "description": "File path for file-based backends (sqlite)" + }, + "ttl": { + "type": "integer", + "description": "Time-to-live in seconds for ephemeral memory (whiteboard, redis)", + "minimum": 0 + }, + "mode": { + "type": "string", + "description": "Access mode for the memory", + "enum": ["read_write", "read_only", "append_only"] + }, + "connection": { + "type": "object", + "description": "Connection details for remote backends", + "properties": { + "url": { + "type": "string", + "description": "Connection URL for the memory backend" + }, + "database": { + "type": "string", + "description": "Database name (for backends supporting multiple databases)" + }, + "collection": { + "type": "string", + "description": "Collection/table name (for vector stores)" + }, + "auth": { + "type": "object", + "description": "Authentication credentials", + "properties": { + "username": { + "type": "string", + "description": "Username for basic authentication" + }, + "password": { + "type": "string", + "description": "Password for basic authentication" + }, + "token": { + "type": "string", + "description": "Token for token-based authentication" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false } } } diff --git a/examples/memory_demo.yaml b/examples/memory_demo.yaml new file mode 100644 index 000000000..6ce4058d2 --- /dev/null +++ b/examples/memory_demo.yaml @@ -0,0 +1,80 @@ +#!/usr/bin/env cagent run +version: "3" + +metadata: + author: Memory System Demo + readme: | + Demonstrates the pluggable memory system with both long-term and short-term memory. + + Memory strategies: + - **Long-term (SQLite)**: Persistent memory for user facts + - **Short-term (Whiteboard)**: Ephemeral shared context for multi-agent collaboration + + Future backends (config-ready, implementation pending): + - Neo4j GraphRAG for knowledge graphs + - Qdrant for vector-based semantic search + - Redis for distributed whiteboard + +memory: + # Long-term persistent memory (current implementation) + user_facts: + kind: sqlite + strategy: long_term + path: ./memory/user_facts.db + description: "Persistent memory about the user" + + # Short-term shared whiteboard (implementation pending) + # team_whiteboard: + # kind: whiteboard + # strategy: short_term + # ttl: 3600 # 1 hour expiry + # description: "Shared context for agent collaboration" + + # GraphRAG with Neo4j (implementation pending) + # knowledge_graph: + # kind: neo4j + # strategy: long_term + # connection: + # url: bolt://localhost:7687 + # database: cagent + # auth: + # username: neo4j + # password: ${NEO4J_PASSWORD} + # description: "Knowledge graph for semantic relationships" + +agents: + root: + model: anthropic/claude-sonnet-4-5 + description: "Assistant with long-term memory" + instruction: | + You are a helpful assistant with memory capabilities. + + Use the memory tool to remember things about the user. + Before responding, always check memories to personalize your responses. + memory: + - user_facts + toolsets: + - type: think + + # Multi-agent example with shared whiteboard (pending whiteboard implementation) + # coordinator: + # model: anthropic/claude-sonnet-4-5 + # description: "Coordinates team using shared whiteboard" + # memory: + # - team_whiteboard # Shared with sub-agents + # - user_facts # Personal long-term memory + # sub_agents: + # - researcher + # - writer + # + # researcher: + # model: openai/gpt-4o + # description: "Research specialist" + # memory: + # - team_whiteboard # Shared whiteboard + # + # writer: + # model: anthropic/claude-sonnet-4-5 + # description: "Writing specialist" + # memory: + # - team_whiteboard # Shared whiteboard diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index aab024847..e0921ec57 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -18,6 +18,7 @@ type Config struct { Providers map[string]ProviderConfig `json:"providers,omitempty"` Models map[string]ModelConfig `json:"models,omitempty"` RAG map[string]RAGConfig `json:"rag,omitempty"` + Memory map[string]MemoryConfig `json:"memory,omitempty"` Metadata Metadata `json:"metadata,omitempty"` Permissions *PermissionsConfig `json:"permissions,omitempty"` } @@ -119,6 +120,7 @@ type AgentConfig struct { SubAgents []string `json:"sub_agents,omitempty"` Handoffs []string `json:"handoffs,omitempty"` RAG []string `json:"rag,omitempty"` + Memory []string `json:"memory,omitempty"` AddDate bool `json:"add_date,omitempty"` AddEnvironmentInfo bool `json:"add_environment_info,omitempty"` CodeModeTools bool `json:"code_mode_tools,omitempty"` @@ -1003,3 +1005,45 @@ func (h *HookDefinition) validate(prefix string, index int) error { return nil } + +// MemoryConfig represents a named memory scope configuration. +// Memory scopes define how agents store and retrieve information. +type MemoryConfig struct { + // Kind specifies the memory backend type: sqlite, whiteboard, neo4j, qdrant, redis. + Kind string `json:"kind"` + // Strategy specifies the memory strategy: "long_term" (persistent RAG-style) or "short_term" (ephemeral whiteboard). + // Default is "long_term" for sqlite/neo4j/qdrant, "short_term" for whiteboard/redis. + Strategy string `json:"strategy,omitempty"` + // Description is an optional human-readable description of this memory scope. + Description string `json:"description,omitempty"` + // Connection holds connection details for remote backends. + Connection *MemoryConnectionConfig `json:"connection,omitempty"` + // Path is the file path for file-based backends like sqlite. + Path string `json:"path,omitempty"` + // TTL is the time-to-live in seconds for ephemeral memory (e.g., whiteboard). 0 means no expiry. + TTL int `json:"ttl,omitempty"` + // Mode specifies the access mode: "read_write" (default), "read_only", or "append_only" (event-log style). + Mode string `json:"mode,omitempty"` +} + +// MemoryConnectionConfig holds connection details for remote memory backends. +type MemoryConnectionConfig struct { + // URL is the connection URL for the memory backend. + URL string `json:"url"` + // Database is the database name (for backends that support multiple databases). + Database string `json:"database,omitempty"` + // Collection is the collection/table name (for vector stores). + Collection string `json:"collection,omitempty"` + // Auth holds authentication credentials. + Auth *MemoryAuthConfig `json:"auth,omitempty"` +} + +// MemoryAuthConfig holds authentication credentials for memory backends. +type MemoryAuthConfig struct { + // Username for basic authentication. + Username string `json:"username,omitempty"` + // Password for basic authentication. + Password string `json:"password,omitempty"` + // Token for token-based authentication. + Token string `json:"token,omitempty"` +} diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 752c29987..372420650 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -2,6 +2,7 @@ package latest import ( "errors" + "fmt" "strings" ) @@ -16,8 +17,7 @@ func (t *Config) UnmarshalYAML(unmarshal func(any) error) error { } func (t *Config) validate() error { - for i := range t.Agents { - agent := t.Agents[i] + for agentName, agent := range t.Agents { for j := range agent.Toolsets { if err := agent.Toolsets[j].validate(); err != nil { return err @@ -28,6 +28,18 @@ func (t *Config) validate() error { return err } } + // Validate agent memory references exist in top-level memory map + for _, memRef := range agent.Memory { + if _, exists := t.Memory[memRef]; !exists { + return fmt.Errorf("agent %q: references undefined memory %q", agentName, memRef) + } + } + } + + for name, mem := range t.Memory { + if err := mem.validate(name); err != nil { + return err + } } return nil @@ -123,3 +135,89 @@ func (t *Toolset) validate() error { return nil } + +func (m *MemoryConfig) validate(name string) error { + if m.Kind == "" { + return fmt.Errorf("memory %q: kind is required", name) + } + + validKinds := map[string]bool{ + "sqlite": true, + "neo4j": true, + "qdrant": true, + "redis": true, + "whiteboard": true, + } + if !validKinds[m.Kind] { + return fmt.Errorf("memory %q: invalid kind %q, must be one of: sqlite, neo4j, qdrant, redis, whiteboard", name, m.Kind) + } + + // Validate strategy if provided + if m.Strategy != "" { + validStrategies := map[string]bool{ + "long_term": true, // Persistent RAG-style memory (sqlite, neo4j, qdrant) + "short_term": true, // Ephemeral whiteboard-style memory (whiteboard, redis) + } + if !validStrategies[m.Strategy] { + return fmt.Errorf("memory %q: invalid strategy %q, must be one of: long_term, short_term", name, m.Strategy) + } + + // Validate strategy matches kind semantics + longTermKinds := map[string]bool{"sqlite": true, "neo4j": true, "qdrant": true} + shortTermKinds := map[string]bool{"whiteboard": true, "redis": true} + + if m.Strategy == "long_term" && shortTermKinds[m.Kind] { + return fmt.Errorf("memory %q: kind %q is not suitable for long_term strategy (use sqlite, neo4j, or qdrant)", name, m.Kind) + } + if m.Strategy == "short_term" && longTermKinds[m.Kind] && m.Kind != "redis" { + // Note: redis can be used for both strategies; sqlite/neo4j/qdrant are long-term only + if m.Kind != "redis" { + return fmt.Errorf("memory %q: kind %q is not suitable for short_term strategy (use whiteboard or redis)", name, m.Kind) + } + } + } + + // Validate mode if provided + if m.Mode != "" { + validModes := map[string]bool{ + "read_write": true, // Default: full read/write access + "read_only": true, // Read-only access (useful for shared knowledge bases) + "append_only": true, // Append-only (event-log style, no updates/deletes) + } + if !validModes[m.Mode] { + return fmt.Errorf("memory %q: invalid mode %q, must be one of: read_write, read_only, append_only", name, m.Mode) + } + } + + // Validate auth completeness if present (check Connection is not nil first) + if m.Connection != nil && m.Connection.Auth != nil { + auth := m.Connection.Auth + hasUserPass := auth.Username != "" || auth.Password != "" + hasToken := auth.Token != "" + + if hasUserPass && hasToken { + return fmt.Errorf("memory %q: auth must use either username/password or token, not both", name) + } + if hasUserPass && (auth.Username == "" || auth.Password == "") { + return fmt.Errorf("memory %q: auth requires both username and password when using user/password auth", name) + } + } + + // For sqlite, path is required + if m.Kind == "sqlite" && m.Path == "" { + return fmt.Errorf("memory %q: sqlite requires a path", name) + } + + // For remote backends, connection URL is typically required + remoteKinds := map[string]bool{"neo4j": true, "qdrant": true, "redis": true} + if remoteKinds[m.Kind] && (m.Connection == nil || m.Connection.URL == "") { + return fmt.Errorf("memory %q: %s requires connection.url", name, m.Kind) + } + + // TTL validation: only meaningful for short-term/ephemeral memory + if m.TTL > 0 && m.Kind != "whiteboard" && m.Kind != "redis" { + return fmt.Errorf("memory %q: ttl is only supported for whiteboard and redis kinds", name) + } + + return nil +} diff --git a/pkg/memory/adapter.go b/pkg/memory/adapter.go new file mode 100644 index 000000000..b0851b952 --- /dev/null +++ b/pkg/memory/adapter.go @@ -0,0 +1,49 @@ +package memory + +import ( + "context" + + "github.com/docker/cagent/pkg/memory/database" + "github.com/google/uuid" +) + +// DatabaseAdapter adapts the new Driver interface to the legacy database.Database interface +type DatabaseAdapter struct { + driver Driver +} + +var _ database.Database = (*DatabaseAdapter)(nil) + +// NewDatabaseAdapter creates an adapter that wraps a Driver +func NewDatabaseAdapter(driver Driver) *DatabaseAdapter { + return &DatabaseAdapter{driver: driver} +} + +func (a *DatabaseAdapter) AddMemory(ctx context.Context, memory database.UserMemory) error { + key := memory.ID + if key == "" { + key = uuid.New().String() + } + return a.driver.Store(ctx, key, memory.Memory) +} + +func (a *DatabaseAdapter) GetMemories(ctx context.Context) ([]database.UserMemory, error) { + entries, err := a.driver.Retrieve(ctx, Query{}) + if err != nil { + return nil, err + } + + memories := make([]database.UserMemory, len(entries)) + for i, e := range entries { + memories[i] = database.UserMemory{ + ID: e.ID, + CreatedAt: e.CreatedAt, + Memory: e.Content, + } + } + return memories, nil +} + +func (a *DatabaseAdapter) DeleteMemory(ctx context.Context, memory database.UserMemory) error { + return a.driver.Delete(ctx, memory.ID) +} diff --git a/pkg/memory/driver.go b/pkg/memory/driver.go new file mode 100644 index 000000000..98e4d76d9 --- /dev/null +++ b/pkg/memory/driver.go @@ -0,0 +1,88 @@ +package memory + +import ( + "context" + "io" + + "github.com/docker/cagent/pkg/config/latest" +) + +// Driver defines the interface for memory backends. +// Implementations support different strategies (long-term RAG, short-term whiteboard). +type Driver interface { + // Store saves a memory entry with the given key and value + Store(ctx context.Context, key string, value string) error + + // Retrieve fetches memory entries matching the query + Retrieve(ctx context.Context, query Query) ([]Entry, error) + + // Delete removes a memory entry by key + Delete(ctx context.Context, key string) error + + // Close releases resources held by the driver + io.Closer +} + +// Query represents different types of memory queries +type Query struct { + // ID for exact match retrieval + ID string + + // Semantic for natural language queries (GraphRAG, vector search) + Semantic string + + // Limit on number of results + Limit int + + // Filters for metadata-based filtering + Filters map[string]any +} + +// Entry represents a memory item returned from a query +type Entry struct { + ID string + CreatedAt string + Content string + Metadata map[string]any + Score float64 // Relevance score for semantic queries +} + +// Factory creates memory drivers from configuration +type Factory interface { + CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (Driver, error) +} + +// Registry holds registered driver factories +type Registry struct { + factories map[string]Factory +} + +// NewRegistry creates a new driver registry +func NewRegistry() *Registry { + return &Registry{ + factories: make(map[string]Factory), + } +} + +// Register adds a factory for a specific backend kind +func (r *Registry) Register(kind string, factory Factory) { + r.factories[kind] = factory +} + +// CreateDriver instantiates a driver from config +func (r *Registry) CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (Driver, error) { + factory, ok := r.factories[cfg.Kind] + if !ok { + return nil, &UnsupportedKindError{Kind: cfg.Kind} + } + return factory.CreateDriver(ctx, cfg) +} + +// UnsupportedKindError indicates an unknown backend kind +type UnsupportedKindError struct { + Kind string +} + +func (e *UnsupportedKindError) Error() string { + return "unsupported memory kind: " + e.Kind +} diff --git a/pkg/memory/memory_test.go b/pkg/memory/memory_test.go new file mode 100644 index 000000000..964e0bb22 --- /dev/null +++ b/pkg/memory/memory_test.go @@ -0,0 +1,210 @@ +package memory_test + +import ( + "context" + "path/filepath" + "testing" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/memory" + "github.com/docker/cagent/pkg/memory/database" + "github.com/docker/cagent/pkg/memory/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockDriver implements memory.Driver for testing +type MockDriver struct { + stored map[string]string + entries []memory.Entry + closeErr error +} + +func NewMockDriver() *MockDriver { + return &MockDriver{ + stored: make(map[string]string), + entries: []memory.Entry{}, + } +} + +func (m *MockDriver) Store(ctx context.Context, key string, value string) error { + m.stored[key] = value + m.entries = append(m.entries, memory.Entry{ + ID: key, + Content: value, + CreatedAt: "2026-01-17T00:00:00Z", + }) + return nil +} + +func (m *MockDriver) Retrieve(ctx context.Context, query memory.Query) ([]memory.Entry, error) { + if query.ID != "" { + for _, e := range m.entries { + if e.ID == query.ID { + return []memory.Entry{e}, nil + } + } + return []memory.Entry{}, nil + } + return m.entries, nil +} + +func (m *MockDriver) Delete(ctx context.Context, key string) error { + delete(m.stored, key) + var newEntries []memory.Entry + for _, e := range m.entries { + if e.ID != key { + newEntries = append(newEntries, e) + } + } + m.entries = newEntries + return nil +} + +func (m *MockDriver) Close() error { + return m.closeErr +} + +func TestRegistry(t *testing.T) { + t.Parallel() + + t.Run("register and create driver", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + registry := memory.NewRegistry() + + // Register a mock factory + mockFactory := &mockFactory{} + registry.Register("mock", mockFactory) + + // Create driver + cfg := latest.MemoryConfig{Kind: "mock"} + driver, err := registry.CreateDriver(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, driver) + }) + + t.Run("error on unknown kind", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + registry := memory.NewRegistry() + + cfg := latest.MemoryConfig{Kind: "unknown"} + driver, err := registry.CreateDriver(ctx, cfg) + require.Error(t, err) + assert.Nil(t, driver) + + var unsupportedErr *memory.UnsupportedKindError + assert.ErrorAs(t, err, &unsupportedErr) + assert.Equal(t, "unknown", unsupportedErr.Kind) + }) +} + +type mockFactory struct{} + +func (f *mockFactory) CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (memory.Driver, error) { + return NewMockDriver(), nil +} + +func TestDatabaseAdapter(t *testing.T) { + t.Parallel() + ctx := t.Context() + + mockDriver := NewMockDriver() + adapter := memory.NewDatabaseAdapter(mockDriver) + + t.Run("add memory", func(t *testing.T) { + mem := database.UserMemory{ + ID: "test-1", + Memory: "test content", + } + err := adapter.AddMemory(ctx, mem) + require.NoError(t, err) + assert.Equal(t, "test content", mockDriver.stored["test-1"]) + }) + + t.Run("add memory with auto ID", func(t *testing.T) { + mem := database.UserMemory{ + Memory: "auto id content", + } + err := adapter.AddMemory(ctx, mem) + require.NoError(t, err) + // Should have stored with a UUID + assert.Len(t, mockDriver.stored, 2) + }) + + t.Run("get memories", func(t *testing.T) { + memories, err := adapter.GetMemories(ctx) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(memories), 2) + }) + + t.Run("delete memory", func(t *testing.T) { + mem := database.UserMemory{ID: "test-1"} + err := adapter.DeleteMemory(ctx, mem) + require.NoError(t, err) + _, exists := mockDriver.stored["test-1"] + assert.False(t, exists) + }) +} + +func TestDefaultRegistry(t *testing.T) { + t.Run("default registry is singleton", func(t *testing.T) { + reg1 := memory.DefaultRegistry() + reg2 := memory.DefaultRegistry() + assert.Same(t, reg1, reg2) + }) +} + +// Integration test with real SQLite driver +func TestSQLiteDriverIntegration(t *testing.T) { + t.Parallel() + ctx := t.Context() + + dbPath := filepath.Join(t.TempDir(), "integration_test.db") + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: dbPath, + } + + // Use the actual sqlite factory + sqliteFactory := &sqlite.Factory{} + registry := memory.NewRegistry() + registry.Register("sqlite", sqliteFactory) + + driver, err := registry.CreateDriver(ctx, cfg) + require.NoError(t, err) + defer driver.Close() + + // Test full workflow through adapter + adapter := memory.NewDatabaseAdapter(driver) + + // Add + err = adapter.AddMemory(ctx, database.UserMemory{ + ID: "integration-1", + Memory: "Integration test memory", + }) + require.NoError(t, err) + + // Get + memories, err := adapter.GetMemories(ctx) + require.NoError(t, err) + require.Len(t, memories, 1) + assert.Equal(t, "Integration test memory", memories[0].Memory) + + // Delete + err = adapter.DeleteMemory(ctx, database.UserMemory{ID: "integration-1"}) + require.NoError(t, err) + + memories, err = adapter.GetMemories(ctx) + require.NoError(t, err) + assert.Empty(t, memories) +} + +// sqliteTestFactory wraps sqlite.Factory for testing (kept for reference) +type sqliteTestFactory struct{} + +func (f *sqliteTestFactory) CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (memory.Driver, error) { + factory := &sqlite.Factory{} + return factory.CreateDriver(ctx, cfg) +} diff --git a/pkg/memory/registry.go b/pkg/memory/registry.go new file mode 100644 index 000000000..85134942f --- /dev/null +++ b/pkg/memory/registry.go @@ -0,0 +1,31 @@ +package memory + +import ( + "context" + "sync" + + "github.com/docker/cagent/pkg/config/latest" +) + +var ( + globalRegistry *Registry + globalRegistryOnce sync.Once +) + +// DefaultRegistry returns the global driver registry +func DefaultRegistry() *Registry { + globalRegistryOnce.Do(func() { + globalRegistry = NewRegistry() + }) + return globalRegistry +} + +// RegisterFactory registers a driver factory for a backend kind +func RegisterFactory(kind string, factory Factory) { + DefaultRegistry().Register(kind, factory) +} + +// CreateDriver creates a driver from config using the default registry +func CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (Driver, error) { + return DefaultRegistry().CreateDriver(ctx, cfg) +} diff --git a/pkg/memory/sqlite/driver.go b/pkg/memory/sqlite/driver.go new file mode 100644 index 000000000..91172289f --- /dev/null +++ b/pkg/memory/sqlite/driver.go @@ -0,0 +1,123 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/memory" + "github.com/docker/cagent/pkg/sqliteutil" + "github.com/google/uuid" +) + +// Driver implements the memory.Driver interface using SQLite +type Driver struct { + db *sql.DB +} + +// Factory creates SQLite drivers +type Factory struct{} + +var _ memory.Factory = (*Factory)(nil) + +func (f *Factory) CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (memory.Driver, error) { + if cfg.Path == "" { + return nil, fmt.Errorf("sqlite driver requires a path") + } + + db, err := sqliteutil.OpenDB(cfg.Path) + if err != nil { + return nil, fmt.Errorf("failed to open sqlite database: %w", err) + } + + _, err = db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + created_at TEXT, + content TEXT, + metadata TEXT + )`) + if err != nil { + db.Close() + return nil, fmt.Errorf("failed to create memories table: %w", err) + } + + return &Driver{db: db}, nil +} + +func (d *Driver) Store(ctx context.Context, key string, value string) error { + if key == "" { + key = uuid.New().String() + } + + createdAt := time.Now().UTC().Format(time.RFC3339) + _, err := d.db.ExecContext(ctx, + "INSERT OR REPLACE INTO memories (id, created_at, content, metadata) VALUES (?, ?, ?, ?)", + key, createdAt, value, "{}") + if err != nil { + return fmt.Errorf("failed to store memory: %w", err) + } + return nil +} + +func (d *Driver) Retrieve(ctx context.Context, query memory.Query) ([]memory.Entry, error) { + var rows *sql.Rows + var err error + + if query.ID != "" { + rows, err = d.db.QueryContext(ctx, + "SELECT id, created_at, content FROM memories WHERE id = ?", + query.ID) + } else if query.Semantic != "" { + // Semantic search not yet implemented for SQLite + // For now, fall back to retrieving all memories + // Future: Use FTS5 or vector extension for semantic search + sqlQuery := "SELECT id, created_at, content FROM memories ORDER BY created_at DESC" + if query.Limit > 0 { + sqlQuery = fmt.Sprintf("%s LIMIT %d", sqlQuery, query.Limit) + } + rows, err = d.db.QueryContext(ctx, sqlQuery) + } else { + sqlQuery := "SELECT id, created_at, content FROM memories ORDER BY created_at DESC" + if query.Limit > 0 { + sqlQuery = fmt.Sprintf("%s LIMIT %d", sqlQuery, query.Limit) + } + rows, err = d.db.QueryContext(ctx, sqlQuery) + } + + if err != nil { + return nil, fmt.Errorf("failed to retrieve memories: %w", err) + } + defer rows.Close() + + var entries []memory.Entry + for rows.Next() { + var e memory.Entry + if err := rows.Scan(&e.ID, &e.CreatedAt, &e.Content); err != nil { + return nil, fmt.Errorf("failed to scan memory row: %w", err) + } + entries = append(entries, e) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating memory rows: %w", err) + } + + return entries, nil +} + +func (d *Driver) Delete(ctx context.Context, key string) error { + _, err := d.db.ExecContext(ctx, "DELETE FROM memories WHERE id = ?", key) + if err != nil { + return fmt.Errorf("failed to delete memory: %w", err) + } + return nil +} + +func (d *Driver) Close() error { + if d.db != nil { + return d.db.Close() + } + return nil +} diff --git a/pkg/memory/sqlite/driver_test.go b/pkg/memory/sqlite/driver_test.go new file mode 100644 index 000000000..7839cc71c --- /dev/null +++ b/pkg/memory/sqlite/driver_test.go @@ -0,0 +1,156 @@ +package sqlite + +import ( + "path/filepath" + "testing" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFactory_CreateDriver(t *testing.T) { + t.Parallel() + + ctx := t.Context() + factory := &Factory{} + + t.Run("creates driver with valid path", func(t *testing.T) { + t.Parallel() + dbPath := filepath.Join(t.TempDir(), "test.db") + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: dbPath, + } + + driver, err := factory.CreateDriver(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, driver) + defer driver.Close() + }) + + t.Run("fails without path", func(t *testing.T) { + t.Parallel() + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: "", + } + + driver, err := factory.CreateDriver(ctx, cfg) + require.Error(t, err) + assert.Nil(t, driver) + assert.Contains(t, err.Error(), "requires a path") + }) +} + +func TestDriver_StoreAndRetrieve(t *testing.T) { + t.Parallel() + + ctx := t.Context() + driver := createTestDriver(t) + defer driver.Close() + + t.Run("store with explicit key", func(t *testing.T) { + err := driver.Store(ctx, "test-key-1", "test value 1") + require.NoError(t, err) + + entries, err := driver.Retrieve(ctx, memory.Query{ID: "test-key-1"}) + require.NoError(t, err) + require.Len(t, entries, 1) + assert.Equal(t, "test-key-1", entries[0].ID) + assert.Equal(t, "test value 1", entries[0].Content) + }) + + t.Run("store with auto-generated key", func(t *testing.T) { + err := driver.Store(ctx, "", "auto key value") + require.NoError(t, err) + + entries, err := driver.Retrieve(ctx, memory.Query{}) + require.NoError(t, err) + require.GreaterOrEqual(t, len(entries), 1) + }) + + t.Run("retrieve all with limit", func(t *testing.T) { + // Add more entries + for i := 0; i < 5; i++ { + err := driver.Store(ctx, "", "bulk value") + require.NoError(t, err) + } + + entries, err := driver.Retrieve(ctx, memory.Query{Limit: 3}) + require.NoError(t, err) + assert.Len(t, entries, 3) + }) + + t.Run("retrieve with semantic query falls back to all", func(t *testing.T) { + entries, err := driver.Retrieve(ctx, memory.Query{ + Semantic: "some semantic query", + Limit: 2, + }) + require.NoError(t, err) + assert.LessOrEqual(t, len(entries), 2) + }) +} + +func TestDriver_Delete(t *testing.T) { + t.Parallel() + + ctx := t.Context() + driver := createTestDriver(t) + defer driver.Close() + + // Store a memory + err := driver.Store(ctx, "delete-test", "to be deleted") + require.NoError(t, err) + + // Verify it exists + entries, err := driver.Retrieve(ctx, memory.Query{ID: "delete-test"}) + require.NoError(t, err) + require.Len(t, entries, 1) + + // Delete it + err = driver.Delete(ctx, "delete-test") + require.NoError(t, err) + + // Verify it's gone + entries, err = driver.Retrieve(ctx, memory.Query{ID: "delete-test"}) + require.NoError(t, err) + assert.Empty(t, entries) +} + +func TestDriver_UpdateExisting(t *testing.T) { + t.Parallel() + + ctx := t.Context() + driver := createTestDriver(t) + defer driver.Close() + + // Store initial value + err := driver.Store(ctx, "update-key", "initial value") + require.NoError(t, err) + + // Update with same key + err = driver.Store(ctx, "update-key", "updated value") + require.NoError(t, err) + + // Retrieve and verify updated + entries, err := driver.Retrieve(ctx, memory.Query{ID: "update-key"}) + require.NoError(t, err) + require.Len(t, entries, 1) + assert.Equal(t, "updated value", entries[0].Content) +} + +func createTestDriver(t *testing.T) *Driver { + t.Helper() + ctx := t.Context() + factory := &Factory{} + dbPath := filepath.Join(t.TempDir(), "test.db") + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: dbPath, + } + driver, err := factory.CreateDriver(ctx, cfg) + require.NoError(t, err) + return driver.(*Driver) +} diff --git a/pkg/memory/sqlite/init.go b/pkg/memory/sqlite/init.go new file mode 100644 index 000000000..dc9af38bc --- /dev/null +++ b/pkg/memory/sqlite/init.go @@ -0,0 +1,7 @@ +package sqlite + +import "github.com/docker/cagent/pkg/memory" + +func init() { + memory.RegisterFactory("sqlite", &Factory{}) +} diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index 7374605cc..dd59e9ecf 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -12,7 +12,8 @@ import ( "github.com/docker/cagent/pkg/environment" "github.com/docker/cagent/pkg/gateway" "github.com/docker/cagent/pkg/js" - "github.com/docker/cagent/pkg/memory/database/sqlite" + "github.com/docker/cagent/pkg/memory" + _ "github.com/docker/cagent/pkg/memory/sqlite" // Register sqlite driver "github.com/docker/cagent/pkg/path" "github.com/docker/cagent/pkg/tools" "github.com/docker/cagent/pkg/tools/a2a" @@ -80,7 +81,7 @@ func createTodoTool(_ context.Context, toolset latest.Toolset, _ string, _ *conf return builtin.NewTodoTool(), nil } -func createMemoryTool(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createMemoryTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { var memoryPath string if filepath.IsAbs(toolset.Path) { memoryPath = "" @@ -98,11 +99,19 @@ func createMemoryTool(_ context.Context, toolset latest.Toolset, parentDir strin return nil, fmt.Errorf("failed to create memory database directory: %w", err) } - db, err := sqlite.NewMemoryDatabase(validatedMemoryPath) + // Use new driver-based approach + cfg := latest.MemoryConfig{ + Kind: "sqlite", + Path: validatedMemoryPath, + } + + driver, err := memory.CreateDriver(ctx, cfg) if err != nil { - return nil, fmt.Errorf("failed to create memory database: %w", err) + return nil, fmt.Errorf("failed to create memory driver: %w", err) } + // Adapt new driver to legacy database interface + db := memory.NewDatabaseAdapter(driver) return builtin.NewMemoryTool(db), nil } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 527d17378..45143d770 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -12,6 +12,8 @@ import ( "github.com/docker/cagent/pkg/config" "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/js" + "github.com/docker/cagent/pkg/memory" + _ "github.com/docker/cagent/pkg/memory/sqlite" // Register sqlite driver "github.com/docker/cagent/pkg/model/provider" "github.com/docker/cagent/pkg/model/provider/options" "github.com/docker/cagent/pkg/modelsdev" @@ -113,6 +115,16 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c return nil, err } + // Create memory drivers from top-level memory configs + memoryDrivers := make(map[string]memory.Driver) + for name, memCfg := range cfg.Memory { + driver, err := memory.CreateDriver(ctx, memCfg) + if err != nil { + return nil, fmt.Errorf("failed to create memory driver %q: %w", name, err) + } + memoryDrivers[name] = driver + } + // Create RAG managers parentDir := cmp.Or(agentSource.ParentDir(), runConfig.WorkingDir) ragManagers, err := rag.NewManagers(ctx, cfg, rag.ManagersBuildConfig{ @@ -175,6 +187,12 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c agentTools = append(agentTools, ragTools...) } + // Add memory tools if agent has memory scopes + if len(agentConfig.Memory) > 0 { + memoryTools := createMemoryToolsForAgent(&agentConfig, memoryDrivers) + agentTools = append(agentTools, memoryTools...) + } + opts = append(opts, agent.WithToolSets(agentTools...)) ag := agent.New(agentConfig.Name, agentConfig.Instruction, opts...) @@ -392,3 +410,31 @@ func createRAGToolsForAgent(agentConfig *latest.AgentConfig, allManagers map[str return ragTools } + +// createMemoryToolsForAgent creates memory tools for an agent, one for each referenced memory scope +func createMemoryToolsForAgent(agentConfig *latest.AgentConfig, allDrivers map[string]memory.Driver) []tools.ToolSet { + if len(agentConfig.Memory) == 0 { + return nil + } + + var memoryTools []tools.ToolSet + + for _, memName := range agentConfig.Memory { + driver, exists := allDrivers[memName] + if !exists { + slog.Error("Memory scope not found", "memory_scope", memName) + continue + } + + // Adapt driver to legacy database interface + db := memory.NewDatabaseAdapter(driver) + memTool := builtin.NewMemoryTool(db) + memoryTools = append(memoryTools, memTool) + + slog.Debug("Created memory tool for agent", + "memory_scope", memName, + ) + } + + return memoryTools +} From 5f680dc9bbff4d25518b325b2c3072c5ffa29be9 Mon Sep 17 00:00:00 2001 From: venkat1701 Date: Sat, 24 Jan 2026 14:53:48 +0530 Subject: [PATCH 2/7] Update memory_demo.yaml version to 4 and refactor memory driver interface for improved readability --- examples/memory_demo.yaml | 2 +- pkg/memory/adapter.go | 3 ++- pkg/memory/driver.go | 2 +- pkg/memory/memory_test.go | 17 +++++----------- pkg/memory/sqlite/driver.go | 34 +++++++++++++++----------------- pkg/memory/sqlite/driver_test.go | 2 +- pkg/memory/sqlite/init.go | 5 +++-- 7 files changed, 29 insertions(+), 36 deletions(-) diff --git a/examples/memory_demo.yaml b/examples/memory_demo.yaml index 6ce4058d2..e80b25931 100644 --- a/examples/memory_demo.yaml +++ b/examples/memory_demo.yaml @@ -1,5 +1,5 @@ #!/usr/bin/env cagent run -version: "3" +version: "4" metadata: author: Memory System Demo diff --git a/pkg/memory/adapter.go b/pkg/memory/adapter.go index b0851b952..c8140a34f 100644 --- a/pkg/memory/adapter.go +++ b/pkg/memory/adapter.go @@ -3,8 +3,9 @@ package memory import ( "context" - "github.com/docker/cagent/pkg/memory/database" "github.com/google/uuid" + + "github.com/docker/cagent/pkg/memory/database" ) // DatabaseAdapter adapts the new Driver interface to the legacy database.Database interface diff --git a/pkg/memory/driver.go b/pkg/memory/driver.go index 98e4d76d9..8b012e03b 100644 --- a/pkg/memory/driver.go +++ b/pkg/memory/driver.go @@ -11,7 +11,7 @@ import ( // Implementations support different strategies (long-term RAG, short-term whiteboard). type Driver interface { // Store saves a memory entry with the given key and value - Store(ctx context.Context, key string, value string) error + Store(ctx context.Context, key, value string) error // Retrieve fetches memory entries matching the query Retrieve(ctx context.Context, query Query) ([]Entry, error) diff --git a/pkg/memory/memory_test.go b/pkg/memory/memory_test.go index 964e0bb22..a9eed971c 100644 --- a/pkg/memory/memory_test.go +++ b/pkg/memory/memory_test.go @@ -5,12 +5,13 @@ import ( "path/filepath" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/memory" "github.com/docker/cagent/pkg/memory/database" "github.com/docker/cagent/pkg/memory/sqlite" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // MockDriver implements memory.Driver for testing @@ -27,7 +28,7 @@ func NewMockDriver() *MockDriver { } } -func (m *MockDriver) Store(ctx context.Context, key string, value string) error { +func (m *MockDriver) Store(ctx context.Context, key, value string) error { m.stored[key] = value m.entries = append(m.entries, memory.Entry{ ID: key, @@ -95,7 +96,7 @@ func TestRegistry(t *testing.T) { assert.Nil(t, driver) var unsupportedErr *memory.UnsupportedKindError - assert.ErrorAs(t, err, &unsupportedErr) + require.ErrorAs(t, err, &unsupportedErr) assert.Equal(t, "unknown", unsupportedErr.Kind) }) } @@ -200,11 +201,3 @@ func TestSQLiteDriverIntegration(t *testing.T) { require.NoError(t, err) assert.Empty(t, memories) } - -// sqliteTestFactory wraps sqlite.Factory for testing (kept for reference) -type sqliteTestFactory struct{} - -func (f *sqliteTestFactory) CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (memory.Driver, error) { - factory := &sqlite.Factory{} - return factory.CreateDriver(ctx, cfg) -} diff --git a/pkg/memory/sqlite/driver.go b/pkg/memory/sqlite/driver.go index 91172289f..ee680a956 100644 --- a/pkg/memory/sqlite/driver.go +++ b/pkg/memory/sqlite/driver.go @@ -6,10 +6,11 @@ import ( "fmt" "time" + "github.com/google/uuid" + "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/memory" "github.com/docker/cagent/pkg/sqliteutil" - "github.com/google/uuid" ) // Driver implements the memory.Driver interface using SQLite @@ -46,7 +47,7 @@ func (f *Factory) CreateDriver(ctx context.Context, cfg latest.MemoryConfig) (me return &Driver{db: db}, nil } -func (d *Driver) Store(ctx context.Context, key string, value string) error { +func (d *Driver) Store(ctx context.Context, key, value string) error { if key == "" { key = uuid.New().String() } @@ -62,23 +63,20 @@ func (d *Driver) Store(ctx context.Context, key string, value string) error { } func (d *Driver) Retrieve(ctx context.Context, query memory.Query) ([]memory.Entry, error) { - var rows *sql.Rows - var err error - - if query.ID != "" { - rows, err = d.db.QueryContext(ctx, + var ( + rows *sql.Rows + err error + ) + + switch { + case query.ID != "": + rows, err = d.db.QueryContext( + ctx, "SELECT id, created_at, content FROM memories WHERE id = ?", - query.ID) - } else if query.Semantic != "" { - // Semantic search not yet implemented for SQLite - // For now, fall back to retrieving all memories - // Future: Use FTS5 or vector extension for semantic search - sqlQuery := "SELECT id, created_at, content FROM memories ORDER BY created_at DESC" - if query.Limit > 0 { - sqlQuery = fmt.Sprintf("%s LIMIT %d", sqlQuery, query.Limit) - } - rows, err = d.db.QueryContext(ctx, sqlQuery) - } else { + query.ID, + ) + default: + // Semantic search is not yet implemented for SQLite. sqlQuery := "SELECT id, created_at, content FROM memories ORDER BY created_at DESC" if query.Limit > 0 { sqlQuery = fmt.Sprintf("%s LIMIT %d", sqlQuery, query.Limit) diff --git a/pkg/memory/sqlite/driver_test.go b/pkg/memory/sqlite/driver_test.go index 7839cc71c..c141cf16a 100644 --- a/pkg/memory/sqlite/driver_test.go +++ b/pkg/memory/sqlite/driver_test.go @@ -73,7 +73,7 @@ func TestDriver_StoreAndRetrieve(t *testing.T) { t.Run("retrieve all with limit", func(t *testing.T) { // Add more entries - for i := 0; i < 5; i++ { + for range 5 { err := driver.Store(ctx, "", "bulk value") require.NoError(t, err) } diff --git a/pkg/memory/sqlite/init.go b/pkg/memory/sqlite/init.go index dc9af38bc..53086185b 100644 --- a/pkg/memory/sqlite/init.go +++ b/pkg/memory/sqlite/init.go @@ -2,6 +2,7 @@ package sqlite import "github.com/docker/cagent/pkg/memory" -func init() { +var _ = func() struct{} { memory.RegisterFactory("sqlite", &Factory{}) -} + return struct{}{} +}() From 9766a967984f6799e0af01b53a1b88106b598953 Mon Sep 17 00:00:00 2001 From: venkat1701 Date: Sat, 24 Jan 2026 15:05:17 +0530 Subject: [PATCH 3/7] Remove version declaration from memory_demo.yaml --- examples/memory_demo.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/memory_demo.yaml b/examples/memory_demo.yaml index e80b25931..6c1f54785 100644 --- a/examples/memory_demo.yaml +++ b/examples/memory_demo.yaml @@ -1,5 +1,4 @@ #!/usr/bin/env cagent run -version: "4" metadata: author: Memory System Demo From ee3e78685329a2ead773c8a2dcb291ac0737664e Mon Sep 17 00:00:00 2001 From: venkat1701 Date: Sat, 24 Jan 2026 15:12:50 +0530 Subject: [PATCH 4/7] Refactor SQLite driver registration for clarity and maintainability --- pkg/memory/sqlite/driver_test.go | 5 +++-- pkg/memory/sqlite/init.go | 9 +++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pkg/memory/sqlite/driver_test.go b/pkg/memory/sqlite/driver_test.go index c141cf16a..63f15835c 100644 --- a/pkg/memory/sqlite/driver_test.go +++ b/pkg/memory/sqlite/driver_test.go @@ -4,10 +4,11 @@ import ( "path/filepath" "testing" - "github.com/docker/cagent/pkg/config/latest" - "github.com/docker/cagent/pkg/memory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/memory" ) func TestFactory_CreateDriver(t *testing.T) { diff --git a/pkg/memory/sqlite/init.go b/pkg/memory/sqlite/init.go index 53086185b..14adbf4e0 100644 --- a/pkg/memory/sqlite/init.go +++ b/pkg/memory/sqlite/init.go @@ -2,7 +2,12 @@ package sqlite import "github.com/docker/cagent/pkg/memory" -var _ = func() struct{} { +// registerSQLite registers the sqlite driver factory via package side-effects. +// +//nolint:unparam // Return value exists only to allow calling from a var initializer. +func registerSQLite() struct{} { memory.RegisterFactory("sqlite", &Factory{}) return struct{}{} -}() +} + +var _ = registerSQLite() From f18b8bbeabf4de27665585677897dda1f197af24 Mon Sep 17 00:00:00 2001 From: venkat1701 Date: Wed, 28 Jan 2026 05:27:04 +0530 Subject: [PATCH 5/7] implement memory driver support in Team and enhance toolset creation with fix in memory driver closing Signed-off by Krish Jaiswal - thejeastdev@gmail.com --- pkg/team/team.go | 19 +++++++++++-- pkg/team/team_test.go | 44 +++++++++++++++++++++++++++++ pkg/teamloader/registry.go | 20 +++++++++++++- pkg/teamloader/registry_test.go | 46 +++++++++++++++++++++++++++++++ pkg/teamloader/teamloader.go | 33 ++++++++++++++++++++-- pkg/teamloader/teamloader_test.go | 36 ++++++++++++++++++++++++ pkg/tools/builtin/memory.go | 31 +++++++++++++++++---- pkg/tools/builtin/memory_test.go | 19 +++++++++++++ 8 files changed, 237 insertions(+), 11 deletions(-) create mode 100644 pkg/team/team_test.go diff --git a/pkg/team/team.go b/pkg/team/team.go index 42fee64e3..895907f0b 100644 --- a/pkg/team/team.go +++ b/pkg/team/team.go @@ -9,14 +9,16 @@ import ( "github.com/docker/cagent/pkg/agent" "github.com/docker/cagent/pkg/config/types" + "github.com/docker/cagent/pkg/memory" "github.com/docker/cagent/pkg/permissions" "github.com/docker/cagent/pkg/rag" ) type Team struct { - agents []*agent.Agent - ragManagers map[string]*rag.Manager - permissions *permissions.Checker + agents []*agent.Agent + ragManagers map[string]*rag.Manager + memoryDrivers map[string]memory.Driver + permissions *permissions.Checker } type Opt func(*Team) @@ -33,6 +35,12 @@ func WithRAGManagers(managers map[string]*rag.Manager) Opt { } } +func WithMemoryDrivers(drivers map[string]memory.Driver) Opt { + return func(t *Team) { + t.memoryDrivers = drivers + } +} + func WithPermissions(checker *permissions.Checker) Opt { return func(t *Team) { t.permissions = checker @@ -129,6 +137,11 @@ func (t *Team) StopToolSets(ctx context.Context) error { return fmt.Errorf("failed to stop tool sets: %w", err) } } + for name, driver := range t.memoryDrivers { + if err := driver.Close(); err != nil { + slog.Error("Failed to close memory driver", "name", name, "error", err) + } + } for name, mgr := range t.ragManagers { if err := mgr.Close(); err != nil { slog.Error("Failed to close RAG manager", "name", name, "error", err) diff --git a/pkg/team/team_test.go b/pkg/team/team_test.go new file mode 100644 index 000000000..c4f3445a0 --- /dev/null +++ b/pkg/team/team_test.go @@ -0,0 +1,44 @@ +package team_test + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/docker/cagent/pkg/memory" + "github.com/docker/cagent/pkg/team" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type closeTrackingDriver struct { + closed atomic.Bool +} + +func (d *closeTrackingDriver) Store(context.Context, string, string) error { return nil } + +func (d *closeTrackingDriver) Retrieve(context.Context, memory.Query) ([]memory.Entry, error) { + return nil, nil +} + +func (d *closeTrackingDriver) Delete(context.Context, string) error { return nil } + +func (d *closeTrackingDriver) Close() error { + d.closed.Store(true) + return nil +} + +func TestTeamStopToolSets_ClosesMemoryDrivers(t *testing.T) { + t.Parallel() + + driver := &closeTrackingDriver{} + + tm := team.New(team.WithMemoryDrivers(map[string]memory.Driver{ + "test": driver, + })) + + require.NoError(t, tm.StopToolSets(t.Context())) + assert.True(t, driver.closed.Load()) +} + + diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index dd59e9ecf..bb821bf99 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -3,6 +3,7 @@ package teamloader import ( "context" "fmt" + "io" "os" "path/filepath" "time" @@ -21,6 +22,20 @@ import ( "github.com/docker/cagent/pkg/tools/mcp" ) +type toolSetWithCloser struct { + tools.ToolSet + closer io.Closer +} + +func (t toolSetWithCloser) Stop(ctx context.Context) error { + stopErr := t.ToolSet.Stop(ctx) + closeErr := t.closer.Close() + if stopErr != nil { + return stopErr + } + return closeErr +} + // ToolsetCreator is a function that creates a toolset based on the provided configuration type ToolsetCreator func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) @@ -112,7 +127,10 @@ func createMemoryTool(ctx context.Context, toolset latest.Toolset, parentDir str // Adapt new driver to legacy database interface db := memory.NewDatabaseAdapter(driver) - return builtin.NewMemoryTool(db), nil + return toolSetWithCloser{ + ToolSet: builtin.NewMemoryTool(db), + closer: driver, + }, nil } func createThinkTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { diff --git a/pkg/teamloader/registry_test.go b/pkg/teamloader/registry_test.go index c705b2c6c..d0b81314c 100644 --- a/pkg/teamloader/registry_test.go +++ b/pkg/teamloader/registry_test.go @@ -2,6 +2,8 @@ package teamloader import ( "context" + "errors" + "io" "testing" "github.com/stretchr/testify/assert" @@ -10,6 +12,7 @@ import ( "github.com/docker/cagent/pkg/config" "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/environment" + "github.com/docker/cagent/pkg/tools" ) type mockEnvProvider struct { @@ -140,3 +143,46 @@ func TestCreateShellToolWithSandboxExpansion(t *testing.T) { require.NoError(t, err) require.NotNil(t, tool) } + +type stopTrackingToolSet struct { + tools.BaseToolSet + stopped bool + stopErr error +} + +func (s *stopTrackingToolSet) Tools(context.Context) ([]tools.Tool, error) { return nil, nil } +func (s *stopTrackingToolSet) Stop(context.Context) error { + s.stopped = true + return s.stopErr +} + +type closeTrackingCloser struct { + closed bool + err error +} + +func (c *closeTrackingCloser) Close() error { + c.closed = true + return c.err +} + +func TestToolSetWithCloser_Stop_AlwaysCloses(t *testing.T) { + t.Parallel() + + base := &stopTrackingToolSet{stopErr: errors.New("stop failed")} + closer := &closeTrackingCloser{} + + ts := toolSetWithCloser{ToolSet: base, closer: closer} + + err := ts.Stop(t.Context()) + require.Error(t, err) + + assert.True(t, base.stopped) + assert.True(t, closer.closed) + + // Ensure we return the stop error (and not swallow it) + assert.ErrorContains(t, err, "stop failed") + + // Closer should be an io.Closer, as required + var _ io.Closer = closer +} diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 45143d770..0116cbdac 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "log/slog" + "regexp" "strings" "sync" @@ -25,6 +26,8 @@ import ( "github.com/docker/cagent/pkg/tools/codemode" ) +var nonToolNameChars = regexp.MustCompile(`[^a-zA-Z0-9_-]+`) + var defaultMaxTokens int64 = 32000 // isThinkingBudgetDisabled returns true if the thinking budget is explicitly set to disable thinking @@ -84,7 +87,7 @@ func Load(ctx context.Context, agentSource config.Source, runConfig *config.Runt // LoadWithConfig loads an agent team and returns both the team and config info // needed for runtime model switching. -func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *config.RuntimeConfig, opts ...Opt) (*LoadResult, error) { +func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *config.RuntimeConfig, opts ...Opt) (res *LoadResult, err error) { var loadOpts loadOptions loadOpts.toolsetRegistry = NewDefaultToolsetRegistry() @@ -117,6 +120,16 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c // Create memory drivers from top-level memory configs memoryDrivers := make(map[string]memory.Driver) + defer func() { + if err == nil { + return + } + for name, driver := range memoryDrivers { + if closeErr := driver.Close(); closeErr != nil { + slog.Error("Failed to close memory driver after load failure", "name", name, "error", closeErr) + } + } + }() for name, memCfg := range cfg.Memory { driver, err := memory.CreateDriver(ctx, memCfg) if err != nil { @@ -242,6 +255,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c Team: team.New( team.WithAgents(agents...), team.WithRAGManagers(ragManagers), + team.WithMemoryDrivers(memoryDrivers), team.WithPermissions(permChecker), ), Models: cfg.Models, @@ -418,6 +432,8 @@ func createMemoryToolsForAgent(agentConfig *latest.AgentConfig, allDrivers map[s } var memoryTools []tools.ToolSet + multiScope := len(agentConfig.Memory) > 1 + usedPrefixes := make(map[string]int) for _, memName := range agentConfig.Memory { driver, exists := allDrivers[memName] @@ -428,7 +444,20 @@ func createMemoryToolsForAgent(agentConfig *latest.AgentConfig, allDrivers map[s // Adapt driver to legacy database interface db := memory.NewDatabaseAdapter(driver) - memTool := builtin.NewMemoryTool(db) + var memTool tools.ToolSet + if !multiScope { + memTool = builtin.NewMemoryTool(db) + } else { + prefix := nonToolNameChars.ReplaceAllString(memName, "_") + if prefix == "" { + prefix = "memory" + } + usedPrefixes[prefix]++ + if usedPrefixes[prefix] > 1 { + prefix = fmt.Sprintf("%s_%d", prefix, usedPrefixes[prefix]) + } + memTool = builtin.NewMemoryToolWithPrefix(db, prefix) + } memoryTools = append(memoryTools, memTool) slog.Debug("Created memory tool for agent", diff --git a/pkg/teamloader/teamloader_test.go b/pkg/teamloader/teamloader_test.go index ce37c5902..4de9cd44f 100644 --- a/pkg/teamloader/teamloader_test.go +++ b/pkg/teamloader/teamloader_test.go @@ -12,6 +12,7 @@ import ( "github.com/docker/cagent/pkg/config" "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/memory" ) // skipExamples contains example files that require cloud-specific configurations @@ -226,3 +227,38 @@ func TestIsThinkingBudgetDisabled(t *testing.T) { }) } } + +type noopMemoryDriver struct{} + +func (noopMemoryDriver) Store(context.Context, string, string) error { return nil } +func (noopMemoryDriver) Retrieve(context.Context, memory.Query) ([]memory.Entry, error) { + return nil, nil +} +func (noopMemoryDriver) Delete(context.Context, string) error { return nil } +func (noopMemoryDriver) Close() error { return nil } + +func TestCreateMemoryToolsForAgent_MultipleScopes_UsesUniqueToolNames(t *testing.T) { + t.Parallel() + + agentCfg := &latest.AgentConfig{ + Memory: []string{"long-term", "long term"}, + } + + drivers := map[string]memory.Driver{ + "long-term": noopMemoryDriver{}, + "long term": noopMemoryDriver{}, + } + + toolsets := createMemoryToolsForAgent(agentCfg, drivers) + require.Len(t, toolsets, 2) + + seen := make(map[string]bool) + for _, ts := range toolsets { + toolsList, err := ts.Tools(t.Context()) + require.NoError(t, err) + for _, tl := range toolsList { + require.False(t, seen[tl.Name], "tool name collision: %s", tl.Name) + seen[tl.Name] = true + } + } +} diff --git a/pkg/tools/builtin/memory.go b/pkg/tools/builtin/memory.go index d7a5ae8c2..42761393e 100644 --- a/pkg/tools/builtin/memory.go +++ b/pkg/tools/builtin/memory.go @@ -25,14 +25,24 @@ type DB interface { type MemoryTool struct { tools.BaseToolSet db DB + // namePrefix, when set, namespaces the tool names to avoid collisions + // (e.g., "_get_memories"). + namePrefix string } // Make sure Memory Tool implements the ToolSet Interface var _ tools.ToolSet = (*MemoryTool)(nil) func NewMemoryTool(manager DB) *MemoryTool { + return NewMemoryToolWithPrefix(manager, "") +} + +// NewMemoryToolWithPrefix creates a MemoryTool that uses prefixed tool names. +// When prefix is empty, tool names are the legacy unprefixed names. +func NewMemoryToolWithPrefix(manager DB, prefix string) *MemoryTool { return &MemoryTool{ - db: manager, + db: manager, + namePrefix: prefix, } } @@ -45,19 +55,30 @@ type DeleteMemoryArgs struct { } func (t *MemoryTool) Instructions() string { + getMemoriesTool := ToolNameGetMemories + if t.namePrefix != "" { + getMemoriesTool = t.namePrefix + "_" + ToolNameGetMemories + } return `## Using the memory tool -Before taking any action or responding to the user use the "get_memories" tool to remember things about the user. +Before taking any action or responding to the user use the "` + getMemoriesTool + `" tool to remember things about the user. Do not talk about using the tool, just use it. ## Rules - Use the memory tool generously to remember things about the user.` } +func (t *MemoryTool) toolName(base string) string { + if t.namePrefix == "" { + return base + } + return t.namePrefix + "_" + base +} + func (t *MemoryTool) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { - Name: ToolNameAddMemory, + Name: t.toolName(ToolNameAddMemory), Category: "memory", Description: "Add a new memory to the database", Parameters: tools.MustSchemaFor[AddMemoryArgs](), @@ -68,7 +89,7 @@ func (t *MemoryTool) Tools(context.Context) ([]tools.Tool, error) { }, }, { - Name: ToolNameGetMemories, + Name: t.toolName(ToolNameGetMemories), Category: "memory", Description: "Retrieve all stored memories", OutputSchema: tools.MustSchemaFor[[]database.UserMemory](), @@ -79,7 +100,7 @@ func (t *MemoryTool) Tools(context.Context) ([]tools.Tool, error) { }, }, { - Name: ToolNameDeleteMemory, + Name: t.toolName(ToolNameDeleteMemory), Category: "memory", Description: "Delete a specific memory by ID", Parameters: tools.MustSchemaFor[DeleteMemoryArgs](), diff --git a/pkg/tools/builtin/memory_test.go b/pkg/tools/builtin/memory_test.go index 42d8cdd13..340c0a9c3 100644 --- a/pkg/tools/builtin/memory_test.go +++ b/pkg/tools/builtin/memory_test.go @@ -144,3 +144,22 @@ func TestMemoryTool_ParametersAreObjects(t *testing.T) { assert.Equal(t, "object", m["type"]) } } + +func TestMemoryTool_WithPrefix_NamespacesToolNames(t *testing.T) { + t.Parallel() + + manager := new(MockDB) + toolset := NewMemoryToolWithPrefix(manager, "longterm") + + all, err := toolset.Tools(t.Context()) + require.NoError(t, err) + + names := make([]string, 0, len(all)) + for _, tl := range all { + names = append(names, tl.Name) + } + + assert.Contains(t, names, "longterm_"+ToolNameAddMemory) + assert.Contains(t, names, "longterm_"+ToolNameGetMemories) + assert.Contains(t, names, "longterm_"+ToolNameDeleteMemory) +} From 1a80301382d5b3382e702ee5126a688200bd45d8 Mon Sep 17 00:00:00 2001 From: venkat1701 Date: Thu, 29 Jan 2026 23:05:55 +0530 Subject: [PATCH 6/7] reorder import statements in team_test.go for lint Signed off by Krish Jaiswal - thejeastdev@gmail.com --- pkg/team/team_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/team/team_test.go b/pkg/team/team_test.go index c4f3445a0..9133ef950 100644 --- a/pkg/team/team_test.go +++ b/pkg/team/team_test.go @@ -5,10 +5,11 @@ import ( "sync/atomic" "testing" - "github.com/docker/cagent/pkg/memory" - "github.com/docker/cagent/pkg/team" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/memory" + "github.com/docker/cagent/pkg/team" ) type closeTrackingDriver struct { From 3d70b26f5ecec9a3220dc698d15a58c85ad20475 Mon Sep 17 00:00:00 2001 From: venkat1701 Date: Thu, 29 Jan 2026 23:16:41 +0530 Subject: [PATCH 7/7] lints fixes --- pkg/team/team_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/team/team_test.go b/pkg/team/team_test.go index 9133ef950..19b104924 100644 --- a/pkg/team/team_test.go +++ b/pkg/team/team_test.go @@ -41,5 +41,3 @@ func TestTeamStopToolSets_ClosesMemoryDrivers(t *testing.T) { require.NoError(t, tm.StopToolSets(t.Context())) assert.True(t, driver.closed.Load()) } - -