Skip to content

Commit 22e4e64

Browse files
make oauth and timeout setup happen outside the http client
1 parent c9eac59 commit 22e4e64

4 files changed

Lines changed: 18 additions & 71 deletions

File tree

core/httpclient/transport.go

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@ import (
44
"net/http"
55
"net/url"
66
"strings"
7-
"time"
8-
9-
"golang.org/x/oauth2"
107
)
118

129
// BaseURLTransport is an http.RoundTripper that rewrites every request URL
@@ -39,24 +36,17 @@ func (t *BaseURLTransport) RoundTrip(req *http.Request) (*http.Response, error)
3936
return next.RoundTrip(newReq)
4037
}
4138

42-
// NewClient builds an *http.Client with BaseURLTransport and optionally
43-
// oauth2 bearer auth configured. The transport chain is:
44-
//
45-
// oauth2.Transport (if token provided) → BaseURLTransport → DefaultTransport
46-
func NewClient(rawBaseURL, token string, timeout time.Duration) (*http.Client, error) {
39+
// NewClient builds an *http.Client with BaseURLTransport configured.
40+
// Callers are responsible for layering additional transports (e.g. auth) and
41+
// setting Timeout on the returned client.
42+
func NewClient(rawBaseURL string) (*http.Client, error) {
4743
u, err := url.Parse(rawBaseURL)
4844
if err != nil {
4945
return nil, err
5046
}
5147

52-
var transport http.RoundTripper = &BaseURLTransport{
48+
return &http.Client{Transport: &BaseURLTransport{
5349
BaseURL: u,
5450
Next: http.DefaultTransport,
55-
}
56-
if token != "" {
57-
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
58-
transport = &oauth2.Transport{Source: ts, Base: transport}
59-
}
60-
61-
return &http.Client{Transport: transport, Timeout: timeout}, nil
51+
}}, nil
6252
}

core/httpclient/transport_test.go

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@ package httpclient
22

33
import (
44
"net/http"
5-
"net/http/httptest"
65
"net/url"
76
"testing"
8-
"time"
97

108
"github.com/stretchr/testify/assert"
119
"github.com/stretchr/testify/require"
@@ -78,55 +76,10 @@ func TestBaseURLTransport_DoesNotMutateOriginalRequest(t *testing.T) {
7876
}
7977

8078
func TestNewClient_InvalidURL(t *testing.T) {
81-
_, err := NewClient("://invalid", "", 30*time.Second)
79+
_, err := NewClient("://invalid")
8280
require.Error(t, err)
8381
}
8482

85-
func TestNewClient_SetsTimeout(t *testing.T) {
86-
client, err := NewClient("https://api.github.com", "", 10*time.Second)
87-
require.NoError(t, err)
88-
assert.Equal(t, 10*time.Second, client.Timeout)
89-
}
90-
91-
func TestNewClient_AuthHeader(t *testing.T) {
92-
tests := []struct {
93-
name string
94-
token string
95-
wantAuthHeader string
96-
}{
97-
{
98-
name: "no token, no auth header",
99-
token: "",
100-
wantAuthHeader: "",
101-
},
102-
{
103-
name: "with token, adds bearer auth header",
104-
token: "my-token",
105-
wantAuthHeader: "Bearer my-token",
106-
},
107-
}
108-
109-
for _, tt := range tests {
110-
t.Run(tt.name, func(t *testing.T) {
111-
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
112-
w.Header().Set("X-Captured-Auth", r.Header.Get("Authorization"))
113-
w.WriteHeader(http.StatusOK)
114-
}))
115-
defer server.Close()
116-
117-
client, err := NewClient(server.URL, tt.token, 30*time.Second)
118-
require.NoError(t, err)
119-
120-
req, err := http.NewRequest(http.MethodGet, "/", nil)
121-
require.NoError(t, err)
122-
123-
resp, err := client.Do(req)
124-
require.NoError(t, err)
125-
assert.Equal(t, tt.wantAuthHeader, resp.Header.Get("X-Captured-Auth"))
126-
})
127-
}
128-
}
129-
13083
func mustParseURL(t *testing.T, raw string) *url.URL {
13184
t.Helper()
13285
u, err := url.Parse(raw)

example/server/orchestrator/main.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ import (
2828
"time"
2929

3030
_ "github.com/go-sql-driver/mysql"
31+
"golang.org/x/oauth2"
32+
3133
"github.com/uber-go/tally/v4"
3234
"github.com/uber/submitqueue/core/consumer"
3335
"github.com/uber/submitqueue/core/httpclient"
@@ -572,15 +574,18 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh
572574
// newChangeProvider creates a ChangeProvider for GitHub (github.com).
573575
// Configured via GITHUB_BASE_URL, GITHUB_TOKEN, and GITHUB_TIMEOUT environment variables.
574576
func newChangeProvider(logger *zap.Logger, scope tally.Scope) (changeprovider.ChangeProvider, error) {
575-
client, err := httpclient.NewClient(
576-
getEnv("GITHUB_BASE_URL", "https://api.github.com"),
577-
os.Getenv("GITHUB_TOKEN"),
578-
parseTimeout(os.Getenv("GITHUB_TIMEOUT"), 30*time.Second),
579-
)
577+
client, err := httpclient.NewClient(getEnv("GITHUB_BASE_URL", "https://api.github.com"))
580578
if err != nil {
581579
return nil, fmt.Errorf("failed to build GitHub HTTP client: %w", err)
582580
}
583581

582+
if token := os.Getenv("GITHUB_TOKEN"); token != "" {
583+
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
584+
client.Transport = &oauth2.Transport{Source: ts, Base: client.Transport}
585+
}
586+
587+
client.Timeout = parseTimeout(os.Getenv("GITHUB_TIMEOUT"), 30*time.Second)
588+
584589
return githubprovider.NewProvider(githubprovider.Params{
585590
HTTPClient: client,
586591
Logger: logger.Sugar(),

extension/changeprovider/github/provider_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"net/http"
77
"net/http/httptest"
88
"testing"
9-
"time"
109

1110
"github.com/stretchr/testify/assert"
1211
"github.com/stretchr/testify/require"
@@ -20,7 +19,7 @@ import (
2019

2120
func newTestProvider(t *testing.T, serverURL string) changeprovider.ChangeProvider {
2221
t.Helper()
23-
client, err := httpclient.NewClient(serverURL, "", 30*time.Second)
22+
client, err := httpclient.NewClient(serverURL)
2423
require.NoError(t, err)
2524
return NewProvider(Params{
2625
HTTPClient: client,

0 commit comments

Comments
 (0)