From 01b118308ac37928afe12254d9991e3f66c736bb Mon Sep 17 00:00:00 2001 From: rprithyani Date: Fri, 27 Feb 2026 05:16:45 +0000 Subject: [PATCH 1/6] implement change provider - github --- example/server/orchestrator/BUILD.bazel | 2 + example/server/orchestrator/main.go | 32 +- extension/changeprovider/change_provider.go | 10 +- extension/changeprovider/github/BUILD.bazel | 33 + extension/changeprovider/github/convert.go | 42 ++ extension/changeprovider/github/graphql.go | 200 +++++ extension/changeprovider/github/provider.go | 213 ++++++ .../changeprovider/github/provider_test.go | 684 ++++++++++++++++++ extension/changeprovider/github/validate.go | 84 +++ orchestrator/controller/validate/BUILD.bazel | 2 + orchestrator/controller/validate/validate.go | 59 +- .../controller/validate/validate_test.go | 24 +- 12 files changed, 1364 insertions(+), 21 deletions(-) create mode 100644 extension/changeprovider/github/BUILD.bazel create mode 100644 extension/changeprovider/github/convert.go create mode 100644 extension/changeprovider/github/graphql.go create mode 100644 extension/changeprovider/github/provider.go create mode 100644 extension/changeprovider/github/provider_test.go create mode 100644 extension/changeprovider/github/validate.go diff --git a/example/server/orchestrator/BUILD.bazel b/example/server/orchestrator/BUILD.bazel index 4a89170f..65ad5408 100644 --- a/example/server/orchestrator/BUILD.bazel +++ b/example/server/orchestrator/BUILD.bazel @@ -12,6 +12,8 @@ go_library( visibility = ["//visibility:private"], deps = [ "//core/consumer", + "//extension/changeprovider", + "//extension/changeprovider/github", "//extension/counter", "//extension/counter/mysql", "//extension/mergechecker", diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index 25f13364..2af761c2 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -30,6 +30,8 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/extension/changeprovider" + githubprovider "github.com/uber/submitqueue/extension/changeprovider/github" "github.com/uber/submitqueue/extension/counter" mysqlcounter "github.com/uber/submitqueue/extension/counter/mysql" "github.com/uber/submitqueue/extension/mergechecker" @@ -188,8 +190,11 @@ func run() error { // Create merge checker mc := newMergeChecker(logger, scope) + // Create change provider + cp := newChangeProvider(logger, scope) + // Register controllers - if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cnt, store); err != nil { + if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cp, cnt, store); err != nil { return err } @@ -374,7 +379,7 @@ func newTopicRegistry(q extqueue.Queue, subscriberName string) (consumer.TopicRe // │ │ │ // └────────┴────────────────────────┘ -func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker, cnt counter.Counter, store storage.Storage) error { +func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker, cp changeprovider.ChangeProvider, cnt counter.Counter, store storage.Storage) error { requestController := start.NewController( logger, scope, @@ -393,6 +398,7 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t store, registry, mc, + cp, consumer.TopicKeyValidate, "orchestrator-validate", ) @@ -524,6 +530,28 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh }) } +// newChangeProvider creates a ChangeProvider for GitHub (github.com). +// Configured via GITHUB_TOKEN and GITHUB_GRAPHQL_URL environment variables. +// Reuses the same HTTP client configuration as the mergechecker. +func newChangeProvider(logger *zap.Logger, scope tally.Scope) changeprovider.ChangeProvider { + graphQLURL := os.Getenv("GITHUB_GRAPHQL_URL") + if graphQLURL == "" { + graphQLURL = "https://api.github.com/graphql" + } + + httpClient := &http.Client{} + if token := os.Getenv("GITHUB_TOKEN"); token != "" { + httpClient.Transport = &bearerTransport{token: token} + } + + return githubprovider.NewProvider(githubprovider.Params{ + HTTPClient: httpClient, + GraphQLURL: graphQLURL, + Logger: logger.Sugar(), + MetricsScope: scope.SubScope("changeprovider"), + }) +} + // bearerTransport is an http.RoundTripper that adds a Bearer token to requests. type bearerTransport struct { token string diff --git a/extension/changeprovider/change_provider.go b/extension/changeprovider/change_provider.go index 377cbd08..72bdd5af 100644 --- a/extension/changeprovider/change_provider.go +++ b/extension/changeprovider/change_provider.go @@ -44,8 +44,9 @@ type ChangedFile struct { // ChangeInfo contains metadata and file changes for a code change. type ChangeInfo struct { - // ID is the change identifier (e.g., "PR: uber-code/go-code/1" or "diff: uber-code/go-code/D1"). - ID string + // URI is the full change URI for correlation with the input request + // (e.g., "github://uber/repo/98/abc123sha" or "phab://D123/xyz789"). + URI string // User is the author of the change. User User // ChangedFiles is the list of files modified in this change. Order is unspecified. @@ -56,6 +57,7 @@ type ChangeInfo struct { // Each implementation is configured for a specific provider (GitHub, GitLab, Phabricator). type ChangeProvider interface { // Get retrieves change information for the provided Change. - // Returns the change info containing metadata and file changes. - Get(ctx context.Context, change entity.Change) (ChangeInfo, error) + // For a Change with multiple URIs (e.g., stacked PRs), returns one ChangeInfo per URI. + // Returns a slice of ChangeInfo, one for each change in the stack. + Get(ctx context.Context, change entity.Change) ([]ChangeInfo, error) } diff --git a/extension/changeprovider/github/BUILD.bazel b/extension/changeprovider/github/BUILD.bazel new file mode 100644 index 00000000..1b11821c --- /dev/null +++ b/extension/changeprovider/github/BUILD.bazel @@ -0,0 +1,33 @@ +load("@rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "github", + srcs = [ + "convert.go", + "graphql.go", + "provider.go", + "validate.go", + ], + importpath = "github.com/uber/submitqueue/extension/changeprovider/github", + visibility = ["//visibility:public"], + deps = [ + "//entity", + "//entity/github", + "//extension/changeprovider", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "github_test", + srcs = ["provider_test.go"], + embed = [":github"], + deps = [ + "//entity", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//zaptest", + ], +) diff --git a/extension/changeprovider/github/convert.go b/extension/changeprovider/github/convert.go new file mode 100644 index 00000000..60394f83 --- /dev/null +++ b/extension/changeprovider/github/convert.go @@ -0,0 +1,42 @@ +package github + +import ( + entitygithub "github.com/uber/submitqueue/entity/github" + "github.com/uber/submitqueue/extension/changeprovider" +) + +// convertToChangeInfo converts GitHub PR data to ChangeInfo. +func convertToChangeInfo(parsed entitygithub.ChangeID, prData *pullRequestData) changeprovider.ChangeInfo { + changedFiles := convertFiles(prData.Files.Nodes) + + return changeprovider.ChangeInfo{ + URI: parsed.String(), + User: changeprovider.User{ + Name: prData.Author.Name, + Email: prData.Author.Email, + }, + ChangedFiles: changedFiles, + } +} + +// convertFiles converts GitHub file nodes to ChangedFile structs. +func convertFiles(nodes []fileNode) []changeprovider.ChangedFile { + changedFiles := make([]changeprovider.ChangedFile, 0, len(nodes)) + + for _, file := range nodes { + linesModified := 0 + if file.Additions > 0 && file.Deletions > 0 { + linesModified = min(file.Additions, file.Deletions) + } + + changedFiles = append(changedFiles, changeprovider.ChangedFile{ + Path: file.Path, + Patch: file.Patch, + LinesAdded: file.Additions, + LinesDeleted: file.Deletions, + LinesModified: linesModified, + }) + } + + return changedFiles +} diff --git a/extension/changeprovider/github/graphql.go b/extension/changeprovider/github/graphql.go new file mode 100644 index 00000000..058f2d56 --- /dev/null +++ b/extension/changeprovider/github/graphql.go @@ -0,0 +1,200 @@ +package github + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/uber-go/tally/v4" + "go.uber.org/zap" +) + +// pullRequestQuery is the GraphQL query to fetch pull request information including files, author, and head SHA. +const pullRequestQuery = ` +query($owner: String!, $repo: String!, $prNumber: Int!, $filesCursor: String) { + repository(owner: $owner, name: $repo) { + pullRequest(number: $prNumber) { + number + headRefOid + author { + login + ... on User { + name + email + } + } + files(first: 100, after: $filesCursor) { + totalCount + pageInfo { + endCursor + hasNextPage + } + nodes { + path + additions + deletions + changeType + patch + } + } + } + } +} +` + +// graphqlRequest represents a GraphQL request. +type graphqlRequest struct { + Query string `json:"query"` + Variables map[string]any `json:"variables"` +} + +// graphqlResponse represents the top-level GraphQL response. +type graphqlResponse struct { + Data struct { + Repository struct { + PullRequest pullRequestData `json:"pullRequest"` + } `json:"repository"` + } `json:"data"` + Errors []graphqlError `json:"errors,omitempty"` +} + +// graphqlError represents a GraphQL error. +type graphqlError struct { + Message string `json:"message"` + Type string `json:"type"` +} + +// pullRequestData contains the pull request metadata. +type pullRequestData struct { + Number int `json:"number"` + HeadRefOid string `json:"headRefOid"` + Author authorData `json:"author"` + Files filesData `json:"files"` +} + +// authorData contains the author information. +type authorData struct { + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` +} + +// filesData contains the files changed in the pull request. +type filesData struct { + TotalCount int `json:"totalCount"` + PageInfo pageInfo `json:"pageInfo"` + Nodes []fileNode `json:"nodes"` +} + +// pageInfo contains pagination information. +type pageInfo struct { + EndCursor string `json:"endCursor"` + HasNextPage bool `json:"hasNextPage"` +} + +// fileNode represents a single changed file. +type fileNode struct { + Path string `json:"path"` + Additions int `json:"additions"` + Deletions int `json:"deletions"` + ChangeType string `json:"changeType"` + Patch string `json:"patch"` +} + +// buildGraphQLRequest builds a GraphQL request for fetching pull request data. +func buildGraphQLRequest(org, repo string, prNumber int, cursor string) graphqlRequest { + return graphqlRequest{ + Query: pullRequestQuery, + Variables: map[string]any{ + "owner": org, + "repo": repo, + "prNumber": prNumber, + "filesCursor": cursor, + }, + } +} + +// doGraphQLRequest executes a GraphQL HTTP request. +func doGraphQLRequest( + ctx context.Context, + bodyBytes []byte, + graphQLURL string, + httpClient *http.Client, + org, repo string, + metrics tally.Scope, +) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, graphQLURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + metrics.Tagged(map[string]string{ + "org": org, + "repo": repo, + "error_type": "http_error", + }).Counter("get_errors").Inc(1) + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + + return resp, nil +} + +// parseGraphQLResponse parses and validates a GraphQL response. +func parseGraphQLResponse( + resp *http.Response, + org, repo string, + prNumber int, + logger *zap.SugaredLogger, + metrics tally.Scope, +) (*pullRequestData, error) { + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + logger.Errorw("GitHub API error", + "status", resp.StatusCode, + "org", org, + "repo", repo, + "pr", prNumber, + "response", string(body), + ) + metrics.Tagged(map[string]string{ + "org": org, + "repo": repo, + "error_type": "api_error", + }).Counter("get_errors").Inc(1) + return nil, fmt.Errorf("GitHub API returned status %d: %s", resp.StatusCode, string(body)) + } + + var gqlResp graphqlResponse + if err := json.NewDecoder(resp.Body).Decode(&gqlResp); err != nil { + metrics.Tagged(map[string]string{ + "org": org, + "repo": repo, + "error_type": "decode_error", + }).Counter("get_errors").Inc(1) + return nil, fmt.Errorf("failed to decode GraphQL response: %w", err) + } + + if len(gqlResp.Errors) > 0 { + logger.Errorw("GraphQL errors", + "org", org, + "repo", repo, + "pr", prNumber, + "errors", gqlResp.Errors, + ) + metrics.Tagged(map[string]string{ + "org": org, + "repo": repo, + "error_type": "graphql_error", + }).Counter("get_errors").Inc(1) + return nil, fmt.Errorf("GraphQL errors: %+v", gqlResp.Errors) + } + + return &gqlResp.Data.Repository.PullRequest, nil +} diff --git a/extension/changeprovider/github/provider.go b/extension/changeprovider/github/provider.go new file mode 100644 index 00000000..546889ec --- /dev/null +++ b/extension/changeprovider/github/provider.go @@ -0,0 +1,213 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/uber-go/tally/v4" + "go.uber.org/zap" + + "github.com/uber/submitqueue/entity" + entitygithub "github.com/uber/submitqueue/entity/github" + "github.com/uber/submitqueue/extension/changeprovider" +) + +// Params holds the dependencies for the GitHub ChangeProvider. +type Params struct { + // HTTPClient is a pre-configured HTTP client with auth (bearer token, GitHub App JWT, etc.). + // Auth is the caller's responsibility via HTTP transport/round-tripper. + HTTPClient *http.Client + // GraphQLURL is the GitHub GraphQL API endpoint + // (e.g., "https://api.github.com/graphql" or "https://ghe.example.com/api/graphql"). + GraphQLURL string + // Logger is the structured logger. + Logger *zap.SugaredLogger + // MetricsScope is the metrics scope for instrumentation. + MetricsScope tally.Scope +} + +// provider implements the ChangeProvider interface for GitHub. +type provider struct { + httpClient *http.Client + graphQLURL string + logger *zap.SugaredLogger + metrics tally.Scope +} + +// NewProvider creates a new GitHub ChangeProvider. +func NewProvider(params Params) changeprovider.ChangeProvider { + return &provider{ + httpClient: params.HTTPClient, + graphQLURL: params.GraphQLURL, + logger: params.Logger.Named("github_changeprovider"), + metrics: params.MetricsScope.SubScope("github_changeprovider"), + } +} + +// Get retrieves change information from GitHub for the provided Change. +// Returns one ChangeInfo per URI (one per PR in stacked changes). +// TODO add error codes for user errors (non-retryable) vs system errors. +func (p *provider) Get(ctx context.Context, change entity.Change) ([]changeprovider.ChangeInfo, error) { + p.metrics.Counter("get_change_info_started").Inc(1) + startTime := time.Now() + defer func() { + p.metrics.Timer("get_change_info_latency").Record(time.Since(startTime)) + }() + + if len(change.URIs) == 0 { + p.logger.Errorw("no URIs provided in change") + p.metrics.Counter("get_change_info_errors").Inc(1) + return nil, fmt.Errorf("no URIs provided") + } + + // Parse all change IDs + changeIDs := make([]entitygithub.ChangeID, 0, len(change.URIs)) + for _, uri := range change.URIs { + parsed, err := entitygithub.ParseChangeID(uri) + if err != nil { + p.logger.Errorw("failed to parse GitHub change ID", "uri", uri, "error", err) + p.metrics.Counter("get_change_info_errors").Inc(1) + return nil, fmt.Errorf("failed to parse GitHub change ID %q: %w", uri, err) + } + changeIDs = append(changeIDs, parsed) + } + + p.logger.Debugw("fetching PR data from GitHub", + "pr_count", len(changeIDs), + "uris", change.URIs, + ) + + // Validate stacked changes are consistent (same provider, org, and repo) + org, repo, err := validateChangeConsistency(changeIDs, p.logger, p.metrics) + if err != nil { + return nil, err + } + + // Fetch each PR and build ChangeInfo for each + changeInfos, fetchErrors, failedPRs := p.fetchAllPRs(ctx, changeIDs) + + // Return partial results if any PRs failed + if len(fetchErrors) > 0 { + p.logger.Errorw("failed to fetch some PRs", + "total_prs", len(changeIDs), + "failed_count", len(fetchErrors), + "failed_prs", failedPRs, + "succeeded_count", len(changeInfos), + ) + return changeInfos, fmt.Errorf("failed to fetch %d of %d PRs (failed: %v): %v", + len(fetchErrors), len(changeIDs), failedPRs, fetchErrors) + } + + p.logger.Infow("successfully fetched PR data", + "pr_count", len(changeIDs), + ) + + p.metrics.Tagged(map[string]string{ + "org": org, + "repo": repo, + }).Counter("get_success").Inc(1) + + return changeInfos, nil +} + +// fetchAllPRs fetches and validates all PRs in the stack, handling partial failures. +// Returns the successfully fetched ChangeInfos, any errors encountered, and the list of failed PR numbers. +func (p *provider) fetchAllPRs( + ctx context.Context, + changeIDs []entitygithub.ChangeID, +) ([]changeprovider.ChangeInfo, []error, []int) { + changeInfos := make([]changeprovider.ChangeInfo, 0, len(changeIDs)) + var fetchErrors []error + var failedPRs []int + + for _, cid := range changeIDs { + prData, err := p.fetchPullRequest(ctx, cid) + if err != nil { + p.logger.Errorw("failed to fetch PR from GitHub", + "org", cid.Org, + "repo", cid.Repo, + "pr", cid.PRNumber, + "error", err, + ) + p.metrics.Tagged(map[string]string{ + "org": cid.Org, + "repo": cid.Repo, + "error_type": "fetch_pr", + }).Counter("get_errors").Inc(1) + fetchErrors = append(fetchErrors, fmt.Errorf("PR #%d: %w", cid.PRNumber, err)) + failedPRs = append(failedPRs, cid.PRNumber) + continue // Continue to next PR + } + + // Validate PR hasn't changed since submission + if err := validatePRStaleness(cid, prData, p.logger, p.metrics); err != nil { + fetchErrors = append(fetchErrors, err) + failedPRs = append(failedPRs, cid.PRNumber) + continue // Continue to next PR + } + + // Convert to ChangeInfo + changeInfo := convertToChangeInfo(cid, prData) + changeInfos = append(changeInfos, changeInfo) + + p.logger.Debugw("fetched PR data", + "org", cid.Org, + "repo", cid.Repo, + "pr", cid.PRNumber, + "files_count", len(changeInfo.ChangedFiles), + "head_sha", prData.HeadRefOid, + ) + } + + return changeInfos, fetchErrors, failedPRs +} + +// fetchPullRequest makes GraphQL request(s) to fetch PR data, handling pagination. +func (p *provider) fetchPullRequest(ctx context.Context, parsed entitygithub.ChangeID) (*pullRequestData, error) { + var allFiles []fileNode + var prData pullRequestData + cursor := "" + + for { + data, err := p.fetchPullRequestPage(ctx, parsed.Org, parsed.Repo, parsed.PRNumber, cursor) + if err != nil { + return nil, err + } + + if cursor == "" { + prData = *data + } + + allFiles = append(allFiles, data.Files.Nodes...) + + if !data.Files.PageInfo.HasNextPage { + break + } + cursor = data.Files.PageInfo.EndCursor + } + + prData.Files.Nodes = allFiles + return &prData, nil +} + +// fetchPullRequestPage fetches a single page of PR data. +func (p *provider) fetchPullRequestPage(ctx context.Context, org, repo string, prNumber int, cursor string) (*pullRequestData, error) { + reqBody := buildGraphQLRequest(org, repo, prNumber, cursor) + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) + } + + resp, err := doGraphQLRequest(ctx, bodyBytes, p.graphQLURL, p.httpClient, org, repo, p.metrics) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return parseGraphQLResponse(resp, org, repo, prNumber, p.logger, p.metrics) +} + diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go new file mode 100644 index 00000000..ec59db0a --- /dev/null +++ b/extension/changeprovider/github/provider_test.go @@ -0,0 +1,684 @@ +package github + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally/v4" + "go.uber.org/zap/zaptest" + + "github.com/uber/submitqueue/entity" +) + +// mockRoundTripper is a mock implementation of http.RoundTripper for testing. +type mockRoundTripper struct { + roundTripFunc func(*http.Request) (*http.Response, error) +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.roundTripFunc(req) +} + +// newMockClient creates an http.Client with a mock RoundTripper. +func newMockClient(roundTripFunc func(*http.Request) (*http.Response, error)) *http.Client { + return &http.Client{ + Transport: &mockRoundTripper{roundTripFunc: roundTripFunc}, + } +} + +func TestProvider_Get_Success(t *testing.T) { + responseBody := `{ + "data": { + "repository": { + "pullRequest": { + "number": 123, + "headRefOid": "abc123def456", + "author": { + "login": "testuser", + "name": "Test User", + "email": "test@example.com" + }, + "files": { + "totalCount": 2, + "pageInfo": { + "endCursor": "", + "hasNextPage": false + }, + "nodes": [ + { + "path": "main.go", + "additions": 10, + "deletions": 5, + "changeType": "MODIFIED", + "patch": "diff --git a/main.go b/main.go\n..." + }, + { + "path": "test.go", + "additions": 20, + "deletions": 0, + "changeType": "ADDED", + "patch": "diff --git a/test.go b/test.go\n..." + } + ] + } + } + } + } + }` + + mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, http.MethodPost, req.Method) + assert.Equal(t, "https://api.github.test/graphql", req.URL.String()) + assert.Equal(t, "application/json", req.Header.Get("Content-Type")) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(responseBody)), + Header: make(http.Header), + }, nil + }) + + provider := NewProvider(Params{ + HTTPClient: mockClient, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + changeInfo, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{"github://uber/submitqueue/123/abc123def456"}, + }) + + require.NoError(t, err) + require.Len(t, changeInfo, 1, "should return 1 ChangeInfo for 1 PR") + + info := changeInfo[0] + assert.Equal(t, "PR: uber/submitqueue/123", info.ID) + assert.Equal(t, "Test User", info.User.Name) + assert.Equal(t, "test@example.com", info.User.Email) + assert.Len(t, info.ChangedFiles, 2) + + assert.Equal(t, "main.go", info.ChangedFiles[0].Path) + assert.Equal(t, 10, info.ChangedFiles[0].LinesAdded) + assert.Equal(t, 5, info.ChangedFiles[0].LinesDeleted) + assert.Equal(t, 5, info.ChangedFiles[0].LinesModified) + assert.Contains(t, info.ChangedFiles[0].Patch, "diff --git a/main.go") + + assert.Equal(t, "test.go", info.ChangedFiles[1].Path) + assert.Equal(t, 20, info.ChangedFiles[1].LinesAdded) + assert.Equal(t, 0, info.ChangedFiles[1].LinesDeleted) + assert.Equal(t, 0, info.ChangedFiles[1].LinesModified) +} + +func TestProvider_Get_Pagination(t *testing.T) { + callCount := 0 + responses := []string{ + `{ + "data": { + "repository": { + "pullRequest": { + "number": 456, + "headRefOid": "xyz789", + "author": { + "login": "user", + "name": "User", + "email": "user@example.com" + }, + "files": { + "totalCount": 150, + "pageInfo": { + "endCursor": "cursor1", + "hasNextPage": true + }, + "nodes": [ + { + "path": "file1.go", + "additions": 5, + "deletions": 2, + "changeType": "MODIFIED", + "patch": "diff1" + } + ] + } + } + } + } + }`, + `{ + "data": { + "repository": { + "pullRequest": { + "number": 456, + "headRefOid": "xyz789", + "author": { + "login": "user", + "name": "User", + "email": "user@example.com" + }, + "files": { + "totalCount": 150, + "pageInfo": { + "endCursor": "", + "hasNextPage": false + }, + "nodes": [ + { + "path": "file2.go", + "additions": 3, + "deletions": 1, + "changeType": "MODIFIED", + "patch": "diff2" + } + ] + } + } + } + } + }`, + } + + mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + response := responses[callCount] + callCount++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(response)), + Header: make(http.Header), + }, nil + }) + + provider := NewProvider(Params{ + HTTPClient: mockClient, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + changeInfo, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{"github://uber/submitqueue/456/xyz789"}, + }) + + require.NoError(t, err) + assert.Equal(t, 2, callCount, "should make 2 GraphQL requests for pagination") + require.Len(t, changeInfo, 1, "should return 1 ChangeInfo for 1 PR") + + info := changeInfo[0] + assert.Len(t, info.ChangedFiles, 2, "should combine files from both pages") + assert.Equal(t, "file1.go", info.ChangedFiles[0].Path) + assert.Equal(t, "file2.go", info.ChangedFiles[1].Path) +} + +func TestProvider_Get_NoURIs(t *testing.T) { + provider := NewProvider(Params{ + HTTPClient: &http.Client{}, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + _, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{}, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "no URIs provided") +} + +func TestProvider_Get_InvalidURI(t *testing.T) { + provider := NewProvider(Params{ + HTTPClient: &http.Client{}, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + _, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{"invalid://uri"}, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse GitHub change ID") +} + +func TestProvider_Get_HTTPError(t *testing.T) { + mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + return nil, assert.AnError + }) + + provider := NewProvider(Params{ + HTTPClient: mockClient, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + _, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/123/abc"}, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "HTTP request failed") +} + +func TestProvider_Get_APIError404(t *testing.T) { + mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewBufferString(`{"message":"Not Found"}`)), + Header: make(http.Header), + }, nil + }) + + provider := NewProvider(Params{ + HTTPClient: mockClient, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + _, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/999/abc"}, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "GitHub API returned status 404") +} + +func TestProvider_Get_GraphQLError(t *testing.T) { + responseBody := `{ + "errors": [ + { + "message": "Field 'pullRequest' doesn't exist on type 'Repository'", + "type": "INVALID_FIELD" + } + ] + }` + + mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(responseBody)), + Header: make(http.Header), + }, nil + }) + + provider := NewProvider(Params{ + HTTPClient: mockClient, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + _, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/123/abc"}, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "GraphQL errors") +} + +func TestProvider_Get_InvalidJSON(t *testing.T) { + mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{invalid json`)), + Header: make(http.Header), + }, nil + }) + + provider := NewProvider(Params{ + HTTPClient: mockClient, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + _, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/123/abc"}, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode GraphQL response") +} + +func TestNewProvider_DefaultConfig(t *testing.T) { + httpClient := &http.Client{Timeout: 30 * time.Second} + provider := NewProvider(Params{ + HTTPClient: httpClient, + GraphQLURL: "https://api.github.com/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }).(*provider) + + assert.Equal(t, "https://api.github.com/graphql", provider.graphQLURL) + assert.NotNil(t, provider.httpClient) +} + +func TestProvider_Get_MultiplePRs(t *testing.T) { + callCount := 0 + responses := map[int]string{ + 0: `{ + "data": { + "repository": { + "pullRequest": { + "number": 123, + "headRefOid": "abc123", + "author": { + "login": "user1", + "name": "User One", + "email": "user1@example.com" + }, + "files": { + "totalCount": 1, + "pageInfo": { + "endCursor": "", + "hasNextPage": false + }, + "nodes": [ + { + "path": "file1.go", + "additions": 10, + "deletions": 5, + "changeType": "MODIFIED", + "patch": "diff1" + } + ] + } + } + } + } + }`, + 1: `{ + "data": { + "repository": { + "pullRequest": { + "number": 456, + "headRefOid": "def456", + "author": { + "login": "user1", + "name": "User One", + "email": "user1@example.com" + }, + "files": { + "totalCount": 1, + "pageInfo": { + "endCursor": "", + "hasNextPage": false + }, + "nodes": [ + { + "path": "file2.go", + "additions": 20, + "deletions": 2, + "changeType": "ADDED", + "patch": "diff2" + } + ] + } + } + } + } + }`, + } + + mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + response := responses[callCount] + callCount++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(response)), + Header: make(http.Header), + }, nil + }) + + provider := NewProvider(Params{ + HTTPClient: mockClient, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + changeInfo, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/submitqueue/456/def456", + }, + }) + + require.NoError(t, err) + assert.Equal(t, 2, callCount, "should make 2 GraphQL requests for 2 PRs") + require.Len(t, changeInfo, 2, "should return 2 ChangeInfo for 2 PRs") + + // First PR + assert.Equal(t, "PR: uber/submitqueue/123", changeInfo[0].ID) + assert.Equal(t, "User One", changeInfo[0].User.Name) + assert.Equal(t, "user1@example.com", changeInfo[0].User.Email) + assert.Len(t, changeInfo[0].ChangedFiles, 1) + assert.Equal(t, "file1.go", changeInfo[0].ChangedFiles[0].Path) + + // Second PR + assert.Equal(t, "PR: uber/submitqueue/456", changeInfo[1].ID) + assert.Equal(t, "User One", changeInfo[1].User.Name) + assert.Equal(t, "user1@example.com", changeInfo[1].User.Email) + assert.Len(t, changeInfo[1].ChangedFiles, 1) + assert.Equal(t, "file2.go", changeInfo[1].ChangedFiles[0].Path) +} + +func TestProvider_Get_CrossRepoStack(t *testing.T) { + provider := NewProvider(Params{ + HTTPClient: &http.Client{}, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + _, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/different-repo/456/def456", // Different repo! + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "stacked changes must be from same repository") + assert.Contains(t, err.Error(), "expected uber/submitqueue") + assert.Contains(t, err.Error(), "got uber/different-repo") +} + +func TestProvider_Get_MixedProviderStack(t *testing.T) { + provider := NewProvider(Params{ + HTTPClient: &http.Client{}, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + _, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{ + "github://uber/submitqueue/123/abc123", + "ghe://uber/submitqueue/456/def456", // Different provider (GHE instead of github)! + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "stacked changes must use same change provider") + assert.Contains(t, err.Error(), "expected github") + assert.Contains(t, err.Error(), "got ghe") +} + +func TestProvider_Get_StalePR(t *testing.T) { + responseBody := `{ + "data": { + "repository": { + "pullRequest": { + "number": 123, + "headRefOid": "newsha123", + "author": { + "login": "testuser", + "name": "Test User", + "email": "test@example.com" + }, + "files": { + "totalCount": 1, + "pageInfo": { + "endCursor": "", + "hasNextPage": false + }, + "nodes": [ + { + "path": "main.go", + "additions": 10, + "deletions": 5, + "changeType": "MODIFIED", + "patch": "diff --git a/main.go..." + } + ] + } + } + } + } + }` + + mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(responseBody)), + Header: make(http.Header), + }, nil + }) + + provider := NewProvider(Params{ + HTTPClient: mockClient, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + _, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{"github://uber/submitqueue/123/oldsha456"}, // Different SHA! + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "PR #123 head SHA changed") + assert.Contains(t, err.Error(), "expected oldsha456") + assert.Contains(t, err.Error(), "got newsha123") +} + +func TestProvider_Get_PartialSuccess(t *testing.T) { + callCount := 0 + responses := map[int]string{ + 0: `{ + "data": { + "repository": { + "pullRequest": { + "number": 123, + "headRefOid": "abc123", + "author": { + "login": "user1", + "name": "User One", + "email": "user1@example.com" + }, + "files": { + "totalCount": 1, + "pageInfo": { + "endCursor": "", + "hasNextPage": false + }, + "nodes": [ + { + "path": "file1.go", + "additions": 10, + "deletions": 5, + "changeType": "MODIFIED", + "patch": "diff1" + } + ] + } + } + } + } + }`, + // Second PR will get an error response (404) + 2: `{ + "data": { + "repository": { + "pullRequest": { + "number": 789, + "headRefOid": "ghi789", + "author": { + "login": "user1", + "name": "User One", + "email": "user1@example.com" + }, + "files": { + "totalCount": 1, + "pageInfo": { + "endCursor": "", + "hasNextPage": false + }, + "nodes": [ + { + "path": "file3.go", + "additions": 15, + "deletions": 3, + "changeType": "MODIFIED", + "patch": "diff3" + } + ] + } + } + } + } + }`, + } + + mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + if callCount == 1 { + // Fail on second PR + callCount++ + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewBufferString(`{"message":"Not Found"}`)), + Header: make(http.Header), + }, nil + } + response := responses[callCount] + callCount++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(response)), + Header: make(http.Header), + }, nil + }) + + provider := NewProvider(Params{ + HTTPClient: mockClient, + GraphQLURL: "https://api.github.test/graphql", + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) + + changeInfo, err := provider.Get(context.Background(), entity.Change{ + URIs: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/submitqueue/456/def456", // This will fail + "github://uber/submitqueue/789/ghi789", + }, + }) + + // Should return partial results with error + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to fetch 1 of 3 PRs") + assert.Contains(t, err.Error(), "failed: [456]") + + // Should have 2 successful PRs + require.Len(t, changeInfo, 2, "should return 2 successful ChangeInfo despite 1 failure") + assert.Equal(t, "PR: uber/submitqueue/123", changeInfo[0].ID) + assert.Equal(t, "PR: uber/submitqueue/789", changeInfo[1].ID) +} diff --git a/extension/changeprovider/github/validate.go b/extension/changeprovider/github/validate.go new file mode 100644 index 00000000..d986b8e0 --- /dev/null +++ b/extension/changeprovider/github/validate.go @@ -0,0 +1,84 @@ +package github + +import ( + "fmt" + + "github.com/uber-go/tally/v4" + "go.uber.org/zap" + + entitygithub "github.com/uber/submitqueue/entity/github" +) + +// validateChangeConsistency validates that all changeIDs in the stack are consistent. +// Stacked changes must have the same change provider (scheme), org, and repo. +// Returns the org and repo if valid, or an error if any change is inconsistent. +func validateChangeConsistency( + changeIDs []entitygithub.ChangeID, + logger *zap.SugaredLogger, + metrics tally.Scope, +) (string, string, error) { + if len(changeIDs) == 0 { + return "", "", nil + } + + expectedScheme := changeIDs[0].Scheme + expectedOrg := changeIDs[0].Org + expectedRepo := changeIDs[0].Repo + + for _, cid := range changeIDs { + // Validate same change provider (scheme) + if cid.Scheme != expectedScheme { + logger.Errorw("stacked changes must use same change provider", + "expected_provider", expectedScheme, + "got_provider", cid.Scheme, + "pr", cid.PRNumber, + ) + metrics.Tagged(map[string]string{"error_type": "mixed_provider_stack"}).Counter("get_errors").Inc(1) + return "", "", fmt.Errorf("stacked changes must use same change provider: expected %s, got %s for PR #%d", + expectedScheme, cid.Scheme, cid.PRNumber) + } + + // Validate same org and repo + if cid.Org != expectedOrg || cid.Repo != expectedRepo { + logger.Errorw("stacked changes must be from same repository", + "expected_org", expectedOrg, + "expected_repo", expectedRepo, + "got_org", cid.Org, + "got_repo", cid.Repo, + "pr", cid.PRNumber, + ) + metrics.Tagged(map[string]string{"error_type": "cross_repo_stack"}).Counter("get_errors").Inc(1) + return "", "", fmt.Errorf("stacked changes must be from same repository: expected %s/%s, got %s/%s for PR #%d", + expectedOrg, expectedRepo, cid.Org, cid.Repo, cid.PRNumber) + } + } + + return expectedOrg, expectedRepo, nil +} + +// validatePRStaleness validates that the PR hasn't changed since submission. +// Compares the fetched head SHA with the expected SHA from the change URI. +func validatePRStaleness( + cid entitygithub.ChangeID, + prData *pullRequestData, + logger *zap.SugaredLogger, + metrics tally.Scope, +) error { + if prData.HeadRefOid != cid.HeadCommitSHA { + logger.Errorw("PR head SHA changed since submission", + "org", cid.Org, + "repo", cid.Repo, + "pr", cid.PRNumber, + "expected_sha", cid.HeadCommitSHA, + "current_sha", prData.HeadRefOid, + ) + metrics.Tagged(map[string]string{ + "org": cid.Org, + "repo": cid.Repo, + "error_type": "stale_pr", + }).Counter("get_errors").Inc(1) + return fmt.Errorf("PR #%d head SHA changed: expected %s, got %s", + cid.PRNumber, cid.HeadCommitSHA, prData.HeadRefOid) + } + return nil +} diff --git a/orchestrator/controller/validate/BUILD.bazel b/orchestrator/controller/validate/BUILD.bazel index 0ea6244a..a4a2f264 100644 --- a/orchestrator/controller/validate/BUILD.bazel +++ b/orchestrator/controller/validate/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "//core/errs", "//entity", "//entity/queue", + "//extension/changeprovider", "//extension/mergechecker", "//extension/storage", "@com_github_uber_go_tally_v4//:tally", @@ -26,6 +27,7 @@ go_test( "//core/errs", "//entity", "//entity/queue", + "//extension/changeprovider", "//extension/mergechecker", "//extension/mergechecker/mock", "//extension/queue/mock", diff --git a/orchestrator/controller/validate/validate.go b/orchestrator/controller/validate/validate.go index 2ee8d8b1..771a3c80 100644 --- a/orchestrator/controller/validate/validate.go +++ b/orchestrator/controller/validate/validate.go @@ -23,6 +23,7 @@ import ( "github.com/uber/submitqueue/core/errs" "github.com/uber/submitqueue/entity" entityqueue "github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/changeprovider" "github.com/uber/submitqueue/extension/mergechecker" "github.com/uber/submitqueue/extension/storage" "go.uber.org/zap" @@ -33,13 +34,14 @@ import ( // and publishes to the batch stage. Validation logic is extensible to support additional checks. // Implements consumer.Controller interface for integration with the consumer. type Controller struct { - logger *zap.SugaredLogger - metricsScope tally.Scope - store storage.Storage - registry consumer.TopicRegistry - mergeChecker mergechecker.MergeChecker - topicKey consumer.TopicKey - consumerGroup string + logger *zap.SugaredLogger + metricsScope tally.Scope + store storage.Storage + registry consumer.TopicRegistry + mergeChecker mergechecker.MergeChecker + changeProvider changeprovider.ChangeProvider + topicKey consumer.TopicKey + consumerGroup string } // Verify Controller implements consumer.Controller interface at compile time. @@ -52,17 +54,19 @@ func NewController( store storage.Storage, registry consumer.TopicRegistry, mergeChecker mergechecker.MergeChecker, + changeProvider changeprovider.ChangeProvider, topicKey consumer.TopicKey, consumerGroup string, ) *Controller { return &Controller{ - logger: logger.Named("validate_controller"), - metricsScope: scope.SubScope("validate_controller"), - store: store, - registry: registry, - mergeChecker: mergeChecker, - topicKey: topicKey, - consumerGroup: consumerGroup, + logger: logger.Named("validate_controller"), + metricsScope: scope.SubScope("validate_controller"), + store: store, + registry: registry, + mergeChecker: mergeChecker, + changeProvider: changeProvider, + topicKey: topicKey, + consumerGroup: consumerGroup, } } @@ -113,6 +117,24 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er return errs.NewUserError(fmt.Errorf("request %s is not mergeable: %s", request.ID, mergeResult.Reason)) } + // Fetch change metadata + changeInfos, err := c.changeProvider.Get(ctx, request.Change) + if err != nil { + c.logger.Errorw("failed to fetch change information", + "request_id", request.ID, + "change_uris", request.Change.URIs, + "error", err, + ) + c.metricsScope.Counter("change_provider_errors").Inc(1) + return fmt.Errorf("failed to fetch change information: %w", err) + } + + c.logger.Infow("fetched change information", + "request_id", request.ID, + "change_count", len(changeInfos), + "total_files", totalFiles(changeInfos), + ) + // Publish to batch topic if err := c.publish(ctx, consumer.TopicKeyBatch, request.ID, request.Queue); err != nil { c.metricsScope.Counter("publish_errors").Inc(1) @@ -156,6 +178,15 @@ func (c *Controller) publish(ctx context.Context, key consumer.TopicKey, request return nil } +// totalFiles returns the total number of files across all changeInfos. +func totalFiles(infos []changeprovider.ChangeInfo) int { + total := 0 + for _, info := range infos { + total += len(info.ChangedFiles) + } + return total +} + // Name returns the controller name for logging and metrics. func (c *Controller) Name() string { return "validate" diff --git a/orchestrator/controller/validate/validate_test.go b/orchestrator/controller/validate/validate_test.go index cba0500d..0d1b2ca7 100644 --- a/orchestrator/controller/validate/validate_test.go +++ b/orchestrator/controller/validate/validate_test.go @@ -26,6 +26,7 @@ import ( "github.com/uber/submitqueue/core/errs" "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/changeprovider" "github.com/uber/submitqueue/extension/mergechecker" mergecheckermock "github.com/uber/submitqueue/extension/mergechecker/mock" queuemock "github.com/uber/submitqueue/extension/queue/mock" @@ -41,6 +42,25 @@ func requestIDPayload(t *testing.T, id string) []byte { return payload } +// mockChangeProvider is a simple mock that returns test data. +type mockChangeProvider struct{} + +func (m *mockChangeProvider) Get(ctx context.Context, change entity.Change) ([]changeprovider.ChangeInfo, error) { + // Return simple test data + return []changeprovider.ChangeInfo{ + { + URI: "github://org/repo/123/abc123", + User: changeprovider.User{ + Name: "Test User", + Email: "test@example.com", + }, + ChangedFiles: []changeprovider.ChangedFile{ + {Path: "main.go"}, + }, + }, + }, nil +} + // newMergeableMock returns a mock MergeChecker that always returns mergeable. func newMergeableMock(ctrl *gomock.Controller) *mergecheckermock.MockMergeChecker { mc := mergecheckermock.NewMockMergeChecker(ctrl) @@ -78,7 +98,9 @@ func newTestController(t *testing.T, ctrl *gomock.Controller, store *storagemock ) require.NoError(t, err) - return NewController(logger, scope, store, registry, mc, consumer.TopicKeyValidate, "orchestrator-validate") + cp := &mockChangeProvider{} + + return NewController(logger, scope, store, registry, mc, cp, consumer.TopicKeyValidate, "orchestrator-validate") } func TestNewController(t *testing.T) { From ca56f243c7fd584f6f75a638115e6bb0f24ce51c Mon Sep 17 00:00:00 2001 From: rprithyani Date: Tue, 3 Mar 2026 07:35:08 +0000 Subject: [PATCH 2/6] Refactor and address code review suggestions - Rename ID to URI with full scheme://org/repo/pr/sha format - Remove defensive checks and excessive logging - Simplify validation utilities (pure functions) Co-Authored-By: Claude Opus 4.6 --- extension/changeprovider/github/provider.go | 13 ++----- .../changeprovider/github/provider_test.go | 10 ++--- extension/changeprovider/github/validate.go | 37 ------------------- 3 files changed, 8 insertions(+), 52 deletions(-) diff --git a/extension/changeprovider/github/provider.go b/extension/changeprovider/github/provider.go index 546889ec..e6580354 100644 --- a/extension/changeprovider/github/provider.go +++ b/extension/changeprovider/github/provider.go @@ -57,18 +57,11 @@ func (p *provider) Get(ctx context.Context, change entity.Change) ([]changeprovi p.metrics.Timer("get_change_info_latency").Record(time.Since(startTime)) }() - if len(change.URIs) == 0 { - p.logger.Errorw("no URIs provided in change") - p.metrics.Counter("get_change_info_errors").Inc(1) - return nil, fmt.Errorf("no URIs provided") - } - // Parse all change IDs changeIDs := make([]entitygithub.ChangeID, 0, len(change.URIs)) for _, uri := range change.URIs { parsed, err := entitygithub.ParseChangeID(uri) if err != nil { - p.logger.Errorw("failed to parse GitHub change ID", "uri", uri, "error", err) p.metrics.Counter("get_change_info_errors").Inc(1) return nil, fmt.Errorf("failed to parse GitHub change ID %q: %w", uri, err) } @@ -81,7 +74,7 @@ func (p *provider) Get(ctx context.Context, change entity.Change) ([]changeprovi ) // Validate stacked changes are consistent (same provider, org, and repo) - org, repo, err := validateChangeConsistency(changeIDs, p.logger, p.metrics) + org, repo, err := validateChangeConsistency(changeIDs) if err != nil { return nil, err } @@ -101,7 +94,7 @@ func (p *provider) Get(ctx context.Context, change entity.Change) ([]changeprovi len(fetchErrors), len(changeIDs), failedPRs, fetchErrors) } - p.logger.Infow("successfully fetched PR data", + p.logger.Debugw("successfully fetched PR data", "pr_count", len(changeIDs), ) @@ -143,7 +136,7 @@ func (p *provider) fetchAllPRs( } // Validate PR hasn't changed since submission - if err := validatePRStaleness(cid, prData, p.logger, p.metrics); err != nil { + if err := validatePRStaleness(cid, prData); err != nil { fetchErrors = append(fetchErrors, err) failedPRs = append(failedPRs, cid.PRNumber) continue // Continue to next PR diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go index ec59db0a..8292e6ce 100644 --- a/extension/changeprovider/github/provider_test.go +++ b/extension/changeprovider/github/provider_test.go @@ -99,7 +99,7 @@ func TestProvider_Get_Success(t *testing.T) { require.Len(t, changeInfo, 1, "should return 1 ChangeInfo for 1 PR") info := changeInfo[0] - assert.Equal(t, "PR: uber/submitqueue/123", info.ID) + assert.Equal(t, "github://uber/submitqueue/123/abc123def456", info.URI) assert.Equal(t, "Test User", info.User.Name) assert.Equal(t, "test@example.com", info.User.Email) assert.Len(t, info.ChangedFiles, 2) @@ -456,14 +456,14 @@ func TestProvider_Get_MultiplePRs(t *testing.T) { require.Len(t, changeInfo, 2, "should return 2 ChangeInfo for 2 PRs") // First PR - assert.Equal(t, "PR: uber/submitqueue/123", changeInfo[0].ID) + assert.Equal(t, "github://uber/submitqueue/123/abc123", changeInfo[0].URI) assert.Equal(t, "User One", changeInfo[0].User.Name) assert.Equal(t, "user1@example.com", changeInfo[0].User.Email) assert.Len(t, changeInfo[0].ChangedFiles, 1) assert.Equal(t, "file1.go", changeInfo[0].ChangedFiles[0].Path) // Second PR - assert.Equal(t, "PR: uber/submitqueue/456", changeInfo[1].ID) + assert.Equal(t, "github://uber/submitqueue/456/def456", changeInfo[1].URI) assert.Equal(t, "User One", changeInfo[1].User.Name) assert.Equal(t, "user1@example.com", changeInfo[1].User.Email) assert.Len(t, changeInfo[1].ChangedFiles, 1) @@ -679,6 +679,6 @@ func TestProvider_Get_PartialSuccess(t *testing.T) { // Should have 2 successful PRs require.Len(t, changeInfo, 2, "should return 2 successful ChangeInfo despite 1 failure") - assert.Equal(t, "PR: uber/submitqueue/123", changeInfo[0].ID) - assert.Equal(t, "PR: uber/submitqueue/789", changeInfo[1].ID) + assert.Equal(t, "github://uber/submitqueue/123/abc123", changeInfo[0].URI) + assert.Equal(t, "github://uber/submitqueue/789/ghi789", changeInfo[1].URI) } diff --git a/extension/changeprovider/github/validate.go b/extension/changeprovider/github/validate.go index d986b8e0..b885dd7d 100644 --- a/extension/changeprovider/github/validate.go +++ b/extension/changeprovider/github/validate.go @@ -3,9 +3,6 @@ package github import ( "fmt" - "github.com/uber-go/tally/v4" - "go.uber.org/zap" - entitygithub "github.com/uber/submitqueue/entity/github" ) @@ -14,13 +11,7 @@ import ( // Returns the org and repo if valid, or an error if any change is inconsistent. func validateChangeConsistency( changeIDs []entitygithub.ChangeID, - logger *zap.SugaredLogger, - metrics tally.Scope, ) (string, string, error) { - if len(changeIDs) == 0 { - return "", "", nil - } - expectedScheme := changeIDs[0].Scheme expectedOrg := changeIDs[0].Org expectedRepo := changeIDs[0].Repo @@ -28,26 +19,12 @@ func validateChangeConsistency( for _, cid := range changeIDs { // Validate same change provider (scheme) if cid.Scheme != expectedScheme { - logger.Errorw("stacked changes must use same change provider", - "expected_provider", expectedScheme, - "got_provider", cid.Scheme, - "pr", cid.PRNumber, - ) - metrics.Tagged(map[string]string{"error_type": "mixed_provider_stack"}).Counter("get_errors").Inc(1) return "", "", fmt.Errorf("stacked changes must use same change provider: expected %s, got %s for PR #%d", expectedScheme, cid.Scheme, cid.PRNumber) } // Validate same org and repo if cid.Org != expectedOrg || cid.Repo != expectedRepo { - logger.Errorw("stacked changes must be from same repository", - "expected_org", expectedOrg, - "expected_repo", expectedRepo, - "got_org", cid.Org, - "got_repo", cid.Repo, - "pr", cid.PRNumber, - ) - metrics.Tagged(map[string]string{"error_type": "cross_repo_stack"}).Counter("get_errors").Inc(1) return "", "", fmt.Errorf("stacked changes must be from same repository: expected %s/%s, got %s/%s for PR #%d", expectedOrg, expectedRepo, cid.Org, cid.Repo, cid.PRNumber) } @@ -61,22 +38,8 @@ func validateChangeConsistency( func validatePRStaleness( cid entitygithub.ChangeID, prData *pullRequestData, - logger *zap.SugaredLogger, - metrics tally.Scope, ) error { if prData.HeadRefOid != cid.HeadCommitSHA { - logger.Errorw("PR head SHA changed since submission", - "org", cid.Org, - "repo", cid.Repo, - "pr", cid.PRNumber, - "expected_sha", cid.HeadCommitSHA, - "current_sha", prData.HeadRefOid, - ) - metrics.Tagged(map[string]string{ - "org": cid.Org, - "repo": cid.Repo, - "error_type": "stale_pr", - }).Counter("get_errors").Inc(1) return fmt.Errorf("PR #%d head SHA changed: expected %s, got %s", cid.PRNumber, cid.HeadCommitSHA, prData.HeadRefOid) } From 59e5aeceb5fd87c514382cb0cef73aeeccf063fe Mon Sep 17 00:00:00 2001 From: rprithyani Date: Tue, 3 Mar 2026 08:29:26 +0000 Subject: [PATCH 3/6] Refactor GitHub provider to use Config pattern Replace Params with Config to encapsulate BaseURL + Token + Timeout together, addressing review feedback about HTTPClient encapsulation. Changes: - Add Config type with BaseURL, Token, Timeout, and optional HTTPClient - Derive GraphQL URL from BaseURL (no separate parameter needed) - Add configurable timeout (defaults to 30s) - Simplify main.go to use DefaultConfig() - Update all tests to use Config instead of Params Benefits: - Impossible to misconfigure (URL and auth bundled) - Multi-instance support (github.com + GHE) - Flexible timeout per instance - Cleaner API surface Co-Authored-By: Claude Opus 4.6 --- example/server/orchestrator/main.go | 21 +-- extension/changeprovider/github/BUILD.bazel | 1 + extension/changeprovider/github/config.go | 59 ++++++++ extension/changeprovider/github/provider.go | 79 +++++++--- .../changeprovider/github/provider_test.go | 143 ++++++------------ 5 files changed, 171 insertions(+), 132 deletions(-) create mode 100644 extension/changeprovider/github/config.go diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index 2af761c2..4e273ee2 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -531,25 +531,10 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh } // newChangeProvider creates a ChangeProvider for GitHub (github.com). -// Configured via GITHUB_TOKEN and GITHUB_GRAPHQL_URL environment variables. -// Reuses the same HTTP client configuration as the mergechecker. +// Configured via GITHUB_BASE_URL and GITHUB_TOKEN environment variables. func newChangeProvider(logger *zap.Logger, scope tally.Scope) changeprovider.ChangeProvider { - graphQLURL := os.Getenv("GITHUB_GRAPHQL_URL") - if graphQLURL == "" { - graphQLURL = "https://api.github.com/graphql" - } - - httpClient := &http.Client{} - if token := os.Getenv("GITHUB_TOKEN"); token != "" { - httpClient.Transport = &bearerTransport{token: token} - } - - return githubprovider.NewProvider(githubprovider.Params{ - HTTPClient: httpClient, - GraphQLURL: graphQLURL, - Logger: logger.Sugar(), - MetricsScope: scope.SubScope("changeprovider"), - }) + config := githubprovider.DefaultConfig() + return githubprovider.NewProvider(config, logger.Sugar(), scope.SubScope("changeprovider")) } // bearerTransport is an http.RoundTripper that adds a Bearer token to requests. diff --git a/extension/changeprovider/github/BUILD.bazel b/extension/changeprovider/github/BUILD.bazel index 1b11821c..c67ce397 100644 --- a/extension/changeprovider/github/BUILD.bazel +++ b/extension/changeprovider/github/BUILD.bazel @@ -3,6 +3,7 @@ load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "github", srcs = [ + "config.go", "convert.go", "graphql.go", "provider.go", diff --git a/extension/changeprovider/github/config.go b/extension/changeprovider/github/config.go new file mode 100644 index 00000000..5a082474 --- /dev/null +++ b/extension/changeprovider/github/config.go @@ -0,0 +1,59 @@ +package github + +import ( + "fmt" + "net/http" + "os" + "time" +) + +const ( + // DefaultTimeout is the default HTTP request timeout for GitHub API calls. + // This applies to the entire request/response cycle. + DefaultTimeout = 30 * time.Second +) + +// Config holds configuration for connecting to a GitHub backend. +type Config struct { + // BaseURL is the GitHub instance base URL (without /graphql suffix). + // Examples: "https://api.github.com", "https://ghe.company.com" + BaseURL string + + // Token for authenticating to this GitHub instance. + // Can be empty for unauthenticated requests. + Token string + + // Timeout for HTTP requests to this GitHub instance. + // If zero or negative, defaults to DefaultTimeout (30s). + // Set this higher for slow GHE instances or flaky networks. + Timeout time.Duration + + // HTTPClient provides complete control over the HTTP client. + // If set, BaseURL is still used but Token and Timeout are ignored. + // Use this for custom transports, connection pooling, or testing. + HTTPClient *http.Client +} + +// Validate checks if the config is valid. +func (c Config) Validate() error { + if c.BaseURL == "" { + return fmt.Errorf("BaseURL is required") + } + return nil +} + +// DefaultConfig returns a Config for github.com from environment. +func DefaultConfig() Config { + return Config{ + BaseURL: getEnvOrDefault("GITHUB_BASE_URL", "https://api.github.com"), + Token: os.Getenv("GITHUB_TOKEN"), + Timeout: 0, // Will use DefaultTimeout + } +} + +func getEnvOrDefault(key, defaultVal string) string { + if val := os.Getenv(key); val != "" { + return val + } + return defaultVal +} diff --git a/extension/changeprovider/github/provider.go b/extension/changeprovider/github/provider.go index e6580354..3ba07c4c 100644 --- a/extension/changeprovider/github/provider.go +++ b/extension/changeprovider/github/provider.go @@ -15,20 +15,6 @@ import ( "github.com/uber/submitqueue/extension/changeprovider" ) -// Params holds the dependencies for the GitHub ChangeProvider. -type Params struct { - // HTTPClient is a pre-configured HTTP client with auth (bearer token, GitHub App JWT, etc.). - // Auth is the caller's responsibility via HTTP transport/round-tripper. - HTTPClient *http.Client - // GraphQLURL is the GitHub GraphQL API endpoint - // (e.g., "https://api.github.com/graphql" or "https://ghe.example.com/api/graphql"). - GraphQLURL string - // Logger is the structured logger. - Logger *zap.SugaredLogger - // MetricsScope is the metrics scope for instrumentation. - MetricsScope tally.Scope -} - // provider implements the ChangeProvider interface for GitHub. type provider struct { httpClient *http.Client @@ -37,14 +23,67 @@ type provider struct { metrics tally.Scope } -// NewProvider creates a new GitHub ChangeProvider. -func NewProvider(params Params) changeprovider.ChangeProvider { +// NewProvider creates a new GitHub ChangeProvider from configuration. +func NewProvider(config Config, logger *zap.SugaredLogger, metrics tally.Scope) changeprovider.ChangeProvider { + if err := config.Validate(); err != nil { + panic(fmt.Sprintf("invalid GitHub config: %v", err)) + } + + // Derive GraphQL URL from base URL + graphQLURL := config.BaseURL + "/graphql" + + // Use provided client or create default + httpClient := config.HTTPClient + if httpClient == nil { + timeout := config.Timeout + if timeout <= 0 { + timeout = DefaultTimeout + } + httpClient = createDefaultClient(config.Token, timeout) + } + return &provider{ - httpClient: params.HTTPClient, - graphQLURL: params.GraphQLURL, - logger: params.Logger.Named("github_changeprovider"), - metrics: params.MetricsScope.SubScope("github_changeprovider"), + httpClient: httpClient, + graphQLURL: graphQLURL, + logger: logger.Named("github_changeprovider"), + metrics: metrics.SubScope("github_changeprovider"), + } +} + +// createDefaultClient creates an HTTP client with the given token and timeout. +func createDefaultClient(token string, timeout time.Duration) *http.Client { + transport := createTransport(token) + + return &http.Client{ + Timeout: timeout, + Transport: transport, + } +} + +// createTransport creates an HTTP transport with optional bearer token authentication. +func createTransport(token string) http.RoundTripper { + base := http.DefaultTransport + + if token == "" { + return base } + + return &bearerTransport{ + token: token, + base: base, + } +} + +// bearerTransport is an http.RoundTripper that adds a Bearer token to requests. +type bearerTransport struct { + token string + base http.RoundTripper +} + +func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+t.token) + return t.base.RoundTrip(req) } // Get retrieves change information from GitHub for the provided Change. diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go index 8292e6ce..285f91c4 100644 --- a/extension/changeprovider/github/provider_test.go +++ b/extension/changeprovider/github/provider_test.go @@ -84,12 +84,10 @@ func TestProvider_Get_Success(t *testing.T) { }, nil }) - provider := NewProvider(Params{ - HTTPClient: mockClient, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + HTTPClient: mockClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) changeInfo, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/123/abc123def456"}, @@ -193,12 +191,10 @@ func TestProvider_Get_Pagination(t *testing.T) { }, nil }) - provider := NewProvider(Params{ - HTTPClient: mockClient, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + HTTPClient: mockClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) changeInfo, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/456/xyz789"}, @@ -214,29 +210,10 @@ func TestProvider_Get_Pagination(t *testing.T) { assert.Equal(t, "file2.go", info.ChangedFiles[1].Path) } -func TestProvider_Get_NoURIs(t *testing.T) { - provider := NewProvider(Params{ - HTTPClient: &http.Client{}, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) - - _, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{}, - }) - - require.Error(t, err) - assert.Contains(t, err.Error(), "no URIs provided") -} - func TestProvider_Get_InvalidURI(t *testing.T) { - provider := NewProvider(Params{ - HTTPClient: &http.Client{}, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"invalid://uri"}, @@ -251,12 +228,10 @@ func TestProvider_Get_HTTPError(t *testing.T) { return nil, assert.AnError }) - provider := NewProvider(Params{ - HTTPClient: mockClient, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + HTTPClient: mockClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/pull/123/abc"}, @@ -275,12 +250,10 @@ func TestProvider_Get_APIError404(t *testing.T) { }, nil }) - provider := NewProvider(Params{ - HTTPClient: mockClient, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + HTTPClient: mockClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/pull/999/abc"}, @@ -308,12 +281,10 @@ func TestProvider_Get_GraphQLError(t *testing.T) { }, nil }) - provider := NewProvider(Params{ - HTTPClient: mockClient, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + HTTPClient: mockClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/pull/123/abc"}, @@ -332,12 +303,10 @@ func TestProvider_Get_InvalidJSON(t *testing.T) { }, nil }) - provider := NewProvider(Params{ - HTTPClient: mockClient, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + HTTPClient: mockClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/pull/123/abc"}, @@ -349,12 +318,10 @@ func TestProvider_Get_InvalidJSON(t *testing.T) { func TestNewProvider_DefaultConfig(t *testing.T) { httpClient := &http.Client{Timeout: 30 * time.Second} - provider := NewProvider(Params{ - HTTPClient: httpClient, - GraphQLURL: "https://api.github.com/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }).(*provider) + provider := NewProvider(Config{ + BaseURL: "https://api.github.com", + HTTPClient: httpClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope).(*provider) assert.Equal(t, "https://api.github.com/graphql", provider.graphQLURL) assert.NotNil(t, provider.httpClient) @@ -437,12 +404,10 @@ func TestProvider_Get_MultiplePRs(t *testing.T) { }, nil }) - provider := NewProvider(Params{ - HTTPClient: mockClient, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + HTTPClient: mockClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) changeInfo, err := provider.Get(context.Background(), entity.Change{ URIs: []string{ @@ -471,12 +436,9 @@ func TestProvider_Get_MultiplePRs(t *testing.T) { } func TestProvider_Get_CrossRepoStack(t *testing.T) { - provider := NewProvider(Params{ - HTTPClient: &http.Client{}, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{ @@ -492,12 +454,9 @@ func TestProvider_Get_CrossRepoStack(t *testing.T) { } func TestProvider_Get_MixedProviderStack(t *testing.T) { - provider := NewProvider(Params{ - HTTPClient: &http.Client{}, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{ @@ -553,12 +512,10 @@ func TestProvider_Get_StalePR(t *testing.T) { }, nil }) - provider := NewProvider(Params{ - HTTPClient: mockClient, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + HTTPClient: mockClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/123/oldsha456"}, // Different SHA! @@ -657,12 +614,10 @@ func TestProvider_Get_PartialSuccess(t *testing.T) { }, nil }) - provider := NewProvider(Params{ - HTTPClient: mockClient, - GraphQLURL: "https://api.github.test/graphql", - Logger: zaptest.NewLogger(t).Sugar(), - MetricsScope: tally.NoopScope, - }) + provider := NewProvider(Config{ + BaseURL: "https://api.github.test", + HTTPClient: mockClient, + }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) changeInfo, err := provider.Get(context.Background(), entity.Change{ URIs: []string{ From 1faf6c434dd9cf05c465d7268fbecbbe3ec78ffb Mon Sep 17 00:00:00 2001 From: rprithyani Date: Thu, 5 Mar 2026 06:31:57 +0000 Subject: [PATCH 4/6] Refactor GitHub change provider to use Client wrapper pattern Replace Config pattern with Client wrapper that encapsulates HTTP client and GraphQL URL. Simplifies provider by giving caller full control over authentication and endpoint configuration. Co-Authored-By: Claude Opus 4.6 --- example/server/orchestrator/main.go | 11 +- extension/changeprovider/github/BUILD.bazel | 5 +- extension/changeprovider/github/config.go | 112 ++++++++++----- .../changeprovider/github/config_test.go | 115 ++++++++++++++++ extension/changeprovider/github/graphql.go | 7 +- extension/changeprovider/github/provider.go | 85 +++--------- .../changeprovider/github/provider_test.go | 129 +++++++++++------- 7 files changed, 305 insertions(+), 159 deletions(-) create mode 100644 extension/changeprovider/github/config_test.go diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index 4e273ee2..a0861bb9 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -533,8 +533,15 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh // newChangeProvider creates a ChangeProvider for GitHub (github.com). // Configured via GITHUB_BASE_URL and GITHUB_TOKEN environment variables. func newChangeProvider(logger *zap.Logger, scope tally.Scope) changeprovider.ChangeProvider { - config := githubprovider.DefaultConfig() - return githubprovider.NewProvider(config, logger.Sugar(), scope.SubScope("changeprovider")) + baseURL := os.Getenv("GITHUB_BASE_URL") + if baseURL == "" { + baseURL = "https://api.github.com" + } + + token := os.Getenv("GITHUB_TOKEN") + + client := githubprovider.NewAuthenticatedClient(token, baseURL, githubprovider.DefaultTimeout) + return githubprovider.NewProvider(client, logger.Sugar(), scope.SubScope("changeprovider")) } // bearerTransport is an http.RoundTripper that adds a Bearer token to requests. diff --git a/extension/changeprovider/github/BUILD.bazel b/extension/changeprovider/github/BUILD.bazel index c67ce397..2c1d28f4 100644 --- a/extension/changeprovider/github/BUILD.bazel +++ b/extension/changeprovider/github/BUILD.bazel @@ -22,7 +22,10 @@ go_library( go_test( name = "github_test", - srcs = ["provider_test.go"], + srcs = [ + "config_test.go", + "provider_test.go", + ], embed = [":github"], deps = [ "//entity", diff --git a/extension/changeprovider/github/config.go b/extension/changeprovider/github/config.go index 5a082474..0474dfd6 100644 --- a/extension/changeprovider/github/config.go +++ b/extension/changeprovider/github/config.go @@ -1,9 +1,7 @@ package github import ( - "fmt" "net/http" - "os" "time" ) @@ -13,47 +11,91 @@ const ( DefaultTimeout = 30 * time.Second ) -// Config holds configuration for connecting to a GitHub backend. -type Config struct { - // BaseURL is the GitHub instance base URL (without /graphql suffix). - // Examples: "https://api.github.com", "https://ghe.company.com" - BaseURL string +// Client is a GitHub API client that encapsulates connection details and authentication. +type Client struct { + httpClient *http.Client + graphQLURL string +} - // Token for authenticating to this GitHub instance. - // Can be empty for unauthenticated requests. - Token string +// NewClient creates a new GitHub API client with a pre-configured HTTP client. +// The caller is responsible for configuring authentication in the HTTP client. +// +// Parameters: +// - httpClient: Configured HTTP client (with auth, timeout, transport, etc.) +// - graphQLURL: GitHub GraphQL endpoint (e.g., "https://api.github.com/graphql") +// +// Example with custom HTTP client: +// +// tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "ghp_xxx"}) +// httpClient := oauth2.NewClient(ctx, tokenSource) +// client := github.NewClient(httpClient, "https://api.github.com/graphql") +func NewClient(httpClient *http.Client, graphQLURL string) *Client { + return &Client{ + httpClient: httpClient, + graphQLURL: graphQLURL, + } +} - // Timeout for HTTP requests to this GitHub instance. - // If zero or negative, defaults to DefaultTimeout (30s). - // Set this higher for slow GHE instances or flaky networks. - Timeout time.Duration +// NewAuthenticatedClient creates a GitHub API client with bearer token authentication. +// This is a convenience helper for simple token-based auth. +// +// Parameters: +// - token: GitHub personal access token (can be empty for public access) +// - baseURL: GitHub instance base URL (e.g., "https://api.github.com" or "https://ghe.company.com") +// - timeout: HTTP request timeout (use DefaultTimeout if unsure) +// +// The GraphQL URL is derived by appending "/graphql" to baseURL. +// +// Example: +// +// // GitHub.com +// client := github.NewAuthenticatedClient("ghp_xxx", "https://api.github.com", github.DefaultTimeout) +// +// // GitHub Enterprise Server +// client := github.NewAuthenticatedClient("ghp_xxx", "https://ghe.company.com/api", github.DefaultTimeout) +func NewAuthenticatedClient(token string, baseURL string, timeout time.Duration) *Client { + httpClient := &http.Client{ + Timeout: timeout, + Transport: newBearerTransport(token, http.DefaultTransport), + } - // HTTPClient provides complete control over the HTTP client. - // If set, BaseURL is still used but Token and Timeout are ignored. - // Use this for custom transports, connection pooling, or testing. - HTTPClient *http.Client + return &Client{ + httpClient: httpClient, + graphQLURL: baseURL + "/graphql", + } } -// Validate checks if the config is valid. -func (c Config) Validate() error { - if c.BaseURL == "" { - return fmt.Errorf("BaseURL is required") - } - return nil +// bearerTransport is an http.RoundTripper that adds a Bearer token to requests. +type bearerTransport struct { + token string + base http.RoundTripper } -// DefaultConfig returns a Config for github.com from environment. -func DefaultConfig() Config { - return Config{ - BaseURL: getEnvOrDefault("GITHUB_BASE_URL", "https://api.github.com"), - Token: os.Getenv("GITHUB_TOKEN"), - Timeout: 0, // Will use DefaultTimeout +// newBearerTransport creates an HTTP transport with bearer token authentication. +// If token is empty, returns the base transport unchanged. +func newBearerTransport(token string, base http.RoundTripper) http.RoundTripper { + if token == "" { + return base } -} -func getEnvOrDefault(key, defaultVal string) string { - if val := os.Getenv(key); val != "" { - return val + return &bearerTransport{ + token: token, + base: base, } - return defaultVal +} + +func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+t.token) + return t.base.RoundTrip(req) +} + +// HTTPClient returns the configured HTTP client. +func (c *Client) HTTPClient() *http.Client { + return c.httpClient +} + +// GraphQLURL returns the configured GraphQL endpoint URL. +func (c *Client) GraphQLURL() string { + return c.graphQLURL } diff --git a/extension/changeprovider/github/config_test.go b/extension/changeprovider/github/config_test.go new file mode 100644 index 00000000..5174be09 --- /dev/null +++ b/extension/changeprovider/github/config_test.go @@ -0,0 +1,115 @@ +package github + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewClient(t *testing.T) { + httpClient := &http.Client{Timeout: 10 * time.Second} + graphQLURL := "https://api.github.com/graphql" + + client := NewClient(httpClient, graphQLURL) + + assert.Equal(t, httpClient, client.HTTPClient()) + assert.Equal(t, graphQLURL, client.GraphQLURL()) +} + +func TestNewAuthenticatedClient(t *testing.T) { + token := "ghp_test123" + baseURL := "https://api.github.com" + timeout := 30 * time.Second + + client := NewAuthenticatedClient(token, baseURL, timeout) + + assert.NotNil(t, client.HTTPClient()) + assert.Equal(t, "https://api.github.com/graphql", client.GraphQLURL()) + assert.Equal(t, timeout, client.HTTPClient().Timeout) + + // Verify bearer transport is configured + transport, ok := client.HTTPClient().Transport.(*bearerTransport) + assert.True(t, ok, "transport should be bearerTransport") + assert.Equal(t, token, transport.token) + assert.Equal(t, http.DefaultTransport, transport.base) +} + +func TestNewAuthenticatedClient_EmptyToken(t *testing.T) { + token := "" + baseURL := "https://api.github.com" + timeout := 30 * time.Second + + client := NewAuthenticatedClient(token, baseURL, timeout) + + assert.NotNil(t, client.HTTPClient()) + assert.Equal(t, "https://api.github.com/graphql", client.GraphQLURL()) + + // Verify transport is NOT bearerTransport when token is empty + assert.Equal(t, http.DefaultTransport, client.HTTPClient().Transport) +} + +func TestNewAuthenticatedClient_GHES(t *testing.T) { + token := "ghp_enterprise" + baseURL := "https://ghe.company.com/api" + timeout := 15 * time.Second + + client := NewAuthenticatedClient(token, baseURL, timeout) + + assert.NotNil(t, client.HTTPClient()) + assert.Equal(t, "https://ghe.company.com/api/graphql", client.GraphQLURL()) + assert.Equal(t, timeout, client.HTTPClient().Timeout) +} + +func TestBearerTransport_AddsAuthHeader(t *testing.T) { + token := "ghp_test_token" + mockBase := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + // Verify the Authorization header was added + assert.Equal(t, "Bearer ghp_test_token", req.Header.Get("Authorization")) + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + } + + transport := &bearerTransport{ + token: token, + base: mockBase, + } + + req, err := http.NewRequest(http.MethodGet, "https://api.github.com/test", nil) + assert.NoError(t, err) + + _, err = transport.RoundTrip(req) + assert.NoError(t, err) +} + +func TestBearerTransport_ClonesRequest(t *testing.T) { + token := "ghp_test_token" + originalReq, err := http.NewRequest(http.MethodGet, "https://api.github.com/test", nil) + assert.NoError(t, err) + + mockBase := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + // Verify the request was cloned (different pointer) + assert.NotSame(t, originalReq, req, "request should be cloned") + assert.Equal(t, originalReq.URL.String(), req.URL.String()) + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + } + + transport := &bearerTransport{ + token: token, + base: mockBase, + } + + _, err = transport.RoundTrip(originalReq) + assert.NoError(t, err) + + // Verify original request is unchanged + assert.Empty(t, originalReq.Header.Get("Authorization"), "original request should not be modified") +} diff --git a/extension/changeprovider/github/graphql.go b/extension/changeprovider/github/graphql.go index 058f2d56..1c141c12 100644 --- a/extension/changeprovider/github/graphql.go +++ b/extension/changeprovider/github/graphql.go @@ -121,19 +121,18 @@ func buildGraphQLRequest(org, repo string, prNumber int, cursor string) graphqlR func doGraphQLRequest( ctx context.Context, bodyBytes []byte, - graphQLURL string, - httpClient *http.Client, + client *Client, org, repo string, metrics tally.Scope, ) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, graphQLURL, bytes.NewReader(bodyBytes)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, client.GraphQLURL(), bytes.NewReader(bodyBytes)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) } req.Header.Set("Content-Type", "application/json") - resp, err := httpClient.Do(req) + resp, err := client.HTTPClient().Do(req) if err != nil { metrics.Tagged(map[string]string{ "org": org, diff --git a/extension/changeprovider/github/provider.go b/extension/changeprovider/github/provider.go index 3ba07c4c..11223208 100644 --- a/extension/changeprovider/github/provider.go +++ b/extension/changeprovider/github/provider.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "time" "github.com/uber-go/tally/v4" @@ -17,75 +16,31 @@ import ( // provider implements the ChangeProvider interface for GitHub. type provider struct { - httpClient *http.Client - graphQLURL string - logger *zap.SugaredLogger - metrics tally.Scope + client *Client + logger *zap.SugaredLogger + metrics tally.Scope } -// NewProvider creates a new GitHub ChangeProvider from configuration. -func NewProvider(config Config, logger *zap.SugaredLogger, metrics tally.Scope) changeprovider.ChangeProvider { - if err := config.Validate(); err != nil { - panic(fmt.Sprintf("invalid GitHub config: %v", err)) - } - - // Derive GraphQL URL from base URL - graphQLURL := config.BaseURL + "/graphql" - - // Use provided client or create default - httpClient := config.HTTPClient - if httpClient == nil { - timeout := config.Timeout - if timeout <= 0 { - timeout = DefaultTimeout - } - httpClient = createDefaultClient(config.Token, timeout) - } - +// NewProvider creates a new GitHub ChangeProvider. +// The caller is responsible for providing a fully-configured Client with authentication. +// Use NewAuthenticatedClient helper to create a client with bearer token auth. +// +// Parameters: +// - client: Pre-configured GitHub API client (encapsulates HTTP client and GraphQL URL) +// - logger: Structured logger +// - metrics: Metrics scope +func NewProvider( + client *Client, + logger *zap.SugaredLogger, + metrics tally.Scope, +) changeprovider.ChangeProvider { return &provider{ - httpClient: httpClient, - graphQLURL: graphQLURL, - logger: logger.Named("github_changeprovider"), - metrics: metrics.SubScope("github_changeprovider"), - } -} - -// createDefaultClient creates an HTTP client with the given token and timeout. -func createDefaultClient(token string, timeout time.Duration) *http.Client { - transport := createTransport(token) - - return &http.Client{ - Timeout: timeout, - Transport: transport, + client: client, + logger: logger.Named("github_changeprovider"), + metrics: metrics.SubScope("github_changeprovider"), } } -// createTransport creates an HTTP transport with optional bearer token authentication. -func createTransport(token string) http.RoundTripper { - base := http.DefaultTransport - - if token == "" { - return base - } - - return &bearerTransport{ - token: token, - base: base, - } -} - -// bearerTransport is an http.RoundTripper that adds a Bearer token to requests. -type bearerTransport struct { - token string - base http.RoundTripper -} - -func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.Header.Set("Authorization", "Bearer "+t.token) - return t.base.RoundTrip(req) -} - // Get retrieves change information from GitHub for the provided Change. // Returns one ChangeInfo per URI (one per PR in stacked changes). // TODO add error codes for user errors (non-retryable) vs system errors. @@ -234,7 +189,7 @@ func (p *provider) fetchPullRequestPage(ctx context.Context, org, repo string, p return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) } - resp, err := doGraphQLRequest(ctx, bodyBytes, p.graphQLURL, p.httpClient, org, repo, p.metrics) + resp, err := doGraphQLRequest(ctx, bodyBytes, p.client, org, repo, p.metrics) if err != nil { return nil, err } diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go index 285f91c4..4549aba6 100644 --- a/extension/changeprovider/github/provider_test.go +++ b/extension/changeprovider/github/provider_test.go @@ -84,10 +84,12 @@ func TestProvider_Get_Success(t *testing.T) { }, nil }) - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - HTTPClient: mockClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(mockClient, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) changeInfo, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/123/abc123def456"}, @@ -191,10 +193,12 @@ func TestProvider_Get_Pagination(t *testing.T) { }, nil }) - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - HTTPClient: mockClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(mockClient, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) changeInfo, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/456/xyz789"}, @@ -211,9 +215,12 @@ func TestProvider_Get_Pagination(t *testing.T) { } func TestProvider_Get_InvalidURI(t *testing.T) { - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(&http.Client{}, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"invalid://uri"}, @@ -228,10 +235,12 @@ func TestProvider_Get_HTTPError(t *testing.T) { return nil, assert.AnError }) - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - HTTPClient: mockClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(mockClient, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/pull/123/abc"}, @@ -250,10 +259,12 @@ func TestProvider_Get_APIError404(t *testing.T) { }, nil }) - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - HTTPClient: mockClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(mockClient, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/pull/999/abc"}, @@ -281,10 +292,12 @@ func TestProvider_Get_GraphQLError(t *testing.T) { }, nil }) - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - HTTPClient: mockClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(mockClient, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/pull/123/abc"}, @@ -303,10 +316,12 @@ func TestProvider_Get_InvalidJSON(t *testing.T) { }, nil }) - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - HTTPClient: mockClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(mockClient, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/pull/123/abc"}, @@ -316,15 +331,13 @@ func TestProvider_Get_InvalidJSON(t *testing.T) { assert.Contains(t, err.Error(), "failed to decode GraphQL response") } -func TestNewProvider_DefaultConfig(t *testing.T) { +func TestNewProvider(t *testing.T) { httpClient := &http.Client{Timeout: 30 * time.Second} - provider := NewProvider(Config{ - BaseURL: "https://api.github.com", - HTTPClient: httpClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope).(*provider) + client := NewClient(httpClient, "https://api.github.com/graphql") - assert.Equal(t, "https://api.github.com/graphql", provider.graphQLURL) - assert.NotNil(t, provider.httpClient) + provider := NewProvider(client, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + + assert.NotNil(t, provider) } func TestProvider_Get_MultiplePRs(t *testing.T) { @@ -404,10 +417,12 @@ func TestProvider_Get_MultiplePRs(t *testing.T) { }, nil }) - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - HTTPClient: mockClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(mockClient, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) changeInfo, err := provider.Get(context.Background(), entity.Change{ URIs: []string{ @@ -436,9 +451,12 @@ func TestProvider_Get_MultiplePRs(t *testing.T) { } func TestProvider_Get_CrossRepoStack(t *testing.T) { - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(&http.Client{}, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{ @@ -454,9 +472,12 @@ func TestProvider_Get_CrossRepoStack(t *testing.T) { } func TestProvider_Get_MixedProviderStack(t *testing.T) { - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(&http.Client{}, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{ @@ -512,10 +533,12 @@ func TestProvider_Get_StalePR(t *testing.T) { }, nil }) - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - HTTPClient: mockClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(mockClient, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) _, err := provider.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/123/oldsha456"}, // Different SHA! @@ -614,10 +637,12 @@ func TestProvider_Get_PartialSuccess(t *testing.T) { }, nil }) - provider := NewProvider(Config{ - BaseURL: "https://api.github.test", - HTTPClient: mockClient, - }, zaptest.NewLogger(t).Sugar(), tally.NoopScope) + client := NewClient(mockClient, "https://api.github.test/graphql") + provider := NewProvider( + client, + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + ) changeInfo, err := provider.Get(context.Background(), entity.Change{ URIs: []string{ From 5c812f8daf5a46987d194d169ef08e38cfd647db Mon Sep 17 00:00:00 2001 From: rprithyani Date: Thu, 5 Mar 2026 07:38:20 +0000 Subject: [PATCH 5/6] make auth configurable --- example/server/orchestrator/main.go | 49 +++++- extension/changeprovider/github/config.go | 98 ++++------- .../changeprovider/github/config_test.go | 158 +++++++++--------- .../changeprovider/github/provider_test.go | 26 +-- 4 files changed, 158 insertions(+), 173 deletions(-) diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index a0861bb9..b25ff44b 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -505,6 +505,36 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t return nil } +// getEnv returns environment variable value or default if not set. +func getEnv(key, defaultVal string) string { + if val := os.Getenv(key); val != "" { + return val + } + return defaultVal +} + +// parseTimeout parses a duration from environment variable with fallback to default. +// Returns defaultVal if envVal is empty or cannot be parsed. +func parseTimeout(envVal string, defaultVal time.Duration) time.Duration { + if envVal == "" { + return defaultVal + } + if d, err := time.ParseDuration(envVal); err == nil { + return d + } + return defaultVal +} + +// buildGitHubHTTPClient creates an http.Client configured for GitHub API calls. +// Configures timeout and optional bearer token authentication. +func buildGitHubHTTPClient(token string, timeout time.Duration) *http.Client { + httpClient := &http.Client{Timeout: timeout} + if token != "" { + httpClient.Transport = &bearerTransport{token: token} + } + return httpClient +} + // newMergeChecker creates a MergeChecker for GitHub (github.com). // Configured via GITHUB_TOKEN and GITHUB_GRAPHQL_URL environment variables. func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeChecker { @@ -531,16 +561,21 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh } // newChangeProvider creates a ChangeProvider for GitHub (github.com). -// Configured via GITHUB_BASE_URL and GITHUB_TOKEN environment variables. +// Configured via GITHUB_BASE_URL, GITHUB_TOKEN, and GITHUB_TIMEOUT environment variables. +// Uses pure dependency injection - creates http.Client with auth configured in Transport. func newChangeProvider(logger *zap.Logger, scope tally.Scope) changeprovider.ChangeProvider { - baseURL := os.Getenv("GITHUB_BASE_URL") - if baseURL == "" { - baseURL = "https://api.github.com" - } - + // 1. Read configuration from environment + baseURL := getEnv("GITHUB_BASE_URL", "https://api.github.com") token := os.Getenv("GITHUB_TOKEN") + timeout := parseTimeout(os.Getenv("GITHUB_TIMEOUT"), 30*time.Second) + + // 2. Build HTTP client with caller-controlled config (auth + timeout) + httpClient := buildGitHubHTTPClient(token, timeout) + + // 3. Create GitHub client wrapper with baseURL + client := githubprovider.NewClient(httpClient, baseURL) - client := githubprovider.NewAuthenticatedClient(token, baseURL, githubprovider.DefaultTimeout) + // 4. Inject into provider return githubprovider.NewProvider(client, logger.Sugar(), scope.SubScope("changeprovider")) } diff --git a/extension/changeprovider/github/config.go b/extension/changeprovider/github/config.go index 0474dfd6..61e9eb4e 100644 --- a/extension/changeprovider/github/config.go +++ b/extension/changeprovider/github/config.go @@ -2,100 +2,58 @@ package github import ( "net/http" - "time" -) - -const ( - // DefaultTimeout is the default HTTP request timeout for GitHub API calls. - // This applies to the entire request/response cycle. - DefaultTimeout = 30 * time.Second ) // Client is a GitHub API client that encapsulates connection details and authentication. +// The client is protocol-agnostic - it provides helpers for both GraphQL and REST endpoints. type Client struct { httpClient *http.Client - graphQLURL string + baseURL string } // NewClient creates a new GitHub API client with a pre-configured HTTP client. -// The caller is responsible for configuring authentication in the HTTP client. +// The caller is responsible for configuring authentication in the HTTP client's Transport. // // Parameters: // - httpClient: Configured HTTP client (with auth, timeout, transport, etc.) -// - graphQLURL: GitHub GraphQL endpoint (e.g., "https://api.github.com/graphql") +// - baseURL: GitHub instance base URL (e.g., "https://api.github.com" or "https://ghe.company.com/api") // -// Example with custom HTTP client: +// Example with Bearer token auth: // -// tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "ghp_xxx"}) -// httpClient := oauth2.NewClient(ctx, tokenSource) -// client := github.NewClient(httpClient, "https://api.github.com/graphql") -func NewClient(httpClient *http.Client, graphQLURL string) *Client { - return &Client{ - httpClient: httpClient, - graphQLURL: graphQLURL, - } -} - -// NewAuthenticatedClient creates a GitHub API client with bearer token authentication. -// This is a convenience helper for simple token-based auth. -// -// Parameters: -// - token: GitHub personal access token (can be empty for public access) -// - baseURL: GitHub instance base URL (e.g., "https://api.github.com" or "https://ghe.company.com") -// - timeout: HTTP request timeout (use DefaultTimeout if unsure) -// -// The GraphQL URL is derived by appending "/graphql" to baseURL. +// transport := &bearerTransport{token: "ghp_xxx", base: http.DefaultTransport} +// httpClient := &http.Client{Transport: transport, Timeout: 30 * time.Second} +// client := github.NewClient(httpClient, "https://api.github.com") // -// Example: +// Example with OAuth2: // -// // GitHub.com -// client := github.NewAuthenticatedClient("ghp_xxx", "https://api.github.com", github.DefaultTimeout) -// -// // GitHub Enterprise Server -// client := github.NewAuthenticatedClient("ghp_xxx", "https://ghe.company.com/api", github.DefaultTimeout) -func NewAuthenticatedClient(token string, baseURL string, timeout time.Duration) *Client { - httpClient := &http.Client{ - Timeout: timeout, - Transport: newBearerTransport(token, http.DefaultTransport), - } - +// tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "ghp_xxx"}) +// httpClient := oauth2.NewClient(ctx, tokenSource) +// client := github.NewClient(httpClient, "https://api.github.com") +func NewClient(httpClient *http.Client, baseURL string) *Client { return &Client{ httpClient: httpClient, - graphQLURL: baseURL + "/graphql", + baseURL: baseURL, } } -// bearerTransport is an http.RoundTripper that adds a Bearer token to requests. -type bearerTransport struct { - token string - base http.RoundTripper -} - -// newBearerTransport creates an HTTP transport with bearer token authentication. -// If token is empty, returns the base transport unchanged. -func newBearerTransport(token string, base http.RoundTripper) http.RoundTripper { - if token == "" { - return base - } - - return &bearerTransport{ - token: token, - base: base, - } -} - -func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.Header.Set("Authorization", "Bearer "+t.token) - return t.base.RoundTrip(req) -} - // HTTPClient returns the configured HTTP client. func (c *Client) HTTPClient() *http.Client { return c.httpClient } -// GraphQLURL returns the configured GraphQL endpoint URL. +// BaseURL returns the configured GitHub base URL. +func (c *Client) BaseURL() string { + return c.baseURL +} + +// GraphQLURL returns the GitHub GraphQL endpoint URL. +// Constructs the URL by appending "/graphql" to the base URL. func (c *Client) GraphQLURL() string { - return c.graphQLURL + return c.baseURL + "/graphql" +} + +// RESTURL constructs a GitHub REST API endpoint URL. +// The path should start with "/" (e.g., "/repos/uber/submitqueue/pulls/123"). +func (c *Client) RESTURL(path string) string { + return c.baseURL + path } diff --git a/extension/changeprovider/github/config_test.go b/extension/changeprovider/github/config_test.go index 5174be09..55a0e107 100644 --- a/extension/changeprovider/github/config_test.go +++ b/extension/changeprovider/github/config_test.go @@ -10,106 +10,98 @@ import ( func TestNewClient(t *testing.T) { httpClient := &http.Client{Timeout: 10 * time.Second} - graphQLURL := "https://api.github.com/graphql" - - client := NewClient(httpClient, graphQLURL) - - assert.Equal(t, httpClient, client.HTTPClient()) - assert.Equal(t, graphQLURL, client.GraphQLURL()) -} - -func TestNewAuthenticatedClient(t *testing.T) { - token := "ghp_test123" - baseURL := "https://api.github.com" - timeout := 30 * time.Second - - client := NewAuthenticatedClient(token, baseURL, timeout) - - assert.NotNil(t, client.HTTPClient()) - assert.Equal(t, "https://api.github.com/graphql", client.GraphQLURL()) - assert.Equal(t, timeout, client.HTTPClient().Timeout) - - // Verify bearer transport is configured - transport, ok := client.HTTPClient().Transport.(*bearerTransport) - assert.True(t, ok, "transport should be bearerTransport") - assert.Equal(t, token, transport.token) - assert.Equal(t, http.DefaultTransport, transport.base) -} - -func TestNewAuthenticatedClient_EmptyToken(t *testing.T) { - token := "" baseURL := "https://api.github.com" - timeout := 30 * time.Second - client := NewAuthenticatedClient(token, baseURL, timeout) + client := NewClient(httpClient, baseURL) - assert.NotNil(t, client.HTTPClient()) + assert.Equal(t, httpClient, client.HTTPClient()) + assert.Equal(t, baseURL, client.BaseURL()) assert.Equal(t, "https://api.github.com/graphql", client.GraphQLURL()) - - // Verify transport is NOT bearerTransport when token is empty - assert.Equal(t, http.DefaultTransport, client.HTTPClient().Transport) } -func TestNewAuthenticatedClient_GHES(t *testing.T) { - token := "ghp_enterprise" - baseURL := "https://ghe.company.com/api" - timeout := 15 * time.Second - - client := NewAuthenticatedClient(token, baseURL, timeout) +func TestClient_BaseURL(t *testing.T) { + tests := []struct { + name string + baseURL string + want string + }{ + { + name: "GitHub.com", + baseURL: "https://api.github.com", + want: "https://api.github.com", + }, + { + name: "GitHub Enterprise Server", + baseURL: "https://ghe.company.com/api", + want: "https://ghe.company.com/api", + }, + } - assert.NotNil(t, client.HTTPClient()) - assert.Equal(t, "https://ghe.company.com/api/graphql", client.GraphQLURL()) - assert.Equal(t, timeout, client.HTTPClient().Timeout) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewClient(&http.Client{}, tt.baseURL) + assert.Equal(t, tt.want, client.BaseURL()) + }) + } } -func TestBearerTransport_AddsAuthHeader(t *testing.T) { - token := "ghp_test_token" - mockBase := &mockRoundTripper{ - roundTripFunc: func(req *http.Request) (*http.Response, error) { - // Verify the Authorization header was added - assert.Equal(t, "Bearer ghp_test_token", req.Header.Get("Authorization")) - return &http.Response{ - StatusCode: http.StatusOK, - }, nil +func TestClient_GraphQLURL(t *testing.T) { + tests := []struct { + name string + baseURL string + want string + }{ + { + name: "GitHub.com", + baseURL: "https://api.github.com", + want: "https://api.github.com/graphql", + }, + { + name: "GitHub Enterprise Server", + baseURL: "https://ghe.company.com/api", + want: "https://ghe.company.com/api/graphql", }, } - transport := &bearerTransport{ - token: token, - base: mockBase, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewClient(&http.Client{}, tt.baseURL) + assert.Equal(t, tt.want, client.GraphQLURL()) + }) } - - req, err := http.NewRequest(http.MethodGet, "https://api.github.com/test", nil) - assert.NoError(t, err) - - _, err = transport.RoundTrip(req) - assert.NoError(t, err) } -func TestBearerTransport_ClonesRequest(t *testing.T) { - token := "ghp_test_token" - originalReq, err := http.NewRequest(http.MethodGet, "https://api.github.com/test", nil) - assert.NoError(t, err) - - mockBase := &mockRoundTripper{ - roundTripFunc: func(req *http.Request) (*http.Response, error) { - // Verify the request was cloned (different pointer) - assert.NotSame(t, originalReq, req, "request should be cloned") - assert.Equal(t, originalReq.URL.String(), req.URL.String()) - return &http.Response{ - StatusCode: http.StatusOK, - }, nil +func TestClient_RESTURL(t *testing.T) { + tests := []struct { + name string + baseURL string + path string + want string + }{ + { + name: "GitHub.com - pull request", + baseURL: "https://api.github.com", + path: "/repos/uber/submitqueue/pulls/123", + want: "https://api.github.com/repos/uber/submitqueue/pulls/123", + }, + { + name: "GitHub Enterprise Server - pull request", + baseURL: "https://ghe.company.com/api", + path: "/repos/uber/submitqueue/pulls/456", + want: "https://ghe.company.com/api/repos/uber/submitqueue/pulls/456", + }, + { + name: "repos endpoint", + baseURL: "https://api.github.com", + path: "/repos/uber/submitqueue", + want: "https://api.github.com/repos/uber/submitqueue", }, } - transport := &bearerTransport{ - token: token, - base: mockBase, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewClient(&http.Client{}, tt.baseURL) + assert.Equal(t, tt.want, client.RESTURL(tt.path)) + }) } - - _, err = transport.RoundTrip(originalReq) - assert.NoError(t, err) - - // Verify original request is unchanged - assert.Empty(t, originalReq.Header.Get("Authorization"), "original request should not be modified") } diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go index 4549aba6..f7f1083d 100644 --- a/extension/changeprovider/github/provider_test.go +++ b/extension/changeprovider/github/provider_test.go @@ -84,7 +84,7 @@ func TestProvider_Get_Success(t *testing.T) { }, nil }) - client := NewClient(mockClient, "https://api.github.test/graphql") + client := NewClient(mockClient, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -193,7 +193,7 @@ func TestProvider_Get_Pagination(t *testing.T) { }, nil }) - client := NewClient(mockClient, "https://api.github.test/graphql") + client := NewClient(mockClient, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -215,7 +215,7 @@ func TestProvider_Get_Pagination(t *testing.T) { } func TestProvider_Get_InvalidURI(t *testing.T) { - client := NewClient(&http.Client{}, "https://api.github.test/graphql") + client := NewClient(&http.Client{}, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -235,7 +235,7 @@ func TestProvider_Get_HTTPError(t *testing.T) { return nil, assert.AnError }) - client := NewClient(mockClient, "https://api.github.test/graphql") + client := NewClient(mockClient, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -259,7 +259,7 @@ func TestProvider_Get_APIError404(t *testing.T) { }, nil }) - client := NewClient(mockClient, "https://api.github.test/graphql") + client := NewClient(mockClient, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -292,7 +292,7 @@ func TestProvider_Get_GraphQLError(t *testing.T) { }, nil }) - client := NewClient(mockClient, "https://api.github.test/graphql") + client := NewClient(mockClient, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -316,7 +316,7 @@ func TestProvider_Get_InvalidJSON(t *testing.T) { }, nil }) - client := NewClient(mockClient, "https://api.github.test/graphql") + client := NewClient(mockClient, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -333,7 +333,7 @@ func TestProvider_Get_InvalidJSON(t *testing.T) { func TestNewProvider(t *testing.T) { httpClient := &http.Client{Timeout: 30 * time.Second} - client := NewClient(httpClient, "https://api.github.com/graphql") + client := NewClient(httpClient, "https://api.github.com") provider := NewProvider(client, zaptest.NewLogger(t).Sugar(), tally.NoopScope) @@ -417,7 +417,7 @@ func TestProvider_Get_MultiplePRs(t *testing.T) { }, nil }) - client := NewClient(mockClient, "https://api.github.test/graphql") + client := NewClient(mockClient, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -451,7 +451,7 @@ func TestProvider_Get_MultiplePRs(t *testing.T) { } func TestProvider_Get_CrossRepoStack(t *testing.T) { - client := NewClient(&http.Client{}, "https://api.github.test/graphql") + client := NewClient(&http.Client{}, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -472,7 +472,7 @@ func TestProvider_Get_CrossRepoStack(t *testing.T) { } func TestProvider_Get_MixedProviderStack(t *testing.T) { - client := NewClient(&http.Client{}, "https://api.github.test/graphql") + client := NewClient(&http.Client{}, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -533,7 +533,7 @@ func TestProvider_Get_StalePR(t *testing.T) { }, nil }) - client := NewClient(mockClient, "https://api.github.test/graphql") + client := NewClient(mockClient, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), @@ -637,7 +637,7 @@ func TestProvider_Get_PartialSuccess(t *testing.T) { }, nil }) - client := NewClient(mockClient, "https://api.github.test/graphql") + client := NewClient(mockClient, "https://api.github.test") provider := NewProvider( client, zaptest.NewLogger(t).Sugar(), From 330aeb746545138e46615e2e860685bd2a0e232c Mon Sep 17 00:00:00 2001 From: rprithyani Date: Thu, 12 Mar 2026 19:02:58 +0000 Subject: [PATCH 6/6] fix conflicts and re-do work on the change provider and add tests --- example/server/orchestrator/main.go | 11 +- extension/changeprovider/github/config.go | 16 +- .../changeprovider/github/config_test.go | 80 +- extension/changeprovider/github/graphql.go | 42 - .../changeprovider/github/graphql_test.go | 145 ++++ extension/changeprovider/github/provider.go | 109 +-- .../changeprovider/github/provider_test.go | 750 ++++-------------- extension/changeprovider/github/validate.go | 9 +- .../changeprovider/github/validate_test.go | 99 +++ orchestrator/controller/validate/validate.go | 20 +- 10 files changed, 481 insertions(+), 800 deletions(-) create mode 100644 extension/changeprovider/github/graphql_test.go create mode 100644 extension/changeprovider/github/validate_test.go diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index b25ff44b..ae4526ad 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -572,11 +572,12 @@ func newChangeProvider(logger *zap.Logger, scope tally.Scope) changeprovider.Cha // 2. Build HTTP client with caller-controlled config (auth + timeout) httpClient := buildGitHubHTTPClient(token, timeout) - // 3. Create GitHub client wrapper with baseURL - client := githubprovider.NewClient(httpClient, baseURL) - - // 4. Inject into provider - return githubprovider.NewProvider(client, logger.Sugar(), scope.SubScope("changeprovider")) + // 3. Inject into provider + return githubprovider.NewProvider(githubprovider.Params{ + Client: githubprovider.NewClient(httpClient, baseURL), + Logger: logger.Sugar(), + MetricsScope: scope.SubScope("changeprovider"), + }) } // bearerTransport is an http.RoundTripper that adds a Bearer token to requests. diff --git a/extension/changeprovider/github/config.go b/extension/changeprovider/github/config.go index 61e9eb4e..9c646cd9 100644 --- a/extension/changeprovider/github/config.go +++ b/extension/changeprovider/github/config.go @@ -15,20 +15,8 @@ type Client struct { // The caller is responsible for configuring authentication in the HTTP client's Transport. // // Parameters: -// - httpClient: Configured HTTP client (with auth, timeout, transport, etc.) -// - baseURL: GitHub instance base URL (e.g., "https://api.github.com" or "https://ghe.company.com/api") -// -// Example with Bearer token auth: -// -// transport := &bearerTransport{token: "ghp_xxx", base: http.DefaultTransport} -// httpClient := &http.Client{Transport: transport, Timeout: 30 * time.Second} -// client := github.NewClient(httpClient, "https://api.github.com") -// -// Example with OAuth2: -// -// tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "ghp_xxx"}) -// httpClient := oauth2.NewClient(ctx, tokenSource) -// client := github.NewClient(httpClient, "https://api.github.com") +// - httpClient: Configured HTTP client (with auth, timeout, transport, etc.) +// - baseURL: GitHub instance base URL (e.g., "https://api.github.com" or "https://ghe.company.com/api") func NewClient(httpClient *http.Client, baseURL string) *Client { return &Client{ httpClient: httpClient, diff --git a/extension/changeprovider/github/config_test.go b/extension/changeprovider/github/config_test.go index 55a0e107..34d6e81f 100644 --- a/extension/changeprovider/github/config_test.go +++ b/extension/changeprovider/github/config_test.go @@ -9,99 +9,89 @@ import ( ) func TestNewClient(t *testing.T) { - httpClient := &http.Client{Timeout: 10 * time.Second} + httpClient := &http.Client{Timeout: 30 * time.Second} baseURL := "https://api.github.com" client := NewClient(httpClient, baseURL) + assert.NotNil(t, client) assert.Equal(t, httpClient, client.HTTPClient()) assert.Equal(t, baseURL, client.BaseURL()) - assert.Equal(t, "https://api.github.com/graphql", client.GraphQLURL()) } -func TestClient_BaseURL(t *testing.T) { +func TestClient_GraphQLURL(t *testing.T) { tests := []struct { - name string - baseURL string - want string + name string + baseURL string + expected string }{ { - name: "GitHub.com", - baseURL: "https://api.github.com", - want: "https://api.github.com", + name: "standard github", + baseURL: "https://api.github.com", + expected: "https://api.github.com/graphql", }, { - name: "GitHub Enterprise Server", - baseURL: "https://ghe.company.com/api", - want: "https://ghe.company.com/api", + name: "github enterprise", + baseURL: "https://ghe.example.com/api", + expected: "https://ghe.example.com/api/graphql", + }, + { + name: "localhost", + baseURL: "http://localhost:8080", + expected: "http://localhost:8080/graphql", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := NewClient(&http.Client{}, tt.baseURL) - assert.Equal(t, tt.want, client.BaseURL()) + assert.Equal(t, tt.expected, client.GraphQLURL()) }) } } -func TestClient_GraphQLURL(t *testing.T) { +func TestClient_BaseURL(t *testing.T) { tests := []struct { name string baseURL string - want string }{ - { - name: "GitHub.com", - baseURL: "https://api.github.com", - want: "https://api.github.com/graphql", - }, - { - name: "GitHub Enterprise Server", - baseURL: "https://ghe.company.com/api", - want: "https://ghe.company.com/api/graphql", - }, + {name: "standard github", baseURL: "https://api.github.com"}, + {name: "github enterprise", baseURL: "https://ghe.example.com/api"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := NewClient(&http.Client{}, tt.baseURL) - assert.Equal(t, tt.want, client.GraphQLURL()) + assert.Equal(t, tt.baseURL, client.BaseURL()) }) } } func TestClient_RESTURL(t *testing.T) { tests := []struct { - name string - baseURL string - path string - want string + name string + baseURL string + path string + expected string }{ { - name: "GitHub.com - pull request", - baseURL: "https://api.github.com", - path: "/repos/uber/submitqueue/pulls/123", - want: "https://api.github.com/repos/uber/submitqueue/pulls/123", - }, - { - name: "GitHub Enterprise Server - pull request", - baseURL: "https://ghe.company.com/api", - path: "/repos/uber/submitqueue/pulls/456", - want: "https://ghe.company.com/api/repos/uber/submitqueue/pulls/456", + name: "repos endpoint", + baseURL: "https://api.github.com", + path: "/repos/uber/submitqueue/pulls/123", + expected: "https://api.github.com/repos/uber/submitqueue/pulls/123", }, { - name: "repos endpoint", - baseURL: "https://api.github.com", - path: "/repos/uber/submitqueue", - want: "https://api.github.com/repos/uber/submitqueue", + name: "enterprise repos endpoint", + baseURL: "https://ghe.example.com/api", + path: "/repos/myorg/myrepo/pulls/456", + expected: "https://ghe.example.com/api/repos/myorg/myrepo/pulls/456", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := NewClient(&http.Client{}, tt.baseURL) - assert.Equal(t, tt.want, client.RESTURL(tt.path)) + assert.Equal(t, tt.expected, client.RESTURL(tt.path)) }) } } diff --git a/extension/changeprovider/github/graphql.go b/extension/changeprovider/github/graphql.go index 1c141c12..08f43206 100644 --- a/extension/changeprovider/github/graphql.go +++ b/extension/changeprovider/github/graphql.go @@ -7,9 +7,6 @@ import ( "fmt" "io" "net/http" - - "github.com/uber-go/tally/v4" - "go.uber.org/zap" ) // pullRequestQuery is the GraphQL query to fetch pull request information including files, author, and head SHA. @@ -122,8 +119,6 @@ func doGraphQLRequest( ctx context.Context, bodyBytes []byte, client *Client, - org, repo string, - metrics tally.Scope, ) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, client.GraphQLURL(), bytes.NewReader(bodyBytes)) if err != nil { @@ -134,11 +129,6 @@ func doGraphQLRequest( resp, err := client.HTTPClient().Do(req) if err != nil { - metrics.Tagged(map[string]string{ - "org": org, - "repo": repo, - "error_type": "http_error", - }).Counter("get_errors").Inc(1) return nil, fmt.Errorf("HTTP request failed: %w", err) } @@ -148,50 +138,18 @@ func doGraphQLRequest( // parseGraphQLResponse parses and validates a GraphQL response. func parseGraphQLResponse( resp *http.Response, - org, repo string, - prNumber int, - logger *zap.SugaredLogger, - metrics tally.Scope, ) (*pullRequestData, error) { if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - logger.Errorw("GitHub API error", - "status", resp.StatusCode, - "org", org, - "repo", repo, - "pr", prNumber, - "response", string(body), - ) - metrics.Tagged(map[string]string{ - "org": org, - "repo": repo, - "error_type": "api_error", - }).Counter("get_errors").Inc(1) return nil, fmt.Errorf("GitHub API returned status %d: %s", resp.StatusCode, string(body)) } var gqlResp graphqlResponse if err := json.NewDecoder(resp.Body).Decode(&gqlResp); err != nil { - metrics.Tagged(map[string]string{ - "org": org, - "repo": repo, - "error_type": "decode_error", - }).Counter("get_errors").Inc(1) return nil, fmt.Errorf("failed to decode GraphQL response: %w", err) } if len(gqlResp.Errors) > 0 { - logger.Errorw("GraphQL errors", - "org", org, - "repo", repo, - "pr", prNumber, - "errors", gqlResp.Errors, - ) - metrics.Tagged(map[string]string{ - "org": org, - "repo": repo, - "error_type": "graphql_error", - }).Counter("get_errors").Inc(1) return nil, fmt.Errorf("GraphQL errors: %+v", gqlResp.Errors) } diff --git a/extension/changeprovider/github/graphql_test.go b/extension/changeprovider/github/graphql_test.go new file mode 100644 index 00000000..81e4b591 --- /dev/null +++ b/extension/changeprovider/github/graphql_test.go @@ -0,0 +1,145 @@ +package github + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildGraphQLRequest(t *testing.T) { + tests := []struct { + name string + org string + repo string + prNumber int + cursor string + wantVars map[string]any + }{ + { + name: "no cursor", + org: "uber", + repo: "submitqueue", + prNumber: 123, + cursor: "", + wantVars: map[string]any{ + "owner": "uber", + "repo": "submitqueue", + "prNumber": 123, + "filesCursor": "", + }, + }, + { + name: "with cursor", + org: "myorg", + repo: "myrepo", + prNumber: 456, + cursor: "cursor_token_xyz", + wantVars: map[string]any{ + "owner": "myorg", + "repo": "myrepo", + "prNumber": 456, + "filesCursor": "cursor_token_xyz", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := buildGraphQLRequest(tt.org, tt.repo, tt.prNumber, tt.cursor) + assert.Equal(t, pullRequestQuery, req.Query) + assert.Equal(t, tt.wantVars, req.Variables) + }) + } +} + +func TestParseGraphQLResponse(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantErr bool + wantData *pullRequestData + }{ + { + name: "success", + statusCode: http.StatusOK, + body: `{ + "data": { + "repository": { + "pullRequest": { + "number": 42, + "headRefOid": "abc123", + "author": {"login": "octocat", "name": "The Octocat", "email": "octocat@example.com"}, + "files": { + "totalCount": 1, + "pageInfo": {"endCursor": "cur1", "hasNextPage": false}, + "nodes": [{"path": "main.go", "additions": 10, "deletions": 2, "changeType": "MODIFIED", "patch": "diff content"}] + } + } + } + } + }`, + wantData: &pullRequestData{ + Number: 42, + HeadRefOid: "abc123", + Author: authorData{Login: "octocat", Name: "The Octocat", Email: "octocat@example.com"}, + Files: filesData{ + TotalCount: 1, + PageInfo: pageInfo{EndCursor: "cur1", HasNextPage: false}, + Nodes: []fileNode{{Path: "main.go", Additions: 10, Deletions: 2, ChangeType: "MODIFIED", Patch: "diff content"}}, + }, + }, + }, + { + name: "non-200 status", + statusCode: http.StatusInternalServerError, + body: `{"message":"Internal Server Error"}`, + wantErr: true, + }, + { + name: "404 not found", + statusCode: http.StatusNotFound, + body: `{"message":"Not Found"}`, + wantErr: true, + }, + { + name: "invalid JSON", + statusCode: http.StatusOK, + body: `{invalid json`, + wantErr: true, + }, + { + name: "GraphQL errors", + statusCode: http.StatusOK, + body: `{"errors":[{"message":"Field doesn't exist","type":"INVALID_FIELD"}]}`, + wantErr: true, + }, + { + name: "multiple GraphQL errors", + statusCode: http.StatusOK, + body: `{"errors":[{"message":"Not Found","type":"NOT_FOUND"},{"message":"Forbidden","type":"FORBIDDEN"}]}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader(tt.body)), + } + + got, err := parseGraphQLResponse(resp) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantData, got) + }) + } +} diff --git a/extension/changeprovider/github/provider.go b/extension/changeprovider/github/provider.go index 11223208..e8024b6d 100644 --- a/extension/changeprovider/github/provider.go +++ b/extension/changeprovider/github/provider.go @@ -4,59 +4,54 @@ import ( "context" "encoding/json" "fmt" - "time" "github.com/uber-go/tally/v4" "go.uber.org/zap" + coremetrics "github.com/uber/submitqueue/core/metrics" "github.com/uber/submitqueue/entity" entitygithub "github.com/uber/submitqueue/entity/github" "github.com/uber/submitqueue/extension/changeprovider" ) +// Params holds the dependencies for the GitHub ChangeProvider. +type Params struct { + // Client is a pre-configured GitHub API client (encapsulates HTTP client and GraphQL URL). + // Auth is the caller's responsibility via HTTP transport/round-tripper. + Client *Client + // Logger is the structured logger. + Logger *zap.SugaredLogger + // MetricsScope is the metrics scope for instrumentation. + MetricsScope tally.Scope +} + // provider implements the ChangeProvider interface for GitHub. type provider struct { - client *Client - logger *zap.SugaredLogger - metrics tally.Scope + client *Client + logger *zap.SugaredLogger + metricsScope tally.Scope } // NewProvider creates a new GitHub ChangeProvider. -// The caller is responsible for providing a fully-configured Client with authentication. -// Use NewAuthenticatedClient helper to create a client with bearer token auth. -// -// Parameters: -// - client: Pre-configured GitHub API client (encapsulates HTTP client and GraphQL URL) -// - logger: Structured logger -// - metrics: Metrics scope -func NewProvider( - client *Client, - logger *zap.SugaredLogger, - metrics tally.Scope, -) changeprovider.ChangeProvider { +func NewProvider(params Params) changeprovider.ChangeProvider { return &provider{ - client: client, - logger: logger.Named("github_changeprovider"), - metrics: metrics.SubScope("github_changeprovider"), + client: params.Client, + logger: params.Logger.Named("github_changeprovider"), + metricsScope: params.MetricsScope.SubScope("github_changeprovider"), } } // Get retrieves change information from GitHub for the provided Change. // Returns one ChangeInfo per URI (one per PR in stacked changes). -// TODO add error codes for user errors (non-retryable) vs system errors. -func (p *provider) Get(ctx context.Context, change entity.Change) ([]changeprovider.ChangeInfo, error) { - p.metrics.Counter("get_change_info_started").Inc(1) - startTime := time.Now() - defer func() { - p.metrics.Timer("get_change_info_latency").Record(time.Since(startTime)) - }() +func (p *provider) Get(ctx context.Context, change entity.Change) (_ []changeprovider.ChangeInfo, retErr error) { + op := coremetrics.Begin(p.metricsScope, "get") + defer func() { op.Complete(retErr) }() // Parse all change IDs changeIDs := make([]entitygithub.ChangeID, 0, len(change.URIs)) for _, uri := range change.URIs { parsed, err := entitygithub.ParseChangeID(uri) if err != nil { - p.metrics.Counter("get_change_info_errors").Inc(1) return nil, fmt.Errorf("failed to parse GitHub change ID %q: %w", uri, err) } changeIDs = append(changeIDs, parsed) @@ -68,75 +63,44 @@ func (p *provider) Get(ctx context.Context, change entity.Change) ([]changeprovi ) // Validate stacked changes are consistent (same provider, org, and repo) - org, repo, err := validateChangeConsistency(changeIDs) - if err != nil { + if err := validateChangeConsistency(changeIDs); err != nil { return nil, err } // Fetch each PR and build ChangeInfo for each - changeInfos, fetchErrors, failedPRs := p.fetchAllPRs(ctx, changeIDs) - - // Return partial results if any PRs failed - if len(fetchErrors) > 0 { - p.logger.Errorw("failed to fetch some PRs", - "total_prs", len(changeIDs), - "failed_count", len(fetchErrors), - "failed_prs", failedPRs, - "succeeded_count", len(changeInfos), - ) - return changeInfos, fmt.Errorf("failed to fetch %d of %d PRs (failed: %v): %v", - len(fetchErrors), len(changeIDs), failedPRs, fetchErrors) + changeInfos, err := p.fetchAllPRs(ctx, changeIDs) + if err != nil { + return nil, err } p.logger.Debugw("successfully fetched PR data", "pr_count", len(changeIDs), ) - p.metrics.Tagged(map[string]string{ - "org": org, - "repo": repo, - }).Counter("get_success").Inc(1) - return changeInfos, nil } -// fetchAllPRs fetches and validates all PRs in the stack, handling partial failures. -// Returns the successfully fetched ChangeInfos, any errors encountered, and the list of failed PR numbers. +// fetchAllPRs fetches and validates all PRs in the stack, returning on the first error. func (p *provider) fetchAllPRs( ctx context.Context, changeIDs []entitygithub.ChangeID, -) ([]changeprovider.ChangeInfo, []error, []int) { +) ([]changeprovider.ChangeInfo, error) { changeInfos := make([]changeprovider.ChangeInfo, 0, len(changeIDs)) - var fetchErrors []error - var failedPRs []int for _, cid := range changeIDs { prData, err := p.fetchPullRequest(ctx, cid) if err != nil { - p.logger.Errorw("failed to fetch PR from GitHub", - "org", cid.Org, - "repo", cid.Repo, - "pr", cid.PRNumber, - "error", err, + coremetrics.NamedCounter(p.metricsScope, "fetch_pr", "errors", 1, + coremetrics.NewTag("org", cid.Org), + coremetrics.NewTag("repo", cid.Repo), ) - p.metrics.Tagged(map[string]string{ - "org": cid.Org, - "repo": cid.Repo, - "error_type": "fetch_pr", - }).Counter("get_errors").Inc(1) - fetchErrors = append(fetchErrors, fmt.Errorf("PR #%d: %w", cid.PRNumber, err)) - failedPRs = append(failedPRs, cid.PRNumber) - continue // Continue to next PR + return nil, fmt.Errorf("failed to fetch PR #%d: %w", cid.PRNumber, err) } - // Validate PR hasn't changed since submission if err := validatePRStaleness(cid, prData); err != nil { - fetchErrors = append(fetchErrors, err) - failedPRs = append(failedPRs, cid.PRNumber) - continue // Continue to next PR + return nil, err } - // Convert to ChangeInfo changeInfo := convertToChangeInfo(cid, prData) changeInfos = append(changeInfos, changeInfo) @@ -149,7 +113,7 @@ func (p *provider) fetchAllPRs( ) } - return changeInfos, fetchErrors, failedPRs + return changeInfos, nil } // fetchPullRequest makes GraphQL request(s) to fetch PR data, handling pagination. @@ -189,12 +153,11 @@ func (p *provider) fetchPullRequestPage(ctx context.Context, org, repo string, p return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) } - resp, err := doGraphQLRequest(ctx, bodyBytes, p.client, org, repo, p.metrics) + resp, err := doGraphQLRequest(ctx, bodyBytes, p.client) if err != nil { return nil, err } defer resp.Body.Close() - return parseGraphQLResponse(resp, org, repo, prNumber, p.logger, p.metrics) + return parseGraphQLResponse(resp) } - diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go index f7f1083d..28e9a7dd 100644 --- a/extension/changeprovider/github/provider_test.go +++ b/extension/changeprovider/github/provider_test.go @@ -1,12 +1,11 @@ package github import ( - "bytes" "context" - "io" + "encoding/json" "net/http" + "net/http/httptest" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -14,417 +13,152 @@ import ( "go.uber.org/zap/zaptest" "github.com/uber/submitqueue/entity" + "github.com/uber/submitqueue/extension/changeprovider" ) -// mockRoundTripper is a mock implementation of http.RoundTripper for testing. -type mockRoundTripper struct { - roundTripFunc func(*http.Request) (*http.Response, error) -} - -func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - return m.roundTripFunc(req) -} - -// newMockClient creates an http.Client with a mock RoundTripper. -func newMockClient(roundTripFunc func(*http.Request) (*http.Response, error)) *http.Client { - return &http.Client{ - Transport: &mockRoundTripper{roundTripFunc: roundTripFunc}, +func newTestProvider(t *testing.T, serverURL string) changeprovider.ChangeProvider { + t.Helper() + return NewProvider(Params{ + Client: NewClient(&http.Client{}, serverURL), + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) +} + +func servePR(t *testing.T, w http.ResponseWriter, data pullRequestData) { + t.Helper() + var resp graphqlResponse + resp.Data.Repository.PullRequest = data + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(resp)) +} + +func TestProvider_Get(t *testing.T) { + tests := []struct { + name string + handler http.HandlerFunc + uris []string + wantErr bool + }{ + { + name: "returns result for valid PR", + handler: func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, pullRequestData{ + Number: 123, + HeadRefOid: "abc123", + Author: authorData{Name: "Test User", Email: "test@example.com"}, + Files: filesData{ + Nodes: []fileNode{ + {Path: "main.go"}, + {Path: "test.go"}, + }, + }, + }) + }, + uris: []string{"github://uber/submitqueue/123/abc123"}, + }, + { + name: "invalid URI returns error", + uris: []string{"invalid://uri"}, + wantErr: true, + }, + { + name: "inconsistent change set returns error", + uris: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/different-repo/456/def456", + }, + wantErr: true, + }, + { + name: "stale PR returns error", + handler: func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, pullRequestData{ + Number: 123, + HeadRefOid: "newsha", + Files: filesData{Nodes: []fileNode{{Path: "main.go"}}}, + }) + }, + uris: []string{"github://uber/submitqueue/123/oldsha"}, + wantErr: true, + }, } -} -func TestProvider_Get_Success(t *testing.T) { - responseBody := `{ - "data": { - "repository": { - "pullRequest": { - "number": 123, - "headRefOid": "abc123def456", - "author": { - "login": "testuser", - "name": "Test User", - "email": "test@example.com" - }, - "files": { - "totalCount": 2, - "pageInfo": { - "endCursor": "", - "hasNextPage": false - }, - "nodes": [ - { - "path": "main.go", - "additions": 10, - "deletions": 5, - "changeType": "MODIFIED", - "patch": "diff --git a/main.go b/main.go\n..." - }, - { - "path": "test.go", - "additions": 20, - "deletions": 0, - "changeType": "ADDED", - "patch": "diff --git a/test.go b/test.go\n..." - } - ] - } - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serverURL := "http://localhost" + if tt.handler != nil { + server := httptest.NewServer(tt.handler) + defer server.Close() + serverURL = server.URL } - } - }` - - mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { - assert.Equal(t, http.MethodPost, req.Method) - assert.Equal(t, "https://api.github.test/graphql", req.URL.String()) - assert.Equal(t, "application/json", req.Header.Get("Content-Type")) - - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(responseBody)), - Header: make(http.Header), - }, nil - }) - client := NewClient(mockClient, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) + p := newTestProvider(t, serverURL) + infos, err := p.Get(context.Background(), entity.Change{URIs: tt.uris}) - changeInfo, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{"github://uber/submitqueue/123/abc123def456"}, - }) - - require.NoError(t, err) - require.Len(t, changeInfo, 1, "should return 1 ChangeInfo for 1 PR") - - info := changeInfo[0] - assert.Equal(t, "github://uber/submitqueue/123/abc123def456", info.URI) - assert.Equal(t, "Test User", info.User.Name) - assert.Equal(t, "test@example.com", info.User.Email) - assert.Len(t, info.ChangedFiles, 2) - - assert.Equal(t, "main.go", info.ChangedFiles[0].Path) - assert.Equal(t, 10, info.ChangedFiles[0].LinesAdded) - assert.Equal(t, 5, info.ChangedFiles[0].LinesDeleted) - assert.Equal(t, 5, info.ChangedFiles[0].LinesModified) - assert.Contains(t, info.ChangedFiles[0].Patch, "diff --git a/main.go") - - assert.Equal(t, "test.go", info.ChangedFiles[1].Path) - assert.Equal(t, 20, info.ChangedFiles[1].LinesAdded) - assert.Equal(t, 0, info.ChangedFiles[1].LinesDeleted) - assert.Equal(t, 0, info.ChangedFiles[1].LinesModified) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Len(t, infos, 1) + assert.Equal(t, tt.uris[0], infos[0].URI) + assert.Len(t, infos[0].ChangedFiles, 2) + }) + } } func TestProvider_Get_Pagination(t *testing.T) { - callCount := 0 - responses := []string{ - `{ - "data": { - "repository": { - "pullRequest": { - "number": 456, - "headRefOid": "xyz789", - "author": { - "login": "user", - "name": "User", - "email": "user@example.com" - }, - "files": { - "totalCount": 150, - "pageInfo": { - "endCursor": "cursor1", - "hasNextPage": true - }, - "nodes": [ - { - "path": "file1.go", - "additions": 5, - "deletions": 2, - "changeType": "MODIFIED", - "patch": "diff1" - } - ] - } - } - } - } - }`, - `{ - "data": { - "repository": { - "pullRequest": { - "number": 456, - "headRefOid": "xyz789", - "author": { - "login": "user", - "name": "User", - "email": "user@example.com" - }, - "files": { - "totalCount": 150, - "pageInfo": { - "endCursor": "", - "hasNextPage": false - }, - "nodes": [ - { - "path": "file2.go", - "additions": 3, - "deletions": 1, - "changeType": "MODIFIED", - "patch": "diff2" - } - ] - } - } - } - } - }`, + pages := []pullRequestData{ + { + Number: 456, + HeadRefOid: "xyz789", + Files: filesData{ + PageInfo: pageInfo{EndCursor: "cursor1", HasNextPage: true}, + Nodes: []fileNode{{Path: "file1.go"}}, + }, + }, + { + Number: 456, + HeadRefOid: "xyz789", + Files: filesData{ + PageInfo: pageInfo{HasNextPage: false}, + Nodes: []fileNode{{Path: "file2.go"}}, + }, + }, } - - mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { - response := responses[callCount] + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, pages[callCount]) callCount++ - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(response)), - Header: make(http.Header), - }, nil - }) - - client := NewClient(mockClient, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) + })) + defer server.Close() - changeInfo, err := provider.Get(context.Background(), entity.Change{ + p := newTestProvider(t, server.URL) + infos, err := p.Get(context.Background(), entity.Change{ URIs: []string{"github://uber/submitqueue/456/xyz789"}, }) require.NoError(t, err) - assert.Equal(t, 2, callCount, "should make 2 GraphQL requests for pagination") - require.Len(t, changeInfo, 1, "should return 1 ChangeInfo for 1 PR") - - info := changeInfo[0] - assert.Len(t, info.ChangedFiles, 2, "should combine files from both pages") - assert.Equal(t, "file1.go", info.ChangedFiles[0].Path) - assert.Equal(t, "file2.go", info.ChangedFiles[1].Path) -} - -func TestProvider_Get_InvalidURI(t *testing.T) { - client := NewClient(&http.Client{}, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) - - _, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{"invalid://uri"}, - }) - - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to parse GitHub change ID") -} - -func TestProvider_Get_HTTPError(t *testing.T) { - mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { - return nil, assert.AnError - }) - - client := NewClient(mockClient, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) - - _, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{"github://uber/submitqueue/pull/123/abc"}, - }) - - require.Error(t, err) - assert.Contains(t, err.Error(), "HTTP request failed") -} - -func TestProvider_Get_APIError404(t *testing.T) { - mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewBufferString(`{"message":"Not Found"}`)), - Header: make(http.Header), - }, nil - }) - - client := NewClient(mockClient, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) - - _, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{"github://uber/submitqueue/pull/999/abc"}, - }) - - require.Error(t, err) - assert.Contains(t, err.Error(), "GitHub API returned status 404") -} - -func TestProvider_Get_GraphQLError(t *testing.T) { - responseBody := `{ - "errors": [ - { - "message": "Field 'pullRequest' doesn't exist on type 'Repository'", - "type": "INVALID_FIELD" - } - ] - }` - - mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(responseBody)), - Header: make(http.Header), - }, nil - }) - - client := NewClient(mockClient, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) - - _, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{"github://uber/submitqueue/pull/123/abc"}, - }) - - require.Error(t, err) - assert.Contains(t, err.Error(), "GraphQL errors") -} - -func TestProvider_Get_InvalidJSON(t *testing.T) { - mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{invalid json`)), - Header: make(http.Header), - }, nil - }) - - client := NewClient(mockClient, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) - - _, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{"github://uber/submitqueue/pull/123/abc"}, - }) - - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to decode GraphQL response") -} - -func TestNewProvider(t *testing.T) { - httpClient := &http.Client{Timeout: 30 * time.Second} - client := NewClient(httpClient, "https://api.github.com") - - provider := NewProvider(client, zaptest.NewLogger(t).Sugar(), tally.NoopScope) - - assert.NotNil(t, provider) + assert.Equal(t, 2, callCount) + require.Len(t, infos, 1) + assert.Len(t, infos[0].ChangedFiles, 2) } func TestProvider_Get_MultiplePRs(t *testing.T) { - callCount := 0 - responses := map[int]string{ - 0: `{ - "data": { - "repository": { - "pullRequest": { - "number": 123, - "headRefOid": "abc123", - "author": { - "login": "user1", - "name": "User One", - "email": "user1@example.com" - }, - "files": { - "totalCount": 1, - "pageInfo": { - "endCursor": "", - "hasNextPage": false - }, - "nodes": [ - { - "path": "file1.go", - "additions": 10, - "deletions": 5, - "changeType": "MODIFIED", - "patch": "diff1" - } - ] - } - } - } - } - }`, - 1: `{ - "data": { - "repository": { - "pullRequest": { - "number": 456, - "headRefOid": "def456", - "author": { - "login": "user1", - "name": "User One", - "email": "user1@example.com" - }, - "files": { - "totalCount": 1, - "pageInfo": { - "endCursor": "", - "hasNextPage": false - }, - "nodes": [ - { - "path": "file2.go", - "additions": 20, - "deletions": 2, - "changeType": "ADDED", - "patch": "diff2" - } - ] - } - } - } - } - }`, + prData := []pullRequestData{ + {Number: 123, HeadRefOid: "abc123", Files: filesData{Nodes: []fileNode{{Path: "file1.go"}}}}, + {Number: 456, HeadRefOid: "def456", Files: filesData{Nodes: []fileNode{{Path: "file2.go"}}}}, } - - mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { - response := responses[callCount] + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, prData[callCount]) callCount++ - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(response)), - Header: make(http.Header), - }, nil - }) - - client := NewClient(mockClient, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) + })) + defer server.Close() - changeInfo, err := provider.Get(context.Background(), entity.Change{ + p := newTestProvider(t, server.URL) + infos, err := p.Get(context.Background(), entity.Change{ URIs: []string{ "github://uber/submitqueue/123/abc123", "github://uber/submitqueue/456/def456", @@ -432,233 +166,37 @@ func TestProvider_Get_MultiplePRs(t *testing.T) { }) require.NoError(t, err) - assert.Equal(t, 2, callCount, "should make 2 GraphQL requests for 2 PRs") - require.Len(t, changeInfo, 2, "should return 2 ChangeInfo for 2 PRs") - - // First PR - assert.Equal(t, "github://uber/submitqueue/123/abc123", changeInfo[0].URI) - assert.Equal(t, "User One", changeInfo[0].User.Name) - assert.Equal(t, "user1@example.com", changeInfo[0].User.Email) - assert.Len(t, changeInfo[0].ChangedFiles, 1) - assert.Equal(t, "file1.go", changeInfo[0].ChangedFiles[0].Path) - - // Second PR - assert.Equal(t, "github://uber/submitqueue/456/def456", changeInfo[1].URI) - assert.Equal(t, "User One", changeInfo[1].User.Name) - assert.Equal(t, "user1@example.com", changeInfo[1].User.Email) - assert.Len(t, changeInfo[1].ChangedFiles, 1) - assert.Equal(t, "file2.go", changeInfo[1].ChangedFiles[0].Path) -} - -func TestProvider_Get_CrossRepoStack(t *testing.T) { - client := NewClient(&http.Client{}, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) - - _, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{ - "github://uber/submitqueue/123/abc123", - "github://uber/different-repo/456/def456", // Different repo! - }, - }) - - require.Error(t, err) - assert.Contains(t, err.Error(), "stacked changes must be from same repository") - assert.Contains(t, err.Error(), "expected uber/submitqueue") - assert.Contains(t, err.Error(), "got uber/different-repo") + assert.Equal(t, 2, callCount) + require.Len(t, infos, 2) + assert.Equal(t, "github://uber/submitqueue/123/abc123", infos[0].URI) + assert.Equal(t, "github://uber/submitqueue/456/def456", infos[1].URI) } -func TestProvider_Get_MixedProviderStack(t *testing.T) { - client := NewClient(&http.Client{}, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) - - _, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{ - "github://uber/submitqueue/123/abc123", - "ghe://uber/submitqueue/456/def456", // Different provider (GHE instead of github)! - }, - }) - - require.Error(t, err) - assert.Contains(t, err.Error(), "stacked changes must use same change provider") - assert.Contains(t, err.Error(), "expected github") - assert.Contains(t, err.Error(), "got ghe") -} - -func TestProvider_Get_StalePR(t *testing.T) { - responseBody := `{ - "data": { - "repository": { - "pullRequest": { - "number": 123, - "headRefOid": "newsha123", - "author": { - "login": "testuser", - "name": "Test User", - "email": "test@example.com" - }, - "files": { - "totalCount": 1, - "pageInfo": { - "endCursor": "", - "hasNextPage": false - }, - "nodes": [ - { - "path": "main.go", - "additions": 10, - "deletions": 5, - "changeType": "MODIFIED", - "patch": "diff --git a/main.go..." - } - ] - } - } - } - } - }` - - mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(responseBody)), - Header: make(http.Header), - }, nil - }) - - client := NewClient(mockClient, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) - - _, err := provider.Get(context.Background(), entity.Change{ - URIs: []string{"github://uber/submitqueue/123/oldsha456"}, // Different SHA! - }) - - require.Error(t, err) - assert.Contains(t, err.Error(), "PR #123 head SHA changed") - assert.Contains(t, err.Error(), "expected oldsha456") - assert.Contains(t, err.Error(), "got newsha123") -} - -func TestProvider_Get_PartialSuccess(t *testing.T) { +func TestProvider_Get_FetchError_StopsOnFirstFailure(t *testing.T) { callCount := 0 - responses := map[int]string{ - 0: `{ - "data": { - "repository": { - "pullRequest": { - "number": 123, - "headRefOid": "abc123", - "author": { - "login": "user1", - "name": "User One", - "email": "user1@example.com" - }, - "files": { - "totalCount": 1, - "pageInfo": { - "endCursor": "", - "hasNextPage": false - }, - "nodes": [ - { - "path": "file1.go", - "additions": 10, - "deletions": 5, - "changeType": "MODIFIED", - "patch": "diff1" - } - ] - } - } - } - } - }`, - // Second PR will get an error response (404) - 2: `{ - "data": { - "repository": { - "pullRequest": { - "number": 789, - "headRefOid": "ghi789", - "author": { - "login": "user1", - "name": "User One", - "email": "user1@example.com" - }, - "files": { - "totalCount": 1, - "pageInfo": { - "endCursor": "", - "hasNextPage": false - }, - "nodes": [ - { - "path": "file3.go", - "additions": 15, - "deletions": 3, - "changeType": "MODIFIED", - "patch": "diff3" - } - ] - } - } - } - } - }`, - } - - mockClient := newMockClient(func(req *http.Request) (*http.Response, error) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if callCount == 1 { - // Fail on second PR callCount++ - return &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewBufferString(`{"message":"Not Found"}`)), - Header: make(http.Header), - }, nil + w.WriteHeader(http.StatusInternalServerError) + return } - response := responses[callCount] + servePR(t, w, pullRequestData{ + Number: 123, + HeadRefOid: "abc123", + Files: filesData{Nodes: []fileNode{{Path: "file1.go"}}}, + }) callCount++ - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(response)), - Header: make(http.Header), - }, nil - }) - - client := NewClient(mockClient, "https://api.github.test") - provider := NewProvider( - client, - zaptest.NewLogger(t).Sugar(), - tally.NoopScope, - ) + })) + defer server.Close() - changeInfo, err := provider.Get(context.Background(), entity.Change{ + p := newTestProvider(t, server.URL) + _, err := p.Get(context.Background(), entity.Change{ URIs: []string{ "github://uber/submitqueue/123/abc123", - "github://uber/submitqueue/456/def456", // This will fail - "github://uber/submitqueue/789/ghi789", + "github://uber/submitqueue/456/def456", }, }) - // Should return partial results with error require.Error(t, err) - assert.Contains(t, err.Error(), "failed to fetch 1 of 3 PRs") - assert.Contains(t, err.Error(), "failed: [456]") - - // Should have 2 successful PRs - require.Len(t, changeInfo, 2, "should return 2 successful ChangeInfo despite 1 failure") - assert.Equal(t, "github://uber/submitqueue/123/abc123", changeInfo[0].URI) - assert.Equal(t, "github://uber/submitqueue/789/ghi789", changeInfo[1].URI) + assert.Equal(t, 2, callCount) } diff --git a/extension/changeprovider/github/validate.go b/extension/changeprovider/github/validate.go index b885dd7d..55bf1fb5 100644 --- a/extension/changeprovider/github/validate.go +++ b/extension/changeprovider/github/validate.go @@ -8,10 +8,9 @@ import ( // validateChangeConsistency validates that all changeIDs in the stack are consistent. // Stacked changes must have the same change provider (scheme), org, and repo. -// Returns the org and repo if valid, or an error if any change is inconsistent. func validateChangeConsistency( changeIDs []entitygithub.ChangeID, -) (string, string, error) { +) error { expectedScheme := changeIDs[0].Scheme expectedOrg := changeIDs[0].Org expectedRepo := changeIDs[0].Repo @@ -19,18 +18,18 @@ func validateChangeConsistency( for _, cid := range changeIDs { // Validate same change provider (scheme) if cid.Scheme != expectedScheme { - return "", "", fmt.Errorf("stacked changes must use same change provider: expected %s, got %s for PR #%d", + return fmt.Errorf("stacked changes must use same change provider: expected %s, got %s for PR #%d", expectedScheme, cid.Scheme, cid.PRNumber) } // Validate same org and repo if cid.Org != expectedOrg || cid.Repo != expectedRepo { - return "", "", fmt.Errorf("stacked changes must be from same repository: expected %s/%s, got %s/%s for PR #%d", + return fmt.Errorf("stacked changes must be from same org/repository: expected %s/%s, got %s/%s for PR #%d", expectedOrg, expectedRepo, cid.Org, cid.Repo, cid.PRNumber) } } - return expectedOrg, expectedRepo, nil + return nil } // validatePRStaleness validates that the PR hasn't changed since submission. diff --git a/extension/changeprovider/github/validate_test.go b/extension/changeprovider/github/validate_test.go new file mode 100644 index 00000000..0612c1ea --- /dev/null +++ b/extension/changeprovider/github/validate_test.go @@ -0,0 +1,99 @@ +package github + +import ( + "testing" + + "github.com/stretchr/testify/require" + entitygithub "github.com/uber/submitqueue/entity/github" +) + +func TestValidateChangeConsistency(t *testing.T) { + tests := []struct { + name string + changeIDs []entitygithub.ChangeID + wantErr bool + }{ + { + name: "single PR", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + }, + }, + { + name: "consistent stack", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 2}, + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 3}, + }, + }, + { + name: "different repo", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "github", Org: "uber", Repo: "other-repo", PRNumber: 2}, + }, + wantErr: true, + }, + { + name: "different org", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "github", Org: "other-org", Repo: "submitqueue", PRNumber: 2}, + }, + wantErr: true, + }, + { + name: "different scheme", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "ghe", Org: "uber", Repo: "submitqueue", PRNumber: 2}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateChangeConsistency(tt.changeIDs) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidatePRStaleness(t *testing.T) { + tests := []struct { + name string + cid entitygithub.ChangeID + prData pullRequestData + wantErr bool + }{ + { + name: "matching SHA", + cid: entitygithub.ChangeID{PRNumber: 1, HeadCommitSHA: "abc123"}, + prData: pullRequestData{HeadRefOid: "abc123"}, + wantErr: false, + }, + { + name: "mismatched SHA", + cid: entitygithub.ChangeID{PRNumber: 1, HeadCommitSHA: "oldsha"}, + prData: pullRequestData{HeadRefOid: "newsha"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePRStaleness(tt.cid, &tt.prData) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/orchestrator/controller/validate/validate.go b/orchestrator/controller/validate/validate.go index 771a3c80..d541d16c 100644 --- a/orchestrator/controller/validate/validate.go +++ b/orchestrator/controller/validate/validate.go @@ -21,6 +21,7 @@ import ( "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" "github.com/uber/submitqueue/core/errs" + coremetrics "github.com/uber/submitqueue/core/metrics" "github.com/uber/submitqueue/entity" entityqueue "github.com/uber/submitqueue/entity/queue" "github.com/uber/submitqueue/extension/changeprovider" @@ -73,22 +74,23 @@ func NewController( // Process processes a validate delivery from the queue. // Deserializes the request and publishes to the batch topic. // Returns nil to ack (success), or error to nack (retry). -func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) error { - c.metricsScope.Counter("received").Inc(1) +func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (retErr error) { + op := coremetrics.Begin(c.metricsScope, "process") + defer func() { op.Complete(retErr) }() msg := delivery.Message() // Deserialize request ID from payload rid, err := entity.RequestIDFromBytes(msg.Payload) if err != nil { - c.metricsScope.Counter("deserialize_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "deserialize_errors", 1) return fmt.Errorf("failed to deserialize request ID: %w", err) } // Fetch request from storage request, err := c.store.GetRequestStore().Get(ctx, rid.ID) if err != nil { - c.metricsScope.Counter("storage_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "storage_errors", 1) return fmt.Errorf("failed to get request %s: %w", rid.ID, err) } @@ -104,7 +106,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er // Merge conflict check mergeResult, err := c.mergeChecker.Check(ctx, request.Queue, request.Change) if err != nil { - c.metricsScope.Counter("merge_check_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "merge_check_errors", 1) return fmt.Errorf("merge check failed: %w", err) } if !mergeResult.Mergeable { @@ -113,7 +115,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "queue", request.Queue, "reason", mergeResult.Reason, ) - c.metricsScope.Counter("not_mergeable").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "not_mergeable", 1) return errs.NewUserError(fmt.Errorf("request %s is not mergeable: %s", request.ID, mergeResult.Reason)) } @@ -125,7 +127,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "change_uris", request.Change.URIs, "error", err, ) - c.metricsScope.Counter("change_provider_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "change_provider_errors", 1) return fmt.Errorf("failed to fetch change information: %w", err) } @@ -137,7 +139,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er // Publish to batch topic if err := c.publish(ctx, consumer.TopicKeyBatch, request.ID, request.Queue); err != nil { - c.metricsScope.Counter("publish_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "publish_errors", 1) return fmt.Errorf("failed to publish to batch: %w", err) } @@ -146,8 +148,6 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "topic_key", consumer.TopicKeyBatch, ) - c.metricsScope.Counter("processed").Inc(1) - return nil // Success - message will be acked }