Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions cmd/aifr/args_coverage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
// Copyright 2026 — see LICENSE file for terms.
package main

import (
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
"testing"
)

// TestVariadicCommandsUseFullArgs uses Go AST analysis to verify that every
// cobra command declared with MinimumNArgs(1) passes the full `args` slice
// to functions, rather than subscripting `args[0]` and silently dropping
// extra positional arguments.
//
// This prevents a class of bug where shell glob expansion (e.g. nats-pi-*)
// produces multiple arguments but only the first is processed.
func TestVariadicCommandsUseFullArgs(t *testing.T) {
cmdFiles, err := filepath.Glob(filepath.Join(".", "cmd_*.go"))
if err != nil {
t.Fatal(err)
}

// If running from the repo root, try the cmd/aifr directory.
if len(cmdFiles) == 0 {
cmdFiles, err = filepath.Glob(filepath.Join("cmd", "aifr", "cmd_*.go"))
if err != nil {
t.Fatal(err)
}
}

if len(cmdFiles) == 0 {
t.Fatal("found no cmd_*.go files")
}

fset := token.NewFileSet()

for _, path := range cmdFiles {
src, err := os.ReadFile(path)
if err != nil {
t.Fatalf("reading %s: %v", path, err)
}

f, err := parser.ParseFile(fset, path, src, 0)
if err != nil {
t.Fatalf("parsing %s: %v", path, err)
}

checkVariadicCommands(t, fset, f, path)
}
}

// checkVariadicCommands finds cobra.Command declarations in the file that
// use MinimumNArgs(1), then verifies their RunE bodies do not subscript
// args[0] (which would silently discard extra positional arguments).
func checkVariadicCommands(t *testing.T, fset *token.FileSet, f *ast.File, filename string) {
t.Helper()

// Walk all top-level variable declarations to find cobra commands.
for _, decl := range f.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.VAR {
continue
}
for _, spec := range genDecl.Specs {
vs, ok := spec.(*ast.ValueSpec)
if !ok || len(vs.Values) != 1 {
continue
}

// Look for &cobra.Command{...} composite literals.
unary, ok := vs.Values[0].(*ast.UnaryExpr)
if !ok || unary.Op != token.AND {
continue
}
compLit, ok := unary.X.(*ast.CompositeLit)
if !ok {
continue
}

if !isCobraCommandType(compLit.Type) {
continue
}

varName := vs.Names[0].Name

if !hasMinimumNArgs1(compLit) {
continue
}

// Found a variadic command. Check its RunE body.
runE := findRunEBody(compLit)
if runE == nil {
continue
}

if containsArgsSubscript(runE) {
t.Errorf("%s: command %q uses MinimumNArgs(1) but subscripts args[N] "+
"in its RunE — this silently drops extra positional arguments. "+
"Pass the full args slice instead.",
filename, varName)
}
}
}
}

// isCobraCommandType checks if a type expression refers to cobra.Command.
func isCobraCommandType(expr ast.Expr) bool {
sel, ok := expr.(*ast.SelectorExpr)
if !ok {
return false
}
ident, ok := sel.X.(*ast.Ident)
if !ok {
return false
}
return ident.Name == "cobra" && sel.Sel.Name == "Command"
}

// hasMinimumNArgs1 checks if a cobra.Command composite literal contains
// Args: cobra.MinimumNArgs(1).
func hasMinimumNArgs1(lit *ast.CompositeLit) bool {
for _, elt := range lit.Elts {
kv, ok := elt.(*ast.KeyValueExpr)
if !ok {
continue
}
keyIdent, ok := kv.Key.(*ast.Ident)
if !ok || keyIdent.Name != "Args" {
continue
}

// Check for cobra.MinimumNArgs(1).
call, ok := kv.Value.(*ast.CallExpr)
if !ok {
continue
}
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
continue
}
pkg, ok := sel.X.(*ast.Ident)
if !ok {
continue
}
if pkg.Name != "cobra" || sel.Sel.Name != "MinimumNArgs" {
continue
}
if len(call.Args) == 1 {
bl, ok := call.Args[0].(*ast.BasicLit)
if ok && bl.Value == "1" {
return true
}
}
}
return false
}

// findRunEBody finds the RunE field's function literal body in a cobra.Command
// composite literal.
func findRunEBody(lit *ast.CompositeLit) *ast.BlockStmt {
for _, elt := range lit.Elts {
kv, ok := elt.(*ast.KeyValueExpr)
if !ok {
continue
}
keyIdent, ok := kv.Key.(*ast.Ident)
if !ok || keyIdent.Name != "RunE" {
continue
}
funcLit, ok := kv.Value.(*ast.FuncLit)
if !ok {
continue
}
return funcLit.Body
}
return nil
}

