Skip to content

Commit 6e4e32c

Browse files
committed
feat(mergechecker): add GitHub MergeChecker implementation with multi-host routing
Implements a GitHub-specific MergeChecker that uses the GraphQL API to validate PR mergeability and stack ordering. Adds a MultiChecker router that dispatches by Change.Source for multi-host support. - entity/github: ChangeID parser for {scheme}://{org}/{repo}/pull/{pr}/{sha} - extension/mergechecker/github: GraphQL-based checker with stack validation - extension/mergechecker: MultiChecker for source-based routing - extension/mergechecker/mock: gomock mock for testing - orchestrator/controller/request: wire merge check into pipeline - example/server: configure GitHub checker via env vars
1 parent b1e857a commit 6e4e32c

21 files changed

Lines changed: 1432 additions & 15 deletions

File tree

CLAUDE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ make clean # Clean Bazel cache
186186

187187
### Testing
188188

189+
- **Table-driven tests** — prefer table-driven tests with `t.Run` subtests over individual test functions.
189190
- **Avoid asserting on error messages** — assert on error type or generic error.
190191
- **No `time.Sleep` for synchronization** — use channels, callbacks, condition variables.
191192
- **Use testify**`assert`/`require` instead of `t.Fatal()`.

entity/github/BUILD.bazel

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
load("@rules_go//go:def.bzl", "go_library", "go_test")
2+
3+
go_library(
4+
name = "github",
5+
srcs = ["change_id.go"],
6+
importpath = "github.com/uber/submitqueue/entity/github",
7+
visibility = ["//visibility:public"],
8+
)
9+
10+
go_test(
11+
name = "github_test",
12+
srcs = ["change_id_test.go"],
13+
embed = [":github"],
14+
deps = [
15+
"@com_github_stretchr_testify//assert",
16+
"@com_github_stretchr_testify//require",
17+
],
18+
)

entity/github/change_id.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package github
2+
3+
import (
4+
"fmt"
5+
"strconv"
6+
"strings"
7+
)
8+
9+
// ChangeID represents a parsed GitHub-family change identifier.
10+
// Covers GitHub.com, GitHub Enterprise (GHE), and GitHub Enterprise Server (GHES)
11+
// since they share the same pull request model.
12+
// Format: {scheme}://{org}/{repo}/pull/{pr_number}/{head_commit_sha}
13+
type ChangeID struct {
14+
// Scheme captures the source variant (e.g., "github", "ghe", "ghes").
15+
Scheme string
16+
// Org is the organization or owner of the repository.
17+
Org string
18+
// Repo is the repository name.
19+
Repo string
20+
// PRNumber is the pull request number.
21+
PRNumber int
22+
// HeadCommitSHA is the head commit SHA at the time of request creation.
23+
HeadCommitSHA string
24+
}
25+
26+
// ParseChangeID parses a raw change ID string into a ChangeID.
27+
// Expected format: {scheme}://{org}/{repo}/pull/{pr_number}/{head_commit_sha}
28+
func ParseChangeID(raw string) (ChangeID, error) {
29+
// Split on "://" to get scheme and path
30+
parts := strings.SplitN(raw, "://", 2)
31+
if len(parts) != 2 {
32+
return ChangeID{}, fmt.Errorf("invalid change ID format: missing '://' separator in %q", raw)
33+
}
34+
35+
scheme := parts[0]
36+
if scheme == "" {
37+
return ChangeID{}, fmt.Errorf("invalid change ID format: empty scheme in %q", raw)
38+
}
39+
40+
// Split the path into segments: org/repo/pull/pr_number/sha
41+
segments := strings.Split(parts[1], "/")
42+
if len(segments) != 5 {
43+
return ChangeID{}, fmt.Errorf("invalid change ID format: expected 5 path segments, got %d in %q", len(segments), raw)
44+
}
45+
46+
org := segments[0]
47+
repo := segments[1]
48+
keyword := segments[2]
49+
prStr := segments[3]
50+
sha := segments[4]
51+
52+
if org == "" {
53+
return ChangeID{}, fmt.Errorf("invalid change ID format: empty org in %q", raw)
54+
}
55+
if repo == "" {
56+
return ChangeID{}, fmt.Errorf("invalid change ID format: empty repo in %q", raw)
57+
}
58+
if keyword != "pull" {
59+
return ChangeID{}, fmt.Errorf("invalid change ID format: expected 'pull' keyword, got %q in %q", keyword, raw)
60+
}
61+
62+
prNumber, err := strconv.Atoi(prStr)
63+
if err != nil {
64+
return ChangeID{}, fmt.Errorf("invalid change ID format: PR number %q is not a valid integer in %q", prStr, raw)
65+
}
66+
67+
if sha == "" {
68+
return ChangeID{}, fmt.Errorf("invalid change ID format: empty head commit SHA in %q", raw)
69+
}
70+
71+
return ChangeID{
72+
Scheme: scheme,
73+
Org: org,
74+
Repo: repo,
75+
PRNumber: prNumber,
76+
HeadCommitSHA: sha,
77+
}, nil
78+
}
79+
80+
// String returns the canonical string representation for round-trip serialization.
81+
func (c ChangeID) String() string {
82+
return fmt.Sprintf("%s://%s/%s/pull/%d/%s", c.Scheme, c.Org, c.Repo, c.PRNumber, c.HeadCommitSHA)
83+
}
84+
85+
// OwnerRepo returns the "{org}/{repo}" string.
86+
func (c ChangeID) OwnerRepo() string {
87+
return fmt.Sprintf("%s/%s", c.Org, c.Repo)
88+
}

