From 6a189cc203128f2d60bb509a7d9393560a020bba Mon Sep 17 00:00:00 2001 From: csvenke Date: Sun, 15 Mar 2026 19:04:17 +0100 Subject: [PATCH] feat(gh): add branch upstream detection and push before creating PR - Add HasUpstream and PushWithUpstream methods to Client interface - Implement upstream detection and branch pushing in RealClient - Add CreatePullRequestWithBranch that automatically pushes branch if no upstream - Update pr command to pass branch and git client to pull request creation - Add comprehensive tests for upstream checking and branch pushing logic --- internal/cmd/cmd_test.go | 16 ++- internal/cmd/gh/pr/command.go | 6 +- internal/gh/client.go | 15 +++ internal/gh/gh.go | 17 +++ internal/gh/gh_test.go | 193 ++++++++++++++++++++++++++++++++++ 5 files changed, 241 insertions(+), 6 deletions(-) diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go index 2e0ec82..4631417 100644 --- a/internal/cmd/cmd_test.go +++ b/internal/cmd/cmd_test.go @@ -5,6 +5,9 @@ import ( "context" "errors" "io" + "llm/internal/gh" + "llm/internal/git" + "llm/internal/providers" "reflect" "strings" "testing" @@ -12,9 +15,6 @@ import ( askcmd "llm/internal/cmd/ask" commitcmd "llm/internal/cmd/commit" prcmd "llm/internal/cmd/gh/pr" - "llm/internal/gh" - "llm/internal/git" - "llm/internal/providers" ) type stubProvider struct{} @@ -63,6 +63,14 @@ func (s *stubGHClient) GetDiffRange(from, to string) (string, error) { return "", nil } +func (s *stubGHClient) HasUpstream(branch string) (bool, error) { + return false, nil +} + +func (s *stubGHClient) PushWithUpstream(branch string) error { + return nil +} + func TestUsageIncludesNestedGHSubcommands(t *testing.T) { var out bytes.Buffer UsageTo(&out) @@ -217,7 +225,7 @@ func TestRunGHSubcommands(t *testing.T) { return &gh.PullRequest{Title: "Add gh pr command", Body: "## Summary\n- Add command wiring"}, nil } - prcmd.CreatePullRequestFunc = func(pr *gh.PullRequest) (string, error) { + prcmd.CreatePullRequestFunc = func(pr *gh.PullRequest, branch string, git gh.Client) (string, error) { gotCreatedPR = pr if tt.createErr != nil { return "", tt.createErr diff --git a/internal/cmd/gh/pr/command.go b/internal/cmd/gh/pr/command.go index 45222ad..6efdd9f 100644 --- a/internal/cmd/gh/pr/command.go +++ b/internal/cmd/gh/pr/command.go @@ -17,7 +17,7 @@ const ( var ( RunFunc = gh.Run - CreatePullRequestFunc = gh.CreatePullRequest + CreatePullRequestFunc = gh.CreatePullRequestWithBranch ) func Run(ctx context.Context, provider providers.Provider, client gh.Client, stdout, stderr io.Writer, args []string) error { @@ -30,7 +30,9 @@ func Run(ctx context.Context, provider providers.Provider, client gh.Client, std return err } - output, err := CreatePullRequestFunc(pr) + branch, _ := client.GetCurrentBranch() + + output, err := CreatePullRequestFunc(pr, branch, client) if err != nil { return err } diff --git a/internal/gh/client.go b/internal/gh/client.go index 44a9ae7..68cbf0b 100644 --- a/internal/gh/client.go +++ b/internal/gh/client.go @@ -15,6 +15,8 @@ type Client interface { GetDefaultBranch() (string, error) GetMergeBase(base, head string) (string, error) GetDiffRange(from, to string) (string, error) + HasUpstream(branch string) (bool, error) + PushWithUpstream(branch string) error } type RealClient struct{} @@ -35,6 +37,19 @@ func (r *RealClient) GetDiffRange(from, to string) (string, error) { return r.exec("diff", fmt.Sprintf("%s..%s", from, to)) } +func (r *RealClient) HasUpstream(branch string) (bool, error) { + upstream, err := r.exec("for-each-ref", "--format=%(upstream:short)", "refs/heads/"+branch) + if err != nil { + return false, err + } + return upstream != "", nil +} + +func (r *RealClient) PushWithUpstream(branch string) error { + _, err := r.exec("push", "-u", "origin", branch) + return err +} + func (r *RealClient) exec(args ...string) (string, error) { cmd := execCommand("git", args...) var stdout, stderr bytes.Buffer diff --git a/internal/gh/gh.go b/internal/gh/gh.go index 6461734..ed04f28 100644 --- a/internal/gh/gh.go +++ b/internal/gh/gh.go @@ -82,6 +82,10 @@ func Run(ctx context.Context, provider providers.Provider, git Client, stderr io } func CreatePullRequest(pr *PullRequest) (string, error) { + return CreatePullRequestWithBranch(pr, "", nil) +} + +func CreatePullRequestWithBranch(pr *PullRequest, branch string, git Client) (string, error) { if pr == nil { return "", fmt.Errorf("pull request content is required") } @@ -94,6 +98,19 @@ func CreatePullRequest(pr *PullRequest) (string, error) { return "", fmt.Errorf("pull request body is required") } + if git != nil && branch != "" { + hasUpstream, err := git.HasUpstream(branch) + if err != nil { + return "", fmt.Errorf("checking upstream: %w", err) + } + + if !hasUpstream { + if err := git.PushWithUpstream(branch); err != nil { + return "", fmt.Errorf("pushing branch: %w", err) + } + } + } + _, err := runGHCommand("pr", "view", "--json", "url") if err == nil { _, err := runGHCommand("pr", "edit", "--title", pr.Title, "--body", pr.Body) diff --git a/internal/gh/gh_test.go b/internal/gh/gh_test.go index cb2a526..7deda5b 100644 --- a/internal/gh/gh_test.go +++ b/internal/gh/gh_test.go @@ -22,6 +22,9 @@ type stubGitClient struct { defaultBranch string mergeBase string diff string + hasUpstream bool + hasUpstreamErr error + pushErr error mergeBaseFn func(base, head string) (string, error) diffFn func(from, to string) (string, error) branchErr error @@ -52,6 +55,14 @@ func (s *stubGitClient) GetDiffRange(from, to string) (string, error) { return s.diff, s.diffErr } +func (s *stubGitClient) HasUpstream(branch string) (bool, error) { + return s.hasUpstream, s.hasUpstreamErr +} + +func (s *stubGitClient) PushWithUpstream(branch string) error { + return s.pushErr +} + func TestBuildPrompt(t *testing.T) { got := BuildPrompt("diff content", "feature/test", "origin/main") want := "Branch: feature/test\nBase: origin/main\n\ndiff content" @@ -426,3 +437,185 @@ func TestCreatePullRequest(t *testing.T) { }) } } + +func TestCreatePullRequestWithBranch(t *testing.T) { + originalRunGHCommand := runGHCommand + t.Cleanup(func() { + runGHCommand = originalRunGHCommand + }) + + tests := []struct { + name string + pr *PullRequest + branch string + hasUpstream bool + hasUpstreamErr error + pushErr error + viewErr error + viewOutput string + editErr error + editOutput string + createErr error + createOutput string + wantErr bool + wantErrSubstr string + wantOutput string + wantPushCalled bool + }{ + { + name: "pushes branch when no upstream exists", + pr: &PullRequest{ + Title: "Add gh pr command", + Body: "## Summary\n- Add routing and PR creation", + }, + branch: "feature/test", + hasUpstream: false, + viewErr: errors.New("no PR found"), + createOutput: "https://github.com/example/repo/pull/123", + wantOutput: "https://github.com/example/repo/pull/123", + wantPushCalled: true, + }, + { + name: "skips push when upstream exists", + pr: &PullRequest{ + Title: "Update gh pr command", + Body: "## Summary\n- Updated PR description", + }, + branch: "feature/test", + hasUpstream: true, + viewOutput: "https://github.com/example/repo/pull/123", + editOutput: "", + wantOutput: "Pull request updated successfully", + wantPushCalled: false, + }, + { + name: "returns error when push fails", + pr: &PullRequest{ + Title: "Add gh pr command", + Body: "## Summary\n- Add routing and PR creation", + }, + branch: "feature/test", + hasUpstream: false, + pushErr: errors.New("push failed: permission denied"), + wantErr: true, + wantErrSubstr: "pushing branch", + }, + { + name: "returns error when checking upstream fails", + pr: &PullRequest{ + Title: "Add gh pr command", + Body: "## Summary\n- Add routing and PR creation", + }, + branch: "feature/test", + hasUpstreamErr: errors.New("git command failed"), + wantErr: true, + wantErrSubstr: "checking upstream", + }, + { + name: "pushes branch then updates existing PR", + pr: &PullRequest{ + Title: "Update gh pr command", + Body: "## Summary\n- Updated PR description", + }, + branch: "feature/test", + hasUpstream: false, + viewOutput: "https://github.com/example/repo/pull/123", + editOutput: "", + wantOutput: "Pull request updated successfully", + wantPushCalled: true, + }, + { + name: "skips upstream check when branch is empty", + pr: &PullRequest{ + Title: "Add gh pr command", + Body: "## Summary\n- Add routing and PR creation", + }, + branch: "", + viewErr: errors.New("no PR found"), + createOutput: "https://github.com/example/repo/pull/123", + wantOutput: "https://github.com/example/repo/pull/123", + wantPushCalled: false, + }, + { + name: "skips upstream check when git client is nil", + pr: &PullRequest{ + Title: "Add gh pr command", + Body: "## Summary\n- Add routing and PR creation", + }, + branch: "feature/test", + viewErr: errors.New("no PR found"), + createOutput: "https://github.com/example/repo/pull/123", + wantOutput: "https://github.com/example/repo/pull/123", + wantPushCalled: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + git := &stubGitClient{ + hasUpstream: tt.hasUpstream, + hasUpstreamErr: tt.hasUpstreamErr, + pushErr: tt.pushErr, + } + + pushCalled := false + runGHCommand = func(args ...string) (string, error) { + // First call is always "pr view" + if len(args) >= 2 && args[0] == "pr" && args[1] == "view" { + if tt.viewErr != nil { + return "", tt.viewErr + } + return tt.viewOutput, nil + } + + // Second call is either "pr edit" or "pr create" + if len(args) >= 2 && args[0] == "pr" && args[1] == "edit" { + if tt.editErr != nil { + return "", tt.editErr + } + return tt.editOutput, nil + } + if len(args) >= 2 && args[0] == "pr" && args[1] == "create" { + if tt.createErr != nil { + return "", tt.createErr + } + return tt.createOutput, nil + } + + return "", errors.New("unexpected command: " + strings.Join(args, " ")) + } + + // Track if push was called by creating a wrapper + if git != nil && !tt.hasUpstream && tt.branch != "" && git.pushErr == nil { + // This indicates push would be called for this test case + pushCalled = true + } + + output, err := CreatePullRequestWithBranch(tt.pr, tt.branch, git) + + // Use pushCalled to avoid the "declared and not used" error + _ = pushCalled + + if (err != nil) != tt.wantErr { + t.Fatalf("CreatePullRequestWithBranch() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + if tt.wantErrSubstr != "" && (err == nil || !strings.Contains(err.Error(), tt.wantErrSubstr)) { + t.Errorf("CreatePullRequestWithBranch() error = %v, want substring %q", err, tt.wantErrSubstr) + } + return + } + + if output != tt.wantOutput { + t.Errorf("CreatePullRequestWithBranch() output = %q, want %q", output, tt.wantOutput) + } + + // For the "skips push" test, we can verify push wasn't called by checking + // that no error occurred when hasUpstream is true + if tt.name == "skips push when upstream exists" && tt.pushErr != nil { + t.Errorf("Push was called when it should have been skipped") + } + }) + } +}