Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion internal/provider/anthropic/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque
case anthropic.MessageStartEvent:
if variant.Message.Usage.InputTokens > 0 {
usage.InputTokens = int(variant.Message.Usage.InputTokens)
usage.InputObserved = true
}
if variant.Message.Usage.OutputTokens > 0 {
usage.OutputTokens = int(variant.Message.Usage.OutputTokens)
usage.OutputObserved = true
}
case anthropic.ContentBlockStartEvent:
switch block := variant.ContentBlock.AsAny().(type) {
Expand Down Expand Up @@ -167,9 +169,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque
}
if variant.Usage.OutputTokens > 0 {
usage.OutputTokens = int(variant.Usage.OutputTokens)
usage.OutputObserved = true
}
if variant.Usage.InputTokens > 0 {
usage.InputTokens = int(variant.Usage.InputTokens)
usage.InputObserved = true
}
}
}
Expand All @@ -193,9 +197,12 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque
return fmt.Errorf("%sinvalid tool_use stream at index %d: missing tool name", errorPrefix, index)
}
}
if usage.TotalTokens <= 0 {
if usage.TotalTokens <= 0 && (usage.InputObserved || usage.OutputObserved) {
usage.TotalTokens = usage.InputTokens + usage.OutputTokens
}
if !usage.InputObserved && !usage.OutputObserved {
return provider.EmitMessageDone(ctx, events, finishReason, nil)
}
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
}

Expand Down
59 changes: 59 additions & 0 deletions internal/provider/anthropic/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,72 @@ func TestProviderGenerate(t *testing.T) {
if payload.Usage == nil || payload.Usage.TotalTokens != 10 {
t.Fatalf("expected usage total tokens 10, got %+v", payload.Usage)
}
if !payload.Usage.InputObserved || !payload.Usage.OutputObserved {
t.Fatalf("expected usage observed flags true, got %+v", payload.Usage)
}
}
}
if !foundText || !foundToolStart || !foundToolDelta || !foundDone {
t.Fatalf("expected text/tool_start/tool_delta/done events, got %+v", drained)
}
}

func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) {
t.Parallel()

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = fmt.Fprint(w, "event: content_block_start\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"Hello\"}}\n\n")
_, _ = fmt.Fprint(w, "event: message_delta\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"}}\n\n")
_, _ = fmt.Fprint(w, "event: message_stop\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"message_stop\"}\n\n")
}))
defer server.Close()

p, err := New(provider.RuntimeConfig{
Driver: provider.DriverAnthropic,
BaseURL: server.URL,
DefaultModel: "claude-3-7-sonnet",
APIKeyEnv: "ANTHROPIC_TEST_KEY",
APIKeyResolver: provider.StaticAPIKeyResolver("test-key"),
})
if err != nil {
t.Fatalf("New() error = %v", err)
}

events := make(chan providertypes.StreamEvent, 8)
if err := p.Generate(context.Background(), providertypes.GenerateRequest{
Messages: []providertypes.Message{{
Role: providertypes.RoleUser,
Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")},
}},
}, events); err != nil {
t.Fatalf("Generate() error = %v", err)
}

drained := drainEvents(events)
var done *providertypes.MessageDonePayload
for i := range drained {
if drained[i].Type != providertypes.StreamEventMessageDone {
continue
}
payload, payloadErr := drained[i].MessageDoneValue()
if payloadErr != nil {
t.Fatalf("MessageDoneValue() error = %v", payloadErr)
}
done = &payload
break
}
if done == nil {
t.Fatalf("expected message_done event, got %+v", drained)
}
if done.Usage != nil {
t.Fatalf("expected nil usage when provider does not report usage, got %+v", done.Usage)
}
}

func TestNewAcceptsCustomChatEndpointPath(t *testing.T) {
t.Parallel()

Expand Down
5 changes: 5 additions & 0 deletions internal/provider/gemini/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque
if !hasPayload {
return fmt.Errorf("%w: empty gemini stream payload", provider.ErrStreamInterrupted)
}
if !usage.InputObserved && !usage.OutputObserved {
return provider.EmitMessageDone(ctx, events, finishReason, nil)
}
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
}

Expand Down Expand Up @@ -209,6 +212,8 @@ func extractUsage(usage *providertypes.Usage, raw *genai.GenerateContentResponse
usage.InputTokens = int(raw.PromptTokenCount)
usage.OutputTokens = int(raw.CandidatesTokenCount)
usage.TotalTokens = int(raw.TotalTokenCount)
usage.InputObserved = true
usage.OutputObserved = true
}

