From eb360d310ad91c9e256f7f289a6c4d0fafdeccba Mon Sep 17 00:00:00 2001 From: xgopilot Date: Thu, 23 Apr 2026 04:38:20 +0000 Subject: [PATCH] fix(runtime): reconcile token ledger with partial usage observation Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/provider/anthropic/provider.go | 9 ++- internal/provider/anthropic/provider_test.go | 59 ++++++++++++++++ internal/provider/gemini/provider.go | 5 ++ internal/provider/gemini/provider_test.go | 55 +++++++++++++++ .../openaicompat/chatcompletions/adapter.go | 40 +++++++---- .../chatcompletions/adapter_test.go | 33 +++++++++ .../openaicompat/responses/adapter.go | 19 +++-- .../openaicompat/responses/adapter_test.go | 33 +++++++++ internal/provider/types/usage.go | 8 ++- internal/runtime/budget_models.go | 7 +- internal/runtime/provider_stream.go | 14 ++-- internal/runtime/run.go | 24 ++++--- .../runtime/runtime_internal_helpers_test.go | 69 +++++++++++++++++++ internal/runtime/runtime_test.go | 8 ++- 14 files changed, 343 insertions(+), 40 deletions(-) diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index 1e176a16..4c594286 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -98,9 +98,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque case anthropic.MessageStartEvent: if variant.Message.Usage.InputTokens > 0 { usage.InputTokens = int(variant.Message.Usage.InputTokens) + usage.InputObserved = true } if variant.Message.Usage.OutputTokens > 0 { usage.OutputTokens = int(variant.Message.Usage.OutputTokens) + usage.OutputObserved = true } case anthropic.ContentBlockStartEvent: switch block := variant.ContentBlock.AsAny().(type) { @@ -167,9 +169,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque } if variant.Usage.OutputTokens > 0 { usage.OutputTokens = int(variant.Usage.OutputTokens) + usage.OutputObserved = true } if variant.Usage.InputTokens > 0 { usage.InputTokens = int(variant.Usage.InputTokens) + usage.InputObserved = true } } } @@ -193,9 +197,12 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque return fmt.Errorf("%sinvalid tool_use stream at index %d: missing tool name", errorPrefix, index) } } - if usage.TotalTokens <= 0 { + if usage.TotalTokens <= 0 && (usage.InputObserved || usage.OutputObserved) { usage.TotalTokens = usage.InputTokens + usage.OutputTokens } + if !usage.InputObserved && !usage.OutputObserved { + return provider.EmitMessageDone(ctx, events, finishReason, nil) + } return provider.EmitMessageDone(ctx, events, finishReason, &usage) } diff --git a/internal/provider/anthropic/provider_test.go b/internal/provider/anthropic/provider_test.go index 63a2ecdd..a7219d20 100644 --- a/internal/provider/anthropic/provider_test.go +++ b/internal/provider/anthropic/provider_test.go @@ -92,6 +92,9 @@ func TestProviderGenerate(t *testing.T) { if payload.Usage == nil || payload.Usage.TotalTokens != 10 { t.Fatalf("expected usage total tokens 10, got %+v", payload.Usage) } + if !payload.Usage.InputObserved || !payload.Usage.OutputObserved { + t.Fatalf("expected usage observed flags true, got %+v", payload.Usage) + } } } if !foundText || !foundToolStart || !foundToolDelta || !foundDone { @@ -99,6 +102,62 @@ func TestProviderGenerate(t *testing.T) { } } +func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "event: content_block_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"Hello\"}}\n\n") + _, _ = fmt.Fprint(w, "event: message_delta\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"}}\n\n") + _, _ = fmt.Fprint(w, "event: message_stop\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_stop\"}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverAnthropic, + BaseURL: server.URL, + DefaultModel: "claude-3-7-sonnet", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + + drained := drainEvents(events) + var done *providertypes.MessageDonePayload + for i := range drained { + if drained[i].Type != providertypes.StreamEventMessageDone { + continue + } + payload, payloadErr := drained[i].MessageDoneValue() + if payloadErr != nil { + t.Fatalf("MessageDoneValue() error = %v", payloadErr) + } + done = &payload + break + } + if done == nil { + t.Fatalf("expected message_done event, got %+v", drained) + } + if done.Usage != nil { + t.Fatalf("expected nil usage when provider does not report usage, got %+v", done.Usage) + } +} + func TestNewAcceptsCustomChatEndpointPath(t *testing.T) { t.Parallel() diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index 5a6ea83e..017ce282 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -157,6 +157,9 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque if !hasPayload { return fmt.Errorf("%w: empty gemini stream payload", provider.ErrStreamInterrupted) } + if !usage.InputObserved && !usage.OutputObserved { + return provider.EmitMessageDone(ctx, events, finishReason, nil) + } return provider.EmitMessageDone(ctx, events, finishReason, &usage) } @@ -209,6 +212,8 @@ func extractUsage(usage *providertypes.Usage, raw *genai.GenerateContentResponse usage.InputTokens = int(raw.PromptTokenCount) usage.OutputTokens = int(raw.CandidatesTokenCount) usage.TotalTokens = int(raw.TotalTokenCount) + usage.InputObserved = true + usage.OutputObserved = true } // encodeArguments 将函数参数对象编码为 JSON 字符串,供统一 tool_call_delta 事件复用。 diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index 3e8296bb..8109fe8c 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -80,6 +80,9 @@ func TestProviderGenerate(t *testing.T) { if payload.Usage == nil || payload.Usage.TotalTokens != 7 { t.Fatalf("expected usage total tokens 7, got %+v", payload.Usage) } + if !payload.Usage.InputObserved || !payload.Usage.OutputObserved { + t.Fatalf("expected usage observed flags true, got %+v", payload.Usage) + } if payload.FinishReason != "stop" { t.Fatalf("expected finish reason stop, got %q", payload.FinishReason) } @@ -90,6 +93,58 @@ func TestProviderGenerate(t *testing.T) { } } +func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"content\":{\"parts\":[{\"text\":\"Hello \"}]}}]}\n\n") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"finishReason\":\"STOP\",\"content\":{\"parts\":[{\"text\":\"done\"}]}}]}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + + drained := drainEvents(events) + var done *providertypes.MessageDonePayload + for i := range drained { + if drained[i].Type != providertypes.StreamEventMessageDone { + continue + } + payload, payloadErr := drained[i].MessageDoneValue() + if payloadErr != nil { + t.Fatalf("MessageDoneValue() error = %v", payloadErr) + } + done = &payload + break + } + if done == nil { + t.Fatalf("expected message_done event, got %+v", drained) + } + if done.Usage != nil { + t.Fatalf("expected nil usage when provider does not report usage, got %+v", done.Usage) + } +} + func TestNewAcceptsCustomChatEndpointPath(t *testing.T) { t.Parallel() diff --git a/internal/provider/openaicompat/chatcompletions/adapter.go b/internal/provider/openaicompat/chatcompletions/adapter.go index e1fe0fe7..ba3ffc5c 100644 --- a/internal/provider/openaicompat/chatcompletions/adapter.go +++ b/internal/provider/openaicompat/chatcompletions/adapter.go @@ -72,6 +72,9 @@ func EmitFromSDKStream( return fmt.Errorf("SDK stream error: %w", err) } + if !usage.InputObserved && !usage.OutputObserved { + return provider.EmitMessageDone(ctx, events, finishReason, nil) + } return provider.EmitMessageDone(ctx, events, finishReason, &usage) } @@ -183,13 +186,13 @@ func ConsumeStream( if flushErr := flushDataLines(); flushErr != nil { return flushErr } - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } if flushErr := flushDataLines(); flushErr != nil { return flushErr } if strings.TrimSpace(finishReason) != "" { - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } return fmt.Errorf("%w: %w", provider.ErrStreamInterrupted, err) } @@ -206,7 +209,7 @@ func ConsumeStream( return flushErr } done = true - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } else { dataLines = append(dataLines, data) } @@ -215,7 +218,7 @@ func ConsumeStream( return flushErr } if done { - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } default: if len(dataLines) == 0 { @@ -225,7 +228,7 @@ func ConsumeStream( return flushErr } if done { - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } } @@ -234,7 +237,7 @@ func ConsumeStream( return flushErr } if done || strings.TrimSpace(finishReason) != "" { - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } return fmt.Errorf("%w: missing [DONE] marker before EOF", provider.ErrStreamInterrupted) } @@ -247,19 +250,32 @@ func extractLegacyStreamUsage(usage *providertypes.Usage, raw *streamUsage) { return } *usage = providertypes.Usage{ - InputTokens: raw.PromptTokens, - OutputTokens: raw.CompletionTokens, - TotalTokens: raw.TotalTokens, + InputTokens: raw.PromptTokens, + OutputTokens: raw.CompletionTokens, + TotalTokens: raw.TotalTokens, + InputObserved: true, + OutputObserved: true, } } // extractStreamUsage 将 OpenAI usage 覆盖到统一 token 统计。 func extractStreamUsage(usage *providertypes.Usage, raw openai.CompletionUsage) { *usage = providertypes.Usage{ - InputTokens: int(raw.PromptTokens), - OutputTokens: int(raw.CompletionTokens), - TotalTokens: int(raw.TotalTokens), + InputTokens: int(raw.PromptTokens), + OutputTokens: int(raw.CompletionTokens), + TotalTokens: int(raw.TotalTokens), + InputObserved: true, + OutputObserved: true, + } +} + +// doneUsagePtr 在 message_done 事件中按 usage 观测状态返回 payload,未观测时返回 nil。 +func doneUsagePtr(usage providertypes.Usage) *providertypes.Usage { + if !usage.InputObserved && !usage.OutputObserved { + return nil } + copy := usage + return © } // mergeToolCallDeltaFromSDK 将单个 SDK tool call 增量合并到累积状态,并在必要时发出起始/增量事件。 diff --git a/internal/provider/openaicompat/chatcompletions/adapter_test.go b/internal/provider/openaicompat/chatcompletions/adapter_test.go index 5d9c2da8..d1eb195c 100644 --- a/internal/provider/openaicompat/chatcompletions/adapter_test.go +++ b/internal/provider/openaicompat/chatcompletions/adapter_test.go @@ -43,6 +43,36 @@ func TestConsumeStreamSupportsWeakSSEFormat(t *testing.T) { if done.Usage == nil || done.Usage.TotalTokens != 3 { t.Fatalf("unexpected usage: %+v", done.Usage) } + if !done.Usage.InputObserved || !done.Usage.OutputObserved { + t.Fatalf("expected usage observed flags to be true, got %+v", done.Usage) + } +} + +func TestConsumeStreamEmitsNilUsageWhenProviderDidNotReturnUsage(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"choices":[{"delta":{"content":"ok"},"finish_reason":"stop"}]}`, + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 4) + if err := ConsumeStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("ConsumeStream() error = %v", err) + } + + drained := drainEvents(events) + if len(drained) != 2 { + t.Fatalf("expected 2 events, got %d", len(drained)) + } + done, err := drained[1].MessageDoneValue() + if err != nil { + t.Fatalf("expected message done, got err=%v", err) + } + if done.Usage != nil { + t.Fatalf("expected nil usage when stream carries no usage, got %+v", done.Usage) + } } func TestConsumeStreamParsesMultilineDataEvent(t *testing.T) { @@ -197,6 +227,9 @@ func TestEmitFromSDKStream(t *testing.T) { if done.Usage == nil || done.Usage.TotalTokens != 3 { t.Fatalf("expected usage total tokens 3, got %+v", done.Usage) } + if !done.Usage.InputObserved || !done.Usage.OutputObserved { + t.Fatalf("expected usage observed flags to be true, got %+v", done.Usage) + } } func TestEmitFromSDKStreamErrors(t *testing.T) { diff --git a/internal/provider/openaicompat/responses/adapter.go b/internal/provider/openaicompat/responses/adapter.go index f146207d..6a2c0cf5 100644 --- a/internal/provider/openaicompat/responses/adapter.go +++ b/internal/provider/openaicompat/responses/adapter.go @@ -36,7 +36,7 @@ func EmitFromStream( if reason == "" { reason = "stop" } - return provider.EmitMessageDone(ctx, events, reason, &usage) + return provider.EmitMessageDone(ctx, events, reason, doneUsagePtr(usage)) } processPayload := func(payload string) error { if strings.TrimSpace(payload) == "[DONE]" { @@ -392,9 +392,11 @@ func extractUsage(usage *providertypes.Usage, response *streamResponse) { return } *usage = providertypes.Usage{ - InputTokens: response.Usage.InputTokens, - OutputTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.TotalTokens, + InputTokens: response.Usage.InputTokens, + OutputTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.TotalTokens, + InputObserved: true, + OutputObserved: true, } } @@ -435,3 +437,12 @@ func resolveFinishReason(eventType string, response *streamResponse) string { return "" } } + +// doneUsagePtr 在 message_done 事件中按 usage 观测状态返回 payload,未观测时返回 nil。 +func doneUsagePtr(usage providertypes.Usage) *providertypes.Usage { + if !usage.InputObserved && !usage.OutputObserved { + return nil + } + copy := usage + return © +} diff --git a/internal/provider/openaicompat/responses/adapter_test.go b/internal/provider/openaicompat/responses/adapter_test.go index a75cd0b3..f922b752 100644 --- a/internal/provider/openaicompat/responses/adapter_test.go +++ b/internal/provider/openaicompat/responses/adapter_test.go @@ -47,6 +47,39 @@ func TestEmitFromStreamSupportsMultilineSSEData(t *testing.T) { if done.Usage == nil || done.Usage.TotalTokens != 3 { t.Fatalf("unexpected usage in done event: %+v", done.Usage) } + if !done.Usage.InputObserved || !done.Usage.OutputObserved { + t.Fatalf("expected usage observed flags to be true, got %+v", done.Usage) + } +} + +func TestEmitFromStreamEmitsNilUsageWhenProviderDidNotReturnUsage(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"hello"}`, + "", + `data: {"type":"response.completed","response":{"status":"completed"}}`, + "", + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 4) + if err := EmitFromStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("EmitFromStream() error = %v", err) + } + + drained := drainResponseEvents(events) + if len(drained) != 2 { + t.Fatalf("expected 2 events, got %d (%+v)", len(drained), drained) + } + done, err := drained[1].MessageDoneValue() + if err != nil { + t.Fatalf("expected message done event, got err=%v", err) + } + if done.Usage != nil { + t.Fatalf("expected nil usage when stream carries no usage, got %+v", done.Usage) + } } func TestEmitFromStreamSupportsLongDataLine(t *testing.T) { diff --git a/internal/provider/types/usage.go b/internal/provider/types/usage.go index a605919c..c8c76e26 100644 --- a/internal/provider/types/usage.go +++ b/internal/provider/types/usage.go @@ -2,9 +2,11 @@ package types // Usage 记录本次请求的 token 使用统计。 type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + InputObserved bool `json:"input_observed"` + OutputObserved bool `json:"output_observed"` } // BudgetEstimate 描述 provider 对冻结请求输入 token 的估算结果。 diff --git a/internal/runtime/budget_models.go b/internal/runtime/budget_models.go index 6283931e..b573eeb2 100644 --- a/internal/runtime/budget_models.go +++ b/internal/runtime/budget_models.go @@ -108,13 +108,14 @@ func newTurnBudgetUsageObservation( id controlplane.TurnBudgetID, inputTokens int, outputTokens int, - observed bool, + inputObserved bool, + outputObserved bool, ) TurnBudgetUsageObservation { return TurnBudgetUsageObservation{ ID: id, InputTokens: inputTokens, OutputTokens: outputTokens, - InputObserved: observed, - OutputObserved: observed, + InputObserved: inputObserved, + OutputObserved: outputObserved, } } diff --git a/internal/runtime/provider_stream.go b/internal/runtime/provider_stream.go index b120e4a4..75d4d7ab 100644 --- a/internal/runtime/provider_stream.go +++ b/internal/runtime/provider_stream.go @@ -11,11 +11,12 @@ import ( // streamGenerateResult 统一承载一次流式生成的消息、用量与消费错误。 type streamGenerateResult struct { - message providertypes.Message - inputTokens int - outputTokens int - usagePresent bool - err error + message providertypes.Message + inputTokens int + outputTokens int + inputObserved bool + outputObserved bool + err error } // generateStreamingMessage 负责执行一次基于流式事件的生成调用,并收敛最终 assistant 消息与 usage。 @@ -41,7 +42,8 @@ func generateStreamingMessage( if payload.Usage != nil { outcome.inputTokens = payload.Usage.InputTokens outcome.outputTokens = payload.Usage.OutputTokens - outcome.usagePresent = true + outcome.inputObserved = payload.Usage.InputObserved + outcome.outputObserved = payload.Usage.OutputObserved } if userOnMessageDone != nil { userOnMessageDone(payload) diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 17197798..55a86ab5 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -445,7 +445,8 @@ func (s *Service) callProviderWithRetry( snapshot.ID, streamOutcome.inputTokens, streamOutcome.outputTokens, - streamOutcome.usagePresent, + streamOutcome.inputObserved, + streamOutcome.outputObserved, ), }, nil } @@ -572,18 +573,23 @@ func (s *Service) reconcileLedger( return ledgerReconcileResult{}, fmt.Errorf("runtime: turn budget id mismatch between decision and usage observation") } reconciled := ledgerReconcileResult{ - inputTokens: observation.InputTokens, - inputSource: usageSourceObserved, - outputTokens: observation.OutputTokens, - outputSource: usageSourceObserved, + inputSource: usageSourceUnknown, + outputSource: usageSourceUnknown, + } + if observation.InputObserved { + reconciled.inputTokens = observation.InputTokens + reconciled.inputSource = usageSourceObserved + } else { + reconciled.inputTokens = decision.EstimatedInputTokens + reconciled.inputSource = usageSourceEstimated + } + if observation.OutputObserved { + reconciled.outputTokens = observation.OutputTokens + reconciled.outputSource = usageSourceObserved } if observation.InputObserved && observation.OutputObserved { return reconciled, nil } - reconciled.inputTokens = decision.EstimatedInputTokens - reconciled.inputSource = usageSourceEstimated - reconciled.outputTokens = 0 - reconciled.outputSource = usageSourceUnknown reconciled.hasUnknownUsage = true if state != nil { state.session.HasUnknownUsage = true diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index 4a31e7c8..375a79d1 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -11,6 +11,7 @@ import ( "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/approval" + "neo-code/internal/runtime/controlplane" agentsession "neo-code/internal/session" "neo-code/internal/tools" ) @@ -625,6 +626,74 @@ func TestEmitTokenUsageSkipsZeroUsage(t *testing.T) { } } +func TestReconcileLedgerSupportsPartialObservation(t *testing.T) { + t.Parallel() + + service := &Service{} + state := &runState{session: newRuntimeSession("session-partial-observed")} + id := controlplane.TurnBudgetID{AttemptSeq: 2, RequestHash: "hash-partial-observed"} + decision := controlplane.TurnBudgetDecision{ + ID: id, + EstimatedInputTokens: 37, + } + observation := TurnBudgetUsageObservation{ + ID: id, + InputTokens: 13, + OutputTokens: 0, + InputObserved: true, + OutputObserved: false, + } + + result, err := service.reconcileLedger(state, decision, observation) + if err != nil { + t.Fatalf("reconcileLedger() error = %v", err) + } + if result.inputTokens != 13 || result.inputSource != usageSourceObserved { + t.Fatalf("expected observed input reconciliation, got %+v", result) + } + if result.outputTokens != 0 || result.outputSource != usageSourceUnknown { + t.Fatalf("expected unknown output reconciliation, got %+v", result) + } + if !result.hasUnknownUsage { + t.Fatalf("expected hasUnknownUsage=true for partial observation") + } + if !state.session.HasUnknownUsage || !state.hasUnknownUsage { + t.Fatalf("expected unknown usage flag to propagate to run state") + } +} + +func TestReconcileLedgerUsesEstimateWhenInputNotObserved(t *testing.T) { + t.Parallel() + + service := &Service{} + id := controlplane.TurnBudgetID{AttemptSeq: 3, RequestHash: "hash-no-input-observed"} + decision := controlplane.TurnBudgetDecision{ + ID: id, + EstimatedInputTokens: 41, + } + observation := TurnBudgetUsageObservation{ + ID: id, + InputTokens: 0, + OutputTokens: 7, + InputObserved: false, + OutputObserved: true, + } + + result, err := service.reconcileLedger(nil, decision, observation) + if err != nil { + t.Fatalf("reconcileLedger() error = %v", err) + } + if result.inputTokens != 41 || result.inputSource != usageSourceEstimated { + t.Fatalf("expected estimated input reconciliation, got %+v", result) + } + if result.outputTokens != 7 || result.outputSource != usageSourceObserved { + t.Fatalf("expected observed output reconciliation, got %+v", result) + } + if !result.hasUnknownUsage { + t.Fatalf("expected hasUnknownUsage=true when any side is unobserved") + } +} + func TestExecuteAssistantToolCallsFillsErrorContent(t *testing.T) { t.Parallel() diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 3d6fea52..df63dda9 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -3823,6 +3823,8 @@ func TestServiceRunPersistsAndRestoresTokenUsage(t *testing.T) { usage.InputTokens = 25 usage.OutputTokens = 10 } + usage.InputObserved = true + usage.OutputObserved = true select { case events <- providertypes.NewTextDeltaStreamEvent("assistant reply"): @@ -4917,8 +4919,10 @@ func TestTokenUsageRecordedOnMessageDone(t *testing.T) { // Create a MessageDone stream event with token usage messageDoneEvent := providertypes.NewMessageDoneStreamEvent("stop", &providertypes.Usage{ - InputTokens: 100, - OutputTokens: 50, + InputTokens: 100, + OutputTokens: 50, + InputObserved: true, + OutputObserved: true, }) // 使用与运行时相同的流式事件处理器验证 usage 累积行为。