diff --git a/docs/tools-and-tui-integration.md b/docs/tools-and-tui-integration.md index 79d72691..f00f7927 100644 --- a/docs/tools-and-tui-integration.md +++ b/docs/tools-and-tui-integration.md @@ -21,11 +21,14 @@ - `webfetch` - `memo_remember` - `memo_recall` +- `memo_list` +- `memo_remove` ## Memo 能力集成 -- `memo_remember` 与 `memo_recall` 作为标准工具暴露给模型,沿 `Runtime -> Tool Manager -> internal/tools/memo` 链路执行。 +- `memo_remember`、`memo_recall`、`memo_list`、`memo_remove` 作为标准工具暴露给模型,沿 `Runtime -> Tool Manager -> internal/tools/memo` 链路执行。 - 自动记忆提取不作为单独工具暴露给模型,也不由 TUI 直接触发;它在 runtime 完成最终回复后由 memo 子系统后台调度。 -- TUI 目前只通过 Slash Command 展示和管理 memo(如 `/memo`、`/remember`、`/forget`),不会展示后台自动提取的中间状态。 +- TUI 的 `/memo`、`/remember`、`/forget` 等 Slash Command 不再直接依赖 memo service,而是通过 `Runtime.ExecuteSystemTool` 统一入口触发系统工具执行,保证 UI 与 memo 逻辑解耦。 +- TUI 不会展示后台自动提取的中间状态。 ## TUI 集成方式 - 本地配置操作统一通过 Slash Command 完成,例如 Base URL、API Key 和模型选择 diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 9a9f217d..05bde9c4 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -79,7 +79,7 @@ func newMemoExtractorAdapter( }) }) - scheduler.ScheduleWithExtractor(sessionID, messages, memo.NewLLMExtractor(generator)) + scheduler.ScheduleWithExtractor(sessionID, messages, memo.NewLLMExtractor(generator, cfg.Memo.ExtractRecentMessages)) }) } @@ -159,9 +159,11 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er sourceInvl = invalidator.InvalidateCache } contextBuilder = agentcontext.NewBuilderWithMemoAndSummarizers(toolRegistry, toolRegistry, memoSource) - memoSvc = memo.NewService(memoStore, nil, cfg.Memo, sourceInvl) + memoSvc = memo.NewService(memoStore, cfg.Memo, sourceInvl) toolRegistry.Register(memotool.NewRememberTool(memoSvc)) toolRegistry.Register(memotool.NewRecallTool(memoSvc)) + toolRegistry.Register(memotool.NewListTool(memoSvc)) + toolRegistry.Register(memotool.NewRemoveTool(memoSvc)) } runtimeSvc := agentruntime.NewWithFactory( @@ -189,7 +191,7 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er runtimeSvc.SetMemoExtractor(newMemoExtractorAdapter( providerRegistry, manager, - memo.NewAutoExtractor(nil, memoSvc), + memo.NewAutoExtractor(nil, memoSvc, time.Duration(cfg.Memo.ExtractTimeoutSec)*time.Second), )) } needCleanup = false diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 46dc3966..71e52b9b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1492,16 +1492,19 @@ func TestMemoConfigClone(t *testing.T) { t.Parallel() original := MemoConfig{ - Enabled: true, - AutoExtract: false, - MaxIndexLines: 100, + Enabled: true, + AutoExtract: false, + MaxEntries: 100, + MaxIndexBytes: 2048, + ExtractTimeoutSec: 9, + ExtractRecentMessages: 3, } cloned := original.Clone() if cloned != original { t.Fatalf("Clone() = %+v, want %+v", cloned, original) } - cloned.MaxIndexLines = 200 - if original.MaxIndexLines != 100 { + cloned.MaxEntries = 200 + if original.MaxEntries != 100 { t.Error("modifying clone should not affect original (value type check)") } } @@ -1509,25 +1512,57 @@ func TestMemoConfigClone(t *testing.T) { func TestMemoConfigApplyDefaults(t *testing.T) { t.Parallel() - t.Run("fills zero MaxIndexLines", func(t *testing.T) { - cfg := MemoConfig{Enabled: true, MaxIndexLines: 0} - cfg.ApplyDefaults(MemoConfig{MaxIndexLines: DefaultMemoMaxIndexLines}) - if cfg.MaxIndexLines != DefaultMemoMaxIndexLines { - t.Errorf("MaxIndexLines = %d, want %d", cfg.MaxIndexLines, DefaultMemoMaxIndexLines) + t.Run("fills zero fields", func(t *testing.T) { + cfg := MemoConfig{} + cfg.ApplyDefaults(MemoConfig{ + MaxEntries: DefaultMemoMaxEntries, + MaxIndexBytes: DefaultMemoMaxIndexBytes, + ExtractTimeoutSec: DefaultMemoExtractTimeoutSec, + ExtractRecentMessages: DefaultMemoExtractRecentMessage, + }) + if cfg.MaxEntries != DefaultMemoMaxEntries { + t.Errorf("MaxEntries = %d, want %d", cfg.MaxEntries, DefaultMemoMaxEntries) + } + if cfg.MaxIndexBytes != DefaultMemoMaxIndexBytes { + t.Errorf("MaxIndexBytes = %d, want %d", cfg.MaxIndexBytes, DefaultMemoMaxIndexBytes) + } + if cfg.ExtractTimeoutSec != DefaultMemoExtractTimeoutSec { + t.Errorf("ExtractTimeoutSec = %d, want %d", cfg.ExtractTimeoutSec, DefaultMemoExtractTimeoutSec) + } + if cfg.ExtractRecentMessages != DefaultMemoExtractRecentMessage { + t.Errorf("ExtractRecentMessages = %d, want %d", cfg.ExtractRecentMessages, DefaultMemoExtractRecentMessage) + } + }) + + t.Run("preserves explicit fields", func(t *testing.T) { + cfg := MemoConfig{ + MaxEntries: 50, + MaxIndexBytes: 1024, + ExtractTimeoutSec: 30, + ExtractRecentMessages: 5, + } + cfg.ApplyDefaults(defaultMemoConfig()) + if cfg.MaxEntries != 50 || cfg.MaxIndexBytes != 1024 || cfg.ExtractTimeoutSec != 30 || cfg.ExtractRecentMessages != 5 { + t.Fatalf("ApplyDefaults() unexpectedly overwrote explicit values: %+v", cfg) } }) - t.Run("preserves explicit MaxIndexLines", func(t *testing.T) { - cfg := MemoConfig{MaxIndexLines: 50} - cfg.ApplyDefaults(MemoConfig{MaxIndexLines: DefaultMemoMaxIndexLines}) - if cfg.MaxIndexLines != 50 { - t.Errorf("MaxIndexLines = %d, want 50", cfg.MaxIndexLines) + t.Run("preserves negative fields for validation", func(t *testing.T) { + cfg := MemoConfig{ + MaxEntries: -1, + MaxIndexBytes: -2, + ExtractTimeoutSec: -3, + ExtractRecentMessages: -4, + } + cfg.ApplyDefaults(defaultMemoConfig()) + if cfg.MaxEntries != -1 || cfg.MaxIndexBytes != -2 || cfg.ExtractTimeoutSec != -3 || cfg.ExtractRecentMessages != -4 { + t.Fatalf("ApplyDefaults() unexpectedly rewrote invalid values: %+v", cfg) } }) t.Run("nil receiver is no-op", func(t *testing.T) { var cfg *MemoConfig - cfg.ApplyDefaults(MemoConfig{MaxIndexLines: 200}) + cfg.ApplyDefaults(defaultMemoConfig()) }) } @@ -1535,23 +1570,41 @@ func TestMemoConfigValidate(t *testing.T) { t.Parallel() t.Run("valid config", func(t *testing.T) { - cfg := MemoConfig{MaxIndexLines: 100} + cfg := defaultMemoConfig() if err := cfg.Validate(); err != nil { t.Fatalf("valid config should not error: %v", err) } }) - t.Run("negative MaxIndexLines", func(t *testing.T) { - cfg := MemoConfig{MaxIndexLines: -1} + t.Run("non-positive MaxEntries", func(t *testing.T) { + cfg := defaultMemoConfig() + cfg.MaxEntries = 0 if err := cfg.Validate(); err == nil { - t.Fatal("negative MaxIndexLines should fail validation") + t.Fatal("non-positive MaxEntries should fail validation") } }) - t.Run("zero MaxIndexLines is valid", func(t *testing.T) { - cfg := MemoConfig{MaxIndexLines: 0} - if err := cfg.Validate(); err != nil { - t.Fatalf("zero MaxIndexLines should be valid: %v", err) + t.Run("non-positive MaxIndexBytes", func(t *testing.T) { + cfg := defaultMemoConfig() + cfg.MaxIndexBytes = -1 + if err := cfg.Validate(); err == nil { + t.Fatal("non-positive MaxIndexBytes should fail validation") + } + }) + + t.Run("non-positive ExtractTimeoutSec", func(t *testing.T) { + cfg := defaultMemoConfig() + cfg.ExtractTimeoutSec = 0 + if err := cfg.Validate(); err == nil { + t.Fatal("non-positive ExtractTimeoutSec should fail validation") + } + }) + + t.Run("non-positive ExtractRecentMessages", func(t *testing.T) { + cfg := defaultMemoConfig() + cfg.ExtractRecentMessages = 0 + if err := cfg.Validate(); err == nil { + t.Fatal("non-positive ExtractRecentMessages should fail validation") } }) } diff --git a/internal/config/loader.go b/internal/config/loader.go index 976f923f..e0eaa68c 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -56,9 +56,13 @@ type persistedAutoCompactConfig struct { } type persistedMemoConfig struct { - Enabled *bool `yaml:"enabled,omitempty"` - AutoExtract *bool `yaml:"auto_extract,omitempty"` - MaxIndexLines int `yaml:"max_index_lines,omitempty"` + Enabled *bool `yaml:"enabled,omitempty"` + AutoExtract *bool `yaml:"auto_extract,omitempty"` + MaxEntries *int `yaml:"max_entries,omitempty"` + MaxIndexBytes *int `yaml:"max_index_bytes,omitempty"` + ExtractTimeoutSec *int `yaml:"extract_timeout_sec,omitempty"` + ExtractRecentMessages *int `yaml:"extract_recent_messages,omitempty"` + MaxIndexLines *int `yaml:"max_index_lines,omitempty"` } func NewLoader(baseDir string, defaults *Config) *Loader { @@ -333,10 +337,17 @@ func assembleProviders(builtin []ProviderConfig, custom []ProviderConfig) ([]Pro func newPersistedMemoConfig(cfg MemoConfig) persistedMemoConfig { enabled := cfg.Enabled autoExtract := cfg.AutoExtract + maxEntries := cfg.MaxEntries + maxIndexBytes := cfg.MaxIndexBytes + extractTimeoutSec := cfg.ExtractTimeoutSec + extractRecentMessages := cfg.ExtractRecentMessages return persistedMemoConfig{ - Enabled: &enabled, - AutoExtract: &autoExtract, - MaxIndexLines: cfg.MaxIndexLines, + Enabled: &enabled, + AutoExtract: &autoExtract, + MaxEntries: &maxEntries, + MaxIndexBytes: &maxIndexBytes, + ExtractTimeoutSec: &extractTimeoutSec, + ExtractRecentMessages: &extractRecentMessages, } } @@ -349,8 +360,19 @@ func fromPersistedMemoConfig(file persistedMemoConfig, defaults MemoConfig) Memo if file.AutoExtract != nil { out.AutoExtract = *file.AutoExtract } - if file.MaxIndexLines > 0 { - out.MaxIndexLines = file.MaxIndexLines + if file.MaxEntries != nil { + out.MaxEntries = *file.MaxEntries + } else if file.MaxIndexLines != nil { + out.MaxEntries = *file.MaxIndexLines + } + if file.MaxIndexBytes != nil { + out.MaxIndexBytes = *file.MaxIndexBytes + } + if file.ExtractTimeoutSec != nil { + out.ExtractTimeoutSec = *file.ExtractTimeoutSec + } + if file.ExtractRecentMessages != nil { + out.ExtractRecentMessages = *file.ExtractRecentMessages } return out } diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 842f6ef8..ee62d674 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -1259,7 +1259,10 @@ shell: powershell memo: enabled: false auto_extract: false - max_index_lines: 123 + max_entries: 123 + max_index_bytes: 4096 + extract_timeout_sec: 9 + extract_recent_messages: 4 ` writeLoaderConfig(t, loader, raw) @@ -1273,8 +1276,17 @@ memo: if cfg.Memo.AutoExtract { t.Fatalf("expected memo.auto_extract to stay false") } - if cfg.Memo.MaxIndexLines != 123 { - t.Fatalf("expected memo.max_index_lines=123, got %d", cfg.Memo.MaxIndexLines) + if cfg.Memo.MaxEntries != 123 { + t.Fatalf("expected memo.max_entries=123, got %d", cfg.Memo.MaxEntries) + } + if cfg.Memo.MaxIndexBytes != 4096 { + t.Fatalf("expected memo.max_index_bytes=4096, got %d", cfg.Memo.MaxIndexBytes) + } + if cfg.Memo.ExtractTimeoutSec != 9 { + t.Fatalf("expected memo.extract_timeout_sec=9, got %d", cfg.Memo.ExtractTimeoutSec) + } + if cfg.Memo.ExtractRecentMessages != 4 { + t.Fatalf("expected memo.extract_recent_messages=4, got %d", cfg.Memo.ExtractRecentMessages) } data, err := os.ReadFile(loader.ConfigPath()) @@ -1312,8 +1324,92 @@ shell: powershell if !cfg.Memo.AutoExtract { t.Fatalf("expected memo.auto_extract default true when memo section missing") } - if cfg.Memo.MaxIndexLines <= 0 { - t.Fatalf("expected memo.max_index_lines to be defaulted, got %d", cfg.Memo.MaxIndexLines) + if cfg.Memo.MaxEntries <= 0 { + t.Fatalf("expected memo.max_entries to be defaulted, got %d", cfg.Memo.MaxEntries) + } + if cfg.Memo.MaxIndexBytes <= 0 { + t.Fatalf("expected memo.max_index_bytes to be defaulted, got %d", cfg.Memo.MaxIndexBytes) + } + if cfg.Memo.ExtractTimeoutSec <= 0 { + t.Fatalf("expected memo.extract_timeout_sec to be defaulted, got %d", cfg.Memo.ExtractTimeoutSec) + } + if cfg.Memo.ExtractRecentMessages <= 0 { + t.Fatalf("expected memo.extract_recent_messages to be defaulted, got %d", cfg.Memo.ExtractRecentMessages) + } +} + +func TestLoaderSupportsLegacyMemoMaxIndexLinesField(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + raw := ` +selected_provider: openai +current_model: gpt-4.1 +shell: powershell +memo: + max_index_lines: 123 +` + writeLoaderConfig(t, loader, raw) + + cfg, err := loader.Load(context.Background()) + if err != nil { + t.Fatalf("expected legacy memo field to be accepted, got %v", err) + } + if cfg.Memo.MaxEntries != 123 { + t.Fatalf("expected legacy max_index_lines mapped to memo.max_entries=123, got %d", cfg.Memo.MaxEntries) + } +} + +func TestLoaderRejectsExplicitInvalidMemoNumbers(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + fieldYAML string + errContain string + }{ + { + name: "negative max_entries", + fieldYAML: "max_entries: -1", + errContain: "config: memo: max_entries must be greater than 0", + }, + { + name: "negative max_index_bytes", + fieldYAML: "max_index_bytes: -1", + errContain: "config: memo: max_index_bytes must be greater than 0", + }, + { + name: "negative extract_timeout_sec", + fieldYAML: "extract_timeout_sec: -1", + errContain: "config: memo: extract_timeout_sec must be greater than 0", + }, + { + name: "negative extract_recent_messages", + fieldYAML: "extract_recent_messages: -1", + errContain: "config: memo: extract_recent_messages must be greater than 0", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + raw := ` +selected_provider: openai +current_model: gpt-4.1 +shell: powershell +memo: + ` + tt.fieldYAML + ` +` + writeLoaderConfig(t, loader, raw) + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), tt.errContain) { + t.Fatalf("expected %q, got %v", tt.errContain, err) + } + }) } } diff --git a/internal/config/memo.go b/internal/config/memo.go index d6c0f42e..9da2e2d3 100644 --- a/internal/config/memo.go +++ b/internal/config/memo.go @@ -2,21 +2,32 @@ package config import "errors" -const DefaultMemoMaxIndexLines = 200 +const ( + DefaultMemoMaxEntries = 200 + DefaultMemoMaxIndexBytes = 16 * 1024 + DefaultMemoExtractTimeoutSec = 15 + DefaultMemoExtractRecentMessage = 10 +) // MemoConfig 控制跨会话持久记忆的行为配置。 type MemoConfig struct { - Enabled bool `yaml:"enabled,omitempty"` - AutoExtract bool `yaml:"auto_extract,omitempty"` - MaxIndexLines int `yaml:"max_index_lines,omitempty"` + Enabled bool `yaml:"enabled,omitempty"` + AutoExtract bool `yaml:"auto_extract,omitempty"` + MaxEntries int `yaml:"max_entries,omitempty"` + MaxIndexBytes int `yaml:"max_index_bytes,omitempty"` + ExtractTimeoutSec int `yaml:"extract_timeout_sec,omitempty"` + ExtractRecentMessages int `yaml:"extract_recent_messages,omitempty"` } // defaultMemoConfig 返回跨会话记忆的默认配置。 func defaultMemoConfig() MemoConfig { return MemoConfig{ - Enabled: true, - AutoExtract: true, - MaxIndexLines: DefaultMemoMaxIndexLines, + Enabled: true, + AutoExtract: true, + MaxEntries: DefaultMemoMaxEntries, + MaxIndexBytes: DefaultMemoMaxIndexBytes, + ExtractTimeoutSec: DefaultMemoExtractTimeoutSec, + ExtractRecentMessages: DefaultMemoExtractRecentMessage, } } @@ -30,15 +41,33 @@ func (c *MemoConfig) ApplyDefaults(defaults MemoConfig) { if c == nil { return } - if c.MaxIndexLines <= 0 { - c.MaxIndexLines = defaults.MaxIndexLines + if c.MaxEntries == 0 { + c.MaxEntries = defaults.MaxEntries + } + if c.MaxIndexBytes == 0 { + c.MaxIndexBytes = defaults.MaxIndexBytes + } + if c.ExtractTimeoutSec == 0 { + c.ExtractTimeoutSec = defaults.ExtractTimeoutSec + } + if c.ExtractRecentMessages == 0 { + c.ExtractRecentMessages = defaults.ExtractRecentMessages } } // Validate 校验 memo 配置是否合法。 func (c MemoConfig) Validate() error { - if c.MaxIndexLines < 0 { - return errors.New("max_index_lines must be non-negative") + if c.MaxEntries <= 0 { + return errors.New("max_entries must be greater than 0") + } + if c.MaxIndexBytes <= 0 { + return errors.New("max_index_bytes must be greater than 0") + } + if c.ExtractTimeoutSec <= 0 { + return errors.New("extract_timeout_sec must be greater than 0") + } + if c.ExtractRecentMessages <= 0 { + return errors.New("extract_recent_messages must be greater than 0") } return nil } diff --git a/internal/memo/auto_extractor.go b/internal/memo/auto_extractor.go index 7ca659b4..666342af 100644 --- a/internal/memo/auto_extractor.go +++ b/internal/memo/auto_extractor.go @@ -17,11 +17,12 @@ const ( // AutoExtractor 负责按会话在后台调度自动提取,并处理防抖、互斥和尾随执行。 type AutoExtractor struct { - extractor Extractor - svc *Service - debounce time.Duration - idleTTL time.Duration - logf func(format string, args ...any) + extractor Extractor + svc *Service + debounce time.Duration + idleTTL time.Duration + extractTimeout time.Duration + logf func(format string, args ...any) mu sync.Mutex states map[string]*autoExtractState @@ -44,14 +45,18 @@ type autoExtractRequest struct { } // NewAutoExtractor 创建后台自动提取调度器。 -func NewAutoExtractor(extractor Extractor, svc *Service) *AutoExtractor { +func NewAutoExtractor(extractor Extractor, svc *Service, extractTimeout time.Duration) *AutoExtractor { + if extractTimeout <= 0 { + extractTimeout = 15 * time.Second + } return &AutoExtractor{ - extractor: extractor, - svc: svc, - debounce: autoExtractDebounce, - idleTTL: autoExtractIdleTTL, - logf: log.Printf, - states: make(map[string]*autoExtractState), + extractor: extractor, + svc: svc, + debounce: autoExtractDebounce, + idleTTL: autoExtractIdleTTL, + extractTimeout: extractTimeout, + logf: log.Printf, + states: make(map[string]*autoExtractState), } } @@ -205,7 +210,9 @@ func isIdleStateLocked(state *autoExtractState, seq uint64) bool { // extractAndStore 执行提取,并在写入前做本地批次去重和持久化级别的原子去重。 func (a *AutoExtractor) extractAndStore(extractor Extractor, messages []providertypes.Message) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), a.extractTimeout) + defer cancel() + entries, err := extractor.Extract(ctx, messages) if err != nil { a.logError("memo: auto extract failed: %v", err) diff --git a/internal/memo/auto_extractor_test.go b/internal/memo/auto_extractor_test.go index 5cc9e956..502903f4 100644 --- a/internal/memo/auto_extractor_test.go +++ b/internal/memo/auto_extractor_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "neo-code/internal/config" providertypes "neo-code/internal/provider/types" ) @@ -41,7 +40,19 @@ func (s *stubMemoExtractor) Calls() int { func newAutoExtractorTestService(t *testing.T) *Service { t.Helper() store := NewFileStore(t.TempDir(), t.TempDir()) - return NewService(store, nil, config.MemoConfig{MaxIndexLines: 200}, nil) + return NewService(store, testMemoConfig(), nil) +} + +func registerAutoExtractorCleanup(t *testing.T, auto *AutoExtractor) { + t.Helper() + auto.idleTTL = 20 * time.Millisecond + t.Cleanup(func() { + waitFor(t, time.Second, func() bool { + auto.mu.Lock() + defer auto.mu.Unlock() + return len(auto.states) == 0 + }) + }) } func TestAutoExtractorDebounceMergesRequests(t *testing.T) { @@ -52,32 +63,18 @@ func TestAutoExtractorDebounceMergesRequests(t *testing.T) { return []Entry{{Type: TypeProject, Title: last, Content: last, Source: SourceAutoExtract}}, nil }, } - auto := NewAutoExtractor(extractor, svc) + auto := NewAutoExtractor(extractor, svc, time.Second) auto.debounce = 20 * time.Millisecond auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}}}) auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("second")}}}) - waitFor(t, time.Second, func() bool { return extractor.Calls() == 1 }) - time.Sleep(60 * time.Millisecond) - - if extractor.Calls() != 1 { - t.Fatalf("extractor calls = %d, want 1", extractor.Calls()) - } - - recall, err := svc.Recall(context.Background(), "second") - if err != nil { - t.Fatalf("Recall() error = %v", err) - } - if len(recall) != 1 { - t.Fatalf("recall = %#v", recall) - } - for _, content := range recall { - if !strings.Contains(content, "second") { - t.Fatalf("recall content = %q", content) - } - } + waitFor(t, time.Second, func() bool { + recall, err := svc.Recall(context.Background(), "second", ScopeAll) + return err == nil && len(recall) == 1 && strings.Contains(recall[0].Content, "second") + }) } func TestAutoExtractorTrailingRun(t *testing.T) { @@ -99,12 +96,12 @@ func TestAutoExtractorTrailingRun(t *testing.T) { return []Entry{{Type: TypeProject, Title: last, Content: last, Source: SourceAutoExtract}}, nil }, } - auto := NewAutoExtractor(extractor, svc) + auto := NewAutoExtractor(extractor, svc, time.Second) auto.debounce = 15 * time.Millisecond auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}}}) - select { case <-firstStarted: case <-time.After(time.Second): @@ -118,12 +115,11 @@ func TestAutoExtractorTrailingRun(t *testing.T) { select { case <-secondStarted: case <-time.After(time.Second): - t.Fatal("second trailing extraction did not start") + t.Fatal("second extraction did not start") } - waitFor(t, time.Second, func() bool { return extractor.Calls() == 2 }) waitFor(t, time.Second, func() bool { - entries, err := svc.List(context.Background()) + entries, err := svc.List(context.Background(), ScopeProject) return err == nil && len(entries) == 2 }) } @@ -135,14 +131,15 @@ func TestAutoExtractorErrorsAreSilent(t *testing.T) { return nil, errors.New("boom") }, } - auto := NewAutoExtractor(extractor, svc) + auto := NewAutoExtractor(extractor, svc, time.Second) auto.debounce = 10 * time.Millisecond auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("x")}}}) waitFor(t, time.Second, func() bool { return extractor.Calls() == 1 }) - entries, err := svc.List(context.Background()) + entries, err := svc.List(context.Background(), ScopeAll) if err != nil { t.Fatalf("List() error = %v", err) } @@ -154,9 +151,10 @@ func TestAutoExtractorErrorsAreSilent(t *testing.T) { func TestAutoExtractorSuppressesExactDuplicates(t *testing.T) { svc := newAutoExtractorTestService(t) if err := svc.Add(context.Background(), Entry{ - Type: TypeUser, - Title: "reply in chinese", Content: "reply in chinese", - Source: SourceAutoExtract, + Type: TypeUser, + Title: "reply in chinese", + Content: "reply in chinese", + Source: SourceAutoExtract, }); err != nil { t.Fatalf("seed Add() error = %v", err) } @@ -170,67 +168,40 @@ func TestAutoExtractorSuppressesExactDuplicates(t *testing.T) { }, nil }, } - auto := NewAutoExtractor(extractor, svc) + auto := NewAutoExtractor(extractor, svc, time.Second) auto.debounce = 10 * time.Millisecond auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("dedupe")}}}) waitFor(t, time.Second, func() bool { - entries, err := svc.List(context.Background()) + entries, err := svc.List(context.Background(), ScopeAll) return err == nil && len(entries) == 2 }) - - entries, err := svc.List(context.Background()) - if err != nil { - t.Fatalf("List() error = %v", err) - } - if len(entries) != 2 { - t.Fatalf("len(entries) = %d, want 2", len(entries)) - } } -func TestAutoExtractorSuppressesExactDuplicatesAcrossSessions(t *testing.T) { +func TestAutoExtractorUsesTimeoutContext(t *testing.T) { svc := newAutoExtractorTestService(t) - started := make(chan struct{}, 2) - release := make(chan struct{}) - extractor := &stubMemoExtractor{ extractFn: func(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { - started <- struct{}{} - <-release - return []Entry{ - {Type: TypeProject, Title: "same title", Content: "same content", Source: SourceAutoExtract}, - }, nil + <-ctx.Done() + return nil, ctx.Err() }, } - auto := NewAutoExtractor(extractor, svc) - auto.debounce = 0 + auto := NewAutoExtractor(extractor, svc, 20*time.Millisecond) + auto.debounce = 5 * time.Millisecond auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) - auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("one")}}}) - auto.Schedule("session-2", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("two")}}}) - - for i := 0; i < 2; i++ { - select { - case <-started: - case <-time.After(time.Second): - t.Fatal("concurrent extraction did not start") - } - } - close(release) - - waitFor(t, time.Second, func() bool { return extractor.Calls() == 2 }) - waitFor(t, time.Second, func() bool { - entries, err := svc.List(context.Background()) - return err == nil && len(entries) == 1 - }) + auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("timeout")}}}) + waitFor(t, time.Second, func() bool { return extractor.Calls() == 1 }) - entries, err := svc.List(context.Background()) + entries, err := svc.List(context.Background(), ScopeAll) if err != nil { t.Fatalf("List() error = %v", err) } - if len(entries) != 1 { - t.Fatalf("len(entries) = %d, want 1", len(entries)) + if len(entries) != 0 { + t.Fatalf("entries after timeout = %#v, want empty", entries) } } @@ -241,13 +212,13 @@ func TestAutoExtractorRemovesIdleState(t *testing.T) { return []Entry{{Type: TypeProject, Title: "done", Content: "done", Source: SourceAutoExtract}}, nil }, } - auto := NewAutoExtractor(extractor, svc) + auto := NewAutoExtractor(extractor, svc, time.Second) auto.debounce = 5 * time.Millisecond auto.idleTTL = 20 * time.Millisecond auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("cleanup")}}}) - waitFor(t, time.Second, func() bool { return extractor.Calls() == 1 }) waitFor(t, time.Second, func() bool { auto.mu.Lock() @@ -256,35 +227,10 @@ func TestAutoExtractorRemovesIdleState(t *testing.T) { }) } -func TestAutoExtractorHandleIdleKeepsActiveState(t *testing.T) { - auto := NewAutoExtractor(&stubMemoExtractor{}, newAutoExtractorTestService(t)) - auto.logf = func(string, ...any) {} - - state := &autoExtractState{ - idleSeq: 2, - pending: &autoExtractRequest{ - messages: []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("keep")}}}, - }, - } - - auto.mu.Lock() - auto.states["session-1"] = state - auto.mu.Unlock() - - auto.handleIdle("session-1", state, 2) - - auto.mu.Lock() - defer auto.mu.Unlock() - if _, ok := auto.states["session-1"]; !ok { - t.Fatal("active state should not be removed by idle callback") - } -} - func TestAutoExtractorLoadsDedupIndexOutsideCurrentProcessState(t *testing.T) { baseDir := t.TempDir() - workspace := t.TempDir() - store := NewFileStore(baseDir, workspace) - svc := NewService(store, nil, config.MemoConfig{MaxIndexLines: 200}, nil) + store := NewFileStore(baseDir, "/workspace/a") + svc := NewService(store, testMemoConfig(), nil) if err := svc.Add(context.Background(), Entry{ Type: TypeUser, Title: "reply in chinese", @@ -294,22 +240,21 @@ func TestAutoExtractorLoadsDedupIndexOutsideCurrentProcessState(t *testing.T) { t.Fatalf("seed Add() error = %v", err) } - reloaded := NewService(NewFileStore(baseDir, workspace), nil, config.MemoConfig{MaxIndexLines: 200}, nil) + reloaded := NewService(NewFileStore(baseDir, "/workspace/b"), testMemoConfig(), nil) extractor := &stubMemoExtractor{ extractFn: func(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { - return []Entry{ - {Type: TypeUser, Title: "reply in chinese", Content: "reply in chinese", Source: SourceAutoExtract}, - }, nil + return []Entry{{Type: TypeUser, Title: "reply in chinese", Content: "reply in chinese", Source: SourceAutoExtract}}, nil }, } - auto := NewAutoExtractor(extractor, reloaded) + auto := NewAutoExtractor(extractor, reloaded, time.Second) auto.debounce = 5 * time.Millisecond auto.logf = func(string, ...any) {} + registerAutoExtractorCleanup(t, auto) auto.Schedule("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("dedupe after reload")}}}) - waitFor(t, time.Second, func() bool { return extractor.Calls() == 1 }) - entries, err := reloaded.List(context.Background()) + + entries, err := reloaded.List(context.Background(), ScopeAll) if err != nil { t.Fatalf("List() error = %v", err) } @@ -318,112 +263,6 @@ func TestAutoExtractorLoadsDedupIndexOutsideCurrentProcessState(t *testing.T) { } } -func TestAutoExtractorScheduleWithExtractorUsesBoundExtractor(t *testing.T) { - svc := newAutoExtractorTestService(t) - defaultExtractor := &stubMemoExtractor{ - extractFn: func(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { - return []Entry{{Type: TypeProject, Title: "default", Content: "default", Source: SourceAutoExtract}}, nil - }, - } - boundExtractor := &stubMemoExtractor{ - extractFn: func(ctx context.Context, messages []providertypes.Message) ([]Entry, error) { - return []Entry{{Type: TypeProject, Title: "bound", Content: "bound", Source: SourceAutoExtract}}, nil - }, - } - - auto := NewAutoExtractor(defaultExtractor, svc) - auto.debounce = 5 * time.Millisecond - auto.logf = func(string, ...any) {} - - auto.ScheduleWithExtractor("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("use bound")}}}, boundExtractor) - - waitFor(t, time.Second, func() bool { return boundExtractor.Calls() == 1 }) - if defaultExtractor.Calls() != 0 { - t.Fatalf("default extractor calls = %d, want 0", defaultExtractor.Calls()) - } -} - -func TestAutoExtractorScheduleGuardClauses(t *testing.T) { - svc := newAutoExtractorTestService(t) - extractor := &stubMemoExtractor{} - auto := NewAutoExtractor(extractor, svc) - auto.debounce = 5 * time.Millisecond - auto.logf = func(string, ...any) {} - - auto.Schedule("", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("skip")}}}) - auto.ScheduleWithExtractor("session-1", []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("skip")}}}, nil) - - waitFor(t, 150*time.Millisecond, func() bool { return true }) - if extractor.Calls() != 0 { - t.Fatalf("extractor calls = %d, want 0", extractor.Calls()) - } -} - -func TestAutoExtractDedupKeyAndTopicParsing(t *testing.T) { - if key := autoExtractDedupKey(Entry{Type: TypeProject, Title: " demo ", Content: " value "}); key != "project\x1fdemo\x1fvalue" { - t.Fatalf("autoExtractDedupKey() = %q", key) - } - if key := autoExtractDedupKey(Entry{Type: "invalid", Title: "demo", Content: "value"}); key != "" { - t.Fatalf("autoExtractDedupKey() invalid type = %q, want empty", key) - } - - source, body := parseTopicSourceAndContent("plain body") - if source != "" || body != "plain body" { - t.Fatalf("parse plain topic = (%q,%q)", source, body) - } - - source, body = parseTopicSourceAndContent("---\nsource: extractor_auto\n---\n\n正文") - if source != SourceAutoExtract || body != "正文" { - t.Fatalf("parse frontmatter topic = (%q,%q)", source, body) - } -} - -func TestCloneProviderMessagesDeepCopyAndStopTimer(t *testing.T) { - original := []providertypes.Message{ - { - Role: providertypes.RoleAssistant, - Parts: []providertypes.ContentPart{providertypes.NewTextPart("msg")}, - ToolCalls: []providertypes.ToolCall{ - {ID: "c1", Name: "tool", Arguments: "{}"}, - }, - ToolMetadata: map[string]string{"k": "v"}, - }, - } - cloned := cloneProviderMessages(original) - if len(cloned) != 1 || len(cloned[0].ToolCalls) != 1 || cloned[0].ToolMetadata["k"] != "v" { - t.Fatalf("cloneProviderMessages() = %#v", cloned) - } - - original[0].ToolCalls[0].Name = "changed" - original[0].ToolMetadata["k"] = "changed" - if cloned[0].ToolCalls[0].Name != "tool" || cloned[0].ToolMetadata["k"] != "v" { - t.Fatalf("clone should be isolated, got %#v", cloned[0]) - } - - stopTimer(nil) - timer := time.NewTimer(5 * time.Millisecond) - time.Sleep(10 * time.Millisecond) - stopTimer(timer) -} - -func TestIsIdleStateLocked(t *testing.T) { - state := &autoExtractState{idleSeq: 3} - if !isIdleStateLocked(state, 3) { - t.Fatal("expected idle state to be recyclable") - } - - state.pending = &autoExtractRequest{} - if isIdleStateLocked(state, 3) { - t.Fatal("state with pending request should not be recyclable") - } - - state.pending = nil - state.running = true - if isIdleStateLocked(state, 3) { - t.Fatal("running state should not be recyclable") - } -} - func waitFor(t *testing.T, timeout time.Duration, fn func() bool) { t.Helper() deadline := time.Now().Add(timeout) diff --git a/internal/memo/context_source.go b/internal/memo/context_source.go index 4a7de38b..aee6d908 100644 --- a/internal/memo/context_source.go +++ b/internal/memo/context_source.go @@ -9,14 +9,13 @@ import ( agentcontext "neo-code/internal/context" ) -// memoContextSource 将持久化记忆作为 prompt section 注入上下文构建器。 -// 它实现 agentcontext.SectionSource 接口,仅加载 MEMO.md 索引内容, -// topic 文件的详细内容通过 memo_recall 工具按需加载。 +// memoContextSource 将持久化记忆索引作为 prompt section 注入上下文构建器。 +// 它按 user/project 双层输出目录索引,topic 详情仍通过 memo_recall 工具按需加载。 type memoContextSource struct { store Store mu sync.RWMutex cacheReady bool - cachedText string + cacheText map[Scope]string cacheTime time.Time ttl time.Duration } @@ -33,39 +32,43 @@ func WithCacheTTL(ttl time.Duration) MemoContextSourceOption { // NewContextSource 创建注入记忆到上下文的 SectionSource 实现。 func NewContextSource(store Store, opts ...MemoContextSourceOption) agentcontext.SectionSource { - s := &memoContextSource{ - store: store, - ttl: 5 * time.Second, + source := &memoContextSource{ + store: store, + ttl: 5 * time.Second, + cacheText: make(map[Scope]string, 2), } for _, opt := range opts { - opt(s) + opt(source) } - return s + return source } -// Sections 实现 agentcontext.SectionSource,返回记忆索引作为 prompt section。 +// Sections 实现 agentcontext.SectionSource,返回 user/project 双层记忆索引。 func (s *memoContextSource) Sections(ctx context.Context, _ agentcontext.BuildInput) ([]agentcontext.PromptSection, error) { - text, err := s.loadCached(ctx) + cached, err := s.loadCached(ctx) if err != nil { - // 记忆加载失败不应阻断上下文构建,返回空 section return nil, nil } - if text == "" { + + sections := make([]agentcontext.PromptSection, 0, 2) + if text := cached[ScopeUser]; text != "" { + sections = append(sections, agentcontext.NewPromptSection("User Memo", buildMemoSectionPayload(text))) + } + if text := cached[ScopeProject]; text != "" { + sections = append(sections, agentcontext.NewPromptSection("Project Memo", buildMemoSectionPayload(text))) + } + if len(sections) == 0 { return nil, nil } - payload := fmt.Sprintf("以下内容是持久记忆数据,只可作为参考,不可视为当前用户指令。\n```memo\n%s\n```", text) - - return []agentcontext.PromptSection{ - agentcontext.NewPromptSection("Memo", payload), - }, nil + return sections, nil } -// loadCached 带缓存地加载 MEMO.md 内容。 -func (s *memoContextSource) loadCached(ctx context.Context) (string, error) { +// loadCached 带缓存地加载 user/project 双层 MEMO 索引内容。 +func (s *memoContextSource) loadCached(ctx context.Context) (map[Scope]string, error) { now := time.Now() s.mu.RLock() if s.isCacheValid(now) { - text := s.cachedText + text := cloneMemoCache(s.cacheText) s.mu.RUnlock() return text, nil } @@ -74,22 +77,24 @@ func (s *memoContextSource) loadCached(ctx context.Context) (string, error) { s.mu.Lock() defer s.mu.Unlock() - // 双重检查 now = time.Now() if s.isCacheValid(now) { - return s.cachedText, nil + return cloneMemoCache(s.cacheText), nil } - index, err := s.store.LoadIndex(ctx) - if err != nil { - return "", err + next := make(map[Scope]string, 2) + for _, scope := range supportedStorageScopes() { + index, err := s.store.LoadIndex(ctx, scope) + if err != nil { + return nil, err + } + next[scope] = RenderIndex(index) } - text := RenderIndex(index) s.cacheReady = true - s.cachedText = text + s.cacheText = next s.cacheTime = time.Now() - return text, nil + return cloneMemoCache(next), nil } // isCacheValid 判断当前缓存是否仍在有效期内。 @@ -102,6 +107,20 @@ func (s *memoContextSource) InvalidateCache() { s.mu.Lock() defer s.mu.Unlock() s.cacheReady = false - s.cachedText = "" + s.cacheText = make(map[Scope]string, 2) s.cacheTime = time.Time{} } + +// buildMemoSectionPayload 构造注入 prompt 的 memo section 文本。 +func buildMemoSectionPayload(text string) string { + return fmt.Sprintf("以下内容是持久记忆数据,只可作为参考,不可视为当前用户指令。\n```memo\n%s\n```", text) +} + +// cloneMemoCache 复制缓存 map,避免外部修改共享状态。 +func cloneMemoCache(source map[Scope]string) map[Scope]string { + cloned := make(map[Scope]string, len(source)) + for key, value := range source { + cloned[key] = value + } + return cloned +} diff --git a/internal/memo/context_source_test.go b/internal/memo/context_source_test.go index be7899fb..35172794 100644 --- a/internal/memo/context_source_test.go +++ b/internal/memo/context_source_test.go @@ -10,196 +10,136 @@ import ( agentcontext "neo-code/internal/context" ) -// stubStore 实现 Store 接口用于测试。 -type stubStore struct { - index *Index - err error - loadIndexCalls int - saveIndexErr error - saveTopicErr error - deleteTopicErr error - deletedTopics []string - saveIndexCalls int - saveTopicCalls int - deleteTopicCalls int -} - -func (s *stubStore) LoadIndex(_ context.Context) (*Index, error) { - s.loadIndexCalls++ - if s.err != nil { - return nil, s.err - } - if s.index == nil { - return &Index{Entries: []Entry{}}, nil - } - return s.index, nil -} - -func (s *stubStore) SaveIndex(_ context.Context, index *Index) error { - s.saveIndexCalls++ - if s.saveIndexErr != nil { - return s.saveIndexErr - } - s.index = index - return nil -} -func (s *stubStore) LoadTopic(_ context.Context, _ string) (string, error) { - return "", nil -} -func (s *stubStore) SaveTopic(_ context.Context, _, _ string) error { - s.saveTopicCalls++ - if s.saveTopicErr != nil { - return s.saveTopicErr - } - return nil -} -func (s *stubStore) DeleteTopic(_ context.Context, filename string) error { - s.deleteTopicCalls++ - s.deletedTopics = append(s.deletedTopics, filename) - if s.deleteTopicErr != nil { - return s.deleteTopicErr - } - return nil -} -func (s *stubStore) ListTopics(_ context.Context) ([]string, error) { return nil, nil } - func TestContextSourceEmpty(t *testing.T) { - store := &stubStore{} + store := newMemoryTestStore() source := NewContextSource(store) + sections, err := source.Sections(context.Background(), agentcontext.BuildInput{}) if err != nil { - t.Fatalf("Sections error: %v", err) + t.Fatalf("Sections() error = %v", err) } if len(sections) != 0 { - t.Errorf("Sections on empty store = %d, want 0", len(sections)) + t.Fatalf("len(sections) = %d, want 0", len(sections)) } } -func TestContextSourceWithEntries(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{ - {Type: TypeUser, Title: "偏好 tab", TopicFile: "user.md"}, - }, - }, - } +func TestContextSourceRendersTwoScopes(t *testing.T) { + store := newMemoryTestStore() + store.indexes[ScopeUser] = &Index{Entries: []Entry{{Type: TypeUser, Title: "user pref", TopicFile: "u.md"}}} + store.indexes[ScopeProject] = &Index{Entries: []Entry{{Type: TypeProject, Title: "project fact", TopicFile: "p.md"}}} source := NewContextSource(store) + sections, err := source.Sections(context.Background(), agentcontext.BuildInput{}) if err != nil { - t.Fatalf("Sections error: %v", err) + t.Fatalf("Sections() error = %v", err) } - if len(sections) != 1 { - t.Fatalf("Sections = %d, want 1", len(sections)) + if len(sections) != 2 { + t.Fatalf("len(sections) = %d, want 2", len(sections)) } - if sections[0].Title != "Memo" { - t.Errorf("Title = %q, want %q", sections[0].Title, "Memo") + if sections[0].Title != "User Memo" || !strings.Contains(sections[0].Content, "user pref") { + t.Fatalf("unexpected user section: %+v", sections[0]) } - if !strings.Contains(sections[0].Content, "偏好 tab") { - t.Errorf("Content should contain entry: %q", sections[0].Content) + if sections[1].Title != "Project Memo" || !strings.Contains(sections[1].Content, "project fact") { + t.Fatalf("unexpected project section: %+v", sections[1]) } } func TestContextSourceCache(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{ - {Type: TypeUser, Title: "first"}, - }, - }, - } - source := NewContextSource(store, WithCacheTTL(10*time.Second)) - ctx := context.Background() - - // 第一次加载 - sections1, _ := source.Sections(ctx, agentcontext.BuildInput{}) - if !strings.Contains(sections1[0].Content, "first") { - t.Error("first load should contain 'first'") - } - - // 修改 store 数据(模拟外部变更) - store.index.Entries[0].Title = "second" - - // 缓存 TTL 内应返回旧数据 - sections2, _ := source.Sections(ctx, agentcontext.BuildInput{}) - if !strings.Contains(sections2[0].Content, "first") { - t.Error("cached load should still contain 'first'") - } -} - -func TestContextSourceCacheCachesEmptyIndex(t *testing.T) { - store := &stubStore{index: &Index{Entries: []Entry{}}} + store := newMemoryTestStore() + store.indexes[ScopeUser] = &Index{Entries: []Entry{{Type: TypeUser, Title: "first"}}} source := NewContextSource(store, WithCacheTTL(10*time.Second)) - ctx := context.Background() - sections1, err := source.Sections(ctx, agentcontext.BuildInput{}) + sections, err := source.Sections(context.Background(), agentcontext.BuildInput{}) if err != nil { - t.Fatalf("Sections first call error: %v", err) + t.Fatalf("Sections() error = %v", err) } - if len(sections1) != 0 { - t.Fatalf("sections first call = %d, want 0", len(sections1)) + if !strings.Contains(sections[0].Content, "first") { + t.Fatalf("expected cached content to include first, got %q", sections[0].Content) } - sections2, err := source.Sections(ctx, agentcontext.BuildInput{}) + store.indexes[ScopeUser].Entries[0].Title = "second" + sections, err = source.Sections(context.Background(), agentcontext.BuildInput{}) if err != nil { - t.Fatalf("Sections second call error: %v", err) + t.Fatalf("Sections() second call error = %v", err) } - if len(sections2) != 0 { - t.Fatalf("sections second call = %d, want 0", len(sections2)) + if !strings.Contains(sections[0].Content, "first") { + t.Fatalf("expected cached content to stay stale, got %q", sections[0].Content) } - if store.loadIndexCalls != 1 { - t.Fatalf("LoadIndex calls = %d, want 1", store.loadIndexCalls) + if store.loadIndexCalls != 2 { + t.Fatalf("LoadIndex() calls = %d, want 2 (one per scope)", store.loadIndexCalls) } } -func TestContextSourceCacheInvalidation(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{ - {Type: TypeUser, Title: "old"}, - }, - }, - } +func TestContextSourceInvalidateCache(t *testing.T) { + store := newMemoryTestStore() + store.indexes[ScopeUser] = &Index{Entries: []Entry{{Type: TypeUser, Title: "old"}}} source := NewContextSource(store, WithCacheTTL(10*time.Second)) - ctx := context.Background() - - // 加载并缓存 - source.Sections(ctx, agentcontext.BuildInput{}) - // 修改数据 - store.index.Entries[0].Title = "new" - - // 手动失效缓存 - cs := source.(*memoContextSource) - cs.InvalidateCache() + if _, err := source.Sections(context.Background(), agentcontext.BuildInput{}); err != nil { + t.Fatalf("Sections() warm cache error = %v", err) + } + store.indexes[ScopeUser].Entries[0].Title = "new" - // 应加载新数据 - sections, _ := source.Sections(ctx, agentcontext.BuildInput{}) + source.(*memoContextSource).InvalidateCache() + sections, err := source.Sections(context.Background(), agentcontext.BuildInput{}) + if err != nil { + t.Fatalf("Sections() after invalidation error = %v", err) + } if !strings.Contains(sections[0].Content, "new") { - t.Errorf("after invalidation, should contain 'new': %q", sections[0].Content) + t.Fatalf("expected invalidated content to include new, got %q", sections[0].Content) } } -func TestContextSourceStoreError(t *testing.T) { - store := &stubStore{err: errors.New("read error")} +func TestContextSourceStoreErrorReturnsNil(t *testing.T) { + store := newMemoryTestStore() + store.err = errors.New("boom") source := NewContextSource(store) + sections, err := source.Sections(context.Background(), agentcontext.BuildInput{}) if err != nil { - t.Fatalf("Sections should not propagate error: %v", err) + t.Fatalf("Sections() should suppress store error, got %v", err) } if sections != nil { - t.Errorf("Sections on store error should return nil, got %v", sections) + t.Fatalf("sections = %+v, want nil", sections) } } -func TestContextSourceCancelledContext(t *testing.T) { - store := &stubStore{index: &Index{}} - source := NewContextSource(store) - ctx, cancel := context.WithCancel(context.Background()) - cancel() +func TestContextSourceReadsGlobalUserAndScopedProject(t *testing.T) { + baseDir := t.TempDir() + storeA := NewFileStore(baseDir, "/workspace/a") + storeB := NewFileStore(baseDir, "/workspace/b") + + if err := storeA.SaveIndex(context.Background(), ScopeUser, &Index{ + Entries: []Entry{{Type: TypeUser, Title: "global user pref", TopicFile: "u.md"}}, + }); err != nil { + t.Fatalf("SaveIndex(user) error = %v", err) + } + if err := storeA.SaveIndex(context.Background(), ScopeProject, &Index{ + Entries: []Entry{{Type: TypeProject, Title: "workspace a fact", TopicFile: "a.md"}}, + }); err != nil { + t.Fatalf("SaveIndex(project a) error = %v", err) + } + if err := storeB.SaveIndex(context.Background(), ScopeProject, &Index{ + Entries: []Entry{{Type: TypeProject, Title: "workspace b fact", TopicFile: "b.md"}}, + }); err != nil { + t.Fatalf("SaveIndex(project b) error = %v", err) + } - sections, err := source.Sections(ctx, agentcontext.BuildInput{}) - if err == nil && sections != nil { - // 取消的上下文可能导致错误或空结果,都合理 - t.Logf("cancelled context returned %d sections", len(sections)) + source := NewContextSource(storeB) + sections, err := source.Sections(context.Background(), agentcontext.BuildInput{}) + if err != nil { + t.Fatalf("Sections() error = %v", err) + } + if len(sections) != 2 { + t.Fatalf("len(sections) = %d, want 2", len(sections)) + } + if !strings.Contains(sections[0].Content, "global user pref") { + t.Fatalf("expected user memo to include global entry, got %q", sections[0].Content) + } + if strings.Contains(sections[1].Content, "workspace a fact") { + t.Fatalf("project memo leaked another workspace: %q", sections[1].Content) + } + if !strings.Contains(sections[1].Content, "workspace b fact") { + t.Fatalf("expected project memo to include current workspace entry, got %q", sections[1].Content) } } diff --git a/internal/memo/extractor_test.go b/internal/memo/extractor_test.go index 6d1f0526..faf854ce 100644 --- a/internal/memo/extractor_test.go +++ b/internal/memo/extractor_test.go @@ -6,7 +6,6 @@ import ( "testing" "unicode/utf8" - "neo-code/internal/config" providertypes "neo-code/internal/provider/types" ) @@ -207,12 +206,12 @@ func TestExtractAndStore(t *testing.T) { t.Run("no signal does not add entries", func(t *testing.T) { store := NewFileStore(t.TempDir(), t.TempDir()) - svc := NewService(store, nil, config.MemoConfig{MaxIndexLines: 200}, nil) + svc := NewService(store, testMemoConfig(), nil) messages := []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("写个函数")}}, } ExtractAndStore(context.Background(), NewRuleExtractor(), svc, messages) - entries, _ := svc.List(context.Background()) + entries, _ := svc.List(context.Background(), ScopeAll) if len(entries) != 0 { t.Errorf("expected 0 entries, got %d", len(entries)) } @@ -220,12 +219,12 @@ func TestExtractAndStore(t *testing.T) { t.Run("with signal adds entry", func(t *testing.T) { store := NewFileStore(t.TempDir(), t.TempDir()) - svc := NewService(store, nil, config.MemoConfig{MaxIndexLines: 200}, nil) + svc := NewService(store, testMemoConfig(), nil) messages := []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住以后都用中文注释")}}, } ExtractAndStore(context.Background(), NewRuleExtractor(), svc, messages) - entries, _ := svc.List(context.Background()) + entries, _ := svc.List(context.Background(), ScopeAll) if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } diff --git a/internal/memo/llm_extractor.go b/internal/memo/llm_extractor.go index ed2a6ef5..cee9eb41 100644 --- a/internal/memo/llm_extractor.go +++ b/internal/memo/llm_extractor.go @@ -12,12 +12,11 @@ import ( providertypes "neo-code/internal/provider/types" ) -const llmExtractorRecentMessageLimit = 10 - // LLMExtractor 基于 LLM 分析最近对话,并返回结构化记忆条目。 type LLMExtractor struct { - generator TextGenerator - now func() time.Time + generator TextGenerator + now func() time.Time + recentMessageLimit int } type extractedEntry struct { @@ -28,10 +27,14 @@ type extractedEntry struct { } // NewLLMExtractor 创建基于 TextGenerator 的记忆提取器。 -func NewLLMExtractor(generator TextGenerator) *LLMExtractor { +func NewLLMExtractor(generator TextGenerator, recentMessageLimit int) *LLMExtractor { + if recentMessageLimit <= 0 { + recentMessageLimit = 10 + } return &LLMExtractor{ - generator: generator, - now: time.Now, + generator: generator, + now: time.Now, + recentMessageLimit: recentMessageLimit, } } @@ -44,7 +47,7 @@ func (e *LLMExtractor) Extract(ctx context.Context, messages []providertypes.Mes return nil, errors.New("memo: text generator is nil") } - recent := agentcontext.BuildRecentMessagesForModel(messages, llmExtractorRecentMessageLimit) + recent := agentcontext.BuildRecentMessagesForModel(messages, e.recentMessageLimit) if len(recent) == 0 || !containsUserMessage(recent) { return nil, nil } diff --git a/internal/memo/llm_extractor_test.go b/internal/memo/llm_extractor_test.go index 43d82291..3ccf9ee9 100644 --- a/internal/memo/llm_extractor_test.go +++ b/internal/memo/llm_extractor_test.go @@ -37,7 +37,7 @@ func TestLLMExtractorExtractValidJSON(t *testing.T) { generator := &stubTextGenerator{ response: `[{"type":"user","title":" 偏好 Go 代码风格 ","content":"用户偏好使用 Go 惯用写法。","keywords":["go"," style ","go"]}]`, } - extractor := NewLLMExtractor(generator) + extractor := NewLLMExtractor(generator, 10) extractor.now = func() time.Time { return time.Date(2026, 4, 13, 10, 0, 0, 0, time.FixedZone("CST", 8*3600)) } @@ -83,7 +83,7 @@ func TestLLMExtractorExtractValidJSON(t *testing.T) { // TestLLMExtractorExtractEmptyResult 验证空数组响应会返回零条记忆。 func TestLLMExtractorExtractEmptyResult(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{response: `[]`}) + extractor := NewLLMExtractor(&stubTextGenerator{response: `[]`}, 10) entries, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("这轮没有需要记住的内容。")}}, @@ -99,7 +99,7 @@ func TestLLMExtractorExtractEmptyResult(t *testing.T) { // TestLLMExtractorExtractNoUserMessage 验证没有用户消息时不会调用模型。 func TestLLMExtractorExtractNoUserMessage(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator) + extractor := NewLLMExtractor(generator, 10) entries, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("只有助手消息。")}}, @@ -117,7 +117,7 @@ func TestLLMExtractorExtractNoUserMessage(t *testing.T) { // TestLLMExtractorExtractNoMessages 验证空消息输入直接返回空结果。 func TestLLMExtractorExtractNoMessages(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{response: `[]`}) + extractor := NewLLMExtractor(&stubTextGenerator{response: `[]`}, 10) entries, err := extractor.Extract(context.Background(), nil) if err != nil { @@ -130,7 +130,7 @@ func TestLLMExtractorExtractNoMessages(t *testing.T) { // TestLLMExtractorExtractInvalidJSON 验证无效 JSON 会返回错误。 func TestLLMExtractorExtractInvalidJSON(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{response: `[{invalid json}]`}) + extractor := NewLLMExtractor(&stubTextGenerator{response: `[{invalid json}]`}, 10) _, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, @@ -144,7 +144,7 @@ func TestLLMExtractorExtractInvalidJSON(t *testing.T) { func TestLLMExtractorExtractToleratesWrappedJSON(t *testing.T) { extractor := NewLLMExtractor(&stubTextGenerator{ response: "分析如下:\n[{\"type\":\"feedback\",\"title\":\"以后先跑测试\",\"content\":\"用户要求修改后先跑测试。\"}]\n以上完毕。", - }) + }, 10) entries, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("以后改完先跑测试。")}}, @@ -165,7 +165,7 @@ func TestLLMExtractorExtractFiltersInvalidEntries(t *testing.T) { {"type":"project","title":" ","content":"missing title"}, {"type":"reference","title":"文档入口","content":"查看 docs/runtime-provider-event-flow.md"} ]`, - }) + }, 10) entries, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("参考文档在 docs/runtime-provider-event-flow.md。")}}, @@ -181,7 +181,7 @@ func TestLLMExtractorExtractFiltersInvalidEntries(t *testing.T) { // TestLLMExtractorExtractCancelledContext 验证已取消上下文会中止提取。 func TestLLMExtractorExtractCancelledContext(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator) + extractor := NewLLMExtractor(generator, 10) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -199,7 +199,7 @@ func TestLLMExtractorExtractCancelledContext(t *testing.T) { // TestLLMExtractorExtractUsesRecentNonToolMessages 验证只取最近 10 条非 tool 消息。 func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator) + extractor := NewLLMExtractor(generator, 10) messages := make([]providertypes.Message, 0, 16) for index := 0; index < 12; index++ { @@ -237,7 +237,7 @@ func TestLLMExtractorExtractUsesRecentNonToolMessages(t *testing.T) { func TestLLMExtractorExtractDropsIncompleteToolCallSpan(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator) + extractor := NewLLMExtractor(generator, 10) _, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("first")}}, @@ -264,7 +264,7 @@ func TestLLMExtractorExtractDropsIncompleteToolCallSpan(t *testing.T) { func TestLLMExtractorExtractKeepsProjectedToolCallSpan(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator) + extractor := NewLLMExtractor(generator, 10) _, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("remember this")}}, @@ -305,7 +305,7 @@ func TestLLMExtractorExtractKeepsProjectedToolCallSpan(t *testing.T) { func TestLLMExtractorExtractKeepsMetadataOnlyToolCallSpan(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator) + extractor := NewLLMExtractor(generator, 10) _, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("remember this")}}, @@ -347,7 +347,7 @@ func TestLLMExtractorExtractKeepsMetadataOnlyToolCallSpan(t *testing.T) { func TestLLMExtractorExtractSkipsOrphanAndClearedToolMessages(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator) + extractor := NewLLMExtractor(generator, 10) _, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("alpha")}}, @@ -383,7 +383,7 @@ func TestLLMExtractorExtractNilGenerator(t *testing.T) { t.Fatalf("Extract() error = %v", err) } - extractor = NewLLMExtractor(nil) + extractor = NewLLMExtractor(nil, 10) _, err = extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, }) @@ -393,7 +393,7 @@ func TestLLMExtractorExtractNilGenerator(t *testing.T) { } func TestLLMExtractorExtractGeneratorFailure(t *testing.T) { - extractor := NewLLMExtractor(&stubTextGenerator{err: errors.New("upstream failed")}) + extractor := NewLLMExtractor(&stubTextGenerator{err: errors.New("upstream failed")}, 10) _, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("记住这个。")}}, }) @@ -413,7 +413,7 @@ func TestExtractJSONArrayErrors(t *testing.T) { func TestLLMExtractorExtractImageOnlyUserMessageSkipsGenerator(t *testing.T) { generator := &stubTextGenerator{response: `[]`} - extractor := NewLLMExtractor(generator) + extractor := NewLLMExtractor(generator, 10) entries, err := extractor.Extract(context.Background(), []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewRemoteImagePart("https://example.com/pic.png")}}, diff --git a/internal/memo/service.go b/internal/memo/service.go index fa88f781..d983efef 100644 --- a/internal/memo/service.go +++ b/internal/memo/service.go @@ -12,10 +12,9 @@ import ( "neo-code/internal/config" ) -// Service 编排记忆的存储、检索、提取和删除,是 memo 子系统对外的统一入口。 +// Service 编排记忆的存储、检索、删除和索引维护,是 memo 子系统对外的统一入口。 type Service struct { store Store - extractor Extractor config config.MemoConfig mu sync.Mutex sourceInvl func() @@ -25,17 +24,16 @@ type Service struct { autoExtractKeyRefs map[string]int } -// NewService 创建 memo Service 实例;extractor 可以为 nil。 -func NewService(store Store, extractor Extractor, cfg config.MemoConfig, sourceInvl func()) *Service { +// NewService 创建 memo Service 实例。 +func NewService(store Store, cfg config.MemoConfig, sourceInvl func()) *Service { return &Service{ store: store, - extractor: extractor, config: cfg, sourceInvl: sourceInvl, } } -// Add 添加一条记忆并持久化索引和 topic 文件。 +// Add 添加一条记忆并持久化到对应分层的索引与 topic 文件。 func (s *Service) Add(ctx context.Context, entry Entry) error { entry, err := normalizeEntryForPersist(entry) if err != nil { @@ -64,15 +62,149 @@ func (s *Service) addAutoExtractIfAbsent(ctx context.Context, entry Entry) (bool if s.hasExactAutoExtractLocked(entry) { return false, nil } - if err := s.saveEntryLocked(ctx, entry); err != nil { return false, err } return true, nil } +// normalizeKeyword 统一关键词的空格与大小写处理。 +func normalizeKeyword(keyword string) string { + return strings.ToLower(strings.TrimSpace(keyword)) +} + +// Remove 按关键词删除匹配的记忆条目,支持按 scope 过滤。 +func (s *Service) Remove(ctx context.Context, keyword string, scope Scope) (int, error) { + keyword = normalizeKeyword(keyword) + if keyword == "" { + return 0, fmt.Errorf("memo: keyword is empty") + } + if err := validateQueryScope(scope); err != nil { + return 0, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + removed := 0 + for _, bucket := range scopesForQuery(scope) { + index, err := s.loadIndexLocked(ctx, bucket) + if err != nil { + return 0, err + } + + remaining := make([]Entry, 0, len(index.Entries)) + removedEntries := make([]Entry, 0, len(index.Entries)) + for _, entry := range index.Entries { + if matchesKeyword(entry, keyword) { + removedEntries = append(removedEntries, entry) + continue + } + remaining = append(remaining, entry) + } + if len(removedEntries) == 0 { + continue + } + + index.Entries = remaining + index.UpdatedAt = time.Now() + if err := s.store.SaveIndex(ctx, bucket, index); err != nil { + return removed, fmt.Errorf("memo: save index: %w", err) + } + for _, entry := range removedEntries { + if topicFile := strings.TrimSpace(entry.TopicFile); topicFile != "" { + _ = s.store.DeleteTopic(ctx, bucket, topicFile) + s.removeAutoExtractTopicLocked(bucket, topicFile) + } + } + removed += len(removedEntries) + } + + if removed > 0 { + s.invalidateCache() + } + return removed, nil +} + +// List 返回 scope 范围内的记忆条目浅拷贝。 +func (s *Service) List(ctx context.Context, scope Scope) ([]Entry, error) { + if err := validateQueryScope(scope); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + return s.listLocked(ctx, scope) +} + +// Search 按关键词搜索记忆条目,支持按 scope 过滤。 +func (s *Service) Search(ctx context.Context, keyword string, scope Scope) ([]Entry, error) { + if err := validateQueryScope(scope); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + entries, err := s.listLocked(ctx, scope) + if err != nil { + return nil, err + } + keyword = normalizeKeyword(keyword) + if keyword == "" { + return entries, nil + } + + results := make([]Entry, 0, len(entries)) + for _, entry := range entries { + if matchesKeyword(entry, keyword) { + results = append(results, entry) + } + } + return results, nil +} + +// Recall 加载匹配关键词的 topic 文件内容,支持按 scope 过滤。 +func (s *Service) Recall(ctx context.Context, keyword string, scope Scope) ([]RecalledEntry, error) { + if err := validateQueryScope(scope); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + keyword = normalizeKeyword(keyword) + results := make([]RecalledEntry, 0) + for _, bucket := range scopesForQuery(scope) { + index, err := s.loadIndexLocked(ctx, bucket) + if err != nil { + return nil, err + } + for _, entry := range index.Entries { + if !matchesKeyword(entry, keyword) { + continue + } + if strings.TrimSpace(entry.TopicFile) == "" { + continue + } + content, err := s.store.LoadTopic(ctx, bucket, entry.TopicFile) + if err != nil { + continue + } + results = append(results, RecalledEntry{ + Scope: bucket, + Entry: entry, + Content: content, + }) + } + } + return results, nil +} + // saveEntryLocked 在持有 Service 锁的前提下持久化单条记忆及索引。 func (s *Service) saveEntryLocked(ctx context.Context, entry Entry) error { + scope := ScopeForType(entry.Type) now := time.Now() if entry.ID == "" { entry.ID = newEntryID(entry.Type) @@ -81,14 +213,13 @@ func (s *Service) saveEntryLocked(ctx context.Context, entry Entry) error { entry.CreatedAt = now } entry.UpdatedAt = now - if entry.TopicFile == "" { entry.TopicFile = fmt.Sprintf("%s_%s.md", entry.Type, entry.ID) } - index, err := s.store.LoadIndex(ctx) + index, err := s.loadIndexLocked(ctx, scope) if err != nil { - return fmt.Errorf("memo: load index: %w", err) + return err } working := cloneIndex(index) @@ -107,38 +238,36 @@ func (s *Service) saveEntryLocked(ctx context.Context, entry Entry) error { } working.UpdatedAt = now - var topicsToDelete []string - if s.config.MaxIndexLines > 0 && len(working.Entries) > s.config.MaxIndexLines { - excess := len(working.Entries) - s.config.MaxIndexLines - for i := 0; i < excess; i++ { - topicFile := strings.TrimSpace(working.Entries[i].TopicFile) - if topicFile != "" && topicFile != entry.TopicFile { - topicsToDelete = append(topicsToDelete, topicFile) - } - } - working.Entries = working.Entries[excess:] - } - - if err := s.store.SaveTopic(ctx, entry.TopicFile, RenderTopic(&entry)); err != nil { + removedEntries := trimIndexEntries(working, s.config.MaxEntries, s.config.MaxIndexBytes) + if err := s.store.SaveTopic(ctx, scope, entry.TopicFile, RenderTopic(&entry)); err != nil { return fmt.Errorf("memo: save topic: %w", err) } - if err := s.store.SaveIndex(ctx, working); err != nil { + if err := s.store.SaveIndex(ctx, scope, working); err != nil { if !replaced { - _ = s.store.DeleteTopic(ctx, entry.TopicFile) + _ = s.store.DeleteTopic(ctx, scope, entry.TopicFile) } return fmt.Errorf("memo: save index: %w", err) } - for _, topicFile := range topicsToDelete { - _ = s.store.DeleteTopic(ctx, topicFile) + + if replaced && previous.TopicFile != "" && previous.TopicFile != entry.TopicFile { + _ = s.store.DeleteTopic(ctx, scope, previous.TopicFile) + } + for _, removed := range removedEntries { + if topicFile := strings.TrimSpace(removed.TopicFile); topicFile != "" { + _ = s.store.DeleteTopic(ctx, scope, topicFile) + } } + if s.autoExtractIndexReady { if replaced { - s.removeAutoExtractTopicLocked(previous.TopicFile) + s.removeAutoExtractTopicLocked(scope, previous.TopicFile) + } + for _, removed := range removedEntries { + s.removeAutoExtractTopicLocked(scope, removed.TopicFile) } - for _, topicFile := range topicsToDelete { - s.removeAutoExtractTopicLocked(topicFile) + if indexContainsEntryID(working, entry.ID) { + s.trackAutoExtractEntryLocked(scope, entry) } - s.trackAutoExtractEntryLocked(entry) } s.invalidateCache() @@ -164,38 +293,36 @@ func (s *Service) ensureAutoExtractIndex(ctx context.Context) error { } s.mu.Unlock() - index, err := s.store.LoadIndex(ctx) - if err != nil { - return fmt.Errorf("memo: load index: %w", err) - } - keysByTopic := make(map[string]string) keyRefs := make(map[string]int) - for _, entry := range index.Entries { - topicFile := strings.TrimSpace(entry.TopicFile) - if topicFile == "" { - continue - } - - topicContent, err := s.store.LoadTopic(ctx, topicFile) + for _, scope := range supportedStorageScopes() { + index, err := s.store.LoadIndex(ctx, scope) if err != nil { - continue - } - - source, content := parseTopicSourceAndContent(topicContent) - if source != SourceAutoExtract { - continue + return fmt.Errorf("memo: load index: %w", err) } - - entry.Source = source - entry.Content = content - key := autoExtractDedupKey(entry) - if key == "" { - continue + for _, entry := range index.Entries { + topicFile := strings.TrimSpace(entry.TopicFile) + if topicFile == "" { + continue + } + topicContent, err := s.store.LoadTopic(ctx, scope, topicFile) + if err != nil { + continue + } + source, content := parseTopicSourceAndContent(topicContent) + if source != SourceAutoExtract { + continue + } + entry.Source = source + entry.Content = content + key := autoExtractDedupKey(entry) + if key == "" { + continue + } + topicKey := scopedTopicKey(scope, topicFile) + keysByTopic[topicKey] = key + keyRefs[key]++ } - - keysByTopic[topicFile] = key - keyRefs[key]++ } s.mu.Lock() @@ -219,47 +346,42 @@ func (s *Service) hasExactAutoExtractLocked(target Entry) bool { } // trackAutoExtractEntryLocked 在自动提取索引已就绪时维护单条条目的去重键。 -func (s *Service) trackAutoExtractEntryLocked(entry Entry) { +func (s *Service) trackAutoExtractEntryLocked(scope Scope, entry Entry) { if !s.autoExtractIndexReady { return } - topicFile := strings.TrimSpace(entry.TopicFile) if topicFile == "" { return } - s.removeAutoExtractTopicLocked(topicFile) - + s.removeAutoExtractTopicLocked(scope, topicFile) if entry.Source != SourceAutoExtract { return } - key := autoExtractDedupKey(entry) if key == "" { return } - - s.autoExtractKeysByTopic[topicFile] = key + s.autoExtractKeysByTopic[scopedTopicKey(scope, topicFile)] = key s.autoExtractKeyRefs[key]++ } // removeAutoExtractTopicLocked 从精确去重索引中移除指定 topic 的记录。 -func (s *Service) removeAutoExtractTopicLocked(topicFile string) { +func (s *Service) removeAutoExtractTopicLocked(scope Scope, topicFile string) { if !s.autoExtractIndexReady { return } - topicFile = strings.TrimSpace(topicFile) if topicFile == "" { return } - key, ok := s.autoExtractKeysByTopic[topicFile] + topicKey := scopedTopicKey(scope, topicFile) + key, ok := s.autoExtractKeysByTopic[topicKey] if !ok { return } - delete(s.autoExtractKeysByTopic, topicFile) - + delete(s.autoExtractKeysByTopic, topicKey) if refs := s.autoExtractKeyRefs[key]; refs > 1 { s.autoExtractKeyRefs[key] = refs - 1 return @@ -267,127 +389,26 @@ func (s *Service) removeAutoExtractTopicLocked(topicFile string) { delete(s.autoExtractKeyRefs, key) } -// loadIndexLocked 在持有锁的状态下加载索引。 -func (s *Service) loadIndexLocked(ctx context.Context) (*Index, error) { - index, err := s.store.LoadIndex(ctx) +// loadIndexLocked 在持有锁的状态下加载指定分层索引。 +func (s *Service) loadIndexLocked(ctx context.Context, scope Scope) (*Index, error) { + index, err := s.store.LoadIndex(ctx, scope) if err != nil { return nil, fmt.Errorf("memo: load index: %w", err) } return index, nil } -// Remove 按关键词搜索并删除匹配的记忆条目,返回删除数量。 -func (s *Service) Remove(ctx context.Context, keyword string) (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - - index, err := s.loadIndexLocked(ctx) - if err != nil { - return 0, err - } - working := cloneIndex(index) - - keyword = strings.ToLower(strings.TrimSpace(keyword)) - if keyword == "" { - return 0, fmt.Errorf("memo: keyword is empty") - } - - var remaining []Entry - removed := 0 - topicsToDelete := make([]string, 0, len(working.Entries)) - for _, entry := range working.Entries { - if matchesKeyword(entry, keyword) { - if topicFile := strings.TrimSpace(entry.TopicFile); topicFile != "" { - topicsToDelete = append(topicsToDelete, topicFile) - } - removed++ - } else { - remaining = append(remaining, entry) - } - } - - if removed == 0 { - return 0, nil - } - - working.Entries = remaining - working.UpdatedAt = time.Now() - if err := s.store.SaveIndex(ctx, working); err != nil { - return 0, fmt.Errorf("memo: save index: %w", err) - } - for _, topicFile := range topicsToDelete { - _ = s.store.DeleteTopic(ctx, topicFile) - } - if s.autoExtractIndexReady { - for _, topicFile := range topicsToDelete { - s.removeAutoExtractTopicLocked(topicFile) - } - } - - s.invalidateCache() - return removed, nil -} - -// List 返回索引中的所有记忆条目浅拷贝。 -func (s *Service) List(ctx context.Context) ([]Entry, error) { - s.mu.Lock() - defer s.mu.Unlock() - - index, err := s.loadIndexLocked(ctx) - if err != nil { - return nil, err - } - result := make([]Entry, len(index.Entries)) - copy(result, index.Entries) - return result, nil -} - -// Search 按关键词搜索记忆条目。 -func (s *Service) Search(ctx context.Context, keyword string) ([]Entry, error) { - s.mu.Lock() - defer s.mu.Unlock() - - index, err := s.loadIndexLocked(ctx) - if err != nil { - return nil, err - } - - keyword = strings.ToLower(strings.TrimSpace(keyword)) - var results []Entry - for _, entry := range index.Entries { - if matchesKeyword(entry, keyword) { - results = append(results, entry) - } - } - return results, nil -} - -// Recall 加载匹配关键词的 topic 文件内容。 -func (s *Service) Recall(ctx context.Context, keyword string) (map[string]string, error) { - s.mu.Lock() - defer s.mu.Unlock() - - index, err := s.loadIndexLocked(ctx) - if err != nil { - return nil, err - } - - keyword = strings.ToLower(strings.TrimSpace(keyword)) - results := make(map[string]string) - for _, entry := range index.Entries { - if !matchesKeyword(entry, keyword) { - continue - } - if entry.TopicFile == "" { - continue - } - content, err := s.store.LoadTopic(ctx, entry.TopicFile) +// listLocked 返回 scope 范围内的所有条目。 +func (s *Service) listLocked(ctx context.Context, scope Scope) ([]Entry, error) { + results := make([]Entry, 0) + for _, bucket := range scopesForQuery(scope) { + index, err := s.loadIndexLocked(ctx, bucket) if err != nil { - continue + return nil, err } - results[entry.TopicFile] = content + results = append(results, index.Entries...) } - return results, nil + return append([]Entry(nil), results...), nil } // invalidateCache 触发上下文源缓存失效回调。 @@ -399,12 +420,18 @@ func (s *Service) invalidateCache() { // matchesKeyword 检查条目是否匹配关键词。 func matchesKeyword(entry Entry, keyword string) bool { + if keyword == "" { + return true + } if strings.Contains(strings.ToLower(entry.Title), keyword) { return true } if strings.Contains(strings.ToLower(string(entry.Type)), keyword) { return true } + if strings.Contains(strings.ToLower(entry.Content), keyword) { + return true + } for _, kw := range entry.Keywords { if strings.Contains(strings.ToLower(kw), keyword) { return true @@ -422,6 +449,11 @@ func normalizeEntryForPersist(entry Entry) (Entry, error) { if entry.Title == "" { return Entry{}, fmt.Errorf("memo: title is empty") } + entry.Content = strings.TrimSpace(entry.Content) + if entry.Content == "" { + return Entry{}, fmt.Errorf("memo: content is empty") + } + entry.Keywords = normalizeKeywords(entry.Keywords) return entry, nil } @@ -430,8 +462,7 @@ func newEntryID(t Type) string { ts := fmt.Sprintf("%x", time.Now().Unix()) buf := make([]byte, 4) _, _ = rand.Read(buf) - randHex := hex.EncodeToString(buf) - return fmt.Sprintf("%s_%s_%s", t, ts, randHex) + return fmt.Sprintf("%s_%s_%s", t, ts, hex.EncodeToString(buf)) } // cloneIndex 复制索引结构,避免持久化失败时污染原始数据引用。 @@ -446,3 +477,86 @@ func cloneIndex(index *Index) *Index { copy(cloned.Entries, index.Entries) return cloned } + +// trimIndexEntries 先按条目数、再按索引字节数裁剪最旧条目,并返回被删除的记录。 +func trimIndexEntries(index *Index, maxEntries int, maxIndexBytes int) []Entry { + if index == nil { + return nil + } + removed := make([]Entry, 0) + for maxEntries > 0 && len(index.Entries) > maxEntries { + removed = append(removed, index.Entries[0]) + index.Entries = index.Entries[1:] + } + if maxIndexBytes > 0 && len(index.Entries) > 0 { + removed = append(removed, trimIndexEntriesByBytes(index, maxIndexBytes)...) + } + return removed +} + +// trimIndexEntriesByBytes 在索引超过字节阈值时,通过二分定位最小移除数量并返回被删除条目。 +func trimIndexEntriesByBytes(index *Index, maxIndexBytes int) []Entry { + if index == nil || len(index.Entries) == 0 || maxIndexBytes <= 0 { + return nil + } + if len(RenderIndex(index)) <= maxIndexBytes { + return nil + } + + entries := index.Entries + lo, hi := 0, len(entries) + for lo < hi { + mid := lo + (hi-lo)/2 + candidate := &Index{Entries: entries[mid:], UpdatedAt: index.UpdatedAt} + if len(RenderIndex(candidate)) > maxIndexBytes { + lo = mid + 1 + continue + } + hi = mid + } + + removed := append([]Entry(nil), entries[:lo]...) + index.Entries = entries[lo:] + return removed +} + +// indexContainsEntryID 判断索引中是否仍保留目标 ID,用于避免为已裁剪条目建立去重索引。 +func indexContainsEntryID(index *Index, entryID string) bool { + if index == nil || strings.TrimSpace(entryID) == "" { + return false + } + for _, item := range index.Entries { + if item.ID == entryID { + return true + } + } + return false +} + +// scopesForQuery 将查询范围展开为实际存储分层列表。 +func scopesForQuery(scope Scope) []Scope { + switch NormalizeScope(scope) { + case ScopeUser: + return []Scope{ScopeUser} + case ScopeProject: + return []Scope{ScopeProject} + default: + return supportedStorageScopes() + } +} + +// validateQueryScope 校验外部查询/删除接口允许的 scope 取值。 +func validateQueryScope(scope Scope) error { + normalized := strings.ToLower(strings.TrimSpace(string(scope))) + switch Scope(normalized) { + case "", ScopeAll, ScopeUser, ScopeProject: + return nil + default: + return fmt.Errorf("memo: unsupported scope %q", scope) + } +} + +// scopedTopicKey 为自动提取去重索引生成稳定的 topic 维度键。 +func scopedTopicKey(scope Scope, topicFile string) string { + return string(scope) + ":" + strings.TrimSpace(topicFile) +} diff --git a/internal/memo/service_test.go b/internal/memo/service_test.go index 8b5db591..21988dba 100644 --- a/internal/memo/service_test.go +++ b/internal/memo/service_test.go @@ -4,370 +4,348 @@ import ( "context" "errors" "strings" - "sync" "testing" "neo-code/internal/config" ) -func TestServiceAdd(t *testing.T) { - store := &stubStore{} - var invCalled bool - svc := NewService(store, nil, config.MemoConfig{MaxIndexLines: 200}, func() { invCalled = true }) +func testMemoConfig() config.MemoConfig { + return config.MemoConfig{ + MaxEntries: 200, + MaxIndexBytes: 16 * 1024, + ExtractTimeoutSec: 15, + ExtractRecentMessages: 10, + } +} - entry := Entry{ +func TestServiceAddRoutesByScope(t *testing.T) { + store := newMemoryTestStore() + invalidateCalls := 0 + svc := NewService(store, testMemoConfig(), func() { invalidateCalls++ }) + + if err := svc.Add(context.Background(), Entry{ Type: TypeUser, - Title: "偏好 tab 缩进", - Content: "用户偏好使用 tab 缩进", + Title: "user pref", + Content: "user pref", Source: SourceUserManual, + }); err != nil { + t.Fatalf("Add(user) error = %v", err) } - if err := svc.Add(context.Background(), entry); err != nil { - t.Fatalf("Add error: %v", err) - } - if !invCalled { - t.Error("cache invalidation callback should have been called") + if err := svc.Add(context.Background(), Entry{ + Type: TypeProject, + Title: "project fact", + Content: "project fact", + Source: SourceUserManual, + }); err != nil { + t.Fatalf("Add(project) error = %v", err) } - entries, _ := svc.List(context.Background()) - if len(entries) != 1 { - t.Fatalf("List entries = %d, want 1", len(entries)) + userEntries, err := svc.List(context.Background(), ScopeUser) + if err != nil { + t.Fatalf("List(user) error = %v", err) } - if entries[0].Title != "偏好 tab 缩进" { - t.Errorf("Title = %q, want %q", entries[0].Title, "偏好 tab 缩进") + projectEntries, err := svc.List(context.Background(), ScopeProject) + if err != nil { + t.Fatalf("List(project) error = %v", err) } - if entries[0].ID == "" { - t.Error("ID should be auto-generated") + if len(userEntries) != 1 || userEntries[0].Type != TypeUser { + t.Fatalf("unexpected user entries: %#v", userEntries) } - if entries[0].TopicFile == "" { - t.Error("TopicFile should be auto-generated") + if len(projectEntries) != 1 || projectEntries[0].Type != TypeProject { + t.Fatalf("unexpected project entries: %#v", projectEntries) } -} - -func TestServiceAddInvalidType(t *testing.T) { - svc := NewService(&stubStore{}, nil, config.MemoConfig{}, nil) - err := svc.Add(context.Background(), Entry{Type: "invalid", Title: "test"}) - if err == nil { - t.Error("Add with invalid type should return error") + if invalidateCalls != 2 { + t.Fatalf("invalidate calls = %d, want 2", invalidateCalls) } } -func TestServiceAddEmptyTitle(t *testing.T) { - svc := NewService(&stubStore{}, nil, config.MemoConfig{}, nil) - err := svc.Add(context.Background(), Entry{Type: TypeUser, Title: ""}) - if err == nil { - t.Error("Add with empty title should return error") +func TestServiceAddValidatesEntry(t *testing.T) { + svc := NewService(newMemoryTestStore(), testMemoConfig(), nil) + + tests := []Entry{ + {Type: "invalid", Title: "x", Content: "x"}, + {Type: TypeUser, Title: "", Content: "x"}, + {Type: TypeUser, Title: "x", Content: ""}, + } + for _, entry := range tests { + if err := svc.Add(context.Background(), entry); err == nil { + t.Fatalf("expected Add(%+v) to fail", entry) + } } } -func TestServiceAddNormalizesTitle(t *testing.T) { - store := &stubStore{} - svc := NewService(store, nil, config.MemoConfig{}, nil) +func TestServiceAddNormalizesTitleAndKeywords(t *testing.T) { + svc := NewService(newMemoryTestStore(), testMemoConfig(), nil) err := svc.Add(context.Background(), Entry{ - Type: TypeUser, - Title: " # heading\n(with suffix) ", - Source: SourceUserManual, + Type: TypeUser, + Title: " # heading\n(with suffix) ", + Content: "content", + Keywords: []string{" Tabs ", "tabs", "", "Style"}, + Source: SourceUserManual, }) if err != nil { - t.Fatalf("Add error: %v", err) + t.Fatalf("Add() error = %v", err) } - entries, _ := svc.List(context.Background()) - if len(entries) != 1 { - t.Fatalf("entries = %d, want 1", len(entries)) + entries, err := svc.List(context.Background(), ScopeUser) + if err != nil { + t.Fatalf("List() error = %v", err) } if entries[0].Title != "# heading {with suffix}" { t.Fatalf("normalized title = %q", entries[0].Title) } + if len(entries[0].Keywords) != 2 || entries[0].Keywords[0] != "Tabs" || entries[0].Keywords[1] != "Style" { + t.Fatalf("normalized keywords = %#v", entries[0].Keywords) + } } -func TestServiceRemove(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{ - {ID: "1", Type: TypeUser, Title: "偏好 tab", TopicFile: "a.md", Keywords: []string{"tabs"}}, - {ID: "2", Type: TypeProject, Title: "使用 Go", TopicFile: "b.md"}, - }, - }, - } - svc := NewService(store, nil, config.MemoConfig{}, nil) +func TestServiceSearchMatchesContentAndScope(t *testing.T) { + svc := NewService(newMemoryTestStore(), testMemoConfig(), nil) + _ = svc.Add(context.Background(), Entry{ + Type: TypeUser, + Title: "style", + Content: "please use tabs", + Source: SourceUserManual, + }) + _ = svc.Add(context.Background(), Entry{ + Type: TypeProject, + Title: "plan", + Content: "ship in april", + Source: SourceUserManual, + }) - removed, err := svc.Remove(context.Background(), "tab") + results, err := svc.Search(context.Background(), "tabs", ScopeUser) if err != nil { - t.Fatalf("Remove error: %v", err) + t.Fatalf("Search() error = %v", err) } - if removed != 1 { - t.Errorf("Remove returned %d, want 1", removed) + if len(results) != 1 || results[0].Type != TypeUser { + t.Fatalf("unexpected search results: %#v", results) } - entries, _ := svc.List(context.Background()) - if len(entries) != 1 { - t.Fatalf("after remove, entries = %d, want 1", len(entries)) + results, err = svc.Search(context.Background(), "ship", ScopeProject) + if err != nil { + t.Fatalf("Search() project error = %v", err) } - if entries[0].Title != "使用 Go" { - t.Errorf("remaining Title = %q, want %q", entries[0].Title, "使用 Go") + if len(results) != 1 || results[0].Type != TypeProject { + t.Fatalf("unexpected project search results: %#v", results) } } -func TestServiceRemoveNoMatch(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{{ID: "1", Type: TypeUser, Title: "test"}}, - }, - } - svc := NewService(store, nil, config.MemoConfig{}, nil) +func TestServiceRecallReturnsScopedEntries(t *testing.T) { + store := newMemoryTestStore() + svc := NewService(store, testMemoConfig(), nil) + _ = svc.Add(context.Background(), Entry{ + Type: TypeUser, + Title: "reply in chinese", + Content: "reply in chinese", + Source: SourceUserManual, + }) - removed, err := svc.Remove(context.Background(), "nonexistent") + results, err := svc.Recall(context.Background(), "chinese", ScopeAll) if err != nil { - t.Fatalf("Remove error: %v", err) + t.Fatalf("Recall() error = %v", err) } - if removed != 0 { - t.Errorf("Remove returned %d, want 0", removed) + if len(results) != 1 { + t.Fatalf("len(results) = %d, want 1", len(results)) } -} - -func TestServiceRemoveEmptyKeyword(t *testing.T) { - svc := NewService(&stubStore{}, nil, config.MemoConfig{}, nil) - _, err := svc.Remove(context.Background(), "") - if err == nil { - t.Error("Remove with empty keyword should return error") + if results[0].Scope != ScopeUser { + t.Fatalf("Scope = %q, want %q", results[0].Scope, ScopeUser) } -} - -func TestServiceSearch(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{ - {ID: "1", Type: TypeUser, Title: "偏好 tab", Keywords: []string{"indentation"}}, - {ID: "2", Type: TypeProject, Title: "使用 Go"}, - {ID: "3", Type: TypeUser, Title: "偏好中文注释"}, - }, - }, - } - svc := NewService(store, nil, config.MemoConfig{}, nil) - - results, err := svc.Search(context.Background(), "偏好") - if err != nil { - t.Fatalf("Search error: %v", err) - } - if len(results) != 2 { - t.Errorf("Search '偏好' = %d results, want 2", len(results)) + if !strings.Contains(results[0].Content, "reply in chinese") { + t.Fatalf("content = %q", results[0].Content) } } -func TestServiceSearchByKeyword(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{ - {ID: "1", Type: TypeUser, Title: "style", Keywords: []string{"tabs", "indentation"}}, - }, - }, - } - svc := NewService(store, nil, config.MemoConfig{}, nil) +func TestServiceRemoveRespectsScope(t *testing.T) { + svc := NewService(newMemoryTestStore(), testMemoConfig(), nil) + _ = svc.Add(context.Background(), Entry{Type: TypeUser, Title: "same", Content: "user content", Source: SourceUserManual}) + _ = svc.Add(context.Background(), Entry{Type: TypeProject, Title: "same", Content: "project content", Source: SourceUserManual}) - results, err := svc.Search(context.Background(), "indent") + removed, err := svc.Remove(context.Background(), "same", ScopeUser) if err != nil { - t.Fatalf("Search error: %v", err) + t.Fatalf("Remove() error = %v", err) } - if len(results) != 1 { - t.Errorf("Search by keyword = %d results, want 1", len(results)) + if removed != 1 { + t.Fatalf("removed = %d, want 1", removed) } -} -func TestServiceRecall(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{ - {ID: "1", Type: TypeUser, Title: "偏好 tab", TopicFile: "a.md"}, - {ID: "2", Type: TypeProject, Title: "其他", TopicFile: "b.md"}, - }, - }, + userEntries, _ := svc.List(context.Background(), ScopeUser) + projectEntries, _ := svc.List(context.Background(), ScopeProject) + if len(userEntries) != 0 { + t.Fatalf("expected user scope to be empty, got %#v", userEntries) } - // 为 stubStore 添加 topic 加载能力 - storeWithTopics := &stubStoreWithTopics{ - stubStore: store, - topics: map[string]string{ - "a.md": "---\ntype: user\n---\n\n详细内容A", - "b.md": "---\ntype: project\n---\n\n详细内容B", - }, + if len(projectEntries) != 1 { + t.Fatalf("expected project scope to remain, got %#v", projectEntries) } - svc := NewService(storeWithTopics, nil, config.MemoConfig{}, nil) +} - results, err := svc.Recall(context.Background(), "tab") - if err != nil { - t.Fatalf("Recall error: %v", err) - } - if len(results) != 1 { - t.Fatalf("Recall = %d results, want 1", len(results)) - } - if !strings.Contains(results["a.md"], "详细内容A") { - t.Errorf("Recall content = %q, should contain 详细内容A", results["a.md"]) +func TestServiceRemoveRejectsInvalidScope(t *testing.T) { + svc := NewService(newMemoryTestStore(), testMemoConfig(), nil) + if _, err := svc.Remove(context.Background(), "x", Scope("bad")); err == nil { + t.Fatal("expected invalid scope error") } } -func TestServiceMaxIndexLines(t *testing.T) { - store := &stubStore{} - svc := NewService(store, nil, config.MemoConfig{MaxIndexLines: 2}, nil) +func TestServiceMaxEntriesTrim(t *testing.T) { + cfg := testMemoConfig() + cfg.MaxEntries = 2 + svc := NewService(newMemoryTestStore(), cfg, nil) - for i := 0; i < 4; i++ { - entry := Entry{ + for _, title := range []string{"A", "B", "C"} { + if err := svc.Add(context.Background(), Entry{ Type: TypeUser, - Title: string(rune('A' + i)), - Content: string(rune('A' + i)), + Title: title, + Content: title, Source: SourceUserManual, - } - if err := svc.Add(context.Background(), entry); err != nil { - t.Fatalf("Add %d error: %v", i, err) + }); err != nil { + t.Fatalf("Add(%s) error = %v", title, err) } } - entries, _ := svc.List(context.Background()) - if len(entries) != 2 { - t.Errorf("after overflow, entries = %d, want 2", len(entries)) + entries, err := svc.List(context.Background(), ScopeUser) + if err != nil { + t.Fatalf("List() error = %v", err) } - // 应保留最新的两条 - if entries[0].Title != "C" || entries[1].Title != "D" { - t.Errorf("expected [C, D], got [%s, %s]", entries[0].Title, entries[1].Title) + if len(entries) != 2 || entries[0].Title != "B" || entries[1].Title != "C" { + t.Fatalf("entries after trim = %#v", entries) } } -func TestServiceAddUpdate(t *testing.T) { - store := &stubStore{} - svc := NewService(store, nil, config.MemoConfig{}, nil) - - entry := Entry{ - ID: "fixed_id", - Type: TypeUser, - Title: "旧标题", - Source: SourceUserManual, +func TestServiceMaxIndexBytesTrim(t *testing.T) { + cfg := testMemoConfig() + cfg.MaxEntries = 10 + cfg.MaxIndexBytes = 40 + svc := NewService(newMemoryTestStore(), cfg, nil) + + for _, title := range []string{"one", "two", "three"} { + if err := svc.Add(context.Background(), Entry{ + Type: TypeProject, + Title: title, + Content: title, + Source: SourceUserManual, + }); err != nil { + t.Fatalf("Add(%s) error = %v", title, err) + } } - _ = svc.Add(context.Background(), entry) - entry.Title = "新标题" - _ = svc.Add(context.Background(), entry) - - entries, _ := svc.List(context.Background()) - if len(entries) != 1 { - t.Fatalf("after update, entries = %d, want 1", len(entries)) + entries, err := svc.List(context.Background(), ScopeProject) + if err != nil { + t.Fatalf("List() error = %v", err) } - if entries[0].Title != "新标题" { - t.Errorf("Title = %q, want %q", entries[0].Title, "新标题") + if len(entries) >= 3 { + t.Fatalf("expected byte trimming to remove oldest entry, got %#v", entries) } } -func TestServiceAddSaveTopicFailureDoesNotPersistIndex(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{ - {ID: "existing", Type: TypeUser, Title: "existing", TopicFile: "existing.md"}, - }, - }, - saveTopicErr: errors.New("save topic failed"), - } - svc := NewService(store, nil, config.MemoConfig{}, nil) +func TestServiceSaveTopicFailureDoesNotPersistIndex(t *testing.T) { + store := newMemoryTestStore() + store.saveTopicErr = errors.New("save topic failed") + svc := NewService(store, testMemoConfig(), nil) err := svc.Add(context.Background(), Entry{ - ID: "new-id", - Type: TypeUser, - Title: "new entry", - Source: SourceUserManual, - TopicFile: "new.md", + Type: TypeUser, + Title: "new", + Content: "new", + Source: SourceUserManual, }) if err == nil || !strings.Contains(err.Error(), "save topic") { t.Fatalf("expected save topic error, got %v", err) } - if len(store.index.Entries) != 1 { - t.Fatalf("index should stay unchanged on topic failure, entries=%d", len(store.index.Entries)) - } if store.saveIndexCalls != 0 { - t.Fatalf("SaveIndex should not run when SaveTopic fails, calls=%d", store.saveIndexCalls) + t.Fatalf("SaveIndex() calls = %d, want 0", store.saveIndexCalls) } } -func TestServiceRemoveSaveIndexFailureDoesNotDeleteTopics(t *testing.T) { - store := &stubStore{ - index: &Index{ - Entries: []Entry{ - {ID: "1", Type: TypeUser, Title: "match", TopicFile: "a.md"}, - {ID: "2", Type: TypeUser, Title: "other", TopicFile: "b.md"}, - }, - }, - saveIndexErr: errors.New("save index failed"), - } - svc := NewService(store, nil, config.MemoConfig{}, nil) +func TestServiceSaveIndexFailureDoesNotDeleteTopic(t *testing.T) { + store := newMemoryTestStore() + store.saveIndexErr = errors.New("save index failed") + svc := NewService(store, testMemoConfig(), nil) - _, err := svc.Remove(context.Background(), "match") + err := svc.Add(context.Background(), Entry{ + Type: TypeUser, + Title: "new", + Content: "new", + Source: SourceUserManual, + }) if err == nil || !strings.Contains(err.Error(), "save index") { t.Fatalf("expected save index error, got %v", err) } - if store.deleteTopicCalls != 0 { - t.Fatalf("DeleteTopic should not run when SaveIndex fails, calls=%d", store.deleteTopicCalls) - } - if len(store.index.Entries) != 2 { - t.Fatalf("index should stay unchanged on save failure, entries=%d", len(store.index.Entries)) + if store.deleteTopicCalls != 1 { + t.Fatalf("DeleteTopic() calls = %d, want 1 rollback delete", store.deleteTopicCalls) } } -func TestServiceAddAutoExtractIfAbsent(t *testing.T) { - baseDir := t.TempDir() - workspace := t.TempDir() - store := NewFileStore(baseDir, workspace) - svc := NewService(store, nil, config.MemoConfig{MaxIndexLines: 200}, nil) - ctx := context.Background() - - first := Entry{ +func TestServiceAutoExtractDedupAcrossScopes(t *testing.T) { + svc := NewService(newMemoryTestStore(), testMemoConfig(), nil) + entry := Entry{ Type: TypeUser, Title: "reply in chinese", Content: "reply in chinese", Source: SourceAutoExtract, } - added, err := svc.addAutoExtractIfAbsent(ctx, first) + + added, err := svc.addAutoExtractIfAbsent(context.Background(), entry) if err != nil || !added { - t.Fatalf("first addAutoExtractIfAbsent() = (%v,%v), want (true,nil)", added, err) + t.Fatalf("first addAutoExtractIfAbsent() = (%v, %v), want (true, nil)", added, err) } - - added, err = svc.addAutoExtractIfAbsent(ctx, first) + added, err = svc.addAutoExtractIfAbsent(context.Background(), entry) if err != nil { t.Fatalf("second addAutoExtractIfAbsent() error = %v", err) } if added { - t.Fatalf("duplicate addAutoExtractIfAbsent() = true, want false") + t.Fatal("expected duplicate auto extract to be skipped") } +} - entries, err := svc.List(ctx) +func TestServiceAutoExtractTrimmedEntryDoesNotPolluteDedupIndex(t *testing.T) { + svc := NewService(newMemoryTestStore(), config.MemoConfig{ + MaxEntries: 10, + MaxIndexBytes: 1, + ExtractTimeoutSec: 15, + ExtractRecentMessages: 10, + }, nil) + entry := Entry{ + Type: TypeUser, + Title: "reply in chinese", + Content: "reply in chinese", + Source: SourceAutoExtract, + } + + added, err := svc.addAutoExtractIfAbsent(context.Background(), entry) + if err != nil || !added { + t.Fatalf("first addAutoExtractIfAbsent() = (%v, %v), want (true, nil)", added, err) + } + added, err = svc.addAutoExtractIfAbsent(context.Background(), entry) + if err != nil || !added { + t.Fatalf("second addAutoExtractIfAbsent() = (%v, %v), want (true, nil)", added, err) + } + + entries, err := svc.List(context.Background(), ScopeUser) if err != nil { t.Fatalf("List() error = %v", err) } - if len(entries) != 1 { - t.Fatalf("entries after dedupe = %d, want 1", len(entries)) + if len(entries) != 0 { + t.Fatalf("len(entries) = %d, want 0 after byte trim", len(entries)) } -} -func TestServiceEnsureAutoExtractIndexAndLoadFailureTolerance(t *testing.T) { - store := &stubStoreWithTopics{ - stubStore: &stubStore{ - index: &Index{ - Entries: []Entry{ - {ID: "1", Type: TypeUser, Title: "A", TopicFile: "a.md"}, - {ID: "2", Type: TypeFeedback, Title: "B", TopicFile: "b.md"}, - {ID: "3", Type: TypeProject, Title: "C", TopicFile: "missing.md"}, - }, - }, - }, - topics: map[string]string{ - "a.md": "---\nsource: extractor_auto\n---\n\na content", - "b.md": "---\nsource: user_manual\n---\n\nb content", - }, + key := autoExtractDedupKey(entry) + if refs := svc.autoExtractKeyRefs[key]; refs != 0 { + t.Fatalf("autoExtractKeyRefs[%q] = %d, want 0", key, refs) } - svc := NewService(store, nil, config.MemoConfig{MaxIndexLines: 200}, nil) +} + +func TestServiceEnsureAutoExtractIndexLoadsExistingEntries(t *testing.T) { + store := newMemoryTestStore() + store.indexes[ScopeUser] = &Index{Entries: []Entry{{Type: TypeUser, Title: "A", TopicFile: "a.md"}}} + store.topics[ScopeUser]["a.md"] = "---\nsource: extractor_auto\n---\n\na content" + svc := NewService(store, testMemoConfig(), nil) + if err := svc.ensureAutoExtractIndex(context.Background()); err != nil { t.Fatalf("ensureAutoExtractIndex() error = %v", err) } if !svc.autoExtractIndexReady { - t.Fatalf("autoExtractIndexReady = false, want true") - } - if len(svc.autoExtractKeysByTopic) != 1 { - t.Fatalf("autoExtractKeysByTopic = %+v", svc.autoExtractKeysByTopic) + t.Fatal("autoExtractIndexReady = false") } if svc.autoExtractKeyRefs[autoExtractDedupKey(Entry{ Type: TypeUser, @@ -375,138 +353,45 @@ func TestServiceEnsureAutoExtractIndexAndLoadFailureTolerance(t *testing.T) { Content: "a content", Source: SourceAutoExtract, })] != 1 { - t.Fatalf("autoExtractKeyRefs = %+v", svc.autoExtractKeyRefs) + t.Fatalf("autoExtractKeyRefs = %#v", svc.autoExtractKeyRefs) } } -func TestServiceEnsureAutoExtractIndexLoadIndexFailure(t *testing.T) { - store := &stubStore{err: errors.New("load index failed")} - svc := NewService(store, nil, config.MemoConfig{}, nil) - err := svc.ensureAutoExtractIndex(context.Background()) - if err == nil || !strings.Contains(err.Error(), "load index") { - t.Fatalf("ensureAutoExtractIndex() error = %v", err) - } -} - -func TestServiceAutoExtractIndexHelpers(t *testing.T) { - svc := NewService(&stubStore{}, nil, config.MemoConfig{}, nil) - svc.autoExtractIndexReady = true - svc.autoExtractKeysByTopic = map[string]string{} - svc.autoExtractKeyRefs = map[string]int{} - +func TestMatchesKeywordIncludesContent(t *testing.T) { entry := Entry{ - Type: TypeProject, - Title: " release plan ", - Content: " ship in april ", - Source: SourceAutoExtract, - TopicFile: "plan.md", - } - svc.trackAutoExtractEntryLocked(entry) - key := autoExtractDedupKey(entry) - if svc.autoExtractKeysByTopic["plan.md"] != key || svc.autoExtractKeyRefs[key] != 1 { - t.Fatalf("trackAutoExtractEntryLocked() state = %+v %+v", svc.autoExtractKeysByTopic, svc.autoExtractKeyRefs) - } - if !svc.hasExactAutoExtractLocked(entry) { - t.Fatalf("hasExactAutoExtractLocked() = false, want true") - } - - svc.trackAutoExtractEntryLocked(entry) - if svc.autoExtractKeyRefs[key] != 1 { - t.Fatalf("expected stable ref count on same topic replacement, refs=%d", svc.autoExtractKeyRefs[key]) - } - - svc.autoExtractKeysByTopic["plan-copy.md"] = key - svc.autoExtractKeyRefs[key] = 2 - svc.removeAutoExtractTopicLocked("plan.md") - if svc.autoExtractKeyRefs[key] != 1 { - t.Fatalf("removeAutoExtractTopicLocked() should decrement refs, got %d", svc.autoExtractKeyRefs[key]) - } - svc.removeAutoExtractTopicLocked("plan-copy.md") - if svc.autoExtractKeyRefs[key] != 0 { - t.Fatalf("removeAutoExtractTopicLocked() should clear refs, got %+v", svc.autoExtractKeyRefs) - } -} - -func TestCloneIndexNilAndCopyIsolation(t *testing.T) { - clonedNil := cloneIndex(nil) - if clonedNil == nil || len(clonedNil.Entries) != 0 { - t.Fatalf("cloneIndex(nil) = %+v", clonedNil) + Type: TypeUser, + Title: "style", + Content: "please use tabs", + Keywords: []string{"indentation"}, } - - origin := &Index{Entries: []Entry{{ID: "1", Type: TypeUser, Title: "old"}}} - cloned := cloneIndex(origin) - origin.Entries[0].Title = "changed" - if cloned.Entries[0].Title != "old" { - t.Fatalf("cloneIndex should isolate entries, got %+v", cloned.Entries) + if !matchesKeyword(entry, "tabs") { + t.Fatal("expected content match") } -} - -func TestNewEntryID(t *testing.T) { - id := newEntryID(TypeUser) - if !strings.HasPrefix(id, "user_") { - t.Errorf("ID = %q, should start with 'user_'", id) + if !matchesKeyword(entry, "indent") { + t.Fatal("expected keyword match") } - // 确保唯一 - id2 := newEntryID(TypeUser) - if id == id2 { - t.Error("consecutive IDs should be unique") + if matchesKeyword(entry, "missing") { + t.Fatal("unexpected match for missing keyword") } } -func TestMatchesKeyword(t *testing.T) { - entry := Entry{ - Type: TypeUser, - Title: "偏好 tab 缩进", - Keywords: []string{"indentation", "style"}, - } - tests := []struct { - kw string - want bool - }{ - {"tab", true}, - {"偏好", true}, - {"indent", true}, - {"user", true}, - {"nonexistent", false}, - } - for _, tt := range tests { - got := matchesKeyword(entry, strings.ToLower(tt.kw)) - if got != tt.want { - t.Errorf("matchesKeyword(%q) = %v, want %v", tt.kw, got, tt.want) - } +func TestTrimIndexEntriesByBytesRemovesMinimalPrefix(t *testing.T) { + index := &Index{ + Entries: []Entry{ + {Type: TypeUser, Title: "one", TopicFile: "one.md"}, + {Type: TypeUser, Title: "two", TopicFile: "two.md"}, + {Type: TypeUser, Title: "three", TopicFile: "three.md"}, + }, } -} -// stubStoreWithTopics 扩展 stubStore 支持 topic 加载。 -type stubStoreWithTopics struct { - *stubStore - topics map[string]string - mu sync.Mutex -} + target := &Index{Entries: append([]Entry(nil), index.Entries[1:]...)} + maxIndexBytes := len(RenderIndex(target)) + removed := trimIndexEntries(index, 10, maxIndexBytes) -func (s *stubStoreWithTopics) LoadTopic(_ context.Context, filename string) (string, error) { - s.mu.Lock() - defer s.mu.Unlock() - content, ok := s.topics[filename] - if !ok { - return "", errors.New("not found") + if len(removed) != 1 || removed[0].Title != "one" { + t.Fatalf("removed = %#v, want only first entry", removed) } - return content, nil -} - -func (s *stubStoreWithTopics) SaveTopic(_ context.Context, filename string, content string) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.topics == nil { - s.topics = make(map[string]string) + if len(index.Entries) != 2 || index.Entries[0].Title != "two" || index.Entries[1].Title != "three" { + t.Fatalf("remaining entries = %#v", index.Entries) } - s.topics[filename] = content - return nil -} - -func (s *stubStoreWithTopics) DeleteTopic(_ context.Context, filename string) error { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.topics, filename) - return nil } diff --git a/internal/memo/store.go b/internal/memo/store.go index b22fe704..befbb527 100644 --- a/internal/memo/store.go +++ b/internal/memo/store.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "sort" "sync" agentsession "neo-code/internal/session" @@ -17,39 +18,51 @@ const ( memoFileName = "MEMO.md" ) -// FileStore 基于文件系统实现 Store 接口,采用工作区隔离的目录布局。 +// FileStore 基于文件系统实现 Store 接口,采用工作区隔离的双层目录布局。 type FileStore struct { - mu sync.RWMutex - memoDir string - topicsDir string + mu sync.RWMutex + baseDir string + workspaceRoot string } // NewFileStore 创建 FileStore 实例,目录基于 baseDir 和 workspaceRoot 计算工作区隔离路径。 func NewFileStore(baseDir string, workspaceRoot string) *FileStore { - dir := memoDirectory(baseDir, workspaceRoot) return &FileStore{ - memoDir: dir, - topicsDir: filepath.Join(dir, topicsDirName), + baseDir: baseDir, + workspaceRoot: workspaceRoot, } } -// LoadIndex 加载 MEMO.md 索引文件并解析为 Index 结构。 -func (s *FileStore) LoadIndex(ctx context.Context) (*Index, error) { +// LoadIndex 加载指定分层下的 MEMO.md 索引文件并解析为 Index 结构。 +func (s *FileStore) LoadIndex(ctx context.Context, scope Scope) (*Index, error) { if err := ctx.Err(); err != nil { return nil, err } + if err := validateStorageScope(scope); err != nil { + return nil, err + } s.mu.RLock() defer s.mu.RUnlock() - return s.loadIndexUnlocked() + data, err := readFirstExistingFile(s.indexPaths(scope)) + if errors.Is(err, os.ErrNotExist) { + return &Index{}, nil + } + if err != nil { + return nil, fmt.Errorf("memo: read index: %w", err) + } + return ParseIndex(string(data)) } -// SaveIndex 将索引写入 MEMO.md 文件,采用临时文件 + 原子替换策略。 -func (s *FileStore) SaveIndex(ctx context.Context, index *Index) error { +// SaveIndex 将索引写入指定分层下的 MEMO.md 文件,采用临时文件 + 原子替换策略。 +func (s *FileStore) SaveIndex(ctx context.Context, scope Scope, index *Index) error { if err := ctx.Err(); err != nil { return err } + if err := validateStorageScope(scope); err != nil { + return err + } if index == nil { return errors.New("memo: index is nil") } @@ -57,13 +70,18 @@ func (s *FileStore) SaveIndex(ctx context.Context, index *Index) error { s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.memoDir, 0o755); err != nil { + if err := s.migrateLegacyProjectScopeLocked(scope); err != nil { + return err + } + + dir := s.scopeDir(scope) + if err := os.MkdirAll(dir, 0o755); err != nil { return fmt.Errorf("memo: create memo dir: %w", err) } - content := RenderIndex(index) - target := filepath.Join(s.memoDir, memoFileName) + target := filepath.Join(dir, memoFileName) temp := target + ".tmp" + content := RenderIndex(index) if err := os.WriteFile(temp, []byte(content), 0o644); err != nil { return fmt.Errorf("memo: write temp index: %w", err) @@ -74,43 +92,51 @@ func (s *FileStore) SaveIndex(ctx context.Context, index *Index) error { if err := os.Rename(temp, target); err != nil { return fmt.Errorf("memo: commit index: %w", err) } - return nil } -// LoadTopic 读取指定 topic 文件的完整内容。 -func (s *FileStore) LoadTopic(ctx context.Context, filename string) (string, error) { +// LoadTopic 读取指定分层下的 topic 文件完整内容。 +func (s *FileStore) LoadTopic(ctx context.Context, scope Scope, filename string) (string, error) { if err := ctx.Err(); err != nil { return "", err } + if err := validateStorageScope(scope); err != nil { + return "", err + } s.mu.RLock() defer s.mu.RUnlock() - path := s.topicPath(filename) - data, err := os.ReadFile(path) + data, err := readFirstExistingFile(s.topicPaths(scope, filename)) if err != nil { return "", fmt.Errorf("memo: read topic %s: %w", filename, err) } return string(data), nil } -// SaveTopic 将内容写入指定 topic 文件,采用临时文件 + 原子替换策略。 -func (s *FileStore) SaveTopic(ctx context.Context, filename string, content string) error { +// SaveTopic 将内容写入指定分层下的 topic 文件,采用临时文件 + 原子替换策略。 +func (s *FileStore) SaveTopic(ctx context.Context, scope Scope, filename string, content string) error { if err := ctx.Err(); err != nil { return err } + if err := validateStorageScope(scope); err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.topicsDir, 0o755); err != nil { + if err := s.migrateLegacyProjectScopeLocked(scope); err != nil { + return err + } + + dir := s.topicsDir(scope) + if err := os.MkdirAll(dir, 0o755); err != nil { return fmt.Errorf("memo: create topics dir: %w", err) } - path := s.topicPath(filename) + path := s.topicPath(scope, filename) temp := path + ".tmp" - if err := os.WriteFile(temp, []byte(content), 0o644); err != nil { return fmt.Errorf("memo: write temp topic: %w", err) } @@ -120,43 +146,70 @@ func (s *FileStore) SaveTopic(ctx context.Context, filename string, content stri if err := os.Rename(temp, path); err != nil { return fmt.Errorf("memo: commit topic: %w", err) } - return nil } -// DeleteTopic 删除指定 topic 文件。 -func (s *FileStore) DeleteTopic(ctx context.Context, filename string) error { +// DeleteTopic 删除指定分层下的 topic 文件。 +func (s *FileStore) DeleteTopic(ctx context.Context, scope Scope, filename string) error { if err := ctx.Err(); err != nil { return err } + if err := validateStorageScope(scope); err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() - path := s.topicPath(filename) + if err := s.migrateLegacyProjectScopeLocked(scope); err != nil { + return err + } + + path := s.topicPath(scope, filename) if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("memo: delete topic %s: %w", filename, err) } return nil } -// ListTopics 列出 topics 目录下所有 .md 文件名。 -func (s *FileStore) ListTopics(ctx context.Context) ([]string, error) { +// ListTopics 列出指定分层下 topics 目录中的所有 .md 文件名。 +func (s *FileStore) ListTopics(ctx context.Context, scope Scope) ([]string, error) { if err := ctx.Err(); err != nil { return nil, err } + if err := validateStorageScope(scope); err != nil { + return nil, err + } s.mu.RLock() defer s.mu.RUnlock() - entries, err := os.ReadDir(s.topicsDir) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil, nil + seen := make(map[string]struct{}) + for _, dir := range s.topicsDirs(scope) { + entries, err := os.ReadDir(dir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + continue + } + return nil, fmt.Errorf("memo: list topics: %w", err) } - return nil, fmt.Errorf("memo: list topics: %w", err) + for _, name := range collectTopicNames(entries) { + seen[name] = struct{}{} + } + } + if len(seen) == 0 { + return nil, nil + } + names := make([]string, 0, len(seen)) + for name := range seen { + names = append(names, name) } + sort.Strings(names) + return names, nil +} +// collectTopicNames 将目录项过滤为 topic 文件名列表。 +func collectTopicNames(entries []os.DirEntry) []string { names := make([]string, 0, len(entries)) for _, entry := range entries { if entry.IsDir() || filepath.Ext(entry.Name()) != ".md" { @@ -164,29 +217,161 @@ func (s *FileStore) ListTopics(ctx context.Context) ([]string, error) { } names = append(names, entry.Name()) } - return names, nil + return names +} + +// readFirstExistingFile 按顺序读取候选路径,返回首个存在文件内容;若均不存在则返回 os.ErrNotExist。 +func readFirstExistingFile(paths []string) ([]byte, error) { + for _, path := range paths { + data, err := os.ReadFile(path) + if err == nil { + return data, nil + } + if !errors.Is(err, os.ErrNotExist) { + return nil, err + } + } + return nil, os.ErrNotExist +} + +// scopeDir 返回指定 memo 分层的根目录。 +func (s *FileStore) scopeDir(scope Scope) string { + if scope == ScopeUser { + return filepath.Join(globalMemoDirectory(s.baseDir), string(scope)) + } + return filepath.Join(projectMemoDirectory(s.baseDir, s.workspaceRoot), string(scope)) } -// loadIndexUnlocked 在无锁状态下读取并解析 MEMO.md。 -func (s *FileStore) loadIndexUnlocked() (*Index, error) { - path := filepath.Join(s.memoDir, memoFileName) - data, err := os.ReadFile(path) +// scopeDirLegacy 返回旧版本 project scope 的根目录,仅用于兼容迁移。 +func (s *FileStore) scopeDirLegacy(scope Scope) string { + if scope == ScopeProject { + return projectMemoDirectory(s.baseDir, s.workspaceRoot) + } + return "" +} + +// indexPaths 返回读取索引时的候选路径,顺序为新路径优先、旧路径兜底。 +func (s *FileStore) indexPaths(scope Scope) []string { + paths := []string{filepath.Join(s.scopeDir(scope), memoFileName)} + if legacy := s.scopeDirLegacy(scope); legacy != "" { + paths = append(paths, filepath.Join(legacy, memoFileName)) + } + return paths +} + +// topicsDir 返回指定 memo 分层的 topics 目录。 +func (s *FileStore) topicsDir(scope Scope) string { + return filepath.Join(s.scopeDir(scope), topicsDirName) +} + +// topicsDirs 返回读取 topics 时的候选目录,顺序为新路径优先、旧路径兜底。 +func (s *FileStore) topicsDirs(scope Scope) []string { + dirs := []string{s.topicsDir(scope)} + if legacy := s.scopeDirLegacy(scope); legacy != "" { + dirs = append(dirs, filepath.Join(legacy, topicsDirName)) + } + return dirs +} + +// topicPath 生成指定分层下 topic 文件的安全路径,防止目录穿越。 +func (s *FileStore) topicPath(scope Scope, filename string) string { + return filepath.Join(s.topicsDir(scope), filepath.Base(filename)) +} + +// topicPaths 返回读取 topic 时的候选路径,顺序为新路径优先、旧路径兜底。 +func (s *FileStore) topicPaths(scope Scope, filename string) []string { + base := filepath.Base(filename) + paths := []string{filepath.Join(s.topicsDir(scope), base)} + if legacy := s.scopeDirLegacy(scope); legacy != "" { + paths = append(paths, filepath.Join(legacy, topicsDirName, base)) + } + return paths +} + +// migrateLegacyProjectScopeLocked 在首次写入前把旧版 project 目录迁移到新目录,避免历史数据不可见。 +func (s *FileStore) migrateLegacyProjectScopeLocked(scope Scope) error { + if scope != ScopeProject { + return nil + } + + legacyDir := s.scopeDirLegacy(scope) + if legacyDir == "" { + return nil + } + + if err := os.MkdirAll(s.scopeDir(scope), 0o755); err != nil { + return fmt.Errorf("memo: create scoped dir for migration: %w", err) + } + + legacyMemo := filepath.Join(legacyDir, memoFileName) + targetMemo := filepath.Join(s.scopeDir(scope), memoFileName) + if err := moveFileIfDstMissing(legacyMemo, targetMemo); err != nil { + return fmt.Errorf("memo: migrate legacy index: %w", err) + } + + legacyTopics := filepath.Join(legacyDir, topicsDirName) + legacyEntries, err := os.ReadDir(legacyTopics) if err != nil { if errors.Is(err, os.ErrNotExist) { - return &Index{}, nil + return nil } - return nil, fmt.Errorf("memo: read index: %w", err) + return fmt.Errorf("memo: list legacy topics: %w", err) } - return ParseIndex(string(data)) + if len(legacyEntries) == 0 { + return nil + } + + newTopics := s.topicsDir(scope) + if err := os.MkdirAll(newTopics, 0o755); err != nil { + return fmt.Errorf("memo: create scoped topics dir for migration: %w", err) + } + for _, entry := range legacyEntries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".md" { + continue + } + oldPath := filepath.Join(legacyTopics, entry.Name()) + newPath := filepath.Join(newTopics, entry.Name()) + if err := moveFileIfDstMissing(oldPath, newPath); err != nil { + return fmt.Errorf("memo: migrate legacy topic %s: %w", entry.Name(), err) + } + } + + return nil } -// topicPath 生成 topic 文件的安全路径,防止目录穿越。 -func (s *FileStore) topicPath(filename string) string { - safe := filepath.Base(filename) - return filepath.Join(s.topicsDir, safe) +// moveFileIfDstMissing 在源文件存在且目标文件不存在时执行迁移重命名。 +func moveFileIfDstMissing(src string, dst string) error { + if _, err := os.Stat(src); err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + if _, err := os.Stat(dst); err == nil { + return nil + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + return os.Rename(src, dst) } // memoDirectory 根据工作区根目录计算记忆分桶目录,复用 session 包的工作区哈希。 -func memoDirectory(baseDir string, workspaceRoot string) string { +// globalMemoDirectory 返回全局 memo 根目录,用于存放 user 层记忆。 +func globalMemoDirectory(baseDir string) string { + return filepath.Join(baseDir, memoDirName) +} + +// projectMemoDirectory 根据 workspace 根目录计算 project 层 memo 根目录。 +func projectMemoDirectory(baseDir string, workspaceRoot string) string { return filepath.Join(baseDir, "projects", agentsession.HashWorkspaceRoot(workspaceRoot), memoDirName) } + +// validateStorageScope 校验当前 scope 是否是允许落盘的 memo 分层。 +func validateStorageScope(scope Scope) error { + switch scope { + case ScopeUser, ScopeProject: + return nil + default: + return fmt.Errorf("memo: unsupported storage scope %q", scope) + } +} diff --git a/internal/memo/store_test.go b/internal/memo/store_test.go index c799f827..1a917dc0 100644 --- a/internal/memo/store_test.go +++ b/internal/memo/store_test.go @@ -2,6 +2,7 @@ package memo import ( "context" + "errors" "os" "path/filepath" "strings" @@ -15,263 +16,332 @@ func TestNewFileStore(t *testing.T) { tmp := t.TempDir() store := NewFileStore(tmp, "/workspace/project") if store == nil { - t.Fatal("NewFileStore returned nil") + t.Fatal("NewFileStore() returned nil") } - if store.memoDir == "" { - t.Error("memoDir is empty") + if store.baseDir != tmp { + t.Fatalf("baseDir = %q, want %q", store.baseDir, tmp) } - if store.topicsDir == "" { - t.Error("topicsDir is empty") + if store.workspaceRoot != "/workspace/project" { + t.Fatalf("workspaceRoot = %q, want %q", store.workspaceRoot, "/workspace/project") } } -func TestFileStoreLoadIndexNotExist(t *testing.T) { - tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") +func TestFileStoreSaveAndLoadIndexByScope(t *testing.T) { + store := NewFileStore(t.TempDir(), "/workspace/project") + index := &Index{ + Entries: []Entry{{ + ID: "user_001", + Type: TypeUser, + Title: "user pref", + Content: "content", + Keywords: []string{"tabs"}, + Source: SourceUserManual, + TopicFile: "user.md", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }}, + UpdatedAt: time.Now(), + } - idx, err := store.LoadIndex(context.Background()) - if err != nil { - t.Fatalf("LoadIndex on nonexistent dir error: %v", err) + if err := store.SaveIndex(context.Background(), ScopeUser, index); err != nil { + t.Fatalf("SaveIndex() error = %v", err) } - if idx == nil { - t.Fatal("LoadIndex returned nil index") + loaded, err := store.LoadIndex(context.Background(), ScopeUser) + if err != nil { + t.Fatalf("LoadIndex() error = %v", err) } - if len(idx.Entries) != 0 { - t.Errorf("Entries = %d, want 0", len(idx.Entries)) + if len(loaded.Entries) != 1 || loaded.Entries[0].Title != "user pref" { + t.Fatalf("loaded entries = %#v", loaded.Entries) } } -func TestFileStoreSaveAndLoadIndex(t *testing.T) { - tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") - - original := &Index{ - Entries: []Entry{ - { - ID: "user_001", - Type: TypeUser, - Title: "偏好 tab 缩进", - Content: "详细内容", - Keywords: []string{"tabs"}, - Source: SourceUserManual, - TopicFile: "user_profile.md", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }, - }, - UpdatedAt: time.Now(), - } +func TestFileStoreSaveAndLoadTopicByScope(t *testing.T) { + store := NewFileStore(t.TempDir(), "/workspace/project") + content := "---\ntype: user\n---\n\nbody\n" - ctx := context.Background() - if err := store.SaveIndex(ctx, original); err != nil { - t.Fatalf("SaveIndex error: %v", err) + if err := store.SaveTopic(context.Background(), ScopeUser, "user.md", content); err != nil { + t.Fatalf("SaveTopic() error = %v", err) } - - loaded, err := store.LoadIndex(ctx) + loaded, err := store.LoadTopic(context.Background(), ScopeUser, "user.md") if err != nil { - t.Fatalf("LoadIndex error: %v", err) + t.Fatalf("LoadTopic() error = %v", err) } - if len(loaded.Entries) != 1 { - t.Fatalf("loaded entries = %d, want 1", len(loaded.Entries)) - } - if loaded.Entries[0].Title != "偏好 tab 缩进" { - t.Errorf("Title = %q, want %q", loaded.Entries[0].Title, "偏好 tab 缩进") - } - if loaded.Entries[0].TopicFile != "user_profile.md" { - t.Errorf("TopicFile = %q, want %q", loaded.Entries[0].TopicFile, "user_profile.md") + if loaded != content { + t.Fatalf("LoadTopic() = %q, want %q", loaded, content) } } -func TestFileStoreSaveIndexNil(t *testing.T) { - tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") - err := store.SaveIndex(context.Background(), nil) - if err == nil { - t.Error("SaveIndex(nil) should return error") +func TestFileStoreDeleteTopic(t *testing.T) { + store := NewFileStore(t.TempDir(), "/workspace/project") + + if err := store.SaveTopic(context.Background(), ScopeProject, "p.md", "content"); err != nil { + t.Fatalf("SaveTopic() error = %v", err) + } + if err := store.DeleteTopic(context.Background(), ScopeProject, "p.md"); err != nil { + t.Fatalf("DeleteTopic() error = %v", err) + } + if _, err := store.LoadTopic(context.Background(), ScopeProject, "p.md"); err == nil { + t.Fatal("expected deleted topic to be missing") } } -func TestFileStoreSaveAndLoadTopic(t *testing.T) { - tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") - ctx := context.Background() +func TestFileStoreListTopics(t *testing.T) { + store := NewFileStore(t.TempDir(), "/workspace/project") - content := "---\ntype: user\n---\n\n这是详细内容\n" - if err := store.SaveTopic(ctx, "user_profile.md", content); err != nil { - t.Fatalf("SaveTopic error: %v", err) + if err := store.SaveTopic(context.Background(), ScopeProject, "a.md", "a"); err != nil { + t.Fatalf("SaveTopic(a) error = %v", err) + } + if err := store.SaveTopic(context.Background(), ScopeProject, "b.md", "b"); err != nil { + t.Fatalf("SaveTopic(b) error = %v", err) } - loaded, err := store.LoadTopic(ctx, "user_profile.md") + topics, err := store.ListTopics(context.Background(), ScopeProject) if err != nil { - t.Fatalf("LoadTopic error: %v", err) + t.Fatalf("ListTopics() error = %v", err) } - if loaded != content { - t.Errorf("LoadTopic = %q, want %q", loaded, content) + if len(topics) != 2 { + t.Fatalf("len(topics) = %d, want 2", len(topics)) } } -func TestFileStoreLoadTopicNotExist(t *testing.T) { +func TestFileStoreUserScopeIsGlobal(t *testing.T) { tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") - ctx := context.Background() + storeA := NewFileStore(tmp, "/workspace/a") + storeB := NewFileStore(tmp, "/workspace/b") - _, err := store.LoadTopic(ctx, "nonexistent.md") - if err == nil { - t.Error("LoadTopic on nonexistent file should return error") + if err := storeA.SaveIndex(context.Background(), ScopeUser, &Index{Entries: []Entry{{Type: TypeUser, Title: "A"}}}); err != nil { + t.Fatalf("SaveIndex() error = %v", err) + } + index, err := storeB.LoadIndex(context.Background(), ScopeUser) + if err != nil { + t.Fatalf("LoadIndex() error = %v", err) + } + if len(index.Entries) != 1 || index.Entries[0].Title != "A" { + t.Fatalf("global user scope failed, got %#v", index.Entries) } } -func TestFileStoreDeleteTopic(t *testing.T) { +func TestFileStoreProjectScopeIsWorkspaceIsolated(t *testing.T) { tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") - ctx := context.Background() + storeA := NewFileStore(tmp, "/workspace/a") + storeB := NewFileStore(tmp, "/workspace/b") - if err := store.SaveTopic(ctx, "to_delete.md", "content"); err != nil { - t.Fatalf("SaveTopic error: %v", err) + if err := storeA.SaveIndex(context.Background(), ScopeProject, &Index{Entries: []Entry{{Type: TypeProject, Title: "A"}}}); err != nil { + t.Fatalf("SaveIndex() error = %v", err) } - if err := store.DeleteTopic(ctx, "to_delete.md"); err != nil { - t.Fatalf("DeleteTopic error: %v", err) + index, err := storeB.LoadIndex(context.Background(), ScopeProject) + if err != nil { + t.Fatalf("LoadIndex() error = %v", err) } - if _, err := store.LoadTopic(ctx, "to_delete.md"); err == nil { - t.Error("LoadTopic after delete should return error") + if len(index.Entries) != 0 { + t.Fatalf("workspace isolation failed, got %#v", index.Entries) } } -func TestFileStoreDeleteTopicNotExist(t *testing.T) { - tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") - ctx := context.Background() - - err := store.DeleteTopic(ctx, "nonexistent.md") - if err != nil { - t.Errorf("DeleteTopic on nonexistent file should not error: %v", err) +func TestFileStoreRejectsUnsupportedScope(t *testing.T) { + store := NewFileStore(t.TempDir(), "/workspace/project") + if _, err := store.LoadIndex(context.Background(), ScopeAll); err == nil { + t.Fatal("expected ScopeAll load to fail") } } -func TestFileStoreListTopics(t *testing.T) { - tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") - ctx := context.Background() +func TestFileStoreCancelContext(t *testing.T) { + store := NewFileStore(t.TempDir(), "/workspace/project") + ctx, cancel := context.WithCancel(context.Background()) + cancel() - // 空目录应返回空列表 - topics, err := store.ListTopics(ctx) - if err != nil { - t.Fatalf("ListTopics on empty dir error: %v", err) + if _, err := store.LoadIndex(ctx, ScopeUser); err == nil { + t.Fatal("expected LoadIndex() to fail on canceled context") + } + if err := store.SaveIndex(ctx, ScopeUser, &Index{}); err == nil { + t.Fatal("expected SaveIndex() to fail on canceled context") + } + if _, err := store.LoadTopic(ctx, ScopeUser, "x.md"); err == nil { + t.Fatal("expected LoadTopic() to fail on canceled context") + } + if err := store.SaveTopic(ctx, ScopeUser, "x.md", "body"); err == nil { + t.Fatal("expected SaveTopic() to fail on canceled context") } - if len(topics) != 0 { - t.Errorf("ListTopics empty = %d, want 0", len(topics)) + if err := store.DeleteTopic(ctx, ScopeUser, "x.md"); err == nil { + t.Fatal("expected DeleteTopic() to fail on canceled context") } + if _, err := store.ListTopics(ctx, ScopeUser); err == nil { + t.Fatal("expected ListTopics() to fail on canceled context") + } +} - // 写入几个 topic - for _, name := range []string{"a.md", "b.md", "c.txt"} { - if strings.HasSuffix(name, ".md") { - _ = store.SaveTopic(ctx, name, "content") - } +func TestFileStoreAtomicWriteLeavesNoTempFiles(t *testing.T) { + store := NewFileStore(t.TempDir(), "/workspace/project") + if err := store.SaveIndex(context.Background(), ScopeUser, &Index{Entries: []Entry{{Type: TypeUser, Title: "test"}}}); err != nil { + t.Fatalf("SaveIndex() error = %v", err) } - topics, err = store.ListTopics(ctx) + entries, err := os.ReadDir(store.scopeDir(ScopeUser)) if err != nil { - t.Fatalf("ListTopics error: %v", err) + t.Fatalf("ReadDir() error = %v", err) } - if len(topics) != 2 { - t.Errorf("ListTopics = %d, want 2 (only .md files)", len(topics)) + for _, entry := range entries { + if strings.HasSuffix(entry.Name(), ".tmp") { + t.Fatalf("unexpected temp file %s", entry.Name()) + } } } -func TestFileStoreCancelContext(t *testing.T) { - tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - if _, err := store.LoadIndex(ctx); err == nil { - t.Error("LoadIndex with cancelled context should return error") +func TestGlobalMemoDirectory(t *testing.T) { + got := globalMemoDirectory("/base") + want := filepath.Join("/base", "memo") + if got != want { + t.Fatalf("globalMemoDirectory() = %q, want %q", got, want) } - if err := store.SaveIndex(ctx, &Index{}); err == nil { - t.Error("SaveIndex with cancelled context should return error") +} + +func TestProjectMemoDirectory(t *testing.T) { + got := projectMemoDirectory("/base", "/workspace") + want := filepath.Join("/base", "projects", agentsession.HashWorkspaceRoot("/workspace"), "memo") + if got != want { + t.Fatalf("projectMemoDirectory() = %q, want %q", got, want) } - if _, err := store.LoadTopic(ctx, "f.md"); err == nil { - t.Error("LoadTopic with cancelled context should return error") +} + +func TestFileStoreWritesScopesToExpectedDirectories(t *testing.T) { + baseDir := t.TempDir() + store := NewFileStore(baseDir, "/workspace/project") + + if err := store.SaveIndex(context.Background(), ScopeUser, &Index{Entries: []Entry{{Type: TypeUser, Title: "user"}}}); err != nil { + t.Fatalf("SaveIndex(user) error = %v", err) } - if err := store.SaveTopic(ctx, "f.md", "c"); err == nil { - t.Error("SaveTopic with cancelled context should return error") + if err := store.SaveIndex(context.Background(), ScopeProject, &Index{Entries: []Entry{{Type: TypeProject, Title: "project"}}}); err != nil { + t.Fatalf("SaveIndex(project) error = %v", err) } - if err := store.DeleteTopic(ctx, "f.md"); err == nil { - t.Error("DeleteTopic with cancelled context should return error") + + if _, err := os.Stat(filepath.Join(baseDir, "memo", "user", memoFileName)); err != nil { + t.Fatalf("expected global user memo to exist: %v", err) } - if _, err := store.ListTopics(ctx); err == nil { - t.Error("ListTopics with cancelled context should return error") + if _, err := os.Stat(filepath.Join(baseDir, "projects", agentsession.HashWorkspaceRoot("/workspace/project"), "memo", "project", memoFileName)); err != nil { + t.Fatalf("expected project memo to exist: %v", err) } } -func TestFileStoreWorkspaceIsolation(t *testing.T) { - tmp := t.TempDir() - store1 := NewFileStore(tmp, "/workspace/a") - store2 := NewFileStore(tmp, "/workspace/b") - ctx := context.Background() - - idx1 := &Index{Entries: []Entry{{Type: TypeUser, Title: "Project A"}}} - if err := store1.SaveIndex(ctx, idx1); err != nil { - t.Fatalf("SaveIndex store1 error: %v", err) +func TestFileStoreLoadIndexFallsBackToLegacyProjectPath(t *testing.T) { + store, legacyDir := newLegacyProjectStore(t) + if err := os.MkdirAll(legacyDir, 0o755); err != nil { + t.Fatalf("MkdirAll(legacy) error = %v", err) + } + index := &Index{Entries: []Entry{{Type: TypeProject, Title: "legacy"}}} + if err := os.WriteFile(filepath.Join(legacyDir, memoFileName), []byte(RenderIndex(index)), 0o644); err != nil { + t.Fatalf("WriteFile(legacy index) error = %v", err) } - idx2, err := store2.LoadIndex(ctx) + loaded, err := store.LoadIndex(context.Background(), ScopeProject) if err != nil { - t.Fatalf("LoadIndex store2 error: %v", err) + t.Fatalf("LoadIndex() error = %v", err) } - if len(idx2.Entries) != 0 { - t.Errorf("store2 should have no entries (workspace isolation), got %d", len(idx2.Entries)) + if len(loaded.Entries) != 1 || loaded.Entries[0].Title != "legacy" { + t.Fatalf("loaded entries = %#v", loaded.Entries) } } -func TestFileStoreAtomicWrite(t *testing.T) { - tmp := t.TempDir() - store := NewFileStore(tmp, "/workspace/project") - ctx := context.Background() +func TestFileStoreLoadTopicAndListTopicsFallbackToLegacyProjectPath(t *testing.T) { + store, legacyDir := newLegacyProjectStore(t) + legacyTopicsDir := filepath.Join(legacyDir, topicsDirName) + if err := os.MkdirAll(legacyTopicsDir, 0o755); err != nil { + t.Fatalf("MkdirAll(legacy topics) error = %v", err) + } + if err := os.WriteFile(filepath.Join(legacyTopicsDir, "legacy.md"), []byte("legacy"), 0o644); err != nil { + t.Fatalf("WriteFile(legacy topic) error = %v", err) + } - // 写入索引后不应存在临时文件 - _ = store.SaveIndex(ctx, &Index{Entries: []Entry{{Type: TypeUser, Title: "test"}}}) + content, err := store.LoadTopic(context.Background(), ScopeProject, "legacy.md") + if err != nil { + t.Fatalf("LoadTopic() error = %v", err) + } + if content != "legacy" { + t.Fatalf("LoadTopic() = %q, want %q", content, "legacy") + } - memoDir := store.memoDir - entries, _ := os.ReadDir(memoDir) - for _, e := range entries { - if strings.HasSuffix(e.Name(), ".tmp") { - t.Errorf("temp file should not exist after atomic write: %s", e.Name()) - } + topics, err := store.ListTopics(context.Background(), ScopeProject) + if err != nil { + t.Fatalf("ListTopics() error = %v", err) + } + if len(topics) != 1 || topics[0] != "legacy.md" { + t.Fatalf("ListTopics() = %#v, want [legacy.md]", topics) } } -func TestMemoDirectory(t *testing.T) { - dir := memoDirectory("/base", "/workspace") - expected := filepath.Join("/base", "projects", agentsession.HashWorkspaceRoot("/workspace"), "memo") - if dir != expected { - t.Errorf("memoDirectory = %q, want %q", dir, expected) +func TestFileStoreListTopicsMergesScopedAndLegacyProjectTopics(t *testing.T) { + store, legacyDir := newLegacyProjectStore(t) + scopedTopicsDir := store.topicsDir(ScopeProject) + legacyTopicsDir := filepath.Join(legacyDir, topicsDirName) + if err := os.MkdirAll(scopedTopicsDir, 0o755); err != nil { + t.Fatalf("MkdirAll(scoped topics) error = %v", err) + } + if err := os.MkdirAll(legacyTopicsDir, 0o755); err != nil { + t.Fatalf("MkdirAll(legacy topics) error = %v", err) + } + if err := os.WriteFile(filepath.Join(scopedTopicsDir, "scoped.md"), []byte("scoped"), 0o644); err != nil { + t.Fatalf("WriteFile(scoped topic) error = %v", err) + } + if err := os.WriteFile(filepath.Join(legacyTopicsDir, "legacy.md"), []byte("legacy"), 0o644); err != nil { + t.Fatalf("WriteFile(legacy topic) error = %v", err) + } + if err := os.WriteFile(filepath.Join(legacyTopicsDir, "scoped.md"), []byte("legacy dup"), 0o644); err != nil { + t.Fatalf("WriteFile(legacy duplicate topic) error = %v", err) } -} -func TestHashWorkspaceRootStable(t *testing.T) { - h1 := agentsession.HashWorkspaceRoot("/workspace/project") - h2 := agentsession.HashWorkspaceRoot("/workspace/project") - if h1 != h2 { - t.Errorf("hash not stable: %q != %q", h1, h2) + topics, err := store.ListTopics(context.Background(), ScopeProject) + if err != nil { + t.Fatalf("ListTopics() error = %v", err) + } + want := []string{"legacy.md", "scoped.md"} + if len(topics) != len(want) { + t.Fatalf("len(topics) = %d, want %d, topics = %#v", len(topics), len(want), topics) + } + for i := range want { + if topics[i] != want[i] { + t.Fatalf("topics[%d] = %q, want %q (topics=%#v)", i, topics[i], want[i], topics) + } } } -func TestHashWorkspaceRootDifferent(t *testing.T) { - h1 := agentsession.HashWorkspaceRoot("/workspace/a") - h2 := agentsession.HashWorkspaceRoot("/workspace/b") - if h1 == h2 { - t.Errorf("different paths should produce different hashes") +func TestFileStoreSaveIndexMigratesLegacyProjectData(t *testing.T) { + store, legacyDir := newLegacyProjectStore(t) + legacyTopicsDir := filepath.Join(legacyDir, topicsDirName) + if err := os.MkdirAll(legacyTopicsDir, 0o755); err != nil { + t.Fatalf("MkdirAll(legacy topics) error = %v", err) + } + if err := os.WriteFile(filepath.Join(legacyDir, memoFileName), []byte(RenderIndex(&Index{ + Entries: []Entry{{Type: TypeProject, Title: "legacy index"}}, + })), 0o644); err != nil { + t.Fatalf("WriteFile(legacy index) error = %v", err) + } + if err := os.WriteFile(filepath.Join(legacyTopicsDir, "legacy.md"), []byte("legacy topic"), 0o644); err != nil { + t.Fatalf("WriteFile(legacy topic) error = %v", err) } -} -func TestHashWorkspaceRootEmpty(t *testing.T) { - h := agentsession.HashWorkspaceRoot("") - // 空路径回退到 "unknown" 的哈希,应产生稳定的非空结果 - if h == "" { - t.Error("hash of empty workspace root should not be empty") + if err := store.SaveIndex(context.Background(), ScopeProject, &Index{ + Entries: []Entry{{Type: TypeProject, Title: "new index"}}, + }); err != nil { + t.Fatalf("SaveIndex() error = %v", err) } - if len(h) != 16 { - t.Errorf("hash length = %d, want 16 (8 bytes hex)", len(h)) + + newScopeDir := store.scopeDir(ScopeProject) + if _, err := os.Stat(filepath.Join(newScopeDir, memoFileName)); err != nil { + t.Fatalf("expected scoped index after migration: %v", err) + } + if _, err := os.Stat(filepath.Join(newScopeDir, topicsDirName, "legacy.md")); err != nil { + t.Fatalf("expected scoped topic after migration: %v", err) + } + if _, err := os.Stat(filepath.Join(legacyDir, memoFileName)); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected legacy index to be migrated, stat err = %v", err) } + if _, err := os.Stat(filepath.Join(legacyTopicsDir, "legacy.md")); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected legacy topic to be migrated, stat err = %v", err) + } +} + +func newLegacyProjectStore(t *testing.T) (*FileStore, string) { + t.Helper() + baseDir := t.TempDir() + workspaceRoot := "/workspace/project" + return NewFileStore(baseDir, workspaceRoot), projectMemoDirectory(baseDir, workspaceRoot) } diff --git a/internal/memo/test_helpers_test.go b/internal/memo/test_helpers_test.go new file mode 100644 index 00000000..aa92b023 --- /dev/null +++ b/internal/memo/test_helpers_test.go @@ -0,0 +1,119 @@ +package memo + +import ( + "context" + "errors" + "sync" +) + +type memoryTestStore struct { + mu sync.Mutex + + indexes map[Scope]*Index + topics map[Scope]map[string]string + + err error + saveIndexErr error + saveTopicErr error + deleteTopicErr error + + loadIndexCalls int + saveIndexCalls int + saveTopicCalls int + deleteTopicCalls int + + deletedTopics []string +} + +func newMemoryTestStore() *memoryTestStore { + return &memoryTestStore{ + indexes: make(map[Scope]*Index), + topics: map[Scope]map[string]string{ + ScopeUser: {}, + ScopeProject: {}, + }, + } +} + +func (s *memoryTestStore) LoadIndex(_ context.Context, scope Scope) (*Index, error) { + s.mu.Lock() + defer s.mu.Unlock() + + s.loadIndexCalls++ + if s.err != nil { + return nil, s.err + } + index, ok := s.indexes[scope] + if !ok || index == nil { + return &Index{Entries: []Entry{}}, nil + } + return cloneIndex(index), nil +} + +func (s *memoryTestStore) SaveIndex(_ context.Context, scope Scope, index *Index) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.saveIndexCalls++ + if s.saveIndexErr != nil { + return s.saveIndexErr + } + s.indexes[scope] = cloneIndex(index) + return nil +} + +func (s *memoryTestStore) LoadTopic(_ context.Context, scope Scope, filename string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return "", s.err + } + content, ok := s.topics[scope][filename] + if !ok { + return "", errors.New("not found") + } + return content, nil +} + +func (s *memoryTestStore) SaveTopic(_ context.Context, scope Scope, filename string, content string) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.saveTopicCalls++ + if s.saveTopicErr != nil { + return s.saveTopicErr + } + if s.topics[scope] == nil { + s.topics[scope] = map[string]string{} + } + s.topics[scope][filename] = content + return nil +} + +func (s *memoryTestStore) DeleteTopic(_ context.Context, scope Scope, filename string) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.deleteTopicCalls++ + s.deletedTopics = append(s.deletedTopics, scopedTopicKey(scope, filename)) + if s.deleteTopicErr != nil { + return s.deleteTopicErr + } + delete(s.topics[scope], filename) + return nil +} + +func (s *memoryTestStore) ListTopics(_ context.Context, scope Scope) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return nil, s.err + } + result := make([]string, 0, len(s.topics[scope])) + for name := range s.topics[scope] { + result = append(result, name) + } + return result, nil +} diff --git a/internal/memo/types.go b/internal/memo/types.go index 62ab2b9d..563cde52 100644 --- a/internal/memo/types.go +++ b/internal/memo/types.go @@ -2,6 +2,7 @@ package memo import ( "context" + "strings" "time" providertypes "neo-code/internal/provider/types" @@ -43,20 +44,39 @@ type Entry struct { UpdatedAt time.Time } +// Scope 表示 memo 的逻辑分层范围。 +type Scope string + +const ( + // ScopeAll 表示同时覆盖 user 与 project 两层。 + ScopeAll Scope = "all" + // ScopeUser 表示仅 user 记忆层。 + ScopeUser Scope = "user" + // ScopeProject 表示 project 记忆层,承载 feedback/project/reference。 + ScopeProject Scope = "project" +) + // Index 表示 MEMO.md 索引文件的内存模型。 type Index struct { Entries []Entry UpdatedAt time.Time } +// RecalledEntry 表示一次 recall 命中的结构化结果。 +type RecalledEntry struct { + Scope Scope + Entry Entry + Content string +} + // Store 定义记忆持久化的最小抽象。 type Store interface { - LoadIndex(ctx context.Context) (*Index, error) - SaveIndex(ctx context.Context, index *Index) error - LoadTopic(ctx context.Context, filename string) (string, error) - SaveTopic(ctx context.Context, filename string, content string) error - DeleteTopic(ctx context.Context, filename string) error - ListTopics(ctx context.Context) ([]string, error) + LoadIndex(ctx context.Context, scope Scope) (*Index, error) + SaveIndex(ctx context.Context, scope Scope, index *Index) error + LoadTopic(ctx context.Context, scope Scope, filename string) (string, error) + SaveTopic(ctx context.Context, scope Scope, filename string, content string) error + DeleteTopic(ctx context.Context, scope Scope, filename string) error + ListTopics(ctx context.Context, scope Scope) ([]string, error) } // Extractor 定义从对话消息中提取记忆的最小能力。 @@ -90,3 +110,32 @@ func ParseType(s string) (Type, bool) { t := Type(s) return t, IsValidType(t) } + +// NormalizeScope 将外部输入收敛为受支持的 memo 作用范围。 +func NormalizeScope(scope Scope) Scope { + switch Scope(strings.ToLower(string(scope))) { + case ScopeUser: + return ScopeUser + case ScopeProject: + return ScopeProject + default: + return ScopeAll + } +} + +// ScopeForType 返回给定类型固定落到的记忆分层。 +func ScopeForType(t Type) Scope { + switch t { + case TypeUser: + return ScopeUser + case TypeFeedback, TypeProject, TypeReference: + return ScopeProject + default: + return ScopeProject + } +} + +// supportedStorageScopes 返回当前实现实际落盘的所有 memo 分层。 +func supportedStorageScopes() []Scope { + return []Scope{ScopeUser, ScopeProject} +} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 07c6cbc1..a8383f21 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -32,6 +32,7 @@ type Runtime interface { PrepareUserInput(ctx context.Context, input PrepareInput) (UserInput, error) Run(ctx context.Context, input UserInput) error Compact(ctx context.Context, input CompactInput) (CompactResult, error) + ExecuteSystemTool(ctx context.Context, input SystemToolInput) (tools.ToolResult, error) ResolvePermission(ctx context.Context, input PermissionResolutionInput) error CancelActiveRun() bool Events() <-chan RuntimeEvent @@ -68,6 +69,15 @@ type PrepareInput struct { Images []UserImageInput } +// SystemToolInput 描述一次由系统入口触发的确定性工具执行请求。 +type SystemToolInput struct { + SessionID string + RunID string + Workdir string + ToolName string + Arguments []byte +} + // PreparedInputResult 描述输入归一化完成后的结果快照(标准 UserInput + 本轮保存附件元数据)。 type PreparedInputResult struct { UserInput UserInput diff --git a/internal/runtime/system_tool.go b/internal/runtime/system_tool.go new file mode 100644 index 00000000..fa70de49 --- /dev/null +++ b/internal/runtime/system_tool.go @@ -0,0 +1,130 @@ +package runtime + +import ( + "context" + "fmt" + "strings" + "time" + + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" + "neo-code/internal/tools" +) + +// ExecuteSystemTool 通过 runtime 统一执行一次确定性系统工具调用,不进入 provider/ReAct 主循环。 +func (s *Service) ExecuteSystemTool(ctx context.Context, input SystemToolInput) (tools.ToolResult, error) { + if s == nil { + return tools.ToolResult{}, fmt.Errorf("runtime: service is nil") + } + if err := ctx.Err(); err != nil { + return tools.ToolResult{}, err + } + + toolName := strings.TrimSpace(input.ToolName) + if toolName == "" { + return tools.ToolResult{}, fmt.Errorf("runtime: tool name is empty") + } + + sessionID := strings.TrimSpace(input.SessionID) + runID := strings.TrimSpace(input.RunID) + if runID == "" { + runID = newSystemToolRunID(toolName) + } + + cfg := s.configManager.Get() + workdir := strings.TrimSpace(input.Workdir) + if workdir == "" { + workdir = cfg.Workdir + } + + var ( + state *runState + loaded agentsession.Session + ) + if sessionID != "" { + sessionMu, releaseLockRef := s.acquireSessionLock(sessionID) + sessionMu.Lock() + + session, err := s.sessionStore.LoadSession(ctx, sessionID) + if err != nil { + sessionMu.Unlock() + releaseLockRef() + return tools.ToolResult{}, err + } + loaded = session + if workdir == "" { + workdir = strings.TrimSpace(session.Workdir) + } + runStateValue := newRunState(runID, session) + state = &runStateValue + sessionMu.Unlock() + releaseLockRef() + } + + call := providertypes.ToolCall{ + ID: newSystemToolCallID(toolName), + Name: toolName, + Arguments: string(input.Arguments), + } + + if state != nil { + _ = s.emitRunScoped(ctx, EventToolStart, state, call) + } else { + _ = s.emit(ctx, EventToolStart, runID, sessionID, call) + } + + result, execErr := s.executeToolCallWithPermission(ctx, permissionExecutionInput{ + RunID: runID, + SessionID: sessionID, + State: state, + Call: call, + Workdir: workdir, + ToolTimeout: time.Duration(cfg.ToolTimeoutSec) * time.Second, + }) + + if strings.TrimSpace(result.ToolCallID) == "" { + result.ToolCallID = call.ID + } + if strings.TrimSpace(result.Name) == "" { + result.Name = toolName + } + if execErr != nil { + result.IsError = true + } + + if state != nil { + if loaded.ID != "" { + state.session = loaded + } + _ = s.emitRunScoped(ctx, EventToolResult, state, result) + s.emitTodoToolEvent(ctx, state, call, result, execErr) + } else { + _ = s.emit(ctx, EventToolResult, runID, sessionID, result) + } + + return result, execErr +} + +// normalizeToolName 将工具名标准化,空值回退为 "tool"。 +func normalizeToolName(name string) string { + normalized := strings.ToLower(strings.TrimSpace(name)) + if normalized == "" { + normalized = "tool" + } + return normalized +} + +// newSystemToolRunID 为系统工具调用生成稳定前缀的运行标识,便于事件与日志定位。 +func newSystemToolRunID(toolName string) string { + return formatSystemToolID("system-tool", toolName) +} + +// newSystemToolCallID 为系统工具调用生成单次执行唯一的 tool call id。 +func newSystemToolCallID(toolName string) string { + return formatSystemToolID("call", toolName) +} + +// formatSystemToolID 统一构造系统工具相关 ID,避免不同类型 ID 生成逻辑分散重复。 +func formatSystemToolID(prefix, toolName string) string { + return fmt.Sprintf("%s-%s-%d", prefix, normalizeToolName(toolName), time.Now().UnixNano()) +} diff --git a/internal/runtime/system_tool_test.go b/internal/runtime/system_tool_test.go new file mode 100644 index 00000000..b6af7330 --- /dev/null +++ b/internal/runtime/system_tool_test.go @@ -0,0 +1,475 @@ +package runtime + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + agentsession "neo-code/internal/session" + "neo-code/internal/tools" +) + +// TestExecuteSystemToolNilService 验证在 nil *Service 上调用返回错误。 +func TestExecuteSystemToolNilService(t *testing.T) { + t.Parallel() + + var s *Service + _, err := s.ExecuteSystemTool(context.Background(), SystemToolInput{ToolName: "bash"}) + if err == nil { + t.Fatal("expected error on nil service, got nil") + } + if !strings.Contains(err.Error(), "service is nil") { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestExecuteSystemToolEmptyToolName 验证空工具名返回错误。 +func TestExecuteSystemToolEmptyToolName(t *testing.T) { + t.Parallel() + + service := NewWithFactory( + newRuntimeConfigManager(t), + &stubToolManager{}, + newMemoryStore(), + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + _, err := service.ExecuteSystemTool(context.Background(), SystemToolInput{ToolName: ""}) + if err == nil { + t.Fatal("expected error for empty tool name, got nil") + } + if !strings.Contains(err.Error(), "tool name is empty") { + t.Fatalf("unexpected error: %v", err) + } + + // 空白字符串也应返回错误 + _, err = service.ExecuteSystemTool(context.Background(), SystemToolInput{ToolName: " "}) + if err == nil { + t.Fatal("expected error for whitespace-only tool name, got nil") + } +} + +// TestExecuteSystemToolCancelledContext 验证已取消的上下文立即返回错误。 +func TestExecuteSystemToolCancelledContext(t *testing.T) { + t.Parallel() + + service := NewWithFactory( + newRuntimeConfigManager(t), + &stubToolManager{}, + newMemoryStore(), + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := service.ExecuteSystemTool(ctx, SystemToolInput{ToolName: "bash"}) + if err == nil { + t.Fatal("expected error for cancelled context, got nil") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got: %v", err) + } +} + +// TestExecuteSystemToolSuccess 验证基本成功执行路径。 +func TestExecuteSystemToolSuccess(t *testing.T) { + t.Parallel() + + tm := &stubToolManager{ + result: tools.ToolResult{Content: "ok"}, + } + service := NewWithFactory( + newRuntimeConfigManager(t), + tm, + newMemoryStore(), + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + result, err := service.ExecuteSystemTool(context.Background(), SystemToolInput{ + ToolName: "bash", + Arguments: []byte(`{"command":"echo hello"}`), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.IsError { + t.Fatal("result should not be an error") + } + if result.Name != "bash" { + t.Fatalf("expected tool name 'bash', got %q", result.Name) + } + if result.ToolCallID == "" { + t.Fatal("expected non-empty tool call ID") + } + + // 验证事件发射 + events := collectRuntimeEvents(service.Events()) + assertEventContains(t, events, EventToolStart) + assertEventContains(t, events, EventToolResult) +} + +// TestExecuteSystemToolWithSession 验证提供 sessionID 时能正确加载会话并执行。 +func TestExecuteSystemToolWithSession(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session, err := store.CreateSession(context.Background(), agentsession.CreateSessionInput{ + Title: "test-session", + }) + if err != nil { + t.Fatalf("create session: %v", err) + } + + tm := &stubToolManager{ + result: tools.ToolResult{Content: "done"}, + } + service := NewWithFactory( + newRuntimeConfigManager(t), + tm, + store, + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + result, err := service.ExecuteSystemTool(context.Background(), SystemToolInput{ + SessionID: session.ID, + ToolName: "bash", + Arguments: []byte(`{"command":"ls"}`), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.IsError { + t.Fatal("result should not be an error") + } + + events := collectRuntimeEvents(service.Events()) + assertEventContains(t, events, EventToolStart) + assertEventContains(t, events, EventToolResult) +} + +// TestExecuteSystemToolReleasesSessionLockBeforeToolExecution 验证工具执行期间不会继续持有会话锁。 +func TestExecuteSystemToolReleasesSessionLockBeforeToolExecution(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session, err := store.CreateSession(context.Background(), agentsession.CreateSessionInput{ + Title: "test-session", + }) + if err != nil { + t.Fatalf("create session: %v", err) + } + + started := make(chan struct{}) + releaseTool := make(chan struct{}) + tm := &stubToolManager{ + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + close(started) + <-releaseTool + return tools.ToolResult{Content: "done"}, nil + }, + } + service := NewWithFactory( + newRuntimeConfigManager(t), + tm, + store, + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + runErr := make(chan error, 1) + go func() { + _, runErrValue := service.ExecuteSystemTool(context.Background(), SystemToolInput{ + SessionID: session.ID, + ToolName: "bash", + Arguments: []byte(`{"command":"ls"}`), + }) + runErr <- runErrValue + }() + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting tool execution to start") + } + + sessionMu, releaseLockRef := service.acquireSessionLock(session.ID) + lockAcquired := make(chan struct{}) + go func() { + sessionMu.Lock() + close(lockAcquired) + sessionMu.Unlock() + releaseLockRef() + }() + + select { + case <-lockAcquired: + case <-time.After(2 * time.Second): + close(releaseTool) + t.Fatal("session lock is still held during tool execution") + } + + close(releaseTool) + select { + case err := <-runErr: + if err != nil { + t.Fatalf("execute system tool: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting ExecuteSystemTool to return") + } +} + +// TestExecuteSystemToolWithSessionLoadError 验证会话加载失败时返回错误。 +func TestExecuteSystemToolWithSessionLoadError(t *testing.T) { + t.Parallel() + + tm := &stubToolManager{ + result: tools.ToolResult{Content: "should not run"}, + } + service := NewWithFactory( + newRuntimeConfigManager(t), + tm, + newMemoryStore(), // 空 store,无会话数据 + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + _, err := service.ExecuteSystemTool(context.Background(), SystemToolInput{ + SessionID: "nonexistent-session", + ToolName: "bash", + }) + if err == nil { + t.Fatal("expected error for nonexistent session, got nil") + } + if !strings.Contains(err.Error(), "not found") { + t.Fatalf("expected 'not found' error, got: %v", err) + } +} + +// TestExecuteSystemToolCustomRunID 验证自定义 RunID 被正确使用。 +func TestExecuteSystemToolCustomRunID(t *testing.T) { + t.Parallel() + + tm := &stubToolManager{ + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + return tools.ToolResult{Content: "ok"}, nil + }, + } + service := NewWithFactory( + newRuntimeConfigManager(t), + tm, + newMemoryStore(), + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + result, err := service.ExecuteSystemTool(context.Background(), SystemToolInput{ + RunID: "my-custom-run-id", + ToolName: "bash", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.IsError { + t.Fatal("result should not be an error") + } + + // 验证事件中包含自定义 RunID + events := collectRuntimeEvents(service.Events()) + found := false + for _, e := range events { + if e.RunID == "my-custom-run-id" { + found = true + break + } + } + if !found { + t.Fatalf("expected event with RunID 'my-custom-run-id' in %d events", len(events)) + } +} + +// TestExecuteSystemToolDefaultWorkdir 验证 workdir 为空时使用配置默认值。 +func TestExecuteSystemToolDefaultWorkdir(t *testing.T) { + t.Parallel() + + cfgManager := newRuntimeConfigManager(t) + cfg := cfgManager.Get() + + var capturedInput tools.ToolCallInput + tm := &stubToolManager{ + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + capturedInput = input + return tools.ToolResult{Content: "ok"}, nil + }, + } + service := NewWithFactory( + cfgManager, + tm, + newMemoryStore(), + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + _, err := service.ExecuteSystemTool(context.Background(), SystemToolInput{ + ToolName: "bash", + Workdir: "", // 空值,应使用配置默认值 + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if capturedInput.Workdir != cfg.Workdir { + t.Fatalf("expected workdir %q, got %q", cfg.Workdir, capturedInput.Workdir) + } +} + +// TestExecuteSystemToolToolExecutionError 验证工具执行错误时 result.IsError 被设置为 true。 +func TestExecuteSystemToolToolExecutionError(t *testing.T) { + t.Parallel() + + execErr := errors.New("tool execution failed") + tm := &stubToolManager{ + result: tools.ToolResult{Content: "partial output"}, + err: execErr, + } + service := NewWithFactory( + newRuntimeConfigManager(t), + tm, + newMemoryStore(), + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + result, err := service.ExecuteSystemTool(context.Background(), SystemToolInput{ + ToolName: "bash", + }) + if err == nil { + t.Fatal("expected error from tool execution, got nil") + } + if !errors.Is(err, execErr) { + t.Fatalf("expected wrapped exec error, got: %v", err) + } + if !result.IsError { + t.Fatal("result.IsError should be true when execution fails") + } + if result.Name != "bash" { + t.Fatalf("expected tool name 'bash', got %q", result.Name) + } + + events := collectRuntimeEvents(service.Events()) + assertEventContains(t, events, EventToolStart) + assertEventContains(t, events, EventToolResult) +} + +// TestNewSystemToolRunID 验证 run ID 生成格式与空名称回退。 +func TestNewSystemToolRunID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + prefix string + }{ + { + name: "normal tool name", + input: "Bash", + prefix: "system-tool-bash-", + }, + { + name: "mixed case normalized", + input: "ReadFile", + prefix: "system-tool-readfile-", + }, + { + name: "whitespace trimmed", + input: " bash ", + prefix: "system-tool-bash-", + }, + { + name: "empty name falls back to tool", + input: "", + prefix: "system-tool-tool-", + }, + { + name: "whitespace-only falls back to tool", + input: " ", + prefix: "system-tool-tool-", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := newSystemToolRunID(tc.input) + assertGeneratedIDWithPrefix(t, got, tc.prefix) + }) + } +} + +// TestNewSystemToolCallID 验证 call ID 生成格式与空名称回退。 +func TestNewSystemToolCallID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + prefix string + }{ + { + name: "normal tool name", + input: "Bash", + prefix: "call-bash-", + }, + { + name: "mixed case normalized", + input: "ReadFile", + prefix: "call-readfile-", + }, + { + name: "whitespace trimmed", + input: " bash ", + prefix: "call-bash-", + }, + { + name: "empty name falls back to tool", + input: "", + prefix: "call-tool-", + }, + { + name: "whitespace-only falls back to tool", + input: " ", + prefix: "call-tool-", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := newSystemToolCallID(tc.input) + assertGeneratedIDWithPrefix(t, got, tc.prefix) + }) + } +} + +func assertGeneratedIDWithPrefix(t *testing.T, got, prefix string) { + t.Helper() + + if !strings.HasPrefix(got, prefix) { + t.Fatalf("expected prefix %q, got %q", prefix, got) + } + + suffix := strings.TrimPrefix(got, prefix) + if suffix == "" { + t.Fatal("expected numeric suffix after prefix") + } + for _, ch := range suffix { + if ch < '0' || ch > '9' { + t.Fatalf("expected numeric suffix, got %q in %q", string(ch), got) + } + } +} diff --git a/internal/security/workspace.go b/internal/security/workspace.go index b36f7582..33b1fd4e 100644 --- a/internal/security/workspace.go +++ b/internal/security/workspace.go @@ -380,10 +380,14 @@ func normalizeVolumeName(path string) string { func nearestExistingPath(root string, target string) (string, error) { current := cleanedPathKey(target) + originalTarget := current root = cleanedPathKey(root) for { info, err := os.Lstat(current) if err == nil { + if !samePathKey(current, originalTarget) && info.Mode()&os.ModeSymlink == 0 && !info.IsDir() { + return "", fmt.Errorf("security: inspect path %q: parent is not a directory", current) + } if info.Mode()&os.ModeSymlink != 0 || current != root { return current, nil } diff --git a/internal/tools/memo/common.go b/internal/tools/memo/common.go index 5534b3d5..18a83fd4 100644 --- a/internal/tools/memo/common.go +++ b/internal/tools/memo/common.go @@ -2,7 +2,9 @@ package memo import ( "fmt" + "strings" + "neo-code/internal/memo" "neo-code/internal/tools" ) @@ -11,3 +13,40 @@ func nilServiceError(toolName string) (tools.ToolResult, error) { err := fmt.Errorf("%s: service is nil", toolName) return tools.NewErrorResult(toolName, tools.NormalizeErrorReason(toolName, err), "", nil), err } + +// invalidArgumentsError 构造 memo 工具参数解析失败时的统一错误结果。 +func invalidArgumentsError(toolName string, err error) (tools.ToolResult, error) { + wrappedErr := fmt.Errorf("%s: %w", toolName, err) + return tools.NewErrorResult(toolName, "invalid arguments", wrappedErr.Error(), nil), wrappedErr +} + +// memoScopePropertySchema 返回 memo 工具统一的 scope 参数 schema 描述。 +func memoScopePropertySchema() map[string]any { + return map[string]any{ + "type": "string", + "description": "Optional scope filter: all, user, or project.", + "enum": []string{"all", "user", "project"}, + } +} + +// parseMemoScope 解析 memo scope,并根据 allowAll 决定是否接受 all。 +func parseMemoScope(raw string, allowAll bool) (memo.Scope, error) { + normalized := strings.ToLower(strings.TrimSpace(raw)) + if normalized == "" { + if allowAll { + return memo.ScopeAll, nil + } + return memo.ScopeProject, fmt.Errorf("memo: scope is required") + } + switch memo.Scope(normalized) { + case memo.ScopeUser: + return memo.ScopeUser, nil + case memo.ScopeProject: + return memo.ScopeProject, nil + case memo.ScopeAll: + if allowAll { + return memo.ScopeAll, nil + } + } + return "", fmt.Errorf("memo: unsupported scope %q", raw) +} diff --git a/internal/tools/memo/list.go b/internal/tools/memo/list.go new file mode 100644 index 00000000..72b2676e --- /dev/null +++ b/internal/tools/memo/list.go @@ -0,0 +1,118 @@ +package memo + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "neo-code/internal/memo" + "neo-code/internal/tools" +) + +const listToolName = tools.ToolNameMemoList + +type listInput struct { + Scope string `json:"scope,omitempty"` +} + +// ListTool 让调用方按层列出持久记忆目录。 +type ListTool struct { + svc *memo.Service +} + +// NewListTool 创建 memo_list 工具。 +func NewListTool(svc *memo.Service) *ListTool { + return &ListTool{svc: svc} +} + +// Name 返回工具注册名。 +func (t *ListTool) Name() string { return listToolName } + +// Description 返回工具描述。 +func (t *ListTool) Description() string { + return "List persistent memories grouped by scope. Use this to inspect saved user or project memories." +} + +// Schema 返回 JSON Schema 描述的工具参数格式。 +func (t *ListTool) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "scope": memoScopePropertySchema(), + }, + } +} + +// MicroCompactPolicy 记忆目录结果应保留在上下文中,不参与 micro compact 清理。 +func (t *ListTool) MicroCompactPolicy() tools.MicroCompactPolicy { + return tools.MicroCompactPolicyPreserveHistory +} + +// Execute 执行 memo_list 工具调用。 +func (t *ListTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { + if t.svc == nil { + return nilServiceError(listToolName) + } + + var args listInput + if len(call.Arguments) > 0 { + if err := json.Unmarshal(call.Arguments, &args); err != nil { + return invalidArgumentsError(listToolName, err) + } + } + + scope, err := parseMemoScope(args.Scope, true) + if err != nil { + return tools.NewErrorResult(listToolName, tools.NormalizeErrorReason(listToolName, err), "", nil), err + } + entries, err := t.svc.List(ctx, scope) + if err != nil { + return tools.NewErrorResult(listToolName, tools.NormalizeErrorReason(listToolName, err), "", nil), err + } + if len(entries) == 0 { + return tools.ToolResult{ + Name: listToolName, + Content: "No memos stored yet.", + }, nil + } + + var userLines []string + var projectLines []string + for _, entry := range entries { + line := fmt.Sprintf("- [%s] %s", entry.Type, entry.Title) + if memo.ScopeForType(entry.Type) == memo.ScopeUser { + userLines = append(userLines, line) + continue + } + projectLines = append(projectLines, line) + } + + var builder strings.Builder + if scope == memo.ScopeAll || scope == memo.ScopeUser { + builder.WriteString("User Memo:\n") + if len(userLines) == 0 { + builder.WriteString("- \n") + } else { + builder.WriteString(strings.Join(userLines, "\n")) + builder.WriteString("\n") + } + } + if scope == memo.ScopeAll || scope == memo.ScopeProject { + if builder.Len() > 0 { + builder.WriteString("\n") + } + builder.WriteString("Project Memo:\n") + if len(projectLines) == 0 { + builder.WriteString("- \n") + } else { + builder.WriteString(strings.Join(projectLines, "\n")) + builder.WriteString("\n") + } + } + + return tools.ToolResult{ + Name: listToolName, + Content: strings.TrimSpace(builder.String()), + }, nil +} diff --git a/internal/tools/memo/list_test.go b/internal/tools/memo/list_test.go new file mode 100644 index 00000000..7062039f --- /dev/null +++ b/internal/tools/memo/list_test.go @@ -0,0 +1,224 @@ +package memo + +import ( + "context" + "encoding/json" + "strings" + "testing" + + "neo-code/internal/memo" + "neo-code/internal/tools" +) + +func TestListToolName(t *testing.T) { + tool := NewListTool(nil) + if tool.Name() != tools.ToolNameMemoList { + t.Fatalf("Name() = %q, want %q", tool.Name(), tools.ToolNameMemoList) + } + if tool.Description() == "" { + t.Fatal("Description() should not be empty") + } + if tool.Schema() == nil { + t.Fatal("Schema() should not be nil") + } + if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { + t.Fatalf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) + } +} + +func TestListToolExecuteEmpty(t *testing.T) { + svc := newTestService(t) + tool := NewListTool(svc) + + args, _ := json.Marshal(listInput{}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if result.IsError { + t.Fatalf("unexpected error result: %+v", result) + } + if result.Content != "No memos stored yet." { + t.Fatalf("Content = %q, want %q", result.Content, "No memos stored yet.") + } +} + +func TestListToolExecuteWithUserMemos(t *testing.T) { + svc := newTestService(t) + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeUser, + Title: "prefer go style", + Content: "prefer go style", + Source: memo.SourceUserManual, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + + tool := NewListTool(svc) + args, _ := json.Marshal(listInput{Scope: "user"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if result.IsError { + t.Fatalf("unexpected error result: %+v", result) + } + if !strings.Contains(result.Content, "User Memo:") { + t.Fatalf("expected 'User Memo:' header, got %q", result.Content) + } + if !strings.Contains(result.Content, "[user] prefer go style") { + t.Fatalf("expected user entry line, got %q", result.Content) + } + if strings.Contains(result.Content, "Project Memo:") { + t.Fatalf("scope=user should not include Project Memo section, got %q", result.Content) + } +} + +func TestListToolExecuteWithProjectMemos(t *testing.T) { + svc := newTestService(t) + for _, e := range []memo.Entry{ + {Type: memo.TypeFeedback, Title: "fix logging", Content: "fix logging", Source: memo.SourceToolInitiated}, + {Type: memo.TypeProject, Title: "use grpc", Content: "use grpc", Source: memo.SourceToolInitiated}, + {Type: memo.TypeReference, Title: "design doc", Content: "design doc", Source: memo.SourceToolInitiated}, + } { + if err := svc.Add(context.Background(), e); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + } + + tool := NewListTool(svc) + args, _ := json.Marshal(listInput{Scope: "project"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if result.IsError { + t.Fatalf("unexpected error result: %+v", result) + } + if !strings.Contains(result.Content, "Project Memo:") { + t.Fatalf("expected 'Project Memo:' header, got %q", result.Content) + } + if !strings.Contains(result.Content, "[feedback] fix logging") { + t.Fatalf("expected feedback entry, got %q", result.Content) + } + if !strings.Contains(result.Content, "[project] use grpc") { + t.Fatalf("expected project entry, got %q", result.Content) + } + if !strings.Contains(result.Content, "[reference] design doc") { + t.Fatalf("expected reference entry, got %q", result.Content) + } + if strings.Contains(result.Content, "User Memo:") { + t.Fatalf("scope=project should not include User Memo section, got %q", result.Content) + } +} + +func TestListToolExecuteAllScopes(t *testing.T) { + svc := newTestService(t) + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeUser, + Title: "dark mode", + Content: "dark mode", + Source: memo.SourceUserManual, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeProject, + Title: "use sqlite", + Content: "use sqlite", + Source: memo.SourceToolInitiated, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + + tool := NewListTool(svc) + + // scope="all" 应包含两层 + args, _ := json.Marshal(listInput{Scope: "all"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if result.IsError { + t.Fatalf("unexpected error result: %+v", result) + } + if !strings.Contains(result.Content, "User Memo:") { + t.Fatalf("expected 'User Memo:' header, got %q", result.Content) + } + if !strings.Contains(result.Content, "Project Memo:") { + t.Fatalf("expected 'Project Memo:' header, got %q", result.Content) + } + if !strings.Contains(result.Content, "[user] dark mode") { + t.Fatalf("expected user entry, got %q", result.Content) + } + if !strings.Contains(result.Content, "[project] use sqlite") { + t.Fatalf("expected project entry, got %q", result.Content) + } + + // 默认空 scope 等价于 all + args, _ = json.Marshal(listInput{}) + result2, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if result2.Content != result.Content { + t.Fatalf("default scope content = %q, want same as all = %q", result2.Content, result.Content) + } +} + +func TestListToolExecuteAllScopesPartialEmpty(t *testing.T) { + svc := newTestService(t) + // 只有 project 条目,user 层为空 + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeProject, + Title: "use grpc", + Content: "use grpc", + Source: memo.SourceToolInitiated, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + + tool := NewListTool(svc) + args, _ := json.Marshal(listInput{Scope: "all"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if !strings.Contains(result.Content, "User Memo:") { + t.Fatalf("expected 'User Memo:' header for empty section, got %q", result.Content) + } + if !strings.Contains(result.Content, "") { + t.Fatalf("expected '' for user section, got %q", result.Content) + } + if !strings.Contains(result.Content, "Project Memo:") { + t.Fatalf("expected 'Project Memo:' header, got %q", result.Content) + } +} + +func TestListToolExecuteNilService(t *testing.T) { + tool := NewListTool(nil) + args, _ := json.Marshal(listInput{}) + + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err == nil || !result.IsError { + t.Fatalf("expected nil service error, got result=%+v err=%v", result, err) + } +} + +func TestListToolExecuteInvalidJSON(t *testing.T) { + tool := NewListTool(nil) + if _, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: []byte("not json")}); err == nil { + t.Fatal("expected invalid JSON error") + } +} + +func TestListToolExecuteInvalidScope(t *testing.T) { + svc := newTestService(t) + tool := NewListTool(svc) + + args, _ := json.Marshal(listInput{Scope: "badscope"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err == nil || !result.IsError { + t.Fatalf("expected invalid scope error, got result=%+v err=%v", result, err) + } +} diff --git a/internal/tools/memo/recall.go b/internal/tools/memo/recall.go index e921787b..dfcc2cdf 100644 --- a/internal/tools/memo/recall.go +++ b/internal/tools/memo/recall.go @@ -18,6 +18,7 @@ const ( // recallInput 定义 memo_recall 工具的 JSON 入参。 type recallInput struct { Keyword string `json:"keyword"` + Scope string `json:"scope,omitempty"` } // RecallTool 让 Agent 按关键词搜索并加载记忆详情。 @@ -47,8 +48,9 @@ func (t *RecallTool) Schema() map[string]any { "properties": map[string]any{ "keyword": map[string]any{ "type": "string", - "description": "Search keyword to find matching memory entries (searches title, type, and keywords).", + "description": "Search keyword to find matching memory entries (searches title, type, content, and keywords).", }, + "scope": memoScopePropertySchema(), }, "required": []string{"keyword"}, } @@ -63,8 +65,7 @@ func (t *RecallTool) MicroCompactPolicy() tools.MicroCompactPolicy { func (t *RecallTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { var args recallInput if err := json.Unmarshal(call.Arguments, &args); err != nil { - err = fmt.Errorf("%s: %w", recallToolName, err) - return tools.NewErrorResult(recallToolName, "invalid arguments", err.Error(), nil), err + return invalidArgumentsError(recallToolName, err) } args.Keyword = strings.TrimSpace(args.Keyword) @@ -76,7 +77,12 @@ func (t *RecallTool) Execute(ctx context.Context, call tools.ToolCallInput) (too return nilServiceError(recallToolName) } - results, err := t.svc.Recall(ctx, args.Keyword) + scope, err := parseMemoScope(args.Scope, true) + if err != nil { + return tools.NewErrorResult(recallToolName, tools.NormalizeErrorReason(recallToolName, err), "", nil), err + } + + results, err := t.svc.Recall(ctx, args.Keyword, scope) if err != nil { return tools.NewErrorResult(recallToolName, tools.NormalizeErrorReason(recallToolName, err), "", nil), err } @@ -88,17 +94,16 @@ func (t *RecallTool) Execute(ctx context.Context, call tools.ToolCallInput) (too }, tools.DefaultOutputLimitBytes), nil } - // 按 key 排序保证输出稳定性 - keys := make([]string, 0, len(results)) - for k := range results { - keys = append(keys, k) - } - sort.Strings(keys) - var builder strings.Builder fmt.Fprintf(&builder, "Found %d memory topic(s) matching %q:\n\n", len(results), args.Keyword) - for _, k := range keys { - fmt.Fprintf(&builder, "--- %s ---\n%s\n\n", k, results[k]) + sort.SliceStable(results, func(i, j int) bool { + if results[i].Scope != results[j].Scope { + return results[i].Scope < results[j].Scope + } + return results[i].Entry.TopicFile < results[j].Entry.TopicFile + }) + for _, item := range results { + fmt.Fprintf(&builder, "--- [%s] %s ---\n%s\n\n", item.Scope, item.Entry.TopicFile, item.Content) } return tools.ApplyOutputLimit(tools.ToolResult{ diff --git a/internal/tools/memo/recall_test.go b/internal/tools/memo/recall_test.go index 239a0fd4..04cbf1c7 100644 --- a/internal/tools/memo/recall_test.go +++ b/internal/tools/memo/recall_test.go @@ -13,151 +13,124 @@ import ( func TestRecallToolName(t *testing.T) { tool := NewRecallTool(nil) if tool.Name() != tools.ToolNameMemoRecall { - t.Errorf("Name() = %q, want %q", tool.Name(), tools.ToolNameMemoRecall) - } -} - -func TestRecallToolSchema(t *testing.T) { - tool := NewRecallTool(nil) - schema := tool.Schema() - if schema["type"] != "object" { - t.Errorf("Schema type = %v, want object", schema["type"]) - } - props, ok := schema["properties"].(map[string]any) - if !ok { - t.Fatal("Schema properties is not a map") - } - if _, exists := props["keyword"]; !exists { - t.Error("Schema missing 'keyword' property") - } -} - -func TestRecallToolMicroCompactPolicy(t *testing.T) { - tool := NewRecallTool(nil) - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { - t.Errorf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) + t.Fatalf("Name() = %q, want %q", tool.Name(), tools.ToolNameMemoRecall) } } func TestRecallToolExecuteSuccess(t *testing.T) { svc := newTestService(t) - // 预先写入记忆 - svc.Add(context.Background(), memo.Entry{ + if err := svc.Add(context.Background(), memo.Entry{ Type: memo.TypeUser, - Title: "偏好中文注释", - Content: "用户偏好使用中文注释和 tab 缩进", + Title: "prefer chinese comments", + Content: "prefer chinese comments", Source: memo.SourceUserManual, - }) + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } tool := NewRecallTool(svc) - args, _ := json.Marshal(recallInput{Keyword: "中文"}) + args, _ := json.Marshal(recallInput{Keyword: "chinese"}) result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) if err != nil { - t.Fatalf("Execute error: %v", err) - } - if result.IsError { - t.Errorf("unexpected error result: %s", result.Content) + t.Fatalf("Execute() error = %v", err) } - if !strings.Contains(result.Content, "Found 1 memory") { - t.Errorf("Content should show match count: %q", result.Content) + if result.IsError || !strings.Contains(result.Content, "Found 1 memory topic") { + t.Fatalf("unexpected result: %+v", result) } - if !strings.Contains(result.Content, "中文注释") { - t.Errorf("Content should contain topic content: %q", result.Content) + if !strings.Contains(result.Content, "[user]") { + t.Fatalf("expected scoped header, got %q", result.Content) } } func TestRecallToolExecuteNoMatch(t *testing.T) { - svc := newTestService(t) - tool := NewRecallTool(svc) + tool := NewRecallTool(newTestService(t)) + args, _ := json.Marshal(recallInput{Keyword: "missing"}) - args, _ := json.Marshal(recallInput{Keyword: "nonexistent"}) result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) if err != nil { - t.Fatalf("Execute error: %v", err) + t.Fatalf("Execute() error = %v", err) } - if result.IsError { - t.Errorf("no match should not be an error: %s", result.Content) - } - if !strings.Contains(result.Content, "No memories found") { - t.Errorf("Content should show no match: %q", result.Content) + if result.IsError || !strings.Contains(result.Content, "No memories found") { + t.Fatalf("unexpected result: %+v", result) } } -func TestRecallToolExecuteInvalidJSON(t *testing.T) { - tool := NewRecallTool(nil) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: []byte("not json")}) - if err == nil { - t.Error("expected error for invalid JSON") +func TestRecallToolExecuteBadInput(t *testing.T) { + tool := NewRecallTool(newTestService(t)) + + if _, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: []byte("not json")}); err == nil { + t.Fatal("expected invalid JSON error") + } + args, _ := json.Marshal(recallInput{Keyword: ""}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err == nil || !result.IsError { + t.Fatalf("expected empty keyword error, got result=%+v err=%v", result, err) } } func TestRecallToolExecuteNilService(t *testing.T) { tool := NewRecallTool(nil) - args, _ := json.Marshal(recallInput{Keyword: "tab"}) + args, _ := json.Marshal(recallInput{Keyword: "x"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) - if err == nil { - t.Fatal("expected error for nil service") - } - if !result.IsError { - t.Fatal("expected error result") - } - if !strings.Contains(result.Content, "service is nil") { - t.Fatalf("unexpected error content: %q", result.Content) + if err == nil || !result.IsError { + t.Fatalf("expected nil service error, got result=%+v err=%v", result, err) } } -func TestRecallToolExecuteEmptyKeyword(t *testing.T) { - svc := newTestService(t) - tool := NewRecallTool(svc) - - args, _ := json.Marshal(recallInput{Keyword: ""}) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) - if err == nil { - t.Error("expected error for empty keyword") +func TestRecallToolDescriptionAndSchema(t *testing.T) { + tool := NewRecallTool(nil) + if tool.Description() == "" { + t.Fatal("Description() should not be empty") } - if !result.IsError { - t.Error("expected error result") + schema := tool.Schema() + if schema == nil { + t.Fatal("Schema() should not be nil") + } + if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { + t.Fatalf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) } } -func TestRecallToolExecuteWhitespaceKeyword(t *testing.T) { +func TestRecallToolExecuteWithScopeFilter(t *testing.T) { svc := newTestService(t) - tool := NewRecallTool(svc) - - args, _ := json.Marshal(recallInput{Keyword: " "}) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) - if err == nil { - t.Error("expected error for whitespace keyword") + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeUser, + Title: "user pref", + Content: "user pref content", + Source: memo.SourceUserManual, + }); err != nil { + t.Fatalf("Add user: %v", err) } - if !result.IsError { - t.Error("expected error result") + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeFeedback, + Title: "feedback pref", + Content: "feedback pref content", + Source: memo.SourceUserManual, + }); err != nil { + t.Fatalf("Add feedback: %v", err) } -} - -func TestRecallToolExecuteMultipleResults(t *testing.T) { - svc := newTestService(t) - svc.Add(context.Background(), memo.Entry{Type: memo.TypeUser, Title: "偏好 tab", Content: "tab content", Source: memo.SourceUserManual}) - svc.Add(context.Background(), memo.Entry{Type: memo.TypeFeedback, Title: "反馈 tab 问题", Content: "feedback content", Source: memo.SourceUserManual}) tool := NewRecallTool(svc) - args, _ := json.Marshal(recallInput{Keyword: "tab"}) + args, _ := json.Marshal(recallInput{Keyword: "pref", Scope: "user"}) result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) if err != nil { - t.Fatalf("Execute error: %v", err) + t.Fatalf("Execute() error = %v", err) } - if result.IsError { - t.Errorf("unexpected error: %s", result.Content) + if result.IsError || !strings.Contains(result.Content, "Found 1 memory topic") { + t.Fatalf("expected 1 user result, got: %s", result.Content) } - if !strings.Contains(result.Content, "Found 2 memory") { - t.Errorf("Content should show 2 matches: %q", result.Content) + if strings.Contains(result.Content, "feedback") { + t.Fatalf("should not contain feedback entry: %s", result.Content) } } -func TestRecallToolDescription(t *testing.T) { - tool := NewRecallTool(nil) - desc := tool.Description() - if !strings.Contains(desc, "memory") { - t.Errorf("Description should mention 'memory': %q", desc) +func TestRecallToolExecuteInvalidScope(t *testing.T) { + tool := NewRecallTool(newTestService(t)) + args, _ := json.Marshal(recallInput{Keyword: "test", Scope: "badscope"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err == nil || !result.IsError { + t.Fatalf("expected bad scope error, got result=%+v err=%v", result, err) } } @@ -165,27 +138,20 @@ func TestRecallToolExecuteAppliesOutputLimit(t *testing.T) { svc := newTestService(t) if err := svc.Add(context.Background(), memo.Entry{ Type: memo.TypeReference, - Title: "超长记忆", + Title: "long memory", Content: strings.Repeat("x", tools.DefaultOutputLimitBytes+1024), Source: memo.SourceUserManual, }); err != nil { - t.Fatalf("seed memo entry: %v", err) + t.Fatalf("seed Add() error = %v", err) } tool := NewRecallTool(svc) - args, _ := json.Marshal(recallInput{Keyword: "超长"}) + args, _ := json.Marshal(recallInput{Keyword: "long"}) result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) if err != nil { - t.Fatalf("Execute error: %v", err) - } - if result.IsError { - t.Fatalf("expected success result, got error: %s", result.Content) - } - if !strings.Contains(result.Content, "...[truncated]") { - t.Fatalf("expected truncated suffix, got content length %d", len(result.Content)) + t.Fatalf("Execute() error = %v", err) } - truncated, ok := result.Metadata["truncated"].(bool) - if !ok || !truncated { - t.Fatalf("expected metadata truncated=true, got %+v", result.Metadata) + if result.IsError || !strings.Contains(result.Content, "...[truncated]") { + t.Fatalf("unexpected truncated result: %+v", result) } } diff --git a/internal/tools/memo/remember.go b/internal/tools/memo/remember.go index d6cac2db..5009ef01 100644 --- a/internal/tools/memo/remember.go +++ b/internal/tools/memo/remember.go @@ -78,8 +78,7 @@ func (t *RememberTool) MicroCompactPolicy() tools.MicroCompactPolicy { func (t *RememberTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { var args rememberInput if err := json.Unmarshal(call.Arguments, &args); err != nil { - err = fmt.Errorf("%s: %w", rememberToolName, err) - return tools.NewErrorResult(rememberToolName, "invalid arguments", err.Error(), nil), err + return invalidArgumentsError(rememberToolName, err) } args.Type = strings.TrimSpace(args.Type) diff --git a/internal/tools/memo/remember_test.go b/internal/tools/memo/remember_test.go index f2ac7fab..ac0d39f8 100644 --- a/internal/tools/memo/remember_test.go +++ b/internal/tools/memo/remember_test.go @@ -11,41 +11,21 @@ import ( "neo-code/internal/tools" ) -// newTestService 创建绑定临时目录的 memo.Service 实例。 func newTestService(t *testing.T) *memo.Service { t.Helper() store := memo.NewFileStore(t.TempDir(), t.TempDir()) - return memo.NewService(store, nil, config.MemoConfig{MaxIndexLines: 200}, nil) + return memo.NewService(store, config.MemoConfig{ + MaxEntries: 200, + MaxIndexBytes: 16 * 1024, + ExtractTimeoutSec: 15, + ExtractRecentMessages: 10, + }, nil) } func TestRememberToolName(t *testing.T) { tool := NewRememberTool(nil) if tool.Name() != tools.ToolNameMemoRemember { - t.Errorf("Name() = %q, want %q", tool.Name(), tools.ToolNameMemoRemember) - } -} - -func TestRememberToolSchema(t *testing.T) { - tool := NewRememberTool(nil) - schema := tool.Schema() - if schema["type"] != "object" { - t.Errorf("Schema type = %v, want object", schema["type"]) - } - props, ok := schema["properties"].(map[string]any) - if !ok { - t.Fatal("Schema properties is not a map") - } - for _, field := range []string{"type", "title", "content"} { - if _, exists := props[field]; !exists { - t.Errorf("Schema missing required property %q", field) - } - } -} - -func TestRememberToolMicroCompactPolicy(t *testing.T) { - tool := NewRememberTool(nil) - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { - t.Errorf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) + t.Fatalf("Name() = %q, want %q", tool.Name(), tools.ToolNameMemoRemember) } } @@ -54,169 +34,63 @@ func TestRememberToolExecuteSuccess(t *testing.T) { tool := NewRememberTool(svc) args, _ := json.Marshal(rememberInput{ - Type: "user", - Title: "偏好中文注释", - Content: "用户偏好使用中文注释和 tab 缩进", + Type: "user", + Title: "prefer chinese comments", + Content: "prefer chinese comments", + Keywords: []string{" comments ", "comments", "style"}, }) result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) if err != nil { - t.Fatalf("Execute error: %v", err) - } - if result.IsError { - t.Errorf("unexpected error result: %s", result.Content) + t.Fatalf("Execute() error = %v", err) } - if !strings.Contains(result.Content, "Memory saved") { - t.Errorf("Content = %q, want saved confirmation", result.Content) - } - if !strings.Contains(result.Content, "偏好中文注释") { - t.Errorf("Content should contain title: %q", result.Content) + if result.IsError || !strings.Contains(result.Content, "Memory saved") { + t.Fatalf("unexpected result: %+v", result) } - // 验证实际保存(索引只保留 Type/Title/TopicFile,完整信息在 topic 文件中) - entries, _ := svc.List(context.Background()) - if len(entries) != 1 { - t.Fatalf("expected 1 entry, got %d", len(entries)) - } - if entries[0].Type != memo.TypeUser { - t.Errorf("Type = %q, want %q", entries[0].Type, memo.TypeUser) - } -} - -func TestRememberToolExecuteWithKeywords(t *testing.T) { - svc := newTestService(t) - tool := NewRememberTool(svc) - - args, _ := json.Marshal(rememberInput{ - Type: "feedback", - Title: "不要 mock 数据库", - Content: "集成测试必须连接真实数据库", - Keywords: []string{"testing", "database"}, - }) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + entries, err := svc.List(context.Background(), memo.ScopeUser) if err != nil { - t.Fatalf("Execute error: %v", err) + t.Fatalf("List() error = %v", err) } - if result.IsError { - t.Errorf("unexpected error result: %s", result.Content) + if len(entries) != 1 || entries[0].Type != memo.TypeUser { + t.Fatalf("unexpected entries: %#v", entries) } - - entries, _ := svc.List(context.Background()) - if len(entries) != 1 { - t.Fatalf("expected 1 entry, got %d", len(entries)) - } - // Keywords 存储在 topic 文件中,不在索引中 if entries[0].TopicFile == "" { - t.Error("TopicFile should be set") + t.Fatal("expected TopicFile to be set") } } -func TestRememberToolExecuteInvalidJSON(t *testing.T) { - tool := NewRememberTool(nil) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: []byte("not json")}) - if err == nil { - t.Error("expected error for invalid JSON") - } -} - -func TestRememberToolExecuteNilService(t *testing.T) { - tool := NewRememberTool(nil) - args, _ := json.Marshal(rememberInput{ - Type: "user", - Title: "偏好中文注释", - Content: "用户偏好使用中文注释和 tab 缩进", - }) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) - if err == nil { - t.Fatal("expected error for nil service") - } - if !result.IsError { - t.Fatal("expected error result") - } - if !strings.Contains(result.Content, "service is nil") { - t.Fatalf("unexpected error content: %q", result.Content) - } -} - -func TestRememberToolExecuteMissingFields(t *testing.T) { +func TestRememberToolExecuteRejectsBadInput(t *testing.T) { svc := newTestService(t) tool := NewRememberTool(svc) - tests := []struct { - name string - args rememberInput - }{ - {"empty type", rememberInput{Type: "", Title: "t", Content: "c"}}, - {"empty title", rememberInput{Type: "user", Title: "", Content: "c"}}, - {"empty content", rememberInput{Type: "user", Title: "t", Content: ""}}, - {"whitespace type", rememberInput{Type: " ", Title: "t", Content: "c"}}, + tests := []rememberInput{ + {Type: "", Title: "t", Content: "c"}, + {Type: "user", Title: "", Content: "c"}, + {Type: "user", Title: "t", Content: ""}, + {Type: "bad", Title: "t", Content: "c"}, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - args, _ := json.Marshal(tt.args) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) - if err == nil { - t.Error("expected error for missing fields") - } - if !result.IsError { - t.Error("expected error result") - } - }) + args, _ := json.Marshal(tt) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err == nil || !result.IsError { + t.Fatalf("expected bad input to fail: %+v / %+v", tt, result) + } } } -func TestRememberToolExecuteInvalidType(t *testing.T) { - svc := newTestService(t) - tool := NewRememberTool(svc) +func TestRememberToolExecuteNilService(t *testing.T) { + tool := NewRememberTool(nil) + args, _ := json.Marshal(rememberInput{Type: "user", Title: "t", Content: "c"}) - args, _ := json.Marshal(rememberInput{Type: "invalid", Title: "t", Content: "c"}) result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) - if err == nil { - t.Error("expected error for invalid type") - } - if !result.IsError { - t.Error("expected error result") - } - if !strings.Contains(result.Content, "invalid type") { - t.Errorf("Content should mention invalid type: %q", result.Content) - } -} - -func TestRememberToolExecuteAllTypes(t *testing.T) { - svc := newTestService(t) - tool := NewRememberTool(svc) - - for _, memoType := range memo.ValidTypes() { - t.Run(string(memoType), func(t *testing.T) { - args, _ := json.Marshal(rememberInput{ - Type: string(memoType), - Title: "test " + string(memoType), - Content: "content for " + string(memoType), - }) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) - if err != nil { - t.Fatalf("Execute error for type %s: %v", memoType, err) - } - if result.IsError { - t.Errorf("unexpected error for type %s: %s", memoType, result.Content) - } - }) + if err == nil || !result.IsError { + t.Fatalf("expected nil service error, got result=%+v err=%v", result, err) } } -func TestRememberToolExecuteServiceError(t *testing.T) { - svc := newTestService(t) - tool := NewRememberTool(svc) - - args, _ := json.Marshal(rememberInput{Type: "user", Title: "test", Content: "test"}) - ctx, cancel := context.WithCancel(context.Background()) - cancel() // 取消上下文以触发错误 - - result, err := tool.Execute(ctx, tools.ToolCallInput{Arguments: args}) - if err == nil { - t.Error("expected error when context is cancelled") - } - if !result.IsError { - t.Error("expected error result") +func TestRememberToolExecuteInvalidJSON(t *testing.T) { + tool := NewRememberTool(nil) + if _, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: []byte("not json")}); err == nil { + t.Fatal("expected invalid JSON error") } } diff --git a/internal/tools/memo/remove.go b/internal/tools/memo/remove.go new file mode 100644 index 00000000..6f5e0c81 --- /dev/null +++ b/internal/tools/memo/remove.go @@ -0,0 +1,93 @@ +package memo + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "neo-code/internal/memo" + "neo-code/internal/tools" +) + +const removeToolName = tools.ToolNameMemoRemove + +type removeInput struct { + Keyword string `json:"keyword"` + Scope string `json:"scope,omitempty"` +} + +// RemoveTool 让调用方按关键词删除持久记忆。 +type RemoveTool struct { + svc *memo.Service +} + +// NewRemoveTool 创建 memo_remove 工具。 +func NewRemoveTool(svc *memo.Service) *RemoveTool { + return &RemoveTool{svc: svc} +} + +// Name 返回工具注册名。 +func (t *RemoveTool) Name() string { return removeToolName } + +// Description 返回工具描述。 +func (t *RemoveTool) Description() string { + return "Remove persistent memories by keyword. Optionally limit deletion scope to user or project memories." +} + +// Schema 返回 JSON Schema 描述的工具参数格式。 +func (t *RemoveTool) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "keyword": map[string]any{ + "type": "string", + "description": "Keyword to match against memory title, type, keywords, or content.", + }, + "scope": memoScopePropertySchema(), + }, + "required": []string{"keyword"}, + } +} + +// MicroCompactPolicy 删除结果应保留在上下文中,不参与 micro compact 清理。 +func (t *RemoveTool) MicroCompactPolicy() tools.MicroCompactPolicy { + return tools.MicroCompactPolicyPreserveHistory +} + +// Execute 执行 memo_remove 工具调用。 +func (t *RemoveTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { + var args removeInput + if err := json.Unmarshal(call.Arguments, &args); err != nil { + return invalidArgumentsError(removeToolName, err) + } + if t.svc == nil { + return nilServiceError(removeToolName) + } + + args.Keyword = strings.TrimSpace(args.Keyword) + if args.Keyword == "" { + err := fmt.Errorf("%s: keyword is required", removeToolName) + return tools.NewErrorResult(removeToolName, tools.NormalizeErrorReason(removeToolName, err), "", nil), err + } + + scope, err := parseMemoScope(args.Scope, true) + if err != nil { + return tools.NewErrorResult(removeToolName, tools.NormalizeErrorReason(removeToolName, err), "", nil), err + } + + removed, err := t.svc.Remove(ctx, args.Keyword, scope) + if err != nil { + return tools.NewErrorResult(removeToolName, tools.NormalizeErrorReason(removeToolName, err), "", nil), err + } + if removed == 0 { + return tools.ToolResult{ + Name: removeToolName, + Content: fmt.Sprintf("No memos matching %q.", args.Keyword), + }, nil + } + return tools.ToolResult{ + Name: removeToolName, + Content: fmt.Sprintf("Removed %d memo(s) matching %q.", removed, args.Keyword), + }, nil +} diff --git a/internal/tools/memo/remove_test.go b/internal/tools/memo/remove_test.go new file mode 100644 index 00000000..bcf5d9a8 --- /dev/null +++ b/internal/tools/memo/remove_test.go @@ -0,0 +1,229 @@ +package memo + +import ( + "context" + "encoding/json" + "strings" + "testing" + + "neo-code/internal/memo" + "neo-code/internal/tools" +) + +func TestRemoveToolName(t *testing.T) { + tool := NewRemoveTool(nil) + if tool.Name() != tools.ToolNameMemoRemove { + t.Fatalf("Name() = %q, want %q", tool.Name(), tools.ToolNameMemoRemove) + } + if tool.Description() == "" { + t.Fatal("Description() should not be empty") + } + if tool.Schema() == nil { + t.Fatal("Schema() should not be nil") + } + if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { + t.Fatalf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) + } +} + +func TestRemoveToolExecuteSuccess(t *testing.T) { + svc := newTestService(t) + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeUser, + Title: "prefer chinese comments", + Content: "always write comments in chinese", + Source: memo.SourceUserManual, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeUser, + Title: "prefer short functions", + Content: "keep functions under 30 lines", + Source: memo.SourceUserManual, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + + tool := NewRemoveTool(svc) + args, _ := json.Marshal(removeInput{Keyword: "prefer"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if result.IsError { + t.Fatalf("unexpected error result: %+v", result) + } + if !strings.Contains(result.Content, "Removed 2 memo(s)") { + t.Fatalf("unexpected content: %q", result.Content) + } + + // 验证记忆已被真正删除 + entries, err := svc.List(context.Background(), memo.ScopeUser) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(entries) != 0 { + t.Fatalf("expected no entries after removal, got %d", len(entries)) + } +} + +func TestRemoveToolExecuteNoMatch(t *testing.T) { + svc := newTestService(t) + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeUser, + Title: "prefer tabs", + Content: "use tabs for indentation", + Source: memo.SourceUserManual, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + + tool := NewRemoveTool(svc) + args, _ := json.Marshal(removeInput{Keyword: "nonexistent"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if result.IsError { + t.Fatalf("unexpected error result: %+v", result) + } + if !strings.Contains(result.Content, "No memos matching") { + t.Fatalf("unexpected content: %q", result.Content) + } +} + +func TestRemoveToolExecuteWithScopeFilter(t *testing.T) { + svc := newTestService(t) + + // 添加 user 作用域的记忆(TypeUser -> ScopeUser) + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeUser, + Title: "prefer dark theme", + Content: "dark theme preference", + Source: memo.SourceUserManual, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + + // 添加 project 作用域的记忆(TypeProject -> ScopeProject) + if err := svc.Add(context.Background(), memo.Entry{ + Type: memo.TypeProject, + Title: "prefer dark theme for project", + Content: "dark theme for this project", + Source: memo.SourceUserManual, + }); err != nil { + t.Fatalf("seed Add() error = %v", err) + } + + tool := NewRemoveTool(svc) + // 只删除 user 作用域的匹配条目 + args, _ := json.Marshal(removeInput{Keyword: "dark theme", Scope: "user"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if result.IsError { + t.Fatalf("unexpected error result: %+v", result) + } + if !strings.Contains(result.Content, "Removed 1 memo(s)") { + t.Fatalf("unexpected content: %q", result.Content) + } + + // 验证 user 作用域已清空 + userEntries, err := svc.List(context.Background(), memo.ScopeUser) + if err != nil { + t.Fatalf("List(user) error = %v", err) + } + if len(userEntries) != 0 { + t.Fatalf("expected 0 user entries, got %d", len(userEntries)) + } + + // 验证 project 作用域保留 + projEntries, err := svc.List(context.Background(), memo.ScopeProject) + if err != nil { + t.Fatalf("List(project) error = %v", err) + } + if len(projEntries) != 1 { + t.Fatalf("expected 1 project entry, got %d", len(projEntries)) + } +} + +func TestRemoveToolExecuteEmptyKeyword(t *testing.T) { + svc := newTestService(t) + tool := NewRemoveTool(svc) + + args, _ := json.Marshal(removeInput{Keyword: " "}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err == nil || !result.IsError { + t.Fatalf("expected empty keyword error, got result=%+v err=%v", result, err) + } +} + +func TestRemoveToolExecuteNilService(t *testing.T) { + tool := NewRemoveTool(nil) + args, _ := json.Marshal(removeInput{Keyword: "test"}) + + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err == nil || !result.IsError { + t.Fatalf("expected nil service error, got result=%+v err=%v", result, err) + } +} + +func TestRemoveToolExecuteInvalidJSON(t *testing.T) { + tool := NewRemoveTool(nil) + if _, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: []byte("not json")}); err == nil { + t.Fatal("expected invalid JSON error") + } +} + +func TestRemoveToolExecuteInvalidScope(t *testing.T) { + svc := newTestService(t) + tool := NewRemoveTool(svc) + + args, _ := json.Marshal(removeInput{Keyword: "test", Scope: "invalid"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{Arguments: args}) + if err == nil || !result.IsError { + t.Fatalf("expected invalid scope error, got result=%+v err=%v", result, err) + } +} + +func TestParseMemoScope(t *testing.T) { + tests := []struct { + name string + raw string + allowAll bool + want memo.Scope + wantErr bool + }{ + {"空字符串 allowAll=true 默认 ScopeAll", "", true, memo.ScopeAll, false}, + {"空字符串 allowAll=false 返回错误", "", false, "", true}, + {"user", "user", true, memo.ScopeUser, false}, + {"USER 大写", "USER", true, memo.ScopeUser, false}, + {"project", "project", true, memo.ScopeProject, false}, + {"PROJECT 大写", "PROJECT", true, memo.ScopeProject, false}, + {"all allowAll=true", "all", true, memo.ScopeAll, false}, + {"all allowAll=false 返回错误", "all", false, "", true}, + {"不合法的 scope", "invalid", true, "", true}, + {"带空格的 user", " user ", true, memo.ScopeUser, false}, + {"不合法 scope allowAll=false", "badscope", false, "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseMemoScope(tt.raw, tt.allowAll) + if tt.wantErr { + if err == nil { + t.Fatalf("parseMemoScope(%q, %v) expected error, got scope=%q", tt.raw, tt.allowAll, got) + } + return + } + if err != nil { + t.Fatalf("parseMemoScope(%q, %v) unexpected error: %v", tt.raw, tt.allowAll, err) + } + if got != tt.want { + t.Fatalf("parseMemoScope(%q, %v) = %q, want %q", tt.raw, tt.allowAll, got, tt.want) + } + }) + } +} diff --git a/internal/tools/names.go b/internal/tools/names.go index fd55e8d2..0be5454a 100644 --- a/internal/tools/names.go +++ b/internal/tools/names.go @@ -12,4 +12,6 @@ const ( ToolNameTodoWrite = "todo_write" ToolNameMemoRemember = "memo_remember" ToolNameMemoRecall = "memo_recall" + ToolNameMemoList = "memo_list" + ToolNameMemoRemove = "memo_remove" ) diff --git a/internal/tui/bootstrap/builder_test.go b/internal/tui/bootstrap/builder_test.go index 873a5135..298f64ff 100644 --- a/internal/tui/bootstrap/builder_test.go +++ b/internal/tui/bootstrap/builder_test.go @@ -11,6 +11,7 @@ import ( agentruntime "neo-code/internal/runtime" agentsession "neo-code/internal/session" "neo-code/internal/skills" + "neo-code/internal/tools" ) type testRuntime struct{} @@ -43,6 +44,10 @@ func (r *testRuntime) Compact(ctx context.Context, input agentruntime.CompactInp return agentruntime.CompactResult{}, nil } +func (r *testRuntime) ExecuteSystemTool(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) { + return tools.ToolResult{}, nil +} + func (r *testRuntime) ResolvePermission(ctx context.Context, input agentruntime.PermissionResolutionInput) error { return nil } @@ -261,6 +266,10 @@ func (r noopRuntime) Compact(ctx context.Context, input agentruntime.CompactInpu return agentruntime.CompactResult{}, nil } +func (r noopRuntime) ExecuteSystemTool(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) { + return tools.ToolResult{}, nil +} + func (r noopRuntime) ResolvePermission(ctx context.Context, input agentruntime.PermissionResolutionInput) error { return nil } diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index bc0c4316..4ab17c59 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -2,6 +2,7 @@ package tui import ( "context" + "encoding/json" "errors" "fmt" "path/filepath" @@ -18,7 +19,6 @@ import ( "neo-code/internal/config" configstate "neo-code/internal/config/state" - "neo-code/internal/memo" "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" agentruntime "neo-code/internal/runtime" @@ -2083,88 +2083,70 @@ func (a App) isBusy() bool { // handleMemoCommand 处理 /memo 命令,显示记忆索引内容。 func (a *App) handleMemoCommand() tea.Cmd { - if a.memoSvc == nil { - a.appendInlineMessage(roleError, "[System] Memo service is not enabled.") - a.rebuildTranscript() - return nil - } - entries, err := a.memoSvc.List(context.Background()) - if err != nil { - a.appendInlineMessage(roleError, fmt.Sprintf("[System] Failed to load memo: %s", err)) - a.rebuildTranscript() - return nil - } - if len(entries) == 0 { - a.appendInlineMessage(roleSystem, "[System] No memos stored yet. Use /remember to add one.") - a.rebuildTranscript() - return nil - } - var lines []string - lines = append(lines, fmt.Sprintf("[System] %d memo(s):", len(entries))) - for _, entry := range entries { - lines = append(lines, fmt.Sprintf(" [%s] %s", entry.Type, entry.Title)) - } - a.appendInlineMessage(roleSystem, strings.Join(lines, "\n")) - a.rebuildTranscript() - return nil + return a.runMemoSystemTool(tools.ToolNameMemoList, map[string]any{}) } // handleRememberCommand 处理 /remember 命令,创建新的记忆条目。 func (a *App) handleRememberCommand(text string) tea.Cmd { text = strings.TrimSpace(text) - if a.memoSvc == nil { - a.appendInlineMessage(roleError, "[System] Memo service is not enabled.") - a.rebuildTranscript() - return nil - } if text == "" { a.appendInlineMessage(roleError, fmt.Sprintf("[System] Usage: %s", slashUsageRemember)) a.rebuildTranscript() return nil } - title := memo.NormalizeTitle(text) - entry := memo.Entry{ - Type: memo.TypeUser, - Title: title, - Content: text, - Source: memo.SourceUserManual, - } - if err := a.memoSvc.Add(context.Background(), entry); err != nil { - a.appendInlineMessage(roleError, fmt.Sprintf("[System] Failed to save memo: %s", err)) - a.rebuildTranscript() - return nil - } - a.appendInlineMessage(roleSystem, fmt.Sprintf("[System] Memo saved: %s", title)) - a.rebuildTranscript() - return nil + return a.runMemoSystemTool(tools.ToolNameMemoRemember, map[string]any{ + "type": "user", + "title": text, + "content": text, + }) } // handleForgetCommand 处理 /forget 命令,删除匹配的记忆条目。 func (a *App) handleForgetCommand(keyword string) tea.Cmd { keyword = strings.TrimSpace(keyword) - if a.memoSvc == nil { - a.appendInlineMessage(roleError, "[System] Memo service is not enabled.") - a.rebuildTranscript() - return nil - } if keyword == "" { a.appendInlineMessage(roleError, fmt.Sprintf("[System] Usage: %s", slashUsageForget)) a.rebuildTranscript() return nil } - removed, err := a.memoSvc.Remove(context.Background(), keyword) + return a.runMemoSystemTool(tools.ToolNameMemoRemove, map[string]any{ + "keyword": keyword, + "scope": "all", + }) +} + +// runMemoSystemTool 通过 runtime 的系统工具入口执行 memo 相关 slash 命令。 +func (a *App) runMemoSystemTool(toolName string, arguments map[string]any) tea.Cmd { + payload, err := json.Marshal(arguments) if err != nil { - a.appendInlineMessage(roleError, fmt.Sprintf("[System] Failed to remove memo: %s", err)) + a.appendInlineMessage(roleError, fmt.Sprintf("[System] Failed to encode memo command: %s", err)) a.rebuildTranscript() return nil } - if removed == 0 { - a.appendInlineMessage(roleSystem, fmt.Sprintf("[System] No memos matching %q.", keyword)) - } else { - a.appendInlineMessage(roleSystem, fmt.Sprintf("[System] Removed %d memo(s) matching %q.", removed, keyword)) - } - a.rebuildTranscript() - return nil + + return tuiservices.RunSystemToolCmd( + a.runtime, + agentruntime.SystemToolInput{ + SessionID: a.state.ActiveSessionID, + Workdir: a.state.CurrentWorkdir, + ToolName: toolName, + Arguments: payload, + }, + func(result tools.ToolResult, err error) tea.Msg { + if err != nil { + message := strings.TrimSpace(result.Content) + if message == "" { + message = err.Error() + } + return localCommandResultMsg{Err: errors.New(message)} + } + notice := strings.TrimSpace(result.Content) + if notice == "" { + notice = "Memo command completed." + } + return localCommandResultMsg{Notice: notice} + }, + ) } // setCurrentWorkdir 统一设置当前工作目录,仅接受非空白且为绝对路径的值。 diff --git a/internal/tui/core/app/update_permission_test.go b/internal/tui/core/app/update_permission_test.go index 5b24ef62..12d28987 100644 --- a/internal/tui/core/app/update_permission_test.go +++ b/internal/tui/core/app/update_permission_test.go @@ -15,6 +15,7 @@ import ( agentruntime "neo-code/internal/runtime" approvalflow "neo-code/internal/runtime/approval" agentsession "neo-code/internal/session" + "neo-code/internal/tools" tuistate "neo-code/internal/tui/state" ) @@ -45,6 +46,10 @@ func (r *permissionTestRuntime) Compact(ctx context.Context, input agentruntime. return agentruntime.CompactResult{}, nil } +func (r *permissionTestRuntime) ExecuteSystemTool(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) { + return tools.ToolResult{}, nil +} + func (r *permissionTestRuntime) ResolvePermission(ctx context.Context, input agentruntime.PermissionResolutionInput) error { r.lastResolved = input return r.resolveErr diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index 2b54cb4a..ca246202 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -21,6 +21,7 @@ import ( approvalflow "neo-code/internal/runtime/approval" agentsession "neo-code/internal/session" "neo-code/internal/tools" + memotool "neo-code/internal/tools/memo" tuibootstrap "neo-code/internal/tui/bootstrap" tuiservices "neo-code/internal/tui/services" tuistate "neo-code/internal/tui/state" @@ -103,6 +104,8 @@ type stubRuntime struct { prepareErr error preparedOutput agentruntime.UserInput runInputs []agentruntime.UserInput + systemToolCalls []agentruntime.SystemToolInput + systemToolFn func(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) resolveCalls []agentruntime.PermissionResolutionInput resolveErr error cancelInvoked bool @@ -164,6 +167,14 @@ func (s *stubRuntime) Compact(ctx context.Context, input agentruntime.CompactInp return agentruntime.CompactResult{}, nil } +func (s *stubRuntime) ExecuteSystemTool(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) { + s.systemToolCalls = append(s.systemToolCalls, input) + if s.systemToolFn != nil { + return s.systemToolFn(ctx, input) + } + return tools.ToolResult{}, nil +} + func (s *stubRuntime) ResolvePermission(ctx context.Context, input agentruntime.PermissionResolutionInput) error { s.resolveCalls = append(s.resolveCalls, input) return s.resolveErr @@ -2014,6 +2025,11 @@ func TestSetCurrentWorkdir(t *testing.T) { // newTestAppWithMemo 创建一个注入了 memo 服务的测试 App。 func newTestAppWithMemo(t *testing.T) (App, *stubRuntime) { t.Helper() + return newTestAppWithMemoBaseDir(t, t.TempDir()) +} + +func newTestAppWithMemoBaseDir(t *testing.T, memoBaseDir string) (App, *stubRuntime) { + t.Helper() cfg := newDefaultAppConfig() cfg.Workdir = t.TempDir() @@ -2039,10 +2055,22 @@ func newTestAppWithMemo(t *testing.T) (App, *stubRuntime) { } // 创建真实的 memo 服务 - memoStore := memo.NewFileStore(t.TempDir(), cfg.Workdir) - memoSvc := memo.NewService(memoStore, nil, cfg.Memo, nil) + memoStore := memo.NewFileStore(memoBaseDir, cfg.Workdir) + memoSvc := memo.NewService(memoStore, cfg.Memo, nil) runtime := newStubRuntime() + runtime.systemToolFn = func(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) { + switch input.ToolName { + case tools.ToolNameMemoList: + return memotool.NewListTool(memoSvc).Execute(ctx, tools.ToolCallInput{Arguments: input.Arguments}) + case tools.ToolNameMemoRemember: + return memotool.NewRememberTool(memoSvc).Execute(ctx, tools.ToolCallInput{Arguments: input.Arguments}) + case tools.ToolNameMemoRemove: + return memotool.NewRemoveTool(memoSvc).Execute(ctx, tools.ToolCallInput{Arguments: input.Arguments}) + default: + return tools.ToolResult{}, errors.New("unsupported system tool") + } + } app, err := newApp(tuibootstrap.Container{ Config: *cfg, ConfigManager: manager, @@ -2062,16 +2090,13 @@ func TestHandleMemoCommand(t *testing.T) { t.Run("shows no memos message when empty", func(t *testing.T) { app, _ := newTestAppWithMemo(t) cmd := app.handleMemoCommand() - if cmd != nil { - t.Error("expected nil cmd") + if cmd == nil { + t.Fatal("expected async cmd") } - msgs := app.activeMessages - if len(msgs) == 0 { - t.Fatal("expected at least one inline message") - } - last := msgs[len(msgs)-1] - if !strings.Contains(messageText(last), "No memos stored yet") { - t.Errorf("expected 'no memos' message, got: %s", messageText(last)) + model, _ := app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "No memos stored yet") { + t.Errorf("expected status to mention no memos, got: %s", app.state.StatusText) } }) @@ -2079,30 +2104,26 @@ func TestHandleMemoCommand(t *testing.T) { app, _ := newTestAppWithMemo(t) app.memoSvc.Add(context.Background(), memo.Entry{Type: memo.TypeUser, Title: "test entry", Content: "test", Source: memo.SourceUserManual}) - app.handleMemoCommand() - msgs := app.activeMessages - last := msgs[len(msgs)-1] - if !strings.Contains(messageText(last), "1 memo(s)") { - t.Errorf("expected memo count, got: %s", messageText(last)) - } - if !strings.Contains(messageText(last), "test entry") { - t.Errorf("expected entry title, got: %s", messageText(last)) + cmd := app.handleMemoCommand() + model, _ := app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "test entry") { + t.Errorf("expected status to include entry title, got: %s", app.state.StatusText) } }) - t.Run("nil memoSvc shows error", func(t *testing.T) { - app, _ := newTestApp(t) + t.Run("routes through runtime system tool", func(t *testing.T) { + app, runtime := newTestApp(t) cmd := app.handleMemoCommand() - if cmd != nil { - t.Error("expected nil cmd") + if cmd == nil { + t.Fatal("expected async cmd") } - msgs := app.activeMessages - if len(msgs) == 0 { - t.Fatal("expected at least one inline message") + _ = cmd() + if len(runtime.systemToolCalls) != 1 { + t.Fatalf("system tool calls = %d, want 1", len(runtime.systemToolCalls)) } - last := msgs[len(msgs)-1] - if !strings.Contains(messageText(last), "not enabled") { - t.Errorf("expected 'not enabled' message, got: %s", messageText(last)) + if runtime.systemToolCalls[0].ToolName != tools.ToolNameMemoList { + t.Fatalf("ToolName = %q, want %q", runtime.systemToolCalls[0].ToolName, tools.ToolNameMemoList) } }) } @@ -2113,16 +2134,16 @@ func TestHandleRememberCommand(t *testing.T) { t.Run("saves memo and shows confirmation", func(t *testing.T) { app, _ := newTestAppWithMemo(t) cmd := app.handleRememberCommand("my preference") - if cmd != nil { - t.Error("expected nil cmd") + if cmd == nil { + t.Fatal("expected async cmd") } - msgs := app.activeMessages - last := msgs[len(msgs)-1] - if !strings.Contains(messageText(last), "Memo saved") { - t.Errorf("expected saved confirmation, got: %s", messageText(last)) + model, _ := app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "Memory saved") { + t.Errorf("expected saved confirmation, got: %s", app.state.StatusText) } // Verify the entry was actually saved - entries, _ := app.memoSvc.List(context.Background()) + entries, _ := app.memoSvc.List(context.Background(), memo.ScopeUser) if len(entries) != 1 { t.Fatalf("expected 1 entry, got %d", len(entries)) } @@ -2131,6 +2152,29 @@ func TestHandleRememberCommand(t *testing.T) { } }) + t.Run("user memo is visible from another workspace", func(t *testing.T) { + baseDir := t.TempDir() + app, _ := newTestAppWithMemoBaseDir(t, baseDir) + cmd := app.handleRememberCommand("global preference") + if cmd == nil { + t.Fatal("expected async cmd") + } + model, _ := app.Update(cmd()) + app = model.(App) + + otherSvc := memo.NewService(memo.NewFileStore(baseDir, t.TempDir()), newDefaultAppConfig().Memo, nil) + entries, err := otherSvc.List(context.Background(), memo.ScopeUser) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 shared user memo, got %d", len(entries)) + } + if entries[0].Title != "global preference" { + t.Fatalf("shared entry title = %q, want %q", entries[0].Title, "global preference") + } + }) + t.Run("empty text shows usage", func(t *testing.T) { app, _ := newTestAppWithMemo(t) app.handleRememberCommand("") @@ -2151,13 +2195,18 @@ func TestHandleRememberCommand(t *testing.T) { } }) - t.Run("nil memoSvc shows error", func(t *testing.T) { - app, _ := newTestApp(t) - app.handleRememberCommand("something") - msgs := app.activeMessages - last := msgs[len(msgs)-1] - if !strings.Contains(messageText(last), "not enabled") { - t.Errorf("expected 'not enabled' message, got: %s", messageText(last)) + t.Run("routes through runtime system tool", func(t *testing.T) { + app, runtime := newTestApp(t) + cmd := app.handleRememberCommand("something") + if cmd == nil { + t.Fatal("expected async cmd") + } + _ = cmd() + if len(runtime.systemToolCalls) != 1 { + t.Fatalf("system tool calls = %d, want 1", len(runtime.systemToolCalls)) + } + if runtime.systemToolCalls[0].ToolName != tools.ToolNameMemoRemember { + t.Fatalf("ToolName = %q, want %q", runtime.systemToolCalls[0].ToolName, tools.ToolNameMemoRemember) } }) } @@ -2170,14 +2219,14 @@ func TestHandleForgetCommand(t *testing.T) { app.memoSvc.Add(context.Background(), memo.Entry{Type: memo.TypeUser, Title: "remove me", Content: "test", Source: memo.SourceUserManual}) app.memoSvc.Add(context.Background(), memo.Entry{Type: memo.TypeFeedback, Title: "keep this", Content: "test2", Source: memo.SourceUserManual}) - app.handleForgetCommand("remove") - msgs := app.activeMessages - last := msgs[len(msgs)-1] - if !strings.Contains(messageText(last), "Removed 1 memo") { - t.Errorf("expected removal confirmation, got: %s", messageText(last)) + cmd := app.handleForgetCommand("remove") + model, _ := app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "Removed 1 memo") { + t.Errorf("expected removal confirmation, got: %s", app.state.StatusText) } // Verify only one was removed - entries, _ := app.memoSvc.List(context.Background()) + entries, _ := app.memoSvc.List(context.Background(), memo.ScopeAll) if len(entries) != 1 { t.Fatalf("expected 1 remaining entry, got %d", len(entries)) } @@ -2188,11 +2237,11 @@ func TestHandleForgetCommand(t *testing.T) { t.Run("no match shows message", func(t *testing.T) { app, _ := newTestAppWithMemo(t) - app.handleForgetCommand("nonexistent") - msgs := app.activeMessages - last := msgs[len(msgs)-1] - if !strings.Contains(messageText(last), "No memos matching") { - t.Errorf("expected no match message, got: %s", messageText(last)) + cmd := app.handleForgetCommand("nonexistent") + model, _ := app.Update(cmd()) + app = model.(App) + if !strings.Contains(app.state.StatusText, "No memos matching") { + t.Errorf("expected no match message, got: %s", app.state.StatusText) } }) @@ -2206,13 +2255,18 @@ func TestHandleForgetCommand(t *testing.T) { } }) - t.Run("nil memoSvc shows error", func(t *testing.T) { - app, _ := newTestApp(t) - app.handleForgetCommand("something") - msgs := app.activeMessages - last := msgs[len(msgs)-1] - if !strings.Contains(messageText(last), "not enabled") { - t.Errorf("expected 'not enabled' message, got: %s", messageText(last)) + t.Run("routes through runtime system tool", func(t *testing.T) { + app, runtime := newTestApp(t) + cmd := app.handleForgetCommand("something") + if cmd == nil { + t.Fatal("expected async cmd") + } + _ = cmd() + if len(runtime.systemToolCalls) != 1 { + t.Fatalf("system tool calls = %d, want 1", len(runtime.systemToolCalls)) + } + if runtime.systemToolCalls[0].ToolName != tools.ToolNameMemoRemove { + t.Fatalf("ToolName = %q, want %q", runtime.systemToolCalls[0].ToolName, tools.ToolNameMemoRemove) } }) } diff --git a/internal/tui/services/runtime_service.go b/internal/tui/services/runtime_service.go index 16d5caa1..dc879987 100644 --- a/internal/tui/services/runtime_service.go +++ b/internal/tui/services/runtime_service.go @@ -7,6 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea" agentruntime "neo-code/internal/runtime" + "neo-code/internal/tools" ) const permissionResolveTimeout = 10 * time.Second @@ -27,6 +28,11 @@ type Compactor interface { Compact(ctx context.Context, input agentruntime.CompactInput) (agentruntime.CompactResult, error) } +// SystemToolRunner 定义执行 runtime 系统工具入口所需最小能力。 +type SystemToolRunner interface { + ExecuteSystemTool(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) +} + // PermissionResolver 定义权限审批提交所需最小能力。 type PermissionResolver interface { ResolvePermission(ctx context.Context, input agentruntime.PermissionResolutionInput) error @@ -80,6 +86,18 @@ func RunCompactCmd( } } +// RunSystemToolCmd 执行 runtime 系统工具入口,并将结果映射为 UI 消息。 +func RunSystemToolCmd( + runtime SystemToolRunner, + input agentruntime.SystemToolInput, + doneMsg func(tools.ToolResult, error) tea.Msg, +) tea.Cmd { + return func() tea.Msg { + result, err := runtime.ExecuteSystemTool(context.Background(), input) + return doneMsg(result, err) + } +} + // RunResolvePermissionCmd 提交权限审批决定,并将结果映射为 UI 消息。 func RunResolvePermissionCmd( runtime PermissionResolver,