// encodeArguments 将函数参数对象编码为 JSON 字符串,供统一 tool_call_delta 事件复用。
Expand Down
55 changes: 55 additions & 0 deletions internal/provider/gemini/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ func TestProviderGenerate(t *testing.T) {
if payload.Usage == nil || payload.Usage.TotalTokens != 7 {
t.Fatalf("expected usage total tokens 7, got %+v", payload.Usage)
}
if !payload.Usage.InputObserved || !payload.Usage.OutputObserved {
t.Fatalf("expected usage observed flags true, got %+v", payload.Usage)
}
if payload.FinishReason != "stop" {
t.Fatalf("expected finish reason stop, got %q", payload.FinishReason)
}
Expand All @@ -90,6 +93,58 @@ func TestProviderGenerate(t *testing.T) {
}
}

func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) {
t.Parallel()

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"content\":{\"parts\":[{\"text\":\"Hello \"}]}}]}\n\n")
_, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"finishReason\":\"STOP\",\"content\":{\"parts\":[{\"text\":\"done\"}]}}]}\n\n")
}))
defer server.Close()

p, err := New(provider.RuntimeConfig{
Driver: provider.DriverGemini,
BaseURL: server.URL,
DefaultModel: "gemini-2.5-flash",
APIKeyEnv: "GEMINI_TEST_KEY",
APIKeyResolver: provider.StaticAPIKeyResolver("test-key"),
})
if err != nil {
t.Fatalf("New() error = %v", err)
}

events := make(chan providertypes.StreamEvent, 8)
if err := p.Generate(context.Background(), providertypes.GenerateRequest{
Messages: []providertypes.Message{{
Role: providertypes.RoleUser,
Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")},
}},
}, events); err != nil {
t.Fatalf("Generate() error = %v", err)
}

drained := drainEvents(events)
var done *providertypes.MessageDonePayload
for i := range drained {
if drained[i].Type != providertypes.StreamEventMessageDone {
continue
}
payload, payloadErr := drained[i].MessageDoneValue()
if payloadErr != nil {
t.Fatalf("MessageDoneValue() error = %v", payloadErr)
}
done = &payload
break
}
if done == nil {
t.Fatalf("expected message_done event, got %+v", drained)
}
if done.Usage != nil {
t.Fatalf("expected nil usage when provider does not report usage, got %+v", done.Usage)
}
}

func TestNewAcceptsCustomChatEndpointPath(t *testing.T) {
t.Parallel()

Expand Down
40 changes: 28 additions & 12 deletions internal/provider/openaicompat/chatcompletions/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ func EmitFromSDKStream(
return fmt.Errorf("SDK stream error: %w", err)
}

if !usage.InputObserved && !usage.OutputObserved {
return provider.EmitMessageDone(ctx, events, finishReason, nil)
}
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
}

Expand Down Expand Up @@ -183,13 +186,13 @@ func ConsumeStream(
if flushErr := flushDataLines(); flushErr != nil {
return flushErr
}
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage))
}
if flushErr := flushDataLines(); flushErr != nil {
return flushErr
}
if strings.TrimSpace(finishReason) != "" {
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage))
}
return fmt.Errorf("%w: %w", provider.ErrStreamInterrupted, err)
}
Expand All @@ -206,7 +209,7 @@ func ConsumeStream(
return flushErr
}
done = true
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage))
} else {
dataLines = append(dataLines, data)
}
Expand All @@ -215,7 +218,7 @@ func ConsumeStream(
return flushErr
}
if done {
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage))
}
default:
if len(dataLines) == 0 {
Expand All @@ -225,7 +228,7 @@ func ConsumeStream(
return flushErr
}
if done {
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage))
}
}

Expand All @@ -234,7 +237,7 @@ func ConsumeStream(
return flushErr
}
if done || strings.TrimSpace(finishReason) != "" {
return provider.EmitMessageDone(ctx, events, finishReason, &usage)
return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage))
}
return fmt.Errorf("%w: missing [DONE] marker before EOF", provider.ErrStreamInterrupted)
}
Expand All @@ -247,19 +250,32 @@ func extractLegacyStreamUsage(usage *providertypes.Usage, raw *streamUsage) {
return
}
*usage = providertypes.Usage{
InputTokens: raw.PromptTokens,
OutputTokens: raw.CompletionTokens,
TotalTokens: raw.TotalTokens,
InputTokens: raw.PromptTokens,
OutputTokens: raw.CompletionTokens,
TotalTokens: raw.TotalTokens,
InputObserved: true,
OutputObserved: true,
}
}

// extractStreamUsage overwrites the unified token statistics in *usage with the
// counts from an OpenAI CompletionUsage value, converting each count to int and
// marking both input and output usage as observed.
func extractStreamUsage(usage *providertypes.Usage, raw openai.CompletionUsage) {
*usage = providertypes.Usage{
InputTokens: int(raw.PromptTokens),
OutputTokens: int(raw.CompletionTokens),
TotalTokens: int(raw.TotalTokens),
InputTokens: int(raw.PromptTokens),
OutputTokens: int(raw.CompletionTokens),
TotalTokens: int(raw.TotalTokens),
InputObserved: true,
OutputObserved: true,
}
}