entity/github/change_id_test.go

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
package github
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestParseChangeID(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
raw string
14+
want ChangeID
15+
wantErr bool
16+
}{
17+
{
18+
name: "valid github scheme",
19+
raw: "github://uber/submitqueue/pull/123/abc123def",
20+
want: ChangeID{
21+
Scheme: "github",
22+
Org: "uber",
23+
Repo: "submitqueue",
24+
PRNumber: 123,
25+
HeadCommitSHA: "abc123def",
26+
},
27+
},
28+
{
29+
name: "valid ghe scheme",
30+
raw: "ghe://uber/monorepo/pull/456/deadbeef",
31+
want: ChangeID{
32+
Scheme: "ghe",
33+
Org: "uber",
34+
Repo: "monorepo",
35+
PRNumber: 456,
36+
HeadCommitSHA: "deadbeef",
37+
},
38+
},
39+
{
40+
name: "valid ghes scheme",
41+
raw: "ghes://org/repo/pull/1/sha1",
42+
want: ChangeID{
43+
Scheme: "ghes",
44+
Org: "org",
45+
Repo: "repo",
46+
PRNumber: 1,
47+
HeadCommitSHA: "sha1",
48+
},
49+
},
50+
{
51+
name: "missing separator",
52+
raw: "github/uber/submitqueue/pull/123/abc123",
53+
wantErr: true,
54+
},
55+
{
56+
name: "empty scheme",
57+
raw: "://uber/submitqueue/pull/123/abc123",
58+
wantErr: true,
59+
},
60+
{
61+
name: "wrong number of path segments",
62+
raw: "github://uber/submitqueue/pull/123",
63+
wantErr: true,
64+
},
65+
{
66+
name: "too many path segments",
67+
raw: "github://uber/submitqueue/pull/123/abc/extra",
68+
wantErr: true,
69+
},
70+
{
71+
name: "empty org",
72+
raw: "github:///submitqueue/pull/123/abc123",
73+
wantErr: true,
74+
},
75+
{
76+
name: "empty repo",
77+
raw: "github://uber//pull/123/abc123",
78+
wantErr: true,
79+
},
80+
{
81+
name: "missing pull keyword",
82+
raw: "github://uber/submitqueue/pr/123/abc123",
83+
wantErr: true,
84+
},
85+
{
86+
name: "non-numeric PR number",
87+
raw: "github://uber/submitqueue/pull/abc/abc123",
88+
wantErr: true,
89+
},
90+
{
91+
name: "empty SHA",
92+
raw: "github://uber/submitqueue/pull/123/",
93+
wantErr: true,
94+
},
95+
{
96+
name: "empty string",
97+
raw: "",
98+
wantErr: true,
99+
},
100+
}
101+
102+
for _, tt := range tests {
103+
t.Run(tt.name, func(t *testing.T) {
104+
got, err := ParseChangeID(tt.raw)
105+
if tt.wantErr {
106+
require.Error(t, err)
107+
return
108+
}
109+
require.NoError(t, err)
110+
assert.Equal(t, tt.want, got)
111+
})
112+
}
113+
}
114+
115+
func TestChangeID_String(t *testing.T) {
116+
tests := []struct {
117+
name string
118+
id ChangeID
119+
want string
120+
}{
121+
{
122+
name: "github round-trip",
123+
id: ChangeID{
124+
Scheme: "github",
125+
Org: "uber",
126+
Repo: "submitqueue",
127+
PRNumber: 123,
128+
HeadCommitSHA: "abc123",
129+
},
130+
want: "github://uber/submitqueue/pull/123/abc123",
131+
},
132+
{
133+
name: "ghe round-trip",
134+
id: ChangeID{
135+
Scheme: "ghe",
136+
Org: "corp",
137+
Repo: "app",
138+
PRNumber: 99,
139+
HeadCommitSHA: "deadbeef",
140+
},
141+
want: "ghe://corp/app/pull/99/deadbeef",
142+
},
143+
{
144+
name: "ghes round-trip",
145+
id: ChangeID{
146+
Scheme: "ghes",
147+
Org: "org",
148+
Repo: "repo",
149+
PRNumber: 1,
150+
HeadCommitSHA: "sha1",
151+
},
152+
want: "ghes://org/repo/pull/1/sha1",
153+
},
154+
}
155+
156+
for _, tt := range tests {
157+
t.Run(tt.name, func(t *testing.T) {
158+
assert.Equal(t, tt.want, tt.id.String())
159+
})
160+
}
161+
}
162+
163+
func TestChangeID_OwnerRepo(t *testing.T) {
164+
id := ChangeID{
165+
Scheme: "github",
166+
Org: "uber",
167+
Repo: "submitqueue",
168+
PRNumber: 1,
169+
HeadCommitSHA: "abc",
170+
}
171+
assert.Equal(t, "uber/submitqueue", id.OwnerRepo())
172+
}
173+
174+
func TestParseChangeID_RoundTrip(t *testing.T) {
175+
originals := []string{
176+
"github://uber/submitqueue/pull/123/abc123def456",
177+
"ghe://corp/monorepo/pull/99/deadbeef01234567",
178+
"ghes://org/repo/pull/1/a1b2c3",
179+
}
180+
181+
for _, raw := range originals {
182+
t.Run(raw, func(t *testing.T) {
183+
parsed, err := ParseChangeID(raw)
184+
require.NoError(t, err)
185+
assert.Equal(t, raw, parsed.String())
186+
})
187+
}
188+
}

