From 8a0bc4f9353385bc167716fd8d9c89a1b663aafb Mon Sep 17 00:00:00 2001 From: magodo Date: Fri, 11 Jul 2025 16:07:39 +1000 Subject: [PATCH 1/4] Connection initializes Client with `AuthProvider` instead `AuthString` This change is to allow AAD based auth token to refresh when exipred. Previously, if the token encoded in the auth string is expired, the client will always fail. With this change, the Client is initialized with the `AuthProvider`, which is built on top of `azidentity` and MSAL-Go. Everytime the client is gonna make a request, the token will be retrieved from the underlying MSAL-Go library, either from its cache (if not expired) or a fresh new one retrieved via API. This means if the token got expired, a new token will be retireved and used by the client. However, there is one edge case: Since there is no "token expiration buffer" in the MSAL-Go right now, if the token returned from cache expires right after returning, the client will then use this invalid token for an API call, hence fail. There is no "retry" mechanism in the current client implementation to mitigate this. --- azuredevops/v7/auth.go | 13 +++++++++++++ azuredevops/v7/auth_aad.go | 30 ++++++++++++++++++++++++++++++ azuredevops/v7/auth_pat.go | 19 +++++++++++++++++++ azuredevops/v7/client.go | 14 ++++++++++---- azuredevops/v7/connection.go | 6 +++--- azuredevops/v7/go.mod | 15 +++++++++++++-- azuredevops/v7/go.sum | 20 ++++++++++++++++++-- 7 files changed, 106 insertions(+), 11 deletions(-) create mode 100644 azuredevops/v7/auth.go create mode 100644 azuredevops/v7/auth_aad.go create mode 100644 azuredevops/v7/auth_pat.go diff --git a/azuredevops/v7/auth.go b/azuredevops/v7/auth.go new file mode 100644 index 00000000..966a39b4 --- /dev/null +++ b/azuredevops/v7/auth.go @@ -0,0 +1,13 @@ +package azuredevops + +import ( + "context" +) + +type Auth struct { + AuthString string +} + +type AuthProvider interface { + GetAuth(ctx context.Context) (string, error) +} diff --git a/azuredevops/v7/auth_aad.go b/azuredevops/v7/auth_aad.go new file mode 100644 index 00000000..28f66e93 --- /dev/null +++ b/azuredevops/v7/auth_aad.go @@ -0,0 +1,30 @@ +package azuredevops + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +type AADCred interface { + GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) +} + +type AuthProviderAAD struct { + cred AADCred + opts policy.TokenRequestOptions +} + +func NewAuthProviderAAD(cred AADCred, opts policy.TokenRequestOptions) AuthProvider { + return AuthProviderAAD{cred, opts} +} + +func (p AuthProviderAAD) GetAuth(ctx context.Context) (string, error) { + token, err := p.cred.GetToken(ctx, p.opts) + if err != nil { + return "", fmt.Errorf("failed to get AAD token: %v", err) + } + return "Bearer " + token.Token, nil +} diff --git a/azuredevops/v7/auth_pat.go b/azuredevops/v7/auth_pat.go new file mode 100644 index 00000000..a0149d15 --- /dev/null +++ b/azuredevops/v7/auth_pat.go @@ -0,0 +1,19 @@ +package azuredevops + +import ( + "context" + "encoding/base64" +) + +type AuthProviderPAT struct { + pat string +} + +func NewAuthProviderPAT(pat string) AuthProvider { + return AuthProviderPAT{pat} +} + +func (p AuthProviderPAT) GetAuth(_ context.Context) (string, error) { + auth := "_:" + p.pat + return "Basic " + base64.StdEncoding.EncodeToString([]byte(auth)), nil +} diff --git a/azuredevops/v7/client.go b/azuredevops/v7/client.go index e0d0d4ce..766d0321 100644 --- a/azuredevops/v7/client.go +++ b/azuredevops/v7/client.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "io/ioutil" "net/http" @@ -70,7 +71,7 @@ func NewClientWithOptions(connection *Connection, baseUrl string, options ...Cli client := &Client{ baseUrl: baseUrl, client: httpClient, - authorization: connection.AuthorizationString, + authProvider: connection.AuthProvider, suppressFedAuthRedirect: connection.SuppressFedAuthRedirect, forceMsaPassThrough: connection.ForceMsaPassThrough, userAgent: connection.UserAgent, @@ -84,7 +85,7 @@ func NewClientWithOptions(connection *Connection, baseUrl string, options ...Cli type Client struct { baseUrl string client *http.Client - authorization string + authProvider AuthProvider suppressFedAuthRedirect bool forceMsaPassThrough bool userAgent string @@ -169,9 +170,14 @@ func (client *Client) CreateRequestMessage(ctx context.Context, req = req.WithContext(ctx) } - if client.authorization != "" { - req.Header.Add(headerKeyAuthorization, client.authorization) + if client.authProvider != nil { + auth, err := client.authProvider.GetAuth(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get auth from auth cache: %v", err) + } + req.Header.Add(headerKeyAuthorization, auth) } + accept := acceptMediaType if apiVersion != "" { accept += ";api-version=" + apiVersion diff --git a/azuredevops/v7/connection.go b/azuredevops/v7/connection.go index eb76f53c..e954b5d4 100644 --- a/azuredevops/v7/connection.go +++ b/azuredevops/v7/connection.go @@ -16,10 +16,10 @@ import ( // Creates a new Azure DevOps connection instance using a personal access token. func NewPatConnection(organizationUrl string, personalAccessToken string) *Connection { - authorizationString := CreateBasicAuthHeaderValue("", personalAccessToken) organizationUrl = normalizeUrl(organizationUrl) + authProvider := NewAuthProviderPAT(personalAccessToken) return &Connection{ - AuthorizationString: authorizationString, + AuthProvider: authProvider, BaseUrl: organizationUrl, SuppressFedAuthRedirect: true, } @@ -34,7 +34,7 @@ func NewAnonymousConnection(organizationUrl string) *Connection { } type Connection struct { - AuthorizationString string + AuthProvider AuthProvider BaseUrl string UserAgent string SuppressFedAuthRedirect bool diff --git a/azuredevops/v7/go.mod b/azuredevops/v7/go.mod index 0c3ebc6c..5c909f45 100644 --- a/azuredevops/v7/go.mod +++ b/azuredevops/v7/go.mod @@ -1,5 +1,16 @@ module github.com/microsoft/azure-devops-go-api/azuredevops/v7 -go 1.12 +go 1.23.0 -require github.com/google/uuid v1.1.1 +toolchain go1.24.1 + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 + github.com/google/uuid v1.6.0 +) + +require ( + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.0 // indirect + golang.org/x/net v0.38.0 // indirect + golang.org/x/text v0.23.0 // indirect +) diff --git a/azuredevops/v7/go.sum b/azuredevops/v7/go.sum index b864886e..39e35ef0 100644 --- a/azuredevops/v7/go.sum +++ b/azuredevops/v7/go.sum @@ -1,2 +1,18 @@ -github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= -github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.0 h1:Bg8m3nq/X1DeePkAbCfb6ml6F3F0IunEhE8TMh+lY48= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.0/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 7fc85ff92766f7882d2a0e0a31fa81ab02a73aca Mon Sep 17 00:00:00 2001 From: magodo Date: Mon, 19 Jan 2026 03:00:56 +0000 Subject: [PATCH 2/4] Patch `core` models --- azuredevops/v7/core/models.go | 6 +++--- azuredevops/v7/core/models_ext.go | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 azuredevops/v7/core/models_ext.go diff --git a/azuredevops/v7/core/models.go b/azuredevops/v7/core/models.go index 7977022d..efb2901e 100644 --- a/azuredevops/v7/core/models.go +++ b/azuredevops/v7/core/models.go @@ -229,8 +229,8 @@ type sourceControlTypesValuesType struct { } var SourceControlTypesValues = sourceControlTypesValuesType{ - Tfvc: "tfvc", - Git: "git", + Tfvc: "Tfvc", + Git: "Git", } // The Team Context for an operation. @@ -270,7 +270,7 @@ type TeamProject struct { // The links to other objects related to this object. Links interface{} `json:"_links,omitempty"` // Set of capabilities this project has (such as process template & version control). - Capabilities *map[string]map[string]string `json:"capabilities,omitempty"` + Capabilities *TeamProjectCapabilities `json:"capabilities,omitempty"` // The shallow ref to the default team. DefaultTeam *WebApiTeamRef `json:"defaultTeam,omitempty"` } diff --git a/azuredevops/v7/core/models_ext.go b/azuredevops/v7/core/models_ext.go new file mode 100644 index 00000000..e9e3b621 --- /dev/null +++ b/azuredevops/v7/core/models_ext.go @@ -0,0 +1,14 @@ +package core + +type TeamProjectCapabilities struct { + Versioncontrol *TeamProjectCapabilitiesVersionControl `json:"versioncontrol,omitempty"` + ProcessTemplate *TeamProjectCapabilitiesProcessTemplate `json:"processTemplate,omitempty"` +} + +type TeamProjectCapabilitiesVersionControl struct { + SourceControlType *SourceControlTypes `json:"sourceControlType,omitempty"` +} + +type TeamProjectCapabilitiesProcessTemplate struct { + TemplateId *string `json:"templateTypeId,omitempty"` +} From 5ba47e0e551c2b0b585495d92092cbe8ad6f5e76 Mon Sep 17 00:00:00 2001 From: magodo Date: Thu, 29 Jan 2026 16:07:19 +1100 Subject: [PATCH 3/4] UnwrapError() only returns WrappedError (instead of its pointer sometimes) --- azuredevops/v7/client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/azuredevops/v7/client.go b/azuredevops/v7/client.go index 766d0321..526871cd 100644 --- a/azuredevops/v7/client.go +++ b/azuredevops/v7/client.go @@ -378,7 +378,7 @@ func trimByteOrderMark(body []byte) []byte { func (client *Client) UnwrapError(response *http.Response) (err error) { if response.ContentLength == 0 { message := "Request returned status: " + response.Status - return &WrappedError{ + return WrappedError{ Message: &message, StatusCode: &response.StatusCode, } @@ -415,7 +415,7 @@ func (client *Client) UnwrapError(response *http.Response) (err error) { var wrappedImproperError WrappedImproperError err = json.Unmarshal(body, &wrappedImproperError) if err == nil && wrappedImproperError.Value != nil && wrappedImproperError.Value.Message != nil { - return &WrappedError{ + return WrappedError{ Message: wrappedImproperError.Value.Message, StatusCode: &response.StatusCode, } From f7730ffa2eabb4ea34be00f1d1de9991f4c0942d Mon Sep 17 00:00:00 2001 From: magodo Date: Wed, 15 Apr 2026 16:48:13 +1000 Subject: [PATCH 4/4] Add configurable retry with exponential backoff for transient errors in v7 client ADO servers occasionally fail with transient connection errors such as "peer connection closed" or "connection reset by peer", where a simple retry succeeds. Add retry support to the v7 Client.SendRequest method with exponential backoff that is configurable via WithRetryOptions. RetryOptions allows callers to set MaxRetries (default 3), initial Delay (default 1s, doubled each attempt), and an optional IsRetryable predicate. The default predicate (DefaultIsRetryable) covers common transient network failures: EOF, connection resets, closed connections, TLS handshake timeouts, and I/O timeouts, while never retrying context cancellation or deadline exceeded errors. Request bodies are buffered before the first attempt so they can be replayed on retries, and backoff sleeps respect context cancellation. --- azuredevops/v7/client.go | 107 +++++++++++- azuredevops/v7/client_options.go | 27 +++ azuredevops/v7/client_test.go | 276 ++++++++++++++++++++++++++++++- 3 files changed, 402 insertions(+), 8 deletions(-) diff --git a/azuredevops/v7/client.go b/azuredevops/v7/client.go index 526871cd..5b756b51 100644 --- a/azuredevops/v7/client.go +++ b/azuredevops/v7/client.go @@ -18,6 +18,7 @@ import ( "runtime" "strings" "sync" + "time" "github.com/google/uuid" ) @@ -62,11 +63,11 @@ func NewClient(connection *Connection, baseUrl string) *Client { httpClient.Timeout = *connection.Timeout } - return NewClientWithOptions(connection, baseUrl, WithHTTPClient(httpClient)) + return newClientWithOptions(connection, baseUrl, WithHTTPClient(httpClient)) } -// NewClientWithOptions returns an Azure DevOps client modified by the options -func NewClientWithOptions(connection *Connection, baseUrl string, options ...ClientOptionFunc) *Client { +// newClientWithOptions returns an Azure DevOps client modified by the options +func newClientWithOptions(connection *Connection, baseUrl string, options ...ClientOptionFunc) *Client { httpClient := &http.Client{} client := &Client{ baseUrl: baseUrl, @@ -89,14 +90,106 @@ type Client struct { suppressFedAuthRedirect bool forceMsaPassThrough bool userAgent string + retryOptions *RetryOptions } func (client *Client) SendRequest(request *http.Request) (response *http.Response, err error) { - resp, err := client.client.Do(request) // todo: add retry logic - if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) { - err = client.UnwrapError(resp) + var ( + maxRetries = 3 + retryDelay = time.Second + isRetryable = DefaultIsRetryable + ) + + if opt := client.retryOptions; opt != nil { + maxRetries = opt.MaxRetries + if opt.Delay > 0 { + retryDelay = opt.Delay + } + if opt.IsRetryable != nil { + isRetryable = opt.IsRetryable + } } - return resp, err + + // Buffer the request body so it can be replayed on retries. + if maxRetries > 0 && request.Body != nil && request.GetBody == nil { + bodyBytes, readErr := io.ReadAll(request.Body) + request.Body.Close() + if readErr != nil { + return nil, readErr + } + request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + request.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(bodyBytes)), nil + } + } + + for attempt := 0; ; attempt++ { + resp, doErr := client.client.Do(request) + + if doErr != nil && attempt < maxRetries && isRetryable(resp, doErr) { + // Drain and close response body if present. + if resp != nil && resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + // Reset body for retry. + if request.GetBody != nil { + newBody, bodyErr := request.GetBody() + if bodyErr != nil { + return nil, bodyErr + } + request.Body = newBody + } + // Exponential backoff: delay * 2^attempt, respecting context cancellation. + delay := retryDelay * time.Duration(1<= 300) { + doErr = client.UnwrapError(resp) + } + return resp, doErr + } +} + +// DefaultIsRetryable returns true for transient connection errors that are +// safe to retry: connection resets, unexpected EOFs, closed connections, and +// similar network-level failures. It returns false for context cancellation +// and deadline exceeded errors. +func DefaultIsRetryable(resp *http.Response, err error) bool { + if err == nil { + return false + } + + // Retry on EOF / unexpected EOF. + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + + // Check the error message for well-known transient patterns. + msg := strings.ToLower(err.Error()) + transientPatterns := []string{ + "connection reset by peer", + "connection was forcibly closed", + "peer connection closed", + "tls handshake timeout", + "i/o timeout", + "unexpected eof", + "use of closed network connection", + } + for _, pattern := range transientPatterns { + if strings.Contains(msg, pattern) { + return true + } + } + + return false } func (client *Client) Send(ctx context.Context, diff --git a/azuredevops/v7/client_options.go b/azuredevops/v7/client_options.go index 315288f6..4faad33f 100644 --- a/azuredevops/v7/client_options.go +++ b/azuredevops/v7/client_options.go @@ -2,6 +2,7 @@ package azuredevops import ( "net/http" + "time" ) // ClientOptionFunc can be used customize a new AzureDevops API client. @@ -13,3 +14,29 @@ func WithHTTPClient(httpClient *http.Client) ClientOptionFunc { c.client = httpClient } } + +// WithRetryOptions configures retry behavior for transient errors. +// When set, the client will retry failed requests that match the IsRetryable +// predicate, up to MaxRetries times with exponential backoff. +func WithRetryOptions(options RetryOptions) ClientOptionFunc { + return func(c *Client) { + c.retryOptions = &options + } +} + +// RetryOptions configures retry behavior for the client. +type RetryOptions struct { + // MaxRetries is the maximum number of retry attempts. + // A value of 0 means no retries. + // Defaults to 3. + MaxRetries int + + // Delay is the initial delay between retries. Subsequent retries use + // exponential backoff (delay * 2^attempt). Default: 1 second. + Delay time.Duration + + // IsRetryable determines whether a failed request should be retried. + // It receives the HTTP response (may be nil for transport-level errors) + // and the error. If nil, DefaultIsRetryable is used. + IsRetryable func(resp *http.Response, err error) bool +} diff --git a/azuredevops/v7/client_test.go b/azuredevops/v7/client_test.go index 2df26bf6..89c750b7 100644 --- a/azuredevops/v7/client_test.go +++ b/azuredevops/v7/client_test.go @@ -1,8 +1,15 @@ package azuredevops import ( + "bytes" + "context" "crypto/tls" + "errors" + "fmt" + "io" "net/http" + "strings" + "sync/atomic" "testing" "time" ) @@ -38,7 +45,7 @@ func TestClient_NewClientWithOptions_WithHTTPClient(t *testing.T) { httpClient := &http.Client{Timeout: httpTimeout} baseURL := "localhost" - client := NewClientWithOptions(conn, baseURL, WithHTTPClient(httpClient)) + client := newClientWithOptions(conn, baseURL, WithHTTPClient(httpClient)) if client.baseUrl != baseURL { t.Errorf("Expected baseURL: %v Actual baseURL: %v", baseURL, client.baseUrl) } @@ -46,3 +53,270 @@ func TestClient_NewClientWithOptions_WithHTTPClient(t *testing.T) { t.Errorf("Expected httpClient.Timeout: %#v Actual httpClient.Timeout: %#v", httpClient.Timeout, actualHTTPClient.Timeout) } } + +func TestClient_NewClientWithOptions_WithRetryOptions(t *testing.T) { + conn := &Connection{} + opts := RetryOptions{MaxRetries: 5, Delay: 2 * time.Second} + client := newClientWithOptions(conn, "localhost", WithRetryOptions(opts)) + if client.retryOptions == nil { + t.Fatal("Expected retryOptions to be set") + } + if client.retryOptions.MaxRetries != 5 { + t.Errorf("Expected MaxRetries=5, got %d", client.retryOptions.MaxRetries) + } + if client.retryOptions.Delay != 2*time.Second { + t.Errorf("Expected Delay=2s, got %v", client.retryOptions.Delay) + } +} + +// roundTripFunc adapts a function to http.RoundTripper for testing. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestSendRequest_NoRetry_Success(t *testing.T) { + client := &Client{ + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }), + }, + } + req, _ := http.NewRequest("GET", "http://example.com", nil) + resp, err := client.SendRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } +} + +func TestSendRequest_NoRetryOptions_TransientError(t *testing.T) { + // Without retry options, transient errors are returned immediately. + client := &Client{ + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("connection reset by peer") + }), + }, + } + req, _ := http.NewRequest("GET", "http://example.com", nil) + _, err := client.SendRequest(req) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "connection reset by peer") { + t.Errorf("expected 'connection reset by peer' error, got: %v", err) + } +} + +func TestSendRequest_RetryOnTransientError(t *testing.T) { + var attempts int32 + client := &Client{ + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + n := atomic.AddInt32(&attempts, 1) + if n <= 2 { + return nil, fmt.Errorf("read tcp: connection reset by peer") + } + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }), + }, + retryOptions: &RetryOptions{ + MaxRetries: 3, + Delay: time.Millisecond, // fast for tests + }, + } + req, _ := http.NewRequest("GET", "http://example.com", nil) + resp, err := client.SendRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if atomic.LoadInt32(&attempts) != 3 { + t.Errorf("expected 3 attempts, got %d", atomic.LoadInt32(&attempts)) + } +} + +func TestSendRequest_RetryExhausted(t *testing.T) { + var attempts int32 + client := &Client{ + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + atomic.AddInt32(&attempts, 1) + return nil, fmt.Errorf("peer connection closed") + }), + }, + retryOptions: &RetryOptions{ + MaxRetries: 2, + Delay: time.Millisecond, + }, + } + req, _ := http.NewRequest("GET", "http://example.com", nil) + _, err := client.SendRequest(req) + if err == nil { + t.Fatal("expected error after retries exhausted") + } + // 1 initial + 2 retries = 3 total attempts + if atomic.LoadInt32(&attempts) != 3 { + t.Errorf("expected 3 attempts, got %d", atomic.LoadInt32(&attempts)) + } +} + +func TestSendRequest_RetryWithBody(t *testing.T) { + var bodies []string + var attempts int32 + client := &Client{ + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + b, _ := io.ReadAll(req.Body) + bodies = append(bodies, string(b)) + n := atomic.AddInt32(&attempts, 1) + if n == 1 { + return nil, fmt.Errorf("connection reset by peer") + } + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }), + }, + retryOptions: &RetryOptions{ + MaxRetries: 2, + Delay: time.Millisecond, + }, + } + body := "request body content" + req, _ := http.NewRequest("POST", "http://example.com", bytes.NewBufferString(body)) + resp, err := client.SendRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if len(bodies) != 2 { + t.Fatalf("expected 2 attempts, got %d", len(bodies)) + } + for i, b := range bodies { + if b != body { + t.Errorf("attempt %d: expected body %q, got %q", i, body, b) + } + } +} + +func TestSendRequest_NoRetryOnContextCanceled(t *testing.T) { + var attempts int32 + client := &Client{ + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + atomic.AddInt32(&attempts, 1) + return nil, context.Canceled + }), + }, + retryOptions: &RetryOptions{ + MaxRetries: 3, + Delay: time.Millisecond, + }, + } + req, _ := http.NewRequest("GET", "http://example.com", nil) + _, err := client.SendRequest(req) + if !errors.Is(err, context.Canceled) { + t.Errorf("expected context.Canceled, got: %v", err) + } + if atomic.LoadInt32(&attempts) != 1 { + t.Errorf("expected 1 attempt (no retry), got %d", atomic.LoadInt32(&attempts)) + } +} + +func TestSendRequest_NoRetryOnNonRetryableError(t *testing.T) { + var attempts int32 + client := &Client{ + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + atomic.AddInt32(&attempts, 1) + return nil, fmt.Errorf("some permanent error") + }), + }, + retryOptions: &RetryOptions{ + MaxRetries: 3, + Delay: time.Millisecond, + }, + } + req, _ := http.NewRequest("GET", "http://example.com", nil) + _, err := client.SendRequest(req) + if err == nil { + t.Fatal("expected error") + } + if atomic.LoadInt32(&attempts) != 1 { + t.Errorf("expected 1 attempt (no retry for non-retryable), got %d", atomic.LoadInt32(&attempts)) + } +} + +func TestSendRequest_CustomIsRetryable(t *testing.T) { + var attempts int32 + client := &Client{ + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + n := atomic.AddInt32(&attempts, 1) + if n == 1 { + return nil, fmt.Errorf("custom transient error") + } + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }), + }, + retryOptions: &RetryOptions{ + MaxRetries: 2, + Delay: time.Millisecond, + IsRetryable: func(resp *http.Response, err error) bool { + return err != nil && strings.Contains(err.Error(), "custom transient") + }, + }, + } + req, _ := http.NewRequest("GET", "http://example.com", nil) + resp, err := client.SendRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if atomic.LoadInt32(&attempts) != 2 { + t.Errorf("expected 2 attempts, got %d", atomic.LoadInt32(&attempts)) + } +} + +func TestSendRequest_ContextCancelDuringSleep(t *testing.T) { + var attempts int32 + client := &Client{ + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + atomic.AddInt32(&attempts, 1) + return nil, fmt.Errorf("connection reset by peer") + }), + }, + retryOptions: &RetryOptions{ + MaxRetries: 3, + Delay: 5 * time.Second, // long delay + }, + } + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil) + + done := make(chan error, 1) + go func() { + _, err := client.SendRequest(req) + done <- err + }() + + // Cancel context shortly after the first failed attempt triggers a retry sleep. + time.Sleep(50 * time.Millisecond) + cancel() + + err := <-done + if !errors.Is(err, context.Canceled) { + t.Errorf("expected context.Canceled, got: %v", err) + } +}