From d12205d259f8e6a7e2cf46865489d0d6bb1f5fdf Mon Sep 17 00:00:00 2001 From: rprithyani Date: Thu, 12 Mar 2026 20:30:45 +0000 Subject: [PATCH 1/4] feat(provider): implement change provider interface with GitHub integration - Add ChangeProvider interface and GitHub implementation - Use Client wrapper pattern with configurable auth - Add tests for the change provider --- example/server/orchestrator/BUILD.bazel | 2 + example/server/orchestrator/main.go | 60 +++++- extension/changeprovider/change_provider.go | 10 +- extension/changeprovider/github/BUILD.bazel | 42 ++++ extension/changeprovider/github/config.go | 47 ++++ .../changeprovider/github/config_test.go | 97 +++++++++ extension/changeprovider/github/convert.go | 42 ++++ extension/changeprovider/github/graphql.go | 157 ++++++++++++++ .../changeprovider/github/graphql_test.go | 145 +++++++++++++ extension/changeprovider/github/provider.go | 163 ++++++++++++++ .../changeprovider/github/provider_test.go | 202 ++++++++++++++++++ extension/changeprovider/github/validate.go | 46 ++++ .../changeprovider/github/validate_test.go | 99 +++++++++ orchestrator/controller/validate/BUILD.bazel | 3 + orchestrator/controller/validate/validate.go | 77 +++++-- .../controller/validate/validate_test.go | 24 ++- 16 files changed, 1186 insertions(+), 30 deletions(-) create mode 100644 extension/changeprovider/github/BUILD.bazel create mode 100644 extension/changeprovider/github/config.go create mode 100644 extension/changeprovider/github/config_test.go create mode 100644 extension/changeprovider/github/convert.go create mode 100644 extension/changeprovider/github/graphql.go create mode 100644 extension/changeprovider/github/graphql_test.go create mode 100644 extension/changeprovider/github/provider.go create mode 100644 extension/changeprovider/github/provider_test.go create mode 100644 extension/changeprovider/github/validate.go create mode 100644 extension/changeprovider/github/validate_test.go diff --git a/example/server/orchestrator/BUILD.bazel b/example/server/orchestrator/BUILD.bazel index 26c5316d..7d2b09fe 100644 --- a/example/server/orchestrator/BUILD.bazel +++ b/example/server/orchestrator/BUILD.bazel @@ -13,6 +13,8 @@ go_library( deps = [ "//core/consumer", "//entity", + "//extension/changeprovider", + "//extension/changeprovider/github", "//extension/counter", "//extension/counter/mysql", "//extension/mergechecker", diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index d7daae4a..2db5391d 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -31,6 +31,8 @@ import ( "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" "github.com/uber/submitqueue/entity" + "github.com/uber/submitqueue/extension/changeprovider" + githubprovider "github.com/uber/submitqueue/extension/changeprovider/github" "github.com/uber/submitqueue/extension/counter" mysqlcounter "github.com/uber/submitqueue/extension/counter/mysql" "github.com/uber/submitqueue/extension/mergechecker" @@ -190,8 +192,11 @@ func run() error { // Create merge checker mc := newMergeChecker(logger, scope) + // Create change provider + cp := newChangeProvider(logger, scope) + // Register controllers - if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cnt, store); err != nil { + if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cp, cnt, store); err != nil { return err } @@ -376,7 +381,7 @@ func newTopicRegistry(q extqueue.Queue, subscriberName string) (consumer.TopicRe // │ │ │ // └────────┴────────────────────────┘ -func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker, cnt counter.Counter, store storage.Storage) error { +func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker, cp changeprovider.ChangeProvider, cnt counter.Counter, store storage.Storage) error { requestController := start.NewController( logger, scope, @@ -395,6 +400,7 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t store, registry, mc, + cp, consumer.TopicKeyValidate, "orchestrator-validate", ) @@ -514,6 +520,36 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t return nil } +// getEnv returns environment variable value or default if not set. +func getEnv(key, defaultVal string) string { + if val := os.Getenv(key); val != "" { + return val + } + return defaultVal +} + +// parseTimeout parses a duration from environment variable with fallback to default. +// Returns defaultVal if envVal is empty or cannot be parsed. +func parseTimeout(envVal string, defaultVal time.Duration) time.Duration { + if envVal == "" { + return defaultVal + } + if d, err := time.ParseDuration(envVal); err == nil { + return d + } + return defaultVal +} + +// buildGitHubHTTPClient creates an http.Client configured for GitHub API calls. +// Configures timeout and optional bearer token authentication. +func buildGitHubHTTPClient(token string, timeout time.Duration) *http.Client { + httpClient := &http.Client{Timeout: timeout} + if token != "" { + httpClient.Transport = &bearerTransport{token: token} + } + return httpClient +} + // newMergeChecker creates a MergeChecker for GitHub (github.com). // Configured via GITHUB_TOKEN and GITHUB_GRAPHQL_URL environment variables. func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeChecker { @@ -539,6 +575,26 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh }) } +// newChangeProvider creates a ChangeProvider for GitHub (github.com). +// Configured via GITHUB_BASE_URL, GITHUB_TOKEN, and GITHUB_TIMEOUT environment variables. +// Uses pure dependency injection - creates http.Client with auth configured in Transport. +func newChangeProvider(logger *zap.Logger, scope tally.Scope) changeprovider.ChangeProvider { + // 1. Read configuration from environment + baseURL := getEnv("GITHUB_BASE_URL", "https://api.github.com") + token := os.Getenv("GITHUB_TOKEN") + timeout := parseTimeout(os.Getenv("GITHUB_TIMEOUT"), 30*time.Second) + + // 2. Build HTTP client with caller-controlled config (auth + timeout) + httpClient := buildGitHubHTTPClient(token, timeout) + + // 3. Inject into provider + return githubprovider.NewProvider(githubprovider.Params{ + Client: githubprovider.NewClient(httpClient, baseURL), + Logger: logger.Sugar(), + MetricsScope: scope.SubScope("changeprovider"), + }) +} + // bearerTransport is an http.RoundTripper that adds a Bearer token to requests. type bearerTransport struct { token string diff --git a/extension/changeprovider/change_provider.go b/extension/changeprovider/change_provider.go index 377cbd08..72bdd5af 100644 --- a/extension/changeprovider/change_provider.go +++ b/extension/changeprovider/change_provider.go @@ -44,8 +44,9 @@ type ChangedFile struct { // ChangeInfo contains metadata and file changes for a code change. type ChangeInfo struct { - // ID is the change identifier (e.g., "PR: uber-code/go-code/1" or "diff: uber-code/go-code/D1"). - ID string + // URI is the full change URI for correlation with the input request + // (e.g., "github://uber/repo/98/abc123sha" or "phab://D123/xyz789"). + URI string // User is the author of the change. User User // ChangedFiles is the list of files modified in this change. Order is unspecified. @@ -56,6 +57,7 @@ type ChangeInfo struct { // Each implementation is configured for a specific provider (GitHub, GitLab, Phabricator). type ChangeProvider interface { // Get retrieves change information for the provided Change. - // Returns the change info containing metadata and file changes. - Get(ctx context.Context, change entity.Change) (ChangeInfo, error) + // For a Change with multiple URIs (e.g., stacked PRs), returns one ChangeInfo per URI. + // Returns a slice of ChangeInfo, one for each change in the stack. + Get(ctx context.Context, change entity.Change) ([]ChangeInfo, error) } diff --git a/extension/changeprovider/github/BUILD.bazel b/extension/changeprovider/github/BUILD.bazel new file mode 100644 index 00000000..dd8882b4 --- /dev/null +++ b/extension/changeprovider/github/BUILD.bazel @@ -0,0 +1,42 @@ +load("@rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "github", + srcs = [ + "config.go", + "convert.go", + "graphql.go", + "provider.go", + "validate.go", + ], + importpath = "github.com/uber/submitqueue/extension/changeprovider/github", + visibility = ["//visibility:public"], + deps = [ + "//core/metrics", + "//entity", + "//entity/github", + "//extension/changeprovider", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "github_test", + srcs = [ + "config_test.go", + "graphql_test.go", + "provider_test.go", + "validate_test.go", + ], + embed = [":github"], + deps = [ + "//entity", + "//entity/github", + "//extension/changeprovider", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//zaptest", + ], +) diff --git a/extension/changeprovider/github/config.go b/extension/changeprovider/github/config.go new file mode 100644 index 00000000..9c646cd9 --- /dev/null +++ b/extension/changeprovider/github/config.go @@ -0,0 +1,47 @@ +package github + +import ( + "net/http" +) + +// Client is a GitHub API client that encapsulates connection details and authentication. +// The client is protocol-agnostic - it provides helpers for both GraphQL and REST endpoints. +type Client struct { + httpClient *http.Client + baseURL string +} + +// NewClient creates a new GitHub API client with a pre-configured HTTP client. +// The caller is responsible for configuring authentication in the HTTP client's Transport. +// +// Parameters: +// - httpClient: Configured HTTP client (with auth, timeout, transport, etc.) +// - baseURL: GitHub instance base URL (e.g., "https://api.github.com" or "https://ghe.company.com/api") +func NewClient(httpClient *http.Client, baseURL string) *Client { + return &Client{ + httpClient: httpClient, + baseURL: baseURL, + } +} + +// HTTPClient returns the configured HTTP client. +func (c *Client) HTTPClient() *http.Client { + return c.httpClient +} + +// BaseURL returns the configured GitHub base URL. +func (c *Client) BaseURL() string { + return c.baseURL +} + +// GraphQLURL returns the GitHub GraphQL endpoint URL. +// Constructs the URL by appending "/graphql" to the base URL. +func (c *Client) GraphQLURL() string { + return c.baseURL + "/graphql" +} + +// RESTURL constructs a GitHub REST API endpoint URL. +// The path should start with "/" (e.g., "/repos/uber/submitqueue/pulls/123"). +func (c *Client) RESTURL(path string) string { + return c.baseURL + path +} diff --git a/extension/changeprovider/github/config_test.go b/extension/changeprovider/github/config_test.go new file mode 100644 index 00000000..34d6e81f --- /dev/null +++ b/extension/changeprovider/github/config_test.go @@ -0,0 +1,97 @@ +package github + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewClient(t *testing.T) { + httpClient := &http.Client{Timeout: 30 * time.Second} + baseURL := "https://api.github.com" + + client := NewClient(httpClient, baseURL) + + assert.NotNil(t, client) + assert.Equal(t, httpClient, client.HTTPClient()) + assert.Equal(t, baseURL, client.BaseURL()) +} + +func TestClient_GraphQLURL(t *testing.T) { + tests := []struct { + name string + baseURL string + expected string + }{ + { + name: "standard github", + baseURL: "https://api.github.com", + expected: "https://api.github.com/graphql", + }, + { + name: "github enterprise", + baseURL: "https://ghe.example.com/api", + expected: "https://ghe.example.com/api/graphql", + }, + { + name: "localhost", + baseURL: "http://localhost:8080", + expected: "http://localhost:8080/graphql", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewClient(&http.Client{}, tt.baseURL) + assert.Equal(t, tt.expected, client.GraphQLURL()) + }) + } +} + +func TestClient_BaseURL(t *testing.T) { + tests := []struct { + name string + baseURL string + }{ + {name: "standard github", baseURL: "https://api.github.com"}, + {name: "github enterprise", baseURL: "https://ghe.example.com/api"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewClient(&http.Client{}, tt.baseURL) + assert.Equal(t, tt.baseURL, client.BaseURL()) + }) + } +} + +func TestClient_RESTURL(t *testing.T) { + tests := []struct { + name string + baseURL string + path string + expected string + }{ + { + name: "repos endpoint", + baseURL: "https://api.github.com", + path: "/repos/uber/submitqueue/pulls/123", + expected: "https://api.github.com/repos/uber/submitqueue/pulls/123", + }, + { + name: "enterprise repos endpoint", + baseURL: "https://ghe.example.com/api", + path: "/repos/myorg/myrepo/pulls/456", + expected: "https://ghe.example.com/api/repos/myorg/myrepo/pulls/456", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewClient(&http.Client{}, tt.baseURL) + assert.Equal(t, tt.expected, client.RESTURL(tt.path)) + }) + } +} diff --git a/extension/changeprovider/github/convert.go b/extension/changeprovider/github/convert.go new file mode 100644 index 00000000..60394f83 --- /dev/null +++ b/extension/changeprovider/github/convert.go @@ -0,0 +1,42 @@ +package github + +import ( + entitygithub "github.com/uber/submitqueue/entity/github" + "github.com/uber/submitqueue/extension/changeprovider" +) + +// convertToChangeInfo converts GitHub PR data to ChangeInfo. +func convertToChangeInfo(parsed entitygithub.ChangeID, prData *pullRequestData) changeprovider.ChangeInfo { + changedFiles := convertFiles(prData.Files.Nodes) + + return changeprovider.ChangeInfo{ + URI: parsed.String(), + User: changeprovider.User{ + Name: prData.Author.Name, + Email: prData.Author.Email, + }, + ChangedFiles: changedFiles, + } +} + +// convertFiles converts GitHub file nodes to ChangedFile structs. +func convertFiles(nodes []fileNode) []changeprovider.ChangedFile { + changedFiles := make([]changeprovider.ChangedFile, 0, len(nodes)) + + for _, file := range nodes { + linesModified := 0 + if file.Additions > 0 && file.Deletions > 0 { + linesModified = min(file.Additions, file.Deletions) + } + + changedFiles = append(changedFiles, changeprovider.ChangedFile{ + Path: file.Path, + Patch: file.Patch, + LinesAdded: file.Additions, + LinesDeleted: file.Deletions, + LinesModified: linesModified, + }) + } + + return changedFiles +} diff --git a/extension/changeprovider/github/graphql.go b/extension/changeprovider/github/graphql.go new file mode 100644 index 00000000..08f43206 --- /dev/null +++ b/extension/changeprovider/github/graphql.go @@ -0,0 +1,157 @@ +package github + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +// pullRequestQuery is the GraphQL query to fetch pull request information including files, author, and head SHA. +const pullRequestQuery = ` +query($owner: String!, $repo: String!, $prNumber: Int!, $filesCursor: String) { + repository(owner: $owner, name: $repo) { + pullRequest(number: $prNumber) { + number + headRefOid + author { + login + ... on User { + name + email + } + } + files(first: 100, after: $filesCursor) { + totalCount + pageInfo { + endCursor + hasNextPage + } + nodes { + path + additions + deletions + changeType + patch + } + } + } + } +} +` + +// graphqlRequest represents a GraphQL request. +type graphqlRequest struct { + Query string `json:"query"` + Variables map[string]any `json:"variables"` +} + +// graphqlResponse represents the top-level GraphQL response. +type graphqlResponse struct { + Data struct { + Repository struct { + PullRequest pullRequestData `json:"pullRequest"` + } `json:"repository"` + } `json:"data"` + Errors []graphqlError `json:"errors,omitempty"` +} + +// graphqlError represents a GraphQL error. +type graphqlError struct { + Message string `json:"message"` + Type string `json:"type"` +} + +// pullRequestData contains the pull request metadata. +type pullRequestData struct { + Number int `json:"number"` + HeadRefOid string `json:"headRefOid"` + Author authorData `json:"author"` + Files filesData `json:"files"` +} + +// authorData contains the author information. +type authorData struct { + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` +} + +// filesData contains the files changed in the pull request. +type filesData struct { + TotalCount int `json:"totalCount"` + PageInfo pageInfo `json:"pageInfo"` + Nodes []fileNode `json:"nodes"` +} + +// pageInfo contains pagination information. +type pageInfo struct { + EndCursor string `json:"endCursor"` + HasNextPage bool `json:"hasNextPage"` +} + +// fileNode represents a single changed file. +type fileNode struct { + Path string `json:"path"` + Additions int `json:"additions"` + Deletions int `json:"deletions"` + ChangeType string `json:"changeType"` + Patch string `json:"patch"` +} + +// buildGraphQLRequest builds a GraphQL request for fetching pull request data. +func buildGraphQLRequest(org, repo string, prNumber int, cursor string) graphqlRequest { + return graphqlRequest{ + Query: pullRequestQuery, + Variables: map[string]any{ + "owner": org, + "repo": repo, + "prNumber": prNumber, + "filesCursor": cursor, + }, + } +} + +// doGraphQLRequest executes a GraphQL HTTP request. +func doGraphQLRequest( + ctx context.Context, + bodyBytes []byte, + client *Client, +) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, client.GraphQLURL(), bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := client.HTTPClient().Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + + return resp, nil +} + +// parseGraphQLResponse parses and validates a GraphQL response. +func parseGraphQLResponse( + resp *http.Response, +) (*pullRequestData, error) { + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("GitHub API returned status %d: %s", resp.StatusCode, string(body)) + } + + var gqlResp graphqlResponse + if err := json.NewDecoder(resp.Body).Decode(&gqlResp); err != nil { + return nil, fmt.Errorf("failed to decode GraphQL response: %w", err) + } + + if len(gqlResp.Errors) > 0 { + return nil, fmt.Errorf("GraphQL errors: %+v", gqlResp.Errors) + } + + return &gqlResp.Data.Repository.PullRequest, nil +} diff --git a/extension/changeprovider/github/graphql_test.go b/extension/changeprovider/github/graphql_test.go new file mode 100644 index 00000000..605f449d --- /dev/null +++ b/extension/changeprovider/github/graphql_test.go @@ -0,0 +1,145 @@ +package github + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildGraphQLRequest(t *testing.T) { + tests := []struct { + name string + org string + repo string + prNumber int + cursor string + wantVars map[string]any + }{ + { + name: "no cursor", + org: "uber", + repo: "submitqueue", + prNumber: 123, + cursor: "", + wantVars: map[string]any{ + "owner": "uber", + "repo": "submitqueue", + "prNumber": 123, + "filesCursor": "", + }, + }, + { + name: "with cursor", + org: "myorg", + repo: "myrepo", + prNumber: 456, + cursor: "cursor_token_xyz", + wantVars: map[string]any{ + "owner": "myorg", + "repo": "myrepo", + "prNumber": 456, + "filesCursor": "cursor_token_xyz", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := buildGraphQLRequest(tt.org, tt.repo, tt.prNumber, tt.cursor) + assert.Equal(t, pullRequestQuery, req.Query) + assert.Equal(t, tt.wantVars, req.Variables) + }) + } +} + +func TestParseGraphQLResponse(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantErr bool + wantData *pullRequestData + }{ + { + name: "success", + statusCode: http.StatusOK, + body: `{ + "data": { + "repository": { + "pullRequest": { + "number": 42, + "headRefOid": "abc123", + "author": {"login": "octocat", "name": "The Octocat", "email": "octocat@example.com"}, + "files": { + "totalCount": 1, + "pageInfo": {"endCursor": "cur1", "hasNextPage": false}, + "nodes": [{"path": "main.go", "additions": 10, "deletions": 2, "changeType": "MODIFIED", "patch": "diff content"}] + } + } + } + } + }`, + wantData: &pullRequestData{ + Number: 42, + HeadRefOid: "abc123", + Author: authorData{Login: "octocat", Name: "The Octocat", Email: "octocat@example.com"}, + Files: filesData{ + TotalCount: 1, + PageInfo: pageInfo{EndCursor: "cur1", HasNextPage: false}, + Nodes: []fileNode{{Path: "main.go", Additions: 10, Deletions: 2, ChangeType: "MODIFIED", Patch: "diff content"}}, + }, + }, + }, + { + name: "non-200 status", + statusCode: http.StatusInternalServerError, + body: `{"message":"Internal Server Error"}`, + wantErr: true, + }, + { + name: "404 not found", + statusCode: http.StatusNotFound, + body: `{"message":"Not Found"}`, + wantErr: true, + }, + { + name: "invalid JSON", + statusCode: http.StatusOK, + body: `{invalid json`, + wantErr: true, + }, + { + name: "GraphQL errors", + statusCode: http.StatusOK, + body: `{"errors":[{"message":"Field doesn't exist","type":"INVALID_FIELD"}]}`, + wantErr: true, + }, + { + name: "multiple GraphQL errors", + statusCode: http.StatusOK, + body: `{"errors":[{"message":"Not Found","type":"NOT_FOUND"},{"message":"Forbidden","type":"FORBIDDEN"}]}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader(tt.body)), + } + + got, err := parseGraphQLResponse(resp) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantData, got) + }) + } +} diff --git a/extension/changeprovider/github/provider.go b/extension/changeprovider/github/provider.go new file mode 100644 index 00000000..e8024b6d --- /dev/null +++ b/extension/changeprovider/github/provider.go @@ -0,0 +1,163 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/uber-go/tally/v4" + "go.uber.org/zap" + + coremetrics "github.com/uber/submitqueue/core/metrics" + "github.com/uber/submitqueue/entity" + entitygithub "github.com/uber/submitqueue/entity/github" + "github.com/uber/submitqueue/extension/changeprovider" +) + +// Params holds the dependencies for the GitHub ChangeProvider. +type Params struct { + // Client is a pre-configured GitHub API client (encapsulates HTTP client and GraphQL URL). + // Auth is the caller's responsibility via HTTP transport/round-tripper. + Client *Client + // Logger is the structured logger. + Logger *zap.SugaredLogger + // MetricsScope is the metrics scope for instrumentation. + MetricsScope tally.Scope +} + +// provider implements the ChangeProvider interface for GitHub. +type provider struct { + client *Client + logger *zap.SugaredLogger + metricsScope tally.Scope +} + +// NewProvider creates a new GitHub ChangeProvider. +func NewProvider(params Params) changeprovider.ChangeProvider { + return &provider{ + client: params.Client, + logger: params.Logger.Named("github_changeprovider"), + metricsScope: params.MetricsScope.SubScope("github_changeprovider"), + } +} + +// Get retrieves change information from GitHub for the provided Change. +// Returns one ChangeInfo per URI (one per PR in stacked changes). +func (p *provider) Get(ctx context.Context, change entity.Change) (_ []changeprovider.ChangeInfo, retErr error) { + op := coremetrics.Begin(p.metricsScope, "get") + defer func() { op.Complete(retErr) }() + + // Parse all change IDs + changeIDs := make([]entitygithub.ChangeID, 0, len(change.URIs)) + for _, uri := range change.URIs { + parsed, err := entitygithub.ParseChangeID(uri) + if err != nil { + return nil, fmt.Errorf("failed to parse GitHub change ID %q: %w", uri, err) + } + changeIDs = append(changeIDs, parsed) + } + + p.logger.Debugw("fetching PR data from GitHub", + "pr_count", len(changeIDs), + "uris", change.URIs, + ) + + // Validate stacked changes are consistent (same provider, org, and repo) + if err := validateChangeConsistency(changeIDs); err != nil { + return nil, err + } + + // Fetch each PR and build ChangeInfo for each + changeInfos, err := p.fetchAllPRs(ctx, changeIDs) + if err != nil { + return nil, err + } + + p.logger.Debugw("successfully fetched PR data", + "pr_count", len(changeIDs), + ) + + return changeInfos, nil +} + +// fetchAllPRs fetches and validates all PRs in the stack, returning on the first error. +func (p *provider) fetchAllPRs( + ctx context.Context, + changeIDs []entitygithub.ChangeID, +) ([]changeprovider.ChangeInfo, error) { + changeInfos := make([]changeprovider.ChangeInfo, 0, len(changeIDs)) + + for _, cid := range changeIDs { + prData, err := p.fetchPullRequest(ctx, cid) + if err != nil { + coremetrics.NamedCounter(p.metricsScope, "fetch_pr", "errors", 1, + coremetrics.NewTag("org", cid.Org), + coremetrics.NewTag("repo", cid.Repo), + ) + return nil, fmt.Errorf("failed to fetch PR #%d: %w", cid.PRNumber, err) + } + + if err := validatePRStaleness(cid, prData); err != nil { + return nil, err + } + + changeInfo := convertToChangeInfo(cid, prData) + changeInfos = append(changeInfos, changeInfo) + + p.logger.Debugw("fetched PR data", + "org", cid.Org, + "repo", cid.Repo, + "pr", cid.PRNumber, + "files_count", len(changeInfo.ChangedFiles), + "head_sha", prData.HeadRefOid, + ) + } + + return changeInfos, nil +} + +// fetchPullRequest makes GraphQL request(s) to fetch PR data, handling pagination. +func (p *provider) fetchPullRequest(ctx context.Context, parsed entitygithub.ChangeID) (*pullRequestData, error) { + var allFiles []fileNode + var prData pullRequestData + cursor := "" + + for { + data, err := p.fetchPullRequestPage(ctx, parsed.Org, parsed.Repo, parsed.PRNumber, cursor) + if err != nil { + return nil, err + } + + if cursor == "" { + prData = *data + } + + allFiles = append(allFiles, data.Files.Nodes...) + + if !data.Files.PageInfo.HasNextPage { + break + } + cursor = data.Files.PageInfo.EndCursor + } + + prData.Files.Nodes = allFiles + return &prData, nil +} + +// fetchPullRequestPage fetches a single page of PR data. +func (p *provider) fetchPullRequestPage(ctx context.Context, org, repo string, prNumber int, cursor string) (*pullRequestData, error) { + reqBody := buildGraphQLRequest(org, repo, prNumber, cursor) + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) + } + + resp, err := doGraphQLRequest(ctx, bodyBytes, p.client) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return parseGraphQLResponse(resp) +} diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go new file mode 100644 index 00000000..28e9a7dd --- /dev/null +++ b/extension/changeprovider/github/provider_test.go @@ -0,0 +1,202 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally/v4" + "go.uber.org/zap/zaptest" + + "github.com/uber/submitqueue/entity" + "github.com/uber/submitqueue/extension/changeprovider" +) + +func newTestProvider(t *testing.T, serverURL string) changeprovider.ChangeProvider { + t.Helper() + return NewProvider(Params{ + Client: NewClient(&http.Client{}, serverURL), + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) +} + +func servePR(t *testing.T, w http.ResponseWriter, data pullRequestData) { + t.Helper() + var resp graphqlResponse + resp.Data.Repository.PullRequest = data + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(resp)) +} + +func TestProvider_Get(t *testing.T) { + tests := []struct { + name string + handler http.HandlerFunc + uris []string + wantErr bool + }{ + { + name: "returns result for valid PR", + handler: func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, pullRequestData{ + Number: 123, + HeadRefOid: "abc123", + Author: authorData{Name: "Test User", Email: "test@example.com"}, + Files: filesData{ + Nodes: []fileNode{ + {Path: "main.go"}, + {Path: "test.go"}, + }, + }, + }) + }, + uris: []string{"github://uber/submitqueue/123/abc123"}, + }, + { + name: "invalid URI returns error", + uris: []string{"invalid://uri"}, + wantErr: true, + }, + { + name: "inconsistent change set returns error", + uris: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/different-repo/456/def456", + }, + wantErr: true, + }, + { + name: "stale PR returns error", + handler: func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, pullRequestData{ + Number: 123, + HeadRefOid: "newsha", + Files: filesData{Nodes: []fileNode{{Path: "main.go"}}}, + }) + }, + uris: []string{"github://uber/submitqueue/123/oldsha"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serverURL := "http://localhost" + if tt.handler != nil { + server := httptest.NewServer(tt.handler) + defer server.Close() + serverURL = server.URL + } + + p := newTestProvider(t, serverURL) + infos, err := p.Get(context.Background(), entity.Change{URIs: tt.uris}) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Len(t, infos, 1) + assert.Equal(t, tt.uris[0], infos[0].URI) + assert.Len(t, infos[0].ChangedFiles, 2) + }) + } +} + +func TestProvider_Get_Pagination(t *testing.T) { + pages := []pullRequestData{ + { + Number: 456, + HeadRefOid: "xyz789", + Files: filesData{ + PageInfo: pageInfo{EndCursor: "cursor1", HasNextPage: true}, + Nodes: []fileNode{{Path: "file1.go"}}, + }, + }, + { + Number: 456, + HeadRefOid: "xyz789", + Files: filesData{ + PageInfo: pageInfo{HasNextPage: false}, + Nodes: []fileNode{{Path: "file2.go"}}, + }, + }, + } + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, pages[callCount]) + callCount++ + })) + defer server.Close() + + p := newTestProvider(t, server.URL) + infos, err := p.Get(context.Background(), entity.Change{ + URIs: []string{"github://uber/submitqueue/456/xyz789"}, + }) + + require.NoError(t, err) + assert.Equal(t, 2, callCount) + require.Len(t, infos, 1) + assert.Len(t, infos[0].ChangedFiles, 2) +} + +func TestProvider_Get_MultiplePRs(t *testing.T) { + prData := []pullRequestData{ + {Number: 123, HeadRefOid: "abc123", Files: filesData{Nodes: []fileNode{{Path: "file1.go"}}}}, + {Number: 456, HeadRefOid: "def456", Files: filesData{Nodes: []fileNode{{Path: "file2.go"}}}}, + } + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, prData[callCount]) + callCount++ + })) + defer server.Close() + + p := newTestProvider(t, server.URL) + infos, err := p.Get(context.Background(), entity.Change{ + URIs: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/submitqueue/456/def456", + }, + }) + + require.NoError(t, err) + assert.Equal(t, 2, callCount) + require.Len(t, infos, 2) + assert.Equal(t, "github://uber/submitqueue/123/abc123", infos[0].URI) + assert.Equal(t, "github://uber/submitqueue/456/def456", infos[1].URI) +} + +func TestProvider_Get_FetchError_StopsOnFirstFailure(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if callCount == 1 { + callCount++ + w.WriteHeader(http.StatusInternalServerError) + return + } + servePR(t, w, pullRequestData{ + Number: 123, + HeadRefOid: "abc123", + Files: filesData{Nodes: []fileNode{{Path: "file1.go"}}}, + }) + callCount++ + })) + defer server.Close() + + p := newTestProvider(t, server.URL) + _, err := p.Get(context.Background(), entity.Change{ + URIs: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/submitqueue/456/def456", + }, + }) + + require.Error(t, err) + assert.Equal(t, 2, callCount) +} diff --git a/extension/changeprovider/github/validate.go b/extension/changeprovider/github/validate.go new file mode 100644 index 00000000..55bf1fb5 --- /dev/null +++ b/extension/changeprovider/github/validate.go @@ -0,0 +1,46 @@ +package github + +import ( + "fmt" + + entitygithub "github.com/uber/submitqueue/entity/github" +) + +// validateChangeConsistency validates that all changeIDs in the stack are consistent. +// Stacked changes must have the same change provider (scheme), org, and repo. +func validateChangeConsistency( + changeIDs []entitygithub.ChangeID, +) error { + expectedScheme := changeIDs[0].Scheme + expectedOrg := changeIDs[0].Org + expectedRepo := changeIDs[0].Repo + + for _, cid := range changeIDs { + // Validate same change provider (scheme) + if cid.Scheme != expectedScheme { + return fmt.Errorf("stacked changes must use same change provider: expected %s, got %s for PR #%d", + expectedScheme, cid.Scheme, cid.PRNumber) + } + + // Validate same org and repo + if cid.Org != expectedOrg || cid.Repo != expectedRepo { + return fmt.Errorf("stacked changes must be from same org/repository: expected %s/%s, got %s/%s for PR #%d", + expectedOrg, expectedRepo, cid.Org, cid.Repo, cid.PRNumber) + } + } + + return nil +} + +// validatePRStaleness validates that the PR hasn't changed since submission. +// Compares the fetched head SHA with the expected SHA from the change URI. +func validatePRStaleness( + cid entitygithub.ChangeID, + prData *pullRequestData, +) error { + if prData.HeadRefOid != cid.HeadCommitSHA { + return fmt.Errorf("PR #%d head SHA changed: expected %s, got %s", + cid.PRNumber, cid.HeadCommitSHA, prData.HeadRefOid) + } + return nil +} diff --git a/extension/changeprovider/github/validate_test.go b/extension/changeprovider/github/validate_test.go new file mode 100644 index 00000000..93ceb028 --- /dev/null +++ b/extension/changeprovider/github/validate_test.go @@ -0,0 +1,99 @@ +package github + +import ( + "testing" + + "github.com/stretchr/testify/require" + entitygithub "github.com/uber/submitqueue/entity/github" +) + +func TestValidateChangeConsistency(t *testing.T) { + tests := []struct { + name string + changeIDs []entitygithub.ChangeID + wantErr bool + }{ + { + name: "single PR", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + }, + }, + { + name: "consistent stack", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 2}, + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 3}, + }, + }, + { + name: "different repo", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "github", Org: "uber", Repo: "other-repo", PRNumber: 2}, + }, + wantErr: true, + }, + { + name: "different org", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "github", Org: "other-org", Repo: "submitqueue", PRNumber: 2}, + }, + wantErr: true, + }, + { + name: "different scheme", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "ghe", Org: "uber", Repo: "submitqueue", PRNumber: 2}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateChangeConsistency(tt.changeIDs) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidatePRStaleness(t *testing.T) { + tests := []struct { + name string + cid entitygithub.ChangeID + prData pullRequestData + wantErr bool + }{ + { + name: "matching SHA", + cid: entitygithub.ChangeID{PRNumber: 1, HeadCommitSHA: "abc123"}, + prData: pullRequestData{HeadRefOid: "abc123"}, + wantErr: false, + }, + { + name: "mismatched SHA", + cid: entitygithub.ChangeID{PRNumber: 1, HeadCommitSHA: "oldsha"}, + prData: pullRequestData{HeadRefOid: "newsha"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePRStaleness(tt.cid, &tt.prData) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/orchestrator/controller/validate/BUILD.bazel b/orchestrator/controller/validate/BUILD.bazel index 0ea6244a..a7da4c8b 100644 --- a/orchestrator/controller/validate/BUILD.bazel +++ b/orchestrator/controller/validate/BUILD.bazel @@ -8,8 +8,10 @@ go_library( deps = [ "//core/consumer", "//core/errs", + "//core/metrics", "//entity", "//entity/queue", + "//extension/changeprovider", "//extension/mergechecker", "//extension/storage", "@com_github_uber_go_tally_v4//:tally", @@ -26,6 +28,7 @@ go_test( "//core/errs", "//entity", "//entity/queue", + "//extension/changeprovider", "//extension/mergechecker", "//extension/mergechecker/mock", "//extension/queue/mock", diff --git a/orchestrator/controller/validate/validate.go b/orchestrator/controller/validate/validate.go index 2ee8d8b1..d541d16c 100644 --- a/orchestrator/controller/validate/validate.go +++ b/orchestrator/controller/validate/validate.go @@ -21,8 +21,10 @@ import ( "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" "github.com/uber/submitqueue/core/errs" + coremetrics "github.com/uber/submitqueue/core/metrics" "github.com/uber/submitqueue/entity" entityqueue "github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/changeprovider" "github.com/uber/submitqueue/extension/mergechecker" "github.com/uber/submitqueue/extension/storage" "go.uber.org/zap" @@ -33,13 +35,14 @@ import ( // and publishes to the batch stage. Validation logic is extensible to support additional checks. // Implements consumer.Controller interface for integration with the consumer. type Controller struct { - logger *zap.SugaredLogger - metricsScope tally.Scope - store storage.Storage - registry consumer.TopicRegistry - mergeChecker mergechecker.MergeChecker - topicKey consumer.TopicKey - consumerGroup string + logger *zap.SugaredLogger + metricsScope tally.Scope + store storage.Storage + registry consumer.TopicRegistry + mergeChecker mergechecker.MergeChecker + changeProvider changeprovider.ChangeProvider + topicKey consumer.TopicKey + consumerGroup string } // Verify Controller implements consumer.Controller interface at compile time. @@ -52,39 +55,42 @@ func NewController( store storage.Storage, registry consumer.TopicRegistry, mergeChecker mergechecker.MergeChecker, + changeProvider changeprovider.ChangeProvider, topicKey consumer.TopicKey, consumerGroup string, ) *Controller { return &Controller{ - logger: logger.Named("validate_controller"), - metricsScope: scope.SubScope("validate_controller"), - store: store, - registry: registry, - mergeChecker: mergeChecker, - topicKey: topicKey, - consumerGroup: consumerGroup, + logger: logger.Named("validate_controller"), + metricsScope: scope.SubScope("validate_controller"), + store: store, + registry: registry, + mergeChecker: mergeChecker, + changeProvider: changeProvider, + topicKey: topicKey, + consumerGroup: consumerGroup, } } // Process processes a validate delivery from the queue. // Deserializes the request and publishes to the batch topic. // Returns nil to ack (success), or error to nack (retry). -func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) error { - c.metricsScope.Counter("received").Inc(1) +func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (retErr error) { + op := coremetrics.Begin(c.metricsScope, "process") + defer func() { op.Complete(retErr) }() msg := delivery.Message() // Deserialize request ID from payload rid, err := entity.RequestIDFromBytes(msg.Payload) if err != nil { - c.metricsScope.Counter("deserialize_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "deserialize_errors", 1) return fmt.Errorf("failed to deserialize request ID: %w", err) } // Fetch request from storage request, err := c.store.GetRequestStore().Get(ctx, rid.ID) if err != nil { - c.metricsScope.Counter("storage_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "storage_errors", 1) return fmt.Errorf("failed to get request %s: %w", rid.ID, err) } @@ -100,7 +106,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er // Merge conflict check mergeResult, err := c.mergeChecker.Check(ctx, request.Queue, request.Change) if err != nil { - c.metricsScope.Counter("merge_check_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "merge_check_errors", 1) return fmt.Errorf("merge check failed: %w", err) } if !mergeResult.Mergeable { @@ -109,13 +115,31 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "queue", request.Queue, "reason", mergeResult.Reason, ) - c.metricsScope.Counter("not_mergeable").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "not_mergeable", 1) return errs.NewUserError(fmt.Errorf("request %s is not mergeable: %s", request.ID, mergeResult.Reason)) } + // Fetch change metadata + changeInfos, err := c.changeProvider.Get(ctx, request.Change) + if err != nil { + c.logger.Errorw("failed to fetch change information", + "request_id", request.ID, + "change_uris", request.Change.URIs, + "error", err, + ) + coremetrics.NamedCounter(c.metricsScope, "process", "change_provider_errors", 1) + return fmt.Errorf("failed to fetch change information: %w", err) + } + + c.logger.Infow("fetched change information", + "request_id", request.ID, + "change_count", len(changeInfos), + "total_files", totalFiles(changeInfos), + ) + // Publish to batch topic if err := c.publish(ctx, consumer.TopicKeyBatch, request.ID, request.Queue); err != nil { - c.metricsScope.Counter("publish_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "publish_errors", 1) return fmt.Errorf("failed to publish to batch: %w", err) } @@ -124,8 +148,6 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "topic_key", consumer.TopicKeyBatch, ) - c.metricsScope.Counter("processed").Inc(1) - return nil // Success - message will be acked } @@ -156,6 +178,15 @@ func (c *Controller) publish(ctx context.Context, key consumer.TopicKey, request return nil } +// totalFiles returns the total number of files across all changeInfos. +func totalFiles(infos []changeprovider.ChangeInfo) int { + total := 0 + for _, info := range infos { + total += len(info.ChangedFiles) + } + return total +} + // Name returns the controller name for logging and metrics. func (c *Controller) Name() string { return "validate" diff --git a/orchestrator/controller/validate/validate_test.go b/orchestrator/controller/validate/validate_test.go index cba0500d..0d1b2ca7 100644 --- a/orchestrator/controller/validate/validate_test.go +++ b/orchestrator/controller/validate/validate_test.go @@ -26,6 +26,7 @@ import ( "github.com/uber/submitqueue/core/errs" "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/changeprovider" "github.com/uber/submitqueue/extension/mergechecker" mergecheckermock "github.com/uber/submitqueue/extension/mergechecker/mock" queuemock "github.com/uber/submitqueue/extension/queue/mock" @@ -41,6 +42,25 @@ func requestIDPayload(t *testing.T, id string) []byte { return payload } +// mockChangeProvider is a simple mock that returns test data. +type mockChangeProvider struct{} + +func (m *mockChangeProvider) Get(ctx context.Context, change entity.Change) ([]changeprovider.ChangeInfo, error) { + // Return simple test data + return []changeprovider.ChangeInfo{ + { + URI: "github://org/repo/123/abc123", + User: changeprovider.User{ + Name: "Test User", + Email: "test@example.com", + }, + ChangedFiles: []changeprovider.ChangedFile{ + {Path: "main.go"}, + }, + }, + }, nil +} + // newMergeableMock returns a mock MergeChecker that always returns mergeable. func newMergeableMock(ctrl *gomock.Controller) *mergecheckermock.MockMergeChecker { mc := mergecheckermock.NewMockMergeChecker(ctrl) @@ -78,7 +98,9 @@ func newTestController(t *testing.T, ctrl *gomock.Controller, store *storagemock ) require.NoError(t, err) - return NewController(logger, scope, store, registry, mc, consumer.TopicKeyValidate, "orchestrator-validate") + cp := &mockChangeProvider{} + + return NewController(logger, scope, store, registry, mc, cp, consumer.TopicKeyValidate, "orchestrator-validate") } func TestNewController(t *testing.T) { From 308671335589001dc7bb660ae2ff7c256757f201 Mon Sep 17 00:00:00 2001 From: rprithyani Date: Fri, 13 Mar 2026 04:19:43 +0000 Subject: [PATCH 2/4] make http client configurable with base url and bearer transport --- core/httpclient/transport.go | 80 +++++++++ core/httpclient/transport_test.go | 169 ++++++++++++++++++ example/server/orchestrator/main.go | 39 ++-- extension/changeprovider/github/config.go | 47 ----- .../changeprovider/github/config_test.go | 97 ---------- extension/changeprovider/github/graphql.go | 7 +- extension/changeprovider/github/provider.go | 13 +- .../changeprovider/github/provider_test.go | 6 +- 8 files changed, 281 insertions(+), 177 deletions(-) create mode 100644 core/httpclient/transport.go create mode 100644 core/httpclient/transport_test.go delete mode 100644 extension/changeprovider/github/config.go delete mode 100644 extension/changeprovider/github/config_test.go diff --git a/core/httpclient/transport.go b/core/httpclient/transport.go new file mode 100644 index 00000000..664f1218 --- /dev/null +++ b/core/httpclient/transport.go @@ -0,0 +1,80 @@ +package httpclient + +import ( + "net/http" + "net/url" + "strings" + "time" +) + +// BaseURLTransport is an http.RoundTripper that rewrites every request URL +// to resolve against a fixed base URL. This allows callers to make requests +// with relative paths (e.g. "/graphql") and have the transport prepend the +// configured base URL transparently. +type BaseURLTransport struct { + // BaseURL is the API base URL (e.g. "https://api.github.com"). + BaseURL *url.URL + // Next is the underlying RoundTripper. Defaults to http.DefaultTransport if nil. + Next http.RoundTripper +} + +// RoundTrip rewrites req.URL to resolve against BaseURL, then delegates to Next. +// The base URL path and request path are joined explicitly so that base URLs +// with a path component (e.g. "https://ghe.example.com/api") are handled +// correctly regardless of whether the request path starts with "/". +func (t *BaseURLTransport) RoundTrip(req *http.Request) (*http.Response, error) { + newReq := req.Clone(req.Context()) + + merged := *t.BaseURL + merged.Path = strings.TrimRight(t.BaseURL.Path, "/") + "/" + strings.TrimLeft(req.URL.Path, "/") + merged.RawQuery = req.URL.RawQuery + newReq.URL = &merged + + next := t.Next + if next == nil { + next = http.DefaultTransport + } + return next.RoundTrip(newReq) +} + +// BearerTransport is an http.RoundTripper that adds a Bearer token +// Authorization header to every request. +type BearerTransport struct { + // Token is the bearer token to include in requests. + Token string + // Next is the underlying RoundTripper. Defaults to http.DefaultTransport if nil. + Next http.RoundTripper +} + +// RoundTrip adds the Authorization header, then delegates to Next. +func (t *BearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + newReq := req.Clone(req.Context()) + newReq.Header.Set("Authorization", "Bearer "+t.Token) + + next := t.Next + if next == nil { + next = http.DefaultTransport + } + return next.RoundTrip(newReq) +} + +// NewClient builds an *http.Client with BaseURLTransport and optionally +// BearerTransport configured. The transport chain is: +// +// BearerTransport (if token provided) → BaseURLTransport → DefaultTransport +func NewClient(rawBaseURL, token string, timeout time.Duration) (*http.Client, error) { + u, err := url.Parse(rawBaseURL) + if err != nil { + return nil, err + } + + var transport http.RoundTripper = &BaseURLTransport{ + BaseURL: u, + Next: http.DefaultTransport, + } + if token != "" { + transport = &BearerTransport{Token: token, Next: transport} + } + + return &http.Client{Transport: transport, Timeout: timeout}, nil +} diff --git a/core/httpclient/transport_test.go b/core/httpclient/transport_test.go new file mode 100644 index 00000000..3c37a821 --- /dev/null +++ b/core/httpclient/transport_test.go @@ -0,0 +1,169 @@ +package httpclient + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// roundTripFunc is a test helper that implements http.RoundTripper via a function. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestBaseURLTransport_RewritesURL(t *testing.T) { + tests := []struct { + name string + baseURL string + requestPath string + wantURL string + }{ + { + name: "relative path resolved against base", + baseURL: "https://api.github.com", + requestPath: "/graphql", + wantURL: "https://api.github.com/graphql", + }, + { + name: "enterprise base URL", + baseURL: "https://ghe.example.com/api", + requestPath: "/graphql", + wantURL: "https://ghe.example.com/api/graphql", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedURL string + transport := &BaseURLTransport{ + BaseURL: mustParseURL(t, tt.baseURL), + Next: roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedURL = req.URL.String() + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }), + } + + req, err := http.NewRequest(http.MethodGet, tt.requestPath, nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, tt.wantURL, capturedURL) + }) + } +} + +func TestBaseURLTransport_DoesNotMutateOriginalRequest(t *testing.T) { + transport := &BaseURLTransport{ + BaseURL: mustParseURL(t, "https://api.github.com"), + Next: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }), + } + + req, err := http.NewRequest(http.MethodGet, "/graphql", nil) + require.NoError(t, err) + originalURL := req.URL.String() + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, originalURL, req.URL.String()) +} + +func TestBearerTransport_AddsAuthHeader(t *testing.T) { + var capturedHeader string + transport := &BearerTransport{ + Token: "test-token", + Next: roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedHeader = req.Header.Get("Authorization") + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }), + } + + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, "Bearer test-token", capturedHeader) +} + +func TestBearerTransport_DoesNotMutateOriginalRequest(t *testing.T) { + transport := &BearerTransport{ + Token: "test-token", + Next: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }), + } + + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + assert.Empty(t, req.Header.Get("Authorization")) +} + +func TestNewClient_InvalidURL(t *testing.T) { + _, err := NewClient("://invalid", "", 30*time.Second) + require.Error(t, err) +} + +func TestNewClient_SetsTimeout(t *testing.T) { + client, err := NewClient("https://api.github.com", "", 10*time.Second) + require.NoError(t, err) + assert.Equal(t, 10*time.Second, client.Timeout) +} + +func TestNewClient_AuthHeader(t *testing.T) { + tests := []struct { + name string + token string + wantAuthHeader string + }{ + { + name: "no token, no auth header", + token: "", + wantAuthHeader: "", + }, + { + name: "with token, adds bearer auth header", + token: "my-token", + wantAuthHeader: "Bearer my-token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Captured-Auth", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewClient(server.URL, tt.token, 30*time.Second) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, tt.wantAuthHeader, resp.Header.Get("X-Captured-Auth")) + }) + } +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + require.NoError(t, err) + return u +} diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index 2db5391d..d05965ee 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -30,6 +30,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/core/httpclient" "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/extension/changeprovider" githubprovider "github.com/uber/submitqueue/extension/changeprovider/github" @@ -193,7 +194,10 @@ func run() error { mc := newMergeChecker(logger, scope) // Create change provider - cp := newChangeProvider(logger, scope) + cp, err := newChangeProvider(logger, scope) + if err != nil { + return fmt.Errorf("failed to create change provider: %w", err) + } // Register controllers if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cp, cnt, store); err != nil { @@ -540,16 +544,6 @@ func parseTimeout(envVal string, defaultVal time.Duration) time.Duration { return defaultVal } -// buildGitHubHTTPClient creates an http.Client configured for GitHub API calls. -// Configures timeout and optional bearer token authentication. -func buildGitHubHTTPClient(token string, timeout time.Duration) *http.Client { - httpClient := &http.Client{Timeout: timeout} - if token != "" { - httpClient.Transport = &bearerTransport{token: token} - } - return httpClient -} - // newMergeChecker creates a MergeChecker for GitHub (github.com). // Configured via GITHUB_TOKEN and GITHUB_GRAPHQL_URL environment variables. func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeChecker { @@ -577,22 +571,21 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh // newChangeProvider creates a ChangeProvider for GitHub (github.com). // Configured via GITHUB_BASE_URL, GITHUB_TOKEN, and GITHUB_TIMEOUT environment variables. -// Uses pure dependency injection - creates http.Client with auth configured in Transport. -func newChangeProvider(logger *zap.Logger, scope tally.Scope) changeprovider.ChangeProvider { - // 1. Read configuration from environment - baseURL := getEnv("GITHUB_BASE_URL", "https://api.github.com") - token := os.Getenv("GITHUB_TOKEN") - timeout := parseTimeout(os.Getenv("GITHUB_TIMEOUT"), 30*time.Second) - - // 2. Build HTTP client with caller-controlled config (auth + timeout) - httpClient := buildGitHubHTTPClient(token, timeout) +func newChangeProvider(logger *zap.Logger, scope tally.Scope) (changeprovider.ChangeProvider, error) { + client, err := httpclient.NewClient( + getEnv("GITHUB_BASE_URL", "https://api.github.com"), + os.Getenv("GITHUB_TOKEN"), + parseTimeout(os.Getenv("GITHUB_TIMEOUT"), 30*time.Second), + ) + if err != nil { + return nil, fmt.Errorf("failed to build GitHub HTTP client: %w", err) + } - // 3. Inject into provider return githubprovider.NewProvider(githubprovider.Params{ - Client: githubprovider.NewClient(httpClient, baseURL), + HTTPClient: client, Logger: logger.Sugar(), MetricsScope: scope.SubScope("changeprovider"), - }) + }), nil } // bearerTransport is an http.RoundTripper that adds a Bearer token to requests. diff --git a/extension/changeprovider/github/config.go b/extension/changeprovider/github/config.go deleted file mode 100644 index 9c646cd9..00000000 --- a/extension/changeprovider/github/config.go +++ /dev/null @@ -1,47 +0,0 @@ -package github - -import ( - "net/http" -) - -// Client is a GitHub API client that encapsulates connection details and authentication. -// The client is protocol-agnostic - it provides helpers for both GraphQL and REST endpoints. -type Client struct { - httpClient *http.Client - baseURL string -} - -// NewClient creates a new GitHub API client with a pre-configured HTTP client. -// The caller is responsible for configuring authentication in the HTTP client's Transport. -// -// Parameters: -// - httpClient: Configured HTTP client (with auth, timeout, transport, etc.) -// - baseURL: GitHub instance base URL (e.g., "https://api.github.com" or "https://ghe.company.com/api") -func NewClient(httpClient *http.Client, baseURL string) *Client { - return &Client{ - httpClient: httpClient, - baseURL: baseURL, - } -} - -// HTTPClient returns the configured HTTP client. -func (c *Client) HTTPClient() *http.Client { - return c.httpClient -} - -// BaseURL returns the configured GitHub base URL. -func (c *Client) BaseURL() string { - return c.baseURL -} - -// GraphQLURL returns the GitHub GraphQL endpoint URL. -// Constructs the URL by appending "/graphql" to the base URL. -func (c *Client) GraphQLURL() string { - return c.baseURL + "/graphql" -} - -// RESTURL constructs a GitHub REST API endpoint URL. -// The path should start with "/" (e.g., "/repos/uber/submitqueue/pulls/123"). -func (c *Client) RESTURL(path string) string { - return c.baseURL + path -} diff --git a/extension/changeprovider/github/config_test.go b/extension/changeprovider/github/config_test.go deleted file mode 100644 index 34d6e81f..00000000 --- a/extension/changeprovider/github/config_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package github - -import ( - "net/http" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestNewClient(t *testing.T) { - httpClient := &http.Client{Timeout: 30 * time.Second} - baseURL := "https://api.github.com" - - client := NewClient(httpClient, baseURL) - - assert.NotNil(t, client) - assert.Equal(t, httpClient, client.HTTPClient()) - assert.Equal(t, baseURL, client.BaseURL()) -} - -func TestClient_GraphQLURL(t *testing.T) { - tests := []struct { - name string - baseURL string - expected string - }{ - { - name: "standard github", - baseURL: "https://api.github.com", - expected: "https://api.github.com/graphql", - }, - { - name: "github enterprise", - baseURL: "https://ghe.example.com/api", - expected: "https://ghe.example.com/api/graphql", - }, - { - name: "localhost", - baseURL: "http://localhost:8080", - expected: "http://localhost:8080/graphql", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - client := NewClient(&http.Client{}, tt.baseURL) - assert.Equal(t, tt.expected, client.GraphQLURL()) - }) - } -} - -func TestClient_BaseURL(t *testing.T) { - tests := []struct { - name string - baseURL string - }{ - {name: "standard github", baseURL: "https://api.github.com"}, - {name: "github enterprise", baseURL: "https://ghe.example.com/api"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - client := NewClient(&http.Client{}, tt.baseURL) - assert.Equal(t, tt.baseURL, client.BaseURL()) - }) - } -} - -func TestClient_RESTURL(t *testing.T) { - tests := []struct { - name string - baseURL string - path string - expected string - }{ - { - name: "repos endpoint", - baseURL: "https://api.github.com", - path: "/repos/uber/submitqueue/pulls/123", - expected: "https://api.github.com/repos/uber/submitqueue/pulls/123", - }, - { - name: "enterprise repos endpoint", - baseURL: "https://ghe.example.com/api", - path: "/repos/myorg/myrepo/pulls/456", - expected: "https://ghe.example.com/api/repos/myorg/myrepo/pulls/456", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - client := NewClient(&http.Client{}, tt.baseURL) - assert.Equal(t, tt.expected, client.RESTURL(tt.path)) - }) - } -} diff --git a/extension/changeprovider/github/graphql.go b/extension/changeprovider/github/graphql.go index 08f43206..0d527710 100644 --- a/extension/changeprovider/github/graphql.go +++ b/extension/changeprovider/github/graphql.go @@ -115,19 +115,20 @@ func buildGraphQLRequest(org, repo string, prNumber int, cursor string) graphqlR } // doGraphQLRequest executes a GraphQL HTTP request. +// The path "/graphql" is relative — BaseURLTransport on the client resolves it to the full URL. func doGraphQLRequest( ctx context.Context, bodyBytes []byte, - client *Client, + client *http.Client, ) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, client.GraphQLURL(), bytes.NewReader(bodyBytes)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/graphql", bytes.NewReader(bodyBytes)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) } req.Header.Set("Content-Type", "application/json") - resp, err := client.HTTPClient().Do(req) + resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("HTTP request failed: %w", err) } diff --git a/extension/changeprovider/github/provider.go b/extension/changeprovider/github/provider.go index e8024b6d..79b36179 100644 --- a/extension/changeprovider/github/provider.go +++ b/extension/changeprovider/github/provider.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "github.com/uber-go/tally/v4" "go.uber.org/zap" @@ -16,9 +17,9 @@ import ( // Params holds the dependencies for the GitHub ChangeProvider. type Params struct { - // Client is a pre-configured GitHub API client (encapsulates HTTP client and GraphQL URL). - // Auth is the caller's responsibility via HTTP transport/round-tripper. - Client *Client + // HTTPClient is a pre-configured HTTP client. The caller is responsible for + // configuring the base URL (via BaseURLTransport) and auth (via a transport layer). + HTTPClient *http.Client // Logger is the structured logger. Logger *zap.SugaredLogger // MetricsScope is the metrics scope for instrumentation. @@ -27,7 +28,7 @@ type Params struct { // provider implements the ChangeProvider interface for GitHub. type provider struct { - client *Client + httpClient *http.Client logger *zap.SugaredLogger metricsScope tally.Scope } @@ -35,7 +36,7 @@ type provider struct { // NewProvider creates a new GitHub ChangeProvider. func NewProvider(params Params) changeprovider.ChangeProvider { return &provider{ - client: params.Client, + httpClient: params.HTTPClient, logger: params.Logger.Named("github_changeprovider"), metricsScope: params.MetricsScope.SubScope("github_changeprovider"), } @@ -153,7 +154,7 @@ func (p *provider) fetchPullRequestPage(ctx context.Context, org, repo string, p return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) } - resp, err := doGraphQLRequest(ctx, bodyBytes, p.client) + resp, err := doGraphQLRequest(ctx, bodyBytes, p.httpClient) if err != nil { return nil, err } diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go index 28e9a7dd..5ba01671 100644 --- a/extension/changeprovider/github/provider_test.go +++ b/extension/changeprovider/github/provider_test.go @@ -6,20 +6,24 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber-go/tally/v4" "go.uber.org/zap/zaptest" + "github.com/uber/submitqueue/core/httpclient" "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/extension/changeprovider" ) func newTestProvider(t *testing.T, serverURL string) changeprovider.ChangeProvider { t.Helper() + client, err := httpclient.NewClient(serverURL, "", 30*time.Second) + require.NoError(t, err) return NewProvider(Params{ - Client: NewClient(&http.Client{}, serverURL), + HTTPClient: client, Logger: zaptest.NewLogger(t).Sugar(), MetricsScope: tally.NoopScope, }) From c9eac59c51023d8319b5415142f395abbe8af76c Mon Sep 17 00:00:00 2001 From: rprithyani Date: Fri, 13 Mar 2026 04:32:31 +0000 Subject: [PATCH 3/4] remove lines modified as github api doesnt give that and also replaced bearerTransport with oath --- MODULE.bazel | 1 + core/httpclient/BUILD.bazel | 19 ++++++++++++ core/httpclient/transport.go | 30 ++++-------------- core/httpclient/transport_test.go | 34 --------------------- example/server/orchestrator/BUILD.bazel | 1 + extension/changeprovider/change_provider.go | 2 -- extension/changeprovider/github/BUILD.bazel | 3 +- extension/changeprovider/github/convert.go | 14 +++------ go.mod | 1 + go.sum | 2 ++ 10 files changed, 35 insertions(+), 72 deletions(-) create mode 100644 core/httpclient/BUILD.bazel diff --git a/MODULE.bazel b/MODULE.bazel index 28fb6a43..eabace86 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -38,6 +38,7 @@ use_repo( "com_github_uber_go_tally_v4", "org_golang_google_grpc", "org_golang_google_protobuf", + "org_golang_x_oauth2", "org_uber_go_fx", "org_uber_go_mock", "org_uber_go_yarpc", diff --git a/core/httpclient/BUILD.bazel b/core/httpclient/BUILD.bazel new file mode 100644 index 00000000..8656c5c8 --- /dev/null +++ b/core/httpclient/BUILD.bazel @@ -0,0 +1,19 @@ +load("@rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "httpclient", + srcs = ["transport.go"], + importpath = "github.com/uber/submitqueue/core/httpclient", + visibility = ["//visibility:public"], + deps = ["@org_golang_x_oauth2//:oauth2"], +) + +go_test( + name = "httpclient_test", + srcs = ["transport_test.go"], + embed = [":httpclient"], + deps = [ + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) diff --git a/core/httpclient/transport.go b/core/httpclient/transport.go index 664f1218..963b52fb 100644 --- a/core/httpclient/transport.go +++ b/core/httpclient/transport.go @@ -5,6 +5,8 @@ import ( "net/url" "strings" "time" + + "golang.org/x/oauth2" ) // BaseURLTransport is an http.RoundTripper that rewrites every request URL @@ -37,31 +39,10 @@ func (t *BaseURLTransport) RoundTrip(req *http.Request) (*http.Response, error) return next.RoundTrip(newReq) } -// BearerTransport is an http.RoundTripper that adds a Bearer token -// Authorization header to every request. -type BearerTransport struct { - // Token is the bearer token to include in requests. - Token string - // Next is the underlying RoundTripper. Defaults to http.DefaultTransport if nil. - Next http.RoundTripper -} - -// RoundTrip adds the Authorization header, then delegates to Next. -func (t *BearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - newReq := req.Clone(req.Context()) - newReq.Header.Set("Authorization", "Bearer "+t.Token) - - next := t.Next - if next == nil { - next = http.DefaultTransport - } - return next.RoundTrip(newReq) -} - // NewClient builds an *http.Client with BaseURLTransport and optionally -// BearerTransport configured. The transport chain is: +// oauth2 bearer auth configured. The transport chain is: // -// BearerTransport (if token provided) → BaseURLTransport → DefaultTransport +// oauth2.Transport (if token provided) → BaseURLTransport → DefaultTransport func NewClient(rawBaseURL, token string, timeout time.Duration) (*http.Client, error) { u, err := url.Parse(rawBaseURL) if err != nil { @@ -73,7 +54,8 @@ func NewClient(rawBaseURL, token string, timeout time.Duration) (*http.Client, e Next: http.DefaultTransport, } if token != "" { - transport = &BearerTransport{Token: token, Next: transport} + ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) + transport = &oauth2.Transport{Source: ts, Base: transport} } return &http.Client{Transport: transport, Timeout: timeout}, nil diff --git a/core/httpclient/transport_test.go b/core/httpclient/transport_test.go index 3c37a821..fcdedf7c 100644 --- a/core/httpclient/transport_test.go +++ b/core/httpclient/transport_test.go @@ -77,40 +77,6 @@ func TestBaseURLTransport_DoesNotMutateOriginalRequest(t *testing.T) { assert.Equal(t, originalURL, req.URL.String()) } -func TestBearerTransport_AddsAuthHeader(t *testing.T) { - var capturedHeader string - transport := &BearerTransport{ - Token: "test-token", - Next: roundTripFunc(func(req *http.Request) (*http.Response, error) { - capturedHeader = req.Header.Get("Authorization") - return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil - }), - } - - req, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - - _, err = transport.RoundTrip(req) - require.NoError(t, err) - assert.Equal(t, "Bearer test-token", capturedHeader) -} - -func TestBearerTransport_DoesNotMutateOriginalRequest(t *testing.T) { - transport := &BearerTransport{ - Token: "test-token", - Next: roundTripFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil - }), - } - - req, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - - _, err = transport.RoundTrip(req) - require.NoError(t, err) - assert.Empty(t, req.Header.Get("Authorization")) -} - func TestNewClient_InvalidURL(t *testing.T) { _, err := NewClient("://invalid", "", 30*time.Second) require.Error(t, err) diff --git a/example/server/orchestrator/BUILD.bazel b/example/server/orchestrator/BUILD.bazel index 7d2b09fe..0b3bab01 100644 --- a/example/server/orchestrator/BUILD.bazel +++ b/example/server/orchestrator/BUILD.bazel @@ -12,6 +12,7 @@ go_library( visibility = ["//visibility:private"], deps = [ "//core/consumer", + "//core/httpclient", "//entity", "//extension/changeprovider", "//extension/changeprovider/github", diff --git a/extension/changeprovider/change_provider.go b/extension/changeprovider/change_provider.go index 72bdd5af..379e03b3 100644 --- a/extension/changeprovider/change_provider.go +++ b/extension/changeprovider/change_provider.go @@ -38,8 +38,6 @@ type ChangedFile struct { LinesAdded int // LinesDeleted is the number of lines deleted in this file. LinesDeleted int - // LinesModified is the number of lines modified in this file. - LinesModified int } // ChangeInfo contains metadata and file changes for a code change. diff --git a/extension/changeprovider/github/BUILD.bazel b/extension/changeprovider/github/BUILD.bazel index dd8882b4..90209846 100644 --- a/extension/changeprovider/github/BUILD.bazel +++ b/extension/changeprovider/github/BUILD.bazel @@ -3,7 +3,6 @@ load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "github", srcs = [ - "config.go", "convert.go", "graphql.go", "provider.go", @@ -24,13 +23,13 @@ go_library( go_test( name = "github_test", srcs = [ - "config_test.go", "graphql_test.go", "provider_test.go", "validate_test.go", ], embed = [":github"], deps = [ + "//core/httpclient", "//entity", "//entity/github", "//extension/changeprovider", diff --git a/extension/changeprovider/github/convert.go b/extension/changeprovider/github/convert.go index 60394f83..0c7ce869 100644 --- a/extension/changeprovider/github/convert.go +++ b/extension/changeprovider/github/convert.go @@ -24,17 +24,11 @@ func convertFiles(nodes []fileNode) []changeprovider.ChangedFile { changedFiles := make([]changeprovider.ChangedFile, 0, len(nodes)) for _, file := range nodes { - linesModified := 0 - if file.Additions > 0 && file.Deletions > 0 { - linesModified = min(file.Additions, file.Deletions) - } - changedFiles = append(changedFiles, changeprovider.ChangedFile{ - Path: file.Path, - Patch: file.Patch, - LinesAdded: file.Additions, - LinesDeleted: file.Deletions, - LinesModified: linesModified, + Path: file.Path, + Patch: file.Patch, + LinesAdded: file.Additions, + LinesDeleted: file.Deletions, }) } diff --git a/go.mod b/go.mod index ca09d475..94f1a323 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( go.uber.org/mock v0.6.0 go.uber.org/yarpc v1.81.0 go.uber.org/zap v1.27.1 + golang.org/x/oauth2 v0.23.0 google.golang.org/grpc v1.68.1 google.golang.org/protobuf v1.36.3 ) diff --git a/go.sum b/go.sum index fa003055..62eb297f 100644 --- a/go.sum +++ b/go.sum @@ -234,6 +234,8 @@ golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= From 008f7d25bc21088f88005d3fee64228b8a3789df Mon Sep 17 00:00:00 2001 From: rprithyani Date: Fri, 13 Mar 2026 06:05:17 +0000 Subject: [PATCH 4/4] make oauth and timeout setup happen outside the http client --- core/httpclient/BUILD.bazel | 1 - core/httpclient/transport.go | 22 +++------ core/httpclient/transport_test.go | 49 +------------------ example/server/orchestrator/BUILD.bazel | 1 + example/server/orchestrator/main.go | 15 ++++-- .../changeprovider/github/provider_test.go | 3 +- 6 files changed, 19 insertions(+), 72 deletions(-) diff --git a/core/httpclient/BUILD.bazel b/core/httpclient/BUILD.bazel index 8656c5c8..c28998ad 100644 --- a/core/httpclient/BUILD.bazel +++ b/core/httpclient/BUILD.bazel @@ -5,7 +5,6 @@ go_library( srcs = ["transport.go"], importpath = "github.com/uber/submitqueue/core/httpclient", visibility = ["//visibility:public"], - deps = ["@org_golang_x_oauth2//:oauth2"], ) go_test( diff --git a/core/httpclient/transport.go b/core/httpclient/transport.go index 963b52fb..a485142e 100644 --- a/core/httpclient/transport.go +++ b/core/httpclient/transport.go @@ -4,9 +4,6 @@ import ( "net/http" "net/url" "strings" - "time" - - "golang.org/x/oauth2" ) // BaseURLTransport is an http.RoundTripper that rewrites every request URL @@ -39,24 +36,17 @@ func (t *BaseURLTransport) RoundTrip(req *http.Request) (*http.Response, error) return next.RoundTrip(newReq) } -// NewClient builds an *http.Client with BaseURLTransport and optionally -// oauth2 bearer auth configured. The transport chain is: -// -// oauth2.Transport (if token provided) → BaseURLTransport → DefaultTransport -func NewClient(rawBaseURL, token string, timeout time.Duration) (*http.Client, error) { +// NewClient builds an *http.Client with BaseURLTransport configured. +// Callers are responsible for layering additional transports (e.g. auth) and +// setting Timeout on the returned client. +func NewClient(rawBaseURL string) (*http.Client, error) { u, err := url.Parse(rawBaseURL) if err != nil { return nil, err } - var transport http.RoundTripper = &BaseURLTransport{ + return &http.Client{Transport: &BaseURLTransport{ BaseURL: u, Next: http.DefaultTransport, - } - if token != "" { - ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) - transport = &oauth2.Transport{Source: ts, Base: transport} - } - - return &http.Client{Transport: transport, Timeout: timeout}, nil + }}, nil } diff --git a/core/httpclient/transport_test.go b/core/httpclient/transport_test.go index fcdedf7c..7cbb710b 100644 --- a/core/httpclient/transport_test.go +++ b/core/httpclient/transport_test.go @@ -2,10 +2,8 @@ package httpclient import ( "net/http" - "net/http/httptest" "net/url" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -78,55 +76,10 @@ func TestBaseURLTransport_DoesNotMutateOriginalRequest(t *testing.T) { } func TestNewClient_InvalidURL(t *testing.T) { - _, err := NewClient("://invalid", "", 30*time.Second) + _, err := NewClient("://invalid") require.Error(t, err) } -func TestNewClient_SetsTimeout(t *testing.T) { - client, err := NewClient("https://api.github.com", "", 10*time.Second) - require.NoError(t, err) - assert.Equal(t, 10*time.Second, client.Timeout) -} - -func TestNewClient_AuthHeader(t *testing.T) { - tests := []struct { - name string - token string - wantAuthHeader string - }{ - { - name: "no token, no auth header", - token: "", - wantAuthHeader: "", - }, - { - name: "with token, adds bearer auth header", - token: "my-token", - wantAuthHeader: "Bearer my-token", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Captured-Auth", r.Header.Get("Authorization")) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - client, err := NewClient(server.URL, tt.token, 30*time.Second) - require.NoError(t, err) - - req, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - - resp, err := client.Do(req) - require.NoError(t, err) - assert.Equal(t, tt.wantAuthHeader, resp.Header.Get("X-Captured-Auth")) - }) - } -} - func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) diff --git a/example/server/orchestrator/BUILD.bazel b/example/server/orchestrator/BUILD.bazel index 0b3bab01..bbbb3a8d 100644 --- a/example/server/orchestrator/BUILD.bazel +++ b/example/server/orchestrator/BUILD.bazel @@ -41,6 +41,7 @@ go_library( "@com_github_uber_go_tally_v4//:tally", "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//reflection", + "@org_golang_x_oauth2//:oauth2", "@org_uber_go_zap//:zap", ], ) diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index d05965ee..5ed6b3ce 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -28,6 +28,8 @@ import ( "time" _ "github.com/go-sql-driver/mysql" + "golang.org/x/oauth2" + "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" "github.com/uber/submitqueue/core/httpclient" @@ -572,15 +574,18 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh // newChangeProvider creates a ChangeProvider for GitHub (github.com). // Configured via GITHUB_BASE_URL, GITHUB_TOKEN, and GITHUB_TIMEOUT environment variables. func newChangeProvider(logger *zap.Logger, scope tally.Scope) (changeprovider.ChangeProvider, error) { - client, err := httpclient.NewClient( - getEnv("GITHUB_BASE_URL", "https://api.github.com"), - os.Getenv("GITHUB_TOKEN"), - parseTimeout(os.Getenv("GITHUB_TIMEOUT"), 30*time.Second), - ) + client, err := httpclient.NewClient(getEnv("GITHUB_BASE_URL", "https://api.github.com")) if err != nil { return nil, fmt.Errorf("failed to build GitHub HTTP client: %w", err) } + if token := os.Getenv("GITHUB_TOKEN"); token != "" { + ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) + client.Transport = &oauth2.Transport{Source: ts, Base: client.Transport} + } + + client.Timeout = parseTimeout(os.Getenv("GITHUB_TIMEOUT"), 30*time.Second) + return githubprovider.NewProvider(githubprovider.Params{ HTTPClient: client, Logger: logger.Sugar(), diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go index 5ba01671..cb7e380c 100644 --- a/extension/changeprovider/github/provider_test.go +++ b/extension/changeprovider/github/provider_test.go @@ -6,7 +6,6 @@ import ( "net/http" "net/http/httptest" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,7 +19,7 @@ import ( func newTestProvider(t *testing.T, serverURL string) changeprovider.ChangeProvider { t.Helper() - client, err := httpclient.NewClient(serverURL, "", 30*time.Second) + client, err := httpclient.NewClient(serverURL) require.NoError(t, err) return NewProvider(Params{ HTTPClient: client,