diff --git a/internal/cmd/ask/command.go b/internal/cmd/ask/command.go new file mode 100644 index 0000000..ec54b25 --- /dev/null +++ b/internal/cmd/ask/command.go @@ -0,0 +1,21 @@ +package askcmd + +import ( + "context" + "io" + + "llm/internal/ask" + "llm/internal/providers" +) + +const ( + Name = "ask" + Usage = "ask " + Description = "Ask a question" +) + +var RunFunc = ask.Run + +func Run(ctx context.Context, provider providers.Provider, stdout, stderr io.Writer, args []string) error { + return RunFunc(ctx, provider, stdout, stderr, args) +} diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go new file mode 100644 index 0000000..2973655 --- /dev/null +++ b/internal/cmd/cmd.go @@ -0,0 +1,190 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + + askcmd "llm/internal/cmd/ask" + commitcmd "llm/internal/cmd/commit" + ghcmd "llm/internal/cmd/gh" + prcmd "llm/internal/cmd/gh/pr" + "llm/internal/gh" + "llm/internal/git" + "llm/internal/providers" +) + +var ErrUnknownCommand = errors.New("unknown command") + +type Dependencies struct { + Provider providers.Provider + Stdout io.Writer + Stderr io.Writer + Git git.Client + GH gh.Client +} + +type Handler func(ctx context.Context, deps Dependencies, args []string) error + +type Command struct { + Name string + Usage string + Description string + Run Handler + Subcommands []*Command +} + +type Registry struct { + commands []*Command +} + +var defaultRegistry = NewRegistry() + +func NewRegistry() *Registry { + return &Registry{ + commands: []*Command{ + { + Name: askcmd.Name, + Usage: askcmd.Usage, + Description: askcmd.Description, + Run: func(ctx context.Context, deps Dependencies, args []string) error { + return askcmd.Run(ctx, deps.Provider, deps.Stdout, deps.Stderr, args) + }, + }, + { + Name: commitcmd.Name, + Usage: commitcmd.Usage, + Description: commitcmd.Description, + Run: func(ctx context.Context, deps Dependencies, args []string) error { + return commitcmd.Run(ctx, deps.Provider, deps.Git, deps.Stderr, args) + }, + }, + { + Name: ghcmd.Name, + Usage: ghcmd.Usage, + Description: ghcmd.Description, + Subcommands: []*Command{ + { + Name: prcmd.Name, + Usage: prcmd.Usage, + Description: prcmd.Description, + Run: func(ctx context.Context, deps Dependencies, args []string) error { + return prcmd.Run(ctx, deps.Provider, deps.GH, deps.Stdout, deps.Stderr, args) + }, + }, + }, + }, + }, + } +} + +func UsageTo(w io.Writer) { + defaultRegistry.UsageTo(w) +} + +func Run(ctx context.Context, provider providers.Provider, stdout, stderr io.Writer, args []string) error { + deps := Dependencies{ + Provider: provider, + Stdout: stdout, + Stderr: stderr, + Git: &git.RealClient{}, + GH: &gh.RealClient{}, + } + + return defaultRegistry.Run(ctx, deps, args) +} + +func (r *Registry) UsageTo(w io.Writer) { + fmt.Fprintf(w, "Usage: llm [options]\n\n") + fmt.Fprintf(w, "Commands:\n") + r.writeCommands(w, r.commands, 2) +} + +func (r *Registry) Run(ctx context.Context, deps Dependencies, args []string) error { + if len(args) == 0 { + return ErrUnknownCommand + } + + deps = normalizeDependencies(deps) + + cmd := findCommand(r.commands, args[0]) + if cmd == nil { + return ErrUnknownCommand + } + + return runCommand(ctx, deps, cmd, cmd.Name, args[1:]) +} + +func runCommand(ctx context.Context, deps Dependencies, cmd *Command, path string, args []string) error { + if len(cmd.Subcommands) > 0 { + if len(args) == 0 { + return fmt.Errorf("usage: llm %s ", path) + } + + sub := findCommand(cmd.Subcommands, args[0]) + if sub == nil { + return fmt.Errorf("unknown %s subcommand %q (usage: llm %s )", cmd.Name, args[0], path) + } + + return runCommand(ctx, deps, sub, path+" "+sub.Name, args[1:]) + } + + if cmd.Run == nil { + return nil + } + + return cmd.Run(ctx, deps, args) +} + +func findCommand(commands []*Command, name string) *Command { + for _, cmd := range commands { + if cmd.Name == name { + return cmd + } + } + + return nil +} + +func (r *Registry) writeCommands(w io.Writer, commands []*Command, indent int) { + for _, cmd := range commands { + if indent <= 2 { + fmt.Fprintf(w, "%s%-19s %s\n", spaces(indent), cmd.Usage, cmd.Description) + } else { + fmt.Fprintf(w, "%s%-17s %s\n", spaces(indent), cmd.Usage, cmd.Description) + } + + if len(cmd.Subcommands) > 0 { + r.writeCommands(w, cmd.Subcommands, indent+2) + } + } +} + +func spaces(n int) string { + if n <= 0 { + return "" + } + + return fmt.Sprintf("%*s", n, "") +} + +func normalizeDependencies(deps Dependencies) Dependencies { + if deps.Stdout == nil { + deps.Stdout = io.Discard + } + + if deps.Stderr == nil { + deps.Stderr = io.Discard + } + + if deps.Git == nil { + deps.Git = &git.RealClient{} + } + + if deps.GH == nil { + deps.GH = &gh.RealClient{} + } + + return deps +} diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go new file mode 100644 index 0000000..f45a32c --- /dev/null +++ b/internal/cmd/cmd_test.go @@ -0,0 +1,276 @@ +package cmd + +import ( + "bytes" + "context" + "errors" + "io" + "reflect" + "strings" + "testing" + + askcmd "llm/internal/cmd/ask" + commitcmd "llm/internal/cmd/commit" + prcmd "llm/internal/cmd/gh/pr" + "llm/internal/gh" + "llm/internal/git" + "llm/internal/providers" +) + +type stubProvider struct{} + +func (s *stubProvider) Complete(ctx context.Context, system, userMsg string) (string, error) { + return "", nil +} + +type stubGitClient struct{} + +func (s *stubGitClient) GetStagedDiff() (string, error) { + return "", nil +} + +func (s *stubGitClient) GetDiffFromRevision(revision string) (string, error) { + return "", nil +} + +func (s *stubGitClient) GetCurrentBranch() (string, error) { + return "", nil +} + +func (s *stubGitClient) HasParentCommit() (bool, error) { + return false, nil +} + +func (s *stubGitClient) Commit(msg string, amend bool) error { + return nil +} + +type stubGHClient struct{} + +func (s *stubGHClient) GetCurrentBranch() (string, error) { + return "", nil +} + +func (s *stubGHClient) GetDefaultBranch() (string, error) { + return "", nil +} + +func (s *stubGHClient) GetMergeBase(base, head string) (string, error) { + return "", nil +} + +func (s *stubGHClient) GetDiffRange(from, to string) (string, error) { + return "", nil +} + +func TestUsageIncludesNestedGHSubcommands(t *testing.T) { + var out bytes.Buffer + UsageTo(&out) + + if !strings.Contains(out.String(), "gh ") { + t.Fatalf("usage output = %q, want to include %q", out.String(), "gh ") + } + + if !strings.Contains(out.String(), "pr Create a GitHub pull request") { + t.Fatalf("usage output = %q, want to include gh pr subcommand help", out.String()) + } +} + +func TestRunRoutesAskAndCommit(t *testing.T) { + originalAskRun := askcmd.RunFunc + originalCommitRun := commitcmd.RunFunc + t.Cleanup(func() { + askcmd.RunFunc = originalAskRun + commitcmd.RunFunc = originalCommitRun + }) + + provider := &stubProvider{} + deps := Dependencies{ + Provider: provider, + Stdout: &bytes.Buffer{}, + Stderr: &bytes.Buffer{}, + Git: &stubGitClient{}, + GH: &stubGHClient{}, + } + + t.Run("routes ask", func(t *testing.T) { + var gotArgs []string + askcmd.RunFunc = func(ctx context.Context, gotProvider providers.Provider, output io.Writer, stderr io.Writer, args []string) error { + if gotProvider != provider { + t.Fatalf("ask provider = %#v, want %#v", gotProvider, provider) + } + gotArgs = args + return nil + } + + err := defaultRegistry.Run(context.Background(), deps, []string{"ask", "what", "is", "go"}) + if err != nil { + t.Fatalf("Run() error = %v, want nil", err) + } + + if !reflect.DeepEqual(gotArgs, []string{"what", "is", "go"}) { + t.Fatalf("ask args = %v, want %v", gotArgs, []string{"what", "is", "go"}) + } + }) + + t.Run("routes commit", func(t *testing.T) { + var gotArgs []string + commitcmd.RunFunc = func(ctx context.Context, gotProvider providers.Provider, gitClient git.Client, stderr io.Writer, args []string) error { + if gotProvider != provider { + t.Fatalf("commit provider = %#v, want %#v", gotProvider, provider) + } + gotArgs = args + return nil + } + + err := defaultRegistry.Run(context.Background(), deps, []string{"commit", "--amend"}) + if err != nil { + t.Fatalf("Run() error = %v, want nil", err) + } + + if !reflect.DeepEqual(gotArgs, []string{"--amend"}) { + t.Fatalf("commit args = %v, want %v", gotArgs, []string{"--amend"}) + } + }) +} + +func TestRunGHSubcommands(t *testing.T) { + originalGHRun := prcmd.RunFunc + originalGHCreatePullRequest := prcmd.CreatePullRequestFunc + originalEnsureBranchPushed := prcmd.EnsureBranchPushedFunc + t.Cleanup(func() { + prcmd.RunFunc = originalGHRun + prcmd.CreatePullRequestFunc = originalGHCreatePullRequest + prcmd.EnsureBranchPushedFunc = originalEnsureBranchPushed + }) + + deps := Dependencies{ + Provider: &stubProvider{}, + Stdout: &bytes.Buffer{}, + Stderr: &bytes.Buffer{}, + Git: &stubGitClient{}, + GH: &stubGHClient{}, + } + + tests := []struct { + name string + args []string + runErr error + createErr error + pushErr error + createOutput string + wantErr bool + wantErrSubstr string + wantOutput string + wantRunArgs []string + }{ + { + name: "runs gh pr and prints command output", + args: []string{"gh", "pr"}, + createOutput: "https://github.com/example/repo/pull/123", + wantOutput: "https://github.com/example/repo/pull/123\n", + wantRunArgs: []string{}, + }, + { + name: "fails when gh subcommand is missing", + args: []string{"gh"}, + wantErr: true, + wantErrSubstr: "usage: llm gh ", + }, + { + name: "fails when gh subcommand is unknown", + args: []string{"gh", "issue"}, + wantErr: true, + wantErrSubstr: "unknown gh subcommand", + }, + { + name: "fails when pr receives trailing args", + args: []string{"gh", "pr", "--extra"}, + wantErr: true, + wantErrSubstr: "usage: llm gh pr", + }, + { + name: "returns generation error", + args: []string{"gh", "pr"}, + runErr: errors.New("generation failed"), + wantErr: true, + wantErrSubstr: "generation failed", + }, + { + name: "returns push error", + args: []string{"gh", "pr"}, + pushErr: errors.New("push failed"), + wantErr: true, + wantErrSubstr: "push failed", + }, + { + name: "returns create error", + args: []string{"gh", "pr"}, + createErr: errors.New("gh pr create failed"), + wantErr: true, + wantErrSubstr: "gh pr create failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotRunArgs []string + var gotCreatedPR *gh.PullRequest + var stdout bytes.Buffer + deps.Stdout = &stdout + + prcmd.RunFunc = func(ctx context.Context, provider providers.Provider, client gh.Client, stderr io.Writer, args []string) (*gh.PullRequest, error) { + gotRunArgs = args + if tt.runErr != nil { + return nil, tt.runErr + } + return &gh.PullRequest{Title: "Add gh pr command", Body: "## Summary\n- Add command wiring"}, nil + } + + prcmd.CreatePullRequestFunc = func(pr *gh.PullRequest) (string, error) { + gotCreatedPR = pr + if tt.createErr != nil { + return "", tt.createErr + } + return tt.createOutput, nil + } + + prcmd.EnsureBranchPushedFunc = func(client gh.Client) error { + return tt.pushErr + } + + err := defaultRegistry.Run(context.Background(), deps, tt.args) + + if (err != nil) != tt.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + if tt.wantErrSubstr != "" && (err == nil || !strings.Contains(err.Error(), tt.wantErrSubstr)) { + t.Errorf("Run() error = %v, want substring %q", err, tt.wantErrSubstr) + } + return + } + + if !reflect.DeepEqual(gotRunArgs, tt.wantRunArgs) { + t.Errorf("Run() args passed to gh.Run = %v, want %v", gotRunArgs, tt.wantRunArgs) + } + + wantCreatedPR := &gh.PullRequest{Title: "Add gh pr command", Body: "## Summary\n- Add command wiring"} + if !reflect.DeepEqual(gotCreatedPR, wantCreatedPR) { + t.Errorf("Run() PR passed to CreatePullRequest = %#v, want %#v", gotCreatedPR, wantCreatedPR) + } + + if stdout.String() != tt.wantOutput { + t.Errorf("Run() stdout = %q, want %q", stdout.String(), tt.wantOutput) + } + }) + } +} + +func TestRunReturnsUnknownCommandForUnknownTopLevel(t *testing.T) { + err := defaultRegistry.Run(context.Background(), Dependencies{}, []string{"unknown"}) + if !errors.Is(err, ErrUnknownCommand) { + t.Fatalf("Run() error = %v, want ErrUnknownCommand", err) + } +} diff --git a/internal/cmd/commit/command.go b/internal/cmd/commit/command.go new file mode 100644 index 0000000..447c1ac --- /dev/null +++ b/internal/cmd/commit/command.go @@ -0,0 +1,22 @@ +package commitcmd + +import ( + "context" + "io" + + "llm/internal/commit" + "llm/internal/git" + "llm/internal/providers" +) + +const ( + Name = "commit" + Usage = "commit [-a|--amend]" + Description = "Draft a commit message" +) + +var RunFunc = commit.Run + +func Run(ctx context.Context, provider providers.Provider, gitClient git.Client, stderr io.Writer, args []string) error { + return RunFunc(ctx, provider, gitClient, stderr, args) +} diff --git a/internal/cmd/gh/command.go b/internal/cmd/gh/command.go new file mode 100644 index 0000000..cb5917f --- /dev/null +++ b/internal/cmd/gh/command.go @@ -0,0 +1,7 @@ +package ghcmd + +const ( + Name = "gh" + Usage = "gh " + Description = "GitHub-related commands" +) diff --git a/internal/cmd/gh/pr/command.go b/internal/cmd/gh/pr/command.go new file mode 100644 index 0000000..6db8617 --- /dev/null +++ b/internal/cmd/gh/pr/command.go @@ -0,0 +1,45 @@ +package prcmd + +import ( + "context" + "fmt" + "io" + + "llm/internal/gh" + "llm/internal/providers" +) + +const ( + Name = "pr" + Usage = "pr" + Description = "Create a GitHub pull request" +) + +var ( + RunFunc = gh.Run + CreatePullRequestFunc = gh.CreatePullRequest + EnsureBranchPushedFunc = gh.EnsureBranchPushed +) + +func Run(ctx context.Context, provider providers.Provider, client gh.Client, stdout, stderr io.Writer, args []string) error { + if len(args) > 0 { + return fmt.Errorf("usage: llm gh pr") + } + + pr, err := RunFunc(ctx, provider, client, stderr, args) + if err != nil { + return err + } + + if err := EnsureBranchPushedFunc(client); err != nil { + return err + } + + output, err := CreatePullRequestFunc(pr) + if err != nil { + return err + } + + _, err = fmt.Fprintln(stdout, output) + return err +} diff --git a/internal/gh/client.go b/internal/gh/client.go new file mode 100644 index 0000000..44a9ae7 --- /dev/null +++ b/internal/gh/client.go @@ -0,0 +1,49 @@ +package gh + +import ( + "bytes" + "fmt" + "os/exec" + "strings" +) + +var lookPathExec = exec.LookPath +var execCommand = exec.Command + +type Client interface { + GetCurrentBranch() (string, error) + GetDefaultBranch() (string, error) + GetMergeBase(base, head string) (string, error) + GetDiffRange(from, to string) (string, error) +} + +type RealClient struct{} + +func (r *RealClient) GetCurrentBranch() (string, error) { + return r.exec("branch", "--show-current") +} + +func (r *RealClient) GetDefaultBranch() (string, error) { + return r.exec("rev-parse", "--abbrev-ref", "--symbolic-full-name", "refs/remotes/origin/HEAD") +} + +func (r *RealClient) GetMergeBase(base, head string) (string, error) { + return r.exec("merge-base", base, head) +} + +func (r *RealClient) GetDiffRange(from, to string) (string, error) { + return r.exec("diff", fmt.Sprintf("%s..%s", from, to)) +} + +func (r *RealClient) exec(args ...string) (string, error) { + cmd := execCommand("git", args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("git %s: %s", strings.Join(args, " "), strings.TrimSpace(stderr.String())) + } + + return strings.TrimSpace(stdout.String()), nil +} diff --git a/internal/gh/gh.go b/internal/gh/gh.go new file mode 100644 index 0000000..91e9cbe --- /dev/null +++ b/internal/gh/gh.go @@ -0,0 +1,215 @@ +package gh + +import ( + "bytes" + "context" + _ "embed" + "fmt" + "io" + "regexp" + "slices" + "strings" + + "llm/internal/loading" + "llm/internal/providers" +) + +//go:embed prompt.md +var systemPrompt string + +var ( + lookPath = lookPathExec + runGHCommand = func(args ...string) (string, error) { + cmd := execCommand("gh", args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("gh %s: %s", strings.Join(args, " "), strings.TrimSpace(stderr.String())) + } + + return strings.TrimSpace(stdout.String()), nil + } + runGitCommand = func(args ...string) (string, error) { + cmd := execCommand("git", args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("git %s: %s", strings.Join(args, " "), strings.TrimSpace(stderr.String())) + } + + return strings.TrimSpace(stdout.String()), nil + } +) + +type PullRequest struct { + Title string + Body string +} + +func BuildPrompt(diff, branch, base string) string { + if diff == "" { + return "" + } + + return fmt.Sprintf("Branch: %s\nBase: %s\n\n%s", branch, base, diff) +} + +func GeneratePullRequest(ctx context.Context, provider providers.Provider, prompt string, stderr io.Writer) (*PullRequest, error) { + ind := loading.Start(stderr) + raw, err := provider.Complete(ctx, systemPrompt, prompt) + ind.Stop() + + if err != nil { + return nil, err + } + + pr, err := parsePullRequest(raw) + if err != nil { + return nil, fmt.Errorf("parsing generated pull request content: %w", err) + } + + return pr, nil +} + +func Run(ctx context.Context, provider providers.Provider, git Client, stderr io.Writer, args []string) (*PullRequest, error) { + if len(args) > 0 { + return nil, fmt.Errorf("usage: llm gh pr") + } + + if _, err := lookPath("gh"); err != nil { + return nil, fmt.Errorf("gh CLI is required. Install GitHub CLI from https://cli.github.com/") + } + + diff, branch, base, err := getBaseDiffContext(git) + if err != nil { + return nil, err + } + + prompt := BuildPrompt(diff, branch, base) + return GeneratePullRequest(ctx, provider, prompt, stderr) +} + +func CreatePullRequest(pr *PullRequest) (string, error) { + if pr == nil { + return "", fmt.Errorf("pull request content is required") + } + + if strings.TrimSpace(pr.Title) == "" { + return "", fmt.Errorf("pull request title is required") + } + + if strings.TrimSpace(pr.Body) == "" { + return "", fmt.Errorf("pull request body is required") + } + + output, err := runGHCommand("pr", "create", "--title", pr.Title, "--body", pr.Body) + if err != nil { + return "", err + } + + return output, nil +} + +func EnsureBranchPushed(git Client) error { + if _, err := runGitCommand("rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"); err == nil { + if _, err := runGitCommand("push"); err != nil { + return fmt.Errorf("pushing current branch: %w", err) + } + return nil + } + + branch, err := git.GetCurrentBranch() + if err != nil { + return fmt.Errorf("getting current branch for push: %w", err) + } + + if strings.TrimSpace(branch) == "" { + return fmt.Errorf("unable to determine current branch for push") + } + + if _, err := runGitCommand("push", "-u", "origin", branch); err != nil { + return fmt.Errorf("pushing current branch to origin: %w", err) + } + + return nil +} + +func getBaseDiffContext(git Client) (string, string, string, error) { + base, mergeBase, err := resolveBaseBranch(git) + if err != nil { + return "", "", "", err + } + + diff, err := git.GetDiffRange(mergeBase, "HEAD") + if err != nil { + return "", "", "", fmt.Errorf("getting diff context from merge-base: %w", err) + } + + if strings.TrimSpace(diff) == "" { + return "", "", "", fmt.Errorf("no branch changes found against base branch %q", base) + } + + branch, _ := git.GetCurrentBranch() + if branch == "" { + branch = "HEAD" + } + + return diff, branch, base, nil +} + +func resolveBaseBranch(git Client) (string, string, error) { + var candidates []string + + defaultBranch, _ := git.GetDefaultBranch() + if defaultBranch != "" { + candidates = append(candidates, defaultBranch) + } + + fallbacks := []string{"origin/main", "origin/master", "main", "master"} + for _, fallback := range fallbacks { + if slices.Contains(candidates, fallback) { + continue + } + candidates = append(candidates, fallback) + } + + for _, candidate := range candidates { + mergeBase, err := git.GetMergeBase(candidate, "HEAD") + if err != nil || mergeBase == "" { + continue + } + + return candidate, mergeBase, nil + } + + if len(candidates) == 0 { + return "", "", fmt.Errorf("unable to resolve pull request base branch") + } + + return "", "", fmt.Errorf("unable to resolve pull request base branch from candidates: %s", strings.Join(candidates, ", ")) +} + +func parsePullRequest(raw string) (*PullRequest, error) { + re := regexp.MustCompile(`(?s)\s*(.*?)\s*\s*\s*(.*?)\s*`) + m := re.FindStringSubmatch(raw) + if len(m) != 3 { + return nil, fmt.Errorf("expected and <body> tags in provider response") + } + + title := strings.TrimSpace(m[1]) + body := strings.TrimSpace(m[2]) + + if title == "" { + return nil, fmt.Errorf("generated pull request title is empty") + } + + if body == "" { + return nil, fmt.Errorf("generated pull request body is empty") + } + + return &PullRequest{Title: title, Body: body}, nil +} diff --git a/internal/gh/gh_test.go b/internal/gh/gh_test.go new file mode 100644 index 0000000..1a1f83c --- /dev/null +++ b/internal/gh/gh_test.go @@ -0,0 +1,484 @@ +package gh + +import ( + "bytes" + "context" + "errors" + "reflect" + "strings" + "testing" +) + +type stubProvider struct { + resp string + err error +} + +func (s *stubProvider) Complete(ctx context.Context, system, userMsg string) (string, error) { + return s.resp, s.err +} + +type stubGitClient struct { + branch string + defaultBranch string + mergeBase string + diff string + mergeBaseFn func(base, head string) (string, error) + diffFn func(from, to string) (string, error) + branchErr error + defaultBranchErr error + mergeErr error + diffErr error +} + +func (s *stubGitClient) GetCurrentBranch() (string, error) { + return s.branch, s.branchErr +} + +func (s *stubGitClient) GetDefaultBranch() (string, error) { + return s.defaultBranch, s.defaultBranchErr +} + +func (s *stubGitClient) GetMergeBase(base, head string) (string, error) { + if s.mergeBaseFn != nil { + return s.mergeBaseFn(base, head) + } + return s.mergeBase, s.mergeErr +} + +func (s *stubGitClient) GetDiffRange(from, to string) (string, error) { + if s.diffFn != nil { + return s.diffFn(from, to) + } + return s.diff, s.diffErr +} + +func TestBuildPrompt(t *testing.T) { + got := BuildPrompt("diff content", "feature/test", "origin/main") + want := "Branch: feature/test\nBase: origin/main\n\ndiff content" + + if got != want { + t.Errorf("BuildPrompt() = %q, want %q", got, want) + } +} + +func TestGeneratePullRequest(t *testing.T) { + tests := []struct { + name string + resp string + err error + wantErr bool + want *PullRequest + }{ + { + name: "successfully parses title and body", + resp: "<title>Add GH PR generation\n## Summary\n- Add PR generation workflow", + want: &PullRequest{ + Title: "Add GH PR generation", + Body: "## Summary\n- Add PR generation workflow", + }, + }, + { + name: "provider error", + err: errors.New("provider failed"), + wantErr: true, + }, + { + name: "invalid response format", + resp: "just text", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stderr bytes.Buffer + provider := &stubProvider{resp: tt.resp, err: tt.err} + + got, err := GeneratePullRequest(context.Background(), provider, "prompt", &stderr) + + if (err != nil) != tt.wantErr { + t.Errorf("GeneratePullRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + if got.Title != tt.want.Title { + t.Errorf("GeneratePullRequest() title = %q, want %q", got.Title, tt.want.Title) + } + + if got.Body != tt.want.Body { + t.Errorf("GeneratePullRequest() body = %q, want %q", got.Body, tt.want.Body) + } + }) + } +} + +func TestRun(t *testing.T) { + originalLookPath := lookPath + t.Cleanup(func() { + lookPath = originalLookPath + }) + + tests := []struct { + name string + args []string + git *stubGitClient + providerResp string + providerErr error + lookPathErr error + wantErr bool + wantErrSubstr string + want *PullRequest + }{ + { + name: "successful PR generation", + git: &stubGitClient{ + branch: "feature/test", + defaultBranch: "origin/main", + mergeBase: "abc123", + diff: "diff content", + }, + providerResp: "Add PR workflow\n## Summary\n- Add gh PR generation", + want: &PullRequest{ + Title: "Add PR workflow", + Body: "## Summary\n- Add gh PR generation", + }, + }, + { + name: "fails when gh CLI is missing", + lookPathErr: errors.New("not found"), + git: &stubGitClient{}, + wantErr: true, + wantErrSubstr: "gh CLI is required", + }, + { + name: "fails when base branch cannot be resolved", + git: &stubGitClient{ + defaultBranchErr: errors.New("no default branch"), + mergeErr: errors.New("no merge base"), + }, + wantErr: true, + wantErrSubstr: "unable to resolve pull request base branch", + }, + { + name: "uses fallback base branch candidates", + git: &stubGitClient{ + mergeBaseFn: func(base, head string) (string, error) { + if base != "origin/main" || head != "HEAD" { + return "", errors.New("unexpected merge-base range") + } + return "abc123", nil + }, + diff: "diff content", + }, + providerResp: "Add PR workflow\n## Summary\n- Add gh PR generation", + want: &PullRequest{ + Title: "Add PR workflow", + Body: "## Summary\n- Add gh PR generation", + }, + }, + { + name: "fails when merge-base cannot be computed", + git: &stubGitClient{ + defaultBranch: "origin/main", + mergeErr: errors.New("unrelated histories"), + }, + wantErr: true, + wantErrSubstr: "unable to resolve pull request base branch", + }, + { + name: "fails when diff context is empty", + git: &stubGitClient{ + defaultBranch: "origin/main", + mergeBase: "abc123", + diff: "", + }, + wantErr: true, + wantErrSubstr: "no branch changes found against base branch", + }, + { + name: "uses base branch merge-base flow for diff context", + git: func() *stubGitClient { + base := "origin/main" + mergeBase := "abc123" + return &stubGitClient{ + branch: "feature/test", + defaultBranch: base, + mergeBaseFn: func(base, head string) (string, error) { + if base != "origin/main" || head != "HEAD" { + return "", errors.New("unexpected merge-base range") + } + return mergeBase, nil + }, + diffFn: func(from, to string) (string, error) { + if from != mergeBase || to != "HEAD" { + return "", errors.New("unexpected diff range") + } + return "diff content", nil + }, + } + }(), + providerResp: "Add PR workflow\n## Summary\n- Add gh PR generation", + want: &PullRequest{ + Title: "Add PR workflow", + Body: "## Summary\n- Add gh PR generation", + }, + }, + { + name: "succeeds when tracked branch diff is empty but base diff exists", + git: func() *stubGitClient { + return &stubGitClient{ + defaultBranch: "origin/main", + mergeBaseFn: func(base, head string) (string, error) { + if base != "origin/main" || head != "HEAD" { + return "", errors.New("unexpected merge-base range") + } + return "abc123", nil + }, + diffFn: func(from, to string) (string, error) { + if from != "abc123" || to != "HEAD" { + return "", errors.New("unexpected diff range") + } + return "diff content against base", nil + }, + } + }(), + providerResp: "Add PR workflow\n## Summary\n- Add gh PR generation", + want: &PullRequest{ + Title: "Add PR workflow", + Body: "## Summary\n- Add gh PR generation", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lookPath = func(file string) (string, error) { + if tt.lookPathErr != nil { + return "", tt.lookPathErr + } + return "/stub/bin/gh", nil + } + + var stderr bytes.Buffer + provider := &stubProvider{resp: tt.providerResp, err: tt.providerErr} + got, err := Run(context.Background(), provider, tt.git, &stderr, tt.args) + + if (err != nil) != tt.wantErr { + t.Errorf("Run() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + if tt.wantErrSubstr != "" && (err == nil || !strings.Contains(err.Error(), tt.wantErrSubstr)) { + t.Errorf("Run() error = %v, want substring %q", err, tt.wantErrSubstr) + } + return + } + + if got.Title != tt.want.Title { + t.Errorf("Run() title = %q, want %q", got.Title, tt.want.Title) + } + + if got.Body != tt.want.Body { + t.Errorf("Run() body = %q, want %q", got.Body, tt.want.Body) + } + }) + } +} + +func TestCreatePullRequest(t *testing.T) { + originalRunGHCommand := runGHCommand + t.Cleanup(func() { + runGHCommand = originalRunGHCommand + }) + + tests := []struct { + name string + pr *PullRequest + runErr error + runOutput string + wantErr bool + wantErrSubstr string + wantArgs []string + wantOutput string + }{ + { + name: "creates PR with generated title and body", + pr: &PullRequest{ + Title: "Add gh pr command", + Body: "## Summary\n- Add routing and PR creation", + }, + runOutput: "https://github.com/example/repo/pull/123", + wantArgs: []string{"pr", "create", "--title", "Add gh pr command", "--body", "## Summary\n- Add routing and PR creation"}, + wantOutput: "https://github.com/example/repo/pull/123", + }, + { + name: "fails when PR content is nil", + pr: nil, + wantErr: true, + wantErrSubstr: "pull request content is required", + }, + { + name: "fails when title is empty", + pr: &PullRequest{ + Body: "body", + }, + wantErr: true, + wantErrSubstr: "pull request title is required", + }, + { + name: "fails when body is empty", + pr: &PullRequest{ + Title: "title", + }, + wantErr: true, + wantErrSubstr: "pull request body is required", + }, + { + name: "returns gh command error", + pr: &PullRequest{ + Title: "title", + Body: "body", + }, + runErr: errors.New("gh pr create failed"), + wantErr: true, + wantErrSubstr: "gh pr create failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotArgs []string + runGHCommand = func(args ...string) (string, error) { + gotArgs = args + if tt.runErr != nil { + return "", tt.runErr + } + return tt.runOutput, nil + } + + output, err := CreatePullRequest(tt.pr) + + if (err != nil) != tt.wantErr { + t.Fatalf("CreatePullRequest() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + if tt.wantErrSubstr != "" && (err == nil || !strings.Contains(err.Error(), tt.wantErrSubstr)) { + t.Errorf("CreatePullRequest() error = %v, want substring %q", err, tt.wantErrSubstr) + } + return + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("CreatePullRequest() args = %v, want %v", gotArgs, tt.wantArgs) + } + + if output != tt.wantOutput { + t.Errorf("CreatePullRequest() output = %q, want %q", output, tt.wantOutput) + } + }) + } +} + +func TestEnsureBranchPushed(t *testing.T) { + originalRunGitCommand := runGitCommand + t.Cleanup(func() { + runGitCommand = originalRunGitCommand + }) + + tests := []struct { + name string + git *stubGitClient + runGit func(args ...string) (string, error) + wantErr bool + wantErrSubstr string + wantCalls [][]string + }{ + { + name: "pushes tracked branch with git push", + git: &stubGitClient{branch: "feature/test"}, + runGit: func(args ...string) (string, error) { + return "", nil + }, + wantCalls: [][]string{ + {"rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"}, + {"push"}, + }, + }, + { + name: "pushes untracked branch to origin with upstream", + git: &stubGitClient{branch: "feature/test"}, + runGit: func(args ...string) (string, error) { + if reflect.DeepEqual(args, []string{"rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"}) { + return "", errors.New("no upstream") + } + return "", nil + }, + wantCalls: [][]string{ + {"rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"}, + {"push", "-u", "origin", "feature/test"}, + }, + }, + { + name: "fails when push errors", + git: &stubGitClient{branch: "feature/test"}, + runGit: func(args ...string) (string, error) { + if reflect.DeepEqual(args, []string{"push"}) { + return "", errors.New("push failed") + } + return "", nil + }, + wantErr: true, + wantErrSubstr: "pushing current branch", + wantCalls: [][]string{ + {"rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"}, + {"push"}, + }, + }, + { + name: "fails when branch is unknown for upstream setup", + git: &stubGitClient{branch: ""}, + runGit: func(args ...string) (string, error) { + if reflect.DeepEqual(args, []string{"rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"}) { + return "", errors.New("no upstream") + } + return "", nil + }, + wantErr: true, + wantErrSubstr: "unable to determine current branch", + wantCalls: [][]string{ + {"rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotCalls [][]string + runGitCommand = func(args ...string) (string, error) { + gotCalls = append(gotCalls, append([]string(nil), args...)) + return tt.runGit(args...) + } + + err := EnsureBranchPushed(tt.git) + + if (err != nil) != tt.wantErr { + t.Fatalf("EnsureBranchPushed() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr && tt.wantErrSubstr != "" && (err == nil || !strings.Contains(err.Error(), tt.wantErrSubstr)) { + t.Fatalf("EnsureBranchPushed() error = %v, want substring %q", err, tt.wantErrSubstr) + } + + if !reflect.DeepEqual(gotCalls, tt.wantCalls) { + t.Fatalf("EnsureBranchPushed() calls = %v, want %v", gotCalls, tt.wantCalls) + } + }) + } +} diff --git a/internal/gh/prompt.md b/internal/gh/prompt.md new file mode 100644 index 0000000..32bfb0b --- /dev/null +++ b/internal/gh/prompt.md @@ -0,0 +1,10 @@ +You generate GitHub pull request metadata from a git diff. Respond with ONLY the following XML-like format and no other text: + +concise PR title +markdown PR description + +Rules: +- Keep the title concise and action-oriented. +- Body should summarize intent and key changes in markdown. +- Never include code fences around the response. +- Never include explanations outside the required tags. diff --git a/main.go b/main.go index 2a7ad8a..060ca4f 100644 --- a/main.go +++ b/main.go @@ -2,15 +2,15 @@ package main import ( "context" + "errors" "flag" "fmt" + "io" "os" "os/signal" "syscall" - "llm/internal/ask" - "llm/internal/commit" - "llm/internal/git" + "llm/internal/cmd" "llm/internal/providers" ) @@ -54,12 +54,8 @@ func main() { os.Exit(1) } - switch args[0] { - case "ask": - err = ask.Run(ctx, provider, os.Stdout, os.Stderr, args[1:]) - case "commit": - err = commit.Run(ctx, provider, &git.RealClient{}, os.Stderr, args[1:]) - default: + err = cmd.Run(ctx, provider, os.Stdout, os.Stderr, args) + if errors.Is(err, cmd.ErrUnknownCommand) { usage() os.Exit(1) } @@ -71,10 +67,14 @@ func main() { } func usage() { - fmt.Fprintf(os.Stderr, "Usage: llm [options]\n\n") - fmt.Fprintf(os.Stderr, "Commands:\n") - fmt.Fprintf(os.Stderr, " ask Ask a question\n") - fmt.Fprintf(os.Stderr, " commit [-a|--amend] Draft a commit message\n") - fmt.Fprintf(os.Stderr, "\nOptions:\n") + usageTo(os.Stderr) +} + +func usageTo(w io.Writer) { + cmd.UsageTo(w) + fmt.Fprintf(w, "\nOptions:\n") + oldOutput := flag.CommandLine.Output() + flag.CommandLine.SetOutput(w) + defer flag.CommandLine.SetOutput(oldOutput) flag.PrintDefaults() } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..05384c7 --- /dev/null +++ b/main_test.go @@ -0,0 +1,20 @@ +package main + +import ( + "bytes" + "strings" + "testing" +) + +func TestUsageIncludesGHPR(t *testing.T) { + var out bytes.Buffer + usageTo(&out) + + if !strings.Contains(out.String(), "gh ") { + t.Fatalf("usage output = %q, want to include %q", out.String(), "gh ") + } + + if !strings.Contains(out.String(), "pr Create a GitHub pull request") { + t.Fatalf("usage output = %q, want to include gh pr subcommand help", out.String()) + } +}