diff --git a/go.mod b/go.mod index 1a60a521..10b3edd5 100644 --- a/go.mod +++ b/go.mod @@ -76,6 +76,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 6fb8e9e6..8e0e5ad8 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -10,6 +10,7 @@ import ( tea "github.com/charmbracelet/bubbletea" + "neo-code/internal/checkpoint" "neo-code/internal/config" configstate "neo-code/internal/config/state" agentcontext "neo-code/internal/context" @@ -230,8 +231,34 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime )) } + // Checkpoint 基础设施:SQLite + per-edit 版本化文件历史(不依赖 git)。 + // 优先复用 sessionStore 已打开的 *sql.DB;冷启动尚未建连时显式初始化, + // 避免 sessionStore.DB() 为 nil 时整条 checkpoint 链路被静默跳过。 + sessionDB := sessionStore.DB() + if sessionDB == nil { + if initDB, initErr := sessionStore.InitDB(ctx); initErr == nil { + sessionDB = initDB + } + } + var checkpointStore *checkpoint.SQLiteCheckpointStore + if sessionDB != nil { + checkpointStore = checkpoint.NewSQLiteCheckpointStoreWithDB(sessionDB) + projectDir := agentsession.HashWorkspaceRoot(cfg.Workdir) + snapshotRoot := filepath.Join(sharedDeps.ConfigManager.BaseDir(), "projects", projectDir) + perEditStore := checkpoint.NewPerEditSnapshotStore(snapshotRoot, cfg.Workdir) + runtimeSvc.SetCheckpointDependencies(checkpointStore, perEditStore) + } + // 启动时修复残留的 creating 状态 checkpoint + if checkpointStore != nil { + if repaired, err := checkpointStore.RepairCreatingCheckpoints(ctx); err != nil { + log.Printf("checkpoint repair warning: %v", err) + } else if repaired > 0 { + 
log.Printf("checkpoint repair: fixed %d stale checkpoints", repaired) + } + } + runtimeImpl := agentruntime.Runtime(runtimeSvc) - closeFns := []func() error{toolsCleanup, sessionStore.Close} + closeFns := []func() error{toolsCleanup, checkpointStore.Close, sessionStore.Close} needCleanup = false @@ -411,6 +438,11 @@ func buildToolRegistry(cfg config.Config) (*tools.Registry, func() error, error) toolRegistry.Register(filesystem.NewGrep(cfg.Workdir)) toolRegistry.Register(filesystem.NewGlob(cfg.Workdir)) toolRegistry.Register(filesystem.NewEdit(cfg.Workdir)) + toolRegistry.Register(filesystem.NewMove(cfg.Workdir)) + toolRegistry.Register(filesystem.NewCopy(cfg.Workdir)) + toolRegistry.Register(filesystem.NewDelete(cfg.Workdir)) + toolRegistry.Register(filesystem.NewCreateDir(cfg.Workdir)) + toolRegistry.Register(filesystem.NewRemoveDir(cfg.Workdir)) toolRegistry.Register(bash.New(cfg.Workdir, cfg.Shell, time.Duration(cfg.ToolTimeoutSec)*time.Second)) toolRegistry.Register(diagnosetool.New()) toolRegistry.Register(webfetch.New(webfetch.Config{ diff --git a/internal/checkpoint/bash_capture.go b/internal/checkpoint/bash_capture.go new file mode 100644 index 00000000..b64bfe5e --- /dev/null +++ b/internal/checkpoint/bash_capture.go @@ -0,0 +1,208 @@ +package checkpoint + +import ( + "path/filepath" + "regexp" + "strings" +) + +var ( + bashWriteRedirectRE = regexp.MustCompile(`(^|[^&\d])>{1,2}\s*[^\s&>]`) + bashSedInplaceRE = regexp.MustCompile(`\bsed\b[^|;&]*?\s(-i|-i\.[^\s]+|--in-place)`) + bashAwkInplaceRE = regexp.MustCompile(`\bawk\b[^|;&]*?-i\b`) + bashGitWriteRE = regexp.MustCompile(`\bgit\s+(checkout|restore|reset|apply|pull|merge|rebase|am|cherry-pick|revert|commit|add|rm|mv|stash|clean)\b`) + bashPkgManagerRE = regexp.MustCompile(`\b(npm|yarn|pnpm|bower)\s+(install|i|add|remove|uninstall|i)\b`) + bashPipInstallRE = regexp.MustCompile(`\bpip\s*\d*\s+(install|uninstall)\b`) + bashGoInstallRE = 
regexp.MustCompile(`\bgo\s+(get|install|mod\s+(download|tidy|vendor)|generate)\b`) + bashCargoRE = regexp.MustCompile(`\bcargo\s+(install|add|remove|update|build|fetch|generate)\b`) + bashArchiveRE = regexp.MustCompile(`\b(unzip|gunzip|bunzip2|tar)\b`) + bashFindDeleteRE = regexp.MustCompile(`\bfind\b[^|;&]*?(-delete|-exec\s+rm)`) + bashTeeRE = regexp.MustCompile(`\btee\b`) + bashShellSplitRE = regexp.MustCompile(`[;&|<>()\s{}` + "`" + `]+`) +) + +// bashWriteCommands lists single-word commands that mutate files when invoked. +var bashWriteCommands = []string{ + "mv", "cp", "rm", "touch", "mkdir", "rmdir", "ln", "chmod", "chown", + "dd", "patch", "install", "rsync", "shred", "truncate", "trash", +} + +var bashWriteCommandRE = regexp.MustCompile(`\b(` + strings.Join(bashWriteCommands, "|") + `)\b`) + +// BashLikelyWritesFiles 基于启发式判断 bash 命令是否可能写文件。 +// 设计偏保守:宁可多 capture(返回 true),也不漏(false 时由 fingerprint 兜底)。 +// 仅在能明确判定为只读时返回 false。 +func BashLikelyWritesFiles(command string) bool { + cmd := strings.TrimSpace(command) + if cmd == "" { + return false + } + sanitized := stripHarmlessRedirects(cmd) + if bashWriteRedirectRE.MatchString(sanitized) { + return true + } + lower := strings.ToLower(sanitized) + if bashWriteCommandRE.MatchString(lower) { + return true + } + if bashSedInplaceRE.MatchString(lower) { + return true + } + if bashAwkInplaceRE.MatchString(lower) { + return true + } + if bashGitWriteRE.MatchString(lower) { + return true + } + if bashPkgManagerRE.MatchString(lower) { + return true + } + if bashPipInstallRE.MatchString(lower) { + return true + } + if bashGoInstallRE.MatchString(lower) { + return true + } + if bashCargoRE.MatchString(lower) { + return true + } + if bashTeeRE.MatchString(lower) { + return true + } + if bashFindDeleteRE.MatchString(lower) { + return true + } + if bashArchiveRE.MatchString(lower) && bashHasArchiveExtractFlag(lower) { + return true + } + return false +} + +// SourceFilesInWorkdir 从命令中尝试提取 workdir 内的文件路径(保守估计)。 +// 
仅匹配看起来像源代码/配置/文本的扩展名,返回的路径可能不准确(启发式),由 fingerprint 兜底。 +func SourceFilesInWorkdir(command, workdir string) []string { + if strings.TrimSpace(command) == "" { + return nil + } + tokens := tokenizeBashArgs(command) + seen := make(map[string]struct{}) + out := make([]string, 0, len(tokens)) + for _, tok := range tokens { + tok = strings.Trim(tok, `"'`) + if tok == "" { + continue + } + if !hasRecognizedSourceExt(tok) { + continue + } + abs := resolvePathAgainstWorkdir(tok, workdir) + if abs == "" { + continue + } + if _, dup := seen[abs]; dup { + continue + } + seen[abs] = struct{}{} + out = append(out, abs) + } + if len(out) == 0 { + return nil + } + return out +} + +func tokenizeBashArgs(cmd string) []string { + parts := bashShellSplitRE.Split(cmd, -1) + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + p = strings.Trim(p, `"'`) + if p == "" { + continue + } + out = append(out, p) + } + return out +} + +func resolvePathAgainstWorkdir(p, workdir string) string { + if strings.ContainsAny(p, "*?[") { + return "" + } + workdirClean := filepath.Clean(strings.TrimSpace(workdir)) + var abs string + if filepath.IsAbs(p) { + abs = filepath.Clean(p) + } else { + if workdirClean == "" || workdirClean == "." { + return "" + } + abs = filepath.Clean(filepath.Join(workdirClean, p)) + } + if workdirClean == "" || workdirClean == "." { + return abs + } + rel, err := filepath.Rel(workdirClean, abs) + if err != nil { + return "" + } + if rel == ".." 
|| strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "" + } + return abs +} + +func stripHarmlessRedirects(cmd string) string { + r := strings.NewReplacer( + "2>&1", "", + "1>&2", "", + ">&2", "", + ">&-", "", + "<&-", "", + "&>&-", "", + ) + return r.Replace(cmd) +} + +func bashHasArchiveExtractFlag(lower string) bool { + if strings.Contains(lower, "unzip") || strings.Contains(lower, "gunzip") || strings.Contains(lower, "bunzip2") { + return true + } + if !strings.Contains(lower, "tar") { + return false + } + for _, marker := range []string{" -x", " --extract", "tar x", "-xf", "-xv", "-xz", "-xj", "-xJ", "xvf"} { + if strings.Contains(lower, marker) { + return true + } + } + return false +} + +var bashSourceExts = map[string]struct{}{ + ".go": {}, ".rs": {}, ".py": {}, ".js": {}, ".jsx": {}, ".ts": {}, ".tsx": {}, + ".java": {}, ".c": {}, ".cpp": {}, ".cc": {}, ".cxx": {}, ".h": {}, ".hpp": {}, ".hxx": {}, + ".rb": {}, ".php": {}, ".swift": {}, ".kt": {}, ".scala": {}, ".groovy": {}, + ".md": {}, ".rst": {}, ".txt": {}, + ".json": {}, ".yaml": {}, ".yml": {}, ".toml": {}, ".ini": {}, ".conf": {}, ".cfg": {}, ".properties": {}, + ".html": {}, ".htm": {}, ".xml": {}, ".css": {}, ".scss": {}, ".sass": {}, ".less": {}, + ".vue": {}, ".svelte": {}, ".astro": {}, + ".sh": {}, ".bash": {}, ".zsh": {}, ".fish": {}, ".ps1": {}, + ".sql": {}, ".graphql": {}, ".gql": {}, ".proto": {}, + ".csv": {}, ".tsv": {}, ".log": {}, + ".env": {}, ".lock": {}, +} + +func hasRecognizedSourceExt(p string) bool { + ext := strings.ToLower(filepath.Ext(p)) + if ext == "" { + base := strings.ToLower(filepath.Base(p)) + switch base { + case "dockerfile", "makefile", ".gitignore", ".dockerignore", ".env": + return true + } + return false + } + _, ok := bashSourceExts[ext] + return ok +} diff --git a/internal/checkpoint/bash_capture_test.go b/internal/checkpoint/bash_capture_test.go new file mode 100644 index 00000000..ecfbf34f --- /dev/null +++ 
b/internal/checkpoint/bash_capture_test.go @@ -0,0 +1,222 @@ +package checkpoint + +import ( + "path/filepath" + "sort" + "testing" +) + +// TestBashLikelyWritesFiles_PositiveCases: commands that mutate files must return true. +func TestBashLikelyWritesFiles_PositiveCases(t *testing.T) { + cases := []string{ + `echo hello > out.txt`, + `echo more >> out.txt`, + `cat src.go > dst.go`, + `mv old.go new.go`, + `cp src.go dst.go`, + `rm stale.txt`, + `rm -rf build/`, + `touch new.go`, + `mkdir -p pkg/foo`, + `rmdir empty`, + `ln -s a b`, + `chmod +x script.sh`, + `chown user:group file`, + `patch -p1 < change.patch`, + `rsync -av src/ dst/`, + `sed -i 's/foo/bar/g' main.go`, + `sed -i.bak 's/foo/bar/' main.go`, + `sed --in-place 's/x/y/' f.txt`, + `awk -i inplace '{print}' f.txt`, + `git checkout main`, + `git restore --staged file.go`, + `git reset --hard HEAD`, + `git apply patch.diff`, + `git pull origin main`, + `git merge feature`, + `git rebase main`, + `git cherry-pick abc123`, + `git revert HEAD`, + `git commit -m "x"`, + `git add .`, + `git rm old.go`, + `git mv a b`, + `git stash`, + `git clean -fd`, + `npm install`, + `npm i lodash`, + `yarn add react`, + `pnpm install`, + `pnpm add foo`, + `pip install requests`, + `pip3 install -r requirements.txt`, + `go get github.com/x/y`, + `go install ./cmd/x`, + `go mod tidy`, + `go mod download`, + `go mod vendor`, + `go generate ./...`, + `cargo install ripgrep`, + `cargo build`, + `cargo update`, + `unzip archive.zip`, + `tar -xzf bundle.tar.gz`, + `tar xvf bundle.tar`, + `gunzip data.gz`, + `bunzip2 data.bz2`, + `find . -name '*.tmp' -delete`, + `find . -type f -exec rm {} \;`, + `echo content | tee out.txt`, + `dd if=/dev/zero of=disk.img bs=1M count=100`, + `truncate -s 0 log.txt`, + } + for _, cmd := range cases { + if !BashLikelyWritesFiles(cmd) { + t.Errorf("expected write=true for %q", cmd) + } + } +} + +// TestBashLikelyWritesFiles_NegativeCases: read-only commands must return false. 
+func TestBashLikelyWritesFiles_NegativeCases(t *testing.T) { + cases := []string{ + ``, + ` `, + `ls`, + `ls -la`, + `pwd`, + `cat file.txt`, + `cat file.txt 2>&1`, + `head -20 file.go`, + `tail -f log.txt`, + `grep -r foo .`, + `grep -n bar file.go`, + `find . -name '*.go'`, + `find . -type f`, + `git status`, + `git log --oneline`, + `git diff main`, + `git show HEAD`, + `git branch`, + `git remote -v`, + `go version`, + `go env`, + `go test ./...`, + `go vet ./...`, + `go build ./...`, + `echo hello`, + `printf "x"`, + `which bash`, + `whoami`, + `uname -a`, + `ps aux`, + `df -h`, + `du -sh .`, + `wc -l main.go`, + `sort file.txt`, + `uniq file.txt`, + `diff a.txt b.txt`, + `stat file.go`, + `file binary`, + `echo done 2>&1 1>&2`, + // echo with stderr-only redirection should remain read-only after stripHarmlessRedirects + `some_cmd >&2`, + } + for _, cmd := range cases { + if BashLikelyWritesFiles(cmd) { + t.Errorf("expected write=false for %q", cmd) + } + } +} + +// TestSourceFilesInWorkdir_ExtractsFromCommonPatterns: paths inside workdir with recognized +// extensions are returned, paths outside or with unknown extensions are filtered. +func TestSourceFilesInWorkdir_ExtractsFromCommonPatterns(t *testing.T) { + root := t.TempDir() + + // Helper to convert a workdir-relative slash path into the platform abs path. 
+ abs := func(rel string) string { + return filepath.Clean(filepath.Join(root, filepath.FromSlash(rel))) + } + + type tc struct { + name string + command string + want []string + } + + cases := []tc{ + { + name: "simple_relative_paths", + command: `mv pkg/a.go pkg/b.go`, + want: []string{abs("pkg/a.go"), abs("pkg/b.go")}, + }, + { + name: "absolute_paths_inside_workdir", + command: `cp ` + abs("src/main.go") + ` ` + abs("src/main_copy.go"), + want: []string{abs("src/main.go"), abs("src/main_copy.go")}, + }, + { + name: "redirect_target", + command: `echo hello > notes.md`, + want: []string{abs("notes.md")}, + }, + { + name: "deduplicates_repeated_paths", + command: `cp main.go main.go`, + want: []string{abs("main.go")}, + }, + { + name: "filters_unknown_extensions", + command: `mv binary.bin other.exe`, + want: nil, + }, + { + name: "filters_paths_outside_workdir", + command: `cat ../escape.go ../../further.go > out.log`, + want: []string{abs("out.log")}, + }, + { + name: "ignores_glob_arguments", + command: `rm *.go pkg/*.json`, + want: nil, + }, + { + name: "extracts_yaml_and_toml_paths", + command: `sed -i 's/x/y/g' config.yaml settings.toml`, + want: []string{abs("config.yaml"), abs("settings.toml")}, + }, + { + name: "empty_command_returns_nil", + command: ``, + want: nil, + }, + { + name: "whitespace_command_returns_nil", + command: ` `, + want: nil, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := SourceFilesInWorkdir(c.command, root) + gotSorted := append([]string(nil), got...) + sort.Strings(gotSorted) + wantSorted := append([]string(nil), c.want...) + sort.Strings(wantSorted) + if !equalStringSlice(gotSorted, wantSorted) { + t.Fatalf("got %v want %v", got, c.want) + } + }) + } +} + +// TestSourceFilesInWorkdir_HandlesEmptyWorkdir: with empty workdir we cannot compute +// safe relative paths, so the function should return nil for relative inputs. 
+func TestSourceFilesInWorkdir_HandlesEmptyWorkdir(t *testing.T) { + got := SourceFilesInWorkdir(`mv a.go b.go`, "") + if got != nil { + t.Fatalf("expected nil with empty workdir, got %v", got) + } +} diff --git a/internal/checkpoint/checkpoint_manager.go b/internal/checkpoint/checkpoint_manager.go new file mode 100644 index 00000000..d8d4a950 --- /dev/null +++ b/internal/checkpoint/checkpoint_manager.go @@ -0,0 +1,660 @@ +package checkpoint + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "reflect" + "sync" + "time" + + providertypes "neo-code/internal/provider/types" + "neo-code/internal/session" +) + +// CheckpointStore 定义 checkpoint 持久化的意图型接口。 +type CheckpointStore interface { + CreateCheckpoint(ctx context.Context, input CreateCheckpointInput) (session.CheckpointRecord, error) + ListCheckpoints(ctx context.Context, sessionID string, opts ListCheckpointOpts) ([]session.CheckpointRecord, error) + GetCheckpoint(ctx context.Context, checkpointID string) (session.CheckpointRecord, *session.SessionCheckpoint, error) + UpdateCheckpointStatus(ctx context.Context, checkpointID string, status session.CheckpointStatus) error + GetLatestResumeCheckpoint(ctx context.Context, sessionID string) (*session.ResumeCheckpoint, error) + RestoreCheckpoint(ctx context.Context, input RestoreCheckpointInput) error + SetResumeCheckpoint(ctx context.Context, rc session.ResumeCheckpoint) error + PruneExpiredCheckpoints(ctx context.Context, sessionID string, maxAutoKeep int) (int, error) + RepairCreatingCheckpoints(ctx context.Context) (int, error) +} + +// CreateCheckpointInput 描述一次 checkpoint 创建的完整输入。 +type CreateCheckpointInput struct { + Record session.CheckpointRecord + SessionCP session.SessionCheckpoint +} + +// ListCheckpointOpts 描述 checkpoint 列表查询选项。 +type ListCheckpointOpts struct { + Limit int + RestorableOnly bool +} + +// RestoreCheckpointInput 描述一次 restore 操作的完整输入。 +type RestoreCheckpointInput struct { + SessionID string + Head 
session.SessionHead + Messages []providertypes.Message + UpdatedAt time.Time + MarkAvailableIDs []string + MarkRestoredIDs []string +} + +// SQLiteCheckpointStore 基于 SQLite 实现 checkpoint 持久化。 +type SQLiteCheckpointStore struct { + dbPath string + initMu sync.Mutex + db *sql.DB + ownsDB bool // true 表示本实例打开的连接,Close 时需释放 +} + +// NewSQLiteCheckpointStore 创建 checkpoint 存储实例。 +// dbPath 为 session.db 文件路径,可通过 session.DatabasePath 获取。 +func NewSQLiteCheckpointStore(dbPath string) *SQLiteCheckpointStore { + return &SQLiteCheckpointStore{ + dbPath: dbPath, + } +} + +// NewSQLiteCheckpointStoreWithDB 创建 checkpoint 存储实例,复用已有的 *sql.DB 连接。 +// 适用于与 session store 共享同一数据库文件的场景,避免 Windows 上多连接文件锁定。 +// Close 不会关闭传入的连接,由调用方管理连接生命周期。 +func NewSQLiteCheckpointStoreWithDB(db *sql.DB) *SQLiteCheckpointStore { + return &SQLiteCheckpointStore{ + db: db, + ownsDB: false, + } +} + +// Close 释放数据库连接。仅当本实例拥有连接时(ownsDB=true)才实际关闭。 +func (s *SQLiteCheckpointStore) Close() error { + if s == nil || s.db == nil || !s.ownsDB { + return nil + } + return s.db.Close() +} + +func (s *SQLiteCheckpointStore) ensureDB(ctx context.Context) (*sql.DB, error) { + s.initMu.Lock() + defer s.initMu.Unlock() + if s.db != nil { + return s.db, nil + } + db, err := sql.Open("sqlite", s.dbPath) + if err != nil { + return nil, fmt.Errorf("checkpoint: open sqlite db: %w", err) + } + db.SetMaxOpenConns(2) + db.SetMaxIdleConns(2) + + pragmas := []string{ + `PRAGMA journal_mode=WAL`, + `PRAGMA synchronous=NORMAL`, + `PRAGMA foreign_keys=ON`, + `PRAGMA busy_timeout=5000`, + } + for _, pragma := range pragmas { + if _, err := db.ExecContext(ctx, pragma); err != nil { + _ = db.Close() + return nil, fmt.Errorf("checkpoint: apply pragma %q: %w", pragma, err) + } + } + s.db = db + s.ownsDB = true + return db, nil +} + +// CreateCheckpoint 在单一事务内写入 checkpoint record + session checkpoint。 +// 事务内完成 record INSERT → session_cp INSERT → record UPDATE(设置 session_checkpoint_ref + status=available)。 +func (s *SQLiteCheckpointStore) 
CreateCheckpoint(ctx context.Context, input CreateCheckpointInput) (session.CheckpointRecord, error) { + if err := ctx.Err(); err != nil { + return session.CheckpointRecord{}, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return session.CheckpointRecord{}, err + } + + record := input.Record + sessionCP := input.SessionCP + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: begin create tx: %w", err) + } + defer rollbackTx(tx) + + // INSERT checkpoint_record (status = creating) + _, err = tx.ExecContext(ctx, ` +INSERT INTO checkpoint_records ( + id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +`, + record.CheckpointID, + record.WorkspaceKey, + record.SessionID, + record.RunID, + record.Workdir, + toUnixMillis(record.CreatedAt), + string(record.Reason), + record.CodeCheckpointRef, + "", // session_checkpoint_ref filled below + record.ResumeCheckpointRef, + record.TranscriptRevision, + boolToInt(record.Restorable), + string(session.CheckpointStatusCreating), + ) + if err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: insert record %s: %w", record.CheckpointID, err) + } + + // INSERT session_checkpoint + _, err = tx.ExecContext(ctx, ` +INSERT INTO session_checkpoints (id, session_id, head_json, messages_json, created_at_ms) +VALUES (?, ?, ?, ?, ?) +`, + sessionCP.ID, + sessionCP.SessionID, + sessionCP.HeadJSON, + sessionCP.MessagesJSON, + toUnixMillis(sessionCP.CreatedAt), + ) + if err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: insert session cp %s: %w", sessionCP.ID, err) + } + + // UPDATE checkpoint_record: set session_checkpoint_ref + status = available + _, err = tx.ExecContext(ctx, ` +UPDATE checkpoint_records +SET session_checkpoint_ref = ?, status = ? 
+WHERE id = ? +`, + sessionCP.ID, + string(session.CheckpointStatusAvailable), + record.CheckpointID, + ) + if err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: update record %s: %w", record.CheckpointID, err) + } + + if err := tx.Commit(); err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: commit create %s: %w", record.CheckpointID, err) + } + + record.SessionCheckpointRef = sessionCP.ID + record.Status = session.CheckpointStatusAvailable + return record, nil +} + +// ListCheckpoints 查询指定会话的 checkpoint 记录列表。 +func (s *SQLiteCheckpointStore) ListCheckpoints(ctx context.Context, sessionID string, opts ListCheckpointOpts) ([]session.CheckpointRecord, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return nil, err + } + + query := ` +SELECT id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +FROM checkpoint_records +WHERE session_id = ? +` + args := []any{sessionID} + if opts.RestorableOnly { + query += ` AND restorable = 1 AND status = ?` + args = append(args, string(session.CheckpointStatusAvailable)) + } + query += ` ORDER BY created_at_ms DESC` + if opts.Limit > 0 { + query += ` LIMIT ?` + args = append(args, opts.Limit) + } + + rows, err := db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, fmt.Errorf("checkpoint: list checkpoints for %s: %w", sessionID, err) + } + defer rows.Close() + + var records []session.CheckpointRecord + for rows.Next() { + var r session.CheckpointRecord + var createdAtMS int64 + var reason, status string + if err := rows.Scan( + &r.CheckpointID, &r.WorkspaceKey, &r.SessionID, &r.RunID, &r.Workdir, &createdAtMS, + &reason, &r.CodeCheckpointRef, &r.SessionCheckpointRef, &r.ResumeCheckpointRef, + &r.TranscriptRevision, &r.Restorable, &status, + ); err != nil { + return nil, fmt.Errorf("checkpoint: scan record: %w", err) + } + r.CreatedAt = fromUnixMillis(createdAtMS) + r.Reason = session.CheckpointReason(reason) + r.Status = session.CheckpointStatus(status) + records = append(records, r) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("checkpoint: iterate records: %w", err) + } + return records, nil +} + +// GetCheckpoint 查询单条 checkpoint record 及其关联的 session checkpoint。 +func (s *SQLiteCheckpointStore) GetCheckpoint(ctx context.Context, checkpointID string) (session.CheckpointRecord, *session.SessionCheckpoint, error) { + if err := ctx.Err(); err != nil { + return session.CheckpointRecord{}, nil, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return session.CheckpointRecord{}, nil, err + } + + var r session.CheckpointRecord + var createdAtMS int64 + var reason, status string + err = db.QueryRowContext(ctx, ` +SELECT id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +FROM checkpoint_records +WHERE id = ? 
+`, checkpointID).Scan( + &r.CheckpointID, &r.WorkspaceKey, &r.SessionID, &r.RunID, &r.Workdir, &createdAtMS, + &reason, &r.CodeCheckpointRef, &r.SessionCheckpointRef, &r.ResumeCheckpointRef, + &r.TranscriptRevision, &r.Restorable, &status, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return session.CheckpointRecord{}, nil, fmt.Errorf("checkpoint: record %s not found", checkpointID) + } + return session.CheckpointRecord{}, nil, fmt.Errorf("checkpoint: query record %s: %w", checkpointID, err) + } + r.CreatedAt = fromUnixMillis(createdAtMS) + r.Reason = session.CheckpointReason(reason) + r.Status = session.CheckpointStatus(status) + + if r.SessionCheckpointRef == "" { + return r, nil, nil + } + + var sc session.SessionCheckpoint + var scCreatedAtMS int64 + err = db.QueryRowContext(ctx, ` +SELECT id, session_id, head_json, messages_json, created_at_ms +FROM session_checkpoints +WHERE id = ? +`, r.SessionCheckpointRef).Scan( + &sc.ID, &sc.SessionID, &sc.HeadJSON, &sc.MessagesJSON, &scCreatedAtMS, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return r, nil, nil + } + return session.CheckpointRecord{}, nil, fmt.Errorf("checkpoint: query session cp %s: %w", r.SessionCheckpointRef, err) + } + sc.CreatedAt = fromUnixMillis(scCreatedAtMS) + return r, &sc, nil +} + +// UpdateCheckpointStatus 更新 checkpoint 的生命周期状态。 +func (s *SQLiteCheckpointStore) UpdateCheckpointStatus(ctx context.Context, checkpointID string, status session.CheckpointStatus) error { + if err := ctx.Err(); err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + result, err := db.ExecContext(ctx, `UPDATE checkpoint_records SET status = ? 
WHERE id = ?`, string(status), checkpointID) + if err != nil { + return fmt.Errorf("checkpoint: update status %s: %w", checkpointID, err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("checkpoint: inspect rows affected for %s: %w", checkpointID, err) + } + if affected == 0 { + return fmt.Errorf("checkpoint: record %s not found", checkpointID) + } + return nil +} + +// GetLatestResumeCheckpoint 查询指定会话最新的 resume checkpoint。 +func (s *SQLiteCheckpointStore) GetLatestResumeCheckpoint(ctx context.Context, sessionID string) (*session.ResumeCheckpoint, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return nil, err + } + + var rc session.ResumeCheckpoint + var updatedAtMS int64 + err = db.QueryRowContext(ctx, ` +SELECT id, workspace_key, run_id, session_id, turn, phase, completion_state, transcript_revision, updated_at_ms +FROM resume_checkpoints +WHERE session_id = ? +ORDER BY updated_at_ms DESC +LIMIT 1 +`, sessionID).Scan( + &rc.ID, &rc.WorkspaceKey, &rc.RunID, &rc.SessionID, + &rc.Turn, &rc.Phase, &rc.CompletionState, + &rc.TranscriptRevision, &updatedAtMS, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("checkpoint: query resume checkpoint for %s: %w", sessionID, err) + } + rc.UpdatedAt = fromUnixMillis(updatedAtMS) + return &rc, nil +} + +// RestoreCheckpoint 在单一事务内恢复会话消息和头状态,并批量更新 checkpoint 状态。 +func (s *SQLiteCheckpointStore) RestoreCheckpoint(ctx context.Context, input RestoreCheckpointInput) error { + if err := ctx.Err(); err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("checkpoint: begin restore tx: %w", err) + } + defer rollbackTx(tx) + + // DELETE existing messages + if _, err := tx.ExecContext(ctx, `DELETE FROM messages WHERE session_id = ?`, input.SessionID); err != nil { + 
return fmt.Errorf("checkpoint: delete messages %s: %w", input.SessionID, err) + } + + // Re-insert messages + now := input.UpdatedAt + for i, msg := range input.Messages { + seq := i + 1 + toolCallsJSON := "[]" + if len(msg.ToolCalls) > 0 { + if data, err := json.Marshal(msg.ToolCalls); err == nil { + toolCallsJSON = string(data) + } + } + toolMetadataJSON := "{}" + if msg.ToolMetadata != nil { + if data, err := json.Marshal(msg.ToolMetadata); err == nil { + toolMetadataJSON = string(data) + } + } + partsJSON := "[]" + if len(msg.Parts) > 0 { + if data, err := json.Marshal(msg.Parts); err == nil { + partsJSON = string(data) + } + } + if _, err := tx.ExecContext(ctx, `INSERT INTO messages (session_id, seq, role, parts_json, tool_calls_json, tool_call_id, is_error, tool_metadata_json, created_at_ms) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + input.SessionID, seq, msg.Role, partsJSON, toolCallsJSON, msg.ToolCallID, boolToInt(msg.IsError), toolMetadataJSON, toUnixMillis(now), + ); err != nil { + return fmt.Errorf("checkpoint: insert message %s/%d: %w", input.SessionID, seq, err) + } + } + + // UPDATE session head + h := input.Head + result, err := tx.ExecContext(ctx, `UPDATE sessions SET updated_at_ms=?, provider=?, model=?, workdir=?, task_state_json=?, todos_json=?, activated_skills_json=?, token_input_total=?, token_output_total=?, has_unknown_usage=?, agent_mode=?, current_plan_json=?, last_full_plan_revision=?, plan_approval_pending_full_align=?, plan_completion_pending_full_review=?, plan_context_dirty=?, plan_restore_pending_align=?, last_seq=?, message_count=? 
WHERE id=?`, + toUnixMillis(input.UpdatedAt), h.Provider, h.Model, h.Workdir, + marshalHeadField(h.TaskState), marshalHeadField(h.Todos), marshalHeadField(h.ActivatedSkills), + h.TokenInputTotal, h.TokenOutputTotal, boolToInt(h.HasUnknownUsage), h.AgentMode, + marshalPlanField(h.CurrentPlan), h.LastFullPlanRevision, + boolToInt(h.PlanApprovalPendingFullAlign), boolToInt(h.PlanCompletionPendingFullReview), + boolToInt(h.PlanContextDirty), boolToInt(h.PlanRestorePendingAlign), + len(input.Messages), len(input.Messages), input.SessionID, + ) + if err != nil { + return fmt.Errorf("checkpoint: update session %s: %w", input.SessionID, err) + } + if affected, _ := result.RowsAffected(); affected == 0 { + return fmt.Errorf("checkpoint: session %s not found", input.SessionID) + } + + // Mark available + for _, id := range input.MarkAvailableIDs { + if _, err := tx.ExecContext(ctx, `UPDATE checkpoint_records SET status=? WHERE id=?`, string(session.CheckpointStatusAvailable), id); err != nil { + return fmt.Errorf("checkpoint: mark available %s: %w", id, err) + } + } + + // Mark restored + for _, id := range input.MarkRestoredIDs { + if _, err := tx.ExecContext(ctx, `UPDATE checkpoint_records SET status=? 
WHERE id=?`, string(session.CheckpointStatusRestored), id); err != nil { + return fmt.Errorf("checkpoint: mark restored %s: %w", id, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("checkpoint: commit restore %s: %w", input.SessionID, err) + } + return nil +} + +func marshalHeadField(value any) string { + data, err := json.Marshal(value) + if err != nil { + return "null" + } + return string(data) +} + +// marshalPlanField 将可选计划字段编码为 session 兼容的持久化格式,nil 计划统一写为空串。 +func marshalPlanField(value any) string { + if value == nil { + return "" + } + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Pointer && rv.IsNil() { + return "" + } + data, err := json.Marshal(value) + if err != nil { + return "" + } + return string(data) +} + +// SetResumeCheckpoint 写入或更新 ResumeCheckpoint(一个 session 只保留一条)。 +func (s *SQLiteCheckpointStore) SetResumeCheckpoint(ctx context.Context, rc session.ResumeCheckpoint) error { + if err := ctx.Err(); err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("checkpoint: begin set resume tx: %w", err) + } + defer rollbackTx(tx) + + if _, err := tx.ExecContext(ctx, `DELETE FROM resume_checkpoints WHERE session_id=?`, rc.SessionID); err != nil { + return fmt.Errorf("checkpoint: delete old resume cp %s: %w", rc.SessionID, err) + } + + if _, err := tx.ExecContext(ctx, `INSERT INTO resume_checkpoints (id, workspace_key, run_id, session_id, turn, phase, completion_state, transcript_revision, updated_at_ms) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + rc.ID, rc.WorkspaceKey, rc.RunID, rc.SessionID, rc.Turn, rc.Phase, rc.CompletionState, rc.TranscriptRevision, toUnixMillis(rc.UpdatedAt), + ); err != nil { + return fmt.Errorf("checkpoint: insert resume cp %s: %w", rc.SessionID, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("checkpoint: commit set resume cp %s: %w", rc.SessionID, err) + } + return nil 
+} + +// PruneExpiredCheckpoints 窗口裁剪,将超出 maxAutoKeep 的旧自动 checkpoint 标记为 pruned。 +func (s *SQLiteCheckpointStore) PruneExpiredCheckpoints(ctx context.Context, sessionID string, maxAutoKeep int) (int, error) { + if err := ctx.Err(); err != nil { + return 0, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return 0, err + } + + rows, err := db.QueryContext(ctx, `SELECT id, session_checkpoint_ref FROM checkpoint_records WHERE session_id=? AND restorable=1 AND status=? AND reason NOT IN (?, ?) ORDER BY created_at_ms DESC`, + sessionID, string(session.CheckpointStatusAvailable), string(session.CheckpointReasonManual), string(session.CheckpointReasonGuard), + ) + if err != nil { + return 0, fmt.Errorf("checkpoint: query prune candidates %s: %w", sessionID, err) + } + defer rows.Close() + + type pruneTarget struct { + ID string + SessionCPRef string + } + var targets []pruneTarget + idx := 0 + for rows.Next() { + var t pruneTarget + if err := rows.Scan(&t.ID, &t.SessionCPRef); err != nil { + return 0, fmt.Errorf("checkpoint: scan prune candidate: %w", err) + } + if idx >= maxAutoKeep { + targets = append(targets, t) + } + idx++ + } + if err := rows.Err(); err != nil { + return 0, fmt.Errorf("checkpoint: iterate prune candidates: %w", err) + } + if len(targets) == 0 { + return 0, nil + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return 0, fmt.Errorf("checkpoint: begin prune tx: %w", err) + } + defer rollbackTx(tx) + + for _, t := range targets { + if _, err := tx.ExecContext(ctx, `UPDATE checkpoint_records SET restorable=0, status=? 
WHERE id=?`, string(session.CheckpointStatusPruned), t.ID); err != nil { + return 0, fmt.Errorf("checkpoint: prune record %s: %w", t.ID, err) + } + if t.SessionCPRef != "" { + if _, err := tx.ExecContext(ctx, `DELETE FROM session_checkpoints WHERE id=?`, t.SessionCPRef); err != nil { + return 0, fmt.Errorf("checkpoint: delete session cp %s: %w", t.SessionCPRef, err) + } + } + } + + if err := tx.Commit(); err != nil { + return 0, fmt.Errorf("checkpoint: commit prune %s: %w", sessionID, err) + } + return len(targets), nil +} + +// RepairCreatingCheckpoints 修复残留的 creating 状态 checkpoint。 +func (s *SQLiteCheckpointStore) RepairCreatingCheckpoints(ctx context.Context) (int, error) { + if err := ctx.Err(); err != nil { + return 0, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return 0, err + } + + rows, err := db.QueryContext(ctx, `SELECT id, session_checkpoint_ref FROM checkpoint_records WHERE status=?`, string(session.CheckpointStatusCreating)) + if err != nil { + return 0, fmt.Errorf("checkpoint: query creating records: %w", err) + } + defer rows.Close() + + type repairTarget struct { + ID string + SessionCPRef string + } + var targets []repairTarget + for rows.Next() { + var t repairTarget + if err := rows.Scan(&t.ID, &t.SessionCPRef); err != nil { + return 0, fmt.Errorf("checkpoint: scan creating record: %w", err) + } + targets = append(targets, t) + } + if err := rows.Err(); err != nil { + return 0, fmt.Errorf("checkpoint: iterate creating records: %w", err) + } + if len(targets) == 0 { + return 0, nil + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return 0, fmt.Errorf("checkpoint: begin repair tx: %w", err) + } + defer rollbackTx(tx) + + for _, t := range targets { + if t.SessionCPRef != "" { + if _, err := tx.ExecContext(ctx, `UPDATE checkpoint_records SET status=? 
WHERE id=?`, string(session.CheckpointStatusAvailable), t.ID); err != nil { + return 0, fmt.Errorf("checkpoint: repair available %s: %w", t.ID, err) + } + } else { + if _, err := tx.ExecContext(ctx, `DELETE FROM checkpoint_records WHERE id=?`, t.ID); err != nil { + return 0, fmt.Errorf("checkpoint: delete orphan %s: %w", t.ID, err) + } + } + } + + if err := tx.Commit(); err != nil { + return 0, fmt.Errorf("checkpoint: commit repair: %w", err) + } + return len(targets), nil +} + +func toUnixMillis(value time.Time) int64 { + return value.UTC().UnixMilli() +} + +func fromUnixMillis(value int64) time.Time { + if value == 0 { + return time.Time{} + } + return time.UnixMilli(value).UTC() +} + +func boolToInt(value bool) int { + if value { + return 1 + } + return 0 +} + +func rollbackTx(tx *sql.Tx) { + if tx != nil { + _ = tx.Rollback() + } +} diff --git a/internal/checkpoint/checkpoint_manager_test.go b/internal/checkpoint/checkpoint_manager_test.go new file mode 100644 index 00000000..6afd96a1 --- /dev/null +++ b/internal/checkpoint/checkpoint_manager_test.go @@ -0,0 +1,590 @@ +package checkpoint + +import ( + "context" + "database/sql" + "encoding/json" + "path/filepath" + "strings" + "testing" + "time" + + providertypes "neo-code/internal/provider/types" + "neo-code/internal/session" +) + +type checkpointStoreFixture struct { + sessionStore *session.SQLiteStore + checkpointStore *SQLiteCheckpointStore + baseDir string + workspaceRoot string +} + +func newCheckpointStoreFixture(t *testing.T) checkpointStoreFixture { + t.Helper() + + baseDir := t.TempDir() + workspaceRoot := t.TempDir() + + sessionStore := session.NewSQLiteStore(baseDir, workspaceRoot) + t.Cleanup(func() { + _ = sessionStore.Close() + }) + + checkpointStore := NewSQLiteCheckpointStore(session.DatabasePath(baseDir, workspaceRoot)) + t.Cleanup(func() { + _ = checkpointStore.Close() + }) + + return checkpointStoreFixture{ + sessionStore: sessionStore, + checkpointStore: checkpointStore, + baseDir: baseDir, 
+ workspaceRoot: workspaceRoot, + } +} + +func createCheckpointTestSession(t *testing.T, store *session.SQLiteStore, id string, workdir string) session.Session { + t.Helper() + + created, err := store.CreateSession(context.Background(), session.CreateSessionInput{ + ID: id, + Title: "checkpoint test", + Head: session.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + TaskState: session.TaskState{ + Goal: "before restore", + VerificationProfile: session.VerificationProfileTaskOnly, + }, + Todos: []session.TodoItem{ + {ID: "todo-1", Content: "before restore"}, + }, + }, + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + + messages := []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("before restore"), + }, + }, + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("tool planned"), + }, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "bash", Arguments: `{"cmd":"pwd"}`}, + }, + ToolMetadata: map[string]string{"source": "test"}, + }, + } + if err := store.AppendMessages(context.Background(), session.AppendMessagesInput{ + SessionID: created.ID, + Messages: messages, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) + } + + loaded, err := store.LoadSession(context.Background(), created.ID) + if err != nil { + t.Fatalf("LoadSession() error = %v", err) + } + return loaded +} + +func checkpointInputFromSession(t *testing.T, loaded session.Session, checkpointID string, reason session.CheckpointReason, createdAt time.Time) CreateCheckpointInput { + t.Helper() + + headJSON, err := json.Marshal(loaded.HeadSnapshot()) + if err != nil { + t.Fatalf("Marshal(head) error = %v", err) + } + messagesJSON, err := json.Marshal(loaded.Messages) + if err != nil { + t.Fatalf("Marshal(messages) 
error = %v", err) + } + + return CreateCheckpointInput{ + Record: session.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: session.WorkspacePathKey(loaded.Workdir), + SessionID: loaded.ID, + RunID: "run-" + checkpointID, + Workdir: loaded.Workdir, + CreatedAt: createdAt, + Reason: reason, + CodeCheckpointRef: RefForPerEditCheckpoint(checkpointID), + Restorable: true, + Status: session.CheckpointStatusCreating, + }, + SessionCP: session.SessionCheckpoint{ + ID: "sc-" + checkpointID, + SessionID: loaded.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: createdAt, + }, + } +} + +func TestSQLiteCheckpointStoreCreateRestoreAndResume(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_restore", fixture.workspaceRoot) + checkpointCreatedAt := time.Now().Add(-time.Minute) + + input := checkpointInputFromSession(t, loaded, "cp-restore", session.CheckpointReasonPreWrite, checkpointCreatedAt) + saved, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), input) + if err != nil { + t.Fatalf("CreateCheckpoint() error = %v", err) + } + if saved.SessionCheckpointRef == "" || saved.Status != session.CheckpointStatusAvailable { + t.Fatalf("CreateCheckpoint() = %#v, want available checkpoint with session ref", saved) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), loaded.ID, ListCheckpointOpts{ + Limit: 10, + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].CheckpointID != saved.CheckpointID { + t.Fatalf("ListCheckpoints() = %#v, want only %q", records, saved.CheckpointID) + } + + record, sessionCP, err := fixture.checkpointStore.GetCheckpoint(context.Background(), saved.CheckpointID) + if err != nil { + t.Fatalf("GetCheckpoint() error = %v", err) + } + if record.CheckpointID != saved.CheckpointID || 
sessionCP == nil || sessionCP.ID != saved.SessionCheckpointRef { + t.Fatalf("GetCheckpoint() = (%#v, %#v), want saved record and session snapshot", record, sessionCP) + } + + if err := fixture.sessionStore.UpdateSessionState(context.Background(), session.UpdateSessionStateInput{ + SessionID: loaded.ID, + UpdatedAt: time.Now(), + Title: "mutated", + Head: session.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: loaded.Workdir, + TaskState: session.TaskState{ + Goal: "after restore", + VerificationProfile: session.VerificationProfileTaskOnly, + }, + Todos: []session.TodoItem{ + {ID: "todo-2", Content: "after restore"}, + }, + }, + }); err != nil { + t.Fatalf("UpdateSessionState() error = %v", err) + } + if err := fixture.sessionStore.AppendMessages(context.Background(), session.AppendMessagesInput{ + SessionID: loaded.ID, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("after restore"), + }, + }, + }, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: loaded.Workdir, + }); err != nil { + t.Fatalf("AppendMessages(after) error = %v", err) + } + + if err := fixture.checkpointStore.RestoreCheckpoint(context.Background(), RestoreCheckpointInput{ + SessionID: loaded.ID, + Head: loaded.HeadSnapshot(), + Messages: loaded.Messages, + UpdatedAt: time.Now(), + }); err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + + restored, err := fixture.sessionStore.LoadSession(context.Background(), loaded.ID) + if err != nil { + t.Fatalf("LoadSession(restored) error = %v", err) + } + if restored.TaskState.Goal != loaded.TaskState.Goal { + t.Fatalf("restored goal = %q, want %q", restored.TaskState.Goal, loaded.TaskState.Goal) + } + if len(restored.Messages) != len(loaded.Messages) { + t.Fatalf("restored message count = %d, want %d", len(restored.Messages), len(loaded.Messages)) + } + if restored.Messages[1].ToolMetadata["source"] != 
"test" { + t.Fatalf("restored tool metadata = %#v, want preserved metadata", restored.Messages[1].ToolMetadata) + } + + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), saved.CheckpointID, session.CheckpointStatusRestored); err != nil { + t.Fatalf("UpdateCheckpointStatus() error = %v", err) + } + filtered, err := fixture.checkpointStore.ListCheckpoints(context.Background(), loaded.ID, ListCheckpointOpts{ + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints(filtered) error = %v", err) + } + if len(filtered) != 0 { + t.Fatalf("expected no restorable checkpoints after status change, got %#v", filtered) + } + + firstResume := session.ResumeCheckpoint{ + ID: "rc-1", + WorkspaceKey: session.WorkspacePathKey(loaded.Workdir), + RunID: "run-1", + SessionID: loaded.ID, + Turn: 1, + Phase: "plan", + CompletionState: "running", + TranscriptRevision: 3, + UpdatedAt: time.Now().Add(-time.Minute), + } + secondResume := firstResume + secondResume.ID = "rc-2" + secondResume.RunID = "run-2" + secondResume.Turn = 2 + secondResume.Phase = "execute" + secondResume.UpdatedAt = time.Now() + + if err := fixture.checkpointStore.SetResumeCheckpoint(context.Background(), firstResume); err != nil { + t.Fatalf("SetResumeCheckpoint(first) error = %v", err) + } + if err := fixture.checkpointStore.SetResumeCheckpoint(context.Background(), secondResume); err != nil { + t.Fatalf("SetResumeCheckpoint(second) error = %v", err) + } + gotResume, err := fixture.checkpointStore.GetLatestResumeCheckpoint(context.Background(), loaded.ID) + if err != nil { + t.Fatalf("GetLatestResumeCheckpoint() error = %v", err) + } + if gotResume == nil || gotResume.ID != secondResume.ID || gotResume.Turn != secondResume.Turn { + t.Fatalf("GetLatestResumeCheckpoint() = %#v, want %#v", gotResume, secondResume) + } +} + +func TestSQLiteCheckpointStoreRestoreCheckpointUpdatesStatuses(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := 
createCheckpointTestSession(t, fixture.sessionStore, "session_restore_status", fixture.workspaceRoot) + + recordA, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession( + t, loaded, "cp-available", session.CheckpointReasonPreWrite, time.Now().Add(-2*time.Minute), + )) + if err != nil { + t.Fatalf("CreateCheckpoint(cp-available) error = %v", err) + } + recordB, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession( + t, loaded, "cp-restored", session.CheckpointReasonEndOfTurn, time.Now().Add(-time.Minute), + )) + if err != nil { + t.Fatalf("CreateCheckpoint(cp-restored) error = %v", err) + } + + if err := fixture.checkpointStore.RestoreCheckpoint(context.Background(), RestoreCheckpointInput{ + SessionID: loaded.ID, + Head: loaded.HeadSnapshot(), + Messages: loaded.Messages, + UpdatedAt: time.Now(), + MarkAvailableIDs: []string{recordA.CheckpointID}, + MarkRestoredIDs: []string{recordB.CheckpointID}, + }); err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + + availableRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), recordA.CheckpointID) + if err != nil { + t.Fatalf("GetCheckpoint(cp-available) error = %v", err) + } + if availableRecord.Status != session.CheckpointStatusAvailable { + t.Fatalf("available record status = %q, want available", availableRecord.Status) + } + + restoredRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), recordB.CheckpointID) + if err != nil { + t.Fatalf("GetCheckpoint(cp-restored) error = %v", err) + } + if restoredRecord.Status != session.CheckpointStatusRestored { + t.Fatalf("restored record status = %q, want restored", restoredRecord.Status) + } +} + +func TestSQLiteCheckpointStoreGetCheckpointWithoutSessionSnapshot(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_record_only", 
fixture.workspaceRoot) + db, err := fixture.checkpointStore.ensureDB(context.Background()) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) + } + + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO checkpoint_records ( + id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +`, + "cp-record-only", + session.WorkspacePathKey(loaded.Workdir), + loaded.ID, + "run-record-only", + loaded.Workdir, + time.Now().UnixMilli(), + string(session.CheckpointReasonPreWrite), + "", + "", + "", + 0, + 1, + string(session.CheckpointStatusAvailable), + ); err != nil { + t.Fatalf("insert checkpoint record error = %v", err) + } + + record, sessionCP, err := fixture.checkpointStore.GetCheckpoint(context.Background(), "cp-record-only") + if err != nil { + t.Fatalf("GetCheckpoint() error = %v", err) + } + if record.CheckpointID != "cp-record-only" { + t.Fatalf("record id = %q, want cp-record-only", record.CheckpointID) + } + if sessionCP != nil { + t.Fatalf("session checkpoint = %#v, want nil", sessionCP) + } +} + +func TestSQLiteCheckpointStorePruneAndRepair(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_prune", fixture.workspaceRoot) + + createdAt := time.Now().Add(-10 * time.Minute) + for i := 0; i < 4; i++ { + checkpointID := "cp-auto-" + string(rune('a'+i)) + input := checkpointInputFromSession(t, loaded, checkpointID, session.CheckpointReasonPreWrite, createdAt.Add(time.Duration(i)*time.Minute)) + if _, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), input); err != nil { + t.Fatalf("CreateCheckpoint(%s) error = %v", checkpointID, err) + } + } + if _, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession(t, loaded, 
"cp-manual", session.CheckpointReasonManual, time.Now())); err != nil { + t.Fatalf("CreateCheckpoint(manual) error = %v", err) + } + if _, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession(t, loaded, "cp-guard", session.CheckpointReasonGuard, time.Now().Add(time.Minute))); err != nil { + t.Fatalf("CreateCheckpoint(guard) error = %v", err) + } + + pruned, err := fixture.checkpointStore.PruneExpiredCheckpoints(context.Background(), loaded.ID, 2) + if err != nil { + t.Fatalf("PruneExpiredCheckpoints() error = %v", err) + } + if pruned != 2 { + t.Fatalf("PruneExpiredCheckpoints() = %d, want 2", pruned) + } + + prunedRecord, prunedSessionCP, err := fixture.checkpointStore.GetCheckpoint(context.Background(), "cp-auto-a") + if err != nil { + t.Fatalf("GetCheckpoint(pruned) error = %v", err) + } + if prunedRecord.Status != session.CheckpointStatusPruned || prunedRecord.Restorable { + t.Fatalf("pruned record = %#v, want pruned and not restorable", prunedRecord) + } + if prunedSessionCP != nil { + t.Fatalf("expected pruned session snapshot to be deleted, got %#v", prunedSessionCP) + } + + manualRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), "cp-manual") + if err != nil { + t.Fatalf("GetCheckpoint(manual) error = %v", err) + } + if manualRecord.Status != session.CheckpointStatusAvailable || !manualRecord.Restorable { + t.Fatalf("manual record = %#v, want still available", manualRecord) + } + + db, err := fixture.checkpointStore.ensureDB(context.Background()) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) + } + withSessionCPID := "cp-creating-with-session" + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO session_checkpoints (id, session_id, head_json, messages_json, created_at_ms) +VALUES (?, ?, ?, ?, ?) 
+`, "sc-creating", loaded.ID, `{}`, `[]`, time.Now().UnixMilli()); err != nil { + t.Fatalf("insert session_checkpoint error = %v", err) + } + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO checkpoint_records ( + id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +`, + withSessionCPID, + session.WorkspacePathKey(loaded.Workdir), + loaded.ID, + "run-repair", + loaded.Workdir, + time.Now().UnixMilli(), + string(session.CheckpointReasonPreWrite), + "", + "sc-creating", + "", + 0, + 1, + string(session.CheckpointStatusCreating), + ); err != nil { + t.Fatalf("insert creating checkpoint with session ref error = %v", err) + } + + orphanID := "cp-creating-orphan" + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO checkpoint_records ( + id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+`, + orphanID, + session.WorkspacePathKey(loaded.Workdir), + loaded.ID, + "run-repair", + loaded.Workdir, + time.Now().UnixMilli(), + string(session.CheckpointReasonPreWrite), + "", + "", + "", + 0, + 1, + string(session.CheckpointStatusCreating), + ); err != nil { + t.Fatalf("insert orphan checkpoint error = %v", err) + } + + repaired, err := fixture.checkpointStore.RepairCreatingCheckpoints(context.Background()) + if err != nil { + t.Fatalf("RepairCreatingCheckpoints() error = %v", err) + } + if repaired != 2 { + t.Fatalf("RepairCreatingCheckpoints() = %d, want 2", repaired) + } + + repairedRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), withSessionCPID) + if err != nil { + t.Fatalf("GetCheckpoint(repaired) error = %v", err) + } + if repairedRecord.Status != session.CheckpointStatusAvailable { + t.Fatalf("repaired record status = %q, want available", repairedRecord.Status) + } + + if _, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), orphanID); err == nil { + t.Fatalf("expected orphan creating checkpoint to be deleted") + } +} + +func TestSQLiteCheckpointStoreUsesSessionDatabasePath(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + expected := filepath.Clean(session.DatabasePath(fixture.baseDir, fixture.workspaceRoot)) + if filepath.Clean(fixture.checkpointStore.dbPath) != expected { + t.Fatalf("dbPath = %q, want %q", fixture.checkpointStore.dbPath, expected) + } +} + +func TestSQLiteCheckpointStoreSharedDBAndHelpers(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_shared_db", fixture.workspaceRoot) + db, err := fixture.checkpointStore.ensureDB(context.Background()) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) + } + + shared := NewSQLiteCheckpointStoreWithDB(db) + if shared.ownsDB { + t.Fatal("shared checkpoint store should not own injected db") + } + if err := 
shared.Close(); err != nil { + t.Fatalf("Close(shared) error = %v", err) + } + if err := db.PingContext(context.Background()); err != nil { + t.Fatalf("db should remain open after shared Close(), got %v", err) + } + if _, err := shared.ListCheckpoints(context.Background(), loaded.ID, ListCheckpointOpts{}); err != nil { + t.Fatalf("shared ListCheckpoints() error = %v", err) + } + + if got := marshalPlanField(nil); got != "" { + t.Fatalf("marshalPlanField(nil) = %q, want empty", got) + } + var nilPlan *session.PlanArtifact + if got := marshalPlanField(nilPlan); got != "" { + t.Fatalf("marshalPlanField(nil pointer) = %q, want empty", got) + } + if got := marshalPlanField(map[string]any{"step": "verify"}); !strings.Contains(got, `"step":"verify"`) { + t.Fatalf("marshalPlanField(map) = %q", got) + } + if got := marshalPlanField(func() {}); got != "" { + t.Fatalf("marshalPlanField(unmarshalable) = %q, want empty", got) + } + if got := marshalHeadField(func() {}); got != "null" { + t.Fatalf("marshalHeadField(unmarshalable) = %q, want null", got) + } +} + +func TestSQLiteCheckpointStoreErrorsAndEmptyResults(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_empty_resume", fixture.workspaceRoot) + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), "missing", session.CheckpointStatusAvailable); err == nil { + t.Fatal("expected UpdateCheckpointStatus() to fail for missing checkpoint") + } + + rc, err := fixture.checkpointStore.GetLatestResumeCheckpoint(context.Background(), loaded.ID) + if err != nil { + t.Fatalf("GetLatestResumeCheckpoint(missing) error = %v", err) + } + if rc != nil { + t.Fatalf("GetLatestResumeCheckpoint(missing) = %#v, want nil", rc) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := fixture.checkpointStore.ensureDB(context.Background()); err != nil { + t.Fatalf("ensureDB() error = %v", err) + 
} + if _, err := fixture.checkpointStore.CreateCheckpoint(ctx, CreateCheckpointInput{}); err == nil { + t.Fatal("expected CreateCheckpoint() to honor canceled context") + } +} + +func TestNewSQLiteCheckpointStoreWithNilDBClose(t *testing.T) { + t.Parallel() + + store := NewSQLiteCheckpointStoreWithDB((*sql.DB)(nil)) + if err := store.Close(); err != nil { + t.Fatalf("Close(nil db) error = %v", err) + } +} diff --git a/internal/checkpoint/fingerprint.go b/internal/checkpoint/fingerprint.go new file mode 100644 index 00000000..1d335026 --- /dev/null +++ b/internal/checkpoint/fingerprint.go @@ -0,0 +1,187 @@ +package checkpoint + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" +) + +const ( + fingerprintHeadBytes = 4 * 1024 + fingerprintHashLen = 16 +) + +// FileFingerprint 描述单个文件的廉价指纹。 +type FileFingerprint struct { + Size int64 `json:"size"` + ModTime time.Time `json:"mod_time"` + HeadHash string `json:"head_hash"` +} + +// WorkdirFingerprint 是 workdir 相对路径 → 指纹的快照。 +type WorkdirFingerprint map[string]FileFingerprint + +// FingerprintOptions 控制扫描范围与跳过规则。 +type FingerprintOptions struct { + SkipDirs []string + SkipExts []string + MaxFiles int +} + +// FingerprintDiff 描述两次指纹快照之间的差异。 +type FingerprintDiff struct { + Added []string + Deleted []string + Modified []string +} + +// DefaultFingerprintOptions 返回常用的扫描跳过规则与上限。 +func DefaultFingerprintOptions() FingerprintOptions { + return FingerprintOptions{ + SkipDirs: []string{".git", ".neocode", ".shadow", "node_modules", ".idea", ".vscode", "vendor", "target", "dist", "build"}, + SkipExts: []string{".exe", ".dll", ".so", ".dylib", ".bin", ".zip", ".tar", ".gz", ".7z", ".rar", ".jar", ".class", ".o", ".a", ".obj", ".pyc"}, + MaxFiles: 5000, + } +} + +// ScanWorkdir 扫描 workdir 下所有文件并生成指纹。 +// 第二个返回值为 true 表示因 MaxFiles 截断,结果可能不完整。 +func ScanWorkdir(ctx context.Context, workdir string, opts FingerprintOptions) (WorkdirFingerprint, bool, error) { 
+ result := make(WorkdirFingerprint) + if strings.TrimSpace(workdir) == "" { + return result, false, nil + } + skipDirSet := setOf(opts.SkipDirs) + skipExtSet := setOf(lowerSlice(opts.SkipExts)) + truncated := false + + walkErr := filepath.Walk(workdir, func(path string, info os.FileInfo, err error) error { + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if info.IsDir() { + if path == workdir { + return nil + } + if _, skip := skipDirSet[info.Name()]; skip { + return filepath.SkipDir + } + return nil + } + if !info.Mode().IsRegular() { + return nil + } + ext := strings.ToLower(filepath.Ext(info.Name())) + if _, skip := skipExtSet[ext]; skip { + return nil + } + if opts.MaxFiles > 0 && len(result) >= opts.MaxFiles { + truncated = true + return errSkipScan + } + rel, relErr := filepath.Rel(workdir, path) + if relErr != nil { + return relErr + } + hash, hashErr := hashHead(path, fingerprintHeadBytes) + if hashErr != nil { + return nil + } + result[filepath.ToSlash(rel)] = FileFingerprint{ + Size: info.Size(), + ModTime: info.ModTime(), + HeadHash: hash, + } + return nil + }) + if walkErr != nil && !errors.Is(walkErr, errSkipScan) { + return result, truncated, walkErr + } + return result, truncated, nil +} + +// DiffFingerprints 对比两个指纹快照,返回新增/删除/修改的相对路径列表(按字典序)。 +func DiffFingerprints(before, after WorkdirFingerprint) FingerprintDiff { + diff := FingerprintDiff{} + for path, fp := range after { + prev, ok := before[path] + if !ok { + diff.Added = append(diff.Added, path) + continue + } + if !fingerprintEqual(prev, fp) { + diff.Modified = append(diff.Modified, path) + } + } + for path := range before { + if _, ok := after[path]; !ok { + diff.Deleted = append(diff.Deleted, path) + } + } + sort.Strings(diff.Added) + sort.Strings(diff.Modified) + sort.Strings(diff.Deleted) + return diff +} + +func fingerprintEqual(a, b FileFingerprint) bool { + if a.Size != b.Size { + 
return false + } + if !a.ModTime.Equal(b.ModTime) { + return false + } + return a.HeadHash == b.HeadHash +} + +func hashHead(path string, max int) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + h := sha256.New() + _, err = io.CopyN(h, f, int64(max)) + if err != nil && !errors.Is(err, io.EOF) { + return "", err + } + sum := h.Sum(nil) + return hex.EncodeToString(sum)[:fingerprintHashLen], nil +} + +func setOf(values []string) map[string]struct{} { + out := make(map[string]struct{}, len(values)) + for _, v := range values { + v = strings.TrimSpace(v) + if v == "" { + continue + } + out[v] = struct{}{} + } + return out +} + +func lowerSlice(values []string) []string { + out := make([]string, len(values)) + for i, v := range values { + out[i] = strings.ToLower(strings.TrimSpace(v)) + } + return out +} + +// errSkipScan 用于在 walk 过程中提前终止后续遍历,避免使用 panic/recover。 +var errSkipScan = errors.New("scan-truncated") diff --git a/internal/checkpoint/fingerprint_test.go b/internal/checkpoint/fingerprint_test.go new file mode 100644 index 00000000..7170c358 --- /dev/null +++ b/internal/checkpoint/fingerprint_test.go @@ -0,0 +1,166 @@ +package checkpoint + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +// TestScanWorkdir_SkipsConfiguredDirs: skip dirs in opts are not scanned. 
+func TestScanWorkdir_SkipsConfiguredDirs(t *testing.T) { + root := t.TempDir() + mustWrite(t, filepath.Join(root, "src", "main.go"), "package main") + mustWrite(t, filepath.Join(root, "node_modules", "lib", "x.js"), "x") + mustWrite(t, filepath.Join(root, ".git", "config"), "[core]") + mustWrite(t, filepath.Join(root, "vendor", "v.go"), "v") + + fp, truncated, err := ScanWorkdir(context.Background(), root, DefaultFingerprintOptions()) + if err != nil { + t.Fatalf("scan: %v", err) + } + if truncated { + t.Fatalf("should not be truncated") + } + if _, ok := fp[filepath.ToSlash(filepath.Join("src", "main.go"))]; !ok { + t.Fatalf("src/main.go missing from fingerprint") + } + for k := range fp { + if filepath.HasPrefix(k, "node_modules/") || filepath.HasPrefix(k, ".git/") || filepath.HasPrefix(k, "vendor/") { + t.Fatalf("skipped dir leaked: %s", k) + } + } +} + +// TestScanWorkdir_SkipsBinaryExtensions: extensions in SkipExts excluded. +func TestScanWorkdir_SkipsBinaryExtensions(t *testing.T) { + root := t.TempDir() + mustWrite(t, filepath.Join(root, "a.go"), "package a") + mustWrite(t, filepath.Join(root, "b.exe"), "binary") + mustWrite(t, filepath.Join(root, "c.zip"), "zip") + + fp, _, err := ScanWorkdir(context.Background(), root, DefaultFingerprintOptions()) + if err != nil { + t.Fatalf("scan: %v", err) + } + if _, ok := fp["a.go"]; !ok { + t.Fatalf("a.go missing") + } + if _, ok := fp["b.exe"]; ok { + t.Fatalf(".exe should be skipped") + } + if _, ok := fp["c.zip"]; ok { + t.Fatalf(".zip should be skipped") + } +} + +// TestScanWorkdir_TruncatesBeyondMaxFiles: MaxFiles enforces an upper bound and sets truncated=true. 
+func TestScanWorkdir_TruncatesBeyondMaxFiles(t *testing.T) { + root := t.TempDir() + for i := 0; i < 20; i++ { + mustWrite(t, filepath.Join(root, "f", "x"+itoa(i)+".go"), "package x") + } + opts := DefaultFingerprintOptions() + opts.MaxFiles = 5 + + fp, truncated, err := ScanWorkdir(context.Background(), root, opts) + if err != nil { + t.Fatalf("scan: %v", err) + } + if !truncated { + t.Fatalf("expected truncated=true with MaxFiles=%d", opts.MaxFiles) + } + if len(fp) > opts.MaxFiles { + t.Fatalf("got %d entries, MaxFiles=%d", len(fp), opts.MaxFiles) + } +} + +// TestDiffFingerprints_ClassifiesAddDeleteModify: Diff produces three categories correctly. +func TestDiffFingerprints_ClassifiesAddDeleteModify(t *testing.T) { + before := WorkdirFingerprint{ + "keep.go": {Size: 10, HeadHash: "AA"}, + "remove.go": {Size: 20, HeadHash: "BB"}, + "modified.go": {Size: 30, HeadHash: "CC"}, + } + after := WorkdirFingerprint{ + "keep.go": {Size: 10, HeadHash: "AA"}, + "modified.go": {Size: 30, HeadHash: "DD"}, // hash differs + "new.go": {Size: 40, HeadHash: "EE"}, + } + diff := DiffFingerprints(before, after) + wantAdded := []string{"new.go"} + wantDeleted := []string{"remove.go"} + wantModified := []string{"modified.go"} + if !equalStringSlice(diff.Added, wantAdded) { + t.Fatalf("Added: got %v want %v", diff.Added, wantAdded) + } + if !equalStringSlice(diff.Deleted, wantDeleted) { + t.Fatalf("Deleted: got %v want %v", diff.Deleted, wantDeleted) + } + if !equalStringSlice(diff.Modified, wantModified) { + t.Fatalf("Modified: got %v want %v", diff.Modified, wantModified) + } +} + +// TestScanWorkdir_DetectsContentEdit: editing file content updates HeadHash. 
+func TestScanWorkdir_DetectsContentEdit(t *testing.T) { + root := t.TempDir() + mustWrite(t, filepath.Join(root, "f.go"), "v1") + + before, _, err := ScanWorkdir(context.Background(), root, DefaultFingerprintOptions()) + if err != nil { + t.Fatalf("scan before: %v", err) + } + mustWrite(t, filepath.Join(root, "f.go"), "v2content_different_size") + after, _, err := ScanWorkdir(context.Background(), root, DefaultFingerprintOptions()) + if err != nil { + t.Fatalf("scan after: %v", err) + } + diff := DiffFingerprints(before, after) + if !equalStringSlice(diff.Modified, []string{"f.go"}) { + t.Fatalf("Modified=%v, want [f.go]", diff.Modified) + } +} + +// helpers + +func mustWrite(t *testing.T, path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write: %v", err) + } +} + +func itoa(i int) string { + if i == 0 { + return "0" + } + digits := []byte{} + neg := i < 0 + if neg { + i = -i + } + for i > 0 { + digits = append([]byte{byte('0' + i%10)}, digits...) 
+ i /= 10 + } + if neg { + return "-" + string(digits) + } + return string(digits) +} + +func equalStringSlice(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/checkpoint/per_edit_snapshot.go b/internal/checkpoint/per_edit_snapshot.go new file mode 100644 index 00000000..70fbd2b7 --- /dev/null +++ b/internal/checkpoint/per_edit_snapshot.go @@ -0,0 +1,815 @@ +package checkpoint + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "github.com/pmezard/go-difflib/difflib" +) + +const ( + perEditPathHashLen = 16 + perEditMaxCaptureBytes = 64 * 1024 * 1024 + perEditIndexFileName = "index.jsonl" +) + +// ConflictResult 是 RestoreResult.Conflict 字段的占位类型,保留以维持 Gateway/CLI 旧契约。 +// per-edit 后端不做冲突检测,HasConflict 始终为 false。 +type ConflictResult struct { + HasConflict bool `json:"has_conflict"` +} + +// FileVersionMeta 描述某次 CapturePreWrite 时刻的元信息,伴随 .bin 内容文件落盘。 +type FileVersionMeta struct { + PathHash string `json:"path_hash"` + DisplayPath string `json:"display_path"` + Version int `json:"version"` + Existed bool `json:"existed"` + IsDir bool `json:"is_dir,omitempty"` + Mode os.FileMode `json:"mode,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// CheckpointMeta 是 cp_.json 的内容。 +type CheckpointMeta struct { + CheckpointID string `json:"checkpoint_id"` + CreatedAt time.Time `json:"created_at"` + FileVersions map[string]int `json:"file_versions"` +} + +// perEditIndexEntry 是 index.jsonl 的单行结构,进程重启时用于重建内存索引。 +type perEditIndexEntry struct { + PathHash string `json:"path_hash"` + DisplayPath string `json:"display_path"` + Version int `json:"version"` +} + +// PerEditSnapshotStore 提供基于"工具触碰"的版本化增量文件历史。 +// 每个版本独立寻址(@v.bin/.meta),checkpoint 仅存 (pathHash → version) 映射。 +// 同一 workdir 下跨 session 共享 file-history 目录,pathHash 
type PerEditSnapshotStore struct {
	// Directory layout, fixed at construction time.
	fileHistoryDir string // <project>/file-history: versioned .bin/.meta blobs + index.jsonl
	checkpointsDir string // <project>/checkpoints: cp_<id>.json mapping files
	workdir        string // workspace root, used to relativize display paths

	// indexMu guards the durable version index below and serializes version
	// allocation with the on-disk writes.
	indexMu        sync.Mutex
	pathToVersions map[string][]int  // pathHash → ascending version numbers
	displayPaths   map[string]string // pathHash → last-seen absolute path

	// pendingMu guards the per-turn dedupe map.
	pendingMu sync.Mutex
	pending   map[string]int // pathHash → version captured in the current turn
}

// NewPerEditSnapshotStore creates a file-history store instance and rebuilds
// the in-memory index from disk. projectDir is ~/.neocode/projects/<hash>;
// workdir is the actual workspace root directory.
func NewPerEditSnapshotStore(projectDir, workdir string) *PerEditSnapshotStore {
	store := &PerEditSnapshotStore{
		fileHistoryDir: filepath.Join(projectDir, "file-history"),
		checkpointsDir: filepath.Join(projectDir, "checkpoints"),
		workdir:        workdir,
		pathToVersions: make(map[string][]int),
		displayPaths:   make(map[string]string),
		pending:        make(map[string]int),
	}
	store.loadIndexFromDisk()
	return store
}

// IsAvailable returns true for any non-nil store: the pure-file implementation
// has no external dependencies.
func (s *PerEditSnapshotStore) IsAvailable() bool {
	return s != nil
}

// CapturePreWrite creates a new version (holding the OLD content) of absPath
// before a tool modifies it. Repeat calls for the same path within one turn
// (between Resets) are deduped: the first allocated version number is
// returned. When the file does not exist, .meta.Existed=false and the .bin is
// an empty file.
func (s *PerEditSnapshotStore) CapturePreWrite(absPath string) (int, error) {
	cleanPath := filepath.Clean(absPath)
	if cleanPath == "" || cleanPath == "." {
		return 0, fmt.Errorf("per-edit: empty path")
	}
	hash := perEditPathHash(cleanPath)

	// indexMu is held for the whole capture: it serializes version allocation
	// with the .bin/.meta/index writes so versions stay dense and ordered.
	s.indexMu.Lock()
	defer s.indexMu.Unlock()

	// Per-turn dedupe: return the version allocated on the first touch.
	s.pendingMu.Lock()
	if v, ok := s.pending[hash]; ok {
		s.pendingMu.Unlock()
		return v, nil
	}
	s.pendingMu.Unlock()

	versions := s.pathToVersions[hash]
	nextVersion := 1
	if len(versions) > 0 {
		nextVersion = versions[len(versions)-1] + 1
	}

	content, existed, isDir, mode, err := readFileForCapture(cleanPath)
	if err != nil {
		return 0, fmt.Errorf("per-edit: read %s: %w", cleanPath, err)
	}

	meta := FileVersionMeta{
		PathHash:    hash,
		DisplayPath: cleanPath,
		Version:     nextVersion,
		Existed:     existed,
		IsDir:       isDir,
		Mode:        mode,
		CreatedAt:   time.Now().UTC(),
	}

	if err := s.writeVersionFiles(meta, content); err != nil {
		return 0, err
	}
	if err := s.appendIndex(perEditIndexEntry{
		PathHash:    hash,
		DisplayPath: cleanPath,
		Version:     nextVersion,
	}); err != nil {
		return 0, fmt.Errorf("per-edit: append index: %w", err)
	}

	// Mutate in-memory state only after all disk writes succeeded.
	s.pathToVersions[hash] = append(versions, nextVersion)
	s.displayPaths[hash] = cleanPath

	s.pendingMu.Lock()
	s.pending[hash] = nextVersion
	s.pendingMu.Unlock()

	return nextVersion, nil
}

// CaptureBatch calls CapturePreWrite for each path and returns the abs paths
// captured so far. It stops at the first failure; paths already captured
// remain in the returned slice.
func (s *PerEditSnapshotStore) CaptureBatch(absPaths []string) ([]string, error) {
	captured := make([]string, 0, len(absPaths))
	for _, p := range absPaths {
		if strings.TrimSpace(p) == "" {
			continue
		}
		if _, err := s.CapturePreWrite(p); err != nil {
			return captured, err
		}
		captured = append(captured, filepath.Clean(p))
	}
	return captured, nil
}

// CapturePostDelete writes a post-delete version (Existed=false) for paths
// that were just removed. These versions bypass pending and go straight into
// the index, so restore/diff v_next lookups can observe the deletion.
func (s *PerEditSnapshotStore) CapturePostDelete(absPaths []string) error {
	s.indexMu.Lock()
	defer s.indexMu.Unlock()

	for _, p := range absPaths {
		cleanPath := filepath.Clean(p)
		if cleanPath == "" || cleanPath == "." {
			continue
		}
		hash := perEditPathHash(cleanPath)

		versions := s.pathToVersions[hash]
		nextVersion := 1
		if len(versions) > 0 {
			nextVersion = versions[len(versions)-1] + 1
		}

		// Existed=false marks "path absent at this version"; no .bin is written.
		meta := FileVersionMeta{
			PathHash:    hash,
			DisplayPath: cleanPath,
			Version:     nextVersion,
			Existed:     false,
			IsDir:       false,
			Mode:        0,
			CreatedAt:   time.Now().UTC(),
		}
		metaPath := s.versionMetaPath(hash, nextVersion)
		if err := s.writeVersionMetaOnly(metaPath, meta); err != nil {
			return fmt.Errorf("per-edit: post-delete %s: %w", cleanPath, err)
		}
		if err := s.appendIndex(perEditIndexEntry{
			PathHash:    hash,
			DisplayPath: cleanPath,
			Version:     nextVersion,
		}); err != nil {
			return fmt.Errorf("per-edit: append post-delete index %s: %w", cleanPath, err)
		}

		s.pathToVersions[hash] = append(versions, nextVersion)
		s.displayPaths[hash] = cleanPath
	}
	return nil
}

// Finalize writes the current pending (pathHash → version) mapping into
// cp_<id>.json. Returns (false, nil) when pending is empty so no empty
// checkpoint is created. Callers are expected to call Reset after Finalize.
func (s *PerEditSnapshotStore) Finalize(checkpointID string) (bool, error) {
	if checkpointID == "" {
		return false, fmt.Errorf("per-edit: empty checkpointID")
	}
	s.pendingMu.Lock()
	if len(s.pending) == 0 {
		s.pendingMu.Unlock()
		return false, nil
	}
	// Copy under the lock so the meta write below runs without holding it.
	snapshot := make(map[string]int, len(s.pending))
	for k, v := range s.pending {
		snapshot[k] = v
	}
	s.pendingMu.Unlock()

	meta := CheckpointMeta{
		CheckpointID: checkpointID,
		CreatedAt:    time.Now().UTC(),
		FileVersions: snapshot,
	}
	if err := s.writeCheckpointMeta(meta); err != nil {
		return false, err
	}
	return true, nil
}

// Reset clears the pending map. Called at the start of every turn so captures
// do not leak across turns.
func (s *PerEditSnapshotStore) Reset() {
	s.pendingMu.Lock()
	s.pending = make(map[string]int)
	s.pendingMu.Unlock()
}
// Restore brings the workspace back to its state at the given checkpoint.
// Core algorithm ("the next version IS the post-edit state" duality):
//   - For each (pathHash, v_A): find v_next, the smallest version after v_A in
//     pathToVersions[hash].
//   - If v_next exists, write v_next.bin back to displayPath (or delete the
//     path when v_next.meta.Existed=false); v_next's content is exactly the
//     state at checkpoint A.
//   - If v_next does not exist, no-op: the current workdir already equals the
//     state at A.
//
// Files not present in cp.FileVersions are left untouched (the key per-edit
// property).
func (s *PerEditSnapshotStore) Restore(ctx context.Context, checkpointID string) error {
	cp, err := s.readCheckpointMeta(checkpointID)
	if err != nil {
		return err
	}
	s.indexMu.Lock()
	defer s.indexMu.Unlock()

	for hash, vA := range cp.FileVersions {
		if err := ctx.Err(); err != nil {
			return err
		}
		nextVersion := s.findNextVersionLocked(hash, vA)
		if nextVersion == 0 {
			continue
		}
		nextMeta, err := s.readVersionMeta(hash, nextVersion)
		if err != nil {
			return fmt.Errorf("per-edit: read meta v%d: %w", nextVersion, err)
		}
		target := s.resolveDisplayPathLocked(hash, nextMeta.DisplayPath)
		if target == "" {
			return fmt.Errorf("per-edit: missing display path for hash %s", hash)
		}
		if !nextMeta.Existed {
			// At checkpoint time the path did not exist: remove it.
			if err := os.RemoveAll(target); err != nil && !errors.Is(err, os.ErrNotExist) {
				return fmt.Errorf("per-edit: restore remove %s: %w", target, err)
			}
			continue
		}
		if nextMeta.IsDir {
			if err := os.MkdirAll(target, nextMeta.Mode); err != nil {
				return fmt.Errorf("per-edit: restore mkdir %s: %w", target, err)
			}
			continue
		}
		content, err := s.readVersionBin(hash, nextVersion)
		if err != nil {
			return fmt.Errorf("per-edit: read bin v%d: %w", nextVersion, err)
		}
		if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
			return fmt.Errorf("per-edit: restore mkdir parent %s: %w", target, err)
		}
		if err := writeFileAtomic(target, content, nextMeta.Mode); err != nil {
			return fmt.Errorf("per-edit: write restore %s: %w", target, err)
		}
	}
	return nil
}

// RestoreExact restores the EXACT versions recorded in the checkpoint (no
// v_next lookup). Used by UndoRestore: a guard checkpoint stores the pre-write
// state captured right before the restore, so writing it back directly is
// correct and the v_next duality is not needed.
func (s *PerEditSnapshotStore) RestoreExact(ctx context.Context, checkpointID string) error {
	cp, err := s.readCheckpointMeta(checkpointID)
	if err != nil {
		return err
	}
	s.indexMu.Lock()
	defer s.indexMu.Unlock()

	for hash, vAt := range cp.FileVersions {
		if err := ctx.Err(); err != nil {
			return err
		}
		meta, err := s.readVersionMeta(hash, vAt)
		if err != nil {
			return fmt.Errorf("per-edit: read meta v%d: %w", vAt, err)
		}
		target := s.resolveDisplayPathLocked(hash, meta.DisplayPath)
		if target == "" {
			return fmt.Errorf("per-edit: missing display path for hash %s", hash)
		}
		if !meta.Existed {
			if err := os.RemoveAll(target); err != nil && !errors.Is(err, os.ErrNotExist) {
				return fmt.Errorf("per-edit: restore-exact remove %s: %w", target, err)
			}
			continue
		}
		if meta.IsDir {
			if err := os.MkdirAll(target, meta.Mode); err != nil {
				return fmt.Errorf("per-edit: restore-exact mkdir %s: %w", target, err)
			}
			continue
		}
		content, err := s.readVersionBin(hash, vAt)
		if err != nil {
			return fmt.Errorf("per-edit: restore-exact read bin v%d: %w", vAt, err)
		}
		if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
			return fmt.Errorf("per-edit: restore-exact mkdir parent %s: %w", target, err)
		}
		if err := writeFileAtomic(target, content, meta.Mode); err != nil {
			return fmt.Errorf("per-edit: restore-exact write %s: %w", target, err)
		}
	}
	return nil
}

// Diff compares the workspace state at two checkpoints end-to-end and returns
// a unified diff. End-to-end property: the diff algorithm only sees the two
// endpoints, so intermediate edits that revert to the original value vanish
// from the output automatically.
func (s *PerEditSnapshotStore) Diff(ctx context.Context, fromID, toID string) (string, error) {
	fromMeta, err := s.readCheckpointMeta(fromID)
	if err != nil {
		return "", err
	}
	toMeta, err := s.readCheckpointMeta(toID)
	if err != nil {
		return "", err
	}

	s.indexMu.Lock()
	defer s.indexMu.Unlock()

	// Union of paths touched by either checkpoint, iterated in sorted order
	// for deterministic output.
	hashSet := make(map[string]struct{})
	for h := range fromMeta.FileVersions {
		hashSet[h] = struct{}{}
	}
	for h := range toMeta.FileVersions {
		hashSet[h] = struct{}{}
	}
	hashes := make([]string, 0, len(hashSet))
	for h := range hashSet {
		hashes = append(hashes, h)
	}
	sort.Strings(hashes)

	var buf bytes.Buffer
	for _, hash := range hashes {
		if err := ctx.Err(); err != nil {
			return "", err
		}
		fromContent, fromIsDir, fromExists, fromDisplay, err := s.contentAtCheckpointLocked(hash, fromMeta.FileVersions, false)
		if err != nil {
			return "", err
		}
		toContent, toIsDir, toExists, toDisplay, err := s.contentAtCheckpointLocked(hash, toMeta.FileVersions, false)
		if err != nil {
			return "", err
		}
		if fromIsDir && toIsDir {
			continue
		}
		// Skip paths whose content and existence/dir status are unchanged.
		if bytes.Equal(fromContent, toContent) && fromExists == toExists && fromIsDir == toIsDir {
			continue
		}
		display := toDisplay
		if display == "" {
			display = fromDisplay
		}
		rel := s.relativeDisplay(display)
		diff := difflib.UnifiedDiff{
			A:        difflib.SplitLines(string(fromContent)),
			B:        difflib.SplitLines(string(toContent)),
			FromFile: "a/" + filepath.ToSlash(rel),
			ToFile:   "b/" + filepath.ToSlash(rel),
			Context:  3,
		}
		out, err := difflib.GetUnifiedDiffString(diff)
		if err != nil {
			return "", fmt.Errorf("per-edit: diff %s: %w", rel, err)
		}
		buf.WriteString(out)
	}
	return strings.TrimRight(buf.String(), "\n"), nil
}

// DeleteCheckpoint removes only the cp_<id>.json metadata. The .bin/.meta
// files under file-history are kept, since other checkpoints may still
// reference them; garbage collection is a separate process.
func (s *PerEditSnapshotStore) DeleteCheckpoint(checkpointID string) error {
	if checkpointID == "" {
		return nil
	}
	err := os.Remove(s.checkpointMetaPath(checkpointID))
	if errors.Is(err, os.ErrNotExist) {
		return nil
	}
	return err
}

// HasPending reports whether the current turn has captures awaiting Finalize;
// used as a gate for deciding whether to create a checkpoint.
func (s *PerEditSnapshotStore) HasPending() bool {
	s.pendingMu.Lock()
	defer s.pendingMu.Unlock()
	return len(s.pending) > 0
}

// FileChangeKind classifies a single path's change between two checkpoints.
type FileChangeKind string

const (
	FileChangeAdded    FileChangeKind = "added"
	FileChangeDeleted  FileChangeKind = "deleted"
	FileChangeModified FileChangeKind = "modified"
)

// FileChangeEntry describes one path's change in an end-to-end diff.
type FileChangeEntry struct {
	Path string
	Kind FileChangeKind
}

// ChangedFiles compares two checkpoints end-to-end and returns a
// path → change-kind list, sorted by path. It reports no content deltas —
// only the add/delete/modify grouping for UI purposes; the full patch is
// still produced by Diff.
func (s *PerEditSnapshotStore) ChangedFiles(ctx context.Context, fromID, toID string) ([]FileChangeEntry, error) {
	fromMeta, err := s.readCheckpointMeta(fromID)
	if err != nil {
		return nil, err
	}
	toMeta, err := s.readCheckpointMeta(toID)
	if err != nil {
		return nil, err
	}

	s.indexMu.Lock()
	defer s.indexMu.Unlock()

	hashSet := make(map[string]struct{})
	for h := range fromMeta.FileVersions {
		hashSet[h] = struct{}{}
	}
	for h := range toMeta.FileVersions {
		hashSet[h] = struct{}{}
	}
	hashes := make([]string, 0, len(hashSet))
	for h := range hashSet {
		hashes = append(hashes, h)
	}
	sort.Strings(hashes)

	out := make([]FileChangeEntry, 0, len(hashes))
	for _, hash := range hashes {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		fromContent, fromIsDir, fromExists, fromDisplay, err := s.contentAtCheckpointLocked(hash, fromMeta.FileVersions, false)
		if err != nil {
			return nil, err
		}
		toContent, toIsDir, toExists, toDisplay, err := s.contentAtCheckpointLocked(hash, toMeta.FileVersions, false)
		if err != nil {
			return nil, err
		}
		display := toDisplay
		if display == "" {
			display = fromDisplay
		}
		rel := filepath.ToSlash(s.relativeDisplay(display))
		switch {
		case !fromExists && toExists:
			out = append(out, FileChangeEntry{Path: rel, Kind: FileChangeAdded})
		case fromExists && !toExists:
			out = append(out, FileChangeEntry{Path: rel, Kind: FileChangeDeleted})
		case fromIsDir != toIsDir || !bytes.Equal(fromContent, toContent):
			out = append(out, FileChangeEntry{Path: rel, Kind: FileChangeModified})
		}
	}
	return out, nil
}

// PerEditRefPrefix marks CheckpointRecord.CodeCheckpointRef values produced by
// the per-edit backend.
const PerEditRefPrefix = "peredit:"
// RefForPerEditCheckpoint returns the string reference the per-edit backend
// stores in CheckpointRecord.CodeCheckpointRef.
func RefForPerEditCheckpoint(checkpointID string) string {
	return PerEditRefPrefix + checkpointID
}

// IsPerEditRef reports whether a CodeCheckpointRef was produced by the
// per-edit backend.
func IsPerEditRef(ref string) bool {
	return strings.HasPrefix(ref, PerEditRefPrefix)
}

// PerEditCheckpointIDFromRef extracts the checkpoint ID from a
// CodeCheckpointRef; returns "" for non-per-edit refs.
func PerEditCheckpointIDFromRef(ref string) string {
	if !IsPerEditRef(ref) {
		return ""
	}
	return strings.TrimPrefix(ref, PerEditRefPrefix)
}

// perEditPathHash derives the stable key for an absolute path: the first
// perEditPathHashLen hex chars of sha256(cleaned path).
func perEditPathHash(absPath string) string {
	sum := sha256.Sum256([]byte(filepath.Clean(absPath)))
	return hex.EncodeToString(sum[:])[:perEditPathHashLen]
}

// readFileForCapture stats and reads absPath for capture.
// Returns (content, existed, isDir, mode, err). A missing path is NOT an
// error (existed=false); an oversized file IS an error but still reports
// existed=true.
func readFileForCapture(absPath string) ([]byte, bool, bool, os.FileMode, error) {
	info, err := os.Stat(absPath)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil, false, false, 0, nil
		}
		return nil, false, false, 0, err
	}
	if info.IsDir() {
		return nil, true, true, info.Mode(), nil
	}
	if info.Size() > perEditMaxCaptureBytes {
		return nil, true, false, info.Mode(), fmt.Errorf("file %d bytes exceeds per-edit capture limit", info.Size())
	}
	content, err := os.ReadFile(absPath)
	if err != nil {
		return nil, true, false, info.Mode(), err
	}
	return content, true, false, info.Mode(), nil
}

// writeVersionFiles persists a version's .bin and .meta pair; on meta failure
// the freshly written .bin is removed so no orphan content remains.
func (s *PerEditSnapshotStore) writeVersionFiles(meta FileVersionMeta, content []byte) error {
	if err := os.MkdirAll(s.fileHistoryDir, 0o755); err != nil {
		return fmt.Errorf("per-edit: mkdir file-history: %w", err)
	}
	binPath := s.versionBinPath(meta.PathHash, meta.Version)
	metaPath := s.versionMetaPath(meta.PathHash, meta.Version)

	if err := writeFileAtomic(binPath, content, 0o644); err != nil {
		return fmt.Errorf("per-edit: write bin: %w", err)
	}
	if err := s.writeVersionMetaOnly(metaPath, meta); err != nil {
		_ = os.Remove(binPath)
		return err
	}
	return nil
}

// writeVersionMetaOnly marshals and atomically writes a .meta file.
func (s *PerEditSnapshotStore) writeVersionMetaOnly(metaPath string, meta FileVersionMeta) error {
	metaJSON, err := json.Marshal(meta)
	if err != nil {
		return fmt.Errorf("per-edit: marshal meta: %w", err)
	}
	if err := writeFileAtomic(metaPath, metaJSON, 0o644); err != nil {
		return fmt.Errorf("per-edit: write meta: %w", err)
	}
	return nil
}

// writeFileAtomic writes data to target via a temp file + rename within the
// same directory, so readers never observe a partially written file.
// NOTE(review): os.Rename over an existing file fails on Windows; acceptable
// only if the store targets POSIX hosts — confirm.
func writeFileAtomic(target string, data []byte, mode os.FileMode) error {
	if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
		return err
	}
	if mode == 0 {
		mode = 0o644
	}
	tmp, err := os.CreateTemp(filepath.Dir(target), filepath.Base(target)+".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		os.Remove(tmpName)
		return err
	}
	if err := tmp.Close(); err != nil {
		os.Remove(tmpName)
		return err
	}
	// CreateTemp uses 0600; apply the intended mode before publishing.
	if err := os.Chmod(tmpName, mode); err != nil {
		os.Remove(tmpName)
		return err
	}
	return os.Rename(tmpName, target)
}

// appendIndex appends one JSONL entry to index.jsonl (O_APPEND, create if
// missing).
func (s *PerEditSnapshotStore) appendIndex(entry perEditIndexEntry) error {
	if err := os.MkdirAll(s.fileHistoryDir, 0o755); err != nil {
		return err
	}
	line, err := json.Marshal(entry)
	if err != nil {
		return err
	}
	line = append(line, '\n')
	f, err := os.OpenFile(s.indexPath(), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
	if err != nil {
		return err
	}
	defer f.Close()
	_, err = f.Write(line)
	return err
}

// loadIndexFromDisk rebuilds pathToVersions/displayPaths from index.jsonl.
// Best-effort by design: a missing index or malformed lines are silently
// skipped rather than failing construction.
func (s *PerEditSnapshotStore) loadIndexFromDisk() {
	f, err := os.Open(s.indexPath())
	if err != nil {
		return
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	scanner.Buffer(make([]byte, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Bytes()
		if len(line) == 0 {
			continue
		}
		var entry perEditIndexEntry
		if err := json.Unmarshal(line, &entry); err != nil {
			continue
		}
		s.pathToVersions[entry.PathHash] = append(s.pathToVersions[entry.PathHash], entry.Version)
		s.displayPaths[entry.PathHash] = entry.DisplayPath
	}
	// Versions must be ascending for findNextVersionLocked's linear scan.
	for hash, versions := range s.pathToVersions {
		sort.Ints(versions)
		s.pathToVersions[hash] = versions
	}
}

// writeCheckpointMeta atomically persists cp_<id>.json.
func (s *PerEditSnapshotStore) writeCheckpointMeta(meta CheckpointMeta) error {
	if err := os.MkdirAll(s.checkpointsDir, 0o755); err != nil {
		return fmt.Errorf("per-edit: mkdir checkpoints: %w", err)
	}
	data, err := json.Marshal(meta)
	if err != nil {
		return fmt.Errorf("per-edit: marshal cp meta: %w", err)
	}
	return writeFileAtomic(s.checkpointMetaPath(meta.CheckpointID), data, 0o644)
}

// readCheckpointMeta loads cp_<id>.json, normalizing a nil FileVersions map.
func (s *PerEditSnapshotStore) readCheckpointMeta(checkpointID string) (CheckpointMeta, error) {
	var meta CheckpointMeta
	data, err := os.ReadFile(s.checkpointMetaPath(checkpointID))
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return meta, fmt.Errorf("per-edit: checkpoint %s not found", checkpointID)
		}
		return meta, fmt.Errorf("per-edit: read cp meta %s: %w", checkpointID, err)
	}
	if err := json.Unmarshal(data, &meta); err != nil {
		return meta, fmt.Errorf("per-edit: unmarshal cp meta %s: %w", checkpointID, err)
	}
	if meta.FileVersions == nil {
		meta.FileVersions = map[string]int{}
	}
	return meta, nil
}

// readVersionMeta loads one version's .meta file.
func (s *PerEditSnapshotStore) readVersionMeta(hash string, version int) (FileVersionMeta, error) {
	var meta FileVersionMeta
	data, err := os.ReadFile(s.versionMetaPath(hash, version))
	if err != nil {
		return meta, err
	}
	err = json.Unmarshal(data, &meta)
	return meta, err
}

// readVersionBin loads one version's raw content.
func (s *PerEditSnapshotStore) readVersionBin(hash string, version int) ([]byte, error) {
	return os.ReadFile(s.versionBinPath(hash, version))
}

// findNextVersionLocked returns the smallest version of hash greater than vA,
// or 0 when none exists. indexMu must be held.
func (s *PerEditSnapshotStore) findNextVersionLocked(hash string, vA int) int {
	versions := s.pathToVersions[hash]
	for _, v := range versions {
		if v > vA {
			return v
		}
	}
	return 0
}

// resolveDisplayPathLocked picks the workspace absolute path for hash,
// preferring the in-memory index over the per-version fallback. indexMu must
// be held.
func (s *PerEditSnapshotStore) resolveDisplayPathLocked(hash, fallback string) string {
	if dp, ok := s.displayPaths[hash]; ok && dp != "" {
		return dp
	}
	return fallback
}

// contentAtCheckpointLocked computes hash's workdir content at a checkpoint.
// In cp.FileVersions: read the next version's .bin (or report absent when
// Existed=false); with no next version the live workdir content is
// authoritative. Not in cp.FileVersions and fallbackIfMissing=false: report
// exists=false, so diff does not mistake a current workdir file for one that
// existed at checkpoint time. indexMu must be held.
// Returns (content, isDir, exists, displayPath, err).
func (s *PerEditSnapshotStore) contentAtCheckpointLocked(hash string, cpVersions map[string]int, fallbackIfMissing bool) ([]byte, bool, bool, string, error) {
	display := s.displayPaths[hash]
	vAt, ok := cpVersions[hash]
	if !ok {
		if fallbackIfMissing {
			c, isDir, exists := readWorkdirContent(display)
			return c, isDir, exists, display, nil
		}
		return nil, false, false, display, nil
	}
	nextVersion := s.findNextVersionLocked(hash, vAt)
	if nextVersion == 0 {
		c, isDir, exists := readWorkdirContent(display)
		return c, isDir, exists, display, nil
	}
	nextMeta, err := s.readVersionMeta(hash, nextVersion)
	if err != nil {
		return nil, false, false, display, fmt.Errorf("per-edit: read meta v%d for %s: %w", nextVersion, hash, err)
	}
	if !nextMeta.Existed {
		return nil, false, false, display, nil
	}
	if nextMeta.IsDir {
		return nil, true, true, display, nil
	}
	content, err := s.readVersionBin(hash, nextVersion)
	if err != nil {
		return nil, false, false, display, fmt.Errorf("per-edit: read bin v%d for %s: %w", nextVersion, hash, err)
	}
	return content, false, true, display, nil
}

// readWorkdirContent reads the live workdir state of absPath.
// Returns (content, isDir, exists); any stat/read failure reports absent.
func readWorkdirContent(absPath string) ([]byte, bool, bool) {
	if absPath == "" {
		return nil, false, false
	}
	info, err := os.Stat(absPath)
	if err != nil {
		return nil, false, false
	}
	if info.IsDir() {
		return nil, true, true
	}
	data, err := os.ReadFile(absPath)
	if err != nil {
		return nil, false, false
	}
	return data, false, true
}

// relativeDisplay renders absPath relative to the workdir for diff headers;
// falls back to the absolute path when no workdir is set or Rel fails.
func (s *PerEditSnapshotStore) relativeDisplay(absPath string) string {
	if absPath == "" {
		return ""
	}
	if s.workdir == "" {
		return absPath
	}
	rel, err := filepath.Rel(s.workdir, absPath)
	if err != nil {
		return absPath
	}
	return rel
}

// versionBinPath returns <file-history>/<hash>@v<n>.bin.
func (s *PerEditSnapshotStore) versionBinPath(hash string, version int) string {
	return filepath.Join(s.fileHistoryDir, fmt.Sprintf("%s@v%d.bin", hash, version))
}

// versionMetaPath returns <file-history>/<hash>@v<n>.meta.
func (s *PerEditSnapshotStore) versionMetaPath(hash string, version int) string {
	return filepath.Join(s.fileHistoryDir, fmt.Sprintf("%s@v%d.meta", hash, version))
}

// checkpointMetaPath returns <checkpoints>/cp_<id>.json.
func (s *PerEditSnapshotStore) checkpointMetaPath(checkpointID string) string {
	return filepath.Join(s.checkpointsDir, fmt.Sprintf("cp_%s.json", checkpointID))
}

// indexPath returns the JSONL index location under file-history.
func (s *PerEditSnapshotStore) indexPath() string {
	return filepath.Join(s.fileHistoryDir, perEditIndexFileName)
}
diff --git a/internal/checkpoint/per_edit_snapshot_test.go b/internal/checkpoint/per_edit_snapshot_test.go
new file mode 100644
index 00000000..c83dc0cb
--- /dev/null
+++ b/internal/checkpoint/per_edit_snapshot_test.go
@@ -0,0 +1,1125 @@
package checkpoint

import (
	"context"
	"os"
	"path/filepath"
	"strings"
	"testing"
)

// newTestStore returns a PerEditSnapshotStore rooted at t.TempDir() and a workdir under it.
+func newTestStore(t *testing.T) (*PerEditSnapshotStore, string) { + t.Helper() + root := t.TempDir() + projectDir := filepath.Join(root, "project") + workdir := filepath.Join(root, "workdir") + if err := os.MkdirAll(workdir, 0o755); err != nil { + t.Fatalf("mkdir workdir: %v", err) + } + return NewPerEditSnapshotStore(projectDir, workdir), workdir +} + +func writeWorkdirFile(t *testing.T, workdir, rel, content string) string { + t.Helper() + abs := filepath.Join(workdir, rel) + if err := os.MkdirAll(filepath.Dir(abs), 0o755); err != nil { + t.Fatalf("mkdir parent: %v", err) + } + if err := os.WriteFile(abs, []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", rel, err) + } + return abs +} + +func mustReadFile(t *testing.T, path string) string { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read %s: %v", path, err) + } + return string(data) +} + +// TestCapturePreWrite_AssignsMonotonicVersions: same path captured across turns gets v1, v2, v3... +func TestCapturePreWrite_AssignsMonotonicVersions(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "a.txt", "v0") + + for i := 1; i <= 3; i++ { + v, err := store.CapturePreWrite(abs) + if err != nil { + t.Fatalf("capture %d: %v", i, err) + } + if v != i { + t.Fatalf("capture %d: want version %d, got %d", i, i, v) + } + store.Reset() + } +} + +// TestCapturePreWrite_DedupesWithinTurn: same path within one turn returns first version every time. 
func TestCapturePreWrite_DedupesWithinTurn(t *testing.T) {
	store, workdir := newTestStore(t)
	abs := writeWorkdirFile(t, workdir, "a.txt", "hello")

	v1, err := store.CapturePreWrite(abs)
	if err != nil || v1 != 1 {
		t.Fatalf("first capture: v=%d err=%v", v1, err)
	}
	// Without an intervening Reset, repeat captures must return v1.
	v2, err := store.CapturePreWrite(abs)
	if err != nil {
		t.Fatalf("second capture: %v", err)
	}
	if v2 != v1 {
		t.Fatalf("dedupe failed: v1=%d v2=%d", v1, v2)
	}
	v3, err := store.CapturePreWrite(abs)
	if err != nil {
		t.Fatalf("third capture: %v", err)
	}
	if v3 != v1 {
		t.Fatalf("dedupe failed: v1=%d v3=%d", v1, v3)
	}
}

// TestCapturePreWrite_NewFileMarksExistedFalse: capturing a non-existent path stores Existed=false.
func TestCapturePreWrite_NewFileMarksExistedFalse(t *testing.T) {
	store, workdir := newTestStore(t)
	abs := filepath.Join(workdir, "ghost.txt")

	v, err := store.CapturePreWrite(abs)
	if err != nil {
		t.Fatalf("capture missing file: %v", err)
	}

	// Inspect the persisted .meta/.bin pair directly via the store internals.
	hash := perEditPathHash(abs)
	meta, err := store.readVersionMeta(hash, v)
	if err != nil {
		t.Fatalf("read meta: %v", err)
	}
	if meta.Existed {
		t.Fatalf("Existed should be false for missing file")
	}
	bin, err := store.readVersionBin(hash, v)
	if err != nil {
		t.Fatalf("read bin: %v", err)
	}
	if len(bin) != 0 {
		t.Fatalf("bin should be empty, got %d bytes", len(bin))
	}
}

// TestRestore_UsesNextVersionAsTargetState: capture v1, modify, finalize cp1; capture v2, modify;
// Restore(cp1) should put v2.bin (== state right after v1's edit) on disk.
func TestRestore_UsesNextVersionAsTargetState(t *testing.T) {
	store, workdir := newTestStore(t)
	abs := writeWorkdirFile(t, workdir, "a.txt", "STATE_INITIAL")

	// Turn 1: capture preX, simulate tool edit to STATE_AFTER_TURN_1, finalize cp1.
	if _, err := store.CapturePreWrite(abs); err != nil {
		t.Fatalf("turn1 capture: %v", err)
	}
	if err := os.WriteFile(abs, []byte("STATE_AFTER_TURN_1"), 0o644); err != nil {
		t.Fatalf("turn1 edit: %v", err)
	}
	if written, err := store.Finalize("cp1"); err != nil || !written {
		t.Fatalf("turn1 finalize: written=%v err=%v", written, err)
	}
	store.Reset()

	// Turn 2: capture (current=STATE_AFTER_TURN_1), simulate edit to STATE_AFTER_TURN_2, finalize cp2.
	if _, err := store.CapturePreWrite(abs); err != nil {
		t.Fatalf("turn2 capture: %v", err)
	}
	if err := os.WriteFile(abs, []byte("STATE_AFTER_TURN_2"), 0o644); err != nil {
		t.Fatalf("turn2 edit: %v", err)
	}
	if _, err := store.Finalize("cp2"); err != nil {
		t.Fatalf("turn2 finalize: %v", err)
	}
	store.Reset()

	// Workdir is now STATE_AFTER_TURN_2.
	if got := mustReadFile(t, abs); got != "STATE_AFTER_TURN_2" {
		t.Fatalf("pre-restore: %q", got)
	}

	// Restore cp1: should write STATE_AFTER_TURN_1 (== v2.bin == content captured at start of turn 2).
	if err := store.Restore(context.Background(), "cp1"); err != nil {
		t.Fatalf("restore cp1: %v", err)
	}
	if got := mustReadFile(t, abs); got != "STATE_AFTER_TURN_1" {
		t.Fatalf("after restore cp1 want %q got %q", "STATE_AFTER_TURN_1", got)
	}
}

// TestRestore_NoNextVersionIsNoOp: restoring the latest checkpoint doesn't change workdir.
func TestRestore_NoNextVersionIsNoOp(t *testing.T) {
	store, workdir := newTestStore(t)
	abs := writeWorkdirFile(t, workdir, "a.txt", "BEFORE")

	if _, err := store.CapturePreWrite(abs); err != nil {
		t.Fatalf("capture: %v", err)
	}
	if err := os.WriteFile(abs, []byte("AFTER"), 0o644); err != nil {
		t.Fatalf("edit: %v", err)
	}
	if _, err := store.Finalize("cp1"); err != nil {
		t.Fatalf("finalize: %v", err)
	}
	store.Reset()

	// cp1's captured version has no successor, so Restore must leave the
	// workdir untouched (it already equals the cp1 state).
	if err := store.Restore(context.Background(), "cp1"); err != nil {
		t.Fatalf("restore: %v", err)
	}
	if got := mustReadFile(t, abs); got != "AFTER" {
		t.Fatalf("workdir after restore should be unchanged AFTER, got %q", got)
	}
}

// TestRestore_PreservesUntrackedFiles: files not in cp.FileVersions stay untouched.
func TestRestore_PreservesUntrackedFiles(t *testing.T) {
	store, workdir := newTestStore(t)
	tracked := writeWorkdirFile(t, workdir, "tracked.txt", "TR_INITIAL")
	untracked := writeWorkdirFile(t, workdir, "untracked.txt", "UN_INITIAL")

	// Turn 1: only touch tracked.
	if _, err := store.CapturePreWrite(tracked); err != nil {
		t.Fatalf("capture tracked: %v", err)
	}
	if err := os.WriteFile(tracked, []byte("TR_AFTER_T1"), 0o644); err != nil {
		t.Fatalf("edit tracked: %v", err)
	}
	if _, err := store.Finalize("cp1"); err != nil {
		t.Fatalf("finalize: %v", err)
	}
	store.Reset()

	// Turn 2: edit tracked again so cp1 has a usable v_next.
	if _, err := store.CapturePreWrite(tracked); err != nil {
		t.Fatalf("capture tracked t2: %v", err)
	}
	if err := os.WriteFile(tracked, []byte("TR_AFTER_T2"), 0o644); err != nil {
		t.Fatalf("edit tracked t2: %v", err)
	}
	// External (non-agent) edit to untracked file at any time; should NOT be reverted.
	if err := os.WriteFile(untracked, []byte("UN_EXTERNAL"), 0o644); err != nil {
		t.Fatalf("edit untracked: %v", err)
	}
	if _, err := store.Finalize("cp2"); err != nil {
		t.Fatalf("finalize cp2: %v", err)
	}
	store.Reset()

	if err := store.Restore(context.Background(), "cp1"); err != nil {
		t.Fatalf("restore cp1: %v", err)
	}
	if got := mustReadFile(t, tracked); got != "TR_AFTER_T1" {
		t.Fatalf("tracked after restore want TR_AFTER_T1 got %q", got)
	}
	if got := mustReadFile(t, untracked); got != "UN_EXTERNAL" {
		t.Fatalf("untracked must stay UN_EXTERNAL, got %q", got)
	}
}

// TestDiff_EndToEnd_SameLineMultipleEdits: a→b→a→b→a sequence; Diff(first, last) is empty.
func TestDiff_EndToEnd_SameLineMultipleEdits(t *testing.T) {
	store, workdir := newTestStore(t)
	abs := writeWorkdirFile(t, workdir, "f.txt", "X\n")

	transitions := []string{"A\n", "B\n", "A\n", "B\n", "A\n"}
	for i, target := range transitions {
		if _, err := store.CapturePreWrite(abs); err != nil {
			t.Fatalf("capture turn %d: %v", i+1, err)
		}
		if err := os.WriteFile(abs, []byte(target), 0o644); err != nil {
			t.Fatalf("edit turn %d: %v", i+1, err)
		}
		// Single-digit IDs only (i < 9), so a one-rune suffix suffices.
		cpID := "cp" + string(rune('0'+i+1))
		if _, err := store.Finalize(cpID); err != nil {
			t.Fatalf("finalize %s: %v", cpID, err)
		}
		store.Reset()
	}

	// State at cp1 (== content right after turn 1) should be "A".
	// State at cp5 (== current workdir, since v5 has no v_next) should also be "A".
	patch, err := store.Diff(context.Background(), "cp1", "cp5")
	if err != nil {
		t.Fatalf("diff: %v", err)
	}
	if strings.TrimSpace(patch) != "" {
		t.Fatalf("expected empty diff for endpoints both 'A', got:\n%s", patch)
	}
}

// TestDiff_NoNextVersionFallsBackToWorkdir: latest checkpoint uses current workdir for its content.
+func TestDiff_NoNextVersionFallsBackToWorkdir(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "f.txt", "X") + + // Turn 1: X → A + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture t1: %v", err) + } + if err := os.WriteFile(abs, []byte("A"), 0o644); err != nil { + t.Fatalf("edit t1: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + // Turn 2: A → B + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture t2: %v", err) + } + if err := os.WriteFile(abs, []byte("B"), 0o644); err != nil { + t.Fatalf("edit t2: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + // content_at_cp1 = v2.bin = "A" + // content_at_cp2 = current workdir = "B" + patch, err := store.Diff(context.Background(), "cp1", "cp2") + if err != nil { + t.Fatalf("diff: %v", err) + } + if !strings.Contains(patch, "-A") || !strings.Contains(patch, "+B") { + t.Fatalf("expected diff A→B, got:\n%s", patch) + } +} + +// TestIndexReload_SurvivesProcessRestart: reconstruct store from disk, verify pathToVersions/displayPaths. 
+func TestIndexReload_SurvivesProcessRestart(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "project") + workdir := filepath.Join(root, "workdir") + if err := os.MkdirAll(workdir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + + abs := filepath.Join(workdir, "a.txt") + if err := os.WriteFile(abs, []byte("X"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + { + store := NewPerEditSnapshotStore(projectDir, workdir) + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("first capture: %v", err) + } + if err := os.WriteFile(abs, []byte("Y"), 0o644); err != nil { + t.Fatalf("edit1: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize: %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("second capture: %v", err) + } + } + + // Simulate process restart: build fresh store from same dirs. + revived := NewPerEditSnapshotStore(projectDir, workdir) + hash := perEditPathHash(abs) + versions := revived.pathToVersions[hash] + if len(versions) != 2 || versions[0] != 1 || versions[1] != 2 { + t.Fatalf("revived versions = %v, want [1 2]", versions) + } + if revived.displayPaths[hash] != filepath.Clean(abs) { + t.Fatalf("revived display = %q, want %q", revived.displayPaths[hash], filepath.Clean(abs)) + } + + // Restore on revived store should still work (verifies cp1.json + version files are usable). + // Workdir is "Y" right now (we never edited again post second capture). + // cp1 -> v_next(v1) = v2 -> meta.Existed=true, content="Y" + // So Restore writes "Y" back which is no-op effectively. + if err := revived.Restore(context.Background(), "cp1"); err != nil { + t.Fatalf("revived restore: %v", err) + } + if got := mustReadFile(t, abs); got != "Y" { + t.Fatalf("post-restore want Y got %q", got) + } +} + +// TestFinalize_EmptyPendingReturnsFalse: Finalize with no captures should be a no-op. 
+func TestFinalize_EmptyPendingReturnsFalse(t *testing.T) { + store, _ := newTestStore(t) + written, err := store.Finalize("cp_empty") + if err != nil { + t.Fatalf("finalize: %v", err) + } + if written { + t.Fatalf("written should be false on empty pending") + } + if _, err := os.Stat(store.checkpointMetaPath("cp_empty")); !os.IsNotExist(err) { + t.Fatalf("checkpoint meta should not exist, stat err=%v", err) + } +} + +// TestRestore_RemovesFileWhenVNextExistedFalse: capture-existing → delete → restore should NOT +// recreate the file because the next captured version has Existed=false. +func TestRestore_RemovesFileWhenVNextExistedFalse(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "doomed.txt", "I_LIVE") + + // Turn 1: edit + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture t1: %v", err) + } + if err := os.WriteFile(abs, []byte("STILL_LIVE"), 0o644); err != nil { + t.Fatalf("edit t1: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + // Turn 2: capture existing then delete; v2.bin contains "STILL_LIVE", v2.meta.Existed=true. + // We need a v3 that has Existed=false to model "restore should delete". + // So: turn 2 deletes, capture pre-delete: v2.bin="STILL_LIVE", Existed=true; remove file. + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture t2: %v", err) + } + if err := os.Remove(abs); err != nil { + t.Fatalf("delete t2: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + // Turn 3: re-create file; capture pre-create finds Existed=false. 
+ if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture t3: %v", err) + } + if err := os.WriteFile(abs, []byte("RECREATED"), 0o644); err != nil { + t.Fatalf("recreate t3: %v", err) + } + if _, err := store.Finalize("cp3"); err != nil { + t.Fatalf("finalize cp3: %v", err) + } + store.Reset() + + // Restore cp2: v2 captured "STILL_LIVE"; v_next(v2)=v3 has Existed=false → delete file. + if err := store.Restore(context.Background(), "cp2"); err != nil { + t.Fatalf("restore cp2: %v", err) + } + if _, err := os.Stat(abs); !os.IsNotExist(err) { + t.Fatalf("file should be deleted, stat err=%v", err) + } +} + +// TestCaptureBatch_DedupesAndCaptures: batch is just sequential CapturePreWrite, dedupe works. +func TestCaptureBatch_DedupesAndCaptures(t *testing.T) { + store, workdir := newTestStore(t) + a := writeWorkdirFile(t, workdir, "a.txt", "A") + b := writeWorkdirFile(t, workdir, "b.txt", "B") + + captured, err := store.CaptureBatch([]string{a, b, a, " ", "", b}) + if err != nil { + t.Fatalf("batch: %v", err) + } + if len(captured) != 4 { + t.Fatalf("captured paths len = %d, want 4 (empty/whitespace skipped)", len(captured)) + } + + // pending should have exactly two unique hashes. + store.pendingMu.Lock() + count := len(store.pending) + store.pendingMu.Unlock() + if count != 2 { + t.Fatalf("pending count = %d, want 2", count) + } +} + +// TestCapturePreWrite_DirectoryMarksExistedTrue: capturing an existing directory stores Existed=true, IsDir=true. 
+func TestCapturePreWrite_DirectoryMarksExistedTrue(t *testing.T) { + store, workdir := newTestStore(t) + abs := filepath.Join(workdir, "subdir") + if err := os.MkdirAll(abs, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + + v, err := store.CapturePreWrite(abs) + if err != nil { + t.Fatalf("capture dir: %v", err) + } + + hash := perEditPathHash(abs) + meta, err := store.readVersionMeta(hash, v) + if err != nil { + t.Fatalf("read meta: %v", err) + } + if !meta.Existed { + t.Fatalf("Existed should be true for directory") + } + if !meta.IsDir { + t.Fatalf("IsDir should be true for directory") + } + bin, err := store.readVersionBin(hash, v) + if err != nil { + t.Fatalf("read bin: %v", err) + } + if len(bin) != 0 { + t.Fatalf("bin should be empty, got %d bytes", len(bin)) + } +} + +// TestRestore_DirectoryRecreateAndDelete: per-edit restore uses v_next to determine directory state. +func TestRestore_DirectoryRecreateAndDelete(t *testing.T) { + store, workdir := newTestStore(t) + dir := filepath.Join(workdir, "foo") + + // Turn 1: create_dir — capture pre-create (does not exist), then create. + if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture pre-create: %v", err) + } + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + // Turn 2: remove_dir — capture pre-remove (exists, IsDir=true), then remove. + if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture pre-remove: %v", err) + } + if err := os.RemoveAll(dir); err != nil { + t.Fatalf("remove: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + // Turn 3: recreate_dir — capture pre-recreate (does not exist), then create. + // This gives cp2 a v_next with Existed=false so restore can delete the directory. 
+ if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture pre-recreate: %v", err) + } + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("mkdir recreate: %v", err) + } + if _, err := store.Finalize("cp3"); err != nil { + t.Fatalf("finalize cp3: %v", err) + } + store.Reset() + + // Restore cp1: v_next=v2(Existed=true,IsDir=true) → MkdirAll. Dir should exist. + if err := os.RemoveAll(dir); err != nil { + t.Fatalf("manual remove before restore: %v", err) + } + if err := store.Restore(context.Background(), "cp1"); err != nil { + t.Fatalf("restore cp1: %v", err) + } + info, err := os.Stat(dir) + if err != nil { + t.Fatalf("dir should exist after restore cp1, got %v", err) + } + if !info.IsDir() { + t.Fatalf("restored path should be a directory") + } + + // Restore cp2: v_next=v3(Existed=false) → RemoveAll. Dir should be deleted. + if err := store.Restore(context.Background(), "cp2"); err != nil { + t.Fatalf("restore cp2: %v", err) + } + if _, err := os.Stat(dir); !os.IsNotExist(err) { + t.Fatalf("expected dir absent after restore cp2, stat err=%v", err) + } +} + +// TestRestore_DirectoryWithNestedFile: RemoveAll can delete a directory that later got nested files. +func TestRestore_DirectoryWithNestedFile(t *testing.T) { + store, workdir := newTestStore(t) + dir := filepath.Join(workdir, "foo") + child := filepath.Join(dir, "bar.txt") + + // Turn 1: create_dir. + if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture pre-create dir: %v", err) + } + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if _, err := store.Finalize("cp-dir"); err != nil { + t.Fatalf("finalize cp-dir: %v", err) + } + store.Reset() + + // Turn 2: write file inside dir AND re-capture dir (so dir gets a v2 with Existed=true,IsDir=true). 
+ if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture pre-touch dir: %v", err) + } + if _, err := store.CapturePreWrite(child); err != nil { + t.Fatalf("capture pre-write child: %v", err) + } + if err := os.WriteFile(child, []byte("hello"), 0o644); err != nil { + t.Fatalf("write child: %v", err) + } + if _, err := store.Finalize("cp-child"); err != nil { + t.Fatalf("finalize cp-child: %v", err) + } + store.Reset() + + // Turn 3: remove_dir — capture pre-remove (dir+child exist), then delete. + if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture pre-remove dir: %v", err) + } + if err := os.RemoveAll(dir); err != nil { + t.Fatalf("remove dir: %v", err) + } + if _, err := store.Finalize("cp-remove"); err != nil { + t.Fatalf("finalize cp-remove: %v", err) + } + store.Reset() + + // Turn 4: recreate empty dir — gives cp-remove a v_next with Existed=false. + if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture pre-recreate dir: %v", err) + } + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("mkdir recreate: %v", err) + } + if _, err := store.Finalize("cp-recreate"); err != nil { + t.Fatalf("finalize cp-recreate: %v", err) + } + store.Reset() + + // Restore cp-dir: v_next=v2(Existed=true,IsDir=true) → MkdirAll. Dir should exist (child won't be restored because child has its own chain). + if err := os.RemoveAll(dir); err != nil { + t.Fatalf("manual remove before restore: %v", err) + } + if err := store.Restore(context.Background(), "cp-dir"); err != nil { + t.Fatalf("restore cp-dir: %v", err) + } + if _, err := os.Stat(dir); os.IsNotExist(err) { + t.Fatalf("dir should be recreated after restore cp-dir") + } + + // Restore cp-remove: v_next=v4(Existed=false) → RemoveAll. Should delete even if non-empty. 
+ if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("mkdir before restore: %v", err) + } + if err := os.WriteFile(child, []byte("new"), 0o644); err != nil { + t.Fatalf("write child before restore: %v", err) + } + if err := store.Restore(context.Background(), "cp-remove"); err != nil { + t.Fatalf("restore cp-remove: %v", err) + } + if _, err := os.Stat(dir); !os.IsNotExist(err) { + t.Fatalf("expected dir absent after restore cp-remove, stat err=%v", err) + } +} + +func TestChangedFiles(t *testing.T) { + store, workdir := newTestStore(t) + + // Setup files for cp1. + writeWorkdirFile(t, workdir, "a.txt", "alpha") + writeWorkdirFile(t, workdir, "b.txt", "beta") + + // Turn 1: capture, finalize cp1. + if _, err := store.CapturePreWrite(filepath.Join(workdir, "a.txt")); err != nil { + t.Fatalf("capture cp1 a: %v", err) + } + if _, err := store.CapturePreWrite(filepath.Join(workdir, "b.txt")); err != nil { + t.Fatalf("capture cp1 b: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + // Turn 2: capture all paths (including new c.txt), then edit. + if _, err := store.CapturePreWrite(filepath.Join(workdir, "a.txt")); err != nil { + t.Fatalf("capture cp2 a: %v", err) + } + if _, err := store.CapturePreWrite(filepath.Join(workdir, "b.txt")); err != nil { + t.Fatalf("capture cp2 b: %v", err) + } + if _, err := store.CapturePreWrite(filepath.Join(workdir, "c.txt")); err != nil { + t.Fatalf("capture cp2 c: %v", err) + } + writeWorkdirFile(t, workdir, "a.txt", "alpha-v2") + if err := os.Remove(filepath.Join(workdir, "b.txt")); err != nil { + t.Fatalf("remove b.txt: %v", err) + } + writeWorkdirFile(t, workdir, "c.txt", "gamma") + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + // Turn 3: capture all paths again to create v_next for cp2 (needed for correct diff resolution). 
+ if _, err := store.CapturePreWrite(filepath.Join(workdir, "a.txt")); err != nil { + t.Fatalf("capture cp3 a: %v", err) + } + if _, err := store.CapturePreWrite(filepath.Join(workdir, "b.txt")); err != nil { + t.Fatalf("capture cp3 b: %v", err) + } + if _, err := store.CapturePreWrite(filepath.Join(workdir, "c.txt")); err != nil { + t.Fatalf("capture cp3 c: %v", err) + } + if _, err := store.Finalize("cp3"); err != nil { + t.Fatalf("finalize cp3: %v", err) + } + store.Reset() + + // Restore to cp1 so workdir fallback matches cp1 state. + if err := store.Restore(context.Background(), "cp1"); err != nil { + t.Fatalf("restore cp1: %v", err) + } + // c.txt did not exist in cp1; Restore won't remove it because cp1 doesn't know about it. + if err := os.Remove(filepath.Join(workdir, "c.txt")); err != nil && !os.IsNotExist(err) { + t.Fatalf("remove stray c.txt: %v", err) + } + + changes, err := store.ChangedFiles(context.Background(), "cp1", "cp2") + if err != nil { + t.Fatalf("changed files cp1->cp2: %v", err) + } + if len(changes) != 3 { + t.Fatalf("expected 3 changes, got %d: %+v", len(changes), changes) + } + + want := map[string]FileChangeKind{ + "a.txt": FileChangeModified, + "b.txt": FileChangeDeleted, + "c.txt": FileChangeAdded, + } + for _, ch := range changes { + if want[ch.Path] != ch.Kind { + t.Errorf("path %s: expected kind %s, got %s", ch.Path, want[ch.Path], ch.Kind) + } + delete(want, ch.Path) + } + if len(want) > 0 { + t.Errorf("missing expected changes: %+v", want) + } +} + +func TestChangedFiles_NoChange(t *testing.T) { + store, workdir := newTestStore(t) + writeWorkdirFile(t, workdir, "x.txt", "same") + + if _, err := store.CapturePreWrite(filepath.Join(workdir, "x.txt")); err != nil { + t.Fatalf("capture cp1: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(filepath.Join(workdir, "x.txt")); err != nil { + t.Fatalf("capture cp2: %v", err) + } 
+ if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + changes, err := store.ChangedFiles(context.Background(), "cp1", "cp2") + if err != nil { + t.Fatalf("changed files cp1->cp2: %v", err) + } + if len(changes) != 0 { + t.Fatalf("expected no changes, got %d: %+v", len(changes), changes) + } +} + +func TestChangedFiles_DirectoryToFile(t *testing.T) { + store, workdir := newTestStore(t) + path := filepath.Join(workdir, "target") + + // Turn 1: target is a directory. + if err := os.MkdirAll(path, 0o755); err != nil { + t.Fatalf("mkdir target: %v", err) + } + if _, err := store.CapturePreWrite(path); err != nil { + t.Fatalf("capture cp1: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + // Turn 2: target becomes a file. + if _, err := store.CapturePreWrite(path); err != nil { + t.Fatalf("capture cp2: %v", err) + } + if err := os.RemoveAll(path); err != nil { + t.Fatalf("remove dir: %v", err) + } + if err := os.WriteFile(path, []byte("file"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + // Turn 3: capture again to give cp2 a v_next. 
+ if _, err := store.CapturePreWrite(path); err != nil { + t.Fatalf("capture cp3: %v", err) + } + if _, err := store.Finalize("cp3"); err != nil { + t.Fatalf("finalize cp3: %v", err) + } + store.Reset() + + changes, err := store.ChangedFiles(context.Background(), "cp1", "cp2") + if err != nil { + t.Fatalf("changed files: %v", err) + } + if len(changes) != 1 || changes[0].Path != "target" || changes[0].Kind != FileChangeModified { + t.Fatalf("expected target modified, got %+v", changes) + } +} + +func TestCapturePostDelete_CreatesExistedFalseVersion(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "a.txt", "hello") + + // Turn 1: capture pre-write (v1, Existed=true). + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture v1: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + // Delete the file, then call CapturePostDelete to create v2(Existed=false). + if err := os.Remove(abs); err != nil { + t.Fatalf("remove: %v", err) + } + if err := store.CapturePostDelete([]string{abs}); err != nil { + t.Fatalf("CapturePostDelete: %v", err) + } + + // Restore cp1: v_next should be v2(Existed=false) → file should be deleted. + if err := store.Restore(context.Background(), "cp1"); err != nil { + t.Fatalf("restore cp1: %v", err) + } + if _, err := os.Stat(abs); !os.IsNotExist(err) { + t.Fatalf("expected file absent after restore, stat err=%v", err) + } +} + +func TestCapturePostDelete_DirectoryTreeRecovery(t *testing.T) { + store, workdir := newTestStore(t) + dir := filepath.Join(workdir, "foo") + child1 := filepath.Join(dir, "a.txt") + child2 := filepath.Join(dir, "sub", "b.txt") + + // Create nested tree. + writeWorkdirFile(t, workdir, "foo/a.txt", "alpha") + writeWorkdirFile(t, workdir, "foo/sub/b.txt", "beta") + + // Turn 1: pre-capture directory and all nested files. 
+ if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture dir: %v", err) + } + if _, err := store.CapturePreWrite(child1); err != nil { + t.Fatalf("capture child1: %v", err) + } + if _, err := store.CapturePreWrite(child2); err != nil { + t.Fatalf("capture child2: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + // Turn 2: pre-capture, then delete tree, then CapturePostDelete, then finalize cp2. + if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture dir t2: %v", err) + } + if _, err := store.CapturePreWrite(child1); err != nil { + t.Fatalf("capture child1 t2: %v", err) + } + if _, err := store.CapturePreWrite(child2); err != nil { + t.Fatalf("capture child2 t2: %v", err) + } + if err := os.RemoveAll(dir); err != nil { + t.Fatalf("removeAll: %v", err) + } + if err := store.CapturePostDelete([]string{dir, child1, child2}); err != nil { + t.Fatalf("CapturePostDelete: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + // Restore cp1: v_next is v2(pre-delete, Existed=true) → tree recreated. + if err := store.Restore(context.Background(), "cp1"); err != nil { + t.Fatalf("restore cp1: %v", err) + } + if got := mustReadFile(t, child1); got != "alpha" { + t.Fatalf("child1 want alpha got %q", got) + } + if got := mustReadFile(t, child2); got != "beta" { + t.Fatalf("child2 want beta got %q", got) + } + + // Restore cp2: v_next is v3(post-delete, Existed=false) → tree deleted. 
+ if err := store.Restore(context.Background(), "cp2"); err != nil { + t.Fatalf("restore cp2: %v", err) + } + if _, err := os.Stat(dir); !os.IsNotExist(err) { + t.Fatalf("expected dir absent after restore cp2, stat err=%v", err) + } +} + +func TestRestore_RemoveDirWithNestedFiles(t *testing.T) { + store, workdir := newTestStore(t) + dir := filepath.Join(workdir, "foo") + child := filepath.Join(dir, "bar.txt") + + // Turn 1: create tree. + writeWorkdirFile(t, workdir, "foo/bar.txt", "hello") + if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture dir t1: %v", err) + } + if _, err := store.CapturePreWrite(child); err != nil { + t.Fatalf("capture child t1: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + // Turn 2: remove tree with recursive pre-capture + post-delete. + if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture dir t2: %v", err) + } + if _, err := store.CapturePreWrite(child); err != nil { + t.Fatalf("capture child t2: %v", err) + } + if err := os.RemoveAll(dir); err != nil { + t.Fatalf("removeAll: %v", err) + } + if err := store.CapturePostDelete([]string{dir, child}); err != nil { + t.Fatalf("CapturePostDelete: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + // Turn 3: recreate tree with different content. + writeWorkdirFile(t, workdir, "foo/bar.txt", "world") + if _, err := store.CapturePreWrite(dir); err != nil { + t.Fatalf("capture dir t3: %v", err) + } + if _, err := store.CapturePreWrite(child); err != nil { + t.Fatalf("capture child t3: %v", err) + } + if _, err := store.Finalize("cp3"); err != nil { + t.Fatalf("finalize cp3: %v", err) + } + store.Reset() + + // Restore cp2: should delete the tree. 
+ if err := store.Restore(context.Background(), "cp2"); err != nil { + t.Fatalf("restore cp2: %v", err) + } + if _, err := os.Stat(dir); !os.IsNotExist(err) { + t.Fatalf("expected dir absent after restore cp2, stat err=%v", err) + } + + // Restore cp1: should recreate the tree with original content. + if err := store.Restore(context.Background(), "cp1"); err != nil { + t.Fatalf("restore cp1: %v", err) + } + if got := mustReadFile(t, child); got != "hello" { + t.Fatalf("child want hello got %q", got) + } +} + +func TestPerEditStoreHelperMethods(t *testing.T) { + t.Run("availability and pending lifecycle", func(t *testing.T) { + var nilStore *PerEditSnapshotStore + if nilStore.IsAvailable() { + t.Fatal("nil store should report unavailable") + } + + store, workdir := newTestStore(t) + if !store.IsAvailable() { + t.Fatal("store should report available") + } + if store.HasPending() { + t.Fatal("new store should not have pending captures") + } + + abs := writeWorkdirFile(t, workdir, "pending.txt", "hello") + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("CapturePreWrite() error = %v", err) + } + if !store.HasPending() { + t.Fatal("capture should mark store pending") + } + + store.Reset() + if store.HasPending() { + t.Fatal("Reset() should clear pending captures") + } + }) + + t.Run("delete checkpoint and ref helpers", func(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "tracked.txt", "v1") + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("CapturePreWrite() error = %v", err) + } + if written, err := store.Finalize("cp-delete"); err != nil || !written { + t.Fatalf("Finalize() written=%v err=%v", written, err) + } + + cpPath := store.checkpointMetaPath("cp-delete") + if _, err := os.Stat(cpPath); err != nil { + t.Fatalf("checkpoint meta missing before delete: %v", err) + } + if err := store.DeleteCheckpoint("cp-delete"); err != nil { + t.Fatalf("DeleteCheckpoint() error = %v", err) + } + if _, 
err := os.Stat(cpPath); !os.IsNotExist(err) { + t.Fatalf("checkpoint meta should be removed, err=%v", err) + } + if err := store.DeleteCheckpoint("cp-delete"); err != nil { + t.Fatalf("DeleteCheckpoint() missing should be noop, got %v", err) + } + if err := store.DeleteCheckpoint(""); err != nil { + t.Fatalf("DeleteCheckpoint(\"\") should be noop, got %v", err) + } + + ref := RefForPerEditCheckpoint("cp-delete") + if !IsPerEditRef(ref) { + t.Fatalf("expected per-edit ref: %q", ref) + } + if got := PerEditCheckpointIDFromRef(ref); got != "cp-delete" { + t.Fatalf("PerEditCheckpointIDFromRef() = %q, want cp-delete", got) + } + if IsPerEditRef("git:deadbeef") { + t.Fatal("non per-edit ref should not match") + } + if got := PerEditCheckpointIDFromRef("git:deadbeef"); got != "" { + t.Fatalf("non per-edit ref should return empty id, got %q", got) + } + }) +} + +func TestRestoreExact(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "a.txt", "hello") + + // Turn 1: capture, edit, finalize cp1. + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture: %v", err) + } + if err := os.WriteFile(abs, []byte("world"), 0o644); err != nil { + t.Fatalf("edit: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize: %v", err) + } + store.Reset() + + // Turn 2: capture (v2="world"), edit again. + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture t2: %v", err) + } + if err := os.WriteFile(abs, []byte("third"), 0o644); err != nil { + t.Fatalf("edit t2: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + // RestoreExact(cp2) should write back v2="world" (the exact version in cp2). 
+ if err := store.RestoreExact(context.Background(), "cp2"); err != nil { + t.Fatalf("RestoreExact(cp2): %v", err) + } + if got := mustReadFile(t, abs); got != "world" { + t.Fatalf("RestoreExact(cp2) want world got %q", got) + } + + // RestoreExact(cp1) should write back v1="hello". + if err := store.RestoreExact(context.Background(), "cp1"); err != nil { + t.Fatalf("RestoreExact(cp1): %v", err) + } + if got := mustReadFile(t, abs); got != "hello" { + t.Fatalf("RestoreExact(cp1) want hello got %q", got) + } +} + +func TestChangedFiles_NewFileDetectedAsAdded(t *testing.T) { + store, workdir := newTestStore(t) + + // Turn 1: only a.txt exists. + writeWorkdirFile(t, workdir, "a.txt", "alpha") + if _, err := store.CapturePreWrite(filepath.Join(workdir, "a.txt")); err != nil { + t.Fatalf("capture cp1 a: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + // Turn 2: create b.txt. + writeWorkdirFile(t, workdir, "b.txt", "beta") + if _, err := store.CapturePreWrite(filepath.Join(workdir, "a.txt")); err != nil { + t.Fatalf("capture cp2 a: %v", err) + } + if _, err := store.CapturePreWrite(filepath.Join(workdir, "b.txt")); err != nil { + t.Fatalf("capture cp2 b: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + // Turn 3: capture again to give cp2 a v_next. + if _, err := store.CapturePreWrite(filepath.Join(workdir, "a.txt")); err != nil { + t.Fatalf("capture cp3 a: %v", err) + } + if _, err := store.CapturePreWrite(filepath.Join(workdir, "b.txt")); err != nil { + t.Fatalf("capture cp3 b: %v", err) + } + if _, err := store.Finalize("cp3"); err != nil { + t.Fatalf("finalize cp3: %v", err) + } + store.Reset() + + // Restore to cp1 so workdir fallback matches cp1 state. 
+ if err := store.Restore(context.Background(), "cp1"); err != nil { + t.Fatalf("restore cp1: %v", err) + } + if err := os.Remove(filepath.Join(workdir, "b.txt")); err != nil && !os.IsNotExist(err) { + t.Fatalf("remove stray b.txt: %v", err) + } + + changes, err := store.ChangedFiles(context.Background(), "cp1", "cp2") + if err != nil { + t.Fatalf("changed files cp1->cp2: %v", err) + } + if len(changes) != 1 { + t.Fatalf("expected 1 change, got %d: %+v", len(changes), changes) + } + if changes[0].Path != "b.txt" || changes[0].Kind != FileChangeAdded { + t.Fatalf("expected b.txt added, got %+v", changes[0]) + } +} diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go index 9f296bf6..79b0ba03 100644 --- a/internal/cli/gateway_runtime_bridge.go +++ b/internal/cli/gateway_runtime_bridge.go @@ -13,6 +13,7 @@ import ( "time" "neo-code/internal/app" + "neo-code/internal/checkpoint" "neo-code/internal/config" configstate "neo-code/internal/config/state" "neo-code/internal/gateway" @@ -42,6 +43,13 @@ type runtimeSnapshotGetter interface { GetRuntimeSnapshot(ctx context.Context, sessionID string) (agentruntime.RuntimeSnapshot, error) } +type runtimeCheckpointer interface { + ListCheckpoints(ctx context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) + RestoreCheckpoint(ctx context.Context, input agentruntime.GatewayRestoreInput) (agentruntime.RestoreResult, error) + UndoRestoreCheckpoint(ctx context.Context, sessionID string) (agentruntime.RestoreResult, error) + CheckpointDiff(ctx context.Context, input agentruntime.CheckpointDiffInput) (agentruntime.CheckpointDiffResult, error) +} + // bridgeSessionStore 定义桥接层对会话存储的最低需求。 type bridgeSessionStore interface { DeleteSession(ctx context.Context, sessionID string) error @@ -1363,3 +1371,91 @@ type manualModelPayload struct { } var _ gateway.RuntimePort = (*gatewayRuntimePortBridge)(nil) + +func (b *gatewayRuntimePortBridge) 
ListCheckpoints(ctx context.Context, input gateway.ListCheckpointsInput) ([]gateway.CheckpointEntry, error) { + cp, ok := b.runtime.(runtimeCheckpointer) + if !ok { + return nil, fmt.Errorf("gateway runtime bridge: runtime does not support checkpoint operations") + } + records, err := cp.ListCheckpoints(ctx, strings.TrimSpace(input.SessionID), checkpoint.ListCheckpointOpts{ + Limit: input.Limit, + RestorableOnly: input.RestorableOnly, + }) + if err != nil { + return nil, err + } + entries := make([]gateway.CheckpointEntry, 0, len(records)) + for _, r := range records { + entries = append(entries, gateway.CheckpointEntry{ + CheckpointID: r.CheckpointID, + SessionID: r.SessionID, + Reason: string(r.Reason), + Status: string(r.Status), + Restorable: r.Restorable, + CreatedAt: r.CreatedAt.UnixMilli(), + }) + } + return entries, nil +} + +func (b *gatewayRuntimePortBridge) RestoreCheckpoint(ctx context.Context, input gateway.CheckpointRestoreInput) (gateway.CheckpointRestoreResult, error) { + cp, ok := b.runtime.(runtimeCheckpointer) + if !ok { + return gateway.CheckpointRestoreResult{}, fmt.Errorf("gateway runtime bridge: runtime does not support checkpoint operations") + } + result, err := cp.RestoreCheckpoint(ctx, agentruntime.GatewayRestoreInput{ + SessionID: strings.TrimSpace(input.SessionID), + CheckpointID: strings.TrimSpace(input.CheckpointID), + Force: input.Force, + }) + if err != nil { + return gateway.CheckpointRestoreResult{}, err + } + return gateway.CheckpointRestoreResult{ + CheckpointID: result.CheckpointID, + SessionID: result.SessionID, + HasConflict: result.Conflict != nil && result.Conflict.HasConflict, + }, nil +} + +func (b *gatewayRuntimePortBridge) UndoRestore(ctx context.Context, input gateway.UndoRestoreInput) (gateway.CheckpointRestoreResult, error) { + cp, ok := b.runtime.(runtimeCheckpointer) + if !ok { + return gateway.CheckpointRestoreResult{}, fmt.Errorf("gateway runtime bridge: runtime does not support checkpoint operations") + } + 
result, err := cp.UndoRestoreCheckpoint(ctx, strings.TrimSpace(input.SessionID)) + if err != nil { + return gateway.CheckpointRestoreResult{}, err + } + return gateway.CheckpointRestoreResult{ + CheckpointID: result.CheckpointID, + SessionID: result.SessionID, + HasConflict: result.Conflict != nil && result.Conflict.HasConflict, + }, nil +} + +func (b *gatewayRuntimePortBridge) CheckpointDiff(ctx context.Context, input gateway.CheckpointDiffInput) (gateway.CheckpointDiffResult, error) { + cp, ok := b.runtime.(runtimeCheckpointer) + if !ok { + return gateway.CheckpointDiffResult{}, fmt.Errorf("gateway runtime bridge: runtime does not support checkpoint operations") + } + result, err := cp.CheckpointDiff(ctx, agentruntime.CheckpointDiffInput{ + SessionID: strings.TrimSpace(input.SessionID), + CheckpointID: strings.TrimSpace(input.CheckpointID), + }) + if err != nil { + return gateway.CheckpointDiffResult{}, err + } + return gateway.CheckpointDiffResult{ + CheckpointID: result.CheckpointID, + PrevCheckpointID: result.PrevCheckpointID, + CommitHash: result.CommitHash, + PrevCommitHash: result.PrevCommitHash, + Files: gateway.FileDiffs{ + Added: result.Files.Added, + Deleted: result.Files.Deleted, + Modified: result.Files.Modified, + }, + Patch: result.Patch, + }, nil +} diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index b936544d..f9b48123 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "neo-code/internal/checkpoint" "neo-code/internal/config" configstate "neo-code/internal/config/state" "neo-code/internal/gateway" @@ -63,6 +64,19 @@ type runtimeStub struct { getSnapshotSessionID string getSnapshotResult agentruntime.RuntimeSnapshot getSnapshotErr error + listCheckpointsID string + listCheckpointsOpts checkpoint.ListCheckpointOpts + listCheckpointsResult []agentsession.CheckpointRecord + listCheckpointsErr 
error + restoreCheckpointIn agentruntime.GatewayRestoreInput + restoreCheckpointOut agentruntime.RestoreResult + restoreCheckpointErr error + undoRestoreSessionID string + undoRestoreOut agentruntime.RestoreResult + undoRestoreErr error + checkpointDiffIn agentruntime.CheckpointDiffInput + checkpointDiffOut agentruntime.CheckpointDiffResult + checkpointDiffErr error } const testBridgeSubjectID = bridgeLocalSubjectID @@ -150,6 +164,23 @@ func (s *runtimeStub) GetRuntimeSnapshot(_ context.Context, sessionID string) (a s.getSnapshotSessionID = sessionID return s.getSnapshotResult, s.getSnapshotErr } +func (s *runtimeStub) ListCheckpoints(_ context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + s.listCheckpointsID = sessionID + s.listCheckpointsOpts = opts + return s.listCheckpointsResult, s.listCheckpointsErr +} +func (s *runtimeStub) RestoreCheckpoint(_ context.Context, input agentruntime.GatewayRestoreInput) (agentruntime.RestoreResult, error) { + s.restoreCheckpointIn = input + return s.restoreCheckpointOut, s.restoreCheckpointErr +} +func (s *runtimeStub) UndoRestoreCheckpoint(_ context.Context, sessionID string) (agentruntime.RestoreResult, error) { + s.undoRestoreSessionID = sessionID + return s.undoRestoreOut, s.undoRestoreErr +} +func (s *runtimeStub) CheckpointDiff(_ context.Context, input agentruntime.CheckpointDiffInput) (agentruntime.CheckpointDiffResult, error) { + s.checkpointDiffIn = input + return s.checkpointDiffOut, s.checkpointDiffErr +} func (s *runtimeStub) DeleteSession(_ context.Context, _ string) error { return nil } @@ -207,6 +238,65 @@ func (r *runtimeWithoutCreator) ListAvailableSkills( ) ([]agentruntime.AvailableSkillState, error) { return r.base.ListAvailableSkills(ctx, sessionID) } +func (r *runtimeWithoutCreator) ListCheckpoints(ctx context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + return 
r.base.ListCheckpoints(ctx, sessionID, opts) +} +func (r *runtimeWithoutCreator) RestoreCheckpoint(ctx context.Context, input agentruntime.GatewayRestoreInput) (agentruntime.RestoreResult, error) { + return r.base.RestoreCheckpoint(ctx, input) +} +func (r *runtimeWithoutCreator) UndoRestoreCheckpoint(ctx context.Context, sessionID string) (agentruntime.RestoreResult, error) { + return r.base.UndoRestoreCheckpoint(ctx, sessionID) +} +func (r *runtimeWithoutCreator) CheckpointDiff(ctx context.Context, input agentruntime.CheckpointDiffInput) (agentruntime.CheckpointDiffResult, error) { + return r.base.CheckpointDiff(ctx, input) +} + +type runtimeWithoutCheckpointer struct { + base *runtimeStub +} + +func (r *runtimeWithoutCheckpointer) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + return r.base.Submit(ctx, input) +} +func (r *runtimeWithoutCheckpointer) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + return r.base.PrepareUserInput(ctx, input) +} +func (r *runtimeWithoutCheckpointer) Run(ctx context.Context, input agentruntime.UserInput) error { + return r.base.Run(ctx, input) +} +func (r *runtimeWithoutCheckpointer) Compact(ctx context.Context, input agentruntime.CompactInput) (agentruntime.CompactResult, error) { + return r.base.Compact(ctx, input) +} +func (r *runtimeWithoutCheckpointer) ExecuteSystemTool(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) { + return r.base.ExecuteSystemTool(ctx, input) +} +func (r *runtimeWithoutCheckpointer) ResolvePermission(ctx context.Context, input agentruntime.PermissionResolutionInput) error { + return r.base.ResolvePermission(ctx, input) +} +func (r *runtimeWithoutCheckpointer) CancelActiveRun() bool { + return r.base.CancelActiveRun() +} +func (r *runtimeWithoutCheckpointer) Events() <-chan agentruntime.RuntimeEvent { + return r.base.Events() +} +func (r *runtimeWithoutCheckpointer) ListSessions(ctx 
context.Context) ([]agentsession.Summary, error) { + return r.base.ListSessions(ctx) +} +func (r *runtimeWithoutCheckpointer) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + return r.base.LoadSession(ctx, id) +} +func (r *runtimeWithoutCheckpointer) ActivateSessionSkill(ctx context.Context, sessionID string, skillID string) error { + return r.base.ActivateSessionSkill(ctx, sessionID, skillID) +} +func (r *runtimeWithoutCheckpointer) DeactivateSessionSkill(ctx context.Context, sessionID string, skillID string) error { + return r.base.DeactivateSessionSkill(ctx, sessionID, skillID) +} +func (r *runtimeWithoutCheckpointer) ListSessionSkills(ctx context.Context, sessionID string) ([]agentruntime.SessionSkillState, error) { + return r.base.ListSessionSkills(ctx, sessionID) +} +func (r *runtimeWithoutCheckpointer) ListAvailableSkills(ctx context.Context, sessionID string) ([]agentruntime.AvailableSkillState, error) { + return r.base.ListAvailableSkills(ctx, sessionID) +} type bridgeSessionStoreStub struct { deleteFn func(ctx context.Context, id string) error @@ -226,6 +316,193 @@ func (s *bridgeSessionStoreStub) UpdateSessionState(ctx context.Context, input a return nil } +func TestGatewayRuntimePortBridgeCheckpointOperations(t *testing.T) { + stub := &runtimeStub{ + listCheckpointsResult: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-1", + SessionID: "session-1", + Reason: agentsession.CheckpointReasonCompact, + Status: agentsession.CheckpointStatusAvailable, + Restorable: true, + CreatedAt: time.UnixMilli(1234), + }, + }, + restoreCheckpointOut: agentruntime.RestoreResult{ + CheckpointID: "cp-1", + SessionID: "session-1", + }, + undoRestoreOut: agentruntime.RestoreResult{ + CheckpointID: "guard-1", + SessionID: "session-1", + }, + checkpointDiffOut: agentruntime.CheckpointDiffResult{ + CheckpointID: "cp-2", + PrevCheckpointID: "cp-1", + CommitHash: "commit-2", + PrevCommitHash: "commit-1", + Files: agentruntime.FileDiffs{ + 
Added: []string{"new.txt"}, + Deleted: []string{"old.txt"}, + Modified: []string{"keep.txt"}, + }, + Patch: "diff --git a/keep.txt b/keep.txt", + }, + } + + bridge := &gatewayRuntimePortBridge{runtime: stub} + + entries, err := bridge.ListCheckpoints(context.Background(), gateway.ListCheckpointsInput{ + SessionID: " session-1 ", + Limit: 5, + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if stub.listCheckpointsID != "session-1" || stub.listCheckpointsOpts.Limit != 5 || !stub.listCheckpointsOpts.RestorableOnly { + t.Fatalf("ListCheckpoints() forwarded (%q, %#v)", stub.listCheckpointsID, stub.listCheckpointsOpts) + } + if len(entries) != 1 || entries[0].CheckpointID != "cp-1" || entries[0].Reason != string(agentsession.CheckpointReasonCompact) { + t.Fatalf("ListCheckpoints() = %#v", entries) + } + + restoreResult, err := bridge.RestoreCheckpoint(context.Background(), gateway.CheckpointRestoreInput{ + SessionID: " session-1 ", + CheckpointID: " cp-1 ", + Force: true, + }) + if err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + if stub.restoreCheckpointIn.SessionID != "session-1" || stub.restoreCheckpointIn.CheckpointID != "cp-1" || !stub.restoreCheckpointIn.Force { + t.Fatalf("RestoreCheckpoint() forwarded %#v", stub.restoreCheckpointIn) + } + if restoreResult.CheckpointID != "cp-1" || restoreResult.SessionID != "session-1" || restoreResult.HasConflict { + t.Fatalf("RestoreCheckpoint() = %#v", restoreResult) + } + + undoResult, err := bridge.UndoRestore(context.Background(), gateway.UndoRestoreInput{SessionID: " session-1 "}) + if err != nil { + t.Fatalf("UndoRestore() error = %v", err) + } + if stub.undoRestoreSessionID != "session-1" { + t.Fatalf("UndoRestore() forwarded session %q", stub.undoRestoreSessionID) + } + if undoResult.CheckpointID != "guard-1" || undoResult.SessionID != "session-1" { + t.Fatalf("UndoRestore() = %#v", undoResult) + } + + diffResult, err := 
bridge.CheckpointDiff(context.Background(), gateway.CheckpointDiffInput{ + SessionID: " session-1 ", + CheckpointID: " cp-2 ", + }) + if err != nil { + t.Fatalf("CheckpointDiff() error = %v", err) + } + if stub.checkpointDiffIn.SessionID != "session-1" || stub.checkpointDiffIn.CheckpointID != "cp-2" { + t.Fatalf("CheckpointDiff() forwarded %#v", stub.checkpointDiffIn) + } + if diffResult.CheckpointID != "cp-2" || diffResult.PrevCheckpointID != "cp-1" || + diffResult.CommitHash != "commit-2" || diffResult.PrevCommitHash != "commit-1" || + len(diffResult.Files.Added) != 1 || diffResult.Files.Added[0] != "new.txt" || + len(diffResult.Files.Deleted) != 1 || diffResult.Files.Deleted[0] != "old.txt" || + len(diffResult.Files.Modified) != 1 || diffResult.Files.Modified[0] != "keep.txt" || + diffResult.Patch != "diff --git a/keep.txt b/keep.txt" { + t.Fatalf("CheckpointDiff() = %#v", diffResult) + } +} + +func TestGatewayRuntimePortBridgeCheckpointOperations_ReportConflictAndUnsupportedRuntime(t *testing.T) { + t.Run("conflict forwarded", func(t *testing.T) { + stub := &runtimeStub{ + restoreCheckpointOut: agentruntime.RestoreResult{ + CheckpointID: "cp-1", + SessionID: "session-1", + Conflict: &checkpoint.ConflictResult{HasConflict: true}, + }, + undoRestoreOut: agentruntime.RestoreResult{ + CheckpointID: "guard-1", + SessionID: "session-1", + Conflict: &checkpoint.ConflictResult{HasConflict: true}, + }, + } + bridge := &gatewayRuntimePortBridge{runtime: stub} + + restoreResult, err := bridge.RestoreCheckpoint(context.Background(), gateway.CheckpointRestoreInput{ + SessionID: "session-1", + CheckpointID: "cp-1", + }) + if err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + if !restoreResult.HasConflict { + t.Fatalf("RestoreCheckpoint() conflict flag = false, want true") + } + + undoResult, err := bridge.UndoRestore(context.Background(), gateway.UndoRestoreInput{SessionID: "session-1"}) + if err != nil { + t.Fatalf("UndoRestore() error = %v", err) + } + if 
!undoResult.HasConflict { + t.Fatalf("UndoRestore() conflict flag = false, want true") + } + }) + + t.Run("unsupported runtime", func(t *testing.T) { + bridge := &gatewayRuntimePortBridge{runtime: &runtimeWithoutCheckpointer{base: &runtimeStub{}}} + cases := []struct { + name string + call func() error + }{ + { + name: "list", + call: func() error { + _, err := bridge.ListCheckpoints(context.Background(), gateway.ListCheckpointsInput{SessionID: "session-1"}) + return err + }, + }, + { + name: "restore", + call: func() error { + _, err := bridge.RestoreCheckpoint(context.Background(), gateway.CheckpointRestoreInput{ + SessionID: "session-1", + CheckpointID: "cp-1", + }) + return err + }, + }, + { + name: "undo", + call: func() error { + _, err := bridge.UndoRestore(context.Background(), gateway.UndoRestoreInput{SessionID: "session-1"}) + return err + }, + }, + { + name: "diff", + call: func() error { + _, err := bridge.CheckpointDiff(context.Background(), gateway.CheckpointDiffInput{ + SessionID: "session-1", + CheckpointID: "cp-1", + }) + return err + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := tc.call() + if err == nil || !strings.Contains(err.Error(), "does not support checkpoint operations") { + t.Fatalf("error = %v, want unsupported checkpoint operations", err) + } + }) + } + }) +} + var testSessionStore bridgeSessionStore = &bridgeSessionStoreStub{} func TestNewGatewayRuntimePortBridgeRuntimeUnavailable(t *testing.T) { @@ -1351,7 +1628,7 @@ func TestResolveListFilesRootPriorities(t *testing.T) { // priority 2: session workdir (store implements bridgeSessionLoader) loaderStore := &bridgeSessionStoreWithLoader{ bridgeSessionStoreStub: bridgeSessionStoreStub{}, - session: agentsession.Session{Workdir: subDir}, + session: agentsession.Session{Workdir: subDir}, } bridge2, _ := newGatewayRuntimePortBridge(context.Background(), &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, loaderStore) defer 
bridge2.Close() @@ -1958,9 +2235,9 @@ func TestModelDisplayName(t *testing.T) { func TestGatewayRuntimePortBridgeCancelRunAndSnapshots(t *testing.T) { stub := &runtimeStub{ - eventsCh: make(chan agentruntime.RuntimeEvent, 1), - cancelReturn: true, - listTodosErr: errors.New("todo failed"), + eventsCh: make(chan agentruntime.RuntimeEvent, 1), + cancelReturn: true, + listTodosErr: errors.New("todo failed"), getSnapshotErr: errors.New("snapshot failed"), } bridge, err := newGatewayRuntimePortBridge(context.Background(), stub, testSessionStore) diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 9576962f..476eec36 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -1092,6 +1092,22 @@ func (stubRuntimePort) GetRuntimeSnapshot( return gateway.RuntimeSnapshot{}, nil } +func (stubRuntimePort) ListCheckpoints(context.Context, gateway.ListCheckpointsInput) ([]gateway.CheckpointEntry, error) { + return nil, nil +} + +func (stubRuntimePort) RestoreCheckpoint(context.Context, gateway.CheckpointRestoreInput) (gateway.CheckpointRestoreResult, error) { + return gateway.CheckpointRestoreResult{}, nil +} + +func (stubRuntimePort) UndoRestore(context.Context, gateway.UndoRestoreInput) (gateway.CheckpointRestoreResult, error) { + return gateway.CheckpointRestoreResult{}, nil +} + +func (stubRuntimePort) CheckpointDiff(context.Context, gateway.CheckpointDiffInput) (gateway.CheckpointDiffResult, error) { + return gateway.CheckpointDiffResult{}, nil +} + func (s *stubGatewayServer) ListenAddress() string { return s.listenAddress } diff --git a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go index caadece3..42e55e14 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go @@ -212,6 +212,34 @@ func (s *urlschemeIntegrationRuntimeStub) 
GetSessionModel(context.Context, gatew return gateway.SessionModelResult{}, nil } +func (s *urlschemeIntegrationRuntimeStub) ListCheckpoints( + context.Context, + gateway.ListCheckpointsInput, +) ([]gateway.CheckpointEntry, error) { + return nil, nil +} + +func (s *urlschemeIntegrationRuntimeStub) RestoreCheckpoint( + context.Context, + gateway.CheckpointRestoreInput, +) (gateway.CheckpointRestoreResult, error) { + return gateway.CheckpointRestoreResult{}, nil +} + +func (s *urlschemeIntegrationRuntimeStub) UndoRestore( + context.Context, + gateway.UndoRestoreInput, +) (gateway.CheckpointRestoreResult, error) { + return gateway.CheckpointRestoreResult{}, nil +} + +func (s *urlschemeIntegrationRuntimeStub) CheckpointDiff( + context.Context, + gateway.CheckpointDiffInput, +) (gateway.CheckpointDiffResult, error) { + return gateway.CheckpointDiffResult{}, nil +} + func waitGatewayReady(address string, timeout time.Duration) error { deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { diff --git a/internal/gateway/bootstrap.go b/internal/gateway/bootstrap.go index 0d97e67d..8e59d68f 100644 --- a/internal/gateway/bootstrap.go +++ b/internal/gateway/bootstrap.go @@ -1792,6 +1792,18 @@ func readStringValue(payload map[string]any, key string) string { return strings.TrimSpace(stringValue) } +func readBoolValue(payload map[string]any, key string) bool { + rawValue, exists := payload[key] + if !exists { + return false + } + boolValue, ok := rawValue.(bool) + if !ok { + return false + } + return boolValue +} + // decodeWakeIntent 将任意 payload 解码为 WakeIntent。 func decodeWakeIntent(payload any) (protocol.WakeIntent, error) { if payload == nil { @@ -1844,3 +1856,162 @@ func toFrameError(err *handlers.WakeError) *FrameError { } return NewFrameError(ErrorCodeInternalError, err.Message) } + +func handleListCheckpointsFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + if runtimePort == nil { + return 
runtimePortUnavailableFrame(frame) + } + subjectID, subjectErr := requireAuthenticatedSubjectID(ctx) + if subjectErr != nil { + return errorFrame(frame, subjectErr) + } + + callCtx, cancel := withRuntimeOperationTimeout(ctx) + defer cancel() + + entries, err := runtimePort.ListCheckpoints(callCtx, ListCheckpointsInput{ + SubjectID: subjectID, + SessionID: strings.TrimSpace(frame.SessionID), + }) + if err != nil { + return runtimeCallFailedFrame(callCtx, frame, err, "checkpoint_list") + } + + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionListCheckpoints, + RequestID: frame.RequestID, + SessionID: strings.TrimSpace(frame.SessionID), + Payload: entries, + } +} + +func handleRestoreCheckpointFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + if runtimePort == nil { + return runtimePortUnavailableFrame(frame) + } + subjectID, subjectErr := requireAuthenticatedSubjectID(ctx) + if subjectErr != nil { + return errorFrame(frame, subjectErr) + } + + input := decodeCheckpointRestorePayload(frame.Payload) + input.SubjectID = subjectID + if input.SessionID == "" { + input.SessionID = strings.TrimSpace(frame.SessionID) + } + + callCtx, cancel := withRuntimeOperationTimeout(ctx) + defer cancel() + + result, err := runtimePort.RestoreCheckpoint(callCtx, input) + if err != nil { + return runtimeCallFailedFrame(callCtx, frame, err, "checkpoint_restore") + } + + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionRestoreCheckpoint, + RequestID: frame.RequestID, + SessionID: input.SessionID, + Payload: result, + } +} + +func handleUndoRestoreFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + if runtimePort == nil { + return runtimePortUnavailableFrame(frame) + } + subjectID, subjectErr := requireAuthenticatedSubjectID(ctx) + if subjectErr != nil { + return errorFrame(frame, subjectErr) + } + + callCtx, cancel := withRuntimeOperationTimeout(ctx) + defer cancel() + + result, err := 
runtimePort.UndoRestore(callCtx, UndoRestoreInput{ + SubjectID: subjectID, + SessionID: strings.TrimSpace(frame.SessionID), + }) + if err != nil { + return runtimeCallFailedFrame(callCtx, frame, err, "checkpoint_undo_restore") + } + + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionUndoRestore, + RequestID: frame.RequestID, + SessionID: strings.TrimSpace(frame.SessionID), + Payload: result, + } +} + +func handleCheckpointDiffFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + if runtimePort == nil { + return runtimePortUnavailableFrame(frame) + } + subjectID, subjectErr := requireAuthenticatedSubjectID(ctx) + if subjectErr != nil { + return errorFrame(frame, subjectErr) + } + + input := decodeCheckpointDiffPayload(frame.Payload) + input.SubjectID = subjectID + if input.SessionID == "" { + input.SessionID = strings.TrimSpace(frame.SessionID) + } + + callCtx, cancel := withRuntimeOperationTimeout(ctx) + defer cancel() + + result, err := runtimePort.CheckpointDiff(callCtx, input) + if err != nil { + return runtimeCallFailedFrame(callCtx, frame, err, "checkpoint_diff") + } + + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionCheckpointDiff, + RequestID: frame.RequestID, + SessionID: input.SessionID, + Payload: result, + } +} + +func decodeCheckpointDiffPayload(payload any) CheckpointDiffInput { + switch typed := payload.(type) { + case map[string]any: + return CheckpointDiffInput{ + SessionID: readStringValue(typed, "session_id"), + CheckpointID: readStringValue(typed, "checkpoint_id"), + } + default: + raw, marshalErr := json.Marshal(payload) + if marshalErr != nil { + return CheckpointDiffInput{} + } + var input CheckpointDiffInput + _ = json.Unmarshal(raw, &input) + return input + } +} + +func decodeCheckpointRestorePayload(payload any) CheckpointRestoreInput { + switch typed := payload.(type) { + case map[string]any: + return CheckpointRestoreInput{ + SessionID: readStringValue(typed, 
"session_id"), + CheckpointID: readStringValue(typed, "checkpoint_id"), + Force: readBoolValue(typed, "force"), + } + default: + raw, marshalErr := json.Marshal(payload) + if marshalErr != nil { + return CheckpointRestoreInput{} + } + var decoded CheckpointRestoreInput + _ = json.Unmarshal(raw, &decoded) + return decoded + } +} diff --git a/internal/gateway/bootstrap_test.go b/internal/gateway/bootstrap_test.go index 26428156..817e5340 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -16,35 +16,39 @@ import ( ) type bootstrapRuntimeStub struct { - runFn func(ctx context.Context, input RunInput) error - createSessionFn func(ctx context.Context, input CreateSessionInput) (string, error) - compactFn func(ctx context.Context, input CompactInput) (CompactResult, error) - executeSystemToolFn func(ctx context.Context, input ExecuteSystemToolInput) (tools.ToolResult, error) - activateSkillFn func(ctx context.Context, input SessionSkillMutationInput) error - deactivateSkillFn func(ctx context.Context, input SessionSkillMutationInput) error - listSessionSkillsFn func(ctx context.Context, input ListSessionSkillsInput) ([]SessionSkillState, error) - listAvailableFn func(ctx context.Context, input ListAvailableSkillsInput) ([]AvailableSkillState, error) - resolvePermissionFn func(ctx context.Context, input PermissionResolutionInput) error - cancelRunFn func(ctx context.Context, input CancelInput) (bool, error) - events <-chan RuntimeEvent - listSessionsFn func(ctx context.Context) ([]SessionSummary, error) - loadSessionFn func(ctx context.Context, input LoadSessionInput) (Session, error) - listSessionTodosFn func(ctx context.Context, input ListSessionTodosInput) (TodoSnapshot, error) + runFn func(ctx context.Context, input RunInput) error + createSessionFn func(ctx context.Context, input CreateSessionInput) (string, error) + compactFn func(ctx context.Context, input CompactInput) (CompactResult, error) + executeSystemToolFn func(ctx 
context.Context, input ExecuteSystemToolInput) (tools.ToolResult, error) + activateSkillFn func(ctx context.Context, input SessionSkillMutationInput) error + deactivateSkillFn func(ctx context.Context, input SessionSkillMutationInput) error + listSessionSkillsFn func(ctx context.Context, input ListSessionSkillsInput) ([]SessionSkillState, error) + listAvailableFn func(ctx context.Context, input ListAvailableSkillsInput) ([]AvailableSkillState, error) + resolvePermissionFn func(ctx context.Context, input PermissionResolutionInput) error + cancelRunFn func(ctx context.Context, input CancelInput) (bool, error) + events <-chan RuntimeEvent + listSessionsFn func(ctx context.Context) ([]SessionSummary, error) + loadSessionFn func(ctx context.Context, input LoadSessionInput) (Session, error) + listSessionTodosFn func(ctx context.Context, input ListSessionTodosInput) (TodoSnapshot, error) getRuntimeSnapshotFn func(ctx context.Context, input GetRuntimeSnapshotInput) (RuntimeSnapshot, error) - deleteSessionFn func(ctx context.Context, input DeleteSessionInput) (bool, error) - renameSessionFn func(ctx context.Context, input RenameSessionInput) error - listFilesFn func(ctx context.Context, input ListFilesInput) ([]FileEntry, error) - listModelsFn func(ctx context.Context, input ListModelsInput) ([]ModelEntry, error) - setSessionModelFn func(ctx context.Context, input SetSessionModelInput) error - getSessionModelFn func(ctx context.Context, input GetSessionModelInput) (SessionModelResult, error) - listProvidersFn func(ctx context.Context, input ListProvidersInput) ([]ProviderOption, error) - createProviderFn func(ctx context.Context, input CreateProviderInput) (ProviderSelectionResult, error) - deleteProviderFn func(ctx context.Context, input DeleteProviderInput) error - selectProviderFn func(ctx context.Context, input SelectProviderModelInput) (ProviderSelectionResult, error) - listMCPServersFn func(ctx context.Context, input ListMCPServersInput) ([]MCPServerEntry, error) - 
upsertMCPServerFn func(ctx context.Context, input UpsertMCPServerInput) error - setMCPEnabledFn func(ctx context.Context, input SetMCPServerEnabledInput) error - deleteMCPServerFn func(ctx context.Context, input DeleteMCPServerInput) error + deleteSessionFn func(ctx context.Context, input DeleteSessionInput) (bool, error) + renameSessionFn func(ctx context.Context, input RenameSessionInput) error + listFilesFn func(ctx context.Context, input ListFilesInput) ([]FileEntry, error) + listModelsFn func(ctx context.Context, input ListModelsInput) ([]ModelEntry, error) + setSessionModelFn func(ctx context.Context, input SetSessionModelInput) error + getSessionModelFn func(ctx context.Context, input GetSessionModelInput) (SessionModelResult, error) + listProvidersFn func(ctx context.Context, input ListProvidersInput) ([]ProviderOption, error) + createProviderFn func(ctx context.Context, input CreateProviderInput) (ProviderSelectionResult, error) + deleteProviderFn func(ctx context.Context, input DeleteProviderInput) error + selectProviderFn func(ctx context.Context, input SelectProviderModelInput) (ProviderSelectionResult, error) + listMCPServersFn func(ctx context.Context, input ListMCPServersInput) ([]MCPServerEntry, error) + upsertMCPServerFn func(ctx context.Context, input UpsertMCPServerInput) error + setMCPEnabledFn func(ctx context.Context, input SetMCPServerEnabledInput) error + deleteMCPServerFn func(ctx context.Context, input DeleteMCPServerInput) error + listCheckpointsFn func(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) + restoreCheckpointFn func(ctx context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error) + undoRestoreFn func(ctx context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) + checkpointDiffFn func(ctx context.Context, input CheckpointDiffInput) (CheckpointDiffResult, error) } func (s *bootstrapRuntimeStub) Run(ctx context.Context, input RunInput) error { @@ -249,6 +253,34 @@ 
func (s *bootstrapRuntimeStub) CreateSession(ctx context.Context, input CreateSe return strings.TrimSpace(input.SessionID), nil } +func (s *bootstrapRuntimeStub) ListCheckpoints(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) { + if s != nil && s.listCheckpointsFn != nil { + return s.listCheckpointsFn(ctx, input) + } + return nil, nil +} + +func (s *bootstrapRuntimeStub) RestoreCheckpoint(ctx context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error) { + if s != nil && s.restoreCheckpointFn != nil { + return s.restoreCheckpointFn(ctx, input) + } + return CheckpointRestoreResult{}, nil +} + +func (s *bootstrapRuntimeStub) UndoRestore(ctx context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) { + if s != nil && s.undoRestoreFn != nil { + return s.undoRestoreFn(ctx, input) + } + return CheckpointRestoreResult{}, nil +} + +func (s *bootstrapRuntimeStub) CheckpointDiff(ctx context.Context, input CheckpointDiffInput) (CheckpointDiffResult, error) { + if s != nil && s.checkpointDiffFn != nil { + return s.checkpointDiffFn(ctx, input) + } + return CheckpointDiffResult{}, nil +} + func TestDispatchRequestFramePing(t *testing.T) { response := dispatchRequestFrame(context.Background(), MessageFrame{ Type: FrameTypeRequest, @@ -419,6 +451,192 @@ func TestDecodeSessionSkillAndSnapshotPayloadBranches(t *testing.T) { } } +func TestCheckpointFrameHandlers(t *testing.T) { + t.Run("list checkpoints success", func(t *testing.T) { + runtime := &bootstrapRuntimeStub{ + listCheckpointsFn: func(_ context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) { + if input.SubjectID != "subject-1" || input.SessionID != "session-1" { + t.Fatalf("input = %#v", input) + } + return []CheckpointEntry{{CheckpointID: "cp-1", SessionID: "session-1"}}, nil + }, + } + authState := NewConnectionAuthState() + authState.MarkAuthenticated("subject-1") + ctx := WithConnectionAuthState(context.Background(), authState) + + response 
:= handleListCheckpointsFrame(ctx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionListCheckpoints, + RequestID: "req-checkpoint-list", + SessionID: " session-1 ", + }, runtime) + + if response.Type != FrameTypeAck || response.Action != FrameActionListCheckpoints { + t.Fatalf("response = %#v", response) + } + entries, ok := response.Payload.([]CheckpointEntry) + if !ok || len(entries) != 1 || entries[0].CheckpointID != "cp-1" { + t.Fatalf("payload = %#v", response.Payload) + } + }) + + t.Run("restore checkpoint success", func(t *testing.T) { + runtime := &bootstrapRuntimeStub{ + restoreCheckpointFn: func(_ context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error) { + if input.SubjectID != "subject-1" || input.SessionID != "session-1" || input.CheckpointID != "cp-1" || !input.Force { + t.Fatalf("input = %#v", input) + } + return CheckpointRestoreResult{CheckpointID: input.CheckpointID, SessionID: input.SessionID}, nil + }, + } + authState := NewConnectionAuthState() + authState.MarkAuthenticated("subject-1") + ctx := WithConnectionAuthState(context.Background(), authState) + + response := handleRestoreCheckpointFrame(ctx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRestoreCheckpoint, + RequestID: "req-checkpoint-restore", + SessionID: " session-1 ", + Payload: map[string]any{ + "checkpoint_id": " cp-1 ", + "force": true, + }, + }, runtime) + + if response.Type != FrameTypeAck || response.Action != FrameActionRestoreCheckpoint || response.SessionID != "session-1" { + t.Fatalf("response = %#v", response) + } + result, ok := response.Payload.(CheckpointRestoreResult) + if !ok || result.CheckpointID != "cp-1" { + t.Fatalf("payload = %#v", response.Payload) + } + }) + + t.Run("undo restore success", func(t *testing.T) { + runtime := &bootstrapRuntimeStub{ + undoRestoreFn: func(_ context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) { + if input.SubjectID != "subject-1" || input.SessionID != 
"session-1" { + t.Fatalf("input = %#v", input) + } + return CheckpointRestoreResult{CheckpointID: "cp-guard", SessionID: input.SessionID}, nil + }, + } + authState := NewConnectionAuthState() + authState.MarkAuthenticated("subject-1") + ctx := WithConnectionAuthState(context.Background(), authState) + + response := handleUndoRestoreFrame(ctx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionUndoRestore, + RequestID: "req-checkpoint-undo", + SessionID: " session-1 ", + }, runtime) + + if response.Type != FrameTypeAck || response.Action != FrameActionUndoRestore || response.SessionID != "session-1" { + t.Fatalf("response = %#v", response) + } + result, ok := response.Payload.(CheckpointRestoreResult) + if !ok || result.CheckpointID != "cp-guard" { + t.Fatalf("payload = %#v", response.Payload) + } + }) + + t.Run("checkpoint diff success", func(t *testing.T) { + runtime := &bootstrapRuntimeStub{ + checkpointDiffFn: func(_ context.Context, input CheckpointDiffInput) (CheckpointDiffResult, error) { + if input.SubjectID != "subject-1" || input.SessionID != "session-1" || input.CheckpointID != "cp-1" { + t.Fatalf("input = %#v", input) + } + return CheckpointDiffResult{ + CheckpointID: input.CheckpointID, + PrevCheckpointID: "cp-0", + Files: FileDiffs{ + Modified: []string{"README.md"}, + }, + Patch: "diff --git a/README.md b/README.md", + }, nil + }, + } + authState := NewConnectionAuthState() + authState.MarkAuthenticated("subject-1") + ctx := WithConnectionAuthState(context.Background(), authState) + + response := handleCheckpointDiffFrame(ctx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionCheckpointDiff, + RequestID: "req-checkpoint-diff", + SessionID: " session-1 ", + Payload: map[string]any{ + "checkpoint_id": " cp-1 ", + }, + }, runtime) + + if response.Type != FrameTypeAck || response.Action != FrameActionCheckpointDiff || response.SessionID != "session-1" { + t.Fatalf("response = %#v", response) + } + result, ok := 
response.Payload.(CheckpointDiffResult) + if !ok || result.CheckpointID != "cp-1" || result.PrevCheckpointID != "cp-0" || + len(result.Files.Modified) != 1 || result.Files.Modified[0] != "README.md" || + result.Patch != "diff --git a/README.md b/README.md" { + t.Fatalf("payload = %#v", response.Payload) + } + }) +} + +func TestDecodeCheckpointRestorePayloadBranches(t *testing.T) { + t.Parallel() + + params := decodeCheckpointRestorePayload(map[string]any{ + "session_id": " session-1 ", + "checkpoint_id": " cp-1 ", + "force": true, + }) + if params.SessionID != "session-1" || params.CheckpointID != "cp-1" || !params.Force { + t.Fatalf("decode map payload = %#v", params) + } + + params = decodeCheckpointRestorePayload(CheckpointRestoreInput{ + SessionID: "session-2", + CheckpointID: "cp-2", + Force: true, + }) + if params.SessionID != "session-2" || params.CheckpointID != "cp-2" || !params.Force { + t.Fatalf("decode struct payload = %#v", params) + } + + params = decodeCheckpointRestorePayload(invalidJSONMarshaler{}) + if params != (CheckpointRestoreInput{}) { + t.Fatalf("marshal failure should return zero input, got %#v", params) + } +} + +func TestDecodeCheckpointDiffPayloadBranches(t *testing.T) { + t.Parallel() + + params := decodeCheckpointDiffPayload(map[string]any{ + "session_id": " session-1 ", + "checkpoint_id": " cp-1 ", + }) + if params.SessionID != "session-1" || params.CheckpointID != "cp-1" { + t.Fatalf("decode map payload = %#v", params) + } + + params = decodeCheckpointDiffPayload(CheckpointDiffInput{ + SessionID: "session-2", + CheckpointID: "cp-2", + }) + if params.SessionID != "session-2" || params.CheckpointID != "cp-2" { + t.Fatalf("decode struct payload = %#v", params) + } + + params = decodeCheckpointDiffPayload(invalidJSONMarshaler{}) + if params != (CheckpointDiffInput{}) { + t.Fatalf("marshal failure should return zero input, got %#v", params) + } +} + func TestDispatchRequestFrameWakeOpenURLReviewSuccess(t *testing.T) { createInputs := 
make(chan CreateSessionInput, 1) stub := &bootstrapRuntimeStub{ @@ -3799,7 +4017,9 @@ func TestDecodeRenameSessionPayloadBranches(t *testing.T) { }) t.Run("marshal error", func(t *testing.T) { - _, err := decodeRenameSessionPayload(struct{ Bad chan int `json:"bad"` }{Bad: make(chan int)}) + _, err := decodeRenameSessionPayload(struct { + Bad chan int `json:"bad"` + }{Bad: make(chan int)}) if err == nil || err.Code != ErrorCodeInvalidFrame.String() { t.Fatalf("expected invalid frame error, got %#v", err) } @@ -3845,7 +4065,9 @@ func TestDecodeListFilesPayloadBranches(t *testing.T) { }) t.Run("marshal error", func(t *testing.T) { - _, err := decodeListFilesPayload(struct{ Bad chan int `json:"bad"` }{Bad: make(chan int)}) + _, err := decodeListFilesPayload(struct { + Bad chan int `json:"bad"` + }{Bad: make(chan int)}) if err == nil || err.Code != ErrorCodeInvalidFrame.String() { t.Fatalf("expected invalid frame error, got %#v", err) } @@ -3888,7 +4110,9 @@ func TestDecodeSetSessionModelPayloadBranches(t *testing.T) { }) t.Run("marshal error", func(t *testing.T) { - _, err := decodeSetSessionModelPayload(struct{ Bad chan int `json:"bad"` }{Bad: make(chan int)}) + _, err := decodeSetSessionModelPayload(struct { + Bad chan int `json:"bad"` + }{Bad: make(chan int)}) if err == nil || err.Code != ErrorCodeInvalidFrame.String() { t.Fatalf("expected invalid frame error, got %#v", err) } @@ -3956,7 +4180,7 @@ func TestHandleAuthenticateFrameAdditionalBranches(t *testing.T) { type emptySubjectAuthenticator struct{} -func (emptySubjectAuthenticator) ValidateToken(token string) bool { return true } +func (emptySubjectAuthenticator) ValidateToken(token string) bool { return true } func (emptySubjectAuthenticator) ResolveSubjectID(token string) (string, bool) { return "", true } func TestRuntimeCallFailedFrameErrorCodes(t *testing.T) { @@ -3996,39 +4220,73 @@ func TestRuntimeCallFailedFrameErrorCodes(t *testing.T) { // runtimeOnlyStub implements RuntimePort but NOT 
ManagementRuntimePort. type runtimeOnlyStub struct{} -func (runtimeOnlyStub) Run(ctx context.Context, input RunInput) error { return nil } -func (runtimeOnlyStub) Compact(ctx context.Context, input CompactInput) (CompactResult, error) { return CompactResult{}, nil } +func (runtimeOnlyStub) Run(ctx context.Context, input RunInput) error { return nil } +func (runtimeOnlyStub) Compact(ctx context.Context, input CompactInput) (CompactResult, error) { + return CompactResult{}, nil +} func (runtimeOnlyStub) ExecuteSystemTool(ctx context.Context, input ExecuteSystemToolInput) (tools.ToolResult, error) { return tools.ToolResult{}, nil } -func (runtimeOnlyStub) ActivateSessionSkill(ctx context.Context, input SessionSkillMutationInput) error { return nil } -func (runtimeOnlyStub) DeactivateSessionSkill(ctx context.Context, input SessionSkillMutationInput) error { return nil } +func (runtimeOnlyStub) ActivateSessionSkill(ctx context.Context, input SessionSkillMutationInput) error { + return nil +} +func (runtimeOnlyStub) DeactivateSessionSkill(ctx context.Context, input SessionSkillMutationInput) error { + return nil +} func (runtimeOnlyStub) ListSessionSkills(ctx context.Context, input ListSessionSkillsInput) ([]SessionSkillState, error) { return nil, nil } func (runtimeOnlyStub) ListAvailableSkills(ctx context.Context, input ListAvailableSkillsInput) ([]AvailableSkillState, error) { return nil, nil } -func (runtimeOnlyStub) ResolvePermission(ctx context.Context, input PermissionResolutionInput) error { return nil } -func (runtimeOnlyStub) CancelRun(ctx context.Context, input CancelInput) (bool, error) { return false, nil } -func (runtimeOnlyStub) Events() <-chan RuntimeEvent { return nil } -func (runtimeOnlyStub) ListSessions(ctx context.Context) ([]SessionSummary, error) { return nil, nil } -func (runtimeOnlyStub) LoadSession(ctx context.Context, input LoadSessionInput) (Session, error) { return Session{}, nil } +func (runtimeOnlyStub) ResolvePermission(ctx 
context.Context, input PermissionResolutionInput) error { + return nil +} +func (runtimeOnlyStub) CancelRun(ctx context.Context, input CancelInput) (bool, error) { + return false, nil +} +func (runtimeOnlyStub) Events() <-chan RuntimeEvent { return nil } +func (runtimeOnlyStub) ListSessions(ctx context.Context) ([]SessionSummary, error) { return nil, nil } +func (runtimeOnlyStub) LoadSession(ctx context.Context, input LoadSessionInput) (Session, error) { + return Session{}, nil +} func (runtimeOnlyStub) ListSessionTodos(ctx context.Context, input ListSessionTodosInput) (TodoSnapshot, error) { return TodoSnapshot{}, nil } func (runtimeOnlyStub) GetRuntimeSnapshot(ctx context.Context, input GetRuntimeSnapshotInput) (RuntimeSnapshot, error) { return RuntimeSnapshot{}, nil } -func (runtimeOnlyStub) CreateSession(ctx context.Context, input CreateSessionInput) (string, error) { return "", nil } -func (runtimeOnlyStub) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { return false, nil } -func (runtimeOnlyStub) RenameSession(ctx context.Context, input RenameSessionInput) error { return nil } -func (runtimeOnlyStub) ListFiles(ctx context.Context, input ListFilesInput) ([]FileEntry, error) { return nil, nil } -func (runtimeOnlyStub) ListModels(ctx context.Context, input ListModelsInput) ([]ModelEntry, error) { return nil, nil } -func (runtimeOnlyStub) SetSessionModel(ctx context.Context, input SetSessionModelInput) error { return nil } +func (runtimeOnlyStub) CreateSession(ctx context.Context, input CreateSessionInput) (string, error) { + return "", nil +} +func (runtimeOnlyStub) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { + return false, nil +} +func (runtimeOnlyStub) RenameSession(ctx context.Context, input RenameSessionInput) error { return nil } +func (runtimeOnlyStub) ListFiles(ctx context.Context, input ListFilesInput) ([]FileEntry, error) { + return nil, nil +} +func (runtimeOnlyStub) ListModels(ctx 
context.Context, input ListModelsInput) ([]ModelEntry, error) { + return nil, nil +} +func (runtimeOnlyStub) SetSessionModel(ctx context.Context, input SetSessionModelInput) error { + return nil +} func (runtimeOnlyStub) GetSessionModel(ctx context.Context, input GetSessionModelInput) (SessionModelResult, error) { return SessionModelResult{}, nil } +func (runtimeOnlyStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} +func (runtimeOnlyStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} +func (runtimeOnlyStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} +func (runtimeOnlyStub) CheckpointDiff(_ context.Context, _ CheckpointDiffInput) (CheckpointDiffResult, error) { + return CheckpointDiffResult{}, nil +} type managementRuntimeStub struct { bootstrapRuntimeStub diff --git a/internal/gateway/contracts.go b/internal/gateway/contracts.go index 3a1da391..f229769e 100644 --- a/internal/gateway/contracts.go +++ b/internal/gateway/contracts.go @@ -265,6 +265,82 @@ type SessionModelResult struct { Provider string `json:"provider,omitempty"` } +// ListCheckpointsInput 描述查询 checkpoint 列表的输入。 +type ListCheckpointsInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是目标会话标识。 + SessionID string + // Limit 限制返回数量,0 表示不限制。 + Limit int + // RestorableOnly 仅返回可恢复的 checkpoint。 + RestorableOnly bool +} + +// CheckpointEntry 描述单个 checkpoint 的列表视图。 +type CheckpointEntry struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + Reason string `json:"reason"` + Status string `json:"status"` + Restorable bool `json:"restorable"` + CreatedAt int64 `json:"created_at_ms"` +} + +// CheckpointRestoreInput 描述恢复 checkpoint 的输入。 +type CheckpointRestoreInput struct { + // SubjectID 是请求方身份主体标识。 + 
SubjectID string + // SessionID 是目标会话标识。 + SessionID string + // CheckpointID 是要恢复的 checkpoint 标识。 + CheckpointID string + // Force 强制恢复,忽略冲突检测。 + Force bool +} + +// UndoRestoreInput 描述撤销 restore 的输入。 +type UndoRestoreInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是目标会话标识。 + SessionID string +} + +// CheckpointRestoreResult 描述 checkpoint 恢复操作的结果。 +type CheckpointRestoreResult struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + HasConflict bool `json:"has_conflict,omitempty"` +} + +// CheckpointDiffInput 描述 checkpoint diff 查询输入。 +type CheckpointDiffInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是目标会话标识。 + SessionID string `json:"session_id"` + // CheckpointID 是可选的 checkpoint 标识,为空则查最新代码检查点。 + CheckpointID string `json:"checkpoint_id,omitempty"` +} + +// CheckpointDiffResult 描述两个相邻代码检查点之间的差异。 +type CheckpointDiffResult struct { + CheckpointID string `json:"checkpoint_id"` + PrevCheckpointID string `json:"prev_checkpoint_id,omitempty"` + CommitHash string `json:"commit_hash,omitempty"` + PrevCommitHash string `json:"prev_commit_hash,omitempty"` + Files FileDiffs `json:"files"` + Patch string `json:"patch,omitempty"` +} + +// FileDiffs 描述 diff 中的文件变更列表。 +type FileDiffs struct { + Added []string `json:"added,omitempty"` + Deleted []string `json:"deleted,omitempty"` + Modified []string `json:"modified,omitempty"` +} + // ProviderOption 表示前端管理面可见的 provider 及模型候选。 type ProviderOption struct { // ID 是 provider 标识。 @@ -596,6 +672,14 @@ type RuntimePort interface { SetSessionModel(ctx context.Context, input SetSessionModelInput) error // GetSessionModel 获取当前会话模型。 GetSessionModel(ctx context.Context, input GetSessionModelInput) (SessionModelResult, error) + // ListCheckpoints 查询指定会话的 checkpoint 列表。 + ListCheckpoints(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) + // RestoreCheckpoint 恢复到指定 checkpoint。 + RestoreCheckpoint(ctx context.Context, input 
CheckpointRestoreInput) (CheckpointRestoreResult, error) + // UndoRestore 撤销最近一次 checkpoint 恢复。 + UndoRestore(ctx context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) + // CheckpointDiff 查询两个相邻代码检查点之间的差异。 + CheckpointDiff(ctx context.Context, input CheckpointDiffInput) (CheckpointDiffResult, error) } // ManagementRuntimePort 定义前端管理面访问配置能力的可选下游端口。 diff --git a/internal/gateway/contracts_test.go b/internal/gateway/contracts_test.go index 2b4f498d..303b3b5c 100644 --- a/internal/gateway/contracts_test.go +++ b/internal/gateway/contracts_test.go @@ -126,6 +126,22 @@ func (s *runtimePortCompileStub) CreateSession(_ context.Context, _ CreateSessio return "", nil } +func (s *runtimePortCompileStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} + +func (s *runtimePortCompileStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *runtimePortCompileStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *runtimePortCompileStub) CheckpointDiff(_ context.Context, _ CheckpointDiffInput) (CheckpointDiffResult, error) { + return CheckpointDiffResult{}, nil +} + var _ RuntimePort = (*runtimePortCompileStub)(nil) var _ TransportAdapter = (*Server)(nil) var _ TransportAdapter = (*NetworkServer)(nil) diff --git a/internal/gateway/registry.go b/internal/gateway/registry.go index 90d66cb1..e1a4522e 100644 --- a/internal/gateway/registry.go +++ b/internal/gateway/registry.go @@ -71,6 +71,10 @@ func (r *ActionRegistry) initCore() { r.core[FrameActionUpsertMCPServer] = handleUpsertMCPServerFrame r.core[FrameActionSetMCPServerEnabled] = handleSetMCPServerEnabledFrame r.core[FrameActionDeleteMCPServer] = handleDeleteMCPServerFrame + r.core[FrameActionListCheckpoints] = handleListCheckpointsFrame + 
r.core[FrameActionRestoreCheckpoint] = handleRestoreCheckpointFrame + r.core[FrameActionUndoRestore] = handleUndoRestoreFrame + r.core[FrameActionCheckpointDiff] = handleCheckpointDiffFrame } // Lookup returns the handler for an action. diff --git a/internal/gateway/rpc_dispatch_test.go b/internal/gateway/rpc_dispatch_test.go index 3c1dac8d..0edd72a3 100644 --- a/internal/gateway/rpc_dispatch_test.go +++ b/internal/gateway/rpc_dispatch_test.go @@ -198,6 +198,22 @@ func (s *rpcRunCaptureRuntimeStub) GetRuntimeSnapshot( return RuntimeSnapshot{}, nil } +func (s *rpcRunCaptureRuntimeStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} + +func (s *rpcRunCaptureRuntimeStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *rpcRunCaptureRuntimeStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *rpcRunCaptureRuntimeStub) CheckpointDiff(_ context.Context, _ CheckpointDiffInput) (CheckpointDiffResult, error) { + return CheckpointDiffResult{}, nil +} + func TestDispatchRPCRequestResultEncodeError(t *testing.T) { installHandlerRegistryForTest(t, map[FrameAction]requestFrameHandler{ FrameActionPing: func(_ context.Context, frame MessageFrame, _ RuntimePort) MessageFrame { @@ -957,6 +973,22 @@ func (s *runtimePortOnlyStub) GetSessionModel(_ context.Context, _ GetSessionMod return SessionModelResult{}, nil } +func (s *runtimePortOnlyStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} + +func (s *runtimePortOnlyStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *runtimePortOnlyStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, 
error) { + return CheckpointRestoreResult{}, nil +} + +func (s *runtimePortOnlyStub) CheckpointDiff(_ context.Context, _ CheckpointDiffInput) (CheckpointDiffResult, error) { + return CheckpointDiffResult{}, nil +} + func TestDispatchRPCRequestProviderMethodsManagementPortUnavailable(t *testing.T) { ctx := WithRequestSource(context.Background(), RequestSourceIPC) ctx = WithRequestACL(ctx, NewStrictControlPlaneACL()) diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go index a4e27b04..b39ae3a5 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -457,6 +457,22 @@ func (s *runtimePortEventStub) GetRuntimeSnapshot( return RuntimeSnapshot{}, nil } +func (s *runtimePortEventStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} + +func (s *runtimePortEventStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *runtimePortEventStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *runtimePortEventStub) CheckpointDiff(_ context.Context, _ CheckpointDiffInput) (CheckpointDiffResult, error) { + return CheckpointDiffResult{}, nil +} + func decodeJSONRPCResultFrame(response protocol.JSONRPCResponse) (MessageFrame, error) { if response.Result == nil { return MessageFrame{}, errors.New("rpc result is nil") diff --git a/internal/gateway/types.go b/internal/gateway/types.go index eca4b8e0..53e5db8d 100644 --- a/internal/gateway/types.go +++ b/internal/gateway/types.go @@ -80,6 +80,14 @@ const ( FrameActionDeleteMCPServer FrameAction = "delete_mcp_server" // FrameActionWakeOpenURL 表示处理 URL Scheme 唤醒请求。 FrameActionWakeOpenURL FrameAction = "wake.openUrl" + // FrameActionListCheckpoints 表示查询会话 checkpoint 列表。 + FrameActionListCheckpoints FrameAction = "checkpoint_list" + // 
FrameActionRestoreCheckpoint 表示恢复到指定 checkpoint。 + FrameActionRestoreCheckpoint FrameAction = "checkpoint_restore" + // FrameActionUndoRestore 表示撤销最近一次 checkpoint 恢复。 + FrameActionUndoRestore FrameAction = "checkpoint_undo_restore" + // FrameActionCheckpointDiff 表示查询两个相邻代码检查点之间的差异。 + FrameActionCheckpointDiff FrameAction = "checkpoint_diff" ) // InputPartType 表示多模态输入分片类型。 diff --git a/internal/promptasset/templates/core/tool_usage.md b/internal/promptasset/templates/core/tool_usage.md index a5244b06..e80cf99f 100644 --- a/internal/promptasset/templates/core/tool_usage.md +++ b/internal/promptasset/templates/core/tool_usage.md @@ -14,6 +14,13 @@ - Use `filesystem_write_file` only for new files or full rewrites. - For simple create/overwrite tasks, prefer `filesystem_write_file` with `verify_after_write=true` so one call can emit write + verification facts. - Do not use `bash` to edit files when the filesystem tools can make the change safely. +- For file system structure changes inside the workspace, prefer the dedicated tools over `bash`: + - rename/move: `filesystem_move_file` (not `bash mv`) + - copy: `filesystem_copy_file` (not `bash cp`) + - delete file: `filesystem_delete_file` (not `bash rm`) + - create directory: `filesystem_create_dir` (not `bash mkdir`) + - remove directory: `filesystem_remove_dir` (not `bash rmdir` / `rm -rf`) + These tools record their changes for checkpoint/rollback; equivalent `bash` commands produce reduced rollback coverage. - For multi-step implementation, debugging, refactoring, or long-running work, keep task state explicit via `todo_write` (plan/add/update/set_status/claim/complete/fail) instead of relying on implicit memory. - Create todos that map to real acceptance work, not vague activity. - Required todos are acceptance-relevant and must converge before finalization. 
@@ -55,6 +62,7 @@ - Do not claim work is done if verification failed, was skipped without reason, could not run, or the needed files and commands did not actually succeed. ## Bash usage +- Whenever a `filesystem_*` tool can express the operation, use it instead of `bash`. The runtime tracks `filesystem_*` operations precisely; `bash` mutations are tracked only via best-effort heuristics + workdir scanning, so undoing them is less reliable. - When using `bash`, avoid interactive or blocking commands and pass non-interactive flags when they are available. - Stay within the current workspace unless the user clearly asks for something else. - Use Git through `bash` with this order: inspect (`git status`/`git diff`/`git log`), then mutate, then verify (`git status`/`git diff`), then summarize. diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go new file mode 100644 index 00000000..5a6bb735 --- /dev/null +++ b/internal/runtime/checkpoint_flow_test.go @@ -0,0 +1,786 @@ +package runtime + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "neo-code/internal/checkpoint" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +type checkpointStoreSpy struct { + lastResume agentsession.ResumeCheckpoint + listRecords []agentsession.CheckpointRecord + listSessionID string + listOpts checkpoint.ListCheckpointOpts + listErr error + getRecord agentsession.CheckpointRecord + getSessionCP *agentsession.SessionCheckpoint + getErr error +} + +func (s *checkpointStoreSpy) CreateCheckpoint(_ context.Context, in checkpoint.CreateCheckpointInput) (agentsession.CheckpointRecord, error) { + return in.Record, nil +} + +func (s *checkpointStoreSpy) ListCheckpoints(_ context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + s.listSessionID = sessionID + s.listOpts = opts + return s.listRecords, s.listErr +} + +func (s 
*checkpointStoreSpy) GetCheckpoint(context.Context, string) (agentsession.CheckpointRecord, *agentsession.SessionCheckpoint, error) { + return s.getRecord, s.getSessionCP, s.getErr +} + +func (s *checkpointStoreSpy) UpdateCheckpointStatus(context.Context, string, agentsession.CheckpointStatus) error { + return nil +} + +func (s *checkpointStoreSpy) GetLatestResumeCheckpoint(context.Context, string) (*agentsession.ResumeCheckpoint, error) { + return nil, nil +} + +func (s *checkpointStoreSpy) RestoreCheckpoint(context.Context, checkpoint.RestoreCheckpointInput) error { + return nil +} + +func (s *checkpointStoreSpy) SetResumeCheckpoint(_ context.Context, rc agentsession.ResumeCheckpoint) error { + s.lastResume = rc + return nil +} + +func (s *checkpointStoreSpy) PruneExpiredCheckpoints(context.Context, string, int) (int, error) { + return 0, nil +} + +func (s *checkpointStoreSpy) RepairCreatingCheckpoints(context.Context) (int, error) { + return 0, nil +} + +type runtimeCheckpointFixture struct { + service *Service + sessionStore *agentsession.SQLiteStore + checkpointStore *checkpoint.SQLiteCheckpointStore + perEditStore *checkpoint.PerEditSnapshotStore + workdir string + projectDir string + session agentsession.Session +} + +func newRuntimeCheckpointFixture(t *testing.T) runtimeCheckpointFixture { + t.Helper() + + baseDir := t.TempDir() + workdir := t.TempDir() + projectDir := t.TempDir() + + sessionStore := agentsession.NewSQLiteStore(baseDir, workdir) + t.Cleanup(func() { _ = sessionStore.Close() }) + + checkpointStore := checkpoint.NewSQLiteCheckpointStore(agentsession.DatabasePath(baseDir, workdir)) + t.Cleanup(func() { _ = checkpointStore.Close() }) + + perEditStore := checkpoint.NewPerEditSnapshotStore(projectDir, workdir) + + created, err := sessionStore.CreateSession(context.Background(), agentsession.CreateSessionInput{ + ID: "runtime-checkpoint-session", + Title: "runtime checkpoint", + Head: agentsession.SessionHead{ + Provider: "openai", + Model: 
"gpt-test", + Workdir: workdir, + TaskState: agentsession.TaskState{ + Goal: "initial goal", + VerificationProfile: agentsession.VerificationProfileTaskOnly, + }, + }, + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + if err := sessionStore.AppendMessages(context.Background(), agentsession.AppendMessagesInput{ + SessionID: created.ID, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("before restore"), + }, + }, + }, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) + } + loaded, err := sessionStore.LoadSession(context.Background(), created.ID) + if err != nil { + t.Fatalf("LoadSession() error = %v", err) + } + + return runtimeCheckpointFixture{ + service: &Service{ + sessionStore: sessionStore, + checkpointStore: checkpointStore, + perEditStore: perEditStore, + events: make(chan RuntimeEvent, 32), + }, + sessionStore: sessionStore, + checkpointStore: checkpointStore, + perEditStore: perEditStore, + workdir: workdir, + projectDir: projectDir, + session: loaded, + } +} + +// captureFile is a test helper that drops a file at workdir-relative path and asks +// the per-edit store to capture its current content as a pending pre-write version. 
+func (f runtimeCheckpointFixture) captureFile(t *testing.T, relPath string, content []byte) string { + t.Helper() + abs := filepath.Join(f.workdir, relPath) + if err := os.MkdirAll(filepath.Dir(abs), 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := os.WriteFile(abs, content, 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if _, err := f.perEditStore.CapturePreWrite(abs); err != nil { + t.Fatalf("CapturePreWrite() error = %v", err) + } + return abs +} + +func TestCreateStartOfTurnCheckpoint_PendingWrite(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + fixture.captureFile(t, "main.go", []byte("package main\nconst v = 1\n")) + + state := newRunState("run-pending", fixture.session) + if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createStartOfTurnCheckpoint() error = %v", err) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("records count = %d, want 1: %#v", len(records), records) + } + if records[0].Reason != agentsession.CheckpointReasonPreWrite { + t.Fatalf("reason = %s, want pre_write", records[0].Reason) + } + if !checkpoint.IsPerEditRef(records[0].CodeCheckpointRef) { + t.Fatalf("code ref = %q, want peredit ref", records[0].CodeCheckpointRef) + } +} + +func TestCreateStartOfTurnCheckpoint_NoPending_SessionOnly(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + + state := newRunState("run-empty", fixture.session) + if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createStartOfTurnCheckpoint() error = %v", err) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + 
t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("records = %#v, want one session-only checkpoint", records) + } + if records[0].Reason != agentsession.CheckpointReasonPreWrite { + t.Fatalf("reason = %s, want pre_write", records[0].Reason) + } + if records[0].CodeCheckpointRef != "" { + t.Fatalf("code ref = %q, want empty (session-only)", records[0].CodeCheckpointRef) + } +} + +func TestCreateEndOfTurnCheckpoint_NoWriteSkipped(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + fixture.captureFile(t, "main.go", []byte("package main\n")) + + state := newRunState("run-no-write", fixture.session) + fixture.service.createEndOfTurnCheckpoint(context.Background(), &state, false) + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 0 { + t.Fatalf("records = %#v, want no checkpoint when hasWorkspaceWrite=false", records) + } +} + +func TestCreateEndOfTurnCheckpoint_PerEditSkipsEmpty(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + + state := newRunState("run-empty", fixture.session) + fixture.service.createEndOfTurnCheckpoint(context.Background(), &state, true) + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 0 { + t.Fatalf("records = %#v, want no checkpoint when no pending writes captured", records) + } +} + +func TestCreateEndOfTurnCheckpoint_WithPending(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + fixture.captureFile(t, "lib.go", []byte("package lib\n")) + + state := newRunState("run-eot", fixture.session) + fixture.service.createEndOfTurnCheckpoint(context.Background(), &state, true) + + records, err := 
fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("records = %#v, want 1 end-of-turn checkpoint", records) + } + if records[0].Reason != agentsession.CheckpointReasonEndOfTurn { + t.Fatalf("reason = %s, want end_of_turn", records[0].Reason) + } + if !checkpoint.IsPerEditRef(records[0].CodeCheckpointRef) { + t.Fatalf("code ref = %q, want peredit ref", records[0].CodeCheckpointRef) + } +} + +func TestCreateCompactCheckpoint(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + if err := os.WriteFile(filepath.Join(fixture.workdir, "compact.txt"), []byte("compact"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + fixture.service.createCompactCheckpoint(context.Background(), "run-compact", fixture.session) + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].Reason != agentsession.CheckpointReasonCompact { + t.Fatalf("records = %#v, want compact checkpoint", records) + } +} + +func TestUpdateResumeCheckpoint(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-resume", fixture.session) + state.turn = 3 + spy := &checkpointStoreSpy{} + service := &Service{checkpointStore: spy} + service.updateResumeCheckpoint(context.Background(), &state, "verify", "running") + + if spy.lastResume.SessionID != fixture.session.ID || spy.lastResume.RunID != "run-resume" || spy.lastResume.Turn != 3 || spy.lastResume.Phase != "verify" { + t.Fatalf("SetResumeCheckpoint() captured %#v", spy.lastResume) + } +} + +func TestRuntimeCheckpointFacadeMethods(t *testing.T) { + t.Run("list checkpoints delegates to store", func(t *testing.T) { + spy := &checkpointStoreSpy{ + 
listRecords: []agentsession.CheckpointRecord{{CheckpointID: "cp-1"}}, + } + service := &Service{checkpointStore: spy} + + records, err := service.ListCheckpoints(context.Background(), "session-1", checkpoint.ListCheckpointOpts{ + Limit: 5, + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if spy.listSessionID != "session-1" || spy.listOpts.Limit != 5 || !spy.listOpts.RestorableOnly { + t.Fatalf("spy captured session=%q opts=%#v", spy.listSessionID, spy.listOpts) + } + if len(records) != 1 || records[0].CheckpointID != "cp-1" { + t.Fatalf("records = %#v", records) + } + }) + + t.Run("list checkpoints reports unavailable store", func(t *testing.T) { + service := &Service{} + if _, err := service.ListCheckpoints(context.Background(), "session-1", checkpoint.ListCheckpointOpts{}); err == nil { + t.Fatal("expected error when checkpoint store is unavailable") + } + }) + + t.Run("set checkpoint dependencies stores references", func(t *testing.T) { + service := &Service{} + store := &checkpointStoreSpy{} + perEdit := checkpoint.NewPerEditSnapshotStore(t.TempDir(), t.TempDir()) + + service.SetCheckpointDependencies(store, perEdit) + if service.checkpointStore != store || service.perEditStore != perEdit { + t.Fatalf("service checkpoint dependencies not set correctly") + } + }) + + t.Run("update runtime session after restore invalidates cache", func(t *testing.T) { + service := &Service{ + runtimeSnapshots: map[string]RuntimeSnapshot{ + "session-1": {SessionID: "session-1", Phase: "execute"}, + }, + } + service.updateRuntimeSessionAfterRestore("session-1", agentsession.SessionHead{}, nil) + if _, ok := service.runtimeSnapshots["session-1"]; ok { + t.Fatal("expected cached snapshot to be deleted after restore") + } + }) +} + +func TestRestoreCheckpoint_RecoversCapturedFile(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + target := filepath.Join(fixture.workdir, "restore.txt") + if err := os.WriteFile(target, 
[]byte("version one"), 0o644); err != nil { + t.Fatalf("WriteFile(version one) error = %v", err) + } + if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite() error = %v", err) + } + + state := newRunState("run-restore", fixture.session) + if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createStartOfTurnCheckpoint() error = %v", err) + } + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("records = %#v, want 1", records) + } + cpRecord := records[0] + + // mark checkpoint available so RestoreCheckpoint accepts it + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), cpRecord.CheckpointID, agentsession.CheckpointStatusAvailable); err != nil { + t.Fatalf("UpdateCheckpointStatus() error = %v", err) + } + + // agent rewrites the file (capture v2 mid-flight) + if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite(v2) error = %v", err) + } + if err := os.WriteFile(target, []byte("version two"), 0o644); err != nil { + t.Fatalf("WriteFile(version two) error = %v", err) + } + + if _, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: fixture.session.ID, + CheckpointID: cpRecord.CheckpointID, + }); err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + + got, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + if string(got) != "version one" { + t.Fatalf("restored content = %q, want %q", string(got), "version one") + } +} + +func TestUndoRestoreCheckpoint_RestoresGuardState(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + target := filepath.Join(fixture.workdir, "undo.txt") + if err := 
os.WriteFile(target, []byte("before"), 0o644); err != nil { + t.Fatalf("WriteFile(before) error = %v", err) + } + if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite(before) error = %v", err) + } + + state := newRunState("run-undo", fixture.session) + if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createStartOfTurnCheckpoint() error = %v", err) + } + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + cpRecord := records[0] + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), cpRecord.CheckpointID, agentsession.CheckpointStatusAvailable); err != nil { + t.Fatalf("UpdateCheckpointStatus() error = %v", err) + } + + if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite(guard) error = %v", err) + } + if err := os.WriteFile(target, []byte("after"), 0o644); err != nil { + t.Fatalf("WriteFile(after) error = %v", err) + } + + if _, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: fixture.session.ID, + CheckpointID: cpRecord.CheckpointID, + }); err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + if got := string(mustReadRuntimeFile(t, target)); got != "before" { + t.Fatalf("restored content = %q, want before", got) + } + + if _, err := fixture.service.UndoRestoreCheckpoint(context.Background(), fixture.session.ID); err != nil { + t.Fatalf("UndoRestoreCheckpoint() error = %v", err) + } + if got := string(mustReadRuntimeFile(t, target)); got != "before" { + t.Fatalf("undo content = %q, want before", got) + } + + seenUndo := false + for { + select { + case evt := <-fixture.service.events: + if evt.Type == EventCheckpointUndoRestore { + payload, ok := 
evt.Payload.(CheckpointUndoRestorePayload) + if !ok { + t.Fatalf("undo payload type = %T", evt.Payload) + } + if payload.SessionID != fixture.session.ID { + t.Fatalf("undo payload session = %q, want %q", payload.SessionID, fixture.session.ID) + } + seenUndo = true + } + default: + if !seenUndo { + t.Fatal("expected checkpoint undo event") + } + return + } + } +} + +func TestCheckpointDiff_ReturnsPatchAndClassifiedFiles(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + ctx := context.Background() + + alpha := filepath.Join(fixture.workdir, "alpha.txt") + if err := os.WriteFile(alpha, []byte("zero\n"), 0o644); err != nil { + t.Fatalf("WriteFile(alpha) error = %v", err) + } + + if _, err := fixture.perEditStore.CapturePreWrite(alpha); err != nil { + t.Fatalf("CapturePreWrite(alpha cp1) error = %v", err) + } + if err := os.WriteFile(alpha, []byte("one\n"), 0o644); err != nil { + t.Fatalf("WriteFile(alpha one) error = %v", err) + } + if _, err := fixture.perEditStore.Finalize("cp1"); err != nil { + t.Fatalf("Finalize(cp1) error = %v", err) + } + fixture.perEditStore.Reset() + + if _, err := fixture.perEditStore.CapturePreWrite(alpha); err != nil { + t.Fatalf("CapturePreWrite(alpha cp2) error = %v", err) + } + if err := os.WriteFile(alpha, []byte("two\n"), 0o644); err != nil { + t.Fatalf("WriteFile(alpha two) error = %v", err) + } + if _, err := fixture.perEditStore.Finalize("cp2"); err != nil { + t.Fatalf("Finalize(cp2) error = %v", err) + } + fixture.perEditStore.Reset() + + for _, cpID := range []string{"cp1", "cp2"} { + if _, err := fixture.checkpointStore.CreateCheckpoint(ctx, checkpoint.CreateCheckpointInput{ + Record: agentsession.CheckpointRecord{ + CheckpointID: cpID, + WorkspaceKey: agentsession.WorkspacePathKey(fixture.workdir), + SessionID: fixture.session.ID, + RunID: "run-diff", + Workdir: fixture.workdir, + CreatedAt: time.Now().UTC(), + Reason: agentsession.CheckpointReasonEndOfTurn, + CodeCheckpointRef: 
checkpoint.RefForPerEditCheckpoint(cpID), + Restorable: true, + Status: agentsession.CheckpointStatusAvailable, + }, + SessionCP: agentsession.SessionCheckpoint{ + ID: agentsession.NewID("sc"), + SessionID: fixture.session.ID, + HeadJSON: `{}`, + MessagesJSON: `[]`, + CreatedAt: time.Now().UTC(), + }, + }); err != nil { + t.Fatalf("CreateCheckpoint(%s) error = %v", cpID, err) + } + time.Sleep(time.Millisecond) + } + + result, err := fixture.service.CheckpointDiff(ctx, CheckpointDiffInput{ + SessionID: fixture.session.ID, + CheckpointID: "cp2", + }) + if err != nil { + t.Fatalf("CheckpointDiff() error = %v", err) + } + if result.CheckpointID != "cp2" || result.PrevCheckpointID != "cp1" { + t.Fatalf("unexpected checkpoint ids: %#v", result) + } + if !strings.Contains(result.Patch, "alpha.txt") { + t.Fatalf("patch should mention changed files, got:\n%s", result.Patch) + } + if len(result.Files.Modified) != 1 || result.Files.Modified[0] != "alpha.txt" { + t.Fatalf("modified files = %#v, want alpha.txt", result.Files.Modified) + } + if len(result.Files.Added) != 0 { + t.Fatalf("added files = %#v, want empty", result.Files.Added) + } + if len(result.Files.Deleted) != 0 { + t.Fatalf("deleted files = %#v, want empty", result.Files.Deleted) + } +} + +func TestRestoreCheckpointRejectsInvalidRequestAndMismatchedSession(t *testing.T) { + service := &Service{} + if _, err := service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{}); err == nil { + t.Fatal("expected error when checkpoint store is unavailable") + } + + service = &Service{ + checkpointStore: &checkpointStoreSpy{}, + perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), t.TempDir()), + } + if _, err := service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{}); err == nil { + t.Fatal("expected validation error for empty identifiers") + } + + service.checkpointStore = &checkpointStoreSpy{ + getRecord: agentsession.CheckpointRecord{ + CheckpointID: "cp-1", + SessionID: 
"other-session", + Status: agentsession.CheckpointStatusAvailable, + Restorable: true, + }, + getSessionCP: &agentsession.SessionCheckpoint{ + HeadJSON: `{}`, + MessagesJSON: `[]`, + }, + } + if _, err := service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: "session-1", + CheckpointID: "cp-1", + }); err == nil || !strings.Contains(err.Error(), "session mismatch") { + t.Fatalf("RestoreCheckpoint() error = %v, want session mismatch", err) + } + + for _, tc := range []struct { + name string + record agentsession.CheckpointRecord + sessionCP *agentsession.SessionCheckpoint + wantSubstr string + }{ + { + name: "status must be available", + record: agentsession.CheckpointRecord{ + CheckpointID: "cp-status", + SessionID: "session-1", + Status: agentsession.CheckpointStatusRestored, + Restorable: true, + }, + sessionCP: &agentsession.SessionCheckpoint{HeadJSON: `{}`, MessagesJSON: `[]`}, + wantSubstr: "status is restored", + }, + { + name: "checkpoint must be restorable", + record: agentsession.CheckpointRecord{ + CheckpointID: "cp-restorable", + SessionID: "session-1", + Status: agentsession.CheckpointStatusAvailable, + Restorable: false, + }, + sessionCP: &agentsession.SessionCheckpoint{HeadJSON: `{}`, MessagesJSON: `[]`}, + wantSubstr: "not restorable", + }, + { + name: "session checkpoint data is required", + record: agentsession.CheckpointRecord{ + CheckpointID: "cp-session-data", + SessionID: "session-1", + Status: agentsession.CheckpointStatusAvailable, + Restorable: true, + }, + sessionCP: nil, + wantSubstr: "no session checkpoint data", + }, + { + name: "head json must be valid", + record: agentsession.CheckpointRecord{ + CheckpointID: "cp-head-json", + SessionID: "session-1", + Status: agentsession.CheckpointStatusAvailable, + Restorable: true, + }, + sessionCP: &agentsession.SessionCheckpoint{HeadJSON: `{invalid`, MessagesJSON: `[]`}, + wantSubstr: "unmarshal head", + }, + { + name: "messages json must be valid", + record: 
agentsession.CheckpointRecord{ + CheckpointID: "cp-messages-json", + SessionID: "session-1", + Status: agentsession.CheckpointStatusAvailable, + Restorable: true, + }, + sessionCP: &agentsession.SessionCheckpoint{HeadJSON: `{}`, MessagesJSON: `{invalid`}, + wantSubstr: "unmarshal messages", + }, + } { + t.Run(tc.name, func(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + tc.record.SessionID = fixture.session.ID + spy := &checkpointStoreSpy{ + getRecord: tc.record, + getSessionCP: tc.sessionCP, + } + service := &Service{ + sessionStore: fixture.sessionStore, + checkpointStore: spy, + perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), fixture.workdir), + events: make(chan RuntimeEvent, 8), + } + _, err := service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: fixture.session.ID, + CheckpointID: tc.record.CheckpointID, + }) + if err == nil || !strings.Contains(err.Error(), tc.wantSubstr) { + t.Fatalf("RestoreCheckpoint() error = %v, want substring %q", err, tc.wantSubstr) + } + }) + } +} + +func TestCheckpointDiffSelectsLatestCodeCheckpointAndRejectsSessionOnlyTarget(t *testing.T) { + now := time.Now().UTC() + workdir := t.TempDir() + projectDir := t.TempDir() + perEditStore := checkpoint.NewPerEditSnapshotStore(projectDir, workdir) + target := filepath.Join(workdir, "tracked.txt") + if err := os.WriteFile(target, []byte("one\n"), 0o644); err != nil { + t.Fatalf("WriteFile(cp1 base) error = %v", err) + } + if _, err := perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite(cp1) error = %v", err) + } + if err := os.WriteFile(target, []byte("two\n"), 0o644); err != nil { + t.Fatalf("WriteFile(cp1 next) error = %v", err) + } + if _, err := perEditStore.Finalize("cp-1"); err != nil { + t.Fatalf("Finalize(cp-1) error = %v", err) + } + perEditStore.Reset() + + if _, err := perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite(cp2) error = %v", err) + } + if err := 
os.WriteFile(target, []byte("three\n"), 0o644); err != nil { + t.Fatalf("WriteFile(cp2 next) error = %v", err) + } + if _, err := perEditStore.Finalize("cp-2"); err != nil { + t.Fatalf("Finalize(cp-2) error = %v", err) + } + perEditStore.Reset() + + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + { + CheckpointID: "session-only", + SessionID: "session-1", + CreatedAt: now.Add(2 * time.Second), + CodeCheckpointRef: "", + }, + { + CheckpointID: "cp-2", + SessionID: "session-1", + CreatedAt: now.Add(time.Second), + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-2"), + }, + { + CheckpointID: "cp-1", + SessionID: "session-1", + CreatedAt: now, + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-1"), + }, + }, + } + service := &Service{ + checkpointStore: spy, + perEditStore: perEditStore, + } + + result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{SessionID: "session-1"}) + if err != nil { + t.Fatalf("CheckpointDiff() error = %v", err) + } + if result.CheckpointID != "cp-2" || result.PrevCheckpointID != "cp-1" { + t.Fatalf("CheckpointDiff() = %#v, want latest code checkpoint pair", result) + } + + if _, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: "session-1", + CheckpointID: "session-only", + }); err == nil || !strings.Contains(err.Error(), "not found or has no code snapshot") { + t.Fatalf("CheckpointDiff() error = %v, want session-only target rejection", err) + } +} + +func TestCheckpointDiffRejectsMissingStateAndReturnsEmptyWhenNoPreviousSnapshot(t *testing.T) { + service := &Service{} + if _, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{}); err == nil { + t.Fatal("expected store availability error") + } + + service = &Service{ + checkpointStore: &checkpointStoreSpy{}, + perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), t.TempDir()), + } + if _, err := service.CheckpointDiff(context.Background(), 
CheckpointDiffInput{}); err == nil || !strings.Contains(err.Error(), "session_id required") { + t.Fatalf("CheckpointDiff() error = %v, want session_id validation", err) + } + + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-only", + SessionID: "session-1", + CreatedAt: time.Now().UTC(), + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-only"), + }, + }, + } + service = &Service{ + checkpointStore: spy, + perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), t.TempDir()), + } + result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{SessionID: "session-1"}) + if err != nil { + t.Fatalf("CheckpointDiff() error = %v", err) + } + if result.CheckpointID != "cp-only" || result.PrevCheckpointID != "" || result.Patch != "" { + t.Fatalf("CheckpointDiff() = %#v, want latest checkpoint without previous diff", result) + } +} + +func mustReadRuntimeFile(t *testing.T, path string) []byte { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile(%s) error = %v", path, err) + } + return data +} diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go new file mode 100644 index 00000000..a74bef2b --- /dev/null +++ b/internal/runtime/checkpoint_gate.go @@ -0,0 +1,193 @@ +package runtime + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + "time" + + "neo-code/internal/checkpoint" + agentsession "neo-code/internal/session" +) + +// createStartOfTurnCheckpoint 在每轮 turn 开始时创建检查点。 +// 把上一轮 turn 的 pending capture 固化为 cp_.json;pending 为空时退化为 session-only。 +// 返回 error 由调用方发 warning event;失败不阻塞执行。 +func (s *Service) createStartOfTurnCheckpoint(ctx context.Context, state *runState) error { + if s.checkpointStore == nil || s.perEditStore == nil { + return nil + } + + state.mu.Lock() + session := state.session + runID := state.runID + state.mu.Unlock() + + checkpointID := agentsession.NewID("checkpoint") + written, err := 
s.perEditStore.Finalize(checkpointID) + if err != nil { + return fmt.Errorf("checkpoint: finalize per-edit: %w", err) + } + + if !written { + return s.createSessionOnlyCheckpoint(ctx, session, runID, state, agentsession.CheckpointReasonPreWrite) + } + defer s.perEditStore.Reset() + return s.createCheckpointRecord(ctx, session, runID, state, checkpointID, agentsession.CheckpointReasonPreWrite) +} + +// createEndOfTurnCheckpoint 在工具执行完成后创建代码检查点。 +// hasWorkspaceWrite=false 时不创建(避免空 checkpoint);为 true 时 Finalize 当前 pending。 +// 失败仅 log,不阻塞主流程。 +func (s *Service) createEndOfTurnCheckpoint(ctx context.Context, state *runState, hasWorkspaceWrite bool) { + if s.checkpointStore == nil || s.perEditStore == nil { + return + } + if !hasWorkspaceWrite { + return + } + + state.mu.Lock() + session := state.session + runID := state.runID + state.mu.Unlock() + + checkpointID := agentsession.NewID("checkpoint") + written, err := s.perEditStore.Finalize(checkpointID) + if err != nil { + log.Printf("checkpoint: end-of-turn finalize: %v", err) + return + } + if !written { + return + } + defer s.perEditStore.Reset() + if err := s.createCheckpointRecord(ctx, session, runID, state, checkpointID, agentsession.CheckpointReasonEndOfTurn); err != nil { + log.Printf("checkpoint: end-of-turn record: %v", err) + } +} + +// createCheckpointRecord 写入 SQLite checkpoint 记录 + session 快照,并发出 EventCheckpointCreated。 +// CodeCheckpointRef 复用为 "peredit:",由 per-edit 后端解释为版本化文件历史的引用。 +func (s *Service) createCheckpointRecord( + ctx context.Context, + session agentsession.Session, + runID string, + state *runState, + checkpointID string, + reason agentsession.CheckpointReason, +) error { + head := session.HeadSnapshot() + headJSON, err := json.Marshal(head) + if err != nil { + _ = s.perEditStore.DeleteCheckpoint(checkpointID) + return fmt.Errorf("checkpoint: marshal head: %w", err) + } + messagesJSON, err := json.Marshal(session.Messages) + if err != nil { + _ = 
s.perEditStore.DeleteCheckpoint(checkpointID) + return fmt.Errorf("checkpoint: marshal messages: %w", err) + } + + effectiveWorkdir := strings.TrimSpace(session.Workdir) + now := time.Now() + ref := checkpoint.RefForPerEditCheckpoint(checkpointID) + + record := agentsession.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: agentsession.WorkspacePathKey(effectiveWorkdir), + SessionID: session.ID, + RunID: runID, + Workdir: effectiveWorkdir, + CreatedAt: now, + Reason: reason, + CodeCheckpointRef: ref, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + } + sessionCP := agentsession.SessionCheckpoint{ + ID: agentsession.NewID("sc"), + SessionID: session.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: now, + } + + saved, err := s.checkpointStore.CreateCheckpoint(ctx, checkpoint.CreateCheckpointInput{ + Record: record, + SessionCP: sessionCP, + }) + if err != nil { + _ = s.perEditStore.DeleteCheckpoint(checkpointID) + return fmt.Errorf("checkpoint: db write: %w", err) + } + + s.emitRunScoped(ctx, EventCheckpointCreated, state, CheckpointCreatedPayload{ + CheckpointID: saved.CheckpointID, + CodeCheckpointRef: saved.CodeCheckpointRef, + SessionCheckpointRef: saved.SessionCheckpointRef, + CommitHash: "", + Reason: string(saved.Reason), + }) + return nil +} + +// createSessionOnlyCheckpoint 创建仅含 session 状态的 checkpoint(无代码引用),用于无 pending 写入时的边界标记。 +func (s *Service) createSessionOnlyCheckpoint( + ctx context.Context, + session agentsession.Session, + runID string, + state *runState, + reason agentsession.CheckpointReason, +) error { + checkpointID := agentsession.NewID("checkpoint") + now := time.Now() + + head := session.HeadSnapshot() + headJSON, err := json.Marshal(head) + if err != nil { + return fmt.Errorf("checkpoint: marshal session-only head: %w", err) + } + messagesJSON, err := json.Marshal(session.Messages) + if err != nil { + return fmt.Errorf("checkpoint: marshal session-only messages: %w", 
err) + } + + record := agentsession.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + SessionID: session.ID, + RunID: runID, + Workdir: session.Workdir, + CreatedAt: now, + Reason: reason, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + } + sessionCP := agentsession.SessionCheckpoint{ + ID: agentsession.NewID("sc"), + SessionID: session.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: now, + } + + saved, err := s.checkpointStore.CreateCheckpoint(ctx, checkpoint.CreateCheckpointInput{ + Record: record, + SessionCP: sessionCP, + }) + if err != nil { + return fmt.Errorf("checkpoint: session-only create: %w", err) + } + + s.emitRunScoped(ctx, EventCheckpointCreated, state, CheckpointCreatedPayload{ + CheckpointID: saved.CheckpointID, + CodeCheckpointRef: "", + SessionCheckpointRef: saved.SessionCheckpointRef, + CommitHash: "", + Reason: string(saved.Reason), + }) + return nil +} diff --git a/internal/runtime/checkpoint_restore.go b/internal/runtime/checkpoint_restore.go new file mode 100644 index 00000000..09db09fe --- /dev/null +++ b/internal/runtime/checkpoint_restore.go @@ -0,0 +1,390 @@ +package runtime + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "neo-code/internal/checkpoint" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +// GatewayRestoreInput 描述来自 Gateway 的 checkpoint 恢复请求。 +type GatewayRestoreInput struct { + SessionID string `json:"session_id"` + CheckpointID string `json:"checkpoint_id"` + Force bool `json:"force,omitempty"` +} + +// RestoreResult 描述 restore/undo 操作的结果。 +// per-edit 后端只还原本快照覆盖的文件,因此 Conflict 字段恒为空,仅保留以维持网关契约。 +type RestoreResult struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + Conflict *checkpoint.ConflictResult `json:"conflict,omitempty"` +} + +// RestoreCheckpoint 恢复指定 checkpoint 的会话和工作区状态。 
+// per-edit 后端只还原本快照覆盖的文件,不会破坏 agent 未触碰的文件,因此不再做冲突检测。 +// input.Force 字段保留以维持网关 API 契约。 +func (s *Service) RestoreCheckpoint(ctx context.Context, input GatewayRestoreInput) (RestoreResult, error) { + if s.checkpointStore == nil || s.perEditStore == nil { + return RestoreResult{}, fmt.Errorf("checkpoint: store not available") + } + + sessionID := strings.TrimSpace(input.SessionID) + checkpointID := strings.TrimSpace(input.CheckpointID) + if sessionID == "" || checkpointID == "" { + return RestoreResult{}, fmt.Errorf("checkpoint: session_id and checkpoint_id required") + } + + // 1. Load checkpoint record + record, sessionCP, err := s.checkpointStore.GetCheckpoint(ctx, checkpointID) + if err != nil { + return RestoreResult{}, err + } + if record.SessionID != sessionID { + return RestoreResult{}, fmt.Errorf("checkpoint: session mismatch") + } + if record.Status != agentsession.CheckpointStatusAvailable { + return RestoreResult{}, fmt.Errorf("checkpoint: status is %s, expected available", record.Status) + } + if !record.Restorable { + return RestoreResult{}, fmt.Errorf("checkpoint: not restorable") + } + + // 2. Pre-restore guard checkpoint:把当前 pending 固化为 guard cp,以便 undo 回到 restore 之前。 + guardID := agentsession.NewID("checkpoint") + guardWritten, finalizeErr := s.perEditStore.Finalize(guardID) + if finalizeErr != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: finalize guard: %w", finalizeErr) + } + if guardWritten { + s.perEditStore.Reset() + } + guardRecord, guardErr := s.createGuardCheckpoint(ctx, sessionID, record.RunID, guardID, guardWritten) + if guardErr != nil { + if guardWritten { + _ = s.perEditStore.DeleteCheckpoint(guardID) + } + return RestoreResult{}, fmt.Errorf("checkpoint: create guard: %w", guardErr) + } + + // 3. 
Restore code via per-edit store(不在 cp.FileVersions 中的文件保持不变)。 + // Guard checkpoint 恢复时使用 RestoreExact:guard 中存储的 version 就是 restore 前的 pre-write 状态, + // 而 Restore 的 v_next 语义在 guard 上通常是 no-op(guard 之后没有新的 capture)。 + isGuardRestore := record.Reason == agentsession.CheckpointReasonGuard + if checkpoint.IsPerEditRef(record.CodeCheckpointRef) { + perEditID := checkpoint.PerEditCheckpointIDFromRef(record.CodeCheckpointRef) + if perEditID != "" { + if isGuardRestore { + if err := s.perEditStore.RestoreExact(ctx, perEditID); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: restore code: %w", err) + } + } else { + if err := s.perEditStore.Restore(ctx, perEditID); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: restore code: %w", err) + } + } + } + } + + // 4. Unmarshal session checkpoint + if sessionCP == nil { + return RestoreResult{}, fmt.Errorf("checkpoint: no session checkpoint data") + } + var head agentsession.SessionHead + if err := json.Unmarshal([]byte(sessionCP.HeadJSON), &head); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: unmarshal head: %w", err) + } + var messages []providertypes.Message + if err := json.Unmarshal([]byte(sessionCP.MessagesJSON), &messages); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: unmarshal messages: %w", err) + } + + // 5. Determine checkpoint IDs to mark + markAvailableIDs := []string{guardRecord.CheckpointID} + var markRestoredIDs []string + allRecords, listErr := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{}) + if listErr == nil { + for _, r := range allRecords { + if r.CreatedAt.After(record.CreatedAt) && r.Status == agentsession.CheckpointStatusAvailable && r.Reason != agentsession.CheckpointReasonGuard { + markRestoredIDs = append(markRestoredIDs, r.CheckpointID) + } + } + } + + // 6. 
Restore session + update checkpoint statuses (single transaction) + restoreInput := checkpoint.RestoreCheckpointInput{ + SessionID: sessionID, + Head: head, + Messages: messages, + UpdatedAt: time.Now(), + MarkAvailableIDs: markAvailableIDs, + MarkRestoredIDs: markRestoredIDs, + } + if err := s.checkpointStore.RestoreCheckpoint(ctx, restoreInput); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: restore: %w", err) + } + + // 7. Update runtime session if it's the current session + s.updateRuntimeSessionAfterRestore(sessionID, head, messages) + + s.emitRunScoped(ctx, EventCheckpointRestored, nil, CheckpointRestoredPayload{ + CheckpointID: checkpointID, + SessionID: sessionID, + GuardCheckpointID: guardRecord.CheckpointID, + }) + return RestoreResult{ + CheckpointID: checkpointID, + SessionID: sessionID, + }, nil +} + +// UndoRestoreCheckpoint 撤销最近一次 restore,通过 pre_restore_guard 恢复到 restore 前的状态。 +func (s *Service) UndoRestoreCheckpoint(ctx context.Context, sessionID string) (RestoreResult, error) { + if s.checkpointStore == nil { + return RestoreResult{}, fmt.Errorf("checkpoint: store not available") + } + + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return RestoreResult{}, fmt.Errorf("checkpoint: session_id required") + } + + records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{ + Limit: 20, + RestorableOnly: true, + }) + if err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: list for undo: %w", err) + } + + var guardRecord *agentsession.CheckpointRecord + for _, r := range records { + if r.Reason == agentsession.CheckpointReasonGuard { + guardRecord = &r + break + } + } + if guardRecord == nil { + return RestoreResult{}, fmt.Errorf("checkpoint: no guard checkpoint found for undo") + } + + result, err := s.RestoreCheckpoint(ctx, GatewayRestoreInput{ + SessionID: sessionID, + CheckpointID: guardRecord.CheckpointID, + Force: true, + }) + if err != nil { + return RestoreResult{}, 
fmt.Errorf("checkpoint: undo restore: %w", err) + } + + s.emitRunScoped(ctx, EventCheckpointUndoRestore, nil, CheckpointUndoRestorePayload{ + GuardCheckpointID: guardRecord.CheckpointID, + SessionID: sessionID, + }) + return result, nil +} + +// createGuardCheckpoint 创建 pre_restore_guard 类型的 checkpoint。 +// guardWritten=true 时 guardID 对应的 per-edit cp_.json 已写入,CodeCheckpointRef 指向它;否则仅记 session 状态。 +func (s *Service) createGuardCheckpoint(ctx context.Context, sessionID, runID, guardID string, guardWritten bool) (agentsession.CheckpointRecord, error) { + session, err := s.sessionStore.LoadSession(ctx, sessionID) + if err != nil { + return agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: load session for guard: %w", err) + } + + head := session.HeadSnapshot() + headJSON, err := json.Marshal(head) + if err != nil { + return agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: marshal guard head: %w", err) + } + messagesJSON, err := json.Marshal(session.Messages) + if err != nil { + return agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: marshal guard messages: %w", err) + } + + var ref string + if guardWritten { + ref = checkpoint.RefForPerEditCheckpoint(guardID) + } + + now := time.Now() + record := agentsession.CheckpointRecord{ + CheckpointID: guardID, + WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + SessionID: sessionID, + RunID: runID, + Workdir: session.Workdir, + CreatedAt: now, + Reason: agentsession.CheckpointReasonGuard, + CodeCheckpointRef: ref, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + } + sessionCP := agentsession.SessionCheckpoint{ + ID: agentsession.NewID("sc"), + SessionID: sessionID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: now, + } + + saved, err := s.checkpointStore.CreateCheckpoint(ctx, checkpoint.CreateCheckpointInput{ + Record: record, + SessionCP: sessionCP, + }) + if err != nil { + return agentsession.CheckpointRecord{}, err + } + + 
s.emitRunScoped(ctx, EventCheckpointCreated, nil, CheckpointCreatedPayload{ + CheckpointID: saved.CheckpointID, + CodeCheckpointRef: saved.CodeCheckpointRef, + SessionCheckpointRef: saved.SessionCheckpointRef, + CommitHash: "", + Reason: string(saved.Reason), + }) + return saved, nil +} + +// ListCheckpoints 查询指定会话的 checkpoint 列表。 +func (s *Service) ListCheckpoints(ctx context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + if s.checkpointStore == nil { + return nil, fmt.Errorf("checkpoint: store not available") + } + return s.checkpointStore.ListCheckpoints(ctx, sessionID, opts) +} + +// updateRuntimeSessionAfterRestore 使运行时快照缓存失效。 +// GetRuntimeSnapshot 会从 DB 重新加载恢复后的状态,而非返回旧缓存。 +func (s *Service) updateRuntimeSessionAfterRestore(sessionID string, head agentsession.SessionHead, messages []providertypes.Message) { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return + } + s.runtimeSnapshotMu.Lock() + delete(s.runtimeSnapshots, normalized) + s.runtimeSnapshotMu.Unlock() +} + +// CheckpointDiffInput 描述 checkpoint diff 查询请求。 +type CheckpointDiffInput struct { + SessionID string `json:"session_id"` + CheckpointID string `json:"checkpoint_id,omitempty"` // 可选,为空则查最新代码检查点 +} + +// CheckpointDiffResult 描述两个相邻代码检查点之间的差异。 +type CheckpointDiffResult struct { + CheckpointID string `json:"checkpoint_id"` + PrevCheckpointID string `json:"prev_checkpoint_id,omitempty"` + CommitHash string `json:"commit_hash,omitempty"` + PrevCommitHash string `json:"prev_commit_hash,omitempty"` + Files FileDiffs `json:"files"` + Patch string `json:"patch,omitempty"` +} + +// FileDiffs 描述 diff 中的文件变更列表。 +type FileDiffs struct { + Added []string `json:"added,omitempty"` + Deleted []string `json:"deleted,omitempty"` + Modified []string `json:"modified,omitempty"` +} + +// CheckpointDiff 查询两个相邻代码检查点之间的差异,单一 per-edit 后端路径。 +func (s *Service) CheckpointDiff(ctx context.Context, input CheckpointDiffInput) 
(CheckpointDiffResult, error) { + if s.checkpointStore == nil || s.perEditStore == nil { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: store not available") + } + + sessionID := strings.TrimSpace(input.SessionID) + if sessionID == "" { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: session_id required") + } + + records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{Limit: 20}) + if err != nil { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: list for diff: %w", err) + } + + targetID := strings.TrimSpace(input.CheckpointID) + var targetRecord *agentsession.CheckpointRecord + if targetID != "" { + for i := range records { + if records[i].CheckpointID != targetID { + continue + } + if !checkpoint.IsPerEditRef(records[i].CodeCheckpointRef) { + continue + } + targetRecord = &records[i] + break + } + if targetRecord == nil { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: %s not found or has no code snapshot", targetID) + } + } else { + for i := range records { + if !checkpoint.IsPerEditRef(records[i].CodeCheckpointRef) { + continue + } + targetRecord = &records[i] + break + } + if targetRecord == nil { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: no code checkpoint found") + } + } + + var prevRecord *agentsession.CheckpointRecord + for i := range records { + if records[i].CheckpointID == targetRecord.CheckpointID { + continue + } + if !records[i].CreatedAt.Before(targetRecord.CreatedAt) { + continue + } + if !checkpoint.IsPerEditRef(records[i].CodeCheckpointRef) { + continue + } + prevRecord = &records[i] + break + } + + result := CheckpointDiffResult{ + CheckpointID: targetRecord.CheckpointID, + } + if prevRecord == nil { + return result, nil + } + result.PrevCheckpointID = prevRecord.CheckpointID + + fromID := checkpoint.PerEditCheckpointIDFromRef(prevRecord.CodeCheckpointRef) + toID := checkpoint.PerEditCheckpointIDFromRef(targetRecord.CodeCheckpointRef) + patch, err := 
s.perEditStore.Diff(ctx, fromID, toID) + if err != nil { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: per-edit diff: %w", err) + } + result.Patch = patch + + changes, err := s.perEditStore.ChangedFiles(ctx, fromID, toID) + if err != nil { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: per-edit changed files: %w", err) + } + for _, c := range changes { + switch c.Kind { + case checkpoint.FileChangeAdded: + result.Files.Added = append(result.Files.Added, c.Path) + case checkpoint.FileChangeDeleted: + result.Files.Deleted = append(result.Files.Deleted, c.Path) + case checkpoint.FileChangeModified: + result.Files.Modified = append(result.Files.Modified, c.Path) + } + } + + return result, nil +} diff --git a/internal/runtime/checkpoint_resume.go b/internal/runtime/checkpoint_resume.go new file mode 100644 index 00000000..1c3b027e --- /dev/null +++ b/internal/runtime/checkpoint_resume.go @@ -0,0 +1,38 @@ +package runtime + +import ( + "context" + "log" + "time" + + agentsession "neo-code/internal/session" +) + +// updateResumeCheckpoint 在 phase 转换时写入或更新 ResumeCheckpoint。 +// 失败仅 log,不阻塞主流程。 +func (s *Service) updateResumeCheckpoint(ctx context.Context, state *runState, phase string, completionState string) { + if s.checkpointStore == nil { + return + } + + state.mu.Lock() + session := state.session + runID := state.runID + turn := state.turn + state.mu.Unlock() + + rc := agentsession.ResumeCheckpoint{ + ID: agentsession.NewID("rc"), + WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + RunID: runID, + SessionID: session.ID, + Turn: turn, + Phase: phase, + CompletionState: completionState, + UpdatedAt: time.Now(), + } + + if err := s.checkpointStore.SetResumeCheckpoint(ctx, rc); err != nil { + log.Printf("checkpoint: set resume checkpoint for %s: %v", session.ID, err) + } +} diff --git a/internal/runtime/compact.go b/internal/runtime/compact.go index e39cfb59..55b001c3 100644 --- a/internal/runtime/compact.go +++ 
b/internal/runtime/compact.go @@ -2,10 +2,12 @@ package runtime import ( "context" + "encoding/json" "errors" "strings" "time" + "neo-code/internal/checkpoint" "neo-code/internal/config" contextcompact "neo-code/internal/context/compact" providertypes "neo-code/internal/provider/types" @@ -155,6 +157,8 @@ func (s *Service) runCompactForSession( return failCompact(errors.New(reason)) } + s.createCompactCheckpoint(ctx, runID, session) + s.emit(ctx, EventCompactStart, runID, session.ID, string(mode)) result, err := runner.Run(ctx, contextcompact.Input{ @@ -248,3 +252,58 @@ func resolveCompactProviderSelection(session agentsession.Session, cfg config.Co } return resolved, strings.TrimSpace(cfg.CurrentModel), nil } + +// createCompactCheckpoint 为 compact 操作创建 session-only checkpoint。 +func (s *Service) createCompactCheckpoint(ctx context.Context, runID string, session agentsession.Session) { + if s.checkpointStore == nil { + return + } + + checkpointID := agentsession.NewID("checkpoint") + now := time.Now() + + head := session.HeadSnapshot() + headJSON, err := json.Marshal(head) + if err != nil { + return + } + messagesJSON, err := json.Marshal(session.Messages) + if err != nil { + return + } + + record := agentsession.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + SessionID: session.ID, + RunID: runID, + Workdir: session.Workdir, + CreatedAt: now, + Reason: agentsession.CheckpointReasonCompact, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + } + + // Per-edit snapshot if pending writes exist this turn. 
+ if s.perEditStore != nil { + if written, err := s.perEditStore.Finalize(checkpointID); err == nil && written { + record.CodeCheckpointRef = checkpoint.RefForPerEditCheckpoint(checkpointID) + s.perEditStore.Reset() + } + } + + sessionCP := agentsession.SessionCheckpoint{ + ID: agentsession.NewID("sc"), + SessionID: session.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: now, + } + + if _, err := s.checkpointStore.CreateCheckpoint(ctx, checkpoint.CreateCheckpointInput{ + Record: record, + SessionCP: sessionCP, + }); err != nil { + return + } +} diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 4dcb0a8c..47d127d1 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -325,6 +325,8 @@ const ( EventToolResult EventType = "tool_result" // EventToolChunk 表示工具流式输出分片。 EventToolChunk EventType = "tool_chunk" + // EventToolDiff 表示写工具修改了某个文件。 + EventToolDiff EventType = "tool_diff" // EventRunCanceled 表示运行被取消。 EventRunCanceled EventType = "run_canceled" // EventError 表示运行出现终止错误。 @@ -415,6 +417,17 @@ const ( EventSubAgentSnapshotUpdated EventType = "subagent_snapshot_updated" // EventTodoSnapshotUpdated 表示 todo 快照已更新。 EventTodoSnapshotUpdated EventType = "todo_snapshot_updated" + + // EventCheckpointCreated 表示 pre-write checkpoint 已创建。 + EventCheckpointCreated EventType = "checkpoint_created" + // EventCheckpointWarning 表示 checkpoint 创建过程中出现非致命告警。 + EventCheckpointWarning EventType = "checkpoint_warning" + // EventCheckpointRestored 表示 checkpoint 已成功恢复。 + EventCheckpointRestored EventType = "checkpoint_restored" + // EventCheckpointUndoRestore 表示 restore 已撤销。 + EventCheckpointUndoRestore EventType = "checkpoint_undo_restore" + // EventBashSideEffect 表示 bash 命令在 workdir 内产生了文件变更。 + EventBashSideEffect EventType = "bash_side_effect" ) // TokenUsagePayload 承载单轮 token 用量统计。 @@ -427,3 +440,65 @@ type TokenUsagePayload struct { SessionInputTokens int `json:"session_input_tokens"` SessionOutputTokens 
int `json:"session_output_tokens"` } + +// CheckpointCreatedPayload 描述 checkpoint 创建成功事件。 +type CheckpointCreatedPayload struct { + CheckpointID string `json:"checkpoint_id"` + CodeCheckpointRef string `json:"code_checkpoint_ref"` + SessionCheckpointRef string `json:"session_checkpoint_ref"` + CommitHash string `json:"commit_hash"` + Reason string `json:"reason"` +} + +// CheckpointWarningPayload 描述 checkpoint 创建过程中非致命告警。 +type CheckpointWarningPayload struct { + Error string `json:"error"` + Phase string `json:"phase"` +} + +// CheckpointRestoredPayload 描述 checkpoint 恢复成功事件。 +type CheckpointRestoredPayload struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + GuardCheckpointID string `json:"guard_checkpoint_id"` +} + +// CheckpointUndoRestorePayload 描述 restore 撤销事件。 +type CheckpointUndoRestorePayload struct { + GuardCheckpointID string `json:"guard_checkpoint_id"` + SessionID string `json:"session_id"` +} + +// FileChange 描述一次文件变更的最小信息。 +type FileChange struct { + Path string `json:"path"` + Kind string `json:"kind"` // "added" | "modified" | "deleted" +} + +// FileDiffEntry 描述单个文件的精确 diff(多文件工具下使用)。 +type FileDiffEntry struct { + Path string `json:"path"` + Diff string `json:"diff,omitempty"` + WasNew bool `json:"was_new,omitempty"` +} + +// ToolDiffPayload 描述写工具修改了哪些文件。 +// 单文件兼容字段(FilePath/Diff/WasNew)保留以支持现有消费方;多文件工具填充 Files+Diffs。 +type ToolDiffPayload struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + FilePath string `json:"file_path"` + Diff string `json:"diff,omitempty"` + WasNew bool `json:"was_new,omitempty"` + Files []FileChange `json:"files,omitempty"` + Diffs []FileDiffEntry `json:"diffs,omitempty"` +} + +// BashSideEffectPayload 描述 bash 命令在 workdir 内的文件变更。 +type BashSideEffectPayload struct { + ToolCallID string `json:"tool_call_id"` + Command string `json:"command,omitempty"` + Changes []FileChange `json:"changes"` + PreemptivelyCapturedPaths []string 
`json:"preemptively_captured_paths,omitempty"` + UncoveredPaths []string `json:"uncovered_paths,omitempty"` +} diff --git a/internal/runtime/file_snapshot.go b/internal/runtime/file_snapshot.go new file mode 100644 index 00000000..b2789868 --- /dev/null +++ b/internal/runtime/file_snapshot.go @@ -0,0 +1,70 @@ +package runtime + +import ( + "os" + "strings" + + "github.com/pmezard/go-difflib/difflib" +) + +// fileSnapshot 工具执行前的文件状态快照,用于在执行后计算精确 diff。 +type fileSnapshot struct { + path string + content []byte + existed bool +} + +// captureFileSnapshot 读取目标文件当前内容并打包成快照。文件不存在时 existed=false。 +func captureFileSnapshot(path string) fileSnapshot { + snap := fileSnapshot{path: path} + content, err := os.ReadFile(path) + if err == nil { + snap.content = content + snap.existed = true + } + return snap +} + +// Diff 对比快照内容和文件当前内容,返回 unified diff。 +// 内容未变化或文件仍不存在时返回空字符串。 +func (s fileSnapshot) Diff() (string, error) { + current, err := os.ReadFile(s.path) + if err != nil { + if os.IsNotExist(err) { + if !s.existed { + return "", nil + } + return computeUnifiedDiff(string(s.content), "", s.path) + } + return "", err + } + if s.existed && string(current) == string(s.content) { + return "", nil + } + oldContent := "" + if s.existed { + oldContent = string(s.content) + } + return computeUnifiedDiff(oldContent, string(current), s.path) +} + +// WasNew 判断该文件在 Capture 时是否不存在(agent 新建了该文件)。 +func (s fileSnapshot) WasNew() bool { + return !s.existed +} + +// computeUnifiedDiff 计算两段文本的 unified diff,使用 go-difflib 生成标准格式。 +func computeUnifiedDiff(oldContent, newContent, label string) (string, error) { + diff := difflib.UnifiedDiff{ + A: difflib.SplitLines(oldContent), + B: difflib.SplitLines(newContent), + FromFile: label, + ToFile: label, + Context: 3, + } + out, err := difflib.GetUnifiedDiffString(diff) + if err != nil { + return "", err + } + return strings.TrimRight(out, "\n"), nil +} diff --git a/internal/runtime/file_snapshot_test.go b/internal/runtime/file_snapshot_test.go new 
file mode 100644 index 00000000..c78c0e57 --- /dev/null +++ b/internal/runtime/file_snapshot_test.go @@ -0,0 +1,74 @@ +package runtime + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestCaptureFileSnapshotMissingFileMarksAsNew(t *testing.T) { + path := filepath.Join(t.TempDir(), "missing.txt") + + snap := captureFileSnapshot(path) + if !snap.WasNew() { + t.Fatal("expected missing file snapshot to be treated as new") + } + + diff, err := snap.Diff() + if err != nil { + t.Fatalf("Diff() error = %v", err) + } + if diff != "" { + t.Fatalf("Diff() = %q, want empty", diff) + } +} + +func TestFileSnapshotDiffHandlesDeletion(t *testing.T) { + path := filepath.Join(t.TempDir(), "delete.txt") + if err := os.WriteFile(path, []byte("before\n"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + snap := captureFileSnapshot(path) + if err := os.Remove(path); err != nil { + t.Fatalf("Remove() error = %v", err) + } + + diff, err := snap.Diff() + if err != nil { + t.Fatalf("Diff() error = %v", err) + } + if !strings.Contains(diff, "--- "+path) || !strings.Contains(diff, "-before") { + t.Fatalf("Diff() = %q, want deletion patch for %s", diff, path) + } +} + +func TestFileSnapshotDiffIgnoresUnchangedContent(t *testing.T) { + path := filepath.Join(t.TempDir(), "same.txt") + if err := os.WriteFile(path, []byte("same\n"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + snap := captureFileSnapshot(path) + diff, err := snap.Diff() + if err != nil { + t.Fatalf("Diff() error = %v", err) + } + if diff != "" { + t.Fatalf("Diff() = %q, want empty", diff) + } +} + +func TestComputeUnifiedDiffTrimsTrailingNewline(t *testing.T) { + diff, err := computeUnifiedDiff("one\n", "two\n", "sample.txt") + if err != nil { + t.Fatalf("computeUnifiedDiff() error = %v", err) + } + if strings.HasSuffix(diff, "\n") { + t.Fatalf("diff should be trimmed, got %q", diff) + } + if !strings.Contains(diff, "@@") || !strings.Contains(diff, "+two") || 
!strings.Contains(diff, "-one") { + t.Fatalf("diff = %q, want unified diff body", diff) + } +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index d55f5795..2e35795a 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -90,6 +90,17 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.finishRun(runToken) }() defer func() { + if statePtr != nil { + completion := "completed" + if err != nil { + if errors.Is(err, context.Canceled) { + completion = "cancelled" + } else { + completion = "error" + } + } + s.updateResumeCheckpoint(runCtx, statePtr, "stopped", completion) + } s.emitRunTermination(runCtx, input, statePtr, err) }() ctx = runCtx @@ -170,6 +181,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { return s.handleRunError(err) } s.emitRuntimeSnapshotUpdated(ctx, &state, "session_start") + s.updateResumeCheckpoint(ctx, &state, "plan", "") maxTurns := resolveRuntimeMaxTurns(initialCfg.Runtime) for turn := 0; ; turn++ { @@ -185,6 +197,14 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, baseRunStateForPlanningStage(stage)); err != nil { return s.handleRunError(err) } + if s.checkpointStore != nil { + if cpErr := s.createStartOfTurnCheckpoint(ctx, &state); cpErr != nil { + s.emitRunScoped(ctx, EventCheckpointWarning, &state, CheckpointWarningPayload{ + Error: cpErr.Error(), + Phase: "start_of_turn", + }) + } + } turnAttempt: for { @@ -344,6 +364,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, controlplane.RunStateVerify); err != nil { return s.handleRunError(err) } + s.updateResumeCheckpoint(ctx, &state, "verify", "completed") acceptanceDecision, err := s.runBeforeCompletionDecisionAcceptance( ctx, &state, @@ -434,14 +455,22 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { beforeTask := state.session.TaskState.Clone() 
beforeTodos := cloneTodosForPersistence(state.session.Todos) + if err := s.setBaseRunState(ctx, &state, controlplane.RunStateExecute); err != nil { return s.handleRunError(err) } + s.updateResumeCheckpoint(ctx, &state, "execute", "") summary, err := s.executeAssistantToolCalls(ctx, &state, snapshot, turnOutput.assistant) if err != nil { return s.handleRunError(err) } + // 通知 TUI 本轮修改了哪些文件 + s.emitToolDiffs(ctx, &state, summary) + + // 工具执行完成后创建代码检查点,传入 hasWorkspaceWrite 区分 agent 写操作与外界修改 + s.createEndOfTurnCheckpoint(ctx, &state, summary.HasSuccessfulWorkspaceWrite) + state.mu.Lock() state.completion = applyToolExecutionCompletion(state.completion, summary) afterTask := state.session.TaskState.Clone() @@ -476,6 +505,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, controlplane.RunStateVerify); err != nil { return s.handleRunError(err) } + s.updateResumeCheckpoint(ctx, &state, "verify", "completed") break } } @@ -664,6 +694,90 @@ func (s *Service) emitTokenUsage(ctx context.Context, state *runState, result le }) } +// emitToolDiffs 遍历本轮写操作结果,逐个 emit EventToolDiff 通知 TUI。 +func (s *Service) emitToolDiffs(ctx context.Context, state *runState, summary toolExecutionSummary) { + for _, result := range summary.Results { + if !result.Facts.WorkspaceWrite || toolResultNoopWrite(result.Metadata) { + continue + } + payload, ok := buildToolDiffPayload(result) + if !ok { + continue + } + s.emitRunScopedOptional(EventToolDiff, state, payload) + } +} + +// buildToolDiffPayload 将工具结果 metadata 中的 diff 信息组装成 ToolDiffPayload。 +// 多文件工具(filesystem_move_file 等)使用 Files+Diffs 多路径字段; +// 其他写工具继续填充兼容字段 FilePath/Diff/WasNew,保持现有消费者不破。 +func buildToolDiffPayload(result tools.ToolResult) (ToolDiffPayload, bool) { + payload := ToolDiffPayload{ + ToolCallID: result.ToolCallID, + ToolName: result.Name, + } + if multi, ok := toolResultMultiDiffs(result.Metadata); ok && len(multi) > 0 { + payload.Diffs = multi + payload.Files 
= make([]FileChange, 0, len(multi)) + for _, entry := range multi { + kind := "modified" + if entry.WasNew { + kind = "added" + } + payload.Files = append(payload.Files, FileChange{Path: entry.Path, Kind: kind}) + } + first := multi[0] + payload.FilePath = first.Path + payload.Diff = first.Diff + payload.WasNew = first.WasNew + return payload, true + } + filePath := toolResultFilePath(result.Metadata) + if filePath == "" { + return payload, false + } + diff, _ := result.Metadata["tool_diff"].(string) + wasNew, _ := result.Metadata["tool_diff_new"].(bool) + payload.FilePath = filePath + payload.Diff = diff + payload.WasNew = wasNew + return payload, true +} + +// toolResultMultiDiffs 从工具结果 metadata 解析多文件 diff 列表。 +func toolResultMultiDiffs(metadata map[string]any) ([]FileDiffEntry, bool) { + if metadata == nil { + return nil, false + } + raw, ok := metadata["tool_diffs"] + if !ok || raw == nil { + return nil, false + } + entries, ok := raw.([]map[string]any) + if !ok { + return nil, false + } + out := make([]FileDiffEntry, 0, len(entries)) + for _, entry := range entries { + path, _ := entry["path"].(string) + path = strings.TrimSpace(path) + if path == "" { + continue + } + diff, _ := entry["diff"].(string) + wasNew, _ := entry["was_new"].(bool) + out = append(out, FileDiffEntry{ + Path: path, + Diff: diff, + WasNew: wasNew, + }) + } + if len(out) == 0 { + return nil, false + } + return out, true +} + // applyCompactForState 在运行中执行 compact,并把结果同步回 runState。 func (s *Service) applyCompactForState( ctx context.Context, diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 13a98e6d..aa726ccc 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "neo-code/internal/checkpoint" "neo-code/internal/config" agentcontext "neo-code/internal/context" contextcompact "neo-code/internal/context/compact" @@ -150,6 +151,8 @@ type Service struct { skillsRegistry skills.Registry budgetResolver 
BudgetResolver hookExecutor HookExecutor + checkpointStore checkpoint.CheckpointStore + perEditStore *checkpoint.PerEditSnapshotStore events chan RuntimeEvent runtimeSnapshotMu sync.Mutex @@ -453,3 +456,9 @@ func (s *Service) SetHookExecutor(executor HookExecutor) { } s.hookExecutor = executor } + +// SetCheckpointDependencies 注入 checkpoint 存储与版本化文件历史快照后端,用于 pre-write checkpoint gate。 +func (s *Service) SetCheckpointDependencies(store checkpoint.CheckpointStore, perEdit *checkpoint.PerEditSnapshotStore) { + s.checkpointStore = store + s.perEditStore = perEdit +} diff --git a/internal/runtime/tool_diff_helpers_test.go b/internal/runtime/tool_diff_helpers_test.go new file mode 100644 index 00000000..668a6c4c --- /dev/null +++ b/internal/runtime/tool_diff_helpers_test.go @@ -0,0 +1,243 @@ +package runtime + +import ( + "context" + "testing" + + "neo-code/internal/checkpoint" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" + "neo-code/internal/tools" +) + +func TestBuildToolDiffPayload(t *testing.T) { + t.Run("single file payload", func(t *testing.T) { + result := tools.ToolResult{ + Name: tools.ToolNameFilesystemWriteFile, + ToolCallID: "call-1", + Metadata: map[string]any{ + "path": "main.go", + "tool_diff": "@@ -1 +1 @@", + "tool_diff_new": true, + }, + } + + payload, ok := buildToolDiffPayload(result) + if !ok { + t.Fatal("expected payload") + } + if payload.FilePath != "main.go" || payload.Diff != "@@ -1 +1 @@" || !payload.WasNew { + t.Fatalf("unexpected single payload: %#v", payload) + } + }) + + t.Run("multi file payload", func(t *testing.T) { + result := tools.ToolResult{ + Name: tools.ToolNameFilesystemMoveFile, + ToolCallID: "call-2", + Metadata: map[string]any{ + "tool_diffs": []map[string]any{ + {"path": "old.txt", "diff": "@@ -1 +0 @@", "was_new": false}, + {"path": "new.txt", "diff": "@@ -0 +1 @@", "was_new": true}, + {"path": " ", "diff": "ignored", "was_new": true}, + }, + }, + } + + payload, ok := 
buildToolDiffPayload(result) + if !ok { + t.Fatal("expected payload") + } + if len(payload.Files) != 2 || len(payload.Diffs) != 2 { + t.Fatalf("unexpected multi payload lengths: %#v", payload) + } + if payload.Files[0].Kind != "modified" || payload.Files[1].Kind != "added" { + t.Fatalf("unexpected file kinds: %#v", payload.Files) + } + if payload.FilePath != "old.txt" { + t.Fatalf("first file path = %q, want old.txt", payload.FilePath) + } + }) + + t.Run("missing file path returns false", func(t *testing.T) { + if _, ok := buildToolDiffPayload(tools.ToolResult{Name: tools.ToolNameFilesystemWriteFile}); ok { + t.Fatal("expected no payload when metadata has no path") + } + }) +} + +func TestToolExecutionHelperFunctions(t *testing.T) { + t.Run("toolCallTouchedPaths covers write and move payloads", func(t *testing.T) { + writePaths := toolCallTouchedPaths(providertypes.ToolCall{ + Name: tools.ToolNameFilesystemWriteFile, + Arguments: `{"path":" docs/readme.md "}`, + }, "/repo") + if len(writePaths) != 1 || writePaths[0] != "/repo/docs/readme.md" { + t.Fatalf("write toolCallTouchedPaths() = %#v", writePaths) + } + + movePaths := toolCallTouchedPaths(providertypes.ToolCall{ + Name: tools.ToolNameFilesystemMoveFile, + Arguments: `{"source_path":"src/a.txt","destination_path":" /tmp/b.txt "}`, + }, "/repo") + if len(movePaths) != 2 || movePaths[0] != "/repo/src/a.txt" || movePaths[1] != "/tmp/b.txt" { + t.Fatalf("move toolCallTouchedPaths() = %#v", movePaths) + } + + if got := toolCallTouchedPaths(providertypes.ToolCall{ + Name: tools.ToolNameFilesystemCopyFile, + Arguments: `{invalid`, + }, "/repo"); got != nil { + t.Fatalf("malformed toolCallTouchedPaths() = %#v, want nil", got) + } + }) + + t.Run("toolResultMultiDiffs parses valid entries", func(t *testing.T) { + entries, ok := toolResultMultiDiffs(map[string]any{ + "tool_diffs": []map[string]any{ + {"path": "a.txt", "diff": "a", "was_new": true}, + {"path": " ", "diff": "ignored", "was_new": false}, + }, + }) + if !ok 
|| len(entries) != 1 { + t.Fatalf("entries=%#v ok=%v", entries, ok) + } + if entries[0].Path != "a.txt" || !entries[0].WasNew { + t.Fatalf("unexpected entry: %#v", entries[0]) + } + }) + + t.Run("toolResultFilePath trims metadata", func(t *testing.T) { + if got := toolResultFilePath(map[string]any{"path": " demo.txt "}); got != "demo.txt" { + t.Fatalf("toolResultFilePath() = %q, want demo.txt", got) + } + if got := toolResultFilePath(nil); got != "" { + t.Fatalf("toolResultFilePath(nil) = %q, want empty", got) + } + }) + + t.Run("resolveWorkdirPaths normalizes relative and absolute values", func(t *testing.T) { + paths := resolveWorkdirPaths("/repo", " a.txt ", "/tmp/demo.txt", "") + if len(paths) != 2 || paths[0] != "/repo/a.txt" || paths[1] != "/tmp/demo.txt" { + t.Fatalf("resolveWorkdirPaths() = %#v", paths) + } + }) + + t.Run("bashCommandFromCall prefers command then cmd alias", func(t *testing.T) { + if got := bashCommandFromCall(providertypes.ToolCall{Arguments: `{"command":" echo hi "}`}); got != "echo hi" { + t.Fatalf("command field = %q", got) + } + if got := bashCommandFromCall(providertypes.ToolCall{Arguments: `{"cmd":" pwd "}`}); got != "pwd" { + t.Fatalf("cmd alias = %q", got) + } + if got := bashCommandFromCall(providertypes.ToolCall{Arguments: `{invalid`}); got != "" { + t.Fatalf("invalid json should return empty command, got %q", got) + } + }) + + t.Run("collectUncoveredBashPaths removes covered and duplicate entries", func(t *testing.T) { + diff := checkpoint.FingerprintDiff{ + Added: []string{"new.txt", "new.txt"}, + Modified: []string{"tracked.txt", "covered.txt"}, + } + covered := map[string]struct{}{ + "/repo/covered.txt": {}, + } + got := collectUncoveredBashPaths("/repo", diff, covered) + if len(got) != 2 || got[0] != "tracked.txt" || got[1] != "new.txt" { + t.Fatalf("collectUncoveredBashPaths() = %#v", got) + } + }) +} + +func TestEmitHelpersPublishExpectedEvents(t *testing.T) { + service := &Service{events: make(chan RuntimeEvent, 8)} + 
state := &runState{ + runID: "run-1", + session: agentsession.Session{ID: "session-1"}, + } + + service.emitBashSideEffectEvent( + context.Background(), + state, + providertypes.ToolCall{ID: "tool-1"}, + "touch x", + checkpoint.FingerprintDiff{ + Added: []string{"new.txt"}, + Modified: []string{"edit.txt"}, + Deleted: []string{"old.txt"}, + }, + []string{"/repo/edit.txt"}, + []string{"new.txt"}, + ) + + evt := <-service.events + if evt.Type != EventBashSideEffect { + t.Fatalf("event type = %q, want %q", evt.Type, EventBashSideEffect) + } + payload, ok := evt.Payload.(BashSideEffectPayload) + if !ok { + t.Fatalf("payload type = %T", evt.Payload) + } + if len(payload.Changes) != 3 || payload.UncoveredPaths[0] != "new.txt" { + t.Fatalf("unexpected bash payload: %#v", payload) + } + + service.emitBashSideEffectEvent( + context.Background(), + state, + providertypes.ToolCall{ID: "tool-2"}, + "touch noop", + checkpoint.FingerprintDiff{}, + nil, + nil, + ) + select { + case extra := <-service.events: + t.Fatalf("unexpected empty bash side effect event: %#v", extra) + default: + } + + service.emitToolDiffs(context.Background(), state, toolExecutionSummary{ + Results: []tools.ToolResult{ + { + Name: tools.ToolNameFilesystemWriteFile, + Facts: tools.ToolExecutionFacts{ + WorkspaceWrite: true, + }, + Metadata: map[string]any{ + "path": "main.go", + "tool_diff": "@@ -1 +1 @@", + "tool_diff_new": true, + }, + }, + { + Name: tools.ToolNameFilesystemWriteFile, + Facts: tools.ToolExecutionFacts{ + WorkspaceWrite: true, + }, + Metadata: map[string]any{ + "path": "noop.go", + "noop_write": true, + }, + }, + }, + }) + + evt = <-service.events + if evt.Type != EventToolDiff { + t.Fatalf("event type = %q, want %q", evt.Type, EventToolDiff) + } + diffPayload, ok := evt.Payload.(ToolDiffPayload) + if !ok { + t.Fatalf("diff payload type = %T", evt.Payload) + } + if diffPayload.FilePath != "main.go" || !diffPayload.WasNew { + t.Fatalf("unexpected tool diff payload: %#v", diffPayload) + } + 
select { + case extra := <-service.events: + t.Fatalf("unexpected extra event: %#v", extra) + default: + } +} diff --git a/internal/runtime/toolexec.go b/internal/runtime/toolexec.go index e89b5069..af366d46 100644 --- a/internal/runtime/toolexec.go +++ b/internal/runtime/toolexec.go @@ -2,10 +2,14 @@ package runtime import ( "context" + "encoding/json" "errors" + "os" + "path/filepath" "strings" "sync" + "neo-code/internal/checkpoint" providertypes "neo-code/internal/provider/types" runtimefacts "neo-code/internal/runtime/facts" runtimehooks "neo-code/internal/runtime/hooks" @@ -148,6 +152,55 @@ func (s *Service) executeOneToolCall( s.emitRunScoped(ctx, EventToolStart, state, call) + isWrite := isFileWriteTool(call.Name) + isBash := strings.EqualFold(strings.TrimSpace(call.Name), tools.ToolNameBash) + + var preSnaps map[string]fileSnapshot + var preFingerprint checkpoint.WorkdirFingerprint + var bashCapturedPaths []string + var bashCommand string + var touchedPaths []string + var removeDirNestedPaths []string + + if isWrite { + touchedPaths = toolCallTouchedPaths(call, snapshot.Workdir) + if len(touchedPaths) > 0 { + preSnaps = make(map[string]fileSnapshot, len(touchedPaths)) + for _, p := range touchedPaths { + preSnaps[p] = captureFileSnapshot(p) + if s.perEditStore != nil { + _, _ = s.perEditStore.CapturePreWrite(p) + } + // remove_dir: recursively pre-capture all nested files/dirs. 
+ if strings.EqualFold(strings.TrimSpace(call.Name), tools.ToolNameFilesystemRemoveDir) { + if info, err := os.Stat(p); err == nil && info.IsDir() { + _ = filepath.WalkDir(p, func(path string, d os.DirEntry, err error) error { + if err != nil || path == p { + return nil + } + removeDirNestedPaths = append(removeDirNestedPaths, path) + if s.perEditStore != nil { + _, _ = s.perEditStore.CapturePreWrite(path) + } + return nil + }) + } + } + } + } + } else if isBash && s.perEditStore != nil { + bashCommand = bashCommandFromCall(call) + if checkpoint.BashLikelyWritesFiles(bashCommand) { + bashCapturedPaths = checkpoint.SourceFilesInWorkdir(bashCommand, snapshot.Workdir) + if len(bashCapturedPaths) > 0 { + _, _ = s.perEditStore.CaptureBatch(bashCapturedPaths) + } + if fp, _, err := checkpoint.ScanWorkdir(ctx, snapshot.Workdir, checkpoint.DefaultFingerprintOptions()); err == nil { + preFingerprint = fp + } + } + } + result, execErr := s.executeToolCallWithPermission(ctx, permissionExecutionInput{ RunID: state.runID, SessionID: state.session.ID, @@ -160,6 +213,61 @@ func (s *Service) executeOneToolCall( ToolTimeout: snapshot.ToolTimeout, }) + if isWrite && len(preSnaps) > 0 && execErr == nil && !result.IsError { + if result.Metadata == nil { + result.Metadata = map[string]any{} + } + diffs := make([]map[string]any, 0, len(preSnaps)) + for path, snap := range preSnaps { + diff, err := snap.Diff() + if err != nil { + continue + } + diffs = append(diffs, map[string]any{ + "path": path, + "diff": diff, + "was_new": snap.WasNew(), + }) + } + if len(diffs) > 0 { + result.Metadata["tool_diffs"] = diffs + if len(diffs) == 1 { + result.Metadata["tool_diff"] = diffs[0]["diff"] + result.Metadata["tool_diff_new"] = diffs[0]["was_new"] + } + } + } + + if isWrite && execErr == nil && !result.IsError && s.perEditStore != nil { + switch strings.TrimSpace(call.Name) { + case tools.ToolNameFilesystemRemoveDir: + if len(removeDirNestedPaths) > 0 && len(touchedPaths) > 0 { + allPaths := 
append([]string{touchedPaths[0]}, removeDirNestedPaths...) + _ = s.perEditStore.CapturePostDelete(allPaths) + } else if len(touchedPaths) > 0 { + _ = s.perEditStore.CapturePostDelete(touchedPaths) + } + case tools.ToolNameFilesystemDeleteFile: + if len(touchedPaths) > 0 { + _ = s.perEditStore.CapturePostDelete(touchedPaths) + } + } + } + + if isBash && preFingerprint != nil && execErr == nil && !result.IsError { + if afterFP, _, err := checkpoint.ScanWorkdir(ctx, snapshot.Workdir, checkpoint.DefaultFingerprintOptions()); err == nil { + fpDiff := checkpoint.DiffFingerprints(preFingerprint, afterFP) + if len(fpDiff.Added) > 0 || len(fpDiff.Modified) > 0 || len(fpDiff.Deleted) > 0 { + covered := make(map[string]struct{}, len(bashCapturedPaths)) + for _, p := range bashCapturedPaths { + covered[filepath.Clean(p)] = struct{}{} + } + uncovered := collectUncoveredBashPaths(snapshot.Workdir, fpDiff, covered) + s.emitBashSideEffectEvent(ctx, state, call, bashCommand, fpDiff, bashCapturedPaths, uncovered) + } + } + } + if errors.Is(execErr, context.Canceled) { s.emitAfterToolResultHook(ctx, state, call, result, execErr, snapshot.Workdir) s.emitAfterToolFailureHook(ctx, state, call, result, execErr, snapshot.Workdir) @@ -202,6 +310,10 @@ func (s *Service) executeOneToolCall( state.mu.Unlock() } + if isBash && execErr == nil && !result.IsError && len(bashCapturedPaths) > 0 { + result.Facts.WorkspaceWrite = true + } + if checkContext() { return result, hasSuccessfulWorkspaceWriteFact(result, execErr), ctx.Err() } @@ -363,6 +475,178 @@ func toolResultNoopWrite(metadata map[string]any) bool { } } +// toolResultFilePath 从工具结果 metadata 中取文件路径。 +func toolResultFilePath(metadata map[string]any) string { + if metadata == nil { + return "" + } + p, _ := metadata["path"].(string) + return strings.TrimSpace(p) +} + +// isFileWriteTool 判断工具调用是否为文件写入类工具,需在执行前后做 diff。 +func isFileWriteTool(name string) bool { + switch strings.TrimSpace(name) { + case tools.ToolNameFilesystemWriteFile, + 
tools.ToolNameFilesystemEdit, + tools.ToolNameFilesystemMoveFile, + tools.ToolNameFilesystemCopyFile, + tools.ToolNameFilesystemDeleteFile, + tools.ToolNameFilesystemCreateDir, + tools.ToolNameFilesystemRemoveDir: + return true + } + return false +} + +// toolCallTouchedPaths 从工具调用参数中提取所有可能被修改的工作区绝对路径。 +// move/copy 同时返回 source 与 destination;其他写工具返回单个 path。 +func toolCallTouchedPaths(call providertypes.ToolCall, workdir string) []string { + args := strings.TrimSpace(call.Arguments) + if args == "" { + return nil + } + switch strings.TrimSpace(call.Name) { + case tools.ToolNameFilesystemMoveFile, tools.ToolNameFilesystemCopyFile: + var parsed struct { + SourcePath string `json:"source_path"` + DestinationPath string `json:"destination_path"` + } + if err := json.Unmarshal([]byte(args), &parsed); err != nil { + return nil + } + return resolveWorkdirPaths(workdir, parsed.SourcePath, parsed.DestinationPath) + default: + var parsed struct { + Path string `json:"path"` + } + if err := json.Unmarshal([]byte(args), &parsed); err != nil { + return nil + } + return resolveWorkdirPaths(workdir, parsed.Path) + } +} + +// resolveWorkdirPaths 将多个相对/绝对路径解析为工作区绝对路径,丢弃空字符串。 +func resolveWorkdirPaths(workdir string, raw ...string) []string { + out := make([]string, 0, len(raw)) + for _, p := range raw { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if filepath.IsAbs(p) { + out = append(out, filepath.Clean(p)) + continue + } + wd := strings.TrimSpace(workdir) + if wd == "" { + out = append(out, filepath.Clean(p)) + continue + } + out = append(out, filepath.Clean(filepath.Join(wd, p))) + } + if len(out) == 0 { + return nil + } + return out +} + +// bashCommandFromCall 从 bash 工具调用参数解析 command 字段,兼容 cmd 别名。 +func bashCommandFromCall(call providertypes.ToolCall) string { + args := strings.TrimSpace(call.Arguments) + if args == "" { + return "" + } + var parsed struct { + Command string `json:"command"` + Cmd string `json:"cmd"` + } + if err := json.Unmarshal([]byte(args), 
&parsed); err != nil { + return "" + } + if c := strings.TrimSpace(parsed.Command); c != "" { + return c + } + return strings.TrimSpace(parsed.Cmd) +} + +// collectUncoveredBashPaths 把 fingerprint 检测到的变更路径与启发式预捕获集合做差, +// 输出 EventBashSideEffect.UncoveredPaths 用于可观测性提醒。 +func collectUncoveredBashPaths(workdir string, fpDiff checkpoint.FingerprintDiff, covered map[string]struct{}) []string { + if len(fpDiff.Added) == 0 && len(fpDiff.Modified) == 0 { + return nil + } + wd := strings.TrimSpace(workdir) + seen := make(map[string]struct{}) + out := make([]string, 0) + check := func(rel string) { + rel = strings.TrimSpace(rel) + if rel == "" { + return + } + var abs string + if filepath.IsAbs(rel) { + abs = filepath.Clean(rel) + } else if wd != "" { + abs = filepath.Clean(filepath.Join(wd, rel)) + } else { + abs = filepath.Clean(rel) + } + if _, ok := covered[abs]; ok { + return + } + if _, dup := seen[rel]; dup { + return + } + seen[rel] = struct{}{} + out = append(out, rel) + } + for _, p := range fpDiff.Modified { + check(p) + } + for _, p := range fpDiff.Added { + check(p) + } + if len(out) == 0 { + return nil + } + return out +} + +// emitBashSideEffectEvent 派发 EventBashSideEffect,将 fingerprint 变化分类成 added/modified/deleted。 +func (s *Service) emitBashSideEffectEvent( + ctx context.Context, + state *runState, + call providertypes.ToolCall, + command string, + fpDiff checkpoint.FingerprintDiff, + preCaptured []string, + uncovered []string, +) { + changes := make([]FileChange, 0, len(fpDiff.Added)+len(fpDiff.Modified)+len(fpDiff.Deleted)) + for _, p := range fpDiff.Added { + changes = append(changes, FileChange{Path: p, Kind: "added"}) + } + for _, p := range fpDiff.Modified { + changes = append(changes, FileChange{Path: p, Kind: "modified"}) + } + for _, p := range fpDiff.Deleted { + changes = append(changes, FileChange{Path: p, Kind: "deleted"}) + } + if len(changes) == 0 { + return + } + payload := BashSideEffectPayload{ + ToolCallID: strings.TrimSpace(call.ID), + 
Command: strings.TrimSpace(command), + Changes: changes, + PreemptivelyCapturedPaths: preCaptured, + UncoveredPaths: uncovered, + } + s.emitRunScoped(ctx, EventBashSideEffect, state, payload) +} + func summarizeHookResultContent(content string) string { trimmed := strings.TrimSpace(content) if len(trimmed) <= 256 { diff --git a/internal/session/checkpoint_types.go b/internal/session/checkpoint_types.go new file mode 100644 index 00000000..d21ac79b --- /dev/null +++ b/internal/session/checkpoint_types.go @@ -0,0 +1,66 @@ +package session + +import "time" + +// CheckpointReason 描述 checkpoint 的创建原因。 +type CheckpointReason string + +const ( + CheckpointReasonPreWrite CheckpointReason = "pre_write" + CheckpointReasonCompact CheckpointReason = "compact" + CheckpointReasonPlanMode CheckpointReason = "plan_mode" + CheckpointReasonManual CheckpointReason = "manual" + CheckpointReasonGuard CheckpointReason = "pre_restore_guard" + CheckpointReasonPreWriteDegraded CheckpointReason = "pre_write_degraded" + CheckpointReasonEndOfTurn CheckpointReason = "end_of_turn" +) + +// CheckpointStatus 描述 checkpoint 的生命周期状态。 +type CheckpointStatus string + +const ( + CheckpointStatusCreating CheckpointStatus = "creating" + CheckpointStatusAvailable CheckpointStatus = "available" + CheckpointStatusBroken CheckpointStatus = "broken" + CheckpointStatusRestored CheckpointStatus = "restored" + CheckpointStatusPruned CheckpointStatus = "pruned" +) + +// CheckpointRecord 遵循 RFC 6.2 定义,包含所有恢复相关引用。 +type CheckpointRecord struct { + CheckpointID string + WorkspaceKey string + SessionID string + RunID string + Workdir string + CreatedAt time.Time + Reason CheckpointReason + CodeCheckpointRef string + SessionCheckpointRef string + ResumeCheckpointRef string + TranscriptRevision int64 + Restorable bool + Status CheckpointStatus +} + +// SessionCheckpoint 保存完整 durable 会话上下文快照。 +type SessionCheckpoint struct { + ID string + SessionID string + HeadJSON string + MessagesJSON string + CreatedAt time.Time +} 
+ +// ResumeCheckpoint 记录运行恢复所需的最小闭环信息。 +type ResumeCheckpoint struct { + ID string + WorkspaceKey string + RunID string + SessionID string + Turn int + Phase string + CompletionState string + TranscriptRevision int64 + UpdatedAt time.Time +} diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index d35e18dc..3b79789c 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -94,6 +94,20 @@ func (s *SQLiteStore) Close() error { return s.db.Close() } +// DB 返回底层 *sql.DB 连接,供需要共享同一数据库连接的组件使用。 +// 调用前必须已触发过 ensureDB(如通过任何读写操作),否则返回 nil。 +func (s *SQLiteStore) DB() *sql.DB { + if s == nil { + return nil + } + return s.db +} + +// InitDB 显式触发数据库连接初始化。冷启动时确保 DB 可用,供需要共享连接的组件使用。 +func (s *SQLiteStore) InitDB(ctx context.Context) (*sql.DB, error) { + return s.ensureDB(ctx) +} + // CleanupExpiredSessions 删除超过指定时长未更新的会话及其附件,返回删除数量。 func (s *SQLiteStore) CleanupExpiredSessions(ctx context.Context, maxAge time.Duration) (int, error) { if err := ctx.Err(); err != nil { @@ -891,6 +905,9 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV4ToV5(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } case 2: if err := migrateSQLiteSchemaV2ToV3(ctx, db); err != nil { return err @@ -901,6 +918,9 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV4ToV5(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } case 3: if err := migrateSQLiteSchemaV3ToV4(ctx, db); err != nil { return err @@ -908,10 +928,20 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV4ToV5(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } case 4: if err := migrateSQLiteSchemaV4ToV5(ctx, db); err != nil { return err } + if err := 
migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } + case 5: + if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } default: return fmt.Errorf("session: unsupported sqlite schema version %d", userVersion) } @@ -972,6 +1002,45 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { `CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at_ms DESC)`, `CREATE INDEX IF NOT EXISTS idx_messages_session_seq_desc ON messages(session_id, seq DESC)`, `CREATE INDEX IF NOT EXISTS idx_assets_session_id ON session_assets(session_id)`, + `CREATE TABLE IF NOT EXISTS checkpoint_records ( + id TEXT PRIMARY KEY, + workspace_key TEXT NOT NULL, + session_id TEXT NOT NULL, + run_id TEXT NOT NULL DEFAULT '', + workdir TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + reason TEXT NOT NULL DEFAULT 'pre_write', + code_checkpoint_ref TEXT NOT NULL DEFAULT '', + session_checkpoint_ref TEXT NOT NULL DEFAULT '', + resume_checkpoint_ref TEXT NOT NULL DEFAULT '', + transcript_revision INTEGER NOT NULL DEFAULT 0, + restorable INTEGER NOT NULL DEFAULT 1, + status TEXT NOT NULL DEFAULT 'creating', + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS session_checkpoints ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + head_json TEXT NOT NULL DEFAULT '', + messages_json TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS resume_checkpoints ( + id TEXT PRIMARY KEY, + workspace_key TEXT NOT NULL, + run_id TEXT NOT NULL DEFAULT '', + session_id TEXT NOT NULL, + turn INTEGER NOT NULL DEFAULT 0, + phase TEXT NOT NULL DEFAULT '', + completion_state TEXT NOT NULL DEFAULT '', + transcript_revision INTEGER NOT NULL DEFAULT 0, + updated_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE INDEX IF NOT EXISTS 
idx_checkpoint_session_created ON checkpoint_records(session_id, created_at_ms DESC)`, + `CREATE INDEX IF NOT EXISTS idx_session_checkpoints_session ON session_checkpoints(session_id, created_at_ms DESC)`, + `CREATE INDEX IF NOT EXISTS idx_resume_checkpoints_session ON resume_checkpoints(session_id, updated_at_ms DESC)`, fmt.Sprintf(`PRAGMA user_version=%d`, sqliteSchemaVersion), } for _, statement := range statements { @@ -1014,6 +1083,70 @@ func migrateSQLiteSchemaV1ToV2(ctx context.Context, db *sql.DB) error { return nil } +// migrateSQLiteSchemaV5ToV6 将 v5 会话库升级到 v6 schema,新增 checkpoint 相关表。 +func migrateSQLiteSchemaV5ToV6(ctx context.Context, db *sql.DB) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("session: begin schema migration tx: %w", err) + } + defer rollbackTx(tx) + + stmts := []string{ + `CREATE TABLE IF NOT EXISTS checkpoint_records ( + id TEXT PRIMARY KEY, + workspace_key TEXT NOT NULL, + session_id TEXT NOT NULL, + run_id TEXT NOT NULL DEFAULT '', + workdir TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + reason TEXT NOT NULL DEFAULT 'pre_write', + code_checkpoint_ref TEXT NOT NULL DEFAULT '', + session_checkpoint_ref TEXT NOT NULL DEFAULT '', + resume_checkpoint_ref TEXT NOT NULL DEFAULT '', + transcript_revision INTEGER NOT NULL DEFAULT 0, + restorable INTEGER NOT NULL DEFAULT 1, + status TEXT NOT NULL DEFAULT 'creating', + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS session_checkpoints ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + head_json TEXT NOT NULL DEFAULT '', + messages_json TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS resume_checkpoints ( + id TEXT PRIMARY KEY, + workspace_key TEXT NOT NULL, + run_id TEXT NOT NULL DEFAULT '', + session_id TEXT NOT NULL, + turn INTEGER NOT NULL DEFAULT 0, + phase TEXT NOT NULL 
DEFAULT '', + completion_state TEXT NOT NULL DEFAULT '', + transcript_revision INTEGER NOT NULL DEFAULT 0, + updated_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE INDEX IF NOT EXISTS idx_checkpoint_session_created ON checkpoint_records(session_id, created_at_ms DESC)`, + `CREATE INDEX IF NOT EXISTS idx_session_checkpoints_session ON session_checkpoints(session_id, created_at_ms DESC)`, + `CREATE INDEX IF NOT EXISTS idx_resume_checkpoints_session ON resume_checkpoints(session_id, updated_at_ms DESC)`, + } + for _, stmt := range stmts { + if _, err := tx.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("session: migrate sqlite schema v5 to v6: %w", err) + } + } + + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`PRAGMA user_version=%d`, sqliteSchemaVersion)); err != nil { + return fmt.Errorf("session: set migrated sqlite schema version: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("session: commit schema migration tx: %w", err) + } + return nil +} + // sqliteTableHasColumn 检查指定表是否包含字段,供明确版本迁移保持幂等。 // migrateSQLiteSchemaV2ToV3 将 v2 会话库升级到 v3 schema,补齐 plan/build 所需字段。 func migrateSQLiteSchemaV2ToV3(ctx context.Context, db *sql.DB) error { diff --git a/internal/session/store.go b/internal/session/store.go index b6e5155d..6ec1ac51 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -14,7 +14,7 @@ import ( const ( sessionDatabaseFileName = "session.db" assetsDirName = "assets" - sqliteSchemaVersion = 5 + sqliteSchemaVersion = 6 // MaxSessionMessages 定义单个会话允许持久化的最大消息数,超出时自动裁剪最旧消息。 MaxSessionMessages = 8192 @@ -155,7 +155,7 @@ func NewSQLiteStore(baseDir string, workspaceRoot string) *SQLiteStore { return &SQLiteStore{ projectDir: projectDirectory(baseDir, workspaceRoot), assetsDir: assetsDirectory(baseDir, workspaceRoot), - dbPath: databasePath(baseDir, workspaceRoot), + dbPath: DatabasePath(baseDir, workspaceRoot), assetPolicy: DefaultAssetPolicy(), } } 
diff --git a/internal/session/store_test.go b/internal/session/store_test.go index dd634a25..00228be2 100644 --- a/internal/session/store_test.go +++ b/internal/session/store_test.go @@ -476,7 +476,7 @@ func TestSQLiteStoreInitializationRejectsUnsupportedSchemaVersion(t *testing.T) if err := os.MkdirAll(projectDir, 0o755); err != nil { t.Fatalf("MkdirAll(projectDir) error = %v", err) } - db, err := sql.Open("sqlite", databasePath(baseDir, workspaceRoot)) + db, err := sql.Open("sqlite", DatabasePath(baseDir, workspaceRoot)) if err != nil { t.Fatalf("sql.Open() error = %v", err) } @@ -715,7 +715,7 @@ func createLegacyV1SessionDB( if err := os.MkdirAll(projectDir, 0o755); err != nil { t.Fatalf("MkdirAll(projectDir) error = %v", err) } - db, err := sql.Open("sqlite", databasePath(baseDir, workspaceRoot)) + db, err := sql.Open("sqlite", DatabasePath(baseDir, workspaceRoot)) if err != nil { t.Fatalf("sql.Open() error = %v", err) } @@ -792,7 +792,7 @@ func createLegacyV2SessionDB(t *testing.T, ctx context.Context, baseDir string, if err := os.MkdirAll(projectDir, 0o755); err != nil { t.Fatalf("MkdirAll(projectDir) error = %v", err) } - db, err := sql.Open("sqlite", databasePath(baseDir, workspaceRoot)) + db, err := sql.Open("sqlite", DatabasePath(baseDir, workspaceRoot)) if err != nil { t.Fatalf("sql.Open() error = %v", err) } diff --git a/internal/session/workspace.go b/internal/session/workspace.go index d16e0278..cfe99170 100644 --- a/internal/session/workspace.go +++ b/internal/session/workspace.go @@ -17,8 +17,8 @@ func projectDirectory(baseDir string, workspaceRoot string) string { return filepath.Join(baseDir, projectsDirName, HashWorkspaceRoot(workspaceRoot)) } -// databasePath 返回当前工作区级 SQLite 数据库文件路径。 -func databasePath(baseDir string, workspaceRoot string) string { +// DatabasePath 返回当前工作区级 SQLite 数据库文件路径。 +func DatabasePath(baseDir string, workspaceRoot string) string { return filepath.Join(projectDirectory(baseDir, workspaceRoot), sessionDatabaseFileName) } 
diff --git a/internal/tools/filesystem/copy_file.go b/internal/tools/filesystem/copy_file.go new file mode 100644 index 00000000..c03b6111 --- /dev/null +++ b/internal/tools/filesystem/copy_file.go @@ -0,0 +1,126 @@ +package filesystem + +import ( + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + + "neo-code/internal/tools" +) + +type CopyFileTool struct { + root string +} + +type copyFileInput struct { + SourcePath string `json:"source_path"` + DestinationPath string `json:"destination_path"` + Overwrite bool `json:"overwrite,omitempty"` +} + +func NewCopy(root string) *CopyFileTool { + return &CopyFileTool{root: root} +} + +func (t *CopyFileTool) Name() string { + return copyFileToolName +} + +func (t *CopyFileTool) Description() string { + return "Copy a file inside the workspace. Both paths must resolve inside the workspace." +} + +func (t *CopyFileTool) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "source_path": map[string]any{ + "type": "string", + "description": "Existing file path to copy, relative to workspace root or absolute inside workspace.", + }, + "destination_path": map[string]any{ + "type": "string", + "description": "Destination file path, relative to workspace root or absolute inside workspace.", + }, + "overwrite": map[string]any{ + "type": "boolean", + "description": "When true, replace destination if it already exists. 
Defaults to false.", + }, + }, + "required": []string{"source_path", "destination_path"}, + } +} + +func (t *CopyFileTool) MicroCompactPolicy() tools.MicroCompactPolicy { + return tools.MicroCompactPolicyCompact +} + +func (t *CopyFileTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + var args copyFileInput + if err := json.Unmarshal(input.Arguments, &args); err != nil { + return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err + } + if strings.TrimSpace(args.SourcePath) == "" { + err := errors.New(copyFileToolName + ": source_path is required") + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + if strings.TrimSpace(args.DestinationPath) == "" { + err := errors.New(copyFileToolName + ": destination_path is required") + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + if err := ctx.Err(); err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + base := effectiveRoot(t.root, input.Workdir) + + src, err := resolvePath(base, args.SourcePath) + if err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + dst, err := resolvePath(base, args.DestinationPath) + if err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + srcInfo, statErr := os.Stat(src) + if statErr != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr + } + if srcInfo.IsDir() { + err := errors.New(copyFileToolName + ": source_path must be a file, not a directory") + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + if _, err := os.Stat(dst); err == nil { + if !args.Overwrite { + err := errors.New(copyFileToolName + ": destination_path already exists; 
pass overwrite=true to replace it") + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + } else if !os.IsNotExist(err) { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + if err := copyFileContents(src, dst, srcInfo.Mode().Perm()); err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + return tools.ToolResult{ + Name: t.Name(), + Content: "ok", + Metadata: map[string]any{ + "source_path": src, + "destination_path": dst, + "paths": []string{dst}, + "bytes": srcInfo.Size(), + "overwrite": args.Overwrite, + }, + Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, + }, nil +} diff --git a/internal/tools/filesystem/copy_file_test.go b/internal/tools/filesystem/copy_file_test.go new file mode 100644 index 00000000..bcada3b7 --- /dev/null +++ b/internal/tools/filesystem/copy_file_test.go @@ -0,0 +1,199 @@ +package filesystem + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "neo-code/internal/tools" +) + +func TestCopyFileTool_DuplicatesContent(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + src := filepath.Join(workspace, "a.go") + if err := os.WriteFile(src, []byte("package main"), 0o644); err != nil { + t.Fatalf("seed: %v", err) + } + tool := NewCopy(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "a.go", + "destination_path": filepath.Join("nested", "b.go"), + }) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err != nil { + t.Fatalf("execute: %v", err) + } + if result.IsError { + t.Fatalf("error result: %s", result.Content) + } + if !result.Facts.WorkspaceWrite { + 
t.Fatalf("expected WorkspaceWrite=true") + } + srcData, _ := os.ReadFile(src) + if string(srcData) != "package main" { + t.Fatalf("source modified: %q", string(srcData)) + } + dst := filepath.Join(workspace, "nested", "b.go") + dstData, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("read dst: %v", err) + } + if string(dstData) != "package main" { + t.Fatalf("dst content = %q", string(dstData)) + } + paths, ok := result.Metadata["paths"].([]string) + if !ok || len(paths) != 1 { + t.Fatalf("paths metadata = %#v want 1-item slice", result.Metadata["paths"]) + } +} + +func TestCopyFileTool_RefusesOverwriteByDefault(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + src := filepath.Join(workspace, "src.txt") + dst := filepath.Join(workspace, "dst.txt") + if err := os.WriteFile(src, []byte("a"), 0o644); err != nil { + t.Fatalf("seed src: %v", err) + } + if err := os.WriteFile(dst, []byte("b"), 0o644); err != nil { + t.Fatalf("seed dst: %v", err) + } + tool := NewCopy(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "src.txt", + "destination_path": "dst.txt", + }) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "already exists") { + t.Fatalf("expected exists error, got %v", err) + } + if data, _ := os.ReadFile(dst); string(data) != "b" { + t.Fatalf("dst was clobbered: %q", string(data)) + } +} + +func TestCopyFileTool_OverwriteAllowed(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + src := filepath.Join(workspace, "src.txt") + dst := filepath.Join(workspace, "dst.txt") + if err := os.WriteFile(src, []byte("new"), 0o644); err != nil { + t.Fatalf("seed src: %v", err) + } + if err := os.WriteFile(dst, []byte("old"), 0o644); err != nil { + t.Fatalf("seed dst: %v", err) + } + tool := NewCopy(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "src.txt", + "destination_path": 
"dst.txt", + "overwrite": true, + }) + if _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }); err != nil { + t.Fatalf("execute: %v", err) + } + if data, _ := os.ReadFile(dst); string(data) != "new" { + t.Fatalf("dst content = %q want new", string(data)) + } + if data, _ := os.ReadFile(src); string(data) != "new" { + t.Fatalf("src removed unexpectedly: %q", string(data)) + } +} + +func TestCopyFileTool_RejectsTraversal(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + src := filepath.Join(workspace, "src.txt") + if err := os.WriteFile(src, []byte("x"), 0o644); err != nil { + t.Fatalf("seed: %v", err) + } + tool := NewCopy(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "src.txt", + "destination_path": filepath.Join("..", "escape.txt"), + }) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "escapes workspace") { + t.Fatalf("expected escape error, got %v", err) + } +} + +func TestCopyFileTool_InvalidJSON(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewCopy(workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: []byte(`{invalid`), + Workdir: workspace, + }) + if err == nil { + t.Fatalf("expected json error") + } + if !result.IsError { + t.Fatalf("expected error result") + } +} + +func TestCopyFileTool_RejectsDirectorySource(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + sourceDir := filepath.Join(workspace, "srcdir") + if err := os.MkdirAll(sourceDir, 0o755); err != nil { + t.Fatalf("seed dir: %v", err) + } + tool := NewCopy(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "srcdir", + "destination_path": "copy.txt", + }) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: 
tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "must be a file") { + t.Fatalf("expected directory source error, got %v", err) + } +} + +func TestCopyFileTool_RejectsCanceledContext(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewCopy(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "src.txt", + "destination_path": "dst.txt", + }) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := tool.Execute(ctx, tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) { + t.Fatalf("expected canceled error, got %v", err) + } +} diff --git a/internal/tools/filesystem/create_dir.go b/internal/tools/filesystem/create_dir.go new file mode 100644 index 00000000..538ebb03 --- /dev/null +++ b/internal/tools/filesystem/create_dir.go @@ -0,0 +1,121 @@ +package filesystem + +import ( + "context" + "encoding/json" + "errors" + "os" + "strings" + + "neo-code/internal/security" + "neo-code/internal/tools" +) + +type CreateDirTool struct { + root string +} + +type createDirInput struct { + Path string `json:"path"` + Recursive *bool `json:"recursive,omitempty"` +} + +func NewCreateDir(root string) *CreateDirTool { + return &CreateDirTool{root: root} +} + +func (t *CreateDirTool) Name() string { + return createDirToolName +} + +func (t *CreateDirTool) Description() string { + return "Create a directory inside the workspace. Recursive by default." 
+} + +func (t *CreateDirTool) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Directory path relative to workspace root, or absolute inside the workspace.", + }, + "recursive": map[string]any{ + "type": "boolean", + "description": "When true (default), create parent directories as needed; when false, fail if the parent is missing.", + }, + }, + "required": []string{"path"}, + } +} + +func (t *CreateDirTool) MicroCompactPolicy() tools.MicroCompactPolicy { + return tools.MicroCompactPolicyCompact +} + +func (t *CreateDirTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + var args createDirInput + if err := json.Unmarshal(input.Arguments, &args); err != nil { + return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err + } + if strings.TrimSpace(args.Path) == "" { + err := errors.New(createDirToolName + ": path is required") + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + if err := ctx.Err(); err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + recursive := true + if args.Recursive != nil { + recursive = *args.Recursive + } + + base := effectiveRoot(t.root, input.Workdir) + + _, target, err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.Path, resolvePath) + if err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + if info, statErr := os.Stat(target); statErr == nil { + if !info.IsDir() { + err := errors.New(createDirToolName + ": path exists and is not a directory") + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + return tools.ToolResult{ + Name: t.Name(), + Content: "ok", + Metadata: map[string]any{ + "path": target, + 
"created": false, + "noop_write": true, + "recursive": recursive, + }, + Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, + }, nil + } else if !os.IsNotExist(statErr) { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr + } + + if recursive { + if err := os.MkdirAll(target, 0o755); err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + } else { + if err := os.Mkdir(target, 0o755); err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + } + + return tools.ToolResult{ + Name: t.Name(), + Content: "ok", + Metadata: map[string]any{ + "path": target, + "created": true, + "recursive": recursive, + }, + Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, + }, nil +} diff --git a/internal/tools/filesystem/create_dir_test.go b/internal/tools/filesystem/create_dir_test.go new file mode 100644 index 00000000..a2434d98 --- /dev/null +++ b/internal/tools/filesystem/create_dir_test.go @@ -0,0 +1,136 @@ +package filesystem + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "neo-code/internal/tools" +) + +func TestCreateDirTool_RecursiveByDefault(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewCreateDir(workspace) + args, _ := json.Marshal(map[string]any{ + "path": filepath.Join("a", "b", "c"), + }) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err != nil { + t.Fatalf("execute: %v", err) + } + if result.IsError { + t.Fatalf("error result: %s", result.Content) + } + if !result.Facts.WorkspaceWrite { + t.Fatalf("expected WorkspaceWrite=true") + } + target := filepath.Join(workspace, "a", "b", "c") + if info, err := os.Stat(target); err != nil || !info.IsDir() { + t.Fatalf("dir not created: info=%v err=%v", info, err) + } + if got, _ := 
result.Metadata["created"].(bool); !got { + t.Fatalf("created metadata = %v want true", result.Metadata["created"]) + } +} + +func TestCreateDirTool_NonRecursiveFailsForMissingParent(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewCreateDir(workspace) + args, _ := json.Marshal(map[string]any{ + "path": filepath.Join("missing", "child"), + "recursive": false, + }) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil { + t.Fatalf("expected error for missing parent") + } +} + +func TestCreateDirTool_ExistingDirReturnsNoop(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + dir := filepath.Join(workspace, "existing") + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("seed: %v", err) + } + tool := NewCreateDir(workspace) + args, _ := json.Marshal(map[string]any{"path": "existing"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err != nil { + t.Fatalf("execute: %v", err) + } + if got, _ := result.Metadata["noop_write"].(bool); !got { + t.Fatalf("noop_write metadata = %v", result.Metadata["noop_write"]) + } + if got, _ := result.Metadata["created"].(bool); got { + t.Fatalf("created metadata = %v want false", result.Metadata["created"]) + } +} + +func TestCreateDirTool_RejectsExistingFile(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + target := filepath.Join(workspace, "blocker") + if err := os.WriteFile(target, []byte("file"), 0o644); err != nil { + t.Fatalf("seed: %v", err) + } + tool := NewCreateDir(workspace) + args, _ := json.Marshal(map[string]any{"path": "blocker"}) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "not a directory") { + t.Fatalf("expected file-blocking error, got %v", 
err) + } +} + +func TestCreateDirTool_RejectsTraversal(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewCreateDir(workspace) + args, _ := json.Marshal(map[string]any{"path": filepath.Join("..", "escape")}) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "escapes workspace") { + t.Fatalf("expected escape error, got %v", err) + } +} + +func TestCreateDirTool_InvalidJSON(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewCreateDir(workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: []byte(`{invalid`), + Workdir: workspace, + }) + if err == nil { + t.Fatalf("expected json error") + } + if !result.IsError { + t.Fatalf("expected error result") + } +} diff --git a/internal/tools/filesystem/delete_file.go b/internal/tools/filesystem/delete_file.go new file mode 100644 index 00000000..5415c9ca --- /dev/null +++ b/internal/tools/filesystem/delete_file.go @@ -0,0 +1,106 @@ +package filesystem + +import ( + "context" + "encoding/json" + "errors" + "os" + "strings" + + "neo-code/internal/security" + "neo-code/internal/tools" +) + +type DeleteFileTool struct { + root string +} + +type deleteFileInput struct { + Path string `json:"path"` +} + +func NewDelete(root string) *DeleteFileTool { + return &DeleteFileTool{root: root} +} + +func (t *DeleteFileTool) Name() string { + return deleteFileToolName +} + +func (t *DeleteFileTool) Description() string { + return "Delete a single file inside the workspace. Does not remove directories." 
+} + +func (t *DeleteFileTool) Schema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "File path relative to workspace root, or absolute inside the workspace.", + }, + }, + "required": []string{"path"}, + } +} + +func (t *DeleteFileTool) MicroCompactPolicy() tools.MicroCompactPolicy { + return tools.MicroCompactPolicyCompact +} + +func (t *DeleteFileTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + var args deleteFileInput + if err := json.Unmarshal(input.Arguments, &args); err != nil { + return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err + } + if strings.TrimSpace(args.Path) == "" { + err := errors.New(deleteFileToolName + ": path is required") + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + if err := ctx.Err(); err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + base := effectiveRoot(t.root, input.Workdir) + + _, target, err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.Path, resolvePath) + if err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + info, statErr := os.Stat(target) + if statErr != nil { + if os.IsNotExist(statErr) { + return tools.ToolResult{ + Name: t.Name(), + Content: "ok", + Metadata: map[string]any{ + "path": target, + "deleted": false, + "noop_write": true, + }, + Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, + }, nil + } + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr + } + if info.IsDir() { + err := errors.New(deleteFileToolName + ": path is a directory; use filesystem_remove_dir") + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + 
if err := os.Remove(target); err != nil { + return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err + } + + return tools.ToolResult{ + Name: t.Name(), + Content: "ok", + Metadata: map[string]any{ + "path": target, + "deleted": true, + "bytes": info.Size(), + }, + Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, + }, nil +} diff --git a/internal/tools/filesystem/delete_file_test.go b/internal/tools/filesystem/delete_file_test.go new file mode 100644 index 00000000..6faa0746 --- /dev/null +++ b/internal/tools/filesystem/delete_file_test.go @@ -0,0 +1,133 @@ +package filesystem + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "neo-code/internal/tools" +) + +func TestDeleteFileTool_RemovesExistingFile(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + target := filepath.Join(workspace, "doomed.txt") + if err := os.WriteFile(target, []byte("bye"), 0o644); err != nil { + t.Fatalf("seed: %v", err) + } + tool := NewDelete(workspace) + args, _ := json.Marshal(map[string]any{"path": "doomed.txt"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err != nil { + t.Fatalf("execute: %v", err) + } + if result.IsError { + t.Fatalf("error result: %s", result.Content) + } + if !result.Facts.WorkspaceWrite { + t.Fatalf("expected WorkspaceWrite=true") + } + if _, err := os.Stat(target); !os.IsNotExist(err) { + t.Fatalf("file still exists: %v", err) + } + if got, _ := result.Metadata["deleted"].(bool); !got { + t.Fatalf("deleted metadata = %v want true", result.Metadata["deleted"]) + } +} + +func TestDeleteFileTool_MissingFileReturnsNoop(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewDelete(workspace) + args, _ := json.Marshal(map[string]any{"path": "ghost.txt"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: 
args, + Workdir: workspace, + }) + if err != nil { + t.Fatalf("execute: %v", err) + } + if result.IsError { + t.Fatalf("error result: %s", result.Content) + } + if got, _ := result.Metadata["noop_write"].(bool); !got { + t.Fatalf("noop_write metadata = %v", result.Metadata["noop_write"]) + } + if got, _ := result.Metadata["deleted"].(bool); got { + t.Fatalf("deleted metadata = %v want false", result.Metadata["deleted"]) + } +} + +func TestDeleteFileTool_RejectsDirectory(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + dir := filepath.Join(workspace, "subdir") + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + tool := NewDelete(workspace) + args, _ := json.Marshal(map[string]any{"path": "subdir"}) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "is a directory") { + t.Fatalf("expected directory error, got %v", err) + } +} + +func TestDeleteFileTool_RejectsTraversal(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewDelete(workspace) + args, _ := json.Marshal(map[string]any{"path": filepath.Join("..", "escape.txt")}) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "escapes workspace") { + t.Fatalf("expected escape error, got %v", err) + } +} + +func TestDeleteFileTool_RejectsEmptyPath(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewDelete(workspace) + args, _ := json.Marshal(map[string]any{"path": ""}) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "path is required") { + t.Fatalf("expected path required, got %v", err) + } +} + +func TestDeleteFileTool_InvalidJSON(t 
*testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewDelete(workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: []byte(`{invalid`), + Workdir: workspace, + }) + if err == nil { + t.Fatalf("expected json error") + } + if !result.IsError { + t.Fatalf("expected error result") + } +} diff --git a/internal/tools/filesystem/helpers.go b/internal/tools/filesystem/helpers.go index 929d74e1..4606cb27 100644 --- a/internal/tools/filesystem/helpers.go +++ b/internal/tools/filesystem/helpers.go @@ -9,11 +9,16 @@ import ( ) const ( - readFileToolName = tools.ToolNameFilesystemReadFile - writeFileToolName = tools.ToolNameFilesystemWriteFile - grepToolName = tools.ToolNameFilesystemGrep - globToolName = tools.ToolNameFilesystemGlob - editToolName = tools.ToolNameFilesystemEdit + readFileToolName = tools.ToolNameFilesystemReadFile + writeFileToolName = tools.ToolNameFilesystemWriteFile + grepToolName = tools.ToolNameFilesystemGrep + globToolName = tools.ToolNameFilesystemGlob + editToolName = tools.ToolNameFilesystemEdit + moveFileToolName = tools.ToolNameFilesystemMoveFile + copyFileToolName = tools.ToolNameFilesystemCopyFile + deleteFileToolName = tools.ToolNameFilesystemDeleteFile + createDirToolName = tools.ToolNameFilesystemCreateDir + removeDirToolName = tools.ToolNameFilesystemRemoveDir ) func effectiveRoot(defaultRoot string, workdir string) string { diff --git a/internal/tools/filesystem/helpers_test.go b/internal/tools/filesystem/helpers_test.go new file mode 100644 index 00000000..3fb75c34 --- /dev/null +++ b/internal/tools/filesystem/helpers_test.go @@ -0,0 +1,92 @@ +package filesystem + +import ( + "errors" + "os" + "path/filepath" + "testing" +) + +func TestToRelativePath(t *testing.T) { + t.Parallel() + root := t.TempDir() + inside := filepath.Join(root, "nested", "file.txt") + outside := filepath.Join(filepath.Dir(root), "outside.txt") + + if got := toRelativePath(root, inside); got 
!= filepath.Join("nested", "file.txt") { + t.Fatalf("inside path = %q, want nested/file.txt", got) + } + if got := toRelativePath(root, outside); got != filepath.Join("..", "outside.txt") { + t.Fatalf("outside path = %q, want ../outside.txt", got) + } +} + +func TestSkipDirEntry(t *testing.T) { + t.Parallel() + root := t.TempDir() + mustCreateDir(t, filepath.Join(root, ".git")) + mustCreateDir(t, filepath.Join(root, "node_modules")) + mustCreateDir(t, filepath.Join(root, "keep")) + mustWriteTestFile(t, filepath.Join(root, ".vscode"), "not-a-dir") + + entries, err := os.ReadDir(root) + if err != nil { + t.Fatalf("ReadDir() error = %v", err) + } + + got := map[string]bool{} + for _, entry := range entries { + got[entry.Name()] = skipDirEntry(filepath.Join(root, entry.Name()), entry) + } + + if !got[".git"] { + t.Fatalf(".git skip = false, want true") + } + if !got["node_modules"] { + t.Fatalf("node_modules skip = false, want true") + } + if got["keep"] { + t.Fatalf("keep skip = true, want false") + } + if got[".vscode"] { + t.Fatalf(".vscode file skip = true, want false for non-directory") + } +} + +func TestIsCrossDeviceLinkError(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + err error + want bool + }{ + {name: "nil", err: nil, want: false}, + {name: "other", err: errors.New("permission denied"), want: false}, + {name: "cross-device", err: errors.New("invalid cross-device link"), want: true}, + {name: "exdev", err: errors.New("rename failed: EXDEV"), want: true}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + if got := isCrossDeviceLinkError(tc.err); got != tc.want { + t.Fatalf("isCrossDeviceLinkError(%v) = %v, want %v", tc.err, got, tc.want) + } + }) + } +} + +func mustCreateDir(t *testing.T, path string) { + t.Helper() + if err := os.MkdirAll(path, 0o755); err != nil { + t.Fatalf("MkdirAll(%q) error = %v", path, err) + } +} + +func mustWriteTestFile(t *testing.T, path string, content string) { + 
t.Helper()
+	if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
+		t.Fatalf("WriteFile(%q) error = %v", path, err)
+	}
+}
diff --git a/internal/tools/filesystem/move_file.go b/internal/tools/filesystem/move_file.go
new file mode 100644
index 00000000..c627b85f
--- /dev/null
+++ b/internal/tools/filesystem/move_file.go
@@ -0,0 +1,172 @@
+package filesystem
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"io"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"neo-code/internal/security"
+	"neo-code/internal/tools"
+)
+
+// MoveFileTool moves or renames a single file inside the workspace rooted at root.
+type MoveFileTool struct {
+	root string
+}
+
+// moveFileInput is the JSON argument shape accepted by Execute.
+type moveFileInput struct {
+	SourcePath      string `json:"source_path"`
+	DestinationPath string `json:"destination_path"`
+	Overwrite       bool   `json:"overwrite,omitempty"`
+}
+
+// NewMove returns a MoveFileTool scoped to the given workspace root.
+func NewMove(root string) *MoveFileTool {
+	return &MoveFileTool{root: root}
+}
+
+func (t *MoveFileTool) Name() string {
+	return moveFileToolName
+}
+
+func (t *MoveFileTool) Description() string {
+	return "Move or rename a file inside the workspace. Both paths must resolve inside the workspace."
+}
+
+// Schema describes the tool arguments for the model-facing tool registry.
+func (t *MoveFileTool) Schema() map[string]any {
+	return map[string]any{
+		"type": "object",
+		"properties": map[string]any{
+			"source_path": map[string]any{
+				"type":        "string",
+				"description": "Existing file path to move, relative to workspace root or absolute inside workspace.",
+			},
+			"destination_path": map[string]any{
+				"type":        "string",
+				"description": "New file path, relative to workspace root or absolute inside workspace.",
+			},
+			"overwrite": map[string]any{
+				"type":        "boolean",
+				"description": "When true, replace destination if it already exists. Defaults to false.",
+			},
+		},
+		"required": []string{"source_path", "destination_path"},
+	}
+}
+
+func (t *MoveFileTool) MicroCompactPolicy() tools.MicroCompactPolicy {
+	return tools.MicroCompactPolicyCompact
+}
+
+// Execute validates the arguments, resolves both paths inside the workspace,
+// and moves the source file to the destination. os.Rename is attempted first;
+// on a cross-device (EXDEV) failure it falls back to copy+remove.
+func (t *MoveFileTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) {
+	var args moveFileInput
+	if err := json.Unmarshal(input.Arguments, &args); err != nil {
+		return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err
+	}
+	if strings.TrimSpace(args.SourcePath) == "" {
+		err := errors.New(moveFileToolName + ": source_path is required")
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+	if strings.TrimSpace(args.DestinationPath) == "" {
+		err := errors.New(moveFileToolName + ": destination_path is required")
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+	if err := ctx.Err(); err != nil {
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+
+	base := effectiveRoot(t.root, input.Workdir)
+
+	// Route both paths through the shared workspace-target resolver so move
+	// gets the same security validation as the sibling copy/delete/dir tools
+	// (which all call tools.ResolveWorkspaceTarget) instead of raw resolvePath.
+	_, src, err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.SourcePath, resolvePath)
+	if err != nil {
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+	_, dst, err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.DestinationPath, resolvePath)
+	if err != nil {
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+
+	srcInfo, statErr := os.Stat(src)
+	if statErr != nil {
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr
+	}
+	if srcInfo.IsDir() {
+		err := errors.New(moveFileToolName + ": source_path must be a file, not a directory")
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+
+	// Destination may only be clobbered when the caller opts in.
+	if _, err := os.Stat(dst); err == nil {
+		if !args.Overwrite {
+			err := errors.New(moveFileToolName + ": destination_path already exists; pass overwrite=true to replace it")
+			return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+		}
+	} else if !os.IsNotExist(err) {
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+
+	if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+	if err := os.Rename(src, dst); err != nil {
+		if !isCrossDeviceLinkError(err) {
+			return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+		}
+		// Rename cannot cross filesystems: copy the bytes, then remove source.
+		if copyErr := copyFileContents(src, dst, srcInfo.Mode().Perm()); copyErr != nil {
+			return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), copyErr), "", nil), copyErr
+		}
+		if removeErr := os.Remove(src); removeErr != nil {
+			return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), removeErr), "", nil), removeErr
+		}
+	}
+
+	return tools.ToolResult{
+		Name:    t.Name(),
+		Content: "ok",
+		Metadata: map[string]any{
+			"source_path":      src,
+			"destination_path": dst,
+			"paths":            []string{src, dst},
+			"bytes":            srcInfo.Size(),
+			"overwrite":        args.Overwrite,
+		},
+		Facts: tools.ToolExecutionFacts{WorkspaceWrite: true},
+	}, nil
+}
+
+// copyFileContents copies src to dst (truncating dst) with the given mode and
+// fsyncs the result. The destination Close error is propagated: on some
+// filesystems a write failure only surfaces at close, and ignoring it would
+// silently report a successful move with corrupt/short data.
+func copyFileContents(src, dst string, mode os.FileMode) (err error) {
+	in, err := os.Open(src)
+	if err != nil {
+		return err
+	}
+	defer in.Close()
+	out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		if cerr := out.Close(); cerr != nil && err == nil {
+			err = cerr
+		}
+	}()
+	if _, err = io.Copy(out, in); err != nil {
+		return err
+	}
+	return out.Sync()
+}
+
+// isCrossDeviceLinkError reports whether err looks like an EXDEV
+// ("invalid cross-device link") failure from os.Rename.
+func isCrossDeviceLinkError(err error) bool {
+	if err == nil {
+		return false
+	}
+	msg := strings.ToLower(err.Error())
+	return strings.Contains(msg, "cross-device") || strings.Contains(msg, "exdev")
+}
diff --git a/internal/tools/filesystem/move_file_test.go b/internal/tools/filesystem/move_file_test.go
new file mode 100644
index
00000000..bd95e86d --- /dev/null +++ b/internal/tools/filesystem/move_file_test.go @@ -0,0 +1,255 @@ +package filesystem + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "neo-code/internal/tools" +) + +func TestMoveFileTool_RenamesWithinWorkspace(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + src := filepath.Join(workspace, "old.go") + if err := os.WriteFile(src, []byte("hello"), 0o644); err != nil { + t.Fatalf("seed: %v", err) + } + tool := NewMove(workspace) + + args, _ := json.Marshal(map[string]any{ + "source_path": "old.go", + "destination_path": "renamed.go", + }) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err != nil { + t.Fatalf("execute: %v", err) + } + if result.IsError { + t.Fatalf("unexpected error result: %s", result.Content) + } + if !result.Facts.WorkspaceWrite { + t.Fatalf("expected WorkspaceWrite=true") + } + if _, err := os.Stat(src); !os.IsNotExist(err) { + t.Fatalf("source still exists: err=%v", err) + } + dst := filepath.Join(workspace, "renamed.go") + if data, err := os.ReadFile(dst); err != nil { + t.Fatalf("read dst: %v", err) + } else if string(data) != "hello" { + t.Fatalf("dst content = %q want hello", string(data)) + } + if got := result.Metadata["destination_path"]; got != dst { + t.Fatalf("destination_path metadata = %v want %v", got, dst) + } + paths, ok := result.Metadata["paths"].([]string) + if !ok || len(paths) != 2 { + t.Fatalf("paths metadata = %#v, want 2-item slice", result.Metadata["paths"]) + } +} + +func TestMoveFileTool_RejectsExistingDestinationWithoutOverwrite(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + src := filepath.Join(workspace, "src.txt") + dst := filepath.Join(workspace, "dst.txt") + if err := os.WriteFile(src, []byte("a"), 0o644); err != nil { + t.Fatalf("seed src: %v", err) + } + if err := os.WriteFile(dst, []byte("b"), 0o644); err != nil 
{ + t.Fatalf("seed dst: %v", err) + } + tool := NewMove(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "src.txt", + "destination_path": "dst.txt", + }) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "already exists") { + t.Fatalf("expected exists error, got %v", err) + } + if data, _ := os.ReadFile(dst); string(data) != "b" { + t.Fatalf("dst content modified, got %q want b", string(data)) + } +} + +func TestMoveFileTool_OverwritesWhenAllowed(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + src := filepath.Join(workspace, "src.txt") + dst := filepath.Join(workspace, "dst.txt") + if err := os.WriteFile(src, []byte("new"), 0o644); err != nil { + t.Fatalf("seed src: %v", err) + } + if err := os.WriteFile(dst, []byte("old"), 0o644); err != nil { + t.Fatalf("seed dst: %v", err) + } + tool := NewMove(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "src.txt", + "destination_path": "dst.txt", + "overwrite": true, + }) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err != nil { + t.Fatalf("execute: %v", err) + } + if result.IsError { + t.Fatalf("error result: %s", result.Content) + } + if data, _ := os.ReadFile(dst); string(data) != "new" { + t.Fatalf("dst content = %q want new", string(data)) + } +} + +func TestMoveFileTool_RejectsTraversal(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + src := filepath.Join(workspace, "src.txt") + if err := os.WriteFile(src, []byte("x"), 0o644); err != nil { + t.Fatalf("seed: %v", err) + } + tool := NewMove(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "src.txt", + "destination_path": filepath.Join("..", "escape.txt"), + }) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + 
Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "escapes workspace") { + t.Fatalf("expected escape error, got %v", err) + } +} + +func TestMoveFileTool_RejectsMissingSource(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewMove(workspace) + args, _ := json.Marshal(map[string]any{ + "source_path": "missing.txt", + "destination_path": "out.txt", + }) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil { + t.Fatalf("expected error for missing source") + } +} + +func TestMoveFileTool_RejectsEmptyPaths(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewMove(workspace) + + for _, tc := range []struct { + name string + args map[string]any + want string + }{ + { + name: "empty source", + args: map[string]any{"source_path": "", "destination_path": "x.txt"}, + want: "source_path is required", + }, + { + name: "empty destination", + args: map[string]any{"source_path": "x.txt", "destination_path": ""}, + want: "destination_path is required", + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + args, _ := json.Marshal(tc.args) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), tc.want) { + t.Fatalf("expected %q, got %v", tc.want, err) + } + }) + } +} + +func TestMoveFileTool_InvalidJSON(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewMove(workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: []byte(`{invalid`), + Workdir: workspace, + }) + if err == nil { + t.Fatalf("expected json error") + } + if !result.IsError { + t.Fatalf("expected error result") + } +} + +func TestMoveFileTool_RejectsDirectorySource(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + 
sourceDir := filepath.Join(workspace, "srcdir")
+	if err := os.MkdirAll(sourceDir, 0o755); err != nil {
+		t.Fatalf("seed dir: %v", err)
+	}
+	tool := NewMove(workspace)
+	args, _ := json.Marshal(map[string]any{
+		"source_path":      "srcdir",
+		"destination_path": "moved.txt",
+	})
+	_, err := tool.Execute(context.Background(), tools.ToolCallInput{
+		Name:      tool.Name(),
+		Arguments: args,
+		Workdir:   workspace,
+	})
+	if err == nil || !strings.Contains(err.Error(), "must be a file") {
+		t.Fatalf("expected directory source error, got %v", err)
+	}
+}
+
+func TestMoveFileTool_RejectsCanceledContext(t *testing.T) {
+	t.Parallel()
+	workspace := t.TempDir()
+	tool := NewMove(workspace)
+	args, _ := json.Marshal(map[string]any{
+		"source_path":      "src.txt",
+		"destination_path": "dst.txt",
+	})
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	_, err := tool.Execute(ctx, tools.ToolCallInput{
+		Name:      tool.Name(),
+		Arguments: args,
+		Workdir:   workspace,
+	})
+	if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
+		t.Fatalf("expected canceled error, got %v", err)
+	}
+}
diff --git a/internal/tools/filesystem/remove_dir.go b/internal/tools/filesystem/remove_dir.go
new file mode 100644
index 00000000..cc611b6a
--- /dev/null
+++ b/internal/tools/filesystem/remove_dir.go
@@ -0,0 +1,134 @@
+package filesystem
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"neo-code/internal/security"
+	"neo-code/internal/tools"
+)
+
+// RemoveDirTool removes a directory inside the workspace rooted at root.
+type RemoveDirTool struct {
+	root string
+}
+
+// removeDirInput is the JSON argument shape accepted by Execute.
+type removeDirInput struct {
+	Path  string `json:"path"`
+	Force bool   `json:"force,omitempty"`
+}
+
+// NewRemoveDir returns a RemoveDirTool scoped to the given workspace root.
+func NewRemoveDir(root string) *RemoveDirTool {
+	return &RemoveDirTool{root: root}
+}
+
+func (t *RemoveDirTool) Name() string {
+	return removeDirToolName
+}
+
+func (t *RemoveDirTool) Description() string {
+	return "Remove a directory inside the workspace. By default only empty directories are removed; pass force=true to remove recursively."
+}
+
+// Schema describes the tool arguments for the model-facing tool registry.
+func (t *RemoveDirTool) Schema() map[string]any {
+	return map[string]any{
+		"type": "object",
+		"properties": map[string]any{
+			"path": map[string]any{
+				"type":        "string",
+				"description": "Directory path relative to workspace root, or absolute inside the workspace.",
+			},
+			"force": map[string]any{
+				"type":        "boolean",
+				"description": "When true, remove directory and all contents recursively. Defaults to false (empty directory only).",
+			},
+		},
+		"required": []string{"path"},
+	}
+}
+
+func (t *RemoveDirTool) MicroCompactPolicy() tools.MicroCompactPolicy {
+	return tools.MicroCompactPolicyCompact
+}
+
+// Execute validates the input, resolves the target inside the workspace, and
+// removes it: os.Remove in the default empty-only mode, os.RemoveAll when
+// force=true. A missing directory is reported as a successful no-op so that
+// retries stay idempotent.
+func (t *RemoveDirTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) {
+	var args removeDirInput
+	if err := json.Unmarshal(input.Arguments, &args); err != nil {
+		return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err
+	}
+	if strings.TrimSpace(args.Path) == "" {
+		err := errors.New(removeDirToolName + ": path is required")
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+	if err := ctx.Err(); err != nil {
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+
+	base := effectiveRoot(t.root, input.Workdir)
+
+	_, target, err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.Path, resolvePath)
+	if err != nil {
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+
+	info, statErr := os.Stat(target)
+	if statErr != nil {
+		if os.IsNotExist(statErr) {
+			return tools.ToolResult{
+				Name:    t.Name(),
+				Content: "ok",
+				Metadata: map[string]any{
+					"path":       target,
+					"removed":    false,
+					"noop_write": true,
+					"force":      args.Force,
+				},
+				Facts: tools.ToolExecutionFacts{WorkspaceWrite: true},
+			}, nil
+		}
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr
+	}
+	if !info.IsDir() {
+		err := errors.New(removeDirToolName + ": path is not a directory; use filesystem_delete_file")
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+
+	// Never allow force-removal of the workspace root itself: a path that
+	// resolves to base would pass every check above, and os.RemoveAll would
+	// then wipe the entire workspace in a single call.
+	if args.Force && filepath.Clean(target) == filepath.Clean(base) {
+		err := errors.New(removeDirToolName + ": refusing to remove the workspace root")
+		return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+	}
+
+	if args.Force {
+		if err := os.RemoveAll(target); err != nil {
+			return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+		}
+	} else {
+		// Non-recursive: os.Remove fails on a non-empty directory, which is
+		// the intended guard against accidental data loss.
+		if err := os.Remove(target); err != nil {
+			return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err
+		}
+	}
+
+	return tools.ToolResult{
+		Name:    t.Name(),
+		Content: "ok",
+		Metadata: map[string]any{
+			"path":    target,
+			"removed": true,
+			"force":   args.Force,
+		},
+		Facts: tools.ToolExecutionFacts{WorkspaceWrite: true},
	}, nil
+}
diff --git a/internal/tools/filesystem/remove_dir_test.go b/internal/tools/filesystem/remove_dir_test.go
new file mode 100644
index 00000000..ae9b69ab
--- /dev/null
+++ b/internal/tools/filesystem/remove_dir_test.go
@@ -0,0 +1,166 @@
+package filesystem
+
+import (
+	"context"
+	"encoding/json"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"neo-code/internal/tools"
+)
+
+func TestRemoveDirTool_RemovesEmptyDir(t *testing.T) {
+	t.Parallel()
+	workspace := t.TempDir()
+	dir := filepath.Join(workspace, "empty")
+	if err := os.MkdirAll(dir, 0o755); err != nil {
+		t.Fatalf("seed: %v", err)
+	}
+	tool := NewRemoveDir(workspace)
+	args, _ := json.Marshal(map[string]any{"path": "empty"})
+	result, err := tool.Execute(context.Background(), tools.ToolCallInput{
+		Name:      tool.Name(),
+		Arguments: args,
+		Workdir:   workspace,
+	})
+	if err != nil {
+		t.Fatalf("execute: %v", err)
+	}
+	if result.IsError {
+		t.Fatalf("error result: %s", result.Content)
+	}
+	if !result.Facts.WorkspaceWrite {
+		t.Fatalf("expected WorkspaceWrite=true")
+	}
+	if _, err := os.Stat(dir); !os.IsNotExist(err) {
+		t.Fatalf("dir still exists: %v", err)
+	}
+}
+
+func TestRemoveDirTool_RefusesNonEmptyWithoutForce(t *testing.T) {
+	t.Parallel()
+	workspace :=
t.TempDir() + dir := filepath.Join(workspace, "full") + child := filepath.Join(dir, "x.txt") + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("seed dir: %v", err) + } + if err := os.WriteFile(child, []byte("x"), 0o644); err != nil { + t.Fatalf("seed file: %v", err) + } + tool := NewRemoveDir(workspace) + args, _ := json.Marshal(map[string]any{"path": "full"}) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil { + t.Fatalf("expected error for non-empty directory") + } + if _, err := os.Stat(child); err != nil { + t.Fatalf("child file destroyed: %v", err) + } +} + +func TestRemoveDirTool_ForceRemovesRecursive(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + dir := filepath.Join(workspace, "tree") + nested := filepath.Join(dir, "a", "b") + if err := os.MkdirAll(nested, 0o755); err != nil { + t.Fatalf("seed: %v", err) + } + if err := os.WriteFile(filepath.Join(nested, "c.txt"), []byte("c"), 0o644); err != nil { + t.Fatalf("seed file: %v", err) + } + tool := NewRemoveDir(workspace) + args, _ := json.Marshal(map[string]any{ + "path": "tree", + "force": true, + }) + if _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }); err != nil { + t.Fatalf("execute: %v", err) + } + if _, err := os.Stat(dir); !os.IsNotExist(err) { + t.Fatalf("dir still exists: %v", err) + } +} + +func TestRemoveDirTool_MissingDirReturnsNoop(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewRemoveDir(workspace) + args, _ := json.Marshal(map[string]any{"path": "phantom"}) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err != nil { + t.Fatalf("execute: %v", err) + } + if got, _ := result.Metadata["noop_write"].(bool); !got { + t.Fatalf("noop_write = %v", result.Metadata["noop_write"]) + 
} + if got, _ := result.Metadata["removed"].(bool); got { + t.Fatalf("removed = %v want false", result.Metadata["removed"]) + } +} + +func TestRemoveDirTool_RejectsFile(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + target := filepath.Join(workspace, "afile.txt") + if err := os.WriteFile(target, []byte("x"), 0o644); err != nil { + t.Fatalf("seed: %v", err) + } + tool := NewRemoveDir(workspace) + args, _ := json.Marshal(map[string]any{"path": "afile.txt"}) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "not a directory") { + t.Fatalf("expected directory-required error, got %v", err) + } +} + +func TestRemoveDirTool_RejectsTraversal(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewRemoveDir(workspace) + args, _ := json.Marshal(map[string]any{"path": filepath.Join("..", "escape")}) + _, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err == nil || !strings.Contains(err.Error(), "escapes workspace") { + t.Fatalf("expected escape error, got %v", err) + } +} + +func TestRemoveDirTool_InvalidJSON(t *testing.T) { + t.Parallel() + workspace := t.TempDir() + tool := NewRemoveDir(workspace) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: []byte(`{invalid`), + Workdir: workspace, + }) + if err == nil { + t.Fatalf("expected json error") + } + if !result.IsError { + t.Fatalf("expected error result") + } +} diff --git a/internal/tools/filesystem/tool_metadata_test.go b/internal/tools/filesystem/tool_metadata_test.go new file mode 100644 index 00000000..27a6c873 --- /dev/null +++ b/internal/tools/filesystem/tool_metadata_test.go @@ -0,0 +1,96 @@ +package filesystem + +import ( + "errors" + "testing" + + "neo-code/internal/tools" +) + +func TestFilesystemToolMetadata(t 
 *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name        string
+		toolName    string
+		description string
+		schema      map[string]any
+		policy      tools.MicroCompactPolicy
+	}{
+		{
+			name:        "copy",
+			toolName:    NewCopy("/workspace").Name(),
+			description: NewCopy("/workspace").Description(),
+			schema:      NewCopy("/workspace").Schema(),
+			policy:      NewCopy("/workspace").MicroCompactPolicy(),
+		},
+		{
+			name:        "move",
+			toolName:    NewMove("/workspace").Name(),
+			description: NewMove("/workspace").Description(),
+			schema:      NewMove("/workspace").Schema(),
+			policy:      NewMove("/workspace").MicroCompactPolicy(),
+		},
+		{
+			name:        "create dir",
+			toolName:    NewCreateDir("/workspace").Name(),
+			description: NewCreateDir("/workspace").Description(),
+			schema:      NewCreateDir("/workspace").Schema(),
+			policy:      NewCreateDir("/workspace").MicroCompactPolicy(),
+		},
+		{
+			name:        "delete file",
+			toolName:    NewDelete("/workspace").Name(),
+			description: NewDelete("/workspace").Description(),
+			schema:      NewDelete("/workspace").Schema(),
+			policy:      NewDelete("/workspace").MicroCompactPolicy(),
+		},
+		{
+			name:        "remove dir",
+			toolName:    NewRemoveDir("/workspace").Name(),
+			description: NewRemoveDir("/workspace").Description(),
+			schema:      NewRemoveDir("/workspace").Schema(),
+			policy:      NewRemoveDir("/workspace").MicroCompactPolicy(),
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt // loop-var copy; redundant under Go 1.22+ semantics but harmless
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			if tt.toolName == "" {
+				t.Fatal("tool name should not be empty")
+			}
+			if tt.description == "" {
+				t.Fatal("description should not be empty")
+			}
+			if got, _ := tt.schema["type"].(string); got != "object" {
+				t.Fatalf("schema type = %q, want object", got)
+			}
+			// NOTE(review): this assumes every Schema() stores "required" as
+			// []string (as RemoveDirTool does) rather than []any — confirm the
+			// other four tools match, or the assertion fails on the type check.
+			required, ok := tt.schema["required"].([]string)
+			if !ok || len(required) == 0 {
+				t.Fatalf("required schema fields missing: %#v", tt.schema["required"])
+			}
+			if tt.policy != tools.MicroCompactPolicyCompact {
+				t.Fatalf("policy = %q, want compact", tt.policy)
+			}
+		})
+	}
+}
+
+// TestMoveCrossDeviceHelper pins the substring heuristics used to detect
+// EXDEV-style rename failures (which trigger a copy+delete fallback).
+func TestMoveCrossDeviceHelper(t *testing.T) {
+	t.Parallel()
+
+	if !isCrossDeviceLinkError(errors.New("rename failed: cross-device link")) {
+		t.Fatal("cross-device error should be detected")
+	}
+	if !isCrossDeviceLinkError(errors.New("EXDEV: invalid cross-device link")) {
+		t.Fatal("EXDEV error should be detected")
+	}
+	if isCrossDeviceLinkError(errors.New("permission denied")) {
+		t.Fatal("unrelated error should not be detected as cross-device")
+	}
+	if isCrossDeviceLinkError(nil) {
+		t.Fatal("nil error should not be detected as cross-device")
+	}
+}
diff --git a/internal/tools/names.go b/internal/tools/names.go
index 02d2cd88..625a1dc6 100644
--- a/internal/tools/names.go
+++ b/internal/tools/names.go
@@ -2,18 +2,23 @@ package tools
 
 // Tool name constants are shared across tool implementations, context policies, and tests.
 const (
-	ToolNameBash                = "bash"
-	ToolNameWebFetch            = "webfetch"
-	ToolNameFilesystemReadFile  = "filesystem_read_file"
-	ToolNameFilesystemWriteFile = "filesystem_write_file"
-	ToolNameFilesystemGrep      = "filesystem_grep"
-	ToolNameFilesystemGlob      = "filesystem_glob"
-	ToolNameFilesystemEdit      = "filesystem_edit"
-	ToolNameTodoWrite           = "todo_write"
-	ToolNameSpawnSubAgent       = "spawn_subagent"
-	ToolNameMemoRemember        = "memo_remember"
-	ToolNameMemoRecall          = "memo_recall"
-	ToolNameMemoList            = "memo_list"
-	ToolNameMemoRemove          = "memo_remove"
-	ToolNameDiagnose            = "diagnose"
+	ToolNameBash                 = "bash"
+	ToolNameWebFetch             = "webfetch"
+	ToolNameFilesystemReadFile   = "filesystem_read_file"
+	ToolNameFilesystemWriteFile  = "filesystem_write_file"
+	ToolNameFilesystemGrep       = "filesystem_grep"
+	ToolNameFilesystemGlob       = "filesystem_glob"
+	ToolNameFilesystemEdit       = "filesystem_edit"
+	ToolNameFilesystemMoveFile   = "filesystem_move_file"
+	ToolNameFilesystemCopyFile   = "filesystem_copy_file"
+	ToolNameFilesystemDeleteFile = "filesystem_delete_file"
+	ToolNameFilesystemCreateDir  = "filesystem_create_dir"
+	ToolNameFilesystemRemoveDir  = "filesystem_remove_dir"
+	ToolNameTodoWrite            = "todo_write"
+	ToolNameSpawnSubAgent        = "spawn_subagent"
+	ToolNameMemoRemember         = "memo_remember"
+	ToolNameMemoRecall           = "memo_recall"
+	ToolNameMemoList             = "memo_list"
+	ToolNameMemoRemove           = "memo_remove"
+	ToolNameDiagnose             = "diagnose"
 )
diff --git a/internal/tools/permission_mapper.go b/internal/tools/permission_mapper.go
index c43f196b..ed92a2b0 100644
--- a/internal/tools/permission_mapper.go
+++ b/internal/tools/permission_mapper.go
@@ -92,6 +92,41 @@ func buildPermissionAction(input ToolCallInput) (security.Action, error) {
 		action.Payload.Target = extractStringArgument(input.Arguments, "path")
 		action.Payload.SandboxTargetType = security.TargetTypePath
 		action.Payload.SandboxTarget = action.Payload.Target
+	// NOTE(review): move/copy gate the permission and sandbox checks on
+	// destination_path only; the source path is not part of the action target.
+	// Confirm source-side escape protection is enforced inside the tools
+	// themselves (ResolveWorkspaceTarget), or a move out of an allowed
+	// directory would be under-checked here.
+	case ToolNameFilesystemMoveFile:
+		action.Type = security.ActionTypeWrite
+		action.Payload.Operation = "move_file"
+		action.Payload.TargetType = security.TargetTypePath
+		action.Payload.Target = extractStringArgument(input.Arguments, "destination_path")
+		action.Payload.SandboxTargetType = security.TargetTypePath
+		action.Payload.SandboxTarget = action.Payload.Target
+	case ToolNameFilesystemCopyFile:
+		action.Type = security.ActionTypeWrite
+		action.Payload.Operation = "copy_file"
+		action.Payload.TargetType = security.TargetTypePath
+		action.Payload.Target = extractStringArgument(input.Arguments, "destination_path")
+		action.Payload.SandboxTargetType = security.TargetTypePath
+		action.Payload.SandboxTarget = action.Payload.Target
+	case ToolNameFilesystemDeleteFile:
+		action.Type = security.ActionTypeWrite
+		action.Payload.Operation = "delete_file"
+		action.Payload.TargetType = security.TargetTypePath
+		action.Payload.Target = extractStringArgument(input.Arguments, "path")
+		action.Payload.SandboxTargetType = security.TargetTypePath
+		action.Payload.SandboxTarget = action.Payload.Target
+	case ToolNameFilesystemCreateDir:
+		action.Type = security.ActionTypeWrite
+		action.Payload.Operation = "create_dir"
+		action.Payload.TargetType = security.TargetTypePath
+
action.Payload.Target = extractStringArgument(input.Arguments, "path")
+		action.Payload.SandboxTargetType = security.TargetTypePath
+		action.Payload.SandboxTarget = action.Payload.Target
+	case ToolNameFilesystemRemoveDir:
+		action.Type = security.ActionTypeWrite
+		action.Payload.Operation = "remove_dir"
+		action.Payload.TargetType = security.TargetTypePath
+		action.Payload.Target = extractStringArgument(input.Arguments, "path")
+		action.Payload.SandboxTargetType = security.TargetTypePath
+		action.Payload.SandboxTarget = action.Payload.Target
 	case ToolNameTodoWrite:
 		action.Type = security.ActionTypeWrite
 		action.Payload.Operation = "todo_write"
diff --git a/internal/tools/permission_mapper_test.go b/internal/tools/permission_mapper_test.go
new file mode 100644
index 00000000..b26159fa
--- /dev/null
+++ b/internal/tools/permission_mapper_test.go
@@ -0,0 +1,118 @@
+package tools
+
+import (
+	"testing"
+
+	"neo-code/internal/security"
+)
+
+// An empty tool name has no mapping and must be rejected.
+func TestBuildPermissionActionRejectsEmptyToolName(t *testing.T) {
+	if _, err := buildPermissionAction(ToolCallInput{}); err == nil {
+		t.Fatal("expected error for empty tool name")
+	}
+}
+
+// A git read command maps to a bash action carrying semantic git metadata and
+// a workdir-scoped sandbox default.
+func TestBuildPermissionActionBashGitDefaultsSandboxAndSemanticFields(t *testing.T) {
+	action, err := buildPermissionAction(ToolCallInput{
+		Name:      ToolNameBash,
+		SessionID: "session-1",
+		TaskID:    "task-1",
+		AgentID:   "agent-1",
+		Arguments: []byte(`{"command":"git status --short"}`),
+	})
+	if err != nil {
+		t.Fatalf("buildPermissionAction() error = %v", err)
+	}
+
+	if action.Type != security.ActionTypeBash {
+		t.Fatalf("action.Type = %q, want %q", action.Type, security.ActionTypeBash)
+	}
+	if action.Payload.Operation != "git_status" {
+		t.Fatalf("operation = %q, want git_status", action.Payload.Operation)
+	}
+	if action.Payload.TargetType != security.TargetTypeCommand {
+		t.Fatalf("target type = %q, want %q", action.Payload.TargetType, security.TargetTypeCommand)
+	}
+	if action.Payload.Target != "git status --short" {
+		t.Fatalf("target = %q", action.Payload.Target)
+	}
+	if action.Payload.Resource == "" {
+		t.Fatal("expected git resource to be set")
+	}
+	if action.Payload.SemanticType != "git" {
+		t.Fatalf("semantic type = %q, want git", action.Payload.SemanticType)
+	}
+	if action.Payload.PermissionFingerprint == "" {
+		t.Fatal("expected permission fingerprint to be populated")
+	}
+	// Bash defaults its sandbox scope to the current directory.
+	if action.Payload.SandboxTargetType != security.TargetTypeDirectory {
+		t.Fatalf("sandbox target type = %q, want %q", action.Payload.SandboxTargetType, security.TargetTypeDirectory)
+	}
+	if action.Payload.SandboxTarget != "." {
+		t.Fatalf("sandbox target = %q, want .", action.Payload.SandboxTarget)
+	}
+}
+
+// NOTE(review): `\m` below is NOT a valid JSON escape, so strict decoding of
+// this payload fails; the test pins the raw-extraction fallback presumably
+// implemented in extractStringArgument for Windows-style backslash paths —
+// confirm that fallback is the intended contract, not an accident.
+func TestBuildPermissionActionReadFileFallsBackForWindowsPathLikePayload(t *testing.T) {
+	action, err := buildPermissionAction(ToolCallInput{
+		Name:      ToolNameFilesystemReadFile,
+		Arguments: []byte(`{"path":"C:\repo\main.go"}`),
+	})
+	if err != nil {
+		t.Fatalf("buildPermissionAction() error = %v", err)
+	}
+
+	if action.Type != security.ActionTypeRead {
+		t.Fatalf("action.Type = %q, want %q", action.Type, security.ActionTypeRead)
+	}
+	if action.Payload.Target != `C:\repo\main.go` {
+		t.Fatalf("target = %q", action.Payload.Target)
+	}
+	if action.Payload.SandboxTarget != `C:\repo\main.go` {
+		t.Fatalf("sandbox target = %q", action.Payload.SandboxTarget)
+	}
+}
+
+// Sub-agent spawns take the first allowed path as sandbox scope and build a
+// stable comma-joined item-id target (here "task-a,task-b").
+func TestBuildPermissionActionSpawnSubAgentUsesAllowedPathAndStableTarget(t *testing.T) {
+	action, err := buildPermissionAction(ToolCallInput{
+		Name:    ToolNameSpawnSubAgent,
+		Workdir: "/workspace",
+		Arguments: []byte(`{
+			"id":"fallback",
+			"items":[{"id":"task-a"},{"id":"task-b"}],
+			"allowed_paths":["/workspace/pkg"]
+		}`),
+	})
+	if err != nil {
+		t.Fatalf("buildPermissionAction() error = %v", err)
+	}
+
+	if action.Type != security.ActionTypeWrite {
+		t.Fatalf("action.Type = %q, want %q", action.Type, security.ActionTypeWrite)
+	}
+	if action.Payload.Target != "task-a,task-b" {
+		t.Fatalf("target = %q, want task-a,task-b", action.Payload.Target)
+	}
+	if action.Payload.SandboxTarget != "/workspace/pkg" {
+		t.Fatalf("sandbox target = %q, want /workspace/pkg", action.Payload.SandboxTarget)
+	}
+}
+
+// MCP tool names are lowercased into a dotted identity; mcpServerTarget trims
+// the trailing method segment to yield the server-level target.
+func TestBuildPermissionActionSupportsMCPIdentity(t *testing.T) {
+	action, err := buildPermissionAction(ToolCallInput{
+		Name: "MCP.GitHub.Create_Issue",
+	})
+	if err != nil {
+		t.Fatalf("buildPermissionAction() error = %v", err)
+	}
+
+	if action.Type != security.ActionTypeMCP {
+		t.Fatalf("action.Type = %q, want %q", action.Type, security.ActionTypeMCP)
+	}
+	if action.Payload.Target != "mcp.github.create_issue" {
+		t.Fatalf("target = %q, want mcp.github.create_issue", action.Payload.Target)
+	}
+	if got := mcpServerTarget("MCP.GitHub.Create_Issue"); got != "mcp.github" {
+		t.Fatalf("mcpServerTarget() = %q, want mcp.github", got)
+	}
+}