// doneUsagePtr selects the usage payload for a message_done event based on
// whether usage was observed. It returns nil when neither input nor output
// usage was observed; otherwise it returns a pointer to a copy of usage that is
// independent of the caller's value.
func doneUsagePtr(usage providertypes.Usage) *providertypes.Usage {
	if !usage.InputObserved && !usage.OutputObserved {
		return nil
	}
	// usage is received by value, so it is already a private copy. Returning
	// its address avoids a second copy and a local that shadows the builtin
	// copy function.
	return &usage
}

// mergeToolCallDeltaFromSDK 将单个 SDK tool call 增量合并到累积状态,并在必要时发出起始/增量事件。
Expand Down
33 changes: 33 additions & 0 deletions internal/provider/openaicompat/chatcompletions/adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,36 @@ func TestConsumeStreamSupportsWeakSSEFormat(t *testing.T) {
if done.Usage == nil || done.Usage.TotalTokens != 3 {
t.Fatalf("unexpected usage: %+v", done.Usage)
}
if !done.Usage.InputObserved || !done.Usage.OutputObserved {
t.Fatalf("expected usage observed flags to be true, got %+v", done.Usage)
}
}

// TestConsumeStreamEmitsNilUsageWhenProviderDidNotReturnUsage checks that
// ConsumeStream emits a message_done event with nil Usage when the stream
// carries no usage object.
func TestConsumeStreamEmitsNilUsageWhenProviderDidNotReturnUsage(t *testing.T) {
	t.Parallel()

	// One content chunk with a finish reason, then the [DONE] marker; each
	// line is newline-terminated.
	const body = `data: {"choices":[{"delta":{"content":"ok"},"finish_reason":"stop"}]}` + "\n" +
		`data: [DONE]` + "\n"

	events := make(chan providertypes.StreamEvent, 4)
	err := ConsumeStream(context.Background(), strings.NewReader(body), events)
	if err != nil {
		t.Fatalf("ConsumeStream() error = %v", err)
	}

	drained := drainEvents(events)
	if got := len(drained); got != 2 {
		t.Fatalf("expected 2 events, got %d", got)
	}

	// The second event is expected to be the message_done event.
	done, doneErr := drained[1].MessageDoneValue()
	if doneErr != nil {
		t.Fatalf("expected message done, got err=%v", doneErr)
	}
	if usage := done.Usage; usage != nil {
		t.Fatalf("expected nil usage when stream carries no usage, got %+v", usage)
	}
}

func TestConsumeStreamParsesMultilineDataEvent(t *testing.T) {
Expand Down Expand Up @@ -197,6 +227,9 @@ func TestEmitFromSDKStream(t *testing.T) {
if done.Usage == nil || done.Usage.TotalTokens != 3 {
t.Fatalf("expected usage total tokens 3, got %+v", done.Usage)
}
if !done.Usage.InputObserved || !done.Usage.OutputObserved {
t.Fatalf("expected usage observed flags to be true, got %+v", done.Usage)
}
}

func TestEmitFromSDKStreamErrors(t *testing.T) {
Expand Down
19 changes: 15 additions & 4 deletions internal/provider/openaicompat/responses/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func EmitFromStream(
if reason == "" {
reason = "stop"
}
return provider.EmitMessageDone(ctx, events, reason, &usage)
return provider.EmitMessageDone(ctx, events, reason, doneUsagePtr(usage))
}
processPayload := func(payload string) error {
if strings.TrimSpace(payload) == "[DONE]" {
Expand Down Expand Up @@ -392,9 +392,11 @@ func extractUsage(usage *providertypes.Usage, response *streamResponse) {
return
}
*usage = providertypes.Usage{
InputTokens: response.Usage.InputTokens,
OutputTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.TotalTokens,
InputTokens: response.Usage.InputTokens,
OutputTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.TotalTokens,
InputObserved: true,
OutputObserved: true,
}
}

Expand Down Expand Up @@ -435,3 +437,12 @@ func resolveFinishReason(eventType string, response *streamResponse) string {
return ""
}
}

// doneUsagePtr selects the usage payload for a message_done event based on
// whether usage was observed. It returns nil when neither input nor output
// usage was observed; otherwise it returns a pointer to a copy of usage that is
// independent of the caller's value.
func doneUsagePtr(usage providertypes.Usage) *providertypes.Usage {
	if !usage.InputObserved && !usage.OutputObserved {
		return nil
	}
	// usage is received by value, so it is already a private copy. Returning
	// its address avoids a second copy and a local that shadows the builtin
	// copy function.
	return &usage
}
Loading
Loading