Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions internal/provider/anthropic/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"strconv"
"strings"
"sync"

anthropic "github.com/anthropics/anthropic-sdk-go"

Expand All @@ -25,6 +26,14 @@ type toolCallState struct {
// Provider 封装 Anthropic messages 协议的请求发送与流式解析。
type Provider struct {
cfg provider.RuntimeConfig

mu sync.Mutex
prepared *preparedRequest
}

type preparedRequest struct {
signature string
params anthropic.MessageNewParams
}

// EstimateInputTokens 基于 Anthropic 最终请求结构做本地输入 token 估算。
Expand All @@ -40,10 +49,11 @@ func (p *Provider) EstimateInputTokens(
if err != nil {
return providertypes.BudgetEstimate{}, err
}
p.storePreparedRequest(provider.BuildGenerateRequestSignature(req), params)
return providertypes.BudgetEstimate{
EstimatedInputTokens: tokens,
EstimateSource: provider.EstimateSourceLocal,
GatePolicy: provider.EstimateGateGateable,
GatePolicy: provider.EstimateGateAdvisory,
}, nil
}

Expand All @@ -57,9 +67,13 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) {

// Generate 发起 Anthropic 流式请求,并将 typed stream 转为统一事件。
func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error {
params, err := BuildRequest(ctx, p.cfg, req)
if err != nil {
return err
params, ok := p.takePreparedRequest(provider.BuildGenerateRequestSignature(req))
if !ok {
var err error
params, err = BuildRequest(ctx, p.cfg, req)
if err != nil {
return err
}
}

client, err := newSDKClient(p.cfg)
Expand Down Expand Up @@ -185,6 +199,31 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
}

// storePreparedRequest 缓存估算阶段已构建的 Anthropic 请求,供同轮发送复用。
func (p *Provider) storePreparedRequest(signature string, params anthropic.MessageNewParams) {
p.mu.Lock()
defer p.mu.Unlock()
p.prepared = &preparedRequest{
signature: strings.TrimSpace(signature),
params: params,
}
}

// takePreparedRequest 读取并消费匹配签名的预构建请求,避免跨请求误复用。
func (p *Provider) takePreparedRequest(signature string) (anthropic.MessageNewParams, bool) {
p.mu.Lock()
defer p.mu.Unlock()
if p.prepared == nil {
return anthropic.MessageNewParams{}, false
}
current := p.prepared
p.prepared = nil
if strings.TrimSpace(signature) == "" || current.signature != strings.TrimSpace(signature) {
return anthropic.MessageNewParams{}, false
}
return current.params, true
}

// mapAnthropicSDKError 统一映射 SDK 错误为 provider 领域错误。
func mapAnthropicSDKError(err error) error {
var apiErr *anthropic.Error
Expand Down
71 changes: 65 additions & 6 deletions internal/provider/anthropic/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestBuildRequestSupportsImageParts(t *testing.T) {
},
},
},
SessionAssetReader: stubSessionAssetReader{
SessionAssetReader: &stubSessionAssetReader{
assets: map[string]stubSessionAsset{
"asset-1": {data: []byte("image-bytes"), mime: "image/png"},
},
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestBuildRequestRejectsSessionAssetWithoutReader(t *testing.T) {
}
}

func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) {
func TestEstimateInputTokensReturnsAdvisoryLocalEstimate(t *testing.T) {
t.Parallel()

p, err := New(provider.RuntimeConfig{
Expand All @@ -225,14 +225,67 @@ func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) {
if estimate.EstimateSource != provider.EstimateSourceLocal {
t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal)
}
if estimate.GatePolicy != provider.EstimateGateGateable {
t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable)
if estimate.GatePolicy != provider.EstimateGateAdvisory {
t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateAdvisory)
}
if estimate.EstimatedInputTokens <= 0 {
t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens)
}
}

func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) {
t.Parallel()

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = fmt.Fprint(w, "event: message_start\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":4}}}\n\n")
_, _ = fmt.Fprint(w, "event: content_block_start\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"ok\"}}\n\n")
_, _ = fmt.Fprint(w, "event: message_delta\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n")
_, _ = fmt.Fprint(w, "event: message_stop\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"message_stop\"}\n\n")
}))
defer server.Close()

p, err := New(provider.RuntimeConfig{
Driver: provider.DriverAnthropic,
BaseURL: server.URL,
DefaultModel: "claude-3-7-sonnet",
APIKeyEnv: "ANTHROPIC_TEST_KEY",
APIKeyResolver: provider.StaticAPIKeyResolver("test-key"),
})
if err != nil {
t.Fatalf("New() error = %v", err)
}

reader := &stubSessionAssetReader{
maxOpen: 1,
assets: map[string]stubSessionAsset{
"asset-1": {data: []byte("image-bytes"), mime: "image/png"},
},
}
request := providertypes.GenerateRequest{
Messages: []providertypes.Message{{
Role: providertypes.RoleUser,
Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")},
}},
SessionAssetReader: reader,
}
if _, err := p.EstimateInputTokens(context.Background(), request); err != nil {
t.Fatalf("EstimateInputTokens() error = %v", err)
}

