diff --git a/cmd/aifr/cmd_hook.go b/cmd/aifr/cmd_hook.go new file mode 100644 index 0000000..b00a43c --- /dev/null +++ b/cmd/aifr/cmd_hook.go @@ -0,0 +1,35 @@ +// Copyright 2026 — see LICENSE file for terms. +package main + +import "github.com/spf13/cobra" + +var hookCmd = &cobra.Command{ + Use: "hook", + Short: "Hooks for AI coding agent integration", + Long: `Commands designed for use as hooks in AI coding agents such as Claude Code. + +These sub-commands read hook payloads from stdin and write hook responses +to stdout, following the agent's hook protocol. + +Example Claude Code configuration: + + { + "hooks": { + "PreToolUse": [ + { + "matcher": "Bash", + "hooks": [ + { + "type": "command", + "command": "aifr hook check-command" + } + ] + } + ] + } + }`, +} + +func init() { + rootCmd.AddCommand(hookCmd) +} diff --git a/cmd/aifr/cmd_hook_checkcommand.go b/cmd/aifr/cmd_hook_checkcommand.go new file mode 100644 index 0000000..d42c2a5 --- /dev/null +++ b/cmd/aifr/cmd_hook_checkcommand.go @@ -0,0 +1,77 @@ +// Copyright 2026 — see LICENSE file for terms. +package main + +import ( + "encoding/json" + "io" + "os" + + "github.com/spf13/cobra" + + "go.pennock.tech/aifr/internal/hookcmd" +) + +var checkCommandMCP bool + +var checkCommandCmd = &cobra.Command{ + Use: "check-command", + Short: "Suggest aifr alternatives for Bash tool calls", + Long: `Reads a Claude Code PreToolUse hook payload from stdin, analyzes the +shell command, and if aifr can handle it, outputs a hook response denying +the Bash call and suggesting the aifr alternative. + +If the command is not something aifr handles, exits silently (exit 0, +no output) so the Bash call continues through normal permission evaluation. + +Pipelines ending in | head -n N or | tail -n N are recognized and mapped +to the appropriate aifr limit parameter (--max-count, --limit, --lines, etc.). + +When --mcp is set, or when an aifr MCP server is detected in .mcp.json, +suggestions reference MCP tool calls instead of CLI sub-commands. + +Recognized commands: cat, head, tail, grep/rg, find, ls, wc, stat, +diff, sed -n, sha256sum/md5sum, hexdump/xxd, git log, git diff. + +Usage in Claude Code settings: + + { + "hooks": { + "PreToolUse": [ + { + "matcher": "Bash", + "hooks": [ + { + "type": "command", + "command": "aifr hook check-command" + } + ] + } + ] + } + }`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + input, err := io.ReadAll(os.Stdin) + if err != nil { + return err + } + + result, err := hookcmd.CheckCommand(input, checkCommandMCP) + if err != nil { + return err + } + if result == nil { + return nil + } + + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + return enc.Encode(result) + }, +} + +func init() { + checkCommandCmd.Flags().BoolVar(&checkCommandMCP, "mcp", false, + "suggest MCP tool calls (auto-detected from .mcp.json and $AIFR_MCP if not set)") + hookCmd.AddCommand(checkCommandCmd) +} diff --git a/go.mod b/go.mod index 6797d6a..74aafec 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/pelletier/go-toml/v2 v2.3.0 github.com/spf13/cobra v1.10.2 golang.org/x/crypto v0.49.0 + mvdan.cc/sh/v3 v3.13.1 ) require ( diff --git a/go.sum b/go.sum index 30b1655..9dfd4d7 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMj github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= github.com/go-git/go-git/v5 v5.17.2 h1:B+nkdlxdYrvyFK4GPXVU8w1U+YkbsgciIR7f2sZJ104= github.com/go-git/go-git/v5 v5.17.2/go.mod h1:pW/VmeqkanRFqR6AljLcs7EA7FbZaN5MQqO7oZADXpo= +github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= +github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= @@ -182,3 +184,5 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +mvdan.cc/sh/v3 v3.13.1 h1:DP3TfgZhDkT7lerUdnp6PTGKyxxzz6T+cOlY/xEvfWk= +mvdan.cc/sh/v3 v3.13.1/go.mod h1:lXJ8SexMvEVcHCoDvAGLZgFJ9Wsm2sulmoNEXGhYZD0= diff --git a/internal/hookcmd/hookcmd.go b/internal/hookcmd/hookcmd.go new file mode 100644 index 0000000..b185964 --- /dev/null +++ b/internal/hookcmd/hookcmd.go @@ -0,0 +1,90 @@ +// Copyright 2026 — see LICENSE file for terms. +package hookcmd + +import ( + "encoding/json" + "fmt" +) + +// HookInput is the JSON payload received from a Claude Code hook on stdin. +type HookInput struct { + SessionID string `json:"session_id"` + CWD string `json:"cwd"` + ToolName string `json:"tool_name"` + ToolInput json.RawMessage `json:"tool_input"` + HookEventName string `json:"hook_event_name"` +} + +// BashInput is the tool_input for a Bash tool call. +type BashInput struct { + Command string `json:"command"` +} + +// HookOutput is the JSON response for a Claude Code hook. +type HookOutput struct { + HookSpecificOutput *HookDecision `json:"hookSpecificOutput"` +} + +// HookDecision describes the hook's permission decision. +type HookDecision struct { + HookEventName string `json:"hookEventName"` + Decision string `json:"permissionDecision"` + Reason string `json:"permissionDecisionReason,omitempty"` +} + +// CheckCommand parses a PreToolUse hook payload and returns a hook output +// denying the command with an aifr suggestion, or nil if no suggestion applies. +// +// When forceMCP is true, suggestions always reference MCP tool calls. +// Otherwise, MCP availability is auto-detected from the working directory's +// .mcp.json and the AIFR_MCP environment variable. +func CheckCommand(input []byte, forceMCP bool) (*HookOutput, error) { + var hi HookInput + if err := json.Unmarshal(input, &hi); err != nil { + return nil, err + } + + if hi.ToolName != "Bash" { + return nil, nil + } + + var bi BashInput + if err := json.Unmarshal(hi.ToolInput, &bi); err != nil { + return nil, err + } + + suggestion := AnalyzeCommand(bi.Command) + if suggestion == nil { + return nil, nil + } + + mcpMode := forceMCP || detectMCPAvailable(hi.CWD) + + var reason string + if mcpMode { + reason = formatMCPReason(suggestion) + } else { + reason = formatCLIReason(suggestion) + } + + return &HookOutput{ + HookSpecificOutput: &HookDecision{ + HookEventName: "PreToolUse", + Decision: "deny", + Reason: reason, + }, + }, nil +} + +func formatCLIReason(s *Suggestion) string { + return "This " + s.Original + + " invocation can be handled by aifr with access controls. Use: " + + s.AifrCommand +} + +func formatMCPReason(s *Suggestion) string { + argsJSON, _ := json.Marshal(s.ToolArgs) + return fmt.Sprintf( + "This %s invocation can be handled by aifr with access controls. Use the %s tool: %s", + s.Original, s.ToolName, string(argsJSON)) +} diff --git a/internal/hookcmd/hookcmd_test.go b/internal/hookcmd/hookcmd_test.go new file mode 100644 index 0000000..03f4b8a --- /dev/null +++ b/internal/hookcmd/hookcmd_test.go @@ -0,0 +1,195 @@ +// Copyright 2026 — see LICENSE file for terms. +package hookcmd + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestCheckCommand_BashWithSuggestion(t *testing.T) { + input := `{ + "session_id": "test-session", + "cwd": "/tmp/nonexistent", + "tool_name": "Bash", + "tool_input": {"command": "cat main.go"}, + "hook_event_name": "PreToolUse" + }` + + result, err := CheckCommand([]byte(input), false) + if err != nil { + t.Fatal(err) + } + if result == nil { + t.Fatal("expected result, got nil") + } + if result.HookSpecificOutput == nil { + t.Fatal("expected HookSpecificOutput, got nil") + } + if result.HookSpecificOutput.Decision != "deny" { + t.Errorf("expected deny, got %q", result.HookSpecificOutput.Decision) + } + if result.HookSpecificOutput.HookEventName != "PreToolUse" { + t.Errorf("expected PreToolUse, got %q", result.HookSpecificOutput.HookEventName) + } +} + +func TestCheckCommand_BashNoSuggestion(t *testing.T) { + input := `{ + "session_id": "test-session", + "cwd": "/tmp/nonexistent", + "tool_name": "Bash", + "tool_input": {"command": "go test ./..."}, + "hook_event_name": "PreToolUse" + }` + + result, err := CheckCommand([]byte(input), false) + if err != nil { + t.Fatal(err) + } + if result != nil { + t.Errorf("expected nil, got result with decision %q", result.HookSpecificOutput.Decision) + } +} + +func TestCheckCommand_NonBashTool(t *testing.T) { + input := `{ + "session_id": "test-session", + "cwd": "/tmp/nonexistent", + "tool_name": "Read", + "tool_input": {"file_path": "/tmp/test.go"}, + "hook_event_name": "PreToolUse" + }` + + result, err := CheckCommand([]byte(input), false) + if err != nil { + t.Fatal(err) + } + if result != nil { + t.Error("expected nil for non-Bash tool") + } +} + +func TestCheckCommand_InvalidJSON(t *testing.T) { + _, err := CheckCommand([]byte("not json"), false) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +// TestCheckCommand_PipelineSuggestion is an end-to-end wiring test verifying +// that a command pipeline with a recognized | head tail produces a suggestion +// with the appropriate per-command limit parameter. The full scope of pipeline +// and complex command analysis is covered in suggest_test.go. +func TestCheckCommand_PipelineSuggestion(t *testing.T) { + input := `{ + "session_id": "test-session", + "cwd": "/tmp/nonexistent", + "tool_name": "Bash", + "tool_input": {"command": "git log --oneline | head -n 10"}, + "hook_event_name": "PreToolUse" + }` + + result, err := CheckCommand([]byte(input), false) + if err != nil { + t.Fatal(err) + } + if result == nil { + t.Fatal("expected suggestion for pipeline command, got nil") + } + if result.HookSpecificOutput.Decision != "deny" { + t.Errorf("expected deny, got %q", result.HookSpecificOutput.Decision) + } + reason := result.HookSpecificOutput.Reason + if !strings.Contains(reason, "aifr log") && !strings.Contains(reason, "aifr_log") { + t.Errorf("reason should mention aifr log, got %q", reason) + } + if !strings.Contains(reason, "max-count") && !strings.Contains(reason, "max_count") { + t.Errorf("reason should mention max-count/max_count, got %q", reason) + } +} + +func TestCheckCommand_OutputFormat(t *testing.T) { + input := `{ + "session_id": "s1", + "cwd": "/tmp/nonexistent", + "tool_name": "Bash", + "tool_input": {"command": "head -50 README.md"}, + "hook_event_name": "PreToolUse" + }` + + result, err := CheckCommand([]byte(input), false) + if err != nil { + t.Fatal(err) + } + if result == nil { + t.Fatal("expected result") + } + + // Verify it marshals to valid JSON with expected structure. + data, err := json.Marshal(result) + if err != nil { + t.Fatal(err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatal(err) + } + + hso, ok := decoded["hookSpecificOutput"].(map[string]any) + if !ok { + t.Fatal("missing hookSpecificOutput") + } + if hso["hookEventName"] != "PreToolUse" { + t.Errorf("hookEventName: %v", hso["hookEventName"]) + } + if hso["permissionDecision"] != "deny" { + t.Errorf("permissionDecision: %v", hso["permissionDecision"]) + } + reason, _ := hso["permissionDecisionReason"].(string) + if reason == "" { + t.Error("expected non-empty reason") + } +} + +func TestCheckCommand_MCPMode(t *testing.T) { + input := `{ + "session_id": "test-session", + "cwd": "/tmp/nonexistent", + "tool_name": "Bash", + "tool_input": {"command": "cat main.go"}, + "hook_event_name": "PreToolUse" + }` + + // CLI mode (forceMCP=false, no .mcp.json in /tmp/nonexistent) + cliResult, err := CheckCommand([]byte(input), false) + if err != nil { + t.Fatal(err) + } + if cliResult == nil { + t.Fatal("expected result") + } + cliReason := cliResult.HookSpecificOutput.Reason + if cliReason == "" { + t.Fatal("expected non-empty CLI reason") + } + + // MCP mode (forceMCP=true) + mcpResult, err := CheckCommand([]byte(input), true) + if err != nil { + t.Fatal(err) + } + if mcpResult == nil { + t.Fatal("expected result") + } + mcpReason := mcpResult.HookSpecificOutput.Reason + if mcpReason == "" { + t.Fatal("expected non-empty MCP reason") + } + + // CLI reason should reference the CLI command. + if cliReason == mcpReason { + t.Error("CLI and MCP reasons should differ") + } +} diff --git a/internal/hookcmd/mcpdetect.go b/internal/hookcmd/mcpdetect.go new file mode 100644 index 0000000..6462c06 --- /dev/null +++ b/internal/hookcmd/mcpdetect.go @@ -0,0 +1,52 @@ +// Copyright 2026 — see LICENSE file for terms. +package hookcmd + +import ( + "encoding/json" + "os" + "path/filepath" +) + +// detectMCPAvailable checks whether an aifr MCP server is likely available +// in the current Claude Code session. +// +// Detection order: +// 1. AIFR_MCP environment variable (any non-empty value → true) +// 2. .mcp.json in the given working directory +func detectMCPAvailable(cwd string) bool { + if os.Getenv("AIFR_MCP") != "" { + return true + } + if cwd != "" { + if checkMCPConfig(filepath.Join(cwd, ".mcp.json")) { + return true + } + } + return false +} + +// checkMCPConfig reads a .mcp.json file and returns true if it contains +// an aifr MCP server entry (matched by server name or command basename). +func checkMCPConfig(path string) bool { + data, err := os.ReadFile(path) + if err != nil { + return false + } + var config struct { + MCPServers map[string]struct { + Command string `json:"command"` + } `json:"mcpServers"` + } + if err := json.Unmarshal(data, &config); err != nil { + return false + } + for name, server := range config.MCPServers { + if name == "aifr" { + return true + } + if filepath.Base(server.Command) == "aifr" { + return true + } + } + return false +} diff --git a/internal/hookcmd/shellparse.go b/internal/hookcmd/shellparse.go new file mode 100644 index 0000000..8b8e065 --- /dev/null +++ b/internal/hookcmd/shellparse.go @@ -0,0 +1,271 @@ +// Copyright 2026 — see LICENSE file for terms. + +// Package hookcmd implements command analysis for AI coding agent hooks. +// It parses shell commands from hook payloads and suggests aifr alternatives +// when the command can be safely handled by aifr. +package hookcmd + +import ( + "path/filepath" + "strconv" + "strings" + + "mvdan.cc/sh/v3/syntax" +) + +// parsedCommand represents a shell command extracted from a parsed AST. +type parsedCommand struct { + Name string // base command name (e.g., "cat", "grep") + Args []string // argument values (unquoted where possible) +} + +// parseShellCommand parses a shell command string into a primary command and +// an optional pipeline modifier. Returns nil if the command is too complex +// to analyze (multiple statements, subshells, control operators, etc.). +// +// Two-stage pipelines where the second stage is head, tail, or sed -n are +// recognized and returned as a modifier on the first stage. +func parseShellCommand(command string) (*parsedCommand, PipelineModifier) { + parser := syntax.NewParser(syntax.Variant(syntax.LangBash)) + file, err := parser.Parse(strings.NewReader(command), "") + if err != nil { + return nil, PipelineModifier{} + } + + if len(file.Stmts) != 1 { + return nil, PipelineModifier{} + } + + stmt := file.Stmts[0] + if stmt.Background || stmt.Negated || stmt.Coprocess { + return nil, PipelineModifier{} + } + + switch cmd := stmt.Cmd.(type) { + case *syntax.CallExpr: + parsed := extractCall(cmd) + return parsed, PipelineModifier{} + + case *syntax.BinaryCmd: + if cmd.Op != syntax.Pipe { + return nil, PipelineModifier{} // &&, || + } + return parsePipelineCmd(cmd) + + default: + // if, for, while, case, subshell, function decl, etc. + return nil, PipelineModifier{} + } +} + +// parsePipelineCmd handles a two-stage pipeline (cmd | head/tail). +// Returns nil for 3+ stage pipelines or when the right side isn't head/tail. +func parsePipelineCmd(bc *syntax.BinaryCmd) (*parsedCommand, PipelineModifier) { + // Both sides must be simple commands (not nested pipelines or control structures). + leftCall, ok := bc.X.Cmd.(*syntax.CallExpr) + if !ok { + return nil, PipelineModifier{} + } + rightCall, ok := bc.Y.Cmd.(*syntax.CallExpr) + if !ok { + return nil, PipelineModifier{} + } + + right := extractCall(rightCall) + if right == nil { + return nil, PipelineModifier{} + } + + mod := pipeTailModifier(right) + if !mod.IsSet() { + return nil, PipelineModifier{} + } + + left := extractCall(leftCall) + return left, mod +} + +// pipeTailModifier checks if a parsed command is head, tail, or sed -n and +// extracts the line count or range as a PipelineModifier. +func pipeTailModifier(cmd *parsedCommand) PipelineModifier { + switch cmd.Name { + case "head": + return PipelineModifier{HeadLines: parseHeadTailN(cmd.Args, 10)} + case "tail": + if hasFlag(cmd.Args, "-f", "--follow", "-F") { + return PipelineModifier{} + } + return PipelineModifier{TailLines: parseHeadTailN(cmd.Args, 10)} + case "sed": + return parseSedModifier(cmd.Args) + default: + return PipelineModifier{} + } +} + +// parseSedModifier parses sed -n 'Np' or 'N,Mp' as a pipeline modifier. +// When the range starts at line 1, it normalizes to HeadLines for +// compatibility with commands that support head-style limits. +func parseSedModifier(args []string) PipelineModifier { + if !hasFlag(args, "-n") { + return PipelineModifier{} + } + + // Find the script argument (first non-flag after -n). + var script string + sawN := false + for _, a := range args { + if a == "-n" { + sawN = true + continue + } + if strings.HasPrefix(a, "-") { + continue + } + if sawN { + script = a + break + } + } + if script == "" { + return PipelineModifier{} + } + + script = strings.TrimSuffix(script, "p") + if script == "" { + return PipelineModifier{} + } + + parts := strings.SplitN(script, ",", 2) + if len(parts) == 1 { + n, err := strconv.Atoi(parts[0]) + if err != nil || n <= 0 { + return PipelineModifier{} + } + if n == 1 { + return PipelineModifier{HeadLines: 1} + } + return PipelineModifier{StartLine: n, EndLine: n} + } + + start, err := strconv.Atoi(parts[0]) + if err != nil || start <= 0 { + return PipelineModifier{} + } + end, err := strconv.Atoi(parts[1]) + if err != nil || end <= 0 { + return PipelineModifier{} + } + // Normalize: start=1 is equivalent to head. + if start == 1 { + return PipelineModifier{HeadLines: end} + } + return PipelineModifier{StartLine: start, EndLine: end} +} + +// extractCall extracts a parsedCommand from a CallExpr AST node. +// Variable assignments (LANG=C cmd) are automatically excluded since the +// parser places them in CallExpr.Assigns, not Args. +func extractCall(ce *syntax.CallExpr) *parsedCommand { + if len(ce.Args) == 0 { + return nil + } + + name := wordValue(ce.Args[0]) + if name == "" { + return nil + } + + args := make([]string, len(ce.Args)-1) + for i, w := range ce.Args[1:] { + args[i] = wordValue(w) + } + + return &parsedCommand{ + Name: filepath.Base(name), + Args: args, + } +} + +// wordValue extracts the effective string value from a shell Word, +// stripping quotes where possible. For words containing parameter expansions +// or command substitutions, falls back to the printed shell representation. +func wordValue(w *syntax.Word) string { + if s := w.Lit(); s != "" { + return s + } + + if len(w.Parts) == 1 { + switch p := w.Parts[0].(type) { + case *syntax.Lit: + return p.Value + case *syntax.SglQuoted: + return p.Value + case *syntax.DblQuoted: + return dblQuotedLiteral(p) + } + } + + var buf strings.Builder + syntax.NewPrinter().Print(&buf, w) + return buf.String() +} + +// dblQuotedLiteral extracts the literal content of a double-quoted string. +// If the string contains expansions, falls back to printer output. +func dblQuotedLiteral(dq *syntax.DblQuoted) string { + var sb strings.Builder + for _, p := range dq.Parts { + lit, ok := p.(*syntax.Lit) + if !ok { + var buf strings.Builder + syntax.NewPrinter().Print(&buf, dq) + return buf.String() + } + sb.WriteString(lit.Value) + } + return sb.String() +} + +// nonFlags returns elements of args that don't start with '-'. +// Redirections are not present in args (the parser handles them separately). +func nonFlags(args []string) []string { + var out []string + for _, t := range args { + if !strings.HasPrefix(t, "-") { + out = append(out, t) + } + } + return out +} + +// hasFlag reports whether any element of tokens exactly matches one of the given flags. +func hasFlag(tokens []string, flags ...string) bool { + for _, t := range tokens { + for _, f := range flags { + if t == f { + return true + } + } + } + return false +} + +// shellQuote returns s quoted for shell use if it contains special characters. +func shellQuote(s string) string { + if s == "" { + return "''" + } + safe := true + for _, r := range s { + if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || + r == '/' || r == '.' || r == '_' || r == '-' || r == ':' || r == '~' || r == '+' || r == '@') { + safe = false + break + } + } + if safe { + return s + } + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" +} diff --git a/internal/hookcmd/shellparse_test.go b/internal/hookcmd/shellparse_test.go new file mode 100644 index 0000000..322e8f3 --- /dev/null +++ b/internal/hookcmd/shellparse_test.go @@ -0,0 +1,212 @@ +// Copyright 2026 — see LICENSE file for terms. +package hookcmd + +import ( + "reflect" + "testing" +) + +func TestParseShellCommand_Simple(t *testing.T) { + cases := []struct { + input string + wantName string + wantArgs []string + }{ + {"cat file.go", "cat", []string{"file.go"}}, + {"head -n 50 file.go", "head", []string{"-n", "50", "file.go"}}, + {`grep "hello world" .`, "grep", []string{"hello world", "."}}, + {`cat 'file with spaces.go'`, "cat", []string{"file with spaces.go"}}, + {"/usr/bin/cat file.go", "cat", []string{"file.go"}}, + {"ls -la src/", "ls", []string{"-la", "src/"}}, + } + for _, tc := range cases { + t.Run(tc.input, func(t *testing.T) { + parsed, mod := parseShellCommand(tc.input) + if parsed == nil { + t.Fatal("expected parsed command, got nil") + } + if parsed.Name != tc.wantName { + t.Errorf("Name: got %q, want %q", parsed.Name, tc.wantName) + } + if !reflect.DeepEqual(parsed.Args, tc.wantArgs) { + t.Errorf("Args: got %v, want %v", parsed.Args, tc.wantArgs) + } + if mod.IsSet() { + t.Errorf("expected no modifier, got %+v", mod) + } + }) + } +} + +func TestParseShellCommand_EnvVars(t *testing.T) { + parsed, _ := parseShellCommand("LANG=C cat file.go") + if parsed == nil { + t.Fatal("expected parsed command, got nil") + } + if parsed.Name != "cat" { + t.Errorf("Name: got %q, want %q", parsed.Name, "cat") + } + if !reflect.DeepEqual(parsed.Args, []string{"file.go"}) { + t.Errorf("Args: got %v, want %v", parsed.Args, []string{"file.go"}) + } +} + +func TestParseShellCommand_Redirections(t *testing.T) { + // Redirections should not appear in Args (parser handles them separately). + parsed, _ := parseShellCommand("cat file.go > out.txt") + if parsed == nil { + t.Fatal("expected parsed command, got nil") + } + if parsed.Name != "cat" { + t.Errorf("Name: got %q, want %q", parsed.Name, "cat") + } + if !reflect.DeepEqual(parsed.Args, []string{"file.go"}) { + t.Errorf("Args: got %v, want %v (redirections should be excluded)", parsed.Args, []string{"file.go"}) + } +} + +func TestParseShellCommand_Pipeline(t *testing.T) { + cases := []struct { + input string + wantName string + wantHead int + wantTail int + }{ + {"cat file.go | head -n 50", "cat", 50, 0}, + {"cat file.go | head -10", "cat", 10, 0}, + {"cat file.go | head", "cat", 10, 0}, // default 10 + {"cat file.go | tail -n 20", "cat", 0, 20}, + {"git log --oneline | head -n 10", "git", 10, 0}, + {"grep TODO . | head -5", "grep", 5, 0}, + } + for _, tc := range cases { + t.Run(tc.input, func(t *testing.T) { + parsed, mod := parseShellCommand(tc.input) + if parsed == nil { + t.Fatal("expected parsed command, got nil") + } + if parsed.Name != tc.wantName { + t.Errorf("Name: got %q, want %q", parsed.Name, tc.wantName) + } + if mod.HeadLines != tc.wantHead { + t.Errorf("HeadLines: got %d, want %d", mod.HeadLines, tc.wantHead) + } + if mod.TailLines != tc.wantTail { + t.Errorf("TailLines: got %d, want %d", mod.TailLines, tc.wantTail) + } + }) + } +} + +func TestParseShellCommand_PipelineSed(t *testing.T) { + cases := []struct { + input string + wantName string + wantHead int + wantStart int + wantEnd int + }{ + // sed -n '1,Np' normalizes to HeadLines. + {"cat file.go | sed -n '1,50p'", "cat", 50, 0, 0}, + {"cat file.go | sed -n '1p'", "cat", 1, 0, 0}, + // Arbitrary ranges set StartLine/EndLine. + {"cat file.go | sed -n '5,10p'", "cat", 0, 5, 10}, + {"cat file.go | sed -n '100,200p'", "cat", 0, 100, 200}, + // Single line extraction. + {"cat file.go | sed -n '42p'", "cat", 0, 42, 42}, + } + for _, tc := range cases { + t.Run(tc.input, func(t *testing.T) { + parsed, mod := parseShellCommand(tc.input) + if parsed == nil { + t.Fatal("expected parsed command, got nil") + } + if parsed.Name != tc.wantName { + t.Errorf("Name: got %q, want %q", parsed.Name, tc.wantName) + } + if mod.HeadLines != tc.wantHead { + t.Errorf("HeadLines: got %d, want %d", mod.HeadLines, tc.wantHead) + } + if mod.StartLine != tc.wantStart { + t.Errorf("StartLine: got %d, want %d", mod.StartLine, tc.wantStart) + } + if mod.EndLine != tc.wantEnd { + t.Errorf("EndLine: got %d, want %d", mod.EndLine, tc.wantEnd) + } + }) + } +} + +func TestParseShellCommand_Complex(t *testing.T) { + // All of these should return nil (too complex to analyze). + cases := []struct { + name string + command string + }{ + {"empty", ""}, + {"three-stage pipeline", "cat file | grep pattern | head -5"}, + {"unknown pipe target", "cat file | sort"}, + {"double ampersand", "cd /tmp && ls"}, + {"logical or", "cat file || echo fallback"}, + {"semicolon", "echo hello; echo world"}, + {"background", "cat file &"}, + {"subshell", "(cat file)"}, + {"pipe in quotes OK", `grep "a|b" file`}, // single command, not nil + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + parsed, _ := parseShellCommand(tc.command) + switch tc.name { + case "pipe in quotes OK": + if parsed == nil { + t.Error("grep with | in pattern should parse as simple command") + } + default: + if parsed != nil { + t.Errorf("expected nil for complex command, got %+v", parsed) + } + } + }) + } +} + +func TestShellQuote(t *testing.T) { + cases := []struct { + input string + want string + }{ + {"file.go", "file.go"}, + {"src/main.go", "src/main.go"}, + {"path with spaces", "'path with spaces'"}, + {"it's", "'it'\\''s'"}, + {"", "''"}, + {"*.go", "'*.go'"}, + } + for _, tc := range cases { + t.Run(tc.input, func(t *testing.T) { + got := shellQuote(tc.input) + if got != tc.want { + t.Errorf("shellQuote(%q) = %q, want %q", tc.input, got, tc.want) + } + }) + } +} + +func TestNonFlags(t *testing.T) { + cases := []struct { + input []string + want []string + }{ + {[]string{"-n", "file.go"}, []string{"file.go"}}, + {[]string{"-la", "src/"}, []string{"src/"}}, + {[]string{"-l", "-w", "file.go"}, []string{"file.go"}}, + } + for _, tc := range cases { + t.Run("", func(t *testing.T) { + got := nonFlags(tc.input) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("nonFlags(%v) = %v, want %v", tc.input, got, tc.want) + } + }) + } +} diff --git a/internal/hookcmd/suggest.go b/internal/hookcmd/suggest.go new file mode 100644 index 0000000..39263ed --- /dev/null +++ b/internal/hookcmd/suggest.go @@ -0,0 +1,698 @@ +// Copyright 2026 — see LICENSE file for terms. +package hookcmd + +import ( + "fmt" + "strconv" + "strings" +) + +// Suggestion represents an aifr command that can replace a shell command. +type Suggestion struct { + // Original is the original shell command (or base command name). + Original string + // AifrCommand is the suggested aifr CLI invocation. + AifrCommand string + // ToolName is the MCP tool name (e.g., "aifr_read"). + ToolName string + // ToolArgs are the MCP tool parameters. + ToolArgs map[string]any +} + +// PipelineModifier captures a trailing | head, | tail, or | sed -n in a pipeline. +type PipelineModifier struct { + HeadLines int // > 0 if piped to head -n N + TailLines int // > 0 if piped to tail -n N + StartLine int // > 0 if piped to sed -n extracting from this line (1-based) + EndLine int // > 0 if piped to sed -n extracting to this line (1-based) +} + +// IsSet reports whether any pipeline modifier is active. +func (m PipelineModifier) IsSet() bool { + return m.HeadLines > 0 || m.TailLines > 0 || m.StartLine > 0 +} + +// AnalyzeCommand checks if a shell command can be replaced by an aifr command. +// Returns nil if no suggestion applies. +func AnalyzeCommand(command string) *Suggestion { + command = strings.TrimSpace(command) + if command == "" { + return nil + } + + parsed, mod := parseShellCommand(command) + if parsed == nil { + return nil + } + + // Already an aifr invocation — nothing to suggest. + if parsed.Name == "aifr" { + return nil + } + + switch parsed.Name { + case "cat": + return suggestCat(parsed.Args, mod) + case "head": + return suggestHead(parsed.Args, mod) + case "tail": + return suggestTail(parsed.Args, mod) + case "grep", "egrep", "fgrep", "rg": + return suggestSearch(parsed.Name, parsed.Args, mod) + case "find": + return suggestFind(parsed.Args, mod) + case "ls": + return suggestList(parsed.Args, mod) + case "wc": + return suggestWc(parsed.Args, mod) + case "stat": + return suggestStat(parsed.Args, mod) + case "diff": + return suggestDiff(parsed.Args, mod) + case "sha256sum", "sha1sum", "md5sum", "shasum", "sha384sum", "sha512sum", "b2sum": + return suggestChecksum(parsed.Name, parsed.Args, mod) + case "hexdump", "xxd", "od": + return suggestHexdump(parsed.Name, parsed.Args, mod) + case "sed": + return suggestSed(parsed.Args, mod) + case "git": + return suggestGit(parsed.Args, mod) + default: + return nil + } +} + +// parseHeadTailN extracts the line count from head/tail arguments. +// Handles -n N, -nN, and -N forms. Returns defaultN if not found. +func parseHeadTailN(args []string, defaultN int) int { + for i := 0; i < len(args); i++ { + a := args[i] + switch { + case a == "-n" && i+1 < len(args): + if v, err := strconv.Atoi(args[i+1]); err == nil { + return v + } + i++ + case strings.HasPrefix(a, "-n"): + if v, err := strconv.Atoi(a[2:]); err == nil { + return v + } + case len(a) > 1 && a[0] == '-' && isDigits(a[1:]): + if v, err := strconv.Atoi(a[1:]); err == nil { + return v + } + } + } + return defaultN +} + +func makeSuggestion(original, aifrCmd, toolName string, toolArgs map[string]any) *Suggestion { + return &Suggestion{ + Original: original, + AifrCommand: aifrCmd, + ToolName: toolName, + ToolArgs: toolArgs, + } +} + +func suggestCat(args []string, mod PipelineModifier) *Suggestion { + files := nonFlags(args) + if len(files) == 0 { + return nil // reading from stdin + } + + if mod.HeadLines > 0 { + if len(files) != 1 { + return nil // multi-file cat with head doesn't map cleanly + } + lines := fmt.Sprintf("1:%d", mod.HeadLines) + return makeSuggestion("cat", + fmt.Sprintf("aifr read --lines=%s %s", lines, shellQuote(files[0])), + "aifr_read", + map[string]any{"path": files[0], "lines": lines}) + } + + if mod.TailLines > 0 { + if len(files) != 1 { + return nil + } + lines := fmt.Sprintf("-%d:", mod.TailLines) + return makeSuggestion("cat", + fmt.Sprintf("aifr read --lines=%s %s", lines, shellQuote(files[0])), + "aifr_read", + map[string]any{"path": files[0], "lines": lines}) + } + + if mod.StartLine > 0 { + if len(files) != 1 { + return nil + } + lines := fmt.Sprintf("%d:%d", mod.StartLine, mod.EndLine) + return makeSuggestion("cat", + fmt.Sprintf("aifr read --lines=%s %s", lines, shellQuote(files[0])), + "aifr_read", + map[string]any{"path": files[0], "lines": lines}) + } + + if len(files) == 1 { + return makeSuggestion("cat", + "aifr read "+shellQuote(files[0]), + "aifr_read", + map[string]any{"path": files[0]}) + } + parts := make([]string, len(files)) + for i, f := range files { + parts[i] = shellQuote(f) + } + return makeSuggestion("cat", + "aifr cat "+strings.Join(parts, " "), + "aifr_cat", + map[string]any{"paths": files}) +} + +func suggestHead(args []string, mod PipelineModifier) *Suggestion { + if mod.IsSet() { + return nil // head | head or head | tail is unusual + } + + n := 10 + var file string + for i := 0; i < len(args); i++ { + a := args[i] + switch { + case a == "-n" && i+1 < len(args): + if v, err := strconv.Atoi(args[i+1]); err == nil { + n = v + } + i++ + case strings.HasPrefix(a, "-n"): + if v, err := strconv.Atoi(a[2:]); err == nil { + n = v + } + case len(a) > 1 && a[0] == '-' && isDigits(a[1:]): + if v, err := strconv.Atoi(a[1:]); err == nil { + n = v + } + case !strings.HasPrefix(a, "-"): + if file == "" { + file = a + } + } + } + if file == "" { + return nil + } + lines := fmt.Sprintf("1:%d", n) + return makeSuggestion("head", + fmt.Sprintf("aifr read --lines=%s %s", lines, shellQuote(file)), + "aifr_read", + map[string]any{"path": file, "lines": lines}) +} + +func suggestTail(args []string, mod PipelineModifier) *Suggestion { + if mod.IsSet() { + return nil // tail | head or tail | tail is unusual + } + if hasFlag(args, "-f", "--follow", "-F") { + return nil // tail -f is a live-follow, aifr can't do that + } + + n := 10 + var file string + for i := 0; i < len(args); i++ { + a := args[i] + switch { + case a == "-n" && i+1 < len(args): + if v, err := strconv.Atoi(args[i+1]); err == nil { + n = v + } + i++ + case strings.HasPrefix(a, "-n"): + if v, err := strconv.Atoi(a[2:]); err == nil { + n = v + } + case len(a) > 1 && a[0] == '-' && isDigits(a[1:]): + if v, err := strconv.Atoi(a[1:]); err == nil { + n = v + } + case !strings.HasPrefix(a, "-"): + if file == "" { + file = a + } + } + } + if file == "" { + return nil + } + lines := fmt.Sprintf("-%d:", n) + return makeSuggestion("tail", + fmt.Sprintf("aifr read --lines=%s %s", lines, shellQuote(file)), + "aifr_read", + map[string]any{"path": file, "lines": lines}) +} + +func suggestSearch(baseCmd string, args []string, mod PipelineModifier) *Suggestion { + if mod.TailLines > 0 || mod.StartLine > 0 { + return nil + } + + var pattern string + var path string + recursive := false + + positional := 0 + for i := 0; i < len(args); i++ { + a := args[i] + if strings.HasPrefix(a, "-") { + if a == "-r" || a == "-R" || a == "--recursive" { + recursive = true + continue + } + if flagTakesValue(a) && i+1 < len(args) { + i++ + } + continue + } + switch positional { + case 0: + pattern = a + case 1: + path = a + } + positional++ + } + + if pattern == "" { + return nil + } + if path == "" && !recursive { + return nil // likely reading stdin + } + + effectivePath := path + if effectivePath == "" { + effectivePath = "." + } + + toolArgs := map[string]any{ + "pattern": pattern, + "path": effectivePath, + } + + cmd := "aifr search" + if mod.HeadLines > 0 { + cmd += fmt.Sprintf(" --max-matches=%d", mod.HeadLines) + toolArgs["max_matches"] = mod.HeadLines + } + cmd += " " + shellQuote(pattern) + " " + shellQuote(effectivePath) + + return makeSuggestion(baseCmd, cmd, "aifr_search", toolArgs) +} + +func suggestFind(args []string, mod PipelineModifier) *Suggestion { + if mod.TailLines > 0 || mod.StartLine > 0 { + return nil + } + + var path string + var name string + var ftype string + + for i := 0; i < len(args); i++ { + a := args[i] + switch a { + case "-name", "-iname": + if i+1 < len(args) { + name = args[i+1] + i++ + } + case "-type": + if i+1 < len(args) { + ftype = args[i+1] + i++ + } + default: + if !strings.HasPrefix(a, "-") && path == "" { + path = a + } + } + } + + if path == "" { + path = "." + } + + toolArgs := map[string]any{"path": path} + cmd := "aifr find " + shellQuote(path) + + if name != "" { + cmd += " --name=" + shellQuote(name) + toolArgs["name"] = name + } + if ftype != "" { + cmd += " --type=" + ftype + toolArgs["type"] = ftype + } + if mod.HeadLines > 0 { + cmd += fmt.Sprintf(" --limit=%d", mod.HeadLines) + toolArgs["limit"] = mod.HeadLines + } + + return makeSuggestion("find", cmd, "aifr_find", toolArgs) +} + +func suggestList(args []string, mod PipelineModifier) *Suggestion { + if mod.TailLines > 0 || mod.StartLine > 0 { + return nil + } + + files := nonFlags(args) + path := "." + if len(files) > 0 { + path = files[0] + } + + if hasFlag(args, "-R", "--recursive") { + toolArgs := map[string]any{"path": path} + cmd := "aifr find " + shellQuote(path) + if mod.HeadLines > 0 { + cmd += fmt.Sprintf(" --limit=%d", mod.HeadLines) + toolArgs["limit"] = mod.HeadLines + } + return makeSuggestion("ls", cmd, "aifr_find", toolArgs) + } + + toolArgs := map[string]any{"path": path} + cmd := "aifr list " + shellQuote(path) + if mod.HeadLines > 0 { + cmd += fmt.Sprintf(" --limit=%d", mod.HeadLines) + toolArgs["limit"] = mod.HeadLines + } + return makeSuggestion("ls", cmd, "aifr_list", toolArgs) +} + +func suggestWc(args []string, mod PipelineModifier) *Suggestion { + if mod.IsSet() { + return nil // wc output with head/tail doesn't map usefully + } + + files := nonFlags(args) + if len(files) == 0 { + return nil // reading from stdin + } + + toolArgs := map[string]any{"paths": files} + + parts := make([]string, len(files)) + for i, f := range files { + parts[i] = shellQuote(f) + } + cmd := "aifr wc" + if hasFlag(args, "-l") { + cmd += " -l" + toolArgs["lines"] = true + } + if hasFlag(args, "-w") { + cmd += " -w" + toolArgs["words"] = true + } + if hasFlag(args, "-c") { + cmd += " -c" + toolArgs["bytes"] = true + } + if hasFlag(args, "-m") { + cmd += " -m" + toolArgs["chars"] = true + } + cmd += " " + strings.Join(parts, " ") + return makeSuggestion("wc", cmd, "aifr_wc", toolArgs) +} + +func suggestStat(args []string, mod PipelineModifier) *Suggestion { + if mod.IsSet() { + return nil + } + files := nonFlags(args) + if len(files) == 0 { + return nil + } + return makeSuggestion("stat", + "aifr stat "+shellQuote(files[0]), + "aifr_stat", + map[string]any{"path": files[0]}) +} + +func suggestDiff(args []string, mod PipelineModifier) *Suggestion { + if mod.IsSet() { + return nil + } + files := nonFlags(args) + if len(files) != 2 { + return nil + } + return makeSuggestion("diff", + "aifr diff "+shellQuote(files[0])+" "+shellQuote(files[1]), + "aifr_diff", + map[string]any{"path_a": files[0], "path_b": files[1]}) +} + +func suggestChecksum(baseCmd string, args []string, mod PipelineModifier) *Suggestion { + if mod.IsSet() { + return nil + } + files := nonFlags(args) + if len(files) == 0 { + return nil + } + + algo := "sha256" + switch baseCmd { + case "sha1sum": + algo = "sha1" + case "sha256sum": + algo = "sha256" + case "sha384sum": + algo = "sha384" + case "sha512sum": + algo = "sha512" + case "md5sum": + algo = "md5" + case "b2sum": + return nil // aifr may not support blake2 + case "shasum": + algo = "sha1" + for i := 0; i < len(args); i++ { + if args[i] == "-a" && i+1 < len(args) { + algo = "sha" + args[i+1] + i++ + } + } + } + + parts := make([]string, len(files)) + for i, f := range files { + parts[i] = shellQuote(f) + } + return makeSuggestion(baseCmd, + fmt.Sprintf("aifr checksum -a %s %s", algo, strings.Join(parts, " ")), + "aifr_checksum", + map[string]any{"paths": files, "algorithm": algo}) +} + +func suggestHexdump(_ string, args []string, mod PipelineModifier) *Suggestion { + if mod.IsSet() { + return nil + } + files := nonFlags(args) + if len(files) == 0 { + return nil + } + return makeSuggestion("hexdump", + "aifr hexdump "+shellQuote(files[0]), + "aifr_hexdump", + map[string]any{"path": files[0]}) +} + +func suggestSed(args []string, mod PipelineModifier) *Suggestion { + if mod.IsSet() { + return nil + } + if !hasFlag(args, "-n") { + return nil // without -n, sed may be doing transformations + } + + var script string + var file string + sawN := false + + for i := 0; i < len(args); i++ { + a := args[i] + if a == "-n" { + sawN = true + continue + } + if strings.HasPrefix(a, "-") { + continue + } + if !sawN { + continue + } + if script == "" { + script = a + } else if file == "" { + file = a + } + } + + if script == "" || file == "" { + return nil + } + + script = strings.TrimSuffix(script, "p") + if script == "" { + return nil + } + + parts := strings.SplitN(script, ",", 2) + if len(parts) == 1 { + if _, err := strconv.Atoi(parts[0]); err != nil { + return nil + } + lines := parts[0] + ":" + parts[0] + return makeSuggestion("sed", + fmt.Sprintf("aifr read --lines=%s %s", lines, shellQuote(file)), + "aifr_read", + map[string]any{"path": file, "lines": lines}) + } + + if _, err := strconv.Atoi(parts[0]); err != nil { + return nil + } + if _, err := strconv.Atoi(parts[1]); err != nil { + return nil + } + lines := parts[0] + ":" + parts[1] + return makeSuggestion("sed", + fmt.Sprintf("aifr read --lines=%s %s", lines, shellQuote(file)), + "aifr_read", + map[string]any{"path": file, "lines": lines}) +} + +func suggestGit(args []string, mod PipelineModifier) *Suggestion { + if len(args) == 0 { + return nil + } + switch args[0] { + case "log": + return suggestGitLog(args[1:], mod) + case "diff": + return suggestGitDiff(args[1:], mod) + default: + return nil + } +} + +func suggestGitLog(args []string, mod PipelineModifier) *Suggestion { + if mod.TailLines > 0 || mod.StartLine > 0 { + return nil + } + + oneline := false + var maxCount int + var ref string + + for i := 0; i < len(args); i++ { + a := args[i] + switch { + case a == "--oneline": + oneline = true + case a == "-n" && i+1 < len(args): + maxCount, _ = strconv.Atoi(args[i+1]) + i++ + case strings.HasPrefix(a, "-n") && len(a) > 2 && isDigits(a[2:]): + maxCount, _ = strconv.Atoi(a[2:]) + case strings.HasPrefix(a, "--max-count="): + maxCount, _ = strconv.Atoi(strings.TrimPrefix(a, "--max-count=")) + case !strings.HasPrefix(a, "-"): + if ref == "" { + ref = a + } + } + } + + // Pipeline head overrides/sets max-count. + if mod.HeadLines > 0 { + maxCount = mod.HeadLines + } + + toolArgs := map[string]any{} + cmd := "aifr log" + + if oneline { + cmd += " --oneline" + toolArgs["format"] = "oneline" + } + if maxCount > 0 { + cmd += fmt.Sprintf(" --max-count=%d", maxCount) + toolArgs["max_count"] = maxCount + } + if ref != "" { + cmd += " " + shellQuote(ref) + toolArgs["ref"] = ref + } + + return makeSuggestion("git log", cmd, "aifr_log", toolArgs) +} + +func suggestGitDiff(args []string, mod PipelineModifier) *Suggestion { + if mod.IsSet() { + return nil // diff output with head/tail doesn't map cleanly + } + + refs := nonFlags(args) + var clean []string + for _, r := range refs { + if r != "--" { + clean = append(clean, r) + } + } + if len(clean) == 2 { + return makeSuggestion("git diff", + fmt.Sprintf("aifr diff %s:%s %s:%s", clean[0], ".", clean[1], "."), + "aifr_diff", + map[string]any{"path_a": clean[0] + ":.", "path_b": clean[1] + ":."}) + } + return makeSuggestion("git diff", "aifr diff", "aifr_diff", map[string]any{}) +} + +// isDigits reports whether s is non-empty and contains only ASCII digits. +func isDigits(s string) bool { + if s == "" { + return false + } + for _, r := range s { + if r < '0' || r > '9' { + return false + } + } + return true +} + +// flagTakesValue reports whether a grep/rg flag expects a following argument. +func flagTakesValue(flag string) bool { + switch flag { + case "-e", "-f", "--regexp", "--file", + "-m", "--max-count", + "-A", "--after-context", + "-B", "--before-context", + "-C", "--context", + "--include", "--exclude", "--exclude-dir", + "--color", "--colour", + "-d", "--directories", + "-D", "--devices", + "--label", + "--binary-files": + return true + } + return false +} diff --git a/internal/hookcmd/suggest_test.go b/internal/hookcmd/suggest_test.go new file mode 100644 index 0000000..ac45837 --- /dev/null +++ b/internal/hookcmd/suggest_test.go @@ -0,0 +1,531 @@ +// Copyright 2026 — see LICENSE file for terms. +package hookcmd + +import "testing" + +func TestAnalyzeCommand_NoSuggestion(t *testing.T) { + cases := []struct { + name string + command string + }{ + {"empty", ""}, + {"aifr invocation", "aifr read file.go"}, + {"aifr bare", "aifr"}, + {"unrecognized command", "go build ./..."}, + {"make", "make test"}, + {"npm", "npm install"}, + {"3-stage pipeline", "cat file | grep pattern | head -5"}, + {"unknown pipe target", "cat file | sort"}, + {"double ampersand", "cd /tmp && ls"}, + {"semicolon", "echo hello; echo world"}, + {"subshell", "$(cat file.go)"}, + {"cat from stdin", "cat"}, + {"cat from stdin with flags", "cat -v"}, + {"head from stdin", "head -n 5"}, + {"tail -f", "tail -f server.log"}, + {"tail --follow", "tail --follow server.log"}, + {"tail -f pipe", "tail -f server.log | head -5"}, + {"grep from stdin", "grep pattern"}, + {"wc from stdin", "wc -l"}, + {"stat no args", "stat"}, + {"sed without -n", "sed 's/foo/bar/' file.go"}, + {"git status", "git status"}, + {"git push", "git push origin main"}, + {"wc with head", "wc -l file.go | head -5"}, + {"stat with head", "stat file.go | head -5"}, + {"diff with head", "diff a.go b.go | head -5"}, + {"tail with head", "tail -20 file.go | head -5"}, + {"head with tail", "head -20 file.go | tail -5"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s != nil { + t.Errorf("expected nil, got suggestion: %s", s.AifrCommand) + } + }) + } +} + +func TestAnalyzeCommand_Cat(t *testing.T) { + cases := []struct { + command string + want string + wantTool string + }{ + {"cat file.go", "aifr read file.go", "aifr_read"}, + {"cat src/main.go", "aifr read src/main.go", "aifr_read"}, + {"/usr/bin/cat file.go", "aifr read file.go", "aifr_read"}, + {"cat file1.go file2.go", "aifr cat file1.go file2.go", "aifr_cat"}, + {"cat -n file.go", "aifr read file.go", "aifr_read"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("AifrCommand: got %q, want %q", s.AifrCommand, tc.want) + } + if s.ToolName != tc.wantTool { + t.Errorf("ToolName: got %q, want %q", s.ToolName, tc.wantTool) + } + if s.Original != "cat" { + t.Errorf("Original: got %q, want %q", s.Original, "cat") + } + }) + } +} + +func TestAnalyzeCommand_Head(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"head file.go", "aifr read --lines=1:10 file.go"}, + {"head -n 50 file.go", "aifr read --lines=1:50 file.go"}, + {"head -n50 file.go", "aifr read --lines=1:50 file.go"}, + {"head -20 file.go", "aifr read --lines=1:20 file.go"}, + {"head -n 100 src/main.go", "aifr read --lines=1:100 src/main.go"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + }) + } +} + +func TestAnalyzeCommand_Tail(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"tail file.go", "aifr read --lines=-10: file.go"}, + {"tail -n 20 file.go", "aifr read --lines=-20: file.go"}, + {"tail -n20 file.go", "aifr read --lines=-20: file.go"}, + {"tail -5 file.go", "aifr read --lines=-5: file.go"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + }) + } +} + +func TestAnalyzeCommand_Grep(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"grep TODO src/", "aifr search TODO src/"}, + {"grep -r pattern .", "aifr search pattern ."}, + {"grep -rn 'func main' .", "aifr search 'func main' ."}, + {"rg pattern src/", "aifr search pattern src/"}, + {"egrep 'foo|bar' dir/", "aifr search 'foo|bar' dir/"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + }) + } +} + +func TestAnalyzeCommand_Find(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"find .", "aifr find ."}, + {"find . -name '*.go'", "aifr find . --name='*.go'"}, + {"find . -name '*.go' -type f", "aifr find . --name='*.go' --type=f"}, + {"find src/ -type d", "aifr find src/ --type=d"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + }) + } +} + +func TestAnalyzeCommand_Ls(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"ls", "aifr list ."}, + {"ls src/", "aifr list src/"}, + {"ls -la src/", "aifr list src/"}, + {"ls -R src/", "aifr find src/"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + }) + } +} + +func TestAnalyzeCommand_Wc(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"wc file.go", "aifr wc file.go"}, + {"wc -l file.go", "aifr wc -l file.go"}, + {"wc -l -w file.go", "aifr wc -l -w file.go"}, + {"wc file1.go file2.go", "aifr wc file1.go file2.go"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + }) + } +} + +func TestAnalyzeCommand_Stat(t *testing.T) { + s := AnalyzeCommand("stat file.go") + if s == nil { + t.Fatal("expected suggestion") + } + if s.AifrCommand != "aifr stat file.go" { + t.Errorf("got %q", s.AifrCommand) + } +} + +func TestAnalyzeCommand_Diff(t *testing.T) { + s := AnalyzeCommand("diff file1.go file2.go") + if s == nil { + t.Fatal("expected suggestion") + } + if s.AifrCommand != "aifr diff file1.go file2.go" { + t.Errorf("got %q", s.AifrCommand) + } +} + +func TestAnalyzeCommand_Checksum(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"sha256sum file.go", "aifr checksum -a sha256 file.go"}, + {"md5sum file.go", "aifr checksum -a md5 file.go"}, + {"sha1sum file.go", "aifr checksum -a sha1 file.go"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + }) + } +} + +func TestAnalyzeCommand_Hexdump(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"hexdump file.bin", "aifr hexdump file.bin"}, + {"xxd file.bin", "aifr hexdump file.bin"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + }) + } +} + +func TestAnalyzeCommand_Sed(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"sed -n '5p' file.go", "aifr read --lines=5:5 file.go"}, + {"sed -n '10,20p' file.go", "aifr read --lines=10:20 file.go"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + }) + } +} + +func TestAnalyzeCommand_GitLog(t *testing.T) { + s := AnalyzeCommand("git log") + if s == nil { + t.Fatal("expected suggestion") + } + if s.AifrCommand != "aifr log" { + t.Errorf("got %q", s.AifrCommand) + } +} + +func TestAnalyzeCommand_EnvPrefix(t *testing.T) { + s := AnalyzeCommand("LANG=C cat file.go") + if s == nil { + t.Fatal("expected suggestion") + } + if s.AifrCommand != "aifr read file.go" { + t.Errorf("got %q", s.AifrCommand) + } +} + +func TestAnalyzeCommand_QuotedPaths(t *testing.T) { + s := AnalyzeCommand(`cat "path with spaces/file.go"`) + if s == nil { + t.Fatal("expected suggestion") + } + if s.AifrCommand != "aifr read 'path with spaces/file.go'" { + t.Errorf("got %q", s.AifrCommand) + } +} + +func TestAnalyzeCommand_AbsoluteCommandPath(t *testing.T) { + s := AnalyzeCommand("/usr/bin/head -n 5 file.go") + if s == nil { + t.Fatal("expected suggestion") + } + if s.AifrCommand != "aifr read --lines=1:5 file.go" { + t.Errorf("got %q", s.AifrCommand) + } +} + +// --- Pipeline tests --- + +func TestAnalyzeCommand_PipelineHeadCat(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"cat file.go | head -n 50", "aifr read --lines=1:50 file.go"}, + {"cat file.go | head -10", "aifr read --lines=1:10 file.go"}, + {"cat file.go | head", "aifr read --lines=1:10 file.go"}, // default 10 + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + if s.ToolName != "aifr_read" { + t.Errorf("ToolName: got %q, want %q", s.ToolName, "aifr_read") + } + }) + } +} + +func TestAnalyzeCommand_PipelineTailCat(t *testing.T) { + s := AnalyzeCommand("cat file.go | tail -n 20") + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != "aifr read --lines=-20: file.go" { + t.Errorf("got %q", s.AifrCommand) + } +} + +func TestAnalyzeCommand_PipelineHeadGitLog(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"git log --oneline | head -n 10", "aifr log --oneline --max-count=10"}, + {"git log | head -n 5", "aifr log --max-count=5"}, + {"git log --oneline | head -20", "aifr log --oneline --max-count=20"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + if s.ToolName != "aifr_log" { + t.Errorf("ToolName: got %q, want %q", s.ToolName, "aifr_log") + } + }) + } +} + +func TestAnalyzeCommand_PipelineHeadGrep(t *testing.T) { + s := AnalyzeCommand("grep -rn TODO . | head -n 20") + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != "aifr search --max-matches=20 TODO ." { + t.Errorf("got %q", s.AifrCommand) + } +} + +func TestAnalyzeCommand_PipelineHeadFind(t *testing.T) { + s := AnalyzeCommand("find . -name '*.go' | head -n 30") + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != "aifr find . --name='*.go' --limit=30" { + t.Errorf("got %q", s.AifrCommand) + } +} + +func TestAnalyzeCommand_PipelineHeadLs(t *testing.T) { + s := AnalyzeCommand("ls -la src/ | head -n 20") + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != "aifr list src/ --limit=20" { + t.Errorf("got %q", s.AifrCommand) + } +} + +func TestAnalyzeCommand_PipelineSedCat(t *testing.T) { + cases := []struct { + command string + want string + }{ + {"cat file.go | sed -n '5,10p'", "aifr read --lines=5:10 file.go"}, + {"cat file.go | sed -n '42p'", "aifr read --lines=42:42 file.go"}, + {"cat file.go | sed -n '100,200p'", "aifr read --lines=100:200 file.go"}, + // start=1 normalizes to HeadLines, still produces the right suggestion. + {"cat file.go | sed -n '1,50p'", "aifr read --lines=1:50 file.go"}, + } + for _, tc := range cases { + t.Run(tc.command, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.AifrCommand != tc.want { + t.Errorf("got %q, want %q", s.AifrCommand, tc.want) + } + if s.ToolName != "aifr_read" { + t.Errorf("ToolName: got %q, want %q", s.ToolName, "aifr_read") + } + }) + } +} + +func TestAnalyzeCommand_PipelineSedNoSuggestion(t *testing.T) { + // sed line ranges on non-cat commands don't map to aifr parameters. + cases := []struct { + name string + command string + }{ + {"grep with sed range", "grep TODO . | sed -n '5,10p'"}, + {"find with sed range", "find . -name '*.go' | sed -n '5,20p'"}, + {"git log with sed range", "git log | sed -n '5,10p'"}, + {"sed without -n", "cat file.go | sed '5,10p'"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s != nil { + t.Errorf("expected nil, got suggestion: %s", s.AifrCommand) + } + }) + } +} + +// --- MCP tool info tests --- + +func TestAnalyzeCommand_MCPToolArgs(t *testing.T) { + cases := []struct { + name string + command string + wantTool string + wantArg string + wantVal any + }{ + {"cat path", "cat main.go", "aifr_read", "path", "main.go"}, + {"grep pattern", "grep TODO src/", "aifr_search", "pattern", "TODO"}, + {"find name", "find . -name '*.go'", "aifr_find", "name", "*.go"}, + {"git log oneline", "git log --oneline", "aifr_log", "format", "oneline"}, + {"pipeline max_count", "git log | head -5", "aifr_log", "max_count", 5}, + {"pipeline limit", "find . | head -30", "aifr_find", "limit", 30}, + {"diff paths", "diff a.go b.go", "aifr_diff", "path_a", "a.go"}, + {"checksum algo", "sha256sum f.go", "aifr_checksum", "algorithm", "sha256"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s := AnalyzeCommand(tc.command) + if s == nil { + t.Fatal("expected suggestion, got nil") + } + if s.ToolName != tc.wantTool { + t.Errorf("ToolName: got %q, want %q", s.ToolName, tc.wantTool) + } + got, ok := s.ToolArgs[tc.wantArg] + if !ok { + t.Errorf("ToolArgs missing key %q; args=%v", tc.wantArg, s.ToolArgs) + return + } + // Compare with type awareness: int vs float64 from JSON. + switch want := tc.wantVal.(type) { + case int: + if got != want { + t.Errorf("ToolArgs[%q]: got %v (%T), want %v", tc.wantArg, got, got, want) + } + default: + if got != want { + t.Errorf("ToolArgs[%q]: got %v, want %v", tc.wantArg, got, want) + } + } + }) + } +}