diff --git a/client.go b/client.go index 0273d4565..57737701e 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,9 @@ package linodego import ( "bytes" "context" + "crypto/tls" + "crypto/x509" + _ "embed" "encoding/json" "fmt" "io" @@ -14,13 +17,13 @@ import ( "path/filepath" "reflect" "regexp" + "runtime" + "sort" "strconv" "strings" "sync" "text/template" "time" - - "github.com/go-resty/resty/v2" ) const ( @@ -50,20 +53,35 @@ const ( APIDefaultCacheExpiration = time.Minute * 15 ) -//nolint:unused +// Embed the log template files +// +//go:embed request_log_template.tmpl +var requestTemplateStr string + +//go:embed response_log_template.tmpl +var responseTemplateStr string + var ( - reqLogTemplate = template.Must(template.New("request").Parse(`Sending request: -Method: {{.Method}} -URL: {{.URL}} -Headers: {{.Headers}} -Body: {{.Body}}`)) - - respLogTemplate = template.Must(template.New("response").Parse(`Received response: -Status: {{.Status}} -Headers: {{.Headers}} -Body: {{.Body}}`)) + reqLogTemplate = template.Must(template.New("request").Parse(requestTemplateStr)) + respLogTemplate = template.Must(template.New("response").Parse(responseTemplateStr)) ) +type RequestLog struct { + Request string + Host string + Headers http.Header + Body string +} + +type ResponseLog struct { + Status string + Proto string + ReceivedAt string + TimeDuration string + Headers http.Header + Body string +} + var envDebug = false // redactHeadersMap is a map of headers that should be redacted in logs, @@ -72,18 +90,19 @@ var redactHeadersMap = map[string]string{ "Authorization": "Bearer *******************************", } -// Client is a wrapper around the Resty client +// Client is a wrapper around the http client type Client struct { - resty *resty.Client - userAgent string - debug bool - retryConditionals []RetryConditional + httpClient *http.Client + userAgent string + debug bool pollInterval time.Duration baseURL string apiVersion string apiProto string + hostURL string + header http.Header selectedProfile string loadedProfile string @@ -94,6 +113,16 @@ type Client struct { cacheExpiration time.Duration cachedEntries map[string]clientCacheEntry cachedEntryLock *sync.RWMutex + logger Logger + requestLog func(*RequestLog) error + onBeforeRequest []func(*http.Request) error + onAfterResponse []func(*http.Response) error + + retryConditionals []RetryConditional + retryMaxWaitTime time.Duration + retryMinWaitTime time.Duration + retryAfter RetryAfter + retryCount int } type EnvDefaults struct { @@ -110,13 +139,11 @@ type clientCacheEntry struct { } type ( - Request = resty.Request - Response = resty.Response - Logger = resty.Logger + Request = http.Request + Response = http.Response ) func init() { - // Whether we will enable Resty debugging output if apiDebug, ok := os.LookupEnv("LINODE_DEBUG"); ok { if parsed, err := strconv.ParseBool(apiDebug); err == nil { envDebug = parsed @@ -128,22 +155,37 @@ func init() { } // NewClient factory to create new Client struct +// nolint:funlen func NewClient(hc *http.Client) (client Client) { if hc != nil { - client.resty = resty.NewWithClient(hc) + client.httpClient = hc } else { - client.resty = resty.New() + client.httpClient = &http.Client{} + } + + // Ensure that the Header map is not nil + if client.httpClient.Transport == nil { + client.httpClient.Transport = &http.Transport{} } client.shouldCache = true client.cacheExpiration = APIDefaultCacheExpiration client.cachedEntries = make(map[string]clientCacheEntry) client.cachedEntryLock = &sync.RWMutex{} + client.configProfiles = make(map[string]ConfigProfile) + + const ( + retryMinWaitDuration = 100 * time.Millisecond + retryMaxWaitDuration = 2 * time.Second + ) + + client.retryMinWaitTime = retryMinWaitDuration + client.retryMaxWaitTime = retryMaxWaitDuration client.SetUserAgent(DefaultUserAgent) + client.SetLogger(createLogger()) baseURL, baseURLExists := os.LookupEnv(APIHostVar) - if baseURLExists { client.SetBaseURL(baseURL) } @@ -156,17 +198,11 @@ func NewClient(hc *http.Client) (client Client) { } certPath, certPathExists := os.LookupEnv(APIHostCert) - - if certPathExists && !hasCustomTransport(hc) { - cert, err := os.ReadFile(filepath.Clean(certPath)) - if err != nil { - log.Fatalf("[ERROR] Error when reading cert at %s: %s\n", certPath, err.Error()) - } - + if certPathExists { client.SetRootCertificate(certPath) if envDebug { - log.Printf("[DEBUG] Set API root certificate to %s with contents %s\n", certPath, cert) + log.Printf("[DEBUG] Set API root certificate to %s\n", certPath) } } @@ -174,6 +210,7 @@ func NewClient(hc *http.Client) (client Client) { SetRetryWaitTime(APISecondsPerPoll * time.Second). SetPollDelay(APISecondsPerPoll * time.Second). SetRetries(). + SetLogger(createLogger()). SetDebug(envDebug). enableLogSanitization() @@ -213,8 +250,8 @@ func NewClientFromEnv(hc *http.Client) (*Client, error) { client.selectedProfile = configProfile // We should only load the config if the config file exists - if _, err = os.Stat(configPath); err != nil { - return nil, fmt.Errorf("error loading config file %s: %w", configPath, err) + if _, statErr := os.Stat(configPath); statErr != nil { + return nil, fmt.Errorf("error loading config file %s: %w", configPath, statErr) } err = client.preLoadConfig(configPath) @@ -225,42 +262,284 @@ func NewClientFromEnv(hc *http.Client) (*Client, error) { // SetUserAgent sets a custom user-agent for HTTP requests func (c *Client) SetUserAgent(ua string) *Client { c.userAgent = ua - c.resty.SetHeader("User-Agent", c.userAgent) + c.SetHeader("User-Agent", c.userAgent) return c } -type RequestParams struct { - Body any +type requestParams struct { + Body *bytes.Reader Response any + // Headers are per-request headers that will be applied only to + // the individual request, not stored on the shared client state. + Headers http.Header +} + +func (c *Client) ErrorAndLogf(format string, args ...any) error { + if c.debug && c.logger != nil { + c.logger.Errorf(format, args...) + } + + return fmt.Errorf(format, args...) +} + +// SetRootCertificate adds a root certificate to the underlying TLS client config +func (c *Client) SetRootCertificate(certPath string) *Client { + config, err := c.tlsConfig() + if err != nil { + log.Println("[WARN] Custom transport is not allowed with a custom root CA") + return c + } + + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + + pem, err := os.ReadFile(filepath.Clean(certPath)) + if err != nil { + log.Printf("[ERROR] Failed to read root certificate at %s: %s\n", certPath, err.Error()) + return c + } + + config.RootCAs.AppendCertsFromPEM(pem) + + return c +} + +// SetToken sets the API token for all requests from this client +// Only necessary if you haven't already provided the http client to NewClient() configured with the token. +func (c *Client) SetToken(token string) *Client { + c.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) + return c +} + +// SetRetries adds retry conditions for "Linode Busy." errors and 429s. +func (c *Client) SetRetries() *Client { + c. + AddRetryCondition(LinodeBusyRetryCondition). + AddRetryCondition(TooManyRequestsRetryCondition). + AddRetryCondition(ServiceUnavailableRetryCondition). + AddRetryCondition(RequestTimeoutRetryCondition). + AddRetryCondition(RequestGOAWAYRetryCondition). + AddRetryCondition(RequestNGINXRetryCondition). + SetRetryMaxWaitTime(APIRetryMaxWaitTime) + ConfigureRetries(c) + + return c +} + +// AddRetryCondition adds a RetryConditional function to the Client +func (c *Client) AddRetryCondition(retryCondition RetryConditional) *Client { + c.retryConditionals = append(c.retryConditionals, retryCondition) + + return c +} + +func (c *Client) SetDebug(debug bool) *Client { + c.debug = debug + + return c +} + +func (c *Client) SetLogger(logger Logger) *Client { + c.logger = logger + + return c +} + +func (c *Client) OnBeforeRequest(m func(*http.Request) error) { + c.onBeforeRequest = append(c.onBeforeRequest, m) +} + +func (c *Client) OnAfterResponse(m func(*http.Response) error) { + c.onAfterResponse = append(c.onAfterResponse, m) +} + +// UseURL parses the individual components of the given API URL and configures the client +// accordingly. For example, a valid URL. +// For example: +// +// client.UseURL("https://api.test.linode.com/v4beta") +func (c *Client) UseURL(apiURL string) (*Client, error) { + parsedURL, err := url.Parse(apiURL) + if err != nil { + return nil, fmt.Errorf("failed to parse URL: %w", err) + } + + if parsedURL.Scheme == "" || parsedURL.Host == "" { + return nil, fmt.Errorf("need both scheme and host in API URL, got %q", apiURL) + } + + // Create a new URL excluding the path to use as the base URL + baseURL := &url.URL{ + Host: parsedURL.Host, + Scheme: parsedURL.Scheme, + } + + c.SetBaseURL(baseURL.String()) + + versionMatches := regexp.MustCompile(`/v[a-zA-Z0-9]+`).FindAllString(parsedURL.Path, -1) + + // Only set the version if a version is found in the URL, else use the default + if len(versionMatches) > 0 { + c.SetAPIVersion( + strings.Trim(versionMatches[len(versionMatches)-1], "/"), + ) + } + + return c, nil +} + +func (c *Client) SetBaseURL(baseURL string) *Client { + baseURLPath, _ := url.Parse(baseURL) + + c.baseURL = path.Join(baseURLPath.Host, baseURLPath.Path) + c.apiProto = baseURLPath.Scheme + + c.updateHostURL() + + return c +} + +// SetAPIVersion sets the version of the API to interface with +func (c *Client) SetAPIVersion(apiVersion string) *Client { + c.apiVersion = apiVersion + + c.updateHostURL() + + return c +} + +// InvalidateCache clears all cached responses for all endpoints. +func (c *Client) InvalidateCache() { + c.cachedEntryLock.Lock() + defer c.cachedEntryLock.Unlock() + + // GC will handle the old map + c.cachedEntries = make(map[string]clientCacheEntry) +} + +// InvalidateCacheEndpoint invalidates a single cached endpoint. +func (c *Client) InvalidateCacheEndpoint(endpoint string) error { + u, err := url.Parse(endpoint) + if err != nil { + return fmt.Errorf("failed to parse URL for caching: %w", err) + } + + c.cachedEntryLock.Lock() + defer c.cachedEntryLock.Unlock() + + delete(c.cachedEntries, u.Path) + + return nil +} + +// SetGlobalCacheExpiration sets the desired time for any cached response +// to be valid for. +func (c *Client) SetGlobalCacheExpiration(expiryTime time.Duration) { + c.cacheExpiration = expiryTime +} + +// UseCache sets whether response caching should be used +func (c *Client) UseCache(value bool) { + c.shouldCache = value +} + +// SetRetryMaxWaitTime sets the maximum delay before retrying a request. +func (c *Client) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Client { + c.retryMaxWaitTime = maxWaitTime + return c +} + +// SetRetryWaitTime sets the default (minimum) delay before retrying a request. +func (c *Client) SetRetryWaitTime(minWaitTime time.Duration) *Client { + c.retryMinWaitTime = minWaitTime + return c +} + +// SetRetryAfter sets the callback function to be invoked with a failed request +// to determine wben it should be retried. +func (c *Client) SetRetryAfter(callback RetryAfter) *Client { + c.retryAfter = callback + return c +} + +// SetRetryCount sets the maximum retry attempts before aborting. +func (c *Client) SetRetryCount(count int) *Client { + c.retryCount = count + return c +} + +// SetPollDelay sets the number of milliseconds to wait between events or status polls. +// Affects all WaitFor* functions and retries. +func (c *Client) SetPollDelay(delay time.Duration) *Client { + c.pollInterval = delay + return c +} + +// GetPollDelay gets the number of milliseconds to wait between events or status polls. +// Affects all WaitFor* functions and retries. +func (c *Client) GetPollDelay() time.Duration { + return c.pollInterval +} + +// SetHeader sets a custom header to be used in all API requests made with the current +// client. +// NOTE: Some headers may be overridden by the individual request functions. +func (c *Client) SetHeader(name, value string) { + if c.header == nil { + c.header = make(http.Header) // Initialize header if nil + } + + c.header.Set(name, value) +} + +func (c *Client) Transport() (*http.Transport, error) { + if transport, ok := c.httpClient.Transport.(*http.Transport); ok { + return transport, nil + } + + return nil, fmt.Errorf("current transport is not an *http.Transport instance") } // Generic helper to execute HTTP requests using the net/http package // -// nolint:unused, funlen, gocognit -func (c *httpClient) doRequest(ctx context.Context, method, url string, params RequestParams) error { +// nolint:funlen, gocognit, nestif +func (c *Client) doRequest(ctx context.Context, method, endpoint string, params requestParams, paginationMutator *func(*http.Request) error) error { var ( - req *http.Request - bodyBuffer *bytes.Buffer - resp *http.Response - err error + req *http.Request + resp *http.Response + err error ) - for range httpDefaultRetryCount { - req, bodyBuffer, err = c.createRequest(ctx, method, url, params) + for range c.retryCount { + // Reset the body to the start for each retry if it's not nil + if params.Body != nil { + if _, seekErr := params.Body.Seek(0, io.SeekStart); seekErr != nil { + return c.ErrorAndLogf("failed to seek to the start of the body: %v", seekErr.Error()) + } + } + + req, err = c.createRequest(ctx, method, endpoint, params) if err != nil { return err } + if paginationMutator != nil { + if mutErr := (*paginationMutator)(req); mutErr != nil { + return c.ErrorAndLogf("failed to mutate before request: %v", mutErr.Error()) + } + } + if err = c.applyBeforeRequest(req); err != nil { return err } if c.debug && c.logger != nil { - c.logRequest(req, method, url, bodyBuffer) + req = c.logRequest(req) } - processResponse := func() error { + processResponse := func(start, end time.Time) error { defer func() { closeErr := resp.Body.Close() if closeErr != nil && err == nil { @@ -273,12 +552,7 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R } if c.debug && c.logger != nil { - var logErr error - - resp, logErr = c.logResponse(resp) - if logErr != nil { - return logErr - } + resp = c.logResponse(resp, start, end) } if params.Response != nil { @@ -295,9 +569,12 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R return nil } + startTime := time.Now() resp, err = c.sendRequest(req) + endTime := time.Now() + if err == nil { - if err = processResponse(); err == nil { + if err = processResponse(startTime, endTime); err == nil { return nil } } @@ -311,19 +588,33 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R return retryErr } - // Sleep for the specified duration before retrying. - // If retryAfter is 0 (i.e., Retry-After header is not found), - // no delay is applied. - time.Sleep(retryAfter) + // Determine wait time before retrying. + // If the server provided a Retry-After duration, use it (clamped to bounds). + // Otherwise, fall back to the configured minimum wait time. + waitTime := c.retryMinWaitTime + + if retryAfter > 0 { + waitTime = retryAfter + } + + // Ensure the wait time is within the defined bounds + if waitTime < c.retryMinWaitTime { + waitTime = c.retryMinWaitTime + } else if waitTime > c.retryMaxWaitTime { + waitTime = c.retryMaxWaitTime + } + + // Sleep for the calculated duration before retrying + time.Sleep(waitTime) } return err } -// nolint:unused -func (c *httpClient) shouldRetry(resp *http.Response, err error) bool { +func (c *Client) shouldRetry(resp *http.Response, err error) bool { for _, retryConditional := range c.retryConditionals { if retryConditional(resp, err) { + log.Printf("[INFO] Received error %v - Retrying", err) return true } } @@ -331,35 +622,26 @@ func (c *httpClient) shouldRetry(resp *http.Response, err error) bool { return false } -// nolint:unused -func (c *httpClient) createRequest(ctx context.Context, method, url string, params RequestParams) (*http.Request, *bytes.Buffer, error) { - var ( - bodyReader io.Reader - bodyBuffer *bytes.Buffer - ) +func (c *Client) createRequest(ctx context.Context, method, endpoint string, params requestParams) (*http.Request, error) { + var bodyReader io.Reader if params.Body != nil { - bodyBuffer = new(bytes.Buffer) - if err := json.NewEncoder(bodyBuffer).Encode(params.Body); err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to encode body: %v", err) - } - - return nil, nil, fmt.Errorf("failed to encode body: %w", err) + // Reset the body position to the start before using it + _, err := params.Body.Seek(0, io.SeekStart) + if err != nil { + return nil, c.ErrorAndLogf("failed to seek to the start of the body: %v", err.Error()) } - bodyReader = bodyBuffer + bodyReader = params.Body } - req, err := http.NewRequestWithContext(ctx, method, url, bodyReader) + req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s/%s", strings.TrimRight(c.hostURL, "/"), + strings.TrimLeft(endpoint, "/")), bodyReader) if err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to create request: %v", err) - } - - return nil, nil, fmt.Errorf("failed to create request: %w", err) + return nil, c.ErrorAndLogf("failed to create request: %v", err.Error()) } + // Set the default headers req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") @@ -367,40 +649,43 @@ func (c *httpClient) createRequest(ctx context.Context, method, url string, para req.Header.Set("User-Agent", c.userAgent) } - return req, bodyBuffer, nil + // Set additional headers added to the client + for name, values := range c.header { + for _, value := range values { + req.Header.Set(name, value) + } + } + + // Apply per-request headers (these take priority over client headers) + for name, values := range params.Headers { + for _, value := range values { + req.Header.Set(name, value) + } + } + + return req, nil } -// nolint:unused -func (c *httpClient) applyBeforeRequest(req *http.Request) error { +func (c *Client) applyBeforeRequest(req *http.Request) error { for _, mutate := range c.onBeforeRequest { if err := mutate(req); err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to mutate before request: %v", err) - } - - return fmt.Errorf("failed to mutate before request: %w", err) + return c.ErrorAndLogf("failed to mutate before request: %v", err.Error()) } } return nil } -// nolint:unused -func (c *httpClient) applyAfterResponse(resp *http.Response) error { +func (c *Client) applyAfterResponse(resp *http.Response) error { for _, mutate := range c.onAfterResponse { if err := mutate(resp); err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to mutate after response: %v", err) - } - - return fmt.Errorf("failed to mutate after response: %w", err) + return c.ErrorAndLogf("failed to mutate after response: %v", err.Error()) } } return nil } -// nolint:unused func redactHeaders(headers http.Header) http.Header { redacted := headers.Clone() @@ -413,333 +698,224 @@ func redactHeaders(headers http.Header) http.Header { return redacted } -// nolint:unused -func (c *httpClient) logRequest(req *http.Request, method, url string, bodyBuffer *bytes.Buffer) { - var reqBody string - if bodyBuffer != nil { - reqBody = bodyBuffer.String() - } else { - reqBody = "nil" - } - - var logBuf bytes.Buffer +func (c *Client) logRequest(req *http.Request) *http.Request { + var reqBody bytes.Buffer + if req.Body != nil { + if _, err := io.Copy(&reqBody, req.Body); err != nil { + c.logger.Errorf("failed to read request body: %v", err) + } - err := reqLogTemplate.Execute(&logBuf, map[string]any{ - "Method": method, - "URL": url, - "Headers": redactHeaders(req.Header), - "Body": reqBody, - }) - if err == nil { - c.logger.Debugf(logBuf.String()) + req.Body = io.NopCloser(bytes.NewReader(reqBody.Bytes())) } -} -// nolint:unused -func (c *httpClient) sendRequest(req *http.Request) (*http.Response, error) { - // #nosec G704 - resp, err := c.httpClient.Do(req) - if err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to send request: %v", err) - } + reqLog := &RequestLog{ + Request: strings.Join([]string{req.Method, req.URL.Path, req.Proto}, " "), + Host: req.Host, + Headers: redactHeaders(req.Header.Clone()), + Body: reqBody.String(), + } - return nil, fmt.Errorf("failed to send request: %w", err) + e := c.requestLog(reqLog) + if e != nil { + _ = c.ErrorAndLogf("failed to log request: %v", e.Error()) } - return resp, nil -} + sanitizedBody := sanitizeLogValue(reqLog.Body) -// nolint:unused -func (c *httpClient) checkHTTPError(resp *http.Response) error { - _, err := coupleAPIErrorsHTTP(resp, nil) - if err != nil { + body, jsonErr := formatBody(sanitizedBody) + if jsonErr != nil { if c.debug && c.logger != nil { - c.logger.Errorf("received HTTP error: %v", err) + c.logger.Errorf("%v", jsonErr) } - - return err - } - - return nil -} - -// nolint:unused -func (c *httpClient) logResponse(resp *http.Response) (*http.Response, error) { - var respBody bytes.Buffer - if _, err := io.Copy(&respBody, resp.Body); err != nil { - c.logger.Errorf("failed to read response body: %v", err) } var logBuf bytes.Buffer - err := respLogTemplate.Execute(&logBuf, map[string]any{ - "Status": resp.Status, - "Headers": redactHeaders(resp.Header), - "Body": respBody.String(), + err := reqLogTemplate.Execute(&logBuf, map[string]any{ + "Request": reqLog.Request, + "Host": reqLog.Host, + "Headers": formatHeaders(reqLog.Headers), + "Body": body, }) if err == nil { - c.logger.Debugf(logBuf.String()) + c.logger.Debugf(sanitizeLogValue(logBuf.String())) } - resp.Body = io.NopCloser(bytes.NewReader(respBody.Bytes())) - - return resp, nil + return req } -// nolint:unused -func (c *httpClient) decodeResponseBody(resp *http.Response, response any) error { - if err := json.NewDecoder(resp.Body).Decode(response); err != nil { - if c.debug && c.logger != nil { - c.logger.Errorf("failed to decode response: %v", err) - } +func formatHeaders(headers map[string][]string) string { + var builder strings.Builder + builder.WriteString("\n") - return fmt.Errorf("failed to decode response: %w", err) + keys := make([]string, 0, len(headers)) + for key := range headers { + keys = append(keys, key) } - return nil -} + sort.Strings(keys) -// R wraps resty's R method -func (c *Client) R(ctx context.Context) *resty.Request { - return c.resty.R(). - ExpectContentType("application/json"). - SetHeader("Content-Type", "application/json"). - SetContext(ctx). - SetError(APIError{}) -} - -// SetDebug sets the debug on resty's client -func (c *Client) SetDebug(debug bool) *Client { - c.debug = debug - c.resty.SetDebug(debug) + for _, key := range keys { + builder.WriteString(fmt.Sprintf(" %s: %s\n", key, strings.Join(headers[key], ", "))) + } - return c + return strings.TrimSuffix(builder.String(), "\n") } -// SetLogger allows the user to override the output -// logger for debug logs. -func (c *Client) SetLogger(logger Logger) *Client { - c.resty.SetLogger(logger) +// sanitizeLogValue removes or escapes control characters that could +// enable log injection (e.g., \r, \n) from a string before it is written +// to a log entry. Uses strings.ReplaceAll so static-analysis tools +// (e.g., CodeQL) can recognize the sanitization. +func sanitizeLogValue(s string) string { + s = strings.ReplaceAll(s, "\r\n", "\\n") + s = strings.ReplaceAll(s, "\r", "\\n") + s = strings.ReplaceAll(s, "\n", "\\n") - return c + return s } -//nolint:unused -func (c *httpClient) httpSetDebug(debug bool) *httpClient { - c.debug = debug +func formatBody(body string) (string, error) { + body = strings.TrimSpace(body) + if body == "null" || body == "nil" || body == "" { + return "", nil + } - return c -} + var jsonData any -//nolint:unused -func (c *httpClient) httpSetLogger(logger httpLogger) *httpClient { - c.logger = logger + err := json.Unmarshal([]byte(body), &jsonData) + if err != nil { + return "", fmt.Errorf("error unmarshalling JSON: %w", err) + } - return c -} + prettyJSON, err := json.MarshalIndent(jsonData, "", " ") + if err != nil { + return "", fmt.Errorf("error marshalling JSON: %w", err) + } -// OnBeforeRequest adds a handler to the request body to run before the request is sent -func (c *Client) OnBeforeRequest(m func(request *Request) error) { - c.resty.OnBeforeRequest(func(_ *resty.Client, req *resty.Request) error { - return m(req) - }) + return "\n" + string(prettyJSON), nil } -// OnAfterResponse adds a handler to the request body to run before the request is sent -func (c *Client) OnAfterResponse(m func(response *Response) error) { - c.resty.OnAfterResponse(func(_ *resty.Client, req *resty.Response) error { - return m(req) - }) -} +func formatDate(dateStr string) (string, error) { + parsedTime, err := time.Parse(time.RFC1123, dateStr) + if err != nil { + return "", fmt.Errorf("error parsing date: %v", err) + } -// nolint:unused -func (c *httpClient) httpOnBeforeRequest(m func(*http.Request) error) *httpClient { - c.onBeforeRequest = append(c.onBeforeRequest, m) + formattedDate := parsedTime.In(time.Local).Format("2006-01-02T15:04:05-07:00") // nolint:gosmopolitan - return c + return formattedDate, nil } -// nolint:unused -func (c *httpClient) httpOnAfterResponse(m func(*http.Response) error) *httpClient { - c.onAfterResponse = append(c.onAfterResponse, m) +func (c *Client) sendRequest(req *http.Request) (*http.Response, error) { + resp, err := c.httpClient.Do(req) //#nosec G704 // URL is constructed from client-configured base URL + endpoint + if err != nil { + return nil, c.ErrorAndLogf("failed to send request: %w", err) + } - return c + return resp, nil } -// UseURL parses the individual components of the given API URL and configures the client -// accordingly. For example, a valid URL. -// For example: -// -// client.UseURL("https://api.test.linode.com/v4beta") -func (c *Client) UseURL(apiURL string) (*Client, error) { - parsedURL, err := url.Parse(apiURL) +func (c *Client) checkHTTPError(resp *http.Response) error { + _, err := coupleAPIErrors(resp, nil) if err != nil { - return nil, fmt.Errorf("failed to parse URL: %w", err) + _ = c.ErrorAndLogf("received HTTP error: %v", err.Error()) + return err } - if parsedURL.Scheme == "" || parsedURL.Host == "" { - return nil, fmt.Errorf("need both scheme and host in API URL, got %q", apiURL) - } + return nil +} - // Create a new URL excluding the path to use as the base URL - baseURL := &url.URL{ - Host: parsedURL.Host, - Scheme: parsedURL.Scheme, +func (c *Client) logResponse(resp *http.Response, start, end time.Time) *http.Response { + var respBody bytes.Buffer + if _, err := io.Copy(&respBody, resp.Body); err != nil { + c.logger.Errorf("failed to read response body: %v", err) } - c.SetBaseURL(baseURL.String()) - - versionMatches := regexp.MustCompile(`/v[a-zA-Z0-9]+`).FindAllString(parsedURL.Path, -1) - - // Only set the version if a version is found in the URL, else use the default - if len(versionMatches) > 0 { - c.SetAPIVersion( - strings.Trim(versionMatches[len(versionMatches)-1], "/"), - ) + receivedAt, dateErr := formatDate(resp.Header.Get("Date")) + if dateErr != nil { + if c.debug && c.logger != nil { + c.logger.Errorf("failed to format date: %v", dateErr) + } } - return c, nil -} - -// SetBaseURL sets the base URL of the Linode v4 API (https://api.linode.com/v4) -func (c *Client) SetBaseURL(baseURL string) *Client { - baseURLPath, _ := url.Parse(baseURL) - - c.baseURL = path.Join(baseURLPath.Host, baseURLPath.Path) - c.apiProto = baseURLPath.Scheme - - c.updateHostURL() - - return c -} - -// SetAPIVersion sets the version of the API to interface with -func (c *Client) SetAPIVersion(apiVersion string) *Client { - c.apiVersion = apiVersion + duration := end.Sub(start).String() - c.updateHostURL() - - return c -} - -// SetRootCertificate adds a root certificate to the underlying TLS client config -func (c *Client) SetRootCertificate(path string) *Client { - c.resty.SetRootCertificate(path) - return c -} - -// SetToken sets the API token for all requests from this client -// Only necessary if you haven't already provided the http client to NewClient() configured with the token. -func (c *Client) SetToken(token string) *Client { - c.resty.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) - return c -} + respLog := &ResponseLog{ + Status: resp.Status, + Proto: resp.Proto, + ReceivedAt: receivedAt, + TimeDuration: duration, + Headers: resp.Header, + Body: respBody.String(), + } -// SetRetries adds retry conditions for "Linode Busy." errors and 429s. -func (c *Client) SetRetries() *Client { - c. - addRetryConditional(linodeBusyRetryCondition). - addRetryConditional(tooManyRequestsRetryCondition). - addRetryConditional(serviceUnavailableRetryCondition). - addRetryConditional(requestTimeoutRetryCondition). - addRetryConditional(requestGOAWAYRetryCondition). - addRetryConditional(requestNGINXRetryCondition). - SetRetryMaxWaitTime(APIRetryMaxWaitTime) - configureRetries(c) + body, jsonErr := formatBody(sanitizeLogValue(respLog.Body)) + if jsonErr != nil { + if c.debug && c.logger != nil { + c.logger.Errorf("%v", jsonErr) + } + } - return c -} + var logBuf bytes.Buffer -// AddRetryCondition adds a RetryConditional function to the Client -func (c *Client) AddRetryCondition(retryCondition RetryConditional) *Client { - c.resty.AddRetryCondition(resty.RetryConditionFunc(retryCondition)) - return c -} + err := respLogTemplate.Execute(&logBuf, map[string]any{ + "Status": respLog.Status, + "Proto": respLog.Proto, + "ReceivedAt": respLog.ReceivedAt, + "TimeDuration": respLog.TimeDuration, + "Headers": formatHeaders(redactHeaders(respLog.Headers)), + "Body": body, + }) + if err == nil { + c.logger.Debugf(sanitizeLogValue(logBuf.String())) + } -// InvalidateCache clears all cached responses for all endpoints. -func (c *Client) InvalidateCache() { - c.cachedEntryLock.Lock() - defer c.cachedEntryLock.Unlock() + resp.Body = io.NopCloser(bytes.NewReader(respBody.Bytes())) - // GC will handle the old map - c.cachedEntries = make(map[string]clientCacheEntry) + return resp } -// InvalidateCacheEndpoint invalidates a single cached endpoint. -func (c *Client) InvalidateCacheEndpoint(endpoint string) error { - u, err := url.Parse(endpoint) - if err != nil { - return fmt.Errorf("failed to parse URL for caching: %w", err) +func (c *Client) decodeResponseBody(resp *http.Response, response any) error { + if err := json.NewDecoder(resp.Body).Decode(response); err != nil { + return c.ErrorAndLogf("failed to decode response: %v", err.Error()) } - c.cachedEntryLock.Lock() - defer c.cachedEntryLock.Unlock() - - delete(c.cachedEntries, u.Path) - return nil } -// SetGlobalCacheExpiration sets the desired time for any cached response -// to be valid for. -func (c *Client) SetGlobalCacheExpiration(expiryTime time.Duration) { - c.cacheExpiration = expiryTime -} - -// UseCache sets whether response caching should be used -func (c *Client) UseCache(value bool) { - c.shouldCache = value -} - -// SetRetryMaxWaitTime sets the maximum delay before retrying a request. -func (c *Client) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Client { - c.resty.SetRetryMaxWaitTime(maxWaitTime) - return c -} +func (c *Client) updateHostURL() { + apiProto := APIProto + baseURL := APIHost + apiVersion := APIVersion -// SetRetryWaitTime sets the default (minimum) delay before retrying a request. -func (c *Client) SetRetryWaitTime(minWaitTime time.Duration) *Client { - c.resty.SetRetryWaitTime(minWaitTime) - return c -} + if c.baseURL != "" { + baseURL = c.baseURL + } -// SetRetryAfter sets the callback function to be invoked with a failed request -// to determine wben it should be retried. -func (c *Client) SetRetryAfter(callback RetryAfter) *Client { - c.resty.SetRetryAfter(resty.RetryAfterFunc(callback)) - return c -} + if c.apiVersion != "" { + apiVersion = c.apiVersion + } -// SetRetryCount sets the maximum retry attempts before aborting. -func (c *Client) SetRetryCount(count int) *Client { - c.resty.SetRetryCount(count) - return c -} + if c.apiProto != "" { + apiProto = c.apiProto + } -// SetPollDelay sets the number of milliseconds to wait between events or status polls. -// Affects all WaitFor* functions and retries. -func (c *Client) SetPollDelay(delay time.Duration) *Client { - c.pollInterval = delay - return c + c.hostURL = strings.TrimRight(fmt.Sprintf("%s://%s/%s", apiProto, baseURL, url.PathEscape(apiVersion)), "/") } -// GetPollDelay gets the number of milliseconds to wait between events or status polls. -// Affects all WaitFor* functions and retries. -func (c *Client) GetPollDelay() time.Duration { - return c.pollInterval -} +func (c *Client) tlsConfig() (*tls.Config, error) { + transport, err := c.Transport() + if err != nil { + return nil, err + } -// SetHeader sets a custom header to be used in all API requests made with the current -// client. -// NOTE: Some headers may be overridden by the individual request functions. -func (c *Client) SetHeader(name, value string) { - c.resty.SetHeader(name, value) -} + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + } -func (c *Client) addRetryConditional(retryConditional RetryConditional) *Client { - c.retryConditionals = append(c.retryConditionals, retryConditional) - return c + return transport.TLSClientConfig, nil } func (c *Client) addCachedResponse(endpoint string, response any, expiry *time.Duration) { @@ -819,49 +995,25 @@ func (c *Client) getCachedResponse(endpoint string) any { return c.cachedEntries[endpoint].Data } -func (c *Client) updateHostURL() { - apiProto := APIProto - baseURL := APIHost - apiVersion := APIVersion - - if c.baseURL != "" { - baseURL = c.baseURL +func (c *Client) onRequestLog(rl func(*RequestLog) error) *Client { + if c.requestLog != nil { + c.logger.Warnf("Overwriting an existing on-request-log callback from=%s to=%s", + functionName(c.requestLog), functionName(rl)) } - if c.apiVersion != "" { - apiVersion = c.apiVersion - } + c.requestLog = rl - if c.apiProto != "" { - apiProto = c.apiProto - } - - c.resty.SetBaseURL( - fmt.Sprintf( - "%s://%s/%s", - apiProto, - baseURL, - url.PathEscape(apiVersion), - ), - ) + return c } -func redactLogHeaders(header http.Header) { - for h, redactedValue := range redactHeadersMap { - if header.Get(h) != "" { - header.Set(h, redactedValue) - } - } +func functionName(i any) string { + return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() } func (c *Client) enableLogSanitization() *Client { - c.resty.OnRequestLog(func(r *resty.RequestLog) error { - redactLogHeaders(r.Header) - return nil - }) - - c.resty.OnResponseLog(func(r *resty.ResponseLog) error { - redactLogHeaders(r.Header) + c.onRequestLog(func(r *RequestLog) error { + // masking authorization header + r.Headers.Set("Authorization", "Bearer *******************************") return nil }) @@ -958,16 +1110,3 @@ func generateListCacheURL(endpoint string, opts *ListOptions) (string, error) { return fmt.Sprintf("%s:%s", endpoint, hashedOpts), nil } - -func hasCustomTransport(hc *http.Client) bool { - if hc == nil || hc.Transport == nil { - return false - } - - if _, ok := hc.Transport.(*http.Transport); !ok { - log.Println("[WARN] Custom transport is not allowed with a custom root CA.") - return true - } - - return false -} diff --git a/client_http.go b/client_http.go deleted file mode 100644 index 7f16362c5..000000000 --- a/client_http.go +++ /dev/null @@ -1,56 +0,0 @@ -package linodego - -import ( - "net/http" - "sync" - "time" -) - -// Client is a wrapper around the Resty client -// -//nolint:unused -type httpClient struct { - //nolint:unused - httpClient *http.Client - //nolint:unused - userAgent string - //nolint:unused - debug bool - //nolint:unused - retryConditionals []httpRetryConditional - //nolint:unused - retryAfter httpRetryAfter - - //nolint:unused - pollInterval time.Duration - - //nolint:unused - baseURL string - //nolint:unused - apiVersion string - //nolint:unused - apiProto string - //nolint:unused - selectedProfile string - //nolint:unused - loadedProfile string - - //nolint:unused - configProfiles map[string]ConfigProfile - - // Fields for caching endpoint responses - //nolint:unused - shouldCache bool - //nolint:unused - cacheExpiration time.Duration - //nolint:unused - cachedEntries map[string]clientCacheEntry - //nolint:unused - cachedEntryLock *sync.RWMutex - //nolint:unused - logger httpLogger - //nolint:unused - onBeforeRequest []func(*http.Request) error - //nolint:unused - onAfterResponse []func(*http.Response) error -} diff --git a/client_monitor.go b/client_monitor.go index b54dc8284..3cb959854 100644 --- a/client_monitor.go +++ b/client_monitor.go @@ -2,13 +2,17 @@ package linodego import ( "context" + "crypto/tls" + "crypto/x509" + "encoding/json" "fmt" + "io" "net/http" "net/url" "os" "path" - - "github.com/go-resty/resty/v2" + "path/filepath" + "strings" ) const ( @@ -24,25 +28,36 @@ const ( MonitorAPIEnvVar = "MONITOR_API_TOKEN" ) -// MonitorClient is a wrapper around the Resty client +// MonitorClient is a wrapper around the http client type MonitorClient struct { - resty *resty.Client + httpClient *http.Client debug bool apiBaseURL string apiProtocol string apiVersion string + hostURL string userAgent string + header http.Header + logger Logger } // NewMonitorClient is the entry point for user to create a new MonitorClient // It utilizes default values and looks for environment variables to initialize a MonitorClient. func NewMonitorClient(hc *http.Client) (mClient MonitorClient) { if hc != nil { - mClient.resty = resty.NewWithClient(hc) + mClient.httpClient = hc } else { - mClient.resty = resty.New() + mClient.httpClient = &http.Client{} + } + + // Ensure transport is initialized so SetRootCertificate can configure TLS + if mClient.httpClient.Transport == nil { + mClient.httpClient.Transport = &http.Transport{} } + mClient.header = make(http.Header) + mClient.logger = createLogger() + mClient.SetUserAgent(DefaultUserAgent) baseURL, baseURLExists := os.LookupEnv(MonitorAPIHostVar) @@ -72,24 +87,14 @@ func NewMonitorClient(hc *http.Client) (mClient MonitorClient) { // SetUserAgent sets a custom user-agent for HTTP requests func (mc *MonitorClient) SetUserAgent(ua string) *MonitorClient { mc.userAgent = ua - mc.resty.SetHeader("User-Agent", mc.userAgent) + mc.header.Set("User-Agent", ua) return mc } -// R wraps resty's R method -func (mc *MonitorClient) R(ctx context.Context) *resty.Request { - return mc.resty.R(). - ExpectContentType("application/json"). - SetHeader("Content-Type", "application/json"). - SetContext(ctx). - SetError(APIError{}) -} - -// SetDebug sets the debug on resty's client +// SetDebug sets the debug on the client func (mc *MonitorClient) SetDebug(debug bool) *MonitorClient { mc.debug = debug - mc.resty.SetDebug(debug) return mc } @@ -97,7 +102,7 @@ func (mc *MonitorClient) SetDebug(debug bool) *MonitorClient { // SetLogger allows the user to override the output // logger for debug logs. func (mc *MonitorClient) SetLogger(logger Logger) *MonitorClient { - mc.resty.SetLogger(logger) + mc.logger = logger return mc } @@ -124,21 +129,44 @@ func (mc *MonitorClient) SetAPIVersion(apiVersion string) *MonitorClient { } // SetRootCertificate adds a root certificate to the underlying TLS client config -func (mc *MonitorClient) SetRootCertificate(path string) *MonitorClient { - mc.resty.SetRootCertificate(path) +func (mc *MonitorClient) SetRootCertificate(certPath string) *MonitorClient { + transport, ok := mc.httpClient.Transport.(*http.Transport) + if !ok { + mc.logger.Errorf("current transport is not an *http.Transport instance") + return mc + } + + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + } + + if transport.TLSClientConfig.RootCAs == nil { + transport.TLSClientConfig.RootCAs = x509.NewCertPool() + } + + pem, err := os.ReadFile(filepath.Clean(certPath)) + if err != nil { + mc.logger.Errorf("Failed to read root certificate at %s: %s", certPath, err.Error()) + return mc + } + + transport.TLSClientConfig.RootCAs.AppendCertsFromPEM(pem) + return mc } // SetToken sets the API token for all requests from this client func (mc *MonitorClient) SetToken(token string) *MonitorClient { - mc.resty.SetHeader("Authorization", fmt.Sprintf("Bearer %s", token)) + mc.header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) return mc } // SetHeader sets a custom header to be used in all API requests made with the current client. // NOTE: Some headers may be overridden by the individual request functions. func (mc *MonitorClient) SetHeader(name, value string) { - mc.resty.SetHeader(name, value) + mc.header.Set(name, value) } func (mc *MonitorClient) updateMonitorHostURL() { @@ -158,12 +186,66 @@ func (mc *MonitorClient) updateMonitorHostURL() { apiProto = mc.apiProtocol } - mc.resty.SetBaseURL( - fmt.Sprintf( - "%s://%s/%s", - apiProto, - baseURL, - url.PathEscape(apiVersion), - ), + mc.hostURL = fmt.Sprintf( + "%s://%s/%s", + apiProto, + baseURL, + url.PathEscape(apiVersion), ) } + +// doRequest is a generic helper to execute HTTP requests for the MonitorClient +func (mc *MonitorClient) doRequest(ctx context.Context, method, endpoint string, params requestParams) error { + var bodyReader io.Reader + + if params.Body != nil { + if _, err := params.Body.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek body: %w", err) + } + + bodyReader = params.Body + } + + reqURL := fmt.Sprintf("%s/%s", strings.TrimRight(mc.hostURL, "/"), strings.TrimLeft(endpoint, "/")) + + req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + for name, values := range mc.header { + for _, value := range values { + req.Header.Set(name, value) + } + } + + if mc.debug && mc.logger != nil { + mc.logger.Debugf("Sending request: %s %s", method, reqURL) + } + + resp, err := mc.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + _, err = coupleAPIErrors(resp, nil) + if err != nil { + return err + } + + if mc.debug && mc.logger != nil { + mc.logger.Debugf("Received response: %s", resp.Status) + } + + if params.Response != nil { + if err := json.NewDecoder(resp.Body).Decode(params.Response); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + } + + return nil +} diff --git a/client_test.go b/client_test.go index f925f3678..da4997c17 100644 --- a/client_test.go +++ b/client_test.go @@ -37,39 +37,39 @@ func TestClient_SetAPIVersion(t *testing.T) { client := NewClient(nil) - if client.resty.BaseURL != defaultURL { - t.Fatal(cmp.Diff(client.resty.BaseURL, defaultURL)) + if client.hostURL != defaultURL { + t.Fatal(cmp.Diff(client.hostURL, defaultURL)) } client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Ensure setting twice does not cause conflicts client.SetBaseURL(updatedBaseURL) client.SetAPIVersion(updatedAPIVersion) - if client.resty.BaseURL != updatedExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, updatedExpectedHost)) + if client.hostURL != updatedExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, updatedExpectedHost)) } // Revert client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Custom protocol client.SetBaseURL(protocolBaseURL) client.SetAPIVersion(protocolAPIVersion) - if client.resty.BaseURL != protocolExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != protocolExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } } @@ -111,7 +111,7 @@ func TestClient_NewFromEnvToken(t *testing.T) { t.Fatal(err) } - if client.resty.Header.Get("Authorization") != "Bearer blah" { + if client.header.Get("Authorization") != "Bearer blah" { t.Fatal("token not found in auth header: blah") } } @@ -171,8 +171,8 @@ func TestClient_UseURL(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if client.resty.BaseURL != tt.wantBaseURL { - t.Fatalf("mismatched base url: got %s, want %s", client.resty.BaseURL, tt.wantBaseURL) + if client.hostURL != tt.wantBaseURL { + t.Fatalf("mismatched base url: got %s, want %s", client.hostURL, tt.wantBaseURL) } }) } @@ -209,12 +209,12 @@ func TestDebugLogSanitization(t *testing.T) { logger.L.SetOutput(&lgr) mockClient.SetDebug(true) - if !mockClient.resty.Debug { + if !mockClient.debug { t.Fatal("debug should be enabled") } mockClient.SetHeader("Authorization", fmt.Sprintf("Bearer %s", plainTextToken)) - if mockClient.resty.Header.Get("Authorization") != fmt.Sprintf("Bearer %s", plainTextToken) { + if mockClient.header.Get("Authorization") != fmt.Sprintf("Bearer %s", plainTextToken) { t.Fatal("token not found in auth header") } @@ -242,22 +242,25 @@ func TestDebugLogSanitization(t *testing.T) { func TestDoRequest_Success(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"message":"success"}`)) + if r.URL.Path == "/v4/foo/bar" { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"message":"success"}`)) + } else { + http.NotFound(w, r) + } } server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) - params := RequestParams{ + params := requestParams{ Response: &map[string]string{}, } - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", params, nil) // Pass only the endpoint if err != nil { t.Fatal(cmp.Diff(nil, err)) } @@ -269,31 +272,11 @@ func TestDoRequest_Success(t *testing.T) { } } -func TestDoRequest_FailedEncodeBody(t *testing.T) { - client := &httpClient{ - httpClient: http.DefaultClient, - } - - params := RequestParams{ - Body: map[string]interface{}{ - "invalid": func() {}, - }, - } - - err := client.doRequest(context.Background(), http.MethodPost, "http://example.com", params) - expectedErr := "failed to encode body" - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Fatalf("expected error %q, got: %v", expectedErr, err) - } -} - func TestDoRequest_FailedCreateRequest(t *testing.T) { - client := &httpClient{ - httpClient: http.DefaultClient, - } + client := NewClient(nil) - // Create a request with an invalid URL to simulate a request creation failure - err := client.doRequest(context.Background(), http.MethodGet, "http://invalid url", RequestParams{}) + // Create a request with an invalid method to simulate a request creation failure + err := client.doRequest(context.Background(), "bad method", "/foo/bar", requestParams{}, nil) expectedErr := "failed to create request" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) @@ -302,26 +285,28 @@ func TestDoRequest_FailedCreateRequest(t *testing.T) { func TestDoRequest_Non2xxStatusCode(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "error", http.StatusInternalServerError) + http.Error(w, "error", http.StatusInternalServerError) // Simulate a 500 Internal Server Error } server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) if err == nil { t.Fatal("expected error, got nil") } - httpError, ok := err.(Error) + + httpError, ok := err.(*Error) if !ok { t.Fatalf("expected error to be of type Error, got %T", err) } + if httpError.Code != http.StatusInternalServerError { t.Fatalf("expected status code %d, got %d", http.StatusInternalServerError, httpError.Code) } + if !strings.Contains(httpError.Message, "error") { t.Fatalf("expected error message to contain %q, got %v", "error", httpError.Message) } @@ -331,21 +316,21 @@ func TestDoRequest_FailedDecodeResponse(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`invalid json`)) + _, _ = w.Write([]byte(`invalid json`)) // Simulate invalid JSON } server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) - params := RequestParams{ + params := requestParams{ Response: &map[string]string{}, } - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", params, nil) expectedErr := "failed to decode response" + if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) } @@ -363,24 +348,21 @@ func TestDoRequest_BeforeRequestSuccess(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) - // Define a mutator that successfully modifies the request mutator := func(req *http.Request) error { req.Header.Set("X-Custom-Header", "CustomValue") return nil } - client.httpOnBeforeRequest(mutator) + client.OnBeforeRequest(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) if err != nil { t.Fatalf("expected no error, got: %v", err) } - // Check if the header was successfully added to the captured request if reqHeader := capturedRequest.Header.Get("X-Custom-Header"); reqHeader != "CustomValue" { t.Fatalf("expected X-Custom-Header to be set to CustomValue, got: %v", reqHeader) } @@ -395,18 +377,18 @@ func TestDoRequest_BeforeRequestError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) mutator := func(req *http.Request) error { return errors.New("mutator error") } - client.httpOnBeforeRequest(mutator) + client.OnBeforeRequest(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) expectedErr := "failed to mutate before request" + if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) } @@ -425,23 +407,21 @@ func TestDoRequest_AfterResponseSuccess(t *testing.T) { tr := &testRoundTripper{ Transport: server.Client().Transport, } - client := &httpClient{ - httpClient: &http.Client{Transport: tr}, - } + client := NewClient(&http.Client{Transport: tr}) + client.SetBaseURL(server.URL) mutator := func(resp *http.Response) error { resp.Header.Set("X-Modified-Header", "ModifiedValue") return nil } - client.httpOnAfterResponse(mutator) + client.OnAfterResponse(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) if err != nil { t.Fatalf("expected no error, got: %v", err) } - // Check if the header was successfully added to the response if respHeader := tr.Response.Header.Get("X-Modified-Header"); respHeader != "ModifiedValue" { t.Fatalf("expected X-Modified-Header to be set to ModifiedValue, got: %v", respHeader) } @@ -456,17 +436,16 @@ func TestDoRequest_AfterResponseError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - client := &httpClient{ - httpClient: server.Client(), - } + client := NewClient(server.Client()) + client.SetBaseURL(server.URL) mutator := func(resp *http.Response) error { return errors.New("mutator error") } - client.httpOnAfterResponse(mutator) + client.OnAfterResponse(mutator) - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + err := client.doRequest(context.Background(), http.MethodGet, "/foo/bar", requestParams{}, nil) expectedErr := "failed to mutate after response" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) @@ -478,11 +457,9 @@ func TestDoRequestLogging_Success(t *testing.T) { logger := createLogger() logger.l.SetOutput(&logBuffer) // Redirect log output to buffer - client := &httpClient{ - httpClient: http.DefaultClient, - debug: true, - logger: logger, - } + client := NewClient(nil) + client.SetDebug(true) + client.SetLogger(logger) handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -490,29 +467,46 @@ func TestDoRequestLogging_Success(t *testing.T) { _, _ = w.Write([]byte(`{"message":"success"}`)) } server := httptest.NewServer(http.HandlerFunc(handler)) + client.SetBaseURL(server.URL) defer server.Close() - params := RequestParams{ + params := requestParams{ Response: &map[string]string{}, } + endpoint := "/foo/bar" - err := client.doRequest(context.Background(), http.MethodGet, server.URL, params) + err := client.doRequest(context.Background(), http.MethodGet, endpoint, params, nil) if err != nil { t.Fatal(cmp.Diff(nil, err)) } logInfo := logBuffer.String() - logInfoWithoutTimestamps := removeTimestamps(logInfo) - // Expected logs with templates filled in - expectedRequestLog := "DEBUG RESTY Sending request:\nMethod: GET\nURL: " + server.URL + "\nHeaders: map[Accept:[application/json] Content-Type:[application/json]]\nBody: " - expectedResponseLog := "DEBUG RESTY Received response:\nStatus: 200 OK\nHeaders: map[Content-Length:[21] Content-Type:[text/plain; charset=utf-8]]\nBody: {\"message\":\"success\"}" + expectedRequestParts := []string{ + "GET /v4/foo/bar HTTP/1.1", + "Accept: application/json", + "Authorization: Bearer *******************************", + "Content-Type: application/json", + "User-Agent: linodego/dev https://github.com/linode/linodego", + } - if !strings.Contains(logInfo, expectedRequestLog) { - t.Fatalf("expected log %q not found in logs", expectedRequestLog) + expectedResponseParts := []string{ + "STATUS: 200 OK", + "PROTO: HTTP/1.1", + "Content-Length: 21", + "Content-Type: application/json", + `"message": "success"`, } - if !strings.Contains(logInfoWithoutTimestamps, expectedResponseLog) { - t.Fatalf("expected log %q not found in logs", expectedResponseLog) + + for _, part := range expectedRequestParts { + if !strings.Contains(logInfo, part) { + t.Fatalf("expected request part %q not found in logs", part) + } + } + for _, part := range expectedResponseParts { + if !strings.Contains(logInfo, part) { + t.Fatalf("expected response part %q not found in logs", part) + } } } @@ -521,26 +515,19 @@ func TestDoRequestLogging_Error(t *testing.T) { logger := createLogger() logger.l.SetOutput(&logBuffer) // Redirect log output to buffer - client := &httpClient{ - httpClient: http.DefaultClient, - debug: true, - logger: logger, - } - - params := RequestParams{ - Body: map[string]interface{}{ - "invalid": func() {}, - }, - } + client := NewClient(nil) + client.SetDebug(true) + client.SetLogger(logger) - err := client.doRequest(context.Background(), http.MethodPost, "http://example.com", params) - expectedErr := "failed to encode body" + // Create a request with an invalid method to simulate a request creation failure + err := client.doRequest(context.Background(), "bad method", "/foo/bar", requestParams{}, nil) + expectedErr := "failed to create request" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) } logInfo := logBuffer.String() - expectedLog := "ERROR RESTY failed to encode body" + expectedLog := "ERROR failed to create request" if !strings.Contains(logInfo, expectedLog) { t.Fatalf("expected log %q not found in logs", expectedLog) @@ -638,9 +625,9 @@ func TestClient_CustomRootCAWithoutCustomRoundTripper(t *testing.T) { } client := NewClient(test.httpClient) - transport, err := client.resty.Transport() - if err != nil { - t.Fatal(err) + transport, ok := client.httpClient.Transport.(*http.Transport) + if !ok { + t.Fatal("expected *http.Transport") } if test.setCA && (transport.TLSClientConfig == nil || transport.TLSClientConfig.RootCAs == nil) { t.Error("expected root CAs to be set") @@ -669,39 +656,39 @@ func TestMonitorClient_SetAPIBasics(t *testing.T) { client := NewMonitorClient(nil) - if client.resty.BaseURL != defaultURL { - t.Fatal(cmp.Diff(client.resty.BaseURL, defaultURL)) + if client.hostURL != defaultURL { + t.Fatal(cmp.Diff(client.hostURL, defaultURL)) } client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Ensure setting twice does not cause conflicts client.SetBaseURL(updatedBaseURL) client.SetAPIVersion(updatedAPIVersion) - if client.resty.BaseURL != updatedExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, updatedExpectedHost)) + if client.hostURL != updatedExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, updatedExpectedHost)) } // Revert client.SetBaseURL(baseURL) client.SetAPIVersion(apiVersion) - if client.resty.BaseURL != expectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != expectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } // Custom protocol client.SetBaseURL(protocolBaseURL) client.SetAPIVersion(protocolAPIVersion) - if client.resty.BaseURL != protocolExpectedHost { - t.Fatal(cmp.Diff(client.resty.BaseURL, expectedHost)) + if client.hostURL != protocolExpectedHost { + t.Fatal(cmp.Diff(client.hostURL, expectedHost)) } } @@ -787,7 +774,7 @@ func TestEnableLogSanitization(t *testing.T) { "Authorization": []string{"Bearer " + plainTextToken}, })) - _, err := mockClient.resty.R().Get("https://api.linode.com/v4/test") + err := mockClient.doRequest(context.Background(), http.MethodGet, "/test", requestParams{}, nil) require.NoError(t, err) logOutput := logBuf.String() diff --git a/config_test.go b/config_test.go index b4b3db418..628cc1415 100644 --- a/config_test.go +++ b/config_test.go @@ -42,11 +42,11 @@ func TestConfig_LoadWithDefaults(t *testing.T) { expectedURL := "https://api.cool.linode.com/v4beta" - if client.resty.BaseURL != expectedURL { - t.Fatalf("mismatched host url: %s != %s", client.resty.BaseURL, expectedURL) + if client.hostURL != expectedURL { + t.Fatalf("mismatched host url: %s != %s", client.hostURL, expectedURL) } - if client.resty.Header.Get("Authorization") != "Bearer "+p.APIToken { + if client.header.Get("Authorization") != "Bearer "+p.APIToken { t.Fatalf("token not found in auth header: %s", p.APIToken) } } @@ -88,11 +88,11 @@ func TestConfig_OverrideDefaults(t *testing.T) { expectedURL := "https://api.cool.linode.com/v4" - if client.resty.BaseURL != expectedURL { - t.Fatalf("mismatched host url: %s != %s", client.resty.BaseURL, expectedURL) + if client.hostURL != expectedURL { + t.Fatalf("mismatched host url: %s != %s", client.hostURL, expectedURL) } - if client.resty.Header.Get("Authorization") != "Bearer "+p.APIToken { + if client.header.Get("Authorization") != "Bearer "+p.APIToken { t.Fatalf("token not found in auth header: %s", p.APIToken) } } @@ -124,7 +124,7 @@ func TestConfig_NoDefaults(t *testing.T) { t.Fatalf("mismatched api token: %s != %s", p.APIToken, "mytoken") } - if client.resty.Header.Get("Authorization") != "Bearer "+p.APIToken { + if client.header.Get("Authorization") != "Bearer "+p.APIToken { t.Fatalf("token not found in auth header: %s", p.APIToken) } } diff --git a/errors.go b/errors.go index 873b40450..db7fc2a34 100644 --- a/errors.go +++ b/errors.go @@ -1,7 +1,7 @@ package linodego import ( - "encoding/json" + "bytes" "errors" "fmt" "io" @@ -9,8 +9,6 @@ import ( "reflect" "slices" "strings" - - "github.com/go-resty/resty/v2" ) const ( @@ -49,74 +47,43 @@ type APIError struct { Errors []APIErrorReason `json:"errors"` } -// String returns the error reason in a formatted string -func (r APIErrorReason) String() string { - return fmt.Sprintf("[%s] %s", r.Field, r.Reason) -} - -func coupleAPIErrors(r *resty.Response, err error) (*resty.Response, error) { +//nolint:nestif,unparam +func coupleAPIErrors(resp *http.Response, err error) (*http.Response, error) { if err != nil { - // an error was raised in go code, no need to check the resty Response return nil, NewError(err) } - if r.Error() == nil { - // no error in the resty Response - return r, nil - } - - // handle the resty Response errors - - // Check that response is of the correct content-type before unmarshalling - expectedContentType := r.Request.Header.Get("Accept") - responseContentType := r.Header().Get("Content-Type") - - // If the upstream Linode API server being fronted fails to respond to the request, - // the http server will respond with a default "Bad Gateway" page with Content-Type - // "text/html". - if r.StatusCode() == http.StatusBadGateway && responseContentType == "text/html" { //nolint:goconst - return nil, Error{Code: http.StatusBadGateway, Message: http.StatusText(http.StatusBadGateway)} - } - - if responseContentType != expectedContentType { - msg := fmt.Sprintf( - "Unexpected Content-Type: Expected: %v, Received: %v\nResponse body: %s", - expectedContentType, - responseContentType, - string(r.Body()), - ) - - return nil, Error{Code: r.StatusCode(), Message: msg} - } - - apiError, ok := r.Error().(*APIError) - if !ok || (ok && len(apiError.Errors) == 0) { - return r, nil + if resp == nil { + return nil, NewError(fmt.Errorf("response is nil")) } - return nil, NewError(r) -} - -//nolint:unused -func coupleAPIErrorsHTTP(resp *http.Response, err error) (*http.Response, error) { - if err != nil { - // an error was raised in go code, no need to check the http.Response - return nil, NewError(err) - } - - if resp == nil || resp.StatusCode < 200 || resp.StatusCode >= 300 { + if resp.StatusCode < 200 || resp.StatusCode >= 300 { // Check that response is of the correct content-type before unmarshalling - expectedContentType := resp.Request.Header.Get("Accept") + expectedContentType := "" + if resp.Request != nil && resp.Request.Header != nil { + expectedContentType = resp.Request.Header.Get("Accept") + } + responseContentType := resp.Header.Get("Content-Type") // If the upstream server fails to respond to the request, - // the http server will respond with a default error page with Content-Type "text/html". - if resp.StatusCode == http.StatusBadGateway && responseContentType == "text/html" { //nolint:goconst - return nil, Error{Code: http.StatusBadGateway, Message: http.StatusText(http.StatusBadGateway)} + // the HTTP server will respond with a default error page with Content-Type "text/html". + if resp.StatusCode == http.StatusBadGateway && responseContentType == "text/html" { + return nil, &Error{Code: http.StatusBadGateway, Message: http.StatusText(http.StatusBadGateway), Response: resp} } if responseContentType != expectedContentType { - bodyBytes, _ := io.ReadAll(resp.Body) + if resp.Body == nil { + return nil, NewError(fmt.Errorf("response body is nil")) + } + + bodyBytes, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, NewError(fmt.Errorf("failed to read response body: %w", readErr)) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + msg := fmt.Sprintf( "Unexpected Content-Type: Expected: %v, Received: %v\nResponse body: %s", expectedContentType, @@ -124,22 +91,22 @@ func coupleAPIErrorsHTTP(resp *http.Response, err error) (*http.Response, error) string(bodyBytes), ) - return nil, Error{Code: resp.StatusCode, Message: msg} + return nil, &Error{Code: resp.StatusCode, Message: msg, Response: resp} } - var apiError APIError - if err := json.NewDecoder(resp.Body).Decode(&apiError); err != nil { - return nil, NewError(fmt.Errorf("failed to decode response body: %w", err)) + // Must check if there is no list of reasons in the error before making a call to NewError + apiError, ok := getAPIError(resp) + if !ok { + return nil, NewError(fmt.Errorf("failed to decode response body")) } if len(apiError.Errors) == 0 { return resp, nil } - return nil, Error{Code: resp.StatusCode, Message: apiError.Errors[0].String()} + return nil, NewError(resp) } - // no error in the http.Response return resp, nil } @@ -156,7 +123,7 @@ func (e APIError) Error() string { // - ErrorFromString (1) from a string // - ErrorFromError (2) for an error // - ErrorFromStringer (3) for a Stringer -// - HTTP Status Codes (100-600) for a resty.Response object +// - HTTP Status Codes (100-600) for a http.Response object func NewError(err any) *Error { if err == nil { return nil @@ -165,17 +132,17 @@ func NewError(err any) *Error { switch e := err.(type) { case *Error: return e - case *resty.Response: - apiError, ok := e.Error().(*APIError) + case *http.Response: + apiError, ok := getAPIError(e) if !ok { - return &Error{Code: ErrorUnsupported, Message: "Unexpected Resty Error Response, no error"} + return &Error{Code: ErrorUnsupported, Message: "Unexpected HTTP Error Response, no error"} } return &Error{ - Code: e.RawResponse.StatusCode, + Code: e.StatusCode, Message: apiError.Error(), - Response: e.RawResponse, + Response: e, } case error: return &Error{Code: ErrorFromError, Message: e.Error()} diff --git a/errors_test.go b/errors_test.go index 68428fcd6..4a182c36e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -9,9 +9,10 @@ import ( "io" "net/http" "net/http/httptest" + "strconv" + "strings" "testing" - "github.com/go-resty/resty/v2" "github.com/google/go-cmp/cmp" ) @@ -27,10 +28,10 @@ func (e testError) Error() string { return string(e) } -func restyError(reason, field string) *resty.Response { +func httpError(reason, field string) *http.Response { var reasons []APIErrorReason - // allow for an empty reasons + // Allow for an empty reasons if reason != "" && field != "" { reasons = append(reasons, APIErrorReason{ Reason: reason, @@ -38,15 +39,18 @@ func restyError(reason, field string) *resty.Response { }) } - return &resty.Response{ - RawResponse: &http.Response{ - StatusCode: 500, - }, - Request: &resty.Request{ - Error: &APIError{ - Errors: reasons, - }, - }, + apiError := &APIError{ + Errors: reasons, + } + + body, err := json.Marshal(apiError) + if err != nil { + panic("Failed to marshal APIError") + } + + return &http.Response{ + StatusCode: 500, + Body: io.NopCloser(bytes.NewReader(body)), } } @@ -73,12 +77,12 @@ func TestNewError(t *testing.T) { t.Error("Error should be itself") } - if err := NewError(&resty.Response{Request: &resty.Request{}}); err.Message != "Unexpected Resty Error Response, no error" { - t.Error("Unexpected Resty Error Response, no error") + if err := NewError(&http.Response{Request: &http.Request{}}); err.Message != "Unexpected HTTP Error Response, no error" { + t.Error("Unexpected HTTP Error Response, no error") } - if err := NewError(restyError("testreason", "testfield")); err.Message != "[testfield] testreason" { - t.Error("rest response error should should be set") + if err := NewError(httpError("testreason", "testfield")); err.Message != "[testfield] testreason" { + t.Error("http response error should should be set") } if err := NewError("stringerror"); err.Message != "stringerror" || err.Code != ErrorFromString { @@ -119,98 +123,6 @@ func TestCoupleAPIErrors(t *testing.T) { } }) - t.Run("resty 500 response error with reasons", func(t *testing.T) { - if _, err := coupleAPIErrors(restyError("testreason", "testfield"), nil); err.Error() != "[500] [testfield] testreason" { - t.Error("resty error should return with proper format [code] [field] reason") - } - }) - - t.Run("resty 500 response error without reasons", func(t *testing.T) { - if _, err := coupleAPIErrors(restyError("", ""), nil); err != nil { - t.Error("resty error with no reasons should return no error") - } - }) - - t.Run("resty response with nil error", func(t *testing.T) { - emptyErr := &resty.Response{ - RawResponse: &http.Response{ - StatusCode: 500, - }, - Request: &resty.Request{ - Error: nil, - }, - } - if _, err := coupleAPIErrors(emptyErr, nil); err != nil { - t.Error("resty error with no reasons should return no error") - } - }) - - t.Run("generic html error", func(t *testing.T) { - rawResponse := ` -