From 5655a763689874afa62df662b332986c7b505775 Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Wed, 29 Apr 2026 00:31:17 +0800 Subject: [PATCH 1/2] =?UTF-8?q?refactor(config):=E6=94=B6=E6=95=9B?= =?UTF-8?q?=E5=86=85=E7=BD=AE=E6=A8=A1=E5=9E=8B=E6=9D=A5=E6=BA=90=E5=B9=B6?= =?UTF-8?q?=E7=AE=80=E5=8C=96=E8=87=AA=E5=AE=9A=E4=B9=89=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/config-management-detail-design.md | 11 +- internal/config/config_test.go | 1 + internal/config/loader_test.go | 61 ++---- internal/config/provider.go | 201 ++++++++++++++++-- internal/config/provider_custom_normalize.go | 51 ++--- internal/config/provider_loader.go | 39 +--- internal/config/provider_test.go | 151 ++++++++++--- internal/config/state/model.go | 11 +- .../config/state/model_additional_test.go | 11 + internal/config/state/model_test.go | 9 +- internal/config/state/service.go | 49 +++++ .../config/state/service_provider_create.go | 24 +-- .../state/service_provider_create_test.go | 3 +- internal/config/state/service_test.go | 27 +++ internal/provider/catalog/additional_test.go | 45 +--- internal/provider/catalog/service.go | 120 +++-------- internal/provider/catalog/service_test.go | 197 +++++------------ internal/provider/catalog/store.go | 7 +- internal/provider/catalog/store_test.go | 20 +- .../runtime/subagent_more_branches_test.go | 6 + internal/tui/core/app/update.go | 124 +++++------ internal/tui/core/app/update_test.go | 6 +- 22 files changed, 617 insertions(+), 557 deletions(-) diff --git a/docs/config-management-detail-design.md b/docs/config-management-detail-design.md index 27b98b50..4500aa1e 100644 --- a/docs/config-management-detail-design.md +++ b/docs/config-management-detail-design.md @@ -138,11 +138,12 @@ custom provider 来自: ## custom provider `models` 校验约束 -`~/.neocode/providers//provider.yaml` 中允许通过 `models` 补齐模型元数据,用于 catalog/discovery 无法提供完整 
`ContextWindow` 或 `MaxOutputTokens` 的场景。 +`~/.neocode/providers//provider.yaml` 中的 `models` 现在只表达“用户声明需要出现的模型 ID / 展示名”,不再承担元数据补齐职责。 -该能力的约束是: +当前约束是: - `models[].id` 必须非空。 -- `models[].context_window` 和 `models[].max_output_tokens` 如果显式提供,必须大于 `0`。 -- 重复的模型 `id` 会在加载 custom provider 时直接失败,不保留 silently drop 的宽松行为。 -- 这些元数据不会写回 `config.yaml`,只在 custom provider 文件中声明,并通过现有 catalog 合并链路参与运行时解析。 +- `models[].name` 必须非空。 +- `models` 中不允许出现 `context_window`、`max_output_tokens`、`description`、`capability_hints`。 +- discovery 缓存只保存规范化后的白名单字段:`id/name/description/context_window/max_output_tokens/capability_hints`。 +- builtin provider 不再走 `/models` discovery,模型清单改为仓库内静态维护。 diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 572dac7c..05c0579b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -26,6 +26,7 @@ func testDefaultProviderConfig() ProviderConfig { BaseURL: testBaseURL, Model: testModel, APIKeyEnv: testAPIKeyEnv, + Models: cloneBuiltinModels(openAIStaticModels), Source: ProviderSourceBuiltin, } } diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 2a79f5f6..77cec06f 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -493,8 +493,6 @@ api_key_env: COMPANY_GATEWAY_API_KEY models: - id: deepseek-coder name: DeepSeek Coder - context_window: 131072 - max_output_tokens: 8192 ` customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { @@ -532,10 +530,10 @@ models: t.Fatalf("expected custom provider default model to be empty, got %q", customProvider.Model) } if len(customProvider.Models) != 1 { - t.Fatalf("expected custom provider model metadata from provider.yaml, got %+v", customProvider.Models) + t.Fatalf("expected custom provider models from provider.yaml, got %+v", 
customProvider.Models) } - if customProvider.Models[0].ID != "deepseek-coder" || customProvider.Models[0].ContextWindow != 131072 { - t.Fatalf("expected parsed model metadata, got %+v", customProvider.Models[0]) + if customProvider.Models[0].ID != "deepseek-coder" || customProvider.Models[0].ContextWindow != 0 { + t.Fatalf("expected parsed id/name only model, got %+v", customProvider.Models[0]) } } @@ -815,7 +813,7 @@ models: } } -func TestLoaderRejectsCustomProviderModelWithInvalidContextWindow(t *testing.T) { +func TestLoaderRejectsCustomProviderModelWithUnsupportedContextWindow(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) @@ -832,19 +830,19 @@ api_key_env: COMPANY_GATEWAY_API_KEY models: - id: deepseek-coder name: DeepSeek Coder - context_window: 0 + context_window: 131072 ` if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { t.Fatalf("write provider.yaml: %v", err) } _, err := loader.Load(context.Background()) - if err == nil || !strings.Contains(err.Error(), "context_window") { - t.Fatalf("expected invalid context_window rejection, got %v", err) + if err == nil || !strings.Contains(err.Error(), "field context_window not found") { + t.Fatalf("expected unknown context_window rejection, got %v", err) } } -func TestLoaderRejectsCustomProviderModelWithInvalidMaxOutputTokens(t *testing.T) { +func TestLoaderRejectsCustomProviderModelWithUnsupportedMaxOutputTokens(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) @@ -861,15 +859,15 @@ api_key_env: COMPANY_GATEWAY_API_KEY models: - id: deepseek-coder name: DeepSeek Coder - max_output_tokens: 0 + max_output_tokens: 8192 ` if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { t.Fatalf("write provider.yaml: %v", err) } _, err := loader.Load(context.Background()) - if err == nil 
|| !strings.Contains(err.Error(), "max_output_tokens") { - t.Fatalf("expected invalid max_output_tokens rejection, got %v", err) + if err == nil || !strings.Contains(err.Error(), "field max_output_tokens not found") { + t.Fatalf("expected unknown max_output_tokens rejection, got %v", err) } } @@ -1327,10 +1325,8 @@ func TestSaveCustomProviderAndLoadCustomProviderStayConsistent(t *testing.T) { DiscoveryEndpointPath: "/should-be-cleared", Models: []providertypes.ModelDescriptor{ { - ID: "manual-model-1", - Name: "Manual Model 1", - ContextWindow: 131072, - MaxOutputTokens: 8192, + ID: "manual-model-1", + Name: "Manual Model 1", }, }, }, @@ -1387,7 +1383,7 @@ func TestSaveCustomProviderAndLoadCustomProviderStayConsistent(t *testing.T) { } } -func TestSaveCustomProviderManualModelsPersistOptionalFields(t *testing.T) { +func TestSaveCustomProviderManualModelsPersistIDAndNameOnly(t *testing.T) { t.Parallel() baseDir := t.TempDir() @@ -1404,10 +1400,8 @@ func TestSaveCustomProviderManualModelsPersistOptionalFields(t *testing.T) { Name: "Manual Model 1", }, { - ID: "manual-model-2", - Name: "Manual Model 2", - ContextWindow: 131072, - MaxOutputTokens: 8192, + ID: "manual-model-2", + Name: "Manual Model 2", }, }, }) @@ -1428,11 +1422,8 @@ func TestSaveCustomProviderManualModelsPersistOptionalFields(t *testing.T) { if len(cfg.Models) != 2 { t.Fatalf("expected model list with 2 entries, got %+v", cfg.Models) } - if cfg.Models[0].ContextWindow != 0 || cfg.Models[0].MaxOutputTokens != 0 { - t.Fatalf("expected optional fields omitted for model-1, got %+v", cfg.Models[0]) - } - if cfg.Models[1].ContextWindow != 131072 || cfg.Models[1].MaxOutputTokens != 8192 { - t.Fatalf("expected optional fields persisted for model-2, got %+v", cfg.Models[1]) + if cfg.Models[0].ContextWindow != 0 || cfg.Models[0].MaxOutputTokens != 0 || cfg.Models[1].ContextWindow != 0 || cfg.Models[1].MaxOutputTokens != 0 { + t.Fatalf("expected persisted manual models to omit metadata, got %+v", cfg.Models) } } 
@@ -1641,10 +1632,8 @@ func TestToCustomProviderModelFiles(t *testing.T) { Name: "Model A", }, { - ID: "model-b", - Name: "Model B", - ContextWindow: 32768, - MaxOutputTokens: 2048, + ID: "model-b", + Name: "Model B", }, { ID: "Model-A", @@ -1657,14 +1646,8 @@ func TestToCustomProviderModelFiles(t *testing.T) { if converted[0].ID != "model-a" || converted[0].Name != "Model A" { t.Fatalf("expected normalized merge result for model-a, got %+v", converted[0]) } - if converted[0].ContextWindow != nil || converted[0].MaxOutputTokens != nil { - t.Fatalf("expected model-a optional pointers nil, got %+v", converted[0]) - } - if converted[1].ContextWindow == nil || *converted[1].ContextWindow != 32768 { - t.Fatalf("expected model-b context window pointer, got %+v", converted[1]) - } - if converted[1].MaxOutputTokens == nil || *converted[1].MaxOutputTokens != 2048 { - t.Fatalf("expected model-b max output tokens pointer, got %+v", converted[1]) + if converted[1].ID != "model-b" || converted[1].Name != "Model B" { + t.Fatalf("expected id/name only persistence for model-b, got %+v", converted[1]) } } diff --git a/internal/config/provider.go b/internal/config/provider.go index f05bcd55..8a758d5d 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -85,6 +85,9 @@ func (p ProviderConfig) Validate() error { if p.Source != ProviderSourceCustom && strings.TrimSpace(p.Model) == "" { return fmt.Errorf("provider %q model is empty", p.Name) } + if p.Source == ProviderSourceBuiltin && len(p.Models) > 0 && !containsModelDescriptorID(p.Models, p.Model) { + return fmt.Errorf("provider %q model %q must exist in builtin models", p.Name, p.Model) + } if strings.TrimSpace(p.APIKeyEnv) == "" { return fmt.Errorf("provider %q api_key_env is empty", p.Name) } @@ -210,6 +213,20 @@ func containsProviderName(providers []ProviderConfig, name string) bool { return false } +// containsModelDescriptorID 判断模型列表中是否包含指定 ID。 +func containsModelDescriptorID(models 
[]providertypes.ModelDescriptor, modelID string) bool { + target := provider.NormalizeKey(modelID) + if target == "" { + return false + } + for _, model := range models { + if provider.NormalizeKey(model.ID) == target { + return true + } + } + return false +} + // normalizeConfigKey 统一规范 config 层比较使用的字符串键,避免大小写和空白造成分支漂移。 func normalizeConfigKey(value string) string { return provider.NormalizeKey(value) @@ -415,53 +432,191 @@ const ( ModelScopeName = "modelscope" ModelScopeDefaultBaseURL = "https://api-inference.modelscope.cn/v1" - ModelScopeDefaultModel = "Qwen/Qwen2.5-7B-Instruct" + ModelScopeDefaultModel = "deepseek-ai/DeepSeek-V3.2" ModelScopeDefaultAPIKeyEnv = "MODELSCOPE_API_KEY" ) +var openAIStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "gpt-5.4", + "GPT-5.4", + "Flagship GPT-5 model for reasoning, coding, and multimodal agent workflows.", + 400000, + 128000, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateSupported), + ), + builtinModel( + "gpt-5.4-mini", + "GPT-5.4 Mini", + "Lower-latency GPT-5 variant for everyday coding, chat, and multimodal tasks.", + 400000, + 128000, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateSupported), + ), + builtinModel( + "gpt-5.3-codex", + "GPT-5.3 Codex", + "GPT-5 Codex family model tuned for code generation, editing, and agentic development loops.", + 400000, + 128000, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateSupported), + ), + builtinModel( + "gpt-4.1", + "GPT-4.1", + "High-capability GPT-4.1 model for complex coding and long-context multimodal work.", + 1047576, + 32768, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateSupported), + ), + builtinModel( + "gpt-4o", + "GPT-4o", + "General-purpose GPT-4o omni model for realtime, text, and image workflows.", + 128000, + 16384, + 
builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateSupported), + ), + builtinModel( + "gpt-4o-mini", + "GPT-4o Mini", + "Cost-efficient GPT-4o variant for fast multimodal and tool-using tasks.", + 128000, + 16384, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateSupported), + ), +} + +var geminiStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "gemini-2.5-flash", + "Gemini 2.5 Flash", + "Fast Gemini 2.5 model with long-context multimodal input and tool support.", + 1048576, + 65536, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateSupported), + ), + builtinModel( + "gemini-2.5-pro", + "Gemini 2.5 Pro", + "High-reasoning Gemini 2.5 model with long-context multimodal input and tool support.", + 1048576, + 65536, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateSupported), + ), +} + +var qiniuStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "z-ai/glm-5.1", + "GLM 5.1", + "GLM-5.1 model exposed by the Qiniu gateway for long-context reasoning and tool-using tasks.", + 200000, + 128000, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateUnsupported), + ), +} + +var modelScopeStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "deepseek-ai/DeepSeek-V3.2", + "DeepSeek V3.2", + "Reasoning-first DeepSeek model available from ModelScope API inference with tool use support.", + 128000, + 8192, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateUnsupported), + ), + builtinModel( + "MiniMax/MiniMax-M2.5", + "MiniMax M2.5", + "General-purpose MiniMax model available from ModelScope API inference for coding and agent workflows.", + 204800, + 0, + builtinCapabilities(providertypes.ModelCapabilityStateSupported, 
providertypes.ModelCapabilityStateUnsupported), + ), +} + +// builtinCapabilities 构造内建静态模型的能力提示,显式表达支持、未知或不支持状态。 +func builtinCapabilities( + toolCalling providertypes.ModelCapabilityState, + imageInput providertypes.ModelCapabilityState, +) providertypes.ModelCapabilityHints { + return providertypes.ModelCapabilityHints{ + ToolCalling: toolCalling, + ImageInput: imageInput, + } +} + +// builtinModel 构造内建 provider 使用的静态模型条目。 +func builtinModel( + id string, + name string, + description string, + contextWindow int, + maxOutputTokens int, + capabilityHints providertypes.ModelCapabilityHints, +) providertypes.ModelDescriptor { + return providertypes.ModelDescriptor{ + ID: strings.TrimSpace(id), + Name: strings.TrimSpace(name), + Description: strings.TrimSpace(description), + ContextWindow: contextWindow, + MaxOutputTokens: maxOutputTokens, + CapabilityHints: capabilityHints, + } +} + +// cloneBuiltinModels 返回静态模型清单的独立副本,避免不同配置快照共享底层切片。 +func cloneBuiltinModels(models []providertypes.ModelDescriptor) []providertypes.ModelDescriptor { + return providertypes.CloneModelDescriptors(models) +} + func newBuiltinOpenAICompatProvider(name, baseURL, model, apiKeyEnv string) ProviderConfig { return ProviderConfig{ - Name: name, - Driver: provider.DriverOpenAICompat, - BaseURL: baseURL, - Model: model, - APIKeyEnv: apiKeyEnv, - ModelSource: ModelSourceDiscover, - ChatAPIMode: provider.ChatAPIModeChatCompletions, - ChatEndpointPath: "/chat/completions", - DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, - Source: ProviderSourceBuiltin, + Name: name, + Driver: provider.DriverOpenAICompat, + BaseURL: baseURL, + Model: model, + APIKeyEnv: apiKeyEnv, + ChatAPIMode: provider.ChatAPIModeChatCompletions, + ChatEndpointPath: "/chat/completions", + Source: ProviderSourceBuiltin, } } // OpenAIProvider returns the builtin OpenAI provider definition. 
func OpenAIProvider() ProviderConfig { - return newBuiltinOpenAICompatProvider(OpenAIName, OpenAIDefaultBaseURL, OpenAIDefaultModel, OpenAIDefaultAPIKeyEnv) + cfg := newBuiltinOpenAICompatProvider(OpenAIName, OpenAIDefaultBaseURL, OpenAIDefaultModel, OpenAIDefaultAPIKeyEnv) + cfg.Models = cloneBuiltinModels(openAIStaticModels) + return cfg } // GeminiProvider returns the builtin Gemini provider definition. func GeminiProvider() ProviderConfig { return ProviderConfig{ - Name: GeminiName, - Driver: provider.DriverGemini, - BaseURL: GeminiDefaultBaseURL, - Model: GeminiDefaultModel, - APIKeyEnv: GeminiDefaultAPIKeyEnv, - ModelSource: ModelSourceDiscover, - ChatEndpointPath: "", - DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, - Source: ProviderSourceBuiltin, + Name: GeminiName, + Driver: provider.DriverGemini, + BaseURL: GeminiDefaultBaseURL, + Model: GeminiDefaultModel, + APIKeyEnv: GeminiDefaultAPIKeyEnv, + ChatEndpointPath: "", + Models: cloneBuiltinModels(geminiStaticModels), + Source: ProviderSourceBuiltin, } } // QiniuProvider returns the builtin Qiniu provider definition. func QiniuProvider() ProviderConfig { - return newBuiltinOpenAICompatProvider(QiniuName, QiniuDefaultBaseURL, QiniuDefaultModel, QiniuDefaultAPIKeyEnv) + cfg := newBuiltinOpenAICompatProvider(QiniuName, QiniuDefaultBaseURL, QiniuDefaultModel, QiniuDefaultAPIKeyEnv) + cfg.Models = cloneBuiltinModels(qiniuStaticModels) + return cfg } // ModelScopeProvider 返回内置的 ModelScope provider 配置。 func ModelScopeProvider() ProviderConfig { - return newBuiltinOpenAICompatProvider(ModelScopeName, ModelScopeDefaultBaseURL, ModelScopeDefaultModel, ModelScopeDefaultAPIKeyEnv) + cfg := newBuiltinOpenAICompatProvider(ModelScopeName, ModelScopeDefaultBaseURL, ModelScopeDefaultModel, ModelScopeDefaultAPIKeyEnv) + cfg.Models = cloneBuiltinModels(modelScopeStaticModels) + return cfg } // DefaultProviders returns all builtin provider definitions. 
diff --git a/internal/config/provider_custom_normalize.go b/internal/config/provider_custom_normalize.go index bffb7d03..467fcba4 100644 --- a/internal/config/provider_custom_normalize.go +++ b/internal/config/provider_custom_normalize.go @@ -8,10 +8,7 @@ import ( providertypes "neo-code/internal/provider/types" ) -// ManualModelOptionalIntUnset 用于区分“未填写可选数值字段”和“显式输入 0”。 -const ManualModelOptionalIntUnset = -1 - -// NormalizeCustomProviderInput 统一归一化 custom provider 的输入字段,并执行协议/模型来源的组合校验。 +// NormalizeCustomProviderInput 统一归一化 custom provider 输入,并执行模型来源与 discovery 配置的组合校验。 func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProviderInput, error) { normalized := SaveCustomProviderInput{ Name: strings.TrimSpace(input.Name), @@ -46,7 +43,7 @@ func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProv normalized.ModelSource = ModelSourceDiscover } - models, err := normalizeCustomProviderModels(input.Models, true) + models, err := normalizeCustomProviderModels(input.Models) if err != nil { return SaveCustomProviderInput{}, err } @@ -74,7 +71,6 @@ func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProv } discoveryEndpointPath := "" if normalized.ModelSource != ModelSourceManual && strings.TrimSpace(normalizedDiscoveryEndpointPath) != "" { - var err error discoveryEndpointPath, err = provider.NormalizeProviderDiscoverySettings( normalized.Driver, normalizedDiscoveryEndpointPath, @@ -106,12 +102,12 @@ func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProv return normalized, validateNormalizedCustomProviderInput(normalized) } -// normalizeOptionalGenerateInt 归一化可选的生成控制字段,仅保留调用方原始输入,避免在保存前静默吞掉非法值。 +// normalizeOptionalGenerateInt 归一化可选生成控制字段,仅保留调用方原始输入,避免在保存前静默吞掉非法值。 func normalizeOptionalGenerateInt(value int) int { return value } -// validateNormalizedCustomProviderInput 复用统一的 provider 配置校验,避免 custom provider 保存路径和加载路径出现两套规则。 +// validateNormalizedCustomProviderInput 复用统一的 
provider 配置校验,避免 custom provider 保存与加载路径出现两套规则。 func validateNormalizedCustomProviderInput(input SaveCustomProviderInput) error { cfg := ProviderConfig{ Name: input.Name, @@ -131,16 +127,13 @@ func validateNormalizedCustomProviderInput(input SaveCustomProviderInput) error return cfg.Validate() } -// NormalizeCustomProviderModels 统一归一化 custom provider 的模型描述并校验必填字段和边界条件。 +// NormalizeCustomProviderModels 统一归一化 custom provider 用户可写模型,并校验必填字段与不允许的 metadata。 func NormalizeCustomProviderModels(models []providertypes.ModelDescriptor) ([]providertypes.ModelDescriptor, error) { - return normalizeCustomProviderModels(models, false) + return normalizeCustomProviderModels(models) } -// normalizeCustomProviderModels 统一归一化 custom provider 模型列表,并在需要时兼容历史的零值省略语义。 -func normalizeCustomProviderModels( - models []providertypes.ModelDescriptor, - allowZeroAsUnset bool, -) ([]providertypes.ModelDescriptor, error) { +// normalizeCustomProviderModels 统一清洗 custom provider 用户可写模型,并拒绝任何 metadata 字段输入。 +func normalizeCustomProviderModels(models []providertypes.ModelDescriptor) ([]providertypes.ModelDescriptor, error) { if len(models) == 0 { return nil, nil } @@ -156,17 +149,17 @@ func normalizeCustomProviderModels( if name == "" { return nil, fmt.Errorf("config: models[%d].name is empty", index) } - contextWindow := model.ContextWindow - if contextWindow == ManualModelOptionalIntUnset || (allowZeroAsUnset && contextWindow == 0) { - contextWindow = 0 - } else if contextWindow <= 0 { - return nil, fmt.Errorf("config: models[%d].context_window must be greater than 0", index) + if strings.TrimSpace(model.Description) != "" { + return nil, fmt.Errorf("config: models[%d].description is not supported", index) + } + if model.ContextWindow != 0 { + return nil, fmt.Errorf("config: models[%d].context_window is not supported", index) + } + if model.MaxOutputTokens != 0 { + return nil, fmt.Errorf("config: models[%d].max_output_tokens is not supported", index) } - maxOutputTokens := model.MaxOutputTokens 
- if maxOutputTokens == ManualModelOptionalIntUnset || (allowZeroAsUnset && maxOutputTokens == 0) { - maxOutputTokens = 0 - } else if maxOutputTokens <= 0 { - return nil, fmt.Errorf("config: models[%d].max_output_tokens must be greater than 0", index) + if model.CapabilityHints != (providertypes.ModelCapabilityHints{}) { + return nil, fmt.Errorf("config: models[%d].capability_hints is not supported", index) } key := provider.NormalizeKey(id) @@ -176,12 +169,8 @@ func normalizeCustomProviderModels( seen[key] = struct{}{} normalized = append(normalized, providertypes.ModelDescriptor{ - ID: id, - Name: name, - Description: strings.TrimSpace(model.Description), - ContextWindow: contextWindow, - MaxOutputTokens: maxOutputTokens, - CapabilityHints: model.CapabilityHints, + ID: id, + Name: name, }) } return normalized, nil diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index f460a76a..628ff574 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -35,10 +35,8 @@ type customProviderFile struct { } type customProviderModelFile struct { - ID string `yaml:"id"` - Name string `yaml:"name"` - ContextWindow *int `yaml:"context_window,omitempty"` - MaxOutputTokens *int `yaml:"max_output_tokens,omitempty"` + ID string `yaml:"id"` + Name string `yaml:"name"` } // loadCustomProviders 扫描 baseDir/providers 下的一层子目录,并将其中的 provider.yaml 解析为运行时配置。 @@ -170,28 +168,14 @@ func customProviderModels(models []customProviderModelFile) ([]providertypes.Mod } descriptor := providertypes.ModelDescriptor{ - ID: id, - Name: name, - ContextWindow: ManualModelOptionalIntUnset, - MaxOutputTokens: ManualModelOptionalIntUnset, + ID: id, + Name: name, } key := provider.NormalizeKey(id) if _, exists := seen[key]; exists { return nil, fmt.Errorf("models[%d].id %q is duplicated", index, id) } seen[key] = struct{}{} - if model.ContextWindow != nil { - if *model.ContextWindow <= 0 { - return nil, fmt.Errorf("models[%d].context_window 
must be greater than 0", index) - } - descriptor.ContextWindow = *model.ContextWindow - } - if model.MaxOutputTokens != nil { - if *model.MaxOutputTokens <= 0 { - return nil, fmt.Errorf("models[%d].max_output_tokens must be greater than 0", index) - } - descriptor.MaxOutputTokens = *model.MaxOutputTokens - } descriptors = append(descriptors, descriptor) } return descriptors, nil @@ -238,7 +222,7 @@ func SaveCustomProviderWithModels(baseDir string, input SaveCustomProviderInput) cfg.BaseURL = normalizedInput.BaseURL cfg.ChatEndpointPath = normalizedInput.ChatEndpointPath cfg.DiscoveryEndpointPath = normalizedInput.DiscoveryEndpointPath - if normalizedInput.ModelSource == ModelSourceManual { + if len(normalizedInput.Models) > 0 { cfg.Models = toCustomProviderModelFiles(normalizedInput.Models) } @@ -262,19 +246,10 @@ func toCustomProviderModelFiles(models []providertypes.ModelDescriptor) []custom } items := make([]customProviderModelFile, 0, len(models)) for _, model := range providertypes.MergeModelDescriptors(models) { - item := customProviderModelFile{ + items = append(items, customProviderModelFile{ ID: strings.TrimSpace(model.ID), Name: strings.TrimSpace(model.Name), - } - if model.ContextWindow > 0 { - value := model.ContextWindow - item.ContextWindow = &value - } - if model.MaxOutputTokens > 0 { - value := model.MaxOutputTokens - item.MaxOutputTokens = &value - } - items = append(items, item) + }) } return items } diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 0668e21f..c229764b 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -280,8 +280,11 @@ func TestQiniuProviderConfig(t *testing.T) { if provider.Source != ProviderSourceBuiltin { t.Fatalf("expected builtin source, got %q", provider.Source) } - if provider.DiscoveryEndpointPath != providerpkg.DiscoveryEndpointPathModels { - t.Fatalf("expected discovery endpoint %q, got %q", providerpkg.DiscoveryEndpointPathModels, 
provider.DiscoveryEndpointPath) + if provider.DiscoveryEndpointPath != "" { + t.Fatalf("expected builtin discovery endpoint to stay empty, got %q", provider.DiscoveryEndpointPath) + } + if !containsModelDescriptorID(provider.Models, provider.Model) { + t.Fatalf("expected builtin static models to include default model, got %+v", provider.Models) } } @@ -455,17 +458,13 @@ func TestCloneProviderConfigModelDescriptorsIndependence(t *testing.T) { } } -func TestCustomProviderModelsParsesSupportedMetadata(t *testing.T) { +func TestCustomProviderModelsParsesIDAndNameOnly(t *testing.T) { t.Parallel() - contextWindow := 131072 - maxOutputTokens := 8192 models, err := customProviderModels([]customProviderModelFile{ { - ID: "deepseek-coder", - Name: "DeepSeek Coder", - ContextWindow: &contextWindow, - MaxOutputTokens: &maxOutputTokens, + ID: "deepseek-coder", + Name: "DeepSeek Coder", }, }) if err != nil { @@ -475,7 +474,7 @@ func TestCustomProviderModelsParsesSupportedMetadata(t *testing.T) { if len(models) != 1 { t.Fatalf("expected one parsed model, got %+v", models) } - if models[0].ID != "deepseek-coder" || models[0].ContextWindow != 131072 || models[0].MaxOutputTokens != 8192 { + if models[0].ID != "deepseek-coder" || models[0].ContextWindow != 0 || models[0].MaxOutputTokens != 0 { t.Fatalf("unexpected parsed model descriptor: %+v", models[0]) } } @@ -498,58 +497,57 @@ func TestCustomProviderModelsRejectsEmptyName(t *testing.T) { } } -func TestCustomProviderModelsRejectsNonPositiveContextWindow(t *testing.T) { +func TestNormalizeCustomProviderModelsRejectsContextWindow(t *testing.T) { t.Parallel() - contextWindow := 0 - _, err := customProviderModels([]customProviderModelFile{{ + _, err := NormalizeCustomProviderModels([]providertypes.ModelDescriptor{{ ID: "deepseek-coder", Name: "DeepSeek Coder", - ContextWindow: &contextWindow, + ContextWindow: 131072, }}) if err == nil || !strings.Contains(err.Error(), "context_window") { t.Fatalf("expected context_window validation 
error, got %v", err) } } -func TestCustomProviderModelsRejectsNonPositiveMaxOutputTokens(t *testing.T) { +func TestNormalizeCustomProviderModelsRejectsMaxOutputTokens(t *testing.T) { t.Parallel() - maxOutputTokens := 0 - _, err := customProviderModels([]customProviderModelFile{{ + _, err := NormalizeCustomProviderModels([]providertypes.ModelDescriptor{{ ID: "deepseek-coder", Name: "DeepSeek Coder", - MaxOutputTokens: &maxOutputTokens, + MaxOutputTokens: 8192, }}) if err == nil || !strings.Contains(err.Error(), "max_output_tokens") { t.Fatalf("expected max_output_tokens validation error, got %v", err) } } -func TestNormalizeCustomProviderModelsRejectsZeroLimits(t *testing.T) { +func TestNormalizeCustomProviderModelsRejectsMetadataFields(t *testing.T) { t.Parallel() _, err := NormalizeCustomProviderModels([]providertypes.ModelDescriptor{ { - ID: "deepseek-coder", - Name: "DeepSeek Coder", - ContextWindow: 0, + ID: "deepseek-coder", + Name: "DeepSeek Coder", + Description: "desc", }, }) - if err == nil || !strings.Contains(err.Error(), "context_window") { - t.Fatalf("expected context_window validation error, got %v", err) + if err == nil || !strings.Contains(err.Error(), "description") { + t.Fatalf("expected description validation error, got %v", err) } _, err = NormalizeCustomProviderModels([]providertypes.ModelDescriptor{ { - ID: "deepseek-coder", - Name: "DeepSeek Coder", - ContextWindow: ManualModelOptionalIntUnset, - MaxOutputTokens: 0, + ID: "deepseek-coder", + Name: "DeepSeek Coder", + CapabilityHints: providertypes.ModelCapabilityHints{ + ToolCalling: providertypes.ModelCapabilityStateSupported, + }, }, }) - if err == nil || !strings.Contains(err.Error(), "max_output_tokens") { - t.Fatalf("expected max_output_tokens validation error, got %v", err) + if err == nil || !strings.Contains(err.Error(), "capability_hints") { + t.Fatalf("expected capability_hints validation error, got %v", err) } } @@ -656,6 +654,9 @@ func TestOpenAIProviderConfig(t *testing.T) { if 
provider.APIKeyEnv != OpenAIDefaultAPIKeyEnv { t.Fatalf("expected API key env %q, got %q", OpenAIDefaultAPIKeyEnv, provider.APIKeyEnv) } + if !containsModelDescriptorID(provider.Models, provider.Model) { + t.Fatalf("expected builtin static models to include default model, got %+v", provider.Models) + } } func TestGeminiProviderConfig(t *testing.T) { @@ -678,6 +679,9 @@ func TestGeminiProviderConfig(t *testing.T) { if provider.APIKeyEnv != GeminiDefaultAPIKeyEnv { t.Fatalf("expected API key env %q, got %q", GeminiDefaultAPIKeyEnv, provider.APIKeyEnv) } + if !containsModelDescriptorID(provider.Models, provider.Model) { + t.Fatalf("expected builtin static models to include default model, got %+v", provider.Models) + } } func TestModelScopeProviderConfig(t *testing.T) { @@ -703,8 +707,91 @@ func TestModelScopeProviderConfig(t *testing.T) { if provider.Source != ProviderSourceBuiltin { t.Fatalf("expected builtin source, got %q", provider.Source) } - if provider.DiscoveryEndpointPath != providerpkg.DiscoveryEndpointPathModels { - t.Fatalf("expected discovery endpoint %q, got %q", providerpkg.DiscoveryEndpointPathModels, provider.DiscoveryEndpointPath) + if provider.DiscoveryEndpointPath != "" { + t.Fatalf("expected builtin discovery endpoint to stay empty, got %q", provider.DiscoveryEndpointPath) + } + if !containsModelDescriptorID(provider.Models, provider.Model) { + t.Fatalf("expected builtin static models to include default model, got %+v", provider.Models) + } +} + +func TestBuiltinProvidersExposeStaticModelMetadata(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider ProviderConfig + wantTool providertypes.ModelCapabilityState + wantImage providertypes.ModelCapabilityState + }{ + { + name: "openai", + provider: OpenAIProvider(), + wantTool: providertypes.ModelCapabilityStateSupported, + wantImage: providertypes.ModelCapabilityStateSupported, + }, + { + name: "gemini", + provider: GeminiProvider(), + wantTool: 
providertypes.ModelCapabilityStateSupported, + wantImage: providertypes.ModelCapabilityStateSupported, + }, + { + name: "qiniu", + provider: QiniuProvider(), + wantTool: providertypes.ModelCapabilityStateSupported, + wantImage: providertypes.ModelCapabilityStateUnsupported, + }, + { + name: "modelscope", + provider: ModelScopeProvider(), + wantTool: providertypes.ModelCapabilityStateSupported, + wantImage: providertypes.ModelCapabilityStateUnsupported, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var target providertypes.ModelDescriptor + var found bool + for _, model := range tt.provider.Models { + if model.ID == tt.provider.Model { + target = model + found = true + break + } + } + if !found { + t.Fatalf("expected default model %q in provider models: %+v", tt.provider.Model, tt.provider.Models) + } + if target.Description == "" { + t.Fatalf("expected default model %q to include description", target.ID) + } + if target.ContextWindow <= 0 { + t.Fatalf("expected default model %q to include context window, got %d", target.ID, target.ContextWindow) + } + if target.MaxOutputTokens <= 0 { + t.Fatalf("expected default model %q to include max output tokens, got %d", target.ID, target.MaxOutputTokens) + } + if target.CapabilityHints.ToolCalling != tt.wantTool { + t.Fatalf("expected default model %q tool calling=%q, got %+v", target.ID, tt.wantTool, target.CapabilityHints) + } + if target.CapabilityHints.ImageInput != tt.wantImage { + t.Fatalf("expected default model %q image input=%q, got %+v", target.ID, tt.wantImage, target.CapabilityHints) + } + }) + } +} + +func TestModelScopeProviderIncludesMiniMaxFallbackModel(t *testing.T) { + t.Parallel() + + provider := ModelScopeProvider() + if !containsModelDescriptorID(provider.Models, "MiniMax/MiniMax-M2.5") { + t.Fatalf("expected modelscope static models to include MiniMax/MiniMax-M2.5, got %+v", provider.Models) } } diff --git a/internal/config/state/model.go 
b/internal/config/state/model.go index c74cd715..392d081d 100644 --- a/internal/config/state/model.go +++ b/internal/config/state/model.go @@ -43,6 +43,7 @@ type ModelCatalog interface { ListProviderModels(ctx context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) ListProviderModelsSnapshot(ctx context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) ListProviderModelsCached(ctx context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) + RefreshProviderModels(ctx context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) } // selectionFromConfig 将配置快照映射为当前选择结果。 @@ -97,8 +98,9 @@ func catalogInputFromProvider(cfg config.ProviderConfig) (provider.CatalogInput, input := provider.CatalogInput{ Identity: identity, ConfiguredModels: providertypes.CloneModelDescriptors(cloned.Models), - DisableDiscovery: cloned.Source == config.ProviderSourceCustom && - config.NormalizeModelSource(cloned.ModelSource) == config.ModelSourceManual, + DisableDiscovery: cloned.Source == config.ProviderSourceBuiltin || + (cloned.Source == config.ProviderSourceCustom && + config.NormalizeModelSource(cloned.ModelSource) == config.ModelSourceManual), ResolveDiscoveryConfig: func() (provider.RuntimeConfig, error) { resolved, err := cloned.Resolve() if err != nil { @@ -108,7 +110,10 @@ func catalogInputFromProvider(cfg config.ProviderConfig) (provider.CatalogInput, }, } if cloned.Source != config.ProviderSourceCustom { - input.DefaultModels = providertypes.DescriptorsFromIDs([]string{cloned.Model}) + input.DefaultModels = providertypes.CloneModelDescriptors(cloned.Models) + if len(input.DefaultModels) == 0 { + input.DefaultModels = providertypes.DescriptorsFromIDs([]string{cloned.Model}) + } } return input, nil } diff --git a/internal/config/state/model_additional_test.go b/internal/config/state/model_additional_test.go index 95d4ab1e..6c116822 100644 --- 
a/internal/config/state/model_additional_test.go +++ b/internal/config/state/model_additional_test.go @@ -255,6 +255,10 @@ func (additionalCatalogStub) ListProviderModelsCached(_ context.Context, _ provi return nil, nil } +func (additionalCatalogStub) RefreshProviderModels(_ context.Context, _ provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { + return nil, nil +} + type denyAllDriverSupporter struct{} func (*denyAllDriverSupporter) Supports(_ string) bool { return false } @@ -516,6 +520,13 @@ func (s catalogMethodsStubForAdditional) ListProviderModelsCached(_ context.Cont return s.cachedModels, nil } +func (s catalogMethodsStubForAdditional) RefreshProviderModels(_ context.Context, _ provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { + if s.listErr != nil { + return nil, s.listErr + } + return s.listModels, nil +} + type catalogMethodCallsForAdditional struct { listCalls int snapshotCalls int diff --git a/internal/config/state/model_test.go b/internal/config/state/model_test.go index de9e7e22..28e8c96c 100644 --- a/internal/config/state/model_test.go +++ b/internal/config/state/model_test.go @@ -9,7 +9,7 @@ import ( providertypes "neo-code/internal/provider/types" ) -func TestCatalogInputFromProviderBuiltinIncludesDefaultsAndLazyDiscovery(t *testing.T) { +func TestCatalogInputFromProviderBuiltinUsesStaticModelsAndDisablesDiscovery(t *testing.T) { t.Setenv("CATALOG_PROVIDER_API_KEY", "secret-key") cfg := configpkg.ProviderConfig{ @@ -41,12 +41,15 @@ func TestCatalogInputFromProviderBuiltinIncludesDefaultsAndLazyDiscovery(t *test if input.Identity.DiscoveryEndpointPath != providerpkg.DiscoveryEndpointPathModels { t.Fatalf("expected default discovery endpoint, got %+v", input.Identity) } - if len(input.DefaultModels) != 1 || input.DefaultModels[0].ID != "server-default" { - t.Fatalf("expected builtin default model, got %+v", input.DefaultModels) + if len(input.DefaultModels) != 1 || input.DefaultModels[0].ID != "model-a" { + 
t.Fatalf("expected builtin static models as defaults, got %+v", input.DefaultModels) } if len(input.ConfiguredModels) != 1 || input.ConfiguredModels[0].ID != "model-a" { t.Fatalf("expected configured models to be normalized, got %+v", input.ConfiguredModels) } + if !input.DisableDiscovery { + t.Fatal("expected builtin provider to disable discovery") + } cfg.Models[0].ID = "mutated" if input.ConfiguredModels[0].ID != "model-a" { diff --git a/internal/config/state/service.go b/internal/config/state/service.go index 4167f387..6a61c805 100644 --- a/internal/config/state/service.go +++ b/internal/config/state/service.go @@ -139,6 +139,55 @@ func (s *Service) ListModels(ctx context.Context) ([]providertypes.ModelDescript return s.catalogs.ListProviderModels(ctx, input) } +// RefreshModels 强制刷新当前选中 provider 的模型目录,并在必要时修正 current_model。 +func (s *Service) RefreshModels(ctx context.Context) ([]providertypes.ModelDescriptor, error) { + if err := s.validate(); err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + + selected, err := selectedProviderConfig(s.manager.Get()) + if err != nil { + return nil, err + } + if err := ensureSupportedProvider(s.supporters, selected); err != nil { + return nil, err + } + input, err := catalogInputFromProvider(selected) + if err != nil { + return nil, err + } + models, err := s.catalogs.RefreshProviderModels(ctx, input) + if err != nil { + return nil, err + } + if len(models) == 0 { + return nil, ErrNoModelsAvailable + } + + if err := s.manager.Update(ctx, func(cfg *config.Config) error { + currentSelected, err := selectedProviderConfig(*cfg) + if err != nil { + return err + } + sameIdentity, err := sameProviderIdentity(currentSelected, selected) + if err != nil { + return err + } + if !sameIdentity { + return errSelectionDrifted + } + cfg.CurrentModel, _ = resolveCurrentModel(cfg.CurrentModel, models, currentSelected.Model) + return nil + }); err != nil { + return nil, err + } + + return models, nil 
+} + // ListModelsSnapshot 获取当前选中 provider 的快照模型列表,不阻塞等待同步发现。 func (s *Service) ListModelsSnapshot(ctx context.Context) ([]providertypes.ModelDescriptor, error) { if err := s.validate(); err != nil { diff --git a/internal/config/state/service_provider_create.go b/internal/config/state/service_provider_create.go index 4a9f764c..194d7455 100644 --- a/internal/config/state/service_provider_create.go +++ b/internal/config/state/service_provider_create.go @@ -246,10 +246,8 @@ func normalizeCreateCustomProviderInput(input CreateCustomProviderInput) (create } type manualModelJSON struct { - ID string `json:"id"` - Name string `json:"name"` - ContextWindow *int `json:"context_window,omitempty"` - MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + ID string `json:"id"` + Name string `json:"name"` } // parseManualModelsJSON 解析并校验手工模型 JSON,确保至少包含一个合法模型且 id/name 必填。 @@ -291,22 +289,8 @@ func parseManualModelsJSON(raw string) ([]providertypes.ModelDescriptor, error) seen[key] = struct{}{} descriptor := providertypes.ModelDescriptor{ - ID: id, - Name: name, - ContextWindow: config.ManualModelOptionalIntUnset, - MaxOutputTokens: config.ManualModelOptionalIntUnset, - } - if model.ContextWindow != nil { - if *model.ContextWindow <= 0 { - return nil, fmt.Errorf("selection: models[%d].context_window must be greater than 0", index) - } - descriptor.ContextWindow = *model.ContextWindow - } - if model.MaxOutputTokens != nil { - if *model.MaxOutputTokens <= 0 { - return nil, fmt.Errorf("selection: models[%d].max_output_tokens must be greater than 0", index) - } - descriptor.MaxOutputTokens = *model.MaxOutputTokens + ID: id, + Name: name, } descriptors = append(descriptors, descriptor) } diff --git a/internal/config/state/service_provider_create_test.go b/internal/config/state/service_provider_create_test.go index 076bc774..6e23aa99 100644 --- a/internal/config/state/service_provider_create_test.go +++ b/internal/config/state/service_provider_create_test.go @@ -161,8 +161,7 @@ 
func TestCreateCustomProviderManualSourcePersistsModels(t *testing.T) { }, { "id": "manual-model-2", - "name": "Manual Model 2", - "context_window": 131072 + "name": "Manual Model 2" } ]`, } diff --git a/internal/config/state/service_test.go b/internal/config/state/service_test.go index 3d2ebcb9..e0a7b9bb 100644 --- a/internal/config/state/service_test.go +++ b/internal/config/state/service_test.go @@ -1156,6 +1156,10 @@ func (catalogStub) ListProviderModelsCached(_ context.Context, input provider.Ca return defaultModelsForInput(input), nil } +func (catalogStub) RefreshProviderModels(_ context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { + return defaultModelsForInput(input), nil +} + type catalogMethodsStub struct { listModels []providertypes.ModelDescriptor snapshotModels []providertypes.ModelDescriptor @@ -1196,6 +1200,13 @@ func (s catalogMethodsStub) ListProviderModelsCached(_ context.Context, _ provid return providertypes.MergeModelDescriptors(s.cachedModels), nil } +func (s catalogMethodsStub) RefreshProviderModels(_ context.Context, _ provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { + if s.listErr != nil { + return nil, s.listErr + } + return providertypes.MergeModelDescriptors(s.listModels), nil +} + type catalogMethodCalls struct { listCalls int snapshotCalls int @@ -1223,6 +1234,10 @@ func (c *cancelAfterFirstCachedCatalog) ListProviderModelsCached(_ context.Conte return defaultModelsForInput(input), nil } +func (c *cancelAfterFirstCachedCatalog) RefreshProviderModels(_ context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { + return defaultModelsForInput(input), nil +} + type mutatingCatalog struct { listModels []providertypes.ModelDescriptor snapshotModels []providertypes.ModelDescriptor @@ -1254,6 +1269,10 @@ func (m *mutatingCatalog) ListProviderModelsCached(_ context.Context, _ provider return nil, nil } +func (m *mutatingCatalog) RefreshProviderModels(_ 
context.Context, _ provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { + return providertypes.MergeModelDescriptors(m.listModels), nil +} + type errorCatalogStub struct { err error } @@ -1270,6 +1289,10 @@ func (s errorCatalogStub) ListProviderModelsCached(_ context.Context, _ provider return nil, s.err } +func (s errorCatalogStub) RefreshProviderModels(_ context.Context, _ provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { + return nil, s.err +} + type driftingSnapshotCatalog struct { t *testing.T manager *configpkg.Manager @@ -1298,6 +1321,10 @@ func (c *driftingSnapshotCatalog) ListProviderModelsCached(_ context.Context, in return c.modelsFor(input), nil } +func (c *driftingSnapshotCatalog) RefreshProviderModels(_ context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { + return c.modelsFor(input), nil +} + func (c *driftingSnapshotCatalog) modelsFor(input provider.CatalogInput) []providertypes.ModelDescriptor { switch input.Identity.Key() { case mustCatalogIdentity(c.t, configpkg.OpenAIProvider()).Key(): diff --git a/internal/provider/catalog/additional_test.go b/internal/provider/catalog/additional_test.go index ae01d4ec..271adea5 100644 --- a/internal/provider/catalog/additional_test.go +++ b/internal/provider/catalog/additional_test.go @@ -6,7 +6,6 @@ import ( "os" "path/filepath" "strings" - "sync/atomic" "testing" "time" @@ -59,32 +58,6 @@ func TestDiscoverAndPersistRejectsNilResolver(t *testing.T) { } } -func TestQueueRefreshSkipsIncompleteIdentity(t *testing.T) { - t.Parallel() - - var discoverCalls int32 - registry := newRegistry(t, openaicompat.DriverName, func(context.Context, provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { - atomic.AddInt32(&discoverCalls, 1) - return nil, nil - }) - service := NewService("", registry, newMemoryStore()) - service.backgroundTimeout = 50 * time.Millisecond - - service.queueRefresh(provider.CatalogInput{ - Identity: 
provider.ProviderIdentity{ - Driver: openaicompat.DriverName, - }, - }) - - time.Sleep(100 * time.Millisecond) - if atomic.LoadInt32(&discoverCalls) != 0 { - t.Fatalf("expected no background discovery, got %d calls", discoverCalls) - } - if len(service.inFlightByID) != 0 { - t.Fatalf("expected no in-flight markers, got %+v", service.inFlightByID) - } -} - func TestJSONStoreAdditionalFilesystemErrors(t *testing.T) { t.Parallel() @@ -135,7 +108,7 @@ func TestCatalogSnapshotOnMissingCatalog(t *testing.T) { service := NewService("", provider.NewRegistry(), newMemoryStore()) input := mustCatalogInput(t, config.OpenAIProvider()) snapshot := service.catalogSnapshot(context.Background(), input) - if snapshot.ok || snapshot.expired || len(snapshot.models) != 0 { + if snapshot.ok || len(snapshot.models) != 0 { t.Fatalf("expected empty snapshot on cache miss, got %+v", snapshot) } } @@ -157,7 +130,6 @@ func TestListProviderModelsAllowsUnsupportedOpenAICompatibleAPIStyleWhenCatalogI SchemaVersion: schemaVersion, Identity: input.Identity, FetchedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), Models: []providertypes.ModelDescriptor{{ID: "cached-model", Name: "Cached Model"}}, }); err != nil { t.Fatalf("Save() cached catalog error = %v", err) @@ -189,7 +161,6 @@ func TestListBuiltinProviderModelsAllowsUnsupportedOpenAICompatibleAPIStyleWhenC SchemaVersion: schemaVersion, Identity: input.Identity, FetchedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), Models: []providertypes.ModelDescriptor{{ID: "cached-model", Name: "Cached Model"}}, }); err != nil { t.Fatalf("Save() cached catalog error = %v", err) @@ -199,8 +170,8 @@ func TestListBuiltinProviderModelsAllowsUnsupportedOpenAICompatibleAPIStyleWhenC if err != nil { t.Fatalf("expected warm catalog to bypass chat-only api_style validation, got %v", err) } - if !containsModelDescriptorID(models, "cached-model") { - t.Fatalf("expected cached model in result, got %+v", models) + if !containsModelDescriptorID(models, 
config.OpenAIDefaultModel) { + t.Fatalf("expected builtin static models in result, got %+v", models) } } @@ -220,7 +191,6 @@ func TestListBuiltinProviderModelsSnapshotAllowsUnsupportedOpenAICompatibleAPISt SchemaVersion: schemaVersion, Identity: input.Identity, FetchedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), Models: []providertypes.ModelDescriptor{{ID: "cached-model", Name: "Cached Model"}}, }); err != nil { t.Fatalf("Save() cached catalog error = %v", err) @@ -230,8 +200,8 @@ func TestListBuiltinProviderModelsSnapshotAllowsUnsupportedOpenAICompatibleAPISt if err != nil { t.Fatalf("expected snapshot path to allow chat-only api_style drift, got %v", err) } - if !containsModelDescriptorID(models, "cached-model") { - t.Fatalf("expected cached model in snapshot result, got %+v", models) + if !containsModelDescriptorID(models, config.OpenAIDefaultModel) { + t.Fatalf("expected builtin static models in snapshot result, got %+v", models) } } @@ -252,7 +222,6 @@ func TestListBuiltinProviderModelsSnapshotAndCachedAllowUnsupportedAPIStyleWithC SchemaVersion: schemaVersion, Identity: input.Identity, FetchedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour), Models: []providertypes.ModelDescriptor{{ID: "cached-model", Name: "Cached Model"}}, }); err != nil { t.Fatalf("Save() cached catalog error = %v", err) @@ -272,8 +241,8 @@ func TestListBuiltinProviderModelsSnapshotAndCachedAllowUnsupportedAPIStyleWithC if err != nil { t.Fatalf("expected cached %s path to allow chat-only api_style drift, got %v", tt.name, err) } - if !containsModelDescriptorID(models, "cached-model") { - t.Fatalf("expected cached %s result to include cached model, got %+v", tt.name, models) + if !containsModelDescriptorID(models, config.OpenAIDefaultModel) { + t.Fatalf("expected builtin static %s result, got %+v", tt.name, models) } }) } diff --git a/internal/provider/catalog/service.go b/internal/provider/catalog/service.go index ca72ea07..57934ebc 100644 --- 
a/internal/provider/catalog/service.go +++ b/internal/provider/catalog/service.go @@ -5,29 +5,18 @@ import ( "errors" "fmt" "strings" - "sync" "time" "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" ) -const ( - defaultTTL = 24 * time.Hour - defaultBackgroundTimeout = 30 * time.Second -) - var errCatalogPersist = errors.New("provider catalog: persist discovered models") type Service struct { - registry *provider.Registry - store Store - catalogTTL time.Duration - backgroundTimeout time.Duration - now func() time.Time - - refreshMu sync.Mutex - inFlightByID map[string]struct{} + registry *provider.Registry + store Store + now func() time.Time } func NewService(baseDir string, registry *provider.Registry, store Store) *Service { @@ -36,32 +25,49 @@ func NewService(baseDir string, registry *provider.Registry, store Store) *Servi } return &Service{ - registry: registry, - store: store, - catalogTTL: defaultTTL, - backgroundTimeout: defaultBackgroundTimeout, - now: time.Now, - inFlightByID: map[string]struct{}{}, + registry: registry, + store: store, + now: time.Now, } } func (s *Service) ListProviderModels(ctx context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { return s.listProviderModels(ctx, input, queryOptions{ allowSyncRefresh: true, - queueRefresh: true, }) } func (s *Service) ListProviderModelsSnapshot(ctx context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { - return s.listProviderModels(ctx, input, queryOptions{ - queueRefresh: true, - }) + return s.listProviderModels(ctx, input, queryOptions{}) } func (s *Service) ListProviderModelsCached(ctx context.Context, input provider.CatalogInput) ([]providertypes.ModelDescriptor, error) { return s.listProviderModels(ctx, input, queryOptions{}) } +// RefreshProviderModels 强制重新执行一次远端 discovery,并在成功后覆盖本地缓存。 +func (s *Service) RefreshProviderModels(ctx context.Context, input provider.CatalogInput) 
([]providertypes.ModelDescriptor, error) { + if err := s.validate(); err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + if input.DisableDiscovery { + configuredModels := providertypes.MergeModelDescriptors(input.ConfiguredModels) + defaultModels := providertypes.MergeModelDescriptors(input.DefaultModels) + return providertypes.MergeModelDescriptors(configuredModels, defaultModels), nil + } + + discovered, err := s.discoverAndPersist(ctx, input) + if err != nil { + return nil, err + } + configuredModels := providertypes.MergeModelDescriptors(input.ConfiguredModels) + defaultModels := providertypes.MergeModelDescriptors(input.DefaultModels) + return mergeResolvedModels(true, configuredModels, discovered, defaultModels), nil +} + func (s *Service) listProviderModels( ctx context.Context, input provider.CatalogInput, @@ -89,13 +95,11 @@ func (s *Service) validate() error { type queryOptions struct { allowSyncRefresh bool - queueRefresh bool } type catalogSnapshot struct { - models []providertypes.ModelDescriptor - ok bool - expired bool + models []providertypes.ModelDescriptor + ok bool } func (s *Service) modelsForProvider(ctx context.Context, input provider.CatalogInput, options queryOptions) ([]providertypes.ModelDescriptor, error) { @@ -117,7 +121,6 @@ func (s *Service) modelsForProvider(ctx context.Context, input provider.CatalogI // 空 catalog 等价于未命中,避免历史空缓存长期阻断后续 discovery。 catalogOK = false } - performedSyncRefresh := false if !catalogOK && options.allowSyncRefresh { discovered, err := s.discoverAndPersist(ctx, input) if err != nil { @@ -127,14 +130,9 @@ func (s *Service) modelsForProvider(ctx context.Context, input provider.CatalogI } else { models = discovered catalogOK = true - performedSyncRefresh = true } } - if shouldQueueRefresh(options, snapshot, performedSyncRefresh) { - s.queueRefresh(input) - } - return mergeResolvedModels(catalogOK, configuredModels, models, defaultModels), nil } @@ -144,9 +142,8 @@ func (s 
*Service) catalogSnapshot(ctx context.Context, input provider.CatalogInp return catalogSnapshot{} } return catalogSnapshot{ - models: modelCatalog.Models, - ok: true, - expired: modelCatalog.Expired(s.now()), + models: modelCatalog.Models, + ok: true, } } @@ -184,12 +181,10 @@ func (s *Service) discoverAndPersist(ctx context.Context, input provider.Catalog return discovered, nil } - now := s.now() if err := s.store.Save(ctx, ModelCatalog{ SchemaVersion: schemaVersion, Identity: input.Identity, - FetchedAt: now, - ExpiresAt: now.Add(s.catalogTTL), + FetchedAt: s.now(), Models: discovered, }); err != nil { return nil, fmt.Errorf("%w: %v", errCatalogPersist, err) @@ -197,51 +192,6 @@ func (s *Service) discoverAndPersist(ctx context.Context, input provider.Catalog return discovered, nil } -func (s *Service) queueRefresh(input provider.CatalogInput) { - if s.store == nil { - return - } - - if !s.registry.Supports(input.Identity.Driver) { - return - } - identity := input.Identity - if identity.Driver == "" || identity.BaseURL == "" { - return - } - - key := identity.Key() - s.refreshMu.Lock() - if _, exists := s.inFlightByID[key]; exists { - s.refreshMu.Unlock() - return - } - s.inFlightByID[key] = struct{}{} - s.refreshMu.Unlock() - - go func() { - defer func() { - s.refreshMu.Lock() - delete(s.inFlightByID, key) - s.refreshMu.Unlock() - }() - - ctx, cancel := context.WithTimeout(context.Background(), s.backgroundTimeout) - defer cancel() - _, _ = s.discoverAndPersist(ctx, input) - }() -} - -func shouldQueueRefresh(options queryOptions, snapshot catalogSnapshot, performedSyncRefresh bool) bool { - if !options.queueRefresh { - return false - } - if snapshot.expired { - return true - } - return !snapshot.ok && !performedSyncRefresh -} - func mergeResolvedModels( catalogOK bool, configuredModels []providertypes.ModelDescriptor, diff --git a/internal/provider/catalog/service_test.go b/internal/provider/catalog/service_test.go index d8a89351..fec5c6d3 100644 --- 
a/internal/provider/catalog/service_test.go +++ b/internal/provider/catalog/service_test.go @@ -30,16 +30,19 @@ func TestNewService(t *testing.T) { } } -func TestListProviderModelsFallsBackToDefaultModelWithoutDiscovery(t *testing.T) { +func TestListProviderModelsReturnsConfiguredModelsWhenDiscoveryDisabled(t *testing.T) { t.Parallel() service := NewService("", newRegistry(t, openaicompat.DriverName, nil), newMemoryStore()) - models, err := service.ListProviderModels(context.Background(), openAIProviderSource()) + providerCfg := customGatewayProvider() + providerCfg.ModelSource = config.ModelSourceManual + providerCfg.Models = []providertypes.ModelDescriptor{{ID: "manual-model", Name: "Manual Model"}} + models, err := service.ListProviderModels(context.Background(), mustCatalogInput(t, providerCfg)) if err != nil { t.Fatalf("ListProviderModels() error = %v", err) } - if len(models) != 1 || models[0].ID != config.OpenAIDefaultModel { - t.Fatalf("expected default model fallback, got %+v", models) + if len(models) != 1 || models[0].ID != "manual-model" { + t.Fatalf("expected configured models without discovery, got %+v", models) } } @@ -69,7 +72,7 @@ func TestListProviderModelsMergesConfiguredMetadataAfterDiscovery(t *testing.T) }) service := NewService("", registry, newMemoryStore()) - providerCfg := config.OpenAIProvider() + providerCfg := customGatewayProvider() providerCfg.Models = []providertypes.ModelDescriptor{{ ID: "deepseek-coder", Name: "DeepSeek Coder", @@ -131,7 +134,7 @@ func TestListProviderModelsUsesConfiguredContextWindowWhenDiscoveryMissesIt(t *t } } -func TestListProviderModelsSnapshotReturnsDefaultAndRefreshesInBackgroundOnMiss(t *testing.T) { +func TestListProviderModelsSnapshotReturnsBuiltinStaticModelsWithoutRefresh(t *testing.T) { t.Setenv(testAPIKeyEnv, "test-key") refreshed := make(chan struct{}, 1) @@ -145,41 +148,29 @@ func TestListProviderModelsSnapshotReturnsDefaultAndRefreshesInBackgroundOnMiss( store := newMemoryStore() service := 
NewService("", registry, store) - service.backgroundTimeout = time.Second models, err := service.ListProviderModelsSnapshot(context.Background(), openAIProviderSource()) if err != nil { t.Fatalf("ListProviderModelsSnapshot() error = %v", err) } - if len(models) != 1 || models[0].ID != config.OpenAIDefaultModel { - t.Fatalf("expected immediate default model fallback, got %+v", models) + if !containsModelDescriptorID(models, config.OpenAIDefaultModel) { + t.Fatalf("expected builtin static models on snapshot path, got %+v", models) } select { case <-refreshed: - case <-time.After(2 * time.Second): - t.Fatal("expected background refresh to run") + t.Fatal("expected snapshot path to avoid refresh") + case <-time.After(150 * time.Millisecond): } - identity, err := config.OpenAIProvider().Identity() + identity, err := customGatewayProvider().Identity() if err != nil { t.Fatalf("Identity() error = %v", err) } - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - modelCatalog, err := store.Load(context.Background(), identity) - if err == nil && containsModelDescriptorID(modelCatalog.Models, "gpt-4o") { - return - } - time.Sleep(20 * time.Millisecond) - } - - modelCatalog, err := store.Load(context.Background(), identity) - if err != nil { - t.Fatalf("Load() refreshed catalog error = %v", err) + if _, err := store.Load(context.Background(), identity); !errors.Is(err, ErrCatalogNotFound) { + t.Fatalf("expected snapshot path to avoid cache writes, got %v", err) } - t.Fatalf("expected refreshed catalog to contain gpt-4o, got %+v", modelCatalog.Models) } func TestListProviderModelsReturnsDiscoveryErrorOnCacheMiss(t *testing.T) { @@ -213,7 +204,7 @@ func TestListProviderModelsDiscoversAndCachesOnMiss(t *testing.T) { store := newMemoryStore() service := NewService("", registry, store) - models, err := service.ListProviderModels(context.Background(), openAIProviderSource()) + models, err := service.ListProviderModels(context.Background(), 
customGatewayProviderSource()) if err != nil { t.Fatalf("ListProviderModels() error = %v", err) } @@ -221,7 +212,7 @@ func TestListProviderModelsDiscoversAndCachesOnMiss(t *testing.T) { t.Fatalf("expected discovered model in result, got %+v", models) } - identity, err := config.OpenAIProvider().Identity() + identity, err := customGatewayProvider().Identity() if err != nil { t.Fatalf("Identity() error = %v", err) } @@ -234,10 +225,10 @@ func TestListProviderModelsDiscoversAndCachesOnMiss(t *testing.T) { } } -func TestListProviderModelsReturnsStaleCacheAndRefreshesInBackground(t *testing.T) { +func TestListProviderModelsReturnsCachedCatalogWithoutAutomaticRefresh(t *testing.T) { t.Setenv(testAPIKeyEnv, "test-key") - identity, err := config.OpenAIProvider().Identity() + identity, err := customGatewayProvider().Identity() if err != nil { t.Fatalf("Identity() error = %v", err) } @@ -248,7 +239,6 @@ func TestListProviderModelsReturnsStaleCacheAndRefreshesInBackground(t *testing. SchemaVersion: schemaVersion, Identity: identity, FetchedAt: now.Add(-48 * time.Hour), - ExpiresAt: now.Add(-24 * time.Hour), Models: []providertypes.ModelDescriptor{ {ID: "stale-model", Name: "Stale Model"}, }, @@ -267,9 +257,8 @@ func TestListProviderModelsReturnsStaleCacheAndRefreshesInBackground(t *testing. service := NewService("", registry, store) service.now = func() time.Time { return now } - service.backgroundTimeout = time.Second - models, err := service.ListProviderModels(context.Background(), openAIProviderSource()) + models, err := service.ListProviderModels(context.Background(), customGatewayProviderSource()) if err != nil { t.Fatalf("ListProviderModels() error = %v", err) } @@ -279,24 +268,9 @@ func TestListProviderModelsReturnsStaleCacheAndRefreshesInBackground(t *testing. 
select { case <-refreshed: - case <-time.After(2 * time.Second): - t.Fatal("expected background refresh to run") - } - - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - modelCatalog, err := store.Load(context.Background(), identity) - if err == nil && containsModelDescriptorID(modelCatalog.Models, "fresh-model") { - return - } - time.Sleep(20 * time.Millisecond) - } - - modelCatalog, err := store.Load(context.Background(), identity) - if err != nil { - t.Fatalf("Load() refreshed catalog error = %v", err) + t.Fatal("expected cached catalog path to avoid automatic refresh") + case <-time.After(150 * time.Millisecond): } - t.Fatalf("expected refreshed catalog to contain fresh-model, got %+v", modelCatalog.Models) } func TestDescriptorsFromIDsHelper(t *testing.T) { @@ -351,7 +325,7 @@ func TestListProviderModelsHonorsContextError(t *testing.T) { func TestListProviderModelsCachedUsesFreshCatalogWithoutDiscovery(t *testing.T) { t.Setenv(testAPIKeyEnv, "test-key") - identity, err := config.OpenAIProvider().Identity() + identity, err := customGatewayProvider().Identity() if err != nil { t.Fatalf("Identity() error = %v", err) } @@ -362,7 +336,6 @@ func TestListProviderModelsCachedUsesFreshCatalogWithoutDiscovery(t *testing.T) SchemaVersion: schemaVersion, Identity: identity, FetchedAt: now.Add(-time.Hour), - ExpiresAt: now.Add(time.Hour), Models: []providertypes.ModelDescriptor{ {ID: "cached-model", Name: "Cached Model"}, }, @@ -379,7 +352,7 @@ func TestListProviderModelsCachedUsesFreshCatalogWithoutDiscovery(t *testing.T) service := NewService("", registry, store) service.now = func() time.Time { return now } - models, err := service.ListProviderModelsCached(context.Background(), openAIProviderSource()) + models, err := service.ListProviderModelsCached(context.Background(), customGatewayProviderSource()) if err != nil { t.Fatalf("ListProviderModelsCached() error = %v", err) } @@ -400,7 +373,6 @@ func 
TestListProviderModelsRefreshesWhenCatalogSnapshotIsEmpty(t *testing.T) { SchemaVersion: schemaVersion, Identity: input.Identity, FetchedAt: time.Now().Add(-time.Minute), - ExpiresAt: time.Now().Add(time.Hour), Models: nil, }); err != nil { t.Fatalf("seed empty catalog: %v", err) @@ -428,7 +400,7 @@ func TestListProviderModelsRefreshesWhenCatalogSnapshotIsEmpty(t *testing.T) { func TestDiscoverAndPersistFailurePaths(t *testing.T) { t.Run("unsupported driver", func(t *testing.T) { service := NewService("", provider.NewRegistry(), newMemoryStore()) - discovered, err := service.discoverAndPersist(context.Background(), openAIProviderSource()) + discovered, err := service.discoverAndPersist(context.Background(), customGatewayProviderSource()) if err != nil || discovered != nil { t.Fatalf("expected unsupported driver to skip discovery, got err=%v models=%+v", err, discovered) } @@ -463,7 +435,7 @@ func TestDiscoverAndPersistFailurePaths(t *testing.T) { return nil, errors.New("discover failed") }), newMemoryStore()) - discovered, err := service.discoverAndPersist(context.Background(), openAIProviderSource()) + discovered, err := service.discoverAndPersist(context.Background(), customGatewayProviderSource()) if err == nil || discovered != nil { t.Fatalf("expected discovery error to skip persistence, got err=%v models=%+v", err, discovered) } @@ -475,7 +447,7 @@ func TestDiscoverAndPersistFailurePaths(t *testing.T) { return []providertypes.ModelDescriptor{{ID: "gpt-4.1", Name: "GPT-4.1"}}, nil }), nil) - discovered, err := service.discoverAndPersist(context.Background(), openAIProviderSource()) + discovered, err := service.discoverAndPersist(context.Background(), customGatewayProviderSource()) if err != nil { t.Fatalf("expected discovery without store to succeed, got %v", err) } @@ -519,7 +491,7 @@ func TestDiscoverAndPersistFailurePaths(t *testing.T) { } service := NewService("", registry, store) - models, err := service.ListProviderModels(context.Background(), 
openAIProviderSource()) + models, err := service.ListProviderModels(context.Background(), customGatewayProviderSource()) if err == nil { t.Fatal("expected persist failure to be returned") } @@ -532,93 +504,6 @@ func TestDiscoverAndPersistFailurePaths(t *testing.T) { }) } -func TestQueueRefreshDeduplicatesInFlightRequests(t *testing.T) { - t.Setenv(testAPIKeyEnv, "test-key") - - identity, err := config.OpenAIProvider().Identity() - if err != nil { - t.Fatalf("Identity() error = %v", err) - } - - started := make(chan struct{}, 1) - release := make(chan struct{}) - var discoverCalls int32 - registry := newRegistry(t, openaicompat.DriverName, func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { - atomic.AddInt32(&discoverCalls, 1) - select { - case started <- struct{}{}: - default: - } - select { - case <-release: - case <-ctx.Done(): - } - return []providertypes.ModelDescriptor{{ID: "gpt-4o", Name: "GPT-4o"}}, nil - }) - - service := NewService("", registry, newMemoryStore()) - service.backgroundTimeout = time.Second - - service.queueRefresh(openAIProviderSource()) - <-started - service.queueRefresh(openAIProviderSource()) - - time.Sleep(50 * time.Millisecond) - if calls := atomic.LoadInt32(&discoverCalls); calls != 1 { - t.Fatalf("expected exactly one in-flight refresh, got %d", calls) - } - - close(release) - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - service.refreshMu.Lock() - _, exists := service.inFlightByID[identity.Key()] - service.refreshMu.Unlock() - if !exists { - return - } - time.Sleep(20 * time.Millisecond) - } - - t.Fatal("expected in-flight refresh marker to be cleared") -} - -func TestQueueRefreshSkipsWhenDriverUnsupportedOrIdentityIncomplete(t *testing.T) { - t.Parallel() - - t.Run("unsupported driver", func(t *testing.T) { - t.Parallel() - - service := NewService("", provider.NewRegistry(), newMemoryStore()) - input := openAIProviderSource() - input.Identity.Driver = 
"missing" - - service.queueRefresh(input) - - service.refreshMu.Lock() - defer service.refreshMu.Unlock() - if len(service.inFlightByID) != 0 { - t.Fatalf("expected no refresh to be queued, got %+v", service.inFlightByID) - } - }) - - t.Run("incomplete identity", func(t *testing.T) { - t.Parallel() - - service := NewService("", newRegistry(t, openaicompat.DriverName, nil), newMemoryStore()) - input := openAIProviderSource() - input.Identity.BaseURL = "" - - service.queueRefresh(input) - - service.refreshMu.Lock() - defer service.refreshMu.Unlock() - if len(service.inFlightByID) != 0 { - t.Fatalf("expected no refresh to be queued, got %+v", service.inFlightByID) - } - }) -} - func newRegistry(t *testing.T, name string, discover provider.DiscoveryFunc) *provider.Registry { t.Helper() @@ -644,7 +529,22 @@ func newRegistry(t *testing.T, name string, discover provider.DiscoveryFunc) *pr func openAIProviderSource() provider.CatalogInput { providerCfg := config.OpenAIProvider() providerCfg.APIKeyEnv = testAPIKeyEnv - return mustCatalogInput(nil, providerCfg) + input := mustCatalogInput(nil, providerCfg) + if len(input.ConfiguredModels) == 0 { + input.ConfiguredModels = providertypes.DescriptorsFromIDs([]string{ + config.OpenAIDefaultModel, + "gpt-5.4-mini", + "gpt-5.3-codex", + "gpt-4.1", + "gpt-4o", + "gpt-4o-mini", + }) + } + if len(input.DefaultModels) == 0 { + input.DefaultModels = providertypes.CloneModelDescriptors(input.ConfiguredModels) + } + input.DisableDiscovery = true + return input } func customGatewayProviderSource() provider.CatalogInput { @@ -667,6 +567,8 @@ func mustCatalogInput(t *testing.T, cfg config.ProviderConfig) provider.CatalogI input := provider.CatalogInput{ Identity: identity, ConfiguredModels: providertypes.CloneModelDescriptors(cloned.Models), + DisableDiscovery: cloned.Source == config.ProviderSourceBuiltin || + (cloned.Source == config.ProviderSourceCustom && config.NormalizeModelSource(cloned.ModelSource) == config.ModelSourceManual), 
ResolveDiscoveryConfig: func() (provider.RuntimeConfig, error) { resolved, err := cloned.Resolve() if err != nil { @@ -676,7 +578,10 @@ func mustCatalogInput(t *testing.T, cfg config.ProviderConfig) provider.CatalogI }, } if cloned.Source != config.ProviderSourceCustom { - input.DefaultModels = providertypes.DescriptorsFromIDs([]string{cloned.Model}) + input.DefaultModels = providertypes.CloneModelDescriptors(cloned.Models) + if len(input.DefaultModels) == 0 { + input.DefaultModels = providertypes.DescriptorsFromIDs([]string{cloned.Model}) + } } return input } diff --git a/internal/provider/catalog/store.go b/internal/provider/catalog/store.go index 33a63133..0a6ad84e 100644 --- a/internal/provider/catalog/store.go +++ b/internal/provider/catalog/store.go @@ -17,7 +17,7 @@ import ( providertypes "neo-code/internal/provider/types" ) -const schemaVersion = 3 +const schemaVersion = 4 var ErrCatalogNotFound = errors.New("provider: model catalog not found") @@ -26,14 +26,9 @@ type ModelCatalog struct { SchemaVersion int `json:"schema_version"` Identity provider.ProviderIdentity `json:"identity"` FetchedAt time.Time `json:"fetched_at"` - ExpiresAt time.Time `json:"expires_at"` Models []providertypes.ModelDescriptor `json:"models"` } -func (c ModelCatalog) Expired(now time.Time) bool { - return !c.ExpiresAt.IsZero() && !now.Before(c.ExpiresAt) -} - // Store persists model catalogs keyed by normalized provider identity. 
type Store interface { Load(ctx context.Context, identity provider.ProviderIdentity) (ModelCatalog, error) diff --git a/internal/provider/catalog/store_test.go b/internal/provider/catalog/store_test.go index 77e783d1..8b08001a 100644 --- a/internal/provider/catalog/store_test.go +++ b/internal/provider/catalog/store_test.go @@ -30,7 +30,6 @@ func TestJSONStoreRoundTrip(t *testing.T) { SchemaVersion: schemaVersion, Identity: normalizedIdentity, FetchedAt: time.Date(2026, 4, 2, 10, 0, 0, 0, time.UTC), - ExpiresAt: time.Date(2026, 4, 3, 10, 0, 0, 0, time.UTC), Models: []providertypes.ModelDescriptor{ { ID: "gpt-4.1", @@ -59,7 +58,7 @@ func TestJSONStoreRoundTrip(t *testing.T) { if got.Identity != expected.Identity { t.Fatalf("expected identity %+v, got %+v", expected.Identity, got.Identity) } - if !got.FetchedAt.Equal(expected.FetchedAt) || !got.ExpiresAt.Equal(expected.ExpiresAt) { + if !got.FetchedAt.Equal(expected.FetchedAt) { t.Fatalf("expected timestamps %+v, got %+v", expected, got) } if len(got.Models) != 1 { @@ -154,7 +153,6 @@ func TestJSONStoreSaveReplacesExistingCatalogWithoutTempLeak(t *testing.T) { SchemaVersion: schemaVersion, Identity: identity, FetchedAt: time.Date(2026, 4, 2, 10, 0, 0, 0, time.UTC), - ExpiresAt: time.Date(2026, 4, 3, 10, 0, 0, 0, time.UTC), Models: []providertypes.ModelDescriptor{ {ID: "gpt-old", Name: "GPT Old"}, }, @@ -163,7 +161,6 @@ func TestJSONStoreSaveReplacesExistingCatalogWithoutTempLeak(t *testing.T) { SchemaVersion: schemaVersion, Identity: identity, FetchedAt: time.Date(2026, 4, 4, 10, 0, 0, 0, time.UTC), - ExpiresAt: time.Date(2026, 4, 5, 10, 0, 0, 0, time.UTC), Models: []providertypes.ModelDescriptor{ {ID: "gpt-new", Name: "GPT New"}, }, @@ -201,21 +198,6 @@ func TestJSONStoreSaveReplacesExistingCatalogWithoutTempLeak(t *testing.T) { } } -func TestModelCatalogExpired(t *testing.T) { - t.Parallel() - - now := time.Date(2026, 4, 8, 12, 0, 0, 0, time.UTC) - if (ModelCatalog{}).Expired(now) { - t.Fatal("expected zero-value 
catalog to be treated as not expired") - } - if !(ModelCatalog{ExpiresAt: now}).Expired(now) { - t.Fatal("expected catalog expiring at now to be expired") - } - if (ModelCatalog{ExpiresAt: now.Add(time.Minute)}).Expired(now) { - t.Fatal("expected future expiry to be treated as fresh") - } -} - func TestJSONStoreLoadHonorsContextError(t *testing.T) { t.Parallel() diff --git a/internal/runtime/subagent_more_branches_test.go b/internal/runtime/subagent_more_branches_test.go index f51187d4..759880a9 100644 --- a/internal/runtime/subagent_more_branches_test.go +++ b/internal/runtime/subagent_more_branches_test.go @@ -144,6 +144,12 @@ func TestRuntimeSubAgentResolveSettingsModelFallbackAndEmptyModel(t *testing.T) cfg.CurrentModel = "" for i := range cfg.Providers { cfg.Providers[i].Model = "model-from-provider" + if cfg.Providers[i].Source == config.ProviderSourceBuiltin { + cfg.Providers[i].Models = append(cfg.Providers[i].Models, providertypes.ModelDescriptor{ + ID: "model-from-provider", + Name: "model-from-provider", + }) + } } return nil }); err != nil { diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index 16e0741d..22335d7d 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -767,7 +767,7 @@ func (a *App) applyFullAccessPromptSelection(enable bool) tea.Cmd { return nil } -// openFullAccessPrompt 打开 Full Access 风险确认弹窗,并将输入焦点收敛回输入区。 +// openFullAccessPrompt 打开 Full Access 风险确认弹窗,并将输入焦点收回输入区。 func (a *App) openFullAccessPrompt() { a.pendingFullAccessPrompt = &fullAccessPromptState{Selected: 0} a.focus = panelInput @@ -1401,7 +1401,7 @@ func (a *App) resolveModelScopeGuidePath() string { return candidate } -// handleModelScopeGuideInput 澶勭悊 ModelScope 鍗婂紩瀵兼祦绋嬩腑鐨勯敭鐩樹氦浜掋€? 
+// handleModelScopeGuideInput 处理 ModelScope 半引导流程中的键盘交互。 func (a *App) handleModelScopeGuideInput(msg tea.KeyMsg) (tea.Model, tea.Cmd) { if a.modelScopeGuide == nil { return a, nil @@ -1440,7 +1440,7 @@ func (a *App) handleModelScopeGuideInput(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return a, nil } -// modelScopeGuideOpenTarget 杩斿洖褰撳墠寮曞姝ラ瀵瑰簲鐨勫閮ㄨ祫婧愮洰鏍囥€? +// modelScopeGuideOpenTarget 返回当前引导步骤对应的外部资源目标。 func modelScopeGuideOpenTarget(step modelScopeGuideStep, guidePath string) (string, bool) { switch step { case modelScopeGuideStepGuide: @@ -1458,7 +1458,7 @@ func modelScopeGuideOpenTarget(step modelScopeGuideStep, guidePath string) (stri } } -// submitModelScopeGuideToken 鏍¢獙骞舵彁浜ょ敤鎴风矘璐寸殑 token銆? +// submitModelScopeGuideToken 校验并提交用户粘贴的 token。 func (a *App) submitModelScopeGuideToken(guide *modelScopeGuideState) tea.Cmd { token := strings.TrimSpace(guide.Token) if token == "" { @@ -1471,7 +1471,7 @@ func (a *App) submitModelScopeGuideToken(guide *modelScopeGuideState) tea.Cmd { return a.runModelScopeGuideSubmit(guide.ProviderID, guide.APIKeyEnv, token) } -// runModelScopeGuideOpen 寮傛鎵撳紑 ModelScope 寮曞璧勬簮锛堟湰鍦?HTML 鎴栫綉椤?URL锛夈€? +// runModelScopeGuideOpen 异步打开 ModelScope 引导资源,本地 HTML 和网页 URL 共用这一入口。 func (a *App) runModelScopeGuideOpen(target string) tea.Cmd { openTarget := strings.TrimSpace(target) if openTarget == "" { @@ -1488,7 +1488,7 @@ func (a *App) runModelScopeGuideOpen(target string) tea.Cmd { } } -// handleModelScopeGuideOpenResultMsg 澶勭悊椤甸潰鎵撳紑缁撴灉锛屽苟鎺ㄨ繘寮曞鐘舵€佹満銆? +// handleModelScopeGuideOpenResultMsg 处理引导资源打开结果,并推进引导步骤。 func (a *App) handleModelScopeGuideOpenResultMsg(msg modelScopeGuideOpenResultMsg) { if a.modelScopeGuide == nil { return @@ -1513,7 +1513,7 @@ func (a *App) handleModelScopeGuideOpenResultMsg(msg modelScopeGuideOpenResultMs } } -// advanceModelScopeGuideStep 鏍规嵁鎵撳紑缁撴灉鎺ㄨ繘寮曞姝ラ銆? 
+// advanceModelScopeGuideStep 根据资源打开结果推进 ModelScope 引导步骤。 func advanceModelScopeGuideStep(current modelScopeGuideStep, target string) (modelScopeGuideStep, bool) { switch current { case modelScopeGuideStepGuide: @@ -1528,7 +1528,7 @@ func advanceModelScopeGuideStep(current modelScopeGuideStep, target string) (mod return current, false } -// clearModelScopeGuideFeedback 娓呯┖寮曞闈㈡澘涓婄殑閿欒鍜屾彁绀轰俊鎭€? +// clearModelScopeGuideFeedback 清空引导面板上的错误与提示信息。 func clearModelScopeGuideFeedback(guide *modelScopeGuideState) { if guide == nil { return @@ -1537,7 +1537,7 @@ func clearModelScopeGuideFeedback(guide *modelScopeGuideState) { guide.Notice = "" } -// runModelScopeGuideSubmit 鍦ㄨ缃?token 鍚庡畬鎴?provider 閫夋嫨涓庢渶灏忓彲鐢ㄦ牎楠屻€? +// runModelScopeGuideSubmit 设置 token 后完成 provider 选择与最小可用校验。 func (a *App) runModelScopeGuideSubmit(providerID string, apiKeyEnv string, token string) tea.Cmd { providerSvc := a.providerSvc baseDir := a.configManager.BaseDir() @@ -1604,7 +1604,7 @@ func (a *App) runModelScopeGuideSubmit(providerID string, apiKeyEnv string, toke } } -// rollbackModelScopeGuideSelection 鍦ㄥ紩瀵兼祦绋嬪け璐ユ椂鍥炴粴 provider/model 閫夋嫨锛岄伩鍏嶉厤缃姸鎬佹紓绉汇€? +// rollbackModelScopeGuideSelection 在引导流程失败时回滚 provider 和 model 选择,避免状态漂移。 func rollbackModelScopeGuideSelection(providerSvc ProviderController, providerID string, modelID string) error { normalizedProviderID := strings.TrimSpace(providerID) normalizedModelID := strings.TrimSpace(modelID) @@ -1627,7 +1627,7 @@ func rollbackModelScopeGuideSelection(providerSvc ProviderController, providerID return nil } -// handleModelScopeGuideSubmitResultMsg 澶勭悊 token 鏍¢獙缁撴灉锛屾垚鍔熷悗鍏抽棴鍚戝锛屽け璐ユ椂鍥為€€骞舵彁绀轰笅涓€姝ャ€? 
+// handleModelScopeGuideSubmitResultMsg 处理 token 校验结果;成功后关闭引导,失败时回退并提示下一步。 func (a *App) handleModelScopeGuideSubmitResultMsg(msg modelScopeGuideSubmitResultMsg) tea.Cmd { if a.modelScopeGuide == nil { return nil @@ -1678,7 +1678,7 @@ func (a *App) handleModelScopeGuideSubmitResultMsg(msg modelScopeGuideSubmitResu return a.requestModelCatalogRefresh(a.state.CurrentProvider) } -// isModelScopeAuthOrPermissionError 鍒ゆ柇閿欒鏄惁鎸囧悜璁よ瘉鎴栨潈闄愭湭瀹屾垚鍦烘櫙銆? +// isModelScopeAuthOrPermissionError 判断错误是否指向认证或权限未完成场景。 func isModelScopeAuthOrPermissionError(raw string) bool { lowered := strings.ToLower(strings.TrimSpace(raw)) if lowered == "" { @@ -1753,7 +1753,7 @@ func (a *App) refreshMessages() error { return nil } -// HydrateSession 在 TUI 启动阶段加载并接管既有会话状态,用于 URL 唤醒后的无感接续。 +// HydrateSession 在 TUI 启动阶段加载并接管既有会话状态,用于 URL 唤醒后的无感续接。 func (a *App) HydrateSession(ctx context.Context, sessionID string) error { sessionID = strings.TrimSpace(sessionID) if sessionID == "" { @@ -1795,7 +1795,7 @@ func (a *App) ConfigureStartupWakeInput(text string, workdir string) error { return nil } -// applySessionSnapshot 将会话快照同步到前端状态,统一复用于普通刷新与启动水化路径。 +// applySessionSnapshot 将会话快照同步到前端状态,统一复用于普通刷新与启动接管路径。 func (a *App) applySessionSnapshot(session agentsession.Session, warnOnMissingWorkdir bool) { a.activeMessages = session.Messages a.clearActivities() @@ -1969,7 +1969,7 @@ func (a *App) refreshRuntimeSourceSnapshot() { } } -// runtimeSessionContextSource 瀹氫箟璇诲彇浼氳瘽涓婁笅鏂囧揩鐓х殑鏈€灏忔帴鍙o紝渚夸簬鍦?UI 渚ф寜闇€鍒锋柊杩愯鎬佷俊鎭€? +// runtimeSessionContextSource 定义读取会话上下文快照的最小接口,便于 UI 侧按需刷新运行态信息。 type runtimeSessionContextSource interface { GetSessionContext(ctx context.Context, sessionID string) (any, error) } @@ -2121,7 +2121,7 @@ func runtimeEventPhaseChangedHandler(a *App, event tuiservices.RuntimeEvent) boo return false } -// runtimeEventVerificationStartedHandler 澶勭悊楠岃瘉娴佺▼寮€濮嬩簨浠跺苟鏇存柊杩愯杩涘害鎻愮ず銆? 
+// runtimeEventVerificationStartedHandler 处理验证流程开始事件,并更新运行进度提示。 func runtimeEventVerificationStartedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(tuiservices.VerificationStartedPayload) if !ok { @@ -2136,7 +2136,7 @@ func runtimeEventVerificationStartedHandler(a *App, event tuiservices.RuntimeEve return false } -// runtimeEventVerificationStageFinishedHandler 澶勭悊鍗曚釜 verifier 闃舵瀹屾垚浜嬩欢銆? +// runtimeEventVerificationStageFinishedHandler 处理单个 verifier 阶段完成事件。 func runtimeEventVerificationStageFinishedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(tuiservices.VerificationStageFinishedPayload) if !ok { @@ -2159,7 +2159,7 @@ func runtimeEventVerificationStageFinishedHandler(a *App, event tuiservices.Runt return false } -// runtimeEventVerificationFinishedHandler 澶勭悊楠岃瘉鎬绘祦绋嬬粨鏉熶簨浠躲€? +// runtimeEventVerificationFinishedHandler 处理验证总流程结束事件。 func runtimeEventVerificationFinishedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(tuiservices.VerificationFinishedPayload) if !ok { @@ -2171,7 +2171,7 @@ func runtimeEventVerificationFinishedHandler(a *App, event tuiservices.RuntimeEv return false } -// runtimeEventVerificationCompletedHandler 澶勭悊楠岃瘉閫氳繃浜嬩欢銆? +// runtimeEventVerificationCompletedHandler 处理验证通过事件。 func runtimeEventVerificationCompletedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(tuiservices.VerificationCompletedPayload) if !ok { @@ -2185,7 +2185,7 @@ func runtimeEventVerificationCompletedHandler(a *App, event tuiservices.RuntimeE return false } -// runtimeEventVerificationFailedHandler 澶勭悊楠岃瘉澶辫触浜嬩欢銆? 
+// runtimeEventVerificationFailedHandler 处理验证失败事件。 func runtimeEventVerificationFailedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(tuiservices.VerificationFailedPayload) if !ok { @@ -2202,7 +2202,7 @@ func runtimeEventVerificationFailedHandler(a *App, event tuiservices.RuntimeEven return false } -// runtimeEventAcceptanceDecidedHandler 澶勭悊 acceptance 鍐崇瓥浜嬩欢骞惰褰曞彲瑙傛祴鏃ュ織銆? +// runtimeEventAcceptanceDecidedHandler 处理 acceptance 决策事件,并记录可观测活动日志。 func runtimeEventAcceptanceDecidedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(tuiservices.AcceptanceDecidedPayload) if !ok { @@ -2227,7 +2227,7 @@ func runtimeEventAcceptanceDecidedHandler(a *App, event tuiservices.RuntimeEvent return false } -// runtimeEventStopReasonDecidedHandler 澶勭悊杩愯缁撴潫鍘熷洜骞剁粺涓€鏇存柊鐘舵€佷笌娲诲姩鏃ュ織銆? +// runtimeEventStopReasonDecidedHandler 处理运行终止原因事件,统一收尾状态与活动日志。 func runtimeEventStopReasonDecidedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(tuiservices.StopReasonDecidedPayload) if !ok { @@ -2364,7 +2364,7 @@ func runtimeEventTodoConflictHandler(a *App, event tuiservices.RuntimeEvent) boo return false } -// runtimeEventSkillActivatedHandler 鍦?runtime 婵€娲?skill 鍚庡悓姝ユ椿鍔ㄦ棩蹇椼€? +// runtimeEventSkillActivatedHandler 在 runtime 激活 skill 后同步活动日志。 func runtimeEventSkillActivatedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := parseSessionSkillEventPayload(event.Payload) if !ok { @@ -2375,7 +2375,7 @@ func runtimeEventSkillActivatedHandler(a *App, event tuiservices.RuntimeEvent) b return false } -// runtimeEventSkillDeactivatedHandler 鍦?runtime 鍋滅敤 skill 鍚庡悓姝ユ椿鍔ㄦ棩蹇椼€? 
+// runtimeEventSkillDeactivatedHandler 在 runtime 停用 skill 后同步活动日志。 func runtimeEventSkillDeactivatedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := parseSessionSkillEventPayload(event.Payload) if !ok { @@ -2386,7 +2386,7 @@ func runtimeEventSkillDeactivatedHandler(a *App, event tuiservices.RuntimeEvent) return false } -// runtimeEventSkillMissingHandler 鍦ㄤ細璇?skill 涓㈠け鏃惰緭鍑烘樉寮忛敊璇弽棣堬紝渚夸簬鎺掓煡鎭㈠闂銆? +// runtimeEventSkillMissingHandler 在会话 skill 缺失时输出显式错误反馈,便于排查恢复问题。 func runtimeEventSkillMissingHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := parseSessionSkillEventPayload(event.Payload) if !ok { @@ -2397,7 +2397,7 @@ func runtimeEventSkillMissingHandler(a *App, event tuiservices.RuntimeEvent) boo return false } -// parseSessionSkillEventPayload 瑙f瀽 runtime skill 浜嬩欢璐熻浇骞跺吋瀹?map 缁撴瀯銆? +// parseSessionSkillEventPayload 解析 runtime skill 事件负载,并兼容 map 结构。 func parseSessionSkillEventPayload(payload any) (tuiservices.SessionSkillEventPayload, bool) { switch typed := payload.(type) { case tuiservices.SessionSkillEventPayload: @@ -2625,7 +2625,7 @@ func runtimeEventTokenUsageHandler(a *App, event tuiservices.RuntimeEvent) bool return false } -// runtimeEventToolCallThinkingHandler 鍦ㄥ伐鍏疯皟鐢ㄨ繘鍏ユ€濊€冮樁娈垫椂鍚屾褰撳墠宸ュ叿涓庤繘搴︽彁绀恒€? +// runtimeEventToolCallThinkingHandler 在工具调用进入思考阶段时同步当前工具和进度提示。 func runtimeEventToolCallThinkingHandler(a *App, event tuiservices.RuntimeEvent) bool { if payload, ok := event.Payload.(string); ok && strings.TrimSpace(payload) != "" { a.state.CurrentTool = payload @@ -2635,7 +2635,7 @@ func runtimeEventToolCallThinkingHandler(a *App, event tuiservices.RuntimeEvent) return false } -// runtimeEventToolStartHandler 鍦ㄥ伐鍏峰疄闄呮墽琛屾椂鏇存柊鐘舵€佹潯鍜屾椿鍔ㄨ褰曘€?
+// runtimeEventToolStartHandler 在工具开始执行时更新状态栏和活动记录。 func runtimeEventToolStartHandler(a *App, event tuiservices.RuntimeEvent) bool { a.state.StatusText = statusRunningTool a.state.StreamingReply = false @@ -2671,7 +2671,7 @@ func runtimeEventToolResultHandler(a *App, event tuiservices.RuntimeEvent) bool return true } -// runtimeEventAgentChunkHandler 灏嗘祦寮忓洖澶嶅垎鐗囨寔缁拷鍔犲埌杞綍鍖猴紝骞舵帹杩涜繍琛岃繘搴︺€? +// runtimeEventAgentChunkHandler 将流式回复分片持续追加到转录区,并推进运行进度。 func runtimeEventAgentChunkHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(string) if !ok { @@ -2692,7 +2692,7 @@ func runtimeEventToolChunkHandler(a *App, event tuiservices.RuntimeEvent) bool { return false } -// runtimeEventAgentDoneHandler 鍦ㄤ唬鐞嗗洖澶嶇粨鏉熸椂鏀跺熬鐘舵€佸苟琛ラ綈鏈€缁?assistant 娑堟伅銆? +// runtimeEventAgentDoneHandler 在代理回复结束时收尾状态,并补齐最终 assistant 消息。 func runtimeEventAgentDoneHandler(a *App, event tuiservices.RuntimeEvent) bool { a.state.IsAgentRunning = false a.state.StreamingReply = false @@ -2728,7 +2728,7 @@ func runtimeEventRunCanceledHandler(a *App, event tuiservices.RuntimeEvent) bool return false } -// runtimeEventErrorHandler 鍦ㄨ繍琛屾姤閿欐椂缁熶竴娓呯悊鐜板満骞跺睍绀洪敊璇俊鎭€? +// runtimeEventErrorHandler 在运行报错时统一清理现场,并展示错误信息。 func runtimeEventErrorHandler(a *App, event tuiservices.RuntimeEvent) bool { a.state.StatusText = statusError a.state.IsAgentRunning = false @@ -2809,7 +2809,7 @@ func (a *App) beginAutoPermissionApproval(payload tuiservices.PermissionRequestP return true } -// permissionRequestActivityDetail 统一格式化权限请求相关活动明细,避免各分支重复拼接。 +// permissionRequestActivityDetail 统一格式化权限请求活动详情,避免各分支重复拼接。 func permissionRequestActivityDetail(payload tuiservices.PermissionRequestPayload) string { return fmt.Sprintf("%s -> %s", fallbackText(payload.ToolName, "tool"), fallbackText(payload.Target, "(empty target)")) } @@ -2837,7 +2837,7 @@ func runtimeEventPermissionResolvedHandler(a *App, event tuiservices.RuntimeEven return false } -// refreshPermissionPromptLayout 鍦ㄦ潈闄愭彁绀哄嚭鐜版垨娑堝け鍚庡埛鏂板竷灞€锛岄伩鍏嶉伄鎸¤緭鍏ュ尯銆? 
+// refreshPermissionPromptLayout 在权限提示出现或消失后刷新布局,避免遮挡输入区。 func (a *App) refreshPermissionPromptLayout() { if a.width <= 0 || a.height <= 0 { return @@ -2902,7 +2902,7 @@ func (a *App) appendInlineMessage(role string, message string) { a.activeMessages = append(a.activeMessages, providertypes.Message{Role: role, Parts: []providertypes.ContentPart{providertypes.NewTextPart(content)}}) } -// applyInlineCommandError 缁熶竴鍐欏叆鍛戒护閿欒骞跺埛鏂拌浆褰曞尯锛岀‘淇濋敊璇彁绀虹珛鍗冲彲瑙併€? +// applyInlineCommandError 统一写入内联命令错误,并立即刷新转录区显示。 func (a *App) applyInlineCommandError(message string) { message = strings.TrimSpace(message) if message == "" { @@ -2914,7 +2914,7 @@ func (a *App) applyInlineCommandError(message string) { a.rebuildTranscript() } -// recordStaleSkillCommandResult 璁板綍鏉ヨ嚜鏃т細璇濈殑鎶€鑳藉懡浠ょ粨鏋滐紝閬垮厤鍦ㄤ細璇濆垏鎹㈠悗閿欒琚潤榛樹涪寮冦€? +// recordStaleSkillCommandResult 记录来自旧会话的 skill 命令结果,避免切会话后被静默丢弃。 func (a *App) recordStaleSkillCommandResult(requestSessionID, activeSessionID string, runErr error) { detail := fmt.Sprintf("result from session %q ignored after switching to %q", requestSessionID, activeSessionID) if runErr != nil { @@ -4063,7 +4063,7 @@ func (a *App) persistLogEntriesForActiveSession() { a.logPersistDirty = false } -// sessionLogRuntime 杩斿洖鏀寔浼氳瘽鏃ュ織璇诲啓鐨?runtime 閫傞厤鑳藉姏銆? +// sessionLogRuntime 返回支持会话日志读写的 runtime 适配能力。 func (a *App) sessionLogRuntime() sessionLogPersistenceRuntime { runtimeWithPersistence, ok := a.runtime.(sessionLogPersistenceRuntime) if !ok { @@ -4072,7 +4072,7 @@ func (a *App) sessionLogRuntime() sessionLogPersistenceRuntime { return runtimeWithPersistence } -// reportLogPersistenceError 缁熶竴澶勭悊鏃ュ織鎸佷箙鍖栧け璐ユ彁绀猴紝閬垮厤閿欒闈欓粯鍚炴帀銆? +// reportLogPersistenceError 统一处理日志持久化失败提示,避免错误被静默吞掉。 func (a *App) reportLogPersistenceError(action string, err error) { if err == nil { return @@ -4082,7 +4082,7 @@ func (a *App) reportLogPersistenceError(action string, err error) { a.showFooterError(message) } -// restoreStatusAfterLogViewer 鍦ㄥ叧闂棩蹇楄鍥炬椂鎭㈠鍙鐘舵€侊紝閬垮厤瑕嗙洊鐪熷疄杩愯鎬併€? 
+// restoreStatusAfterLogViewer 关闭日志视图后恢复可读状态,避免覆盖真实运行态。 func (a *App) restoreStatusAfterLogViewer() { defer func() { a.logViewerPrevStatus = "" }() if executionError := strings.TrimSpace(a.state.ExecutionError); executionError != "" { @@ -4108,7 +4108,7 @@ func (a *App) restoreStatusAfterLogViewer() { a.state.StatusText = statusReady } -// toRuntimeSessionLogEntries 杞崲鏃ュ織鏉$洰鍒?runtime 鎸佷箙鍖栨ā鍨嬨€? +// toRuntimeSessionLogEntries 将日志条目转换为 runtime 持久化模型。 func toRuntimeSessionLogEntries(entries []logEntry) []tuiservices.SessionLogEntry { converted := make([]tuiservices.SessionLogEntry, 0, len(entries)) for _, entry := range entries { @@ -4122,7 +4122,7 @@ func toRuntimeSessionLogEntries(entries []logEntry) []tuiservices.SessionLogEntr return converted } -// fromRuntimeSessionLogEntries 灏?runtime 鎸佷箙鍖栨ā鍨嬫仮澶嶄负 TUI 灞曠ず妯″瀷銆? +// fromRuntimeSessionLogEntries 将 runtime 持久化模型还原为 TUI 展示模型。 func fromRuntimeSessionLogEntries(entries []tuiservices.SessionLogEntry) []logEntry { converted := make([]logEntry, 0, len(entries)) for _, entry := range entries { @@ -4360,7 +4360,7 @@ func currentProviderAddField(form *providerAddFormState) providerAddFieldID { return fields[form.Step] } -// isProviderAddEnumField 鍒ゆ柇褰撳墠鏂板 Provider 琛ㄥ崟鐒︾偣鏄惁鍦ㄦ灇涓惧瓧娈碉紙Driver/Model Source锛夈€? +// isProviderAddEnumField 判断当前新增 Provider 表单焦点是否位于枚举字段(Driver/Model Source)。 func isProviderAddEnumField(form *providerAddFormState) bool { switch currentProviderAddField(form) { case providerAddFieldDriver, providerAddFieldModelSource, providerAddFieldChatAPIMode: @@ -4647,7 +4647,7 @@ type modelScopeGuideSubmitResultMsg struct { Warning string } -// providerAddDefaultChatEndpointPath 杩斿洖 provider add 琛ㄥ崟鐨勯┍鍔ㄩ粯璁よ亰澶╃鐐硅矾寰勩€? 
+// providerAddDefaultChatEndpointPath 返回 provider add 表单按 driver 推导的默认 chat endpoint。 func providerAddDefaultChatEndpointPath(driver string) string { switch provider.NormalizeProviderDriver(driver) { case provider.DriverGemini: @@ -4659,7 +4659,7 @@ func providerAddDefaultChatEndpointPath(driver string) string { } } -// providerAddDefaultOpenAICompatChatEndpointPath 鏍规嵁 chat_api_mode 杩斿洖 openaicompat 鐨勯粯璁よ亰澶╃鐐硅矾寰勩€? +// providerAddDefaultOpenAICompatChatEndpointPath 根据 chat_api_mode 返回 openaicompat 的默认 chat endpoint。 func providerAddDefaultOpenAICompatChatEndpointPath(chatAPIMode string) string { mode, err := provider.NormalizeProviderChatAPIMode(chatAPIMode) if err != nil || mode == "" { @@ -4671,7 +4671,7 @@ func providerAddDefaultOpenAICompatChatEndpointPath(chatAPIMode string) string { return "/chat/completions" } -// syncProviderAddOpenAICompatModeDefaults 鍦ㄥ垏鎹?chat_api_mode 鏃跺悓姝ラ粯璁?chat endpoint锛岄伩鍏嶉粯璁ゅ€奸敊閰嶃€? +// syncProviderAddOpenAICompatModeDefaults 在切换 chat_api_mode 时同步默认 chat endpoint,避免默认值错配。 func syncProviderAddOpenAICompatModeDefaults(form *providerAddFormState, previousMode string) { if form == nil || provider.NormalizeProviderDriver(form.Driver) != provider.DriverOpenAICompat { return @@ -4685,7 +4685,7 @@ func syncProviderAddOpenAICompatModeDefaults(form *providerAddFormState, previou form.ChatEndpointPath = providerAddDefaultOpenAICompatChatEndpointPath(form.ChatAPIMode) } -// providerAddDefaultBaseURL 杩斿洖 provider add 琛ㄥ崟鐨勯┍鍔ㄩ粯璁?base URL銆? +// providerAddDefaultBaseURL 返回 provider add 表单按 driver 推导的默认 base URL。 func providerAddDefaultBaseURL(driver string) string { switch provider.NormalizeProviderDriver(driver) { case provider.DriverOpenAICompat: @@ -4699,7 +4699,7 @@ func providerAddDefaultBaseURL(driver string) string { } } -// syncProviderAddDriverDefaults 鍦ㄥ垏鎹?driver 鏃舵寜闇€鏇存柊榛樿 base URL 涓?chat endpoint銆? 
+// syncProviderAddDriverDefaults 在切换 driver 时按需同步默认 base URL 与 chat endpoint。 func syncProviderAddDriverDefaults(form *providerAddFormState, previousDriver string) { if form == nil { return @@ -4844,13 +4844,11 @@ func buildProviderAddRequest(form providerAddFormState) (providerAddRequest, str } type providerAddManualModelJSON struct { - ID string `json:"id"` - Name string `json:"name"` - ContextWindow *int `json:"context_window,omitempty"` - MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + ID string `json:"id"` + Name string `json:"name"` } -// parseProviderAddManualModelsJSON 瑙f瀽 provider add 琛ㄥ崟涓殑鎵嬪伐妯″瀷 JSON锛屽苟澶嶇敤 config 褰掍竴鍖栨牎楠岃鍒欍€? +// parseProviderAddManualModelsJSON 解析 provider add 表单中的手工模型 JSON,并复用 config 归一化校验规则。 func parseProviderAddManualModelsJSON(raw string) ([]providertypes.ModelDescriptor, error) { trimmed := strings.TrimSpace(raw) if trimmed == "" { @@ -4876,34 +4874,20 @@ func parseProviderAddManualModelsJSON(raw string) ([]providertypes.ModelDescript seen := make(map[string]struct{}, len(models)) for _, model := range models { descriptor := providertypes.ModelDescriptor{ - ID: strings.TrimSpace(model.ID), - Name: strings.TrimSpace(model.Name), - ContextWindow: config.ManualModelOptionalIntUnset, - MaxOutputTokens: config.ManualModelOptionalIntUnset, + ID: strings.TrimSpace(model.ID), + Name: strings.TrimSpace(model.Name), } key := provider.NormalizeKey(descriptor.ID) if _, exists := seen[key]; exists { return nil, fmt.Errorf("parse manual model json: models.id %q is duplicated", descriptor.ID) } seen[key] = struct{}{} - if model.ContextWindow != nil { - if *model.ContextWindow <= 0 { - return nil, fmt.Errorf("parse manual model json: models.context_window must be greater than 0") - } - descriptor.ContextWindow = *model.ContextWindow - } - if model.MaxOutputTokens != nil { - if *model.MaxOutputTokens <= 0 { - return nil, fmt.Errorf("parse manual model json: models.max_output_tokens must be greater than 0") - } - 
descriptor.MaxOutputTokens = *model.MaxOutputTokens - } descriptors = append(descriptors, descriptor) } return descriptors, nil } -// sanitizeProviderAddInputRunes 杩囨护 provider 琛ㄥ崟杈撳叆涓殑鎺у埗瀛楃锛岄伩鍏嶄笉鍙瀛楃姹℃煋閰嶇疆瀛楁銆? +// sanitizeProviderAddInputRunes 过滤 provider 表单输入中的控制字符,避免不可见字符污染字段。 func sanitizeProviderAddInputRunes(runes []rune) string { if len(runes) == 0 { return "" @@ -4920,7 +4904,7 @@ func sanitizeProviderAddInputRunes(runes []rune) string { return builder.String() } -// sanitizeProviderAddJSONInputRunes 杩囨护涓嶅彲瑙佹牸寮忔帶鍒跺瓧绗︼紝鍚屾椂淇濈暀 JSON 缂栬緫鎵€闇€鐨勬崲琛屼笌鍒惰〃绗︺€? +// sanitizeProviderAddJSONInputRunes 过滤 JSON 输入中的不可见格式控制字符,同时保留换行与制表符。 func sanitizeProviderAddJSONInputRunes(runes []rune) string { if len(runes) == 0 { return "" @@ -4943,7 +4927,7 @@ func sanitizeProviderAddJSONInputRunes(runes []rune) string { return builder.String() } -// normalizeProviderAddFieldValue 瀵?provider 琛ㄥ崟瀛楁鍋氱粺涓€娓呯悊锛屽幓闄ゆ帶鍒跺瓧绗﹀苟瑁佸壀棣栧熬绌虹櫧銆? +// normalizeProviderAddFieldValue 统一清洗 provider 表单字段,去掉控制字符并裁剪首尾空白。 func normalizeProviderAddFieldValue(value string) string { return strings.TrimSpace(sanitizeProviderAddInputRunes([]rune(value))) } diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index a1e444bb..6a4eb43e 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -3608,7 +3608,7 @@ func TestUpdatePickerProviderEnterSkipsModelScopeGuideWhenTokenExists(t *testing func TestHandleModelScopeGuideInputSubmitSuccess(t *testing.T) { app, _ := newTestAppWithProviderService(t, stubProviderService{ - models: []providertypes.ModelDescriptor{{ID: "Qwen/Qwen2.5-7B-Instruct", Name: "Qwen"}}, + models: []providertypes.ModelDescriptor{{ID: config.ModelScopeDefaultModel, Name: "Qwen"}}, }) app.modelScopeGuide = &modelScopeGuideState{ ProviderID: config.ModelScopeName, @@ -3710,7 +3710,7 @@ func TestRunModelScopeGuideSubmitRollsBackSelectionWhenVerifyFails(t *testing.T) var selectedModels []string app, _ := newTestApp(t) 
app.providerSvc = stubProviderService{ - models: []providertypes.ModelDescriptor{{ID: "Qwen/Qwen2.5-7B-Instruct", Name: "Qwen"}}, + models: []providertypes.ModelDescriptor{{ID: config.ModelScopeDefaultModel, Name: "Qwen"}}, listModelsErr: errors.New("modelscope verify failed"), selectHook: func(providerID string) { selectedProviders = append(selectedProviders, providerID) @@ -3752,7 +3752,7 @@ func TestRunModelScopeGuideSubmitRollsBackSelectionWhenPersistFails(t *testing.T var selectedProviders []string var selectedModels []string providerSvc := stubProviderService{ - models: []providertypes.ModelDescriptor{{ID: "Qwen/Qwen2.5-7B-Instruct", Name: "Qwen"}}, + models: []providertypes.ModelDescriptor{{ID: config.ModelScopeDefaultModel, Name: "Qwen"}}, selectHook: func(providerID string) { selectedProviders = append(selectedProviders, providerID) }, From 50015dd33cc521b59a109b11e9dbb7de95510cf9 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Tue, 28 Apr 2026 16:42:26 +0000 Subject: [PATCH 2/2] test(config): improve refresh model coverage Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/config/state/service_test.go | 46 +++++++++++++++++++++++ internal/provider/catalog/service_test.go | 44 ++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/internal/config/state/service_test.go b/internal/config/state/service_test.go index e0a7b9bb..407fe5a6 100644 --- a/internal/config/state/service_test.go +++ b/internal/config/state/service_test.go @@ -184,6 +184,52 @@ func TestSelectionServiceListModelsSnapshotRejectsUnsupportedDriver(t *testing.T } } +func TestSelectionServiceRefreshModelsUsesRefreshPathAndRepairsCurrentModel(t *testing.T) { + t.Parallel() + + manager := newSelectionTestManager(t, testDefaultConfig()) + if err := manager.Update(context.Background(), func(cfg *configpkg.Config) error { + cfg.CurrentModel = "removed-model" + return nil + }); err != 
nil { + t.Fatalf("seed current model: %v", err) + } + + service := NewService(manager, newDriverSupporterStub(), catalogMethodsStub{ + listModels: []providertypes.ModelDescriptor{ + {ID: OpenAIDefaultModel, Name: "GPT Default"}, + {ID: "gpt-5.4-mini", Name: "GPT Mini"}, + }, + }) + + models, err := service.RefreshModels(context.Background()) + if err != nil { + t.Fatalf("RefreshModels() error = %v", err) + } + if len(models) != 2 { + t.Fatalf("expected refreshed models, got %+v", models) + } + + cfg := manager.Get() + if cfg.CurrentModel != OpenAIDefaultModel { + t.Fatalf("expected current model repaired to provider default, got %q", cfg.CurrentModel) + } +} + +func TestSelectionServiceRefreshModelsReturnsNoModelsAvailable(t *testing.T) { + t.Parallel() + + manager := newSelectionTestManager(t, testDefaultConfig()) + service := NewService(manager, newDriverSupporterStub(), catalogMethodsStub{ + listModels: nil, + }) + + _, err := service.RefreshModels(context.Background()) + if !errors.Is(err, ErrNoModelsAvailable) { + t.Fatalf("expected ErrNoModelsAvailable, got %v", err) + } +} + func TestSelectionServiceSelectProviderAndSetCurrentModel(t *testing.T) { manager := newSelectionTestManager(t, testDefaultConfig()) diff --git a/internal/provider/catalog/service_test.go b/internal/provider/catalog/service_test.go index fec5c6d3..54e4fb44 100644 --- a/internal/provider/catalog/service_test.go +++ b/internal/provider/catalog/service_test.go @@ -397,6 +397,50 @@ func TestListProviderModelsRefreshesWhenCatalogSnapshotIsEmpty(t *testing.T) { } } +func TestRefreshProviderModelsReturnsConfiguredModelsWhenDiscoveryDisabled(t *testing.T) { + t.Parallel() + + service := NewService("", newRegistry(t, openaicompat.DriverName, nil), newMemoryStore()) + providerCfg := customGatewayProvider() + providerCfg.ModelSource = config.ModelSourceManual + providerCfg.Models = []providertypes.ModelDescriptor{{ID: "manual-model", Name: "Manual Model"}} + + models, err := 
service.RefreshProviderModels(context.Background(), mustCatalogInput(t, providerCfg)) + if err != nil { + t.Fatalf("RefreshProviderModels() error = %v", err) + } + if len(models) != 1 || models[0].ID != "manual-model" { + t.Fatalf("expected configured models without discovery, got %+v", models) + } +} + +func TestRefreshProviderModelsDiscoversAndPersists(t *testing.T) { + t.Setenv(testAPIKeyEnv, "test-key") + + registry := newRegistry(t, openaicompat.DriverName, func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + return []providertypes.ModelDescriptor{{ID: "fresh-model", Name: "Fresh Model"}}, nil + }) + store := newMemoryStore() + service := NewService("", registry, store) + + input := customGatewayProviderSource() + models, err := service.RefreshProviderModels(context.Background(), input) + if err != nil { + t.Fatalf("RefreshProviderModels() error = %v", err) + } + if len(models) != 1 || models[0].ID != "fresh-model" { + t.Fatalf("expected discovered models, got %+v", models) + } + + cached, err := store.Load(context.Background(), input.Identity) + if err != nil { + t.Fatalf("load cached catalog: %v", err) + } + if len(cached.Models) != 1 || cached.Models[0].ID != "fresh-model" { + t.Fatalf("expected refreshed models to be persisted, got %+v", cached.Models) + } +} + func TestDiscoverAndPersistFailurePaths(t *testing.T) { t.Run("unsupported driver", func(t *testing.T) { service := NewService("", provider.NewRegistry(), newMemoryStore())