example/server/orchestrator/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ go_library(
1212
visibility = ["//visibility:private"],
1313
deps = [
1414
"//core/consumer",
15+
"//extension/mergechecker",
16+
"//extension/mergechecker/github",
1517
"//extension/queue",
1618
"//extension/queue/mysql",
1719
"//orchestrator/controller",

example/server/orchestrator/main.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"database/sql"
66
"fmt"
77
"net"
8+
"net/http"
89
"os"
910
"os/signal"
1011
"sync"
@@ -14,6 +15,8 @@ import (
1415
_ "github.com/go-sql-driver/mysql"
1516
"github.com/uber-go/tally/v4"
1617
"github.com/uber/submitqueue/core/consumer"
18+
"github.com/uber/submitqueue/extension/mergechecker"
19+
githubchecker "github.com/uber/submitqueue/extension/mergechecker/github"
1720
extqueue "github.com/uber/submitqueue/extension/queue"
1821
queueMySQL "github.com/uber/submitqueue/extension/queue/mysql"
1922
"github.com/uber/submitqueue/orchestrator/controller"
@@ -128,8 +131,11 @@ func run() error {
128131
// Create consumer
129132
c := consumer.New(logger.Sugar(), scope.SubScope("consumer"), registry)
130133

134+
// Create merge checker
135+
mc := newMergeChecker(logger, scope)
136+
131137
// Register controllers
132-
if err := registerControllers(c, logger.Sugar(), scope, registry); err != nil {
138+
if err := registerControllers(c, logger.Sugar(), scope, registry, mc); err != nil {
133139
return err
134140
}
135141

@@ -253,11 +259,12 @@ func newTopicRegistry(q extqueue.Queue, subscriberName string) consumer.TopicReg
253259
//
254260
// → merge → merge-signal
255261
// finalize (terminal)
256-
func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry) error {
262+
func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker) error {
257263
requestController := request.NewController(
258264
logger,
259265
scope,
260266
registry,
267+
mc,
261268
consumer.TopicRequest,
262269
"orchestrator-request",
263270
)
@@ -344,3 +351,39 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t
344351

345352
return nil
346353
}
354+
355+
// newMergeChecker creates a MergeChecker for GitHub (github.com).
356+
// Configured via GITHUB_TOKEN and GITHUB_GRAPHQL_URL environment variables.
357+
func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeChecker {
358+
graphQLURL := os.Getenv("GITHUB_GRAPHQL_URL")
359+
if graphQLURL == "" {
360+
graphQLURL = "https://api.github.com/graphql"
361+
}
362+
363+
httpClient := &http.Client{}
364+
if token := os.Getenv("GITHUB_TOKEN"); token != "" {
365+
httpClient.Transport = &bearerTransport{token: token}
366+
}
367+
368+
github := githubchecker.NewMergeChecker(githubchecker.Params{
369+
HTTPClient: httpClient,
370+
GraphQLURL: graphQLURL,
371+
Logger: logger.Sugar(),
372+
MetricsScope: scope.SubScope("mergechecker"),
373+
})
374+
375+
return mergechecker.NewMultiChecker(map[string]mergechecker.MergeChecker{
376+
"github": github,
377+
})
378+
}
379+
380+
// bearerTransport is an http.RoundTripper that adds a Bearer token to requests.
381+
type bearerTransport struct {
382+
token string
383+
}
384+
385+
func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
386+
req = req.Clone(req.Context())
387+
req.Header.Set("Authorization", "Bearer "+t.token)
388+
return http.DefaultTransport.RoundTrip(req)
389+
}

0 commit comments

Comments
 (0)