diff --git a/build.savant b/build.savant index fb2e64f09..2be40b4c6 100644 --- a/build.savant +++ b/build.savant @@ -282,6 +282,8 @@ target(name: "build-swift", description: "Build the Swift Client Library") { target(name: "build-go", description: "Build the Go Client Library") { clientLibrary.buildClient(template: "src/main/client/go.client.ftl", outputFile: "../go-client/pkg/fusionauth/Client.go") + clientLibrary.buildClient(template: "src/main/client/go.client.test.ftl", + outputFile: "../go-client/pkg/fusionauth/Client_dynamic_test.go") clientLibrary.buildClient(template: "src/main/client/go.domain.ftl", outputFile: "../go-client/pkg/fusionauth/Domain.go") clientLibrary.buildClient(template: "src/main/client/go.domain.test.ftl", diff --git a/src/main/client/go.client.ftl b/src/main/client/go.client.ftl index 6aecd5e9f..0e0f59268 100644 --- a/src/main/client/go.client.ftl +++ b/src/main/client/go.client.ftl @@ -23,9 +23,12 @@ import ( "encoding/json" "fmt" "io" + "math" + "math/rand" "net/http" "net/http/httputil" "net/url" + "os" "path" "strconv" "strings" @@ -49,29 +52,145 @@ func NewClient(httpClient *http.Client, baseURL *url.URL, apiKey string) *Fusion return c } +// NewClientWithRetryConfiguration creates a new FusionAuthClient with the provided retry configuration. +// if httpClient is nil then a DefaultClient is used. +// Use NewBasicRetryConfiguration for sensible retry defaults. 
// NewClientWithRetryConfiguration creates a new FusionAuthClient that uses the provided retry
// configuration. If httpClient is nil, a client with a 5 minute timeout is created.
// A nil retryConfiguration means no retries; use NewBasicRetryConfiguration for sensible
// retry settings.
func NewClientWithRetryConfiguration(httpClient *http.Client, baseURL *url.URL, apiKey string, retryConfiguration *RetryConfiguration) *FusionAuthClient {
	if httpClient == nil {
		httpClient = &http.Client{
			Timeout: 5 * time.Minute,
		}
	}
	return &FusionAuthClient{
		HTTPClient:         httpClient,
		BaseURL:            baseURL,
		APIKey:             apiKey,
		RetryConfiguration: retryConfiguration,
	}
}

// SetTenantId sets the tenantId on the client
func (c *FusionAuthClient) SetTenantId(tenantId string) {
	c.TenantId = tenantId
}

// RetryConfiguration configures automatic retry of failed HTTP requests.
// A nil *RetryConfiguration (the default) means no retries are performed.
//
// Fields are used exactly as given: a struct literal gets Go zero values, not the "basic" values
// quoted below. Those are what NewBasicRetryConfiguration sets.
type RetryConfiguration struct {
	// AllowNonIdempotentRetries, when true, allows every HTTP method (including POST) to be retried.
	// When false, only GET, PUT, DELETE, PATCH and HEAD requests are retried. Note that HTTP does
	// not guarantee PATCH to be idempotent; it is nevertheless retried by default.
	AllowNonIdempotentRetries bool
	// BackoffMultiplier is the factor applied to the delay for each successive retry attempt.
	// It must not be negative or NaN. Basic value: 2.0.
	BackoffMultiplier float64
	// InitialDelay is the delay before the first retry. Basic value: 100ms.
	InitialDelay time.Duration
	// Jitter is the maximum random fraction added to every delay, in the range [0.0, 1.0].
	// The actual fraction is randomly chosen between 0.0 and this value. Basic value: 0.20.
	Jitter float64
	// MaxDelay is the upper bound on the delay between retry attempts, jitter included.
	// Zero means the delay is not capped. Basic value: 30s.
	MaxDelay time.Duration
	// MaxRetries is the number of additional attempts after the initial request.
	// 0 disables retries. Basic value: 4.
	MaxRetries int
	// RetryFunction is an optional function called to determine if a response warrants a retry,
	// in addition to the RetryableStatusCodes check. Return true to retry.
	RetryFunction func(statusCode int, body []byte) bool
	// RetryOnNetworkError, when true, retries requests that fail before a response is received.
	// Basic value: true.
	RetryOnNetworkError bool
	// RetryableStatusCodes is the set of HTTP status codes that trigger a retry.
	// Basic value: {429, 500, 502, 503, 504}.
	RetryableStatusCodes map[int]struct{}
}

// NewBasicRetryConfiguration returns a RetryConfiguration with sensible defaults.
// It retries on status codes 429, 500, 502, 503, 504 and on 409 responses whose body contains
// "[retryableConflict]", using exponential backoff (100ms initial delay, factor 2.0, capped at
// 30s) with up to 20% jitter and at most 4 retries.
func NewBasicRetryConfiguration() *RetryConfiguration {
	return &RetryConfiguration{
		BackoffMultiplier:   2.0,
		InitialDelay:        100 * time.Millisecond,
		Jitter:              0.20,
		MaxDelay:            30 * time.Second,
		MaxRetries:          4,
		RetryOnNetworkError: true,
		RetryableStatusCodes: map[int]struct{}{
			429: {},
			500: {},
			502: {},
			503: {},
			504: {},
		},
		RetryFunction: func(statusCode int, body []byte) bool {
			return statusCode == http.StatusConflict && bytes.Contains(body, []byte("[retryableConflict]"))
		},
	}
}

// RetryConfigurationFromEnv returns NewBasicRetryConfiguration if the FUSIONAUTH_ENABLE_RETRY
// environment variable is set to exactly "true", otherwise it returns nil (no retries).
// This is useful for Terraform providers and other tools that want to opt in to retries via
// an environment variable.
func RetryConfigurationFromEnv() *RetryConfiguration {
	if os.Getenv("FUSIONAUTH_ENABLE_RETRY") == "true" {
		return NewBasicRetryConfiguration()
	}
	return nil
}

// validate reports the first out-of-range field of cfg, or nil when the configuration is usable.
func (cfg *RetryConfiguration) validate() error {
	if cfg.MaxRetries < 0 {
		return fmt.Errorf("RetryConfiguration: MaxRetries must be non-negative")
	}
	if cfg.InitialDelay < 0 {
		return fmt.Errorf("RetryConfiguration: InitialDelay must be non-negative")
	}
	if cfg.MaxDelay < 0 {
		return fmt.Errorf("RetryConfiguration: MaxDelay must be non-negative")
	}
	// NaN compares false against every bound, so it has to be rejected explicitly.
	if math.IsNaN(cfg.Jitter) || cfg.Jitter < 0.0 || cfg.Jitter > 1.0 {
		return fmt.Errorf("RetryConfiguration: Jitter must be in the range [0.0, 1.0]")
	}
	if math.IsNaN(cfg.BackoffMultiplier) || cfg.BackoffMultiplier < 0 {
		return fmt.Errorf("RetryConfiguration: BackoffMultiplier must be non-negative")
	}
	return nil
}

// calculateDelay returns the pause before retry number attempt (1 for the first retry):
// InitialDelay * BackoffMultiplier^(attempt-1), stretched by a random jitter fraction and then
// capped at MaxDelay when MaxDelay is positive. Values of attempt below 1 are treated as 1.
func (cfg *RetryConfiguration) calculateDelay(attempt int) time.Duration {
	if attempt < 1 {
		attempt = 1
	}
	delay := float64(cfg.InitialDelay) * math.Pow(cfg.BackoffMultiplier, float64(attempt-1))
	if cfg.Jitter > 0 {
		delay *= 1.0 + rand.Float64()*cfg.Jitter
	}
	// Cap after jitter so that MaxDelay is a true upper bound on the returned delay.
	if limit := float64(cfg.MaxDelay); limit > 0 && delay > limit {
		return cfg.MaxDelay
	}
	// Converting an out-of-range or NaN float to an integer type is implementation-defined,
	// so clamp explicitly; this is only reachable when no MaxDelay cap is configured.
	if math.IsNaN(delay) {
		return 0
	}
	if delay >= math.MaxInt64 {
		return time.Duration(math.MaxInt64)
	}
	return time.Duration(delay)
}

// FusionAuthClient describes the Go Client for interacting with FusionAuth's RESTful API
type FusionAuthClient struct {
	HTTPClient         *http.Client
	BaseURL            *url.URL
	APIKey             string
	Debug              bool
	TenantId           string
	RetryConfiguration *RetryConfiguration
}

// restClient carries the state of a single API request being built and executed.
type restClient struct {
	Body               io.Reader
	bodyBytes          []byte // request body buffered by Do so it can be replayed on retries
	Debug              bool
	ErrorRef           interface{}
	Headers            map[string]string
	HTTPClient         *http.Client
	Method             string
	ResponseRef        interface{}
	RetryConfiguration *RetryConfiguration
	Uri                *url.URL
}

// [patch context: the definition below continues beyond this span and is not reproduced here]
// func (c *FusionAuthClient) Start(responseRef
interface{}, errorRef interface{}) *restClient { @@ -80,11 +199,12 @@ func (c *FusionAuthClient) Start(responseRef interface{}, errorRef interface{}) func (c *FusionAuthClient) StartAnonymous(responseRef interface{}, errorRef interface{}) *restClient { rc := &restClient{ - Debug: c.Debug, - ErrorRef: errorRef, - Headers: make(map[string]string), - HTTPClient: c.HTTPClient, - ResponseRef: responseRef, + Debug: c.Debug, + ErrorRef: errorRef, + Headers: make(map[string]string), + HTTPClient: c.HTTPClient, + ResponseRef: responseRef, + RetryConfiguration: c.RetryConfiguration, } rc.Uri, _ = url.Parse(c.BaseURL.String()) if c.TenantId != "" { @@ -96,34 +216,118 @@ func (c *FusionAuthClient) StartAnonymous(responseRef interface{}, errorRef inte } func (rc *restClient) Do(ctx context.Context) error { - req, err := http.NewRequestWithContext(ctx, rc.Method, rc.Uri.String(), rc.Body) - if err != nil { - return err - } - for key, val := range rc.Headers { - req.Header.Set(key, val) + if rc.RetryConfiguration != nil { + if err := rc.RetryConfiguration.validate(); err != nil { + return err + } } - resp, err := rc.HTTPClient.Do(req) - if err != nil { - return err + + // Buffer the request body once so it can be replayed on retries. 
+ if rc.Body != nil { + b, err := io.ReadAll(rc.Body) + if err != nil { + return err + } + rc.bodyBytes = b + rc.Body = nil } - defer resp.Body.Close() - if rc.Debug { - responseDump, _ := httputil.DumpResponse(resp, true) - fmt.Println(string(responseDump)) + + maxAttempts := 1 + if rc.RetryConfiguration != nil && rc.RetryConfiguration.MaxRetries > 0 { + maxAttempts = 1 + rc.RetryConfiguration.MaxRetries } - if resp.StatusCode < 200 || resp.StatusCode > 299 { - if err = json.NewDecoder(resp.Body).Decode(rc.ErrorRef); err == io.EOF { - err = nil + + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + delay := rc.RetryConfiguration.calculateDelay(attempt) + if delay > 0 { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delay): + } + } } - } else { - rc.ErrorRef = nil - if _, ok := rc.ResponseRef.(*BaseHTTPResponse); !ok { - err = json.NewDecoder(resp.Body).Decode(rc.ResponseRef) + + var body io.Reader + if rc.bodyBytes != nil { + body = bytes.NewReader(rc.bodyBytes) + } + req, err := http.NewRequestWithContext(ctx, rc.Method, rc.Uri.String(), body) + if err != nil { + return err + } + for key, val := range rc.Headers { + req.Header.Set(key, val) + } + + resp, err := rc.HTTPClient.Do(req) + if err != nil { + // Retry on network error if configured and method is retryable. + if attempt < maxAttempts-1 && rc.RetryConfiguration != nil && + rc.RetryConfiguration.RetryOnNetworkError && rc.isMethodRetryable() { + continue + } + return err + } + respBody, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + if readErr != nil { + return readErr } + + if rc.Debug { + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + responseDump, _ := httputil.DumpResponse(resp, true) + fmt.Println(string(responseDump)) + } + + // Check whether this response should trigger a retry. 
+ if attempt < maxAttempts-1 && rc.RetryConfiguration != nil && rc.isMethodRetryable() { + if rc.shouldRetry(resp.StatusCode, respBody) { + continue + } + } + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + if err = json.NewDecoder(bytes.NewReader(respBody)).Decode(rc.ErrorRef); err == io.EOF { + err = nil + } + } else { + rc.ErrorRef = nil + if _, ok := rc.ResponseRef.(*BaseHTTPResponse); !ok { + err = json.NewDecoder(bytes.NewReader(respBody)).Decode(rc.ResponseRef) + } + } + rc.ResponseRef.(StatusAble).SetStatus(resp.StatusCode) + return err + } + return nil +} + +// isMethodRetryable returns true if the HTTP method is safe to retry. +// Idempotent methods (GET, PUT, DELETE, PATCH, HEAD) are always retryable. +// POST is retryable only when AllowNonIdempotentRetries is true. +func (rc *restClient) isMethodRetryable() bool { + if rc.RetryConfiguration.AllowNonIdempotentRetries { + return true + } + switch rc.Method { + case http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodHead: + return true + } + return false +} + +// shouldRetry returns true if the status code or body indicates a retryable failure. 
+func (rc *restClient) shouldRetry(statusCode int, body []byte) bool { + if _, ok := rc.RetryConfiguration.RetryableStatusCodes[statusCode]; ok { + return true + } + if rc.RetryConfiguration.RetryFunction != nil { + return rc.RetryConfiguration.RetryFunction(statusCode, body) } - rc.ResponseRef.(StatusAble).SetStatus(resp.StatusCode) - return err + return false } func (rc *restClient) WithAuthorization(key string) *restClient { diff --git a/src/main/client/go.client.test.ftl b/src/main/client/go.client.test.ftl new file mode 100644 index 000000000..8ca604eda --- /dev/null +++ b/src/main/client/go.client.test.ftl @@ -0,0 +1,640 @@ +/* +* Copyright (c) 2019-${.now?string('yyyy')}, FusionAuth, All Rights Reserved +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +* either express or implied. See the License for the specific +* language governing permissions and limitations under the License. +*/ + +package fusionauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "sync" + "sync/atomic" + "testing" + "time" +) + +// roundTripFunc allows creating inline http.RoundTripper implementations in tests. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +// newTestRC creates a restClient pointed at serverURL using the given RetryConfiguration and method. +// ResponseRef and ErrorRef are both set to &BaseHTTPResponse{} which satisfies StatusAble. 
+func newTestRC(serverURL *url.URL, cfg *RetryConfiguration, method string) *restClient { + return &restClient{ + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + Headers: make(map[string]string), + RetryConfiguration: cfg, + ResponseRef: &BaseHTTPResponse{}, + ErrorRef: &BaseHTTPResponse{}, + Method: method, + Uri: serverURL, + } +} + +// newCountingServer returns a test server that serves responses[i] for the i-th request. +// Once all responses are consumed, the last entry is repeated. +func newCountingServer(t *testing.T, responses []int) (*httptest.Server, *int32) { + t.Helper() + var callCount int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + idx := int(atomic.AddInt32(&callCount, 1)) - 1 + if idx >= len(responses) { + idx = len(responses) - 1 + } + w.WriteHeader(responses[idx]) + })) + t.Cleanup(server.Close) + return server, &callCount +} + +// --------------------------------------------------------------------------- +// RetryConfiguration struct tests +// --------------------------------------------------------------------------- + +func TestNewBasicRetryConfigurationDefaults(t *testing.T) { + rc := NewBasicRetryConfiguration() + + if rc.BackoffMultiplier != 2.0 { + t.Errorf("BackoffMultiplier: got %v, want 2.0", rc.BackoffMultiplier) + } + if rc.InitialDelay != 100*time.Millisecond { + t.Errorf("InitialDelay: got %v, want 100ms", rc.InitialDelay) + } + if rc.Jitter != 0.20 { + t.Errorf("Jitter: got %v, want 0.20", rc.Jitter) + } + if rc.MaxDelay != 30*time.Second { + t.Errorf("MaxDelay: got %v, want 30s", rc.MaxDelay) + } + if rc.MaxRetries != 4 { + t.Errorf("MaxRetries: got %d, want 4", rc.MaxRetries) + } + if !rc.RetryOnNetworkError { + t.Error("RetryOnNetworkError: want true") + } + if rc.AllowNonIdempotentRetries { + t.Error("AllowNonIdempotentRetries: want false") + } + for _, code := range []int{429, 500, 502, 503, 504} { + if _, ok := rc.RetryableStatusCodes[code]; !ok { + 
t.Errorf("RetryableStatusCodes: missing %d", code) + } + } + if rc.RetryFunction == nil { + t.Error("RetryFunction: want non-nil default") + } +} + +func TestRetryConfigurationValidate(t *testing.T) { + tests := []struct { + name string + cfg RetryConfiguration + wantErr bool + }{ + { + name: "zero value is valid", + cfg: RetryConfiguration{}, + }, + { + name: "valid typical config", + cfg: RetryConfiguration{ + BackoffMultiplier: 2.0, + InitialDelay: 100 * time.Millisecond, + Jitter: 0.20, + MaxDelay: 30 * time.Second, + MaxRetries: 4, + }, + }, + { + name: "jitter boundary 0.0", + cfg: RetryConfiguration{Jitter: 0.0}, + }, + { + name: "jitter boundary 1.0", + cfg: RetryConfiguration{Jitter: 1.0}, + }, + { + name: "negative MaxRetries", + cfg: RetryConfiguration{MaxRetries: -1}, + wantErr: true, + }, + { + name: "negative InitialDelay", + cfg: RetryConfiguration{InitialDelay: -1}, + wantErr: true, + }, + { + name: "negative MaxDelay", + cfg: RetryConfiguration{MaxDelay: -1}, + wantErr: true, + }, + { + name: "negative Jitter", + cfg: RetryConfiguration{Jitter: -0.01}, + wantErr: true, + }, + { + name: "Jitter above 1.0", + cfg: RetryConfiguration{Jitter: 1.01}, + wantErr: true, + }, + { + name: "negative BackoffMultiplier", + cfg: RetryConfiguration{BackoffMultiplier: -0.01}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.validate() + if (err != nil) != tt.wantErr { + t.Errorf("validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestRetryConfigurationCalculateDelay(t *testing.T) { + cfg := &RetryConfiguration{ + InitialDelay: 100 * time.Millisecond, + BackoffMultiplier: 2.0, + MaxDelay: 10 * time.Second, + Jitter: 0, // no jitter for deterministic checks + } + + cases := []struct { + attempt int + want time.Duration + }{ + {1, 100 * time.Millisecond}, + {2, 200 * time.Millisecond}, + {3, 400 * time.Millisecond}, + {4, 800 * time.Millisecond}, + {5, 1600 * time.Millisecond}, + 
} + + for _, tc := range cases { + got := cfg.calculateDelay(tc.attempt) + if got != tc.want { + t.Errorf("attempt %d: got %v, want %v", tc.attempt, got, tc.want) + } + } +} + +func TestRetryConfigurationCalculateDelayMaxDelayCapped(t *testing.T) { + cfg := &RetryConfiguration{ + InitialDelay: 1 * time.Second, + BackoffMultiplier: 2.0, + MaxDelay: 3 * time.Second, + Jitter: 0, + } + + for attempt := 1; attempt <= 6; attempt++ { + got := cfg.calculateDelay(attempt) + if got > 3*time.Second { + t.Errorf("attempt %d: delay %v exceeds MaxDelay 3s", attempt, got) + } + } +} + +func TestRetryConfigurationCalculateDelayWithJitter(t *testing.T) { + cfg := &RetryConfiguration{ + InitialDelay: 100 * time.Millisecond, + BackoffMultiplier: 2.0, + MaxDelay: 30 * time.Second, + Jitter: 0.20, + } + + // With 20% jitter the delay must be in [base, base * 1.20]. + base := 100 * time.Millisecond + for attempt := 1; attempt <= 3; attempt++ { + got := cfg.calculateDelay(attempt) + lo := base + hi := time.Duration(float64(base) * 1.20) + if got < lo || got > hi { + t.Errorf("attempt %d: delay %v outside expected range [%v, %v]", attempt, got, lo, hi) + } + base *= 2 + } +} + +// --------------------------------------------------------------------------- +// Retry behaviour integration tests (httptest server) +// --------------------------------------------------------------------------- + +func TestRetryOnRetryableStatusCodes(t *testing.T) { + for _, code := range []int{429, 500, 502, 503, 504} { + code := code + t.Run(fmt.Sprintf("retries_on_%d", code), func(t *testing.T) { + server, callCount := newCountingServer(t, []int{code, code, 200}) + serverURL, _ := url.Parse(server.URL) + cfg := &RetryConfiguration{ + MaxRetries: 2, + RetryableStatusCodes: map[int]struct{}{code: {}}, + } + rc := newTestRC(serverURL, cfg, http.MethodGet) + + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + if got := atomic.LoadInt32(callCount); got != 3 { 
+ t.Errorf("expected 3 calls, got %d", got) + } + if status := rc.ResponseRef.(*BaseHTTPResponse).StatusCode; status != 200 { + t.Errorf("expected final status 200, got %d", status) + } + }) + } +} + +func TestNoRetryOnSuccessResponse(t *testing.T) { + server, callCount := newCountingServer(t, []int{200}) + serverURL, _ := url.Parse(server.URL) + cfg := &RetryConfiguration{ + MaxRetries: 3, + RetryableStatusCodes: map[int]struct{}{503: {}}, + } + rc := newTestRC(serverURL, cfg, http.MethodGet) + + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + if got := atomic.LoadInt32(callCount); got != 1 { + t.Errorf("expected 1 call (no retry on success), got %d", got) + } +} + +func TestNoRetryOnNonRetryableStatusCode(t *testing.T) { + server, callCount := newCountingServer(t, []int{400}) + serverURL, _ := url.Parse(server.URL) + cfg := &RetryConfiguration{ + MaxRetries: 3, + RetryableStatusCodes: map[int]struct{}{503: {}}, + } + rc := newTestRC(serverURL, cfg, http.MethodGet) + + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + if got := atomic.LoadInt32(callCount); got != 1 { + t.Errorf("expected 1 call (400 is not retryable), got %d", got) + } +} + +func TestMaxRetriesRespected(t *testing.T) { + server, callCount := newCountingServer(t, []int{503, 503, 503, 503, 503}) + serverURL, _ := url.Parse(server.URL) + cfg := &RetryConfiguration{ + MaxRetries: 2, + RetryableStatusCodes: map[int]struct{}{503: {}}, + } + rc := newTestRC(serverURL, cfg, http.MethodGet) + + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + // 1 initial + 2 retries = 3 total calls + if got := atomic.LoadInt32(callCount); got != 3 { + t.Errorf("expected 3 calls (1 initial + 2 retries), got %d", got) + } + // Final response should be the last 503 + if status := rc.ResponseRef.(*BaseHTTPResponse).StatusCode; status != 503 { + t.Errorf("expected 
final status 503, got %d", status) + } +} + +func TestRetryOnNetworkError(t *testing.T) { + var callCount int32 + u, _ := url.Parse("http://127.0.0.1:0") + cfg := &RetryConfiguration{ + MaxRetries: 2, + RetryOnNetworkError: true, + RetryableStatusCodes: map[int]struct{}{}, + } + rc := &restClient{ + HTTPClient: &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&callCount, 1) + return nil, fmt.Errorf("simulated network error") + }), + }, + Headers: make(map[string]string), + RetryConfiguration: cfg, + ResponseRef: &BaseHTTPResponse{}, + ErrorRef: &BaseHTTPResponse{}, + Method: http.MethodGet, + Uri: u, + } + + err := rc.Do(context.Background()) + if err == nil { + t.Fatal("expected error from network failure") + } + // 1 initial + 2 retries = 3 total + if got := atomic.LoadInt32(&callCount); got != 3 { + t.Errorf("expected 3 calls, got %d", got) + } +} + +func TestNoRetryOnNetworkErrorWhenDisabled(t *testing.T) { + var callCount int32 + u, _ := url.Parse("http://127.0.0.1:0") + cfg := &RetryConfiguration{ + MaxRetries: 2, + RetryOnNetworkError: false, + RetryableStatusCodes: map[int]struct{}{}, + } + rc := &restClient{ + HTTPClient: &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&callCount, 1) + return nil, fmt.Errorf("simulated network error") + }), + }, + Headers: make(map[string]string), + RetryConfiguration: cfg, + ResponseRef: &BaseHTTPResponse{}, + ErrorRef: &BaseHTTPResponse{}, + Method: http.MethodGet, + Uri: u, + } + + err := rc.Do(context.Background()) + if err == nil { + t.Fatal("expected error from network failure") + } + if got := atomic.LoadInt32(&callCount); got != 1 { + t.Errorf("expected 1 call (no network retry when disabled), got %d", got) + } +} + +func TestRetryFunctionRetryableConflict(t *testing.T) { + var callCount int32 + conflictBody := `{"generalErrors":[{"code":"[retryableConflict]","message":"conflict"}]}` + + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&callCount, 1) + if n < 3 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusConflict) + fmt.Fprint(w, conflictBody) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer server.Close() + + serverURL, _ := url.Parse(server.URL) + cfg := NewBasicRetryConfiguration() + cfg.InitialDelay = 0 // no delay for fast tests + + rc := newTestRC(serverURL, cfg, http.MethodGet) + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + if got := atomic.LoadInt32(&callCount); got != 3 { + t.Errorf("expected 3 calls (2 conflict retries + 1 success), got %d", got) + } +} + +func TestRetryFunctionConflictWithoutRetryableCodeIsNotRetried(t *testing.T) { + var callCount int32 + // 409 without [retryableConflict] in body should not be retried + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&callCount, 1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusConflict) + fmt.Fprint(w, `{"generalErrors":[{"code":"[someOtherConflict]"}]}`) + })) + defer server.Close() + + serverURL, _ := url.Parse(server.URL) + cfg := NewBasicRetryConfiguration() + cfg.InitialDelay = 0 + + rc := newTestRC(serverURL, cfg, http.MethodGet) + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + if got := atomic.LoadInt32(&callCount); got != 1 { + t.Errorf("expected 1 call (non-retryable 409), got %d", got) + } +} + +func TestNoRetryForPostByDefault(t *testing.T) { + server, callCount := newCountingServer(t, []int{503, 503, 200}) + serverURL, _ := url.Parse(server.URL) + cfg := &RetryConfiguration{ + MaxRetries: 2, + RetryableStatusCodes: map[int]struct{}{503: {}}, + // AllowNonIdempotentRetries defaults to false + } + rc := newTestRC(serverURL, cfg, http.MethodPost) + + if err := 
rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + if got := atomic.LoadInt32(callCount); got != 1 { + t.Errorf("expected 1 call (POST not retried by default), got %d", got) + } +} + +func TestAllowNonIdempotentRetriesEnablesPostRetry(t *testing.T) { + server, callCount := newCountingServer(t, []int{503, 503, 200}) + serverURL, _ := url.Parse(server.URL) + cfg := &RetryConfiguration{ + MaxRetries: 2, + AllowNonIdempotentRetries: true, + RetryableStatusCodes: map[int]struct{}{503: {}}, + } + rc := newTestRC(serverURL, cfg, http.MethodPost) + + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + if got := atomic.LoadInt32(callCount); got != 3 { + t.Errorf("expected 3 calls (POST retried with AllowNonIdempotentRetries), got %d", got) + } +} + +func TestIdempotentMethodsAreRetried(t *testing.T) { + for _, method := range []string{ + http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodHead, + } { + method := method + t.Run(method, func(t *testing.T) { + server, callCount := newCountingServer(t, []int{503, 200}) + serverURL, _ := url.Parse(server.URL) + cfg := &RetryConfiguration{ + MaxRetries: 1, + RetryableStatusCodes: map[int]struct{}{503: {}}, + } + rc := newTestRC(serverURL, cfg, method) + + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() [%s] unexpected error: %v", method, err) + } + if got := atomic.LoadInt32(callCount); got != 2 { + t.Errorf("[%s] expected 2 calls, got %d", method, got) + } + }) + } +} + +func TestContextCancelledDuringBackoff(t *testing.T) { + server, _ := newCountingServer(t, []int{503, 200}) + serverURL, _ := url.Parse(server.URL) + cfg := &RetryConfiguration{ + MaxRetries: 1, + RetryableStatusCodes: map[int]struct{}{503: {}}, + InitialDelay: 10 * time.Second, // long enough that context fires first + BackoffMultiplier: 1.0, + } + rc := newTestRC(serverURL, cfg, http.MethodGet) + + ctx, cancel := 
context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := rc.Do(ctx) + if err == nil { + t.Fatal("expected context cancellation error") + } +} + +func TestRetryBodyReplayedOnRetry(t *testing.T) { + var mu sync.Mutex + var receivedBodies []string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var m map[string]interface{} + if decErr := json.NewDecoder(r.Body).Decode(&m); decErr == nil { + if v, _ := m["key"].(string); v != "" { + mu.Lock() + receivedBodies = append(receivedBodies, v) + mu.Unlock() + } + } + + mu.Lock() + n := len(receivedBodies) + mu.Unlock() + + if n < 2 { + w.WriteHeader(http.StatusServiceUnavailable) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer server.Close() + + serverURL, _ := url.Parse(server.URL) + cfg := &RetryConfiguration{ + MaxRetries: 2, + AllowNonIdempotentRetries: true, + RetryableStatusCodes: map[int]struct{}{503: {}}, + } + rc := newTestRC(serverURL, cfg, http.MethodPost) + rc.WithJSONBody(map[string]string{"key": "hello"}) + + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if len(receivedBodies) < 2 { + t.Fatalf("expected at least 2 requests with body, got %d", len(receivedBodies)) + } + for i, body := range receivedBodies { + if body != "hello" { + t.Errorf("attempt %d: body key = %q, want %q", i+1, body, "hello") + } + } +} + +func TestInvalidRetryConfigReturnsError(t *testing.T) { + u, _ := url.Parse("http://127.0.0.1:0") + cfg := &RetryConfiguration{MaxRetries: -1} + rc := &restClient{ + HTTPClient: &http.Client{}, + Headers: make(map[string]string), + RetryConfiguration: cfg, + ResponseRef: &BaseHTTPResponse{}, + ErrorRef: &BaseHTTPResponse{}, + Method: http.MethodGet, + Uri: u, + } + if err := rc.Do(context.Background()); err == nil { + t.Error("expected validation error for negative MaxRetries") + } +} + +func 
TestNilRetryConfigurationNoRetries(t *testing.T) { + server, callCount := newCountingServer(t, []int{503, 503, 200}) + serverURL, _ := url.Parse(server.URL) + rc := newTestRC(serverURL, nil, http.MethodGet) + + if err := rc.Do(context.Background()); err != nil { + t.Fatalf("Do() unexpected error: %v", err) + } + if got := atomic.LoadInt32(callCount); got != 1 { + t.Errorf("expected 1 call (nil config = no retries), got %d", got) + } +} + +// --------------------------------------------------------------------------- +// RetryConfigurationFromEnv tests +// --------------------------------------------------------------------------- + +func TestRetryConfigurationFromEnvNotSet(t *testing.T) { + os.Unsetenv("FUSIONAUTH_ENABLE_RETRY") + cfg := RetryConfigurationFromEnv() + if cfg != nil { + t.Errorf("expected nil when FUSIONAUTH_ENABLE_RETRY not set, got %+v", cfg) + } +} + +func TestRetryConfigurationFromEnvWrongValue(t *testing.T) { + os.Setenv("FUSIONAUTH_ENABLE_RETRY", "1") + defer os.Unsetenv("FUSIONAUTH_ENABLE_RETRY") + cfg := RetryConfigurationFromEnv() + if cfg != nil { + t.Errorf("expected nil when FUSIONAUTH_ENABLE_RETRY=%q (not 'true'), got non-nil", "1") + } +} + +func TestRetryConfigurationFromEnvEnabled(t *testing.T) { + os.Setenv("FUSIONAUTH_ENABLE_RETRY", "true") + defer os.Unsetenv("FUSIONAUTH_ENABLE_RETRY") + cfg := RetryConfigurationFromEnv() + if cfg == nil { + t.Fatal("expected non-nil RetryConfiguration when FUSIONAUTH_ENABLE_RETRY=true") + } + if cfg.MaxRetries != 4 { + t.Errorf("MaxRetries: got %d, want 4", cfg.MaxRetries) + } + if cfg.InitialDelay != 100*time.Millisecond { + t.Errorf("InitialDelay: got %v, want 100ms", cfg.InitialDelay) + } +}