diff --git a/go.mod b/go.mod index 5d84bc5d..65b3f6ef 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ replace ( require ( cuelang.org/go v0.15.1 github.com/Masterminds/semver/v3 v3.4.0 + github.com/cenkalti/backoff/v5 v5.0.3 github.com/dlclark/regexp2 v1.11.0 github.com/docker/cli v27.5.1+incompatible github.com/fluxcd/pkg/oci v0.43.1 @@ -100,7 +101,6 @@ require ( github.com/aws/smithy-go v1.24.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect - github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chai2010/gettext-go v1.0.3 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect diff --git a/retry/error.go b/retry/error.go new file mode 100644 index 00000000..cd598f00 --- /dev/null +++ b/retry/error.go @@ -0,0 +1,94 @@ +package retry + +import ( + "errors" + + "github.com/cenkalti/backoff/v5" + meshkiterrors "github.com/meshery/meshkit/errors" +) + +// ErrInvalidConfig is returned when retry configuration validation fails. +// Use errors.Is(err, ErrInvalidConfig) to distinguish config errors from +// operation failures. +var ErrInvalidConfig = errors.New("retry: invalid config") + +const ( + ErrRetryCode = "meshkit-10001" + ErrContextCode = "meshkit-10002" + ErrConfigCode = "meshkit-10003" +) + +type retryError struct { + inner error + meshkit *meshkiterrors.Error +} + +func (e *retryError) Error() string { + return e.meshkit.Error() +} + +func (e *retryError) Unwrap() []error { + return []error{e.inner, e.meshkit} +} + +func ErrRetry(err error) error { + return &retryError{ + inner: err, + meshkit: meshkiterrors.New(ErrRetryCode, meshkiterrors.Alert, + []string{"Retry operation failed"}, + []string{err.Error()}, + []string{"Operation did not succeed within retry limits"}, + []string{"Check the underlying operation and retry configuration"}), + } +} + +func ErrContext(err error) error { + return &retryError{ + inner: err, + meshkit: meshkiterrors.New(ErrContextCode, meshkiterrors.Alert, + []string{"Context canceled or deadline exceeded"}, + []string{err.Error()}, + []string{"Operation timed out or context was canceled"}, + []string{"Check context timeout and ensure the operation completes in time"}), + } +} + +func ErrConfig(err error) error { + return &retryError{ + inner: err, + meshkit: meshkiterrors.New(ErrConfigCode, meshkiterrors.Alert, + []string{"Invalid retry configuration"}, + []string{err.Error()}, + []string{"One or more config values are invalid"}, + []string{"Ensure all retry configuration values are correct"}), + } +} + +// ErrorDecision controls retry behaviour for a single error. +type ErrorDecision int + +const ( + DecisionRetry ErrorDecision = iota + DecisionStop +) + +// ErrorClassifier returns the retry decision for a given error. +// Return DecisionStop for errors that should not be retried (e.g. HTTP 4xx, +// validation failures, auth errors). Return DecisionRetry for transient +// errors (timeouts, 5xx, rate limits). +// +// Ignored when the operation explicitly returns Permanent(err). +type ErrorClassifier func(err error) ErrorDecision + +// Permanent wraps err to signal no further retries should be attempted. +// Use for non-transient errors (HTTP 4xx, auth failures, validation errors). +// Do NOT use for context-cancellation; return ctx.Err() directly. +func Permanent(err error) error { + return backoff.Permanent(err) +} + +// IsPermanent reports whether err is (or wraps) a PermanentError. +func IsPermanent(err error) bool { + var pErr *backoff.PermanentError + return errors.As(err, &pErr) +} diff --git a/retry/options.go b/retry/options.go new file mode 100644 index 00000000..48f6bca6 --- /dev/null +++ b/retry/options.go @@ -0,0 +1,101 @@ +package retry + +import ( + "time" + + "github.com/meshery/meshkit/logger" +) + +const ( + DefaultInitialInterval = 500 * time.Millisecond + DefaultMaxInterval = 30 * time.Second + DefaultMaxElapsedTime = 2 * time.Minute + DefaultMultiplier = 1.5 + DefaultRandomizationFactor = 0.3 // Never set to 0 in production +) + +type Config struct { + MaxAttempts uint + InitialInterval time.Duration + MaxInterval time.Duration + MaxElapsedTime time.Duration + Multiplier float64 + RandomizationFactor float64 + Notifier func(err error, wait time.Duration) + ErrorClassifier ErrorClassifier +} + +func defaultConfig() Config { + return Config{ + InitialInterval: DefaultInitialInterval, + MaxInterval: DefaultMaxInterval, + MaxElapsedTime: DefaultMaxElapsedTime, + Multiplier: DefaultMultiplier, + RandomizationFactor: DefaultRandomizationFactor, + } +} + +type Option func(*Config) + +// WithMaxAttempts sets a hard cap on total calls (includes first attempt). +func WithMaxAttempts(n uint) Option { + return func(c *Config) { c.MaxAttempts = n } +} + +func WithInitialInterval(d time.Duration) Option { + return func(c *Config) { c.InitialInterval = d } +} + +func WithMaxInterval(d time.Duration) Option { + return func(c *Config) { c.MaxInterval = d } +} + +// WithMaxElapsedTime sets wall-clock deadline. Pass 0 to disable. +func WithMaxElapsedTime(d time.Duration) Option { + return func(c *Config) { c.MaxElapsedTime = d } +} + +func WithMultiplier(m float64) Option { + return func(c *Config) { c.Multiplier = m } +} + +// WithJitter overrides randomization factor (range: 0.0-1.0). Do not set to 0.0 in production. +func WithJitter(f float64) Option { + return func(c *Config) { c.RandomizationFactor = f } +} + +// WithErrorClassifier provides a decision function for classifying errors as +// retryable (DecisionRetry) or terminal (DecisionStop). When set, every error +// returned by the operation (except those explicitly wrapped with Permanent) +// is passed to this function. If it returns DecisionStop, the error is treated +// as permanent and the retry loop stops immediately. +// +// Example: +// +// retry.Do(ctx, op, +// retry.WithErrorClassifier(func(err error) retry.ErrorDecision { +// var status *myHTTPError +// if errors.As(err, &status) { +// if status.Code >= 500 { +// return retry.DecisionRetry +// } +// return retry.DecisionStop +// } +// return retry.DecisionRetry +// }), +// ) +func WithErrorClassifier(classifier ErrorClassifier) Option { + return func(c *Config) { c.ErrorClassifier = classifier } +} + +func WithNotifier(n func(err error, wait time.Duration)) Option { + return func(c *Config) { c.Notifier = n } +} + +// WithLogNotifier emits a Warn log entry on each retry via MeshKit's logger.Handler. +func WithLogNotifier(log logger.Handler) Option { + return WithNotifier(func(err error, wait time.Duration) { + log.Infof("retry: transient error; retrying in %s", wait.Round(time.Millisecond)) + log.Warn(err) + }) +} diff --git a/retry/retry.go b/retry/retry.go new file mode 100644 index 00000000..60b247fb --- /dev/null +++ b/retry/retry.go @@ -0,0 +1,101 @@ +package retry + +import ( + "context" + "errors" + "fmt" + "math" + + "github.com/cenkalti/backoff/v5" +) + +type Operation func(ctx context.Context) error + +// Do executes op with exponential backoff until success, permanent error, +// context cancellation, or budget exhaustion. Config via opts (default: +// 500ms initial, 1.5x growth, 30% jitter, 2min cap). +// +// When a ErrorClassifier is configured via WithErrorClassifier, every non-nil +// error from op (except those explicitly wrapped with Permanent) is passed to +// the classifier before the retry decision is made. +func Do(ctx context.Context, op Operation, opts ...Option) error { + if err := ctx.Err(); err != nil { + return ErrContext(err) + } + cfg := defaultConfig() + for _, o := range opts { + o(&cfg) + } + + if err := validateConfig(cfg); err != nil { + return ErrConfig(err) + } + + apply := op + if cfg.ErrorClassifier != nil { + apply = func(ctx context.Context) error { + err := op(ctx) + if err == nil { + return nil + } + var pErr *backoff.PermanentError + if errors.As(err, &pErr) { + return err + } + if cfg.ErrorClassifier(err) == DecisionStop { + return backoff.Permanent(err) + } + return err + } + } + + retryOpts := []backoff.RetryOption{ + backoff.WithBackOff(buildBackOff(cfg)), + backoff.WithMaxElapsedTime(cfg.MaxElapsedTime), + backoff.WithNotify(cfg.Notifier), + } + if cfg.MaxAttempts > 0 { + retryOpts = append(retryOpts, backoff.WithMaxTries(cfg.MaxAttempts)) + } + + _, err := backoff.Retry(ctx, func() (struct{}, error) { + return struct{}{}, apply(ctx) + }, retryOpts...) + if err != nil { + return ErrRetry(err) + } + return nil +} + +func validateConfig(cfg Config) error { + if cfg.InitialInterval <= 0 { + return fmt.Errorf("%w: InitialInterval must be > 0, got %v", ErrInvalidConfig, cfg.InitialInterval) + } + if cfg.MaxInterval <= 0 { + return fmt.Errorf("%w: MaxInterval must be > 0, got %v", ErrInvalidConfig, cfg.MaxInterval) + } + if cfg.MaxInterval < cfg.InitialInterval { + return fmt.Errorf("%w: MaxInterval (%v) must be >= InitialInterval (%v)", ErrInvalidConfig, cfg.MaxInterval, cfg.InitialInterval) + } + if cfg.MaxElapsedTime < 0 { + return fmt.Errorf("%w: MaxElapsedTime must be >= 0, got %v", ErrInvalidConfig, cfg.MaxElapsedTime) + } + if math.IsNaN(cfg.Multiplier) || math.IsInf(cfg.Multiplier, 0) || cfg.Multiplier < 1 { + return fmt.Errorf("%w: Multiplier must be finite and >= 1, got %v", ErrInvalidConfig, cfg.Multiplier) + } + if math.IsNaN(cfg.RandomizationFactor) || cfg.RandomizationFactor < 0 || cfg.RandomizationFactor > 1 { + return fmt.Errorf("%w: RandomizationFactor must be finite and in [0,1], got %v", ErrInvalidConfig, cfg.RandomizationFactor) + } + return nil +} + +// buildBackOff constructs a backoff policy from Config. +func buildBackOff(cfg Config) backoff.BackOff { + b := backoff.NewExponentialBackOff() + b.InitialInterval = cfg.InitialInterval + b.MaxInterval = cfg.MaxInterval + b.Multiplier = cfg.Multiplier + b.RandomizationFactor = cfg.RandomizationFactor + + return b +} diff --git a/retry/retry_test.go b/retry/retry_test.go new file mode 100644 index 00000000..2686158d --- /dev/null +++ b/retry/retry_test.go @@ -0,0 +1,736 @@ +package retry_test + +import ( + "context" + "errors" + "fmt" + "math" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/meshery/meshkit/retry" +) + +func alwaysFail(err error) retry.Operation { + return func(ctx context.Context) error { return err } +} + +func countingOp(count *atomic.Int64, err error) retry.Operation { + return func(ctx context.Context) error { + count.Add(1) + return err + } +} + +func TestRetrySucceedsFirstAttempt(t *testing.T) { + t.Parallel() + + calls := 0 + err := retry.Do(context.Background(), func(ctx context.Context) error { + calls++ + return nil + }, retry.WithMaxAttempts(5)) + + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if calls != 1 { + t.Fatalf("expected op called once, got %d", calls) + } +} + +func TestRetrySucceedsAfterTransientErrors(t *testing.T) { + t.Parallel() + + transient := errors.New("transient") + var calls atomic.Int64 + + err := retry.Do(context.Background(), + func(ctx context.Context) error { + n := calls.Add(1) + if n < 4 { + return transient + } + return nil + }, + retry.WithMaxAttempts(10), + retry.WithInitialInterval(1*time.Millisecond), + retry.WithMaxInterval(5*time.Millisecond), + retry.WithMaxElapsedTime(5*time.Second), + ) + + if err != nil { + t.Fatalf("expected success after retries, got %v", err) + } + if calls.Load() != 4 { + t.Fatalf("expected 4 calls (3 failures + 1 success), got %d", calls.Load()) + } +} + +func TestRetryPermanentErrorStopsImmediately(t *testing.T) { + t.Parallel() + + permanent := errors.New("permanent failure") + calls := 0 + + err := retry.Do(context.Background(), + func(ctx context.Context) error { + calls++ + return retry.Permanent(permanent) + }, + retry.WithMaxAttempts(10), + retry.WithInitialInterval(1*time.Millisecond), + ) + + if err == nil { + t.Fatal("expected non-nil error for permanent failure") + } + if !errors.Is(err, permanent) { + t.Fatalf("expected permanent sentinel unwrapped, got %v", err) + } + if calls != 1 { + t.Fatalf("expected exactly 1 call, got %d", calls) + } +} + +func TestIsPermanentReturnsFalseForTransient(t *testing.T) { + t.Parallel() + + err := errors.New("transient") + if retry.IsPermanent(err) { + t.Fatal("plain error should not be permanent") + } +} + +func TestIsPermanentReturnsTrueForPermanentWrapped(t *testing.T) { + t.Parallel() + + inner := errors.New("the cause") + wrapped := retry.Permanent(inner) + if !retry.IsPermanent(wrapped) { + t.Fatal("Permanent(err) should satisfy IsPermanent") + } +} + +func TestIsPermanentHandlesDoublyWrappedErrors(t *testing.T) { + t.Parallel() + + inner := errors.New("the cause") + wrapped := fmt.Errorf("outer layer: %w", retry.Permanent(inner)) + if !retry.IsPermanent(wrapped) { + t.Fatal("IsPermanent should unwrap error chains successfully") + } +} + +func TestRetryContextCancellationStopsLoop(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + + var calls atomic.Int64 + transient := errors.New("transient") + + go func() { + time.Sleep(5 * time.Millisecond) + cancel() + }() + + err := retry.Do(ctx, + func(ctx context.Context) error { + calls.Add(1) + return transient + }, + retry.WithInitialInterval(50*time.Millisecond), // longer than the cancel delay + retry.WithMaxElapsedTime(10*time.Second), + ) + + if err == nil { + t.Fatal("expected error after context cancellation") + } + if calls.Load() == 0 { + t.Fatal("expected at least one call before cancellation") + } +} + +func TestRetryContextAlreadyCancelledBeforeFirstAttempt(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + var calls atomic.Int64 + err := retry.Do(ctx, + func(ctx context.Context) error { + calls.Add(1) + return errors.New("should not reach") + }, + retry.WithMaxAttempts(5), + retry.WithInitialInterval(1*time.Millisecond), + ) + + if err == nil { + t.Fatal("expected error for pre-cancelled context") + } + if calls.Load() > 1 { + t.Fatalf("expected at most 1 call for pre-cancelled context, got %d", calls.Load()) + } +} + +func TestRetryMaxAttemptsEnforced(t *testing.T) { + t.Parallel() + + const maxAttempts = 4 + var count atomic.Int64 + + err := retry.Do(context.Background(), + countingOp(&count, errors.New("always fails")), + retry.WithMaxAttempts(maxAttempts), + retry.WithInitialInterval(1*time.Millisecond), + retry.WithMaxInterval(2*time.Millisecond), + retry.WithMaxElapsedTime(0), // disable elapsed-time cap + ) + + if err == nil { + t.Fatal("expected error when max attempts exhausted") + } + if count.Load() != maxAttempts { + t.Fatalf("expected exactly %d calls, got %d", maxAttempts, count.Load()) + } +} + +func TestRetryMaxElapsedTimeEnforced(t *testing.T) { + t.Parallel() + + start := time.Now() + const budget = 80 * time.Millisecond + + err := retry.Do(context.Background(), + alwaysFail(errors.New("always fails")), + retry.WithMaxElapsedTime(budget), + retry.WithInitialInterval(5*time.Millisecond), + retry.WithMaxInterval(10*time.Millisecond), + retry.WithJitter(0), // deterministic for timing assertions + ) + + elapsed := time.Since(start) + if err == nil { + t.Fatal("expected error when elapsed time exceeded") + } + if elapsed > 3*budget { + t.Fatalf("loop ran for %s, expected <= %s", elapsed, 3*budget) + } +} + +func TestRetryNotifierCalledOnEachRetry(t *testing.T) { + t.Parallel() + + const failures = 3 + transient := errors.New("transient") + var notifyCount atomic.Int64 + + notifier := func(err error, wait time.Duration) { + notifyCount.Add(1) + if !errors.Is(err, transient) { + t.Errorf("notifier: unexpected error %v", err) + } + } + + var calls atomic.Int64 + _ = retry.Do(context.Background(), + func(ctx context.Context) error { + if calls.Add(1) <= failures { + return transient + } + return nil + }, + retry.WithMaxAttempts(10), + retry.WithInitialInterval(1*time.Millisecond), + retry.WithMaxInterval(2*time.Millisecond), + retry.WithNotifier(notifier), + ) + + if notifyCount.Load() != failures { + t.Fatalf("expected notifier called %d times, got %d", failures, notifyCount.Load()) + } +} + +func TestRetryNotifierNotCalledOnImmediateSuccess(t *testing.T) { + t.Parallel() + + var notifyCount atomic.Int64 + _ = retry.Do(context.Background(), + func(ctx context.Context) error { return nil }, + retry.WithNotifier(func(err error, wait time.Duration) { + notifyCount.Add(1) + }), + ) + if notifyCount.Load() != 0 { + t.Fatalf("notifier should not be called on immediate success, called %d time(s)", notifyCount.Load()) + } +} + +func TestRetryNotifierNotCalledOnPermanentError(t *testing.T) { + t.Parallel() + + var notifyCount atomic.Int64 + _ = retry.Do(context.Background(), + func(ctx context.Context) error { return retry.Permanent(errors.New("perm")) }, + retry.WithMaxAttempts(5), + retry.WithInitialInterval(1*time.Millisecond), + retry.WithNotifier(func(err error, wait time.Duration) { + notifyCount.Add(1) + }), + ) + if notifyCount.Load() != 0 { + t.Fatalf("notifier called %d times for permanent error, expected 0", notifyCount.Load()) + } +} + +func TestRetryZeroMaxAttemptsMeansUnlimited(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + alwaysFail(errors.New("always fails")), + retry.WithMaxAttempts(0), + retry.WithMaxElapsedTime(50*time.Millisecond), + retry.WithInitialInterval(5*time.Millisecond), + retry.WithMaxInterval(10*time.Millisecond), + ) + if err == nil { + t.Fatal("expected error when elapsed time runs out with unlimited attempts") + } +} + +func TestRetryWithMaxAttemptsOneNoRetry(t *testing.T) { + t.Parallel() + + var calls atomic.Int64 + err := retry.Do(context.Background(), + countingOp(&calls, errors.New("fail")), + retry.WithMaxAttempts(1), + retry.WithInitialInterval(1*time.Millisecond), + retry.WithMaxElapsedTime(0), + ) + if err == nil { + t.Fatal("expected error") + } + if calls.Load() != 1 { + t.Fatalf("WithMaxAttempts(1) should allow exactly 1 call, got %d", calls.Load()) + } +} + +func TestRetryDefaultsAreApplied(t *testing.T) { + t.Parallel() + + transient := errors.New("transient") + var calls atomic.Int64 + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _ = retry.Do(ctx, + func(ctx context.Context) error { + if calls.Add(1) >= 2 { + return nil + } + return transient + }, + ) + + if calls.Load() < 2 { + t.Fatalf("expected at least 2 calls with default config, got %d", calls.Load()) + } +} + +func TestRetryClassifierStopsOnDecisionStop(t *testing.T) { + t.Parallel() + + classifyErr := errors.New("not a chance") + var calls atomic.Int64 + + err := retry.Do(context.Background(), + func(ctx context.Context) error { + calls.Add(1) + return classifyErr + }, + retry.WithErrorClassifier(func(err error) retry.ErrorDecision { + if errors.Is(err, classifyErr) { + return retry.DecisionStop + } + return retry.DecisionRetry + }), + retry.WithMaxAttempts(10), + retry.WithInitialInterval(1*time.Millisecond), + ) + + if err == nil { + t.Fatal("expected non-nil error when classifier stops") + } + if !errors.Is(err, classifyErr) { + t.Fatalf("expected classifier error unwrapped, got %v", err) + } + if calls.Load() != 1 { + t.Fatalf("expected exactly 1 call when classifier stops, got %d", calls.Load()) + } +} + +func TestRetryClassifierRetriesOnDecisionRetry(t *testing.T) { + t.Parallel() + + classifyErr := errors.New("transient per classifier") + var calls atomic.Int64 + + err := retry.Do(context.Background(), + func(ctx context.Context) error { + n := calls.Add(1) + if n < 3 { + return classifyErr + } + return nil + }, + retry.WithErrorClassifier(func(err error) retry.ErrorDecision { + return retry.DecisionRetry + }), + retry.WithMaxAttempts(10), + retry.WithInitialInterval(1*time.Millisecond), + retry.WithMaxInterval(5*time.Millisecond), + ) + + if err != nil { + t.Fatalf("expected success after classifier retries, got %v", err) + } + if calls.Load() != 3 { + t.Fatalf("expected 3 calls (2 classified retries + success), got %d", calls.Load()) + } +} + +func TestRetryClassifierDoesNotOverrideExplicitPermanent(t *testing.T) { + t.Parallel() + + permErr := errors.New("explicitly permanent") + var calls atomic.Int64 + + err := retry.Do(context.Background(), + func(ctx context.Context) error { + calls.Add(1) + return retry.Permanent(permErr) + }, + // Classifier says retry everything — but Permanent should still win. + retry.WithErrorClassifier(func(err error) retry.ErrorDecision { + return retry.DecisionRetry + }), + retry.WithMaxAttempts(10), + retry.WithInitialInterval(1*time.Millisecond), + ) + + if err == nil { + t.Fatal("expected non-nil error") + } + if !errors.Is(err, permErr) { + t.Fatalf("expected permanent error unwrapped, got %v", err) + } + if calls.Load() != 1 { + t.Fatalf("expected exactly 1 call for explicit Permanent, got %d", calls.Load()) + } +} + +func TestRetryClassifierCanMixWithPermanent(t *testing.T) { + t.Parallel() + + permErr := errors.New("permanent") + transientErr := errors.New("transient") + var calls atomic.Int64 + + err := retry.Do(context.Background(), + func(ctx context.Context) error { + n := calls.Add(1) + if n == 1 { + return transientErr + } + return retry.Permanent(permErr) + }, + retry.WithErrorClassifier(func(err error) retry.ErrorDecision { + if errors.Is(err, transientErr) { + return retry.DecisionRetry + } + return retry.DecisionStop + }), + retry.WithMaxAttempts(5), + retry.WithInitialInterval(1*time.Millisecond), + retry.WithMaxInterval(2*time.Millisecond), + ) + + if err == nil { + t.Fatal("expected non-nil error") + } + if !errors.Is(err, permErr) { + t.Fatalf("expected permanent error unwrapped, got %v", err) + } + if calls.Load() != 2 { + t.Fatalf("expected 2 calls (transient + permanent), got %d", calls.Load()) + } +} + +func TestRetryConfigValidationInitialIntervalZero(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithInitialInterval(0), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "InitialInterval") { + t.Fatalf("expected InitialInterval validation error, got %v", err) + } +} + +func TestRetryConfigValidationMaxIntervalZero(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithMaxInterval(0), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "MaxInterval") { + t.Fatalf("expected MaxInterval validation error, got %v", err) + } +} + +func TestRetryConfigValidationMaxIntervalLessThanInitial(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithInitialInterval(5*time.Second), + retry.WithMaxInterval(1*time.Second), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "MaxInterval") || !strings.Contains(err.Error(), "InitialInterval") { + t.Fatalf("expected MaxInterval/InitialInterval mismatch error, got %v", err) + } +} + +func TestRetryConfigValidationMultiplierNaN(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithMultiplier(float64(math.NaN())), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "Multiplier") { + t.Fatalf("expected Multiplier validation error, got %v", err) + } +} + +func TestRetryConfigValidationMultiplierInf(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithMultiplier(float64(math.Inf(1))), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "Multiplier") { + t.Fatalf("expected Multiplier validation error, got %v", err) + } +} + +func TestRetryConfigValidationMultiplierLessThanOne(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithMultiplier(0.5), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "Multiplier") { + t.Fatalf("expected Multiplier validation error, got %v", err) + } +} + +func TestRetryConfigValidationJitterNaN(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithJitter(float64(math.NaN())), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "RandomizationFactor") { + t.Fatalf("expected RandomizationFactor validation error, got %v", err) + } +} + +func TestRetryConfigValidationJitterOutOfRange(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithJitter(1.5), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "RandomizationFactor") { + t.Fatalf("expected RandomizationFactor validation error, got %v", err) + } +} + +func TestRetryConfigValidationJitterNegative(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithJitter(-0.1), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "RandomizationFactor") { + t.Fatalf("expected RandomizationFactor validation error, got %v", err) + } +} + +func TestRetryConfigValidationZeroMaxElapsedTimeIsValid(t *testing.T) { + t.Parallel() + + // 0 for MaxElapsedTime means "no wall-clock limit". Should be valid. + err := retry.Do(context.Background(), + func(ctx context.Context) error { return nil }, + retry.WithMaxElapsedTime(0), + ) + if err != nil { + t.Fatalf("expected success (0 MaxElapsedTime is valid), got %v", err) + } +} + +func TestRetryConfigValidationNegativeMaxElapsedTime(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("err") }, + retry.WithMaxElapsedTime(-1), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "MaxElapsedTime") { + t.Fatalf("expected MaxElapsedTime validation error, got %v", err) + } +} + +func TestRetryInvalidConfigSentinelIsReachable(t *testing.T) { + t.Parallel() + + err := retry.Do(context.Background(), + func(ctx context.Context) error { return errors.New("fail") }, + retry.WithInitialInterval(0), + ) + if err == nil { + t.Fatal("expected validation error") + } + if !errors.Is(err, retry.ErrInvalidConfig) { + t.Fatalf("errors.Is(err, ErrInvalidConfig) should be true, got %v", err) + } +} + +// ExampleDo demonstrates idiomatic HTTP usage with retry budget and per-attempt timeout. +// +// MaxElapsedTime limits the retry loop but does NOT interrupt an in-flight HTTP +// request. Always pair it with http.Client.Timeout (or NewRequestWithContext) so +// each attempt has its own deadline. +func ExampleDo() { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + client := &http.Client{Timeout: 3 * time.Second} + + err := retry.Do(context.Background(), func(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + if err != nil { + return retry.Permanent(err) + } + + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + switch { + case resp.StatusCode == http.StatusOK: + return nil + case resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500: + return fmt.Errorf("transient response: %s", resp.Status) + default: + return retry.Permanent(fmt.Errorf("non-retryable response: %s", resp.Status)) + } + }, + retry.WithMaxAttempts(3), + retry.WithInitialInterval(time.Second), + retry.WithMaxElapsedTime(10*time.Second), + ) + + if err != nil { + fmt.Printf("request failed: %v\n", err) + } + // Output: + // request failed: non-retryable response: 404 Not Found | Short Description: Retry operation failed | Probable Cause: Operation did not succeed within retry limits | Suggested Remediation: Check the underlying operation and retry configuration +} + +// ExampleWithErrorClassifier shows how to classify errors as retryable or +// terminal using WithErrorClassifier. DecisionRetry keeps retrying; +// DecisionStop ends the loop immediately. +func ExampleWithErrorClassifier() { + var ( + ErrTimeout = errors.New("request timeout") + ErrAuth = errors.New("authentication failed") + ) + + var attempts int + err := retry.Do(context.Background(), func(ctx context.Context) error { + attempts++ + if attempts < 3 { + return ErrTimeout + } + return ErrAuth + }, + retry.WithErrorClassifier(func(err error) retry.ErrorDecision { + if errors.Is(err, ErrTimeout) { + return retry.DecisionRetry + } + return retry.DecisionStop + }), + retry.WithMaxAttempts(5), + retry.WithInitialInterval(10*time.Millisecond), + retry.WithMaxInterval(20*time.Millisecond), + ) + + fmt.Println("attempts:", attempts) + fmt.Println("auth error:", errors.Is(err, ErrAuth)) + // Output: + // attempts: 3 + // auth error: true +}