From 37c75f98e86dac5b43d07095b1ee7e3cc5b3ea69 Mon Sep 17 00:00:00 2001 From: ezilber-akamai Date: Tue, 21 Apr 2026 14:29:24 -0400 Subject: [PATCH 1/9] Ported resty removal changes to new branch from up-to-date main and made additional changes to monitor client --- client.go | 717 +++++++++--------- client_http.go | 56 -- client_monitor.go | 124 ++- client_test.go | 207 +++-- config_test.go | 14 +- errors.go | 105 +-- errors_test.go | 193 ++--- go.mod | 1 - go.sum | 4 - images.go | 48 +- internal/testutil/mock.go | 12 +- k8s/go.mod | 1 - k8s/go.sum | 2 - logger.go | 32 +- monitor_alert_definitions.go | 22 +- monitor_api_services.go | 14 +- pagination.go | 51 +- request_helpers.go | 110 +-- request_helpers_test.go | 41 +- retries.go | 138 ++-- retries_http.go | 132 ---- retries_http_test.go | 81 -- retries_test.go | 65 +- test/go.mod | 1 - test/go.sum | 2 - test/integration/cache_test.go | 18 +- .../TestMaintenancePolicies_List.yaml | 28 +- test/integration/maintenance_test.go | 18 +- test/unit/images_test.go | 8 + 29 files changed, 968 insertions(+), 1277 deletions(-) delete mode 100644 client_http.go delete mode 100644 retries_http.go delete mode 100644 retries_http_test.go diff --git a/client.go b/client.go index 0273d4565..575156fb5 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,8 @@ package linodego import ( "bytes" "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" @@ -14,13 +16,12 @@ import ( "path/filepath" "reflect" "regexp" + "runtime" "strconv" "strings" "sync" "text/template" "time" - - "github.com/go-resty/resty/v2" ) const ( @@ -50,7 +51,6 @@ const ( APIDefaultCacheExpiration = time.Minute * 15 ) -//nolint:unused var ( reqLogTemplate = template.Must(template.New("request").Parse(`Sending request: Method: {{.Method}} @@ -64,6 +64,20 @@ Headers: {{.Headers}} Body: {{.Body}}`)) ) +type RequestLog struct { + Method string + URL string + Headers http.Header + Body string +} + +type ResponseLog struct { + Method string + URL string + Headers http.Header + Body string +} + var envDebug = false // redactHeadersMap is a map of headers that should be redacted in logs, @@ -72,18 +86,19 @@ var redactHeadersMap = map[string]string{ "Authorization": "Bearer *******************************", } -// Client is a wrapper around the Resty client +// Client is a wrapper around the http client type Client struct { - resty *resty.Client - userAgent string - debug bool - retryConditionals []RetryConditional + httpClient *http.Client + userAgent string + debug bool pollInterval time.Duration baseURL string apiVersion string apiProto string + hostURL string + header http.Header selectedProfile string loadedProfile string @@ -94,6 +109,16 @@ type Client struct { cacheExpiration time.Duration cachedEntries map[string]clientCacheEntry cachedEntryLock *sync.RWMutex + logger Logger + requestLog func(*RequestLog) error + onBeforeRequest []func(*http.Request) error + onAfterResponse []func(*http.Response) error + + retryConditionals []RetryConditional + retryMaxWaitTime time.Duration + retryMinWaitTime time.Duration + retryAfter RetryAfter + retryCount int } type EnvDefaults struct { @@ -110,13 +135,11 @@ type clientCacheEntry struct { } type ( - Request = resty.Request - Response = resty.Response - Logger = resty.Logger + Request = http.Request + Response = http.Response ) func init() { - // Whether we will enable Resty debugging output if apiDebug, ok := os.LookupEnv("LINODE_DEBUG"); ok { if parsed, err := strconv.ParseBool(apiDebug); err == nil { envDebug = parsed @@ -127,118 +150,23 @@ func init() { } } -// NewClient factory to create new Client struct -func NewClient(hc *http.Client) (client Client) { - if hc != nil { - client.resty = resty.NewWithClient(hc) - } else { - client.resty = resty.New() - } - - client.shouldCache = true - client.cacheExpiration = APIDefaultCacheExpiration - client.cachedEntries = make(map[string]clientCacheEntry) - client.cachedEntryLock = &sync.RWMutex{} - - client.SetUserAgent(DefaultUserAgent) - - baseURL, baseURLExists := os.LookupEnv(APIHostVar) - - if baseURLExists { - client.SetBaseURL(baseURL) - } - - apiVersion, apiVersionExists := os.LookupEnv(APIVersionVar) - if apiVersionExists { - client.SetAPIVersion(apiVersion) - } else { - client.SetAPIVersion(APIVersion) - } - - certPath, certPathExists := os.LookupEnv(APIHostCert) - - if certPathExists && !hasCustomTransport(hc) { - cert, err := os.ReadFile(filepath.Clean(certPath)) - if err != nil { - log.Fatalf("[ERROR] Error when reading cert at %s: %s\n", certPath, err.Error()) - } - - client.SetRootCertificate(certPath) - - if envDebug { - log.Printf("[DEBUG] Set API root certificate to %s with contents %s\n", certPath, cert) - } - } - - client. - SetRetryWaitTime(APISecondsPerPoll * time.Second). - SetPollDelay(APISecondsPerPoll * time.Second). - SetRetries(). - SetDebug(envDebug). - enableLogSanitization() - - return client -} - -// NewClientFromEnv creates a Client and initializes it with values -// from the LINODE_CONFIG file and the LINODE_TOKEN environment variable. -func NewClientFromEnv(hc *http.Client) (*Client, error) { - client := NewClient(hc) - - // Users are expected to chain NewClient(...) and LoadConfig(...) to customize these options - configPath, err := resolveValidConfigPath() - if err != nil { - return nil, err - } - - // Populate the token from the environment. - // Tokens should be first priority to maintain backwards compatibility - if token, ok := os.LookupEnv(APIEnvVar); ok && token != "" { - client.SetToken(token) - return &client, nil - } - - if p, ok := os.LookupEnv(APIConfigEnvVar); ok { - configPath = p - } else if !ok && configPath == "" { - return nil, fmt.Errorf("no linode config file or token found") - } - - configProfile := DefaultConfigProfile - - if p, ok := os.LookupEnv(APIConfigProfileEnvVar); ok { - configProfile = p - } - - client.selectedProfile = configProfile - - // We should only load the config if the config file exists - if _, err = os.Stat(configPath); err != nil { - return nil, fmt.Errorf("error loading config file %s: %w", configPath, err) - } - - err = client.preLoadConfig(configPath) - - return &client, err -} - // SetUserAgent sets a custom user-agent for HTTP requests func (c *Client) SetUserAgent(ua string) *Client { c.userAgent = ua - c.resty.SetHeader("User-Agent", c.userAgent) + c.SetHeader("User-Agent", c.userAgent) return c } -type RequestParams struct { - Body any +type requestParams struct { + Body *bytes.Reader Response any } // Generic helper to execute HTTP requests using the net/http package // -// nolint:unused, funlen, gocognit -func (c *httpClient) doRequest(ctx context.Context, method, url string, params RequestParams) error { +// nolint:funlen, gocognit, nestif +func (c *Client) doRequest(ctx context.Context, method, endpoint string, params requestParams, paginationMutator *func(*http.Request) error) error { var ( req *http.Request bodyBuffer *bytes.Buffer @@ -246,18 +174,32 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R err error ) - for range httpDefaultRetryCount { - req, bodyBuffer, err = c.createRequest(ctx, method, url, params) + for range c.retryCount { + // Reset the body to the start for each retry if it's not nil + if params.Body != nil { + _, err := params.Body.Seek(0, io.SeekStart) + if err != nil { + return c.ErrorAndLogf("failed to seek to the start of the body: %v", err.Error()) + } + } + + req, bodyBuffer, err = c.createRequest(ctx, method, endpoint, params) if err != nil { return err } + if paginationMutator != nil { + if err := (*paginationMutator)(req); err != nil { + return c.ErrorAndLogf("failed to mutate before request: %v", err.Error()) + } + } + if err = c.applyBeforeRequest(req); err != nil { return err } if c.debug && c.logger != nil { - c.logRequest(req, method, url, bodyBuffer) + c.logRequest(req, method, endpoint, bodyBuffer) } processResponse := func() error { @@ -312,18 +254,28 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R } // Sleep for the specified duration before retrying. - // If retryAfter is 0 (i.e., Retry-After header is not found), - // no delay is applied. - time.Sleep(retryAfter) + if retryAfter > 0 { + waitTime := retryAfter + + // Ensure the wait time is within the defined bounds + if waitTime < c.retryMinWaitTime { + waitTime = c.retryMinWaitTime + } else if waitTime > c.retryMaxWaitTime { + waitTime = c.retryMaxWaitTime + } + + // Sleep for the calculated duration before retrying + time.Sleep(waitTime) + } } return err } -// nolint:unused -func (c *httpClient) shouldRetry(resp *http.Response, err error) bool { +func (c *Client) shouldRetry(resp *http.Response, err error) bool { for _, retryConditional := range c.retryConditionals { if retryConditional(resp, err) { + log.Printf("[INFO] Received error %v - Retrying", err) return true } } @@ -331,35 +283,27 @@ func (c *httpClient) shouldRetry(resp *http.Response, err error) bool { return false } -// nolint:unused -func (c *httpClient) createRequest(ctx context.Context, method, url string, params RequestParams) (*http.Request, *bytes.Buffer, error) { - var ( - bodyReader io.Reader - bodyBuffer *bytes.Buffer - ) +func (c *Client) createRequest(ctx context.Context, method, endpoint string, params requestParams) (*http.Request, *bytes.Buffer, error) { + var bodyReader io.Reader + var bodyBuffer *bytes.Buffer if params.Body != nil { - bodyBuffer = new(bytes.Buffer) - if err := json.NewEncoder(bodyBuffer).Encode(params.Body); err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to encode body: %v", err) - } - - return nil, nil, fmt.Errorf("failed to encode body: %w", err) + // Reset the body position to the start before using it + _, err := params.Body.Seek(0, io.SeekStart) + if err != nil { + return nil, nil, c.ErrorAndLogf("failed to seek to the start of the body: %v", err.Error()) } - bodyReader = bodyBuffer + bodyReader = params.Body } - req, err := http.NewRequestWithContext(ctx, method, url, bodyReader) + req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s/%s", strings.TrimRight(c.hostURL, "/"), + strings.TrimLeft(endpoint, "/")), bodyReader) if err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to create request: %v", err) - } - - return nil, nil, fmt.Errorf("failed to create request: %w", err) + return nil, nil, c.ErrorAndLogf("failed to create request: %v", err.Error()) } + // Set the default headers req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") @@ -367,40 +311,36 @@ func (c *httpClient) createRequest(ctx context.Context, method, url string, para req.Header.Set("User-Agent", c.userAgent) } + // Set additional headers added to the client + for name, values := range c.header { + for _, value := range values { + req.Header.Set(name, value) + } + } + return req, bodyBuffer, nil } -// nolint:unused -func (c *httpClient) applyBeforeRequest(req *http.Request) error { +func (c *Client) applyBeforeRequest(req *http.Request) error { for _, mutate := range c.onBeforeRequest { if err := mutate(req); err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to mutate before request: %v", err) - } - - return fmt.Errorf("failed to mutate before request: %w", err) + return c.ErrorAndLogf("failed to mutate before request: %v", err.Error()) } } return nil } -// nolint:unused -func (c *httpClient) applyAfterResponse(resp *http.Response) error { +func (c *Client) applyAfterResponse(resp *http.Response) error { for _, mutate := range c.onAfterResponse { if err := mutate(resp); err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to mutate after response: %v", err) - } - - return fmt.Errorf("failed to mutate after response: %w", err) + return c.ErrorAndLogf("failed to mutate after response: %v", err.Error()) } } return nil } -// nolint:unused func redactHeaders(headers http.Header) http.Header { redacted := headers.Clone() @@ -413,8 +353,7 @@ func redactHeaders(headers http.Header) http.Header { return redacted } -// nolint:unused -func (c *httpClient) logRequest(req *http.Request, method, url string, bodyBuffer *bytes.Buffer) { +func (c *Client) logRequest(req *http.Request, method, url string, bodyBuffer *bytes.Buffer) { var reqBody string if bodyBuffer != nil { reqBody = bodyBuffer.String() @@ -422,58 +361,57 @@ func (c *httpClient) logRequest(req *http.Request, method, url string, bodyBuffe reqBody = "nil" } - var logBuf bytes.Buffer + reqLog := &RequestLog{ + Method: method, + URL: url, + Headers: req.Header, + Body: reqBody, + } + + e := c.requestLog(reqLog) + if e != nil { + _ = c.ErrorAndLogf("failed to mutate after response: %v", e.Error()) + } - err := reqLogTemplate.Execute(&logBuf, map[string]any{ - "Method": method, - "URL": url, - "Headers": redactHeaders(req.Header), - "Body": reqBody, + var logBuf bytes.Buffer + err := reqLogTemplate.Execute(&logBuf, map[string]interface{}{ + "Method": reqLog.Method, + "URL": reqLog.URL, + "Headers": reqLog.Headers, + "Body": reqLog.Body, }) if err == nil { c.logger.Debugf(logBuf.String()) } } -// nolint:unused -func (c *httpClient) sendRequest(req *http.Request) (*http.Response, error) { - // #nosec G704 +func (c *Client) sendRequest(req *http.Request) (*http.Response, error) { resp, err := c.httpClient.Do(req) if err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to send request: %v", err) - } - - return nil, fmt.Errorf("failed to send request: %w", err) + return nil, c.ErrorAndLogf("failed to send request: %w", err) } return resp, nil } -// nolint:unused -func (c *httpClient) checkHTTPError(resp *http.Response) error { - _, err := coupleAPIErrorsHTTP(resp, nil) +func (c *Client) checkHTTPError(resp *http.Response) error { + _, err := coupleAPIErrors(resp, nil) if err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("received HTTP error: %v", err) - } - + _ = c.ErrorAndLogf("received HTTP error: %v", err.Error()) return err } return nil } -// nolint:unused -func (c *httpClient) logResponse(resp *http.Response) (*http.Response, error) { +func (c *Client) logResponse(resp *http.Response) (*http.Response, error) { var respBody bytes.Buffer if _, err := io.Copy(&respBody, resp.Body); err != nil { c.logger.Errorf("failed to read response body: %v", err) } var logBuf bytes.Buffer - - err := respLogTemplate.Execute(&logBuf, map[string]any{ + err := respLogTemplate.Execute(&logBuf, map[string]interface{}{ "Status": resp.Status, "Headers": redactHeaders(resp.Header), "Body": respBody.String(), @@ -487,84 +425,32 @@ func (c *httpClient) logResponse(resp *http.Response) (*http.Response, error) { return resp, nil } -// nolint:unused -func (c *httpClient) decodeResponseBody(resp *http.Response, response any) error { +func (c *Client) decodeResponseBody(resp *http.Response, response interface{}) error { if err := json.NewDecoder(resp.Body).Decode(response); err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to decode response: %v", err) - } - - return fmt.Errorf("failed to decode response: %w", err) + return c.ErrorAndLogf("failed to decode response: %v", err.Error()) } return nil } -// R wraps resty's R method -func (c *Client) R(ctx context.Context) *resty.Request { - return c.resty.R(). - ExpectContentType("application/json"). - SetHeader("Content-Type", "application/json"). - SetContext(ctx). - SetError(APIError{}) -} - -// SetDebug sets the debug on resty's client func (c *Client) SetDebug(debug bool) *Client { c.debug = debug - c.resty.SetDebug(debug) return c } -// SetLogger allows the user to override the output -// logger for debug logs. func (c *Client) SetLogger(logger Logger) *Client { - c.resty.SetLogger(logger) - - return c -} - -//nolint:unused -func (c *httpClient) httpSetDebug(debug bool) *httpClient { - c.debug = debug - - return c -} - -//nolint:unused -func (c *httpClient) httpSetLogger(logger httpLogger) *httpClient { c.logger = logger return c } -// OnBeforeRequest adds a handler to the request body to run before the request is sent -func (c *Client) OnBeforeRequest(m func(request *Request) error) { - c.resty.OnBeforeRequest(func(_ *resty.Client, req *resty.Request) error { - return m(req) - }) -} - -// OnAfterResponse adds a handler to the request body to run before the request is sent -func (c *Client) OnAfterResponse(m func(response *Response) error) { - c.resty.OnAfterResponse(func(_ *resty.Client, req *resty.Response) error { - return m(req) - }) -} - -// nolint:unused -func (c *httpClient) httpOnBeforeRequest(m func(*http.Request) error) *httpClient { +func (c *Client) OnBeforeRequest(m func(*http.Request) error) { c.onBeforeRequest = append(c.onBeforeRequest, m) - - return c } -// nolint:unused -func (c *httpClient) httpOnAfterResponse(m func(*http.Response) error) *httpClient { +func (c *Client) OnAfterResponse(m func(*http.Response) error) { c.onAfterResponse = append(c.onAfterResponse, m) - - return c } // UseURL parses the individual components of the given API URL and configures the client @@ -602,7 +488,6 @@ func (c *Client) UseURL(apiURL string) (*Client, error) { return c, nil } -// SetBaseURL sets the base URL of the Linode v4 API (https://api.linode.com/v4) func (c *Client) SetBaseURL(baseURL string) *Client { baseURLPath, _ := url.Parse(baseURL) @@ -623,40 +508,162 @@ func (c *Client) SetAPIVersion(apiVersion string) *Client { return c } +func (c *Client) updateHostURL() { + apiProto := APIProto + baseURL := APIHost + apiVersion := APIVersion + + if c.baseURL != "" { + baseURL = c.baseURL + } + + if c.apiVersion != "" { + apiVersion = c.apiVersion + } + + if c.apiProto != "" { + apiProto = c.apiProto + } + + c.hostURL = strings.TrimRight(fmt.Sprintf("%s://%s/%s", apiProto, baseURL, url.PathEscape(apiVersion)), "/") +} + +func (c *Client) Transport() (*http.Transport, error) { + if transport, ok := c.httpClient.Transport.(*http.Transport); ok { + return transport, nil + } + return nil, fmt.Errorf("current transport is not an *http.Transport instance") +} + +func (c *Client) tlsConfig() (*tls.Config, error) { + transport, err := c.Transport() + if err != nil { + return nil, err + } + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + } + return transport.TLSClientConfig, nil +} + // SetRootCertificate adds a root certificate to the underlying TLS client config func (c *Client) SetRootCertificate(path string) *Client { - c.resty.SetRootCertificate(path) + config, err := c.tlsConfig() + if err != nil { + log.Println("[WARN] Custom transport is not allowed with a custom root CA") + return c + } + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + + config.RootCAs.AppendCertsFromPEM([]byte(path)) return c } // SetToken sets the API token for all requests from this client // Only necessary if you haven't already provided the http client to NewClient() configured with the token. func (c *Client) SetToken(token string) *Client { - c.resty.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) + c.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) return c } // SetRetries adds retry conditions for "Linode Busy." errors and 429s. func (c *Client) SetRetries() *Client { c. - addRetryConditional(linodeBusyRetryCondition). - addRetryConditional(tooManyRequestsRetryCondition). - addRetryConditional(serviceUnavailableRetryCondition). - addRetryConditional(requestTimeoutRetryCondition). - addRetryConditional(requestGOAWAYRetryCondition). - addRetryConditional(requestNGINXRetryCondition). + AddRetryCondition(LinodeBusyRetryCondition). + AddRetryCondition(TooManyRequestsRetryCondition). + AddRetryCondition(ServiceUnavailableRetryCondition). + AddRetryCondition(RequestTimeoutRetryCondition). + AddRetryCondition(RequestGOAWAYRetryCondition). + AddRetryCondition(RequestNGINXRetryCondition). SetRetryMaxWaitTime(APIRetryMaxWaitTime) - configureRetries(c) - + ConfigureRetries(c) return c } // AddRetryCondition adds a RetryConditional function to the Client func (c *Client) AddRetryCondition(retryCondition RetryConditional) *Client { - c.resty.AddRetryCondition(resty.RetryConditionFunc(retryCondition)) + c.retryConditionals = append(c.retryConditionals, retryCondition) return c } +func (c *Client) addCachedResponse(endpoint string, response any, expiry *time.Duration) { + if !c.shouldCache { + return + } + + responseValue := reflect.ValueOf(response) + + entry := clientCacheEntry{ + Created: time.Now(), + ExpiryOverride: expiry, + } + + switch responseValue.Kind() { + case reflect.Ptr: + // We want to automatically deref pointers to + // avoid caching mutable data. + entry.Data = responseValue.Elem().Interface() + default: + entry.Data = response + } + + c.cachedEntryLock.Lock() + defer c.cachedEntryLock.Unlock() + + c.cachedEntries[endpoint] = entry +} + +func (c *Client) getCachedResponse(endpoint string) any { + if !c.shouldCache { + return nil + } + + c.cachedEntryLock.RLock() + + // Hacky logic to dynamically RUnlock + // only if it is still locked by the + // end of the function. + // This is necessary as we take write + // access if the entry has expired. + rLocked := true + defer func() { + if rLocked { + c.cachedEntryLock.RUnlock() + } + }() + + entry, ok := c.cachedEntries[endpoint] + if !ok { + return nil + } + + // Handle expired entries + elapsedTime := time.Since(entry.Created) + + hasExpired := elapsedTime > c.cacheExpiration + if entry.ExpiryOverride != nil { + hasExpired = elapsedTime > *entry.ExpiryOverride + } + + if hasExpired { + // We need to give up our read access and request read-write access + c.cachedEntryLock.RUnlock() + rLocked = false + + c.cachedEntryLock.Lock() + defer c.cachedEntryLock.Unlock() + + delete(c.cachedEntries, endpoint) + return nil + } + + return c.cachedEntries[endpoint].Data +} + // InvalidateCache clears all cached responses for all endpoints. func (c *Client) InvalidateCache() { c.cachedEntryLock.Lock() @@ -694,26 +701,26 @@ func (c *Client) UseCache(value bool) { // SetRetryMaxWaitTime sets the maximum delay before retrying a request. func (c *Client) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Client { - c.resty.SetRetryMaxWaitTime(maxWaitTime) + c.retryMaxWaitTime = maxWaitTime return c } // SetRetryWaitTime sets the default (minimum) delay before retrying a request. func (c *Client) SetRetryWaitTime(minWaitTime time.Duration) *Client { - c.resty.SetRetryWaitTime(minWaitTime) + c.retryMinWaitTime = minWaitTime return c } // SetRetryAfter sets the callback function to be invoked with a failed request // to determine wben it should be retried. func (c *Client) SetRetryAfter(callback RetryAfter) *Client { - c.resty.SetRetryAfter(resty.RetryAfterFunc(callback)) + c.retryAfter = callback return c } // SetRetryCount sets the maximum retry attempts before aborting. func (c *Client) SetRetryCount(count int) *Client { - c.resty.SetRetryCount(count) + c.retryCount = count return c } @@ -734,138 +741,141 @@ func (c *Client) GetPollDelay() time.Duration { // client. // NOTE: Some headers may be overridden by the individual request functions. func (c *Client) SetHeader(name, value string) { - c.resty.SetHeader(name, value) + if c.header == nil { + c.header = make(http.Header) // Initialize header if nil + } + c.header.Set(name, value) } -func (c *Client) addRetryConditional(retryConditional RetryConditional) *Client { - c.retryConditionals = append(c.retryConditionals, retryConditional) +func (c *Client) onRequestLog(rl func(*RequestLog) error) *Client { + if c.requestLog != nil { + c.logger.Warnf("Overwriting an existing on-request-log callback from=%s to=%s", + functionName(c.requestLog), functionName(rl)) + } + c.requestLog = rl return c } -func (c *Client) addCachedResponse(endpoint string, response any, expiry *time.Duration) { - if !c.shouldCache { - return - } - - responseValue := reflect.ValueOf(response) - - entry := clientCacheEntry{ - Created: time.Now(), - ExpiryOverride: expiry, - } - - switch responseValue.Kind() { - case reflect.Ptr: - // We want to automatically deref pointers to - // avoid caching mutable data. - entry.Data = responseValue.Elem().Interface() - default: - entry.Data = response - } +func functionName(i interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() +} - c.cachedEntryLock.Lock() - defer c.cachedEntryLock.Unlock() +func (c *Client) enableLogSanitization() *Client { + c.onRequestLog(func(r *RequestLog) error { + // masking authorization header + r.Headers.Set("Authorization", "Bearer *******************************") + return nil + }) - c.cachedEntries[endpoint] = entry + return c } -func (c *Client) getCachedResponse(endpoint string) any { - if !c.shouldCache { - return nil +// NewClient factory to create new Client struct +// nolint:funlen +func NewClient(hc *http.Client) (client Client) { + if hc != nil { + client.httpClient = hc + } else { + client.httpClient = &http.Client{} } - c.cachedEntryLock.RLock() + // Ensure that the Header map is not nil + if client.httpClient.Transport == nil { + client.httpClient.Transport = &http.Transport{} + } - // Hacky logic to dynamically RUnlock - // only if it is still locked by the - // end of the function. - // This is necessary as we take write - // access if the entry has expired. - rLocked := true + client.shouldCache = true + client.cacheExpiration = APIDefaultCacheExpiration + client.cachedEntries = make(map[string]clientCacheEntry) + client.cachedEntryLock = &sync.RWMutex{} + client.configProfiles = make(map[string]ConfigProfile) - defer func() { - if rLocked { - c.cachedEntryLock.RUnlock() - } - }() + const ( + retryMinWaitDuration = 100 * time.Millisecond + retryMaxWaitDuration = 2 * time.Second + ) - entry, ok := c.cachedEntries[endpoint] - if !ok { - return nil - } + client.retryMinWaitTime = retryMinWaitDuration + client.retryMaxWaitTime = retryMaxWaitDuration - // Handle expired entries - elapsedTime := time.Since(entry.Created) + client.SetUserAgent(DefaultUserAgent) + client.SetLogger(createLogger()) - hasExpired := elapsedTime > c.cacheExpiration - if entry.ExpiryOverride != nil { - hasExpired = elapsedTime > *entry.ExpiryOverride + baseURL, baseURLExists := os.LookupEnv(APIHostVar) + if baseURLExists { + client.SetBaseURL(baseURL) + } + apiVersion, apiVersionExists := os.LookupEnv(APIVersionVar) + if apiVersionExists { + client.SetAPIVersion(apiVersion) + } else { + client.SetAPIVersion(APIVersion) } - if hasExpired { - // We need to give up our read access and request read-write access - c.cachedEntryLock.RUnlock() - - rLocked = false - - c.cachedEntryLock.Lock() - defer c.cachedEntryLock.Unlock() + certPath, certPathExists := os.LookupEnv(APIHostCert) + if certPathExists { + cert, err := os.ReadFile(filepath.Clean(certPath)) + if err != nil { + log.Fatalf("[ERROR] Error when reading cert at %s: %s\n", certPath, err.Error()) + } - delete(c.cachedEntries, endpoint) + client.SetRootCertificate(certPath) - return nil + if envDebug { + log.Printf("[DEBUG] Set API root certificate to %s with contents %s\n", certPath, cert) + } } - return c.cachedEntries[endpoint].Data + client. + SetRetryWaitTime(APISecondsPerPoll * time.Second). + SetPollDelay(APISecondsPerPoll * time.Second). + SetRetries(). + SetLogger(createLogger()). + SetDebug(envDebug). + enableLogSanitization() + + return } -func (c *Client) updateHostURL() { - apiProto := APIProto - baseURL := APIHost - apiVersion := APIVersion +// NewClientFromEnv creates a Client and initializes it with values +// from the LINODE_CONFIG file and the LINODE_TOKEN environment variable. +func NewClientFromEnv(hc *http.Client) (*Client, error) { + client := NewClient(hc) - if c.baseURL != "" { - baseURL = c.baseURL + // Users are expected to chain NewClient(...) and LoadConfig(...) to customize these options + configPath, err := resolveValidConfigPath() + if err != nil { + return nil, err } - if c.apiVersion != "" { - apiVersion = c.apiVersion + // Populate the token from the environment. + // Tokens should be first priority to maintain backwards compatibility + if token, ok := os.LookupEnv(APIEnvVar); ok && token != "" { + client.SetToken(token) + return &client, nil } - if c.apiProto != "" { - apiProto = c.apiProto + if p, ok := os.LookupEnv(APIConfigEnvVar); ok { + configPath = p + } else if !ok && configPath == "" { + return nil, fmt.Errorf("no linode config file or token found") } - c.resty.SetBaseURL( - fmt.Sprintf( - "%s://%s/%s", - apiProto, - baseURL, - url.PathEscape(apiVersion), - ), - ) -} + configProfile := DefaultConfigProfile -func redactLogHeaders(header http.Header) { - for h, redactedValue := range redactHeadersMap { - if header.Get(h) != "" { - header.Set(h, redactedValue) - } + if p, ok := os.LookupEnv(APIConfigProfileEnvVar); ok { + configProfile = p } -} -func (c *Client) enableLogSanitization() *Client { - c.resty.OnRequestLog(func(r *resty.RequestLog) error { - redactLogHeaders(r.Header) - return nil - }) + client.selectedProfile = configProfile - c.resty.OnResponseLog(func(r *resty.ResponseLog) error { - redactLogHeaders(r.Header) - return nil - }) + // We should only load the config if the config file exists + if _, err := os.Stat(configPath); err != nil { + return nil, fmt.Errorf("error loading config file %s: %w", configPath, err) + } - return c + err = client.preLoadConfig(configPath) + return &client, err } func (c *Client) preLoadConfig(configPath string) error { @@ -959,6 +969,13 @@ func generateListCacheURL(endpoint string, opts *ListOptions) (string, error) { return fmt.Sprintf("%s:%s", endpoint, hashedOpts), nil } +func (c *Client) ErrorAndLogf(format string, args ...interface{}) error { + if c.debug && c.logger != nil { + c.logger.Errorf(format, args...) + } + return fmt.Errorf(format, args...) +} + func hasCustomTransport(hc *http.Client) bool { if hc == nil || hc.Transport == nil { return false diff --git a/client_http.go b/client_http.go deleted file mode 100644 index 7f16362c5..000000000 --- a/client_http.go +++ /dev/null @@ -1,56 +0,0 @@ -package linodego - -import ( - "net/http" - "sync" - "time" -) - -// Client is a wrapper around the Resty client -// -//nolint:unused -type httpClient struct { - //nolint:unused - httpClient *http.Client - //nolint:unused - userAgent string - //nolint:unused - debug bool - //nolint:unused - retryConditionals []httpRetryConditional - //nolint:unused - retryAfter httpRetryAfter - - //nolint:unused - pollInterval time.Duration - - //nolint:unused - baseURL string - //nolint:unused - apiVersion string - //nolint:unused - apiProto string - //nolint:unused - selectedProfile string - //nolint:unused - loadedProfile string - - //nolint:unused - configProfiles map[string]ConfigProfile - - // Fields for caching endpoint responses - //nolint:unused - shouldCache bool - //nolint:unused - cacheExpiration time.Duration - //nolint:unused - cachedEntries map[string]clientCacheEntry - //nolint:unused - cachedEntryLock *sync.RWMutex - //nolint:unused - logger httpLogger - //nolint:unused - onBeforeRequest []func(*http.Request) error - //nolint:unused - onAfterResponse []func(*http.Response) error -} diff --git a/client_monitor.go b/client_monitor.go index b54dc8284..0cdfe4fa7 100644 --- a/client_monitor.go +++ b/client_monitor.go @@ -2,13 +2,16 @@ package linodego import ( "context" + "crypto/tls" + "crypto/x509" + "encoding/json" "fmt" + "io" "net/http" "net/url" "os" "path" - - "github.com/go-resty/resty/v2" + "strings" ) const ( @@ -24,25 +27,31 @@ const ( MonitorAPIEnvVar = "MONITOR_API_TOKEN" ) -// MonitorClient is a wrapper around the Resty client +// MonitorClient is a wrapper around the http client type MonitorClient struct { - resty *resty.Client + httpClient *http.Client debug bool apiBaseURL string apiProtocol string apiVersion string + hostURL string userAgent string + header http.Header + logger Logger } // NewMonitorClient is the entry point for user to create a new MonitorClient // It utilizes default values and looks for environment variables to initialize a MonitorClient. func NewMonitorClient(hc *http.Client) (mClient MonitorClient) { if hc != nil { - mClient.resty = resty.NewWithClient(hc) + mClient.httpClient = hc } else { - mClient.resty = resty.New() + mClient.httpClient = &http.Client{} } + mClient.header = make(http.Header) + mClient.logger = createLogger() + mClient.SetUserAgent(DefaultUserAgent) baseURL, baseURLExists := os.LookupEnv(MonitorAPIHostVar) @@ -72,24 +81,14 @@ func NewMonitorClient(hc *http.Client) (mClient MonitorClient) { // SetUserAgent sets a custom user-agent for HTTP requests func (mc *MonitorClient) SetUserAgent(ua string) *MonitorClient { mc.userAgent = ua - mc.resty.SetHeader("User-Agent", mc.userAgent) + mc.header.Set("User-Agent", ua) return mc } -// R wraps resty's R method -func (mc *MonitorClient) R(ctx context.Context) *resty.Request { - return mc.resty.R(). - ExpectContentType("application/json"). - SetHeader("Content-Type", "application/json"). - SetContext(ctx). - SetError(APIError{}) -} - -// SetDebug sets the debug on resty's client +// SetDebug sets the debug on the client func (mc *MonitorClient) SetDebug(debug bool) *MonitorClient { mc.debug = debug - mc.resty.SetDebug(debug) return mc } @@ -97,7 +96,7 @@ func (mc *MonitorClient) SetDebug(debug bool) *MonitorClient { // SetLogger allows the user to override the output // logger for debug logs. func (mc *MonitorClient) SetLogger(logger Logger) *MonitorClient { - mc.resty.SetLogger(logger) + mc.logger = logger return mc } @@ -124,21 +123,34 @@ func (mc *MonitorClient) SetAPIVersion(apiVersion string) *MonitorClient { } // SetRootCertificate adds a root certificate to the underlying TLS client config -func (mc *MonitorClient) SetRootCertificate(path string) *MonitorClient { - mc.resty.SetRootCertificate(path) +func (mc *MonitorClient) SetRootCertificate(certPath string) *MonitorClient { + transport, ok := mc.httpClient.Transport.(*http.Transport) + if !ok { + mc.logger.Errorf("current transport is not an *http.Transport instance") + return mc + } + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + } + if transport.TLSClientConfig.RootCAs == nil { + transport.TLSClientConfig.RootCAs = x509.NewCertPool() + } + transport.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(certPath)) return mc } // SetToken sets the API token for all requests from this client func (mc *MonitorClient) SetToken(token string) *MonitorClient { - mc.resty.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) + mc.header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) return mc } // SetHeader sets a custom header to be used in all API requests made with the current client. // NOTE: Some headers may be overridden by the individual request functions. func (mc *MonitorClient) SetHeader(name, value string) { - mc.resty.SetHeader(name, value) + mc.header.Set(name, value) } func (mc *MonitorClient) updateMonitorHostURL() { @@ -158,12 +170,64 @@ func (mc *MonitorClient) updateMonitorHostURL() { apiProto = mc.apiProtocol } - mc.resty.SetBaseURL( - fmt.Sprintf( - "%s://%s/%s", - apiProto, - baseURL, - url.PathEscape(apiVersion), - ), + mc.hostURL = fmt.Sprintf( + "%s://%s/%s", + apiProto, + baseURL, + url.PathEscape(apiVersion), ) } + +// doRequest is a generic helper to execute HTTP requests for the MonitorClient +func (mc *MonitorClient) doRequest(ctx context.Context, method, endpoint string, params requestParams) error { + var bodyReader io.Reader + if params.Body != nil { + if _, err := params.Body.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek body: %w", err) + } + bodyReader = params.Body + } + + reqURL := fmt.Sprintf("%s/%s", strings.TrimRight(mc.hostURL, "/"), strings.TrimLeft(endpoint, "/")) + + req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + for name, values := range mc.header { + for _, value := range values { + req.Header.Set(name, value) + } + } + + if mc.debug && mc.logger != nil { + mc.logger.Debugf("Sending request: %s %s", method, reqURL) + } + + resp, err := mc.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + _, err = coupleAPIErrors(resp, nil) + if err != nil { + return err + } + + if mc.debug && mc.logger != nil { + mc.logger.Debugf("Received response: %s", resp.Status) + } + + if params.Response != nil { + if err := json.NewDecoder(resp.Body).Decode(params.Response); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + } + + return nil +} diff --git a/client_test.go b/client_test.go index f925f3678..cb4a32aa2 100644 --- a/client_test.go +++ b/client_test.go @@ -37,39 +37,39 @@ func TestClient_SetAPIVersion(t *testing.T) { client := NewClient(nil) - if client.resty.BaseURL != defaultURL { - t.Fatal(cmp.Diff(client.resty.BaseURL, defaultURL)) + if client.hostURL != defaultURL { + t.Fatal(cmp.Diff(client.hostURL, defaultURL)) } client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Ensure setting twice does not cause conflicts client.SetBaseURL(updatedBaseURL) client.SetAPIVersion(updatedAPIVersion) - if client.resty.BaseURL != updatedExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, updatedExpectedHost)) + if client.hostURL != updatedExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, updatedExpectedHost)) } // Revert client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Custom protocol client.SetBaseURL(protocolBaseURL) client.SetAPIVersion(protocolAPIVersion) - if client.resty.BaseURL != protocolExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != protocolExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } } @@ -111,7 +111,7 @@ func TestClient_NewFromEnvToken(t *testing.T) { t.Fatal(err) } - if client.resty.Header.Get("Authorization") != "Bearer blah" { + if client.header.Get("Authorization") != "Bearer blah" { t.Fatal("token not found in auth header: blah") } } @@ -171,8 +171,8 @@ func TestClient_UseURL(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if client.resty.BaseURL != tt.wantBaseURL { - t.Fatalf("mismatched base url: got %s, want %s", client.resty.BaseURL, tt.wantBaseURL) + if client.hostURL != tt.wantBaseURL { + t.Fatalf("mismatched base url: got %s, want %s", client.hostURL, tt.wantBaseURL) } }) } @@ -209,12 +209,12 @@ func TestDebugLogSanitization(t *testing.T) { logger.L.SetOutput(&lgr) mockClient.SetDebug(true) - if !mockClient.resty.Debug { + if !mockClient.debug { t.Fatal("debug should be enabled") } mockClient.SetHeader("Authorization", fmt.Sprintf("Bearer %s", plainTextToken)) - if mockClient.resty.Header.Get("Authorization") != fmt.Sprintf("Bearer %s", plainTextToken) { + if mockClient.header.Get("Authorization") != fmt.Sprintf("Bearer %s", plainTextToken) { t.Fatal("token not found in auth header") } @@ -242,22 +242,25 @@ func TestDebugLogSanitization(t *testing.T) { func TestDoRequest_Success(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"message":"success"}`)) + if r.URL.Path == "/v4/foo/bar" { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"message":"success"}`)) + } else { + http.NotFound(w, r) + } } server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) - params := RequestParams{ + params := requestParams{ Response: &map[string]string{}, } - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", params, nil) // Pass only the endpoint if err != nil { t.Fatal(cmp.Diff(nil, err)) } @@ -269,31 +272,11 @@ func TestDoRequest_Success(t *testing.T) { } } -func TestDoRequest_FailedEncodeBody(t *testing.T) { - client := &httpClient{ - httpClient: http.DefaultClient, - } - - params := RequestParams{ - Body: map[string]interface{}{ - "invalid": func() {}, - }, - } - - err := client.doRequest(context.Background(), http.MethodPost, "http://example.com", params) - expectedErr := "failed to encode body" - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Fatalf("expected error %q, got: %v", expectedErr, err) - } -} - func TestDoRequest_FailedCreateRequest(t *testing.T) { - client := &httpClient{ - httpClient: http.DefaultClient, - } + client := NewClient(nil) - // Create a request with an invalid URL to simulate a request creation failure - err := client.doRequest(context.Background(), http.MethodGet, "http://invalid url", RequestParams{}) + // Create a request with an invalid method to simulate a request creation failure + err := client.doRequest(context.Background(), "bad method", "/foo/bar", requestParams{}, nil) expectedErr := "failed to create request" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) @@ -302,26 +285,28 @@ func TestDoRequest_FailedCreateRequest(t *testing.T) { func TestDoRequest_Non2xxStatusCode(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "error", http.StatusInternalServerError) + http.Error(w, "error", http.StatusInternalServerError) // Simulate a 500 Internal Server Error } server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) if err == nil { t.Fatal("expected error, got nil") } - httpError, ok := err.(Error) + + httpError, ok := err.(*Error) if !ok { t.Fatalf("expected error to be of type Error, got %T", err) } + if httpError.Code != http.StatusInternalServerError { t.Fatalf("expected status code %d, got %d", http.StatusInternalServerError, httpError.Code) } + if !strings.Contains(httpError.Message, "error") { t.Fatalf("expected error message to contain %q, got %v", "error", httpError.Message) } @@ -331,21 +316,21 @@ func TestDoRequest_FailedDecodeResponse(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`invalid json`)) + _, _ = w.Write([]byte(`invalid json`)) // Simulate invalid JSON } server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) - params := RequestParams{ + params := requestParams{ Response: &map[string]string{}, } - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", params, nil) expectedErr := "failed to decode response" + if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) } @@ -363,24 +348,21 @@ func TestDoRequest_BeforeRequestSuccess(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) - // Define a mutator that successfully modifies the request mutator := func(req *http.Request) error { req.Header.Set("X-Custom-Header", "CustomValue") return nil } - client.httpOnBeforeRequest(mutator) + client.OnBeforeRequest(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) if err != nil { t.Fatalf("expected no error, got: %v", err) } - // Check if the header was successfully added to the captured request if reqHeader := capturedRequest.Header.Get("X-Custom-Header"); reqHeader != "CustomValue" { t.Fatalf("expected X-Custom-Header to be set to CustomValue, got: %v", reqHeader) } @@ -395,18 +377,18 @@ func TestDoRequest_BeforeRequestError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) mutator := func(req *http.Request) error { return errors.New("mutator error") } - client.httpOnBeforeRequest(mutator) + client.OnBeforeRequest(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) expectedErr := "failed to mutate before request" + if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) } @@ -425,23 +407,21 @@ func TestDoRequest_AfterResponseSuccess(t *testing.T) { tr := &testRoundTripper{ Transport: server.Client().Transport, } - client := &httpClient{ - httpClient: &http.Client{Transport: tr}, - } + client := NewClient(&http.Client{Transport: tr}) + client.SetBaseURL(server.URL) mutator := func(resp *http.Response) error { resp.Header.Set("X-Modified-Header", "ModifiedValue") return nil } - client.httpOnAfterResponse(mutator) + client.OnAfterResponse(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) if err != nil { t.Fatalf("expected no error, got: %v", err) } - // Check if the header was successfully added to the response if respHeader := tr.Response.Header.Get("X-Modified-Header"); respHeader != "ModifiedValue" { t.Fatalf("expected X-Modified-Header to be set to ModifiedValue, got: %v", respHeader) } @@ -456,17 +436,16 @@ func TestDoRequest_AfterResponseError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) mutator := func(resp *http.Response) error { return errors.New("mutator error") } - client.httpOnAfterResponse(mutator) + client.OnAfterResponse(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) expectedErr := "failed to mutate after response" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) @@ -478,11 +457,9 @@ func TestDoRequestLogging_Success(t *testing.T) { logger := createLogger() logger.l.SetOutput(&logBuffer) // Redirect log output to buffer - client := &httpClient{ - httpClient: http.DefaultClient, - debug: true, - logger: logger, - } + client := NewClient(nil) + client.SetDebug(true) + client.SetLogger(logger) handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -490,13 +467,14 @@ func TestDoRequestLogging_Success(t *testing.T) { _, _ = w.Write([]byte(`{"message":"success"}`)) } server := httptest.NewServer(http.HandlerFunc(handler)) + client.SetBaseURL(server.URL) defer server.Close() - params := RequestParams{ + params := requestParams{ Response: &map[string]string{}, } - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params) + err := client.doRequest(context.Background(), http.MethodGet, server.URL, params, nil) if err != nil { t.Fatal(cmp.Diff(nil, err)) } @@ -505,8 +483,8 @@ func TestDoRequestLogging_Success(t *testing.T) { logInfoWithoutTimestamps := removeTimestamps(logInfo) // Expected logs with templates filled in - expectedRequestLog := "DEBUG RESTY Sending request:\nMethod: GET\nURL: " + server.URL + "\nHeaders: map[Accept:[application/json] Content-Type:[application/json]]\nBody: " - expectedResponseLog := "DEBUG RESTY Received response:\nStatus: 200 OK\nHeaders: map[Content-Length:[21] Content-Type:[text/plain; charset=utf-8]]\nBody: {\"message\":\"success\"}" + expectedRequestLog := "DEBUG Sending request:\nMethod: GET\nURL: " + server.URL + "\nHeaders: map[Accept:[application/json] Authorization:[Bearer *******************************] Content-Type:[application/json] User-Agent:[linodego/dev https://github.com/linode/linodego]]\nBody: " + expectedResponseLog := "DEBUG Received response:\nStatus: 200 OK\nHeaders: map[Content-Length:[21] Content-Type:[text/plain; charset=utf-8]]\nBody: {\"message\":\"success\"}" if !strings.Contains(logInfo, expectedRequestLog) { t.Fatalf("expected log %q not found in logs", expectedRequestLog) @@ -521,26 +499,19 @@ func TestDoRequestLogging_Error(t *testing.T) { logger := createLogger() logger.l.SetOutput(&logBuffer) // Redirect log output to buffer - client := &httpClient{ - httpClient: http.DefaultClient, - debug: true, - logger: logger, - } - - params := RequestParams{ - Body: map[string]interface{}{ - "invalid": func() {}, - }, - } + client := NewClient(nil) + client.SetDebug(true) + client.SetLogger(logger) - err := client.doRequest(context.Background(), http.MethodPost, "http://example.com", params) - expectedErr := "failed to encode body" + // Create a request with an invalid method to simulate a request creation failure + err := client.doRequest(context.Background(), "bad method", "/foo/bar", requestParams{}, nil) + expectedErr := "failed to create request" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) } logInfo := logBuffer.String() - expectedLog := "ERROR RESTY failed to encode body" + expectedLog := "ERROR failed to create request" if !strings.Contains(logInfo, expectedLog) { t.Fatalf("expected log %q not found in logs", expectedLog) @@ -638,9 +609,9 @@ func TestClient_CustomRootCAWithoutCustomRoundTripper(t *testing.T) { } client := NewClient(test.httpClient) - transport, err := client.resty.Transport() - if err != nil { - t.Fatal(err) + transport, ok := client.httpClient.Transport.(*http.Transport) + if !ok { + t.Fatal("expected *http.Transport") } if test.setCA && (transport.TLSClientConfig == nil || transport.TLSClientConfig.RootCAs == nil) { t.Error("expected root CAs to be set") @@ -669,39 +640,39 @@ func TestMonitorClient_SetAPIBasics(t *testing.T) { client := NewMonitorClient(nil) - if client.resty.BaseURL != defaultURL { - t.Fatal(cmp.Diff(client.resty.BaseURL, defaultURL)) + if client.hostURL != defaultURL { + t.Fatal(cmp.Diff(client.hostURL, defaultURL)) } client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Ensure setting twice does not cause conflicts client.SetBaseURL(updatedBaseURL) client.SetAPIVersion(updatedAPIVersion) - if client.resty.BaseURL != updatedExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, updatedExpectedHost)) + if client.hostURL != updatedExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, updatedExpectedHost)) } // Revert client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Custom protocol client.SetBaseURL(protocolBaseURL) client.SetAPIVersion(protocolAPIVersion) - if client.resty.BaseURL != protocolExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != protocolExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } } @@ -787,7 +758,7 @@ func TestEnableLogSanitization(t *testing.T) { "Authorization": []string{"Bearer " + plainTextToken}, })) - _, err := mockClient.resty.R().Get("https://api.linode.com/v4/test") + err := mockClient.doRequest(context.Background(), http.MethodGet, "/test", requestParams{}, nil) require.NoError(t, err) logOutput := logBuf.String() diff --git a/config_test.go b/config_test.go index b4b3db418..628cc1415 100644 --- a/config_test.go +++ b/config_test.go @@ -42,11 +42,11 @@ func TestConfig_LoadWithDefaults(t *testing.T) { expectedURL := "https://api.cool.linode.com/v4beta" - if client.resty.BaseURL != expectedURL { - t.Fatalf("mismatched host url: %s != %s", client.resty.BaseURL, expectedURL) + if client.hostURL != expectedURL { + t.Fatalf("mismatched host url: %s != %s", client.hostURL, expectedURL) } - if client.resty.Header.Get("Authorization") != "Bearer "+p.APIToken { + if client.header.Get("Authorization") != "Bearer "+p.APIToken { t.Fatalf("token not found in auth header: %s", p.APIToken) } } @@ -88,11 +88,11 @@ func TestConfig_OverrideDefaults(t *testing.T) { expectedURL := "https://api.cool.linode.com/v4" - if client.resty.BaseURL != expectedURL { - t.Fatalf("mismatched host url: %s != %s", client.resty.BaseURL, expectedURL) + if client.hostURL != expectedURL { + t.Fatalf("mismatched host url: %s != %s", client.hostURL, expectedURL) } - if client.resty.Header.Get("Authorization") != "Bearer "+p.APIToken { + if client.header.Get("Authorization") != "Bearer "+p.APIToken { t.Fatalf("token not found in auth header: %s", p.APIToken) } } @@ -124,7 +124,7 @@ func TestConfig_NoDefaults(t *testing.T) { t.Fatalf("mismatched api token: %s != %s", p.APIToken, "mytoken") } - if client.resty.Header.Get("Authorization") != "Bearer "+p.APIToken { + if client.header.Get("Authorization") != "Bearer "+p.APIToken { t.Fatalf("token not found in auth header: %s", p.APIToken) } } diff --git a/errors.go b/errors.go index 873b40450..62586a477 100644 --- a/errors.go +++ b/errors.go @@ -1,7 +1,7 @@ package linodego import ( - "encoding/json" + "bytes" "errors" "fmt" "io" @@ -9,8 +9,6 @@ import ( "reflect" "slices" "strings" - - "github.com/go-resty/resty/v2" ) const ( @@ -49,74 +47,43 @@ type APIError struct { Errors []APIErrorReason `json:"errors"` } -// String returns the error reason in a formatted string -func (r APIErrorReason) String() string { - return fmt.Sprintf("[%s] %s", r.Field, r.Reason) -} - -func coupleAPIErrors(r *resty.Response, err error) (*resty.Response, error) { +//nolint:nestif +func coupleAPIErrors(resp *http.Response, err error) (*http.Response, error) { if err != nil { - // an error was raised in go code, no need to check the resty Response return nil, NewError(err) } - if r.Error() == nil { - // no error in the resty Response - return r, nil - } - - // handle the resty Response errors - - // Check that response is of the correct content-type before unmarshalling - expectedContentType := r.Request.Header.Get("Accept") - responseContentType := r.Header().Get("Content-Type") - - // If the upstream Linode API server being fronted fails to respond to the request, - // the http server will respond with a default "Bad Gateway" page with Content-Type - // "text/html". - if r.StatusCode() == http.StatusBadGateway && responseContentType == "text/html" { //nolint:goconst - return nil, Error{Code: http.StatusBadGateway, Message: http.StatusText(http.StatusBadGateway)} - } - - if responseContentType != expectedContentType { - msg := fmt.Sprintf( - "Unexpected Content-Type: Expected: %v, Received: %v\nResponse body: %s", - expectedContentType, - responseContentType, - string(r.Body()), - ) - - return nil, Error{Code: r.StatusCode(), Message: msg} - } - - apiError, ok := r.Error().(*APIError) - if !ok || (ok && len(apiError.Errors) == 0) { - return r, nil + if resp == nil { + return nil, NewError(fmt.Errorf("response is nil")) } - return nil, NewError(r) -} - -//nolint:unused -func coupleAPIErrorsHTTP(resp *http.Response, err error) (*http.Response, error) { - if err != nil { - // an error was raised in go code, no need to check the http.Response - return nil, NewError(err) - } - - if resp == nil || resp.StatusCode < 200 || resp.StatusCode >= 300 { + if resp.StatusCode < 200 || resp.StatusCode >= 300 { // Check that response is of the correct content-type before unmarshalling - expectedContentType := resp.Request.Header.Get("Accept") + expectedContentType := "" + if resp.Request != nil && resp.Request.Header != nil { + expectedContentType = resp.Request.Header.Get("Accept") + } + responseContentType := resp.Header.Get("Content-Type") // If the upstream server fails to respond to the request, - // the http server will respond with a default error page with Content-Type "text/html". - if resp.StatusCode == http.StatusBadGateway && responseContentType == "text/html" { //nolint:goconst - return nil, Error{Code: http.StatusBadGateway, Message: http.StatusText(http.StatusBadGateway)} + // the HTTP server will respond with a default error page with Content-Type "text/html". + if resp.StatusCode == http.StatusBadGateway && responseContentType == "text/html" { + return nil, &Error{Code: http.StatusBadGateway, Message: http.StatusText(http.StatusBadGateway), Response: resp} } if responseContentType != expectedContentType { - bodyBytes, _ := io.ReadAll(resp.Body) + if resp.Body == nil { + return nil, NewError(fmt.Errorf("response body is nil")) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewError(fmt.Errorf("failed to read response body: %w", err)) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + msg := fmt.Sprintf( "Unexpected Content-Type: Expected: %v, Received: %v\nResponse body: %s", expectedContentType, @@ -124,11 +91,12 @@ func coupleAPIErrorsHTTP(resp *http.Response, err error) (*http.Response, error) string(bodyBytes), ) - return nil, Error{Code: resp.StatusCode, Message: msg} + return nil, &Error{Code: resp.StatusCode, Message: msg} } - var apiError APIError - if err := json.NewDecoder(resp.Body).Decode(&apiError); err != nil { + // Must check if there is no list of reasons in the error before making a call to NewError + apiError, ok := getAPIError(resp) + if !ok { return nil, NewError(fmt.Errorf("failed to decode response body: %w", err)) } @@ -136,10 +104,9 @@ func coupleAPIErrorsHTTP(resp *http.Response, err error) (*http.Response, error) return resp, nil } - return nil, Error{Code: resp.StatusCode, Message: apiError.Errors[0].String()} + return nil, NewError(resp) } - // no error in the http.Response return resp, nil } @@ -156,7 +123,7 @@ func (e APIError) Error() string { // - ErrorFromString (1) from a string // - ErrorFromError (2) for an error // - ErrorFromStringer (3) for a Stringer -// - HTTP Status Codes (100-600) for a resty.Response object +// - HTTP Status Codes (100-600) for a http.Response object func NewError(err any) *Error { if err == nil { return nil @@ -165,17 +132,17 @@ func NewError(err any) *Error { switch e := err.(type) { case *Error: return e - case *resty.Response: - apiError, ok := e.Error().(*APIError) + case *http.Response: + apiError, ok := getAPIError(e) if !ok { - return &Error{Code: ErrorUnsupported, Message: "Unexpected Resty Error Response, no error"} + return &Error{Code: ErrorUnsupported, Message: "Unexpected HTTP Error Response, no error"} } return &Error{ - Code: e.RawResponse.StatusCode, + Code: e.StatusCode, Message: apiError.Error(), - Response: e.RawResponse, + Response: e, } case error: return &Error{Code: ErrorFromError, Message: e.Error()} diff --git a/errors_test.go b/errors_test.go index 68428fcd6..45ed16b7e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -7,11 +7,13 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net/http" "net/http/httptest" + "strconv" + "strings" "testing" - "github.com/go-resty/resty/v2" "github.com/google/go-cmp/cmp" ) @@ -27,10 +29,10 @@ func (e testError) Error() string { return string(e) } -func restyError(reason, field string) *resty.Response { +func httpError(reason, field string) *http.Response { var reasons []APIErrorReason - // allow for an empty reasons + // Allow for an empty reasons if reason != "" && field != "" { reasons = append(reasons, APIErrorReason{ Reason: reason, @@ -38,15 +40,18 @@ func restyError(reason, field string) *resty.Response { }) } - return &resty.Response{ - RawResponse: &http.Response{ - StatusCode: 500, - }, - Request: &resty.Request{ - Error: &APIError{ - Errors: reasons, - }, - }, + apiError := &APIError{ + Errors: reasons, + } + + body, err := json.Marshal(apiError) + if err != nil { + panic("Failed to marshal APIError") + } + + return &http.Response{ + StatusCode: 500, + Body: ioutil.NopCloser(bytes.NewReader(body)), } } @@ -73,12 +78,12 @@ func TestNewError(t *testing.T) { t.Error("Error should be itself") } - if err := NewError(&resty.Response{Request: &resty.Request{}}); err.Message != "Unexpected Resty Error Response, no error" { - t.Error("Unexpected Resty Error Response, no error") + if err := NewError(&http.Response{Request: &http.Request{}}); err.Message != "Unexpected HTTP Error Response, no error" { + t.Error("Unexpected HTTP Error Response, no error") } - if err := NewError(restyError("testreason", "testfield")); err.Message != "[testfield] testreason" { - t.Error("rest response error should should be set") + if err := NewError(httpError("testreason", "testfield")); err.Message != "[testfield] testreason" { + t.Error("http response error should should be set") } if err := NewError("stringerror"); err.Message != "stringerror" || err.Code != ErrorFromString { @@ -119,98 +124,6 @@ func TestCoupleAPIErrors(t *testing.T) { } }) - t.Run("resty 500 response error with reasons", func(t *testing.T) { - if _, err := coupleAPIErrors(restyError("testreason", "testfield"), nil); err.Error() != "[500] [testfield] testreason" { - t.Error("resty error should return with proper format [code] [field] reason") - } - }) - - t.Run("resty 500 response error without reasons", func(t *testing.T) { - if _, err := coupleAPIErrors(restyError("", ""), nil); err != nil { - t.Error("resty error with no reasons should return no error") - } - }) - - t.Run("resty response with nil error", func(t *testing.T) { - emptyErr := &resty.Response{ - RawResponse: &http.Response{ - StatusCode: 500, - }, - Request: &resty.Request{ - Error: nil, - }, - } - if _, err := coupleAPIErrors(emptyErr, nil); err != nil { - t.Error("resty error with no reasons should return no error") - } - }) - - t.Run("generic html error", func(t *testing.T) { - rawResponse := ` -500 Internal Server Error - -

500 Internal Server Error

-
nginx
- -` - route := "/v4/linode/instances/123" - ts, client := createTestServer(http.MethodGet, route, "text/html", rawResponse, http.StatusInternalServerError) - // client.SetDebug(true) - defer ts.Close() - - expectedError := Error{ - Code: http.StatusInternalServerError, - Message: "Unexpected Content-Type: Expected: application/json, Received: text/html\nResponse body: " + rawResponse, - } - - _, err := coupleAPIErrors(client.R(context.Background()).SetResult(&Instance{}).Get(ts.URL + route)) - if diff := cmp.Diff(expectedError, err); diff != "" { - t.Errorf("expected error to match but got diff:\n%s", diff) - } - }) - - t.Run("bad gateway error", func(t *testing.T) { - rawResponse := []byte(` -502 Bad Gateway - -

502 Bad Gateway

-
nginx
- -`) - buf := io.NopCloser(bytes.NewBuffer(rawResponse)) - - resp := &resty.Response{ - Request: &resty.Request{ - Error: errors.New("Bad Gateway"), - }, - RawResponse: &http.Response{ - Header: http.Header{ - "Content-Type": []string{"text/html"}, - }, - StatusCode: http.StatusBadGateway, - Body: buf, - }, - } - - expectedError := Error{ - Code: http.StatusBadGateway, - Message: http.StatusText(http.StatusBadGateway), - } - - if _, err := coupleAPIErrors(resp, nil); !cmp.Equal(err, expectedError) { - t.Errorf("expected error %#v to match error %#v", err, expectedError) - } - }) -} - -func TestCoupleAPIErrorsHTTP(t *testing.T) { - t.Run("not nil error generates error", func(t *testing.T) { - err := errors.New("test") - if _, err := coupleAPIErrorsHTTP(nil, err); !cmp.Equal(err, NewError(err)) { - t.Errorf("expect a not nil error to be returned as an Error") - } - }) - t.Run("http 500 response error with reasons", func(t *testing.T) { // Create the simulated HTTP response with a 500 status and a JSON body containing the error details apiError := APIError{ @@ -228,7 +141,7 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { Request: &http.Request{Header: http.Header{"Accept": []string{"application/json"}}}, } - _, err := coupleAPIErrorsHTTP(resp, nil) + _, err := coupleAPIErrors(resp, nil) expectedMessage := "[500] [testfield] testreason" if err == nil || err.Error() != expectedMessage { t.Errorf("expected error message %q, got: %v", expectedMessage, err) @@ -250,7 +163,7 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { Request: &http.Request{Header: http.Header{"Accept": []string{"application/json"}}}, } - _, err := coupleAPIErrorsHTTP(resp, nil) + _, err := coupleAPIErrors(resp, nil) if err != nil { t.Error("http error with no reasons should return no error") } @@ -265,7 +178,7 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { Request: &http.Request{Header: http.Header{"Accept": []string{"application/json"}}}, } - _, err := coupleAPIErrorsHTTP(resp, nil) + _, err := coupleAPIErrors(resp, nil) if err != nil { t.Error("http error with no reasons should return no error") } @@ -288,15 +201,10 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { })) defer ts.Close() - client := &httpClient{ + client := &Client{ httpClient: ts.Client(), } - expectedError := Error{ - Code: http.StatusInternalServerError, - Message: "Unexpected Content-Type: Expected: application/json, Received: text/html\nResponse body: " + rawResponse, - } - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ts.URL+route, nil) if err != nil { t.Fatalf("failed to create request: %v", err) @@ -308,11 +216,23 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { if err != nil { t.Fatalf("failed to send request: %v", err) } + + expectedError := &Error{ + Response: resp, + Code: http.StatusInternalServerError, + Message: "Unexpected Content-Type: Expected: application/json, Received: text/html\nResponse body: " + rawResponse, + } + defer resp.Body.Close() - _, err = coupleAPIErrorsHTTP(resp, nil) - if diff := cmp.Diff(expectedError, err); diff != "" { - t.Errorf("expected error to match but got diff:\n%s", diff) + _, err = coupleAPIErrors(resp, nil) + + if !strings.Contains(err.Error(), strconv.Itoa(expectedError.Code)) { + t.Errorf("expected error code %d, got %d", expectedError.Code, resp.StatusCode) + } + + if !strings.Contains(err.Error(), expectedError.Message) { + t.Errorf("expected error message %s, got %s", expectedError.Message, err.Error()) } }) @@ -337,14 +257,19 @@ func TestCoupleAPIErrorsHTTP(t *testing.T) { }, } - expectedError := Error{ - Code: http.StatusBadGateway, - Message: http.StatusText(http.StatusBadGateway), + expectedError := &Error{ + Response: resp, + Code: http.StatusBadGateway, + Message: http.StatusText(http.StatusBadGateway), + } + + _, err := coupleAPIErrors(resp, nil) + if !strings.Contains(err.Error(), strconv.Itoa(expectedError.Code)) { + t.Errorf("expected error code %d, got %d", expectedError.Code, resp.StatusCode) } - _, err := coupleAPIErrorsHTTP(resp, nil) - if !cmp.Equal(err, expectedError) { - t.Errorf("expected error %#v to match error %#v", err, expectedError) + if !strings.Contains(err.Error(), expectedError.Message) { + t.Errorf("expected error message %s, got %s", expectedError.Message, err.Error()) } }) } @@ -382,26 +307,26 @@ func TestErrorIs(t *testing.T) { expectedResult: true, }, { - testName: "default and Error from empty resty error", - err1: NewError(restyError("", "")), + testName: "default and Error from empty http error", + err1: NewError(httpError("", "")), err2: defaultError, expectedResult: true, }, { - testName: "default and Error from resty error with field", - err1: NewError(restyError("", "test field")), + testName: "default and Error from http error with field", + err1: NewError(httpError("", "test field")), err2: defaultError, expectedResult: true, }, { - testName: "default and Error from resty error with field and reason", - err1: NewError(restyError("test reason", "test field")), + testName: "default and Error from http error with field and reason", + err1: NewError(httpError("test reason", "test field")), err2: defaultError, expectedResult: true, }, { - testName: "default and Error from resty error with reason", - err1: NewError(restyError("test reason", "")), + testName: "default and Error from http error with reason", + err1: NewError(httpError("test reason", "")), err2: defaultError, expectedResult: true, }, diff --git a/go.mod b/go.mod index b00d38c4c..6cc016482 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,6 @@ module github.com/linode/linodego require ( - github.com/go-resty/resty/v2 v2.17.2 github.com/google/go-cmp v0.7.0 github.com/google/go-querystring v1.2.0 github.com/jarcoal/httpmock v1.4.1 diff --git a/go.sum b/go.sum index fcb368761..5b27625ab 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk= -github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= @@ -29,8 +27,6 @@ golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.67.1 h1:tVBILHy0R6e4wkYOn3XmiITt/hEVH4TFMYvAX2Ytz6k= diff --git a/images.go b/images.go index 5fdf644fb..88dfad5cb 100644 --- a/images.go +++ b/images.go @@ -4,9 +4,9 @@ import ( "context" "encoding/json" "io" + "net/http" "time" - "github.com/go-resty/resty/v2" "github.com/linode/linodego/internal/parseabletime" ) @@ -297,17 +297,45 @@ func (c *Client) CreateImageUpload(ctx context.Context, opts ImageCreateUploadOp // UploadImageToURL uploads the given image to the given upload URL. func (c *Client) UploadImageToURL(ctx context.Context, uploadURL string, image io.Reader) error { - // Linode-specific headers do not need to be sent to this endpoint - req := resty.New().SetDebug(c.resty.Debug).R(). - SetContext(ctx). - SetContentLength(true). - SetHeader("Content-Type", "application/octet-stream"). - SetBody(image) + clonedClient := *c.httpClient + clonedClient.Transport = http.DefaultTransport - _, err := coupleAPIErrors(req. - Put(uploadURL)) + var contentLength int64 = -1 - return err + if seeker, ok := image.(io.Seeker); ok { + size, err := seeker.Seek(0, io.SeekEnd) + if err != nil { + return err + } + + _, err = seeker.Seek(0, io.SeekStart) + if err != nil { + return err + } + + contentLength = size + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPut, uploadURL, image) + if err != nil { + return err + } + + if contentLength >= 0 { + req.ContentLength = contentLength + } + + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("User-Agent", c.userAgent) + + resp, err := clonedClient.Do(req) + + _, err = coupleAPIErrors(resp, err) + if err != nil { + return err + } + + return nil } // UploadImage creates and uploads an image. diff --git a/internal/testutil/mock.go b/internal/testutil/mock.go index c8235ce65..c32bcae42 100644 --- a/internal/testutil/mock.go +++ b/internal/testutil/mock.go @@ -115,16 +115,16 @@ type TestLogger struct { L *log.Logger } -func (l *TestLogger) Errorf(format string, v ...any) { - l.outputf("ERROR RESTY "+format, v...) +func (l *TestLogger) Errorf(format string, v ...interface{}) { + l.outputf("ERROR "+format, v...) } -func (l *TestLogger) Warnf(format string, v ...any) { - l.outputf("WARN RESTY "+format, v...) +func (l *TestLogger) Warnf(format string, v ...interface{}) { + l.outputf("WARN "+format, v...) } -func (l *TestLogger) Debugf(format string, v ...any) { - l.outputf("DEBUG RESTY "+format, v...) +func (l *TestLogger) Debugf(format string, v ...interface{}) { + l.outputf("DEBUG "+format, v...) } func (l *TestLogger) outputf(format string, v ...any) { diff --git a/k8s/go.mod b/k8s/go.mod index 4e148b8c5..fe876ba3c 100644 --- a/k8s/go.mod +++ b/k8s/go.mod @@ -14,7 +14,6 @@ require ( github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.22.3 // indirect - github.com/go-resty/resty/v2 v2.17.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect diff --git a/k8s/go.sum b/k8s/go.sum index 4e571690a..f93e11a91 100644 --- a/k8s/go.sum +++ b/k8s/go.sum @@ -12,8 +12,6 @@ github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2Kv github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= -github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk= -github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= diff --git a/logger.go b/logger.go index 890327e2b..07e2f99f7 100644 --- a/logger.go +++ b/logger.go @@ -5,44 +5,36 @@ import ( "os" ) -//nolint:unused -type httpLogger interface { - Errorf(format string, v ...any) - Warnf(format string, v ...any) - Debugf(format string, v ...any) +type Logger interface { + Errorf(format string, v ...interface{}) + Warnf(format string, v ...interface{}) + Debugf(format string, v ...interface{}) } -//nolint:unused type logger struct { l *log.Logger } -//nolint:unused func createLogger() *logger { l := &logger{l: log.New(os.Stderr, "", log.Ldate|log.Lmicroseconds)} return l } -//nolint:unused -var _ httpLogger = (*logger)(nil) +var _ Logger = (*logger)(nil) -//nolint:unused -func (l *logger) Errorf(format string, v ...any) { - l.output("ERROR RESTY "+format, v...) +func (l *logger) Errorf(format string, v ...interface{}) { + l.output("ERROR "+format, v...) } -//nolint:unused -func (l *logger) Warnf(format string, v ...any) { - l.output("WARN RESTY "+format, v...) +func (l *logger) Warnf(format string, v ...interface{}) { + l.output("WARN "+format, v...) } -//nolint:unused -func (l *logger) Debugf(format string, v ...any) { - l.output("DEBUG RESTY "+format, v...) +func (l *logger) Debugf(format string, v ...interface{}) { + l.output("DEBUG "+format, v...) } -//nolint:unused -func (l *logger) output(format string, v ...any) { //nolint:goprintffuncname +func (l *logger) output(format string, v ...interface{}) { //nolint:goprintffuncname if len(v) == 0 { l.l.Print(format) return diff --git a/monitor_alert_definitions.go b/monitor_alert_definitions.go index 730665655..3bb09ec9d 100644 --- a/monitor_alert_definitions.go +++ b/monitor_alert_definitions.go @@ -1,8 +1,10 @@ package linodego import ( + "bytes" "context" "encoding/json" + "net/http" "time" "github.com/linode/linodego/internal/parseabletime" @@ -263,25 +265,27 @@ func (c *Client) CreateMonitorAlertDefinitionWithIdempotency( var result AlertDefinition - req := c.R(ctx).SetResult(&result) - - if idempotencyKey != "" { - req.SetHeader("Idempotency-Key", idempotencyKey) - } - body, err := json.Marshal(opts) if err != nil { return nil, err } - req.SetBody(string(body)) + params := requestParams{ + Response: &result, + Body: bytes.NewReader(body), + } + + if idempotencyKey != "" { + c.SetHeader("Idempotency-Key", idempotencyKey) + defer c.header.Del("Idempotency-Key") + } - r, err := coupleAPIErrors(req.Post(e)) + err = c.doRequest(ctx, http.MethodPost, e, params, nil) if err != nil { return nil, err } - return r.Result().(*AlertDefinition), nil + return &result, nil } // UpdateMonitorAlertDefinition updates an ACLP Monitor Alert Definition. diff --git a/monitor_api_services.go b/monitor_api_services.go index 9d379353e..a7ee4ca6d 100644 --- a/monitor_api_services.go +++ b/monitor_api_services.go @@ -4,8 +4,10 @@ package linodego import ( + "bytes" "context" "encoding/json" + "net/http" "time" ) @@ -101,7 +103,11 @@ type MetricAbsoluteTimeDuration struct { func (mc *MonitorClient) FetchEntityMetrics(ctx context.Context, serviceType string, opts *EntityMetricsFetchOptions) (*EntityMetrics, error) { endpoint := formatAPIPath("monitor/services/%s/metrics", serviceType) - req := mc.R(ctx).SetResult(&EntityMetrics{}) + var result EntityMetrics + + params := requestParams{ + Response: &result, + } if opts != nil { body, err := json.Marshal(opts) @@ -109,13 +115,13 @@ func (mc *MonitorClient) FetchEntityMetrics(ctx context.Context, serviceType str return nil, err } - req.SetBody(string(body)) + params.Body = bytes.NewReader(body) } - r, err := coupleAPIErrors(req.Post(endpoint)) + err := mc.doRequest(ctx, http.MethodPost, endpoint, params) if err != nil { return nil, err } - return r.Result().(*EntityMetrics), nil + return &result, nil } diff --git a/pagination.go b/pagination.go index 074720e82..2fdcb86fc 100644 --- a/pagination.go +++ b/pagination.go @@ -9,10 +9,9 @@ import ( "encoding/hex" "encoding/json" "fmt" + "net/http" "reflect" "strconv" - - "github.com/go-resty/resty/v2" ) // PageOptions are the pagination parameters for List endpoints @@ -56,38 +55,48 @@ func (l ListOptions) Hash() (string, error) { return hex.EncodeToString(h.Sum(nil)), nil } -func applyListOptionsToRequest(opts *ListOptions, req *resty.Request) error { +func createListOptionsToRequestMutator(opts *ListOptions) func(*http.Request) error { if opts == nil { return nil } - if opts.QueryParams != nil { - params, err := flattenQueryStruct(opts.QueryParams) - if err != nil { - return fmt.Errorf("failed to apply list options: %w", err) + // Return a mutator to apply query parameters and headers + return func(req *http.Request) error { + query := req.URL.Query() + + // Apply QueryParams from ListOptions if present + if opts.QueryParams != nil { + params, err := flattenQueryStruct(opts.QueryParams) + if err != nil { + return fmt.Errorf("failed to apply list options: %w", err) + } + for key, value := range params { + query.Set(key, value) + } } - req.SetQueryParams(params) - } - - if opts.PageOptions != nil && opts.Page > 0 { - req.SetQueryParam("page", strconv.Itoa(opts.Page)) - } + // Apply pagination options + if opts.PageOptions != nil && opts.Page > 0 { + query.Set("page", strconv.Itoa(opts.Page)) + } + if opts.PageSize > 0 { + query.Set("page_size", strconv.Itoa(opts.PageSize)) + } - if opts.PageSize > 0 { - req.SetQueryParam("page_size", strconv.Itoa(opts.PageSize)) - } + // Apply filters as headers + if len(opts.Filter) > 0 { + req.Header.Set("X-Filter", opts.Filter) + } - if len(opts.Filter) > 0 { - req.SetHeader("X-Filter", opts.Filter) + // Assign the updated query back to the request URL + req.URL.RawQuery = query.Encode() + return nil } - - return nil } type PagedResponse interface { endpoint(...any) string - castResult(*resty.Request, string) (int, int, error) + castResult(*http.Request, string) (int, int, error) } // flattenQueryStruct flattens a structure into a Resty-compatible query param map. diff --git a/request_helpers.go b/request_helpers.go index d70e515ea..d172521e3 100644 --- a/request_helpers.go +++ b/request_helpers.go @@ -1,9 +1,11 @@ package linodego import ( + "bytes" "context" "encoding/json" "fmt" + "net/http" "net/url" "reflect" ) @@ -60,56 +62,33 @@ func handlePaginatedResults[T any, O any]( handlePage := func(page int) error { var resultType paginatedResponse[T] - // Override the page to be applied in applyListOptionsToRequest(...) + // Override the page to be applied in createListOptionsToRequestMutator(...) opts.Page = page - // This request object cannot be reused for each page request - // because it can lead to possible data corruption - req := client.R(ctx).SetResult(&resultType) - - // Apply all user-provided list options to the request - if err := applyListOptionsToRequest(opts, req); err != nil { - return err + params := requestParams{ + Response: &resultType, } - // Set request body if provided if reqBody != "" { - req.SetBody(reqBody) + params.Body = bytes.NewReader([]byte(reqBody)) } - var response *paginatedResponse[T] - // Execute the appropriate HTTP method - switch method { - case "GET": - res, err := coupleAPIErrors(req.Get(endpoint)) - if err != nil { - return err - } - - response = res.Result().(*paginatedResponse[T]) - case "PUT": - res, err := coupleAPIErrors(req.Put(endpoint)) - if err != nil { - return err - } - - response = res.Result().(*paginatedResponse[T]) - case "POST": - res, err := coupleAPIErrors(req.Post(endpoint)) - if err != nil { - return err - } - - response = res.Result().(*paginatedResponse[T]) - default: - return fmt.Errorf("unsupported HTTP method: %s", method) + // Create a mutator to apply all user-provided list options to the request + mutator := createListOptionsToRequestMutator(opts) + + // Make the request using doRequest + err := client.doRequest(ctx, method, endpoint, params, &mutator) + if err != nil { + return err } // Update pagination metadata opts.Page = page - opts.Pages = response.Pages - opts.Results = response.Results - result = append(result, response.Data...) + opts.Pages = resultType.Pages + opts.Results = resultType.Results + + // Append the data to the result slice + result = append(result, resultType.Data...) return nil } @@ -127,13 +106,12 @@ func handlePaginatedResults[T any, O any]( return nil, err } - // If the user has explicitly specified a page, we don't - // need to get any other pages. + // If a specific page is defined, return the result if pageDefined { return result, nil } - // Get the rest of the pages + // Get the remaining pages for page := 2; page <= opts.Pages; page++ { if err := handlePage(page); err != nil { return nil, err @@ -186,15 +164,16 @@ func doGETRequest[T any]( endpoint string, ) (*T, error) { var resultType T + params := requestParams{ + Response: &resultType, + } - req := client.R(ctx).SetResult(&resultType) - - r, err := coupleAPIErrors(req.Get(endpoint)) + err := client.doRequest(ctx, http.MethodGet, endpoint, params, nil) if err != nil { return nil, err } - return r.Result().(*T), nil + return &resultType, nil } // doPOSTRequest runs a PUT request using the given client, API endpoint, @@ -206,30 +185,27 @@ func doPOSTRequest[T, O any]( options ...O, ) (*T, error) { var resultType T - numOpts := len(options) - if numOpts > 1 { - return nil, fmt.Errorf("invalid number of options: %d", len(options)) + return nil, fmt.Errorf("invalid number of options: %d", numOpts) } - req := client.R(ctx).SetResult(&resultType) - + params := requestParams{ + Response: &resultType, + } if numOpts > 0 && !isNil(options[0]) { body, err := json.Marshal(options[0]) if err != nil { return nil, err } - - req.SetBody(string(body)) + params.Body = bytes.NewReader(body) } - r, err := coupleAPIErrors(req.Post(endpoint)) + err := client.doRequest(ctx, http.MethodPost, endpoint, params, nil) if err != nil { return nil, err } - - return r.Result().(*T), nil + return &resultType, nil } // doPOSTRequestNoResponseBody runs a POST request using the given client, API endpoint, @@ -263,30 +239,27 @@ func doPUTRequest[T, O any]( options ...O, ) (*T, error) { var resultType T - numOpts := len(options) - if numOpts > 1 { - return nil, fmt.Errorf("invalid number of options: %d", len(options)) + return nil, fmt.Errorf("invalid number of options: %d", numOpts) } - req := client.R(ctx).SetResult(&resultType) - + params := requestParams{ + Response: &resultType, + } if numOpts > 0 && !isNil(options[0]) { body, err := json.Marshal(options[0]) if err != nil { return nil, err } - - req.SetBody(string(body)) + params.Body = bytes.NewReader(body) } - r, err := coupleAPIErrors(req.Put(endpoint)) + err := client.doRequest(ctx, http.MethodPut, endpoint, params, nil) if err != nil { return nil, err } - - return r.Result().(*T), nil + return &resultType, nil } // doDELETERequest runs a DELETE request using the given client @@ -296,9 +269,8 @@ func doDELETERequest( client *Client, endpoint string, ) error { - req := client.R(ctx) - _, err := coupleAPIErrors(req.Delete(endpoint)) - + params := requestParams{} + err := client.doRequest(ctx, http.MethodDelete, endpoint, params, nil) return err } diff --git a/request_helpers_test.go b/request_helpers_test.go index ce4bf80b5..033c86628 100644 --- a/request_helpers_test.go +++ b/request_helpers_test.go @@ -80,11 +80,8 @@ func TestRequestHelpers_post(t *testing.T) { func TestRequestHelpers_postNoOptions(t *testing.T) { client := testutil.CreateMockClient(t, NewClient) - httpmock.RegisterRegexpResponder( - "POST", - testutil.MockRequestURL("/foo/bar"), - testutil.MockRequestBodyValidateNoBody(t, testResponse), - ) + httpmock.RegisterRegexpResponder("POST", testutil.MockRequestURL("/foo/bar"), + testutil.MockRequestBodyValidateNoBody(t, testResponse)) result, err := doPOSTRequest[testResultType, any]( context.Background(), @@ -124,11 +121,8 @@ func TestRequestHelpers_put(t *testing.T) { func TestRequestHelpers_putNoOptions(t *testing.T) { client := testutil.CreateMockClient(t, NewClient) - httpmock.RegisterRegexpResponder( - "PUT", - testutil.MockRequestURL("/foo/bar"), - testutil.MockRequestBodyValidateNoBody(t, testResponse), - ) + httpmock.RegisterRegexpResponder("PUT", testutil.MockRequestURL("/foo/bar"), + testutil.MockRequestBodyValidateNoBody(t, testResponse)) result, err := doPUTRequest[testResultType, any]( context.Background(), @@ -147,11 +141,8 @@ func TestRequestHelpers_putNoOptions(t *testing.T) { func TestRequestHelpers_delete(t *testing.T) { client := testutil.CreateMockClient(t, NewClient) - httpmock.RegisterRegexpResponder( - "DELETE", - testutil.MockRequestURL("/foo/bar/foo%20bar"), - httpmock.NewStringResponder(200, "{}"), - ) + httpmock.RegisterRegexpResponder("DELETE", testutil.MockRequestURL("/foo/bar/foo%20bar"), + httpmock.NewStringResponder(200, "{}")) if err := doDELETERequest( context.Background(), @@ -169,14 +160,8 @@ func TestRequestHelpers_paginateAll(t *testing.T) { numRequests := 0 - httpmock.RegisterRegexpResponder( - "GET", - testutil.MockRequestURL("/foo/bar"), - mockPaginatedResponse( - buildPaginatedEntries(totalResults), - &numRequests, - ), - ) + httpmock.RegisterRegexpResponder("GET", testutil.MockRequestURL("/foo/bar"), + mockPaginatedResponse(buildPaginatedEntries(totalResults), &numRequests)) response, err := getPaginatedResults[testResultType]( context.Background(), @@ -205,14 +190,8 @@ func TestRequestHelpers_paginateSingle(t *testing.T) { numRequests := 0 - httpmock.RegisterRegexpResponder( - "GET", - testutil.MockRequestURL("/foo/bar"), - mockPaginatedResponse( - buildPaginatedEntries(12), - &numRequests, - ), - ) + httpmock.RegisterRegexpResponder("GET", testutil.MockRequestURL("/foo/bar"), + mockPaginatedResponse(buildPaginatedEntries(12), &numRequests)) response, err := getPaginatedResults[testResultType]( context.Background(), diff --git a/retries.go b/retries.go index 047fed84e..371ee03d5 100644 --- a/retries.go +++ b/retries.go @@ -1,74 +1,89 @@ package linodego import ( + "bytes" + "encoding/json" "errors" + "io" "log" "net/http" "strconv" "time" - "github.com/go-resty/resty/v2" "golang.org/x/net/http2" ) const ( - retryAfterHeaderName = "Retry-After" - maintenanceModeHeaderName = "X-Maintenance-Mode" - - defaultRetryCount = 1000 + RetryAfterHeaderName = "Retry-After" + MaintenanceModeHeaderName = "X-Maintenance-Mode" + DefaultRetryCount = 1000 ) -// RetryConditional func(r *resty.Response) (shouldRetry bool) -type RetryConditional resty.RetryConditionFunc +// RetryConditional is a type alias for a function that determines if a request should be retried based on the response and error. +type RetryConditional func(*http.Response, error) bool -// RetryAfter func(c *resty.Client, r *resty.Response) (time.Duration, error) -type RetryAfter resty.RetryAfterFunc +// RetryAfter is a type alias for a function that determines the duration to wait before retrying based on the response. +type RetryAfter func(*http.Response) (time.Duration, error) -// Configures resty to -// lock until enough time has passed to retry the request as determined by the Retry-After response header. -// If the Retry-After header is not set, we fall back to value of SetPollDelay. -func configureRetries(c *Client) { - c.resty. - SetRetryCount(defaultRetryCount). - AddRetryCondition(checkRetryConditionals(c)). - SetRetryAfter(respectRetryAfter) +// Configures http.Client to lock until enough time has passed to retry the request as determined by the Retry-After response header. +// If the Retry-After header is not set, we fall back to the value of SetPollDelay. +func ConfigureRetries(c *Client) { + c.SetRetryAfter(RespectRetryAfter) + c.SetRetryCount(DefaultRetryCount) } -func checkRetryConditionals(c *Client) func(*resty.Response, error) bool { - return func(r *resty.Response, err error) bool { - for _, retryConditional := range c.retryConditionals { - retry := retryConditional(r, err) - if retry { - log.Printf("[INFO] Received error %s - Retrying", r.Error()) - return true - } - } +func RespectRetryAfter(resp *http.Response) (time.Duration, error) { + if resp == nil { + return 0, nil + } - return false + retryAfterStr := resp.Header.Get(RetryAfterHeaderName) + if retryAfterStr == "" { + return 0, nil + } + + retryAfter, err := strconv.Atoi(retryAfterStr) + if err != nil { + return 0, err } + + duration := time.Duration(retryAfter) * time.Second + log.Printf("[INFO] Respecting Retry-After Header of %d (%s)", retryAfter, duration) + return duration, nil } -// SetLinodeBusyRetry configures resty to retry specifically on "Linode busy." errors -// The retry wait time is configured in SetPollDelay -func linodeBusyRetryCondition(r *resty.Response, _ error) bool { - apiError, ok := r.Error().(*APIError) - linodeBusy := ok && apiError.Error() == "Linode busy." - retry := r.StatusCode() == http.StatusBadRequest && linodeBusy +// Retry conditions +func LinodeBusyRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false + } + + apiError, ok := getAPIError(resp) + linodeBusy := ok && apiError.Error() == "Linode busy." + retry := resp.StatusCode == http.StatusBadRequest && linodeBusy return retry } -func tooManyRequestsRetryCondition(r *resty.Response, _ error) bool { - return r.StatusCode() == http.StatusTooManyRequests +func TooManyRequestsRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false + } + + return resp.StatusCode == http.StatusTooManyRequests } -func serviceUnavailableRetryCondition(r *resty.Response, _ error) bool { - serviceUnavailable := r.StatusCode() == http.StatusServiceUnavailable +func ServiceUnavailableRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false + } + + serviceUnavailable := resp.StatusCode == http.StatusServiceUnavailable // During maintenance events, the API will return a 503 and add // an `X-MAINTENANCE-MODE` header. Don't retry during maintenance // events, only for legitimate 503s. - if serviceUnavailable && r.Header().Get(maintenanceModeHeaderName) != "" { + if serviceUnavailable && resp.Header.Get(MaintenanceModeHeaderName) != "" { log.Printf("[INFO] Linode API is under maintenance, request will not be retried - please see status.linode.com for more information") return false } @@ -76,33 +91,46 @@ func serviceUnavailableRetryCondition(r *resty.Response, _ error) bool { return serviceUnavailable } -func requestTimeoutRetryCondition(r *resty.Response, _ error) bool { - return r.StatusCode() == http.StatusRequestTimeout +func RequestTimeoutRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false + } + + return resp.StatusCode == http.StatusRequestTimeout } -func requestGOAWAYRetryCondition(_ *resty.Response, e error) bool { - return errors.As(e, &http2.GoAwayError{}) +func RequestGOAWAYRetryCondition(_ *http.Response, err error) bool { + return errors.As(err, &http2.GoAwayError{}) } -func requestNGINXRetryCondition(r *resty.Response, _ error) bool { - return r.StatusCode() == http.StatusBadRequest && - r.Header().Get("Server") == "nginx" && - r.Header().Get("Content-Type") == "text/html" +func RequestNGINXRetryCondition(resp *http.Response, _ error) bool { + if resp == nil { + return false + } + + return resp.StatusCode == http.StatusBadRequest && + resp.Header.Get("Server") == "nginx" && + resp.Header.Get("Content-Type") == "text/html" } -func respectRetryAfter(client *resty.Client, resp *resty.Response) (time.Duration, error) { - retryAfterStr := resp.Header().Get(retryAfterHeaderName) - if retryAfterStr == "" { - return 0, nil +// Helper function to extract APIError from response +func getAPIError(resp *http.Response) (*APIError, bool) { + if resp.Body == nil { + return nil, false } - retryAfter, err := strconv.Atoi(retryAfterStr) + body, err := io.ReadAll(resp.Body) if err != nil { - return 0, err + return nil, false } - duration := time.Duration(retryAfter) * time.Second - log.Printf("[INFO] Respecting Retry-After Header of %d (%s) (max %s)", retryAfter, duration, client.RetryMaxWaitTime) + resp.Body = io.NopCloser(bytes.NewReader(body)) - return duration, nil + var apiError APIError + err = json.Unmarshal(body, &apiError) + if err != nil { + return nil, false + } + + return &apiError, true } diff --git a/retries_http.go b/retries_http.go deleted file mode 100644 index 46b986e8d..000000000 --- a/retries_http.go +++ /dev/null @@ -1,132 +0,0 @@ -package linodego - -import ( - "encoding/json" - "errors" - "log" - "net/http" - "strconv" - "time" - - "golang.org/x/net/http2" -) - -const ( - // nolint:unused - httpRetryAfterHeaderName = "Retry-After" - // nolint:unused - httpMaintenanceModeHeaderName = "X-Maintenance-Mode" - - // nolint:unused - httpDefaultRetryCount = 1000 -) - -// RetryConditional is a type alias for a function that determines if a request should be retried based on the response and error. -// nolint:unused -type httpRetryConditional func(*http.Response, error) bool - -// RetryAfter is a type alias for a function that determines the duration to wait before retrying based on the response. -// nolint:unused -type httpRetryAfter func(*http.Response) (time.Duration, error) - -// Configures http.Client to lock until enough time has passed to retry the request as determined by the Retry-After response header. -// If the Retry-After header is not set, we fall back to the value of SetPollDelay. -// nolint:unused -func httpConfigureRetries(c *httpClient) { - c.retryConditionals = append(c.retryConditionals, httpcheckRetryConditionals(c)) - c.retryAfter = httpRespectRetryAfter -} - -// nolint:unused -func httpcheckRetryConditionals(c *httpClient) httpRetryConditional { - return func(resp *http.Response, err error) bool { - for _, retryConditional := range c.retryConditionals { - retry := retryConditional(resp, err) - if retry { - log.Printf("[INFO] Received error %v - Retrying", err) - return true - } - } - - return false - } -} - -// nolint:unused -func httpRespectRetryAfter(resp *http.Response) (time.Duration, error) { - retryAfterStr := resp.Header.Get(retryAfterHeaderName) - if retryAfterStr == "" { - return 0, nil - } - - retryAfter, err := strconv.Atoi(retryAfterStr) - if err != nil { - return 0, err - } - - duration := time.Duration(retryAfter) * time.Second - log.Printf("[INFO] Respecting Retry-After Header of %d (%s)", retryAfter, duration) - - return duration, nil -} - -// Retry conditions - -// nolint:unused -func httpLinodeBusyRetryCondition(resp *http.Response, _ error) bool { - apiError, ok := getAPIError(resp) - linodeBusy := ok && apiError.Error() == "Linode busy." - retry := resp.StatusCode == http.StatusBadRequest && linodeBusy - - return retry -} - -// nolint:unused -func httpTooManyRequestsRetryCondition(resp *http.Response, _ error) bool { - return resp.StatusCode == http.StatusTooManyRequests -} - -// nolint:unused -func httpServiceUnavailableRetryCondition(resp *http.Response, _ error) bool { - serviceUnavailable := resp.StatusCode == http.StatusServiceUnavailable - - // During maintenance events, the API will return a 503 and add - // an `X-MAINTENANCE-MODE` header. Don't retry during maintenance - // events, only for legitimate 503s. - if serviceUnavailable && resp.Header.Get(maintenanceModeHeaderName) != "" { - log.Printf("[INFO] Linode API is under maintenance, request will not be retried - please see status.linode.com for more information") - return false - } - - return serviceUnavailable -} - -// nolint:unused -func httpRequestTimeoutRetryCondition(resp *http.Response, _ error) bool { - return resp.StatusCode == http.StatusRequestTimeout -} - -// nolint:unused -func httpRequestGOAWAYRetryCondition(_ *http.Response, err error) bool { - return errors.As(err, &http2.GoAwayError{}) -} - -// nolint:unused -func httpRequestNGINXRetryCondition(resp *http.Response, _ error) bool { - return resp.StatusCode == http.StatusBadRequest && - resp.Header.Get("Server") == "nginx" && - resp.Header.Get("Content-Type") == "text/html" -} - -// Helper function to extract APIError from response -// nolint:unused -func getAPIError(resp *http.Response) (*APIError, bool) { - var apiError APIError - - err := json.NewDecoder(resp.Body).Decode(&apiError) - if err != nil { - return nil, false - } - - return &apiError, true -} diff --git a/retries_http_test.go b/retries_http_test.go deleted file mode 100644 index 35eea5fc9..000000000 --- a/retries_http_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package linodego - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "testing" - "time" -) - -func TestHTTPLinodeBusyRetryCondition(t *testing.T) { - var retry bool - - // Initialize response body - rawResponse := &http.Response{ - StatusCode: http.StatusBadRequest, - Body: io.NopCloser(bytes.NewBuffer(nil)), - } - - retry = httpLinodeBusyRetryCondition(rawResponse, nil) - - if retry { - t.Errorf("Should not have retried") - } - - apiError := APIError{ - Errors: []APIErrorReason{ - {Reason: "Linode busy."}, - }, - } - rawResponse.Body = createResponseBody(apiError) - - retry = httpLinodeBusyRetryCondition(rawResponse, nil) - - if !retry { - t.Errorf("Should have retried") - } -} - -func TestHTTPServiceUnavailableRetryCondition(t *testing.T) { - rawResponse := &http.Response{ - StatusCode: http.StatusServiceUnavailable, - Header: http.Header{httpRetryAfterHeaderName: []string{"20"}}, - Body: io.NopCloser(bytes.NewBuffer(nil)), // Initialize response body - } - - if retry := httpServiceUnavailableRetryCondition(rawResponse, nil); !retry { - t.Error("expected request to be retried") - } - - if retryAfter, err := httpRespectRetryAfter(rawResponse); err != nil { - t.Errorf("expected error to be nil but got %s", err) - } else if retryAfter != time.Second*20 { - t.Errorf("expected retryAfter to be 20 but got %d", retryAfter) - } -} - -func TestHTTPServiceMaintenanceModeRetryCondition(t *testing.T) { - rawResponse := &http.Response{ - StatusCode: http.StatusServiceUnavailable, - Header: http.Header{ - httpRetryAfterHeaderName: []string{"20"}, - httpMaintenanceModeHeaderName: []string{"Currently in maintenance mode."}, - }, - Body: io.NopCloser(bytes.NewBuffer(nil)), // Initialize response body - } - - if retry := httpServiceUnavailableRetryCondition(rawResponse, nil); retry { - t.Error("expected retry to be skipped due to maintenance mode header") - } -} - -// Helper function to create a response body from an object -func createResponseBody(obj interface{}) io.ReadCloser { - body, err := json.Marshal(obj) - if err != nil { - panic(err) - } - return io.NopCloser(bytes.NewBuffer(body)) -} diff --git a/retries_test.go b/retries_test.go index 4f0029388..45b5fc4d3 100644 --- a/retries_test.go +++ b/retries_test.go @@ -1,24 +1,24 @@ package linodego import ( + "bytes" + "encoding/json" + "io" "net/http" "testing" "time" - - "github.com/go-resty/resty/v2" ) func TestLinodeBusyRetryCondition(t *testing.T) { var retry bool - request := resty.Request{} - rawResponse := http.Response{StatusCode: http.StatusBadRequest} - response := resty.Response{ - Request: &request, - RawResponse: &rawResponse, + // Initialize response body + rawResponse := &http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(bytes.NewBuffer(nil)), } - retry = linodeBusyRetryCondition(&response, nil) + retry = LinodeBusyRetryCondition(rawResponse, nil) if retry { t.Errorf("Should not have retried") @@ -29,48 +29,53 @@ func TestLinodeBusyRetryCondition(t *testing.T) { {Reason: "Linode busy."}, }, } - request.SetError(&apiError) + rawResponse.Body = createResponseBody(apiError) - retry = linodeBusyRetryCondition(&response, nil) + retry = LinodeBusyRetryCondition(rawResponse, nil) if !retry { t.Errorf("Should have retried") } } -func TestLinodeServiceUnavailableRetryCondition(t *testing.T) { - request := resty.Request{} - rawResponse := http.Response{StatusCode: http.StatusServiceUnavailable, Header: http.Header{ - retryAfterHeaderName: []string{"20"}, - }} - response := resty.Response{ - Request: &request, - RawResponse: &rawResponse, +func TestServiceUnavailableRetryCondition(t *testing.T) { + rawResponse := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{RetryAfterHeaderName: []string{"20"}}, + Body: io.NopCloser(bytes.NewBuffer(nil)), // Initialize response body } - if retry := serviceUnavailableRetryCondition(&response, nil); !retry { + if retry := ServiceUnavailableRetryCondition(rawResponse, nil); !retry { t.Error("expected request to be retried") } - if retryAfter, err := respectRetryAfter(NewClient(nil).resty, &response); err != nil { + if retryAfter, err := RespectRetryAfter(rawResponse); err != nil { t.Errorf("expected error to be nil but got %s", err) } else if retryAfter != time.Second*20 { t.Errorf("expected retryAfter to be 20 but got %d", retryAfter) } } -func TestLinodeServiceMaintenanceModeRetryCondition(t *testing.T) { - request := resty.Request{} - rawResponse := http.Response{StatusCode: http.StatusServiceUnavailable, Header: http.Header{ - retryAfterHeaderName: []string{"20"}, - maintenanceModeHeaderName: []string{"Currently in maintenance mode."}, - }} - response := resty.Response{ - Request: &request, - RawResponse: &rawResponse, +func TestServiceMaintenanceModeRetryCondition(t *testing.T) { + rawResponse := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{ + RetryAfterHeaderName: []string{"20"}, + MaintenanceModeHeaderName: []string{"Currently in maintenance mode."}, + }, + Body: io.NopCloser(bytes.NewBuffer(nil)), // Initialize response body } - if retry := serviceUnavailableRetryCondition(&response, nil); retry { + if retry := ServiceUnavailableRetryCondition(rawResponse, nil); retry { t.Error("expected retry to be skipped due to maintenance mode header") } } + +// Helper function to create a response body from an object +func createResponseBody(obj interface{}) io.ReadCloser { + body, err := json.Marshal(obj) + if err != nil { + panic(err) + } + return io.NopCloser(bytes.NewBuffer(body)) +} diff --git a/test/go.mod b/test/go.mod index c1f5dd3a2..a5a9b6ec2 100644 --- a/test/go.mod +++ b/test/go.mod @@ -20,7 +20,6 @@ require ( github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.22.3 // indirect - github.com/go-resty/resty/v2 v2.17.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect diff --git a/test/go.sum b/test/go.sum index 9689a3a72..c74c0080a 100644 --- a/test/go.sum +++ b/test/go.sum @@ -14,8 +14,6 @@ github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2Kv github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= -github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk= -github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= diff --git a/test/integration/cache_test.go b/test/integration/cache_test.go index fd4207de2..6f8385997 100644 --- a/test/integration/cache_test.go +++ b/test/integration/cache_test.go @@ -27,9 +27,19 @@ func TestCache_RegionList(t *testing.T) { // Collect request number totalRequests := int64(0) + //client.OnBeforeRequest(func(request *linodego.Request) error { + // fmt.Printf("Request URL: %s\n", request.URL.String()) // Log the URL + // fmt.Printf("Page: %s\n", request.URL.Query().Get("page")) // Log the page query parameter + // if !strings.Contains(request.URL.String(), "regions") || request.URL.Query().Get("page") != "1" { + // return nil + // } + // + // atomic.AddInt64(&totalRequests, 1) + // return nil + //}) client.OnBeforeRequest(func(request *linodego.Request) error { - page := request.QueryParam.Get("page") - if !strings.Contains(request.URL, "regions") || page != "1" { + page := request.URL.Query().Get("page") + if !strings.Contains(request.URL.String(), "regions") || page != "1" { return nil } @@ -91,8 +101,8 @@ func TestCache_Expiration(t *testing.T) { totalRequests := int64(0) client.OnBeforeRequest(func(request *linodego.Request) error { - page := request.QueryParam.Get("page") - if !strings.Contains(request.URL, "kernels") || page != "1" { + page := request.URL.Query().Get("page") + if !strings.Contains(request.URL.String(), "kernels") || page != "1" { return nil } diff --git a/test/integration/fixtures/TestMaintenancePolicies_List.yaml b/test/integration/fixtures/TestMaintenancePolicies_List.yaml index 54b36a96c..e700c5a77 100644 --- a/test/integration/fixtures/TestMaintenancePolicies_List.yaml +++ b/test/integration/fixtures/TestMaintenancePolicies_List.yaml @@ -11,17 +11,17 @@ interactions: - application/json User-Agent: - linodego/dev https://github.com/linode/linodego - url: https://api.linode.com/v4beta/maintenance/policies + url: https://api.linode.com/v4beta/maintenance/policies?page=1 method: GET response: body: '{"data": [{"slug": "linode/migrate", "label": "Migrate", "description": "Migrates the Linode to a new host while it remains fully operational. Recommended for maximizing availability.", "type": "migrate", "notification_period_sec": - 300, "is_default": true}, {"slug": "linode/power_off_on", "label": "Power-off/on", - "description": "Powers off the Linode at the start of the maintenance event - and reboots it once the maintenance finishes. Recommended for maximizing performance.", - "type": "power_off_on", "notification_period_sec": 1800, "is_default": false}], - "page": 1, "pages": 1, "results": 2}' + 10800, "is_default": true}, {"slug": "linode/power_off_on", "label": "Power + Off / Power On", "description": "Powers off the Linode at the start of the maintenance + event and reboots it once the maintenance finishes. Recommended for maximizing + performance.", "type": "power_off_on", "notification_period_sec": 604800, "is_default": + false}], "page": 1, "pages": 1, "results": 2}' headers: Access-Control-Allow-Credentials: - "true" @@ -33,22 +33,24 @@ interactions: - '*' Access-Control-Expose-Headers: - X-OAuth-Scopes, X-Accepted-OAuth-Scopes, X-Status + Akamai-Internal-Account: + - '*' Cache-Control: - - private, max-age=0, s-maxage=0, no-cache, no-store - - private, max-age=60, s-maxage=60 + - max-age=0, no-cache, no-store Connection: - keep-alive Content-Length: - - "595" + - "607" Content-Security-Policy: - default-src 'none' Content-Type: - application/json - Server: - - nginx/1.18.0 + Expires: + - Tue, 21 Apr 2026 18:20:32 GMT + Pragma: + - no-cache Strict-Transport-Security: - max-age=31536000 - - max-age=31536000 Vary: - Authorization, X-Filter - Authorization, X-Filter @@ -62,7 +64,7 @@ interactions: X-Oauth-Scopes: - unknown X-Ratelimit-Limit: - - "400" + - "1840" X-Xss-Protection: - 1; mode=block status: 200 OK diff --git a/test/integration/maintenance_test.go b/test/integration/maintenance_test.go index 8ce207aad..b8d22650c 100644 --- a/test/integration/maintenance_test.go +++ b/test/integration/maintenance_test.go @@ -2,10 +2,8 @@ package integration import ( "context" - "encoding/json" "testing" - "github.com/linode/linodego" "github.com/stretchr/testify/require" ) @@ -13,21 +11,7 @@ func TestMaintenancePolicies_List(t *testing.T) { client, fixtureTeardown := createTestClient(t, "fixtures/TestMaintenancePolicies_List") defer fixtureTeardown() - resp, err := client.R(context.Background()).Get("maintenance/policies") - require.NoError(t, err) - - var result map[string]any - err = json.Unmarshal(resp.Body(), &result) - require.NoError(t, err) - - dataRaw, ok := result["data"] - require.True(t, ok, "Expected 'data' key in response") - - dataJSON, err := json.Marshal(dataRaw) - require.NoError(t, err) - - var policies []linodego.MaintenancePolicy - err = json.Unmarshal(dataJSON, &policies) + policies, err := client.ListMaintenancePolicies(context.Background(), nil) require.NoError(t, err) if len(policies) == 0 { diff --git a/test/unit/images_test.go b/test/unit/images_test.go index dc9aa9ba7..e89cec8ef 100644 --- a/test/unit/images_test.go +++ b/test/unit/images_test.go @@ -271,6 +271,14 @@ func TestImage_Upload(t *testing.T) { base.MockPost("images/upload", fixtureData) + // Mock the PUT request to the upload URL returned in the fixture. + // UploadImageToURL uses http.DefaultTransport, so we need to + // activate httpmock on the default transport as well. + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("PUT", "https://example.com/upload-endpoint", + httpmock.NewStringResponder(200, "{}")) + image, err := base.Client.UploadImage(context.Background(), requestData) assert.NoError(t, err) From 1ae5e506c4634034d4d9e9fec38658203268a299 Mon Sep 17 00:00:00 2001 From: ezilber-akamai Date: Tue, 21 Apr 2026 16:05:47 -0400 Subject: [PATCH 2/9] Fixed lint --- client.go | 721 +++++++++++++++++++------------------- client_monitor.go | 6 + errors.go | 8 +- internal/testutil/mock.go | 6 +- logger.go | 14 +- pagination.go | 3 + request_helpers.go | 11 + retries.go | 5 +- 8 files changed, 399 insertions(+), 375 deletions(-) diff --git a/client.go b/client.go index 575156fb5..a8e60740e 100644 --- a/client.go +++ b/client.go @@ -150,6 +150,116 @@ func init() { } } +// NewClient factory to create new Client struct +// nolint:funlen +func NewClient(hc *http.Client) (client Client) { + if hc != nil { + client.httpClient = hc + } else { + client.httpClient = &http.Client{} + } + + // Ensure that the Header map is not nil + if client.httpClient.Transport == nil { + client.httpClient.Transport = &http.Transport{} + } + + client.shouldCache = true + client.cacheExpiration = APIDefaultCacheExpiration + client.cachedEntries = make(map[string]clientCacheEntry) + client.cachedEntryLock = &sync.RWMutex{} + client.configProfiles = make(map[string]ConfigProfile) + + const ( + retryMinWaitDuration = 100 * time.Millisecond + retryMaxWaitDuration = 2 * time.Second + ) + + client.retryMinWaitTime = retryMinWaitDuration + client.retryMaxWaitTime = retryMaxWaitDuration + + client.SetUserAgent(DefaultUserAgent) + client.SetLogger(createLogger()) + + baseURL, baseURLExists := os.LookupEnv(APIHostVar) + if baseURLExists { + client.SetBaseURL(baseURL) + } + + apiVersion, apiVersionExists := os.LookupEnv(APIVersionVar) + if apiVersionExists { + client.SetAPIVersion(apiVersion) + } else { + client.SetAPIVersion(APIVersion) + } + + certPath, certPathExists := os.LookupEnv(APIHostCert) + if certPathExists { + cert, err := os.ReadFile(filepath.Clean(certPath)) + if err != nil { + log.Fatalf("[ERROR] Error when reading cert at %s: %s\n", certPath, err.Error()) + } + + client.SetRootCertificate(certPath) + + if envDebug { + log.Printf("[DEBUG] Set API root certificate to %s with contents %s\n", certPath, cert) + } + } + + client. + SetRetryWaitTime(APISecondsPerPoll * time.Second). + SetPollDelay(APISecondsPerPoll * time.Second). + SetRetries(). + SetLogger(createLogger()). + SetDebug(envDebug). + enableLogSanitization() + + return client +} + +// NewClientFromEnv creates a Client and initializes it with values +// from the LINODE_CONFIG file and the LINODE_TOKEN environment variable. +func NewClientFromEnv(hc *http.Client) (*Client, error) { + client := NewClient(hc) + + // Users are expected to chain NewClient(...) and LoadConfig(...) to customize these options + configPath, err := resolveValidConfigPath() + if err != nil { + return nil, err + } + + // Populate the token from the environment. + // Tokens should be first priority to maintain backwards compatibility + if token, ok := os.LookupEnv(APIEnvVar); ok && token != "" { + client.SetToken(token) + return &client, nil + } + + if p, ok := os.LookupEnv(APIConfigEnvVar); ok { + configPath = p + } else if !ok && configPath == "" { + return nil, fmt.Errorf("no linode config file or token found") + } + + configProfile := DefaultConfigProfile + + if p, ok := os.LookupEnv(APIConfigProfileEnvVar); ok { + configProfile = p + } + + client.selectedProfile = configProfile + + // We should only load the config if the config file exists + if _, statErr := os.Stat(configPath); statErr != nil { + return nil, fmt.Errorf("error loading config file %s: %w", configPath, statErr) + } + + err = client.preLoadConfig(configPath) + + return &client, err +} + // SetUserAgent sets a custom user-agent for HTTP requests func (c *Client) SetUserAgent(ua string) *Client { c.userAgent = ua @@ -163,6 +273,227 @@ type requestParams struct { Response any } +func (c *Client) ErrorAndLogf(format string, args ...any) error { + if c.debug && c.logger != nil { + c.logger.Errorf(format, args...) + } + + return fmt.Errorf(format, args...) +} + +// SetRootCertificate adds a root certificate to the underlying TLS client config +func (c *Client) SetRootCertificate(path string) *Client { + config, err := c.tlsConfig() + if err != nil { + log.Println("[WARN] Custom transport is not allowed with a custom root CA") + return c + } + + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + + config.RootCAs.AppendCertsFromPEM([]byte(path)) + + return c +} + +// SetToken sets the API token for all requests from this client +// Only necessary if you haven't already provided the http client to NewClient() configured with the token. +func (c *Client) SetToken(token string) *Client { + c.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) + return c +} + +// SetRetries adds retry conditions for "Linode Busy." errors and 429s. +func (c *Client) SetRetries() *Client { + c. + AddRetryCondition(LinodeBusyRetryCondition). + AddRetryCondition(TooManyRequestsRetryCondition). + AddRetryCondition(ServiceUnavailableRetryCondition). + AddRetryCondition(RequestTimeoutRetryCondition). + AddRetryCondition(RequestGOAWAYRetryCondition). + AddRetryCondition(RequestNGINXRetryCondition). + SetRetryMaxWaitTime(APIRetryMaxWaitTime) + ConfigureRetries(c) + + return c +} + +// AddRetryCondition adds a RetryConditional function to the Client +func (c *Client) AddRetryCondition(retryCondition RetryConditional) *Client { + c.retryConditionals = append(c.retryConditionals, retryCondition) + + return c +} + +func (c *Client) SetDebug(debug bool) *Client { + c.debug = debug + + return c +} + +func (c *Client) SetLogger(logger Logger) *Client { + c.logger = logger + + return c +} + +func (c *Client) OnBeforeRequest(m func(*http.Request) error) { + c.onBeforeRequest = append(c.onBeforeRequest, m) +} + +func (c *Client) OnAfterResponse(m func(*http.Response) error) { + c.onAfterResponse = append(c.onAfterResponse, m) +} + +// UseURL parses the individual components of the given API URL and configures the client +// accordingly. For example, a valid URL. +// For example: +// +// client.UseURL("https://api.test.linode.com/v4beta") +func (c *Client) UseURL(apiURL string) (*Client, error) { + parsedURL, err := url.Parse(apiURL) + if err != nil { + return nil, fmt.Errorf("failed to parse URL: %w", err) + } + + if parsedURL.Scheme == "" || parsedURL.Host == "" { + return nil, fmt.Errorf("need both scheme and host in API URL, got %q", apiURL) + } + + // Create a new URL excluding the path to use as the base URL + baseURL := &url.URL{ + Host: parsedURL.Host, + Scheme: parsedURL.Scheme, + } + + c.SetBaseURL(baseURL.String()) + + versionMatches := regexp.MustCompile(`/v[a-zA-Z0-9]+`).FindAllString(parsedURL.Path, -1) + + // Only set the version if a version is found in the URL, else use the default + if len(versionMatches) > 0 { + c.SetAPIVersion( + strings.Trim(versionMatches[len(versionMatches)-1], "/"), + ) + } + + return c, nil +} + +func (c *Client) SetBaseURL(baseURL string) *Client { + baseURLPath, _ := url.Parse(baseURL) + + c.baseURL = path.Join(baseURLPath.Host, baseURLPath.Path) + c.apiProto = baseURLPath.Scheme + + c.updateHostURL() + + return c +} + +// SetAPIVersion sets the version of the API to interface with +func (c *Client) SetAPIVersion(apiVersion string) *Client { + c.apiVersion = apiVersion + + c.updateHostURL() + + return c +} + +// InvalidateCache clears all cached responses for all endpoints. +func (c *Client) InvalidateCache() { + c.cachedEntryLock.Lock() + defer c.cachedEntryLock.Unlock() + + // GC will handle the old map + c.cachedEntries = make(map[string]clientCacheEntry) +} + +// InvalidateCacheEndpoint invalidates a single cached endpoint. +func (c *Client) InvalidateCacheEndpoint(endpoint string) error { + u, err := url.Parse(endpoint) + if err != nil { + return fmt.Errorf("failed to parse URL for caching: %w", err) + } + + c.cachedEntryLock.Lock() + defer c.cachedEntryLock.Unlock() + + delete(c.cachedEntries, u.Path) + + return nil +} + +// SetGlobalCacheExpiration sets the desired time for any cached response +// to be valid for. +func (c *Client) SetGlobalCacheExpiration(expiryTime time.Duration) { + c.cacheExpiration = expiryTime +} + +// UseCache sets whether response caching should be used +func (c *Client) UseCache(value bool) { + c.shouldCache = value +} + +// SetRetryMaxWaitTime sets the maximum delay before retrying a request. +func (c *Client) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Client { + c.retryMaxWaitTime = maxWaitTime + return c +} + +// SetRetryWaitTime sets the default (minimum) delay before retrying a request. +func (c *Client) SetRetryWaitTime(minWaitTime time.Duration) *Client { + c.retryMinWaitTime = minWaitTime + return c +} + +// SetRetryAfter sets the callback function to be invoked with a failed request +// to determine wben it should be retried. +func (c *Client) SetRetryAfter(callback RetryAfter) *Client { + c.retryAfter = callback + return c +} + +// SetRetryCount sets the maximum retry attempts before aborting. +func (c *Client) SetRetryCount(count int) *Client { + c.retryCount = count + return c +} + +// SetPollDelay sets the number of milliseconds to wait between events or status polls. +// Affects all WaitFor* functions and retries. +func (c *Client) SetPollDelay(delay time.Duration) *Client { + c.pollInterval = delay + return c +} + +// GetPollDelay gets the number of milliseconds to wait between events or status polls. +// Affects all WaitFor* functions and retries. +func (c *Client) GetPollDelay() time.Duration { + return c.pollInterval +} + +// SetHeader sets a custom header to be used in all API requests made with the current +// client. +// NOTE: Some headers may be overridden by the individual request functions. +func (c *Client) SetHeader(name, value string) { + if c.header == nil { + c.header = make(http.Header) // Initialize header if nil + } + + c.header.Set(name, value) +} + +func (c *Client) Transport() (*http.Transport, error) { + if transport, ok := c.httpClient.Transport.(*http.Transport); ok { + return transport, nil + } + + return nil, fmt.Errorf("current transport is not an *http.Transport instance") +} + // Generic helper to execute HTTP requests using the net/http package // // nolint:funlen, gocognit, nestif @@ -177,9 +508,8 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params for range c.retryCount { // Reset the body to the start for each retry if it's not nil if params.Body != nil { - _, err := params.Body.Seek(0, io.SeekStart) - if err != nil { - return c.ErrorAndLogf("failed to seek to the start of the body: %v", err.Error()) + if _, seekErr := params.Body.Seek(0, io.SeekStart); seekErr != nil { + return c.ErrorAndLogf("failed to seek to the start of the body: %v", seekErr.Error()) } } @@ -189,8 +519,8 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params } if paginationMutator != nil { - if err := (*paginationMutator)(req); err != nil { - return c.ErrorAndLogf("failed to mutate before request: %v", err.Error()) + if mutErr := (*paginationMutator)(req); mutErr != nil { + return c.ErrorAndLogf("failed to mutate before request: %v", mutErr.Error()) } } @@ -215,12 +545,7 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params } if c.debug && c.logger != nil { - var logErr error - - resp, logErr = c.logResponse(resp) - if logErr != nil { - return logErr - } + resp = c.logResponse(resp) } if params.Response != nil { @@ -284,8 +609,10 @@ func (c *Client) shouldRetry(resp *http.Response, err error) bool { } func (c *Client) createRequest(ctx context.Context, method, endpoint string, params requestParams) (*http.Request, *bytes.Buffer, error) { - var bodyReader io.Reader - var bodyBuffer *bytes.Buffer + var ( + bodyReader io.Reader + bodyBuffer *bytes.Buffer + ) if params.Body != nil { // Reset the body position to the start before using it @@ -374,7 +701,8 @@ func (c *Client) logRequest(req *http.Request, method, url string, bodyBuffer *b } var logBuf bytes.Buffer - err := reqLogTemplate.Execute(&logBuf, map[string]interface{}{ + + err := reqLogTemplate.Execute(&logBuf, map[string]any{ "Method": reqLog.Method, "URL": reqLog.URL, "Headers": reqLog.Headers, @@ -386,7 +714,7 @@ func (c *Client) logRequest(req *http.Request, method, url string, bodyBuffer *b } func (c *Client) sendRequest(req *http.Request) (*http.Response, error) { - resp, err := c.httpClient.Do(req) + resp, err := c.httpClient.Do(req) //#nosec G704 // URL is constructed from client-configured base URL + endpoint if err != nil { return nil, c.ErrorAndLogf("failed to send request: %w", err) } @@ -404,14 +732,15 @@ func (c *Client) checkHTTPError(resp *http.Response) error { return nil } -func (c *Client) logResponse(resp *http.Response) (*http.Response, error) { +func (c *Client) logResponse(resp *http.Response) *http.Response { var respBody bytes.Buffer if _, err := io.Copy(&respBody, resp.Body); err != nil { c.logger.Errorf("failed to read response body: %v", err) } var logBuf bytes.Buffer - err := respLogTemplate.Execute(&logBuf, map[string]interface{}{ + + err := respLogTemplate.Execute(&logBuf, map[string]any{ "Status": resp.Status, "Headers": redactHeaders(resp.Header), "Body": respBody.String(), @@ -422,10 +751,10 @@ func (c *Client) logResponse(resp *http.Response) (*http.Response, error) { resp.Body = io.NopCloser(bytes.NewReader(respBody.Bytes())) - return resp, nil + return resp } -func (c *Client) decodeResponseBody(resp *http.Response, response interface{}) error { +func (c *Client) decodeResponseBody(resp *http.Response, response any) error { if err := json.NewDecoder(resp.Body).Decode(response); err != nil { return c.ErrorAndLogf("failed to decode response: %v", err.Error()) } @@ -433,81 +762,6 @@ func (c *Client) decodeResponseBody(resp *http.Response, response interface{}) e return nil } -func (c *Client) SetDebug(debug bool) *Client { - c.debug = debug - - return c -} - -func (c *Client) SetLogger(logger Logger) *Client { - c.logger = logger - - return c -} - -func (c *Client) OnBeforeRequest(m func(*http.Request) error) { - c.onBeforeRequest = append(c.onBeforeRequest, m) -} - -func (c *Client) OnAfterResponse(m func(*http.Response) error) { - c.onAfterResponse = append(c.onAfterResponse, m) -} - -// UseURL parses the individual components of the given API URL and configures the client -// accordingly. For example, a valid URL. -// For example: -// -// client.UseURL("https://api.test.linode.com/v4beta") -func (c *Client) UseURL(apiURL string) (*Client, error) { - parsedURL, err := url.Parse(apiURL) - if err != nil { - return nil, fmt.Errorf("failed to parse URL: %w", err) - } - - if parsedURL.Scheme == "" || parsedURL.Host == "" { - return nil, fmt.Errorf("need both scheme and host in API URL, got %q", apiURL) - } - - // Create a new URL excluding the path to use as the base URL - baseURL := &url.URL{ - Host: parsedURL.Host, - Scheme: parsedURL.Scheme, - } - - c.SetBaseURL(baseURL.String()) - - versionMatches := regexp.MustCompile(`/v[a-zA-Z0-9]+`).FindAllString(parsedURL.Path, -1) - - // Only set the version if a version is found in the URL, else use the default - if len(versionMatches) > 0 { - c.SetAPIVersion( - strings.Trim(versionMatches[len(versionMatches)-1], "/"), - ) - } - - return c, nil -} - -func (c *Client) SetBaseURL(baseURL string) *Client { - baseURLPath, _ := url.Parse(baseURL) - - c.baseURL = path.Join(baseURLPath.Host, baseURLPath.Path) - c.apiProto = baseURLPath.Scheme - - c.updateHostURL() - - return c -} - -// SetAPIVersion sets the version of the API to interface with -func (c *Client) SetAPIVersion(apiVersion string) *Client { - c.apiVersion = apiVersion - - c.updateHostURL() - - return c -} - func (c *Client) updateHostURL() { apiProto := APIProto baseURL := APIHost @@ -520,19 +774,12 @@ func (c *Client) updateHostURL() { if c.apiVersion != "" { apiVersion = c.apiVersion } - - if c.apiProto != "" { - apiProto = c.apiProto - } - - c.hostURL = strings.TrimRight(fmt.Sprintf("%s://%s/%s", apiProto, baseURL, url.PathEscape(apiVersion)), "/") -} - -func (c *Client) Transport() (*http.Transport, error) { - if transport, ok := c.httpClient.Transport.(*http.Transport); ok { - return transport, nil + + if c.apiProto != "" { + apiProto = c.apiProto } - return nil, fmt.Errorf("current transport is not an *http.Transport instance") + + c.hostURL = strings.TrimRight(fmt.Sprintf("%s://%s/%s", apiProto, baseURL, url.PathEscape(apiVersion)), "/") } func (c *Client) tlsConfig() (*tls.Config, error) { @@ -540,54 +787,14 @@ func (c *Client) tlsConfig() (*tls.Config, error) { if err != nil { return nil, err } + if transport.TLSClientConfig == nil { transport.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, } } - return transport.TLSClientConfig, nil -} - -// SetRootCertificate adds a root certificate to the underlying TLS client config -func (c *Client) SetRootCertificate(path string) *Client { - config, err := c.tlsConfig() - if err != nil { - log.Println("[WARN] Custom transport is not allowed with a custom root CA") - return c - } - if config.RootCAs == nil { - config.RootCAs = x509.NewCertPool() - } - - config.RootCAs.AppendCertsFromPEM([]byte(path)) - return c -} - -// SetToken sets the API token for all requests from this client -// Only necessary if you haven't already provided the http client to NewClient() configured with the token. -func (c *Client) SetToken(token string) *Client { - c.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) - return c -} - -// SetRetries adds retry conditions for "Linode Busy." errors and 429s. -func (c *Client) SetRetries() *Client { - c. - AddRetryCondition(LinodeBusyRetryCondition). - AddRetryCondition(TooManyRequestsRetryCondition). - AddRetryCondition(ServiceUnavailableRetryCondition). - AddRetryCondition(RequestTimeoutRetryCondition). - AddRetryCondition(RequestGOAWAYRetryCondition). - AddRetryCondition(RequestNGINXRetryCondition). - SetRetryMaxWaitTime(APIRetryMaxWaitTime) - ConfigureRetries(c) - return c -} -// AddRetryCondition adds a RetryConditional function to the Client -func (c *Client) AddRetryCondition(retryCondition RetryConditional) *Client { - c.retryConditionals = append(c.retryConditionals, retryCondition) - return c + return transport.TLSClientConfig, nil } func (c *Client) addCachedResponse(endpoint string, response any, expiry *time.Duration) { @@ -630,6 +837,7 @@ func (c *Client) getCachedResponse(endpoint string) any { // This is necessary as we take write // access if the entry has expired. rLocked := true + defer func() { if rLocked { c.cachedEntryLock.RUnlock() @@ -652,111 +860,32 @@ func (c *Client) getCachedResponse(endpoint string) any { if hasExpired { // We need to give up our read access and request read-write access c.cachedEntryLock.RUnlock() + rLocked = false c.cachedEntryLock.Lock() defer c.cachedEntryLock.Unlock() delete(c.cachedEntries, endpoint) + return nil } return c.cachedEntries[endpoint].Data } -// InvalidateCache clears all cached responses for all endpoints. -func (c *Client) InvalidateCache() { - c.cachedEntryLock.Lock() - defer c.cachedEntryLock.Unlock() - - // GC will handle the old map - c.cachedEntries = make(map[string]clientCacheEntry) -} - -// InvalidateCacheEndpoint invalidates a single cached endpoint. -func (c *Client) InvalidateCacheEndpoint(endpoint string) error { - u, err := url.Parse(endpoint) - if err != nil { - return fmt.Errorf("failed to parse URL for caching: %w", err) - } - - c.cachedEntryLock.Lock() - defer c.cachedEntryLock.Unlock() - - delete(c.cachedEntries, u.Path) - - return nil -} - -// SetGlobalCacheExpiration sets the desired time for any cached response -// to be valid for. -func (c *Client) SetGlobalCacheExpiration(expiryTime time.Duration) { - c.cacheExpiration = expiryTime -} - -// UseCache sets whether response caching should be used -func (c *Client) UseCache(value bool) { - c.shouldCache = value -} - -// SetRetryMaxWaitTime sets the maximum delay before retrying a request. -func (c *Client) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Client { - c.retryMaxWaitTime = maxWaitTime - return c -} - -// SetRetryWaitTime sets the default (minimum) delay before retrying a request. -func (c *Client) SetRetryWaitTime(minWaitTime time.Duration) *Client { - c.retryMinWaitTime = minWaitTime - return c -} - -// SetRetryAfter sets the callback function to be invoked with a failed request -// to determine wben it should be retried. -func (c *Client) SetRetryAfter(callback RetryAfter) *Client { - c.retryAfter = callback - return c -} - -// SetRetryCount sets the maximum retry attempts before aborting. -func (c *Client) SetRetryCount(count int) *Client { - c.retryCount = count - return c -} - -// SetPollDelay sets the number of milliseconds to wait between events or status polls. -// Affects all WaitFor* functions and retries. -func (c *Client) SetPollDelay(delay time.Duration) *Client { - c.pollInterval = delay - return c -} - -// GetPollDelay gets the number of milliseconds to wait between events or status polls. -// Affects all WaitFor* functions and retries. -func (c *Client) GetPollDelay() time.Duration { - return c.pollInterval -} - -// SetHeader sets a custom header to be used in all API requests made with the current -// client. -// NOTE: Some headers may be overridden by the individual request functions. -func (c *Client) SetHeader(name, value string) { - if c.header == nil { - c.header = make(http.Header) // Initialize header if nil - } - c.header.Set(name, value) -} - func (c *Client) onRequestLog(rl func(*RequestLog) error) *Client { if c.requestLog != nil { c.logger.Warnf("Overwriting an existing on-request-log callback from=%s to=%s", functionName(c.requestLog), functionName(rl)) } + c.requestLog = rl + return c } -func functionName(i interface{}) string { +func functionName(i any) string { return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() } @@ -770,114 +899,6 @@ func (c *Client) enableLogSanitization() *Client { return c } -// NewClient factory to create new Client struct -// nolint:funlen -func NewClient(hc *http.Client) (client Client) { - if hc != nil { - client.httpClient = hc - } else { - client.httpClient = &http.Client{} - } - - // Ensure that the Header map is not nil - if client.httpClient.Transport == nil { - client.httpClient.Transport = &http.Transport{} - } - - client.shouldCache = true - client.cacheExpiration = APIDefaultCacheExpiration - client.cachedEntries = make(map[string]clientCacheEntry) - client.cachedEntryLock = &sync.RWMutex{} - client.configProfiles = make(map[string]ConfigProfile) - - const ( - retryMinWaitDuration = 100 * time.Millisecond - retryMaxWaitDuration = 2 * time.Second - ) - - client.retryMinWaitTime = retryMinWaitDuration - client.retryMaxWaitTime = retryMaxWaitDuration - - client.SetUserAgent(DefaultUserAgent) - client.SetLogger(createLogger()) - - baseURL, baseURLExists := os.LookupEnv(APIHostVar) - if baseURLExists { - client.SetBaseURL(baseURL) - } - apiVersion, apiVersionExists := os.LookupEnv(APIVersionVar) - if apiVersionExists { - client.SetAPIVersion(apiVersion) - } else { - client.SetAPIVersion(APIVersion) - } - - certPath, certPathExists := os.LookupEnv(APIHostCert) - if certPathExists { - cert, err := os.ReadFile(filepath.Clean(certPath)) - if err != nil { - log.Fatalf("[ERROR] Error when reading cert at %s: %s\n", certPath, err.Error()) - } - - client.SetRootCertificate(certPath) - - if envDebug { - log.Printf("[DEBUG] Set API root certificate to %s with contents %s\n", certPath, cert) - } - } - - client. - SetRetryWaitTime(APISecondsPerPoll * time.Second). - SetPollDelay(APISecondsPerPoll * time.Second). - SetRetries(). - SetLogger(createLogger()). - SetDebug(envDebug). - enableLogSanitization() - - return -} - -// NewClientFromEnv creates a Client and initializes it with values -// from the LINODE_CONFIG file and the LINODE_TOKEN environment variable. -func NewClientFromEnv(hc *http.Client) (*Client, error) { - client := NewClient(hc) - - // Users are expected to chain NewClient(...) and LoadConfig(...) to customize these options - configPath, err := resolveValidConfigPath() - if err != nil { - return nil, err - } - - // Populate the token from the environment. - // Tokens should be first priority to maintain backwards compatibility - if token, ok := os.LookupEnv(APIEnvVar); ok && token != "" { - client.SetToken(token) - return &client, nil - } - - if p, ok := os.LookupEnv(APIConfigEnvVar); ok { - configPath = p - } else if !ok && configPath == "" { - return nil, fmt.Errorf("no linode config file or token found") - } - - configProfile := DefaultConfigProfile - - if p, ok := os.LookupEnv(APIConfigProfileEnvVar); ok { - configProfile = p - } - - client.selectedProfile = configProfile - - // We should only load the config if the config file exists - if _, err := os.Stat(configPath); err != nil { - return nil, fmt.Errorf("error loading config file %s: %w", configPath, err) - } - - err = client.preLoadConfig(configPath) - return &client, err -} - func (c *Client) preLoadConfig(configPath string) error { if envDebug { log.Printf("[INFO] Loading profile from %s\n", configPath) @@ -968,23 +989,3 @@ func generateListCacheURL(endpoint string, opts *ListOptions) (string, error) { return fmt.Sprintf("%s:%s", endpoint, hashedOpts), nil } - -func (c *Client) ErrorAndLogf(format string, args ...interface{}) error { - if c.debug && c.logger != nil { - c.logger.Errorf(format, args...) - } - return fmt.Errorf(format, args...) -} - -func hasCustomTransport(hc *http.Client) bool { - if hc == nil || hc.Transport == nil { - return false - } - - if _, ok := hc.Transport.(*http.Transport); !ok { - log.Println("[WARN] Custom transport is not allowed with a custom root CA.") - return true - } - - return false -} diff --git a/client_monitor.go b/client_monitor.go index 0cdfe4fa7..866e6a4dd 100644 --- a/client_monitor.go +++ b/client_monitor.go @@ -129,15 +129,19 @@ func (mc *MonitorClient) SetRootCertificate(certPath string) *MonitorClient { mc.logger.Errorf("current transport is not an *http.Transport instance") return mc } + if transport.TLSClientConfig == nil { transport.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, } } + if transport.TLSClientConfig.RootCAs == nil { transport.TLSClientConfig.RootCAs = x509.NewCertPool() } + transport.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(certPath)) + return mc } @@ -181,10 +185,12 @@ func (mc *MonitorClient) updateMonitorHostURL() { // doRequest is a generic helper to execute HTTP requests for the MonitorClient func (mc *MonitorClient) doRequest(ctx context.Context, method, endpoint string, params requestParams) error { var bodyReader io.Reader + if params.Body != nil { if _, err := params.Body.Seek(0, io.SeekStart); err != nil { return fmt.Errorf("failed to seek body: %w", err) } + bodyReader = params.Body } diff --git a/errors.go b/errors.go index 62586a477..513ea1db8 100644 --- a/errors.go +++ b/errors.go @@ -47,7 +47,7 @@ type APIError struct { Errors []APIErrorReason `json:"errors"` } -//nolint:nestif +//nolint:nestif,unparam func coupleAPIErrors(resp *http.Response, err error) (*http.Response, error) { if err != nil { return nil, NewError(err) @@ -77,9 +77,9 @@ func coupleAPIErrors(resp *http.Response, err error) (*http.Response, error) { return nil, NewError(fmt.Errorf("response body is nil")) } - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, NewError(fmt.Errorf("failed to read response body: %w", err)) + bodyBytes, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, NewError(fmt.Errorf("failed to read response body: %w", readErr)) } resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) diff --git a/internal/testutil/mock.go b/internal/testutil/mock.go index c32bcae42..752b27f9f 100644 --- a/internal/testutil/mock.go +++ b/internal/testutil/mock.go @@ -115,15 +115,15 @@ type TestLogger struct { L *log.Logger } -func (l *TestLogger) Errorf(format string, v ...interface{}) { +func (l *TestLogger) Errorf(format string, v ...any) { l.outputf("ERROR "+format, v...) } -func (l *TestLogger) Warnf(format string, v ...interface{}) { +func (l *TestLogger) Warnf(format string, v ...any) { l.outputf("WARN "+format, v...) } -func (l *TestLogger) Debugf(format string, v ...interface{}) { +func (l *TestLogger) Debugf(format string, v ...any) { l.outputf("DEBUG "+format, v...) } diff --git a/logger.go b/logger.go index 07e2f99f7..59b947381 100644 --- a/logger.go +++ b/logger.go @@ -6,9 +6,9 @@ import ( ) type Logger interface { - Errorf(format string, v ...interface{}) - Warnf(format string, v ...interface{}) - Debugf(format string, v ...interface{}) + Errorf(format string, v ...any) + Warnf(format string, v ...any) + Debugf(format string, v ...any) } type logger struct { @@ -22,19 +22,19 @@ func createLogger() *logger { var _ Logger = (*logger)(nil) -func (l *logger) Errorf(format string, v ...interface{}) { +func (l *logger) Errorf(format string, v ...any) { l.output("ERROR "+format, v...) } -func (l *logger) Warnf(format string, v ...interface{}) { +func (l *logger) Warnf(format string, v ...any) { l.output("WARN "+format, v...) } -func (l *logger) Debugf(format string, v ...interface{}) { +func (l *logger) Debugf(format string, v ...any) { l.output("DEBUG "+format, v...) } -func (l *logger) output(format string, v ...interface{}) { //nolint:goprintffuncname +func (l *logger) output(format string, v ...any) { //nolint:goprintffuncname if len(v) == 0 { l.l.Print(format) return diff --git a/pagination.go b/pagination.go index 2fdcb86fc..b0fec7040 100644 --- a/pagination.go +++ b/pagination.go @@ -70,6 +70,7 @@ func createListOptionsToRequestMutator(opts *ListOptions) func(*http.Request) er if err != nil { return fmt.Errorf("failed to apply list options: %w", err) } + for key, value := range params { query.Set(key, value) } @@ -79,6 +80,7 @@ func createListOptionsToRequestMutator(opts *ListOptions) func(*http.Request) er if opts.PageOptions != nil && opts.Page > 0 { query.Set("page", strconv.Itoa(opts.Page)) } + if opts.PageSize > 0 { query.Set("page_size", strconv.Itoa(opts.PageSize)) } @@ -90,6 +92,7 @@ func createListOptionsToRequestMutator(opts *ListOptions) func(*http.Request) er // Assign the updated query back to the request URL req.URL.RawQuery = query.Encode() + return nil } } diff --git a/request_helpers.go b/request_helpers.go index d172521e3..e50e4ec6d 100644 --- a/request_helpers.go +++ b/request_helpers.go @@ -164,6 +164,7 @@ func doGETRequest[T any]( endpoint string, ) (*T, error) { var resultType T + params := requestParams{ Response: &resultType, } @@ -185,6 +186,7 @@ func doPOSTRequest[T, O any]( options ...O, ) (*T, error) { var resultType T + numOpts := len(options) if numOpts > 1 { return nil, fmt.Errorf("invalid number of options: %d", numOpts) @@ -193,11 +195,13 @@ func doPOSTRequest[T, O any]( params := requestParams{ Response: &resultType, } + if numOpts > 0 && !isNil(options[0]) { body, err := json.Marshal(options[0]) if err != nil { return nil, err } + params.Body = bytes.NewReader(body) } @@ -205,6 +209,7 @@ func doPOSTRequest[T, O any]( if err != nil { return nil, err } + return &resultType, nil } @@ -217,6 +222,7 @@ func doPOSTRequestNoResponseBody[T any]( options ...T, ) error { _, err := doPOSTRequest[any, T](ctx, client, endpoint, options...) + return err } @@ -239,6 +245,7 @@ func doPUTRequest[T, O any]( options ...O, ) (*T, error) { var resultType T + numOpts := len(options) if numOpts > 1 { return nil, fmt.Errorf("invalid number of options: %d", numOpts) @@ -247,11 +254,13 @@ func doPUTRequest[T, O any]( params := requestParams{ Response: &resultType, } + if numOpts > 0 && !isNil(options[0]) { body, err := json.Marshal(options[0]) if err != nil { return nil, err } + params.Body = bytes.NewReader(body) } @@ -259,6 +268,7 @@ func doPUTRequest[T, O any]( if err != nil { return nil, err } + return &resultType, nil } @@ -271,6 +281,7 @@ func doDELETERequest( ) error { params := requestParams{} err := client.doRequest(ctx, http.MethodDelete, endpoint, params, nil) + return err } diff --git a/retries.go b/retries.go index 371ee03d5..24cbf2454 100644 --- a/retries.go +++ b/retries.go @@ -25,7 +25,7 @@ type RetryConditional func(*http.Response, error) bool // RetryAfter is a type alias for a function that determines the duration to wait before retrying based on the response. type RetryAfter func(*http.Response) (time.Duration, error) -// Configures http.Client to lock until enough time has passed to retry the request as determined by the Retry-After response header. +// ConfigureRetries configures http.Client to lock until enough time has passed to retry the request as determined by the Retry-After response header. // If the Retry-After header is not set, we fall back to the value of SetPollDelay. func ConfigureRetries(c *Client) { c.SetRetryAfter(RespectRetryAfter) @@ -49,6 +49,7 @@ func RespectRetryAfter(resp *http.Response) (time.Duration, error) { duration := time.Duration(retryAfter) * time.Second log.Printf("[INFO] Respecting Retry-After Header of %d (%s)", retryAfter, duration) + return duration, nil } @@ -62,6 +63,7 @@ func LinodeBusyRetryCondition(resp *http.Response, _ error) bool { apiError, ok := getAPIError(resp) linodeBusy := ok && apiError.Error() == "Linode busy." retry := resp.StatusCode == http.StatusBadRequest && linodeBusy + return retry } @@ -127,6 +129,7 @@ func getAPIError(resp *http.Response) (*APIError, bool) { resp.Body = io.NopCloser(bytes.NewReader(body)) var apiError APIError + err = json.Unmarshal(body, &apiError) if err != nil { return nil, false From 14522d79b49bc6e7211fab05d6c7ece6e2e142e9 Mon Sep 17 00:00:00 2001 From: Erik Zilber Date: Fri, 8 Nov 2024 14:05:35 -0500 Subject: [PATCH 3/9] Updated log formatting to more closely resemble Resty's --- client.go | 190 +++++++++++++++++++++++++++---------- client_test.go | 31 ++++-- request_log_template.tmpl | 8 ++ response_log_template.tmpl | 10 ++ 4 files changed, 182 insertions(+), 57 deletions(-) create mode 100644 request_log_template.tmpl create mode 100644 response_log_template.tmpl diff --git a/client.go b/client.go index a8e60740e..0520d3215 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "crypto/x509" + _ "embed" "encoding/json" "fmt" "io" @@ -17,6 +18,7 @@ import ( "reflect" "regexp" "runtime" + "sort" "strconv" "strings" "sync" @@ -51,31 +53,33 @@ const ( APIDefaultCacheExpiration = time.Minute * 15 ) +// Embed the log template files +// +//go:embed request_log_template.tmpl +var requestTemplateStr string + +//go:embed response_log_template.tmpl +var responseTemplateStr string + var ( - reqLogTemplate = template.Must(template.New("request").Parse(`Sending request: -Method: {{.Method}} -URL: {{.URL}} -Headers: {{.Headers}} -Body: {{.Body}}`)) - - respLogTemplate = template.Must(template.New("response").Parse(`Received response: -Status: {{.Status}} -Headers: {{.Headers}} -Body: {{.Body}}`)) + reqLogTemplate = template.Must(template.New("request").Parse(requestTemplateStr)) + respLogTemplate = template.Must(template.New("response").Parse(responseTemplateStr)) ) type RequestLog struct { - Method string - URL string + Request string + Host string Headers http.Header Body string } type ResponseLog struct { - Method string - URL string - Headers http.Header - Body string + Status string + Proto string + ReceivedAt string + TimeDuration string + Headers http.Header + Body string } var envDebug = false @@ -499,10 +503,9 @@ func (c *Client) Transport() (*http.Transport, error) { // nolint:funlen, gocognit, nestif func (c *Client) doRequest(ctx context.Context, method, endpoint string, params requestParams, paginationMutator *func(*http.Request) error) error { var ( - req *http.Request - bodyBuffer *bytes.Buffer - resp *http.Response - err error + req *http.Request + resp *http.Response + err error ) for range c.retryCount { @@ -513,7 +516,7 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params } } - req, bodyBuffer, err = c.createRequest(ctx, method, endpoint, params) + req, err = c.createRequest(ctx, method, endpoint, params) if err != nil { return err } @@ -529,10 +532,15 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params } if c.debug && c.logger != nil { - c.logRequest(req, method, endpoint, bodyBuffer) + loggedReq, logErr := c.logRequest(req) + if logErr != nil { + return logErr + } + + req = loggedReq } - processResponse := func() error { + processResponse := func(start, end time.Time) error { defer func() { closeErr := resp.Body.Close() if closeErr != nil && err == nil { @@ -545,7 +553,7 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params } if c.debug && c.logger != nil { - resp = c.logResponse(resp) + resp = c.logResponse(resp, start, end) } if params.Response != nil { @@ -562,9 +570,11 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params return nil } + startTime := time.Now() resp, err = c.sendRequest(req) + endTime := time.Now() if err == nil { - if err = processResponse(); err == nil { + if err = processResponse(startTime, endTime); err == nil { return nil } } @@ -608,17 +618,16 @@ func (c *Client) shouldRetry(resp *http.Response, err error) bool { return false } -func (c *Client) createRequest(ctx context.Context, method, endpoint string, params requestParams) (*http.Request, *bytes.Buffer, error) { +func (c *Client) createRequest(ctx context.Context, method, endpoint string, params requestParams) (*http.Request, error) { var ( bodyReader io.Reader - bodyBuffer *bytes.Buffer ) if params.Body != nil { // Reset the body position to the start before using it _, err := params.Body.Seek(0, io.SeekStart) if err != nil { - return nil, nil, c.ErrorAndLogf("failed to seek to the start of the body: %v", err.Error()) + return nil, c.ErrorAndLogf("failed to seek to the start of the body: %v", err.Error()) } bodyReader = params.Body @@ -627,7 +636,7 @@ func (c *Client) createRequest(ctx context.Context, method, endpoint string, par req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s/%s", strings.TrimRight(c.hostURL, "/"), strings.TrimLeft(endpoint, "/")), bodyReader) if err != nil { - return nil, nil, c.ErrorAndLogf("failed to create request: %v", err.Error()) + return nil, c.ErrorAndLogf("failed to create request: %v", err.Error()) } // Set the default headers @@ -645,7 +654,7 @@ func (c *Client) createRequest(ctx context.Context, method, endpoint string, par } } - return req, bodyBuffer, nil + return req, nil } func (c *Client) applyBeforeRequest(req *http.Request) error { @@ -680,37 +689,92 @@ func redactHeaders(headers http.Header) http.Header { return redacted } -func (c *Client) logRequest(req *http.Request, method, url string, bodyBuffer *bytes.Buffer) { - var reqBody string - if bodyBuffer != nil { - reqBody = bodyBuffer.String() - } else { - reqBody = "nil" +func (c *Client) logRequest(req *http.Request) (*http.Request, error) { + var reqBody bytes.Buffer + if req.Body != nil { + if _, err := io.Copy(&reqBody, req.Body); err != nil { + c.logger.Errorf("failed to read request body: %v", err) + } + req.Body = io.NopCloser(bytes.NewReader(reqBody.Bytes())) } reqLog := &RequestLog{ - Method: method, - URL: url, - Headers: req.Header, - Body: reqBody, + Request: strings.Join([]string{req.Method, req.URL.Path, req.Proto}, " "), + Host: req.Host, + Headers: req.Header.Clone(), + Body: reqBody.String(), } e := c.requestLog(reqLog) if e != nil { - _ = c.ErrorAndLogf("failed to mutate after response: %v", e.Error()) + _ = c.ErrorAndLogf("failed to log request: %v", e.Error()) + } + + body, jsonErr := formatBody(reqLog.Body) + if jsonErr != nil { + if c.debug && c.logger != nil { + c.logger.Errorf("%v", jsonErr) + } } var logBuf bytes.Buffer err := reqLogTemplate.Execute(&logBuf, map[string]any{ - "Method": reqLog.Method, - "URL": reqLog.URL, - "Headers": reqLog.Headers, - "Body": reqLog.Body, + "Request": reqLog.Request, + "Host": reqLog.Host, + "Headers": formatHeaders(reqLog.Headers), + "Body": body, }) if err == nil { c.logger.Debugf(logBuf.String()) } + + return req, nil +} + +func formatHeaders(headers map[string][]string) string { + var builder strings.Builder + builder.WriteString("\n") + + keys := make([]string, 0, len(headers)) + for key := range headers { + keys = append(keys, key) + } + sort.Strings(keys) + + for _, key := range keys { + builder.WriteString(fmt.Sprintf(" %s: %s\n", key, strings.Join(headers[key], ", "))) + } + return strings.TrimSuffix(builder.String(), "\n") +} + +func formatBody(body string) (string, error) { + body = strings.TrimSpace(body) + if body == "null" || body == "nil" || body == "" { + return "", nil + } + + var jsonData map[string]interface{} + err := json.Unmarshal([]byte(body), &jsonData) + if err != nil { + return "", fmt.Errorf("error unmarshalling JSON: %w", err) + } + + prettyJSON, err := json.MarshalIndent(jsonData, "", " ") + if err != nil { + return "", fmt.Errorf("error marshalling JSON: %w", err) + } + + return "\n" + string(prettyJSON), nil +} + +func formatDate(dateStr string) (string, error) { + parsedTime, err := time.Parse(time.RFC1123, dateStr) + if err != nil { + return "", fmt.Errorf("error parsing date: %v", err) + } + formattedDate := parsedTime.In(time.Local).Format("2006-01-02T15:04:05-07:00") // nolint:gosmopolitan + return formattedDate, nil } func (c *Client) sendRequest(req *http.Request) (*http.Response, error) { @@ -732,18 +796,46 @@ func (c *Client) checkHTTPError(resp *http.Response) error { return nil } -func (c *Client) logResponse(resp *http.Response) *http.Response { +func (c *Client) logResponse(resp *http.Response, start, end time.Time) *http.Response { var respBody bytes.Buffer if _, err := io.Copy(&respBody, resp.Body); err != nil { c.logger.Errorf("failed to read response body: %v", err) } + receivedAt, dateErr := formatDate(resp.Header.Get("Date")) + if dateErr != nil { + if c.debug && c.logger != nil { + c.logger.Errorf("failed to format date: %v", dateErr) + } + } + + duration := end.Sub(start).String() + + respLog := &ResponseLog{ + Status: resp.Status, + Proto: resp.Proto, + ReceivedAt: receivedAt, + TimeDuration: duration, + Headers: resp.Header, + Body: respBody.String(), + } + + body, jsonErr := formatBody(respLog.Body) + if jsonErr != nil { + if c.debug && c.logger != nil { + c.logger.Errorf("%v", jsonErr) + } + } + var logBuf bytes.Buffer err := respLogTemplate.Execute(&logBuf, map[string]any{ - "Status": resp.Status, - "Headers": redactHeaders(resp.Header), - "Body": respBody.String(), + "Status": respLog.Status, + "Proto": respLog.Proto, + "ReceivedAt": respLog.ReceivedAt, + "TimeDuration": respLog.TimeDuration, + "Headers": formatHeaders(redactHeaders(respLog.Headers)), + "Body": body, }) if err == nil { c.logger.Debugf(logBuf.String()) diff --git a/client_test.go b/client_test.go index cb4a32aa2..504e3ca20 100644 --- a/client_test.go +++ b/client_test.go @@ -480,17 +480,32 @@ func TestDoRequestLogging_Success(t *testing.T) { } logInfo := logBuffer.String() - logInfoWithoutTimestamps := removeTimestamps(logInfo) - // Expected logs with templates filled in - expectedRequestLog := "DEBUG Sending request:\nMethod: GET\nURL: " + server.URL + "\nHeaders: map[Accept:[application/json] Authorization:[Bearer *******************************] Content-Type:[application/json] User-Agent:[linodego/dev https://github.com/linode/linodego]]\nBody: " - expectedResponseLog := "DEBUG Received response:\nStatus: 200 OK\nHeaders: map[Content-Length:[21] Content-Type:[text/plain; charset=utf-8]]\nBody: {\"message\":\"success\"}" + expectedRequestParts := []string{ + "GET /v4/" + server.URL + " " + "HTTP/1.1", + "Accept: application/json", + "Authorization: Bearer *******************************", + "Content-Type: application/json", + "User-Agent: linodego/dev https://github.com/linode/linodego", + } - if !strings.Contains(logInfo, expectedRequestLog) { - t.Fatalf("expected log %q not found in logs", expectedRequestLog) + expectedResponseParts := []string{ + "STATUS: 200 OK", + "PROTO: HTTP/1.1", + "Content-Length: 21", + "Content-Type: application/json", + `"message": "success"`, } - if !strings.Contains(logInfoWithoutTimestamps, expectedResponseLog) { - t.Fatalf("expected log %q not found in logs", expectedResponseLog) + + for _, part := range expectedRequestParts { + if !strings.Contains(logInfo, part) { + t.Fatalf("expected request part %q not found in logs", part) + } + } + for _, part := range expectedResponseParts { + if !strings.Contains(logInfo, part) { + t.Fatalf("expected response part %q not found in logs", part) + } } } diff --git a/request_log_template.tmpl b/request_log_template.tmpl new file mode 100644 index 000000000..250547bd8 --- /dev/null +++ b/request_log_template.tmpl @@ -0,0 +1,8 @@ + +============================================================================================ +~~~ REQUEST ~~~ +{{.Request}} +HOST: {{.Host}} +HEADERS: {{.Headers}} +BODY: {{.Body}} +-------------------------------------------------------------------------------------------- \ No newline at end of file diff --git a/response_log_template.tmpl b/response_log_template.tmpl new file mode 100644 index 000000000..7bd5f38d8 --- /dev/null +++ b/response_log_template.tmpl @@ -0,0 +1,10 @@ + +============================================================================================ +~~~ RESPONSE ~~~ +STATUS: {{.Status}} +PROTO: {{.Proto}} +RECEIVED AT: {{.ReceivedAt}} +TIME DURATION: {{.TimeDuration}} +HEADERS: {{.Headers}} +BODY: {{.Body}} +-------------------------------------------------------------------------------------------- \ No newline at end of file From 026d0b62f1a5097eff454095377748fae97ed4bc Mon Sep 17 00:00:00 2001 From: ezilber-akamai Date: Tue, 21 Apr 2026 16:47:05 -0400 Subject: [PATCH 4/9] Fixed lint --- client.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 0520d3215..17f13c5f0 100644 --- a/client.go +++ b/client.go @@ -532,12 +532,7 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params } if c.debug && c.logger != nil { - loggedReq, logErr := c.logRequest(req) - if logErr != nil { - return logErr - } - - req = loggedReq + req = c.logRequest(req) } processResponse := func(start, end time.Time) error { @@ -573,6 +568,7 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params startTime := time.Now() resp, err = c.sendRequest(req) endTime := time.Now() + if err == nil { if err = processResponse(startTime, endTime); err == nil { return nil @@ -619,9 +615,7 @@ func (c *Client) shouldRetry(resp *http.Response, err error) bool { } func (c *Client) createRequest(ctx context.Context, method, endpoint string, params requestParams) (*http.Request, error) { - var ( - bodyReader io.Reader - ) + var bodyReader io.Reader if params.Body != nil { // Reset the body position to the start before using it @@ -689,12 +683,13 @@ func redactHeaders(headers http.Header) http.Header { return redacted } -func (c *Client) logRequest(req *http.Request) (*http.Request, error) { +func (c *Client) logRequest(req *http.Request) *http.Request { var reqBody bytes.Buffer if req.Body != nil { if _, err := io.Copy(&reqBody, req.Body); err != nil { c.logger.Errorf("failed to read request body: %v", err) } + req.Body = io.NopCloser(bytes.NewReader(reqBody.Bytes())) } @@ -729,7 +724,7 @@ func (c *Client) logRequest(req *http.Request) (*http.Request, error) { c.logger.Debugf(logBuf.String()) } - return req, nil + return req } func formatHeaders(headers map[string][]string) string { @@ -740,11 +735,13 @@ func formatHeaders(headers map[string][]string) string { for key := range headers { keys = append(keys, key) } + sort.Strings(keys) for _, key := range keys { builder.WriteString(fmt.Sprintf(" %s: %s\n", key, strings.Join(headers[key], ", "))) } + return strings.TrimSuffix(builder.String(), "\n") } @@ -754,7 +751,8 @@ func formatBody(body string) (string, error) { return "", nil } - var jsonData map[string]interface{} + var jsonData map[string]any + err := json.Unmarshal([]byte(body), &jsonData) if err != nil { return "", fmt.Errorf("error unmarshalling JSON: %w", err) @@ -773,7 +771,9 @@ func formatDate(dateStr string) (string, error) { if err != nil { return "", fmt.Errorf("error parsing date: %v", err) } + formattedDate := parsedTime.In(time.Local).Format("2006-01-02T15:04:05-07:00") // nolint:gosmopolitan + return formattedDate, nil } From 208f1492b53e5f1a1d0c4647a4803fb240d9a18f Mon Sep 17 00:00:00 2001 From: Zhiwei Liang Date: Tue, 21 Apr 2026 20:11:27 -0400 Subject: [PATCH 5/9] redact request header log --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 17f13c5f0..8691ba4bd 100644 --- a/client.go +++ b/client.go @@ -696,7 +696,7 @@ func (c *Client) logRequest(req *http.Request) *http.Request { reqLog := &RequestLog{ Request: strings.Join([]string{req.Method, req.URL.Path, req.Proto}, " "), Host: req.Host, - Headers: req.Header.Clone(), + Headers: redactHeaders(req.Header.Clone()), Body: reqBody.String(), } From c45b475f9699e910a78f33306590f0b864a9511b Mon Sep 17 00:00:00 2001 From: ezilber-akamai Date: Wed, 22 Apr 2026 10:23:18 -0400 Subject: [PATCH 6/9] Fixes --- client.go | 53 +++++++++++++++++++++++------------- client_monitor.go | 14 +++++++++- images.go | 3 ++ monitor_alert_definitions.go | 5 ++-- 4 files changed, 53 insertions(+), 22 deletions(-) diff --git a/client.go b/client.go index 8691ba4bd..9e289a803 100644 --- a/client.go +++ b/client.go @@ -199,15 +199,10 @@ func NewClient(hc *http.Client) (client Client) { certPath, certPathExists := os.LookupEnv(APIHostCert) if certPathExists { - cert, err := os.ReadFile(filepath.Clean(certPath)) - if err != nil { - log.Fatalf("[ERROR] Error when reading cert at %s: %s\n", certPath, err.Error()) - } - client.SetRootCertificate(certPath) if envDebug { - log.Printf("[DEBUG] Set API root certificate to %s with contents %s\n", certPath, cert) + log.Printf("[DEBUG] Set API root certificate to %s\n", certPath) } } @@ -275,6 +270,9 @@ func (c *Client) SetUserAgent(ua string) *Client { type requestParams struct { Body *bytes.Reader Response any + // Headers are per-request headers that will be applied only to + // the individual request, not stored on the shared client state. + Headers http.Header } func (c *Client) ErrorAndLogf(format string, args ...any) error { @@ -286,7 +284,7 @@ func (c *Client) ErrorAndLogf(format string, args ...any) error { } // SetRootCertificate adds a root certificate to the underlying TLS client config -func (c *Client) SetRootCertificate(path string) *Client { +func (c *Client) SetRootCertificate(certPath string) *Client { config, err := c.tlsConfig() if err != nil { log.Println("[WARN] Custom transport is not allowed with a custom root CA") @@ -297,7 +295,13 @@ func (c *Client) SetRootCertificate(path string) *Client { config.RootCAs = x509.NewCertPool() } - config.RootCAs.AppendCertsFromPEM([]byte(path)) + pem, err := os.ReadFile(filepath.Clean(certPath)) + if err != nil { + log.Printf("[ERROR] Failed to read root certificate at %s: %s\n", certPath, err.Error()) + return c + } + + config.RootCAs.AppendCertsFromPEM(pem) return c } @@ -584,20 +588,24 @@ func (c *Client) doRequest(ctx context.Context, method, endpoint string, params return retryErr } - // Sleep for the specified duration before retrying. - if retryAfter > 0 { - waitTime := retryAfter + // Determine wait time before retrying. + // If the server provided a Retry-After duration, use it (clamped to bounds). + // Otherwise, fall back to the configured minimum wait time. + waitTime := c.retryMinWaitTime - // Ensure the wait time is within the defined bounds - if waitTime < c.retryMinWaitTime { - waitTime = c.retryMinWaitTime - } else if waitTime > c.retryMaxWaitTime { - waitTime = c.retryMaxWaitTime - } + if retryAfter > 0 { + waitTime = retryAfter + } - // Sleep for the calculated duration before retrying - time.Sleep(waitTime) + // Ensure the wait time is within the defined bounds + if waitTime < c.retryMinWaitTime { + waitTime = c.retryMinWaitTime + } else if waitTime > c.retryMaxWaitTime { + waitTime = c.retryMaxWaitTime } + + // Sleep for the calculated duration before retrying + time.Sleep(waitTime) } return err @@ -648,6 +656,13 @@ func (c *Client) createRequest(ctx context.Context, method, endpoint string, par } } + // Apply per-request headers (these take priority over client headers) + for name, values := range params.Headers { + for _, value := range values { + req.Header.Set(name, value) + } + } + return req, nil } diff --git a/client_monitor.go b/client_monitor.go index 866e6a4dd..3cb959854 100644 --- a/client_monitor.go +++ b/client_monitor.go @@ -11,6 +11,7 @@ import ( "net/url" "os" "path" + "path/filepath" "strings" ) @@ -49,6 +50,11 @@ func NewMonitorClient(hc *http.Client) (mClient MonitorClient) { mClient.httpClient = &http.Client{} } + // Ensure transport is initialized so SetRootCertificate can configure TLS + if mClient.httpClient.Transport == nil { + mClient.httpClient.Transport = &http.Transport{} + } + mClient.header = make(http.Header) mClient.logger = createLogger() @@ -140,7 +146,13 @@ func (mc *MonitorClient) SetRootCertificate(certPath string) *MonitorClient { transport.TLSClientConfig.RootCAs = x509.NewCertPool() } - transport.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(certPath)) + pem, err := os.ReadFile(filepath.Clean(certPath)) + if err != nil { + mc.logger.Errorf("Failed to read root certificate at %s: %s", certPath, err.Error()) + return mc + } + + transport.TLSClientConfig.RootCAs.AppendCertsFromPEM(pem) return mc } diff --git a/images.go b/images.go index 88dfad5cb..da2287047 100644 --- a/images.go +++ b/images.go @@ -329,6 +329,9 @@ func (c *Client) UploadImageToURL(ctx context.Context, uploadURL string, image i req.Header.Set("User-Agent", c.userAgent) resp, err := clonedClient.Do(req) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } _, err = coupleAPIErrors(resp, err) if err != nil { diff --git a/monitor_alert_definitions.go b/monitor_alert_definitions.go index 3bb09ec9d..7988ab673 100644 --- a/monitor_alert_definitions.go +++ b/monitor_alert_definitions.go @@ -276,8 +276,9 @@ func (c *Client) CreateMonitorAlertDefinitionWithIdempotency( } if idempotencyKey != "" { - c.SetHeader("Idempotency-Key", idempotencyKey) - defer c.header.Del("Idempotency-Key") + params.Headers = http.Header{ + "Idempotency-Key": {idempotencyKey}, + } } err = c.doRequest(ctx, http.MethodPost, e, params, nil) From 1748b19533beac8b09299dc3f04ce70226d8a61d Mon Sep 17 00:00:00 2001 From: ezilber-akamai Date: Wed, 22 Apr 2026 12:15:34 -0400 Subject: [PATCH 7/9] Fixed CoPilot suggestions --- client_test.go | 5 +++-- errors.go | 4 ++-- errors_test.go | 3 +-- test/integration/cache_test.go | 10 ---------- 4 files changed, 6 insertions(+), 16 deletions(-) diff --git a/client_test.go b/client_test.go index 504e3ca20..da4997c17 100644 --- a/client_test.go +++ b/client_test.go @@ -473,8 +473,9 @@ func TestDoRequestLogging_Success(t *testing.T) { params := requestParams{ Response: &map[string]string{}, } + endpoint := "/foo/bar" - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params, nil) + err := client.doRequest(context.Background(), http.MethodGet, endpoint, params, nil) if err != nil { t.Fatal(cmp.Diff(nil, err)) } @@ -482,7 +483,7 @@ func TestDoRequestLogging_Success(t *testing.T) { logInfo := logBuffer.String() expectedRequestParts := []string{ - "GET /v4/" + server.URL + " " + "HTTP/1.1", + "GET /v4/foo/bar HTTP/1.1", "Accept: application/json", "Authorization: Bearer *******************************", "Content-Type: application/json", diff --git a/errors.go b/errors.go index 513ea1db8..db7fc2a34 100644 --- a/errors.go +++ b/errors.go @@ -91,13 +91,13 @@ func coupleAPIErrors(resp *http.Response, err error) (*http.Response, error) { string(bodyBytes), ) - return nil, &Error{Code: resp.StatusCode, Message: msg} + return nil, &Error{Code: resp.StatusCode, Message: msg, Response: resp} } // Must check if there is no list of reasons in the error before making a call to NewError apiError, ok := getAPIError(resp) if !ok { - return nil, NewError(fmt.Errorf("failed to decode response body: %w", err)) + return nil, NewError(fmt.Errorf("failed to decode response body")) } if len(apiError.Errors) == 0 { diff --git a/errors_test.go b/errors_test.go index 45ed16b7e..4a182c36e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "strconv" @@ -51,7 +50,7 @@ func httpError(reason, field string) *http.Response { return &http.Response{ StatusCode: 500, - Body: ioutil.NopCloser(bytes.NewReader(body)), + Body: io.NopCloser(bytes.NewReader(body)), } } diff --git a/test/integration/cache_test.go b/test/integration/cache_test.go index 6f8385997..9468e870c 100644 --- a/test/integration/cache_test.go +++ b/test/integration/cache_test.go @@ -27,16 +27,6 @@ func TestCache_RegionList(t *testing.T) { // Collect request number totalRequests := int64(0) - //client.OnBeforeRequest(func(request *linodego.Request) error { - // fmt.Printf("Request URL: %s\n", request.URL.String()) // Log the URL - // fmt.Printf("Page: %s\n", request.URL.Query().Get("page")) // Log the page query parameter - // if !strings.Contains(request.URL.String(), "regions") || request.URL.Query().Get("page") != "1" { - // return nil - // } - // - // atomic.AddInt64(&totalRequests, 1) - // return nil - //}) client.OnBeforeRequest(func(request *linodego.Request) error { page := request.URL.Query().Get("page") if !strings.Contains(request.URL.String(), "regions") || page != "1" { From 0cebc5892fb86dedd8b819c0deab7803ab88106b Mon Sep 17 00:00:00 2001 From: ezilber-akamai Date: Thu, 23 Apr 2026 10:45:12 -0400 Subject: [PATCH 8/9] Sanitize log output --- client.go | 22 ++++++++++++++++++---- logger.go | 6 ++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index 9e289a803..bedfcf41c 100644 --- a/client.go +++ b/client.go @@ -720,7 +720,9 @@ func (c *Client) logRequest(req *http.Request) *http.Request { _ = c.ErrorAndLogf("failed to log request: %v", e.Error()) } - body, jsonErr := formatBody(reqLog.Body) + sanitizedBody := sanitizeLogValue(reqLog.Body) + + body, jsonErr := formatBody(sanitizedBody) if jsonErr != nil { if c.debug && c.logger != nil { c.logger.Errorf("%v", jsonErr) @@ -736,7 +738,7 @@ func (c *Client) logRequest(req *http.Request) *http.Request { "Body": body, }) if err == nil { - c.logger.Debugf(logBuf.String()) + c.logger.Debugf(sanitizeLogValue(logBuf.String())) } return req @@ -760,6 +762,18 @@ func formatHeaders(headers map[string][]string) string { return strings.TrimSuffix(builder.String(), "\n") } +// sanitizeLogValue removes or escapes control characters that could +// enable log injection (e.g., \r, \n) from a string before it is written +// to a log entry. Uses strings.ReplaceAll so static-analysis tools +// (e.g., CodeQL) can recognize the sanitization. +func sanitizeLogValue(s string) string { + s = strings.ReplaceAll(s, "\r\n", "\\n") + s = strings.ReplaceAll(s, "\r", "\\n") + s = strings.ReplaceAll(s, "\n", "\\n") + + return s +} + func formatBody(body string) (string, error) { body = strings.TrimSpace(body) if body == "null" || body == "nil" || body == "" { @@ -835,7 +849,7 @@ func (c *Client) logResponse(resp *http.Response, start, end time.Time) *http.Re Body: respBody.String(), } - body, jsonErr := formatBody(respLog.Body) + body, jsonErr := formatBody(sanitizeLogValue(respLog.Body)) if jsonErr != nil { if c.debug && c.logger != nil { c.logger.Errorf("%v", jsonErr) @@ -853,7 +867,7 @@ func (c *Client) logResponse(resp *http.Response, start, end time.Time) *http.Re "Body": body, }) if err == nil { - c.logger.Debugf(logBuf.String()) + c.logger.Debugf(sanitizeLogValue(logBuf.String())) } resp.Body = io.NopCloser(bytes.NewReader(respBody.Bytes())) diff --git a/logger.go b/logger.go index 59b947381..b667dd0f2 100644 --- a/logger.go +++ b/logger.go @@ -3,6 +3,7 @@ package linodego import ( "log" "os" + "strings" ) type Logger interface { @@ -35,6 +36,11 @@ func (l *logger) Debugf(format string, v ...any) { } func (l *logger) output(format string, v ...any) { //nolint:goprintffuncname + // Sanitize to prevent log injection via user-controlled values + format = strings.ReplaceAll(format, "\r\n", "\\n") + format = strings.ReplaceAll(format, "\r", "\\n") + format = strings.ReplaceAll(format, "\n", "\\n") + if len(v) == 0 { l.l.Print(format) return From a06bd9056d7fadcb069c937fa0ef5bd78d63c763 Mon Sep 17 00:00:00 2001 From: ezilber-akamai Date: Thu, 23 Apr 2026 16:10:37 -0400 Subject: [PATCH 9/9] Addressed CoPilot comments --- client.go | 2 +- logger.go | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index bedfcf41c..57737701e 100644 --- a/client.go +++ b/client.go @@ -780,7 +780,7 @@ func formatBody(body string) (string, error) { return "", nil } - var jsonData map[string]any + var jsonData any err := json.Unmarshal([]byte(body), &jsonData) if err != nil { diff --git a/logger.go b/logger.go index b667dd0f2..fc9504b50 100644 --- a/logger.go +++ b/logger.go @@ -1,6 +1,7 @@ package linodego import ( + "fmt" "log" "os" "strings" @@ -36,15 +37,18 @@ func (l *logger) Debugf(format string, v ...any) { } func (l *logger) output(format string, v ...any) { //nolint:goprintffuncname - // Sanitize to prevent log injection via user-controlled values - format = strings.ReplaceAll(format, "\r\n", "\\n") - format = strings.ReplaceAll(format, "\r", "\\n") - format = strings.ReplaceAll(format, "\n", "\\n") - + // Render the final message first, then sanitize control characters + // to prevent log injection via both the format string and variadic args. + var msg string if len(v) == 0 { - l.l.Print(format) - return + msg = format + } else { + msg = fmt.Sprintf(format, v...) } - l.l.Printf(format, v...) + msg = strings.ReplaceAll(msg, "\r\n", "\\n") + msg = strings.ReplaceAll(msg, "\r", "\\n") + msg = strings.ReplaceAll(msg, "\n", "\\n") + + l.l.Print(msg) }