From d850e7b8bf332b44af7e070683e3f67eb62918e0 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Tue, 5 May 2026 15:06:22 +0000 Subject: [PATCH] test: increase thinking coverage Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- .../provider/conformance/conformance_test.go | 418 +++++++++++++++++- internal/provider/minimax/provider_test.go | 79 ++++ .../chatcompletions/adapter_test.go | 33 ++ .../chatcompletions/request_test.go | 40 ++ internal/runtime/thinking_test.go | 141 ++++++ 5 files changed, 699 insertions(+), 12 deletions(-) create mode 100644 internal/runtime/thinking_test.go diff --git a/internal/provider/conformance/conformance_test.go b/internal/provider/conformance/conformance_test.go index bf4235d9..976eb9c8 100644 --- a/internal/provider/conformance/conformance_test.go +++ b/internal/provider/conformance/conformance_test.go @@ -12,8 +12,13 @@ import ( "neo-code/internal/provider" "neo-code/internal/provider/anthropic" + "neo-code/internal/provider/deepseek" "neo-code/internal/provider/gemini" + "neo-code/internal/provider/mimo" + "neo-code/internal/provider/minimax" "neo-code/internal/provider/openaicompat" + "neo-code/internal/provider/openaicompat/glm" + "neo-code/internal/provider/openaicompat/qwen" providertypes "neo-code/internal/provider/types" ) @@ -27,6 +32,7 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { streamBody string expectReason string expectTokens int + expectedOrder []providertypes.StreamEventType }{ { name: "openaicompat_chat_completions", @@ -50,6 +56,12 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { "data: [DONE]\n\n", expectReason: "stop", expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, }, { name: "gemini_native", @@ -71,6 +83,12 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { "data: {\"candidates\":[{\"index\":0,\"finishReason\":\"STOP\",\"content\":{\"parts\":[{\"functionCall\":{\"name\":\"filesystem_read_file\",\"args\":{\"path\":\"README.md\"}}}]}}]}\n\n", expectReason: "stop", expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, }, { name: "anthropic_messages", @@ -100,6 +118,128 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":6}}\n\n", expectReason: "tool_use", expectTokens: 10, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, + }, + { + name: "deepseek_chat_completions_with_reasoning", + driver: deepseek.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: deepseek.DriverName, + Driver: provider.DriverDeepSeek, + BaseURL: baseURL, + DefaultModel: "deepseek-v4", + APIKeyEnv: "DEEPSEEK_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + expectedPath: "/v1/chat/completions", + expectedHeader: "Authorization", + streamBody: "data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"plan\",\"content\":\"Hello \"}}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n" + + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"filesystem_read_file\"}}]}}]}\n" + + "data: {\"choices\":[{\"finish_reason\":\"stop\",\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"arguments\":\"{\\\"path\\\":\\\"README.md\\\"}\"}}]}}]}\n" + + "data: [DONE]\n\n", + expectReason: "stop", + expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventThinkingDelta, + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, + }, + { + name: "mimo_chat_completions_with_reasoning", + driver: mimo.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: mimo.DriverName, + Driver: provider.DriverMiMo, + BaseURL: baseURL, + DefaultModel: "mimo-v2.5", + APIKeyEnv: "MIMO_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + expectedPath: "/chat/completions", + expectedHeader: "Authorization", + streamBody: "data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"plan\",\"content\":\"Hello \"}}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n" + + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"filesystem_read_file\"}}]}}]}\n" + + "data: {\"choices\":[{\"finish_reason\":\"stop\",\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"arguments\":\"{\\\"path\\\":\\\"README.md\\\"}\"}}]}}]}\n" + + "data: [DONE]\n\n", + expectReason: "stop", + expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventThinkingDelta, + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, + }, + { + name: "qwen_chat_completions_with_reasoning", + driver: qwen.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: qwen.DriverName, + Driver: provider.DriverQwen, + BaseURL: baseURL, + DefaultModel: "qwen3", + APIKeyEnv: "QWEN_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + expectedPath: "/chat/completions", + expectedHeader: "Authorization", + streamBody: "data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"plan\",\"content\":\"Hello \"}}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n" + + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"filesystem_read_file\"}}]}}]}\n" + + "data: {\"choices\":[{\"finish_reason\":\"stop\",\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"arguments\":\"{\\\"path\\\":\\\"README.md\\\"}\"}}]}}]}\n" + + "data: [DONE]\n\n", + expectReason: "stop", + expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventThinkingDelta, + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, + }, + { + name: "glm_chat_completions_with_reasoning", + driver: glm.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: glm.DriverName, + Driver: provider.DriverGLM, + BaseURL: baseURL, + DefaultModel: "glm-5.1", + APIKeyEnv: "GLM_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + expectedPath: "/chat/completions", + expectedHeader: "Authorization", + streamBody: "data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"plan\",\"content\":\"Hello \"}}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n" + + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"filesystem_read_file\"}}]}}]}\n" + + "data: {\"choices\":[{\"finish_reason\":\"stop\",\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"arguments\":\"{\\\"path\\\":\\\"README.md\\\"}\"}}]}}]}\n" + + "data: [DONE]\n\n", + expectReason: "stop", + expectTokens: 7, + expectedOrder: []providertypes.StreamEventType{ + providertypes.StreamEventThinkingDelta, + providertypes.StreamEventTextDelta, + providertypes.StreamEventToolCallStart, + providertypes.StreamEventToolCallDelta, + providertypes.StreamEventMessageDone, + }, }, } @@ -131,22 +271,16 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { } drained := drainEvents(events) - if len(drained) != 4 { - t.Fatalf("expected 4 events, got %d (%+v)", len(drained), drained) - } - expectedOrder := []providertypes.StreamEventType{ - providertypes.StreamEventTextDelta, - providertypes.StreamEventToolCallStart, - providertypes.StreamEventToolCallDelta, - providertypes.StreamEventMessageDone, + if len(drained) != len(tt.expectedOrder) { + t.Fatalf("expected %d events, got %d (%+v)", len(tt.expectedOrder), len(drained), drained) } - for i := range expectedOrder { - if drained[i].Type != expectedOrder[i] { - t.Fatalf("unexpected event order at index %d, expected %q got %q", i, expectedOrder[i], drained[i].Type) + for i := range tt.expectedOrder { + if drained[i].Type != tt.expectedOrder[i] { + t.Fatalf("unexpected event order at index %d, expected %q got %q", i, tt.expectedOrder[i], drained[i].Type) } } - done, doneErr := drained[3].MessageDoneValue() + done, doneErr := drained[len(drained)-1].MessageDoneValue() if doneErr != nil { t.Fatalf("MessageDoneValue() error = %v", doneErr) } @@ -186,6 +320,23 @@ func TestDiscoverContractAcrossDrivers(t *testing.T) { expectedHeader: "Authorization", responseBody: `{"data":[{"id":"gpt-4.1","name":"GPT 4.1"}]}`, }, + { + name: "deepseek_discover", + driver: deepseek.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: deepseek.DriverName, + Driver: provider.DriverDeepSeek, + BaseURL: baseURL, + APIKeyEnv: "DEEPSEEK_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: `{"data":[{"id":"deepseek-v4","name":"DeepSeek V4"}]}`, + }, { name: "gemini_discover", driver: gemini.Driver(), @@ -220,6 +371,74 @@ func TestDiscoverContractAcrossDrivers(t *testing.T) { expectedHeader: "x-api-key", responseBody: `{"data":[{"id":"claude-3-7-sonnet","display_name":"Claude 3.7 Sonnet"}],"has_more":false}`, }, + { + name: "mimo_discover", + driver: mimo.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: mimo.DriverName, + Driver: provider.DriverMiMo, + BaseURL: baseURL, + APIKeyEnv: "MIMO_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: `{"data":[{"id":"mimo-v2.5","name":"MiMo V2.5"}]}`, + }, + { + name: "minimax_discover", + driver: minimax.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: minimax.DriverName, + Driver: provider.DriverMiniMax, + BaseURL: baseURL, + APIKeyEnv: "MINIMAX_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: `{"data":[{"id":"minimax-m2.7","name":"MiniMax M2.7"}]}`, + }, + { + name: "qwen_discover", + driver: qwen.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: qwen.DriverName, + Driver: provider.DriverQwen, + BaseURL: baseURL, + APIKeyEnv: "QWEN_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: `{"data":[{"id":"qwen3","name":"Qwen 3"}]}`, + }, + { + name: "glm_discover", + driver: glm.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: glm.DriverName, + Driver: provider.DriverGLM, + BaseURL: baseURL, + APIKeyEnv: "GLM_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + DiscoveryEndpointPath: "/models", + } + }, + expectedPath: "/models", + expectedHeader: "Authorization", + responseBody: `{"data":[{"id":"glm-5.1","name":"GLM 5.1"}]}`, + }, } for _, tt := range testCases { @@ -272,6 +491,21 @@ func TestGenerateErrorClassificationAcrossDrivers(t *testing.T) { path: "/chat/completions", body: `{"error":{"message":"invalid api key"}}`, }, + { + name: "deepseek_auth_error", + driver: deepseek.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverDeepSeek, + BaseURL: baseURL, + DefaultModel: "deepseek-v4", + APIKeyEnv: "DEEPSEEK_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/v1/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, { name: "gemini_auth_error", driver: gemini.Driver(), @@ -304,6 +538,66 @@ func TestGenerateErrorClassificationAcrossDrivers(t *testing.T) { path: "/v1/messages", body: `{"error":{"message":"invalid x-api-key"}}`, }, + { + name: "mimo_auth_error", + driver: mimo.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverMiMo, + BaseURL: baseURL, + DefaultModel: "mimo-v2.5", + APIKeyEnv: "MIMO_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, + { + name: "minimax_auth_error", + driver: minimax.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverMiniMax, + BaseURL: baseURL + "/chat/completions", + DefaultModel: "minimax-m2.7", + APIKeyEnv: "MINIMAX_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, + { + name: "qwen_auth_error", + driver: qwen.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverQwen, + BaseURL: baseURL, + DefaultModel: "qwen3", + APIKeyEnv: "QWEN_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, + { + name: "glm_auth_error", + driver: glm.Driver(), + buildConfig: func(baseURL string) provider.RuntimeConfig { + return provider.RuntimeConfig{ + Driver: provider.DriverGLM, + BaseURL: baseURL, + DefaultModel: "glm-5.1", + APIKeyEnv: "GLM_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + path: "/chat/completions", + body: `{"error":{"message":"invalid api key"}}`, + }, } for _, tt := range testCases { @@ -346,6 +640,106 @@ func TestGenerateErrorClassificationAcrossDrivers(t *testing.T) { } } +func TestEstimateInputTokensContractAcrossDrivers(t *testing.T) { + testCases := []struct { + name string + driver provider.DriverDefinition + buildConfig func() provider.RuntimeConfig + }{ + { + name: "deepseek_estimate", + driver: deepseek.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: deepseek.DriverName, + Driver: provider.DriverDeepSeek, + BaseURL: "https://api.deepseek.example", + DefaultModel: "deepseek-v4", + APIKeyEnv: "DEEPSEEK_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + { + name: "mimo_estimate", + driver: mimo.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: mimo.DriverName, + Driver: provider.DriverMiMo, + BaseURL: "https://api.mimo.example", + DefaultModel: "mimo-v2.5", + APIKeyEnv: "MIMO_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + { + name: "minimax_estimate", + driver: minimax.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: minimax.DriverName, + Driver: provider.DriverMiniMax, + BaseURL: "https://api.minimax.example/chat/completions", + DefaultModel: "minimax-m2.7", + APIKeyEnv: "MINIMAX_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + { + name: "qwen_estimate", + driver: qwen.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: qwen.DriverName, + Driver: provider.DriverQwen, + BaseURL: "https://api.qwen.example", + DefaultModel: "qwen3", + APIKeyEnv: "QWEN_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + { + name: "glm_estimate", + driver: glm.Driver(), + buildConfig: func() provider.RuntimeConfig { + return provider.RuntimeConfig{ + Name: glm.DriverName, + Driver: provider.DriverGLM, + BaseURL: "https://api.glm.example", + DefaultModel: "glm-5.1", + APIKeyEnv: "GLM_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + } + }, + }, + } + + for _, tt := range testCases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + p, err := tt.driver.Build(context.Background(), tt.buildConfig()) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + estimate, err := p.EstimateInputTokens(context.Background(), generateRequestWithAssets()) + if err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if estimate.EstimatedInputTokens <= 0 { + t.Fatalf("expected positive token estimate, got %+v", estimate) + } + if estimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("expected local estimate source, got %+v", estimate) + } + }) + } +} + func generateRequestWithAssets() providertypes.GenerateRequest { return providertypes.GenerateRequest{ Messages: []providertypes.Message{ diff --git a/internal/provider/minimax/provider_test.go b/internal/provider/minimax/provider_test.go index 7a90fa40..631aea5a 100644 --- a/internal/provider/minimax/provider_test.go +++ b/internal/provider/minimax/provider_test.go @@ -1,8 +1,12 @@ package minimax import ( + "context" "encoding/json" + "strings" "testing" + + providertypes "neo-code/internal/provider/types" ) func TestInjectMiniMaxParams(t *testing.T) { @@ -55,3 +59,78 @@ func TestExtractThinkContent_NoTags(t *testing.T) { t.Fatalf("expected empty, got %q", result) } } + +func TestConsumeMiniMaxStream(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_details":"internal plan","content":"visible answer"}}],"usage":{"total_tokens":9}}`, + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 4) + if err := ConsumeMiniMaxStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("ConsumeMiniMaxStream() error = %v", err) + } + + drained := drainMiniMaxEvents(events) + if len(drained) != 3 { + t.Fatalf("expected 3 events, got %d (%+v)", len(drained), drained) + } + thinking, err := drained[0].ThinkingDeltaValue() + if err != nil || thinking.Text != "internal plan" { + t.Fatalf("expected thinking delta, got err=%v event=%+v", err, drained[0]) + } + text, err := drained[1].TextDeltaValue() + if err != nil || text.Text != "visible answer" { + t.Fatalf("expected text delta, got err=%v event=%+v", err, drained[1]) + } + done, err := drained[2].MessageDoneValue() + if err != nil { + t.Fatalf("expected message done, got err=%v", err) + } + if done.Usage == nil || done.Usage.TotalTokens != 9 { + t.Fatalf("unexpected usage payload: %+v", done.Usage) + } +} + +func TestConsumeMiniMaxStreamExtractsThinkTagsFromContent(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"choices":[{"delta":{"content":"internal planvisible answer"},"finish_reason":"stop"}],"usage":{"total_tokens":5}}`, + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 4) + if err := ConsumeMiniMaxStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("ConsumeMiniMaxStream() error = %v", err) + } + + drained := drainMiniMaxEvents(events) + if len(drained) != 3 { + t.Fatalf("expected 3 events, got %d (%+v)", len(drained), drained) + } + thinking, err := drained[0].ThinkingDeltaValue() + if err != nil || thinking.Text != "internal plan" { + t.Fatalf("expected extracted think tag, got err=%v event=%+v", err, drained[0]) + } + text, err := drained[1].TextDeltaValue() + if err != nil || text.Text != "visible answer" { + t.Fatalf("expected think tags removed from text, got err=%v event=%+v", err, drained[1]) + } +} + +func drainMiniMaxEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { + drained := make([]providertypes.StreamEvent, 0, len(events)) + for { + select { + case event := <-events: + drained = append(drained, event) + default: + return drained + } + } +} diff --git a/internal/provider/openaicompat/chatcompletions/adapter_test.go b/internal/provider/openaicompat/chatcompletions/adapter_test.go index d1eb195c..dd70b61d 100644 --- a/internal/provider/openaicompat/chatcompletions/adapter_test.go +++ b/internal/provider/openaicompat/chatcompletions/adapter_test.go @@ -104,6 +104,39 @@ func TestConsumeStreamParsesMultilineDataEvent(t *testing.T) { } } +func TestConsumeStreamEmitsThinkingDeltaFromReasoningFields(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"choices":[{"delta":{"reasoning_content":"plan","content":"answer"}}]}`, + `data: {"choices":[{"delta":{"reasoning":"fallback reasoning"}}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`, + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 8) + if err := ConsumeStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("ConsumeStream() error = %v", err) + } + + drained := drainEvents(events) + if len(drained) != 4 { + t.Fatalf("expected 4 events, got %d (%+v)", len(drained), drained) + } + first, err := drained[0].ThinkingDeltaValue() + if err != nil || first.Text != "plan" { + t.Fatalf("expected reasoning_content thinking delta, got err=%v event=%+v", err, drained[0]) + } + text, err := drained[1].TextDeltaValue() + if err != nil || text.Text != "answer" { + t.Fatalf("expected text delta, got err=%v event=%+v", err, drained[1]) + } + second, err := drained[2].ThinkingDeltaValue() + if err != nil || second.Text != "fallback reasoning" { + t.Fatalf("expected reasoning fallback thinking delta, got err=%v event=%+v", err, drained[2]) + } +} + func TestConsumeStreamEOFWithoutDoneAndWithoutFinishReason(t *testing.T) { t.Parallel() diff --git a/internal/provider/openaicompat/chatcompletions/request_test.go b/internal/provider/openaicompat/chatcompletions/request_test.go index 57a7905d..8930df86 100644 --- a/internal/provider/openaicompat/chatcompletions/request_test.go +++ b/internal/provider/openaicompat/chatcompletions/request_test.go @@ -76,6 +76,46 @@ func TestBuildRequestUsesDefaultModelAndNormalizesTools(t *testing.T) { } } +func TestBuildRequestThinkingConfigAndContinuity(t *testing.T) { + t.Parallel() + + payload, err := BuildRequest(context.Background(), provider.RuntimeConfig{DefaultModel: "gpt-default"}, providertypes.GenerateRequest{ + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + ThinkingMetadata: []byte(`{"reasoning":"step by step"}`), + }, + }, + ThinkingConfig: &providertypes.ThinkingConfig{ + Enabled: true, + Effort: "high", + }, + }) + if err != nil { + t.Fatalf("BuildRequest() error = %v", err) + } + if payload.ReasoningEffort != "high" { + t.Fatalf("expected reasoning effort to be preserved, got %q", payload.ReasoningEffort) + } + if len(payload.Messages) != 1 || payload.Messages[0].ReasoningContent != "step by step" { + t.Fatalf("expected reasoning continuity content, got %+v", payload.Messages) + } + + disabled, err := BuildRequest(context.Background(), provider.RuntimeConfig{DefaultModel: "gpt-default"}, providertypes.GenerateRequest{ + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, + }, + ThinkingConfig: &providertypes.ThinkingConfig{Enabled: false}, + }) + if err != nil { + t.Fatalf("BuildRequest() disabled error = %v", err) + } + if disabled.ReasoningEffort != "none" { + t.Fatalf("expected disabled reasoning effort marker, got %q", disabled.ReasoningEffort) + } +} + func TestBuildRequestAndToOpenAIMessageErrors(t *testing.T) { t.Parallel() diff --git a/internal/runtime/thinking_test.go b/internal/runtime/thinking_test.go new file mode 100644 index 00000000..ec41f053 --- /dev/null +++ b/internal/runtime/thinking_test.go @@ -0,0 +1,141 @@ +package runtime + +import ( + "reflect" + "testing" + + providertypes "neo-code/internal/provider/types" +) + +func TestResolveThinkingConfig(t *testing.T) { + t.Parallel() + + t.Run("unsupported returns nil", func(t *testing.T) { + t.Parallel() + + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateUnsupported, + }, nil, true) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg != nil { + t.Fatalf("expected nil config, got %+v", cfg) + } + }) + + t.Run("unknown follows global toggle", func(t *testing.T) { + t.Parallel() + + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateUnknown, + }, nil, false) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg == nil || cfg.Enabled { + t.Fatalf("expected disabled config, got %+v", cfg) + } + }) + + t.Run("override and effort validation", func(t *testing.T) { + t.Parallel() + + enabled := false + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateSupported, + ThinkingEfforts: []string{"low", "high"}, + ThinkingDefaultEffort: "low", + }, &ThinkingOverride{Enabled: &enabled, Effort: "high"}, true) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg == nil || cfg.Enabled || cfg.Effort != "high" { + t.Fatalf("expected disabled high-effort config, got %+v", cfg) + } + }) + + t.Run("force enabled wins over override", func(t *testing.T) { + t.Parallel() + + enabled := false + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateSupported, + ThinkingForceEnabled: true, + }, &ThinkingOverride{Enabled: &enabled}, false) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg == nil || !cfg.Enabled { + t.Fatalf("expected forced enabled config, got %+v", cfg) + } + }) + + t.Run("unsupported effort returns error", func(t *testing.T) { + t.Parallel() + + _, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateSupported, + ThinkingEfforts: []string{"low"}, + }, &ThinkingOverride{Effort: "high"}, true) + if err == nil { + t.Fatal("expected effort validation error") + } + }) + + t.Run("empty effort list clears default effort", func(t *testing.T) { + t.Parallel() + + cfg, err := resolveThinkingConfig(providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateSupported, + ThinkingDefaultEffort: "medium", + }, nil, true) + if err != nil { + t.Fatalf("resolveThinkingConfig() error = %v", err) + } + if cfg == nil || cfg.Effort != "" { + t.Fatalf("expected empty effort, got %+v", cfg) + } + }) +} + +func TestContainsEffortAndHintsLookup(t *testing.T) { + t.Parallel() + + if !containsEffort([]string{"low", "high"}, "high") { + t.Fatal("expected effort to be found") + } + if containsEffort([]string{"low"}, "max") { + t.Fatal("did not expect unknown effort to be found") + } + + models := []providertypes.ModelDescriptor{ + { + ID: "model-a", + CapabilityHints: providertypes.ModelCapabilityHints{ + Thinking: providertypes.ModelCapabilityStateSupported, + }, + }, + } + hints := modelCapabilityHintsForRequest("model-a", models) + if hints.Thinking != providertypes.ModelCapabilityStateSupported { + t.Fatalf("unexpected hints: %+v", hints) + } + if got := modelCapabilityHintsForRequest("missing", models); !reflect.DeepEqual(got, providertypes.ModelCapabilityHints{}) { + t.Fatalf("expected zero hints for missing model, got %+v", got) + } +} + +func TestServiceThinkingToggle(t *testing.T) { + t.Parallel() + + svc := NewWithFactory(nil, nil, nil, nil, nil) + if !svc.IsThinkingEnabled() { + t.Fatal("expected default thinking toggle to be enabled") + } + + svc.SetThinkingEnabled(false) + if svc.IsThinkingEnabled() { + t.Fatal("expected thinking toggle to be disabled") + } +}