diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml new file mode 100644 index 0000000..a31d5b5 --- /dev/null +++ b/.github/workflows/pull-request.yml @@ -0,0 +1,43 @@ +name: Pull Request + +on: + pull_request: + branches: [main] + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: DeterminateSystems/nix-installer-action@main + - run: nix run .#lint + + test: + name: Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: DeterminateSystems/nix-installer-action@main + - run: nix run .#test + + build: + name: Build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: DeterminateSystems/nix-installer-action@main + - run: nix run .#build + - uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + + merge-ready: + runs-on: ubuntu-latest + if: ${{ always() }} + needs: [lint, test, build] + steps: + - if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') }} + run: exit 1 + - run: echo "LGTM" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..4ea36f6 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,22 @@ +name: Release + +on: + push: + branches: [main] + workflow_dispatch: + +permissions: + contents: write + +jobs: + release: + name: Release + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: DeterminateSystems/nix-installer-action@main + - run: nix run .#release + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5a07d1d --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +# Binaries +dist/ +*.exe + +# Test/coverage +*.test +coverage.out + +# IDE +.idea/ +.vscode/ + +# OS +.DS_Store + +# Direnv +!.envrc diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 0000000..6a82967 --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,38 @@ +version: 2 + +before: + hooks: + - go mod tidy + +builds: + - env: + - CGO_ENABLED=0 + ldflags: + - -s -w -X main.version={{ .Version }} + goos: + - linux + - windows + - darwin + goarch: + - amd64 + - arm64 + +archives: + - formats: + - tar.gz + name_template: >- + {{ .ProjectName }}_{{ title .Os }}_ + {{- if eq .Arch "amd64" }}x86_64{{ else }}{{ .Arch }}{{ end }} + format_overrides: + - goos: windows + formats: + - zip + +changelog: + sort: asc + groups: + - title: Features + regexp: '^.*?feat(\([[:word:]]+\))??!?:.+$' + - title: Bug Fixes + regexp: '^.*?fix(\([[:word:]]+\))??!?:.+$' + - title: Other changes diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..138643e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Christian + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..2ab8b37 --- /dev/null +++ b/README.md @@ -0,0 +1,56 @@ +# llm-cli + +A lightweight CLI tool for interacting with LLMs from the terminal. + +## Features + +- **Generate commit messages** — AI-powered conventional commit messages from your staged changes +- **Ask questions** — Get answers from an LLM directly in your terminal + +## Installation + +### Nix + +```bash +nix profile install github:csvenke/llm-cli +``` + +### Binary releases + +Download pre-built binaries from [GitHub Releases](https://github.com/csvenke/llm-cli/releases). + +## Usage + +### Ask a question + +```bash +llm ask "How do I reverse a string in Go?" +``` + +### Generate a commit message + +Stage your changes and run: + +```bash +llm commit +``` + +To amend the previous commit: + +```bash +llm commit -a +``` + +## Configuration + +Set one of the following environment variables (checked in order): + +| Variable | Description | +|---|---| +| `OPENCODE_ZEN_API_KEY` | OpenCode Zen API key | +| `ANTHROPIC_API_KEY` | Anthropic API key | +| `OPENAI_API_KEY` | OpenAI API key | + +## License + +[MIT](LICENSE) diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..a24d615 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-parts": { + "inputs": { + "nixpkgs-lib": "nixpkgs-lib" + }, + "locked": { + "lastModified": 1769996383, + "narHash": "sha256-AnYjnFWgS49RlqX7LrC4uA+sCCDBj0Ry/WOJ5XWAsa0=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "57928607ea566b5db3ad13af0e57e921e6b12381", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1771008912, + "narHash": "sha256-gf2AmWVTs8lEq7z/3ZAsgnZDhWIckkb+ZnAo5RzSxJg=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "a82ccc39b39b621151d6732718e3e250109076fa", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-lib": { + "locked": { + "lastModified": 1769909678, + "narHash": "sha256-cBEymOf4/o3FD5AZnzC3J9hLbiZ+QDT/KDuyHXVJOpM=", + "owner": "nix-community", + "repo": "nixpkgs.lib", + "rev": "72716169fe93074c333e8d0173151350670b824c", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nixpkgs.lib", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-parts": "flake-parts", + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..e3dcfba --- /dev/null +++ b/flake.nix @@ -0,0 +1,45 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-parts.url = "github:hercules-ci/flake-parts"; + }; + + outputs = + inputs@{ + self, + flake-parts, + nixpkgs, + ... + }: + flake-parts.lib.mkFlake { inherit inputs; } { + systems = nixpkgs.lib.systems.flakeExposed; + perSystem = + { system, ... }: + let + pkgs = import nixpkgs { + inherit system; + overlays = [ + (final: prev: { go = prev.go_1_24; }) + ]; + }; + inherit (pkgs) lib callPackage; + version = self.shortRev or self.dirtyShortRev or "snapshot"; + scripts = lib.packagesFromDirectoryRecursive { + inherit callPackage; + directory = ./nix/scripts; + }; + llm-cli = callPackage ./nix/package.nix { + inherit version; + }; + shell = callPackage ./nix/shell.nix { }; + in + { + packages = scripts // { + default = llm-cli; + }; + devShells = { + default = shell; + }; + }; + }; +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c7c966f --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module llm + +go 1.24.4 diff --git a/internal/ask/ask.go b/internal/ask/ask.go new file mode 100644 index 0000000..d817cd6 --- /dev/null +++ b/internal/ask/ask.go @@ -0,0 +1,31 @@ +package ask + +import ( + "context" + "fmt" + "io" + "strings" + + "llm/internal/providers" +) + +// Run executes the ask command with the given arguments. +// Returns an error if no arguments are provided or if the provider fails. +func Run(ctx context.Context, provider providers.Provider, output io.Writer, args []string) error { + if output == nil { + output = io.Discard + } + + if len(args) == 0 { + return fmt.Errorf("usage: llm ask ") + } + + question := strings.Join(args, " ") + response, err := provider.Complete(ctx, "", question) + if err != nil { + return err + } + + _, err = fmt.Fprintln(output, response) + return err +} diff --git a/internal/ask/ask_test.go b/internal/ask/ask_test.go new file mode 100644 index 0000000..bc09fcc --- /dev/null +++ b/internal/ask/ask_test.go @@ -0,0 +1,95 @@ +package ask + +import ( + "bytes" + "context" + "errors" + "testing" +) + +type mockProvider struct { + resp string + err error +} + +func (m *mockProvider) Complete(ctx context.Context, system, userMsg string) (string, error) { + return m.resp, m.err +} + +func TestRun(t *testing.T) { + tests := []struct { + name string + args []string + resp string + err error + want string + wantErr bool + }{ + { + name: "successfully asks a question and outputs response", + args: []string{"what", "is", "Go?"}, + resp: "Go is a programming language.", + want: "Go is a programming language.\n", + wantErr: false, + }, + { + name: "returns error when no arguments provided", + args: []string{}, + wantErr: true, + }, + { + name: "returns error when provider fails", + args: []string{"hello"}, + err: errors.New("network error"), + wantErr: true, + }, + { + name: "handles single word question", + args: []string{"help"}, + resp: "How can I help you?", + want: "How can I help you?\n", + wantErr: false, + }, + { + name: "handles long question with many arguments", + args: []string{"explain", "the", "difference", "between", "interfaces", "and", "structs", "in", "Go"}, + resp: "Interfaces define behavior, structs define data.", + want: "Interfaces define behavior, structs define data.\n", + wantErr: false, + }, + { + name: "handles empty response from provider", + args: []string{"test"}, + resp: "", + want: "\n", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var output bytes.Buffer + provider := &mockProvider{resp: tt.resp, err: tt.err} + + err := Run(context.Background(), provider, &output, tt.args) + + if (err != nil) != tt.wantErr { + t.Errorf("Run() error = %v, wantErr %v", err, tt.wantErr) + return + } + + gotOutput := output.String() + if gotOutput != tt.want { + t.Errorf("output = %q, want %q", gotOutput, tt.want) + } + }) + } +} + +func TestRun_NilOutputDefaultsToDiscard(t *testing.T) { + provider := &mockProvider{resp: "test response"} + err := Run(context.Background(), provider, nil, []string{"test"}) + if err != nil { + t.Errorf("Run() with nil output error = %v, want nil", err) + } +} diff --git a/internal/commit/commit.go b/internal/commit/commit.go new file mode 100644 index 0000000..adb9b10 --- /dev/null +++ b/internal/commit/commit.go @@ -0,0 +1,128 @@ +package commit + +import ( + "context" + _ "embed" + "fmt" + "io" + "regexp" + "time" + + "llm/internal/git" + "llm/internal/providers" +) + +//go:embed prompt.md +var systemPrompt string + +type Config struct { + Amend bool +} + +func ParseConfig(args []string) (*Config, []string, error) { + cfg := &Config{} + var remaining []string + + for i := 0; i < len(args); i++ { + switch args[i] { + case "-a", "--amend": + cfg.Amend = true + default: + remaining = append(remaining, args[i]) + } + } + + return cfg, remaining, nil +} + +func BuildPrompt(diff, branch string) string { + if diff == "" { + return "" + } + + if issue := extractIssue(branch); issue != "" { + return fmt.Sprintf("Branch: %s (Issue: %s)\n\n%s", branch, issue, diff) + } + + return diff +} + +func GenerateCommitMessage(ctx context.Context, provider providers.Provider, prompt string, stderr io.Writer, tickerFunc func(time.Duration) *time.Ticker) (string, error) { + if stderr == nil { + stderr = io.Discard + } + if tickerFunc == nil { + tickerFunc = func(d time.Duration) *time.Ticker { return time.NewTicker(d) } + } + + done := make(chan struct{}) + go func() { + ticker := tickerFunc(500 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ticker.C: + _, _ = fmt.Fprint(stderr, ".") + } + } + }() + + msg, err := provider.Complete(ctx, systemPrompt, prompt) + close(done) + _, _ = fmt.Fprintln(stderr) + + if err != nil { + return "", err + } + + return msg, nil +} + +func Run(ctx context.Context, provider providers.Provider, git git.Client, stderr io.Writer, args []string) error { + cfg, _, err := ParseConfig(args) + if err != nil { + return err + } + + diff, err := getDiffForCommit(git, cfg.Amend) + if err != nil { + return err + } + + if diff == "" { + if cfg.Amend { + return fmt.Errorf("no changes found to amend") + } + return fmt.Errorf("no staged changes found. Stage your changes with 'git add' first") + } + + branch, _ := git.GetCurrentBranch() + prompt := BuildPrompt(diff, branch) + + msg, err := GenerateCommitMessage(ctx, provider, prompt, stderr, nil) + if err != nil { + return err + } + + return git.Commit(msg, cfg.Amend) +} + +func extractIssue(branch string) string { + re := regexp.MustCompile(`[A-Z]+-\d+`) + return re.FindString(branch) +} + +func getDiffForCommit(git git.Client, amend bool) (string, error) { + if !amend { + return git.GetStagedDiff() + } + + hasParent, _ := git.HasParentCommit() + if !hasParent { + return git.GetStagedDiff() + } + + return git.GetDiffFromRevision("HEAD~1") +} diff --git a/internal/commit/commit_test.go b/internal/commit/commit_test.go new file mode 100644 index 0000000..4998954 --- /dev/null +++ b/internal/commit/commit_test.go @@ -0,0 +1,249 @@ +package commit + +import ( + "bytes" + "context" + "errors" + "testing" + "time" +) + +type mockProvider struct { + resp string + err error +} + +func (m *mockProvider) Complete(ctx context.Context, system, userMsg string) (string, error) { + return m.resp, m.err +} + +type mockGitClient struct { + diff string + branch string + hasParent bool + commitErr error + diffErr error + branchErr error +} + +func (m *mockGitClient) GetStagedDiff() (string, error) { + return m.diff, m.diffErr +} + +func (m *mockGitClient) GetDiffFromRevision(revision string) (string, error) { + return m.diff, m.diffErr +} + +func (m *mockGitClient) GetCurrentBranch() (string, error) { + return m.branch, m.branchErr +} + +func (m *mockGitClient) HasParentCommit() (bool, error) { + return m.hasParent, nil +} + +func (m *mockGitClient) Commit(msg string, amend bool) error { + return m.commitErr +} + +func TestParseConfig(t *testing.T) { + tests := []struct { + name string + args []string + wantAmend bool + }{ + { + name: "no flags", + args: []string{}, + wantAmend: false, + }, + { + name: "with -a flag", + args: []string{"-a"}, + wantAmend: true, + }, + { + name: "with --amend flag", + args: []string{"--amend"}, + wantAmend: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, _, err := ParseConfig(tt.args) + if err != nil { + t.Errorf("ParseConfig() unexpected error: %v", err) + return + } + if cfg.Amend != tt.wantAmend { + t.Errorf("ParseConfig() cfg.Amend = %v, want %v", cfg.Amend, tt.wantAmend) + } + }) + } +} + +func TestBuildPrompt(t *testing.T) { + tests := []struct { + name string + diff string + branch string + want string + }{ + { + name: "diff with JIRA issue in branch", + diff: "some diff content", + branch: "feature/PROJ-123-add-feature", + want: "Branch: feature/PROJ-123-add-feature (Issue: PROJ-123)\n\nsome diff content", + }, + { + name: "diff without issue in branch", + diff: "some diff content", + branch: "feature/add-feature", + want: "some diff content", + }, + { + name: "empty diff returns empty", + diff: "", + branch: "main", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := BuildPrompt(tt.diff, tt.branch) + if got != tt.want { + t.Errorf("BuildPrompt() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGenerateCommitMessage(t *testing.T) { + tests := []struct { + name string + prompt string + resp string + err error + want string + wantErr bool + }{ + { + name: "successful message generation", + prompt: "test prompt", + resp: "feat: add new feature", + want: "feat: add new feature", + wantErr: false, + }, + { + name: "provider returns error", + prompt: "test prompt", + err: errors.New("API error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stderr bytes.Buffer + provider := &mockProvider{resp: tt.resp, err: tt.err} + tickerFactory := func(d time.Duration) *time.Ticker { + return time.NewTicker(d) + } + + msg, err := GenerateCommitMessage(context.Background(), provider, tt.prompt, &stderr, tickerFactory) + + if (err != nil) != tt.wantErr { + t.Errorf("GenerateCommitMessage() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && msg != tt.want { + t.Errorf("GenerateCommitMessage() = %q, want %q", msg, tt.want) + } + }) + } +} + +func TestRun(t *testing.T) { + tests := []struct { + name string + args []string + git *mockGitClient + resp string + provErr error + wantErr bool + }{ + { + name: "successful commit flow", + args: []string{}, + git: &mockGitClient{ + diff: "some changes", + branch: "feature/test", + }, + resp: "feat: add feature", + wantErr: false, + }, + { + name: "successful commit with issue in branch", + args: []string{}, + git: &mockGitClient{ + diff: "some changes", + branch: "feature/PROJ-123-fix", + }, + resp: "fix: resolve bug", + wantErr: false, + }, + { + name: "no staged changes", + args: []string{}, + git: &mockGitClient{ + diff: "", + }, + wantErr: true, + }, + { + name: "no changes to amend", + args: []string{"-a"}, + git: &mockGitClient{ + diff: "", + hasParent: true, // HEAD~1 exists + }, + wantErr: true, + }, + { + name: "provider fails", + args: []string{}, + git: &mockGitClient{ + diff: "some changes", + branch: "main", + }, + provErr: errors.New("API error"), + wantErr: true, + }, + { + name: "git commit fails", + args: []string{}, + git: &mockGitClient{ + diff: "some changes", + branch: "main", + commitErr: errors.New("commit failed"), + }, + resp: "test commit", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stderr bytes.Buffer + provider := &mockProvider{resp: tt.resp, err: tt.provErr} + err := Run(context.Background(), provider, tt.git, &stderr, tt.args) + + if (err != nil) != tt.wantErr { + t.Errorf("Run() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/commit/prompt.md b/internal/commit/prompt.md new file mode 100644 index 0000000..ba57604 --- /dev/null +++ b/internal/commit/prompt.md @@ -0,0 +1,65 @@ +You generate git commit messages following the conventional commits format for semantic versioning. + +## Commit Type Selection + +The commit type determines version bumping: +- feat: MINOR version bump (new user-facing functionality) +- fix: PATCH version bump (bug fixes to existing functionality) +- BREAKING CHANGE footer: MAJOR version bump (incompatible API changes) + +Non-versioning types (NO version bump): +- chore: maintenance, dependency updates, config changes, tooling +- refactor: code restructuring without behavior change +- docs: documentation-only changes +- ci: CI/CD pipeline changes +- build: build system/dependency changes +- test: test-only changes +- perf: performance improvements +- style: formatting, whitespace, semicolons + +When in doubt, prefer non-versioning types over feat/fix. + +## Format + +```gitcommit +type(scope): concise description + +optional body with details + +optional footer (Closes: #issue, BREAKING CHANGE, etc.) +``` + +## Examples + +```gitcommit +fix: prevent null pointer exception in user validation +``` + +```gitcommit +feat(api): add pagination to search results endpoint +``` + +```gitcommit +refactor: extract database connection logic into separate module + +* move connection pooling to db/pool.py +* update imports in affected services +``` + +```gitcommit +chore: upgrade pytest from 7.1.0 to 7.4.2 +``` + +```gitcommit +feat!: change user ID format from integer to UUID + +BREAKING CHANGE: user IDs are now UUIDs instead of integers +``` + +Never include back-ticks in final commit + +## Rules +- Match the style and tone of the repository's recent commits +- Infer the scope from the changed file paths if appropriate +- Reference issues in the footer if an issue number is available +- Return ONLY the commit message text, no markdown formatting or explanation diff --git a/internal/git/git.go b/internal/git/git.go new file mode 100644 index 0000000..ad4098e --- /dev/null +++ b/internal/git/git.go @@ -0,0 +1,77 @@ +package git + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "strings" +) + +type Client interface { + GetStagedDiff() (string, error) + GetDiffFromRevision(revision string) (string, error) + GetCurrentBranch() (string, error) + HasParentCommit() (bool, error) + Commit(msg string, amend bool) error +} + +type RealClient struct{} + +func (r *RealClient) GetStagedDiff() (string, error) { + diff, err := r.exec("diff", "--staged") + if err != nil { + return "", fmt.Errorf("getting staged diff: %w", err) + } + return diff, nil +} + +func (r *RealClient) GetDiffFromRevision(revision string) (string, error) { + diff, err := r.exec("diff", "--staged", revision) + if err != nil { + return "", fmt.Errorf("getting diff from %s: %w", revision, err) + } + return diff, nil +} + +func (r *RealClient) GetCurrentBranch() (string, error) { + return r.exec("branch", "--show-current") +} + +func (r *RealClient) HasParentCommit() (bool, error) { + _, err := r.exec("rev-parse", "--verify", "HEAD~1") + if err != nil { + return false, nil + } + return true, nil +} + +func (r *RealClient) Commit(msg string, amend bool) error { + commitArgs := []string{"commit", "-m", msg, "-e"} + if amend { + commitArgs = []string{"commit", "--amend", "-m", msg, "-e"} + } + + return r.execInteractive(commitArgs...) +} + +func (r *RealClient) exec(args ...string) (string, error) { + cmd := exec.Command("git", args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("git %s: %s", strings.Join(args, " "), strings.TrimSpace(stderr.String())) + } + + return strings.TrimSpace(stdout.String()), nil +} + +func (r *RealClient) execInteractive(args ...string) error { + cmd := exec.Command("git", args...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} diff --git a/internal/providers/anthropic.go b/internal/providers/anthropic.go new file mode 100644 index 0000000..f02cace --- /dev/null +++ b/internal/providers/anthropic.go @@ -0,0 +1,93 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// Anthropic Messages API types +type anthropicRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System string `json:"system,omitempty"` + Messages []Message `json:"messages"` +} + +type anthropicResponse struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + Error *struct { + Message string `json:"message"` + } `json:"error,omitempty"` +} + +// AnthropicProvider implements Provider for Anthropic Messages API. +type AnthropicProvider struct { + endpoint string + model string + apiKey string +} + +// NewAnthropicProvider creates a new Anthropic provider with the given configuration. +func NewAnthropicProvider(endpoint, model, apiKey string) Provider { + return &AnthropicProvider{ + endpoint: endpoint, + model: model, + apiKey: apiKey, + } +} + +func (a *AnthropicProvider) Complete(ctx context.Context, system, userMsg string) (string, error) { + jsonData, err := json.Marshal(anthropicRequest{ + Model: a.model, + MaxTokens: 4096, + System: system, + Messages: []Message{{Role: "user", Content: userMsg}}, + }) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", a.endpoint, bytes.NewReader(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", a.apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + var r anthropicResponse + if err := json.Unmarshal(body, &r); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("%d: %s", resp.StatusCode, string(body)) + } + if r.Error != nil { + return "", fmt.Errorf("%s", r.Error.Message) + } + + if len(r.Content) > 0 { + return r.Content[0].Text, nil + } + return "", nil +} diff --git a/internal/providers/openai.go b/internal/providers/openai.go new file mode 100644 index 0000000..d1ee54b --- /dev/null +++ b/internal/providers/openai.go @@ -0,0 +1,96 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// OpenAI Chat Completions API types +type openaiRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` +} + +type openaiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + Error *struct { + Message string `json:"message"` + } `json:"error,omitempty"` +} + +// OpenAIProvider implements Provider for OpenAI Chat Completions API. +type OpenAIProvider struct { + endpoint string + model string + apiKey string +} + +// NewOpenAIProvider creates a new OpenAI provider with the given configuration. +func NewOpenAIProvider(endpoint, model, apiKey string) Provider { + return &OpenAIProvider{ + endpoint: endpoint, + model: model, + apiKey: apiKey, + } +} + +func (o *OpenAIProvider) Complete(ctx context.Context, system, userMsg string) (string, error) { + msgs := []Message{} + if system != "" { + msgs = append(msgs, Message{Role: "system", Content: system}) + } + msgs = append(msgs, Message{Role: "user", Content: userMsg}) + + jsonData, err := json.Marshal(openaiRequest{ + Model: o.model, + Messages: msgs, + }) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", o.endpoint, bytes.NewReader(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+o.apiKey) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + var r openaiResponse + if err := json.Unmarshal(body, &r); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("%d: %s", resp.StatusCode, string(body)) + } + if r.Error != nil { + return "", fmt.Errorf("%s", r.Error.Message) + } + + if len(r.Choices) > 0 { + return r.Choices[0].Message.Content, nil + } + return "", nil +} diff --git a/internal/providers/provider.go b/internal/providers/provider.go new file mode 100644 index 0000000..4acd42e --- /dev/null +++ b/internal/providers/provider.go @@ -0,0 +1,16 @@ +package providers + +import "context" + +// Provider defines the strategy interface for LLM chat completions. +// Each provider implementation encapsulates its own configuration +// (endpoint, model, API key) and handles the complete request lifecycle. +type Provider interface { + Complete(ctx context.Context, system, userMsg string) (string, error) +} + +// Message represents a chat message. +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} diff --git a/internal/providers/resolve.go b/internal/providers/resolve.go new file mode 100644 index 0000000..1c82c7a --- /dev/null +++ b/internal/providers/resolve.go @@ -0,0 +1,40 @@ +package providers + +import ( + "fmt" + "os" +) + +// ResolveByAPIKey checks environment variables in order of precedence and returns +// a fully configured provider ready to make API calls. +// Priority: OPENCODE_ZEN_API_KEY > ANTHROPIC_API_KEY > OPENAI_API_KEY +func ResolveByAPIKey() (Provider, error) { + // Priority 1: OpenCode Zen (uses Anthropic Messages format) + if apiKey := os.Getenv("OPENCODE_ZEN_API_KEY"); apiKey != "" { + return NewAnthropicProvider( + "https://opencode.ai/zen/v1/messages", + "claude-3-5-haiku", + apiKey, + ), nil + } + + // Priority 2: Anthropic Direct + if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { + return NewAnthropicProvider( + "https://api.anthropic.com/v1/messages", + "claude-3-5-haiku", + apiKey, + ), nil + } + + // Priority 3: OpenAI + if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { + return NewOpenAIProvider( + "https://api.openai.com/v1/chat/completions", + "gpt-4o-mini", + apiKey, + ), nil + } + + return nil, fmt.Errorf("no API key found. Set OPENCODE_ZEN_API_KEY, ANTHROPIC_API_KEY, or OPENAI_API_KEY") +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..316fa84 --- /dev/null +++ b/main.go @@ -0,0 +1,80 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "llm/internal/ask" + "llm/internal/commit" + "llm/internal/git" + "llm/internal/providers" +) + +var version string + +func main() { + var printVersion bool + flag.BoolVar(&printVersion, "v", false, "print version") + flag.BoolVar(&printVersion, "version", false, "print version") + flag.Usage = usage + flag.Parse() + + if printVersion { + if version == "" { + version = "snapshot" + } + fmt.Println(version) + os.Exit(0) + } + + args := flag.Args() + if len(args) == 0 { + usage() + os.Exit(1) + } + + // Create a context that can be cancelled via signals (Ctrl+C, SIGTERM) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + go func() { + <-sigChan + cancel() + }() + + provider, err := providers.ResolveByAPIKey() + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + switch args[0] { + case "ask": + err = ask.Run(ctx, provider, os.Stdout, args[1:]) + case "commit": + err = commit.Run(ctx, provider, &git.RealClient{}, os.Stderr, args[1:]) + default: + usage() + os.Exit(1) + } + + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +func usage() { + fmt.Fprintf(os.Stderr, "Usage: llm [options]\n\n") + fmt.Fprintf(os.Stderr, "Commands:\n") + fmt.Fprintf(os.Stderr, " ask Ask a question\n") + fmt.Fprintf(os.Stderr, " commit [-a|--amend] Draft a commit message\n") + fmt.Fprintf(os.Stderr, "\nOptions:\n") + flag.PrintDefaults() +} diff --git a/nix/package.nix b/nix/package.nix new file mode 100644 index 0000000..c8f5a6d --- /dev/null +++ b/nix/package.nix @@ -0,0 +1,16 @@ +{ version, buildGoModule }: + +buildGoModule { + pname = "llm"; + version = version; + src = ../.; + vendorHash = null; + ldflags = [ + "-s" + "-w" + "-X main.version=${version}" + ]; + meta = { + mainProgram = "llm"; + }; +} diff --git a/nix/scripts/build.nix b/nix/scripts/build.nix new file mode 100644 index 0000000..8c757d6 --- /dev/null +++ b/nix/scripts/build.nix @@ -0,0 +1,16 @@ +{ + writeShellApplication, + goreleaser, + go, +}: + +writeShellApplication { + name = "build"; + runtimeInputs = [ + goreleaser + go + ]; + text = '' + goreleaser release --snapshot --clean + ''; +} diff --git a/nix/scripts/clean.nix b/nix/scripts/clean.nix new file mode 100644 index 0000000..f95e2bb --- /dev/null +++ b/nix/scripts/clean.nix @@ -0,0 +1,9 @@ +{ writeShellApplication, coreutils }: + +writeShellApplication { + name = "clean"; + runtimeInputs = [ coreutils ]; + text = '' + rm -rf dist/ + ''; +} diff --git a/nix/scripts/lint.nix b/nix/scripts/lint.nix new file mode 100644 index 0000000..29482ac --- /dev/null +++ b/nix/scripts/lint.nix @@ -0,0 +1,9 @@ +{ writeShellApplication, golangci-lint }: + +writeShellApplication { + name = "lint"; + runtimeInputs = [ golangci-lint ]; + text = '' + golangci-lint run ./... + ''; +} diff --git a/nix/scripts/release.nix b/nix/scripts/release.nix new file mode 100644 index 0000000..812dcee --- /dev/null +++ b/nix/scripts/release.nix @@ -0,0 +1,22 @@ +{ + writeShellApplication, + goreleaser, + go, + svu, + git, +}: + +writeShellApplication { + name = "release"; + runtimeInputs = [ + goreleaser + go + svu + git + ]; + text = '' + VERSION=$(svu next) + git tag "$VERSION" + goreleaser release --clean + ''; +} diff --git a/nix/scripts/test.nix b/nix/scripts/test.nix new file mode 100644 index 0000000..6f9a115 --- /dev/null +++ b/nix/scripts/test.nix @@ -0,0 +1,9 @@ +{ writeShellApplication, go }: + +writeShellApplication { + name = "test"; + runtimeInputs = [ go ]; + text = '' + go test -race ./... + ''; +} diff --git a/nix/shell.nix b/nix/shell.nix new file mode 100644 index 0000000..dd01300 --- /dev/null +++ b/nix/shell.nix @@ -0,0 +1,19 @@ +{ + mkShell, + go, + gopls, + golangci-lint, + nixd, +}: + +mkShell { + packages = [ + go + gopls + golangci-lint + nixd + ]; + shellHook = '' + export GOFLAGS="-buildvcs=false" + ''; +}