diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index f7db2c17..1e176a16 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" "strings" + "sync" anthropic "github.com/anthropics/anthropic-sdk-go" @@ -25,6 +26,14 @@ type toolCallState struct { // Provider 封装 Anthropic messages 协议的请求发送与流式解析。 type Provider struct { cfg provider.RuntimeConfig + + mu sync.Mutex + prepared *preparedRequest +} + +type preparedRequest struct { + signature string + params anthropic.MessageNewParams } // EstimateInputTokens 基于 Anthropic 最终请求结构做本地输入 token 估算。 @@ -40,10 +49,11 @@ func (p *Provider) EstimateInputTokens( if err != nil { return providertypes.BudgetEstimate{}, err } + p.storePreparedRequest(provider.BuildGenerateRequestSignature(req), params) return providertypes.BudgetEstimate{ EstimatedInputTokens: tokens, EstimateSource: provider.EstimateSourceLocal, - GatePolicy: provider.EstimateGateGateable, + GatePolicy: provider.EstimateGateAdvisory, }, nil } @@ -57,9 +67,13 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { // Generate 发起 Anthropic 流式请求,并将 typed stream 转为统一事件。 func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { - params, err := BuildRequest(ctx, p.cfg, req) - if err != nil { - return err + params, ok := p.takePreparedRequest(provider.BuildGenerateRequestSignature(req)) + if !ok { + var err error + params, err = BuildRequest(ctx, p.cfg, req) + if err != nil { + return err + } } client, err := newSDKClient(p.cfg) @@ -185,6 +199,31 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque return provider.EmitMessageDone(ctx, events, finishReason, &usage) } +// storePreparedRequest 缓存估算阶段已构建的 Anthropic 请求,供同轮发送复用。 +func (p *Provider) storePreparedRequest(signature string, params anthropic.MessageNewParams) { + p.mu.Lock() + defer p.mu.Unlock() + p.prepared = &preparedRequest{ + signature: strings.TrimSpace(signature), + params: params, + } +} + +// takePreparedRequest 读取并消费匹配签名的预构建请求,避免跨请求误复用。 +func (p *Provider) takePreparedRequest(signature string) (anthropic.MessageNewParams, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if p.prepared == nil { + return anthropic.MessageNewParams{}, false + } + current := p.prepared + p.prepared = nil + if strings.TrimSpace(signature) == "" || current.signature != strings.TrimSpace(signature) { + return anthropic.MessageNewParams{}, false + } + return current.params, true +} + // mapAnthropicSDKError 统一映射 SDK 错误为 provider 领域错误。 func mapAnthropicSDKError(err error) error { var apiErr *anthropic.Error diff --git a/internal/provider/anthropic/provider_test.go b/internal/provider/anthropic/provider_test.go index 32b3e43f..63a2ecdd 100644 --- a/internal/provider/anthropic/provider_test.go +++ b/internal/provider/anthropic/provider_test.go @@ -154,7 +154,7 @@ func TestBuildRequestSupportsImageParts(t *testing.T) { }, }, }, - SessionAssetReader: stubSessionAssetReader{ + SessionAssetReader: &stubSessionAssetReader{ assets: map[string]stubSessionAsset{ "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, }, @@ -199,7 +199,7 @@ func TestBuildRequestRejectsSessionAssetWithoutReader(t *testing.T) { } } -func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { +func TestEstimateInputTokensReturnsAdvisoryLocalEstimate(t *testing.T) { t.Parallel() p, err := New(provider.RuntimeConfig{ @@ -225,14 +225,67 @@ func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { if estimate.EstimateSource != provider.EstimateSourceLocal { t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal) } - if estimate.GatePolicy != provider.EstimateGateGateable { - t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable) + if estimate.GatePolicy != provider.EstimateGateAdvisory { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateAdvisory) } if estimate.EstimatedInputTokens <= 0 { t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens) } } +func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "event: message_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":4}}}\n\n") + _, _ = fmt.Fprint(w, "event: content_block_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"ok\"}}\n\n") + _, _ = fmt.Fprint(w, "event: message_delta\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n") + _, _ = fmt.Fprint(w, "event: message_stop\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_stop\"}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverAnthropic, + BaseURL: server.URL, + DefaultModel: "claude-3-7-sonnet", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + reader := &stubSessionAssetReader{ + maxOpen: 1, + assets: map[string]stubSessionAsset{ + "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, + }, + } + request := providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + SessionAssetReader: reader, + } + if _, err := p.EstimateInputTokens(context.Background(), request); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), request, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if reader.openCount != 1 { + t.Fatalf("expected session asset to be opened once, got %d", reader.openCount) + } +} + func drainEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { var drained []providertypes.StreamEvent for { @@ -252,10 +305,16 @@ type stubSessionAsset struct { } type stubSessionAssetReader struct { - assets map[string]stubSessionAsset + assets map[string]stubSessionAsset + openCount int + maxOpen int } -func (r stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { +func (r *stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { + if r.maxOpen > 0 && r.openCount >= r.maxOpen { + return nil, "", fmt.Errorf("open limit exceeded for asset: %s", assetID) + } + r.openCount++ asset, ok := r.assets[assetID] if !ok { return nil, "", fmt.Errorf("asset not found: %s", assetID) diff --git a/internal/provider/estimate.go b/internal/provider/estimate.go index 8467a62b..2f0499e8 100644 --- a/internal/provider/estimate.go +++ b/internal/provider/estimate.go @@ -1,8 +1,12 @@ package provider import ( + "crypto/sha256" + "encoding/hex" "encoding/json" "math" + + providertypes "neo-code/internal/provider/types" ) const ( @@ -29,3 +33,13 @@ func EstimateTextTokens(text string) int { } return int(math.Ceil(float64(len([]byte(text))) / 4.0 * localEstimateSlack)) } + +// BuildGenerateRequestSignature 生成 GenerateRequest 的稳定签名,用于估算与发送阶段的请求复用匹配。 +func BuildGenerateRequestSignature(req providertypes.GenerateRequest) string { + encoded, err := json.Marshal(req) + if err != nil { + return "" + } + hash := sha256.Sum256(encoded) + return hex.EncodeToString(hash[:]) +} diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index af1a9b5d..5a6ea83e 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "strings" + "sync" "google.golang.org/genai" @@ -19,6 +20,16 @@ const errorPrefix = "gemini provider: " // Provider 封装 Gemini native 协议的请求发送与流式响应解析。 type Provider struct { cfg provider.RuntimeConfig + + mu sync.Mutex + prepared *preparedRequest +} + +type preparedRequest struct { + signature string + model string + contents []*genai.Content + config *genai.GenerateContentConfig } // EstimateInputTokens 基于 Gemini 最终请求结构做本地输入 token 估算。 @@ -43,10 +54,11 @@ func (p *Provider) EstimateInputTokens( if err != nil { return providertypes.BudgetEstimate{}, err } + p.storePreparedRequest(provider.BuildGenerateRequestSignature(req), model, contents, genConfig) return providertypes.BudgetEstimate{ EstimatedInputTokens: tokens, EstimateSource: provider.EstimateSourceLocal, - GatePolicy: provider.EstimateGateGateable, + GatePolicy: provider.EstimateGateAdvisory, }, nil } @@ -60,9 +72,13 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { // Generate 发起 Gemini 流式请求,并将 SDK chunk 转为统一流式事件。 func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { - model, contents, config, err := BuildRequest(ctx, p.cfg, req) - if err != nil { - return err + model, contents, config, ok := p.takePreparedRequest(provider.BuildGenerateRequestSignature(req)) + if !ok { + var err error + model, contents, config, err = BuildRequest(ctx, p.cfg, req) + if err != nil { + return err + } } normalizedModel := normalizeGeminiModelName(model) if normalizedModel == "" { @@ -144,6 +160,38 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque return provider.EmitMessageDone(ctx, events, finishReason, &usage) } +// storePreparedRequest 缓存估算阶段的 Gemini 构建结果,供同轮发送直接复用。 +func (p *Provider) storePreparedRequest( + signature string, + model string, + contents []*genai.Content, + config *genai.GenerateContentConfig, +) { + p.mu.Lock() + defer p.mu.Unlock() + p.prepared = &preparedRequest{ + signature: strings.TrimSpace(signature), + model: model, + contents: contents, + config: config, + } +} + +// takePreparedRequest 读取并消费签名匹配的预构建请求,避免跨请求误复用。 +func (p *Provider) takePreparedRequest(signature string) (string, []*genai.Content, *genai.GenerateContentConfig, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if p.prepared == nil { + return "", nil, nil, false + } + current := p.prepared + p.prepared = nil + if strings.TrimSpace(signature) == "" || current.signature != strings.TrimSpace(signature) { + return "", nil, nil, false + } + return current.model, current.contents, current.config, true +} + // normalizeGeminiModelName 统一清洗 Gemini 模型名,兼容 discover 返回的 "models/{id}" 形式。 func normalizeGeminiModelName(model string) string { trimmed := strings.TrimSpace(model) diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index e9e866fc..3e8296bb 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -135,7 +135,7 @@ func TestBuildRequestSupportsImageParts(t *testing.T) { }, }, }, - SessionAssetReader: stubSessionAssetReader{ + SessionAssetReader: &stubSessionAssetReader{ assets: map[string]stubSessionAsset{ "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, }, @@ -186,7 +186,7 @@ func TestBuildRequestRejectsSessionAssetWithoutReader(t *testing.T) { } } -func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { +func TestEstimateInputTokensReturnsAdvisoryLocalEstimate(t *testing.T) { t.Parallel() p, err := New(provider.RuntimeConfig{ @@ -212,14 +212,61 @@ func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { if estimate.EstimateSource != provider.EstimateSourceLocal { t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal) } - if estimate.GatePolicy != provider.EstimateGateGateable { - t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable) + if estimate.GatePolicy != provider.EstimateGateAdvisory { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateAdvisory) } if estimate.EstimatedInputTokens <= 0 { t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens) } } +func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"content\":{\"parts\":[{\"text\":\"ok\"}]}}],\"usageMetadata\":{\"promptTokenCount\":5,\"candidatesTokenCount\":2,\"totalTokenCount\":7}}\n\n") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"finishReason\":\"STOP\",\"content\":{\"parts\":[]}}]}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + reader := &stubSessionAssetReader{ + maxOpen: 1, + assets: map[string]stubSessionAsset{ + "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, + }, + } + request := providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + SessionAssetReader: reader, + } + if _, err := p.EstimateInputTokens(context.Background(), request); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), request, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if reader.openCount != 1 { + t.Fatalf("expected session asset to be opened once, got %d", reader.openCount) + } +} + func drainEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { var drained []providertypes.StreamEvent for { @@ -239,10 +286,16 @@ type stubSessionAsset struct { } type stubSessionAssetReader struct { - assets map[string]stubSessionAsset + assets map[string]stubSessionAsset + openCount int + maxOpen int } -func (r stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { +func (r *stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { + if r.maxOpen > 0 && r.openCount >= r.maxOpen { + return nil, "", fmt.Errorf("open limit exceeded for asset: %s", assetID) + } + r.openCount++ asset, ok := r.assets[assetID] if !ok { return nil, "", fmt.Errorf("asset not found: %s", assetID) diff --git a/internal/provider/openaicompat/generate_sdk.go b/internal/provider/openaicompat/generate_sdk.go index 890997d6..4bcc78de 100644 --- a/internal/provider/openaicompat/generate_sdk.go +++ b/internal/provider/openaicompat/generate_sdk.go @@ -21,14 +21,9 @@ import ( // generateSDKChatCompletions 走 SDK chat/completions 发送请求 func (p *Provider) generateSDKChatCompletions( ctx context.Context, - req providertypes.GenerateRequest, + payload chatcompletions.Request, events chan<- providertypes.StreamEvent, ) error { - payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) - if err != nil { - return err - } - client, err := p.newSDKClient() if err != nil { return err @@ -280,13 +275,9 @@ func (p *Provider) generateChatCompletionsWithCompatibleStream( // generateSDKResponses 走 SDK responses 发送请求,复用本地流事件映射。 func (p *Provider) generateSDKResponses( ctx context.Context, - req providertypes.GenerateRequest, + payload responses.Request, events chan<- providertypes.StreamEvent, ) error { - payload, err := responses.BuildRequest(ctx, p.cfg, req) - if err != nil { - return err - } endpoint, err := resolveChatEndpoint(p.cfg) if err != nil { return err diff --git a/internal/provider/openaicompat/openaicompat_test.go b/internal/provider/openaicompat/openaicompat_test.go index 3cb176be..a45196c9 100644 --- a/internal/provider/openaicompat/openaicompat_test.go +++ b/internal/provider/openaicompat/openaicompat_test.go @@ -3,6 +3,7 @@ package openaicompat import ( "context" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -245,7 +246,7 @@ func TestDiscoverModelsParsesNestedContainerAndAliasFields(t *testing.T) { } } -func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { +func TestEstimateInputTokensReturnsAdvisoryLocalEstimate(t *testing.T) { t.Parallel() p, err := New(resolvedConfig("", "")) @@ -264,14 +265,61 @@ func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { if estimate.EstimateSource != provider.EstimateSourceLocal { t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal) } - if estimate.GatePolicy != provider.EstimateGateGateable { - t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable) + if estimate.GatePolicy != provider.EstimateGateAdvisory { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateAdvisory) } if estimate.EstimatedInputTokens <= 0 { t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens) } } +func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { + t.Setenv(config.OpenAIDefaultAPIKeyEnv, "test-key") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(`data: {"choices":[{"delta":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}} +data: [DONE] + +`)) + })) + defer server.Close() + + p, err := New(resolvedConfig(server.URL, "gpt-4.1")) + if err != nil { + t.Fatalf("New() error = %v", err) + } + p.client = server.Client() + + reader := &singleUseSessionAssetReader{ + maxOpen: 1, + assets: map[string]sessionAsset{ + "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, + }, + } + request := providertypes.GenerateRequest{ + Model: "gpt-4.1", + Messages: []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }, + }, + SessionAssetReader: reader, + } + if _, err := p.EstimateInputTokens(context.Background(), request); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), request, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if reader.openCount != 1 { + t.Fatalf("expected session asset to be opened once, got %d", reader.openCount) + } +} + func TestDiscoverModelsOpenAIProfileFallsBackToGenericListKeys(t *testing.T) { t.Parallel() @@ -733,3 +781,30 @@ func (r *cancelAfterDoneReader) Read(p []byte) (int, error) { r.cancel() return 0, r.err } + +type sessionAsset struct { + data []byte + mime string + err error +} + +type singleUseSessionAssetReader struct { + assets map[string]sessionAsset + openCount int + maxOpen int +} + +func (r *singleUseSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { + if r.maxOpen > 0 && r.openCount >= r.maxOpen { + return nil, "", fmt.Errorf("open limit exceeded for asset: %s", assetID) + } + r.openCount++ + asset, ok := r.assets[assetID] + if !ok { + return nil, "", fmt.Errorf("asset not found: %s", assetID) + } + if asset.err != nil { + return nil, "", asset.err + } + return io.NopCloser(strings.NewReader(string(asset.data))), asset.mime, nil +} diff --git a/internal/provider/openaicompat/provider.go b/internal/provider/openaicompat/provider.go index 6227f9a2..db841b81 100644 --- a/internal/provider/openaicompat/provider.go +++ b/internal/provider/openaicompat/provider.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "strings" + "sync" "time" "neo-code/internal/provider" @@ -38,6 +39,15 @@ func validateRuntimeConfig(cfg provider.RuntimeConfig) error { type Provider struct { cfg provider.RuntimeConfig client *http.Client + + mu sync.Mutex + prepared *preparedRequest +} + +type preparedRequest struct { + mode string + signature string + payload any } // EstimateInputTokens 基于 OpenAI-compatible 最终请求结构做本地输入 token 估算。 @@ -58,12 +68,18 @@ func (p *Provider) EstimateInputTokens( return providertypes.BudgetEstimate{}, buildErr } tokens, err = provider.EstimateSerializedPayloadTokens(payload) + if err == nil { + p.storePreparedRequest(mode, provider.BuildGenerateRequestSignature(req), payload) + } case executionModeResponses: payload, buildErr := responses.BuildRequest(ctx, p.cfg, req) if buildErr != nil { return providertypes.BudgetEstimate{}, buildErr } tokens, err = provider.EstimateSerializedPayloadTokens(payload) + if err == nil { + p.storePreparedRequest(mode, provider.BuildGenerateRequestSignature(req), payload) + } default: return providertypes.BudgetEstimate{}, provider.NewDiscoveryConfigError( fmt.Sprintf("openaicompat provider: driver %q resolved unsupported execution mode %q", p.cfg.Driver, mode), @@ -75,7 +91,7 @@ func (p *Provider) EstimateInputTokens( return providertypes.BudgetEstimate{ EstimatedInputTokens: tokens, EstimateSource: provider.EstimateSourceLocal, - GatePolicy: provider.EstimateGateGateable, + GatePolicy: provider.EstimateGateAdvisory, }, nil } @@ -139,9 +155,25 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque switch mode { case executionModeCompletions: - return p.generateSDKChatCompletions(ctx, req, events) + signature := provider.BuildGenerateRequestSignature(req) + if payload, ok := p.takePreparedChatCompletionsRequest(mode, signature); ok { + return p.generateSDKChatCompletions(ctx, payload, events) + } + payload, buildErr := chatcompletions.BuildRequest(ctx, p.cfg, req) + if buildErr != nil { + return buildErr + } + return p.generateSDKChatCompletions(ctx, payload, events) case executionModeResponses: - return p.generateSDKResponses(ctx, req, events) + signature := provider.BuildGenerateRequestSignature(req) + if payload, ok := p.takePreparedResponsesRequest(mode, signature); ok { + return p.generateSDKResponses(ctx, payload, events) + } + payload, buildErr := responses.BuildRequest(ctx, p.cfg, req) + if buildErr != nil { + return buildErr + } + return p.generateSDKResponses(ctx, payload, events) default: return provider.NewDiscoveryConfigError( fmt.Sprintf("openaicompat provider: driver %q resolved unsupported execution mode %q", p.cfg.Driver, mode), @@ -149,6 +181,61 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque } } +// storePreparedRequest 缓存估算阶段已构建请求,供同轮发送复用以避免重复构建。 +func (p *Provider) storePreparedRequest(mode string, signature string, payload any) { + p.mu.Lock() + defer p.mu.Unlock() + p.prepared = &preparedRequest{ + mode: mode, + signature: strings.TrimSpace(signature), + payload: payload, + } +} + +// takePreparedChatCompletionsRequest 读取并消费 chat/completions 预构建请求,仅在签名匹配时命中。 +func (p *Provider) takePreparedChatCompletionsRequest(mode string, signature string) (chatcompletions.Request, bool) { + raw, ok := p.takePreparedRequest(mode, signature) + if !ok { + return chatcompletions.Request{}, false + } + payload, ok := raw.(chatcompletions.Request) + if !ok { + return chatcompletions.Request{}, false + } + return payload, true +} + +// takePreparedResponsesRequest 读取并消费 responses 预构建请求,仅在签名匹配时命中。 +func (p *Provider) takePreparedResponsesRequest(mode string, signature string) (responses.Request, bool) { + raw, ok := p.takePreparedRequest(mode, signature) + if !ok { + return responses.Request{}, false + } + payload, ok := raw.(responses.Request) + if !ok { + return responses.Request{}, false + } + return payload, true +} + +// takePreparedRequest 读取并消费缓存请求,避免跨请求误复用。 +func (p *Provider) takePreparedRequest(mode string, signature string) (any, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if p.prepared == nil { + return nil, false + } + current := p.prepared + p.prepared = nil + if current.mode != mode { + return nil, false + } + if strings.TrimSpace(signature) == "" || current.signature != strings.TrimSpace(signature) { + return nil, false + } + return current.payload, true +} + // resolveExecutionMode 解析当前配置对应的 OpenAI-compatible 执行模式。 func resolveExecutionMode(cfg provider.RuntimeConfig) (string, error) { if provider.NormalizeProviderDriver(cfg.Driver) != DriverName { diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 681d3d23..76316f24 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -147,7 +147,12 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { continue } - decision, err := s.evaluateTurnBudget(ctx, &state, snapshot) + modelProvider, err := s.providerFactory.Build(ctx, snapshot.ProviderConfig) + if err != nil { + return s.handleRunError(ctx, state.runID, state.session.ID, err) + } + + decision, err := s.evaluateTurnBudget(ctx, &state, snapshot, modelProvider) if err != nil { return s.handleRunError(ctx, state.runID, state.session.ID, err) } @@ -168,7 +173,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { return nil } - turnOutput, err := s.callProviderWithRetry(ctx, &state, snapshot) + turnOutput, err := s.callProviderWithRetry(ctx, &state, snapshot, modelProvider) if err != nil { if provider.IsContextTooLong(err) && state.reactiveCompactAttempts < snapshot.Config.Context.Budget.MaxReactiveCompacts { @@ -388,6 +393,7 @@ func (s *Service) callProviderWithRetry( ctx context.Context, state *runState, snapshot TurnBudgetSnapshot, + initialProvider provider.Provider, ) (turnProviderOutput, error) { var lastErr error @@ -405,9 +411,13 @@ func (s *Service) callProviderWithRetry( } } - modelProvider, err := s.providerFactory.Build(ctx, snapshot.ProviderConfig) - if err != nil { - return turnProviderOutput{}, err + modelProvider := initialProvider + if retryAttempt > 0 { + var err error + modelProvider, err = s.providerFactory.Build(ctx, snapshot.ProviderConfig) + if err != nil { + return turnProviderOutput{}, err + } } streamOutcome := generateStreamingMessage(ctx, modelProvider, snapshot.Request, streaming.Hooks{ @@ -524,11 +534,8 @@ func (s *Service) evaluateTurnBudget( ctx context.Context, state *runState, snapshot TurnBudgetSnapshot, + modelProvider provider.Provider, ) (controlplane.TurnBudgetDecision, error) { - modelProvider, err := s.providerFactory.Build(ctx, snapshot.ProviderConfig) - if err != nil { - return controlplane.TurnBudgetDecision{}, err - } providerEstimate, err := modelProvider.EstimateInputTokens(ctx, snapshot.Request) if err != nil { return controlplane.TurnBudgetDecision{}, err diff --git a/internal/runtime/runtime_remaining_branches_test.go b/internal/runtime/runtime_remaining_branches_test.go index dd1c3aab..2eeabc2c 100644 --- a/internal/runtime/runtime_remaining_branches_test.go +++ b/internal/runtime/runtime_remaining_branches_test.go @@ -606,7 +606,12 @@ func TestRunAndProviderRetryRemainingBranches(t *testing.T) { }() service := &Service{providerFactory: &scriptedProviderFactory{provider: providerRetry}, events: make(chan RuntimeEvent, 8)} state := newRunState("run-retry-backoff", newRuntimeSession("session-retry-backoff")) - _, err := service.callProviderWithRetry(ctx, &state, TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}) + _, err := service.callProviderWithRetry( + ctx, + &state, + TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}, + providerRetry, + ) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -621,7 +626,12 @@ func TestRunAndProviderRetryRemainingBranches(t *testing.T) { }} service := &Service{providerFactory: &scriptedProviderFactory{provider: providerRetry}, events: make(chan RuntimeEvent, 8)} state := newRunState("run-retry-ctx-check", newRuntimeSession("session-retry-ctx-check")) - _, err := service.callProviderWithRetry(ctx, &state, TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}) + _, err := service.callProviderWithRetry( + ctx, + &state, + TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}, + providerRetry, + ) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index c06059b1..091ed67a 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -1022,7 +1022,7 @@ func TestServiceRun(t *testing.T) { t.Fatalf("Run() error = %v", err) } - expectedProviderBuilds := tt.expectProviderCalls * 2 + expectedProviderBuilds := tt.expectProviderCalls if factory.calls != expectedProviderBuilds { t.Fatalf("expected %d provider builds, got %d", expectedProviderBuilds, factory.calls) } @@ -3797,6 +3797,7 @@ func TestCallProviderWithRetryReturnsCombinedForwardError(t *testing.T) { context.Background(), &state, snapshot, + scripted, ) if err == nil || !containsError(err, "provider stream handling failed after provider error") { t.Fatalf("expected combined forward/provider error, got %v", err) diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go index 34219c1d..72d137ba 100644 --- a/internal/tui/services/gateway_stream_client.go +++ b/internal/tui/services/gateway_stream_client.go @@ -71,15 +71,19 @@ func (c *GatewayStreamClient) run() { event, err := decodeRuntimeEventFromGatewayNotification(notification) if err != nil { + errMessage := fmt.Sprintf("gateway stream decode error: %v", err) select { case <-c.closeCh: return case c.events <- RuntimeEvent{ Type: EventError, Timestamp: time.Now().UTC(), - Payload: fmt.Sprintf("gateway stream decode error: %v", err), + Payload: errMessage, }: } + if isRuntimePayloadVersionMismatch(errMessage) { + return + } continue } @@ -92,6 +96,13 @@ func (c *GatewayStreamClient) run() { } } +// isRuntimePayloadVersionMismatch 判断错误是否由 runtime 事件版本不匹配触发,用于快速停止消费避免噪声洪泛。 +func isRuntimePayloadVersionMismatch(errMessage string) bool { + normalized := strings.ToLower(strings.TrimSpace(errMessage)) + return strings.Contains(normalized, "payload_version") && + strings.Contains(normalized, "unsupported") +} + // decodeRuntimeEventFromGatewayNotification 将 gateway.event 通知还原为事件。 func decodeRuntimeEventFromGatewayNotification(notification gatewayRPCNotification) (RuntimeEvent, error) { var frame gateway.MessageFrame diff --git a/internal/tui/services/gateway_stream_client_additional_test.go b/internal/tui/services/gateway_stream_client_additional_test.go index 25a6a660..da167c07 100644 --- a/internal/tui/services/gateway_stream_client_additional_test.go +++ b/internal/tui/services/gateway_stream_client_additional_test.go @@ -317,6 +317,57 @@ func TestGatewayStreamClientRunSkipsNonGatewayEventsAndStopsOnClose(t *testing.T } } +func TestGatewayStreamClientRunStopsOnPayloadVersionMismatch(t *testing.T) { + t.Parallel() + + source := make(chan gatewayRPCNotification, 3) + client := NewGatewayStreamClient(source) + + source <- buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(EventAgentChunk), + "payload_version": runtimeEventPayloadVersion - 1, + "payload": "legacy", + }, + }) + source <- buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(EventAgentChunk), + "payload_version": runtimeEventPayloadVersion, + "payload": "ok", + }, + }) + + select { + case event, ok := <-client.Events(): + if !ok { + t.Fatalf("events channel closed before decode error event") + } + if event.Type != EventError { + t.Fatalf("event.Type = %q, want %q", event.Type, EventError) + } + payload, payloadOK := event.Payload.(string) + if !payloadOK || !containsAll(payload, "payload_version", "want") { + t.Fatalf("event.Payload = %#v", event.Payload) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for decode error event") + } + + select { + case _, ok := <-client.Events(): + if ok { + t.Fatalf("expected stream to stop after payload version mismatch") + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for events channel close") + } +} + func containsAll(input string, subs ...string) bool { for _, sub := range subs { if !strings.Contains(input, sub) {