// containsArgsSubscript walks an AST node looking for index expressions of
// the form args[N] where args is an identifier and N is any expression.
func containsArgsSubscript(node ast.Node) bool {
found := false
ast.Inspect(node, func(n ast.Node) bool {
if found {
return false
}
indexExpr, ok := n.(*ast.IndexExpr)
if !ok {
return true
}
ident, ok := indexExpr.X.(*ast.Ident)
if !ok {
return true
}
if strings.EqualFold(ident.Name, "args") {
found = true
return false
}
return true
})
return found
}
6 changes: 3 additions & 3 deletions cmd/aifr/cmd_cat.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ var catCmd = &cobra.Command{
Discovery: aifr cat --name '*.go' --exclude-path '**/vendor/**' ./src/

Discovery mode activates when --name or --exclude-path flags are set.
In discovery mode, exactly one positional arg (the root directory) is expected.
In discovery mode, positional args are root directories to search.

Divider formats for --format text:
plain --- path/to/file ---
Expand Down Expand Up @@ -60,9 +60,9 @@ Divider formats for --format text:

var resp *protocol.CatResponse
if isDiscovery {
resp, err = eng.Cat(nil, args[0], params)
resp, err = eng.Cat(nil, args, params)
} else {
resp, err = eng.Cat(args, "", params)
resp, err = eng.Cat(args, nil, params)
}
if err != nil {
exitWithError(err)
Expand Down
89 changes: 52 additions & 37 deletions internal/engine/cat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@ type CatParams struct {
MaxFiles int // 0 = use default
}

// catFile pairs an absolute path with its pre-computed relative path.
type catFile struct {
absPath string
relPath string // relative to discovery root; empty in explicit mode
}

// Cat concatenates contents of multiple files.
// If paths is non-empty, reads those files in order (explicit mode).
// If paths is empty and root is set, discovers files under root (discovery mode).
func (e *Engine) Cat(paths []string, root string, params CatParams) (*protocol.CatResponse, error) {
// If paths is empty and roots is non-empty, discovers files under each root (discovery mode).
func (e *Engine) Cat(paths []string, roots []string, params CatParams) (*protocol.CatResponse, error) {
maxTotal := params.MaxTotalSize
if maxTotal <= 0 {
maxTotal = DefaultMaxTotalSize
Expand All @@ -51,31 +57,47 @@ func (e *Engine) Cat(paths []string, root string, params CatParams) (*protocol.C

if len(paths) > 0 {
resp.Mode = "explicit"
e.catReadFiles(paths, "", params, maxTotal, maxFiles, resp)
} else if root != "" {
resp.Mode = "discover"
resolved, err := e.checkAccess(root)
if err != nil {
return nil, err
files := make([]catFile, len(paths))
for i, p := range paths {
files[i] = catFile{absPath: p}
}
e.catReadFiles(files, params, maxTotal, maxFiles, resp)
} else if len(roots) > 0 {
resp.Mode = "discover"

info, err := os.Stat(resolved)
if err != nil {
return nil, protocol.NewPathError(protocol.ErrNotFound, root, "path does not exist")
}
if !info.IsDir() {
return nil, protocol.NewPathError(protocol.ErrIsDirectory, root, "cat discovery mode requires a directory")
}
var allFiles []catFile
for _, root := range roots {
resolved, err := e.checkAccess(root)
if err != nil {
return nil, err
}

resp.Root = resolved
info, err := os.Stat(resolved)
if err != nil {
return nil, protocol.NewPathError(protocol.ErrNotFound, root, "path does not exist")
}
if !info.IsDir() {
return nil, protocol.NewPathError(protocol.ErrIsDirectory, root, "cat discovery mode requires a directory")
}

// Discover files under this root.
discovered := e.catDiscover(resolved, resolved, params, 0)

// Sort within each root for deterministic output.
slices.SortFunc(discovered, func(a, b catFile) int {
return strings.Compare(a.absPath, b.absPath)
})

// Discover files.
discovered := e.catDiscover(resolved, resolved, params, 0)
allFiles = append(allFiles, discovered...)
}

// Sort for deterministic output.
slices.Sort(discovered)
// Set Root when there is exactly one root (backward-compatible).
if len(roots) == 1 {
resolved, _ := e.checkAccess(roots[0])
resp.Root = resolved
}

e.catReadFiles(discovered, resolved, params, maxTotal, maxFiles, resp)
e.catReadFiles(allFiles, params, maxTotal, maxFiles, resp)
} else {
return nil, protocol.NewError("INVALID_ARGS", "cat requires either explicit paths or a root directory")
}
Expand All @@ -85,14 +107,14 @@ func (e *Engine) Cat(paths []string, root string, params CatParams) (*protocol.C
return resp, nil
}

// catDiscover walks a directory tree collecting file paths matching filters.
func (e *Engine) catDiscover(root, dir string, params CatParams, depth int) []string {
// catDiscover walks a directory tree collecting files matching filters.
func (e *Engine) catDiscover(root, dir string, params CatParams, depth int) []catFile {
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}

var result []string
var result []catFile
for _, de := range entries {
fullPath := filepath.Join(dir, de.Name())

Expand Down Expand Up @@ -138,34 +160,27 @@ func (e *Engine) catDiscover(root, dir string, params CatParams, depth int) []st
continue // silently skip inaccessible in discovery
}

result = append(result, fullPath)
rel, _ := filepath.Rel(root, fullPath)
result = append(result, catFile{absPath: fullPath, relPath: rel})
}
return result
}

// catReadFiles reads a list of files and populates the response.
func (e *Engine) catReadFiles(paths []string, root string, params CatParams, maxTotal int64, maxFiles int, resp *protocol.CatResponse) {
func (e *Engine) catReadFiles(files []catFile, params CatParams, maxTotal int64, maxFiles int, resp *protocol.CatResponse) {
var totalBytes int64

for _, path := range paths {
for _, cf := range files {
if len(resp.Files) >= maxFiles {
resp.Truncated = true
resp.Warning = "max_files_limit"
break
}

entry := protocol.CatEntry{Path: path}

// Compute relative path if in discovery mode.
if root != "" {
rel, err := filepath.Rel(root, path)
if err == nil {
entry.RelPath = rel
}
}
entry := protocol.CatEntry{Path: cf.absPath, RelPath: cf.relPath}

// Check access (for explicit mode; discovery already checked).
resolved, err := e.checkAccess(path)
resolved, err := e.checkAccess(cf.absPath)
if err != nil {
if ae, ok := err.(*protocol.AifrError); ok {
entry.Error = ae.Code
Expand Down
Loading