events := make(chan providertypes.StreamEvent, 8)
if err := p.Generate(context.Background(), request, events); err != nil {
t.Fatalf("Generate() error = %v", err)
}
if reader.openCount != 1 {
t.Fatalf("expected session asset to be opened once, got %d", reader.openCount)
}
}

func drainEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent {
var drained []providertypes.StreamEvent
for {
Expand All @@ -252,10 +305,16 @@ type stubSessionAsset struct {
}

type stubSessionAssetReader struct {
assets map[string]stubSessionAsset
assets map[string]stubSessionAsset
openCount int
maxOpen int
}

func (r stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) {
func (r *stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) {
if r.maxOpen > 0 && r.openCount >= r.maxOpen {
return nil, "", fmt.Errorf("open limit exceeded for asset: %s", assetID)
}
r.openCount++
asset, ok := r.assets[assetID]
if !ok {
return nil, "", fmt.Errorf("asset not found: %s", assetID)
Expand Down
14 changes: 14 additions & 0 deletions internal/provider/estimate.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package provider

import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"math"

providertypes "neo-code/internal/provider/types"
)

const (
Expand All @@ -29,3 +33,13 @@ func EstimateTextTokens(text string) int {
}
return int(math.Ceil(float64(len([]byte(text))) / 4.0 * localEstimateSlack))
}

// BuildGenerateRequestSignature 生成 GenerateRequest 的稳定签名,用于估算与发送阶段的请求复用匹配。
func BuildGenerateRequestSignature(req providertypes.GenerateRequest) string {
encoded, err := json.Marshal(req)
if err != nil {
return ""
}
hash := sha256.Sum256(encoded)
return hex.EncodeToString(hash[:])
}
56 changes: 52 additions & 4 deletions internal/provider/gemini/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"

"google.golang.org/genai"

Expand All @@ -19,6 +20,16 @@ const errorPrefix = "gemini provider: "
// Provider 封装 Gemini native 协议的请求发送与流式响应解析。
type Provider struct {
cfg provider.RuntimeConfig

mu sync.Mutex
prepared *preparedRequest
}

type preparedRequest struct {
signature string
model string
contents []*genai.Content
config *genai.GenerateContentConfig
}

// EstimateInputTokens 基于 Gemini 最终请求结构做本地输入 token 估算。
Expand All @@ -43,10 +54,11 @@ func (p *Provider) EstimateInputTokens(
if err != nil {
return providertypes.BudgetEstimate{}, err
}
p.storePreparedRequest(provider.BuildGenerateRequestSignature(req), model, contents, genConfig)
return providertypes.BudgetEstimate{
EstimatedInputTokens: tokens,
EstimateSource: provider.EstimateSourceLocal,
GatePolicy: provider.EstimateGateGateable,
GatePolicy: provider.EstimateGateAdvisory,
}, nil
}

Expand All @@ -60,9 +72,13 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) {

// Generate 发起 Gemini 流式请求,并将 SDK chunk 转为统一流式事件。
func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error {
model, contents, config, err := BuildRequest(ctx, p.cfg, req)
if err != nil {
return err
model, contents, config, ok := p.takePreparedRequest(provider.BuildGenerateRequestSignature(req))
if !ok {
var err error
model, contents, config, err = BuildRequest(ctx, p.cfg, req)
if err != nil {
return err
}
}
normalizedModel := normalizeGeminiModelName(model)
if normalizedModel == "" {
Expand Down Expand Up @@ -144,6 +160,38 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
}

// storePreparedRequest 缓存估算阶段的 Gemini 构建结果,供同轮发送直接复用。
func (p *Provider) storePreparedRequest(
signature string,
model string,
contents []*genai.Content,
config *genai.GenerateContentConfig,
) {
p.mu.Lock()
defer p.mu.Unlock()
p.prepared = &preparedRequest{
signature: strings.TrimSpace(signature),
model: model,
contents: contents,
config: config,
}
}

// takePreparedRequest 读取并消费签名匹配的预构建请求,避免跨请求误复用。
func (p *Provider) takePreparedRequest(signature string) (string, []*genai.Content, *genai.GenerateContentConfig, bool) {
p.mu.Lock()
defer p.mu.Unlock()
if p.prepared == nil {
return "", nil, nil, false
}
current := p.prepared
p.prepared = nil
if strings.TrimSpace(signature) == "" || current.signature != strings.TrimSpace(signature) {
return "", nil, nil, false
}
return current.model, current.contents, current.config, true
}

// normalizeGeminiModelName 统一清洗 Gemini 模型名,兼容 discover 返回的 "models/{id}" 形式。
func normalizeGeminiModelName(model string) string {
trimmed := strings.TrimSpace(model)
Expand Down
Loading
Loading