Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions internal/repository/path.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package repository

import (
"context"
"errors"
"fmt"
"io/fs"
Expand All @@ -17,7 +18,9 @@ type fileReader func(path string) ([]byte, error)

// normalizeRetrievalQuery 统一校验检索请求并补齐默认值。
func normalizeRetrievalQuery(workdir string, query RetrievalQuery) (string, string, RetrievalQuery, error) {
if strings.TrimSpace(query.Value) == "" {
normalized := query
normalized.Value = strings.TrimSpace(query.Value)
if normalized.Value == "" {
return "", "", RetrievalQuery{}, errors.New("repository: query value is empty")
}

Expand All @@ -30,15 +33,13 @@ func normalizeRetrievalQuery(workdir string, query RetrievalQuery) (string, stri
return "", "", RetrievalQuery{}, err
}

normalized := query
switch query.Mode {
switch normalized.Mode {
case RetrievalModePath, RetrievalModeGlob, RetrievalModeText, RetrievalModeSymbol:
default:
return "", "", RetrievalQuery{}, errInvalidMode
}
normalized.Value = strings.TrimSpace(query.Value)
normalized.Limit = normalizeLimit(query.Limit, defaultRetrievalLimit, maxRetrievalLimit)
normalized.ContextLines = normalizeLimit(query.ContextLines, defaultContextLines, maxContextLines)
normalized.Limit = normalizeLimit(normalized.Limit, defaultRetrievalLimit, maxRetrievalLimit)
normalized.ContextLines = normalizeLimit(normalized.ContextLines, defaultContextLines, maxContextLines)
return root, scope, normalized, nil
}

Expand Down Expand Up @@ -121,9 +122,20 @@ func snippetAroundLine(content string, lineNumber int, contextLines int) (string
return snippet.text, lineNumber
}

// walkWorkspaceFiles 遍历工作区文件,同时跳过已约定的噪声目录。
func walkWorkspaceFiles(root string, scope string, visit func(path string, entry fs.DirEntry) error) error {
// walkWorkspaceFiles 遍历工作区文件,同时跳过已约定的噪声目录,并支持取消信号快速中断。
func walkWorkspaceFiles(
ctx context.Context,
root string,
scope string,
visit func(path string, entry fs.DirEntry) error,
) error {
if err := ctx.Err(); err != nil {
return err
}
return filepath.WalkDir(scope, func(path string, entry fs.DirEntry, err error) error {
if ctxErr := ctx.Err(); ctxErr != nil {
return ctxErr
}
if err != nil {
return err
}
Expand Down
53 changes: 25 additions & 28 deletions internal/repository/repository_additional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ func TestPathAndRetrievalHelpers(t *testing.T) {
}

visited := make([]string, 0, 2)
err := walkWorkspaceFiles(workdir, workdir, func(path string, entry fs.DirEntry) error {
err := walkWorkspaceFiles(context.Background(), workdir, workdir, func(path string, entry fs.DirEntry) error {
visited = append(visited, filepath.Base(path))
return nil
})
Expand All @@ -344,7 +344,7 @@ func TestPathAndRetrievalHelpers(t *testing.T) {
if slices.Contains(visited, "ignored.txt") {
t.Fatalf("expected node_modules file to be skipped, got %v", visited)
}
err = walkWorkspaceFiles(workdir, filepath.Join(workdir, "missing"), func(path string, entry fs.DirEntry) error {
err = walkWorkspaceFiles(context.Background(), workdir, filepath.Join(workdir, "missing"), func(path string, entry fs.DirEntry) error {
return nil
})
if err == nil {
Expand All @@ -370,10 +370,7 @@ func TestRetrieveAndServiceEdgeCases(t *testing.T) {
mustWriteFile(t, filepath.Join(workdir, "pkg", "defs.go"), "package pkg\n\ntype Widget struct{}\n\nfunc BuildWidget() {}\nconst WidgetName = \"x\"\nvar WidgetVar = 1\n")
mustWriteFile(t, filepath.Join(workdir, "pkg", "notes.txt"), "Widget WidgetName")

service := &Service{
gitRunner: runGitCommand,
readFile: readFile,
}
service := newTestService(runGitCommand)

t.Run("retrieve path guards and not exist", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -602,7 +599,7 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) {
root := t.TempDir()
mustWriteFile(t, filepath.Join(root, "a.txt"), "a")
expectedErr := errors.New("stop")
err := walkWorkspaceFiles(root, root, func(path string, entry fs.DirEntry) error {
err := walkWorkspaceFiles(context.Background(), root, root, func(path string, entry fs.DirEntry) error {
return expectedErr
})
if !errors.Is(err, expectedErr) {
Expand All @@ -616,13 +613,22 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) {
}
linkPath := filepath.Join(root, "escape.txt")
if err := os.Symlink(outsideFile, linkPath); err == nil {
err = walkWorkspaceFiles(root, root, func(path string, entry fs.DirEntry) error {
err = walkWorkspaceFiles(context.Background(), root, root, func(path string, entry fs.DirEntry) error {
return nil
})
if err == nil {
t.Fatalf("expected symlink escape error from walkWorkspaceFiles")
}
}

canceledCtx, cancel := context.WithCancel(context.Background())
cancel()
err = walkWorkspaceFiles(canceledCtx, root, root, func(path string, entry fs.DirEntry) error {
return nil
})
if !errors.Is(err, context.Canceled) {
t.Fatalf("walkWorkspaceFiles(canceled) err = %v", err)
}
})

t.Run("retrieve branches and service switches", func(t *testing.T) {
Expand All @@ -640,10 +646,7 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) {
}, "\n"))
mustWriteFile(t, filepath.Join(root, "pkg", "match.txt"), "hit\nhit\nhit")

svc := &Service{
gitRunner: runGitCommand,
readFile: readFile,
}
svc := newTestService(runGitCommand)
canceledCtx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := svc.retrieveByGlob(canceledCtx, root, root, RetrievalQuery{Mode: RetrievalModeGlob, Value: "*.go", Limit: 1}); !errors.Is(err, context.Canceled) {
Expand Down Expand Up @@ -762,16 +765,13 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) {
t.Run("summary representative limit and changed-files without snippets", func(t *testing.T) {
t.Parallel()

service := &Service{
gitRunner: func(ctx context.Context, workdir string, args ...string) (string, error) {
lines := []string{"## main"}
for i := 0; i < representativeChangedFilesLimit+2; i++ {
lines = append(lines, fmt.Sprintf(" M file%d.go", i))
}
return strings.Join(lines, "\n"), nil
},
readFile: readFile,
}
service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) {
lines := []string{"## main"}
for i := 0; i < representativeChangedFilesLimit+2; i++ {
lines = append(lines, fmt.Sprintf(" M file%d.go", i))
}
return strings.Join(lines, "\n"), nil
})
summary, err := service.Summary(context.Background(), t.TempDir())
if err != nil {
t.Fatalf("Summary() err = %v", err)
Expand All @@ -796,12 +796,9 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) {
t.Fatalf("ChangedFiles(canceled) err = %v", err)
}

nonGitService := &Service{
gitRunner: func(ctx context.Context, workdir string, args ...string) (string, error) {
return "fatal: not a git repository", errors.New("exit status 128")
},
readFile: readFile,
}
nonGitService := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) {
return "fatal: not a git repository", errors.New("exit status 128")
})
ctxResult, err := nonGitService.ChangedFiles(context.Background(), t.TempDir(), ChangedFilesOptions{})
if err != nil {
t.Fatalf("ChangedFiles(non-git) err = %v", err)
Expand Down
Loading
Loading