diff --git a/CLAUDE.md b/CLAUDE.md index 3d231b11..1b270f5d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -186,6 +186,7 @@ make clean # Clean Bazel cache ### Testing +- **Table-driven tests** — prefer table-driven tests with `t.Run` subtests over individual test functions. - **Avoid asserting on error messages** — assert on error type or generic error. - **No change detector tests** — don't assert on default values, internal structure, or implementation details that can change without affecting behavior. Test what the code *does*, not how it's constructed. - **No `time.Sleep` for synchronization** — use channels, callbacks, condition variables. diff --git a/entity/github/BUILD.bazel b/entity/github/BUILD.bazel new file mode 100644 index 00000000..3cabb00b --- /dev/null +++ b/entity/github/BUILD.bazel @@ -0,0 +1,18 @@ +load("@rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "github", + srcs = ["change_id.go"], + importpath = "github.com/uber/submitqueue/entity/github", + visibility = ["//visibility:public"], +) + +go_test( + name = "github_test", + srcs = ["change_id_test.go"], + embed = [":github"], + deps = [ + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) diff --git a/entity/github/change_id.go b/entity/github/change_id.go new file mode 100644 index 00000000..c3747efd --- /dev/null +++ b/entity/github/change_id.go @@ -0,0 +1,99 @@ +package github + +import ( + "fmt" + "strconv" + "strings" +) + +// changeIDFormat is the expected format for change IDs, included in error messages. +const changeIDFormat = "{scheme}://{owner}/{repo}/{pr_number}/{head_commit_sha}" + +// ChangeID represents a parsed GitHub-family change identifier. +// Covers GitHub.com, GitHub Enterprise (GHE), and GitHub Enterprise Server (GHES) +// since they share the same pull request model. +// Format: {scheme}://{owner}/{repo}/{pr_number}/{head_commit_sha} +type ChangeID struct { + // Scheme captures the source variant (e.g., "github", "ghe", "ghes"). + Scheme string + // Org is the organization or owner of the repository. + Org string + // Repo is the repository name. + Repo string + // PRNumber is the pull request number. + PRNumber int + // HeadCommitSHA is the head commit SHA at the time of request creation. + HeadCommitSHA string +} + +// ParseChangeID parses a raw change ID string into a ChangeID. +// Expected format: {scheme}://{owner}/{repo}/{pr_number}/{head_commit_sha} +// The parser works from the end: SHA (last), PR number (second-to-last), +// and everything before is the repo path (split into owner and repo). +func ParseChangeID(raw string) (ChangeID, error) { + // Split on "://" to get scheme and path + schemeSplit := strings.SplitN(raw, "://", 2) + if len(schemeSplit) != 2 { + return ChangeID{}, fmt.Errorf("invalid change ID %q: missing '://' separator (expected format: %s)", raw, changeIDFormat) + } + + scheme := schemeSplit[0] + if scheme == "" { + return ChangeID{}, fmt.Errorf("invalid change ID %q: empty scheme (expected format: %s)", raw, changeIDFormat) + } + + path := schemeSplit[1] + + // Split the path into segments and parse from the end. + segments := strings.Split(path, "/") + // Need at least 4 segments: {owner}/{repo}/{pr_number}/{sha} + if len(segments) < 4 { + return ChangeID{}, fmt.Errorf("invalid change ID %q: need at least owner/repo/pr/sha, got %d segments (expected format: %s)", raw, len(segments), changeIDFormat) + } + + sha := segments[len(segments)-1] + prStr := segments[len(segments)-2] + repoSegments := segments[:len(segments)-2] + + if sha == "" { + return ChangeID{}, fmt.Errorf("invalid change ID %q: empty head commit SHA (expected format: %s)", raw, changeIDFormat) + } + + prNumber, err := strconv.Atoi(prStr) + if err != nil { + return ChangeID{}, fmt.Errorf("invalid change ID %q: PR number %q is not a valid integer (expected format: %s)", raw, prStr, changeIDFormat) + } + + // Split repo path: last segment is repo name, everything before is the owner. + if len(repoSegments) < 2 { + return ChangeID{}, fmt.Errorf("invalid change ID %q: repo path must have at least owner/repo (expected format: %s)", raw, changeIDFormat) + } + + repo := repoSegments[len(repoSegments)-1] + org := strings.Join(repoSegments[:len(repoSegments)-1], "/") + + if org == "" { + return ChangeID{}, fmt.Errorf("invalid change ID %q: empty owner (expected format: %s)", raw, changeIDFormat) + } + if repo == "" { + return ChangeID{}, fmt.Errorf("invalid change ID %q: empty repo (expected format: %s)", raw, changeIDFormat) + } + + return ChangeID{ + Scheme: scheme, + Org: org, + Repo: repo, + PRNumber: prNumber, + HeadCommitSHA: sha, + }, nil +} + +// String returns the string representation of the change ID. +func (c ChangeID) String() string { + return fmt.Sprintf("%s://%s/%s/%d/%s", c.Scheme, c.Org, c.Repo, c.PRNumber, c.HeadCommitSHA) +} + +// OwnerRepo returns the "{org}/{repo}" string. +func (c ChangeID) OwnerRepo() string { + return fmt.Sprintf("%s/%s", c.Org, c.Repo) +} diff --git a/entity/github/change_id_test.go b/entity/github/change_id_test.go new file mode 100644 index 00000000..ca6a0487 --- /dev/null +++ b/entity/github/change_id_test.go @@ -0,0 +1,194 @@ +package github + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseChangeID(t *testing.T) { + tests := []struct { + name string + raw string + want ChangeID + wantErr bool + }{ + { + name: "valid github scheme", + raw: "github://uber/submitqueue/123/abc123def", + want: ChangeID{ + Scheme: "github", + Org: "uber", + Repo: "submitqueue", + PRNumber: 123, + HeadCommitSHA: "abc123def", + }, + }, + { + name: "valid ghe scheme", + raw: "ghe://uber/monorepo/456/deadbeef", + want: ChangeID{ + Scheme: "ghe", + Org: "uber", + Repo: "monorepo", + PRNumber: 456, + HeadCommitSHA: "deadbeef", + }, + }, + { + name: "valid ghes scheme", + raw: "ghes://org/repo/1/sha1", + want: ChangeID{ + Scheme: "ghes", + Org: "org", + Repo: "repo", + PRNumber: 1, + HeadCommitSHA: "sha1", + }, + }, + { + name: "nested org path", + raw: "github://uber/frontend/webapp/42/abc123", + want: ChangeID{ + Scheme: "github", + Org: "uber/frontend", + Repo: "webapp", + PRNumber: 42, + HeadCommitSHA: "abc123", + }, + }, + { + name: "missing separator", + raw: "github/uber/submitqueue/123/abc123", + wantErr: true, + }, + { + name: "empty scheme", + raw: "://uber/submitqueue/123/abc123", + wantErr: true, + }, + { + name: "too few segments", + raw: "github://uber/123/abc123", + wantErr: true, + }, + { + name: "only one segment", + raw: "github://abc123", + wantErr: true, + }, + { + name: "empty owner", + raw: "github:///submitqueue/123/abc123", + wantErr: true, + }, + { + name: "empty repo", + raw: "github://uber//123/abc123", + wantErr: true, + }, + { + name: "non-numeric PR number", + raw: "github://uber/submitqueue/abc/abc123", + wantErr: true, + }, + { + name: "empty SHA", + raw: "github://uber/submitqueue/123/", + wantErr: true, + }, + { + name: "empty string", + raw: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseChangeID(tt.raw) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestChangeID_String(t *testing.T) { + tests := []struct { + name string + id ChangeID + want string + }{ + { + name: "github", + id: ChangeID{ + Scheme: "github", + Org: "uber", + Repo: "submitqueue", + PRNumber: 123, + HeadCommitSHA: "abc123", + }, + want: "github://uber/submitqueue/123/abc123", + }, + { + name: "ghe", + id: ChangeID{ + Scheme: "ghe", + Org: "corp", + Repo: "app", + PRNumber: 99, + HeadCommitSHA: "deadbeef", + }, + want: "ghe://corp/app/99/deadbeef", + }, + { + name: "ghes", + id: ChangeID{ + Scheme: "ghes", + Org: "org", + Repo: "repo", + PRNumber: 1, + HeadCommitSHA: "sha1", + }, + want: "ghes://org/repo/1/sha1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.id.String()) + }) + } +} + +func TestChangeID_OwnerRepo(t *testing.T) { + id := ChangeID{ + Scheme: "github", + Org: "uber", + Repo: "submitqueue", + PRNumber: 1, + HeadCommitSHA: "abc", + } + assert.Equal(t, "uber/submitqueue", id.OwnerRepo()) +} + +func TestParseChangeID_RoundTrip(t *testing.T) { + originals := []string{ + "github://uber/submitqueue/123/abc123def456", + "ghe://corp/monorepo/99/deadbeef01234567", + "ghes://org/repo/1/a1b2c3", + } + + for _, raw := range originals { + t.Run(raw, func(t *testing.T) { + parsed, err := ParseChangeID(raw) + require.NoError(t, err) + assert.Equal(t, raw, parsed.String()) + }) + } +} diff --git a/example/server/orchestrator/BUILD.bazel b/example/server/orchestrator/BUILD.bazel index 7cf1a272..9db3e45e 100644 --- a/example/server/orchestrator/BUILD.bazel +++ b/example/server/orchestrator/BUILD.bazel @@ -12,6 +12,8 @@ go_library( visibility = ["//visibility:private"], deps = [ "//core/consumer", + "//extension/mergechecker", + "//extension/mergechecker/github", "//extension/queue", "//extension/queue/mysql", "//orchestrator/controller", diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index e838591d..2e2f6afd 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "net" + "net/http" "os" "os/signal" "sync" @@ -14,6 +15,8 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/extension/mergechecker" + githubchecker "github.com/uber/submitqueue/extension/mergechecker/github" extqueue "github.com/uber/submitqueue/extension/queue" queueMySQL "github.com/uber/submitqueue/extension/queue/mysql" "github.com/uber/submitqueue/orchestrator/controller" @@ -128,8 +131,11 @@ func run() error { // Create consumer c := consumer.New(logger.Sugar(), scope.SubScope("consumer"), registry) + // Create merge checker + mc := newMergeChecker(logger, scope) + // Register controllers - if err := registerControllers(c, logger.Sugar(), scope, registry); err != nil { + if err := registerControllers(c, logger.Sugar(), scope, registry, mc); err != nil { return err } @@ -253,11 +259,12 @@ func newTopicRegistry(q extqueue.Queue, subscriberName string) consumer.TopicReg // // → merge → merge-signal // finalize (terminal) -func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry) error { +func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker) error { requestController := request.NewController( logger, scope, registry, + mc, consumer.TopicRequest, "orchestrator-request", ) @@ -344,3 +351,39 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t return nil } + +// newMergeChecker creates a MergeChecker for GitHub (github.com). +// Configured via GITHUB_TOKEN and GITHUB_GRAPHQL_URL environment variables. +func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeChecker { + graphQLURL := os.Getenv("GITHUB_GRAPHQL_URL") + if graphQLURL == "" { + graphQLURL = "https://api.github.com/graphql" + } + + httpClient := &http.Client{} + if token := os.Getenv("GITHUB_TOKEN"); token != "" { + httpClient.Transport = &bearerTransport{token: token} + } + + github := githubchecker.NewMergeChecker(githubchecker.Params{ + HTTPClient: httpClient, + GraphQLURL: graphQLURL, + Logger: logger.Sugar(), + MetricsScope: scope.SubScope("mergechecker"), + }) + + return mergechecker.NewMultiChecker(map[string]mergechecker.MergeChecker{ + "github": github, + }) +} + +// bearerTransport is an http.RoundTripper that adds a Bearer token to requests. +type bearerTransport struct { + token string +} + +func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+t.token) + return http.DefaultTransport.RoundTrip(req) +} diff --git a/extension/mergechecker/BUILD.bazel b/extension/mergechecker/BUILD.bazel index 39e2c33c..1f797fde 100644 --- a/extension/mergechecker/BUILD.bazel +++ b/extension/mergechecker/BUILD.bazel @@ -1,9 +1,23 @@ -load("@rules_go//go:def.bzl", "go_library") +load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "mergechecker", - srcs = ["mergechecker.go"], + srcs = [ + "mergechecker.go", + "multi.go", + ], importpath = "github.com/uber/submitqueue/extension/mergechecker", visibility = ["//visibility:public"], deps = ["//entity"], ) + +go_test( + name = "mergechecker_test", + srcs = ["multi_test.go"], + embed = [":mergechecker"], + deps = [ + "//entity", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) diff --git a/extension/mergechecker/github/BUILD.bazel b/extension/mergechecker/github/BUILD.bazel new file mode 100644 index 00000000..df212f17 --- /dev/null +++ b/extension/mergechecker/github/BUILD.bazel @@ -0,0 +1,38 @@ +load("@rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "github", + srcs = [ + "checker.go", + "graphql.go", + "validate.go", + ], + importpath = "github.com/uber/submitqueue/extension/mergechecker/github", + visibility = ["//visibility:public"], + deps = [ + "//entity", + "//entity/github", + "//extension/mergechecker", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "github_test", + srcs = [ + "checker_test.go", + "graphql_test.go", + "validate_test.go", + ], + embed = [":github"], + deps = [ + "//entity", + "//entity/github", + "//extension/mergechecker", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//zaptest", + ], +) diff --git a/extension/mergechecker/github/checker.go b/extension/mergechecker/github/checker.go new file mode 100644 index 00000000..576f9a37 --- /dev/null +++ b/extension/mergechecker/github/checker.go @@ -0,0 +1,131 @@ +package github + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/entity" + entitygithub "github.com/uber/submitqueue/entity/github" + "github.com/uber/submitqueue/extension/mergechecker" + "go.uber.org/zap" +) + +// Params holds the dependencies for the GitHub MergeChecker. +type Params struct { + // HTTPClient is a pre-configured HTTP client with auth (bearer token, GitHub App JWT, etc.). + // Auth is the caller's responsibility via HTTP transport/round-tripper. + HTTPClient *http.Client + // GraphQLURL is the GitHub GraphQL API endpoint + // (e.g., "https://api.github.com/graphql" or "https://ghe.example.com/api/graphql"). + GraphQLURL string + // Logger is the structured logger. + Logger *zap.SugaredLogger + // MetricsScope is the metrics scope for instrumentation. + MetricsScope tally.Scope +} + +// mergeChecker implements the mergechecker.MergeChecker interface using the GitHub GraphQL API. +type mergeChecker struct { + httpClient *http.Client + graphQLURL string + logger *zap.SugaredLogger + metricsScope tally.Scope +} + +// Verify mergeChecker implements mergechecker.MergeChecker at compile time. +var _ mergechecker.MergeChecker = (*mergeChecker)(nil) + +// NewMergeChecker creates a new GitHub MergeChecker. +func NewMergeChecker(params Params) mergechecker.MergeChecker { + return &mergeChecker{ + httpClient: params.HTTPClient, + graphQLURL: params.GraphQLURL, + logger: params.Logger.Named("github_mergechecker"), + metricsScope: params.MetricsScope.SubScope("github_mergechecker"), + } +} + +// Check assesses whether a change can merge cleanly using the GitHub GraphQL API. +func (c *mergeChecker) Check(ctx context.Context, queue string, change entity.Change) (mergechecker.Result, error) { + c.metricsScope.Counter("check_started").Inc(1) + + result := mergechecker.Result{} + // Parse all change IDs + // TODO: classify parse errors as user errors (non-retryable) vs system errors. + changeIDs := make([]entitygithub.ChangeID, 0, len(change.URIs)) + for _, rawID := range change.URIs { + cid, err := entitygithub.ParseChangeID(rawID) + if err != nil { + c.metricsScope.Counter("parse_errors").Inc(1) + return result, fmt.Errorf("failed to parse change ID %q: %w", rawID, err) + } + changeIDs = append(changeIDs, cid) + } + + // Fetch PR info from GitHub GraphQL API + prInfoMap, err := c.fetchPRInfo(ctx, changeIDs) + if err != nil { + c.metricsScope.Counter("graphql_errors").Inc(1) + return result, fmt.Errorf("failed to fetch PR info: %w", err) + } + + // Validate PR mergeability + mergeable, reason, err := validatePRs(changeIDs, prInfoMap) + if err != nil { + c.metricsScope.Counter("validation_errors").Inc(1) + return result, err + } + + if !mergeable { + c.metricsScope.Counter("not_mergeable").Inc(1) + c.logger.Infow("change not mergeable", + "queue", queue, + "reason", reason, + "change_uris", change.URIs, + ) + } else { + c.metricsScope.Counter("mergeable").Inc(1) + } + + result.Mergeable = mergeable + result.Reason = reason + return result, nil +} + +// fetchPRInfo executes a batched GraphQL query to fetch PR info for all change IDs. +func (c *mergeChecker) fetchPRInfo(ctx context.Context, changeIDs []entitygithub.ChangeID) (map[int]PRInfo, error) { + query := buildGraphQLQuery(changeIDs) + + reqBody, err := json.Marshal(graphQLRequest{Query: query}) + if err != nil { + return nil, fmt.Errorf("failed to marshal graphql request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.graphQLURL, bytes.NewReader(reqBody)) + if err != nil { + return nil, fmt.Errorf("failed to create http request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("graphql request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read graphql response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("graphql request returned status %d: %s", resp.StatusCode, string(body)) + } + + return parseGraphQLResponse(body, changeIDs) +} diff --git a/extension/mergechecker/github/checker_test.go b/extension/mergechecker/github/checker_test.go new file mode 100644 index 00000000..77043aee --- /dev/null +++ b/extension/mergechecker/github/checker_test.go @@ -0,0 +1,154 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/entity" + "github.com/uber/submitqueue/extension/mergechecker" + "go.uber.org/zap/zaptest" +) + +func newTestMergeChecker(t *testing.T, serverURL string) mergechecker.MergeChecker { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + + return NewMergeChecker(Params{ + HTTPClient: &http.Client{}, + GraphQLURL: serverURL, + Logger: logger, + MetricsScope: scope, + }) +} + +func graphQLHandler(t *testing.T, prInfos []PRInfo) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + t.Helper() + + data := make(map[string]json.RawMessage, len(prInfos)) + for i, pr := range prInfos { + alias := fmt.Sprintf("pr%d", i) + prJSON, err := json.Marshal(map[string]any{ + "pullRequest": map[string]any{ + "number": pr.Number, + "mergeable": string(pr.Mergeable), + "baseRefName": pr.BaseRefName, + "headRefName": pr.HeadRefName, + "headRefOid": pr.HeadRefOid, + "state": string(pr.State), + }, + }) + require.NoError(t, err) + data[alias] = json.RawMessage(prJSON) + } + + resp := graphQLResponse{Data: data} + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(resp) + require.NoError(t, err) + } +} + +func TestMergeChecker_Check(t *testing.T) { + tests := []struct { + name string + handler http.HandlerFunc + queue string + change entity.Change + wantMergeable bool + wantReason string + wantErr bool + }{ + { + name: "single PR mergeable", + handler: graphQLHandler(t, []PRInfo{ + {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "abc123", State: PRStateOpen}, + }), + queue: "test-queue", + change: entity.Change{URIs: []string{"github://uber/repo/1/abc123"}}, + wantMergeable: true, + }, + { + name: "single PR conflicting", + handler: graphQLHandler(t, []PRInfo{ + {Number: 1, Mergeable: PRMergeableStateConflicting, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "abc123", State: PRStateOpen}, + }), + queue: "test-queue", + change: entity.Change{URIs: []string{"github://uber/repo/1/abc123"}}, + wantMergeable: false, + wantReason: "PR #1 has merge conflicts", + }, + { + name: "stack of two PRs mergeable", + handler: graphQLHandler(t, []PRInfo{ + {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "sha1", State: PRStateOpen}, + {Number: 2, Mergeable: PRMergeableStateMergeable, BaseRefName: "feature-1", HeadRefName: "feature-2", HeadRefOid: "sha2", State: PRStateOpen}, + }), + queue: "test-queue", + change: entity.Change{URIs: []string{"github://uber/repo/1/sha1", "github://uber/repo/2/sha2"}}, + wantMergeable: true, + }, + { + name: "unknown mergeability returns error", + handler: graphQLHandler(t, []PRInfo{ + {Number: 1, Mergeable: PRMergeableStateUnknown, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "abc123", State: PRStateOpen}, + }), + queue: "test-queue", + change: entity.Change{URIs: []string{"github://uber/repo/1/abc123"}}, + wantErr: true, + }, + { + name: "stale SHA not mergeable", + handler: graphQLHandler(t, []PRInfo{ + {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "new_sha", State: PRStateOpen}, + }), + queue: "test-queue", + change: entity.Change{URIs: []string{"github://uber/repo/1/old_sha"}}, + wantMergeable: false, + wantReason: "PR #1 head SHA changed: expected old_sha, got new_sha", + }, + { + name: "invalid change ID", + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("should not reach server") + }), + queue: "test-queue", + change: entity.Change{URIs: []string{"invalid-change-id"}}, + wantErr: true, + }, + { + name: "server error", + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal server error")) + }), + queue: "test-queue", + change: entity.Change{URIs: []string{"github://uber/repo/1/abc123"}}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(tt.handler) + defer server.Close() + + mc := newTestMergeChecker(t, server.URL) + result, err := mc.Check(context.Background(), tt.queue, tt.change) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantMergeable, result.Mergeable) + assert.Equal(t, tt.wantReason, result.Reason) + }) + } +} diff --git a/extension/mergechecker/github/graphql.go b/extension/mergechecker/github/graphql.go new file mode 100644 index 00000000..820edca6 --- /dev/null +++ b/extension/mergechecker/github/graphql.go @@ -0,0 +1,117 @@ +package github + +import ( + "encoding/json" + "fmt" + "strings" + + entitygithub "github.com/uber/submitqueue/entity/github" +) + +// graphQLRequest is the request body for the GitHub GraphQL API. +type graphQLRequest struct { + // Query is the GraphQL query string. + Query string `json:"query"` +} + +// graphQLResponse is the top-level response from the GitHub GraphQL API. +type graphQLResponse struct { + // Data contains the query results keyed by alias. + Data map[string]json.RawMessage `json:"data"` + // Errors contains any GraphQL errors. + Errors []graphQLError `json:"errors"` +} + +// graphQLError represents a single GraphQL error. +type graphQLError struct { + // Message is the error message. + Message string `json:"message"` +} + +// repositoryResponse represents a repository query result. +type repositoryResponse struct { + // PullRequest contains the PR data. + PullRequest prResponse `json:"pullRequest"` +} + +// prResponse represents the fields fetched for a single pull request. +type prResponse struct { + // Number is the PR number. + Number int `json:"number"` + // Mergeable is the mergeability state. + Mergeable string `json:"mergeable"` + // BaseRefName is the base branch name. + BaseRefName string `json:"baseRefName"` + // HeadRefName is the head branch name. + HeadRefName string `json:"headRefName"` + // HeadRefOid is the head commit SHA. + HeadRefOid string `json:"headRefOid"` + // State is the PR state (OPEN, CLOSED, MERGED). + State string `json:"state"` +} + +// buildGraphQLQuery builds a batched GraphQL query for multiple PRs. +// Each PR gets an alias like "pr0", "pr1", etc. +func buildGraphQLQuery(changeIDs []entitygithub.ChangeID) string { + var sb strings.Builder + sb.WriteString("query {") + + for i, cid := range changeIDs { + fmt.Fprintf(&sb, ` + pr%d: repository(owner: %q, name: %q) { + pullRequest(number: %d) { + number + mergeable + baseRefName + headRefName + headRefOid + state + } + }`, i, cid.Org, cid.Repo, cid.PRNumber) + } + + sb.WriteString("\n}") + return sb.String() +} + +// parseGraphQLResponse parses the GraphQL response body and returns a map of PR number to PRInfo. +func parseGraphQLResponse(body []byte, changeIDs []entitygithub.ChangeID) (map[int]PRInfo, error) { + var resp graphQLResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse GraphQL response: %w", err) + } + + if len(resp.Errors) > 0 { + messages := make([]string, len(resp.Errors)) + for i, e := range resp.Errors { + messages[i] = e.Message + } + return nil, fmt.Errorf("GraphQL errors: %s", strings.Join(messages, "; ")) + } + + result := make(map[int]PRInfo, len(changeIDs)) + for i := range changeIDs { + alias := fmt.Sprintf("pr%d", i) + raw, ok := resp.Data[alias] + if !ok { + return nil, fmt.Errorf("missing alias %q in GraphQL response", alias) + } + + var repoResp repositoryResponse + if err := json.Unmarshal(raw, &repoResp); err != nil { + return nil, fmt.Errorf("failed to parse alias %q: %w", alias, err) + } + + pr := repoResp.PullRequest + result[pr.Number] = PRInfo{ + Number: pr.Number, + Mergeable: PRMergeableState(pr.Mergeable), + BaseRefName: pr.BaseRefName, + HeadRefName: pr.HeadRefName, + HeadRefOid: pr.HeadRefOid, + State: PRState(pr.State), + } + } + + return result, nil +} diff --git a/extension/mergechecker/github/graphql_test.go b/extension/mergechecker/github/graphql_test.go new file mode 100644 index 00000000..78a69709 --- /dev/null +++ b/extension/mergechecker/github/graphql_test.go @@ -0,0 +1,127 @@ +package github + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + entitygithub "github.com/uber/submitqueue/entity/github" +) + +func TestBuildGraphQLQuery(t *testing.T) { + tests := []struct { + name string + changeIDs []entitygithub.ChangeID + wantParts []string + }{ + { + name: "single PR", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "submitqueue", PRNumber: 42, HeadCommitSHA: "abc123"}, + }, + wantParts: []string{ + "query {", + `pr0: repository(owner: "uber", name: "submitqueue")`, + "pullRequest(number: 42)", + "number", "mergeable", "baseRefName", "headRefName", "headRefOid", "state", + }, + }, + { + name: "multiple PRs across repos", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 2, HeadCommitSHA: "sha2"}, + {Scheme: "ghe", Org: "corp", Repo: "app", PRNumber: 99, HeadCommitSHA: "sha99"}, + }, + wantParts: []string{ + `pr0: repository(owner: "uber", name: "repo")`, + "pullRequest(number: 1)", + `pr1: repository(owner: "uber", name: "repo")`, + "pullRequest(number: 2)", + `pr2: repository(owner: "corp", name: "app")`, + "pullRequest(number: 99)", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := buildGraphQLQuery(tt.changeIDs) + for _, part := range tt.wantParts { + assert.Contains(t, query, part) + } + }) + } +} + +func TestParseGraphQLResponse(t *testing.T) { + tests := []struct { + name string + body string + changeIDs []entitygithub.ChangeID + want map[int]PRInfo + wantErr bool + }{ + { + name: "success with two PRs", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 2, HeadCommitSHA: "sha2"}, + }, + body: mustMarshalGraphQLResponse(t, map[string]json.RawMessage{ + "pr0": json.RawMessage(`{"pullRequest":{"number":1,"mergeable":"MERGEABLE","baseRefName":"main","headRefName":"feature-1","headRefOid":"sha1","state":"OPEN"}}`), + "pr1": json.RawMessage(`{"pullRequest":{"number":2,"mergeable":"CONFLICTING","baseRefName":"feature-1","headRefName":"feature-2","headRefOid":"sha2","state":"OPEN"}}`), + }), + want: map[int]PRInfo{ + 1: {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "sha1", State: PRStateOpen}, + 2: {Number: 2, Mergeable: PRMergeableStateConflicting, BaseRefName: "feature-1", HeadRefName: "feature-2", HeadRefOid: "sha2", State: PRStateOpen}, + }, + }, + { + name: "GraphQL errors", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + }, + body: `{"data":null,"errors":[{"message":"Not Found"},{"message":"Forbidden"}]}`, + wantErr: true, + }, + { + name: "invalid JSON", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + }, + body: `invalid`, + wantErr: true, + }, + { + name: "missing alias", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + }, + body: `{"data":{}}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseGraphQLResponse([]byte(tt.body), tt.changeIDs) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +} + +// mustMarshalGraphQLResponse is a test helper to build a GraphQL response body. +func mustMarshalGraphQLResponse(t *testing.T, data map[string]json.RawMessage) string { + t.Helper() + resp := graphQLResponse{Data: data} + body, err := json.Marshal(resp) + require.NoError(t, err) + return string(body) +} diff --git a/extension/mergechecker/github/validate.go b/extension/mergechecker/github/validate.go new file mode 100644 index 00000000..7ec3aad7 --- /dev/null +++ b/extension/mergechecker/github/validate.go @@ -0,0 +1,87 @@ +package github + +import ( + "fmt" + + entitygithub "github.com/uber/submitqueue/entity/github" +) + +// PRMergeableState represents the mergeability state of a pull request. +type PRMergeableState string + +const ( + // PRMergeableStateMergeable indicates the PR can be merged cleanly. + PRMergeableStateMergeable PRMergeableState = "MERGEABLE" + // PRMergeableStateConflicting indicates the PR has merge conflicts. + PRMergeableStateConflicting PRMergeableState = "CONFLICTING" + // PRMergeableStateUnknown indicates GitHub hasn't computed mergeability yet. + // GitHub computes mergeability asynchronously after pushes. The GraphQL API + // returns UNKNOWN when the computation hasn't finished, even though the API + // call itself is synchronous. Callers should retry after a short delay. + PRMergeableStateUnknown PRMergeableState = "UNKNOWN" +) + +// PRState represents the state of a pull request. +type PRState string + +const ( + // PRStateOpen indicates the PR is open. + PRStateOpen PRState = "OPEN" + // PRStateClosed indicates the PR is closed. + PRStateClosed PRState = "CLOSED" + // PRStateMerged indicates the PR has been merged. + PRStateMerged PRState = "MERGED" +) + +// PRInfo holds the relevant pull request information fetched from GitHub. +type PRInfo struct { + // Number is the pull request number. + Number int + // Mergeable is the mergeability state of the PR. + Mergeable PRMergeableState + // BaseRefName is the base branch the PR targets. + BaseRefName string + // HeadRefName is the head branch of the PR. + HeadRefName string + // HeadRefOid is the current head commit SHA of the PR. + HeadRefOid string + // State is the current state of the PR (OPEN, CLOSED, MERGED). + State PRState +} + +// validatePRs validates that all PRs are open, individually mergeable, and not stale. +// Returns (true, "", nil) if all PRs pass validation. +// Returns (false, reason, nil) if definitively not mergeable (conflicts, closed, stale SHA). +// Returns (false, "", error) if mergeability is UNKNOWN (retryable — GitHub hasn't computed yet). +func validatePRs(changeIDs []entitygithub.ChangeID, prInfoMap map[int]PRInfo) (bool, string, error) { + for _, cid := range changeIDs { + pr, ok := prInfoMap[cid.PRNumber] + if !ok { + return false, "", fmt.Errorf("PR #%d not found in API response", cid.PRNumber) + } + + // Check PR is open + if pr.State != PRStateOpen { + return false, fmt.Sprintf("PR #%d is %s", cid.PRNumber, pr.State), nil + } + + // Check mergeability + switch pr.Mergeable { + case PRMergeableStateConflicting: + return false, fmt.Sprintf("PR #%d has merge conflicts", cid.PRNumber), nil + case PRMergeableStateUnknown: + return false, "", fmt.Errorf("mergeability unknown for PR #%d, retry later", cid.PRNumber) + case PRMergeableStateMergeable: + // OK, continue + default: + return false, "", fmt.Errorf("unexpected mergeable state %q for PR #%d", pr.Mergeable, cid.PRNumber) + } + + // Check head commit SHA matches (staleness check) + if pr.HeadRefOid != cid.HeadCommitSHA { + return false, fmt.Sprintf("PR #%d head SHA changed: expected %s, got %s", cid.PRNumber, cid.HeadCommitSHA, pr.HeadRefOid), nil + } + } + + return true, "", nil +} diff --git a/extension/mergechecker/github/validate_test.go b/extension/mergechecker/github/validate_test.go new file mode 100644 index 00000000..c8805add --- /dev/null +++ b/extension/mergechecker/github/validate_test.go @@ -0,0 +1,135 @@ +package github + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + entitygithub "github.com/uber/submitqueue/entity/github" +) + +func TestValidatePRs(t *testing.T) { + tests := []struct { + name string + changeIDs []entitygithub.ChangeID + prInfoMap map[int]PRInfo + wantOK bool + wantReason string + wantErr bool + }{ + { + name: "single PR mergeable", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "abc123"}, + }, + prInfoMap: map[int]PRInfo{ + 1: {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "abc123", State: PRStateOpen}, + }, + wantOK: true, + }, + { + name: "stack of three PRs all mergeable", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 2, HeadCommitSHA: "sha2"}, + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 3, HeadCommitSHA: "sha3"}, + }, + prInfoMap: map[int]PRInfo{ + 1: {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "sha1", State: PRStateOpen}, + 2: {Number: 2, Mergeable: PRMergeableStateMergeable, BaseRefName: "feature-1", HeadRefName: "feature-2", HeadRefOid: "sha2", State: PRStateOpen}, + 3: {Number: 3, Mergeable: PRMergeableStateMergeable, BaseRefName: "feature-2", HeadRefName: "feature-3", HeadRefOid: "sha3", State: PRStateOpen}, + }, + wantOK: true, + }, + { + name: "PR closed", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + }, + prInfoMap: map[int]PRInfo{ + 1: {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "sha1", State: PRStateClosed}, + }, + wantOK: false, + wantReason: "PR #1 is CLOSED", + }, + { + name: "PR already merged", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + }, + prInfoMap: map[int]PRInfo{ + 1: {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "sha1", State: PRStateMerged}, + }, + wantOK: false, + wantReason: "PR #1 is MERGED", + }, + { + name: "PR has conflicts", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + }, + prInfoMap: map[int]PRInfo{ + 1: {Number: 1, Mergeable: PRMergeableStateConflicting, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "sha1", State: PRStateOpen}, + }, + wantOK: false, + wantReason: "PR #1 has merge conflicts", + }, + { + name: "unknown mergeability returns error", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + }, + prInfoMap: map[int]PRInfo{ + 1: {Number: 1, Mergeable: PRMergeableStateUnknown, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "sha1", State: PRStateOpen}, + }, + wantOK: false, + wantErr: true, + }, + { + name: "stale SHA", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "old_sha"}, + }, + prInfoMap: map[int]PRInfo{ + 1: {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "new_sha", State: PRStateOpen}, + }, + wantOK: false, + wantReason: "PR #1 head SHA changed: expected old_sha, got new_sha", + }, + { + name: "PR not found in map", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 999, HeadCommitSHA: "sha1"}, + }, + prInfoMap: map[int]PRInfo{}, + wantOK: false, + wantErr: true, + }, + { + name: "second PR in stack conflicting", + changeIDs: []entitygithub.ChangeID{ + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 1, HeadCommitSHA: "sha1"}, + {Scheme: "github", Org: "uber", Repo: "repo", PRNumber: 2, HeadCommitSHA: "sha2"}, + }, + prInfoMap: map[int]PRInfo{ + 1: {Number: 1, Mergeable: PRMergeableStateMergeable, BaseRefName: "main", HeadRefName: "feature-1", HeadRefOid: "sha1", State: PRStateOpen}, + 2: {Number: 2, Mergeable: PRMergeableStateConflicting, BaseRefName: "feature-1", HeadRefName: "feature-2", HeadRefOid: "sha2", State: PRStateOpen}, + }, + wantOK: false, + wantReason: "PR #2 has merge conflicts", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ok, reason, err := validatePRs(tt.changeIDs, tt.prInfoMap) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantOK, ok) + assert.Equal(t, tt.wantReason, reason) + }) + } +} diff --git a/extension/mergechecker/mergechecker.go b/extension/mergechecker/mergechecker.go index ba1690ac..bfeed16d 100644 --- a/extension/mergechecker/mergechecker.go +++ b/extension/mergechecker/mergechecker.go @@ -6,16 +6,19 @@ import ( "github.com/uber/submitqueue/entity" ) -// MergeChecker predicts whether a request's changes can merge cleanly. +// MergeChecker predicts whether a set of changes can merge cleanly. type MergeChecker interface { - // Check is a fail-fast validation that optimistically assesses the - // mergeability of the request. A positive result does not guarantee - // that the changes will apply cleanly at merge finalization time. - Check(ctx context.Context, request entity.Request) (Result, error) + // Check is a fail-fast mergeability check that optimistically assesses + // whether the changes can be merged. A positive result does not + // guarantee that the changes will apply cleanly at merge time. + Check(ctx context.Context, queue string, change entity.Change) (Result, error) } -// Result holds the outcome of a merge check. +// Result holds the outcome of a mergeability check. type Result struct { // Mergeable is true if the request's changes are expected to merge cleanly. Mergeable bool + // Reason is a human-readable explanation when Mergeable is false. + // Empty when Mergeable is true. + Reason string } diff --git a/extension/mergechecker/mock/BUILD.bazel b/extension/mergechecker/mock/BUILD.bazel new file mode 100644 index 00000000..bb83b0fe --- /dev/null +++ b/extension/mergechecker/mock/BUILD.bazel @@ -0,0 +1,13 @@ +load("@rules_go//go:def.bzl", "go_library") + +go_library( + name = "mock", + srcs = ["mergechecker_mock.go"], + importpath = "github.com/uber/submitqueue/extension/mergechecker/mock", + visibility = ["//visibility:public"], + deps = [ + "//entity", + "//extension/mergechecker", + "@org_uber_go_mock//gomock", + ], +) diff --git a/extension/mergechecker/mock/mergechecker_mock.go b/extension/mergechecker/mock/mergechecker_mock.go new file mode 100644 index 00000000..f40e4397 --- /dev/null +++ b/extension/mergechecker/mock/mergechecker_mock.go @@ -0,0 +1,58 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: extension/mergechecker/mergechecker.go +// +// Generated by this command: +// +// mockgen -source=extension/mergechecker/mergechecker.go -destination=extension/mergechecker/mock/mergechecker_mock.go -package=mock +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + entity "github.com/uber/submitqueue/entity" + mergechecker "github.com/uber/submitqueue/extension/mergechecker" + gomock "go.uber.org/mock/gomock" +) + +// MockMergeChecker is a mock of MergeChecker interface. +type MockMergeChecker struct { + ctrl *gomock.Controller + recorder *MockMergeCheckerMockRecorder + isgomock struct{} +} + +// MockMergeCheckerMockRecorder is the mock recorder for MockMergeChecker. +type MockMergeCheckerMockRecorder struct { + mock *MockMergeChecker +} + +// NewMockMergeChecker creates a new mock instance. +func NewMockMergeChecker(ctrl *gomock.Controller) *MockMergeChecker { + mock := &MockMergeChecker{ctrl: ctrl} + mock.recorder = &MockMergeCheckerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMergeChecker) EXPECT() *MockMergeCheckerMockRecorder { + return m.recorder +} + +// Check mocks base method. +func (m *MockMergeChecker) Check(ctx context.Context, queue string, change entity.Change) (mergechecker.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Check", ctx, queue, change) + ret0, _ := ret[0].(mergechecker.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Check indicates an expected call of Check. +func (mr *MockMergeCheckerMockRecorder) Check(ctx, queue, change any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Check", reflect.TypeOf((*MockMergeChecker)(nil).Check), ctx, queue, change) +} diff --git a/extension/mergechecker/multi.go b/extension/mergechecker/multi.go new file mode 100644 index 00000000..89bca150 --- /dev/null +++ b/extension/mergechecker/multi.go @@ -0,0 +1,43 @@ +package mergechecker + +import ( + "context" + "fmt" + "strings" + + "github.com/uber/submitqueue/entity" +) + +// multiChecker dispatches mergeability checks to scheme-specific checkers +// based on the URI scheme of the first change URI. Each scheme +// (e.g., "github", "ghe", "ghes") maps to a checker configured for that host. +type multiChecker struct { + // checkers maps URI scheme values to their corresponding MergeChecker. + checkers map[string]MergeChecker +} + +// NewMultiChecker creates a MergeChecker that routes mergeability checks +// to scheme-specific checkers. The map keys correspond to URI schemes +// (e.g., "github", "ghe") extracted from the first change URI. +func NewMultiChecker(checkers map[string]MergeChecker) MergeChecker { + return &multiChecker{checkers: checkers} +} + +// Check dispatches the mergeability check to the checker registered for +// the change URI scheme. +func (m *multiChecker) Check(ctx context.Context, queue string, change entity.Change) (Result, error) { + if len(change.URIs) == 0 { + return Result{}, fmt.Errorf("no change URIs provided") + } + + scheme, _, ok := strings.Cut(change.URIs[0], "://") + if !ok || scheme == "" { + return Result{}, fmt.Errorf("invalid change URI %q: missing scheme", change.URIs[0]) + } + + checker, ok := m.checkers[scheme] + if !ok { + return Result{}, fmt.Errorf("no mergeability checker configured for scheme %q", scheme) + } + return checker.Check(ctx, queue, change) +} diff --git a/extension/mergechecker/multi_test.go b/extension/mergechecker/multi_test.go new file mode 100644 index 00000000..080157cc --- /dev/null +++ b/extension/mergechecker/multi_test.go @@ -0,0 +1,68 @@ +package mergechecker + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber/submitqueue/entity" +) + +// stubChecker is a test stub that returns a fixed result. +type stubChecker struct { + result Result + err error +} + +func (s *stubChecker) Check(_ context.Context, _ string, _ entity.Change) (Result, error) { + return s.result, s.err +} + +func TestMultiChecker_RoutesToCorrectChecker(t *testing.T) { + githubChecker := &stubChecker{result: Result{Mergeable: true}} + gheChecker := &stubChecker{result: Result{Mergeable: false}} + + mc := NewMultiChecker(map[string]MergeChecker{ + "github": githubChecker, + "ghe": gheChecker, + }) + + // Route to github checker + result, err := mc.Check(context.Background(), "test-queue", entity.Change{URIs: []string{"github://uber/repo/1/abc123"}}) + require.NoError(t, err) + assert.True(t, result.Mergeable) + + // Route to ghe checker + result, err = mc.Check(context.Background(), "test-queue", entity.Change{URIs: []string{"ghe://uber/repo/1/abc123"}}) + require.NoError(t, err) + assert.False(t, result.Mergeable) +} + +func TestMultiChecker_UnknownScheme(t *testing.T) { + mc := NewMultiChecker(map[string]MergeChecker{ + "github": &stubChecker{result: Result{Mergeable: true}}, + }) + + _, err := mc.Check(context.Background(), "test-queue", entity.Change{URIs: []string{"unknown://uber/repo/1/abc123"}}) + require.Error(t, err) +} + +func TestMultiChecker_PropagatesError(t *testing.T) { + mc := NewMultiChecker(map[string]MergeChecker{ + "github": &stubChecker{err: fmt.Errorf("api failure")}, + }) + + _, err := mc.Check(context.Background(), "test-queue", entity.Change{URIs: []string{"github://uber/repo/1/abc123"}}) + require.Error(t, err) +} + +func TestMultiChecker_EmptyURIs(t *testing.T) { + mc := NewMultiChecker(map[string]MergeChecker{ + "github": &stubChecker{result: Result{Mergeable: true}}, + }) + + _, err := mc.Check(context.Background(), "test-queue", entity.Change{URIs: []string{}}) + require.Error(t, err) +} diff --git a/orchestrator/controller/request/BUILD.bazel b/orchestrator/controller/request/BUILD.bazel index d737d595..f28241a6 100644 --- a/orchestrator/controller/request/BUILD.bazel +++ b/orchestrator/controller/request/BUILD.bazel @@ -9,6 +9,7 @@ go_library( "//core/consumer", "//entity", "//entity/queue", + "//extension/mergechecker", "@com_github_uber_go_tally_v4//:tally", "@org_uber_go_zap//:zap", ], @@ -22,6 +23,8 @@ go_test( "//core/consumer", "//entity", "//entity/queue", + "//extension/mergechecker", + "//extension/mergechecker/mock", "//extension/queue/mock", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/orchestrator/controller/request/request.go b/orchestrator/controller/request/request.go index 71749bae..66d8ee91 100644 --- a/orchestrator/controller/request/request.go +++ b/orchestrator/controller/request/request.go @@ -8,6 +8,7 @@ import ( "github.com/uber/submitqueue/core/consumer" "github.com/uber/submitqueue/entity" entityqueue "github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/mergechecker" "go.uber.org/zap" ) @@ -18,6 +19,7 @@ type Controller struct { logger *zap.SugaredLogger metricsScope tally.Scope registry consumer.TopicRegistry + mergeChecker mergechecker.MergeChecker topic consumer.Topic consumerGroup string } @@ -30,6 +32,7 @@ func NewController( logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, + mergeChecker mergechecker.MergeChecker, topic consumer.Topic, consumerGroup string, ) *Controller { @@ -37,6 +40,7 @@ func NewController( logger: logger.Named("request_controller"), metricsScope: scope.SubScope("request_controller"), registry: registry, + mergeChecker: mergeChecker, topic: topic, consumerGroup: consumerGroup, } @@ -76,8 +80,25 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "partition_key", msg.PartitionKey, ) - // TODO: Add validation logic - // - Merge Conflict Check + // Merge conflict check + mergeResult, err := c.mergeChecker.Check(ctx, request.Queue, request.Change) + if err != nil { + c.logger.Errorw("merge check failed", + "request_id", request.ID, + "error", err, + ) + c.metricsScope.Counter("merge_check_errors").Inc(1) + return fmt.Errorf("merge check failed: %w", err) + } + if !mergeResult.Mergeable { + c.logger.Infow("request not mergeable", + "request_id", request.ID, + "queue", request.Queue, + "reason", mergeResult.Reason, + ) + c.metricsScope.Counter("not_mergeable").Inc(1) + return consumer.NewNonRetryableError(fmt.Errorf("request %s is not mergeable: %s", request.ID, mergeResult.Reason)) + } // Publish to batch topic if err := c.publish(ctx, consumer.TopicToBatch, request); err != nil { diff --git a/orchestrator/controller/request/request_test.go b/orchestrator/controller/request/request_test.go index e08f1c77..da5cca69 100644 --- a/orchestrator/controller/request/request_test.go +++ b/orchestrator/controller/request/request_test.go @@ -11,13 +11,22 @@ import ( "github.com/uber/submitqueue/core/consumer" "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/mergechecker" + mergecheckermock "github.com/uber/submitqueue/extension/mergechecker/mock" queuemock "github.com/uber/submitqueue/extension/queue/mock" "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" ) +// newMergeableMock returns a mock MergeChecker that always returns mergeable. +func newMergeableMock(ctrl *gomock.Controller) *mergecheckermock.MockMergeChecker { + mc := mergecheckermock.NewMockMergeChecker(ctrl) + mc.EXPECT().Check(gomock.Any(), gomock.Any(), gomock.Any()).Return(mergechecker.Result{Mergeable: true}, nil).AnyTimes() + return mc +} + // newTestController creates a controller with test dependencies. -func newTestController(t *testing.T, ctrl *gomock.Controller, publishErr error) *Controller { +func newTestController(t *testing.T, ctrl *gomock.Controller, mc mergechecker.MergeChecker, publishErr error) *Controller { logger := zaptest.NewLogger(t).Sugar() scope := tally.NoopScope @@ -36,12 +45,13 @@ func newTestController(t *testing.T, ctrl *gomock.Controller, publishErr error) nil, ) - return NewController(logger, scope, registry, consumer.TopicRequest, "orchestrator-request") + return NewController(logger, scope, registry, mc, consumer.TopicRequest, "orchestrator-request") } func TestNewController(t *testing.T) { ctrl := gomock.NewController(t) - controller := newTestController(t, ctrl, nil) + mc := newMergeableMock(ctrl) + controller := newTestController(t, ctrl, mc, nil) require.NotNil(t, controller) assert.Equal(t, consumer.TopicRequest, controller.Topic()) @@ -51,8 +61,9 @@ func TestNewController(t *testing.T) { func TestController_Process_Success(t *testing.T) { ctrl := gomock.NewController(t) + mc := newMergeableMock(ctrl) - controller := newTestController(t, ctrl, nil) + controller := newTestController(t, ctrl, mc, nil) request := entity.Request{ ID: "test-queue/123", @@ -77,8 +88,9 @@ func TestController_Process_Success(t *testing.T) { func TestController_Process_InvalidJSON(t *testing.T) { ctrl := gomock.NewController(t) + mc := newMergeableMock(ctrl) - controller := newTestController(t, ctrl, nil) + controller := newTestController(t, ctrl, mc, nil) invalidPayload := []byte(`{"invalid": json"}`) msg := queue.NewMessage("invalid-msg", invalidPayload, "partition1", nil) @@ -109,8 +121,9 @@ func TestController_Process_AllRequestStates(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctrl := gomock.NewController(t) + mc := newMergeableMock(ctrl) - controller := newTestController(t, ctrl, nil) + controller := newTestController(t, ctrl, mc, nil) request := entity.Request{ ID: fmt.Sprintf("queue/%s", tt.state), @@ -137,8 +150,9 @@ func TestController_Process_AllRequestStates(t *testing.T) { func TestController_Process_MultipleChanges(t *testing.T) { ctrl := gomock.NewController(t) + mc := newMergeableMock(ctrl) - controller := newTestController(t, ctrl, nil) + controller := newTestController(t, ctrl, mc, nil) request := entity.Request{ ID: "queue/999", @@ -169,8 +183,9 @@ func TestController_Process_MultipleChanges(t *testing.T) { func TestController_Process_PublishFailure(t *testing.T) { ctrl := gomock.NewController(t) + mc := newMergeableMock(ctrl) - controller := newTestController(t, ctrl, fmt.Errorf("publish failed")) + controller := newTestController(t, ctrl, mc, fmt.Errorf("publish failed")) request := entity.Request{ ID: "test-queue/123", @@ -195,7 +210,69 @@ func TestController_Process_PublishFailure(t *testing.T) { func TestController_InterfaceImplementation(t *testing.T) { ctrl := gomock.NewController(t) - controller := newTestController(t, ctrl, nil) + mc := newMergeableMock(ctrl) + controller := newTestController(t, ctrl, mc, nil) var _ consumer.Controller = controller } + +func TestController_Process_NotMergeable(t *testing.T) { + ctrl := gomock.NewController(t) + + mc := mergecheckermock.NewMockMergeChecker(ctrl) + mc.EXPECT().Check(gomock.Any(), gomock.Any(), gomock.Any()).Return(mergechecker.Result{Mergeable: false}, nil) + + controller := newTestController(t, ctrl, mc, nil) + + request := entity.Request{ + ID: "test-queue/123", + Queue: "test-queue", + Change: entity.Change{URIs: []string{"github://uber/repo/1/abc123"}}, + LandStrategy: entity.RequestLandStrategyRebase, + State: entity.RequestStateNew, + Version: 1, + } + + payload, err := request.ToBytes() + require.NoError(t, err) + + msg := queue.NewMessage(request.ID, payload, request.Queue, nil) + delivery := queuemock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + err = controller.Process(context.Background(), delivery) + require.Error(t, err) + assert.True(t, consumer.IsNonRetryable(err)) +} + +func TestController_Process_MergeCheckError(t *testing.T) { + ctrl := gomock.NewController(t) + + mc := mergecheckermock.NewMockMergeChecker(ctrl) + mc.EXPECT().Check(gomock.Any(), gomock.Any(), gomock.Any()).Return(mergechecker.Result{}, fmt.Errorf("merge check failed")) + + controller := newTestController(t, ctrl, mc, nil) + + request := entity.Request{ + ID: "test-queue/123", + Queue: "test-queue", + Change: entity.Change{URIs: []string{"github://uber/repo/1/abc123"}}, + LandStrategy: entity.RequestLandStrategyRebase, + State: entity.RequestStateNew, + Version: 1, + } + + payload, err := request.ToBytes() + require.NoError(t, err) + + msg := queue.NewMessage(request.ID, payload, request.Queue, nil) + delivery := queuemock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + err = controller.Process(context.Background(), delivery) + require.Error(t, err) + // Merge check errors should be retryable (not NonRetryableError) + assert.False(t, consumer.IsNonRetryable(err)) +}