Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/app/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func newMemoExtractorAdapter(
})
})

scheduler.ScheduleWithExtractor(sessionID, messages, memo.NewLLMExtractor(generator, cfg.Memo.ExtractRecentMessages))
scheduler.ScheduleWithExtractor(sessionID, messages, memo.NewLLMExtractor(generator))
})
}

Expand Down
41 changes: 41 additions & 0 deletions internal/app/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,47 @@ func TestNewMemoExtractorAdapterBuildsProviderSafeMemoWindow(t *testing.T) {
}
}

func TestNewMemoExtractorAdapterUsesFullRunMemoWindow(t *testing.T) {
t.Setenv(config.OpenAIDefaultAPIKeyEnv, "token")
cfg := config.StaticDefaults().Clone()
cfg.SelectedProvider = config.OpenAIName
manager := config.NewManager(config.NewLoader("", &cfg))

providerStub := &stubMemoProvider{
generate: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error {
if len(req.Messages) != 12 {
t.Fatalf("unexpected memo window length %d, want full run: %+v", len(req.Messages), req.Messages)
}
events <- providertypes.NewTextDeltaStreamEvent(`[]`)
events <- providertypes.NewMessageDoneStreamEvent("stop", nil)
return nil
},
}
factory := &stubMemoProviderFactory{provider: providerStub}
scheduler := &stubMemoExtractorScheduler{}
extractor := newMemoExtractorAdapter(factory, manager, scheduler)

inputMessages := make([]providertypes.Message, 0, 12)
for index := 0; index < 12; index++ {
inputMessages = append(inputMessages, providertypes.Message{
Role: providertypes.RoleUser,
Parts: []providertypes.ContentPart{providertypes.NewTextPart(fmt.Sprintf("message-%02d", index))},
})
}
extractor.Schedule("session-1", inputMessages)
if !scheduler.called || scheduler.extractor == nil {
t.Fatalf("expected scheduler to receive extractor")
}

_, err := scheduler.extractor.Extract(context.Background(), inputMessages)
if err != nil {
t.Fatalf("extractor.Extract() error = %v", err)
}
if !factory.called {
t.Fatalf("expected provider factory Build to be called")
}
}

