From 1e48b59e38dc25e6aa564d2e249ec31ea4470f86 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Sun, 26 Apr 2026 08:33:29 +0000 Subject: [PATCH] test(provider): improve generate attempt and timeout normalization coverage Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/provider/constants_test.go | 35 +++- internal/provider/generate_attempt_test.go | 199 +++++++++++++++++++++ 2 files changed, 233 insertions(+), 1 deletion(-) diff --git a/internal/provider/constants_test.go b/internal/provider/constants_test.go index 5d82b4df..41452693 100644 --- a/internal/provider/constants_test.go +++ b/internal/provider/constants_test.go @@ -1,6 +1,9 @@ package provider -import "testing" +import ( + "testing" + "time" +) func TestNormalizeGenerateMaxRetries(t *testing.T) { t.Parallel() @@ -26,3 +29,33 @@ func TestNormalizeGenerateMaxRetries(t *testing.T) { }) } } + +func TestNormalizeGenerateStartTimeout(t *testing.T) { + t.Parallel() + + if got := NormalizeGenerateStartTimeout(0); got != DefaultGenerateStartTimeout { + t.Fatalf("NormalizeGenerateStartTimeout(0) = %s, want %s", got, DefaultGenerateStartTimeout) + } + if got := NormalizeGenerateStartTimeout(-time.Second); got != DefaultGenerateStartTimeout { + t.Fatalf("NormalizeGenerateStartTimeout(-1s) = %s, want %s", got, DefaultGenerateStartTimeout) + } + want := 3 * time.Second + if got := NormalizeGenerateStartTimeout(want); got != want { + t.Fatalf("NormalizeGenerateStartTimeout(3s) = %s, want %s", got, want) + } +} + +func TestNormalizeGenerateIdleTimeout(t *testing.T) { + t.Parallel() + + if got := NormalizeGenerateIdleTimeout(0); got != DefaultGenerateIdleTimeout { + t.Fatalf("NormalizeGenerateIdleTimeout(0) = %s, want %s", got, DefaultGenerateIdleTimeout) + } + if got := NormalizeGenerateIdleTimeout(-time.Second); got != DefaultGenerateIdleTimeout { + t.Fatalf("NormalizeGenerateIdleTimeout(-1s) = %s, want %s", got, DefaultGenerateIdleTimeout) + } + want := 4 * time.Second + if got := NormalizeGenerateIdleTimeout(want); got != want { + t.Fatalf("NormalizeGenerateIdleTimeout(4s) = %s, want %s", got, want) + } +} diff --git a/internal/provider/generate_attempt_test.go b/internal/provider/generate_attempt_test.go index 462abe85..0f386c46 100644 --- a/internal/provider/generate_attempt_test.go +++ b/internal/provider/generate_attempt_test.go @@ -3,6 +3,7 @@ package provider import ( "context" "errors" + "sync/atomic" "testing" "time" @@ -257,6 +258,204 @@ func TestRunGenerateWithRetryUsingTreatsMessageDoneAsCompletedState(t *testing.T } } +func TestRunGenerateWithRetryUsesDefaultRunner(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 0, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 8) + + err := RunGenerateWithRetry( + context.Background(), + cfg, + events, + func(ctx context.Context, attemptEvents chan<- providertypes.StreamEvent) error { + if emitErr := EmitTextDelta(ctx, attemptEvents, "ok"); emitErr != nil { + return emitErr + } + return EmitMessageDone(ctx, attemptEvents, "stop", nil) + }, + ) + if err != nil { + t.Fatalf("RunGenerateWithRetry() error = %v", err) + } + + drained := drainAttemptEvents(events) + if len(drained) != 2 { + t.Fatalf("expected two forwarded events, got %d", len(drained)) + } +} + +func TestRunGenerateWithRetryUsingReturnsRetryWaitError(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 1, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 4) + waitErr := errors.New("wait failed") + attempts := 0 + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return time.Millisecond }, + func(context.Context, time.Duration) error { return waitErr }, + func(context.Context, chan<- providertypes.StreamEvent) error { + attempts++ + return ErrStreamInterrupted + }, + ) + if !errors.Is(err, waitErr) { + t.Fatalf("expected retry wait error, got %v", err) + } + if attempts != 1 { + t.Fatalf("expected only first attempt before wait failure, got %d", attempts) + } +} + +func TestRunGenerateWithRetryUsingReturnsLastErrorAfterExhaustedRetries(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{ + GenerateMaxRetries: 1, + GenerateStartTimeout: time.Second, + GenerateIdleTimeout: time.Second, + } + events := make(chan providertypes.StreamEvent, 4) + firstErr := NewProviderErrorFromStatus(500, "first") + lastErr := NewProviderErrorFromStatus(500, "last") + attempts := 0 + + err := RunGenerateWithRetryUsing( + context.Background(), + cfg, + events, + func(int) time.Duration { return 0 }, + func(context.Context, time.Duration) error { return nil }, + func(context.Context, chan<- providertypes.StreamEvent) error { + attempts++ + if attempts == 1 { + return firstErr + } + return lastErr + }, + ) + if !errors.Is(err, lastErr) { + t.Fatalf("expected last retryable error, got %v", err) + } + if attempts != 2 { + t.Fatalf("expected two attempts after exhausting retries, got %d", attempts) + } +} + +func TestWaitForRetryHonorsContextCancel(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := waitForRetry(ctx, time.Second) + if !errors.Is(err, context.Canceled) { + t.Fatalf("waitForRetry() error = %v, want context canceled", err) + } +} + +func TestWaitForRetryReturnsNilOnNonPositiveWait(t *testing.T) { + t.Parallel() + + if err := waitForRetry(context.Background(), 0); err != nil { + t.Fatalf("waitForRetry() error = %v", err) + } +} + +func TestStopAndResetTimerHelpers(t *testing.T) { + t.Parallel() + + stopTimer(nil) + resetTimer(nil, time.Millisecond) + + timer := time.NewTimer(time.Millisecond) + time.Sleep(5 * time.Millisecond) + stopTimer(timer) + select { + case <-timer.C: + t.Fatal("expected stopTimer to drain timer channel") + default: + } + + resetTimer(timer, 30*time.Millisecond) + select { + case <-timer.C: + t.Fatal("expected resetTimer to apply new wait") + case <-time.After(10 * time.Millisecond): + } + select { + case <-timer.C: + case <-time.After(100 * time.Millisecond): + t.Fatal("expected reset timer to fire") + } +} + +func TestUpdateGenerateAttemptPhaseTransitions(t *testing.T) { + t.Parallel() + + var phase atomic.Uint32 + if got := updateGenerateAttemptPhase( + providertypes.StreamEvent{Type: providertypes.StreamEventType("unknown")}, + &phase, + ); got != generateAttemptPhaseWaitingFirstPayload { + t.Fatalf("unexpected initial phase = %v", got) + } + if got := updateGenerateAttemptPhase( + providertypes.StreamEvent{Type: providertypes.StreamEventToolCallStart}, + &phase, + ); got != generateAttemptPhaseStreaming { + t.Fatalf("expected streaming phase, got %v", got) + } + if got := updateGenerateAttemptPhase( + providertypes.StreamEvent{Type: providertypes.StreamEventMessageDone}, + &phase, + ); got != generateAttemptPhaseCompleted { + t.Fatalf("expected completed phase, got %v", got) + } + if got := updateGenerateAttemptPhase( + providertypes.StreamEvent{Type: providertypes.StreamEventTextDelta}, + &phase, + ); got != generateAttemptPhaseCompleted { + t.Fatalf("expected completed phase to remain terminal, got %v", got) + } +} + +func TestIsEffectiveGeneratePayloadEvent(t *testing.T) { + t.Parallel() + + cases := []struct { + eventType providertypes.StreamEventType + want bool + }{ + {eventType: providertypes.StreamEventTextDelta, want: true}, + {eventType: providertypes.StreamEventToolCallStart, want: true}, + {eventType: providertypes.StreamEventToolCallDelta, want: true}, + {eventType: providertypes.StreamEventMessageDone, want: false}, + } + for _, tc := range cases { + tc := tc + t.Run(string(tc.eventType), func(t *testing.T) { + t.Parallel() + if got := IsEffectiveGeneratePayloadEvent(providertypes.StreamEvent{Type: tc.eventType}); got != tc.want { + t.Fatalf("IsEffectiveGeneratePayloadEvent(%s) = %v, want %v", tc.eventType, got, tc.want) + } + }) + } +} + func drainAttemptEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { out := make([]providertypes.StreamEvent, 0, len(events)) for {