From 6072d432051cbd3f814d6420059195f127120a42 Mon Sep 17 00:00:00 2001 From: csvenke Date: Sun, 15 Mar 2026 17:46:25 +0100 Subject: [PATCH] feat(gh): add PR update logic to CreatePullRequest - Check if PR already exists before creating a new one - Update existing PR with new title and body if one is found - Create new PR only when none exists - Update tests to verify both creation and update paths --- internal/gh/gh.go | 9 +++++ internal/gh/gh_test.go | 77 ++++++++++++++++++++++++++++++++---------- 2 files changed, 68 insertions(+), 18 deletions(-) diff --git a/internal/gh/gh.go b/internal/gh/gh.go index 5482fcb..6461734 100644 --- a/internal/gh/gh.go +++ b/internal/gh/gh.go @@ -94,6 +94,15 @@ func CreatePullRequest(pr *PullRequest) (string, error) { return "", fmt.Errorf("pull request body is required") } + _, err := runGHCommand("pr", "view", "--json", "url") + if err == nil { + _, err := runGHCommand("pr", "edit", "--title", pr.Title, "--body", pr.Body) + if err != nil { + return "", err + } + return "Pull request updated successfully", nil + } + output, err := runGHCommand("pr", "create", "--title", pr.Title, "--body", pr.Body) if err != nil { return "", err diff --git a/internal/gh/gh_test.go b/internal/gh/gh_test.go index 88eb29b..cb2a526 100644 --- a/internal/gh/gh_test.go +++ b/internal/gh/gh_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "errors" - "reflect" "strings" "testing" ) @@ -300,22 +299,36 @@ func TestCreatePullRequest(t *testing.T) { tests := []struct { name string pr *PullRequest - runErr error - runOutput string + viewErr error + viewOutput string + editErr error + editOutput string + createErr error + createOutput string wantErr bool wantErrSubstr string wantArgs []string wantOutput string }{ { - name: "creates PR with generated title and body", + name: "creates PR when none exists", pr: &PullRequest{ Title: "Add gh pr command", Body: "## Summary\n- Add routing and PR creation", }, - runOutput: "https://github.com/example/repo/pull/123", - wantArgs: []string{"pr", "create", "--title", "Add gh pr command", "--body", "## Summary\n- Add routing and PR creation"}, - wantOutput: "https://github.com/example/repo/pull/123", + viewErr: errors.New("no PR found"), + createOutput: "https://github.com/example/repo/pull/123", + wantOutput: "https://github.com/example/repo/pull/123", + }, + { + name: "updates PR when one exists", + pr: &PullRequest{ + Title: "Update gh pr command", + Body: "## Summary\n- Updated PR description", + }, + viewOutput: "https://github.com/example/repo/pull/123", + editOutput: "", + wantOutput: "Pull request updated successfully", }, { name: "fails when PR content is nil", @@ -340,12 +353,24 @@ func TestCreatePullRequest(t *testing.T) { wantErrSubstr: "pull request body is required", }, { - name: "returns gh command error", + name: "returns gh edit error when PR exists", + pr: &PullRequest{ + Title: "title", + Body: "body", + }, + viewOutput: "https://github.com/example/repo/pull/123", + editErr: errors.New("gh pr edit failed"), + wantErr: true, + wantErrSubstr: "gh pr edit failed", + }, + { + name: "returns gh create error when PR does not exist", pr: &PullRequest{ Title: "title", Body: "body", }, - runErr: errors.New("gh pr create failed"), + viewErr: errors.New("no PR found"), + createErr: errors.New("gh pr create failed"), wantErr: true, wantErrSubstr: "gh pr create failed", }, @@ -353,13 +378,33 @@ func TestCreatePullRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var gotArgs []string + callCount := 0 runGHCommand = func(args ...string) (string, error) { - gotArgs = args - if tt.runErr != nil { - return "", tt.runErr + callCount++ + + // First call is always "pr view" + if callCount == 1 { + if tt.viewErr != nil { + return "", tt.viewErr + } + return tt.viewOutput, nil + } + + // Second call is either "pr edit" or "pr create" + if len(args) > 1 && args[1] == "edit" { + if tt.editErr != nil { + return "", tt.editErr + } + return tt.editOutput, nil } - return tt.runOutput, nil + if len(args) > 1 && args[1] == "create" { + if tt.createErr != nil { + return "", tt.createErr + } + return tt.createOutput, nil + } + + return "", errors.New("unexpected command") } output, err := CreatePullRequest(tt.pr) @@ -375,10 +420,6 @@ func TestCreatePullRequest(t *testing.T) { return } - if !reflect.DeepEqual(gotArgs, tt.wantArgs) { - t.Errorf("CreatePullRequest() args = %v, want %v", gotArgs, tt.wantArgs) - } - if output != tt.wantOutput { t.Errorf("CreatePullRequest() output = %q, want %q", output, tt.wantOutput) }