diff --git a/example/server/orchestrator/BUILD.bazel b/example/server/orchestrator/BUILD.bazel index 4a89170f..65ad5408 100644 --- a/example/server/orchestrator/BUILD.bazel +++ b/example/server/orchestrator/BUILD.bazel @@ -12,6 +12,8 @@ go_library( visibility = ["//visibility:private"], deps = [ "//core/consumer", + "//extension/changeprovider", + "//extension/changeprovider/github", "//extension/counter", "//extension/counter/mysql", "//extension/mergechecker", diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index 25f13364..ae4526ad 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -30,6 +30,8 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/extension/changeprovider" + githubprovider "github.com/uber/submitqueue/extension/changeprovider/github" "github.com/uber/submitqueue/extension/counter" mysqlcounter "github.com/uber/submitqueue/extension/counter/mysql" "github.com/uber/submitqueue/extension/mergechecker" @@ -188,8 +190,11 @@ func run() error { // Create merge checker mc := newMergeChecker(logger, scope) + // Create change provider + cp := newChangeProvider(logger, scope) + // Register controllers - if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cnt, store); err != nil { + if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cp, cnt, store); err != nil { return err } @@ -374,7 +379,7 @@ func newTopicRegistry(q extqueue.Queue, subscriberName string) (consumer.TopicRe // │ │ │ // └────────┴────────────────────────┘ -func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker, cnt counter.Counter, store storage.Storage) error { +func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry 
consumer.TopicRegistry, mc mergechecker.MergeChecker, cp changeprovider.ChangeProvider, cnt counter.Counter, store storage.Storage) error { requestController := start.NewController( logger, scope, @@ -393,6 +398,7 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t store, registry, mc, + cp, consumer.TopicKeyValidate, "orchestrator-validate", ) @@ -499,6 +505,36 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t return nil } +// getEnv returns environment variable value or default if not set. +func getEnv(key, defaultVal string) string { + if val := os.Getenv(key); val != "" { + return val + } + return defaultVal +} + +// parseTimeout parses a duration from environment variable with fallback to default. +// Returns defaultVal if envVal is empty or cannot be parsed. +func parseTimeout(envVal string, defaultVal time.Duration) time.Duration { + if envVal == "" { + return defaultVal + } + if d, err := time.ParseDuration(envVal); err == nil { + return d + } + return defaultVal +} + +// buildGitHubHTTPClient creates an http.Client configured for GitHub API calls. +// Configures timeout and optional bearer token authentication. +func buildGitHubHTTPClient(token string, timeout time.Duration) *http.Client { + httpClient := &http.Client{Timeout: timeout} + if token != "" { + httpClient.Transport = &bearerTransport{token: token} + } + return httpClient +} + // newMergeChecker creates a MergeChecker for GitHub (github.com). // Configured via GITHUB_TOKEN and GITHUB_GRAPHQL_URL environment variables. func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeChecker { @@ -524,6 +560,26 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh }) } +// newChangeProvider creates a ChangeProvider for GitHub (github.com). +// Configured via GITHUB_BASE_URL, GITHUB_TOKEN, and GITHUB_TIMEOUT environment variables. 
+// Uses pure dependency injection - creates http.Client with auth configured in Transport. +func newChangeProvider(logger *zap.Logger, scope tally.Scope) changeprovider.ChangeProvider { + // 1. Read configuration from environment + baseURL := getEnv("GITHUB_BASE_URL", "https://api.github.com") + token := os.Getenv("GITHUB_TOKEN") + timeout := parseTimeout(os.Getenv("GITHUB_TIMEOUT"), 30*time.Second) + + // 2. Build HTTP client with caller-controlled config (auth + timeout) + httpClient := buildGitHubHTTPClient(token, timeout) + + // 3. Inject into provider + return githubprovider.NewProvider(githubprovider.Params{ + Client: githubprovider.NewClient(httpClient, baseURL), + Logger: logger.Sugar(), + MetricsScope: scope.SubScope("changeprovider"), + }) +} + // bearerTransport is an http.RoundTripper that adds a Bearer token to requests. type bearerTransport struct { token string diff --git a/extension/changeprovider/change_provider.go b/extension/changeprovider/change_provider.go index 377cbd08..72bdd5af 100644 --- a/extension/changeprovider/change_provider.go +++ b/extension/changeprovider/change_provider.go @@ -44,8 +44,9 @@ type ChangedFile struct { // ChangeInfo contains metadata and file changes for a code change. type ChangeInfo struct { - // ID is the change identifier (e.g., "PR: uber-code/go-code/1" or "diff: uber-code/go-code/D1"). - ID string + // URI is the full change URI for correlation with the input request + // (e.g., "github://uber/repo/98/abc123sha" or "phab://D123/xyz789"). + URI string // User is the author of the change. User User // ChangedFiles is the list of files modified in this change. Order is unspecified. @@ -56,6 +57,7 @@ type ChangeInfo struct { // Each implementation is configured for a specific provider (GitHub, GitLab, Phabricator). type ChangeProvider interface { // Get retrieves change information for the provided Change. - // Returns the change info containing metadata and file changes. 
load("@rules_go//go:def.bzl", "go_library", "go_test")

# Library target for the GitHub ChangeProvider implementation.
go_library(
    name = "github",
    srcs = [
        "config.go",
        "convert.go",
        "graphql.go",
        "provider.go",
        "validate.go",
    ],
    importpath = "github.com/uber/submitqueue/extension/changeprovider/github",
    visibility = ["//visibility:public"],
    deps = [
        # provider.go imports github.com/uber/submitqueue/core/metrics.
        "//core/metrics",
        "//entity",
        "//entity/github",
        "//extension/changeprovider",
        "@com_github_uber_go_tally_v4//:tally",
        "@org_uber_go_zap//:zap",
    ],
)

# Unit tests for the package; every *_test.go file added alongside the library
# must be listed here or it is silently never run.
go_test(
    name = "github_test",
    srcs = [
        "config_test.go",
        "graphql_test.go",
        "provider_test.go",
        "validate_test.go",
    ],
    embed = [":github"],
    deps = [
        "//entity",
        # validate_test.go imports entity/github; provider_test.go imports
        # extension/changeprovider.
        "//entity/github",
        "//extension/changeprovider",
        "@com_github_stretchr_testify//assert",
        "@com_github_stretchr_testify//require",
        "@com_github_uber_go_tally_v4//:tally",
        "@org_uber_go_zap//zaptest",
    ],
)
// Client is a GitHub API client handle: it pairs an HTTP client with the base
// URL of a GitHub API instance and builds endpoint URLs for both the GraphQL
// and REST APIs.
type Client struct {
	httpClient *http.Client
	baseURL    string
}

// NewClient creates a new GitHub API client around a pre-configured HTTP client.
// The caller is responsible for configuring authentication in the HTTP
// client's Transport.
//
// Parameters:
//   - httpClient: configured HTTP client (auth, timeout, transport, etc.)
//   - baseURL: GitHub instance base URL (e.g., "https://api.github.com" or
//     "https://ghe.company.com/api"); a trailing "/" is tolerated.
func NewClient(httpClient *http.Client, baseURL string) *Client {
	return &Client{
		httpClient: httpClient,
		baseURL:    baseURL,
	}
}

// HTTPClient returns the configured HTTP client.
func (c *Client) HTTPClient() *http.Client {
	return c.httpClient
}

// BaseURL returns the GitHub base URL exactly as it was passed to NewClient.
func (c *Client) BaseURL() string {
	return c.baseURL
}

// GraphQLURL returns the GitHub GraphQL endpoint URL: the base URL, without
// any trailing slashes, followed by "/graphql".
func (c *Client) GraphQLURL() string {
	return c.endpointBase() + "/graphql"
}

// RESTURL constructs a GitHub REST API endpoint URL for path
// (e.g., "/repos/uber/submitqueue/pulls/123").
//
// Exactly one "/" separates the base URL and path: trailing slashes on the
// base URL are dropped and a leading "/" is added to a non-empty path that
// lacks one.
func (c *Client) RESTURL(path string) string {
	if path != "" && path[0] != '/' {
		path = "/" + path
	}
	return c.endpointBase() + path
}

// endpointBase returns the base URL with trailing slashes removed so that
// joining it with a "/"-prefixed suffix never yields "//".
func (c *Client) endpointBase() string {
	base := c.baseURL
	for len(base) > 0 && base[len(base)-1] == '/' {
		base = base[:len(base)-1]
	}
	return base
}
"https://api.github.com", + path: "/repos/uber/submitqueue/pulls/123", + expected: "https://api.github.com/repos/uber/submitqueue/pulls/123", + }, + { + name: "enterprise repos endpoint", + baseURL: "https://ghe.example.com/api", + path: "/repos/myorg/myrepo/pulls/456", + expected: "https://ghe.example.com/api/repos/myorg/myrepo/pulls/456", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewClient(&http.Client{}, tt.baseURL) + assert.Equal(t, tt.expected, client.RESTURL(tt.path)) + }) + } +} diff --git a/extension/changeprovider/github/convert.go b/extension/changeprovider/github/convert.go new file mode 100644 index 00000000..60394f83 --- /dev/null +++ b/extension/changeprovider/github/convert.go @@ -0,0 +1,42 @@ +package github + +import ( + entitygithub "github.com/uber/submitqueue/entity/github" + "github.com/uber/submitqueue/extension/changeprovider" +) + +// convertToChangeInfo converts GitHub PR data to ChangeInfo. +func convertToChangeInfo(parsed entitygithub.ChangeID, prData *pullRequestData) changeprovider.ChangeInfo { + changedFiles := convertFiles(prData.Files.Nodes) + + return changeprovider.ChangeInfo{ + URI: parsed.String(), + User: changeprovider.User{ + Name: prData.Author.Name, + Email: prData.Author.Email, + }, + ChangedFiles: changedFiles, + } +} + +// convertFiles converts GitHub file nodes to ChangedFile structs. 
+func convertFiles(nodes []fileNode) []changeprovider.ChangedFile { + changedFiles := make([]changeprovider.ChangedFile, 0, len(nodes)) + + for _, file := range nodes { + linesModified := 0 + if file.Additions > 0 && file.Deletions > 0 { + linesModified = min(file.Additions, file.Deletions) + } + + changedFiles = append(changedFiles, changeprovider.ChangedFile{ + Path: file.Path, + Patch: file.Patch, + LinesAdded: file.Additions, + LinesDeleted: file.Deletions, + LinesModified: linesModified, + }) + } + + return changedFiles +} diff --git a/extension/changeprovider/github/graphql.go b/extension/changeprovider/github/graphql.go new file mode 100644 index 00000000..08f43206 --- /dev/null +++ b/extension/changeprovider/github/graphql.go @@ -0,0 +1,157 @@ +package github + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +// pullRequestQuery is the GraphQL query to fetch pull request information including files, author, and head SHA. +const pullRequestQuery = ` +query($owner: String!, $repo: String!, $prNumber: Int!, $filesCursor: String) { + repository(owner: $owner, name: $repo) { + pullRequest(number: $prNumber) { + number + headRefOid + author { + login + ... on User { + name + email + } + } + files(first: 100, after: $filesCursor) { + totalCount + pageInfo { + endCursor + hasNextPage + } + nodes { + path + additions + deletions + changeType + patch + } + } + } + } +} +` + +// graphqlRequest represents a GraphQL request. +type graphqlRequest struct { + Query string `json:"query"` + Variables map[string]any `json:"variables"` +} + +// graphqlResponse represents the top-level GraphQL response. +type graphqlResponse struct { + Data struct { + Repository struct { + PullRequest pullRequestData `json:"pullRequest"` + } `json:"repository"` + } `json:"data"` + Errors []graphqlError `json:"errors,omitempty"` +} + +// graphqlError represents a GraphQL error. 
+type graphqlError struct { + Message string `json:"message"` + Type string `json:"type"` +} + +// pullRequestData contains the pull request metadata. +type pullRequestData struct { + Number int `json:"number"` + HeadRefOid string `json:"headRefOid"` + Author authorData `json:"author"` + Files filesData `json:"files"` +} + +// authorData contains the author information. +type authorData struct { + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` +} + +// filesData contains the files changed in the pull request. +type filesData struct { + TotalCount int `json:"totalCount"` + PageInfo pageInfo `json:"pageInfo"` + Nodes []fileNode `json:"nodes"` +} + +// pageInfo contains pagination information. +type pageInfo struct { + EndCursor string `json:"endCursor"` + HasNextPage bool `json:"hasNextPage"` +} + +// fileNode represents a single changed file. +type fileNode struct { + Path string `json:"path"` + Additions int `json:"additions"` + Deletions int `json:"deletions"` + ChangeType string `json:"changeType"` + Patch string `json:"patch"` +} + +// buildGraphQLRequest builds a GraphQL request for fetching pull request data. +func buildGraphQLRequest(org, repo string, prNumber int, cursor string) graphqlRequest { + return graphqlRequest{ + Query: pullRequestQuery, + Variables: map[string]any{ + "owner": org, + "repo": repo, + "prNumber": prNumber, + "filesCursor": cursor, + }, + } +} + +// doGraphQLRequest executes a GraphQL HTTP request. 
+func doGraphQLRequest( + ctx context.Context, + bodyBytes []byte, + client *Client, +) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, client.GraphQLURL(), bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := client.HTTPClient().Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + + return resp, nil +} + +// parseGraphQLResponse parses and validates a GraphQL response. +func parseGraphQLResponse( + resp *http.Response, +) (*pullRequestData, error) { + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("GitHub API returned status %d: %s", resp.StatusCode, string(body)) + } + + var gqlResp graphqlResponse + if err := json.NewDecoder(resp.Body).Decode(&gqlResp); err != nil { + return nil, fmt.Errorf("failed to decode GraphQL response: %w", err) + } + + if len(gqlResp.Errors) > 0 { + return nil, fmt.Errorf("GraphQL errors: %+v", gqlResp.Errors) + } + + return &gqlResp.Data.Repository.PullRequest, nil +} diff --git a/extension/changeprovider/github/graphql_test.go b/extension/changeprovider/github/graphql_test.go new file mode 100644 index 00000000..81e4b591 --- /dev/null +++ b/extension/changeprovider/github/graphql_test.go @@ -0,0 +1,145 @@ +package github + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildGraphQLRequest(t *testing.T) { + tests := []struct { + name string + org string + repo string + prNumber int + cursor string + wantVars map[string]any + }{ + { + name: "no cursor", + org: "uber", + repo: "submitqueue", + prNumber: 123, + cursor: "", + wantVars: map[string]any{ + "owner": "uber", + "repo": "submitqueue", + "prNumber": 123, + "filesCursor": "", + }, + }, + { + name: "with cursor", + org: 
"myorg", + repo: "myrepo", + prNumber: 456, + cursor: "cursor_token_xyz", + wantVars: map[string]any{ + "owner": "myorg", + "repo": "myrepo", + "prNumber": 456, + "filesCursor": "cursor_token_xyz", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := buildGraphQLRequest(tt.org, tt.repo, tt.prNumber, tt.cursor) + assert.Equal(t, pullRequestQuery, req.Query) + assert.Equal(t, tt.wantVars, req.Variables) + }) + } +} + +func TestParseGraphQLResponse(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantErr bool + wantData *pullRequestData + }{ + { + name: "success", + statusCode: http.StatusOK, + body: `{ + "data": { + "repository": { + "pullRequest": { + "number": 42, + "headRefOid": "abc123", + "author": {"login": "octocat", "name": "The Octocat", "email": "octocat@example.com"}, + "files": { + "totalCount": 1, + "pageInfo": {"endCursor": "cur1", "hasNextPage": false}, + "nodes": [{"path": "main.go", "additions": 10, "deletions": 2, "changeType": "MODIFIED", "patch": "diff content"}] + } + } + } + } + }`, + wantData: &pullRequestData{ + Number: 42, + HeadRefOid: "abc123", + Author: authorData{Login: "octocat", Name: "The Octocat", Email: "octocat@example.com"}, + Files: filesData{ + TotalCount: 1, + PageInfo: pageInfo{EndCursor: "cur1", HasNextPage: false}, + Nodes: []fileNode{{Path: "main.go", Additions: 10, Deletions: 2, ChangeType: "MODIFIED", Patch: "diff content"}}, + }, + }, + }, + { + name: "non-200 status", + statusCode: http.StatusInternalServerError, + body: `{"message":"Internal Server Error"}`, + wantErr: true, + }, + { + name: "404 not found", + statusCode: http.StatusNotFound, + body: `{"message":"Not Found"}`, + wantErr: true, + }, + { + name: "invalid JSON", + statusCode: http.StatusOK, + body: `{invalid json`, + wantErr: true, + }, + { + name: "GraphQL errors", + statusCode: http.StatusOK, + body: `{"errors":[{"message":"Field doesn't exist","type":"INVALID_FIELD"}]}`, + 
wantErr: true, + }, + { + name: "multiple GraphQL errors", + statusCode: http.StatusOK, + body: `{"errors":[{"message":"Not Found","type":"NOT_FOUND"},{"message":"Forbidden","type":"FORBIDDEN"}]}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader(tt.body)), + } + + got, err := parseGraphQLResponse(resp) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantData, got) + }) + } +} diff --git a/extension/changeprovider/github/provider.go b/extension/changeprovider/github/provider.go new file mode 100644 index 00000000..e8024b6d --- /dev/null +++ b/extension/changeprovider/github/provider.go @@ -0,0 +1,163 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/uber-go/tally/v4" + "go.uber.org/zap" + + coremetrics "github.com/uber/submitqueue/core/metrics" + "github.com/uber/submitqueue/entity" + entitygithub "github.com/uber/submitqueue/entity/github" + "github.com/uber/submitqueue/extension/changeprovider" +) + +// Params holds the dependencies for the GitHub ChangeProvider. +type Params struct { + // Client is a pre-configured GitHub API client (encapsulates HTTP client and GraphQL URL). + // Auth is the caller's responsibility via HTTP transport/round-tripper. + Client *Client + // Logger is the structured logger. + Logger *zap.SugaredLogger + // MetricsScope is the metrics scope for instrumentation. + MetricsScope tally.Scope +} + +// provider implements the ChangeProvider interface for GitHub. +type provider struct { + client *Client + logger *zap.SugaredLogger + metricsScope tally.Scope +} + +// NewProvider creates a new GitHub ChangeProvider. 
+func NewProvider(params Params) changeprovider.ChangeProvider { + return &provider{ + client: params.Client, + logger: params.Logger.Named("github_changeprovider"), + metricsScope: params.MetricsScope.SubScope("github_changeprovider"), + } +} + +// Get retrieves change information from GitHub for the provided Change. +// Returns one ChangeInfo per URI (one per PR in stacked changes). +func (p *provider) Get(ctx context.Context, change entity.Change) (_ []changeprovider.ChangeInfo, retErr error) { + op := coremetrics.Begin(p.metricsScope, "get") + defer func() { op.Complete(retErr) }() + + // Parse all change IDs + changeIDs := make([]entitygithub.ChangeID, 0, len(change.URIs)) + for _, uri := range change.URIs { + parsed, err := entitygithub.ParseChangeID(uri) + if err != nil { + return nil, fmt.Errorf("failed to parse GitHub change ID %q: %w", uri, err) + } + changeIDs = append(changeIDs, parsed) + } + + p.logger.Debugw("fetching PR data from GitHub", + "pr_count", len(changeIDs), + "uris", change.URIs, + ) + + // Validate stacked changes are consistent (same provider, org, and repo) + if err := validateChangeConsistency(changeIDs); err != nil { + return nil, err + } + + // Fetch each PR and build ChangeInfo for each + changeInfos, err := p.fetchAllPRs(ctx, changeIDs) + if err != nil { + return nil, err + } + + p.logger.Debugw("successfully fetched PR data", + "pr_count", len(changeIDs), + ) + + return changeInfos, nil +} + +// fetchAllPRs fetches and validates all PRs in the stack, returning on the first error. 
+func (p *provider) fetchAllPRs( + ctx context.Context, + changeIDs []entitygithub.ChangeID, +) ([]changeprovider.ChangeInfo, error) { + changeInfos := make([]changeprovider.ChangeInfo, 0, len(changeIDs)) + + for _, cid := range changeIDs { + prData, err := p.fetchPullRequest(ctx, cid) + if err != nil { + coremetrics.NamedCounter(p.metricsScope, "fetch_pr", "errors", 1, + coremetrics.NewTag("org", cid.Org), + coremetrics.NewTag("repo", cid.Repo), + ) + return nil, fmt.Errorf("failed to fetch PR #%d: %w", cid.PRNumber, err) + } + + if err := validatePRStaleness(cid, prData); err != nil { + return nil, err + } + + changeInfo := convertToChangeInfo(cid, prData) + changeInfos = append(changeInfos, changeInfo) + + p.logger.Debugw("fetched PR data", + "org", cid.Org, + "repo", cid.Repo, + "pr", cid.PRNumber, + "files_count", len(changeInfo.ChangedFiles), + "head_sha", prData.HeadRefOid, + ) + } + + return changeInfos, nil +} + +// fetchPullRequest makes GraphQL request(s) to fetch PR data, handling pagination. +func (p *provider) fetchPullRequest(ctx context.Context, parsed entitygithub.ChangeID) (*pullRequestData, error) { + var allFiles []fileNode + var prData pullRequestData + cursor := "" + + for { + data, err := p.fetchPullRequestPage(ctx, parsed.Org, parsed.Repo, parsed.PRNumber, cursor) + if err != nil { + return nil, err + } + + if cursor == "" { + prData = *data + } + + allFiles = append(allFiles, data.Files.Nodes...) + + if !data.Files.PageInfo.HasNextPage { + break + } + cursor = data.Files.PageInfo.EndCursor + } + + prData.Files.Nodes = allFiles + return &prData, nil +} + +// fetchPullRequestPage fetches a single page of PR data. 
+func (p *provider) fetchPullRequestPage(ctx context.Context, org, repo string, prNumber int, cursor string) (*pullRequestData, error) { + reqBody := buildGraphQLRequest(org, repo, prNumber, cursor) + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) + } + + resp, err := doGraphQLRequest(ctx, bodyBytes, p.client) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return parseGraphQLResponse(resp) +} diff --git a/extension/changeprovider/github/provider_test.go b/extension/changeprovider/github/provider_test.go new file mode 100644 index 00000000..28e9a7dd --- /dev/null +++ b/extension/changeprovider/github/provider_test.go @@ -0,0 +1,202 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally/v4" + "go.uber.org/zap/zaptest" + + "github.com/uber/submitqueue/entity" + "github.com/uber/submitqueue/extension/changeprovider" +) + +func newTestProvider(t *testing.T, serverURL string) changeprovider.ChangeProvider { + t.Helper() + return NewProvider(Params{ + Client: NewClient(&http.Client{}, serverURL), + Logger: zaptest.NewLogger(t).Sugar(), + MetricsScope: tally.NoopScope, + }) +} + +func servePR(t *testing.T, w http.ResponseWriter, data pullRequestData) { + t.Helper() + var resp graphqlResponse + resp.Data.Repository.PullRequest = data + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(resp)) +} + +func TestProvider_Get(t *testing.T) { + tests := []struct { + name string + handler http.HandlerFunc + uris []string + wantErr bool + }{ + { + name: "returns result for valid PR", + handler: func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, pullRequestData{ + Number: 123, + HeadRefOid: "abc123", + Author: authorData{Name: "Test User", Email: 
"test@example.com"}, + Files: filesData{ + Nodes: []fileNode{ + {Path: "main.go"}, + {Path: "test.go"}, + }, + }, + }) + }, + uris: []string{"github://uber/submitqueue/123/abc123"}, + }, + { + name: "invalid URI returns error", + uris: []string{"invalid://uri"}, + wantErr: true, + }, + { + name: "inconsistent change set returns error", + uris: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/different-repo/456/def456", + }, + wantErr: true, + }, + { + name: "stale PR returns error", + handler: func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, pullRequestData{ + Number: 123, + HeadRefOid: "newsha", + Files: filesData{Nodes: []fileNode{{Path: "main.go"}}}, + }) + }, + uris: []string{"github://uber/submitqueue/123/oldsha"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serverURL := "http://localhost" + if tt.handler != nil { + server := httptest.NewServer(tt.handler) + defer server.Close() + serverURL = server.URL + } + + p := newTestProvider(t, serverURL) + infos, err := p.Get(context.Background(), entity.Change{URIs: tt.uris}) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Len(t, infos, 1) + assert.Equal(t, tt.uris[0], infos[0].URI) + assert.Len(t, infos[0].ChangedFiles, 2) + }) + } +} + +func TestProvider_Get_Pagination(t *testing.T) { + pages := []pullRequestData{ + { + Number: 456, + HeadRefOid: "xyz789", + Files: filesData{ + PageInfo: pageInfo{EndCursor: "cursor1", HasNextPage: true}, + Nodes: []fileNode{{Path: "file1.go"}}, + }, + }, + { + Number: 456, + HeadRefOid: "xyz789", + Files: filesData{ + PageInfo: pageInfo{HasNextPage: false}, + Nodes: []fileNode{{Path: "file2.go"}}, + }, + }, + } + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, pages[callCount]) + callCount++ + })) + defer server.Close() + + p := newTestProvider(t, server.URL) + infos, err := 
p.Get(context.Background(), entity.Change{ + URIs: []string{"github://uber/submitqueue/456/xyz789"}, + }) + + require.NoError(t, err) + assert.Equal(t, 2, callCount) + require.Len(t, infos, 1) + assert.Len(t, infos[0].ChangedFiles, 2) +} + +func TestProvider_Get_MultiplePRs(t *testing.T) { + prData := []pullRequestData{ + {Number: 123, HeadRefOid: "abc123", Files: filesData{Nodes: []fileNode{{Path: "file1.go"}}}}, + {Number: 456, HeadRefOid: "def456", Files: filesData{Nodes: []fileNode{{Path: "file2.go"}}}}, + } + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + servePR(t, w, prData[callCount]) + callCount++ + })) + defer server.Close() + + p := newTestProvider(t, server.URL) + infos, err := p.Get(context.Background(), entity.Change{ + URIs: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/submitqueue/456/def456", + }, + }) + + require.NoError(t, err) + assert.Equal(t, 2, callCount) + require.Len(t, infos, 2) + assert.Equal(t, "github://uber/submitqueue/123/abc123", infos[0].URI) + assert.Equal(t, "github://uber/submitqueue/456/def456", infos[1].URI) +} + +func TestProvider_Get_FetchError_StopsOnFirstFailure(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if callCount == 1 { + callCount++ + w.WriteHeader(http.StatusInternalServerError) + return + } + servePR(t, w, pullRequestData{ + Number: 123, + HeadRefOid: "abc123", + Files: filesData{Nodes: []fileNode{{Path: "file1.go"}}}, + }) + callCount++ + })) + defer server.Close() + + p := newTestProvider(t, server.URL) + _, err := p.Get(context.Background(), entity.Change{ + URIs: []string{ + "github://uber/submitqueue/123/abc123", + "github://uber/submitqueue/456/def456", + }, + }) + + require.Error(t, err) + assert.Equal(t, 2, callCount) +} diff --git a/extension/changeprovider/github/validate.go b/extension/changeprovider/github/validate.go new file mode 
100644 index 00000000..55bf1fb5 --- /dev/null +++ b/extension/changeprovider/github/validate.go @@ -0,0 +1,46 @@ +package github + +import ( + "fmt" + + entitygithub "github.com/uber/submitqueue/entity/github" +) + +// validateChangeConsistency validates that all changeIDs in the stack are consistent. +// Stacked changes must have the same change provider (scheme), org, and repo. +func validateChangeConsistency( + changeIDs []entitygithub.ChangeID, +) error { + expectedScheme := changeIDs[0].Scheme + expectedOrg := changeIDs[0].Org + expectedRepo := changeIDs[0].Repo + + for _, cid := range changeIDs { + // Validate same change provider (scheme) + if cid.Scheme != expectedScheme { + return fmt.Errorf("stacked changes must use same change provider: expected %s, got %s for PR #%d", + expectedScheme, cid.Scheme, cid.PRNumber) + } + + // Validate same org and repo + if cid.Org != expectedOrg || cid.Repo != expectedRepo { + return fmt.Errorf("stacked changes must be from same org/repository: expected %s/%s, got %s/%s for PR #%d", + expectedOrg, expectedRepo, cid.Org, cid.Repo, cid.PRNumber) + } + } + + return nil +} + +// validatePRStaleness validates that the PR hasn't changed since submission. +// Compares the fetched head SHA with the expected SHA from the change URI. 
+func validatePRStaleness( + cid entitygithub.ChangeID, + prData *pullRequestData, +) error { + if prData.HeadRefOid != cid.HeadCommitSHA { + return fmt.Errorf("PR #%d head SHA changed: expected %s, got %s", + cid.PRNumber, cid.HeadCommitSHA, prData.HeadRefOid) + } + return nil +} diff --git a/extension/changeprovider/github/validate_test.go b/extension/changeprovider/github/validate_test.go new file mode 100644 index 00000000..0612c1ea --- /dev/null +++ b/extension/changeprovider/github/validate_test.go @@ -0,0 +1,99 @@ +package github + +import ( + "testing" + + "github.com/stretchr/testify/require" + entitygithub "github.com/uber/submitqueue/entity/github" +) + +func TestValidateChangeConsistency(t *testing.T) { + tests := []struct { + name string + changeIDs []entitygithub.ChangeID + wantErr bool + }{ + { + name: "single PR", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + }, + }, + { + name: "consistent stack", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 2}, + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 3}, + }, + }, + { + name: "different repo", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "github", Org: "uber", Repo: "other-repo", PRNumber: 2}, + }, + wantErr: true, + }, + { + name: "different org", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "github", Org: "other-org", Repo: "submitqueue", PRNumber: 2}, + }, + wantErr: true, + }, + { + name: "different scheme", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 1}, + {Scheme: "ghe", Org: "uber", Repo: "submitqueue", PRNumber: 2}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, 
func(t *testing.T) { + err := validateChangeConsistency(tt.changeIDs) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidatePRStaleness(t *testing.T) { + tests := []struct { + name string + cid entitygithub.ChangeID + prData pullRequestData + wantErr bool + }{ + { + name: "matching SHA", + cid: entitygithub.ChangeID{PRNumber: 1, HeadCommitSHA: "abc123"}, + prData: pullRequestData{HeadRefOid: "abc123"}, + wantErr: false, + }, + { + name: "mismatched SHA", + cid: entitygithub.ChangeID{PRNumber: 1, HeadCommitSHA: "oldsha"}, + prData: pullRequestData{HeadRefOid: "newsha"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePRStaleness(tt.cid, &tt.prData) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/orchestrator/controller/validate/BUILD.bazel b/orchestrator/controller/validate/BUILD.bazel index 0ea6244a..a4a2f264 100644 --- a/orchestrator/controller/validate/BUILD.bazel +++ b/orchestrator/controller/validate/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "//core/errs", "//entity", "//entity/queue", + "//extension/changeprovider", "//extension/mergechecker", "//extension/storage", "@com_github_uber_go_tally_v4//:tally", @@ -26,6 +27,7 @@ go_test( "//core/errs", "//entity", "//entity/queue", + "//extension/changeprovider", "//extension/mergechecker", "//extension/mergechecker/mock", "//extension/queue/mock", diff --git a/orchestrator/controller/validate/validate.go b/orchestrator/controller/validate/validate.go index 2ee8d8b1..d541d16c 100644 --- a/orchestrator/controller/validate/validate.go +++ b/orchestrator/controller/validate/validate.go @@ -21,8 +21,10 @@ import ( "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" "github.com/uber/submitqueue/core/errs" + coremetrics "github.com/uber/submitqueue/core/metrics" "github.com/uber/submitqueue/entity" entityqueue 
"github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/changeprovider" "github.com/uber/submitqueue/extension/mergechecker" "github.com/uber/submitqueue/extension/storage" "go.uber.org/zap" @@ -33,13 +35,14 @@ import ( // and publishes to the batch stage. Validation logic is extensible to support additional checks. // Implements consumer.Controller interface for integration with the consumer. type Controller struct { - logger *zap.SugaredLogger - metricsScope tally.Scope - store storage.Storage - registry consumer.TopicRegistry - mergeChecker mergechecker.MergeChecker - topicKey consumer.TopicKey - consumerGroup string + logger *zap.SugaredLogger + metricsScope tally.Scope + store storage.Storage + registry consumer.TopicRegistry + mergeChecker mergechecker.MergeChecker + changeProvider changeprovider.ChangeProvider + topicKey consumer.TopicKey + consumerGroup string } // Verify Controller implements consumer.Controller interface at compile time. @@ -52,39 +55,42 @@ func NewController( store storage.Storage, registry consumer.TopicRegistry, mergeChecker mergechecker.MergeChecker, + changeProvider changeprovider.ChangeProvider, topicKey consumer.TopicKey, consumerGroup string, ) *Controller { return &Controller{ - logger: logger.Named("validate_controller"), - metricsScope: scope.SubScope("validate_controller"), - store: store, - registry: registry, - mergeChecker: mergeChecker, - topicKey: topicKey, - consumerGroup: consumerGroup, + logger: logger.Named("validate_controller"), + metricsScope: scope.SubScope("validate_controller"), + store: store, + registry: registry, + mergeChecker: mergeChecker, + changeProvider: changeProvider, + topicKey: topicKey, + consumerGroup: consumerGroup, } } // Process processes a validate delivery from the queue. // Deserializes the request and publishes to the batch topic. // Returns nil to ack (success), or error to nack (retry). 
-func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) error { - c.metricsScope.Counter("received").Inc(1) +func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (retErr error) { + op := coremetrics.Begin(c.metricsScope, "process") + defer func() { op.Complete(retErr) }() msg := delivery.Message() // Deserialize request ID from payload rid, err := entity.RequestIDFromBytes(msg.Payload) if err != nil { - c.metricsScope.Counter("deserialize_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "deserialize_errors", 1) return fmt.Errorf("failed to deserialize request ID: %w", err) } // Fetch request from storage request, err := c.store.GetRequestStore().Get(ctx, rid.ID) if err != nil { - c.metricsScope.Counter("storage_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "storage_errors", 1) return fmt.Errorf("failed to get request %s: %w", rid.ID, err) } @@ -100,7 +106,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er // Merge conflict check mergeResult, err := c.mergeChecker.Check(ctx, request.Queue, request.Change) if err != nil { - c.metricsScope.Counter("merge_check_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "merge_check_errors", 1) return fmt.Errorf("merge check failed: %w", err) } if !mergeResult.Mergeable { @@ -109,13 +115,31 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "queue", request.Queue, "reason", mergeResult.Reason, ) - c.metricsScope.Counter("not_mergeable").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "not_mergeable", 1) return errs.NewUserError(fmt.Errorf("request %s is not mergeable: %s", request.ID, mergeResult.Reason)) } + // Fetch change metadata + changeInfos, err := c.changeProvider.Get(ctx, request.Change) + if err != nil { + c.logger.Errorw("failed to fetch change information", + "request_id", request.ID, + "change_uris", request.Change.URIs, + 
"error", err, + ) + coremetrics.NamedCounter(c.metricsScope, "process", "change_provider_errors", 1) + return fmt.Errorf("failed to fetch change information: %w", err) + } + + c.logger.Infow("fetched change information", + "request_id", request.ID, + "change_count", len(changeInfos), + "total_files", totalFiles(changeInfos), + ) + // Publish to batch topic if err := c.publish(ctx, consumer.TopicKeyBatch, request.ID, request.Queue); err != nil { - c.metricsScope.Counter("publish_errors").Inc(1) + coremetrics.NamedCounter(c.metricsScope, "process", "publish_errors", 1) return fmt.Errorf("failed to publish to batch: %w", err) } @@ -124,8 +148,6 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "topic_key", consumer.TopicKeyBatch, ) - c.metricsScope.Counter("processed").Inc(1) - return nil // Success - message will be acked } @@ -156,6 +178,15 @@ func (c *Controller) publish(ctx context.Context, key consumer.TopicKey, request return nil } +// totalFiles returns the total number of files across all changeInfos. +func totalFiles(infos []changeprovider.ChangeInfo) int { + total := 0 + for _, info := range infos { + total += len(info.ChangedFiles) + } + return total +} + // Name returns the controller name for logging and metrics. 
func (c *Controller) Name() string { return "validate" diff --git a/orchestrator/controller/validate/validate_test.go b/orchestrator/controller/validate/validate_test.go index cba0500d..0d1b2ca7 100644 --- a/orchestrator/controller/validate/validate_test.go +++ b/orchestrator/controller/validate/validate_test.go @@ -26,6 +26,7 @@ import ( "github.com/uber/submitqueue/core/errs" "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/changeprovider" "github.com/uber/submitqueue/extension/mergechecker" mergecheckermock "github.com/uber/submitqueue/extension/mergechecker/mock" queuemock "github.com/uber/submitqueue/extension/queue/mock" @@ -41,6 +42,25 @@ func requestIDPayload(t *testing.T, id string) []byte { return payload } +// mockChangeProvider is a simple mock that returns test data. +type mockChangeProvider struct{} + +func (m *mockChangeProvider) Get(ctx context.Context, change entity.Change) ([]changeprovider.ChangeInfo, error) { + // Return simple test data + return []changeprovider.ChangeInfo{ + { + URI: "github://org/repo/123/abc123", + User: changeprovider.User{ + Name: "Test User", + Email: "test@example.com", + }, + ChangedFiles: []changeprovider.ChangedFile{ + {Path: "main.go"}, + }, + }, + }, nil +} + // newMergeableMock returns a mock MergeChecker that always returns mergeable. func newMergeableMock(ctrl *gomock.Controller) *mergecheckermock.MockMergeChecker { mc := mergecheckermock.NewMockMergeChecker(ctrl) @@ -78,7 +98,9 @@ func newTestController(t *testing.T, ctrl *gomock.Controller, store *storagemock ) require.NoError(t, err) - return NewController(logger, scope, store, registry, mc, consumer.TopicKeyValidate, "orchestrator-validate") + cp := &mockChangeProvider{} + + return NewController(logger, scope, store, registry, mc, cp, consumer.TopicKeyValidate, "orchestrator-validate") } func TestNewController(t *testing.T) {