From 62e569bcccd37545547313d7b97899b377ce581d Mon Sep 17 00:00:00 2001 From: xgopilot Date: Thu, 23 Apr 2026 05:08:55 +0000 Subject: [PATCH] fix(runtime,anthropic): tighten estimate failure gate and preserve zero usage observation Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/provider/anthropic/provider.go | 8 +-- internal/provider/anthropic/provider_test.go | 64 ++++++++++++++++++++ internal/runtime/run.go | 9 +++ internal/runtime/runtime_test.go | 46 +++++++++++++- 4 files changed, 122 insertions(+), 5 deletions(-) diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index 4c594286..777fcba6 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -96,11 +96,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque event := streamReader.Current() switch variant := event.AsAny().(type) { case anthropic.MessageStartEvent: - if variant.Message.Usage.InputTokens > 0 { + if variant.Message.Usage.JSON.InputTokens.Valid() { usage.InputTokens = int(variant.Message.Usage.InputTokens) usage.InputObserved = true } - if variant.Message.Usage.OutputTokens > 0 { + if variant.Message.Usage.JSON.OutputTokens.Valid() { usage.OutputTokens = int(variant.Message.Usage.OutputTokens) usage.OutputObserved = true } @@ -167,11 +167,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque if reason := strings.TrimSpace(string(variant.Delta.StopReason)); reason != "" { finishReason = reason } - if variant.Usage.OutputTokens > 0 { + if variant.Usage.JSON.OutputTokens.Valid() { usage.OutputTokens = int(variant.Usage.OutputTokens) usage.OutputObserved = true } - if variant.Usage.InputTokens > 0 { + if variant.Usage.JSON.InputTokens.Valid() { usage.InputTokens = int(variant.Usage.InputTokens) usage.InputObserved = true } diff --git a/internal/provider/anthropic/provider_test.go b/internal/provider/anthropic/provider_test.go index a7219d20..5e8ed159 100644 --- a/internal/provider/anthropic/provider_test.go +++ b/internal/provider/anthropic/provider_test.go @@ -102,6 +102,70 @@ func TestProviderGenerate(t *testing.T) { } } +func TestProviderGenerateMarksZeroUsageAsObserved(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "event: message_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n") + _, _ = fmt.Fprint(w, "event: content_block_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"ok\"}}\n\n") + _, _ = fmt.Fprint(w, "event: message_delta\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}\n\n") + _, _ = fmt.Fprint(w, "event: message_stop\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_stop\"}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverAnthropic, + BaseURL: server.URL, + DefaultModel: "claude-3-7-sonnet", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + + drained := drainEvents(events) + var done *providertypes.MessageDonePayload + for i := range drained { + if drained[i].Type != providertypes.StreamEventMessageDone { + continue + } + payload, payloadErr := drained[i].MessageDoneValue() + if payloadErr != nil { + t.Fatalf("MessageDoneValue() error = %v", payloadErr) + } + done = &payload + break + } + if done == nil { + t.Fatalf("expected message_done event, got %+v", drained) + } + if done.Usage == nil { + t.Fatalf("expected usage to be present when zero usage is observed") + } + if !done.Usage.InputObserved || !done.Usage.OutputObserved { + t.Fatalf("expected observed flags true, got %+v", done.Usage) + } + if done.Usage.InputTokens != 0 || done.Usage.OutputTokens != 0 || done.Usage.TotalTokens != 0 { + t.Fatalf("expected zero usage, got %+v", done.Usage) + } +} + func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) { t.Parallel() diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 7c05dee5..55bfe45d 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -542,6 +542,9 @@ func (s *Service) evaluateTurnBudget( if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return controlplane.TurnBudgetDecision{}, err } + if !shouldBypassEstimateFailure(err) { + return controlplane.TurnBudgetDecision{}, fmt.Errorf("runtime: estimate input tokens: %w", err) + } s.emitRunScoped(ctx, EventBudgetEstimateFailed, state, newBudgetEstimateFailedPayload(snapshot.ID, err)) decision := controlplane.TurnBudgetDecision{ ID: snapshot.ID, @@ -563,6 +566,12 @@ func (s *Service) evaluateTurnBudget( return decision, nil } +// shouldBypassEstimateFailure 判断估算失败是否允许降级放行,仅对可恢复 provider 错误放行。 +func shouldBypassEstimateFailure(err error) bool { + var providerErr *provider.ProviderError + return errors.As(err, &providerErr) && providerErr.Retryable +} + // reconcileLedger 根据 observed usage 或发送前 estimate 生成本轮账本写入结果。 func (s *Service) reconcileLedger( state *runState, diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index df63dda9..2c6b8902 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -4727,7 +4727,12 @@ func TestServiceRunBypassesBudgetGateWhenEstimateFails(t *testing.T) { estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { _ = ctx _ = req - return providertypes.BudgetEstimate{}, errors.New("estimate unavailable") + return providertypes.BudgetEstimate{}, &provider.ProviderError{ + StatusCode: 503, + Code: provider.ErrorCodeServer, + Message: "estimate unavailable", + Retryable: true, + } }, responses: []scriptedResponse{ { @@ -4799,6 +4804,45 @@ func TestServiceRunBypassesBudgetGateWhenEstimateFails(t *testing.T) { assertNoEventType(t, events, EventError) } +func TestServiceRunFailsWhenEstimateFailsWithDeterministicError(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.PromptBudget = 10 + cfg.Context.Budget.FallbackPromptBudget = 10 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{}, errors.New("invalid provider config") + }, + } + + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-estimate-failed-hard-stop", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }) + if err == nil || !containsError(err, "estimate input tokens") { + t.Fatalf("expected estimate input tokens error, got %v", err) + } + if scripted.callCount != 0 { + t.Fatalf("expected provider Generate not to be called, got %d calls", scripted.callCount) + } + + events := collectRuntimeEvents(service.Events()) + assertNoEventType(t, events, EventBudgetEstimateFailed) + assertNoEventType(t, events, EventBudgetChecked) +} + func TestServiceRunFailsWhenEstimateContextCanceled(t *testing.T) { t.Parallel()