diff --git a/CHANGELOG.md b/CHANGELOG.md index e93745c..726a79e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,23 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.11.0] - 2026-01-11 + +### Features + +- **Checkpoint and rollback**: Automatically create checkpoints before session execution, enabling rollback to restore files to their pre-execution state + - Original file content stored as blob files in `~/.local/share/shannot/sessions/{id}/checkpoint/` + - Conflict detection via post-exec hash comparison (bypass with `--force`) + - Support for both local and remote (SSH) rollback + - Large directory handling with limits (100 files / 50MB) and partial checkpoint warnings + - New session status: `rolled_back` + +### CLI + +- Add `shannot rollback ` command with `--force` and `--dry-run` options +- Add `shannot checkpoint list` to list sessions with available checkpoints +- Add `shannot checkpoint show ` to display checkpoint details + ## [0.10.3] - 2026-01-07 ### Bug Fixes diff --git a/README.md b/README.md index c7d56cd..791e94f 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,11 @@ shannot approve - Always-deny destructive patterns (rm -rf /) - Everything else requires human review +**Checkpoint and Rollback** +- Automatic checkpoint before execution +- Restore files to pre-execution state with `shannot rollback` +- Conflict detection prevents accidental overwrites + ## Installation ```bash @@ -112,6 +117,11 @@ shannot setup mcp install # Claude Desktop integration # Status shannot status # Runtime, config, pending sessions + +# Rollback +shannot rollback # Restore files to pre-execution state +shannot checkpoint list # List sessions with checkpoints +shannot checkpoint show # Show checkpoint details ``` ## Configuration diff --git a/docs/reference/cli.md b/docs/reference/cli.md index dac6cee..af8371d 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -96,6 +96,53 @@ shannot setup mcp install # Install for Claude Desktop shannot setup mcp install --client claude-code # Install for Claude Code ``` +### shannot rollback + +Rollback session changes to pre-execution state. + +```bash +shannot rollback # Rollback with conflict check +shannot rollback --force # Skip conflict detection +shannot rollback --dry-run # Preview without making changes +``` + +**Arguments:** + +| Argument | Description | +|----------|-------------| +| `session_id` | Session ID to rollback | + +**Options:** + +| Option | Description | +|--------|-------------| +| `--force`, `-f` | Skip conflict detection | +| `--dry-run`, `-n` | Preview without making changes | + +**Exit Codes:** + +| Code | Meaning | +|------|---------| +| 0 | Rollback successful | +| 1 | Conflict detected (use --force to override) | +| 2 | Session not found or no checkpoint | + +### shannot checkpoint + +Manage session checkpoints. + +```bash +shannot checkpoint list # List sessions with checkpoints +shannot checkpoint show # Show checkpoint details +``` + +**Subcommands:** + +| Subcommand | Description | +|------------|-------------| +| `list` | List sessions with available checkpoints | +| `show ` | Show checkpoint details for a session | + ## See Also - [Usage Guide](../usage.md) - Comprehensive examples diff --git a/docs/reference/execution.md b/docs/reference/execution.md index f6f99b0..3327e7a 100644 --- a/docs/reference/execution.md +++ b/docs/reference/execution.md @@ -12,7 +12,9 @@ Shannot v0.4.0+ uses a session-based approval workflow instead of direct executo |-------|-------------| | Dry-run | Script runs in sandbox, operations captured | | Review | User reviews captured operations via TUI | +| Checkpoint | Original file content saved before changes | | Execute | Approved operations run on host system | +| Rollback | (Optional) Restore files to pre-execution state | ## Session Workflow @@ -47,9 +49,38 @@ The remote receives the session data and executes in its own PyPy sandbox. |--------|---------| | `run_session.py` | Session execution orchestration | | `session.py` | Session data structures | +| `checkpoint.py` | Checkpoint and rollback logic | | `deploy.py` | Remote deployment | | `ssh.py` | Zero-dependency SSH client | +## Checkpoint Creation + +Before committing writes, Shannot creates a checkpoint: + +1. **Blob storage**: Original file content saved as `{hash[:8]}.blob` +2. **Metadata**: Path mappings stored in `session.checkpoint` +3. **Post-exec hashes**: Recorded after writes for conflict detection + +Directory structure: + +``` +~/.local/share/shannot/sessions/{session_id}/ + session.json + checkpoint/ + a1b2c3d4.blob + e5f6g7h8.blob +``` + +## Session Statuses + +| Status | Description | +|--------|-------------| +| `pending` | Awaiting approval | +| `approved` | Ready for execution | +| `executed` | Completed successfully | +| `rolled_back` | Restored to pre-execution state | +| `expired` | TTL exceeded | + ## See Also - [Usage Guide](../usage.md) - Session workflow details diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 9b8429d..4577b09 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -399,6 +399,74 @@ shannot setup runtime shannot status ``` +## Checkpoint and Rollback Issues + +### "Conflict detected" during rollback + +**Symptoms:** +``` +Error: Conflict detected - file was modified since execution +``` + +**Cause:** A file was modified after the session was executed, and the current content differs from what Shannot wrote. + +**Solutions:** + +1. Review the conflict: + ```bash + shannot checkpoint show SESSION_ID + ``` + +2. Force rollback (overwrites current changes): + ```bash + shannot rollback SESSION_ID --force + ``` + +3. Manually restore the file from the checkpoint blob + +### "Partial checkpoint" warning + +**Symptoms:** +``` +Warning: Partial checkpoint - directory too large to fully checkpoint +``` + +**Cause:** Large directories (>100 files or >50MB) cannot be fully checkpointed. + +**Impact:** These directories cannot be restored via rollback. + +**Solutions:** +1. Accept the limitation for large directories +2. Manually backup large directories before execution +3. Use smaller, more targeted scripts + +### "No checkpoint" error + +**Symptoms:** +``` +Error: Session has no checkpoint +``` + +**Causes:** +1. Session was executed before v0.11.0 +2. Checkpoint creation failed +3. Session was created in dry-run only mode + +**Solutions:** +1. Create a new session and execute it +2. Check session details: + ```bash + shannot checkpoint show SESSION_ID + ``` + +### Checkpoint not created + +**Symptoms:** Session executed but no checkpoint is available. + +**Cause:** The session may have been created before the checkpoint feature was added. + +**Solution:** Re-run the script to create a new session with checkpoint support. + ## Getting Help If you're still stuck: diff --git a/docs/usage.md b/docs/usage.md index 28d5a34..d29b473 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -160,6 +160,47 @@ shannot approve show SESSION_ID # Review details # Then press 'x' in TUI ``` +### 4. Checkpoint and Rollback + +Shannot automatically creates checkpoints before executing approved changes, enabling you to restore files to their pre-execution state. + +#### How Checkpoints Work + +1. Before execution, original file content is saved to blob files +2. After execution, file hashes are recorded for conflict detection +3. Use `shannot rollback` to restore files if needed + +#### Rollback Command + +```bash +# Rollback a session (with conflict detection) +shannot rollback abc123 + +# Force rollback (skip conflict check) +shannot rollback abc123 --force + +# Preview what would be restored +shannot rollback abc123 --dry-run +``` + +**Conflict Detection:** If a file was modified after session execution, rollback will fail unless `--force` is used. + +#### Managing Checkpoints + +```bash +# List all sessions with checkpoints +shannot checkpoint list + +# Show checkpoint details for a session +shannot checkpoint show abc123 +``` + +#### Limitations + +- Large directory deletions (>100 files or >50MB) create partial checkpoints +- Partial checkpoints cannot be fully restored +- Checkpoints are tied to session lifecycle + ## Approval Profiles Profiles control which commands execute automatically vs. require approval. diff --git a/pyproject.toml b/pyproject.toml index 7443869..1acea70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "shannot" -version = "0.10.3" +version = "0.11.0" description = "Sandboxed system administration for LLM agents" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/shannot/checkpoint.py b/shannot/checkpoint.py new file mode 100644 index 0000000..b261570 --- /dev/null +++ b/shannot/checkpoint.py @@ -0,0 +1,485 @@ +"""Checkpoint and rollback functionality for session execution.""" + +from __future__ import annotations + +import base64 +import hashlib +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .session import Session + +# Limits for directory deletion checkpointing +CHECKPOINT_MAX_FILES = 100 +CHECKPOINT_MAX_SIZE = 50 * 1024 * 1024 # 50MB + + +def _hash_content(content: bytes) -> str: + """Return SHA256 hash of content.""" + return hashlib.sha256(content).hexdigest() + + +def _blob_name(content_hash: str) -> str: + """Return blob filename from content hash.""" + return f"{content_hash[:8]}.blob" + + +def create_checkpoint(session: Session) -> dict: + """ + Create checkpoint from session's pending writes and deletions. + + Captures original file content before execution so files can be + restored via rollback. Content is stored as blob files in the + session's checkpoint directory. + + Parameters + ---------- + session + Session with pending_writes and pending_deletions to checkpoint. + + Returns + ------- + dict + Checkpoint data: {path: {blob, size, original_hash, was_created}} + """ + checkpoint_dir = session.checkpoint_dir + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + checkpoint: dict[str, dict] = {} + + # Checkpoint original content from pending writes + for write_data in session.pending_writes: + path = write_data.get("path", "") + original_b64 = write_data.get("original_b64") + original_hash = write_data.get("original_hash") + + if original_b64: + # File existed - checkpoint original content + original = base64.b64decode(original_b64) + blob_name = _blob_name(original_hash or _hash_content(original)) + blob_path = checkpoint_dir / blob_name + blob_path.write_bytes(original) + + checkpoint[path] = { + "blob": blob_name, + "size": len(original), + "original_hash": original_hash, + "was_created": False, + } + else: + # New file - mark as created (will be deleted on rollback) + checkpoint[path] = { + "blob": None, + "size": 0, + "original_hash": None, + "was_created": True, + } + + # Checkpoint files being deleted by reading from real filesystem + for del_data in session.pending_deletions: + path = del_data.get("path", "") + target_type = del_data.get("target_type", "file") + + if target_type == "directory": + # Checkpoint directory contents with limits + _checkpoint_directory(path, checkpoint_dir, checkpoint) + else: + # Checkpoint single file + _checkpoint_file(path, checkpoint_dir, checkpoint) + + session.checkpoint = checkpoint + session.checkpoint_created_at = datetime.now().isoformat() + + return checkpoint + + +def _checkpoint_file(path: str, checkpoint_dir: Path, checkpoint: dict) -> bool: + """ + Checkpoint a single file. + + Returns True if file was checkpointed, False if skipped. + """ + real_path = Path(path) + if not real_path.exists() or not real_path.is_file(): + return False + + try: + content = real_path.read_bytes() + content_hash = _hash_content(content) + blob_name = _blob_name(content_hash) + blob_path = checkpoint_dir / blob_name + + # Don't re-write blob if already exists (same content) + if not blob_path.exists(): + blob_path.write_bytes(content) + + checkpoint[path] = { + "blob": blob_name, + "size": len(content), + "original_hash": content_hash, + "was_deleted": True, + } + return True + except (OSError, PermissionError): + return False + + +def _checkpoint_directory(path: str, checkpoint_dir: Path, checkpoint: dict) -> dict: + """ + Checkpoint a directory's contents with limits. + + Returns dict with 'partial' and 'warning' if limits exceeded. + """ + real_path = Path(path) + if not real_path.exists() or not real_path.is_dir(): + return {} + + files = [] + total_size = 0 + + try: + for f in real_path.rglob("*"): + if f.is_file(): + try: + size = f.stat().st_size + files.append((f, size)) + total_size += size + except (OSError, PermissionError): + continue + except (OSError, PermissionError): + return {"partial": True, "warning": f"Cannot read directory: {path}"} + + # Check limits + if len(files) > CHECKPOINT_MAX_FILES or total_size > CHECKPOINT_MAX_SIZE: + # Mark directory as partial checkpoint only + checkpoint[path] = { + "blob": None, + "size": 0, + "original_hash": None, + "was_deleted": True, + "partial": True, + "file_count": len(files), + "total_size": total_size, + } + return { + "partial": True, + "warning": f"Directory {path} too large ({len(files)} files, {total_size} bytes)", + } + + # Checkpoint all files + for f, _size in files: + _checkpoint_file(str(f), checkpoint_dir, checkpoint) + + # Mark directory itself + checkpoint[path] = { + "blob": None, + "size": 0, + "original_hash": None, + "was_deleted": True, + "is_directory": True, + } + + return {} + + +def update_post_exec_hashes(session: Session) -> None: + """ + Record post-execution hashes for conflict detection during rollback. + + Called after commit_writes() to capture the new file state. + """ + if not session.checkpoint: + return + + for path, entry in session.checkpoint.items(): + if entry.get("was_created") or not entry.get("was_deleted"): + # For created/modified files, hash the new content + real_path = Path(path) + if real_path.exists() and real_path.is_file(): + try: + content = real_path.read_bytes() + entry["post_exec_hash"] = _hash_content(content) + except (OSError, PermissionError): + # Skip unreadable files - conflict detection is best-effort + pass + + +def rollback_local(session: Session, *, force: bool = False) -> list[dict]: + """ + Rollback session changes on local filesystem. + + Parameters + ---------- + session + Executed session to rollback. + force + If True, skip conflict detection and restore anyway. + + Returns + ------- + list[dict] + Results: [{path, action, success, error?}] + """ + if not session.checkpoint: + return [{"path": "", "action": "rollback", "success": False, "error": "No checkpoint"}] + + results = [] + conflicts = [] + + # First pass: detect conflicts + if not force: + for path, entry in session.checkpoint.items(): + post_exec_hash = entry.get("post_exec_hash") + if post_exec_hash: + real_path = Path(path) + if real_path.exists() and real_path.is_file(): + try: + current_hash = _hash_content(real_path.read_bytes()) + if current_hash != post_exec_hash: + conflicts.append(path) + except (OSError, PermissionError): + # Skip unreadable files - don't block rollback on read errors + pass + + if conflicts: + return [ + { + "path": p, + "action": "conflict", + "success": False, + "error": "File modified since execution", + } + for p in conflicts + ] + + checkpoint_dir = session.checkpoint_dir + + # Restore files + for path, entry in session.checkpoint.items(): + blob_name = entry.get("blob") + was_created = entry.get("was_created", False) + was_deleted = entry.get("was_deleted", False) + is_directory = entry.get("is_directory", False) + partial = entry.get("partial", False) + + try: + real_path = Path(path) + + if was_created: + # File was created by session - delete it + if real_path.exists(): + real_path.unlink() + results.append({"path": path, "action": "deleted", "success": True}) + + elif was_deleted: + if partial: + # Partial checkpoint - skip with warning + results.append( + { + "path": path, + "action": "skipped", + "success": False, + "error": "Partial checkpoint - directory too large", + } + ) + elif is_directory: + # Recreate directory (files inside restored separately) + real_path.mkdir(parents=True, exist_ok=True) + results.append({"path": path, "action": "recreated_dir", "success": True}) + elif blob_name: + # Recreate deleted file from blob + blob_path = checkpoint_dir / blob_name + if blob_path.exists(): + real_path.parent.mkdir(parents=True, exist_ok=True) + real_path.write_bytes(blob_path.read_bytes()) + results.append({"path": path, "action": "recreated", "success": True}) + else: + results.append( + { + "path": path, + "action": "recreate", + "success": False, + "error": "Blob not found", + } + ) + + elif blob_name: + # File was modified - restore original content + blob_path = checkpoint_dir / blob_name + if blob_path.exists(): + real_path.parent.mkdir(parents=True, exist_ok=True) + real_path.write_bytes(blob_path.read_bytes()) + results.append({"path": path, "action": "restored", "success": True}) + else: + results.append( + { + "path": path, + "action": "restore", + "success": False, + "error": "Blob not found", + } + ) + + except Exception as e: + results.append({"path": path, "action": "rollback", "success": False, "error": str(e)}) + + return results + + +def rollback_remote(session: Session, ssh: object, *, force: bool = False) -> list[dict]: + """ + Rollback session changes on remote filesystem via SSH. + + Parameters + ---------- + session + Executed session to rollback. + ssh + SSHConnection instance. + force + If True, skip conflict detection. + + Returns + ------- + list[dict] + Results: [{path, action, success, error?}] + """ + import shlex + + if not session.checkpoint: + return [{"path": "", "action": "rollback", "success": False, "error": "No checkpoint"}] + + results = [] + conflicts = [] + + # First pass: detect conflicts via SSH + if not force: + for path, entry in session.checkpoint.items(): + post_exec_hash = entry.get("post_exec_hash") + if post_exec_hash: + try: + result = ssh.run( # type: ignore[union-attr] + f"sha256sum {shlex.quote(path)} 2>/dev/null || echo NOTFOUND" + ) + stdout_str = result.stdout.decode("utf-8", errors="replace") + if "NOTFOUND" not in stdout_str: + current_hash = stdout_str.split()[0] + if current_hash != post_exec_hash: + conflicts.append(path) + except Exception: + # Skip on SSH/network errors - don't block rollback on transient failures + pass + + if conflicts: + return [ + { + "path": p, + "action": "conflict", + "success": False, + "error": "File modified since execution", + } + for p in conflicts + ] + + checkpoint_dir = session.checkpoint_dir + + # Restore files via SSH + for path, entry in session.checkpoint.items(): + blob_name = entry.get("blob") + was_created = entry.get("was_created", False) + was_deleted = entry.get("was_deleted", False) + is_directory = entry.get("is_directory", False) + partial = entry.get("partial", False) + + try: + if was_created: + # File was created - delete it via SSH + ssh.run(f"rm -f {shlex.quote(path)}") # type: ignore[union-attr] + results.append({"path": path, "action": "deleted", "success": True}) + + elif was_deleted: + if partial: + results.append( + { + "path": path, + "action": "skipped", + "success": False, + "error": "Partial checkpoint - directory too large", + } + ) + elif is_directory: + ssh.run(f"mkdir -p {shlex.quote(path)}") # type: ignore[union-attr] + results.append({"path": path, "action": "recreated_dir", "success": True}) + elif blob_name: + blob_path = checkpoint_dir / blob_name + if blob_path.exists(): + content = blob_path.read_bytes() + parent = str(Path(path).parent) + if parent != "/": + ssh.run(f"mkdir -p {shlex.quote(parent)}") # type: ignore[union-attr] + ssh.write_file(path, content) # type: ignore[union-attr] + results.append({"path": path, "action": "recreated", "success": True}) + else: + results.append( + { + "path": path, + "action": "recreate", + "success": False, + "error": "Blob not found", + } + ) + + elif blob_name: + blob_path = checkpoint_dir / blob_name + if blob_path.exists(): + content = blob_path.read_bytes() + parent = str(Path(path).parent) + if parent != "/": + ssh.run(f"mkdir -p {shlex.quote(parent)}") # type: ignore[union-attr] + ssh.write_file(path, content) # type: ignore[union-attr] + results.append({"path": path, "action": "restored", "success": True}) + else: + results.append( + { + "path": path, + "action": "restore", + "success": False, + "error": "Blob not found", + } + ) + + except Exception as e: + results.append({"path": path, "action": "rollback", "success": False, "error": str(e)}) + + return results + + +def list_checkpoints() -> list[tuple]: + """ + List all sessions with checkpoints. + + Returns + ------- + list[tuple] + List of (Session, checkpoint_info) tuples. + """ + from .session import Session + + result = [] + for session in Session.list_all(): + if session.checkpoint_created_at and session.checkpoint: + file_count = len(session.checkpoint) + total_size = sum(e.get("size", 0) for e in session.checkpoint.values()) + result.append( + ( + session, + { + "file_count": file_count, + "total_size": total_size, + "created_at": session.checkpoint_created_at, + }, + ) + ) + return result diff --git a/shannot/cli.py b/shannot/cli.py index 7f668d1..5a64454 100644 --- a/shannot/cli.py +++ b/shannot/cli.py @@ -731,6 +731,197 @@ def cmd_mcp_install(args: argparse.Namespace) -> int: return 0 +def cmd_rollback(args: argparse.Namespace) -> int: + """Handle 'shannot rollback' command.""" + from .checkpoint import rollback_local, rollback_remote + from .session import Session + + session_id = args.session_id + + try: + session = Session.load(session_id) + except FileNotFoundError: + print(f"Error: Session not found: {session_id}", file=sys.stderr) + return 1 + + # Verify session has a checkpoint + if not session.checkpoint_created_at: + print(f"Error: No checkpoint for session {session_id}", file=sys.stderr) + print("Only executed sessions have checkpoints.", file=sys.stderr) + return 1 + + if session.status == "rolled_back": + print(f"Error: Session {session_id} already rolled back", file=sys.stderr) + return 1 + + # Show what would be rolled back + if session.checkpoint: + print(f"Session: {session_id}") + print(f"Checkpoint created: {session.checkpoint_created_at}") + print(f"Files to restore: {len(session.checkpoint)}") + print() + + if args.dry_run: + print("Dry run - would restore:") + for path, entry in session.checkpoint.items(): + was_created = entry.get("was_created", False) + was_deleted = entry.get("was_deleted", False) + partial = entry.get("partial", False) + + if was_created: + print(f" DELETE {path} (was created)") + elif was_deleted: + if partial: + print(f" SKIP {path} (partial checkpoint)") + else: + print(f" RECREATE {path}") + else: + print(f" RESTORE {path}") + return 0 + + # Perform rollback + if session.is_remote(): + from .config import resolve_target + from .ssh import SSHConfig, SSHConnection + + user, host, port = resolve_target(session.target or "") + config = SSHConfig(target=f"{user}@{host}", port=port) + + with SSHConnection(config) as ssh: + if not ssh.connect(): + print("Error: Failed to connect to remote", file=sys.stderr) + return 1 + results = rollback_remote(session, ssh, force=args.force) + else: + results = rollback_local(session, force=args.force) + + # Check for conflicts + conflicts = [r for r in results if r.get("action") == "conflict"] + if conflicts: + print(f"Error: {len(conflicts)} file(s) modified since execution:", file=sys.stderr) + for r in conflicts: + print(f" {r['path']}", file=sys.stderr) + print("\nUse --force to restore anyway.", file=sys.stderr) + return 1 + + # Update session status + session.status = "rolled_back" + session.save() + + # Display results + success_count = sum(1 for r in results if r.get("success")) + error_count = sum(1 for r in results if not r.get("success")) + + print(f"Rollback complete: {success_count} succeeded, {error_count} failed") + for r in results: + path = r.get("path", "") + action = r.get("action", "") + success = r.get("success", False) + error = r.get("error", "") + + mark = "✓" if success else "✗" + if success: + print(f" {mark} {action}: {path}") + else: + print(f" {mark} {action}: {path} ({error})") + + return 0 if error_count == 0 else 1 + + +def cmd_checkpoint(args: argparse.Namespace) -> int: + """Handle 'shannot checkpoint' command.""" + if args.checkpoint_cmd == "list": + return cmd_checkpoint_list(args) + elif args.checkpoint_cmd == "show": + return cmd_checkpoint_show(args) + else: + print("Usage: shannot checkpoint {list,show}", file=sys.stderr) + return 1 + + +def cmd_checkpoint_list(args: argparse.Namespace) -> int: + """Handle 'shannot checkpoint list' command.""" + from .checkpoint import list_checkpoints + + checkpoints = list_checkpoints() + + if not checkpoints: + print("No checkpoints available.") + print("Checkpoints are created when sessions are executed.") + return 0 + + print(f"{'SESSION ID':<30} {'STATUS':<12} {'FILES':<6} {'SIZE':<10} {'CREATED'}") + print("-" * 80) + + for session, info in checkpoints: + size_str = _format_checkpoint_size(info["total_size"]) + # Truncate timestamp to date only + created = info["created_at"][:10] if info["created_at"] else "" + print( + f"{session.id:<30} {session.status:<12} {info['file_count']:<6} " + f"{size_str:<10} {created}" + ) + + return 0 + + +def cmd_checkpoint_show(args: argparse.Namespace) -> int: + """Handle 'shannot checkpoint show' command.""" + from .session import Session + + session_id = args.session_id + + try: + session = Session.load(session_id, audit=False) + except FileNotFoundError: + print(f"Error: Session not found: {session_id}", file=sys.stderr) + return 1 + + if not session.checkpoint: + print(f"No checkpoint for session {session_id}", file=sys.stderr) + return 1 + + print(f"Session: {session_id}") + print(f"Status: {session.status}") + print(f"Checkpoint created: {session.checkpoint_created_at}") + print(f"Checkpoint directory: {session.checkpoint_dir}") + print() + print("Files:") + + for path, entry in sorted(session.checkpoint.items()): + blob = entry.get("blob", "") + size = entry.get("size", 0) + was_created = entry.get("was_created", False) + was_deleted = entry.get("was_deleted", False) + partial = entry.get("partial", False) + + if was_created: + tag = "[created]" + elif was_deleted: + if partial: + tag = "[deleted, partial]" + else: + tag = "[deleted]" + else: + tag = "[modified]" + + size_str = _format_checkpoint_size(size) if size else "" + blob_str = f" ({blob})" if blob else "" + print(f" {path} {tag} {size_str}{blob_str}") + + return 0 + + +def _format_checkpoint_size(size: int) -> str: + """Format size for checkpoint display.""" + if size < 1024: + return f"{size} B" + elif size < 1024 * 1024: + return f"{size / 1024:.1f} KB" + else: + return f"{size / (1024 * 1024):.1f} MB" + + def cmd_status(args: argparse.Namespace) -> int: """Handle 'shannot status' command.""" # Determine what to show @@ -870,7 +1061,7 @@ def main() -> int: subparsers = parser.add_subparsers( dest="command", help="Commands", - metavar="{run,approve,status,setup}", + metavar="{run,approve,status,setup,rollback,checkpoint}", ) # ===== setup subcommand (with sub-subcommands) ===== @@ -1114,6 +1305,60 @@ def main() -> int: ) status_parser.set_defaults(func=cmd_status) + # ===== rollback subcommand ===== + rollback_parser = subparsers.add_parser( + "rollback", + help="Rollback session to pre-execution state", + description="Restore files to their state before session was executed", + ) + rollback_parser.add_argument( + "session_id", + help="Session ID to rollback", + ) + rollback_parser.add_argument( + "--force", + "-f", + action="store_true", + help="Skip conflict detection and restore anyway", + ) + rollback_parser.add_argument( + "--dry-run", + "-n", + action="store_true", + help="Show what would be restored without making changes", + ) + rollback_parser.set_defaults(func=cmd_rollback) + + # ===== checkpoint subcommand ===== + checkpoint_parser = subparsers.add_parser( + "checkpoint", + help="Manage checkpoints", + description="List and inspect session checkpoints", + ) + checkpoint_subparsers = checkpoint_parser.add_subparsers( + dest="checkpoint_cmd", + help="Checkpoint commands", + ) + + # checkpoint list + checkpoint_subparsers.add_parser( + "list", + help="List sessions with checkpoints", + description="Show all sessions that have checkpoints available for rollback", + ) + + # checkpoint show + checkpoint_show_parser = checkpoint_subparsers.add_parser( + "show", + help="Show checkpoint details", + description="Display files included in a session checkpoint", + ) + checkpoint_show_parser.add_argument( + "session_id", + help="Session ID to show checkpoint for", + ) + checkpoint_parser.set_defaults(func=cmd_checkpoint) + # Parse and execute args = parser.parse_args() diff --git a/shannot/run_session.py b/shannot/run_session.py index 6e661e9..e19bc11 100644 --- a/shannot/run_session.py +++ b/shannot/run_session.py @@ -209,9 +209,17 @@ def execute_session_direct(session) -> int: exit_code = result or 0 executed_commands = [] + # Create checkpoint before committing changes + from .checkpoint import create_checkpoint, update_post_exec_hashes + + create_checkpoint(session) + # Commit pending writes to filesystem completed_writes = session.commit_writes() + # Record post-execution hashes for rollback conflict detection + update_post_exec_hashes(session) + # Commit pending deletions to filesystem completed_deletions = session.commit_deletions() diff --git a/shannot/session.py b/shannot/session.py index 0db6877..0cd7180 100644 --- a/shannot/session.py +++ b/shannot/session.py @@ -13,7 +13,14 @@ from .config import SESSIONS_DIR SessionStatus = Literal[ - "pending", "approved", "rejected", "executed", "failed", "cancelled", "expired" + "pending", + "approved", + "rejected", + "executed", + "failed", + "cancelled", + "expired", + "rolled_back", ] # Session TTL - pending sessions expire after this duration @@ -50,6 +57,10 @@ class Session: target: str | None = None # SSH target (user@host) if remote remote_session_id: str | None = None # Session ID on remote + # Checkpoint/rollback fields + checkpoint_created_at: str | None = None # ISO timestamp when checkpoint was created + checkpoint: dict | None = None # path → {blob, size, mtime, post_exec_hash} + def is_remote(self) -> bool: """Check if this is a remote session.""" return self.target is not None @@ -312,6 +323,11 @@ def session_dir(self) -> Path: """Directory storing this session's data.""" return SESSIONS_DIR / self.id + @property + def checkpoint_dir(self) -> Path: + """Directory storing checkpoint blob files.""" + return self.session_dir / "checkpoint" + def save(self) -> None: """Persist session to disk.""" self.session_dir.mkdir(parents=True, exist_ok=True) diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py new file mode 100644 index 0000000..aabe007 --- /dev/null +++ b/test/test_checkpoint.py @@ -0,0 +1,433 @@ +"""Tests for checkpoint and rollback functionality.""" + +from __future__ import annotations + +import base64 +import hashlib + +import pytest + + +@pytest.fixture +def temp_session_dir(tmp_path): + """Create a temporary session directory structure.""" + session_dir = tmp_path / "sessions" / "test-session-1234" + session_dir.mkdir(parents=True) + return session_dir + + +@pytest.fixture +def mock_session(temp_session_dir, monkeypatch): + """Create a mock session for testing.""" + from shannot.session import Session + + # Patch SESSIONS_DIR to use temp directory + sessions_dir = temp_session_dir.parent + monkeypatch.setattr("shannot.config.SESSIONS_DIR", sessions_dir) + monkeypatch.setattr("shannot.session.SESSIONS_DIR", sessions_dir) + + session = Session( + id="test-session-1234", + name="Test Session", + script_path="/tmp/test.py", + pending_writes=[], + pending_deletions=[], + ) + return session + + +class TestCreateCheckpoint: + """Tests for create_checkpoint function.""" + + def test_checkpoint_modified_file(self, mock_session, tmp_path): + """Test checkpointing a modified file.""" + from shannot.checkpoint import create_checkpoint + + # Set up: file that will be modified + original_content = b"original content" + new_content = b"new content" + original_hash = hashlib.sha256(original_content).hexdigest() + + mock_session.pending_writes = [ + { + "path": "/tmp/test.txt", + "content_b64": base64.b64encode(new_content).decode(), + "original_b64": base64.b64encode(original_content).decode(), + "original_hash": original_hash, + } + ] + + # Create checkpoint + checkpoint = create_checkpoint(mock_session) + + # Verify checkpoint was created + assert "/tmp/test.txt" in checkpoint + entry = checkpoint["/tmp/test.txt"] + assert entry["blob"] is not None + assert entry["size"] == len(original_content) + assert entry["was_created"] is False + + # Verify blob file exists + blob_path = mock_session.checkpoint_dir / entry["blob"] + assert blob_path.exists() + assert blob_path.read_bytes() == original_content + + def test_checkpoint_new_file(self, mock_session): + """Test checkpointing a newly created file.""" + from shannot.checkpoint import create_checkpoint + + new_content = b"new file content" + + mock_session.pending_writes = [ + { + "path": "/tmp/newfile.txt", + "content_b64": base64.b64encode(new_content).decode(), + "original_b64": None, + "original_hash": None, + } + ] + + checkpoint = create_checkpoint(mock_session) + + assert "/tmp/newfile.txt" in checkpoint + entry = checkpoint["/tmp/newfile.txt"] + assert entry["blob"] is None + assert entry["was_created"] is True + + def test_checkpoint_deleted_file(self, mock_session, tmp_path): + """Test checkpointing a file marked for deletion.""" + from shannot.checkpoint import create_checkpoint + + # Create real file to be deleted + test_file = tmp_path / "to_delete.txt" + test_file.write_bytes(b"content to backup") + + mock_session.pending_deletions = [ + { + "path": str(test_file), + "target_type": "file", + "size": 17, + } + ] + + checkpoint = create_checkpoint(mock_session) + + assert str(test_file) in checkpoint + entry = checkpoint[str(test_file)] + assert entry["blob"] is not None + assert entry["was_deleted"] is True + + # Verify blob contains original content + blob_path = mock_session.checkpoint_dir / entry["blob"] + assert blob_path.read_bytes() == b"content to backup" + + def test_checkpoint_sets_timestamps(self, mock_session): + """Test that checkpoint creation sets timestamps.""" + from shannot.checkpoint import create_checkpoint + + mock_session.pending_writes = [] + mock_session.pending_deletions = [] + + create_checkpoint(mock_session) + + assert mock_session.checkpoint_created_at is not None + assert mock_session.checkpoint is not None + + +class TestUpdatePostExecHashes: + """Tests for update_post_exec_hashes function.""" + + def test_update_hashes_for_modified_files(self, mock_session, tmp_path): + """Test post-exec hash is recorded for modified files.""" + from shannot.checkpoint import update_post_exec_hashes + + # Create file that was "written" during execution + test_file = tmp_path / "modified.txt" + test_file.write_bytes(b"new content after execution") + + mock_session.checkpoint = { + str(test_file): { + "blob": "abc12345.blob", + "size": 10, + "was_created": False, + } + } + + update_post_exec_hashes(mock_session) + + entry = mock_session.checkpoint[str(test_file)] + assert "post_exec_hash" in entry + expected_hash = hashlib.sha256(b"new content after execution").hexdigest() + assert entry["post_exec_hash"] == expected_hash + + +class TestRollbackLocal: + """Tests for rollback_local function.""" + + def test_rollback_modified_file(self, mock_session, tmp_path): + """Test rolling back a modified file to original content.""" + from shannot.checkpoint import rollback_local + + # Set up: file exists with "new" content + test_file = tmp_path / "modified.txt" + test_file.write_bytes(b"new content") + + # Create checkpoint with original content + checkpoint_dir = mock_session.checkpoint_dir + checkpoint_dir.mkdir(parents=True, exist_ok=True) + blob_name = "abc12345.blob" + (checkpoint_dir / blob_name).write_bytes(b"original content") + + mock_session.checkpoint = { + str(test_file): { + "blob": blob_name, + "size": 16, + "was_created": False, + "post_exec_hash": hashlib.sha256(b"new content").hexdigest(), + } + } + + results = rollback_local(mock_session) + + assert len(results) == 1 + assert results[0]["success"] is True + assert results[0]["action"] == "restored" + assert test_file.read_bytes() == b"original content" + + def test_rollback_created_file(self, mock_session, tmp_path): + """Test rolling back a created file by deleting it.""" + from shannot.checkpoint import rollback_local + + # Set up: file was created during execution + test_file = tmp_path / "created.txt" + test_file.write_bytes(b"created content") + + mock_session.checkpoint = { + str(test_file): { + "blob": None, + "size": 0, + "was_created": True, + "post_exec_hash": hashlib.sha256(b"created content").hexdigest(), + } + } + + results = rollback_local(mock_session) + + assert len(results) == 1 + assert results[0]["success"] is True + assert results[0]["action"] == "deleted" + assert not test_file.exists() + + def test_rollback_deleted_file(self, mock_session, tmp_path): + """Test rolling back a deleted file by recreating it.""" + from shannot.checkpoint import rollback_local + + # Set up: file doesn't exist (was deleted) + test_file = tmp_path / "deleted.txt" + assert not test_file.exists() + + # Create checkpoint with original content + checkpoint_dir = mock_session.checkpoint_dir + checkpoint_dir.mkdir(parents=True, exist_ok=True) + blob_name = "def67890.blob" + (checkpoint_dir / blob_name).write_bytes(b"deleted file content") + + mock_session.checkpoint = { + str(test_file): { + "blob": blob_name, + "size": 20, + "was_deleted": True, + } + } + + results = rollback_local(mock_session) + + assert len(results) == 1 + assert results[0]["success"] is True + assert results[0]["action"] == "recreated" + assert test_file.exists() + assert test_file.read_bytes() == b"deleted file content" + + def test_rollback_detects_conflict(self, mock_session, tmp_path): + """Test that rollback detects file modifications since execution.""" + from shannot.checkpoint import rollback_local + + # Set up: file was modified again after execution + test_file = tmp_path / "conflict.txt" + test_file.write_bytes(b"modified again after execution") + + checkpoint_dir = mock_session.checkpoint_dir + checkpoint_dir.mkdir(parents=True, exist_ok=True) + blob_name = "conflict.blob" + (checkpoint_dir / blob_name).write_bytes(b"original content") + + mock_session.checkpoint = { + str(test_file): { + "blob": blob_name, + "size": 16, + "was_created": False, + "post_exec_hash": hashlib.sha256(b"content at execution time").hexdigest(), + } + } + + results = rollback_local(mock_session) + + assert len(results) == 1 + assert results[0]["success"] is False + assert results[0]["action"] == "conflict" + # File should NOT be modified + assert test_file.read_bytes() == b"modified again after execution" + + def test_rollback_force_ignores_conflict(self, mock_session, tmp_path): + """Test that --force bypasses conflict detection.""" + from shannot.checkpoint import rollback_local + + test_file = tmp_path / "conflict.txt" + test_file.write_bytes(b"modified again after execution") + + checkpoint_dir = mock_session.checkpoint_dir + checkpoint_dir.mkdir(parents=True, exist_ok=True) + blob_name = "conflict.blob" + (checkpoint_dir / blob_name).write_bytes(b"original content") + + mock_session.checkpoint = { + str(test_file): { + "blob": blob_name, + "size": 16, + "was_created": False, + "post_exec_hash": hashlib.sha256(b"content at execution time").hexdigest(), + } + } + + results = rollback_local(mock_session, force=True) + + assert len(results) == 1 + assert results[0]["success"] is True + assert results[0]["action"] == "restored" + assert test_file.read_bytes() == b"original content" + + def test_rollback_partial_checkpoint_skipped(self, mock_session, tmp_path): + """Test that partial checkpoints are skipped with warning.""" + from shannot.checkpoint import rollback_local + + mock_session.checkpoint = { + "/tmp/large_dir": { + "blob": None, + "size": 0, + "was_deleted": True, + "partial": True, + "file_count": 1000, + "total_size": 100_000_000, + } + } + + results = rollback_local(mock_session) + + assert len(results) == 1 + assert results[0]["success"] is False + assert results[0]["action"] == "skipped" + assert "partial" in results[0]["error"].lower() + + +class TestListCheckpoints: + """Tests for list_checkpoints function.""" + + def test_list_empty(self, monkeypatch): + """Test listing when no checkpoints exist.""" + from shannot.checkpoint import list_checkpoints + from shannot.session import Session + + monkeypatch.setattr(Session, "list_all", lambda: []) + + result = list_checkpoints() + assert result == [] + + def test_list_with_checkpoints(self, mock_session, monkeypatch): + """Test listing sessions with checkpoints.""" + from shannot.checkpoint import list_checkpoints + from shannot.session import Session + + mock_session.checkpoint_created_at = "2026-01-11T10:00:00" + mock_session.checkpoint = { + "/tmp/file1.txt": {"blob": "abc.blob", "size": 100}, + "/tmp/file2.txt": {"blob": "def.blob", "size": 200}, + } + + monkeypatch.setattr(Session, "list_all", lambda: [mock_session]) + + result = list_checkpoints() + + assert len(result) == 1 + session, info = result[0] + assert session.id == "test-session-1234" + assert info["file_count"] == 2 + assert info["total_size"] == 300 + assert info["created_at"] == "2026-01-11T10:00:00" + + +class TestCheckpointDirectory: + """Tests for directory checkpoint with size limits.""" + + def test_large_directory_creates_partial_checkpoint(self, mock_session, tmp_path): + """Test that large directories create partial checkpoints.""" + from shannot.checkpoint import CHECKPOINT_MAX_FILES, create_checkpoint + + # Create directory with more than CHECKPOINT_MAX_FILES files + large_dir = tmp_path / "large_dir" + large_dir.mkdir() + for i in range(CHECKPOINT_MAX_FILES + 10): + (large_dir / f"file_{i}.txt").write_bytes(b"x") + + mock_session.pending_deletions = [ + { + "path": str(large_dir), + "target_type": "directory", + "size": 0, + } + ] + + checkpoint = create_checkpoint(mock_session) + + # Should have partial checkpoint entry + entry = checkpoint[str(large_dir)] + assert entry["partial"] is True + assert entry["file_count"] > CHECKPOINT_MAX_FILES + + +class TestSessionIntegration: + """Integration tests with actual session execution flow.""" + + def test_checkpoint_survives_session_save_load(self, mock_session): + """Test that checkpoint data survives session save/load cycle.""" + from shannot.checkpoint import create_checkpoint + + mock_session.pending_writes = [ + { + "path": "/tmp/test.txt", + "content_b64": base64.b64encode(b"new").decode(), + "original_b64": base64.b64encode(b"old").decode(), + "original_hash": hashlib.sha256(b"old").hexdigest(), + } + ] + + create_checkpoint(mock_session) + mock_session.save() + + # Load session + from shannot.session import Session + + loaded = Session.load("test-session-1234", audit=False) + + assert loaded.checkpoint_created_at is not None + assert loaded.checkpoint is not None + assert "/tmp/test.txt" in loaded.checkpoint + + def test_rolled_back_status(self, mock_session): + """Test that rolled_back status is valid.""" + mock_session.status = "rolled_back" + mock_session.save() + + from shannot.session import Session + + loaded = Session.load("test-session-1234", audit=False) + assert loaded.status == "rolled_back"