Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/provider/anthropic/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque
event := streamReader.Current()
switch variant := event.AsAny().(type) {
case anthropic.MessageStartEvent:
if variant.Message.Usage.InputTokens > 0 {
if variant.Message.Usage.JSON.InputTokens.Valid() {
usage.InputTokens = int(variant.Message.Usage.InputTokens)
usage.InputObserved = true
}
if variant.Message.Usage.OutputTokens > 0 {
if variant.Message.Usage.JSON.OutputTokens.Valid() {
usage.OutputTokens = int(variant.Message.Usage.OutputTokens)
usage.OutputObserved = true
}
Expand Down Expand Up @@ -167,11 +167,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque
if reason := strings.TrimSpace(string(variant.Delta.StopReason)); reason != "" {
finishReason = reason
}
if variant.Usage.OutputTokens > 0 {
if variant.Usage.JSON.OutputTokens.Valid() {
usage.OutputTokens = int(variant.Usage.OutputTokens)
usage.OutputObserved = true
}
if variant.Usage.InputTokens > 0 {
if variant.Usage.JSON.InputTokens.Valid() {
usage.InputTokens = int(variant.Usage.InputTokens)
usage.InputObserved = true
}
Expand Down
64 changes: 64 additions & 0 deletions internal/provider/anthropic/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,70 @@ func TestProviderGenerate(t *testing.T) {
}
}

func TestProviderGenerateMarksZeroUsageAsObserved(t *testing.T) {
t.Parallel()

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = fmt.Fprint(w, "event: message_start\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n")
_, _ = fmt.Fprint(w, "event: content_block_start\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"ok\"}}\n\n")
_, _ = fmt.Fprint(w, "event: message_delta\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}\n\n")
_, _ = fmt.Fprint(w, "event: message_stop\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"message_stop\"}\n\n")
}))
defer server.Close()

p, err := New(provider.RuntimeConfig{
Driver: provider.DriverAnthropic,
BaseURL: server.URL,
DefaultModel: "claude-3-7-sonnet",
APIKeyEnv: "ANTHROPIC_TEST_KEY",
APIKeyResolver: provider.StaticAPIKeyResolver("test-key"),
})
if err != nil {
t.Fatalf("New() error = %v", err)
}

events := make(chan providertypes.StreamEvent, 8)
if err := p.Generate(context.Background(), providertypes.GenerateRequest{
Messages: []providertypes.Message{{
Role: providertypes.RoleUser,
Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")},
}},
}, events); err != nil {
t.Fatalf("Generate() error = %v", err)
}

drained := drainEvents(events)
var done *providertypes.MessageDonePayload
for i := range drained {
if drained[i].Type != providertypes.StreamEventMessageDone {
continue
}
payload, payloadErr := drained[i].MessageDoneValue()
if payloadErr != nil {
t.Fatalf("MessageDoneValue() error = %v", payloadErr)
}
done = &payload
break
}
if done == nil {
t.Fatalf("expected message_done event, got %+v", drained)
}
if done.Usage == nil {
t.Fatalf("expected usage to be present when zero usage is observed")
}
if !done.Usage.InputObserved || !done.Usage.OutputObserved {
t.Fatalf("expected observed flags true, got %+v", done.Usage)
}
if done.Usage.InputTokens != 0 || done.Usage.OutputTokens != 0 || done.Usage.TotalTokens != 0 {
t.Fatalf("expected zero usage, got %+v", done.Usage)
}
}

func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) {
t.Parallel()

Expand Down
9 changes: 9 additions & 0 deletions internal/runtime/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,9 @@ func (s *Service) evaluateTurnBudget(
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return controlplane.TurnBudgetDecision{}, err
}
if !shouldBypassEstimateFailure(err) {
return controlplane.TurnBudgetDecision{}, fmt.Errorf("runtime: estimate input tokens: %w", err)
}
s.emitRunScoped(ctx, EventBudgetEstimateFailed, state, newBudgetEstimateFailedPayload(snapshot.ID, err))
decision := controlplane.TurnBudgetDecision{
ID: snapshot.ID,
Expand All @@ -563,6 +566,12 @@ func (s *Service) evaluateTurnBudget(
return decision, nil
}

// shouldBypassEstimateFailure 判断估算失败是否允许降级放行,仅对可恢复 provider 错误放行。
func shouldBypassEstimateFailure(err error) bool {
var providerErr *provider.ProviderError
return errors.As(err, &providerErr) && providerErr.Retryable
}

// reconcileLedger 根据 observed usage 或发送前 estimate 生成本轮账本写入结果。
func (s *Service) reconcileLedger(
state *runState,
Expand Down
46 changes: 45 additions & 1 deletion internal/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4727,7 +4727,12 @@ func TestServiceRunBypassesBudgetGateWhenEstimateFails(t *testing.T) {
estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) {
_ = ctx
_ = req
return providertypes.BudgetEstimate{}, errors.New("estimate unavailable")
return providertypes.BudgetEstimate{}, &provider.ProviderError{
StatusCode: 503,
Code: provider.ErrorCodeServer,
Message: "estimate unavailable",
Retryable: true,
}
},
responses: []scriptedResponse{
{
Expand Down Expand Up @@ -4799,6 +4804,45 @@ func TestServiceRunBypassesBudgetGateWhenEstimateFails(t *testing.T) {
assertNoEventType(t, events, EventError)
}

func TestServiceRunFailsWhenEstimateFailsWithDeterministicError(t *testing.T) {
t.Parallel()

manager := newRuntimeConfigManager(t)
if err := manager.Update(context.Background(), func(cfg *config.Config) error {
cfg.Context.Budget.PromptBudget = 10
cfg.Context.Budget.FallbackPromptBudget = 10
return nil
}); err != nil {
t.Fatalf("update config: %v", err)
}

store := newMemoryStore()
registry := tools.NewRegistry()
scripted := &scriptedProvider{
estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) {
_ = ctx
_ = req
return providertypes.BudgetEstimate{}, errors.New("invalid provider config")
},
}

service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{})
err := service.Run(context.Background(), UserInput{
RunID: "run-budget-estimate-failed-hard-stop",
Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")},
})
if err == nil || !containsError(err, "estimate input tokens") {
t.Fatalf("expected estimate input tokens error, got %v", err)
}
if scripted.callCount != 0 {
t.Fatalf("expected provider Generate not to be called, got %d calls", scripted.callCount)
}

events := collectRuntimeEvents(service.Events())
assertNoEventType(t, events, EventBudgetEstimateFailed)
assertNoEventType(t, events, EventBudgetChecked)
}

func TestServiceRunFailsWhenEstimateContextCanceled(t *testing.T) {
t.Parallel()

Expand Down
Loading