diff --git a/pkg/provider/github/parse_payload.go b/pkg/provider/github/parse_payload.go index 7300e8615..709e15056 100644 --- a/pkg/provider/github/parse_payload.go +++ b/pkg/provider/github/parse_payload.go @@ -314,13 +314,55 @@ func selectSingleOpenPullRequest(prs []*github.PullRequest) (*github.PullRequest } } +func (v *Provider) findOpenPullRequestBySHA(ctx context.Context, org, repo, sha string) (*github.PullRequest, error) { + const maxPages = 10 + opts := &github.PullRequestListOptions{ + State: "open", + Sort: "updated", + ListOptions: github.ListOptions{PerPage: 100}, + } + var matches []*github.PullRequest + + for page := 0; page < maxPages; page++ { + prs, resp, err := wrapAPI(v, "list_pull_requests", func() ([]*github.PullRequest, *github.Response, error) { + return v.Client().PullRequests.List(ctx, org, repo, opts) + }) + if err != nil { + return nil, fmt.Errorf("failed to list open pull requests in %s/%s: %w", org, repo, err) + } + + for _, pr := range prs { + if pr.GetHead().GetSHA() == sha { + matches = append(matches, pr) + } + } + + if resp.NextPage == 0 { + break + } + opts.Page = resp.NextPage + } + + return selectSingleOpenPullRequest(matches) +} + func (v *Provider) resolveReRequestPullRequest(ctx context.Context, runevent *info.Event) (*github.PullRequest, error) { prs, err := v.getPullRequestsWithCommit(ctx, runevent.SHA, runevent.Organization, runevent.Repository, false) if err != nil { return nil, err } - return selectSingleOpenPullRequest(prs) + pr, err := selectSingleOpenPullRequest(prs) + if err != nil { + return nil, err + } + if pr != nil { + return pr, nil + } + + // ListPullRequestsWithCommit may return no matches for fork PR commits. + v.Logger.Infof("No PR found via commits API for SHA %s, falling back to open PR listing", runevent.SHA) + return v.findOpenPullRequestBySHA(ctx, runevent.Organization, runevent.Repository, runevent.SHA) } func (v *Provider) processEvent(ctx context.Context, event *info.Event, eventInt any) (*info.Event, error) { @@ -491,44 +533,58 @@ func (v *Provider) handleReRequestEvent(ctx context.Context, event *github.Check if event.GetRepo() == nil { return nil, errors.New("error parsing payload the repository should not be nil") } + checkRun := event.GetCheckRun() + checkSuite := checkRun.GetCheckSuite() + runevent.Organization = event.GetRepo().GetOwner().GetLogin() runevent.Repository = event.GetRepo().GetName() runevent.URL = event.GetRepo().GetHTMLURL() runevent.DefaultBranch = event.GetRepo().GetDefaultBranch() - runevent.SHA = event.GetCheckRun().GetCheckSuite().GetHeadSHA() - runevent.HeadBranch = event.GetCheckRun().GetCheckSuite().GetHeadBranch() - runevent.HeadURL = event.GetCheckRun().GetCheckSuite().GetRepository().GetHTMLURL() - // If we don't have a pull_request in this it probably mean a push - if len(event.GetCheckRun().GetCheckSuite().PullRequests) == 0 { - // If head_branch is null, try to find a PR by SHA before assuming push - if runevent.HeadBranch == "" && runevent.SHA != "" { - pr, err := v.resolveReRequestPullRequest(ctx, runevent) - if err != nil { - return nil, fmt.Errorf("cannot determine pull request for check_run rerequest and SHA %s: %w", runevent.SHA, err) - } - if pr != nil { - runevent.PullRequestNumber = pr.GetNumber() - runevent.TriggerTarget = triggertype.PullRequest - v.Logger.Infof("Recheck of PR %s/%s#%d has been requested (resolved from SHA)", runevent.Organization, runevent.Repository, runevent.PullRequestNumber) - return v.populateRunEventFromPullRequest(runevent, pr), nil - } + runevent.SHA = checkSuite.GetHeadSHA() + runevent.HeadBranch = checkSuite.GetHeadBranch() + runevent.HeadURL = checkSuite.GetRepository().GetHTMLURL() + + if len(checkSuite.PullRequests) > 0 { + runevent.PullRequestNumber = checkSuite.PullRequests[0].GetNumber() + runevent.TriggerTarget = triggertype.PullRequest + v.Logger.Infof("Recheck of PR %s/%s#%d has been requested", runevent.Organization, runevent.Repository, runevent.PullRequestNumber) + return v.getPullRequest(ctx, runevent) + } + + if len(checkRun.PullRequests) > 1 { + return nil, fmt.Errorf("cannot determine pull request for check_run rerequest: found %d associated pull requests in webhook payload", len(checkRun.PullRequests)) + } + if len(checkRun.PullRequests) == 1 { + runevent.PullRequestNumber = checkRun.PullRequests[0].GetNumber() + runevent.TriggerTarget = triggertype.PullRequest + v.Logger.Infof("Recheck of PR %s/%s#%d has been requested (from check_run)", runevent.Organization, runevent.Repository, runevent.PullRequestNumber) + return v.getPullRequest(ctx, runevent) + } + + // If head_branch is null, try to find a PR by SHA before assuming push. + if runevent.HeadBranch == "" && runevent.SHA != "" { + pr, err := v.resolveReRequestPullRequest(ctx, runevent) + if err != nil { + return nil, fmt.Errorf("cannot determine pull request for check_run rerequest and SHA %s: %w", runevent.SHA, err) } - if runevent.HeadBranch == "" { - return nil, fmt.Errorf("cannot determine branch for check_run rerequest: head_branch is null and no associated PR found for SHA %s", runevent.SHA) + if pr != nil { + runevent.PullRequestNumber = pr.GetNumber() + runevent.TriggerTarget = triggertype.PullRequest + v.Logger.Infof("Recheck of PR %s/%s#%d has been requested (resolved from SHA)", runevent.Organization, runevent.Repository, runevent.PullRequestNumber) + return v.populateRunEventFromPullRequest(runevent, pr), nil } - runevent.BaseBranch = runevent.HeadBranch - runevent.BaseURL = runevent.HeadURL - runevent.EventType = "push" - // we allow the rerequest user here, not the push user, i guess it's - // fine because you can't do a rereq without being a github owner? - runevent.Sender = event.GetSender().GetLogin() - v.userType = event.GetSender().GetType() - return runevent, nil } - runevent.PullRequestNumber = event.GetCheckRun().GetCheckSuite().PullRequests[0].GetNumber() - runevent.TriggerTarget = triggertype.PullRequest - v.Logger.Infof("Recheck of PR %s/%s#%d has been requested", runevent.Organization, runevent.Repository, runevent.PullRequestNumber) - return v.getPullRequest(ctx, runevent) + if runevent.HeadBranch == "" { + return nil, fmt.Errorf("cannot determine branch for check_run rerequest: head_branch is null and no associated PR found for SHA %s", runevent.SHA) + } + runevent.BaseBranch = runevent.HeadBranch + runevent.BaseURL = runevent.HeadURL + runevent.EventType = "push" + // we allow the rerequest user here, not the push user, i guess it's + // fine because you can't do a rereq without being a github owner? + runevent.Sender = event.GetSender().GetLogin() + v.userType = event.GetSender().GetType() + return runevent, nil } func (v *Provider) handleCheckSuites(ctx context.Context, event *github.CheckSuiteEvent) (*info.Event, error) { diff --git a/pkg/provider/github/parse_payload_test.go b/pkg/provider/github/parse_payload_test.go index 386f30758..f195622e4 100644 --- a/pkg/provider/github/parse_payload_test.go +++ b/pkg/provider/github/parse_payload_test.go @@ -336,6 +336,76 @@ func TestGetPullRequestsWithCommit(t *testing.T) { } } +func TestFindOpenPullRequestBySHA(t *testing.T) { + tests := []struct { + name string + sha string + setup func(t *testing.T, mux *http.ServeMux) + wantPRNumber int + wantErr string + }{ + { + name: "single matching open PR", + sha: "forkPRsha", + setup: func(t *testing.T, mux *http.ServeMux) { + t.Helper() + mux.HandleFunc("/repos/owner/reponame/pulls", func(rw http.ResponseWriter, _ *http.Request) { + fmt.Fprint(rw, `[ + {"number": 101, "state": "open", "head": {"sha": "otherSHA"}}, + {"number": 202, "state": "open", "head": {"sha": "forkPRsha"}} + ]`) + }) + }, + wantPRNumber: 202, + }, + { + name: "multiple matching open PRs across pages returns ambiguity error", + sha: "ambiguousSHA", + setup: func(t *testing.T, mux *http.ServeMux) { + t.Helper() + mux.HandleFunc("/repos/owner/reponame/pulls", func(rw http.ResponseWriter, r *http.Request) { + switch r.URL.Query().Get("page") { + case "", "1": + rw.Header().Set("Link", `; rel="next"`) + fmt.Fprint(rw, `[{"number": 101, "state": "open", "head": {"sha": "ambiguousSHA"}}]`) + case "2": + fmt.Fprint(rw, `[{"number": 202, "state": "open", "head": {"sha": "ambiguousSHA"}}]`) + default: + t.Fatalf("unexpected page %q", r.URL.Query().Get("page")) + } + }) + }, + wantErr: "found 2 open pull requests associated with the commit", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, _ := rtesting.SetupFakeContext(t) + fakeclient, mux, _, teardown := ghtesthelper.SetupGH() + defer teardown() + + tt.setup(t, mux) + + logger, _ := logger.GetLogger() + provider := &Provider{ + ghClient: fakeclient, + Logger: logger, + } + + pr, err := provider.findOpenPullRequestBySHA(ctx, "owner", "reponame", tt.sha) + if tt.wantErr != "" { + assert.ErrorContains(t, err, tt.wantErr) + return + } + + assert.NilError(t, err) + assert.Assert(t, pr != nil) + assert.Equal(t, tt.wantPRNumber, pr.GetNumber()) + }) + } +} + func TestIsCommitPartOfPullRequest(t *testing.T) { tests := []struct { name string @@ -644,6 +714,153 @@ func TestParsePayLoad(t *testing.T) { shaRet: "samplePRsha", wantedPullRequestNumber: 54321, }, + { + name: "good/rerequest check_run resolves PR from check_run pull requests", + eventType: "check_run", + githubClient: true, + triggerTarget: string(triggertype.PullRequest), + payloadEventStruct: github.CheckRunEvent{ + Action: github.Ptr("rerequested"), + Repo: sampleRepo, + CheckRun: &github.CheckRun{ + PullRequests: []*github.PullRequest{&samplePR}, + CheckSuite: &github.CheckSuite{ + HeadSHA: github.Ptr("samplePRsha"), + }, + }, + }, + muxReplies: map[string]any{ + "/repos/owner/reponame/pulls/54321": samplePR, + }, + shaRet: "samplePRsha", + wantedPullRequestNumber: 54321, + }, + { + name: "bad/rerequest check_run with multiple pull requests in payload", + eventType: "check_run", + githubClient: true, + wantErrString: "cannot determine pull request for check_run rerequest: found 2 associated pull requests in webhook payload", + payloadEventStruct: github.CheckRunEvent{ + Action: github.Ptr("rerequested"), + Repo: sampleRepo, + CheckRun: &github.CheckRun{ + PullRequests: []*github.PullRequest{ + &samplePR, + &samplePRAnother, + }, + CheckSuite: &github.CheckSuite{ + HeadSHA: github.Ptr("samplePRsha"), + }, + }, + }, + shaRet: "samplePRsha", + }, + { + name: "good/rerequest check_run null head_branch resolves fork PR from open PR list", + eventType: "check_run", + githubClient: true, + triggerTarget: string(triggertype.PullRequest), + payloadEventStruct: github.CheckRunEvent{ + Action: github.Ptr("rerequested"), + Repo: sampleRepo, + CheckRun: &github.CheckRun{ + CheckSuite: &github.CheckSuite{ + HeadSHA: github.Ptr("forkPRsha"), + }, + }, + }, + muxReplies: map[string]any{ + "/repos/owner/reponame/commits/forkPRsha/pulls": []*github.PullRequest{}, + "/repos/owner/reponame/pulls": []*github.PullRequest{ + { + Number: github.Ptr(987), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/reponame/pull/987"), + Head: &github.PullRequestBranch{ + SHA: github.Ptr("forkPRsha"), + Ref: github.Ptr("fork-feature"), + Repo: &github.Repository{ + Owner: &github.User{ + Login: github.Ptr("fork-owner"), + }, + Name: github.Ptr("reponame"), + HTMLURL: github.Ptr("https://github.com/fork-owner/reponame"), + }, + }, + Base: &github.PullRequestBranch{ + Ref: github.Ptr("main"), + SHA: github.Ptr("basesha"), + Repo: sampleRepo, + }, + User: &github.User{ + Login: github.Ptr("fork-contributor"), + }, + Title: github.Ptr("fork PR"), + }, + }, + "/repos/owner/reponame/pulls/987": github.PullRequest{ + Number: github.Ptr(987), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/reponame/pull/987"), + Head: &github.PullRequestBranch{ + SHA: github.Ptr("forkPRsha"), + Ref: github.Ptr("fork-feature"), + Repo: &github.Repository{ + Owner: &github.User{ + Login: github.Ptr("fork-owner"), + }, + Name: github.Ptr("reponame"), + HTMLURL: github.Ptr("https://github.com/fork-owner/reponame"), + }, + }, + Base: &github.PullRequestBranch{ + Ref: github.Ptr("main"), + SHA: github.Ptr("basesha"), + Repo: sampleRepo, + }, + User: &github.User{ + Login: github.Ptr("fork-contributor"), + }, + Title: github.Ptr("fork PR"), + }, + }, + shaRet: "forkPRsha", + wantedPullRequestNumber: 987, + }, + { + name: "bad/rerequest check_run null head_branch ambiguous fallback PRs found", + eventType: "check_run", + githubClient: true, + wantErrString: "found 2 open pull requests associated with the commit", + payloadEventStruct: github.CheckRunEvent{ + Action: github.Ptr("rerequested"), + Repo: sampleRepo, + CheckRun: &github.CheckRun{ + CheckSuite: &github.CheckSuite{ + HeadSHA: github.Ptr("ambiguousFallbackSHA"), + }, + }, + }, + muxReplies: map[string]any{ + "/repos/owner/reponame/commits/ambiguousFallbackSHA/pulls": []*github.PullRequest{}, + "/repos/owner/reponame/pulls": []*github.PullRequest{ + { + Number: github.Ptr(301), + State: github.Ptr("open"), + Head: &github.PullRequestBranch{ + SHA: github.Ptr("ambiguousFallbackSHA"), + }, + }, + { + Number: github.Ptr(302), + State: github.Ptr("open"), + Head: &github.PullRequestBranch{ + SHA: github.Ptr("ambiguousFallbackSHA"), + }, + }, + }, + }, + }, { name: "good/rerequest check_suite null head_branch resolves PR from SHA", eventType: "check_suite", @@ -678,6 +895,7 @@ func TestParsePayLoad(t *testing.T) { }, muxReplies: map[string]any{ "/repos/owner/reponame/commits/orphanSHA/pulls": []*github.PullRequest{}, + "/repos/owner/reponame/pulls": []*github.PullRequest{}, }, }, { @@ -694,6 +912,7 @@ func TestParsePayLoad(t *testing.T) { }, muxReplies: map[string]any{ "/repos/owner/reponame/commits/orphanSHA/pulls": []*github.PullRequest{}, + "/repos/owner/reponame/pulls": []*github.PullRequest{}, }, }, { @@ -717,6 +936,7 @@ func TestParsePayLoad(t *testing.T) { State: github.Ptr("closed"), }, }, + "/repos/owner/reponame/pulls": []*github.PullRequest{}, }, }, {