diff --git a/internal/repository/path.go b/internal/repository/path.go index d519539f..02ba92b4 100644 --- a/internal/repository/path.go +++ b/internal/repository/path.go @@ -1,6 +1,7 @@ package repository import ( + "context" "errors" "fmt" "io/fs" @@ -17,7 +18,9 @@ type fileReader func(path string) ([]byte, error) // normalizeRetrievalQuery 统一校验检索请求并补齐默认值。 func normalizeRetrievalQuery(workdir string, query RetrievalQuery) (string, string, RetrievalQuery, error) { - if strings.TrimSpace(query.Value) == "" { + normalized := query + normalized.Value = strings.TrimSpace(query.Value) + if normalized.Value == "" { return "", "", RetrievalQuery{}, errors.New("repository: query value is empty") } @@ -30,15 +33,13 @@ func normalizeRetrievalQuery(workdir string, query RetrievalQuery) (string, stri return "", "", RetrievalQuery{}, err } - normalized := query - switch query.Mode { + switch normalized.Mode { case RetrievalModePath, RetrievalModeGlob, RetrievalModeText, RetrievalModeSymbol: default: return "", "", RetrievalQuery{}, errInvalidMode } - normalized.Value = strings.TrimSpace(query.Value) - normalized.Limit = normalizeLimit(query.Limit, defaultRetrievalLimit, maxRetrievalLimit) - normalized.ContextLines = normalizeLimit(query.ContextLines, defaultContextLines, maxContextLines) + normalized.Limit = normalizeLimit(normalized.Limit, defaultRetrievalLimit, maxRetrievalLimit) + normalized.ContextLines = normalizeLimit(normalized.ContextLines, defaultContextLines, maxContextLines) return root, scope, normalized, nil } @@ -121,9 +122,20 @@ func snippetAroundLine(content string, lineNumber int, contextLines int) (string return snippet.text, lineNumber } -// walkWorkspaceFiles 遍历工作区文件,同时跳过已约定的噪声目录。 -func walkWorkspaceFiles(root string, scope string, visit func(path string, entry fs.DirEntry) error) error { +// walkWorkspaceFiles 遍历工作区文件,同时跳过已约定的噪声目录,并支持取消信号快速中断。 +func walkWorkspaceFiles( + ctx context.Context, + root string, + scope string, + visit func(path string, entry fs.DirEntry) error, +) error { + if err := ctx.Err(); err != nil { + return err + } return filepath.WalkDir(scope, func(path string, entry fs.DirEntry, err error) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } if err != nil { return err } diff --git a/internal/repository/repository_additional_test.go b/internal/repository/repository_additional_test.go index 7b56cfa5..f09b5b89 100644 --- a/internal/repository/repository_additional_test.go +++ b/internal/repository/repository_additional_test.go @@ -334,7 +334,7 @@ func TestPathAndRetrievalHelpers(t *testing.T) { } visited := make([]string, 0, 2) - err := walkWorkspaceFiles(workdir, workdir, func(path string, entry fs.DirEntry) error { + err := walkWorkspaceFiles(context.Background(), workdir, workdir, func(path string, entry fs.DirEntry) error { visited = append(visited, filepath.Base(path)) return nil }) @@ -344,7 +344,7 @@ func TestPathAndRetrievalHelpers(t *testing.T) { if slices.Contains(visited, "ignored.txt") { t.Fatalf("expected node_modules file to be skipped, got %v", visited) } - err = walkWorkspaceFiles(workdir, filepath.Join(workdir, "missing"), func(path string, entry fs.DirEntry) error { + err = walkWorkspaceFiles(context.Background(), workdir, filepath.Join(workdir, "missing"), func(path string, entry fs.DirEntry) error { return nil }) if err == nil { @@ -370,10 +370,7 @@ func TestRetrieveAndServiceEdgeCases(t *testing.T) { mustWriteFile(t, filepath.Join(workdir, "pkg", "defs.go"), "package pkg\n\ntype Widget struct{}\n\nfunc BuildWidget() {}\nconst WidgetName = \"x\"\nvar WidgetVar = 1\n") mustWriteFile(t, filepath.Join(workdir, "pkg", "notes.txt"), "Widget WidgetName") - service := &Service{ - gitRunner: runGitCommand, - readFile: readFile, - } + service := newTestService(runGitCommand) t.Run("retrieve path guards and not exist", func(t *testing.T) { t.Parallel() @@ -602,7 +599,7 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) { root := t.TempDir() mustWriteFile(t, filepath.Join(root, "a.txt"), "a") expectedErr := errors.New("stop") - err := walkWorkspaceFiles(root, root, func(path string, entry fs.DirEntry) error { + err := walkWorkspaceFiles(context.Background(), root, root, func(path string, entry fs.DirEntry) error { return expectedErr }) if !errors.Is(err, expectedErr) { @@ -616,13 +613,22 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) { } linkPath := filepath.Join(root, "escape.txt") if err := os.Symlink(outsideFile, linkPath); err == nil { - err = walkWorkspaceFiles(root, root, func(path string, entry fs.DirEntry) error { + err = walkWorkspaceFiles(context.Background(), root, root, func(path string, entry fs.DirEntry) error { return nil }) if err == nil { t.Fatalf("expected symlink escape error from walkWorkspaceFiles") } } + + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + err = walkWorkspaceFiles(canceledCtx, root, root, func(path string, entry fs.DirEntry) error { + return nil + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("walkWorkspaceFiles(canceled) err = %v", err) + } }) t.Run("retrieve branches and service switches", func(t *testing.T) { @@ -640,10 +646,7 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) { }, "\n")) mustWriteFile(t, filepath.Join(root, "pkg", "match.txt"), "hit\nhit\nhit") - svc := &Service{ - gitRunner: runGitCommand, - readFile: readFile, - } + svc := newTestService(runGitCommand) canceledCtx, cancel := context.WithCancel(context.Background()) cancel() if _, err := svc.retrieveByGlob(canceledCtx, root, root, RetrievalQuery{Mode: RetrievalModeGlob, Value: "*.go", Limit: 1}); !errors.Is(err, context.Canceled) { @@ -762,16 +765,13 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) { t.Run("summary representative limit and changed-files without snippets", func(t *testing.T) { t.Parallel() - service := &Service{ - gitRunner: func(ctx context.Context, workdir string, args ...string) (string, error) { - lines := []string{"## main"} - for i := 0; i < representativeChangedFilesLimit+2; i++ { - lines = append(lines, fmt.Sprintf(" M file%d.go", i)) - } - return strings.Join(lines, "\n"), nil - }, - readFile: readFile, - } + service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { + lines := []string{"## main"} + for i := 0; i < representativeChangedFilesLimit+2; i++ { + lines = append(lines, fmt.Sprintf(" M file%d.go", i)) + } + return strings.Join(lines, "\n"), nil + }) summary, err := service.Summary(context.Background(), t.TempDir()) if err != nil { t.Fatalf("Summary() err = %v", err) @@ -796,12 +796,9 @@ func TestRepositoryCoverageExtraBranches(t *testing.T) { t.Fatalf("ChangedFiles(canceled) err = %v", err) } - nonGitService := &Service{ - gitRunner: func(ctx context.Context, workdir string, args ...string) (string, error) { - return "fatal: not a git repository", errors.New("exit status 128") - }, - readFile: readFile, - } + nonGitService := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { + return "fatal: not a git repository", errors.New("exit status 128") + }) ctxResult, err := nonGitService.ChangedFiles(context.Background(), t.TempDir(), ChangedFilesOptions{}) if err != nil { t.Fatalf("ChangedFiles(non-git) err = %v", err) diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index 3164f865..09b5361c 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -13,12 +13,9 @@ import ( func TestSummaryReturnsStableEmptyForNonGitDirectory(t *testing.T) { t.Parallel() - service := &Service{ - gitRunner: func(ctx context.Context, workdir string, args ...string) (string, error) { - return "fatal: not a git repository", errors.New("exit status 128") - }, - readFile: readFile, - } + service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { + return "fatal: not a git repository", errors.New("exit status 128") + }) summary, err := service.Summary(context.Background(), t.TempDir()) if err != nil { @@ -32,17 +29,14 @@ func TestSummaryReturnsStableEmptyForNonGitDirectory(t *testing.T) { func TestSummaryParsesBranchDirtyAheadBehindAndRepresentativeFiles(t *testing.T) { t.Parallel() - service := &Service{ - gitRunner: func(ctx context.Context, workdir string, args ...string) (string, error) { - return strings.Join([]string{ - "## feature/repository...origin/feature/repository [ahead 2, behind 1]", - " M internal/context/source_system.go", - "R old/name.go -> new/name.go", - "?? internal/repository/service.go", - }, "\n"), nil - }, - readFile: readFile, - } + service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { + return strings.Join([]string{ + "## feature/repository...origin/feature/repository [ahead 2, behind 1]", + " M internal/context/source_system.go", + "R old/name.go -> new/name.go", + "?? internal/repository/service.go", + }, "\n"), nil + }) summary, err := service.Summary(context.Background(), t.TempDir()) if err != nil { @@ -86,32 +80,28 @@ func TestChangedFilesRespectsStatusNormalizationAndSnippetRules(t *testing.T) { t.Fatalf("WriteFile() error = %v", err) } - service := &Service{ - gitRunner: func(ctx context.Context, dir string, args ...string) (string, error) { - command := strings.Join(args, " ") - switch command { - case "status --porcelain=v1 --branch --untracked-files=normal": - return strings.Join([]string{ - "## main...origin/main [ahead 1]", - " M pkg/changed.go", - "A pkg/new.go", - "?? pkg/untracked.go", - "D pkg/deleted.go", - "R pkg/old.go -> pkg/renamed.go", - "UU pkg/conflicted.go", - }, "\n"), nil - case "diff --unified=3 HEAD -- pkg/changed.go": - return "@@ -1,1 +1,1 @@\n-func Old() {}\n+func Changed() {}\n", nil - case "diff --unified=3 HEAD -- pkg/new.go": - return "@@ -0,0 +1,3 @@\n+package pkg\n+\n+func Added() {}\n", nil - case "diff --unified=3 HEAD -- pkg/renamed.go": - return "", nil - default: - return "", nil - } - }, - readFile: readFile, - } + service := newTestService(func(ctx context.Context, dir string, args ...string) (string, error) { + switch strings.Join(args, " ") { + case "status --porcelain=v1 --branch --untracked-files=normal": + return strings.Join([]string{ + "## main...origin/main [ahead 1]", + " M pkg/changed.go", + "A pkg/new.go", + "?? pkg/untracked.go", + "D pkg/deleted.go", + "R pkg/old.go -> pkg/renamed.go", + "UU pkg/conflicted.go", + }, "\n"), nil + case "diff --unified=3 HEAD -- pkg/changed.go": + return "@@ -1,1 +1,1 @@\n-func Old() {}\n+func Changed() {}\n", nil + case "diff --unified=3 HEAD -- pkg/new.go": + return "@@ -0,0 +1,3 @@\n+package pkg\n+\n+func Added() {}\n", nil + case "diff --unified=3 HEAD -- pkg/renamed.go": + return "", nil + default: + return "", nil + } + }) ctx, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ IncludeSnippets: true, @@ -133,16 +123,13 @@ func TestChangedFilesRespectsStatusNormalizationAndSnippetRules(t *testing.T) { func TestChangedFilesAppliesLimitAndTruncation(t *testing.T) { t.Parallel() - service := &Service{ - gitRunner: func(ctx context.Context, workdir string, args ...string) (string, error) { - lines := []string{"## main"} - for i := 0; i < 60; i++ { - lines = append(lines, " M file"+strconv.Itoa(i)+".go") - } - return strings.Join(lines, "\n"), nil - }, - readFile: readFile, - } + service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { + lines := []string{"## main"} + for i := 0; i < 60; i++ { + lines = append(lines, " M file"+strconv.Itoa(i)+".go") + } + return strings.Join(lines, "\n"), nil + }) result, err := service.ChangedFiles(context.Background(), t.TempDir(), ChangedFilesOptions{}) if err != nil { @@ -159,24 +146,20 @@ func TestChangedFilesAppliesLimitAndTruncation(t *testing.T) { func TestChangedFilesMarksTruncatedWhenSingleSnippetExceedsLineLimit(t *testing.T) { t.Parallel() - service := &Service{ - gitRunner: func(ctx context.Context, workdir string, args ...string) (string, error) { - command := strings.Join(args, " ") - switch command { - case "status --porcelain=v1 --branch --untracked-files=normal": - return "## main\n M pkg/long.go\n", nil - case "diff --unified=3 HEAD -- pkg/long.go": - lines := []string{"@@ -1,1 +1,25 @@"} - for i := 0; i < 25; i++ { - lines = append(lines, "+line "+strconv.Itoa(i)) - } - return strings.Join(lines, "\n"), nil - default: - return "", nil + service := newTestService(func(ctx context.Context, workdir string, args ...string) (string, error) { + switch strings.Join(args, " ") { + case "status --porcelain=v1 --branch --untracked-files=normal": + return "## main\n M pkg/long.go\n", nil + case "diff --unified=3 HEAD -- pkg/long.go": + lines := []string{"@@ -1,1 +1,25 @@"} + for i := 0; i < 25; i++ { + lines = append(lines, "+line "+strconv.Itoa(i)) } - }, - readFile: readFile, - } + return strings.Join(lines, "\n"), nil + default: + return "", nil + } + }) result, err := service.ChangedFiles(context.Background(), t.TempDir(), ChangedFilesOptions{ IncludeSnippets: true, @@ -210,15 +193,12 @@ func TestChangedFilesMarksTruncatedWhenTotalSnippetBudgetExceeded(t *testing.T) statusLines = append(statusLines, "?? "+filepath.ToSlash(fileName)) } - service := &Service{ - gitRunner: func(ctx context.Context, dir string, args ...string) (string, error) { - if strings.Join(args, " ") == "status --porcelain=v1 --branch --untracked-files=normal" { - return strings.Join(statusLines, "\n"), nil - } - return "", nil - }, - readFile: readFile, - } + service := newTestService(func(ctx context.Context, dir string, args ...string) (string, error) { + if strings.Join(args, " ") == "status --porcelain=v1 --branch --untracked-files=normal" { + return strings.Join(statusLines, "\n"), nil + } + return "", nil + }) result, err := service.ChangedFiles(context.Background(), workdir, ChangedFilesOptions{ IncludeSnippets: true, @@ -345,6 +325,58 @@ func TestRetrieveSymbolFallsBackToWholeWordTextSearch(t *testing.T) { } } +func TestRetrieveSkipsSensitiveLargeAndBinaryFiles(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + mustWriteFile(t, filepath.Join(workdir, ".env"), "API_KEY=secret\n") + mustWriteFile(t, filepath.Join(workdir, "pkg", "notes.key"), "private") + mustWriteFile(t, filepath.Join(workdir, "pkg", "bin.dat"), string([]byte{0x00, 0x01, 0x02, 0x03})) + mustWriteFile(t, filepath.Join(workdir, "pkg", "target.txt"), "match line\n") + + largeContent := strings.Repeat("x", maxRetrievalFileBytes+1) + mustWriteFile(t, filepath.Join(workdir, "pkg", "large.txt"), largeContent) + + service := NewService() + + pathHits, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ + Mode: RetrievalModePath, + Value: ".env", + }) + if err != nil { + t.Fatalf("Retrieve(path sensitive) error = %v", err) + } + if len(pathHits) != 0 { + t.Fatalf("expected sensitive path retrieval to be filtered, got %+v", pathHits) + } + + textHits, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ + Mode: RetrievalModeText, + Value: "match", + Limit: 10, + }) + if err != nil { + t.Fatalf("Retrieve(text) error = %v", err) + } + if len(textHits) != 1 || textHits[0].Path != filepath.Clean("pkg/target.txt") { + t.Fatalf("expected only safe text hit, got %+v", textHits) + } + + globHits, err := service.Retrieve(context.Background(), workdir, RetrievalQuery{ + Mode: RetrievalModeGlob, + Value: "pkg/*", + Limit: 10, + }) + if err != nil { + t.Fatalf("Retrieve(glob) error = %v", err) + } + for _, hit := range globHits { + if hit.Path == filepath.Clean("pkg/large.txt") || hit.Path == filepath.Clean("pkg/notes.key") || hit.Path == filepath.Clean("pkg/bin.dat") { + t.Fatalf("expected filtered file to be excluded, got %+v", globHits) + } + } +} + func assertChangedFile(t *testing.T, file ChangedFile, path string, oldPath string, status ChangedFileStatus, snippetContains string) { t.Helper() if file.Path != path || file.OldPath != oldPath || file.Status != status { @@ -370,3 +402,10 @@ func mustWriteFile(t *testing.T, path string, content string) { t.Fatalf("WriteFile() error = %v", err) } } + +func newTestService(gitRunner func(ctx context.Context, workdir string, args ...string) (string, error)) *Service { + return &Service{ + gitRunner: gitRunner, + readFile: readFile, + } +} diff --git a/internal/repository/retrieve.go b/internal/repository/retrieve.go index 3869baf2..ba228fcb 100644 --- a/internal/repository/retrieve.go +++ b/internal/repository/retrieve.go @@ -1,6 +1,7 @@ package repository import ( + "bytes" "context" "io/fs" "os" @@ -16,8 +17,21 @@ const ( defaultContextLines = 3 maxContextLines = 8 maxSnippetLines = 20 + maxRetrievalFileBytes = 256 * 1024 + binaryProbePrefixSize = 1024 ) +var blockedSensitiveExtensions = map[string]struct{}{ + ".key": {}, + ".pem": {}, + ".p12": {}, + ".pfx": {}, + ".jks": {}, + ".der": {}, + ".cer": {}, + ".crt": {}, +} + // retrieveByPath 按路径读取目标文件的受限片段。 func (s *Service) retrieveByPath(ctx context.Context, root string, query RetrievalQuery) ([]RetrievalHit, error) { if err := ctx.Err(); err != nil { @@ -27,6 +41,9 @@ func (s *Service) retrieveByPath(ctx context.Context, root string, query Retriev if err != nil { return nil, err } + if !allowRetrievalReadByPath(target) { + return []RetrievalHit{}, nil + } content, err := s.readFile(target) if err != nil { if os.IsNotExist(err) { @@ -34,19 +51,15 @@ func (s *Service) retrieveByPath(ctx context.Context, root string, query Retriev } return nil, err } + if isBinaryContent(content) { + return []RetrievalHit{}, nil + } - snippet, lineHint := snippetAroundLine(string(content), 1, query.ContextLines) - relativePath, err := filepath.Rel(root, target) + hit, err := buildRetrievalHit(root, target, RetrievalModePath, query.Value, string(content), 1, query.ContextLines) if err != nil { return nil, err } - return []RetrievalHit{{ - Path: filepath.Clean(relativePath), - Kind: string(RetrievalModePath), - SymbolOrQuery: query.Value, - Snippet: snippet, - LineHint: lineHint, - }}, nil + return []RetrievalHit{hit}, nil } // retrieveByGlob 按 glob 模式在工作区内定位候选文件。 @@ -56,7 +69,10 @@ func (s *Service) retrieveByGlob(ctx context.Context, root string, scope string, } hits := make([]RetrievalHit, 0, query.Limit) - err := walkWorkspaceFiles(root, scope, func(path string, entry fs.DirEntry) error { + err := walkWorkspaceFiles(ctx, root, scope, func(path string, entry fs.DirEntry) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } if len(hits) >= query.Limit { return nil } @@ -77,23 +93,15 @@ func (s *Service) retrieveByGlob(ctx context.Context, root string, scope string, if !match { return nil } - - content, readErr := s.readFile(path) - if readErr != nil { + content, ok := s.readRetrievalText(path, entry) + if !ok { return nil } - snippet, lineHint := snippetAroundLine(string(content), 1, query.ContextLines) - relative, relErr := filepath.Rel(root, path) - if relErr != nil { - return relErr + hit, hitErr := buildRetrievalHit(root, path, RetrievalModeGlob, query.Value, content, 1, query.ContextLines) + if hitErr != nil { + return hitErr } - hits = append(hits, RetrievalHit{ - Path: filepath.Clean(relative), - Kind: string(RetrievalModeGlob), - SymbolOrQuery: query.Value, - Snippet: snippet, - LineHint: lineHint, - }) + hits = append(hits, hit) return nil }) if err != nil { @@ -118,18 +126,22 @@ func (s *Service) retrieveByText(ctx context.Context, root string, scope string, } hits := make([]RetrievalHit, 0, query.Limit) - err := walkWorkspaceFiles(root, scope, func(path string, entry fs.DirEntry) error { + err := walkWorkspaceFiles(ctx, root, scope, func(path string, entry fs.DirEntry) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } if len(hits) >= query.Limit { return nil } - contentBytes, readErr := s.readFile(path) - if readErr != nil { + content, ok := s.readRetrievalText(path, entry) + if !ok { return nil } - - content := string(contentBytes) lines := strings.Split(strings.ReplaceAll(content, "\r\n", "\n"), "\n") for index, line := range lines { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } if len(hits) >= query.Limit { break } @@ -141,18 +153,11 @@ func (s *Service) retrieveByText(ctx context.Context, root string, scope string, continue } - snippet, lineHint := snippetAroundLine(content, index+1, query.ContextLines) - relative, relErr := filepath.Rel(root, path) - if relErr != nil { - return relErr + hit, hitErr := buildRetrievalHit(root, path, RetrievalModeText, query.Value, content, index+1, query.ContextLines) + if hitErr != nil { + return hitErr } - hits = append(hits, RetrievalHit{ - Path: filepath.Clean(relative), - Kind: string(RetrievalModeText), - SymbolOrQuery: query.Value, - Snippet: snippet, - LineHint: lineHint, - }) + hits = append(hits, hit) } return nil }) @@ -171,36 +176,33 @@ func (s *Service) retrieveBySymbol(ctx context.Context, root string, scope strin } hits := make([]RetrievalHit, 0, query.Limit) - err := walkWorkspaceFiles(root, scope, func(path string, entry fs.DirEntry) error { + err := walkWorkspaceFiles(ctx, root, scope, func(path string, entry fs.DirEntry) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } if len(hits) >= query.Limit { return nil } if filepath.Ext(path) != ".go" { return nil } - - contentBytes, readErr := s.readFile(path) - if readErr != nil { + content, ok := s.readRetrievalText(path, entry) + if !ok { return nil } - content := string(contentBytes) lineNumbers := findGoSymbolDefinitions(content, query.Value) for _, lineNumber := range lineNumbers { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } if len(hits) >= query.Limit { break } - snippet, lineHint := snippetAroundLine(content, lineNumber, query.ContextLines) - relative, relErr := filepath.Rel(root, path) - if relErr != nil { - return relErr + hit, hitErr := buildRetrievalHit(root, path, RetrievalModeSymbol, query.Value, content, lineNumber, query.ContextLines) + if hitErr != nil { + return hitErr } - hits = append(hits, RetrievalHit{ - Path: filepath.Clean(relative), - Kind: string(RetrievalModeSymbol), - SymbolOrQuery: query.Value, - Snippet: snippet, - LineHint: lineHint, - }) + hits = append(hits, hit) } return nil }) @@ -272,6 +274,98 @@ func sortRetrievalHits(hits []RetrievalHit) { }) } +// readRetrievalText 读取并过滤检索候选文件,失败时按“无命中”处理。 +func (s *Service) readRetrievalText(path string, entry fs.DirEntry) (string, bool) { + if !allowRetrievalReadByEntry(path, entry) { + return "", false + } + content, err := s.readFile(path) + if err != nil || isBinaryContent(content) { + return "", false + } + return string(content), true +} + +// buildRetrievalHit 基于命中文件和行号构造统一格式的检索结果。 +func buildRetrievalHit( + root string, + path string, + mode RetrievalMode, + query string, + content string, + lineNumber int, + contextLines int, +) (RetrievalHit, error) { + relativePath, err := filepath.Rel(root, path) + if err != nil { + return RetrievalHit{}, err + } + snippet, lineHint := snippetAroundLine(content, lineNumber, contextLines) + return RetrievalHit{ + Path: filepath.Clean(relativePath), + Kind: string(mode), + SymbolOrQuery: query, + Snippet: snippet, + LineHint: lineHint, + }, nil +} + func readFile(path string) ([]byte, error) { return os.ReadFile(path) } + +// allowRetrievalReadByPath 校验路径模式下目标文件是否允许读取。 +func allowRetrievalReadByPath(path string) bool { + info, err := os.Stat(path) + if err != nil || info.IsDir() { + return false + } + return allowRetrievalByNameAndSize(filepath.Base(path), info.Size()) +} + +// allowRetrievalReadByEntry 校验遍历模式下命中文件是否允许读取。 +func allowRetrievalReadByEntry(path string, entry fs.DirEntry) bool { + info, err := entry.Info() + if err != nil || info.IsDir() { + return false + } + return allowRetrievalByNameAndSize(filepath.Base(path), info.Size()) +} + +// allowRetrievalByNameAndSize 基于文件名和大小过滤敏感文件与高成本文件。 +func allowRetrievalByNameAndSize(name string, size int64) bool { + if size < 0 || size > maxRetrievalFileBytes { + return false + } + lowerName := strings.ToLower(strings.TrimSpace(name)) + if lowerName == "" { + return false + } + if lowerName == ".env" || strings.HasPrefix(lowerName, ".env.") { + return false + } + if _, blocked := blockedSensitiveExtensions[filepath.Ext(lowerName)]; blocked { + return false + } + return true +} + +// isBinaryContent 通过前缀字节判断文件是否为二进制内容。 +func isBinaryContent(content []byte) bool { + if len(content) == 0 { + return false + } + prefixBytes := content + if len(prefixBytes) > binaryProbePrefixSize { + prefixBytes = prefixBytes[:binaryProbePrefixSize] + } + if bytes.IndexByte(prefixBytes, 0x00) >= 0 { + return true + } + for _, b := range prefixBytes { + if b < 0x09 { + return true + } + } + return false +}