func TestNewMemoExtractorAdapterKeepsScheduledConfigSnapshot(t *testing.T) {
t.Setenv(config.OpenAIDefaultAPIKeyEnv, "openai-token")
t.Setenv(config.QiniuDefaultAPIKeyEnv, "qiniu-token")
Expand Down
46 changes: 16 additions & 30 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1495,12 +1495,11 @@ func TestMemoConfigClone(t *testing.T) {
t.Parallel()

original := MemoConfig{
Enabled: true,
AutoExtract: false,
MaxEntries: 100,
MaxIndexBytes: 2048,
ExtractTimeoutSec: 9,
ExtractRecentMessages: 3,
Enabled: true,
AutoExtract: false,
MaxEntries: 100,
MaxIndexBytes: 2048,
ExtractTimeoutSec: 9,
}
cloned := original.Clone()
if cloned != original {
Expand All @@ -1518,10 +1517,9 @@ func TestMemoConfigApplyDefaults(t *testing.T) {
t.Run("fills zero fields", func(t *testing.T) {
cfg := MemoConfig{}
cfg.ApplyDefaults(MemoConfig{
MaxEntries: DefaultMemoMaxEntries,
MaxIndexBytes: DefaultMemoMaxIndexBytes,
ExtractTimeoutSec: DefaultMemoExtractTimeoutSec,
ExtractRecentMessages: DefaultMemoExtractRecentMessage,
MaxEntries: DefaultMemoMaxEntries,
MaxIndexBytes: DefaultMemoMaxIndexBytes,
ExtractTimeoutSec: DefaultMemoExtractTimeoutSec,
})
if cfg.MaxEntries != DefaultMemoMaxEntries {
t.Errorf("MaxEntries = %d, want %d", cfg.MaxEntries, DefaultMemoMaxEntries)
Expand All @@ -1532,33 +1530,28 @@ func TestMemoConfigApplyDefaults(t *testing.T) {
if cfg.ExtractTimeoutSec != DefaultMemoExtractTimeoutSec {
t.Errorf("ExtractTimeoutSec = %d, want %d", cfg.ExtractTimeoutSec, DefaultMemoExtractTimeoutSec)
}
if cfg.ExtractRecentMessages != DefaultMemoExtractRecentMessage {
t.Errorf("ExtractRecentMessages = %d, want %d", cfg.ExtractRecentMessages, DefaultMemoExtractRecentMessage)
}
})

t.Run("preserves explicit fields", func(t *testing.T) {
cfg := MemoConfig{
MaxEntries: 50,
MaxIndexBytes: 1024,
ExtractTimeoutSec: 30,
ExtractRecentMessages: 5,
MaxEntries: 50,
MaxIndexBytes: 1024,
ExtractTimeoutSec: 30,
}
cfg.ApplyDefaults(defaultMemoConfig())
if cfg.MaxEntries != 50 || cfg.MaxIndexBytes != 1024 || cfg.ExtractTimeoutSec != 30 || cfg.ExtractRecentMessages != 5 {
if cfg.MaxEntries != 50 || cfg.MaxIndexBytes != 1024 || cfg.ExtractTimeoutSec != 30 {
t.Fatalf("ApplyDefaults() unexpectedly overwrote explicit values: %+v", cfg)
}
})

t.Run("preserves negative fields for validation", func(t *testing.T) {
cfg := MemoConfig{
MaxEntries: -1,
MaxIndexBytes: -2,
ExtractTimeoutSec: -3,
ExtractRecentMessages: -4,
MaxEntries: -1,
MaxIndexBytes: -2,
ExtractTimeoutSec: -3,
}
cfg.ApplyDefaults(defaultMemoConfig())
if cfg.MaxEntries != -1 || cfg.MaxIndexBytes != -2 || cfg.ExtractTimeoutSec != -3 || cfg.ExtractRecentMessages != -4 {
if cfg.MaxEntries != -1 || cfg.MaxIndexBytes != -2 || cfg.ExtractTimeoutSec != -3 {
t.Fatalf("ApplyDefaults() unexpectedly rewrote invalid values: %+v", cfg)
}
})
Expand Down Expand Up @@ -1603,13 +1596,6 @@ func TestMemoConfigValidate(t *testing.T) {
}
})

t.Run("non-positive ExtractRecentMessages", func(t *testing.T) {
cfg := defaultMemoConfig()
cfg.ExtractRecentMessages = 0
if err := cfg.Validate(); err == nil {
t.Fatal("non-positive ExtractRecentMessages should fail validation")
}
})
}

func TestNormalizeWorkdirEdgeCases(t *testing.T) {
Expand Down
63 changes: 47 additions & 16 deletions internal/config/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,11 @@ type persistedAskConfig struct {
}

type persistedMemoConfig struct {
Enabled *bool `yaml:"enabled,omitempty"`
AutoExtract *bool `yaml:"auto_extract,omitempty"`
MaxEntries *int `yaml:"max_entries,omitempty"`
MaxIndexBytes *int `yaml:"max_index_bytes,omitempty"`
ExtractTimeoutSec *int `yaml:"extract_timeout_sec,omitempty"`
ExtractRecentMessages *int `yaml:"extract_recent_messages,omitempty"`
Enabled *bool `yaml:"enabled,omitempty"`
AutoExtract *bool `yaml:"auto_extract,omitempty"`
MaxEntries *int `yaml:"max_entries,omitempty"`
MaxIndexBytes *int `yaml:"max_index_bytes,omitempty"`
ExtractTimeoutSec *int `yaml:"extract_timeout_sec,omitempty"`
}

func NewLoader(baseDir string, defaults *Config) *Loader {
Expand Down Expand Up @@ -225,6 +224,9 @@ func parseConfigWithContextDefaults(
}

func parseCurrentConfig(data []byte, contextDefaults ContextConfig, memoDefaults MemoConfig) (*Config, error) {
if err := rejectRemovedMemoFields(data); err != nil {
return nil, err
}
var file persistedConfig
decoder := yaml.NewDecoder(bytes.NewReader(data))
decoder.KnownFields(true)
Expand Down Expand Up @@ -384,14 +386,12 @@ func newPersistedMemoConfig(cfg MemoConfig) persistedMemoConfig {
maxEntries := cfg.MaxEntries
maxIndexBytes := cfg.MaxIndexBytes
extractTimeoutSec := cfg.ExtractTimeoutSec
extractRecentMessages := cfg.ExtractRecentMessages
return persistedMemoConfig{
Enabled: &enabled,
AutoExtract: &autoExtract,
MaxEntries: &maxEntries,
MaxIndexBytes: &maxIndexBytes,
ExtractTimeoutSec: &extractTimeoutSec,
ExtractRecentMessages: &extractRecentMessages,
Enabled: &enabled,
AutoExtract: &autoExtract,
MaxEntries: &maxEntries,
MaxIndexBytes: &maxIndexBytes,
ExtractTimeoutSec: &extractTimeoutSec,
}
}

Expand All @@ -413,12 +413,43 @@ func fromPersistedMemoConfig(file persistedMemoConfig, defaults MemoConfig) Memo
if file.ExtractTimeoutSec != nil {
out.ExtractTimeoutSec = *file.ExtractTimeoutSec
}
if file.ExtractRecentMessages != nil {
out.ExtractRecentMessages = *file.ExtractRecentMessages
}
return out
}

// rejectRemovedMemoFields 在 strict decode 前拦截已删除的 memo 字段,输出明确迁移提示。
func rejectRemovedMemoFields(data []byte) error {
var root yaml.Node
if err := yaml.Unmarshal(data, &root); err != nil {
return err
}
if len(root.Content) == 0 {
return nil
}
doc := root.Content[0]
if doc.Kind != yaml.MappingNode {
return nil
}

for i := 0; i < len(doc.Content); i += 2 {
if strings.TrimSpace(doc.Content[i].Value) != "memo" {
continue
}
memoNode := doc.Content[i+1]
if memoNode.Kind != yaml.MappingNode {
return nil
}
for j := 0; j < len(memoNode.Content); j += 2 {
if strings.TrimSpace(memoNode.Content[j].Value) == "extract_recent_messages" {
return fmt.Errorf(
"config: memo.extract_recent_messages has been removed; memory extraction now always uses the full run boundary",
)
}
}
return nil
}
return nil
}

// normalizeVerificationSchemaContent 在内存中预处理 verification schema,避免旧字段先于 strict decode 触发硬失败。
func normalizeVerificationSchemaContent(raw []byte) ([]byte, bool, error) {
if len(bytes.TrimSpace(raw)) == 0 {
Expand Down
34 changes: 22 additions & 12 deletions internal/config/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1930,7 +1930,6 @@ memo:
max_entries: 123
max_index_bytes: 4096
extract_timeout_sec: 9
extract_recent_messages: 4
`
writeLoaderConfig(t, loader, raw)

Expand All @@ -1953,9 +1952,6 @@ memo:
if cfg.Memo.ExtractTimeoutSec != 9 {
t.Fatalf("expected memo.extract_timeout_sec=9, got %d", cfg.Memo.ExtractTimeoutSec)
}
if cfg.Memo.ExtractRecentMessages != 4 {
t.Fatalf("expected memo.extract_recent_messages=4, got %d", cfg.Memo.ExtractRecentMessages)
}

data, err := os.ReadFile(loader.ConfigPath())
if err != nil {
Expand Down Expand Up @@ -2001,9 +1997,6 @@ shell: powershell
if cfg.Memo.ExtractTimeoutSec <= 0 {
t.Fatalf("expected memo.extract_timeout_sec to be defaulted, got %d", cfg.Memo.ExtractTimeoutSec)
}
if cfg.Memo.ExtractRecentMessages <= 0 {
t.Fatalf("expected memo.extract_recent_messages to be defaulted, got %d", cfg.Memo.ExtractRecentMessages)
}
}

func TestLoaderRejectsLegacyMemoMaxIndexLinesField(t *testing.T) {
Expand All @@ -2028,6 +2021,28 @@ memo:
}
}

func TestLoaderRejectsRemovedMemoExtractRecentMessagesField(t *testing.T) {
t.Parallel()

loader := NewLoader(t.TempDir(), testDefaultConfig())
raw := `
selected_provider: openai
current_model: gpt-4.1
shell: powershell
memo:
extract_recent_messages: 4
`
writeLoaderConfig(t, loader, raw)

cfg, err := loader.Load(context.Background())
if err == nil {
t.Fatalf("expected removed memo field to be rejected, cfg=%+v", cfg)
}
if !strings.Contains(err.Error(), "memo.extract_recent_messages has been removed") {
t.Fatalf("expected migration hint for extract_recent_messages, got %v", err)
}
}

func TestLoaderRejectsExplicitInvalidMemoNumbers(t *testing.T) {
t.Parallel()

Expand All @@ -2051,11 +2066,6 @@ func TestLoaderRejectsExplicitInvalidMemoNumbers(t *testing.T) {
fieldYAML: "extract_timeout_sec: -1",
errContain: "config: memo: extract_timeout_sec must be greater than 0",
},
{
name: "negative extract_recent_messages",
fieldYAML: "extract_recent_messages: -1",
errContain: "config: memo: extract_recent_messages must be greater than 0",
},
}

for _, tt := range tests {
Expand Down
35 changes: 13 additions & 22 deletions internal/config/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,28 @@ package config
import "errors"

const (
DefaultMemoMaxEntries = 200
DefaultMemoMaxIndexBytes = 16 * 1024
DefaultMemoExtractTimeoutSec = 15
DefaultMemoExtractRecentMessage = 10
DefaultMemoMaxEntries = 200
DefaultMemoMaxIndexBytes = 16 * 1024
DefaultMemoExtractTimeoutSec = 15
)

// MemoConfig 控制跨会话持久记忆的行为配置。
type MemoConfig struct {
Enabled bool `yaml:"enabled,omitempty"`
AutoExtract bool `yaml:"auto_extract,omitempty"`
MaxEntries int `yaml:"max_entries,omitempty"`
MaxIndexBytes int `yaml:"max_index_bytes,omitempty"`
ExtractTimeoutSec int `yaml:"extract_timeout_sec,omitempty"`
ExtractRecentMessages int `yaml:"extract_recent_messages,omitempty"`
Enabled bool `yaml:"enabled,omitempty"`
AutoExtract bool `yaml:"auto_extract,omitempty"`
MaxEntries int `yaml:"max_entries,omitempty"`
MaxIndexBytes int `yaml:"max_index_bytes,omitempty"`
ExtractTimeoutSec int `yaml:"extract_timeout_sec,omitempty"`
}

// defaultMemoConfig 返回跨会话记忆的默认配置。
func defaultMemoConfig() MemoConfig {
return MemoConfig{
Enabled: true,
AutoExtract: true,
MaxEntries: DefaultMemoMaxEntries,
MaxIndexBytes: DefaultMemoMaxIndexBytes,
ExtractTimeoutSec: DefaultMemoExtractTimeoutSec,
ExtractRecentMessages: DefaultMemoExtractRecentMessage,
Enabled: true,
AutoExtract: true,
MaxEntries: DefaultMemoMaxEntries,
MaxIndexBytes: DefaultMemoMaxIndexBytes,
ExtractTimeoutSec: DefaultMemoExtractTimeoutSec,
}
}

Expand All @@ -50,9 +47,6 @@ func (c *MemoConfig) ApplyDefaults(defaults MemoConfig) {
if c.ExtractTimeoutSec == 0 {
c.ExtractTimeoutSec = defaults.ExtractTimeoutSec
}
if c.ExtractRecentMessages == 0 {
c.ExtractRecentMessages = defaults.ExtractRecentMessages
}
}

// Validate 校验 memo 配置是否合法。
Expand All @@ -66,8 +60,5 @@ func (c MemoConfig) Validate() error {
if c.ExtractTimeoutSec <= 0 {
return errors.New("extract_timeout_sec must be greater than 0")
}
if c.ExtractRecentMessages <= 0 {
return errors.New("extract_recent_messages must be greater than 0")
}
return nil
}
Loading
Loading