From 75abe3fd1dd234469ad341dd13904a604df36d78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E7=BA=AA?= <3049035704@qq.com> Date: Sun, 10 May 2026 01:39:42 +0800 Subject: [PATCH 1/3] add rollback tooling --- internal/rollback/store.go | 906 ++++++++++++++++++++ internal/rollback/store_test.go | 190 ++++ internal/tools/apply_patch.go | 148 +++- internal/tools/replace_in_file.go | 23 +- internal/tools/rollback_helpers.go | 77 ++ internal/tools/rollback_integration_test.go | 188 ++++ internal/tools/write_file.go | 26 +- tui/component_command_utils.go | 1 + tui/component_rollback_command.go | 122 +++ tui/component_rollback_command_test.go | 79 ++ tui/component_slash_entry.go | 2 + tui/model.go | 3 +- tui/model_test.go | 2 +- 13 files changed, 1746 insertions(+), 21 deletions(-) create mode 100644 internal/rollback/store.go create mode 100644 internal/rollback/store_test.go create mode 100644 internal/tools/rollback_helpers.go create mode 100644 internal/tools/rollback_integration_test.go create mode 100644 tui/component_rollback_command.go create mode 100644 tui/component_rollback_command_test.go diff --git a/internal/rollback/store.go b/internal/rollback/store.go new file mode 100644 index 00000000..6cb1d62d --- /dev/null +++ b/internal/rollback/store.go @@ -0,0 +1,906 @@ +package rollback + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "time" + + configpkg "github.com/1024XEngineer/bytemind/internal/config" + corepkg "github.com/1024XEngineer/bytemind/internal/core" + storagepkg "github.com/1024XEngineer/bytemind/internal/storage" +) + +const maxSnapshotBytes int64 = 5 * 1024 * 1024 + +type OpType string + +const ( + OpTypeAdd OpType = "add" + OpTypeUpdate OpType = "update" + OpTypeDelete OpType = "delete" + OpTypeMove OpType = "move" +) + +type Status string + +const ( + StatusPending Status = "pending" + StatusCommitted Status = "committed" + StatusRolledBack Status = "rolled_back" + StatusRollbackFailed Status = "rollback_failed" + StatusAborted Status = "aborted" +) + +type BeginOptions struct { + Workspace string + SessionID string + TaskID string + TraceID string + ToolName string + Actor string +} + +type FileTarget struct { + Path string + AbsPath string + NewPath string + NewAbsPath string + OpType OpType +} + +type Operation struct { + OperationID string `json:"operation_id"` + SessionID string `json:"session_id,omitempty"` + TaskID string `json:"task_id,omitempty"` + TraceID string `json:"trace_id,omitempty"` + Workspace string `json:"workspace"` + ToolName string `json:"tool_name"` + Actor string `json:"actor,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Status Status `json:"status"` + AffectedFiles []FileChange `json:"affected_files"` + RollbackAttempts int `json:"rollback_attempts"` + LastError string `json:"last_error,omitempty"` +} + +type FileChange struct { + Path string `json:"path"` + AbsPath string `json:"abs_path,omitempty"` + NewPath string `json:"new_path,omitempty"` + NewAbsPath string `json:"new_abs_path,omitempty"` + OpType OpType `json:"op_type"` + FileAbsentBefore bool `json:"file_absent_before,omitempty"` + FileAbsentAfter bool `json:"file_absent_after,omitempty"` + BeforeHash string `json:"before_hash,omitempty"` + AfterHash string `json:"after_hash,omitempty"` + BeforeSnapshot string `json:"before_snapshot,omitempty"` + SizeBytes int64 `json:"size_bytes,omitempty"` +} + +type Store struct { + root string + entriesDir string + blobsDir string +} + +func NewDefaultStore() (*Store, error) { + home, err := configpkg.ResolveHomeDir() + if err != nil { + return nil, err + } + return NewStore(filepath.Join(home, "rollback")) +} + +func NewStore(root string) (*Store, error) { + root = strings.TrimSpace(root) + if root == "" { + return nil, errors.New("rollback root is required") + } + root, err := filepath.Abs(root) + if err != nil { + return nil, err + } + store := &Store{ + root: root, + entriesDir: filepath.Join(root, "entries"), + blobsDir: filepath.Join(root, "blobs"), + } + if err := os.MkdirAll(store.entriesDir, 0o755); err != nil { + return nil, err + } + if err := os.MkdirAll(store.blobsDir, 0o755); err != nil { + return nil, err + } + return store, nil +} + +func (s *Store) Begin(ctx context.Context, opts BeginOptions, targets []FileTarget) (*Operation, error) { + if s == nil { + return nil, errors.New("rollback store is unavailable") + } + if len(targets) == 0 { + return nil, errors.New("rollback operation requires at least one file") + } + workspace, err := filepath.Abs(strings.TrimSpace(opts.Workspace)) + if err != nil { + return nil, err + } + workspace = filepath.Clean(workspace) + now := time.Now().UTC() + op := &Operation{ + OperationID: newOperationID(now), + SessionID: strings.TrimSpace(opts.SessionID), + TaskID: strings.TrimSpace(opts.TaskID), + TraceID: strings.TrimSpace(opts.TraceID), + Workspace: workspace, + ToolName: strings.TrimSpace(opts.ToolName), + Actor: strings.TrimSpace(opts.Actor), + CreatedAt: now, + UpdatedAt: now, + Status: StatusPending, + AffectedFiles: make([]FileChange, 0, len(targets)), + } + if op.Actor == "" { + op.Actor = "agent" + } + + for i, target := range targets { + change, err := s.captureBeforeState(op, workspace, i, target) + if err != nil { + return nil, err + } + op.AffectedFiles = append(op.AffectedFiles, change) + } + + if err := s.saveOperation(op); err != nil { + return nil, err + } + return op, nil +} + +func (s *Store) Commit(ctx context.Context, op *Operation) error { + if s == nil || op == nil { + return nil + } + for i := range op.AffectedFiles { + change := &op.AffectedFiles[i] + switch change.OpType { + case OpTypeAdd, OpTypeUpdate: + hash, absent, err := hashCurrentFile(resolveChangePath(*op, *change)) + if err != nil { + return err + } + if absent { + return fmt.Errorf("rollback commit failed: %s is absent after %s", change.Path, change.OpType) + } + change.AfterHash = hash + change.FileAbsentAfter = false + case OpTypeDelete: + _, absent, err := hashCurrentFile(resolveChangePath(*op, *change)) + if err != nil { + return err + } + if !absent { + return fmt.Errorf("rollback commit failed: %s still exists after delete", change.Path) + } + change.AfterHash = "" + change.FileAbsentAfter = true + case OpTypeMove: + oldPath := resolveChangePath(*op, *change) + newPath := resolveNewChangePath(*op, *change) + _, oldAbsent, err := hashCurrentFile(oldPath) + if err != nil { + return err + } + if !oldAbsent { + return fmt.Errorf("rollback commit failed: %s still exists after move", change.Path) + } + hash, newAbsent, err := hashCurrentFile(newPath) + if err != nil { + return err + } + if newAbsent { + return fmt.Errorf("rollback commit failed: %s is absent after move", change.NewPath) + } + change.AfterHash = hash + change.FileAbsentAfter = true + default: + return fmt.Errorf("unsupported rollback op type %q", change.OpType) + } + } + op.Status = StatusCommitted + op.UpdatedAt = time.Now().UTC() + op.LastError = "" + if err := s.saveOperation(op); err != nil { + return err + } + s.appendAudit(ctx, op, "rollback_operation_committed", "success", "") + return nil +} + +func (s *Store) Abort(ctx context.Context, op *Operation, reason string) error { + if s == nil || op == nil { + return nil + } + op.Status = StatusAborted + op.UpdatedAt = time.Now().UTC() + op.LastError = strings.TrimSpace(reason) + if err := s.saveOperation(op); err != nil { + return err + } + s.appendAudit(ctx, op, "rollback_operation_aborted", "aborted", reason) + return nil +} + +func (s *Store) AbortAndRestore(ctx context.Context, op *Operation, reason string, writableRoots ...string) error { + if s == nil || op == nil { + return nil + } + restoreErr := s.restoreBefore(op, writableRoots...) + status := StatusAborted + lastError := strings.TrimSpace(reason) + result := "aborted" + if restoreErr != nil { + status = StatusRollbackFailed + result = "restore_failed" + if lastError != "" { + lastError += "; " + } + lastError += "automatic restore failed: " + restoreErr.Error() + } + op.Status = status + op.UpdatedAt = time.Now().UTC() + op.LastError = lastError + if err := s.saveOperation(op); err != nil && restoreErr == nil { + return err + } + s.appendAudit(ctx, op, "rollback_operation_aborted", result, lastError) + return restoreErr +} + +func (s *Store) ListRecent(ctx context.Context, workspace string, limit int) ([]Operation, error) { + if s == nil { + return nil, errors.New("rollback store is unavailable") + } + if limit <= 0 { + limit = 10 + } + ops, err := s.loadAll() + if err != nil { + return nil, err + } + filtered := make([]Operation, 0, len(ops)) + for _, op := range ops { + if ctx != nil && ctx.Err() != nil { + return nil, ctx.Err() + } + if op.Status != StatusCommitted { + continue + } + if !sameWorkspace(op.Workspace, workspace) { + continue + } + filtered = append(filtered, op) + } + sort.Slice(filtered, func(i, j int) bool { + return filtered[i].CreatedAt.After(filtered[j].CreatedAt) + }) + if len(filtered) > limit { + filtered = filtered[:limit] + } + return filtered, nil +} + +func (s *Store) RollbackLast(ctx context.Context, workspace string, writableRoots ...string) (*Operation, error) { + ops, err := s.ListRecent(ctx, workspace, 1) + if err != nil { + return nil, err + } + if len(ops) == 0 { + return nil, errors.New("no committed rollback operation found for this workspace") + } + return s.Rollback(ctx, workspace, ops[0].OperationID, writableRoots...) +} + +func (s *Store) Rollback(ctx context.Context, workspace, operationID string, writableRoots ...string) (*Operation, error) { + if s == nil { + return nil, errors.New("rollback store is unavailable") + } + op, err := s.findOperation(workspace, operationID) + if err != nil { + return nil, err + } + if op.Status != StatusCommitted { + return nil, fmt.Errorf("rollback operation %s is %s, not committed", op.OperationID, op.Status) + } + if err := validateOperationPaths(op, workspace, writableRoots...); err != nil { + return nil, err + } + op.RollbackAttempts++ + op.UpdatedAt = time.Now().UTC() + + if err := checkConflicts(op); err != nil { + op.LastError = err.Error() + _ = s.saveOperation(&op) + s.appendAudit(ctx, &op, "rollback_operation_blocked", "conflict", err.Error()) + return nil, err + } + + current, err := captureCurrentStates(op) + if err != nil { + op.LastError = err.Error() + _ = s.saveOperation(&op) + return nil, err + } + if err := s.applyRollback(&op); err != nil { + restoreErr := restoreCurrentStates(current) + op.Status = StatusRollbackFailed + op.LastError = err.Error() + if restoreErr != nil { + op.LastError += "; failed to restore rollback attempt state: " + restoreErr.Error() + } + _ = s.saveOperation(&op) + s.appendAudit(ctx, &op, "rollback_operation_failed", "failed", op.LastError) + return nil, err + } + op.Status = StatusRolledBack + op.UpdatedAt = time.Now().UTC() + op.LastError = "" + if err := s.saveOperation(&op); err != nil { + return nil, err + } + s.appendAudit(ctx, &op, "rollback_operation_executed", "success", "") + return &op, nil +} + +func (s *Store) captureBeforeState(op *Operation, workspace string, index int, target FileTarget) (FileChange, error) { + opType := target.OpType + if opType == "" { + opType = OpTypeUpdate + } + absPath, err := normalizeTargetPath(workspace, target.AbsPath, target.Path) + if err != nil { + return FileChange{}, err + } + path := strings.TrimSpace(target.Path) + if path == "" { + path = displayPath(workspace, absPath) + } + change := FileChange{ + Path: filepath.ToSlash(path), + AbsPath: absPath, + OpType: opType, + } + if opType == OpTypeMove { + newAbs, err := normalizeTargetPath(workspace, target.NewAbsPath, target.NewPath) + if err != nil { + return FileChange{}, err + } + newPath := strings.TrimSpace(target.NewPath) + if newPath == "" { + newPath = displayPath(workspace, newAbs) + } + change.NewPath = filepath.ToSlash(newPath) + change.NewAbsPath = newAbs + } + + state, err := readSnapshotCandidate(absPath) + if err != nil { + return FileChange{}, err + } + change.FileAbsentBefore = !state.exists + if !state.exists { + return change, nil + } + change.BeforeHash = hashBytes(state.data) + change.SizeBytes = int64(len(state.data)) + blobRel := filepath.ToSlash(filepath.Join(op.OperationID, fmt.Sprintf("%03d.blob", index))) + blobAbs := filepath.Join(s.blobsDir, filepath.FromSlash(blobRel)) + if err := os.MkdirAll(filepath.Dir(blobAbs), 0o755); err != nil { + return FileChange{}, err + } + if err := os.WriteFile(blobAbs, state.data, 0o644); err != nil { + return FileChange{}, err + } + change.BeforeSnapshot = blobRel + return change, nil +} + +func (s *Store) restoreBefore(op *Operation, writableRoots ...string) error { + if op == nil { + return nil + } + if err := validateOperationPaths(*op, op.Workspace, writableRoots...); err != nil { + return err + } + current, err := captureCurrentStates(*op) + if err != nil { + return err + } + if err := s.applyBeforeState(op); err != nil { + if restoreErr := restoreCurrentStates(current); restoreErr != nil { + return fmt.Errorf("%w; also failed to restore current state: %v", err, restoreErr) + } + return err + } + return nil +} + +func (s *Store) applyRollback(op *Operation) error { + return s.applyBeforeState(op) +} + +func (s *Store) applyBeforeState(op *Operation) error { + if op == nil { + return nil + } + for _, change := range op.AffectedFiles { + path := resolveChangePath(*op, change) + switch change.OpType { + case OpTypeAdd: + if err := removeIfExists(path); err != nil { + return err + } + case OpTypeUpdate, OpTypeDelete: + if err := s.restoreSnapshot(path, change); err != nil { + return err + } + case OpTypeMove: + if err := removeIfExists(resolveNewChangePath(*op, change)); err != nil { + return err + } + if err := s.restoreSnapshot(path, change); err != nil { + return err + } + default: + return fmt.Errorf("unsupported rollback op type %q", change.OpType) + } + } + return nil +} + +func (s *Store) restoreSnapshot(path string, change FileChange) error { + if change.FileAbsentBefore { + return removeIfExists(path) + } + if strings.TrimSpace(change.BeforeSnapshot) == "" { + return fmt.Errorf("rollback snapshot missing for %s", change.Path) + } + data, err := os.ReadFile(filepath.Join(s.blobsDir, filepath.FromSlash(change.BeforeSnapshot))) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} + +func (s *Store) loadAll() ([]Operation, error) { + entries, err := os.ReadDir(s.entriesDir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, err + } + ops := make([]Operation, 0, len(entries)) + for _, entry := range entries { + if entry.IsDir() || !strings.EqualFold(filepath.Ext(entry.Name()), ".json") { + continue + } + op, err := s.loadOperationFile(filepath.Join(s.entriesDir, entry.Name())) + if err != nil { + return nil, err + } + ops = append(ops, op) + } + return ops, nil +} + +func (s *Store) findOperation(workspace, operationID string) (Operation, error) { + operationID = strings.TrimSpace(operationID) + if operationID == "" { + return Operation{}, errors.New("rollback operation id is required") + } + ops, err := s.loadAll() + if err != nil { + return Operation{}, err + } + matches := make([]Operation, 0, 1) + for _, op := range ops { + if !sameWorkspace(op.Workspace, workspace) { + continue + } + if op.OperationID == operationID || strings.HasPrefix(op.OperationID, operationID) { + matches = append(matches, op) + } + } + if len(matches) == 0 { + return Operation{}, fmt.Errorf("rollback operation %s was not found for this workspace", operationID) + } + if len(matches) > 1 { + return Operation{}, fmt.Errorf("rollback operation id %s is ambiguous", operationID) + } + return matches[0], nil +} + +func (s *Store) loadOperationFile(path string) (Operation, error) { + data, err := os.ReadFile(path) + if err != nil { + return Operation{}, err + } + var op Operation + if err := json.Unmarshal(data, &op); err != nil { + return Operation{}, fmt.Errorf("read rollback operation %s: %w", path, err) + } + return op, nil +} + +func (s *Store) saveOperation(op *Operation) error { + if s == nil || op == nil { + return nil + } + if strings.TrimSpace(op.OperationID) == "" { + return errors.New("rollback operation id is required") + } + op.UpdatedAt = op.UpdatedAt.UTC() + if op.CreatedAt.IsZero() { + op.CreatedAt = time.Now().UTC() + } else { + op.CreatedAt = op.CreatedAt.UTC() + } + path := filepath.Join(s.entriesDir, op.OperationID+".json") + data, err := json.MarshalIndent(op, "", " ") + if err != nil { + return err + } + tmp := path + ".tmp" + if err := os.WriteFile(tmp, append(data, '\n'), 0o644); err != nil { + return err + } + return os.Rename(tmp, path) +} + +func (s *Store) appendAudit(ctx context.Context, op *Operation, action, result, reason string) { + if op == nil { + return + } + audit, err := storagepkg.NewDefaultAuditStore() + if err != nil { + return + } + metadata := map[string]string{ + "operation_id": op.OperationID, + "tool_name": op.ToolName, + "workspace": op.Workspace, + "file_count": strconv.Itoa(len(op.AffectedFiles)), + } + if strings.TrimSpace(reason) != "" { + metadata["reason"] = strings.TrimSpace(reason) + } + _ = audit.Append(ctx, storagepkg.AuditEvent{ + SessionID: corepkg.SessionID(op.SessionID), + TaskID: corepkg.TaskID(op.TaskID), + TraceID: corepkg.TraceID(op.TraceID), + Actor: op.Actor, + Action: action, + Result: result, + Metadata: metadata, + }) +} + +type snapshotState struct { + exists bool + data []byte + mode os.FileMode +} + +func readSnapshotCandidate(path string) (snapshotState, error) { + info, err := os.Stat(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return snapshotState{}, nil + } + return snapshotState{}, err + } + if info.IsDir() { + return snapshotState{}, fmt.Errorf("rollback snapshot target is a directory: %s", path) + } + if info.Size() > maxSnapshotBytes { + return snapshotState{}, fmt.Errorf("rollback snapshot target exceeds %d bytes: %s", maxSnapshotBytes, path) + } + data, err := os.ReadFile(path) + if err != nil { + return snapshotState{}, err + } + if !isText(data) { + return snapshotState{}, fmt.Errorf("rollback snapshot target is not a text file: %s", path) + } + return snapshotState{exists: true, data: data, mode: info.Mode().Perm()}, nil +} + +func captureCurrentStates(op Operation) (map[string]snapshotState, error) { + paths := operationTouchedPaths(op) + states := make(map[string]snapshotState, len(paths)) + for _, path := range paths { + state, err := readCurrentState(path) + if err != nil { + return nil, err + } + states[path] = state + } + return states, nil +} + +func readCurrentState(path string) (snapshotState, error) { + info, err := os.Stat(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return snapshotState{}, nil + } + return snapshotState{}, err + } + if info.IsDir() { + return snapshotState{}, fmt.Errorf("rollback target is a directory: %s", path) + } + data, err := os.ReadFile(path) + if err != nil { + return snapshotState{}, err + } + return snapshotState{exists: true, data: data, mode: info.Mode().Perm()}, nil +} + +func restoreCurrentStates(states map[string]snapshotState) error { + for path, state := range states { + if !state.exists { + if err := removeIfExists(path); err != nil { + return err + } + continue + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + mode := state.mode + if mode == 0 { + mode = 0o644 + } + if err := os.WriteFile(path, state.data, mode); err != nil { + return err + } + } + return nil +} + +func checkConflicts(op Operation) error { + for _, change := range op.AffectedFiles { + switch change.OpType { + case OpTypeAdd, OpTypeUpdate: + if err := requireCurrentHash(change.Path, resolveChangePath(op, change), change.AfterHash); err != nil { + return err + } + case OpTypeDelete: + _, absent, err := hashCurrentFile(resolveChangePath(op, change)) + if err != nil { + return err + } + if !absent { + return fmt.Errorf("rollback blocked: %s changed after delete", change.Path) + } + case OpTypeMove: + _, oldAbsent, err := hashCurrentFile(resolveChangePath(op, change)) + if err != nil { + return err + } + if !oldAbsent { + return fmt.Errorf("rollback blocked: %s changed after move", change.Path) + } + if err := requireCurrentHash(change.NewPath, resolveNewChangePath(op, change), change.AfterHash); err != nil { + return err + } + default: + return fmt.Errorf("unsupported rollback op type %q", change.OpType) + } + } + return nil +} + +func requireCurrentHash(displayPath, absPath, expected string) error { + hash, absent, err := hashCurrentFile(absPath) + if err != nil { + return err + } + if absent { + return fmt.Errorf("rollback blocked: %s is missing", displayPath) + } + if hash != expected { + return fmt.Errorf("rollback blocked: %s changed after operation", displayPath) + } + return nil +} + +func hashCurrentFile(path string) (hash string, absent bool, err error) { + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return "", true, nil + } + return "", false, err + } + return hashBytes(data), false, nil +} + +func operationTouchedPaths(op Operation) []string { + seen := map[string]struct{}{} + paths := make([]string, 0, len(op.AffectedFiles)) + add := func(path string) { + path = filepath.Clean(strings.TrimSpace(path)) + if path == "" { + return + } + if _, ok := seen[path]; ok { + return + } + seen[path] = struct{}{} + paths = append(paths, path) + } + for _, change := range op.AffectedFiles { + add(resolveChangePath(op, change)) + if change.OpType == OpTypeMove { + add(resolveNewChangePath(op, change)) + } + } + return paths +} + +func validateOperationPaths(op Operation, workspace string, writableRoots ...string) error { + workspace = strings.TrimSpace(workspace) + if workspace == "" { + workspace = op.Workspace + } + allowed, err := allowedRoots(workspace, writableRoots...) + if err != nil { + return err + } + for _, path := range operationTouchedPaths(op) { + if !pathWithinAnyRoot(path, allowed) { + return fmt.Errorf("rollback blocked: recorded path escapes workspace and writable roots: %s", path) + } + } + return nil +} + +func allowedRoots(workspace string, writableRoots ...string) ([]string, error) { + workspace, err := filepath.Abs(strings.TrimSpace(workspace)) + if err != nil { + return nil, err + } + roots := []string{filepath.Clean(workspace)} + for _, root := range writableRoots { + root = strings.TrimSpace(root) + if root == "" { + continue + } + abs, err := filepath.Abs(root) + if err != nil { + return nil, err + } + roots = append(roots, filepath.Clean(abs)) + } + return roots, nil +} + +func pathWithinAnyRoot(path string, roots []string) bool { + for _, root := range roots { + if isPathWithinRoot(root, path) { + return true + } + } + return false +} + +func isPathWithinRoot(root, candidate string) bool { + root = filepath.Clean(strings.TrimSpace(root)) + candidate = filepath.Clean(strings.TrimSpace(candidate)) + rel, err := filepath.Rel(root, candidate) + if err != nil { + return false + } + return rel == "." || (rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))) +} + +func normalizeTargetPath(workspace, absPath, display string) (string, error) { + candidate := strings.TrimSpace(absPath) + if candidate == "" { + candidate = filepath.FromSlash(strings.TrimSpace(display)) + } + if candidate == "" { + return "", errors.New("rollback file path is required") + } + if !filepath.IsAbs(candidate) { + candidate = filepath.Join(workspace, candidate) + } + abs, err := filepath.Abs(candidate) + if err != nil { + return "", err + } + return filepath.Clean(abs), nil +} + +func resolveChangePath(op Operation, change FileChange) string { + path, err := normalizeTargetPath(op.Workspace, change.AbsPath, change.Path) + if err != nil { + return filepath.Clean(filepath.Join(op.Workspace, filepath.FromSlash(change.Path))) + } + return path +} + +func resolveNewChangePath(op Operation, change FileChange) string { + path, err := normalizeTargetPath(op.Workspace, change.NewAbsPath, change.NewPath) + if err != nil { + return filepath.Clean(filepath.Join(op.Workspace, filepath.FromSlash(change.NewPath))) + } + return path +} + +func displayPath(workspace, absPath string) string { + rel, err := filepath.Rel(workspace, absPath) + if err == nil && rel != "." && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return filepath.ToSlash(rel) + } + return filepath.ToSlash(absPath) +} + +func sameWorkspace(a, b string) bool { + aa, errA := filepath.Abs(strings.TrimSpace(a)) + bb, errB := filepath.Abs(strings.TrimSpace(b)) + if errA == nil { + a = aa + } + if errB == nil { + b = bb + } + return strings.EqualFold(filepath.Clean(a), filepath.Clean(b)) +} + +func removeIfExists(path string) error { + err := os.Remove(path) + if err == nil || errors.Is(err, os.ErrNotExist) { + return nil + } + return err +} + +func hashBytes(data []byte) string { + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +func isText(data []byte) bool { + for _, b := range data { + if b == 0 { + return false + } + } + return true +} + +func newOperationID(now time.Time) string { + buf := make([]byte, 6) + if _, err := rand.Read(buf); err != nil { + return now.UTC().Format("20060102T150405.000000000Z") + } + return now.UTC().Format("20060102T150405.000000000Z") + "-" + hex.EncodeToString(buf) +} diff --git a/internal/rollback/store_test.go b/internal/rollback/store_test.go new file mode 100644 index 00000000..1712b84d --- /dev/null +++ b/internal/rollback/store_test.go @@ -0,0 +1,190 @@ +package rollback + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestStoreRollsBackCommittedUpdate(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + path := filepath.Join(workspace, "file.txt") + mustWriteRollbackTestFile(t, path, "old\n") + + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + op, err := store.Begin(context.Background(), BeginOptions{ + Workspace: workspace, + ToolName: "write_file", + TraceID: "trace-test", + }, []FileTarget{{ + Path: "file.txt", + AbsPath: path, + OpType: OpTypeUpdate, + }}) + if err != nil { + t.Fatal(err) + } + + mustWriteRollbackTestFile(t, path, "new\n") + if err := store.Commit(context.Background(), op); err != nil { + t.Fatal(err) + } + + rolledBack, err := store.Rollback(context.Background(), workspace, op.OperationID) + if err != nil { + t.Fatal(err) + } + if rolledBack.Status != StatusRolledBack { + t.Fatalf("expected rolled_back status, got %s", rolledBack.Status) + } + if got := mustReadRollbackTestFile(t, path); got != "old\n" { + t.Fatalf("expected old content restored, got %q", got) + } +} + +func TestStoreRollbackBlocksOnConflict(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + path := filepath.Join(workspace, "file.txt") + mustWriteRollbackTestFile(t, path, "old\n") + + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + op, err := store.Begin(context.Background(), BeginOptions{ + Workspace: workspace, + ToolName: "replace_in_file", + TraceID: "trace-test", + }, []FileTarget{{ + Path: "file.txt", + AbsPath: path, + OpType: OpTypeUpdate, + }}) + if err != nil { + t.Fatal(err) + } + mustWriteRollbackTestFile(t, path, "new\n") + if err := store.Commit(context.Background(), op); err != nil { + t.Fatal(err) + } + mustWriteRollbackTestFile(t, path, "user edit\n") + + _, err = store.Rollback(context.Background(), workspace, op.OperationID) + if err == nil || !strings.Contains(err.Error(), "changed after operation") { + t.Fatalf("expected conflict error, got %v", err) + } + if got := mustReadRollbackTestFile(t, path); got != "user edit\n" { + t.Fatalf("expected conflict to leave user edit intact, got %q", got) + } +} + +func TestStoreRollsBackAddDeleteAndMove(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + + addPath := filepath.Join(workspace, "added.txt") + addOp, err := store.Begin(context.Background(), BeginOptions{Workspace: workspace, ToolName: "apply_patch", TraceID: "trace-add"}, []FileTarget{{ + Path: "added.txt", + AbsPath: addPath, + OpType: OpTypeAdd, + }}) + if err != nil { + t.Fatal(err) + } + mustWriteRollbackTestFile(t, addPath, "new\n") + if err := store.Commit(context.Background(), addOp); err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(context.Background(), workspace, addOp.OperationID); err != nil { + t.Fatal(err) + } + if _, err := os.Stat(addPath); !os.IsNotExist(err) { + t.Fatalf("expected added file removed, got %v", err) + } + + deletePath := filepath.Join(workspace, "delete.txt") + mustWriteRollbackTestFile(t, deletePath, "before delete\n") + deleteOp, err := store.Begin(context.Background(), BeginOptions{Workspace: workspace, ToolName: "apply_patch", TraceID: "trace-delete"}, []FileTarget{{ + Path: "delete.txt", + AbsPath: deletePath, + OpType: OpTypeDelete, + }}) + if err != nil { + t.Fatal(err) + } + if err := os.Remove(deletePath); err != nil { + t.Fatal(err) + } + if err := store.Commit(context.Background(), deleteOp); err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(context.Background(), workspace, deleteOp.OperationID); err != nil { + t.Fatal(err) + } + if got := mustReadRollbackTestFile(t, deletePath); got != "before delete\n" { + t.Fatalf("expected deleted file restored, got %q", got) + } + + oldPath := filepath.Join(workspace, "old.txt") + newPath := filepath.Join(workspace, "nested", "new.txt") + mustWriteRollbackTestFile(t, oldPath, "before move\n") + moveOp, err := store.Begin(context.Background(), BeginOptions{Workspace: workspace, ToolName: "apply_patch", TraceID: "trace-move"}, []FileTarget{{ + Path: "old.txt", + AbsPath: oldPath, + NewPath: "nested/new.txt", + NewAbsPath: newPath, + OpType: OpTypeMove, + }}) + if err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil { + t.Fatal(err) + } + mustWriteRollbackTestFile(t, newPath, "after move\n") + if err := os.Remove(oldPath); err != nil { + t.Fatal(err) + } + if err := store.Commit(context.Background(), moveOp); err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(context.Background(), workspace, moveOp.OperationID); err != nil { + t.Fatal(err) + } + if got := mustReadRollbackTestFile(t, oldPath); got != "before move\n" { + t.Fatalf("expected moved file restored, got %q", got) + } + if _, err := os.Stat(newPath); !os.IsNotExist(err) { + t.Fatalf("expected moved target removed, got %v", err) + } +} + +func mustWriteRollbackTestFile(t *testing.T, path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} + +func mustReadRollbackTestFile(t *testing.T, path string) string { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + return string(data) +} diff --git a/internal/tools/apply_patch.go b/internal/tools/apply_patch.go index a0ddee17..0d63eb0a 100644 --- a/internal/tools/apply_patch.go +++ b/internal/tools/apply_patch.go @@ -12,6 +12,7 @@ import ( "strings" "github.com/1024XEngineer/bytemind/internal/llm" + rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback" ) type ApplyPatchTool struct{} @@ -34,6 +35,13 @@ type patchHunk struct { Lines []patchLine } +type plannedPatchOperation struct { + Type string + OldPath string + NewPath string + Content string +} + var unifiedHunkHeaderPattern = regexp.MustCompile(`^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(?: .*)?$`) func (ApplyPatchTool) Definition() llm.ToolDefinition { @@ -56,7 +64,7 @@ func (ApplyPatchTool) Definition() llm.ToolDefinition { } } -func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) { +func (ApplyPatchTool) Run(ctx context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) { var args struct { Patch string `json:"patch"` } @@ -78,10 +86,31 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu i := 1 operations := make([]map[string]any, 0, 4) diffFiles := make([]DiffFile, 0, 4) + planned := make([]plannedPatchOperation, 0, 4) + targets := make([]rollbackpkg.FileTarget, 0, 4) + touchedPaths := map[string]string{} for i < len(lines) { line := lines[i] if line == "*** End Patch" { - return toJSON(buildPatchResult(operations, diffFiles)) + tracker, err := beginRollbackOperation(ctx, execCtx, "apply_patch", targets) + if err != nil { + return "", err + } + if err := applyPlannedPatchOperations(planned); err != nil { + if tracker != nil { + tracker.abort(ctx, err.Error()) + } + return "", err + } + operationID, err := tracker.commit(ctx) + if err != nil { + return "", err + } + result := buildPatchResult(operations, diffFiles) + if operationID != "" { + result["rollback_operation_id"] = operationID + } + return toJSON(result) } switch { @@ -100,14 +129,26 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu if err != nil { return "", err } - if err := os.MkdirAll(filepath.Dir(resolved), 0o755); err != nil { + if _, err := os.Stat(resolved); err == nil { + return "", fmt.Errorf("add file target already exists: %s", path) + } else if err != nil && !errors.Is(err, os.ErrNotExist) { return "", err } - content := joinLines(contentLines, true, patchLineEnding) - if err := os.WriteFile(resolved, []byte(content), 0o644); err != nil { + if err := registerPatchTarget(touchedPaths, resolved, path); err != nil { return "", err } + content := joinLines(contentLines, true, patchLineEnding) relPath := filepath.ToSlash(mustRel(execCtx.Workspace, resolved)) + planned = append(planned, plannedPatchOperation{ + Type: "add", + NewPath: resolved, + Content: content, + }) + targets = append(targets, rollbackpkg.FileTarget{ + Path: relPath, + AbsPath: resolved, + OpType: rollbackpkg.OpTypeAdd, + }) operations = append(operations, map[string]any{"type": "add", "path": relPath}) diffFiles = append(diffFiles, DiffFile{ Path: relPath, @@ -122,10 +163,19 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu return "", err } removed := lineCount(resolved) - if err := os.Remove(resolved); err != nil { + relPath := filepath.ToSlash(mustRel(execCtx.Workspace, resolved)) + if err := registerPatchTarget(touchedPaths, resolved, path); err != nil { return "", err } - relPath := filepath.ToSlash(mustRel(execCtx.Workspace, resolved)) + planned = append(planned, plannedPatchOperation{ + Type: "delete", + OldPath: resolved, + }) + targets = append(targets, rollbackpkg.FileTarget{ + Path: relPath, + AbsPath: resolved, + OpType: rollbackpkg.OpTypeDelete, + }) operations = append(operations, map[string]any{"type": "delete", "path": relPath}) diffFiles = append(diffFiles, DiffFile{ Path: relPath, @@ -148,6 +198,21 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu } i++ } + if newPath != oldPath { + if _, err := os.Stat(newPath); err == nil { + return "", fmt.Errorf("move target already exists: %s", filepath.ToSlash(mustRel(execCtx.Workspace, newPath))) + } else if err != nil && !errors.Is(err, os.ErrNotExist) { + return "", err + } + } + if err := registerPatchTarget(touchedPaths, oldPath, path); err != nil { + return "", err + } + if newPath != oldPath { + if err := registerPatchTarget(touchedPaths, newPath, filepath.ToSlash(mustRel(execCtx.Workspace, newPath))); err != nil { + return "", err + } + } chunkLines := make([]string, 0, 64) for i < len(lines) && !strings.HasPrefix(lines[i], "*** ") { chunkLines = append(chunkLines, lines[i]) @@ -161,19 +226,28 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu if err != nil { return "", err } - if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil { - return "", err + relOld := filepath.ToSlash(mustRel(execCtx.Workspace, oldPath)) + relNew := filepath.ToSlash(mustRel(execCtx.Workspace, newPath)) + opType := rollbackpkg.OpTypeUpdate + if newPath != oldPath { + opType = rollbackpkg.OpTypeMove } - if err := os.WriteFile(newPath, []byte(updated), 0o644); err != nil { - return "", err + planned = append(planned, plannedPatchOperation{ + Type: "update", + OldPath: oldPath, + NewPath: newPath, + Content: updated, + }) + target := rollbackpkg.FileTarget{ + Path: relOld, + AbsPath: oldPath, + OpType: opType, } if newPath != oldPath { - if err := os.Remove(oldPath); err != nil { - return "", err - } + target.NewPath = relNew + target.NewAbsPath = newPath } - relOld := filepath.ToSlash(mustRel(execCtx.Workspace, oldPath)) - relNew := filepath.ToSlash(mustRel(execCtx.Workspace, newPath)) + targets = append(targets, target) operations = append(operations, map[string]any{ "type": "update", "path": relOld, @@ -201,6 +275,48 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu return "", errors.New("patch missing *** End Patch") } +func registerPatchTarget(seen map[string]string, absPath, label string) error { + key := filepath.Clean(absPath) + if previous, ok := seen[key]; ok { + return fmt.Errorf("patch touches %s more than once (already used by %s)", label, previous) + } + seen[key] = label + return nil +} + +func applyPlannedPatchOperations(operations []plannedPatchOperation) error { + for _, op := range operations { + switch op.Type { + case "add": + if err := os.MkdirAll(filepath.Dir(op.NewPath), 0o755); err != nil { + return err + } + if err := os.WriteFile(op.NewPath, []byte(op.Content), 0o644); err != nil { + return err + } + case "delete": + if err := os.Remove(op.OldPath); err != nil { + return err + } + case "update": + if err := os.MkdirAll(filepath.Dir(op.NewPath), 0o755); err != nil { + return err + } + if err := os.WriteFile(op.NewPath, []byte(op.Content), 0o644); err != nil { + return err + } + if op.NewPath != op.OldPath { + if err := os.Remove(op.OldPath); err != nil { + return err + } + } + default: + return fmt.Errorf("unsupported planned patch operation: %s", op.Type) + } + } + return nil +} + func applyStructuredPatch(original string, chunkLines []string) (string, error) { hunks, err := parsePatchHunks(chunkLines) if err != nil { diff --git a/internal/tools/replace_in_file.go b/internal/tools/replace_in_file.go index acb49197..5a962338 100644 --- a/internal/tools/replace_in_file.go +++ b/internal/tools/replace_in_file.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/1024XEngineer/bytemind/internal/llm" + rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback" ) type ReplaceInFileTool struct{} @@ -45,7 +46,7 @@ func (ReplaceInFileTool) Definition() llm.ToolDefinition { } } -func (ReplaceInFileTool) Run(_ context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) { +func (ReplaceInFileTool) Run(ctx context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) { var args struct { Path string `json:"path"` Old string `json:"old"` @@ -78,12 +79,27 @@ func (ReplaceInFileTool) Run(_ context.Context, raw json.RawMessage, execCtx *Ex } else { updated = strings.Replace(content, args.Old, args.New, 1) } + relPath := filepath.ToSlash(mustRel(execCtx.Workspace, path)) + tracker, err := beginRollbackOperation(ctx, execCtx, "replace_in_file", []rollbackpkg.FileTarget{{ + Path: relPath, + AbsPath: path, + OpType: rollbackpkg.OpTypeUpdate, + }}) + if err != nil { + return "", err + } if err := os.WriteFile(path, []byte(updated), 0o644); err != nil { + if tracker != nil { + tracker.abort(ctx, err.Error()) + } + return "", err + } + operationID, err := tracker.commit(ctx) + if err != nil { return "", err } var result map[string]any - relPath := filepath.ToSlash(mustRel(execCtx.Workspace, path)) if diffPreview := buildReplaceDiff(content, args.Old, args.New, args.ReplaceAll, relPath); diffPreview != nil { result = map[string]any{ "ok": true, @@ -100,6 +116,9 @@ func (ReplaceInFileTool) Run(_ context.Context, raw json.RawMessage, execCtx *Ex "old_count": count, } } + if operationID != "" { + result["rollback_operation_id"] = operationID + } return toJSON(result) } diff --git a/internal/tools/rollback_helpers.go b/internal/tools/rollback_helpers.go new file mode 100644 index 00000000..f24a7c4b --- /dev/null +++ b/internal/tools/rollback_helpers.go @@ -0,0 +1,77 @@ +package tools + +import ( + "context" + "strings" + + rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback" +) + +type rollbackTracker struct { + store *rollbackpkg.Store + op *rollbackpkg.Operation + roots []string +} + +func beginRollbackOperation(ctx context.Context, execCtx *ExecutionContext, toolName string, targets []rollbackpkg.FileTarget) (*rollbackTracker, error) { + if !rollbackEnabled(execCtx) { + return nil, nil + } + if len(targets) == 0 { + return nil, nil + } + store, err := rollbackpkg.NewDefaultStore() + if err != nil { + return nil, err + } + op, err := store.Begin(ctx, rollbackpkg.BeginOptions{ + Workspace: strings.TrimSpace(execCtx.Workspace), + SessionID: rollbackSessionID(execCtx), + TraceID: strings.TrimSpace(execCtx.RunID), + ToolName: toolName, + Actor: "agent", + }, targets) + if err != nil { + return nil, err + } + return &rollbackTracker{ + store: store, + op: op, + roots: writableRootsFromExecContext(execCtx), + }, nil +} + +func rollbackEnabled(execCtx *ExecutionContext) bool { + if execCtx == nil { + return false + } + return strings.TrimSpace(execCtx.RunID) != "" +} + +func rollbackSessionID(execCtx *ExecutionContext) string { + if execCtx == nil || execCtx.Session == nil { + return "" + } + return strings.TrimSpace(execCtx.Session.ID) +} + +func (t *rollbackTracker) commit(ctx context.Context) (string, error) { + if t == nil || t.store == nil || t.op == nil { + return "", nil + } + if err := t.store.Commit(ctx, t.op); err != nil { + restoreErr := t.store.AbortAndRestore(ctx, t.op, err.Error(), t.roots...) + if restoreErr != nil { + return "", restoreErr + } + return "", err + } + return t.op.OperationID, nil +} + +func (t *rollbackTracker) abort(ctx context.Context, reason string) { + if t == nil || t.store == nil || t.op == nil { + return + } + _ = t.store.AbortAndRestore(ctx, t.op, reason, t.roots...) +} diff --git a/internal/tools/rollback_integration_test.go b/internal/tools/rollback_integration_test.go new file mode 100644 index 00000000..f4906124 --- /dev/null +++ b/internal/tools/rollback_integration_test.go @@ -0,0 +1,188 @@ +package tools + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback" +) + +func TestWriteFileToolRecordsRollbackOperation(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + tool := WriteFileTool{} + payload, _ := json.Marshal(map[string]any{ + "path": "new.txt", + "content": "created\n", + }) + + result, err := tool.Run(context.Background(), payload, &ExecutionContext{ + Workspace: workspace, + RunID: "trace-write", + }) + if err != nil { + t.Fatal(err) + } + operationID := rollbackOperationIDFromToolResult(t, result) + if operationID == "" { + t.Fatal("expected rollback operation id") + } + + store, err := rollbackpkg.NewDefaultStore() + if err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(context.Background(), workspace, operationID); err != nil { + t.Fatal(err) + } + if _, err := os.Stat(filepath.Join(workspace, "new.txt")); !os.IsNotExist(err) { + t.Fatalf("expected rollback to remove created file, got %v", err) + } +} + +func TestReplaceInFileToolRecordsRollbackOperation(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + path := filepath.Join(workspace, "sample.txt") + mustWriteFile(t, path, "alpha beta\n") + tool := ReplaceInFileTool{} + payload, _ := json.Marshal(map[string]any{ + "path": "sample.txt", + "old": "beta", + "new": "gamma", + }) + + result, err := tool.Run(context.Background(), payload, &ExecutionContext{ + Workspace: workspace, + RunID: "trace-replace", + }) + if err != nil { + t.Fatal(err) + } + operationID := rollbackOperationIDFromToolResult(t, result) + + store, err := rollbackpkg.NewDefaultStore() + if err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(context.Background(), workspace, operationID); err != nil { + t.Fatal(err) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if string(data) != "alpha beta\n" { + t.Fatalf("expected replacement rollback, got %q", string(data)) + } +} + +func TestApplyPatchToolRollbackRestoresMultipleFiles(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + mustWriteFile(t, filepath.Join(workspace, "a.txt"), "alpha\nbeta\n") + tool := ApplyPatchTool{} + payload, _ := json.Marshal(map[string]any{ + "patch": strings.Join([]string{ + "*** Begin Patch", + "*** Update File: a.txt", + "@@", + " alpha", + "-beta", + "+gamma", + "*** Add File: b.txt", + "+created", + "*** End Patch", + }, "\n"), + }) + + result, err := tool.Run(context.Background(), payload, &ExecutionContext{ + Workspace: workspace, + RunID: "trace-patch", + }) + if err != nil { + t.Fatal(err) + } + operationID := rollbackOperationIDFromToolResult(t, result) + + store, err := rollbackpkg.NewDefaultStore() + if err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(context.Background(), workspace, operationID); err != nil { + t.Fatal(err) + } + if got := mustReadToolsTestFile(t, filepath.Join(workspace, "a.txt")); got != "alpha\nbeta\n" { + t.Fatalf("expected a.txt restored, got %q", got) + } + if _, err := os.Stat(filepath.Join(workspace, "b.txt")); !os.IsNotExist(err) { + t.Fatalf("expected b.txt removed, got %v", err) + } +} + +func TestApplyPatchToolFailureAutomaticallyRestoresChangedFiles(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + path := filepath.Join(workspace, "a.txt") + mustWriteFile(t, path, "alpha\nbeta\n") + tool := ApplyPatchTool{} + payload, _ := json.Marshal(map[string]any{ + "patch": strings.Join([]string{ + "*** Begin Patch", + "*** Update File: a.txt", + "@@", + " alpha", + "-beta", + "+gamma", + "*** Delete File: missing.txt", + "*** End Patch", + }, "\n"), + }) + + _, err := tool.Run(context.Background(), payload, &ExecutionContext{ + Workspace: workspace, + RunID: "trace-patch-failure", + }) + if err == nil { + t.Fatal("expected patch failure") + } + if got := mustReadToolsTestFile(t, path); got != "alpha\nbeta\n" { + t.Fatalf("expected failed patch to restore original content, got %q", got) + } + + store, err := rollbackpkg.NewDefaultStore() + if err != nil { + t.Fatal(err) + } + ops, err := store.ListRecent(context.Background(), workspace, 10) + if err != nil { + t.Fatal(err) + } + if len(ops) != 0 { + t.Fatalf("expected failed patch not to leave committed rollback operations, got %#v", ops) + } +} + +func rollbackOperationIDFromToolResult(t *testing.T, result string) string { + t.Helper() + var parsed struct { + OperationID string `json:"rollback_operation_id"` + } + if err := json.Unmarshal([]byte(result), &parsed); err != nil { + t.Fatal(err) + } + return strings.TrimSpace(parsed.OperationID) +} + +func mustReadToolsTestFile(t *testing.T, path string) string { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + return string(data) +} diff --git a/internal/tools/write_file.go b/internal/tools/write_file.go index 4f19fffe..f17915d5 100644 --- a/internal/tools/write_file.go +++ b/internal/tools/write_file.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/1024XEngineer/bytemind/internal/llm" + rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback" ) type WriteFileTool struct{} @@ -40,7 +41,7 @@ func (WriteFileTool) Definition() llm.ToolDefinition { } } -func (WriteFileTool) Run(_ context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) { +func (WriteFileTool) Run(ctx context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) { var args struct { Path string `json:"path"` Content string `json:"content"` @@ -73,7 +74,27 @@ func (WriteFileTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execut original = string(data) } + opType := rollbackpkg.OpTypeAdd + if exists { + opType = rollbackpkg.OpTypeUpdate + } + tracker, err := beginRollbackOperation(ctx, execCtx, "write_file", []rollbackpkg.FileTarget{{ + Path: relPath, + AbsPath: path, + OpType: opType, + }}) + if err != nil { + return "", err + } + if err := os.WriteFile(path, []byte(args.Content), 0o644); err != nil { + if tracker != nil { + tracker.abort(ctx, err.Error()) + } + return "", err + } + operationID, err := tracker.commit(ctx) + if err != nil { return "", err } @@ -82,6 +103,9 @@ func (WriteFileTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execut "path": relPath, "bytes_written": len(args.Content), } + if operationID != "" { + result["rollback_operation_id"] = operationID + } if dp := buildWriteFileDiff(original, args.Content, exists, relPath); dp != nil { result["diff_preview"] = dp diff --git a/tui/component_command_utils.go b/tui/component_command_utils.go index 8af35da6..df8ce6e6 100644 --- a/tui/component_command_utils.go +++ b/tui/component_command_utils.go @@ -30,6 +30,7 @@ func (m model) helpText() string { "- `/compact`: summarize long history into a compact continuation context.", "- `/commit `: stage all changes and create a local Git commit.", "- `/undo-commit`: undo the last local commit created by `/commit` in this session.", + "- `/rollback [last|]`: list or undo ByteMind file edits recorded by tool snapshots.", "- `/btw `: interject while a run is in progress.", "- `/quit`: exit the TUI.", "- TUI does not expose `/resume`; use `/session` then `Enter` on the selected row.", diff --git a/tui/component_rollback_command.go b/tui/component_rollback_command.go new file mode 100644 index 00000000..a0a57777 --- /dev/null +++ b/tui/component_rollback_command.go @@ -0,0 +1,122 @@ +package tui + +import ( + "context" + "errors" + "fmt" + "strings" + + rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback" +) + +const rollbackUsage = "Usage: /rollback [last|]\nList or undo ByteMind file edits recorded by write_file, replace_in_file, or apply_patch." + +func (m *model) runRollbackCommand(input string) error { + response, status, err := executeRollbackCommand(context.Background(), m.workspace, m.cfg.WritableRoots, input) + if err != nil { + return m.finishRollbackCommand(input, err.Error(), "Rollback failed.") + } + return m.finishRollbackCommand(input, response, status) +} + +func (m *model) finishRollbackCommand(input, response, status string) error { + m.appendCommandExchange(input, response) + m.statusNote = status + if err := m.recordCommandExchange(input, response); err != nil { + m.statusNote = "Command shown, but session save failed: " + err.Error() + return nil + } + return nil +} + +func executeRollbackCommand(ctx context.Context, workspace string, writableRoots []string, input string) (response string, status string, err error) { + fields := strings.Fields(strings.TrimSpace(input)) + if len(fields) == 0 || fields[0] != "/rollback" { + return "", "", errors.New(rollbackUsage) + } + if len(fields) > 2 { + return "", "", errors.New(rollbackUsage) + } + + store, err := rollbackpkg.NewDefaultStore() + if err != nil { + return "", "", fmt.Errorf("Rollback unavailable: %w", err) + } + + if len(fields) == 1 { + ops, err := store.ListRecent(ctx, workspace, 10) + if err != nil { + return "", "", err + } + return formatRollbackList(ops), "Rollback operations listed.", nil + } + + target := strings.TrimSpace(fields[1]) + var op *rollbackpkg.Operation + if strings.EqualFold(target, "last") { + op, err = store.RollbackLast(ctx, workspace, writableRoots...) + } else { + op, err = store.Rollback(ctx, workspace, target, writableRoots...) + } + if err != nil { + return "", "", err + } + return formatRollbackSuccess(*op), "Rollback completed.", nil +} + +func formatRollbackList(ops []rollbackpkg.Operation) string { + if len(ops) == 0 { + return "No ByteMind rollback operations recorded for this workspace.\n\n`/rollback` is for ByteMind file edits. `/undo-commit` is only for local git commits created by `/commit`." + } + lines := []string{ + "Recent ByteMind rollback operations:", + "", + } + for _, op := range ops { + lines = append(lines, fmt.Sprintf( + "- `%s` %s %s %d file(s) %s", + shortRollbackID(op.OperationID), + op.CreatedAt.Local().Format("2006-01-02 15:04:05"), + op.ToolName, + len(op.AffectedFiles), + rollbackPathSummary(op), + )) + } + lines = append(lines, "", "Use `/rollback last` or `/rollback ` to restore one operation.") + return strings.Join(lines, "\n") +} + +func formatRollbackSuccess(op rollbackpkg.Operation) string { + return fmt.Sprintf( + "Rollback completed.\n\nOperation: `%s`\nTool: %s\nFiles restored: %d\n\nThis restored ByteMind file snapshots and did not modify git history.", + op.OperationID, + op.ToolName, + len(op.AffectedFiles), + ) +} + +func shortRollbackID(id string) string { + id = strings.TrimSpace(id) + if len(id) <= 22 { + return id + } + return id[:22] +} + +func rollbackPathSummary(op rollbackpkg.Operation) string { + paths := make([]string, 0, min(len(op.AffectedFiles), 3)) + for i, file := range op.AffectedFiles { + if i >= 3 { + break + } + path := file.Path + if file.OpType == rollbackpkg.OpTypeMove && strings.TrimSpace(file.NewPath) != "" { + path += " -> " + file.NewPath + } + paths = append(paths, path) + } + if len(op.AffectedFiles) > 3 { + paths = append(paths, fmt.Sprintf("+%d more", len(op.AffectedFiles)-3)) + } + return strings.Join(paths, ", ") +} diff --git a/tui/component_rollback_command_test.go b/tui/component_rollback_command_test.go new file mode 100644 index 00000000..571e4465 --- /dev/null +++ b/tui/component_rollback_command_test.go @@ -0,0 +1,79 @@ +package tui + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback" +) + +func TestExecuteRollbackCommandListsAndRollsBackLast(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + path := filepath.Join(workspace, "file.txt") + writeRollbackCommandTestFile(t, path, "old\n") + + store, err := rollbackpkg.NewDefaultStore() + if err != nil { + t.Fatal(err) + } + op, err := store.Begin(context.Background(), rollbackpkg.BeginOptions{ + Workspace: workspace, + ToolName: "write_file", + TraceID: "trace-rollback-command", + }, []rollbackpkg.FileTarget{{ + Path: "file.txt", + AbsPath: path, + OpType: rollbackpkg.OpTypeUpdate, + }}) + if err != nil { + t.Fatal(err) + } + writeRollbackCommandTestFile(t, path, "new\n") + if err := store.Commit(context.Background(), op); err != nil { + t.Fatal(err) + } + + response, status, err := executeRollbackCommand(context.Background(), workspace, nil, "/rollback") + if err != nil { + t.Fatal(err) + } + if !strings.Contains(response, shortRollbackID(op.OperationID)) || status != "Rollback operations listed." { + t.Fatalf("expected rollback list response, got %q / %q", response, status) + } + + response, status, err = executeRollbackCommand(context.Background(), workspace, nil, "/rollback last") + if err != nil { + t.Fatal(err) + } + if !strings.Contains(response, "Rollback completed.") || status != "Rollback completed." { + t.Fatalf("expected rollback success response, got %q / %q", response, status) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if string(data) != "old\n" { + t.Fatalf("expected file restored, got %q", string(data)) + } +} + +func TestExecuteRollbackCommandRejectsInvalidUsage(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + if _, _, err := executeRollbackCommand(context.Background(), t.TempDir(), nil, "/rollback a b"); err == nil || err.Error() != rollbackUsage { + t.Fatalf("expected rollback usage error, got %v", err) + } +} + +func writeRollbackCommandTestFile(t *testing.T, path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} diff --git a/tui/component_slash_entry.go b/tui/component_slash_entry.go index 33beff1c..b462555a 100644 --- a/tui/component_slash_entry.go +++ b/tui/component_slash_entry.go @@ -57,6 +57,8 @@ func (m *model) handleSlashCommand(input string) error { return m.runCommitCommand(input) case "/undo-commit": return m.runUndoCommitCommand(input) + case "/rollback": + return m.runRollbackCommand(input) default: return fmt.Errorf("unknown command: %s", fields[0]) } diff --git a/tui/model.go b/tui/model.go index 4bc94c11..3be86cd9 100644 --- a/tui/model.go +++ b/tui/model.go @@ -357,6 +357,7 @@ var commandItems = []commandItem{ {Name: "/compact", Usage: "/compact", Description: "Compress long session history into a continuation summary.", Kind: "command"}, {Name: "/commit", Usage: "/commit ", Description: "Stage all changes and create a local Git commit.", Kind: "command"}, {Name: "/undo-commit", Usage: "/undo-commit", Description: "Undo the last local commit created by /commit in this session.", Kind: "command"}, + {Name: "/rollback", Usage: "/rollback [last|]", Description: "List or undo ByteMind file edits recorded by tool snapshots.", Kind: "command"}, {Name: "/btw", Usage: "/btw ", Description: "Interject while a run is in progress.", Kind: "command"}, {Name: "/quit", Usage: "/quit", Description: "Exit the current TUI window.", Kind: "command"}, {Name: "/skills", Usage: "/skills", Description: "List available skills and current active skill.", Kind: "command"}, @@ -3066,7 +3067,7 @@ func shouldExecuteFromPalette(item commandItem) bool { return true } switch item.Name { - case "/help", "/session", "/agents", "/skills", "/skill clear", "/mcp list", "/mcp help", "/model", "/new", "/compact", "/undo-commit", "/quit": + case "/help", "/session", "/agents", "/skills", "/skill clear", "/mcp list", "/mcp help", "/model", "/new", "/compact", "/undo-commit", "/rollback", "/quit": return true default: return false diff --git a/tui/model_test.go b/tui/model_test.go index 14682699..9d6d7274 100644 --- a/tui/model_test.go +++ b/tui/model_test.go @@ -3982,7 +3982,7 @@ func TestFilteredCommandsShowsRootSelectorGroups(t *testing.T) { usages = append(usages, item.Usage) } - for _, want := range []string{"/help", "/session", "/skills-select", "/model", "/new", "/compact", "/commit ", "/undo-commit", "/quit"} { + for _, want := range []string{"/help", "/session", "/skills-select", "/model", "/new", "/compact", "/commit ", "/undo-commit", "/rollback [last|]", "/quit"} { if !containsString(usages, want) { t.Fatalf("expected root selector to contain %q, got %v", want, usages) } From f1b0bbbd92846a6daf761eec3ce6a4dd9c17304d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E7=BA=AA?= <3049035704@qq.com> Date: Sun, 10 May 2026 01:51:29 +0800 Subject: [PATCH 2/3] increase rollback test coverage --- internal/rollback/store_test.go | 296 ++++++++++++++++++++++++++++++++ 1 file changed, 296 insertions(+) diff --git a/internal/rollback/store_test.go b/internal/rollback/store_test.go index 1712b84d..896e248f 100644 --- a/internal/rollback/store_test.go +++ b/internal/rollback/store_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "strings" "testing" + "time" ) func TestStoreRollsBackCommittedUpdate(t *testing.T) { @@ -170,6 +171,301 @@ func TestStoreRollsBackAddDeleteAndMove(t *testing.T) { } } +func TestDefaultStoreListsRecentAndRollsBackLast(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + store, err := NewDefaultStore() + if err != nil { + t.Fatal(err) + } + + firstPath := filepath.Join(workspace, "first.txt") + secondPath := filepath.Join(workspace, "second.txt") + ignoredPath := filepath.Join(workspace, "ignored.txt") + mustWriteRollbackTestFile(t, firstPath, "first old\n") + mustWriteRollbackTestFile(t, secondPath, "second old\n") + mustWriteRollbackTestFile(t, ignoredPath, "ignored old\n") + + first := beginUpdateRollbackTestOperation(t, store, workspace, "first.txt", firstPath) + mustWriteRollbackTestFile(t, firstPath, "first new\n") + if err := store.Commit(context.Background(), first); err != nil { + t.Fatal(err) + } + second := beginUpdateRollbackTestOperation(t, store, workspace, "second.txt", secondPath) + mustWriteRollbackTestFile(t, secondPath, "second new\n") + if err := store.Commit(context.Background(), second); err != nil { + t.Fatal(err) + } + ignored := beginUpdateRollbackTestOperation(t, store, workspace, "ignored.txt", ignoredPath) + if err := store.Abort(context.Background(), ignored, "not committed"); err != nil { + t.Fatal(err) + } + + ops, err := store.ListRecent(context.Background(), workspace, 1) + if err != nil { + t.Fatal(err) + } + if len(ops) != 1 || ops[0].OperationID != second.OperationID { + t.Fatalf("expected only newest committed operation, got %#v", ops) + } + + rolledBack, err := store.RollbackLast(context.Background(), workspace) + if err != nil { + t.Fatal(err) + } + if rolledBack.OperationID != second.OperationID { + t.Fatalf("expected latest operation %s, got %s", second.OperationID, rolledBack.OperationID) + } + if got := mustReadRollbackTestFile(t, secondPath); got != "second old\n" { + t.Fatalf("expected second file restored, got %q", got) + } + if got := mustReadRollbackTestFile(t, firstPath); got != "first new\n" { + t.Fatalf("expected first file untouched, got %q", got) + } +} + +func TestAbortAndRestoreRestoresPendingOperation(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + path := filepath.Join(workspace, "file.txt") + mustWriteRollbackTestFile(t, path, "before\n") + + op := beginUpdateRollbackTestOperation(t, store, workspace, "file.txt", path) + mustWriteRollbackTestFile(t, path, "partial write\n") + if err := store.AbortAndRestore(context.Background(), op, "write failed"); err != nil { + t.Fatal(err) + } + if op.Status != StatusAborted { + t.Fatalf("expected aborted status, got %s", op.Status) + } + if got := mustReadRollbackTestFile(t, path); got != "before\n" { + t.Fatalf("expected pending change restored, got %q", got) + } +} + +func TestAbortAndRestoreMarksFailureWhenPathEscapesRoots(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + external := filepath.Join(t.TempDir(), "external.txt") + mustWriteRollbackTestFile(t, external, "before\n") + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + op, err := store.Begin(context.Background(), BeginOptions{ + Workspace: workspace, + ToolName: "write_file", + TraceID: "trace-external", + }, []FileTarget{{ + Path: filepath.ToSlash(external), + AbsPath: external, + OpType: OpTypeUpdate, + }}) + if err != nil { + t.Fatal(err) + } + mustWriteRollbackTestFile(t, external, "partial\n") + + err = store.AbortAndRestore(context.Background(), op, "write failed") + if err == nil || !strings.Contains(err.Error(), "escapes workspace") { + t.Fatalf("expected path validation failure, got %v", err) + } + if op.Status != StatusRollbackFailed { + t.Fatalf("expected rollback_failed status, got %s", op.Status) + } + if got := mustReadRollbackTestFile(t, external); got != "partial\n" { + t.Fatalf("expected failed restore to leave external file, got %q", got) + } +} + +func TestListRecentHandlesCanceledContextAndNoOperations(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + if ops, err := store.ListRecent(context.Background(), workspace, 0); err != nil || len(ops) != 0 { + t.Fatalf("expected empty list, got %#v / %v", ops, err) + } + if _, err := store.RollbackLast(context.Background(), workspace); err == nil || !strings.Contains(err.Error(), "no committed") { + t.Fatalf("expected no committed operation error, got %v", err) + } + + path := filepath.Join(workspace, "file.txt") + mustWriteRollbackTestFile(t, path, "before\n") + op := beginUpdateRollbackTestOperation(t, store, workspace, "file.txt", path) + mustWriteRollbackTestFile(t, path, "after\n") + if err := store.Commit(context.Background(), op); err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := store.ListRecent(ctx, workspace, 10); err == nil { + t.Fatal("expected canceled context error") + } +} + +func TestStoreRejectsInvalidInputsAndSnapshotTargets(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + if _, err := NewStore(" "); err == nil { + t.Fatal("expected empty rollback root error") + } + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + if _, err := store.Begin(context.Background(), BeginOptions{Workspace: t.TempDir()}, nil); err == nil { + t.Fatal("expected empty target error") + } + if _, err := (*Store)(nil).Begin(context.Background(), BeginOptions{}, []FileTarget{{Path: "x"}}); err == nil { + t.Fatal("expected nil store error") + } + if err := store.saveOperation(&Operation{}); err == nil { + t.Fatal("expected missing operation id error") + } + + workspace := t.TempDir() + dirTarget := filepath.Join(workspace, "dir") + if err := os.MkdirAll(dirTarget, 0o755); err != nil { + t.Fatal(err) + } + if _, err := store.Begin(context.Background(), BeginOptions{Workspace: workspace}, []FileTarget{{Path: "dir", AbsPath: dirTarget}}); err == nil || !strings.Contains(err.Error(), "directory") { + t.Fatalf("expected directory snapshot error, got %v", err) + } + binaryTarget := filepath.Join(workspace, "binary.bin") + if err := os.WriteFile(binaryTarget, []byte{'a', 0, 'b'}, 0o644); err != nil { + t.Fatal(err) + } + if _, err := store.Begin(context.Background(), BeginOptions{Workspace: workspace}, []FileTarget{{Path: "binary.bin", AbsPath: binaryTarget}}); err == nil || !strings.Contains(err.Error(), "not a text") { + t.Fatalf("expected binary snapshot error, got %v", err) + } +} + +func TestFindOperationAndLoadErrors(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + if _, err := store.findOperation(workspace, " "); err == nil { + t.Fatal("expected empty operation id error") + } + if _, err := store.findOperation(workspace, "missing"); err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("expected not found error, got %v", err) + } + + first := beginUpdateRollbackTestOperation(t, store, workspace, "a.txt", filepath.Join(workspace, "a.txt")) + second := beginUpdateRollbackTestOperation(t, store, workspace, "b.txt", filepath.Join(workspace, "b.txt")) + first.OperationID = "same-prefix-one" + second.OperationID = "same-prefix-two" + if err := store.saveOperation(first); err != nil { + t.Fatal(err) + } + if err := store.saveOperation(second); err != nil { + t.Fatal(err) + } + if _, err := store.findOperation(workspace, "same-prefix"); err == nil || !strings.Contains(err.Error(), "ambiguous") { + t.Fatalf("expected ambiguous prefix error, got %v", err) + } + + if err := os.WriteFile(filepath.Join(store.entriesDir, "broken.json"), []byte("{"), 0o644); err != nil { + t.Fatal(err) + } + if _, err := store.loadAll(); err == nil || !strings.Contains(err.Error(), "read rollback operation") { + t.Fatalf("expected invalid json load error, got %v", err) + } +} + +func TestRollbackErrorsAndHelperBranches(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + + unsupported := &Operation{ + OperationID: "unsupported", + Workspace: workspace, + Status: StatusPending, + CreatedAt: timeNowForRollbackTest(), + UpdatedAt: timeNowForRollbackTest(), + AffectedFiles: []FileChange{{Path: "x.txt", OpType: OpType("bad")}}, + } + if err := store.Commit(context.Background(), unsupported); err == nil || !strings.Contains(err.Error(), "unsupported") { + t.Fatalf("expected unsupported commit error, got %v", err) + } + + committed := beginUpdateRollbackTestOperation(t, store, workspace, "file.txt", filepath.Join(workspace, "file.txt")) + committed.Status = StatusRolledBack + if err := store.saveOperation(committed); err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(context.Background(), workspace, committed.OperationID); err == nil || !strings.Contains(err.Error(), "not committed") { + t.Fatalf("expected non-committed rollback error, got %v", err) + } + + outside := filepath.Join(t.TempDir(), "outside.txt") + mustWriteRollbackTestFile(t, outside, "outside\n") + external := beginUpdateRollbackTestOperation(t, store, workspace, filepath.ToSlash(outside), outside) + external.Status = StatusCommitted + external.AffectedFiles[0].AfterHash = hashBytes([]byte("outside\n")) + if err := store.saveOperation(external); err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(context.Background(), workspace, external.OperationID); err == nil || !strings.Contains(err.Error(), "escapes workspace") { + t.Fatalf("expected path escape rollback error, got %v", err) + } + + path := filepath.Join(workspace, "state.txt") + states := map[string]snapshotState{ + path: {exists: true, data: []byte("restored\n")}, + filepath.Join(workspace, "missing.txt"): {}, + } + if err := restoreCurrentStates(states); err != nil { + t.Fatal(err) + } + if got := mustReadRollbackTestFile(t, path); got != "restored\n" { + t.Fatalf("expected restored current state, got %q", got) + } + + insideDisplay := displayPath(workspace, filepath.Join(workspace, "nested", "file.txt")) + if insideDisplay != "nested/file.txt" { + t.Fatalf("expected relative display path, got %q", insideDisplay) + } + outsideDisplay := displayPath(workspace, outside) + if outsideDisplay != filepath.ToSlash(outside) { + t.Fatalf("expected absolute display path, got %q", outsideDisplay) + } +} + +func beginUpdateRollbackTestOperation(t *testing.T, store *Store, workspace, relPath, absPath string) *Operation { + t.Helper() + op, err := store.Begin(context.Background(), BeginOptions{ + Workspace: workspace, + ToolName: "write_file", + TraceID: "trace-test", + }, []FileTarget{{ + Path: relPath, + AbsPath: absPath, + OpType: OpTypeUpdate, + }}) + if err != nil { + t.Fatal(err) + } + return op +} + +func timeNowForRollbackTest() time.Time { + return time.Now().UTC() +} + func mustWriteRollbackTestFile(t *testing.T, path, content string) { t.Helper() if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { From 20dab74d4d322175f42a35600e9bedf9d9d2bfb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E7=BA=AA?= <3049035704@qq.com> Date: Sun, 10 May 2026 03:00:08 +0800 Subject: [PATCH 3/3] cover rollback edge cases --- internal/rollback/store_test.go | 230 ++++++++++++++++++++ internal/tools/rollback_integration_test.go | 65 ++++++ tui/component_rollback_command_test.go | 106 +++++++++ 3 files changed, 401 insertions(+) diff --git a/internal/rollback/store_test.go b/internal/rollback/store_test.go index 896e248f..323b57fa 100644 --- a/internal/rollback/store_test.go +++ b/internal/rollback/store_test.go @@ -445,6 +445,236 @@ func TestRollbackErrorsAndHelperBranches(t *testing.T) { } } +func TestStoreAdditionalErrorBranches(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + ctx := context.Background() + workspace := t.TempDir() + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + + var nilStore *Store + if err := nilStore.Commit(ctx, nil); err != nil { + t.Fatal(err) + } + if err := nilStore.Abort(ctx, nil, "ignored"); err != nil { + t.Fatal(err) + } + if err := nilStore.AbortAndRestore(ctx, nil, "ignored"); err != nil { + t.Fatal(err) + } + if err := store.applyBeforeState(nil); err != nil { + t.Fatal(err) + } + + if err := store.Commit(ctx, &Operation{ + OperationID: "update-missing", + Workspace: workspace, + AffectedFiles: []FileChange{{Path: "missing.txt", OpType: OpTypeUpdate}}, + }); err == nil || !strings.Contains(err.Error(), "absent after update") { + t.Fatalf("expected update missing commit error, got %v", err) + } + + deletePath := filepath.Join(workspace, "delete-still-exists.txt") + mustWriteRollbackTestFile(t, deletePath, "still here\n") + if err := store.Commit(ctx, &Operation{ + OperationID: "delete-present", + Workspace: workspace, + AffectedFiles: []FileChange{{ + Path: "delete-still-exists.txt", + AbsPath: deletePath, + OpType: OpTypeDelete, + }}, + }); err == nil || !strings.Contains(err.Error(), "still exists after delete") { + t.Fatalf("expected delete present commit error, got %v", err) + } + + oldPath := filepath.Join(workspace, "old-still-exists.txt") + newPath := filepath.Join(workspace, "new-after-move.txt") + mustWriteRollbackTestFile(t, oldPath, "old\n") + mustWriteRollbackTestFile(t, newPath, "new\n") + if err := store.Commit(ctx, &Operation{ + OperationID: "move-old-present", + Workspace: workspace, + AffectedFiles: []FileChange{{ + Path: "old-still-exists.txt", + AbsPath: oldPath, + NewPath: "new-after-move.txt", + NewAbsPath: newPath, + OpType: OpTypeMove, + }}, + }); err == nil || !strings.Contains(err.Error(), "still exists after move") { + t.Fatalf("expected move old present commit error, got %v", err) + } + + if err := os.Remove(oldPath); err != nil { + t.Fatal(err) + } + if err := os.Remove(newPath); err != nil { + t.Fatal(err) + } + if err := store.Commit(ctx, &Operation{ + OperationID: "move-new-missing", + Workspace: workspace, + AffectedFiles: []FileChange{{ + Path: "old-still-exists.txt", + AbsPath: oldPath, + NewPath: "new-after-move.txt", + NewAbsPath: newPath, + OpType: OpTypeMove, + }}, + }); err == nil || !strings.Contains(err.Error(), "absent after move") { + t.Fatalf("expected move new missing commit error, got %v", err) + } + + if err := store.applyBeforeState(&Operation{ + Workspace: workspace, + AffectedFiles: []FileChange{{Path: "bad.txt", OpType: OpType("bad")}}, + }); err == nil || !strings.Contains(err.Error(), "unsupported") { + t.Fatalf("expected unsupported restore error, got %v", err) + } + if err := store.restoreSnapshot(filepath.Join(workspace, "missing-snapshot.txt"), FileChange{Path: "missing-snapshot.txt"}); err == nil || !strings.Contains(err.Error(), "snapshot missing") { + t.Fatalf("expected missing snapshot error, got %v", err) + } + removePath := filepath.Join(workspace, "remove-before-absent.txt") + mustWriteRollbackTestFile(t, removePath, "remove me\n") + if err := store.restoreSnapshot(removePath, FileChange{Path: "remove-before-absent.txt", FileAbsentBefore: true}); err != nil { + t.Fatal(err) + } + if _, err := os.Stat(removePath); !os.IsNotExist(err) { + t.Fatalf("expected before-absent restore to remove file, got %v", err) + } +} + +func TestConflictAndPathHelperBranches(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + ctx := context.Background() + workspace := t.TempDir() + store, err := NewStore(filepath.Join(t.TempDir(), "rollback")) + if err != nil { + t.Fatal(err) + } + + deletePath := filepath.Join(workspace, "deleted.txt") + mustWriteRollbackTestFile(t, deletePath, "changed after delete\n") + deleteOp := &Operation{ + OperationID: "delete-conflict", + Workspace: workspace, + Status: StatusCommitted, + AffectedFiles: []FileChange{{ + Path: "deleted.txt", + AbsPath: deletePath, + OpType: OpTypeDelete, + }}, + } + if err := store.saveOperation(deleteOp); err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(ctx, workspace, deleteOp.OperationID); err == nil || !strings.Contains(err.Error(), "changed after delete") { + t.Fatalf("expected delete conflict, got %v", err) + } + + oldPath := filepath.Join(workspace, "move-old.txt") + newPath := filepath.Join(workspace, "move-new.txt") + mustWriteRollbackTestFile(t, oldPath, "old changed\n") + mustWriteRollbackTestFile(t, newPath, "after move\n") + moveOp := &Operation{ + OperationID: "move-old-conflict", + Workspace: workspace, + Status: StatusCommitted, + AffectedFiles: []FileChange{{ + Path: "move-old.txt", + AbsPath: oldPath, + NewPath: "move-new.txt", + NewAbsPath: newPath, + OpType: OpTypeMove, + AfterHash: hashBytes([]byte("after move\n")), + }}, + } + if err := store.saveOperation(moveOp); err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(ctx, workspace, moveOp.OperationID); err == nil || !strings.Contains(err.Error(), "changed after move") { + t.Fatalf("expected move old path conflict, got %v", err) + } + + if err := os.Remove(oldPath); err != nil { + t.Fatal(err) + } + if err := os.Remove(newPath); err != nil { + t.Fatal(err) + } + moveMissing := *moveOp + moveMissing.OperationID = "move-new-missing-conflict" + if err := store.saveOperation(&moveMissing); err != nil { + t.Fatal(err) + } + if _, err := store.Rollback(ctx, workspace, moveMissing.OperationID); err == nil || !strings.Contains(err.Error(), "is missing") { + t.Fatalf("expected move new path missing conflict, got %v", err) + } + + externalRoot := t.TempDir() + externalPath := filepath.Join(externalRoot, "allowed.txt") + op := Operation{ + Workspace: workspace, + AffectedFiles: []FileChange{{ + Path: filepath.ToSlash(externalPath), + AbsPath: externalPath, + OpType: OpTypeUpdate, + }}, + } + if err := validateOperationPaths(op, workspace, " ", externalRoot); err != nil { + t.Fatalf("expected external writable root to be allowed, got %v", err) + } + if _, err := normalizeTargetPath(workspace, "", " "); err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected empty target path error, got %v", err) + } +} + +func TestSnapshotAndCurrentStateErrorBranches(t *testing.T) { + workspace := t.TempDir() + largePath := filepath.Join(workspace, "large.txt") + largeData := make([]byte, maxSnapshotBytes+1) + if err := os.WriteFile(largePath, largeData, 0o644); err != nil { + t.Fatal(err) + } + if _, err := readSnapshotCandidate(largePath); err == nil || !strings.Contains(err.Error(), "exceeds") { + t.Fatalf("expected large snapshot error, got %v", err) + } + + dirPath := filepath.Join(workspace, "dir") + if err := os.MkdirAll(dirPath, 0o755); err != nil { + t.Fatal(err) + } + if _, err := readCurrentState(dirPath); err == nil || !strings.Contains(err.Error(), "directory") { + t.Fatalf("expected current state directory error, got %v", err) + } + + blockFile := filepath.Join(workspace, "block") + if err := os.WriteFile(blockFile, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + if _, err := NewStore(blockFile); err == nil { + t.Fatal("expected NewStore to fail when root is a file") + } + + blobDir := filepath.Join(workspace, "blobs") + if err := os.MkdirAll(blobDir, 0o755); err != nil { + t.Fatal(err) + } + badEntries := filepath.Join(workspace, "entries-file") + if err := os.WriteFile(badEntries, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + badStore := &Store{entriesDir: badEntries, blobsDir: blobDir} + target := filepath.Join(workspace, "file.txt") + mustWriteRollbackTestFile(t, target, "before\n") + if _, err := badStore.Begin(context.Background(), BeginOptions{Workspace: workspace}, []FileTarget{{Path: "file.txt", AbsPath: target}}); err == nil { + t.Fatal("expected begin to fail when operation entry cannot be saved") + } +} + func beginUpdateRollbackTestOperation(t *testing.T, store *Store, workspace, relPath, absPath string) *Operation { t.Helper() op, err := store.Begin(context.Background(), BeginOptions{ diff --git a/internal/tools/rollback_integration_test.go b/internal/tools/rollback_integration_test.go index f4906124..2ce93295 100644 --- a/internal/tools/rollback_integration_test.go +++ b/internal/tools/rollback_integration_test.go @@ -9,6 +9,7 @@ import ( "testing" rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback" + "github.com/1024XEngineer/bytemind/internal/session" ) func TestWriteFileToolRecordsRollbackOperation(t *testing.T) { @@ -167,6 +168,70 @@ func TestApplyPatchToolFailureAutomaticallyRestoresChangedFiles(t *testing.T) { } } +func TestRollbackHelperBranches(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + path := filepath.Join(workspace, "file.txt") + mustWriteFile(t, path, "before\n") + + if rollbackEnabled(nil) { + t.Fatal("expected nil execution context to disable rollback") + } + if tracker, err := beginRollbackOperation(context.Background(), &ExecutionContext{Workspace: workspace}, "write_file", []rollbackpkg.FileTarget{{ + Path: "file.txt", + AbsPath: path, + OpType: rollbackpkg.OpTypeUpdate, + }}); err != nil || tracker != nil { + t.Fatalf("expected rollback disabled without run id, got %#v / %v", tracker, err) + } + if tracker, err := beginRollbackOperation(context.Background(), &ExecutionContext{Workspace: workspace, RunID: "trace-empty"}, "write_file", nil); err != nil || tracker != nil { + t.Fatalf("expected empty targets to skip rollback, got %#v / %v", tracker, err) + } + if got := rollbackSessionID(&ExecutionContext{Session: session.New(workspace)}); got == "" { + t.Fatal("expected session id from execution context") + } + if id, err := (*rollbackTracker)(nil).commit(context.Background()); err != nil || id != "" { + t.Fatalf("expected nil tracker commit no-op, got %q / %v", id, err) + } + (*rollbackTracker)(nil).abort(context.Background(), "ignored") +} + +func TestRollbackHelperPropagatesBeginAndCommitFailures(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + binaryPath := filepath.Join(workspace, "binary.bin") + mustWriteFile(t, binaryPath, "a\x00b") + + if _, err := beginRollbackOperation(context.Background(), &ExecutionContext{Workspace: workspace, RunID: "trace-binary"}, "write_file", []rollbackpkg.FileTarget{{ + Path: "binary.bin", + AbsPath: binaryPath, + OpType: rollbackpkg.OpTypeUpdate, + }}); err == nil || !strings.Contains(err.Error(), "not a text") { + t.Fatalf("expected begin rollback snapshot error, got %v", err) + } + + textPath := filepath.Join(workspace, "text.txt") + mustWriteFile(t, textPath, "before\n") + tracker, err := beginRollbackOperation(context.Background(), &ExecutionContext{Workspace: workspace, RunID: "trace-missing"}, "write_file", []rollbackpkg.FileTarget{{ + Path: "text.txt", + AbsPath: textPath, + OpType: rollbackpkg.OpTypeUpdate, + }}) + if err != nil { + t.Fatal(err) + } + if err := os.Remove(textPath); err != nil { + t.Fatal(err) + } + _, err = tracker.commit(context.Background()) + if err == nil || !strings.Contains(err.Error(), "absent after update") { + t.Fatalf("expected commit failure, got %v", err) + } + if got := mustReadToolsTestFile(t, textPath); got != "before\n" { + t.Fatalf("expected failed commit to restore snapshot, got %q", got) + } +} + func rollbackOperationIDFromToolResult(t *testing.T, result string) string { t.Helper() var parsed struct { diff --git a/tui/component_rollback_command_test.go b/tui/component_rollback_command_test.go index 571e4465..c226e2c0 100644 --- a/tui/component_rollback_command_test.go +++ b/tui/component_rollback_command_test.go @@ -8,6 +8,7 @@ import ( "testing" rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback" + "github.com/1024XEngineer/bytemind/internal/session" ) func TestExecuteRollbackCommandListsAndRollsBackLast(t *testing.T) { @@ -68,6 +69,111 @@ func TestExecuteRollbackCommandRejectsInvalidUsage(t *testing.T) { } } +func TestRunRollbackCommandRecordsExchange(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + store, err := session.NewStore(t.TempDir()) + if err != nil { + t.Fatal(err) + } + sess := session.New(workspace) + m := model{ + workspace: workspace, + store: store, + sess: sess, + } + + if err := m.runRollbackCommand("/rollback"); err != nil { + t.Fatal(err) + } + if m.statusNote != "Rollback operations listed." { + t.Fatalf("expected rollback list status, got %q", m.statusNote) + } + if len(m.chatItems) != 2 || !strings.Contains(m.chatItems[1].Body, "No ByteMind rollback operations") { + t.Fatalf("expected rollback exchange, got %#v", m.chatItems) + } + if len(sess.Messages) != 2 || sess.Messages[0].Text() != "/rollback" { + t.Fatalf("expected rollback exchange recorded in session, got %#v", sess.Messages) + } +} + +func TestRunRollbackCommandHandlesErrorsAndSaveFailure(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + m := model{workspace: workspace, sess: session.New(workspace), store: failingCommitSessionStore{}} + + if err := m.runRollbackCommand("/rollback bad id"); err != nil { + t.Fatal(err) + } + if !strings.Contains(m.chatItems[1].Body, "Usage: /rollback") || !strings.Contains(m.statusNote, "session save failed") { + t.Fatalf("expected error exchange and save failure status, got body=%q status=%q", m.chatItems[1].Body, m.statusNote) + } +} + +func TestHandleSlashCommandRollback(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + store, err := session.NewStore(t.TempDir()) + if err != nil { + t.Fatal(err) + } + m := model{workspace: workspace, store: store, sess: session.New(workspace)} + if err := m.handleSlashCommand("/rollback"); err != nil { + t.Fatal(err) + } + if m.statusNote != "Rollback operations listed." { + t.Fatalf("expected rollback command to run, got %q", m.statusNote) + } +} + +func TestExecuteRollbackCommandAdditionalBranches(t *testing.T) { + t.Setenv("BYTEMIND_HOME", t.TempDir()) + workspace := t.TempDir() + if _, _, err := executeRollbackCommand(context.Background(), workspace, nil, "/not-rollback"); err == nil || err.Error() != rollbackUsage { + t.Fatalf("expected rollback usage error, got %v", err) + } + if _, _, err := executeRollbackCommand(context.Background(), workspace, nil, "/rollback missing"); err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("expected missing rollback operation error, got %v", err) + } + + blockFile := filepath.Join(t.TempDir(), "home-as-file") + if err := os.WriteFile(blockFile, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + t.Setenv("BYTEMIND_HOME", blockFile) + if _, _, err := executeRollbackCommand(context.Background(), workspace, nil, "/rollback"); err == nil || !strings.Contains(err.Error(), "Rollback unavailable") { + t.Fatalf("expected unavailable rollback store, got %v", err) + } +} + +func TestRollbackFormattingBranches(t *testing.T) { + if got := formatRollbackList(nil); !strings.Contains(got, "No ByteMind rollback operations") { + t.Fatalf("expected empty rollback list message, got %q", got) + } + if got := shortRollbackID("short-id"); got != "short-id" { + t.Fatalf("expected short id unchanged, got %q", got) + } + + op := rollbackpkg.Operation{ + OperationID: "20260510T000000.000000000Z-abcdef", + ToolName: "apply_patch", + AffectedFiles: []rollbackpkg.FileChange{ + {Path: "a.txt"}, + {Path: "old.txt", NewPath: "new.txt", OpType: rollbackpkg.OpTypeMove}, + {Path: "c.txt"}, + {Path: "d.txt"}, + }, + } + summary := rollbackPathSummary(op) + if !strings.Contains(summary, "old.txt -> new.txt") || !strings.Contains(summary, "+1 more") { + t.Fatalf("expected move and truncation summary, got %q", summary) + } + success := formatRollbackSuccess(op) + if !strings.Contains(success, "Rollback completed.") || !strings.Contains(success, "Files restored: 4") { + t.Fatalf("expected rollback success details, got %q", success) + } +} + func writeRollbackCommandTestFile(t *testing.T, path, content string) { t.Helper() if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {