diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go
index 97bdcdf6..0f748dd0 100644
--- a/internal/cli/gateway_runtime_bridge.go
+++ b/internal/cli/gateway_runtime_bridge.go
@@ -614,9 +614,10 @@ func (b *gatewayRuntimePortBridge) ListModels(ctx context.Context, input gateway
 			name = id
 		}
 		models = append(models, gateway.ModelEntry{
-			ID:       id,
-			Name:     name,
-			Provider: strings.TrimSpace(p.ID),
+			ID:              id,
+			Name:            name,
+			Provider:        strings.TrimSpace(p.ID),
+			CapabilityHints: model.CapabilityHints,
 		})
 	}
 }
diff --git a/internal/config/defaults_test.go b/internal/config/defaults_test.go
index f5fc7f7e..29fcc179 100644
--- a/internal/config/defaults_test.go
+++ b/internal/config/defaults_test.go
@@ -6,8 +6,8 @@ func TestDefaultProvidersIncludesBuiltinProviders(t *testing.T) {
 	t.Parallel()
 
 	providers := DefaultProviders()
-	if len(providers) != 4 {
-		t.Fatalf("expected 4 builtin providers, got %d", len(providers))
+	if len(providers) != 10 {
+		t.Fatalf("expected 10 builtin providers, got %d", len(providers))
 	}
 	if providers[0].Name != OpenAIName {
 		t.Fatalf("expected first provider %q, got %q", OpenAIName, providers[0].Name)
@@ -15,12 +15,6 @@ func TestDefaultProvidersIncludesBuiltinProviders(t *testing.T) {
 	if providers[1].Name != GeminiName {
 		t.Fatalf("expected second provider %q, got %q", GeminiName, providers[1].Name)
 	}
-	if providers[2].Name != QiniuName {
-		t.Fatalf("expected third provider %q, got %q", QiniuName, providers[2].Name)
-	}
-	if providers[3].Name != ModelScopeName {
-		t.Fatalf("expected fourth provider %q, got %q", ModelScopeName, providers[3].Name)
-	}
 }
 
 func TestLoaderDefaultsValidate(t *testing.T) {
diff --git a/internal/config/provider.go b/internal/config/provider.go
index 061b8265..bfb9ca67 100644
--- a/internal/config/provider.go
+++ b/internal/config/provider.go
@@ -73,8 +73,8 @@ func (p ProviderConfig) Validate() error {
 	if normalizedDriver == "" {
 		return fmt.Errorf("provider %q driver is empty", p.Name)
 	}
-	if normalizedDriver != provider.DriverOpenAICompat && strings.TrimSpace(p.ChatAPIMode) != "" {
-		return fmt.Errorf("provider %q chat_api_mode is only supported for openaicompat driver", p.Name)
+	if !supportsChatAPIMode(normalizedDriver) && strings.TrimSpace(p.ChatAPIMode) != "" {
+		return fmt.Errorf("provider %q chat_api_mode is only supported for OpenAI-compatible drivers", p.Name)
 	}
 	if strings.TrimSpace(p.BaseURL) == "" && !allowsEmptyBaseURL(normalizedDriver) {
 		return fmt.Errorf("provider %q base_url is empty", p.Name)
@@ -367,7 +367,7 @@ func normalizeProviderRuntimePathsFromConfig(cfg ProviderConfig) (string, string
 
 // requiresDiscoveryEndpointPath marks which drivers' discovery still depends on HTTP endpoint configuration.
 func requiresDiscoveryEndpointPath(driver string) bool {
-	return normalizeProviderDriver(driver) == provider.DriverOpenAICompat
+	return isOpenAICompatLike(driver)
 }
 
 // sanitizeRuntimeBaseURL applies minimal safety normalization to the runtime base_url so that sensitive fragments such as userinfo are never passed through.
@@ -397,6 +397,22 @@ func allowsEmptyBaseURL(driver string) bool {
 	}
 }
 
+// supportsChatAPIMode reports whether the given driver allows configuring the chat_api_mode field.
+func supportsChatAPIMode(driver string) bool {
+	return isOpenAICompatLike(driver)
+}
+
+// isOpenAICompatLike reports whether the given driver uses OpenAI-compatible HTTP discovery.
+func isOpenAICompatLike(driver string) bool {
+	switch normalizeProviderDriver(driver) {
+	case provider.DriverOpenAICompat, provider.DriverDeepSeek, provider.DriverQwen,
+		provider.DriverGLM, provider.DriverMiMo, provider.DriverMiniMax:
+		return true
+	default:
+		return false
+	}
+}
+
 // identityBaseURL returns the base_url used for identity normalization, ensuring a stable key even when the value is empty.
 func identityBaseURL(cfg ProviderConfig) string {
 	if strings.TrimSpace(cfg.BaseURL) != "" {
@@ -407,6 +423,16 @@ func identityBaseURL(cfg ProviderConfig) string {
 		return GeminiDefaultBaseURL
 	case provider.DriverAnthropic:
 		return AnthropicDefaultBaseURL
+	case provider.DriverDeepSeek:
+		return DeepSeekDefaultBaseURL
+	case provider.DriverQwen:
+		return QwenDefaultBaseURL
+	case provider.DriverGLM:
+		return GLMDefaultBaseURL
+	case provider.DriverMiMo:
+		return MiMoDefaultBaseURL
+	case provider.DriverMiniMax:
+		return MiniMaxDefaultBaseURL
 	default:
 		return cfg.BaseURL
 	}
@@ -415,12 +441,12 @@ const (
 	OpenAIName             = "openai"
 	OpenAIDefaultBaseURL   = "https://api.openai.com/v1"
-	OpenAIDefaultModel     = "gpt-5.4"
+	OpenAIDefaultModel     = "gpt-5.5"
 	OpenAIDefaultAPIKeyEnv = "OPENAI_API_KEY"
 
 	GeminiName             = "gemini"
 	GeminiDefaultBaseURL   = "https://generativelanguage.googleapis.com/v1beta"
-	GeminiDefaultModel     = "gemini-2.5-flash"
+	GeminiDefaultModel     = "gemini-3.1-pro-preview"
 	GeminiDefaultAPIKeyEnv = "GEMINI_API_KEY"
 
 	AnthropicDefaultBaseURL = "https://api.anthropic.com/v1"
@@ -437,6 +463,21 @@ const (
 )
 
 var openAIStaticModels = []providertypes.ModelDescriptor{
+	builtinModel(
+		"gpt-5.5",
+		"GPT-5.5",
+		"OpenAI flagship GPT-5.5 model with 1M context, enhanced reasoning and agentic coding.",
+		1000000,
+		128000,
+		builtinCapabilitiesV2(
+			providertypes.ModelCapabilityStateSupported,
+			providertypes.ModelCapabilityStateSupported,
+			providertypes.ModelCapabilityStateSupported,
+			[]string{"low", "medium", "high", "xhigh"},
+			"high",
+			false,
+		),
+	),
 	builtinModel(
 		"gpt-5.4",
 		"GPT-5.4",
@@ -488,6 +529,36 @@ var openAIStaticModels = []providertypes.ModelDescriptor{
 }
 
 var geminiStaticModels = []providertypes.ModelDescriptor{
+	builtinModel(
+		"gemini-3.1-pro-preview",
+		"Gemini 3.1 Pro",
+		"Latest Gemini 3.1 flagship with 2M context, advanced reasoning and agentic coding.",
+		2097152,
+		65536,
+		builtinCapabilitiesV2(
+			providertypes.ModelCapabilityStateSupported,
+			providertypes.ModelCapabilityStateSupported,
+			providertypes.ModelCapabilityStateSupported,
+			[]string{"low", "high"},
+			"high",
+			true, // ThinkingForceEnabled: only downgrade to LOW, cannot disable
+		),
+	),
+	builtinModel(
+		"gemini-3.1-flash",
+		"Gemini 3.1 Flash",
+		"Fast Gemini 3.1 model for cost-efficient multimodal tasks.",
+		1048576,
+		65536,
+		builtinCapabilitiesV2(
+			providertypes.ModelCapabilityStateSupported,
+			providertypes.ModelCapabilityStateSupported,
+			providertypes.ModelCapabilityStateSupported,
+			[]string{"low", "high"},
+			"high",
+			true, // ThinkingForceEnabled
+		),
+	),
 	builtinModel(
 		"gemini-2.5-flash",
 		"Gemini 2.5 Flash",
@@ -547,6 +618,23 @@ func builtinCapabilities(
 	}
 }
 
+// builtinCapabilitiesV2 constructs static model capability hints that include thinking support.
+func builtinCapabilitiesV2(
+	toolCalling, imageInput, thinking providertypes.ModelCapabilityState,
+	thinkingEfforts []string,
+	thinkingDefaultEffort string,
+	thinkingForceEnabled bool,
+) providertypes.ModelCapabilityHints {
+	return providertypes.ModelCapabilityHints{
+		ToolCalling:           toolCalling,
+		ImageInput:            imageInput,
+		Thinking:              thinking,
+		ThinkingEfforts:       thinkingEfforts,
+		ThinkingDefaultEffort: thinkingDefaultEffort,
+		ThinkingForceEnabled:  thinkingForceEnabled,
+	}
+}
+
 // builtinModel constructs the static model entries used by builtin providers.
 func builtinModel(
 	id string,
@@ -621,11 +709,341 @@ func ModelScopeProvider() ProviderConfig {
 	return cfg
 }
 
+const (
+	// New provider constants
+	DeepSeekName             = "deepseek"
+	DeepSeekDefaultBaseURL   = 
"https://api.deepseek.com" + DeepSeekDefaultModel = "deepseek-v4-pro" + DeepSeekDefaultAPIKeyEnv = "DEEPSEEK_API_KEY" + + KimiName = "kimi" + KimiDefaultBaseURL = "https://api.moonshot.cn/v1" + KimiDefaultModel = "kimi-k2.6" + KimiDefaultAPIKeyEnv = "KIMI_API_KEY" + + QwenName = "qwen" + QwenDefaultBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" + QwenDefaultModel = "qwen3.6-plus" + QwenDefaultAPIKeyEnv = "QWEN_API_KEY" + + GLMName = "glm" + GLMDefaultBaseURL = "https://api.z.ai/api/paas/v4" + GLMDefaultModel = "glm-5.1" + GLMDefaultAPIKeyEnv = "GLM_API_KEY" + + MiMoName = "mimo" + MiMoDefaultBaseURL = "https://api.xiaomimimo.com/v1" + MiMoDefaultModel = "mimo-v2.5-pro" + MiMoDefaultAPIKeyEnv = "MIMO_API_KEY" + + MiniMaxName = "minimax" + MiniMaxDefaultBaseURL = "https://api.minimax.chat/v1/text/chatcompletion_v2" + MiniMaxDefaultModel = "MiniMax-M2.7" + MiniMaxDefaultAPIKeyEnv = "MINIMAX_API_KEY" +) + +// New static model lists. +var deepSeekStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "deepseek-v4-pro", + "DeepSeek V4 Pro", + "Flagship DeepSeek MoE model with 1M context and thinking/non-thinking dual modes.", + 1048576, + 384000, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateUnsupported, + providertypes.ModelCapabilityStateSupported, + []string{"high", "max"}, + "high", + false, + ), + ), + builtinModel( + "deepseek-v4-flash", + "DeepSeek V4 Flash", + "Fast, economical DeepSeek MoE model with 1M context and dual modes.", + 1048576, + 384000, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateUnsupported, + providertypes.ModelCapabilityStateSupported, + []string{"high", "max"}, + "high", + false, + ), + ), +} + +var kimiStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "kimi-k2.6", + "Kimi K2.6", + "Moonshot flagship 1T MoE model with 256K context and agent swarm support.", + 262144, + 4096, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateUnsupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + false, + ), + ), + builtinModel( + "kimi-k2.5", + "Kimi K2.5", + "Previous generation Moonshot model, flexible thinking default-on.", + 131072, + 4096, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateUnsupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + false, + ), + ), +} + +var qwenStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "qwen3.6-plus", + "Qwen 3.6 Plus", + "Alibaba flagship agentic coding model with 1M context and hybrid thinking.", + 1048576, + 128000, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + false, + ), + ), + builtinModel( + "qwen3.6-flash", + "Qwen 3.6 Flash", + "Fast Qwen MoE model (35B total / 3B active) for cost-efficient agent tasks.", + 1048576, + 32768, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + false, + ), + ), + builtinModel( + "qwen3.6-max-preview", + "Qwen 3.6 Max Preview", + "Latest Qwen flagship preview with preserve_thinking for sustained chain-of-thought.", + 1048576, + 128000, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + 
providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + false, + ), + ), +} + +var glmStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "glm-5.1", + "GLM-5.1", + "Zhipu flagship model, first Chinese model matching Claude Opus 4.6 across the board.", + 200000, + 128000, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateUnsupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + false, + ), + ), + builtinModel( + "glm-5", + "GLM-5", + "Zhipu 744B MoE model with DeepSeek Sparse Attention, MIT licensed.", + 200000, + 128000, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateUnsupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + false, + ), + ), +} + +var miMoStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "mimo-v2.5-pro", + "MiMo V2.5 Pro", + "Xiaomi flagship 1.02T MoE model with 1M context, #1 open-source on Artificial Analysis.", + 1048576, + 128000, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + false, + ), + ), + builtinModel( + "mimo-v2-flash", + "MiMo V2 Flash", + "Xiaomi fast MoE model (309B/15B active), #1 open-source coding at 150 tok/s.", + 262144, + 32768, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + false, + ), + ), +} + +var miniMaxStaticModels = []providertypes.ModelDescriptor{ + builtinModel( + "MiniMax-M2.7", + "MiniMax M2.7", + "Latest MiniMax interleaved-thinking model with cost-aggressive pricing.", + 205000, + 8192, + builtinCapabilitiesV2( + providertypes.ModelCapabilityStateSupported, + providertypes.ModelCapabilityStateUnsupported, + providertypes.ModelCapabilityStateSupported, + nil, + "", + true, // ThinkingForceEnabled: enable_thinking=false is unreliable + ), + ), +} + +// New provider factory functions. 
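+// Each factory below returns a fully-populated ProviderConfig whose static
+// Models slice doubles as the fallback catalog when live discovery fails
+// (the behavior exercised by the fallback case in catalog/service_test.go).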
+func DeepSeekProvider() ProviderConfig { + return ProviderConfig{ + Name: DeepSeekName, + Driver: provider.DriverDeepSeek, + BaseURL: DeepSeekDefaultBaseURL, + Model: DeepSeekDefaultModel, + APIKeyEnv: DeepSeekDefaultAPIKeyEnv, + ChatAPIMode: provider.ChatAPIModeChatCompletions, + ChatEndpointPath: "/chat/completions", + ModelSource: ModelSourceDiscover, + Models: cloneBuiltinModels(deepSeekStaticModels), + Source: ProviderSourceBuiltin, + } +} + +func KimiProvider() ProviderConfig { + return ProviderConfig{ + Name: KimiName, + Driver: provider.DriverOpenAICompat, + BaseURL: KimiDefaultBaseURL, + Model: KimiDefaultModel, + APIKeyEnv: KimiDefaultAPIKeyEnv, + ChatAPIMode: provider.ChatAPIModeChatCompletions, + ChatEndpointPath: "/chat/completions", + DiscoveryEndpointPath: provider.DiscoveryEndpointPathModels, + ModelSource: ModelSourceDiscover, + Models: cloneBuiltinModels(kimiStaticModels), + Source: ProviderSourceBuiltin, + } +} + +func QwenProvider() ProviderConfig { + return ProviderConfig{ + Name: QwenName, + Driver: provider.DriverQwen, + BaseURL: QwenDefaultBaseURL, + Model: QwenDefaultModel, + APIKeyEnv: QwenDefaultAPIKeyEnv, + ChatAPIMode: provider.ChatAPIModeChatCompletions, + ChatEndpointPath: "/chat/completions", + ModelSource: ModelSourceDiscover, + Models: cloneBuiltinModels(qwenStaticModels), + Source: ProviderSourceBuiltin, + } +} + +func GLMProvider() ProviderConfig { + return ProviderConfig{ + Name: GLMName, + Driver: provider.DriverGLM, + BaseURL: GLMDefaultBaseURL, + Model: GLMDefaultModel, + APIKeyEnv: GLMDefaultAPIKeyEnv, + ChatAPIMode: provider.ChatAPIModeChatCompletions, + ChatEndpointPath: "/chat/completions", + ModelSource: ModelSourceDiscover, + Models: cloneBuiltinModels(glmStaticModels), + Source: ProviderSourceBuiltin, + } +} + +func MiMoProvider() ProviderConfig { + return ProviderConfig{ + Name: MiMoName, + Driver: provider.DriverMiMo, + BaseURL: MiMoDefaultBaseURL, + Model: MiMoDefaultModel, + APIKeyEnv: MiMoDefaultAPIKeyEnv, + ChatAPIMode: provider.ChatAPIModeChatCompletions, + ChatEndpointPath: "/chat/completions", + ModelSource: ModelSourceDiscover, + Models: cloneBuiltinModels(miMoStaticModels), + Source: ProviderSourceBuiltin, + } +} + +func MiniMaxProvider() ProviderConfig { + return ProviderConfig{ + Name: MiniMaxName, + Driver: provider.DriverMiniMax, + BaseURL: MiniMaxDefaultBaseURL, + Model: MiniMaxDefaultModel, + APIKeyEnv: MiniMaxDefaultAPIKeyEnv, + ChatAPIMode: provider.ChatAPIModeChatCompletions, + ChatEndpointPath: "/", + ModelSource: ModelSourceDiscover, + Models: cloneBuiltinModels(miniMaxStaticModels), + Source: ProviderSourceBuiltin, + } +} + // DefaultProviders returns all builtin provider definitions. 
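+// The slice order is observable behavior: provider_test.go pins the exact
+// sequence, with OpenAI first and Qiniu/ModelScope last.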
func DefaultProviders() []ProviderConfig { return []ProviderConfig{ OpenAIProvider(), GeminiProvider(), + DeepSeekProvider(), + KimiProvider(), + QwenProvider(), + GLMProvider(), + MiMoProvider(), + MiniMaxProvider(), QiniuProvider(), ModelScopeProvider(), } diff --git a/internal/config/provider_custom_normalize.go b/internal/config/provider_custom_normalize.go index 467fcba4..f18967b7 100644 --- a/internal/config/provider_custom_normalize.go +++ b/internal/config/provider_custom_normalize.go @@ -158,7 +158,7 @@ func normalizeCustomProviderModels(models []providertypes.ModelDescriptor) ([]pr if model.MaxOutputTokens != 0 { return nil, fmt.Errorf("config: models[%d].max_output_tokens is not supported", index) } - if model.CapabilityHints != (providertypes.ModelCapabilityHints{}) { + if !hintsAreZero(model.CapabilityHints) { return nil, fmt.Errorf("config: models[%d].capability_hints is not supported", index) } @@ -175,3 +175,13 @@ func normalizeCustomProviderModels(models []providertypes.ModelDescriptor) ([]pr } return normalized, nil } + +// hintsAreZero reports whether capability hints contain only zero values. +func hintsAreZero(h providertypes.ModelCapabilityHints) bool { + return h.ToolCalling == "" && + h.ImageInput == "" && + h.Thinking == "" && + len(h.ThinkingEfforts) == 0 && + h.ThinkingDefaultEffort == "" && + !h.ThinkingForceEnabled +} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index c229764b..ac345f6b 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -622,11 +622,11 @@ func TestDefaultProvidersReturnsAllBuiltins(t *testing.T) { t.Parallel() providers := DefaultProviders() - if len(providers) != 4 { - t.Fatalf("expected 4 builtin providers, got %d", len(providers)) + if len(providers) != 10 { + t.Fatalf("expected 10 builtin providers, got %d", len(providers)) } - expectedNames := []string{OpenAIName, GeminiName, QiniuName, ModelScopeName} + expectedNames := []string{OpenAIName, GeminiName, DeepSeekName, KimiName, QwenName, GLMName, MiMoName, MiniMaxName, QiniuName, ModelScopeName} for i, provider := range providers { if provider.Name != expectedNames[i] { t.Fatalf("expected provider[%d] name %q, got %q", i, expectedNames[i], provider.Name) diff --git a/internal/config/state/service_test.go b/internal/config/state/service_test.go index 6ea10ec5..47fbf4a9 100644 --- a/internal/config/state/service_test.go +++ b/internal/config/state/service_test.go @@ -53,10 +53,16 @@ func TestSelectionServiceListProviderOptionsUsesCatalogModels(t *testing.T) { t.Fatalf("ListProviderOptions() error = %v", err) } expected := map[string]int{ - OpenAIName: 2, - GeminiName: 2, - QiniuName: 2, - ModelScopeName: 2, + configpkg.OpenAIName: 2, + configpkg.GeminiName: 2, + configpkg.DeepSeekName: 2, + configpkg.KimiName: 2, + configpkg.QwenName: 2, + configpkg.GLMName: 2, + configpkg.MiMoName: 2, + configpkg.MiniMaxName: 2, + configpkg.QiniuName: 2, + configpkg.ModelScopeName: 2, } if len(items) != len(expected) { t.Fatalf("expected only builtin providers, got %d", len(items)) diff --git a/internal/context/builder.go b/internal/context/builder.go index b8dafa03..569f84b8 100644 --- a/internal/context/builder.go +++ b/internal/context/builder.go @@ -19,6 +19,7 @@ type DefaultBuilder struct { func newPromptSources(extra ...SectionSource) []promptSectionSource { sources := []promptSectionSource{ corePromptSource{}, + capabilitiesSource{}, newRulesPromptSource(nil), taskStateSource{}, planModeContextSource{}, diff --git 
a/internal/context/prompt_test.go b/internal/context/prompt_test.go
index b442d67e..7356cbc7 100644
--- a/internal/context/prompt_test.go
+++ b/internal/context/prompt_test.go
@@ -125,6 +125,12 @@ func TestDefaultToolUsagePromptIncludesPermissionAndAntiLoopGuidance(t *testing.
 	if !strings.Contains(toolUsage, "`todo_write`") {
 		t.Fatalf("expected Tool Usage to mention todo_write for task state, got %q", toolUsage)
 	}
+	if !strings.Contains(toolUsage, "If the user clearly switches to a different task") {
+		t.Fatalf("expected Tool Usage to describe task-switch todo handling, got %q", toolUsage)
+	}
+	if !strings.Contains(toolUsage, "mark it `canceled` before planning or executing the new task") {
+		t.Fatalf("expected Tool Usage to require canceling stale todos on task switches, got %q", toolUsage)
+	}
 	if !strings.Contains(toolUsage, "Execute todos sequentially in the main loop") {
 		t.Fatalf("expected Tool Usage to enforce sequential todo execution, got %q", toolUsage)
 	}
diff --git a/internal/context/source_capabilities.go b/internal/context/source_capabilities.go
new file mode 100644
index 00000000..ab991cd0
--- /dev/null
+++ b/internal/context/source_capabilities.go
@@ -0,0 +1,29 @@
+package context
+
+import (
+	"context"
+	"strings"
+
+	"neo-code/internal/promptasset"
+)
+
+// capabilitiesSource injects capability declarations dynamically based on the current PlanStage.
+type capabilitiesSource struct{}
+
+// Sections returns the capability and limitation declarations that match the current mode.
+func (capabilitiesSource) Sections(ctx context.Context, input BuildInput) ([]promptSection, error) {
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+
+	stage := strings.TrimSpace(input.PlanStage)
+	content := promptasset.CapabilitiesPrompt(stage)
+	if content == "" {
+		return nil, nil
+	}
+
+	return []promptSection{{
+		Title:   "Capabilities & Limitations",
+		Content: content,
+	}}, nil
+}
diff --git a/internal/context/source_todos.go b/internal/context/source_todos.go
index 7651c619..cb58cbf2 100644
--- a/internal/context/source_todos.go
+++ b/internal/context/source_todos.go
@@ -55,7 +55,7 @@ func (todosSource) Sections(ctx context.Context, input BuildInput) ([]promptSect
 		active = active[:maxPromptTodos]
 	}
 
-	lines := make([]string, 0, len(active)+1)
+	lines := make([]string, 0, len(active)+2)
 	for _, item := range active {
 		id := sanitizePromptValue(item.ID, maxPromptTodoIDLength)
 		content := sanitizePromptValue(item.Content, maxPromptTodoTextLen)
@@ -80,6 +80,12 @@ func (todosSource) Sections(ctx context.Context, input BuildInput) ([]promptSect
 		}
 	}
 
+	lines = append(lines, "",
+		"stale_todo_reminder: If any todo above is no longer relevant to the current task,",
+		"or the user clearly switches to a different task, use todo_write to mark it completed",
+		"only if the work is actually done; otherwise set_status=canceled before moving on.",
+	)
+
 	return []promptSection{
 		{
 			Title:   "Todo State",
diff --git a/internal/context/source_todos_test.go b/internal/context/source_todos_test.go
index e467ea95..5276b9fe 100644
--- a/internal/context/source_todos_test.go
+++ b/internal/context/source_todos_test.go
@@ -54,13 +54,19 @@ func TestTodosSourceSections(t *testing.T) {
 	if sections[0].Title != "Todo State" {
 		t.Fatalf("title = %q, want %q", sections[0].Title, "Todo State")
 	}
-	if strings.Contains(sections[0].Content, "done") {
+	if strings.Contains(sections[0].Content, `id="done"`) {
 		t.Fatalf("expected terminal todo filtered, got %q", sections[0].Content)
 	}
 	lines := strings.Split(sections[0].Content, "\n")
 	if len(lines) < 2 || !strings.Contains(lines[0], "in-progress") {
 		t.Fatalf("expected in_progress todo first, got %q", sections[0].Content)
 	}
+	if !strings.Contains(sections[0].Content, "user clearly switches to a different task") {
+		t.Fatalf("expected stale todo reminder to mention task switching, got %q", sections[0].Content)
+	}
+	if !strings.Contains(sections[0].Content, "only if the work is actually done") {
+		t.Fatalf("expected stale todo reminder to distinguish completed from canceled, got %q", sections[0].Content)
+	}
 }
 
 func TestTodosSourceSectionsBoundaries(t *testing.T) {
diff --git a/internal/gateway/contracts.go b/internal/gateway/contracts.go
index 7de807f6..ff5b2996 100644
--- a/internal/gateway/contracts.go
+++ b/internal/gateway/contracts.go
@@ -225,6 +225,8 @@ type ModelEntry struct {
 	Name string `json:"name"`
 	// Provider is the model vendor.
 	Provider string `json:"provider"`
+	// CapabilityHints describes the model's capability hints.
+	CapabilityHints providertypes.ModelCapabilityHints `json:"capability_hints,omitempty"`
 }
 
 // ListModelsInput represents the downstream input of the gateway.listModels action.
diff --git a/internal/promptasset/assets.go b/internal/promptasset/assets.go
index f8e8b0ed..3a640f7b 100644
--- a/internal/promptasset/assets.go
+++ b/internal/promptasset/assets.go
@@ -38,6 +38,10 @@ var coderRolePrompt = mustReadTemplate("templates/subagent/coder.md")
 
 var reviewerRolePrompt = mustReadTemplate("templates/subagent/reviewer.md")
 
+var defaultCapabilities = mustReadTemplate("templates/core/capabilities.md")
+
+var planCapabilities = mustReadTemplate("templates/core/capabilities_plan.md")
+
 // CoreSections returns an ordered copy of the main session's fixed core prompt sections.
 func CoreSections() []Section {
 	return append([]Section(nil), coreSections...)
 }
@@ -89,6 +93,14 @@ func ReviewerRolePrompt() string {
 	return reviewerRolePrompt
 }
 
+// CapabilitiesPrompt returns the capability declaration template for the given stage.
+func CapabilitiesPrompt(stage string) string {
+	if stage == "plan" {
+		return planCapabilities
+	}
+	return defaultCapabilities
+}
+
 // loadCoreSections loads the main session's core section templates in a fixed order.
 func loadCoreSections() []Section {
 	return []Section{
diff --git a/internal/promptasset/templates/core/agent_identity.md b/internal/promptasset/templates/core/agent_identity.md
index fb68e481..c5c43e22 100644
--- a/internal/promptasset/templates/core/agent_identity.md
+++ b/internal/promptasset/templates/core/agent_identity.md
@@ -26,18 +26,6 @@ Core workflow:
 5. Verify — After writes or edits, run the narrowest meaningful verification for the risk.
 6. Respond — Report what changed, what was verified, and what remains if incomplete. Do not over-explain.
 
-Capabilities:
-- Read, search, write, and edit files within the current workspace.
-- Run non-interactive shell commands when filesystem tools are insufficient.
-- Maintain explicit task state and todos via `todo_write`.
-- Ask clarifying questions when requirements are ambiguous or conflicting.
-
-Limitations:
-- Cannot access files or directories outside the provided workdir.
-- Cannot browse the internet unless the `webfetch` tool is explicitly exposed.
-- Cannot execute interactive commands that require human input.
-- No persistent memory across sessions without explicit session-level context.
-
 When to ask the user:
 - Destructive or risky operations (e.g., `rm`, `git push --force`).
 - Ambiguous requirements or conflicting constraints.
diff --git a/internal/promptasset/templates/core/capabilities.md b/internal/promptasset/templates/core/capabilities.md
new file mode 100644
index 00000000..1671acfd
--- /dev/null
+++ b/internal/promptasset/templates/core/capabilities.md
@@ -0,0 +1,13 @@
+## Capabilities
+You are currently in build execution mode. 
All tools are available. + +- Read, search, write, and edit files within the current workspace. +- Run non-interactive shell commands when filesystem tools are insufficient. +- Maintain explicit task state and todos via `todo_write`. +- Ask clarifying questions when requirements are ambiguous or conflicting. + +## Limitations +- Cannot access files or directories outside the provided workdir. +- Cannot browse the internet unless the `webfetch` tool is explicitly exposed. +- Cannot execute interactive commands that require human input. +- No persistent memory across sessions without explicit session-level context. diff --git a/internal/promptasset/templates/core/capabilities_plan.md b/internal/promptasset/templates/core/capabilities_plan.md new file mode 100644 index 00000000..646b642d --- /dev/null +++ b/internal/promptasset/templates/core/capabilities_plan.md @@ -0,0 +1,14 @@ +## Capabilities +You are currently in plan mode. Write and edit tools are disabled. Only read and search tools are available. + +- Read and search files within the current workspace. +- Run non-interactive shell commands for read-only inspection only. +- Maintain explicit task state and todos via `todo_write`. +- Ask clarifying questions when requirements are ambiguous or conflicting. +- **Do not perform any write, edit, delete, or file mutation operations.** Use this stage only for research, analysis, and planning. + +## Limitations +- Cannot write, edit, create, or delete files in plan mode. +- Cannot access files or directories outside the provided workdir. +- Cannot browse the internet unless the `webfetch` tool is explicitly exposed. +- No persistent memory across sessions without explicit session-level context. diff --git a/internal/promptasset/templates/core/tool_usage.md b/internal/promptasset/templates/core/tool_usage.md index e80cf99f..410c4ea2 100644 --- a/internal/promptasset/templates/core/tool_usage.md +++ b/internal/promptasset/templates/core/tool_usage.md @@ -24,6 +24,7 @@ - For multi-step implementation, debugging, refactoring, or long-running work, keep task state explicit via `todo_write` (plan/add/update/set_status/claim/complete/fail) instead of relying on implicit memory. - Create todos that map to real acceptance work, not vague activity. - Required todos are acceptance-relevant and must converge before finalization. +- If the user clearly switches to a different task, do not carry unfinished todos forward blindly: mark each old todo `completed` only when the work is actually done, otherwise mark it `canceled` before planning or executing the new task. - `todo_write` parameters must match schema strictly: `id` must be a string (for example, `"3"` instead of `3`). - `todo_write` does not auto-dispatch subagents. Setting todo metadata does not trigger execution by itself. - `todo_write` `set_status` requires: `{"action":"set_status","id":"","status":"pending|in_progress|blocked|completed|failed|canceled"}`. 
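+- Example task-switch cleanup call: `{"action":"set_status","id":"3","status":"canceled"}` (repeat per stale todo, before planning the new task).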
diff --git a/internal/provider/builtin/builtin.go b/internal/provider/builtin/builtin.go index 89ab753a..ab1bdfb0 100644 --- a/internal/provider/builtin/builtin.go +++ b/internal/provider/builtin/builtin.go @@ -5,8 +5,13 @@ import ( "neo-code/internal/provider" "neo-code/internal/provider/anthropic" + "neo-code/internal/provider/deepseek" "neo-code/internal/provider/gemini" + "neo-code/internal/provider/mimo" + "neo-code/internal/provider/minimax" "neo-code/internal/provider/openaicompat" + "neo-code/internal/provider/openaicompat/glm" + "neo-code/internal/provider/openaicompat/qwen" ) func NewRegistry() (*provider.Registry, error) { @@ -21,11 +26,20 @@ func register(registry *provider.Registry) error { if registry == nil { return errors.New("builtin provider registry is nil") } - if err := registry.Register(openaicompat.Driver()); err != nil { - return err + drivers := []provider.DriverDefinition{ + openaicompat.Driver(), + gemini.Driver(), + anthropic.Driver(), + deepseek.Driver(), + qwen.Driver(), + glm.Driver(), + mimo.Driver(), + minimax.Driver(), } - if err := registry.Register(gemini.Driver()); err != nil { - return err + for _, d := range drivers { + if err := registry.Register(d); err != nil { + return err + } } - return registry.Register(anthropic.Driver()) + return nil } diff --git a/internal/provider/catalog/service_test.go b/internal/provider/catalog/service_test.go index d53b9e6a..ee030a2a 100644 --- a/internal/provider/catalog/service_test.go +++ b/internal/provider/catalog/service_test.go @@ -75,8 +75,8 @@ func TestListProviderModelsBuiltinDiscoversAndMergesWithStaticModels(t *testing. if err != nil { t.Fatalf("ListProviderModels() error = %v", err) } - if len(models) != 7 { - t.Fatalf("expected 7 models (6 static + 1 new discovered), got %d: %+v", len(models), models) + if len(models) != 8 { + t.Fatalf("expected 8 models (7 static + 1 new discovered), got %d: %+v", len(models), models) } // Discovered model not in static list should appear @@ -102,8 +102,8 @@ func TestListProviderModelsBuiltinFallbackWhenDiscoveryFails(t *testing.T) { if err != nil { t.Fatalf("ListProviderModels() error = %v", err) } - if len(models) != 6 { - t.Fatalf("expected 6 fallback static models when discovery fails, got %d: %+v", len(models), models) + if len(models) != 7 { + t.Fatalf("expected 7 fallback static models when discovery fails, got %d: %+v", len(models), models) } if !containsModelDescriptorID(models, config.OpenAIDefaultModel) { t.Fatalf("expected fallback default model to be present, got %+v", models) diff --git a/internal/provider/conformance/conformance_test.go b/internal/provider/conformance/conformance_test.go index bf4235d9..976eb9c8 100644 --- a/internal/provider/conformance/conformance_test.go +++ b/internal/provider/conformance/conformance_test.go @@ -12,8 +12,13 @@ import ( "neo-code/internal/provider" "neo-code/internal/provider/anthropic" + "neo-code/internal/provider/deepseek" "neo-code/internal/provider/gemini" + "neo-code/internal/provider/mimo" + "neo-code/internal/provider/minimax" "neo-code/internal/provider/openaicompat" + "neo-code/internal/provider/openaicompat/glm" + "neo-code/internal/provider/openaicompat/qwen" providertypes "neo-code/internal/provider/types" ) @@ -27,6 +32,7 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { streamBody string expectReason string expectTokens int + expectedOrder []providertypes.StreamEventType }{ { name: "openaicompat_chat_completions", @@ -50,6 +56,12 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { "data: 
[DONE]\n\n", expectReason: "stop", expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, }, { name: "gemini_native", @@ -71,6 +83,12 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { "data: {\"candidates\":[{\"index\":0,\"finishReason\":\"STOP\",\"content\":{\"parts\":[{\"functionCall\":{\"name\":\"filesystem_read_file\",\"args\":{\"path\":\"README.md\"}}}]}}]}\n\n", expectReason: "stop", expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, }, { name: "anthropic_messages", @@ -100,6 +118,128 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":6}}\n\n", expectReason: "tool_use", expectTokens: 10, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, + }, + { + name: "deepseek_chat_completions_with_reasoning", + driver: deepseek.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: deepseek.DriverName, + Driver: provider.DriverDeepSeek, + BaseURL: baseURL, + DefaultModel: "deepseek-v4", + APIKeyEnv: "DEEPSEEK_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + expectedPath: "/v1/chat/completions", + expectedHeader: "Authorization", + streamBody: "data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"plan\",\"content\":\"Hello \"}}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n" + + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"filesystem_read_file\"}}]}}]}\n" + + "data: {\"choices\":[{\"finish_reason\":\"stop\",\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"arguments\":\"{\\\"path\\\":\\\"README.md\\\"}\"}}]}}]}\n" + + "data: [DONE]\n\n", + expectReason: "stop", + expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventThinkingDelta, + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, + }, + { + name: "mimo_chat_completions_with_reasoning", + driver: mimo.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: mimo.DriverName, + Driver: provider.DriverMiMo, + BaseURL: baseURL, + DefaultModel: "mimo-v2.5", + APIKeyEnv: "MIMO_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + expectedPath: "/chat/completions", + expectedHeader: "Authorization", + streamBody: "data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"plan\",\"content\":\"Hello \"}}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n" + + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"filesystem_read_file\"}}]}}]}\n" + + "data: 
{\"choices\":[{\"finish_reason\":\"stop\",\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"arguments\":\"{\\\"path\\\":\\\"README.md\\\"}\"}}]}}]}\n" + + "data: [DONE]\n\n", + expectReason: "stop", + expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventThinkingDelta, + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, + }, + { + name: "qwen_chat_completions_with_reasoning", + driver: qwen.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: qwen.DriverName, + Driver: provider.DriverQwen, + BaseURL: baseURL, + DefaultModel: "qwen3", + APIKeyEnv: "QWEN_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + expectedPath: "/chat/completions", + expectedHeader: "Authorization", + streamBody: "data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"plan\",\"content\":\"Hello \"}}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n" + + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"filesystem_read_file\"}}]}}]}\n" + + "data: {\"choices\":[{\"finish_reason\":\"stop\",\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"arguments\":\"{\\\"path\\\":\\\"README.md\\\"}\"}}]}}]}\n" + + "data: [DONE]\n\n", + expectReason: "stop", + expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventThinkingDelta, + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, + }, + { + name: "glm_chat_completions_with_reasoning", + driver: glm.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: glm.DriverName, + Driver: provider.DriverGLM, + BaseURL: baseURL, + DefaultModel: "glm-5.1", + APIKeyEnv: "GLM_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + expectedPath: "/chat/completions", + expectedHeader: "Authorization", + streamBody: "data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"plan\",\"content\":\"Hello \"}}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n" + + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"filesystem_read_file\"}}]}}]}\n" + + "data: {\"choices\":[{\"finish_reason\":\"stop\",\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"arguments\":\"{\\\"path\\\":\\\"README.md\\\"}\"}}]}}]}\n" + + "data: [DONE]\n\n", + expectReason: "stop", + expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventThinkingDelta, + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, }, } @@ -131,22 +271,16 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { } drained := drainEvents(events) - if len(drained) != 4 { - t.Fatalf("expected 4 events, got %d (%+v)", len(drained), drained) - } - expectedOrder := []providertypes.StreamEventType{ - providertypes.StreamEventTextDelta, - providertypes.StreamEventToolCallStart, - providertypes.StreamEventToolCallDelta, - providertypes.StreamEventMessageDone, + if len(drained) != len(tt.expectedOrder) { + 
t.Fatalf("expected %d events, got %d (%+v)", len(tt.expectedOrder), len(drained), drained) } - for i := range expectedOrder { - if drained[i].Type != expectedOrder[i] { - t.Fatalf("unexpected event order at index %d, expected %q got %q", i, expectedOrder[i], drained[i].Type) + for i := range tt.expectedOrder { + if drained[i].Type != tt.expectedOrder[i] { + t.Fatalf("unexpected event order at index %d, expected %q got %q", i, tt.expectedOrder[i], drained[i].Type) } } - done, doneErr := drained[3].MessageDoneValue() + done, doneErr := drained[len(drained)-1].MessageDoneValue() if doneErr != nil { t.Fatalf("MessageDoneValue() error = %v", doneErr) } @@ -186,6 +320,23 @@ func TestDiscoverContractAcrossDrivers(t *testing.T) { expectedHeader: "Authorization", responseBody: `{"data":[{"id":"gpt-4.1","name":"GPT 4.1"}]}`, }, + { + name: "deepseek_discover", + driver: deepseek.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: deepseek.DriverName, + Driver: provider.DriverDeepSeek, + BaseURL: baseURL, + APIKeyEnv: "DEEPSEEK_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: `{"data":[{"id":"deepseek-v4","name":"DeepSeek V4"}]}`, + }, { name: "gemini_discover", driver: gemini.Driver(), @@ -220,6 +371,74 @@ func TestDiscoverContractAcrossDrivers(t *testing.T) { expectedHeader: "x-api-key", responseBody: `{"data":[{"id":"claude-3-7-sonnet","display_name":"Claude 3.7 Sonnet"}],"has_more":false}`, }, + { + name: "mimo_discover", + driver: mimo.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: mimo.DriverName, + Driver: provider.DriverMiMo, + BaseURL: baseURL, + APIKeyEnv: "MIMO_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: `{"data":[{"id":"mimo-v2.5","name":"MiMo V2.5"}]}`, + }, + { + name: "minimax_discover", + driver: minimax.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: minimax.DriverName, + Driver: provider.DriverMiniMax, + BaseURL: baseURL, + APIKeyEnv: "MINIMAX_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: `{"data":[{"id":"minimax-m2.7","name":"MiniMax M2.7"}]}`, + }, + { + name: "qwen_discover", + driver: qwen.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: qwen.DriverName, + Driver: provider.DriverQwen, + BaseURL: baseURL, + APIKeyEnv: "QWEN_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: `{"data":[{"id":"qwen3","name":"Qwen 3"}]}`, + }, + { + name: "glm_discover", + driver: glm.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: glm.DriverName, + Driver: provider.DriverGLM, + BaseURL: baseURL, + APIKeyEnv: "GLM_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: 
`{"data":[{"id":"glm-5.1","name":"GLM 5.1"}]}`, + }, } for _, tt := range testCases { @@ -272,6 +491,21 @@ func TestGenerateErrorClassificationAcrossDrivers(t *testing.T) { path: "/chat/completions", body: `{"error":{"message":"invalid api key"}}`, }, + { + name: "deepseek_auth_error", + driver: deepseek.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverDeepSeek, + BaseURL: baseURL, + DefaultModel: "deepseek-v4", + APIKeyEnv: "DEEPSEEK_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/v1/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, { name: "gemini_auth_error", driver: gemini.Driver(), @@ -304,6 +538,66 @@ func TestGenerateErrorClassificationAcrossDrivers(t *testing.T) { path: "/v1/messages", body: `{"error":{"message":"invalid x-api-key"}}`, }, + { + name: "mimo_auth_error", + driver: mimo.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverMiMo, + BaseURL: baseURL, + DefaultModel: "mimo-v2.5", + APIKeyEnv: "MIMO_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, + { + name: "minimax_auth_error", + driver: minimax.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverMiniMax, + BaseURL: baseURL + "/chat/completions", + DefaultModel: "minimax-m2.7", + APIKeyEnv: "MINIMAX_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, + { + name: "qwen_auth_error", + driver: qwen.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverQwen, + BaseURL: baseURL, + DefaultModel: "qwen3", + APIKeyEnv: "QWEN_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, + { + name: "glm_auth_error", + driver: glm.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverGLM, + BaseURL: baseURL, + DefaultModel: "glm-5.1", + APIKeyEnv: "GLM_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, } for _, tt := range testCases { @@ -346,6 +640,106 @@ func TestGenerateErrorClassificationAcrossDrivers(t *testing.T) { } } +func TestEstimateInputTokensContractAcrossDrivers(t *testing.T) { + testCases := []struct { + name string + driver provider.DriverDefinition + buildConfig func() provider.RuntimeConfig + }{ + { + name: "deepseek_estimate", + driver: deepseek.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: deepseek.DriverName, + Driver: provider.DriverDeepSeek, + BaseURL: "https://api.deepseek.example", + DefaultModel: "deepseek-v4", + APIKeyEnv: "DEEPSEEK_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + { + name: "mimo_estimate", + driver: mimo.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: mimo.DriverName, + Driver: provider.DriverMiMo, + BaseURL: "https://api.mimo.example", + DefaultModel: "mimo-v2.5", 
+ APIKeyEnv: "MIMO_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + { + name: "minimax_estimate", + driver: minimax.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: minimax.DriverName, + Driver: provider.DriverMiniMax, + BaseURL: "https://api.minimax.example/chat/completions", + DefaultModel: "minimax-m2.7", + APIKeyEnv: "MINIMAX_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + { + name: "qwen_estimate", + driver: qwen.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: qwen.DriverName, + Driver: provider.DriverQwen, + BaseURL: "https://api.qwen.example", + DefaultModel: "qwen3", + APIKeyEnv: "QWEN_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + { + name: "glm_estimate", + driver: glm.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: glm.DriverName, + Driver: provider.DriverGLM, + BaseURL: "https://api.glm.example", + DefaultModel: "glm-5.1", + APIKeyEnv: "GLM_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + } + + for _, tt := range testCases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + p, err := tt.driver.Build(context.Background(), tt.buildConfig()) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + estimate, err := p.EstimateInputTokens(context.Background(), generateRequestWithAssets()) + if err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if estimate.EstimatedInputTokens <= 0 { + t.Fatalf("expected positive token estimate, got %+v", estimate) + } + if estimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("expected local estimate source, got %+v", estimate) + } + }) + } +} + func generateRequestWithAssets() providertypes.GenerateRequest { return providertypes.GenerateRequest{ Messages: []providertypes.Message{ diff --git a/internal/provider/constants.go b/internal/provider/constants.go index ea28a24d..e31c13f9 100644 --- a/internal/provider/constants.go +++ b/internal/provider/constants.go @@ -7,6 +7,11 @@ const ( DriverOpenAICompat = "openaicompat" DriverGemini = "gemini" DriverAnthropic = "anthropic" + DriverDeepSeek = "deepseek" + DriverQwen = "qwen" + DriverGLM = "glm" + DriverMiMo = "mimo" + DriverMiniMax = "minimax" DiscoveryEndpointPathModels = "/models" ) diff --git a/internal/provider/deepseek/driver.go b/internal/provider/deepseek/driver.go new file mode 100644 index 00000000..46e35295 --- /dev/null +++ b/internal/provider/deepseek/driver.go @@ -0,0 +1,29 @@ +package deepseek + +import ( + "context" + + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" +) + +const DriverName = provider.DriverDeepSeek + +func Driver() provider.DriverDefinition { + return provider.DriverDefinition{ + Name: DriverName, + Build: func(ctx context.Context, cfg provider.RuntimeConfig) (provider.Provider, error) { + return New(cfg) + }, + Discover: func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + p, err := New(cfg) + if err != nil { + return nil, err + } + return p.DiscoverModels(ctx) + }, + ValidateCatalogIdentity: func(identity provider.ProviderIdentity) error { + return nil + }, + } +} diff --git a/internal/provider/deepseek/driver_test.go b/internal/provider/deepseek/driver_test.go new file mode 100644 index 00000000..ca0b83d9 --- /dev/null +++ 
b/internal/provider/deepseek/driver_test.go @@ -0,0 +1,50 @@ +package deepseek + +import ( + "testing" + + "neo-code/internal/provider" +) + +func TestDriverName(t *testing.T) { + d := Driver() + if d.Name != provider.DriverDeepSeek { + t.Fatalf("expected driver name %q, got %q", provider.DriverDeepSeek, d.Name) + } + if d.Build == nil { + t.Fatal("build func is nil") + } + if d.Discover == nil { + t.Fatal("discover func is nil") + } +} + +func TestNewValidatesBaseURL(t *testing.T) { + t.Parallel() + _, err := New(provider.RuntimeConfig{BaseURL: "", APIKeyEnv: "KEY"}) + if err == nil { + t.Fatal("expected error for empty baseURL") + } +} + +func TestNewValidatesAPIKeyEnv(t *testing.T) { + t.Parallel() + _, err := New(provider.RuntimeConfig{BaseURL: "https://api.example.com", APIKeyEnv: ""}) + if err == nil { + t.Fatal("expected error for empty api_key_env") + } +} + +func TestNewSucceedsWithValidConfig(t *testing.T) { + t.Parallel() + p, err := New(provider.RuntimeConfig{ + BaseURL: "https://api.example.com", + APIKeyEnv: "TEST_KEY", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if p == nil { + t.Fatal("expected non-nil provider") + } +} diff --git a/internal/provider/deepseek/provider.go b/internal/provider/deepseek/provider.go new file mode 100644 index 00000000..cf110709 --- /dev/null +++ b/internal/provider/deepseek/provider.go @@ -0,0 +1,125 @@ +package deepseek + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +const errorPrefix = "deepseek provider: " + +type Provider struct { + cfg provider.RuntimeConfig + generateClient *http.Client + discoveryClient *http.Client +} + +func New(cfg provider.RuntimeConfig) (*Provider, error) { + if strings.TrimSpace(cfg.BaseURL) == "" { + return nil, errors.New(errorPrefix + "base url is empty") + } + if strings.TrimSpace(cfg.APIKeyEnv) == "" { + return nil, errors.New(errorPrefix + "api_key_env is empty") + } + return &Provider{ + cfg: cfg, + generateClient: &http.Client{ + Transport: http.DefaultTransport, + }, + discoveryClient: &http.Client{ + Timeout: provider.DefaultSDKRequestTimeout, + Transport: http.DefaultTransport, + }, + }, nil +} + +func (p *Provider) EstimateInputTokens( + ctx context.Context, + req providertypes.GenerateRequest, +) (providertypes.BudgetEstimate, error) { + payload, err := BuildRequest(ctx, p.cfg, req) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + tokens, err := provider.EstimateSerializedPayloadTokens(payload) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil +} + +func (p *Provider) DiscoverModels(ctx context.Context) ([]providertypes.ModelDescriptor, error) { + requestCfg, err := openaicompat.RequestConfigFromRuntime(p.cfg) + if err != nil { + return nil, err + } + return openaicompat.DiscoverModelDescriptors(ctx, p.discoveryClient, requestCfg) +} + +func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + payload, err := BuildRequest(ctx, p.cfg, req) + if err != nil { + return err + } + + tc := req.ThinkingConfig + + return provider.RunGenerateWithRetry(ctx, p.cfg, 
events, func( + attemptCtx context.Context, + attemptEvents chan<- providertypes.StreamEvent, + ) error { + return p.generateOnce(attemptCtx, payload, tc, attemptEvents) + }) +} + +func (p *Provider) generateOnce(ctx context.Context, payload chatcompletions.Request, tc *providertypes.ThinkingConfig, events chan<- providertypes.StreamEvent) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("%smarshal request: %w", errorPrefix, err) + } + + if tc != nil { + body, err = InjectThinkingParams(body, *tc) + if err != nil { + return fmt.Errorf("%sinject thinking params: %w", errorPrefix, err) + } + } + + apiKey, err := p.cfg.ResolveAPIKeyValue() + if err != nil { + return err + } + + endpoint := strings.TrimRight(strings.TrimSpace(p.cfg.BaseURL), "/") + "/v1/chat/completions" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("%screate request: %w", errorPrefix, err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := p.generateClient.Do(req) + if err != nil { + return fmt.Errorf("%ssend request: %w", errorPrefix, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return provider.WrapIfThinkingNotSupported(chatcompletions.ParseError(resp)) + } + + return chatcompletions.ConsumeStream(ctx, resp.Body, events) +} diff --git a/internal/provider/deepseek/provider_more_test.go b/internal/provider/deepseek/provider_more_test.go new file mode 100644 index 00000000..b4a866c0 --- /dev/null +++ b/internal/provider/deepseek/provider_more_test.go @@ -0,0 +1,223 @@ +package deepseek + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +func TestDriverBuildAndDiscover(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/models" { + t.Fatalf("unexpected path %q", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{{"id": "deepseek-chat"}}, + }) + })) + defer server.Close() + + cfg := provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + Driver: DriverName, + } + driver := Driver() + if _, err := driver.Build(context.Background(), cfg); err != nil { + t.Fatalf("Build() error = %v", err) + } + models, err := driver.Discover(context.Background(), cfg) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + if len(models) != 1 || models[0].ID != "deepseek-chat" { + t.Fatalf("unexpected models: %+v", models) + } + if err := driver.ValidateCatalogIdentity(provider.ProviderIdentity{}); err != nil { + t.Fatalf("ValidateCatalogIdentity() error = %v", err) + } + if _, err := driver.Discover(context.Background(), provider.RuntimeConfig{}); err == nil { + t.Fatal("expected invalid config discover error") + } +} + +func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { + t.Parallel() + + var authHeader string + var requestBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader = r.Header.Get("Authorization") + switch r.URL.Path { + case "/models": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": 
[]map[string]any{{"id": "deepseek-chat"}}, + }) + case "/v1/chat/completions": + var err error + requestBody, err = ioReadAll(r) + if err != nil { + t.Fatalf("read body: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_content":"plan","content":"answer"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`, + `data: [DONE]`, + "", + }, "\n"))) + default: + t.Fatalf("unexpected path %q", r.URL.Path) + } + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "deepseek-chat", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + req := providertypes.GenerateRequest{ + Model: "deepseek-chat", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }}, + ThinkingConfig: &providertypes.ThinkingConfig{Enabled: true, Effort: "high"}, + } + estimate, err := p.EstimateInputTokens(context.Background(), req) + if err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if estimate.EstimatedInputTokens <= 0 { + t.Fatalf("expected positive token estimate, got %+v", estimate) + } + + models, err := p.DiscoverModels(context.Background()) + if err != nil { + t.Fatalf("DiscoverModels() error = %v", err) + } + if len(models) != 1 || models[0].ID != "deepseek-chat" { + t.Fatalf("unexpected models: %+v", models) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), req, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + drained := drainDeepSeekEvents(events) + if len(drained) != 3 { + t.Fatalf("expected 3 events, got %d", len(drained)) + } + thinking, err := drained[0].ThinkingDeltaValue() + if err != nil || thinking.Text != "plan" { + t.Fatalf("unexpected thinking event: err=%v event=%+v", err, drained[0]) + } + if authHeader != "Bearer secret" { + t.Fatalf("authorization header = %q, want bearer token", authHeader) + } + if !strings.Contains(string(requestBody), `"reasoning_effort":"high"`) { + t.Fatalf("request body missing reasoning_effort: %s", string(requestBody)) + } + + errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{"message": "thinking is not supported"}, + }) + })) + defer errorServer.Close() + + p, err = New(provider.RuntimeConfig{ + BaseURL: errorServer.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "deepseek-chat", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + err = p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)) + if !provider.IsThinkingNotSupportedError(err) { + t.Fatalf("expected thinking-not-supported error, got %v", err) + } + + p, err = New(provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "deepseek-chat", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + p.generateClient = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, 
errors.New("network down") + })} + err = p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)) + if err == nil || !strings.Contains(err.Error(), "send request") { + t.Fatalf("expected send request error, got %v", err) + } + invalidReq := providertypes.GenerateRequest{ + Model: "deepseek-chat", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{{Kind: "invalid"}}, + }}, + } + if _, err := p.EstimateInputTokens(context.Background(), invalidReq); err == nil { + t.Fatal("expected invalid estimate request error") + } + if err := p.Generate(context.Background(), invalidReq, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected invalid generate request error") + } + p.cfg.APIKeyResolver = provider.StaticAPIKeyResolver("") + if _, err := p.DiscoverModels(context.Background()); err == nil { + t.Fatal("expected discovery api key error") + } + if err := p.generateOnce(context.Background(), chatcompletions.Request{}, nil, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected api key resolve error") + } +} + +func drainDeepSeekEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { + var drained []providertypes.StreamEvent + for { + select { + case event := <-events: + drained = append(drained, event) + default: + return drained + } + } +} + +func ioReadAll(r *http.Request) ([]byte, error) { + defer r.Body.Close() + return io.ReadAll(r.Body) +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/internal/provider/deepseek/request.go b/internal/provider/deepseek/request.go new file mode 100644 index 00000000..afa322f0 --- /dev/null +++ b/internal/provider/deepseek/request.go @@ -0,0 +1,58 @@ +package deepseek + +import ( + "context" + "encoding/json" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +// BuildRequest 将 GenerateRequest 转换为 chatcompletions.Request。 +func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providertypes.GenerateRequest) (chatcompletions.Request, error) { + return chatcompletions.BuildRequest(ctx, cfg, req) +} + +// InjectThinkingParams 将基础请求 JSON 注入 DeepSeek 特定的 thinking 控制参数。 +func InjectThinkingParams(body []byte, tc providertypes.ThinkingConfig) ([]byte, error) { + if tc.Enabled { + return injectEnabledThinking(body, tc) + } + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return nil, err + } + raw["thinking"] = map[string]any{"type": "disabled"} + result, err := json.Marshal(raw) + if err != nil { + return nil, err + } + return result, nil +} + +func injectEnabledThinking(body []byte, tc providertypes.ThinkingConfig) ([]byte, error) { + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return nil, err + } + raw["thinking"] = map[string]any{"type": "enabled"} + if tc.Effort != "" { + raw["reasoning_effort"] = tc.Effort + } + return json.Marshal(raw) +} + +// ExtractContinuity 将 reasoning_content 存入消息的 ThinkingMetadata 中以备续轮使用。 +func ExtractContinuity(msg *providertypes.Message, reasoningContent string) { + if msg == nil || reasoningContent == "" { + return + } + meta, err := json.Marshal(map[string]string{ + "reasoning_content": reasoningContent, + }) + if err != nil { + return + } + msg.ThinkingMetadata = json.RawMessage(meta) +} diff 
diff --git a/internal/provider/deepseek/request_test.go b/internal/provider/deepseek/request_test.go new file mode 100644 index 00000000..ad274ef0 --- /dev/null +++ b/internal/provider/deepseek/request_test.go @@ -0,0 +1,110 @@ +package deepseek + +import ( + "encoding/json" + "testing" + + providertypes "neo-code/internal/provider/types" +) + +func TestInjectThinkingParams_Enabled(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true}`) + result, err := InjectThinkingParams(body, providertypes.ThinkingConfig{Enabled: true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + thinking, ok := raw["thinking"].(map[string]any) + if !ok { + t.Fatalf("thinking not found or wrong type") + } + if thinking["type"] != "enabled" { + t.Fatalf("expected thinking.type=enabled, got %v", thinking["type"]) + } +} + +func TestInjectThinkingParams_Disabled(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true}`) + result, err := InjectThinkingParams(body, providertypes.ThinkingConfig{Enabled: false}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + thinking, ok := raw["thinking"].(map[string]any) + if !ok { + t.Fatalf("thinking not found") + } + if thinking["type"] != "disabled" { + t.Fatalf("expected thinking.type=disabled, got %v", thinking["type"]) + } +} + +func TestInjectThinkingParams_WithEffort(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true}`) + result, err := InjectThinkingParams(body, providertypes.ThinkingConfig{ + Enabled: true, + Effort: "max", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if raw["reasoning_effort"] != "max" { + t.Fatalf("expected reasoning_effort=max, got %v", raw["reasoning_effort"]) + } +} + +func TestExtractContinuity(t *testing.T) { + t.Parallel() + + msg := &providertypes.Message{ + Role: providertypes.RoleAssistant, + } + ExtractContinuity(msg, "thinking content here") + + if len(msg.ThinkingMetadata) == 0 { + t.Fatalf("expected ThinkingMetadata to be set") + } + + var meta map[string]string + if err := json.Unmarshal(msg.ThinkingMetadata, &meta); err != nil { + t.Fatalf("unmarshal ThinkingMetadata: %v", err) + } + if meta["reasoning_content"] != "thinking content here" { + t.Fatalf("expected reasoning_content, got %v", meta) + } +} + +func TestExtractContinuity_EmptySkips(t *testing.T) { + t.Parallel() + + msg := &providertypes.Message{} + ExtractContinuity(msg, "") + if len(msg.ThinkingMetadata) != 0 { + t.Fatalf("expected empty ThinkingMetadata") + } + + ExtractContinuity(nil, "content") + // nil should not panic +} diff --git a/internal/provider/errors.go b/internal/provider/errors.go index b7e25b82..1bca17fc 100644 --- a/internal/provider/errors.go +++ b/internal/provider/errors.go @@ -14,9 +14,10 @@ var ( ErrDiscoveryConfig = errors.New("provider: discovery config invalid") // 流级哨兵错误,用于区分可恢复/不可恢复的流中断原因。 - ErrStreamInterrupted = errors.New("provider: stream interrupted") - ErrLineTooLong = errors.New("provider: SSE line exceeds max length") - ErrStreamTooLarge = errors.New("provider: stream total size exceeds limit") + ErrStreamInterrupted
= errors.New("provider: stream interrupted") + ErrLineTooLong = errors.New("provider: SSE line exceeds max length") + ErrStreamTooLarge = errors.New("provider: stream total size exceeds limit") + ErrThinkingNotSupported = errors.New("provider: thinking not supported for this model") ) type ProviderErrorCode string @@ -154,6 +155,34 @@ func NewTimeoutProviderError(message string) *ProviderError { } } +// IsThinkingNotSupportedError 判断错误是否为 thinking 不支持错误。 +func IsThinkingNotSupportedError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, ErrThinkingNotSupported) +} + +var thinkingNotSupportedFragments = []string{ + "thinking", + "reasoning", +} + +// WrapIfThinkingNotSupported 检查错误消息是否包含 thinking/reasoning 相关关键字, +// 若是则用 ErrThinkingNotSupported 包装,供 runtime 降级重试。 +func WrapIfThinkingNotSupported(err error) error { + if err == nil { + return nil + } + msg := strings.ToLower(err.Error()) + for _, f := range thinkingNotSupportedFragments { + if strings.Contains(msg, f) { + return fmt.Errorf("%w: %w", ErrThinkingNotSupported, err) + } + } + return err +} + // IsContextTooLong 判断 provider 错误是否表示请求上下文超出模型窗口。 // 优先识别 typed error,必要时再回退到消息文本匹配,兼容不同厂商或额外包装层。 // 已被归类为 rate_limited (429) 的错误不会因文本片段而被误判为 context_too_long。 diff --git a/internal/provider/generate_attempt.go b/internal/provider/generate_attempt.go index f8bec712..eb7f2f72 100644 --- a/internal/provider/generate_attempt.go +++ b/internal/provider/generate_attempt.go @@ -280,7 +280,8 @@ func updateGenerateAttemptPhase( // IsEffectiveGeneratePayloadEvent 判断事件是否属于“流已开始”的有效 payload。 func IsEffectiveGeneratePayloadEvent(event providertypes.StreamEvent) bool { switch event.Type { - case providertypes.StreamEventTextDelta, providertypes.StreamEventToolCallStart, providertypes.StreamEventToolCallDelta: + case providertypes.StreamEventTextDelta, providertypes.StreamEventToolCallStart, providertypes.StreamEventToolCallDelta, + providertypes.StreamEventThinkingDelta: return true default: return false diff --git a/internal/provider/mimo/driver.go b/internal/provider/mimo/driver.go new file mode 100644 index 00000000..5a9330af --- /dev/null +++ b/internal/provider/mimo/driver.go @@ -0,0 +1,29 @@ +package mimo + +import ( + "context" + + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" +) + +const DriverName = provider.DriverMiMo + +func Driver() provider.DriverDefinition { + return provider.DriverDefinition{ + Name: DriverName, + Build: func(ctx context.Context, cfg provider.RuntimeConfig) (provider.Provider, error) { + return New(cfg) + }, + Discover: func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + p, err := New(cfg) + if err != nil { + return nil, err + } + return p.DiscoverModels(ctx) + }, + ValidateCatalogIdentity: func(identity provider.ProviderIdentity) error { + return nil + }, + } +} diff --git a/internal/provider/mimo/provider.go b/internal/provider/mimo/provider.go new file mode 100644 index 00000000..0d52deba --- /dev/null +++ b/internal/provider/mimo/provider.go @@ -0,0 +1,134 @@ +package mimo + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +const errorPrefix = "mimo provider: " + +type Provider struct { + cfg provider.RuntimeConfig + generateClient *http.Client + discoveryClient *http.Client 
diff --git a/internal/provider/generate_attempt.go b/internal/provider/generate_attempt.go index f8bec712..eb7f2f72 100644 --- a/internal/provider/generate_attempt.go +++ b/internal/provider/generate_attempt.go @@ -280,7 +280,8 @@ func updateGenerateAttemptPhase( // IsEffectiveGeneratePayloadEvent 判断事件是否属于“流已开始”的有效 payload。 func IsEffectiveGeneratePayloadEvent(event providertypes.StreamEvent) bool { switch event.Type { - case providertypes.StreamEventTextDelta, providertypes.StreamEventToolCallStart, providertypes.StreamEventToolCallDelta: + case providertypes.StreamEventTextDelta, providertypes.StreamEventToolCallStart, providertypes.StreamEventToolCallDelta, + providertypes.StreamEventThinkingDelta: return true default: return false diff --git a/internal/provider/mimo/driver.go b/internal/provider/mimo/driver.go new file mode 100644 index 00000000..5a9330af --- /dev/null +++ b/internal/provider/mimo/driver.go @@ -0,0 +1,29 @@ +package mimo + +import ( + "context" + + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" +) + +const DriverName = provider.DriverMiMo + +func Driver() provider.DriverDefinition { + return provider.DriverDefinition{ + Name: DriverName, + Build: func(ctx context.Context, cfg provider.RuntimeConfig) (provider.Provider, error) { + return New(cfg) + }, + Discover: func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + p, err := New(cfg) + if err != nil { + return nil, err + } + return p.DiscoverModels(ctx) + }, + ValidateCatalogIdentity: func(identity provider.ProviderIdentity) error { + return nil + }, + } +} diff --git a/internal/provider/mimo/provider.go b/internal/provider/mimo/provider.go new file mode 100644 index 00000000..0d52deba --- /dev/null +++ b/internal/provider/mimo/provider.go @@ -0,0 +1,134 @@ +package mimo + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +const errorPrefix = "mimo provider: " + +type Provider struct { + cfg provider.RuntimeConfig + generateClient *http.Client + discoveryClient *http.Client +} + +func New(cfg provider.RuntimeConfig) (*Provider, error) { + if strings.TrimSpace(cfg.BaseURL) == "" { + return nil, errors.New(errorPrefix + "base url is empty") + } + if strings.TrimSpace(cfg.APIKeyEnv) == "" { + return nil, errors.New(errorPrefix + "api_key_env is empty") + } + return &Provider{ + cfg: cfg, + generateClient: &http.Client{ + Transport: http.DefaultTransport, + }, + discoveryClient: &http.Client{ + Timeout: provider.DefaultSDKRequestTimeout, + Transport: http.DefaultTransport, + }, + }, nil +} + +func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + tokens, err := provider.EstimateSerializedPayloadTokens(payload) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil +} + +func (p *Provider) DiscoverModels(ctx context.Context) ([]providertypes.ModelDescriptor, error) { + requestCfg, err := openaicompat.RequestConfigFromRuntime(p.cfg) + if err != nil { + return nil, err + } + return openaicompat.DiscoverModelDescriptors(ctx, p.discoveryClient, requestCfg) +} + +func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) + if err != nil { + return err + } + tc := req.ThinkingConfig + + return provider.RunGenerateWithRetry(ctx, p.cfg, events, func( + attemptCtx context.Context, + attemptEvents chan<- providertypes.StreamEvent, + ) error { + return p.generateOnce(attemptCtx, payload, tc, attemptEvents) + }) +} + +func (p *Provider) generateOnce(ctx context.Context, payload chatcompletions.Request, tc *providertypes.ThinkingConfig, events chan<- providertypes.StreamEvent) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("%smarshal request: %w", errorPrefix, err) + } + if tc != nil { + body, err = InjectThinkingParams(body, *tc) + if err != nil { + return fmt.Errorf("%sinject thinking params: %w", errorPrefix, err) + } + } + + apiKey, err := p.cfg.ResolveAPIKeyValue() + if err != nil { + return err + } + + endpoint := strings.TrimRight(strings.TrimSpace(p.cfg.BaseURL), "/") + "/chat/completions" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("%screate request: %w", errorPrefix, err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := p.generateClient.Do(req) + if err != nil { + return fmt.Errorf("%ssend request: %w", errorPrefix, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return provider.WrapIfThinkingNotSupported(chatcompletions.ParseError(resp)) + } + + return chatcompletions.ConsumeStream(ctx, resp.Body, events) +} + +// InjectThinkingParams 注入 MiMo 特有的 thinking.type enabled/disabled 参数。 +func InjectThinkingParams(body []byte, tc providertypes.ThinkingConfig) ([]byte, error) { + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return nil, err + } + if tc.Enabled { + raw["thinking"] = map[string]any{"type": "enabled"} + } else { + raw["thinking"] = map[string]any{"type": "disabled"} + } +
return json.Marshal(raw) +} diff --git a/internal/provider/mimo/provider_more_test.go b/internal/provider/mimo/provider_more_test.go new file mode 100644 index 00000000..dff969b1 --- /dev/null +++ b/internal/provider/mimo/provider_more_test.go @@ -0,0 +1,189 @@ +package mimo + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +func TestDriverBuildAndDiscover(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{{"id": "mimo-vl"}}, + }) + })) + defer server.Close() + + cfg := provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + Driver: DriverName, + } + driver := Driver() + if _, err := driver.Build(context.Background(), cfg); err != nil { + t.Fatalf("Build() error = %v", err) + } + models, err := driver.Discover(context.Background(), cfg) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + if len(models) != 1 || models[0].ID != "mimo-vl" { + t.Fatalf("unexpected models: %+v", models) + } + if err := driver.ValidateCatalogIdentity(provider.ProviderIdentity{}); err != nil { + t.Fatalf("ValidateCatalogIdentity() error = %v", err) + } + if _, err := driver.Discover(context.Background(), provider.RuntimeConfig{}); err == nil { + t.Fatal("expected invalid config discover error") + } +} + +func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { + t.Parallel() + + var requestBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/models": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{{"id": "mimo-vl"}}, + }) + case "/chat/completions": + var err error + requestBody, err = io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _, _ = w.Write([]byte(strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_content":"plan","content":"answer"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`, + `data: [DONE]`, + "", + }, "\n"))) + default: + t.Fatalf("unexpected path %q", r.URL.Path) + } + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "mimo-vl", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + req := providertypes.GenerateRequest{ + Model: "mimo-vl", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }}, + ThinkingConfig: &providertypes.ThinkingConfig{Enabled: false}, + } + if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if _, err := p.DiscoverModels(context.Background()); err != nil { + t.Fatalf("DiscoverModels() error = %v", err) + } + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), req, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + drained := drainMimoEvents(events) + if len(drained) != 3 { + t.Fatalf("expected 3 events, got %d", 
len(drained)) + } + if !strings.Contains(string(requestBody), `"thinking":{"type":"disabled"}`) { + t.Fatalf("request body missing disabled thinking: %s", string(requestBody)) + } + + errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{"message": "reasoning unsupported"}, + }) + })) + defer errorServer.Close() + + p, err = New(provider.RuntimeConfig{ + BaseURL: errorServer.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "mimo-vl", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + err = p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)) + if !provider.IsThinkingNotSupportedError(err) { + t.Fatalf("expected thinking-not-supported error, got %v", err) + } + + if _, err := New(provider.RuntimeConfig{APIKeyEnv: "KEY"}); err == nil { + t.Fatal("expected base url validation error") + } + if _, err := New(provider.RuntimeConfig{BaseURL: "https://example.com"}); err == nil { + t.Fatal("expected api key env validation error") + } + p.generateClient = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("network down") + })} + if err := p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)); err == nil || !strings.Contains(err.Error(), "send request") { + t.Fatalf("expected send request error, got %v", err) + } + invalidReq := providertypes.GenerateRequest{ + Model: "mimo-vl", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{{Kind: "invalid"}}, + }}, + } + if _, err := p.EstimateInputTokens(context.Background(), invalidReq); err == nil { + t.Fatal("expected invalid estimate request error") + } + if err := p.Generate(context.Background(), invalidReq, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected invalid generate request error") + } + p.cfg.APIKeyResolver = provider.StaticAPIKeyResolver("") + if _, err := p.DiscoverModels(context.Background()); err == nil { + t.Fatal("expected discovery api key error") + } + if err := p.generateOnce(context.Background(), chatcompletions.Request{}, nil, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected api key resolve error") + } +} + +func drainMimoEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { + var drained []providertypes.StreamEvent + for { + select { + case event := <-events: + drained = append(drained, event) + default: + return drained + } + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/internal/provider/mimo/provider_test.go b/internal/provider/mimo/provider_test.go new file mode 100644 index 00000000..e882a413 --- /dev/null +++ b/internal/provider/mimo/provider_test.go @@ -0,0 +1,54 @@ +package mimo + +import ( + "encoding/json" + "testing" + + providertypes "neo-code/internal/provider/types" +) + +func TestInjectThinkingParams_Enabled(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true}`) + result, err := InjectThinkingParams(body, providertypes.ThinkingConfig{Enabled: true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := 
json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + thinking, ok := raw["thinking"].(map[string]any) + if !ok { + t.Fatalf("thinking not found or wrong type") + } + if thinking["type"] != "enabled" { + t.Fatalf("expected thinking.type=enabled, got %v", thinking["type"]) + } +} + +func TestInjectThinkingParams_Disabled(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true}`) + result, err := InjectThinkingParams(body, providertypes.ThinkingConfig{Enabled: false}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + thinking, ok := raw["thinking"].(map[string]any) + if !ok { + t.Fatalf("thinking not found") + } + if thinking["type"] != "disabled" { + t.Fatalf("expected thinking.type=disabled, got %v", thinking["type"]) + } +} diff --git a/internal/provider/minimax/adapter.go b/internal/provider/minimax/adapter.go new file mode 100644 index 00000000..1e5ad275 --- /dev/null +++ b/internal/provider/minimax/adapter.go @@ -0,0 +1,116 @@ +package minimax + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" +) + +const ( + maxSSELineSize = 256 * 1024 + maxSSEStreamTotalSize = 10 << 20 +) + +// minimaxChunk 匹配 MiniMax chat completion 响应格式。 +type minimaxChunk struct { + Choices []struct { + Delta struct { + Content string `json:"content,omitempty"` + ReasoningDetails string `json:"reasoning_details,omitempty"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *struct { + TotalTokens int `json:"total_tokens"` + } `json:"usage,omitempty"` +} + +// ConsumeMiniMaxStream 解析 MiniMax SSE 流,优先从 reasoning_details 提取 thinking, +// 兜底从 content 中剥离 标签。 +func ConsumeMiniMaxStream(ctx context.Context, body io.Reader, events chan<- providertypes.StreamEvent) error { + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, maxSSELineSize), maxSSELineSize) + + var ( + finishReason string + usage providertypes.Usage + ) + + for scanner.Scan() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimPrefix(line, "data:") + data = strings.TrimSpace(data) + + if data == "[DONE]" { + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) + } + + var chunk minimaxChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + + if chunk.Usage != nil && chunk.Usage.TotalTokens > 0 { + usage = providertypes.Usage{ + InputTokens: 0, + OutputTokens: chunk.Usage.TotalTokens, + TotalTokens: chunk.Usage.TotalTokens, + OutputObserved: true, + } + } + + for _, choice := range chunk.Choices { + if strings.TrimSpace(choice.FinishReason) != "" { + finishReason = strings.TrimSpace(choice.FinishReason) + } + + // 优先使用 reasoning_details 作为 thinking 内容 + useContent := choice.Delta.Content + if reasoning := strings.TrimSpace(choice.Delta.ReasoningDetails); reasoning != "" { + if err := provider.EmitThinkingDelta(ctx, events, reasoning); err != nil { + return err + } + } else if thinkText := ExtractThinkContent(choice.Delta.Content); thinkText != "" { + // 兜底:从 content 中剥离 标签,避免泄漏到正文 + if err := provider.EmitThinkingDelta(ctx, events, thinkText); err != nil { + return err + } + useContent = 
diff --git a/internal/provider/minimax/adapter.go b/internal/provider/minimax/adapter.go new file mode 100644 index 00000000..1e5ad275 --- /dev/null +++ b/internal/provider/minimax/adapter.go @@ -0,0 +1,116 @@ +package minimax + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" +) + +const ( + maxSSELineSize = 256 * 1024 + maxSSEStreamTotalSize = 10 << 20 +) + +// minimaxChunk 匹配 MiniMax chat completion 响应格式。 +type minimaxChunk struct { + Choices []struct { + Delta struct { + Content string `json:"content,omitempty"` + ReasoningDetails string `json:"reasoning_details,omitempty"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *struct { + TotalTokens int `json:"total_tokens"` + } `json:"usage,omitempty"` +} + +// ConsumeMiniMaxStream 解析 MiniMax SSE 流,优先从 reasoning_details 提取 thinking, +// 兜底从 content 中剥离 <think> 标签。 +func ConsumeMiniMaxStream(ctx context.Context, body io.Reader, events chan<- providertypes.StreamEvent) error { + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, maxSSELineSize), maxSSELineSize) + + var ( + finishReason string + usage providertypes.Usage + ) + + for scanner.Scan() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimPrefix(line, "data:") + data = strings.TrimSpace(data) + + if data == "[DONE]" { + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) + } + + var chunk minimaxChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + + if chunk.Usage != nil && chunk.Usage.TotalTokens > 0 { + usage = providertypes.Usage{ + InputTokens: 0, + OutputTokens: chunk.Usage.TotalTokens, + TotalTokens: chunk.Usage.TotalTokens, + OutputObserved: true, + } + } + + for _, choice := range chunk.Choices { + if strings.TrimSpace(choice.FinishReason) != "" { + finishReason = strings.TrimSpace(choice.FinishReason) + } + + // 优先使用 reasoning_details 作为 thinking 内容 + useContent := choice.Delta.Content + if reasoning := strings.TrimSpace(choice.Delta.ReasoningDetails); reasoning != "" { + if err := provider.EmitThinkingDelta(ctx, events, reasoning); err != nil { + return err + } + } else if thinkText := ExtractThinkContent(choice.Delta.Content); thinkText != "" { + // 兜底:从 content 中剥离 <think> 标签,避免泄漏到正文 + if err := provider.EmitThinkingDelta(ctx, events, thinkText); err != nil { + return err + } + useContent = thinkTagRe.ReplaceAllString(choice.Delta.Content, "") + } + + if err := provider.EmitTextDelta(ctx, events, useContent); err != nil { + return err + } + } + } + + if err := scanner.Err(); err != nil && !errors.Is(err, io.EOF) { + return fmt.Errorf("%ssse scanner: %w", errorPrefix, err) + } + + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) +} + +func doneUsagePtr(usage providertypes.Usage) *providertypes.Usage { + if !usage.OutputObserved { + return nil + } + cp := usage + return &cp +}
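The `<think>` fallback above can be exercised in isolation. A minimal standalone sketch using the same regex (hypothetical `main` package, not part of the patch):

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

var thinkTagRe = regexp.MustCompile(`<think>([\s\S]*?)</think>`)

func main() {
	content := "<think>internal plan</think>visible answer"

	// Thinking text: every captured group, as ExtractThinkContent collects it.
	var thoughts []string
	for _, m := range thinkTagRe.FindAllStringSubmatch(content, -1) {
		thoughts = append(thoughts, strings.TrimSpace(m[1]))
	}

	// Visible text: the content with the tag spans stripped, as the adapter does.
	visible := thinkTagRe.ReplaceAllString(content, "")

	fmt.Println(strings.Join(thoughts, "\n")) // internal plan
	fmt.Println(visible)                      // visible answer
}
```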
diff --git a/internal/provider/minimax/driver.go b/internal/provider/minimax/driver.go new file mode 100644 index 00000000..4f018bbb --- /dev/null +++ b/internal/provider/minimax/driver.go @@ -0,0 +1,29 @@ +package minimax + +import ( + "context" + + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" +) + +const DriverName = provider.DriverMiniMax + +func Driver() provider.DriverDefinition { + return provider.DriverDefinition{ + Name: DriverName, + Build: func(ctx context.Context, cfg provider.RuntimeConfig) (provider.Provider, error) { + return New(cfg) + }, + Discover: func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + p, err := New(cfg) + if err != nil { + return nil, err + } + return p.DiscoverModels(ctx) + }, + ValidateCatalogIdentity: func(identity provider.ProviderIdentity) error { + return nil + }, + } +} diff --git a/internal/provider/minimax/provider.go b/internal/provider/minimax/provider.go new file mode 100644 index 00000000..60253ce8 --- /dev/null +++ b/internal/provider/minimax/provider.go @@ -0,0 +1,148 @@ +package minimax + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "regexp" + "strings" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +const errorPrefix = "minimax provider: " + +var thinkTagRe = regexp.MustCompile(`<think>([\s\S]*?)</think>`) + +type Provider struct { + cfg provider.RuntimeConfig + generateClient *http.Client + discoveryClient *http.Client +} + +func New(cfg provider.RuntimeConfig) (*Provider, error) { + if strings.TrimSpace(cfg.BaseURL) == "" { + return nil, errors.New(errorPrefix + "base url is empty") + } + if strings.TrimSpace(cfg.APIKeyEnv) == "" { + return nil, errors.New(errorPrefix + "api_key_env is empty") + } + return &Provider{ + cfg: cfg, + generateClient: &http.Client{ + Transport: http.DefaultTransport, + }, + discoveryClient: &http.Client{ + Timeout: provider.DefaultSDKRequestTimeout, + Transport: http.DefaultTransport, + }, + }, nil +} + +func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + tokens, err := provider.EstimateSerializedPayloadTokens(payload) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil +} + +func (p *Provider) DiscoverModels(ctx context.Context) ([]providertypes.ModelDescriptor, error) { + requestCfg, err := openaicompat.RequestConfigFromRuntime(p.cfg) + if err != nil { + return nil, err + } + return openaicompat.DiscoverModelDescriptors(ctx, p.discoveryClient, requestCfg) +} + +func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) + if err != nil { + return err + } + + return provider.RunGenerateWithRetry(ctx, p.cfg, events, func( + attemptCtx context.Context, + attemptEvents chan<- providertypes.StreamEvent, + ) error { + return p.generateOnce(attemptCtx, payload, attemptEvents) + }) +} + +func (p *Provider) generateOnce(ctx context.Context, payload chatcompletions.Request, events chan<- providertypes.StreamEvent) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("%smarshal request: %w", errorPrefix, err) + } + + // MiniMax 始终发送 reasoning_split:true 和 enable_thinking:true(ThinkingForceEnabled) + body, err = injectMiniMaxParams(body) + if err != nil { + return fmt.Errorf("%sinject params: %w", errorPrefix, err) + } + + apiKey, err := p.cfg.ResolveAPIKeyValue() + if err != nil { + return err + } + + // base_url 可能已直接携带 /chat/completions 后缀,避免重复拼接。 + endpoint := strings.TrimRight(strings.TrimSpace(p.cfg.BaseURL), "/") + if !strings.HasSuffix(endpoint, "/chat/completions") { + endpoint += "/chat/completions" + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("%screate request: %w", errorPrefix, err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := p.generateClient.Do(req) + if err != nil { + return fmt.Errorf("%ssend request: %w", errorPrefix, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return provider.WrapIfThinkingNotSupported(chatcompletions.ParseError(resp)) + } + + return ConsumeMiniMaxStream(ctx, resp.Body, events) +} + +func injectMiniMaxParams(body []byte) ([]byte, error) { + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return nil, err + } + raw["reasoning_split"] = true + raw["enable_thinking"] = true + return json.Marshal(raw) +} + +// ExtractThinkContent 从 MiniMax 的 <think> 标签或 reasoning_details 中提取思考文本。 +func ExtractThinkContent(content string) string { + matches := thinkTagRe.FindAllStringSubmatch(content, -1) + var builder strings.Builder + for _, match := range matches { + if len(match) > 1 { + builder.WriteString(strings.TrimSpace(match[1])) + builder.WriteString("\n") + } + } + return strings.TrimSpace(builder.String()) +} diff --git a/internal/provider/minimax/provider_more_test.go b/internal/provider/minimax/provider_more_test.go new file mode 100644 index 00000000..8490e858 --- /dev/null +++ b/internal/provider/minimax/provider_more_test.go @@ -0,0 +1,199 @@ +package minimax + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +func TestDriverBuildAndDiscover(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{{"id": "minimax-m2"}}, + }) + })) + defer server.Close() + + cfg := provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + Driver: DriverName, + } + driver := Driver() + if _, err := driver.Build(context.Background(), cfg); err != nil { + t.Fatalf("Build() error
= %v", err) + } + models, err := driver.Discover(context.Background(), cfg) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + if len(models) != 1 || models[0].ID != "minimax-m2" { + t.Fatalf("unexpected models: %+v", models) + } + if err := driver.ValidateCatalogIdentity(provider.ProviderIdentity{}); err != nil { + t.Fatalf("ValidateCatalogIdentity() error = %v", err) + } + if _, err := driver.Discover(context.Background(), provider.RuntimeConfig{}); err == nil { + t.Fatal("expected invalid config discover error") + } +} + +func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { + t.Parallel() + + var requestBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/models": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{{"id": "minimax-m2"}}, + }) + case "/chat/completions", "/": + var err error + requestBody, err = io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _, _ = w.Write([]byte(strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_details":"plan","content":"answer"},"finish_reason":"stop"}],"usage":{"total_tokens":5}}`, + `data: [DONE]`, + "", + }, "\n"))) + default: + t.Fatalf("unexpected path %q", r.URL.Path) + } + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "minimax-m2", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + req := providertypes.GenerateRequest{ + Model: "minimax-m2", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }}, + } + if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if _, err := p.DiscoverModels(context.Background()); err != nil { + t.Fatalf("DiscoverModels() error = %v", err) + } + p, err = New(provider.RuntimeConfig{ + BaseURL: server.URL + "/chat/completions", + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "minimax-m2", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), req, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + drained := drainMiniMaxProviderEvents(events) + if len(drained) != 3 { + t.Fatalf("expected 3 events, got %d", len(drained)) + } + if !strings.Contains(string(requestBody), `"reasoning_split":true`) || + !strings.Contains(string(requestBody), `"enable_thinking":true`) { + t.Fatalf("request body missing minimax params: %s", string(requestBody)) + } + + errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{"message": "thinking unsupported"}, + }) + })) + defer errorServer.Close() + + p, err = New(provider.RuntimeConfig{ + BaseURL: errorServer.URL + "/chat/completions", + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "minimax-m2", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + err = p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)) + if 
!provider.IsThinkingNotSupportedError(err) { + t.Fatalf("expected thinking-not-supported error, got %v", err) + } + + if _, err := New(provider.RuntimeConfig{APIKeyEnv: "KEY"}); err == nil { + t.Fatal("expected base url validation error") + } + if _, err := New(provider.RuntimeConfig{BaseURL: "https://example.com"}); err == nil { + t.Fatal("expected api key env validation error") + } + p.generateClient = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("network down") + })} + if err := p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)); err == nil || !strings.Contains(err.Error(), "send request") { + t.Fatalf("expected send request error, got %v", err) + } + invalidReq := providertypes.GenerateRequest{ + Model: "minimax-m2", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{{Kind: "invalid"}}, + }}, + } + if _, err := p.EstimateInputTokens(context.Background(), invalidReq); err == nil { + t.Fatal("expected invalid estimate request error") + } + if err := p.Generate(context.Background(), invalidReq, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected invalid generate request error") + } + p.cfg.APIKeyResolver = provider.StaticAPIKeyResolver("") + if _, err := p.DiscoverModels(context.Background()); err == nil { + t.Fatal("expected discovery api key error") + } + if err := p.generateOnce(context.Background(), chatcompletions.Request{}, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected api key resolve error") + } +} + +func drainMiniMaxProviderEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { + var drained []providertypes.StreamEvent + for { + select { + case event := <-events: + drained = append(drained, event) + default: + return drained + } + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +}
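Unlike the other providers, MiniMax's generateOnce tolerates a base_url that already carries the /chat/completions suffix, which the test above exercises with `server.URL + "/chat/completions"`. A standalone sketch of that normalization (hypothetical helper name, not this package's API):

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeEndpoint mirrors the suffix handling in generateOnce: trim the
// base URL, then append /chat/completions only when it is not already there.
func normalizeEndpoint(baseURL string) string {
	endpoint := strings.TrimRight(strings.TrimSpace(baseURL), "/")
	if !strings.HasSuffix(endpoint, "/chat/completions") {
		endpoint += "/chat/completions"
	}
	return endpoint
}

func main() {
	fmt.Println(normalizeEndpoint("https://api.example.com/"))
	fmt.Println(normalizeEndpoint("https://api.example.com/chat/completions"))
	// Both print: https://api.example.com/chat/completions
}
```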
t.Fatalf("expected 'first thought\\nsecond thought', got %q", result) + } +} + +func TestExtractThinkContent_NoTags(t *testing.T) { + t.Parallel() + + result := ExtractThinkContent("plain text without tags") + if result != "" { + t.Fatalf("expected empty, got %q", result) + } +} + +func TestConsumeMiniMaxStream(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_details":"internal plan","content":"visible answer"}}],"usage":{"total_tokens":9}}`, + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 4) + if err := ConsumeMiniMaxStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("ConsumeMiniMaxStream() error = %v", err) + } + + drained := drainMiniMaxEvents(events) + if len(drained) != 3 { + t.Fatalf("expected 3 events, got %d (%+v)", len(drained), drained) + } + thinking, err := drained[0].ThinkingDeltaValue() + if err != nil || thinking.Text != "internal plan" { + t.Fatalf("expected thinking delta, got err=%v event=%+v", err, drained[0]) + } + text, err := drained[1].TextDeltaValue() + if err != nil || text.Text != "visible answer" { + t.Fatalf("expected text delta, got err=%v event=%+v", err, drained[1]) + } + done, err := drained[2].MessageDoneValue() + if err != nil { + t.Fatalf("expected message done, got err=%v", err) + } + if done.Usage == nil || done.Usage.TotalTokens != 9 { + t.Fatalf("unexpected usage payload: %+v", done.Usage) + } +} + +func TestConsumeMiniMaxStreamExtractsThinkTagsFromContent(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"choices":[{"delta":{"content":"internal planvisible answer"},"finish_reason":"stop"}],"usage":{"total_tokens":5}}`, + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 4) + if err := ConsumeMiniMaxStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("ConsumeMiniMaxStream() error = %v", err) + } + + drained := drainMiniMaxEvents(events) + if len(drained) != 3 { + t.Fatalf("expected 3 events, got %d (%+v)", len(drained), drained) + } + thinking, err := drained[0].ThinkingDeltaValue() + if err != nil || thinking.Text != "internal plan" { + t.Fatalf("expected extracted think tag, got err=%v event=%+v", err, drained[0]) + } + text, err := drained[1].TextDeltaValue() + if err != nil || text.Text != "visible answer" { + t.Fatalf("expected think tags removed from text, got err=%v event=%+v", err, drained[1]) + } +} + +func drainMiniMaxEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { + drained := make([]providertypes.StreamEvent, 0, len(events)) + for { + select { + case event := <-events: + drained = append(drained, event) + default: + return drained + } + } +} diff --git a/internal/provider/openaicompat/chatcompletions/adapter.go b/internal/provider/openaicompat/chatcompletions/adapter.go index ba3ffc5c..986eb1a4 100644 --- a/internal/provider/openaicompat/chatcompletions/adapter.go +++ b/internal/provider/openaicompat/chatcompletions/adapter.go @@ -86,8 +86,10 @@ const ( type streamChunk struct { Choices []struct { Delta struct { - Content string `json:"content,omitempty"` - ToolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall `json:"tool_calls,omitempty"` + Content string `json:"content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + Reasoning string `json:"reasoning,omitempty"` + ToolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall `json:"tool_calls,omitempty"` } 
`json:"delta"` FinishReason string `json:"finish_reason"` } `json:"choices"` @@ -139,6 +141,13 @@ func ConsumeStream( if strings.TrimSpace(choice.FinishReason) != "" { finishReason = strings.TrimSpace(choice.FinishReason) } + reasoningText := choice.Delta.ReasoningContent + if reasoningText == "" { + reasoningText = choice.Delta.Reasoning + } + if err := provider.EmitThinkingDelta(ctx, events, reasoningText); err != nil { + return err + } if err := provider.EmitTextDelta(ctx, events, choice.Delta.Content); err != nil { return err } diff --git a/internal/provider/openaicompat/chatcompletions/adapter_test.go b/internal/provider/openaicompat/chatcompletions/adapter_test.go index d1eb195c..dd70b61d 100644 --- a/internal/provider/openaicompat/chatcompletions/adapter_test.go +++ b/internal/provider/openaicompat/chatcompletions/adapter_test.go @@ -104,6 +104,39 @@ func TestConsumeStreamParsesMultilineDataEvent(t *testing.T) { } } +func TestConsumeStreamEmitsThinkingDeltaFromReasoningFields(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_content":"plan","content":"answer"}}]}`, + `data: {"choices":[{"delta":{"reasoning":"fallback reasoning"}}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`, + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 8) + if err := ConsumeStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("ConsumeStream() error = %v", err) + } + + drained := drainEvents(events) + if len(drained) != 4 { + t.Fatalf("expected 4 events, got %d (%+v)", len(drained), drained) + } + first, err := drained[0].ThinkingDeltaValue() + if err != nil || first.Text != "plan" { + t.Fatalf("expected reasoning_content thinking delta, got err=%v event=%+v", err, drained[0]) + } + text, err := drained[1].TextDeltaValue() + if err != nil || text.Text != "answer" { + t.Fatalf("expected text delta, got err=%v event=%+v", err, drained[1]) + } + second, err := drained[2].ThinkingDeltaValue() + if err != nil || second.Text != "fallback reasoning" { + t.Fatalf("expected reasoning fallback thinking delta, got err=%v event=%+v", err, drained[2]) + } +} + func TestConsumeStreamEOFWithoutDoneAndWithoutFinishReason(t *testing.T) { t.Parallel() diff --git a/internal/provider/openaicompat/chatcompletions/request.go b/internal/provider/openaicompat/chatcompletions/request.go index 9efa2391..27b483a9 100644 --- a/internal/provider/openaicompat/chatcompletions/request.go +++ b/internal/provider/openaicompat/chatcompletions/request.go @@ -17,7 +17,6 @@ import ( const errorPrefix = "openaicompat provider: " -const maxSessionAssetReadBytes = session.MaxSessionAssetBytes const maxSessionAssetsTotalBytes = provider.MaxSessionAssetsTotalBytes const htmlErrorSnippetMaxRunes = 320 @@ -82,6 +81,14 @@ func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providert } } + if tc := req.ThinkingConfig; tc != nil { + if tc.Enabled && tc.Effort != "" { + payload.ReasoningEffort = tc.Effort + } else if !tc.Enabled { + payload.ReasoningEffort = "none" + } + } + return payload, nil } @@ -222,6 +229,20 @@ func toOpenAIMessageWithBudget( } } + if len(message.ThinkingMetadata) > 0 { + var meta struct { + ReasoningContent string `json:"reasoning_content"` + Reasoning string `json:"reasoning"` + } + if err := json.Unmarshal(message.ThinkingMetadata, &meta); err == nil { + if meta.ReasoningContent != "" { + out.ReasoningContent = meta.ReasoningContent + } else if meta.Reasoning != "" 
{ + out.ReasoningContent = meta.Reasoning + } + } + } + return out, usedAssetBytes, nil } diff --git a/internal/provider/openaicompat/chatcompletions/request_test.go b/internal/provider/openaicompat/chatcompletions/request_test.go index 57a7905d..8930df86 100644 --- a/internal/provider/openaicompat/chatcompletions/request_test.go +++ b/internal/provider/openaicompat/chatcompletions/request_test.go @@ -76,6 +76,46 @@ func TestBuildRequestUsesDefaultModelAndNormalizesTools(t *testing.T) { } } +func TestBuildRequestThinkingConfigAndContinuity(t *testing.T) { + t.Parallel() + + payload, err := BuildRequest(context.Background(), provider.RuntimeConfig{DefaultModel: "gpt-default"}, providertypes.GenerateRequest{ + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + ThinkingMetadata: []byte(`{"reasoning":"step by step"}`), + }, + }, + ThinkingConfig: &providertypes.ThinkingConfig{ + Enabled: true, + Effort: "high", + }, + }) + if err != nil { + t.Fatalf("BuildRequest() error = %v", err) + } + if payload.ReasoningEffort != "high" { + t.Fatalf("expected reasoning effort to be preserved, got %q", payload.ReasoningEffort) + } + if len(payload.Messages) != 1 || payload.Messages[0].ReasoningContent != "step by step" { + t.Fatalf("expected reasoning continuity content, got %+v", payload.Messages) + } + + disabled, err := BuildRequest(context.Background(), provider.RuntimeConfig{DefaultModel: "gpt-default"}, providertypes.GenerateRequest{ + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, + }, + ThinkingConfig: &providertypes.ThinkingConfig{Enabled: false}, + }) + if err != nil { + t.Fatalf("BuildRequest() disabled error = %v", err) + } + if disabled.ReasoningEffort != "none" { + t.Fatalf("expected disabled reasoning effort marker, got %q", disabled.ReasoningEffort) + } +} + func TestBuildRequestAndToOpenAIMessageErrors(t *testing.T) { t.Parallel() diff --git a/internal/provider/openaicompat/chatcompletions/types.go b/internal/provider/openaicompat/chatcompletions/types.go index bc592f16..18ce33d5 100644 --- a/internal/provider/openaicompat/chatcompletions/types.go +++ b/internal/provider/openaicompat/chatcompletions/types.go @@ -5,19 +5,21 @@ package chatcompletions // Request 表示 /chat/completions 端点的请求体。 type Request struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Tools []ToolDefinition `json:"tools,omitempty"` - ToolChoice string `json:"tool_choice,omitempty"` - Stream bool `json:"stream"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Tools []ToolDefinition `json:"tools,omitempty"` + ToolChoice string `json:"tool_choice,omitempty"` + Stream bool `json:"stream"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` } // Message 表示 OpenAI 协议中的消息格式。 type Message struct { - Role string `json:"role"` - Content any `json:"content,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Role string `json:"role"` + Content any `json:"content,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` } // MessageContentPart 表示多模态消息的单个部分。 diff --git a/internal/provider/openaicompat/generate_sdk.go b/internal/provider/openaicompat/generate_sdk.go index 6b72214e..3bd0a671 100644 --- 
a/internal/provider/openaicompat/generate_sdk.go +++ b/internal/provider/openaicompat/generate_sdk.go @@ -319,7 +319,7 @@ func (p *Provider) generateChatCompletionsWithCompatibleStream( defer resp.Body.Close() if resp.StatusCode >= http.StatusBadRequest { - return ParseError(resp) + return provider.WrapIfThinkingNotSupported(ParseError(resp)) } return chatcompletions.ConsumeStream(ctx, resp.Body, events) @@ -373,7 +373,7 @@ func (p *Provider) generateSDKResponses( defer resp.Body.Close() if resp.StatusCode >= http.StatusBadRequest { - return ParseError(resp) + return provider.WrapIfThinkingNotSupported(ParseError(resp)) } return responses.EmitFromStream(ctx, resp.Body, events) diff --git a/internal/provider/openaicompat/glm/driver.go b/internal/provider/openaicompat/glm/driver.go new file mode 100644 index 00000000..c2d01cce --- /dev/null +++ b/internal/provider/openaicompat/glm/driver.go @@ -0,0 +1,29 @@ +package glm + +import ( + "context" + + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" +) + +const DriverName = provider.DriverGLM + +func Driver() provider.DriverDefinition { + return provider.DriverDefinition{ + Name: DriverName, + Build: func(ctx context.Context, cfg provider.RuntimeConfig) (provider.Provider, error) { + return New(cfg) + }, + Discover: func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + p, err := New(cfg) + if err != nil { + return nil, err + } + return p.DiscoverModels(ctx) + }, + ValidateCatalogIdentity: func(identity provider.ProviderIdentity) error { + return nil + }, + } +} diff --git a/internal/provider/openaicompat/glm/provider.go b/internal/provider/openaicompat/glm/provider.go new file mode 100644 index 00000000..7bc9e151 --- /dev/null +++ b/internal/provider/openaicompat/glm/provider.go @@ -0,0 +1,135 @@ +package glm + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +const errorPrefix = "glm provider: " + +type Provider struct { + cfg provider.RuntimeConfig + generateClient *http.Client + discoveryClient *http.Client +} + +func New(cfg provider.RuntimeConfig) (*Provider, error) { + if strings.TrimSpace(cfg.BaseURL) == "" { + return nil, errors.New(errorPrefix + "base url is empty") + } + if strings.TrimSpace(cfg.APIKeyEnv) == "" { + return nil, errors.New(errorPrefix + "api_key_env is empty") + } + return &Provider{ + cfg: cfg, + generateClient: &http.Client{ + Transport: http.DefaultTransport, + }, + discoveryClient: &http.Client{ + Timeout: provider.DefaultSDKRequestTimeout, + Transport: http.DefaultTransport, + }, + }, nil +} + +func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + tokens, err := provider.EstimateSerializedPayloadTokens(payload) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil +} + +func (p *Provider) DiscoverModels(ctx context.Context) ([]providertypes.ModelDescriptor, error) { + requestCfg, err := 
openaicompat.RequestConfigFromRuntime(p.cfg) + if err != nil { + return nil, err + } + return openaicompat.DiscoverModelDescriptors(ctx, p.discoveryClient, requestCfg) +} + +func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) + if err != nil { + return err + } + tc := req.ThinkingConfig + + return provider.RunGenerateWithRetry(ctx, p.cfg, events, func( + attemptCtx context.Context, + attemptEvents chan<- providertypes.StreamEvent, + ) error { + return p.generateOnce(attemptCtx, payload, tc, attemptEvents) + }) +} + +func (p *Provider) generateOnce(ctx context.Context, payload chatcompletions.Request, tc *providertypes.ThinkingConfig, events chan<- providertypes.StreamEvent) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("%smarshal request: %w", errorPrefix, err) + } + if tc != nil { + body, err = injectGLMParams(body, *tc) + if err != nil { + return fmt.Errorf("%sinject params: %w", errorPrefix, err) + } + } + + apiKey, err := p.cfg.ResolveAPIKeyValue() + if err != nil { + return err + } + + endpoint := strings.TrimRight(strings.TrimSpace(p.cfg.BaseURL), "/") + "/chat/completions" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("%screate request: %w", errorPrefix, err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := p.generateClient.Do(req) + if err != nil { + return fmt.Errorf("%ssend request: %w", errorPrefix, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return provider.WrapIfThinkingNotSupported(chatcompletions.ParseError(resp)) + } + + return chatcompletions.ConsumeStream(ctx, resp.Body, events) +} + +// injectGLMParams 注入 GLM 特有的 enable_thinking 和 chat_template_kwargs 参数。 +func injectGLMParams(body []byte, tc providertypes.ThinkingConfig) ([]byte, error) { + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return nil, err + } + + raw["enable_thinking"] = tc.Enabled + raw["chat_template_kwargs"] = map[string]any{ + "enable_thinking": tc.Enabled, + "clear_thinking": !tc.Enabled, + } + return json.Marshal(raw) +}
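injectGLMParams mirrors the thinking switch into both the top-level `enable_thinking` field and `chat_template_kwargs`. For the disabled case, the patched payload comes out as follows (standalone sketch with a hypothetical `main` package; output key order follows encoding/json's map-key sorting):

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	base := []byte(`{"model":"glm-5.1","messages":[],"stream":true}`)

	var raw map[string]any
	if err := json.Unmarshal(base, &raw); err != nil {
		panic(err)
	}
	// Disabled thinking, mirroring injectGLMParams with tc.Enabled == false.
	raw["enable_thinking"] = false
	raw["chat_template_kwargs"] = map[string]any{
		"enable_thinking": false,
		"clear_thinking":  true,
	}

	out, err := json.Marshal(raw)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
	// {"chat_template_kwargs":{"clear_thinking":true,"enable_thinking":false},"enable_thinking":false,"messages":[],"model":"glm-5.1","stream":true}
}
```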
cfg) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + if len(models) != 1 || models[0].ID != "glm-5.1" { + t.Fatalf("unexpected models: %+v", models) + } + if err := driver.ValidateCatalogIdentity(provider.ProviderIdentity{}); err != nil { + t.Fatalf("ValidateCatalogIdentity() error = %v", err) + } + if _, err := driver.Discover(context.Background(), provider.RuntimeConfig{}); err == nil { + t.Fatal("expected invalid config discover error") + } +} + +func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { + t.Parallel() + + var requestBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/models": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{{"id": "glm-5.1"}}, + }) + case "/chat/completions": + var err error + requestBody, err = io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _, _ = w.Write([]byte(strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_content":"plan","content":"answer"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`, + `data: [DONE]`, + "", + }, "\n"))) + default: + t.Fatalf("unexpected path %q", r.URL.Path) + } + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "glm-5.1", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + req := providertypes.GenerateRequest{ + Model: "glm-5.1", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }}, + ThinkingConfig: &providertypes.ThinkingConfig{Enabled: false}, + } + if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if _, err := p.DiscoverModels(context.Background()); err != nil { + t.Fatalf("DiscoverModels() error = %v", err) + } + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), req, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + drained := drainGLMEvents(events) + if len(drained) != 3 { + t.Fatalf("expected 3 events, got %d", len(drained)) + } + if !strings.Contains(string(requestBody), `"enable_thinking":false`) || + !strings.Contains(string(requestBody), `"clear_thinking":true`) { + t.Fatalf("request body missing glm params: %s", string(requestBody)) + } + + errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{"message": "reasoning unsupported"}, + }) + })) + defer errorServer.Close() + + p, err = New(provider.RuntimeConfig{ + BaseURL: errorServer.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "glm-5.1", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + err = p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)) + if !provider.IsThinkingNotSupportedError(err) { + t.Fatalf("expected thinking-not-supported error, got %v", err) + } + + if _, err := New(provider.RuntimeConfig{APIKeyEnv: "KEY"}); err == nil { + t.Fatal("expected base url validation error") + } + if _, err := New(provider.RuntimeConfig{BaseURL: 
"https://example.com"}); err == nil { + t.Fatal("expected api key env validation error") + } + p.generateClient = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("network down") + })} + if err := p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)); err == nil || !strings.Contains(err.Error(), "send request") { + t.Fatalf("expected send request error, got %v", err) + } + invalidReq := providertypes.GenerateRequest{ + Model: "glm-5.1", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{{Kind: "invalid"}}, + }}, + } + if _, err := p.EstimateInputTokens(context.Background(), invalidReq); err == nil { + t.Fatal("expected invalid estimate request error") + } + if err := p.Generate(context.Background(), invalidReq, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected invalid generate request error") + } + p.cfg.APIKeyResolver = provider.StaticAPIKeyResolver("") + if _, err := p.DiscoverModels(context.Background()); err == nil { + t.Fatal("expected discovery api key error") + } + if err := p.generateOnce(context.Background(), chatcompletions.Request{}, nil, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected api key resolve error") + } +} + +func drainGLMEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { + var drained []providertypes.StreamEvent + for { + select { + case event := <-events: + drained = append(drained, event) + default: + return drained + } + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/internal/provider/openaicompat/glm/provider_test.go b/internal/provider/openaicompat/glm/provider_test.go new file mode 100644 index 00000000..68beba9c --- /dev/null +++ b/internal/provider/openaicompat/glm/provider_test.go @@ -0,0 +1,65 @@ +package glm + +import ( + "encoding/json" + "testing" + + providertypes "neo-code/internal/provider/types" +) + +func TestInjectGLMParams_Enabled(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true}`) + result, err := injectGLMParams(body, providertypes.ThinkingConfig{Enabled: true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if raw["enable_thinking"] != true { + t.Fatalf("expected enable_thinking=true, got %v", raw["enable_thinking"]) + } + + kwargs, ok := raw["chat_template_kwargs"].(map[string]any) + if !ok { + t.Fatalf("chat_template_kwargs not found") + } + if kwargs["enable_thinking"] != true { + t.Fatalf("expected kwargs.enable_thinking=true, got %v", kwargs["enable_thinking"]) + } + if kwargs["clear_thinking"] != false { + t.Fatalf("expected kwargs.clear_thinking=false, got %v", kwargs["clear_thinking"]) + } +} + +func TestInjectGLMParams_Disabled(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true}`) + result, err := injectGLMParams(body, providertypes.ThinkingConfig{Enabled: false}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if raw["enable_thinking"] != false { + t.Fatalf("expected enable_thinking=false, got %v", raw["enable_thinking"]) + } + 
+ kwargs, ok := raw["chat_template_kwargs"].(map[string]any) + if !ok { + t.Fatalf("chat_template_kwargs not found") + } + if kwargs["clear_thinking"] != true { + t.Fatalf("expected kwargs.clear_thinking=true, got %v", kwargs["clear_thinking"]) + } +} diff --git a/internal/provider/openaicompat/qwen/driver.go b/internal/provider/openaicompat/qwen/driver.go new file mode 100644 index 00000000..28dcc141 --- /dev/null +++ b/internal/provider/openaicompat/qwen/driver.go @@ -0,0 +1,29 @@ +package qwen + +import ( + "context" + + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" +) + +const DriverName = provider.DriverQwen + +func Driver() provider.DriverDefinition { + return provider.DriverDefinition{ + Name: DriverName, + Build: func(ctx context.Context, cfg provider.RuntimeConfig) (provider.Provider, error) { + return New(cfg) + }, + Discover: func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + p, err := New(cfg) + if err != nil { + return nil, err + } + return p.DiscoverModels(ctx) + }, + ValidateCatalogIdentity: func(identity provider.ProviderIdentity) error { + return nil + }, + } +} diff --git a/internal/provider/openaicompat/qwen/provider.go b/internal/provider/openaicompat/qwen/provider.go new file mode 100644 index 00000000..3105618e --- /dev/null +++ b/internal/provider/openaicompat/qwen/provider.go @@ -0,0 +1,148 @@ +package qwen + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +const errorPrefix = "qwen provider: " + +type Provider struct { + cfg provider.RuntimeConfig + generateClient *http.Client + discoveryClient *http.Client +} + +func New(cfg provider.RuntimeConfig) (*Provider, error) { + if strings.TrimSpace(cfg.BaseURL) == "" { + return nil, errors.New(errorPrefix + "base url is empty") + } + if strings.TrimSpace(cfg.APIKeyEnv) == "" { + return nil, errors.New(errorPrefix + "api_key_env is empty") + } + return &Provider{ + cfg: cfg, + generateClient: &http.Client{ + Transport: http.DefaultTransport, + }, + discoveryClient: &http.Client{ + Timeout: provider.DefaultSDKRequestTimeout, + Transport: http.DefaultTransport, + }, + }, nil +} + +func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + tokens, err := provider.EstimateSerializedPayloadTokens(payload) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil +} + +func (p *Provider) DiscoverModels(ctx context.Context) ([]providertypes.ModelDescriptor, error) { + requestCfg, err := openaicompat.RequestConfigFromRuntime(p.cfg) + if err != nil { + return nil, err + } + return openaicompat.DiscoverModelDescriptors(ctx, p.discoveryClient, requestCfg) +} + +func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) + if err != nil { + return err + } + tc := req.ThinkingConfig + + return 
provider.RunGenerateWithRetry(ctx, p.cfg, events, func( + attemptCtx context.Context, + attemptEvents chan<- providertypes.StreamEvent, + ) error { + return p.generateOnce(attemptCtx, payload, tc, attemptEvents) + }) +} + +func (p *Provider) generateOnce(ctx context.Context, payload chatcompletions.Request, tc *providertypes.ThinkingConfig, events chan<- providertypes.StreamEvent) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("%smarshal request: %w", errorPrefix, err) + } + if tc != nil { + body, err = injectQwenParams(body, *tc) + if err != nil { + return fmt.Errorf("%sinject params: %w", errorPrefix, err) + } + } + + apiKey, err := p.cfg.ResolveAPIKeyValue() + if err != nil { + return err + } + + endpoint := strings.TrimRight(strings.TrimSpace(p.cfg.BaseURL), "/") + "/chat/completions" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("%screate request: %w", errorPrefix, err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := p.generateClient.Do(req) + if err != nil { + return fmt.Errorf("%ssend request: %w", errorPrefix, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return provider.WrapIfThinkingNotSupported(chatcompletions.ParseError(resp)) + } + + return chatcompletions.ConsumeStream(ctx, resp.Body, events) +} + +// injectQwenParams 注入 Qwen 特有的 enable_thinking 平级布尔参数及推荐采样参数。 +func injectQwenParams(body []byte, tc providertypes.ThinkingConfig) ([]byte, error) { + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return nil, err + } + + if tc.Enabled { + raw["enable_thinking"] = true + // thinking 模式推荐采样参数 + if _, ok := raw["temperature"]; !ok { + raw["temperature"] = 0.6 + } + if _, ok := raw["top_p"]; !ok { + raw["top_p"] = 0.95 + } + } else { + raw["enable_thinking"] = false + if _, ok := raw["temperature"]; !ok { + raw["temperature"] = 0.7 + } + if _, ok := raw["top_p"]; !ok { + raw["top_p"] = 0.8 + } + } + return json.Marshal(raw) +} diff --git a/internal/provider/openaicompat/qwen/provider_more_test.go b/internal/provider/openaicompat/qwen/provider_more_test.go new file mode 100644 index 00000000..408ad83c --- /dev/null +++ b/internal/provider/openaicompat/qwen/provider_more_test.go @@ -0,0 +1,190 @@ +package qwen + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat/chatcompletions" + providertypes "neo-code/internal/provider/types" +) + +func TestDriverBuildAndDiscover(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{{"id": "qwen-max"}}, + }) + })) + defer server.Close() + + cfg := provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + Driver: DriverName, + } + driver := Driver() + if _, err := driver.Build(context.Background(), cfg); err != nil { + t.Fatalf("Build() error = %v", err) + } + models, err := driver.Discover(context.Background(), cfg) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + if len(models) != 1 || models[0].ID != "qwen-max" { + t.Fatalf("unexpected models: %+v", models) + } + if err := 
driver.ValidateCatalogIdentity(provider.ProviderIdentity{}); err != nil { + t.Fatalf("ValidateCatalogIdentity() error = %v", err) + } + if _, err := driver.Discover(context.Background(), provider.RuntimeConfig{}); err == nil { + t.Fatal("expected invalid config discover error") + } +} + +func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { + t.Parallel() + + var requestBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/models": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{{"id": "qwen-max"}}, + }) + case "/chat/completions": + var err error + requestBody, err = io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _, _ = w.Write([]byte(strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_content":"plan","content":"answer"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`, + `data: [DONE]`, + "", + }, "\n"))) + default: + t.Fatalf("unexpected path %q", r.URL.Path) + } + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + BaseURL: server.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "qwen-max", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + req := providertypes.GenerateRequest{ + Model: "qwen-max", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }}, + ThinkingConfig: &providertypes.ThinkingConfig{Enabled: true}, + } + if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if _, err := p.DiscoverModels(context.Background()); err != nil { + t.Fatalf("DiscoverModels() error = %v", err) + } + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), req, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + drained := drainQwenEvents(events) + if len(drained) != 3 { + t.Fatalf("expected 3 events, got %d", len(drained)) + } + if !strings.Contains(string(requestBody), `"enable_thinking":true`) || + !strings.Contains(string(requestBody), `"temperature":0.6`) { + t.Fatalf("request body missing qwen thinking params: %s", string(requestBody)) + } + + errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{"message": "thinking unsupported"}, + }) + })) + defer errorServer.Close() + + p, err = New(provider.RuntimeConfig{ + BaseURL: errorServer.URL, + APIKeyEnv: "TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("secret"), + DefaultModel: "qwen-max", + Driver: DriverName, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + err = p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)) + if !provider.IsThinkingNotSupportedError(err) { + t.Fatalf("expected thinking-not-supported error, got %v", err) + } + + if _, err := New(provider.RuntimeConfig{APIKeyEnv: "KEY"}); err == nil { + t.Fatal("expected base url validation error") + } + if _, err := New(provider.RuntimeConfig{BaseURL: "https://example.com"}); err == nil { + t.Fatal("expected api key env validation error") + } + p.generateClient = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, 
error) { + return nil, errors.New("network down") + })} + if err := p.Generate(context.Background(), req, make(chan providertypes.StreamEvent, 1)); err == nil || !strings.Contains(err.Error(), "send request") { + t.Fatalf("expected send request error, got %v", err) + } + invalidReq := providertypes.GenerateRequest{ + Model: "qwen-max", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{{Kind: "invalid"}}, + }}, + } + if _, err := p.EstimateInputTokens(context.Background(), invalidReq); err == nil { + t.Fatal("expected invalid estimate request error") + } + if err := p.Generate(context.Background(), invalidReq, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected invalid generate request error") + } + p.cfg.APIKeyResolver = provider.StaticAPIKeyResolver("") + if _, err := p.DiscoverModels(context.Background()); err == nil { + t.Fatal("expected discovery api key error") + } + if err := p.generateOnce(context.Background(), chatcompletions.Request{}, nil, make(chan providertypes.StreamEvent, 1)); err == nil { + t.Fatal("expected api key resolve error") + } +} + +func drainQwenEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { + var drained []providertypes.StreamEvent + for { + select { + case event := <-events: + drained = append(drained, event) + default: + return drained + } + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/internal/provider/openaicompat/qwen/provider_test.go b/internal/provider/openaicompat/qwen/provider_test.go new file mode 100644 index 00000000..6671dbdd --- /dev/null +++ b/internal/provider/openaicompat/qwen/provider_test.go @@ -0,0 +1,77 @@ +package qwen + +import ( + "encoding/json" + "testing" + + providertypes "neo-code/internal/provider/types" +) + +func TestInjectQwenParams_Enabled(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true}`) + result, err := injectQwenParams(body, providertypes.ThinkingConfig{Enabled: true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if raw["enable_thinking"] != true { + t.Fatalf("expected enable_thinking=true, got %v", raw["enable_thinking"]) + } + if raw["temperature"] != 0.6 { + t.Fatalf("expected temperature=0.6, got %v", raw["temperature"]) + } + if raw["top_p"] != 0.95 { + t.Fatalf("expected top_p=0.95, got %v", raw["top_p"]) + } +} + +func TestInjectQwenParams_Disabled(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true}`) + result, err := injectQwenParams(body, providertypes.ThinkingConfig{Enabled: false}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if raw["enable_thinking"] != false { + t.Fatalf("expected enable_thinking=false, got %v", raw["enable_thinking"]) + } + if raw["temperature"] != 0.7 { + t.Fatalf("expected temperature=0.7, got %v", raw["temperature"]) + } + if raw["top_p"] != 0.8 { + t.Fatalf("expected top_p=0.8, got %v", raw["top_p"]) + } +} + +func TestInjectQwenParams_PreservesExistingTemp(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"test","messages":[],"stream":true,"temperature":0.3}`) + result, 
err := injectQwenParams(body, providertypes.ThinkingConfig{Enabled: true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var raw map[string]any + if err := json.Unmarshal(result, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if raw["temperature"] != 0.3 { + t.Fatalf("expected temperature=0.3 (preserved), got %v", raw["temperature"]) + } +} diff --git a/internal/provider/stream_events.go b/internal/provider/stream_events.go index 40e606aa..5419649a 100644 --- a/internal/provider/stream_events.go +++ b/internal/provider/stream_events.go @@ -30,6 +30,14 @@ func EmitToolCallDelta(ctx context.Context, events chan<- providertypes.StreamEv return emitStreamEvent(ctx, events, providertypes.NewToolCallDeltaStreamEvent(index, id, argumentsDelta)) } +// EmitThinkingDelta 发送思考增量事件,空文本直接忽略。 +func EmitThinkingDelta(ctx context.Context, events chan<- providertypes.StreamEvent, text string) error { + if text == "" { + return nil + } + return emitStreamEvent(ctx, events, providertypes.NewThinkingDeltaStreamEvent(text)) +} + // EmitMessageDone 发送消息完成事件,并在上下文取消时做非阻塞兜底。 func EmitMessageDone(ctx context.Context, events chan<- providertypes.StreamEvent, finishReason string, usage *providertypes.Usage) error { event := providertypes.NewMessageDoneStreamEvent(finishReason, usage) @@ -48,16 +56,6 @@ func EmitMessageDone(ctx context.Context, events chan<- providertypes.StreamEven } } -// FlushDataLines 逐行处理 SSE data 缓冲区。 -func FlushDataLines(dataLines []string, processChunk func(string) error) error { - for _, line := range dataLines { - if err := processChunk(line); err != nil { - return err - } - } - return nil -} - // emitStreamEvent 安全发送流式事件,并支持上下文取消。 func emitStreamEvent(ctx context.Context, events chan<- providertypes.StreamEvent, event providertypes.StreamEvent) error { if events == nil { diff --git a/internal/provider/types/event.go b/internal/provider/types/event.go index 72df4bc1..77fe5a0f 100644 --- a/internal/provider/types/event.go +++ b/internal/provider/types/event.go @@ -14,6 +14,8 @@ const ( StreamEventToolCallDelta StreamEventType = "tool_call_delta" // StreamEventMessageDone 表示本轮消息完成,并携带最终统计信息。 StreamEventMessageDone StreamEventType = "message_done" + // StreamEventThinkingDelta 表示模型思考/推理内容的增量片段。 + StreamEventThinkingDelta StreamEventType = "thinking_delta" ) // StreamEvent 表示 provider 向 runtime 推送的流式事件。 @@ -23,6 +25,12 @@ type StreamEvent struct { ToolCallStart *ToolCallStartPayload `json:"tool_call_start,omitempty"` ToolCallDelta *ToolCallDeltaPayload `json:"tool_call_delta,omitempty"` MessageDone *MessageDonePayload `json:"message_done,omitempty"` + ThinkingDelta *ThinkingDeltaPayload `json:"thinking_delta,omitempty"` +} + +// ThinkingDeltaPayload 表示思考内容增量事件的载荷。 +type ThinkingDeltaPayload struct { + Text string `json:"text"` } // TextDeltaPayload 表示文本增量事件的载荷。 @@ -82,6 +90,14 @@ func NewMessageDoneStreamEvent(finishReason string, usage *Usage) StreamEvent { } } +// NewThinkingDeltaStreamEvent 创建思考增量流事件。 +func NewThinkingDeltaStreamEvent(text string) StreamEvent { + return StreamEvent{ + Type: StreamEventThinkingDelta, + ThinkingDelta: &ThinkingDeltaPayload{Text: text}, + } +} + // TextDeltaValue 返回 text_delta 事件的载荷,并在结构缺失时显式报错。 func (e StreamEvent) TextDeltaValue() (TextDeltaPayload, error) { if e.Type != StreamEventTextDelta { @@ -125,3 +141,14 @@ func (e StreamEvent) MessageDoneValue() (MessageDonePayload, error) { } return *e.MessageDone, nil } + +// ThinkingDeltaValue 返回 thinking_delta 事件的载荷,并在结构缺失时显式报错。 +func (e StreamEvent) ThinkingDeltaValue() 
(ThinkingDeltaPayload, error) { + if e.Type != StreamEventThinkingDelta { + return ThinkingDeltaPayload{}, fmt.Errorf("provider: stream event type %q is not thinking_delta", e.Type) + } + if e.ThinkingDelta == nil { + return ThinkingDeltaPayload{}, fmt.Errorf("provider: thinking_delta event payload is nil") + } + return *e.ThinkingDelta, nil +} diff --git a/internal/provider/types/message.go b/internal/provider/types/message.go index a68f2669..eeb82b94 100644 --- a/internal/provider/types/message.go +++ b/internal/provider/types/message.go @@ -1,6 +1,9 @@ package types -import "strings" +import ( + "encoding/json" + "strings" +) // RoleSystem 标识系统消息。 const RoleSystem = "system" @@ -16,12 +19,13 @@ const RoleTool = "tool" // Message 表示对话中的单条消息。 type Message struct { - Role string `json:"role"` - Parts []ContentPart `json:"parts,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - IsError bool `json:"is_error,omitempty"` - ToolMetadata map[string]string `json:"tool_metadata,omitempty"` + Role string `json:"role"` + Parts []ContentPart `json:"parts,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + IsError bool `json:"is_error,omitempty"` + ToolMetadata map[string]string `json:"tool_metadata,omitempty"` + ThinkingMetadata json.RawMessage `json:"thinking_metadata,omitempty"` } // IsEmpty checks if the message has no content parts and no tool calls. diff --git a/internal/provider/types/model.go b/internal/provider/types/model.go index 9f96aecc..2ca9e0ca 100644 --- a/internal/provider/types/model.go +++ b/internal/provider/types/model.go @@ -16,8 +16,12 @@ const ( // ModelCapabilityHints 描述 discovery/catalog 链路共享的模型能力提示。 type ModelCapabilityHints struct { - ToolCalling ModelCapabilityState `json:"tool_calling,omitempty"` - ImageInput ModelCapabilityState `json:"image_input,omitempty"` + ToolCalling ModelCapabilityState `json:"tool_calling,omitempty"` + ImageInput ModelCapabilityState `json:"image_input,omitempty"` + Thinking ModelCapabilityState `json:"thinking,omitempty"` + ThinkingEfforts []string `json:"thinking_efforts,omitempty"` + ThinkingDefaultEffort string `json:"thinking_default_effort,omitempty"` + ThinkingForceEnabled bool `json:"thinking_force_enabled,omitempty"` } // ModelDescriptor 表示 discovery/catalog 链路共享的模型元数据描述符。 @@ -170,6 +174,18 @@ func mergeModelCapabilityHints(primary ModelCapabilityHints, secondary ModelCapa if primary.ImageInput == "" { primary.ImageInput = secondary.ImageInput } + if primary.Thinking == "" { + primary.Thinking = secondary.Thinking + } + if len(primary.ThinkingEfforts) == 0 { + primary.ThinkingEfforts = secondary.ThinkingEfforts + } + if primary.ThinkingDefaultEffort == "" { + primary.ThinkingDefaultEffort = secondary.ThinkingDefaultEffort + } + if !primary.ThinkingForceEnabled { + primary.ThinkingForceEnabled = secondary.ThinkingForceEnabled + } return normalizeModelCapabilityHints(primary) } @@ -177,6 +193,7 @@ func mergeModelCapabilityHints(primary ModelCapabilityHints, secondary ModelCapa func normalizeModelCapabilityHints(hints ModelCapabilityHints) ModelCapabilityHints { hints.ToolCalling = normalizeModelCapabilityState(string(hints.ToolCalling)) hints.ImageInput = normalizeModelCapabilityState(string(hints.ImageInput)) + hints.Thinking = normalizeModelCapabilityState(string(hints.Thinking)) return hints } @@ -189,16 +206,19 @@ func modelCapabilityHintsFromValue(value any) ModelCapabilityHints { hints := 
ModelCapabilityHints{} for key, item := range raw { - boolValue, ok := item.(bool) - if !ok { - continue - } - switch normalizeKey(key) { case "tool_calling", "tool_call": - hints.ToolCalling = modelCapabilityStateFromBool(boolValue) + if boolValue, ok := item.(bool); ok { + hints.ToolCalling = modelCapabilityStateFromBool(boolValue) + } case "image_input": - hints.ImageInput = modelCapabilityStateFromBool(boolValue) + if boolValue, ok := item.(bool); ok { + hints.ImageInput = modelCapabilityStateFromBool(boolValue) + } + case "thinking": + if boolValue, ok := item.(bool); ok { + hints.Thinking = modelCapabilityStateFromBool(boolValue) + } } } return normalizeModelCapabilityHints(hints) diff --git a/internal/provider/types/model_test.go b/internal/provider/types/model_test.go index 7115896a..8cc0931d 100644 --- a/internal/provider/types/model_test.go +++ b/internal/provider/types/model_test.go @@ -1,6 +1,9 @@ package types -import "testing" +import ( + "reflect" + "testing" +) func TestDescriptorFromRawModel(t *testing.T) { t.Parallel() @@ -151,7 +154,7 @@ func TestDescriptorFromRawModel(t *testing.T) { if !tt.wantOK { return } - if got != tt.want { + if !reflect.DeepEqual(got, tt.want) { t.Fatalf("expected descriptor %+v, got %+v", tt.want, got) } }) @@ -247,7 +250,7 @@ func TestModelCapabilityHintsFromValue(t *testing.T) { if result.ImageInput != ModelCapabilityStateUnsupported { t.Fatalf("expected image input unsupported, got %+v", result) } - if result := modelCapabilityHintsFromValue("not a map"); result != (ModelCapabilityHints{}) { + if result := modelCapabilityHintsFromValue("not a map"); !reflect.DeepEqual(result, ModelCapabilityHints{}) { t.Fatalf("expected empty hints for non-map, got %+v", result) } } diff --git a/internal/provider/types/request.go b/internal/provider/types/request.go index c46730fb..e7a21e4d 100644 --- a/internal/provider/types/request.go +++ b/internal/provider/types/request.go @@ -10,11 +10,19 @@ type SessionAssetReader interface { Open(ctx context.Context, assetID string) (io.ReadCloser, string, error) } +// ThinkingConfig 表示 runtime 向 provider 传递的抽象 thinking 控制指令,由各 adapter 翻译为厂商特定参数。 +type ThinkingConfig struct { + Enabled bool `json:"enabled"` + BudgetTokens int `json:"budget_tokens,omitempty"` + Effort string `json:"effort,omitempty"` +} + // GenerateRequest 是 provider.Generate() 的请求参数。 type GenerateRequest struct { Model string `json:"model"` SystemPrompt string `json:"system_prompt"` Messages []Message `json:"messages"` Tools []ToolSpec `json:"tools,omitempty"` + ThinkingConfig *ThinkingConfig `json:"thinking_config,omitempty"` SessionAssetReader SessionAssetReader `json:"-"` } diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 619d1956..9f3773b4 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -317,6 +317,8 @@ const ( EventUserMessage EventType = "user_message" // EventAgentChunk 表示 assistant 流式文本分片。 EventAgentChunk EventType = "agent_chunk" + // EventThinkingDelta 表示模型思考/推理内容的流式分片。 + EventThinkingDelta EventType = "thinking_delta" // EventAgentDone 表示 assistant 正常结束。 EventAgentDone EventType = "agent_done" // EventToolStart 表示工具开始执行。 @@ -502,6 +504,3 @@ type BashSideEffectPayload struct { PreemptivelyCapturedPaths []string `json:"preemptively_captured_paths,omitempty"` UncoveredPaths []string `json:"uncovered_paths,omitempty"` } - - - diff --git a/internal/runtime/planning.go b/internal/runtime/planning.go index 3a788a99..a906951a 100644 --- a/internal/runtime/planning.go +++ 
b/internal/runtime/planning.go @@ -301,6 +301,10 @@ func applyCurrentPlanRevision(session *agentsession.Session, plan *agentsession. if session == nil || plan == nil { return false } + // 新 revision 覆盖时,仅取消旧 plan 明确引用的非终态 todo + if oldPlan := session.CurrentPlan; oldPlan != nil && oldPlan.Revision < plan.Revision { + agentsession.CancelTodosByIDs(session.Todos, oldPlan.Summary.ActiveTodoIDs) + } session.CurrentPlan = plan session.PlanApprovalPendingFullAlign = false session.PlanCompletionPendingFullReview = false diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 2e35795a..201c34d6 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -364,7 +364,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, controlplane.RunStateVerify); err != nil { return s.handleRunError(err) } - s.updateResumeCheckpoint(ctx, &state, "verify", "completed") + s.updateResumeCheckpoint(ctx, &state, "verify", "completed") acceptanceDecision, err := s.runBeforeCompletionDecisionAcceptance( ctx, &state, @@ -595,11 +595,21 @@ func (s *Service) prepareTurnBudgetSnapshot(ctx context.Context, state *runState promptBudget, budgetSource := s.resolvePromptBudget(ctx, cfg) model := strings.TrimSpace(cfg.CurrentModel) requestMessages := append([]providertypes.Message(nil), builtContext.Messages...) + thinkingCfg, thinkingErr := resolveThinkingConfig( + modelCapabilityHintsForRequest(model, resolvedProvider.Models), + nil, // ThinkingOverride 暂未从 TUI 透传 + s.IsThinkingEnabled(), + ) + if thinkingErr != nil { + return TurnBudgetSnapshot{}, false, thinkingErr + } + request := providertypes.GenerateRequest{ Model: model, SystemPrompt: systemPrompt, Messages: requestMessages, Tools: toolSpecs, + ThinkingConfig: thinkingCfg, SessionAssetReader: s.buildSessionAssetReader(ctx, state.session.ID), } attemptSeq := state.nextAttemptSeq @@ -658,12 +668,33 @@ func (s *Service) callProvider( OnTextDelta: func(text string) { s.emitRunScoped(ctx, EventAgentChunk, state, text) }, + OnThinkingDelta: func(text string) { + s.emitRunScoped(ctx, EventThinkingDelta, state, text) + }, OnToolCallStart: func(payload providertypes.ToolCallStartPayload) { s.emitRunScoped(ctx, EventToolCallThinking, state, payload.Name) }, }) if streamOutcome.err != nil { - return turnProviderOutput{}, streamOutcome.err + // unknown 模型 + ErrThinkingNotSupported → 重试不带 ThinkingConfig + if provider.IsThinkingNotSupportedError(streamOutcome.err) && snapshot.Request.ThinkingConfig != nil { + retryRequest := snapshot.Request + retryRequest.ThinkingConfig = nil + retryOutcome := generateStreamingMessage(ctx, modelProvider, retryRequest, streaming.Hooks{ + OnTextDelta: func(text string) { + s.emitRunScoped(ctx, EventAgentChunk, state, text) + }, + OnToolCallStart: func(payload providertypes.ToolCallStartPayload) { + s.emitRunScoped(ctx, EventToolCallThinking, state, payload.Name) + }, + }) + if retryOutcome.err != nil { + return turnProviderOutput{}, retryOutcome.err + } + streamOutcome = retryOutcome + } else { + return turnProviderOutput{}, streamOutcome.err + } } return turnProviderOutput{ diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index aa726ccc..93593a4c 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -74,12 +74,19 @@ type UserImageInput struct { // PrepareInput 表示进入 runtime 归一化前的领域输入(仅包含文本/图片/会话上下文)。 type PrepareInput struct { - SessionID string - RunID string - Workdir string - Mode string - Text string - Images 
[]UserImageInput + SessionID string + RunID string + Workdir string + Mode string + Text string + Images []UserImageInput + ThinkingOverride *ThinkingOverride `json:"thinking_override,omitempty"` +} + +// ThinkingOverride 表示用户对 thinking 能力的运行时偏好。 +type ThinkingOverride struct { + Enabled *bool `json:"enabled,omitempty"` + Effort string `json:"effort,omitempty"` } // SystemToolInput 描述一次由系统入口触发的确定性工具执行请求。 @@ -168,6 +175,8 @@ type Service struct { activeRunStates map[uint64]*runState permissionAskMapMu sync.Mutex permissionAskLocks map[string]*permissionAskLockEntry + + thinkingEnabled bool } // sessionLockEntry 维护单个会话读写锁及其当前引用计数,用于在无引用时回收 map 项。 @@ -223,6 +232,7 @@ func NewWithFactory( activeRunByID: make(map[string]uint64), activeRunTokenIDs: make(map[uint64]string), activeRunStates: make(map[uint64]*runState), + thinkingEnabled: true, } baseHookExecutor := runtimehooks.NewExecutor(runtimehooks.NewRegistry(), newHookRuntimeEventEmitter(service), runtimehooks.DefaultHookTimeout) baseHookExecutor.SetAsyncResultSink(newHookAsyncResultSink(service)) @@ -250,6 +260,20 @@ func (s *Service) SetSkillsRegistry(registry skills.Registry) { s.skillsRegistry = registry } +// SetThinkingEnabled 设置进程级 thinking 全局开关。 +func (s *Service) SetThinkingEnabled(enabled bool) { + s.runMu.Lock() + s.thinkingEnabled = enabled + s.runMu.Unlock() +} + +// IsThinkingEnabled 返回当前 thinking 全局开关状态。 +func (s *Service) IsThinkingEnabled() bool { + s.runMu.Lock() + defer s.runMu.Unlock() + return s.thinkingEnabled +} + // CancelActiveRun 尝试取消最近一次仍在执行的 Run。 func (s *Service) CancelActiveRun() bool { s.runMu.Lock() diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 22528511..6d54c81f 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -55,6 +55,7 @@ type runState struct { hookNotificationSeen map[string]time.Time hookNotificationOmitted int reportedMissingSkills map[string]struct{} + thinkingOverride *ThinkingOverride } // newRunState 基于持久化会话创建一次运行的内存状态镜像。 diff --git a/internal/runtime/streaming/handler.go b/internal/runtime/streaming/handler.go index 9280b98f..8e7aae62 100644 --- a/internal/runtime/streaming/handler.go +++ b/internal/runtime/streaming/handler.go @@ -9,6 +9,7 @@ import ( // Hooks 描述 runtime 在消费 provider 流时可选的回调挂点。 type Hooks struct { OnTextDelta func(string) + OnThinkingDelta func(string) OnToolCallStart func(providertypes.ToolCallStartPayload) OnMessageDone func(providertypes.MessageDonePayload) } @@ -46,6 +47,15 @@ func HandleEvent(event providertypes.StreamEvent, acc *Accumulator, hooks Hooks) if acc != nil { acc.AccumulateToolCallDelta(payload.Index, payload.ID, payload.ArgumentsDelta) } + case providertypes.StreamEventThinkingDelta: + payload, err := event.ThinkingDeltaValue() + if err != nil { + return err + } + if hooks.OnThinkingDelta != nil { + hooks.OnThinkingDelta(payload.Text) + } + // thinking 不进入 accumulator(不混入 assistant 正文) case providertypes.StreamEventMessageDone: payload, err := event.MessageDoneValue() if err != nil { diff --git a/internal/runtime/thinking.go b/internal/runtime/thinking.go new file mode 100644 index 00000000..d48020bf --- /dev/null +++ b/internal/runtime/thinking.go @@ -0,0 +1,75 @@ +package runtime + +import ( + "fmt" + + providertypes "neo-code/internal/provider/types" +) + +// resolveThinkingConfig 根据用户覆盖、全局开关和模型能力构建 ThinkingConfig。 +// 返回 nil 表示不传递 thinking 控制参数(unsupported 模型)。 +func resolveThinkingConfig( + caps providertypes.ModelCapabilityHints, + override *ThinkingOverride, + globalEnabled bool, +) 
(*providertypes.ThinkingConfig, error) { + thinkingState := caps.Thinking + if thinkingState == "" { + thinkingState = providertypes.ModelCapabilityStateUnknown + } + + switch thinkingState { + case providertypes.ModelCapabilityStateUnsupported: + return nil, nil + case providertypes.ModelCapabilityStateSupported, providertypes.ModelCapabilityStateUnknown: + // 继续处理 + default: + return nil, nil + } + + enabled := globalEnabled + if override != nil && override.Enabled != nil { + enabled = *override.Enabled + } + // ThinkingForceEnabled 模型强制开启 + if caps.ThinkingForceEnabled { + enabled = true + } + + effort := caps.ThinkingDefaultEffort + if override != nil && override.Effort != "" { + effort = override.Effort + // 校验 effort 在列表内 + if len(caps.ThinkingEfforts) > 0 && !containsEffort(caps.ThinkingEfforts, effort) { + return nil, fmt.Errorf("runtime: thinking effort %q not in supported list %v", effort, caps.ThinkingEfforts) + } + } + // 空列表时不为空 effort + if len(caps.ThinkingEfforts) == 0 && effort != "" { + effort = "" + } + + return &providertypes.ThinkingConfig{ + Enabled: enabled, + Effort: effort, + }, nil +} + +func containsEffort(list []string, target string) bool { + for _, v := range list { + if v == target { + return true + } + } + return false +} + +// modelCapabilityHintsForRequest 从 provider 配置的静态模型列表中查找能力提示。 +func modelCapabilityHintsForRequest(model string, models []providertypes.ModelDescriptor) providertypes.ModelCapabilityHints { + for _, m := range models { + if m.ID == model { + return m.CapabilityHints + } + } + return providertypes.ModelCapabilityHints{} +} diff --git a/internal/runtime/thinking_callprovider_test.go b/internal/runtime/thinking_callprovider_test.go new file mode 100644 index 00000000..ee10276f --- /dev/null +++ b/internal/runtime/thinking_callprovider_test.go @@ -0,0 +1,104 @@ +package runtime + +import ( + "context" + "errors" + "testing" + + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +func TestCallProviderRetriesWithoutThinkingConfig(t *testing.T) { + t.Parallel() + + scripted := &scriptedProvider{} + scripted.chatFn = func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + if len(scripted.requests) == 1 { + return errors.Join(provider.ErrThinkingNotSupported, errors.New("upstream rejected thinking")) + } + events <- providertypes.NewTextDeltaStreamEvent("answer") + events <- providertypes.NewMessageDoneStreamEvent("stop", nil) + return nil + } + service := &Service{events: make(chan RuntimeEvent, 8)} + state := newRunState("run-thinking-retry", agentsession.Session{ID: "session-thinking-retry"}) + snapshot := TurnBudgetSnapshot{ + Request: providertypes.GenerateRequest{ + Model: "test-model", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }}, + ThinkingConfig: &providertypes.ThinkingConfig{Enabled: true, Effort: "high"}, + }, + } + + output, err := service.callProvider(context.Background(), &state, snapshot, scripted) + if err != nil { + t.Fatalf("callProvider() error = %v", err) + } + if scripted.callCount != 2 { + t.Fatalf("provider calls = %d, want 2", scripted.callCount) + } + if scripted.requests[0].ThinkingConfig == nil { + t.Fatal("first request should include thinking config") + } + if scripted.requests[1].ThinkingConfig != nil { + t.Fatalf("second request should clear thinking config, got %+v", 
scripted.requests[1].ThinkingConfig) + } + if renderPartsForTest(output.assistant.Parts) != "answer" { + t.Fatalf("unexpected assistant output: %+v", output.assistant) + } +} + +func TestCallProviderEmitsThinkingDeltaEvent(t *testing.T) { + t.Parallel() + + scripted := &scriptedProvider{ + chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + events <- providertypes.NewThinkingDeltaStreamEvent("plan") + events <- providertypes.NewTextDeltaStreamEvent("answer") + events <- providertypes.NewMessageDoneStreamEvent("stop", nil) + return nil + }, + } + service := &Service{events: make(chan RuntimeEvent, 8)} + state := newRunState("run-thinking-event", agentsession.Session{ID: "session-thinking-event"}) + + if _, err := service.callProvider( + context.Background(), + &state, + TurnBudgetSnapshot{Request: providertypes.GenerateRequest{Model: "test-model"}}, + scripted, + ); err != nil { + t.Fatalf("callProvider() error = %v", err) + } + + events := collectThinkingRuntimeEvents(service.events) + if !hasRuntimeEvent(events, EventThinkingDelta, "plan") { + t.Fatalf("expected thinking_delta event, got %+v", events) + } +} + +func collectThinkingRuntimeEvents(ch <-chan RuntimeEvent) []RuntimeEvent { + var events []RuntimeEvent + for { + select { + case event := <-ch: + events = append(events, event) + default: + return events + } + } +} + +func hasRuntimeEvent(events []RuntimeEvent, eventType EventType, payload string) bool { + for _, event := range events { + if event.Type == eventType && event.Payload == payload { + return true + } + } + return false +} diff --git a/internal/runtime/thinking_test.go b/internal/runtime/thinking_test.go new file mode 100644 index 00000000..ec41f053 --- /dev/null +++ b/internal/runtime/thinking_test.go @@ -0,0 +1,141 @@ +package runtime + +import ( + "reflect" + "testing" + + providertypes "neo-code/internal/provider/types" +) + +func TestResolveThinkingConfig(t *testing.T) { + t.Parallel() + + t.Run("unsupported returns nil", func(t *testing.T) { + t.Parallel() + + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateUnsupported, + }, nil, true) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg != nil { + t.Fatalf("expected nil config, got %+v", cfg) + } + }) + + t.Run("unknown follows global toggle", func(t *testing.T) { + t.Parallel() + + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateUnknown, + }, nil, false) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg == nil || cfg.Enabled { + t.Fatalf("expected disabled config, got %+v", cfg) + } + }) + + t.Run("override and effort validation", func(t *testing.T) { + t.Parallel() + + enabled := false + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateSupported, + ThinkingEfforts: []string{"low", "high"}, + ThinkingDefaultEffort: "low", + }, &ThinkingOverride{Enabled: &enabled, Effort: "high"}, true) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg == nil || cfg.Enabled || cfg.Effort != "high" { + t.Fatalf("expected disabled high-effort config, got %+v", cfg) + } + }) + + t.Run("force enabled wins over override", func(t *testing.T) { + t.Parallel() + + enabled := false + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: 
providertypes.ModelCapabilityStateSupported, + ThinkingForceEnabled: true, + }, &ThinkingOverride{Enabled: &enabled}, false) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg == nil || !cfg.Enabled { + t.Fatalf("expected forced enabled config, got %+v", cfg) + } + }) + + t.Run("unsupported effort returns error", func(t *testing.T) { + t.Parallel() + + _, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateSupported, + ThinkingEfforts: []string{"low"}, + }, &ThinkingOverride{Effort: "high"}, true) + if err == nil { + t.Fatal("expected effort validation error") + } + }) + + t.Run("empty effort list clears default effort", func(t *testing.T) { + t.Parallel() + + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateSupported, + ThinkingDefaultEffort: "medium", + }, nil, true) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg == nil || cfg.Effort != "" { + t.Fatalf("expected empty effort, got %+v", cfg) + } + }) +} + +func TestContainsEffortAndHintsLookup(t *testing.T) { + t.Parallel() + + if !containsEffort([]string{"low", "high"}, "high") { + t.Fatal("expected effort to be found") + } + if containsEffort([]string{"low"}, "max") { + t.Fatal("did not expect unknown effort to be found") + } + + models := []providertypes.ModelDescriptor{ + { + ID: "model-a", + CapabilityHints: providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateSupported, + }, + }, + } + hints := modelCapabilityHintsForRequest("model-a", models) + if hints.Thinking != providertypes.ModelCapabilityStateSupported { + t.Fatalf("unexpected hints: %+v", hints) + } + if got := modelCapabilityHintsForRequest("missing", models); !reflect.DeepEqual(got, providertypes.ModelCapabilityHints{}) { + t.Fatalf("expected zero hints for missing model, got %+v", got) + } +} + +func TestServiceThinkingToggle(t *testing.T) { + t.Parallel() + + svc := NewWithFactory(nil, nil, nil, nil, nil) + if !svc.IsThinkingEnabled() { + t.Fatal("expected default thinking toggle to be enabled") + } + + svc.SetThinkingEnabled(false) + if svc.IsThinkingEnabled() { + t.Fatal("expected thinking toggle to be disabled") + } +} diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index 3b79789c..72c47b8f 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -43,12 +43,13 @@ type sqliteSessionRow struct { } type sqliteMessageRow struct { - Role string - PartsJSON string - ToolCallsJSON string - ToolCallID string - IsError bool - ToolMetadataJSON string + Role string + PartsJSON string + ToolCallsJSON string + ToolCallID string + IsError bool + ToolMetadataJSON string + ThinkingMetadataJSON string } const maxSessionDeleteBatchSize = 900 @@ -908,6 +909,9 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV6ToV7(ctx, db); err != nil { + return err + } case 2: if err := migrateSQLiteSchemaV2ToV3(ctx, db); err != nil { return err @@ -921,6 +925,9 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV6ToV7(ctx, db); err != nil { + return err + } case 3: if err := migrateSQLiteSchemaV3ToV4(ctx, db); err != nil { return err @@ -931,6 +938,9 @@ func 
initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV6ToV7(ctx, db); err != nil { + return err + } case 4: if err := migrateSQLiteSchemaV4ToV5(ctx, db); err != nil { return err @@ -938,10 +948,16 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV6ToV7(ctx, db); err != nil { + return err + } case 5: if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV6ToV7(ctx, db); err != nil { + return err + } default: return fmt.Errorf("session: unsupported sqlite schema version %d", userVersion) } @@ -986,6 +1002,7 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { tool_call_id TEXT NOT NULL DEFAULT '', is_error INTEGER NOT NULL DEFAULT 0, tool_metadata_json TEXT NOT NULL DEFAULT '', + thinking_metadata_json TEXT NOT NULL DEFAULT '', created_at_ms INTEGER NOT NULL, PRIMARY KEY(session_id, seq), FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE @@ -1263,6 +1280,24 @@ func migrateSQLiteSchemaV4ToV5(ctx context.Context, db *sql.DB) error { return nil } +// migrateSQLiteSchemaV6ToV7 将 v6 会话库升级到 v7 schema,新增 thinking_metadata_json 列。 +func migrateSQLiteSchemaV6ToV7(ctx context.Context, db *sql.DB) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("session: begin v6->v7 migration tx: %w", err) + } + defer rollbackTx(tx) + + if _, err := tx.ExecContext(ctx, `ALTER TABLE messages ADD COLUMN thinking_metadata_json TEXT NOT NULL DEFAULT ''`); err != nil { + return fmt.Errorf("session: add thinking_metadata_json column: %w", err) + } + + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`PRAGMA user_version=%d`, sqliteSchemaVersion)); err != nil { + return fmt.Errorf("session: set v7 schema version: %w", err) + } + return tx.Commit() +} + func sqliteTableHasColumn(ctx context.Context, tx *sql.Tx, table string, column string) (bool, error) { rows, err := tx.QueryContext(ctx, `PRAGMA table_info(`+table+`)`) if err != nil { @@ -1444,7 +1479,7 @@ WHERE id = ? // loadMessages 查询指定会话的全部消息并按顺序返回。 func loadMessages(ctx context.Context, tx *sql.Tx, sessionID string) ([]sqliteMessageRow, error) { rows, err := tx.QueryContext(ctx, ` -SELECT role, parts_json, tool_calls_json, tool_call_id, is_error, tool_metadata_json +SELECT role, parts_json, tool_calls_json, tool_call_id, is_error, tool_metadata_json, thinking_metadata_json FROM messages WHERE session_id = ? 
ORDER BY seq ASC @@ -1466,6 +1501,7 @@ ORDER BY seq ASC &row.ToolCallID, &row.IsError, &row.ToolMetadataJSON, + &row.ThinkingMetadataJSON, ); err != nil { return nil, fmt.Errorf("session: scan message row: %w", err) } @@ -1558,14 +1594,18 @@ func buildMessageFromRow(row sqliteMessageRow) (providertypes.Message, error) { return providertypes.Message{}, fmt.Errorf("session: decode tool metadata: %w", err) } } - return providertypes.Message{ + msg := providertypes.Message{ Role: row.Role, Parts: parts, ToolCalls: toolCalls, ToolCallID: row.ToolCallID, IsError: row.IsError, ToolMetadata: metadata, - }, nil + } + if row.ThinkingMetadataJSON != "" { + msg.ThinkingMetadata = json.RawMessage(row.ThinkingMetadataJSON) + } + return msg, nil } // currentLastSeq 读取当前会话的最后消息序号。 @@ -1727,8 +1767,8 @@ func insertMessage( ) error { result, err := tx.ExecContext(ctx, ` INSERT INTO messages ( - session_id, seq, role, parts_json, tool_calls_json, tool_call_id, is_error, tool_metadata_json, created_at_ms -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + session_id, seq, role, parts_json, tool_calls_json, tool_call_id, is_error, tool_metadata_json, thinking_metadata_json, created_at_ms +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) `, sessionID, seq, @@ -1738,6 +1778,7 @@ INSERT INTO messages ( message.ToolCallID, boolToInt(message.IsError), mustJSONString(message.ToolMetadata), + string(message.ThinkingMetadata), toUnixMillis(createdAt), ) if err != nil { diff --git a/internal/session/sqlite_store_thinking_test.go b/internal/session/sqlite_store_thinking_test.go new file mode 100644 index 00000000..2fb22733 --- /dev/null +++ b/internal/session/sqlite_store_thinking_test.go @@ -0,0 +1,101 @@ +package session + +import ( + "context" + "database/sql" + "encoding/json" + "path/filepath" + "testing" + + _ "modernc.org/sqlite" + + providertypes "neo-code/internal/provider/types" +) + +func TestSQLiteStoreRoundTripsThinkingMetadata(t *testing.T) { + t.Parallel() + + store := newTestStore(t) + ctx := context.Background() + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "thinking_meta", Title: "thinking"}) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + + thinking := json.RawMessage(`{"reasoning_content":"plan"}`) + if err := store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: session.ID, + Messages: []providertypes.Message{{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("answer")}, + ThinkingMetadata: thinking, + }}, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) + } + + loaded, err := store.LoadSession(ctx, session.ID) + if err != nil { + t.Fatalf("LoadSession() error = %v", err) + } + if len(loaded.Messages) != 1 { + t.Fatalf("expected one message, got %d", len(loaded.Messages)) + } + if string(loaded.Messages[0].ThinkingMetadata) != string(thinking) { + t.Fatalf("thinking metadata = %s, want %s", loaded.Messages[0].ThinkingMetadata, thinking) + } +} + +func TestMigrateSQLiteSchemaV6ToV7AddsThinkingMetadataColumn(t *testing.T) { + t.Parallel() + + db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), "schema-v6.db")) + if err != nil { + t.Fatalf("sql.Open() error = %v", err) + } + defer db.Close() + + statements := []string{ + `CREATE TABLE messages ( + session_id TEXT NOT NULL, + seq INTEGER NOT NULL, + role TEXT NOT NULL, + parts_json TEXT NOT NULL, + tool_calls_json TEXT NOT NULL DEFAULT '', + tool_call_id TEXT NOT NULL DEFAULT '', + is_error INTEGER NOT NULL DEFAULT 0, + tool_metadata_json TEXT NOT NULL 
DEFAULT '', + created_at_ms INTEGER NOT NULL, + PRIMARY KEY(session_id, seq) + )`, + `PRAGMA user_version=6`, + } + for _, statement := range statements { + if _, err := db.Exec(statement); err != nil { + t.Fatalf("Exec(%q) error = %v", statement, err) + } + } + + if err := migrateSQLiteSchemaV6ToV7(context.Background(), db); err != nil { + t.Fatalf("migrateSQLiteSchemaV6ToV7() error = %v", err) + } + hasColumn, err := sqliteTableHasColumn(context.Background(), mustBeginTx(t, db), "messages", "thinking_metadata_json") + if err != nil { + t.Fatalf("sqliteTableHasColumn() error = %v", err) + } + if !hasColumn { + t.Fatal("expected thinking_metadata_json column after migration") + } +} + +func mustBeginTx(t *testing.T, db *sql.DB) *sql.Tx { + t.Helper() + tx, err := db.BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx() error = %v", err) + } + t.Cleanup(func() { + _ = tx.Rollback() + }) + return tx +} diff --git a/internal/session/store.go b/internal/session/store.go index 6ec1ac51..deb27de6 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -14,7 +14,7 @@ import ( const ( sessionDatabaseFileName = "session.db" assetsDirName = "assets" - sqliteSchemaVersion = 6 + sqliteSchemaVersion = 7 // MaxSessionMessages 定义单个会话允许持久化的最大消息数,超出时自动裁剪最旧消息。 MaxSessionMessages = 8192 diff --git a/internal/session/todo.go b/internal/session/todo.go index 879988a4..1c1088c2 100644 --- a/internal/session/todo.go +++ b/internal/session/todo.go @@ -856,3 +856,38 @@ type TodoContentCheck struct { Artifact string `json:"artifact,omitempty"` Contains []string `json:"contains,omitempty"` } + +// CancelNonTerminalTodos 将列表中所有非终态(pending/in_progress/blocked)的 todo 标记为 canceled。 +func CancelNonTerminalTodos(todos []TodoItem) { + now := time.Now().UTC() + for i := range todos { + switch todos[i].Status { + case TodoStatusPending, TodoStatusInProgress, TodoStatusBlocked: + todos[i].Status = TodoStatusCanceled + todos[i].UpdatedAt = now + } + } +} + +// CancelTodosByIDs 将指定 ID 列表中非终态的 todo 标记为 canceled。 +// 仅取消 plan 明确引用的 todo,不影响用户手动创建或其他来源的 todo。 +func CancelTodosByIDs(todos []TodoItem, ids []string) { + if len(ids) == 0 { + return + } + target := make(map[string]struct{}, len(ids)) + for _, id := range ids { + target[strings.TrimSpace(id)] = struct{}{} + } + now := time.Now().UTC() + for i := range todos { + if _, ok := target[todos[i].ID]; !ok { + continue + } + switch todos[i].Status { + case TodoStatusPending, TodoStatusInProgress, TodoStatusBlocked: + todos[i].Status = TodoStatusCanceled + todos[i].UpdatedAt = now + } + } +} diff --git a/internal/session/todo_cancel_test.go b/internal/session/todo_cancel_test.go new file mode 100644 index 00000000..eae6ab6a --- /dev/null +++ b/internal/session/todo_cancel_test.go @@ -0,0 +1,67 @@ +package session + +import ( + "testing" + "time" +) + +func TestCancelNonTerminalTodos(t *testing.T) { + t.Parallel() + + todos := []TodoItem{ + {ID: "p", Status: TodoStatusPending}, + {ID: "i", Status: TodoStatusInProgress}, + {ID: "b", Status: TodoStatusBlocked}, + {ID: "c", Status: TodoStatusCompleted}, + } + + CancelNonTerminalTodos(todos) + + for _, id := range []string{"p", "i", "b"} { + item := todos[indexTodoByID(t, todos, id)] + if item.Status != TodoStatusCanceled { + t.Fatalf("todo %q status = %q, want canceled", id, item.Status) + } + if item.UpdatedAt.IsZero() || time.Since(item.UpdatedAt) > time.Minute { + t.Fatalf("todo %q missing updated_at: %+v", id, item) + } + } + if todos[indexTodoByID(t, todos, "c")].Status != 
TodoStatusCompleted { + t.Fatalf("terminal todo should stay completed: %+v", todos) + } +} + +func TestCancelTodosByIDs(t *testing.T) { + t.Parallel() + + todos := []TodoItem{ + {ID: "keep", Status: TodoStatusPending}, + {ID: "cancel", Status: TodoStatusBlocked}, + {ID: "done", Status: TodoStatusCompleted}, + } + + CancelTodosByIDs(todos, []string{" cancel ", "done"}) + + if todos[indexTodoByID(t, todos, "cancel")].Status != TodoStatusCanceled { + t.Fatalf("expected selected non-terminal todo to be canceled: %+v", todos) + } + if todos[indexTodoByID(t, todos, "keep")].Status != TodoStatusPending { + t.Fatalf("expected unmatched todo to stay pending: %+v", todos) + } + if todos[indexTodoByID(t, todos, "done")].Status != TodoStatusCompleted { + t.Fatalf("expected terminal todo to stay completed: %+v", todos) + } + + CancelTodosByIDs(todos, nil) +} + +func indexTodoByID(t *testing.T, todos []TodoItem, id string) int { + t.Helper() + for i := range todos { + if todos[i].ID == id { + return i + } + } + t.Fatalf("todo %q not found", id) + return -1 +} diff --git a/internal/tools/mode_filter.go b/internal/tools/mode_filter.go index 1c331fdb..99d1cdec 100644 --- a/internal/tools/mode_filter.go +++ b/internal/tools/mode_filter.go @@ -14,7 +14,8 @@ func isReadOnlyVisibleTool(name string) bool { ToolNameFilesystemGlob, ToolNameWebFetch, ToolNameMemoRecall, - ToolNameMemoList: + ToolNameMemoList, + ToolNameTodoWrite: return true default: return false
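
The GLM and Qwen adapters in this patch share one serialization pattern for vendor-specific thinking parameters: marshal the portable chat-completions payload, reopen it as map[string]any, graft the extra top-level keys, and marshal again. The standalone sketch below isolates that pattern using only the standard library; injectVendorParams, the force/defaults split, and the sample payload are illustrative stand-ins, not the project's chatcompletions.Request API. GLM-style flags (enable_thinking, chat_template_kwargs) belong in force because injectGLMParams always writes them, while Qwen's recommended sampling values belong in defaults because injectQwenParams only fills temperature and top_p when the caller left them unset.

package main

import (
	"encoding/json"
	"fmt"
)

// injectVendorParams grafts vendor keys onto an already-serialized request
// body. "force" entries are always written (like enable_thinking); "defaults"
// entries are written only when absent (like Qwen's temperature/top_p).
func injectVendorParams(body []byte, force, defaults map[string]any) ([]byte, error) {
	var raw map[string]any
	if err := json.Unmarshal(body, &raw); err != nil {
		return nil, err
	}
	for k, v := range force {
		raw[k] = v
	}
	for k, v := range defaults {
		if _, ok := raw[k]; !ok {
			raw[k] = v
		}
	}
	return json.Marshal(raw)
}

func main() {
	body := []byte(`{"model":"qwen-max","stream":true,"temperature":0.3}`)
	out, err := injectVendorParams(body,
		map[string]any{"enable_thinking": true},
		map[string]any{"temperature": 0.6, "top_p": 0.95},
	)
	if err != nil {
		panic(err)
	}
	// temperature stays at the caller's 0.3; top_p picks up the 0.95 default.
	fmt.Println(string(out))
}

The map round-trip costs one extra unmarshal/marshal per request, but it keeps chatcompletions.BuildRequest vendor-agnostic instead of widening the shared request struct with per-vendor fields.

callProvider in run.go layers a single-shot fallback on top of this: when the stream fails with a thinking-not-supported error and the request actually carried a ThinkingConfig, it retries once with the config cleared. A minimal sketch of that control flow, with local stand-ins for provider.ErrThinkingNotSupported and the request type (the project classifies the error via provider.IsThinkingNotSupportedError; plain errors.Is is used here):

package main

import (
	"errors"
	"fmt"
)

var errThinkingNotSupported = errors.New("thinking not supported") // stand-in for provider.ErrThinkingNotSupported

type thinkingConfig struct{ Enabled bool } // stand-in for providertypes.ThinkingConfig

type generateRequest struct { // stand-in for providertypes.GenerateRequest
	Model    string
	Thinking *thinkingConfig
}

// generate simulates an upstream that rejects any request asking for thinking.
func generate(req generateRequest) error {
	if req.Thinking != nil && req.Thinking.Enabled {
		return fmt.Errorf("upstream: %w", errThinkingNotSupported)
	}
	return nil
}

// generateWithThinkingFallback mirrors the retry in callProvider: exactly one
// retry, only for this error class, and only when thinking was requested.
func generateWithThinkingFallback(req generateRequest) error {
	err := generate(req)
	if err == nil || req.Thinking == nil || !errors.Is(err, errThinkingNotSupported) {
		return err
	}
	retry := req // copy the snapshot so clearing Thinking does not mutate it
	retry.Thinking = nil
	return generate(retry)
}

func main() {
	req := generateRequest{Model: "glm-5.1", Thinking: &thinkingConfig{Enabled: true}}
	fmt.Println("with fallback:", generateWithThinkingFallback(req)) // prints <nil>
}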