From 162ef2ab59c2036473baa8f2b5dd57c393abaa30 Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Sun, 3 May 2026 16:22:42 +0800 Subject: [PATCH 1/8] =?UTF-8?q?feat(checkpoint):=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E4=B8=8E=E4=B8=8A=E4=B8=8B=E6=96=87=E5=BF=AB=E7=85=A7=E5=AE=9A?= =?UTF-8?q?=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/app/bootstrap.go | 15 + internal/checkpoint/checkpoint_manager.go | 364 ++++++++++++++++++++++ internal/checkpoint/shadow_repo.go | 191 ++++++++++++ internal/runtime/checkpoint_gate.go | 133 ++++++++ internal/runtime/events.go | 20 ++ internal/runtime/run.go | 8 + internal/runtime/runtime.go | 9 + internal/session/checkpoint_types.go | 64 ++++ internal/session/sqlite_store.go | 119 +++++++ internal/session/store.go | 4 +- internal/session/store_test.go | 6 +- internal/session/workspace.go | 4 +- 12 files changed, 930 insertions(+), 7 deletions(-) create mode 100644 internal/checkpoint/checkpoint_manager.go create mode 100644 internal/checkpoint/shadow_repo.go create mode 100644 internal/runtime/checkpoint_gate.go create mode 100644 internal/session/checkpoint_types.go diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 6fb8e9e6..87290a7b 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -10,6 +10,7 @@ import ( tea "github.com/charmbracelet/bubbletea" + "neo-code/internal/checkpoint" "neo-code/internal/config" configstate "neo-code/internal/config/state" agentcontext "neo-code/internal/context" @@ -230,6 +231,20 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime )) } + // Checkpoint 基础设施:影子仓库 + SQLite checkpoint 存储 + if gitAvail, _ := checkpoint.CheckGitAvailability(ctx); gitAvail { + projectDir := agentsession.HashWorkspaceRoot(cfg.Workdir) + shadowDir := filepath.Join(sharedDeps.ConfigManager.BaseDir(), "projects", projectDir) + shadowRepo := checkpoint.NewShadowRepo(shadowDir, cfg.Workdir) + if err := shadowRepo.Init(ctx); err != nil { + log.Printf("checkpoint shadow repo init warning: %v", err) + } else { + dbPath := agentsession.DatabasePath(sharedDeps.ConfigManager.BaseDir(), cfg.Workdir) + checkpointStore := checkpoint.NewSQLiteCheckpointStore(dbPath) + runtimeSvc.SetCheckpointDependencies(checkpointStore, shadowRepo) + } + } + runtimeImpl := agentruntime.Runtime(runtimeSvc) closeFns := []func() error{toolsCleanup, sessionStore.Close} diff --git a/internal/checkpoint/checkpoint_manager.go b/internal/checkpoint/checkpoint_manager.go new file mode 100644 index 00000000..97860416 --- /dev/null +++ b/internal/checkpoint/checkpoint_manager.go @@ -0,0 +1,364 @@ +package checkpoint + +import ( + "context" + "database/sql" + "errors" + "fmt" + "sync" + "time" + + "neo-code/internal/session" +) + +// CheckpointStore 定义 checkpoint 持久化的意图型接口。 +type CheckpointStore interface { + CreateCheckpoint(ctx context.Context, input CreateCheckpointInput) (session.CheckpointRecord, error) + ListCheckpoints(ctx context.Context, sessionID string, opts ListCheckpointOpts) ([]session.CheckpointRecord, error) + GetCheckpoint(ctx context.Context, checkpointID string) (session.CheckpointRecord, *session.SessionCheckpoint, error) + UpdateCheckpointStatus(ctx context.Context, checkpointID string, status session.CheckpointStatus) error + GetLatestResumeCheckpoint(ctx context.Context, sessionID string) (*session.ResumeCheckpoint, error) +} + +// CreateCheckpointInput 描述一次 checkpoint 创建的完整输入。 +type CreateCheckpointInput struct { + Record session.CheckpointRecord + SessionCP session.SessionCheckpoint +} + +// ListCheckpointOpts 描述 checkpoint 列表查询选项。 +type ListCheckpointOpts struct { + Limit int + RestorableOnly bool +} + +// SQLiteCheckpointStore 基于 SQLite 实现 checkpoint 持久化。 +type SQLiteCheckpointStore struct { + dbPath string + initMu sync.Mutex + db *sql.DB +} + +// NewSQLiteCheckpointStore 创建 checkpoint 存储实例。 +// dbPath 为 session.db 文件路径,可通过 session.DatabasePath 获取。 +func NewSQLiteCheckpointStore(dbPath string) *SQLiteCheckpointStore { + return &SQLiteCheckpointStore{ + dbPath: dbPath, + } +} + +// Close 释放数据库连接。 +func (s *SQLiteCheckpointStore) Close() error { + if s == nil || s.db == nil { + return nil + } + return s.db.Close() +} + +func (s *SQLiteCheckpointStore) ensureDB(ctx context.Context) (*sql.DB, error) { + s.initMu.Lock() + defer s.initMu.Unlock() + if s.db != nil { + return s.db, nil + } + db, err := sql.Open("sqlite", s.dbPath) + if err != nil { + return nil, fmt.Errorf("checkpoint: open sqlite db: %w", err) + } + db.SetMaxOpenConns(2) + db.SetMaxIdleConns(2) + + pragmas := []string{ + `PRAGMA journal_mode=WAL`, + `PRAGMA synchronous=NORMAL`, + `PRAGMA foreign_keys=ON`, + `PRAGMA busy_timeout=5000`, + } + for _, pragma := range pragmas { + if _, err := db.ExecContext(ctx, pragma); err != nil { + _ = db.Close() + return nil, fmt.Errorf("checkpoint: apply pragma %q: %w", pragma, err) + } + } + s.db = db + return db, nil +} + +// CreateCheckpoint 在单一事务内写入 checkpoint record + session checkpoint。 +// 事务内完成 record INSERT → session_cp INSERT → record UPDATE(设置 session_checkpoint_ref + status=available)。 +func (s *SQLiteCheckpointStore) CreateCheckpoint(ctx context.Context, input CreateCheckpointInput) (session.CheckpointRecord, error) { + if err := ctx.Err(); err != nil { + return session.CheckpointRecord{}, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return session.CheckpointRecord{}, err + } + + record := input.Record + sessionCP := input.SessionCP + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: begin create tx: %w", err) + } + defer rollbackTx(tx) + + // INSERT checkpoint_record (status = creating) + _, err = tx.ExecContext(ctx, ` +INSERT INTO checkpoint_records ( + id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +`, + record.CheckpointID, + record.WorkspaceKey, + record.SessionID, + record.RunID, + record.Workdir, + toUnixMillis(record.CreatedAt), + string(record.Reason), + record.CodeCheckpointRef, + "", // session_checkpoint_ref filled below + record.ResumeCheckpointRef, + record.TranscriptRevision, + boolToInt(record.Restorable), + string(session.CheckpointStatusCreating), + ) + if err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: insert record %s: %w", record.CheckpointID, err) + } + + // INSERT session_checkpoint + _, err = tx.ExecContext(ctx, ` +INSERT INTO session_checkpoints (id, session_id, head_json, messages_json, created_at_ms) +VALUES (?, ?, ?, ?, ?) +`, + sessionCP.ID, + sessionCP.SessionID, + sessionCP.HeadJSON, + sessionCP.MessagesJSON, + toUnixMillis(sessionCP.CreatedAt), + ) + if err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: insert session cp %s: %w", sessionCP.ID, err) + } + + // UPDATE checkpoint_record: set session_checkpoint_ref + status = available + _, err = tx.ExecContext(ctx, ` +UPDATE checkpoint_records +SET session_checkpoint_ref = ?, status = ? +WHERE id = ? +`, + sessionCP.ID, + string(session.CheckpointStatusAvailable), + record.CheckpointID, + ) + if err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: update record %s: %w", record.CheckpointID, err) + } + + if err := tx.Commit(); err != nil { + return session.CheckpointRecord{}, fmt.Errorf("checkpoint: commit create %s: %w", record.CheckpointID, err) + } + + record.SessionCheckpointRef = sessionCP.ID + record.Status = session.CheckpointStatusAvailable + return record, nil +} + +// ListCheckpoints 查询指定会话的 checkpoint 记录列表。 +func (s *SQLiteCheckpointStore) ListCheckpoints(ctx context.Context, sessionID string, opts ListCheckpointOpts) ([]session.CheckpointRecord, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return nil, err + } + + query := ` +SELECT id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +FROM checkpoint_records +WHERE session_id = ? +` + args := []any{sessionID} + if opts.RestorableOnly { + query += ` AND restorable = 1 AND status = ?` + args = append(args, string(session.CheckpointStatusAvailable)) + } + query += ` ORDER BY created_at_ms DESC` + if opts.Limit > 0 { + query += ` LIMIT ?` + args = append(args, opts.Limit) + } + + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("checkpoint: list checkpoints for %s: %w", sessionID, err) + } + defer rows.Close() + + var records []session.CheckpointRecord + for rows.Next() { + var r session.CheckpointRecord + var createdAtMS int64 + var reason, status string + if err := rows.Scan( + &r.CheckpointID, &r.WorkspaceKey, &r.SessionID, &r.RunID, &r.Workdir, &createdAtMS, + &reason, &r.CodeCheckpointRef, &r.SessionCheckpointRef, &r.ResumeCheckpointRef, + &r.TranscriptRevision, &r.Restorable, &status, + ); err != nil { + return nil, fmt.Errorf("checkpoint: scan record: %w", err) + } + r.CreatedAt = fromUnixMillis(createdAtMS) + r.Reason = session.CheckpointReason(reason) + r.Status = session.CheckpointStatus(status) + records = append(records, r) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("checkpoint: iterate records: %w", err) + } + return records, nil +} + +// GetCheckpoint 查询单条 checkpoint record 及其关联的 session checkpoint。 +func (s *SQLiteCheckpointStore) GetCheckpoint(ctx context.Context, checkpointID string) (session.CheckpointRecord, *session.SessionCheckpoint, error) { + if err := ctx.Err(); err != nil { + return session.CheckpointRecord{}, nil, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return session.CheckpointRecord{}, nil, err + } + + var r session.CheckpointRecord + var createdAtMS int64 + var reason, status string + err = db.QueryRowContext(ctx, ` +SELECT id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +FROM checkpoint_records +WHERE id = ? +`, checkpointID).Scan( + &r.CheckpointID, &r.WorkspaceKey, &r.SessionID, &r.RunID, &r.Workdir, &createdAtMS, + &reason, &r.CodeCheckpointRef, &r.SessionCheckpointRef, &r.ResumeCheckpointRef, + &r.TranscriptRevision, &r.Restorable, &status, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return session.CheckpointRecord{}, nil, fmt.Errorf("checkpoint: record %s not found", checkpointID) + } + return session.CheckpointRecord{}, nil, fmt.Errorf("checkpoint: query record %s: %w", checkpointID, err) + } + r.CreatedAt = fromUnixMillis(createdAtMS) + r.Reason = session.CheckpointReason(reason) + r.Status = session.CheckpointStatus(status) + + if r.SessionCheckpointRef == "" { + return r, nil, nil + } + + var sc session.SessionCheckpoint + var scCreatedAtMS int64 + err = db.QueryRowContext(ctx, ` +SELECT id, session_id, head_json, messages_json, created_at_ms +FROM session_checkpoints +WHERE id = ? +`, r.SessionCheckpointRef).Scan( + &sc.ID, &sc.SessionID, &sc.HeadJSON, &sc.MessagesJSON, &scCreatedAtMS, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return r, nil, nil + } + return session.CheckpointRecord{}, nil, fmt.Errorf("checkpoint: query session cp %s: %w", r.SessionCheckpointRef, err) + } + sc.CreatedAt = fromUnixMillis(scCreatedAtMS) + return r, &sc, nil +} + +// UpdateCheckpointStatus 更新 checkpoint 的生命周期状态。 +func (s *SQLiteCheckpointStore) UpdateCheckpointStatus(ctx context.Context, checkpointID string, status session.CheckpointStatus) error { + if err := ctx.Err(); err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + result, err := db.ExecContext(ctx, `UPDATE checkpoint_records SET status = ? WHERE id = ?`, string(status), checkpointID) + if err != nil { + return fmt.Errorf("checkpoint: update status %s: %w", checkpointID, err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("checkpoint: inspect rows affected for %s: %w", checkpointID, err) + } + if affected == 0 { + return fmt.Errorf("checkpoint: record %s not found", checkpointID) + } + return nil +} + +// GetLatestResumeCheckpoint 查询指定会话最新的 resume checkpoint。 +func (s *SQLiteCheckpointStore) GetLatestResumeCheckpoint(ctx context.Context, sessionID string) (*session.ResumeCheckpoint, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return nil, err + } + + var rc session.ResumeCheckpoint + var updatedAtMS int64 + err = db.QueryRowContext(ctx, ` +SELECT id, workspace_key, run_id, session_id, turn, phase, completion_state, transcript_revision, updated_at_ms +FROM resume_checkpoints +WHERE session_id = ? +ORDER BY updated_at_ms DESC +LIMIT 1 +`, sessionID).Scan( + &rc.ID, &rc.WorkspaceKey, &rc.RunID, &rc.SessionID, + &rc.Turn, &rc.Phase, &rc.CompletionState, + &rc.TranscriptRevision, &updatedAtMS, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("checkpoint: query resume checkpoint for %s: %w", sessionID, err) + } + rc.UpdatedAt = fromUnixMillis(updatedAtMS) + return &rc, nil +} + +func toUnixMillis(value time.Time) int64 { + return value.UTC().UnixMilli() +} + +func fromUnixMillis(value int64) time.Time { + if value == 0 { + return time.Time{} + } + return time.UnixMilli(value).UTC() +} + +func boolToInt(value bool) int { + if value { + return 1 + } + return 0 +} + +func rollbackTx(tx *sql.Tx) { + if tx != nil { + _ = tx.Rollback() + } +} diff --git a/internal/checkpoint/shadow_repo.go b/internal/checkpoint/shadow_repo.go new file mode 100644 index 00000000..13e1c4a3 --- /dev/null +++ b/internal/checkpoint/shadow_repo.go @@ -0,0 +1,191 @@ +package checkpoint + +import ( + "context" + "fmt" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" +) + +const ( + gitCommandTimeout = 5 * time.Second +) + +// ShadowRepo 封装 bare git 仓库,用于对用户工作区做代码快照与恢复。 +type ShadowRepo struct { + shadowDir string + workdir string + gitAvailable bool + mu sync.Mutex +} + +// NewShadowRepo 创建影子仓库实例,workdir 为用户工作区根目录。 +func NewShadowRepo(projectDir string, workdir string) *ShadowRepo { + return &ShadowRepo{ + shadowDir: filepath.Join(projectDir, ".shadow"), + workdir: workdir, + } +} + +// CheckGitAvailability 检查系统是否可用 git 命令。 +func CheckGitAvailability(ctx context.Context) (available bool, version string) { + ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) + defer cancel() + + out, err := exec.CommandContext(ctx, "git", "version").CombinedOutput() + if err != nil { + return false, "" + } + return true, strings.TrimSpace(string(out)) +} + +// Init 初始化 bare 仓库,设置 core.worktree 指向用户工作区。 +func (r *ShadowRepo) Init(ctx context.Context) error { + r.mu.Lock() + defer r.mu.Unlock() + + ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) + defer cancel() + + if err := exec.CommandContext(ctx, "git", "init", "--bare", r.shadowDir).Run(); err != nil { + return fmt.Errorf("checkpoint: init bare repo at %s: %w", r.shadowDir, err) + } + + // 设置 core.worktree 使后续操作无需重复指定 --work-tree + ctx2, cancel2 := context.WithTimeout(context.Background(), gitCommandTimeout) + defer cancel2() + if err := r.gitExec(ctx2, "config", "core.worktree", r.workdir); err != nil { + return fmt.Errorf("checkpoint: set core.worktree: %w", err) + } + + r.gitAvailable = true + return nil +} + +// IsAvailable 返回影子仓库是否已初始化且 git 可用。 +func (r *ShadowRepo) IsAvailable() bool { + r.mu.Lock() + defer r.mu.Unlock() + return r.gitAvailable +} + +// Snapshot 对工作区做快照,返回 commit hash。 +// ref 为完整 ref 路径(如 refs/neocode/sessions//checkpoints/)。 +func (r *ShadowRepo) Snapshot(ctx context.Context, ref string, message string) (string, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.gitAvailable { + return "", fmt.Errorf("checkpoint: shadow repo not available") + } + + ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) + defer cancel() + + if err := r.gitExec(ctx, "add", "-A"); err != nil { + return "", fmt.Errorf("checkpoint: git add: %w", err) + } + + commitMsg := message + if commitMsg == "" { + commitMsg = "checkpoint snapshot" + } + if err := r.gitExec(ctx, "commit", "--allow-empty", "-m", commitMsg); err != nil { + return "", fmt.Errorf("checkpoint: git commit: %w", err) + } + + hash, err := r.gitOutput(ctx, "rev-parse", "HEAD") + if err != nil { + return "", fmt.Errorf("checkpoint: git rev-parse HEAD: %w", err) + } + hash = strings.TrimSpace(hash) + + if ref != "" { + if err := r.gitExec(ctx, "update-ref", ref, hash); err != nil { + return "", fmt.Errorf("checkpoint: git update-ref %s: %w", ref, err) + } + } + + return hash, nil +} + +// Restore 将工作区恢复到指定 commit 状态。 +func (r *ShadowRepo) Restore(ctx context.Context, commitHash string) error { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.gitAvailable { + return fmt.Errorf("checkpoint: shadow repo not available") + } + + ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) + defer cancel() + + if err := r.gitExec(ctx, "checkout", commitHash, "--", "."); err != nil { + return fmt.Errorf("checkpoint: git checkout %s: %w", commitHash, err) + } + return nil +} + +// DeleteRef 删除指定的 ref 引用,用于补偿失败的 checkpoint 创建。 +func (r *ShadowRepo) DeleteRef(ctx context.Context, ref string) error { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.gitAvailable { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) + defer cancel() + + if err := r.gitExec(ctx, "update-ref", "-d", ref); err != nil { + return fmt.Errorf("checkpoint: git update-ref -d %s: %w", ref, err) + } + return nil +} + +// HealthCheck 验证 bare 仓库存在且可执行 git 操作。 +func (r *ShadowRepo) HealthCheck(ctx context.Context) error { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.gitAvailable { + return fmt.Errorf("checkpoint: shadow repo not available") + } + + ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) + defer cancel() + + if err := r.gitExec(ctx, "rev-parse", "--git-dir"); err != nil { + return fmt.Errorf("checkpoint: health check failed: %w", err) + } + return nil +} + +func (r *ShadowRepo) gitExec(ctx context.Context, args ...string) error { + cmd := r.buildGitCommand(ctx, args...) + return cmd.Run() +} + +func (r *ShadowRepo) gitOutput(ctx context.Context, args ...string) (string, error) { + cmd := r.buildGitCommand(ctx, args...) + out, err := cmd.Output() + return string(out), err +} + +func (r *ShadowRepo) buildGitCommand(ctx context.Context, args ...string) *exec.Cmd { + fullArgs := make([]string, 0, 4+len(args)) + fullArgs = append(fullArgs, "--git-dir="+r.shadowDir) + fullArgs = append(fullArgs, "--work-tree="+r.workdir) + fullArgs = append(fullArgs, args...) + return exec.CommandContext(ctx, "git", fullArgs...) +} + +// RefForCheckpoint 构造 checkpoint 的 git ref 路径。 +func RefForCheckpoint(sessionID string, checkpointID string) string { + return fmt.Sprintf("refs/neocode/sessions/%s/checkpoints/%s", sessionID, checkpointID) +} diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go new file mode 100644 index 00000000..1f5e83e9 --- /dev/null +++ b/internal/runtime/checkpoint_gate.go @@ -0,0 +1,133 @@ +package runtime + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "neo-code/internal/checkpoint" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" + "neo-code/internal/tools" +) + +// createPreWriteCheckpoint 在工具执行前创建 checkpoint,采用两阶段提交。 +// 失败时不阻塞工具执行,仅返回 error 由调用方发 warning event。 +func (s *Service) createPreWriteCheckpoint(ctx context.Context, state *runState) error { + if s.shadowRepo == nil || !s.shadowRepo.IsAvailable() { + return nil + } + if s.checkpointStore == nil { + return nil + } + + state.mu.Lock() + session := state.session + runID := state.runID + state.mu.Unlock() + + checkpointID := agentsession.NewID("checkpoint") + ref := checkpoint.RefForCheckpoint(session.ID, checkpointID) + commitMsg := fmt.Sprintf("pre_write checkpoint for session %s", session.ID) + + // Phase 1: shadow snapshot + commitHash, err := s.shadowRepo.Snapshot(ctx, ref, commitMsg) + if err != nil { + return fmt.Errorf("checkpoint: shadow snapshot: %w", err) + } + + // Phase 2: DB write + head := session.HeadSnapshot() + headJSON, err := json.Marshal(head) + if err != nil { + _ = s.shadowRepo.DeleteRef(ctx, ref) + return fmt.Errorf("checkpoint: marshal head: %w", err) + } + messagesJSON, err := json.Marshal(session.Messages) + if err != nil { + _ = s.shadowRepo.DeleteRef(ctx, ref) + return fmt.Errorf("checkpoint: marshal messages: %w", err) + } + + effectiveWorkdir := strings.TrimSpace(session.Workdir) + now := time.Now() + + record := agentsession.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: agentsession.WorkspacePathKey(effectiveWorkdir), + SessionID: session.ID, + RunID: runID, + Workdir: effectiveWorkdir, + CreatedAt: now, + Reason: agentsession.CheckpointReasonPreWrite, + CodeCheckpointRef: ref, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + } + + sessionCP := agentsession.SessionCheckpoint{ + ID: agentsession.NewID("sc"), + SessionID: session.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: now, + } + + input := checkpoint.CreateCheckpointInput{ + Record: record, + SessionCP: sessionCP, + } + + saved, err := s.checkpointStore.CreateCheckpoint(ctx, input) + if err != nil { + _ = s.shadowRepo.DeleteRef(ctx, ref) + return fmt.Errorf("checkpoint: db write: %w", err) + } + + s.emitRunScoped(ctx, EventCheckpointCreated, state, CheckpointCreatedPayload{ + CheckpointID: saved.CheckpointID, + CodeCheckpointRef: saved.CodeCheckpointRef, + SessionCheckpointRef: saved.SessionCheckpointRef, + CommitHash: commitHash, + Reason: string(saved.Reason), + }) + return nil +} + +// toolCallsContainWorkspaceWrite 检查工具调用列表中是否包含会修改工作区的调用。 +func toolCallsContainWorkspaceWrite(calls []providertypes.ToolCall) bool { + for _, call := range calls { + if isWorkspaceWriteToolCall(call) { + return true + } + } + return false +} + +func isWorkspaceWriteToolCall(call providertypes.ToolCall) bool { + switch call.Name { + case tools.ToolNameFilesystemWriteFile, tools.ToolNameFilesystemEdit: + return true + case tools.ToolNameBash: + return isBashWriteCommand(call.Arguments) + } + return false +} + +func isBashWriteCommand(argumentsJSON string) bool { + trimmed := strings.TrimSpace(argumentsJSON) + if trimmed == "" { + return false + } + var args struct { + Command string `json:"command"` + } + if err := json.Unmarshal([]byte(trimmed), &args); err != nil { + return false + } + intent := tools.AnalyzeBashCommand(args.Command) + return intent.Classification == tools.BashIntentClassificationLocalMutation || + intent.Classification == tools.BashIntentClassificationDestructive +} diff --git a/internal/runtime/events.go b/internal/runtime/events.go index e0725430..3549ee93 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -400,6 +400,11 @@ const ( EventSubAgentSnapshotUpdated EventType = "subagent_snapshot_updated" // EventTodoSnapshotUpdated 表示 todo 快照已更新。 EventTodoSnapshotUpdated EventType = "todo_snapshot_updated" + + // EventCheckpointCreated 表示 pre-write checkpoint 已创建。 + EventCheckpointCreated EventType = "checkpoint_created" + // EventCheckpointWarning 表示 checkpoint 创建过程中出现非致命告警。 + EventCheckpointWarning EventType = "checkpoint_warning" ) // TokenUsagePayload 承载单轮 token 用量统计。 @@ -412,3 +417,18 @@ type TokenUsagePayload struct { SessionInputTokens int `json:"session_input_tokens"` SessionOutputTokens int `json:"session_output_tokens"` } + +// CheckpointCreatedPayload 描述 checkpoint 创建成功事件。 +type CheckpointCreatedPayload struct { + CheckpointID string `json:"checkpoint_id"` + CodeCheckpointRef string `json:"code_checkpoint_ref"` + SessionCheckpointRef string `json:"session_checkpoint_ref"` + CommitHash string `json:"commit_hash"` + Reason string `json:"reason"` +} + +// CheckpointWarningPayload 描述 checkpoint 创建过程中非致命告警。 +type CheckpointWarningPayload struct { + Error string `json:"error"` + Phase string `json:"phase"` +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index c9faf87d..5958ef29 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -472,6 +472,14 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { beforeTask := state.session.TaskState.Clone() beforeTodos := cloneTodosForPersistence(state.session.Todos) + if s.checkpointStore != nil && toolCallsContainWorkspaceWrite(turnOutput.assistant.ToolCalls) { + if cpErr := s.createPreWriteCheckpoint(ctx, &state); cpErr != nil { + s.emitRunScoped(ctx, EventCheckpointWarning, &state, CheckpointWarningPayload{ + Error: cpErr.Error(), + Phase: "pre_write", + }) + } + } if err := s.setBaseRunState(ctx, &state, controlplane.RunStateExecute); err != nil { return s.handleRunError(err) } diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index e5efc69e..0932437f 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "neo-code/internal/checkpoint" "neo-code/internal/config" agentcontext "neo-code/internal/context" contextcompact "neo-code/internal/context/compact" @@ -150,6 +151,8 @@ type Service struct { skillsRegistry skills.Registry budgetResolver BudgetResolver hookExecutor HookExecutor + checkpointStore checkpoint.CheckpointStore + shadowRepo *checkpoint.ShadowRepo events chan RuntimeEvent runtimeSnapshotMu sync.Mutex @@ -445,3 +448,9 @@ func (s *Service) SetBudgetResolver(resolver BudgetResolver) { func (s *Service) SetHookExecutor(executor HookExecutor) { s.hookExecutor = executor } + +// SetCheckpointDependencies 注入 checkpoint 存储与影子仓库,用于 pre-write checkpoint gate。 +func (s *Service) SetCheckpointDependencies(store checkpoint.CheckpointStore, repo *checkpoint.ShadowRepo) { + s.checkpointStore = store + s.shadowRepo = repo +} diff --git a/internal/session/checkpoint_types.go b/internal/session/checkpoint_types.go new file mode 100644 index 00000000..1cfd148d --- /dev/null +++ b/internal/session/checkpoint_types.go @@ -0,0 +1,64 @@ +package session + +import "time" + +// CheckpointReason 描述 checkpoint 的创建原因。 +type CheckpointReason string + +const ( + CheckpointReasonPreWrite CheckpointReason = "pre_write" + CheckpointReasonCompact CheckpointReason = "compact" + CheckpointReasonPlanMode CheckpointReason = "plan_mode" + CheckpointReasonManual CheckpointReason = "manual" + CheckpointReasonGuard CheckpointReason = "pre_restore_guard" +) + +// CheckpointStatus 描述 checkpoint 的生命周期状态。 +type CheckpointStatus string + +const ( + CheckpointStatusCreating CheckpointStatus = "creating" + CheckpointStatusAvailable CheckpointStatus = "available" + CheckpointStatusBroken CheckpointStatus = "broken" + CheckpointStatusRestored CheckpointStatus = "restored" + CheckpointStatusPruned CheckpointStatus = "pruned" +) + +// CheckpointRecord 遵循 RFC 6.2 定义,包含所有恢复相关引用。 +type CheckpointRecord struct { + CheckpointID string + WorkspaceKey string + SessionID string + RunID string + Workdir string + CreatedAt time.Time + Reason CheckpointReason + CodeCheckpointRef string + SessionCheckpointRef string + ResumeCheckpointRef string + TranscriptRevision int64 + Restorable bool + Status CheckpointStatus +} + +// SessionCheckpoint 保存完整 durable 会话上下文快照。 +type SessionCheckpoint struct { + ID string + SessionID string + HeadJSON string + MessagesJSON string + CreatedAt time.Time +} + +// ResumeCheckpoint 记录运行恢复所需的最小闭环信息。 +type ResumeCheckpoint struct { + ID string + WorkspaceKey string + RunID string + SessionID string + Turn int + Phase string + CompletionState string + TranscriptRevision int64 + UpdatedAt time.Time +} diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index d35e18dc..6a94e645 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -891,6 +891,9 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV4ToV5(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } case 2: if err := migrateSQLiteSchemaV2ToV3(ctx, db); err != nil { return err @@ -901,6 +904,9 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV4ToV5(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } case 3: if err := migrateSQLiteSchemaV3ToV4(ctx, db); err != nil { return err @@ -908,10 +914,20 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := migrateSQLiteSchemaV4ToV5(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } case 4: if err := migrateSQLiteSchemaV4ToV5(ctx, db); err != nil { return err } + if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } + case 5: + if err := migrateSQLiteSchemaV5ToV6(ctx, db); err != nil { + return err + } default: return fmt.Errorf("session: unsupported sqlite schema version %d", userVersion) } @@ -972,6 +988,45 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { `CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at_ms DESC)`, `CREATE INDEX IF NOT EXISTS idx_messages_session_seq_desc ON messages(session_id, seq DESC)`, `CREATE INDEX IF NOT EXISTS idx_assets_session_id ON session_assets(session_id)`, + `CREATE TABLE IF NOT EXISTS checkpoint_records ( + id TEXT PRIMARY KEY, + workspace_key TEXT NOT NULL, + session_id TEXT NOT NULL, + run_id TEXT NOT NULL DEFAULT '', + workdir TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + reason TEXT NOT NULL DEFAULT 'pre_write', + code_checkpoint_ref TEXT NOT NULL DEFAULT '', + session_checkpoint_ref TEXT NOT NULL DEFAULT '', + resume_checkpoint_ref TEXT NOT NULL DEFAULT '', + transcript_revision INTEGER NOT NULL DEFAULT 0, + restorable INTEGER NOT NULL DEFAULT 1, + status TEXT NOT NULL DEFAULT 'creating', + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS session_checkpoints ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + head_json TEXT NOT NULL DEFAULT '', + messages_json TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS resume_checkpoints ( + id TEXT PRIMARY KEY, + workspace_key TEXT NOT NULL, + run_id TEXT NOT NULL DEFAULT '', + session_id TEXT NOT NULL, + turn INTEGER NOT NULL DEFAULT 0, + phase TEXT NOT NULL DEFAULT '', + completion_state TEXT NOT NULL DEFAULT '', + transcript_revision INTEGER NOT NULL DEFAULT 0, + updated_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE INDEX IF NOT EXISTS idx_checkpoint_session_created ON checkpoint_records(session_id, created_at_ms DESC)`, + `CREATE INDEX IF NOT EXISTS idx_session_checkpoints_session ON session_checkpoints(session_id, created_at_ms DESC)`, + `CREATE INDEX IF NOT EXISTS idx_resume_checkpoints_session ON resume_checkpoints(session_id, updated_at_ms DESC)`, fmt.Sprintf(`PRAGMA user_version=%d`, sqliteSchemaVersion), } for _, statement := range statements { @@ -1014,6 +1069,70 @@ func migrateSQLiteSchemaV1ToV2(ctx context.Context, db *sql.DB) error { return nil } +// migrateSQLiteSchemaV5ToV6 将 v5 会话库升级到 v6 schema,新增 checkpoint 相关表。 +func migrateSQLiteSchemaV5ToV6(ctx context.Context, db *sql.DB) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("session: begin schema migration tx: %w", err) + } + defer rollbackTx(tx) + + stmts := []string{ + `CREATE TABLE IF NOT EXISTS checkpoint_records ( + id TEXT PRIMARY KEY, + workspace_key TEXT NOT NULL, + session_id TEXT NOT NULL, + run_id TEXT NOT NULL DEFAULT '', + workdir TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + reason TEXT NOT NULL DEFAULT 'pre_write', + code_checkpoint_ref TEXT NOT NULL DEFAULT '', + session_checkpoint_ref TEXT NOT NULL DEFAULT '', + resume_checkpoint_ref TEXT NOT NULL DEFAULT '', + transcript_revision INTEGER NOT NULL DEFAULT 0, + restorable INTEGER NOT NULL DEFAULT 1, + status TEXT NOT NULL DEFAULT 'creating', + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS session_checkpoints ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + head_json TEXT NOT NULL DEFAULT '', + messages_json TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS resume_checkpoints ( + id TEXT PRIMARY KEY, + workspace_key TEXT NOT NULL, + run_id TEXT NOT NULL DEFAULT '', + session_id TEXT NOT NULL, + turn INTEGER NOT NULL DEFAULT 0, + phase TEXT NOT NULL DEFAULT '', + completion_state TEXT NOT NULL DEFAULT '', + transcript_revision INTEGER NOT NULL DEFAULT 0, + updated_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE INDEX IF NOT EXISTS idx_checkpoint_session_created ON checkpoint_records(session_id, created_at_ms DESC)`, + `CREATE INDEX IF NOT EXISTS idx_session_checkpoints_session ON session_checkpoints(session_id, created_at_ms DESC)`, + `CREATE INDEX IF NOT EXISTS idx_resume_checkpoints_session ON resume_checkpoints(session_id, updated_at_ms DESC)`, + } + for _, stmt := range stmts { + if _, err := tx.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("session: migrate sqlite schema v5 to v6: %w", err) + } + } + + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`PRAGMA user_version=%d`, sqliteSchemaVersion)); err != nil { + return fmt.Errorf("session: set migrated sqlite schema version: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("session: commit schema migration tx: %w", err) + } + return nil +} + // sqliteTableHasColumn 检查指定表是否包含字段,供明确版本迁移保持幂等。 // migrateSQLiteSchemaV2ToV3 将 v2 会话库升级到 v3 schema,补齐 plan/build 所需字段。 func migrateSQLiteSchemaV2ToV3(ctx context.Context, db *sql.DB) error { diff --git a/internal/session/store.go b/internal/session/store.go index b6e5155d..6ec1ac51 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -14,7 +14,7 @@ import ( const ( sessionDatabaseFileName = "session.db" assetsDirName = "assets" - sqliteSchemaVersion = 5 + sqliteSchemaVersion = 6 // MaxSessionMessages 定义单个会话允许持久化的最大消息数,超出时自动裁剪最旧消息。 MaxSessionMessages = 8192 @@ -155,7 +155,7 @@ func NewSQLiteStore(baseDir string, workspaceRoot string) *SQLiteStore { return &SQLiteStore{ projectDir: projectDirectory(baseDir, workspaceRoot), assetsDir: assetsDirectory(baseDir, workspaceRoot), - dbPath: databasePath(baseDir, workspaceRoot), + dbPath: DatabasePath(baseDir, workspaceRoot), assetPolicy: DefaultAssetPolicy(), } } diff --git a/internal/session/store_test.go b/internal/session/store_test.go index dd634a25..00228be2 100644 --- a/internal/session/store_test.go +++ b/internal/session/store_test.go @@ -476,7 +476,7 @@ func TestSQLiteStoreInitializationRejectsUnsupportedSchemaVersion(t *testing.T) if err := os.MkdirAll(projectDir, 0o755); err != nil { t.Fatalf("MkdirAll(projectDir) error = %v", err) } - db, err := sql.Open("sqlite", databasePath(baseDir, workspaceRoot)) + db, err := sql.Open("sqlite", DatabasePath(baseDir, workspaceRoot)) if err != nil { t.Fatalf("sql.Open() error = %v", err) } @@ -715,7 +715,7 @@ func createLegacyV1SessionDB( if err := os.MkdirAll(projectDir, 0o755); err != nil { t.Fatalf("MkdirAll(projectDir) error = %v", err) } - db, err := sql.Open("sqlite", databasePath(baseDir, workspaceRoot)) + db, err := sql.Open("sqlite", DatabasePath(baseDir, workspaceRoot)) if err != nil { t.Fatalf("sql.Open() error = %v", err) } @@ -792,7 +792,7 @@ func createLegacyV2SessionDB(t *testing.T, ctx context.Context, baseDir string, if err := os.MkdirAll(projectDir, 0o755); err != nil { t.Fatalf("MkdirAll(projectDir) error = %v", err) } - db, err := sql.Open("sqlite", databasePath(baseDir, workspaceRoot)) + db, err := sql.Open("sqlite", DatabasePath(baseDir, workspaceRoot)) if err != nil { t.Fatalf("sql.Open() error = %v", err) } diff --git a/internal/session/workspace.go b/internal/session/workspace.go index d16e0278..cfe99170 100644 --- a/internal/session/workspace.go +++ b/internal/session/workspace.go @@ -17,8 +17,8 @@ func projectDirectory(baseDir string, workspaceRoot string) string { return filepath.Join(baseDir, projectsDirName, HashWorkspaceRoot(workspaceRoot)) } -// databasePath 返回当前工作区级 SQLite 数据库文件路径。 -func databasePath(baseDir string, workspaceRoot string) string { +// DatabasePath 返回当前工作区级 SQLite 数据库文件路径。 +func DatabasePath(baseDir string, workspaceRoot string) string { return filepath.Join(projectDirectory(baseDir, workspaceRoot), sessionDatabaseFileName) } From 29078d09d8daeaa5cf405a7c6a647b63c594ad41 Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Sun, 3 May 2026 20:25:10 +0800 Subject: [PATCH 2/8] =?UTF-8?q?feat(checkpoint):=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=9B=9E=E9=80=80=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/app/bootstrap.go | 16 +- internal/checkpoint/checkpoint_manager.go | 270 ++++++++++++++++++++++ internal/checkpoint/shadow_repo.go | 103 +++++++++ internal/cli/gateway_runtime_bridge.go | 21 ++ internal/gateway/bootstrap.go | 121 ++++++++++ internal/gateway/bootstrap_test.go | 143 ++++++++---- internal/gateway/contracts.go | 55 +++++ internal/gateway/contracts_test.go | 12 + internal/gateway/registry.go | 3 + internal/gateway/rpc_dispatch_test.go | 24 ++ internal/gateway/server_test.go | 12 + internal/gateway/types.go | 6 + internal/runtime/checkpoint_gate.go | 61 ++++- internal/runtime/checkpoint_restore.go | 265 +++++++++++++++++++++ internal/runtime/checkpoint_resume.go | 38 +++ internal/runtime/compact.go | 60 +++++ internal/runtime/events.go | 17 ++ internal/runtime/run.go | 7 + internal/session/checkpoint_types.go | 11 +- 19 files changed, 1189 insertions(+), 56 deletions(-) create mode 100644 internal/runtime/checkpoint_restore.go create mode 100644 internal/runtime/checkpoint_resume.go diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 87290a7b..5aae3f5e 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -232,6 +232,7 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime } // Checkpoint 基础设施:影子仓库 + SQLite checkpoint 存储 + dbPath := agentsession.DatabasePath(sharedDeps.ConfigManager.BaseDir(), cfg.Workdir) if gitAvail, _ := checkpoint.CheckGitAvailability(ctx); gitAvail { projectDir := agentsession.HashWorkspaceRoot(cfg.Workdir) shadowDir := filepath.Join(sharedDeps.ConfigManager.BaseDir(), "projects", projectDir) @@ -239,10 +240,23 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime if err := shadowRepo.Init(ctx); err != nil { log.Printf("checkpoint shadow repo init warning: %v", err) } else { - dbPath := agentsession.DatabasePath(sharedDeps.ConfigManager.BaseDir(), cfg.Workdir) checkpointStore := checkpoint.NewSQLiteCheckpointStore(dbPath) runtimeSvc.SetCheckpointDependencies(checkpointStore, shadowRepo) } + } else { + // 降级模式:git 不可用时仍可创建 session-only checkpoint + checkpointStore := checkpoint.NewSQLiteCheckpointStore(dbPath) + runtimeSvc.SetCheckpointDependencies(checkpointStore, nil) + } + // 启动时修复残留的 creating 状态 checkpoint + { + repairStore := checkpoint.NewSQLiteCheckpointStore(dbPath) + if repaired, err := repairStore.RepairCreatingCheckpoints(ctx); err != nil { + log.Printf("checkpoint repair warning: %v", err) + } else if repaired > 0 { + log.Printf("checkpoint repair: fixed %d stale checkpoints", repaired) + } + _ = repairStore.Close() } runtimeImpl := agentruntime.Runtime(runtimeSvc) diff --git a/internal/checkpoint/checkpoint_manager.go b/internal/checkpoint/checkpoint_manager.go index 97860416..d9137620 100644 --- a/internal/checkpoint/checkpoint_manager.go +++ b/internal/checkpoint/checkpoint_manager.go @@ -3,11 +3,13 @@ package checkpoint import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "sync" "time" + providertypes "neo-code/internal/provider/types" "neo-code/internal/session" ) @@ -18,6 +20,10 @@ type CheckpointStore interface { GetCheckpoint(ctx context.Context, checkpointID string) (session.CheckpointRecord, *session.SessionCheckpoint, error) UpdateCheckpointStatus(ctx context.Context, checkpointID string, status session.CheckpointStatus) error GetLatestResumeCheckpoint(ctx context.Context, sessionID string) (*session.ResumeCheckpoint, error) + RestoreCheckpoint(ctx context.Context, input RestoreCheckpointInput) error + SetResumeCheckpoint(ctx context.Context, rc session.ResumeCheckpoint) error + PruneExpiredCheckpoints(ctx context.Context, sessionID string, maxAutoKeep int) (int, error) + RepairCreatingCheckpoints(ctx context.Context) (int, error) } // CreateCheckpointInput 描述一次 checkpoint 创建的完整输入。 @@ -32,6 +38,16 @@ type ListCheckpointOpts struct { RestorableOnly bool } +// RestoreCheckpointInput 描述一次 restore 操作的完整输入。 +type RestoreCheckpointInput struct { + SessionID string + Head session.SessionHead + Messages []providertypes.Message + UpdatedAt time.Time + MarkAvailableIDs []string + MarkRestoredIDs []string +} + // SQLiteCheckpointStore 基于 SQLite 实现 checkpoint 持久化。 type SQLiteCheckpointStore struct { dbPath string @@ -339,6 +355,260 @@ LIMIT 1 return &rc, nil } +// RestoreCheckpoint 在单一事务内恢复会话消息和头状态,并批量更新 checkpoint 状态。 +func (s *SQLiteCheckpointStore) RestoreCheckpoint(ctx context.Context, input RestoreCheckpointInput) error { + if err := ctx.Err(); err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("checkpoint: begin restore tx: %w", err) + } + defer rollbackTx(tx) + + // DELETE existing messages + if _, err := tx.ExecContext(ctx, `DELETE FROM messages WHERE session_id = ?`, input.SessionID); err != nil { + return fmt.Errorf("checkpoint: delete messages %s: %w", input.SessionID, err) + } + + // Re-insert messages + now := input.UpdatedAt + for i, msg := range input.Messages { + seq := i + 1 + toolCallsJSON := "[]" + if len(msg.ToolCalls) > 0 { + if data, err := json.Marshal(msg.ToolCalls); err == nil { + toolCallsJSON = string(data) + } + } + toolMetadataJSON := "{}" + if msg.ToolMetadata != nil { + if data, err := json.Marshal(msg.ToolMetadata); err == nil { + toolMetadataJSON = string(data) + } + } + partsJSON := "[]" + if len(msg.Parts) > 0 { + if data, err := json.Marshal(msg.Parts); err == nil { + partsJSON = string(data) + } + } + if _, err := tx.ExecContext(ctx, `INSERT INTO messages (session_id, seq, role, parts_json, tool_calls_json, tool_call_id, is_error, tool_metadata_json, created_at_ms) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + input.SessionID, seq, msg.Role, partsJSON, toolCallsJSON, msg.ToolCallID, boolToInt(msg.IsError), toolMetadataJSON, toUnixMillis(now), + ); err != nil { + return fmt.Errorf("checkpoint: insert message %s/%d: %w", input.SessionID, seq, err) + } + } + + // UPDATE session head + h := input.Head + result, err := tx.ExecContext(ctx, `UPDATE sessions SET updated_at_ms=?, provider=?, model=?, workdir=?, task_state_json=?, todos_json=?, activated_skills_json=?, token_input_total=?, token_output_total=?, has_unknown_usage=?, agent_mode=?, current_plan_json=?, last_full_plan_revision=?, plan_approval_pending_full_align=?, plan_completion_pending_full_review=?, plan_context_dirty=?, plan_restore_pending_align=?, last_seq=?, message_count=? WHERE id=?`, + toUnixMillis(input.UpdatedAt), h.Provider, h.Model, h.Workdir, + marshalHeadField(h.TaskState), marshalHeadField(h.Todos), marshalHeadField(h.ActivatedSkills), + h.TokenInputTotal, h.TokenOutputTotal, boolToInt(h.HasUnknownUsage), h.AgentMode, + marshalHeadField(h.CurrentPlan), h.LastFullPlanRevision, + boolToInt(h.PlanApprovalPendingFullAlign), boolToInt(h.PlanCompletionPendingFullReview), + boolToInt(h.PlanContextDirty), boolToInt(h.PlanRestorePendingAlign), + len(input.Messages), len(input.Messages), input.SessionID, + ) + if err != nil { + return fmt.Errorf("checkpoint: update session %s: %w", input.SessionID, err) + } + if affected, _ := result.RowsAffected(); affected == 0 { + return fmt.Errorf("checkpoint: session %s not found", input.SessionID) + } + + // Mark available + for _, id := range input.MarkAvailableIDs { + if _, err := tx.ExecContext(ctx, `UPDATE checkpoint_records SET status=? WHERE id=?`, string(session.CheckpointStatusAvailable), id); err != nil { + return fmt.Errorf("checkpoint: mark available %s: %w", id, err) + } + } + + // Mark restored + for _, id := range input.MarkRestoredIDs { + if _, err := tx.ExecContext(ctx, `UPDATE checkpoint_records SET status=? WHERE id=?`, string(session.CheckpointStatusRestored), id); err != nil { + return fmt.Errorf("checkpoint: mark restored %s: %w", id, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("checkpoint: commit restore %s: %w", input.SessionID, err) + } + return nil +} + +func marshalHeadField(value any) string { + if value == nil { + return "{}" + } + data, err := json.Marshal(value) + if err != nil { + return "{}" + } + return string(data) +} + +// SetResumeCheckpoint 写入或更新 ResumeCheckpoint(一个 session 只保留一条)。 +func (s *SQLiteCheckpointStore) SetResumeCheckpoint(ctx context.Context, rc session.ResumeCheckpoint) error { + if err := ctx.Err(); err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("checkpoint: begin set resume tx: %w", err) + } + defer rollbackTx(tx) + + if _, err := tx.ExecContext(ctx, `DELETE FROM resume_checkpoints WHERE session_id=?`, rc.SessionID); err != nil { + return fmt.Errorf("checkpoint: delete old resume cp %s: %w", rc.SessionID, err) + } + + if _, err := tx.ExecContext(ctx, `INSERT INTO resume_checkpoints (id, workspace_key, run_id, session_id, turn, phase, completion_state, transcript_revision, updated_at_ms) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + rc.ID, rc.WorkspaceKey, rc.RunID, rc.SessionID, rc.Turn, rc.Phase, rc.CompletionState, rc.TranscriptRevision, toUnixMillis(rc.UpdatedAt), + ); err != nil { + return fmt.Errorf("checkpoint: insert resume cp %s: %w", rc.SessionID, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("checkpoint: commit set resume cp %s: %w", rc.SessionID, err) + } + return nil +} + +// PruneExpiredCheckpoints 窗口裁剪,将超出 maxAutoKeep 的旧自动 checkpoint 标记为 pruned。 +func (s *SQLiteCheckpointStore) PruneExpiredCheckpoints(ctx context.Context, sessionID string, maxAutoKeep int) (int, error) { + if err := ctx.Err(); err != nil { + return 0, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return 0, err + } + + rows, err := db.QueryContext(ctx, `SELECT id, session_checkpoint_ref FROM checkpoint_records WHERE session_id=? AND restorable=1 AND status=? AND reason NOT IN (?, ?) ORDER BY created_at_ms DESC`, + sessionID, string(session.CheckpointStatusAvailable), string(session.CheckpointReasonManual), string(session.CheckpointReasonGuard), + ) + if err != nil { + return 0, fmt.Errorf("checkpoint: query prune candidates %s: %w", sessionID, err) + } + defer rows.Close() + + type pruneTarget struct { + ID string + SessionCPRef string + } + var targets []pruneTarget + idx := 0 + for rows.Next() { + var t pruneTarget + if err := rows.Scan(&t.ID, &t.SessionCPRef); err != nil { + return 0, fmt.Errorf("checkpoint: scan prune candidate: %w", err) + } + if idx >= maxAutoKeep { + targets = append(targets, t) + } + idx++ + } + if err := rows.Err(); err != nil { + return 0, fmt.Errorf("checkpoint: iterate prune candidates: %w", err) + } + if len(targets) == 0 { + return 0, nil + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return 0, fmt.Errorf("checkpoint: begin prune tx: %w", err) + } + defer rollbackTx(tx) + + for _, t := range targets { + if _, err := tx.ExecContext(ctx, `UPDATE checkpoint_records SET restorable=0, status=? WHERE id=?`, string(session.CheckpointStatusPruned), t.ID); err != nil { + return 0, fmt.Errorf("checkpoint: prune record %s: %w", t.ID, err) + } + if t.SessionCPRef != "" { + if _, err := tx.ExecContext(ctx, `DELETE FROM session_checkpoints WHERE id=?`, t.SessionCPRef); err != nil { + return 0, fmt.Errorf("checkpoint: delete session cp %s: %w", t.SessionCPRef, err) + } + } + } + + if err := tx.Commit(); err != nil { + return 0, fmt.Errorf("checkpoint: commit prune %s: %w", sessionID, err) + } + return len(targets), nil +} + +// RepairCreatingCheckpoints 修复残留的 creating 状态 checkpoint。 +func (s *SQLiteCheckpointStore) RepairCreatingCheckpoints(ctx context.Context) (int, error) { + if err := ctx.Err(); err != nil { + return 0, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return 0, err + } + + rows, err := db.QueryContext(ctx, `SELECT id, session_checkpoint_ref FROM checkpoint_records WHERE status=?`, string(session.CheckpointStatusCreating)) + if err != nil { + return 0, fmt.Errorf("checkpoint: query creating records: %w", err) + } + defer rows.Close() + + type repairTarget struct { + ID string + SessionCPRef string + } + var targets []repairTarget + for rows.Next() { + var t repairTarget + if err := rows.Scan(&t.ID, &t.SessionCPRef); err != nil { + return 0, fmt.Errorf("checkpoint: scan creating record: %w", err) + } + targets = append(targets, t) + } + if err := rows.Err(); err != nil { + return 0, fmt.Errorf("checkpoint: iterate creating records: %w", err) + } + if len(targets) == 0 { + return 0, nil + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return 0, fmt.Errorf("checkpoint: begin repair tx: %w", err) + } + defer rollbackTx(tx) + + for _, t := range targets { + if t.SessionCPRef != "" { + if _, err := tx.ExecContext(ctx, `UPDATE checkpoint_records SET status=? WHERE id=?`, string(session.CheckpointStatusAvailable), t.ID); err != nil { + return 0, fmt.Errorf("checkpoint: repair available %s: %w", t.ID, err) + } + } else { + if _, err := tx.ExecContext(ctx, `DELETE FROM checkpoint_records WHERE id=?`, t.ID); err != nil { + return 0, fmt.Errorf("checkpoint: delete orphan %s: %w", t.ID, err) + } + } + } + + if err := tx.Commit(); err != nil { + return 0, fmt.Errorf("checkpoint: commit repair: %w", err) + } + return len(targets), nil +} + func toUnixMillis(value time.Time) int64 { return value.UTC().UnixMilli() } diff --git a/internal/checkpoint/shadow_repo.go b/internal/checkpoint/shadow_repo.go index 13e1c4a3..33c2eab4 100644 --- a/internal/checkpoint/shadow_repo.go +++ b/internal/checkpoint/shadow_repo.go @@ -3,6 +3,7 @@ package checkpoint import ( "context" "fmt" + "os" "os/exec" "path/filepath" "strings" @@ -30,6 +31,14 @@ func NewShadowRepo(projectDir string, workdir string) *ShadowRepo { } } +// ConflictResult 描述目标 commit 与当前工作区之间的差异。 +type ConflictResult struct { + HasConflict bool + AddedFiles []string + DeletedFiles []string + ModifiedFiles []string +} + // CheckGitAvailability 检查系统是否可用 git 命令。 func CheckGitAvailability(ctx context.Context) (available bool, version string) { ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) @@ -43,6 +52,7 @@ func CheckGitAvailability(ctx context.Context) (available bool, version string) } // Init 初始化 bare 仓库,设置 core.worktree 指向用户工作区。 +// 如果仓库已存在但损坏,会先 Rebuild 再初始化。 func (r *ShadowRepo) Init(ctx context.Context) error { r.mu.Lock() defer r.mu.Unlock() @@ -50,6 +60,19 @@ func (r *ShadowRepo) Init(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) defer cancel() + // 如果目录已存在,先做健康检查 + if _, err := os.Stat(r.shadowDir); err == nil { + healthCtx, healthCancel := context.WithTimeout(context.Background(), gitCommandTimeout) + defer healthCancel() + checkCmd := r.buildGitCommand(healthCtx, "rev-parse", "--git-dir") + if err := checkCmd.Run(); err != nil { + // 损坏的仓库,尝试重建 + if rebuildErr := r.rebuildLocked(context.Background()); rebuildErr != nil { + return fmt.Errorf("checkpoint: rebuild damaged repo: %w", rebuildErr) + } + } + } + if err := exec.CommandContext(ctx, "git", "init", "--bare", r.shadowDir).Run(); err != nil { return fmt.Errorf("checkpoint: init bare repo at %s: %w", r.shadowDir, err) } @@ -130,6 +153,25 @@ func (r *ShadowRepo) Restore(ctx context.Context, commitHash string) error { return nil } +// ResolveRef 解析 ref 对应的 commit hash。 +func (r *ShadowRepo) ResolveRef(ctx context.Context, ref string) (string, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.gitAvailable { + return "", fmt.Errorf("checkpoint: shadow repo not available") + } + + ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) + defer cancel() + + hash, err := r.gitOutput(ctx, "rev-parse", ref) + if err != nil { + return "", fmt.Errorf("checkpoint: resolve ref %s: %w", ref, err) + } + return strings.TrimSpace(hash), nil +} + // DeleteRef 删除指定的 ref 引用,用于补偿失败的 checkpoint 创建。 func (r *ShadowRepo) DeleteRef(ctx context.Context, ref string) error { r.mu.Lock() @@ -185,6 +227,67 @@ func (r *ShadowRepo) buildGitCommand(ctx context.Context, args ...string) *exec. return exec.CommandContext(ctx, "git", fullArgs...) } +// DetectConflicts 比较目标 commit 与当前工作区差异。 +func (r *ShadowRepo) DetectConflicts(ctx context.Context, commitHash string) (ConflictResult, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.gitAvailable { + return ConflictResult{}, fmt.Errorf("checkpoint: shadow repo not available") + } + + ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) + defer cancel() + + out, err := r.gitOutput(ctx, "diff", "--name-status", commitHash, "--", ".") + if err != nil { + return ConflictResult{}, fmt.Errorf("checkpoint: git diff: %w", err) + } + + var result ConflictResult + lines := strings.Split(strings.TrimSpace(out), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parts := strings.SplitN(line, "\t", 2) + if len(parts) != 2 { + continue + } + status := parts[0] + file := parts[1] + switch { + case status == "A": + result.AddedFiles = append(result.AddedFiles, file) + result.HasConflict = true + case status == "D": + result.DeletedFiles = append(result.DeletedFiles, file) + result.HasConflict = true + case strings.HasPrefix(status, "M"): + result.ModifiedFiles = append(result.ModifiedFiles, file) + result.HasConflict = true + } + } + return result, nil +} + +// Rebuild 重建损坏的影子仓库。 +func (r *ShadowRepo) Rebuild(ctx context.Context) error { + r.mu.Lock() + defer r.mu.Unlock() + return r.rebuildLocked(ctx) +} + +// rebuildLocked 在持有锁的情况下重建影子仓库。 +func (r *ShadowRepo) rebuildLocked(ctx context.Context) error { + backupDir := r.shadowDir + ".bak." + fmt.Sprintf("%d", time.Now().UnixNano()) + if err := os.Rename(r.shadowDir, backupDir); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("checkpoint: rename old shadow dir: %w", err) + } + return nil +} + // RefForCheckpoint 构造 checkpoint 的 git ref 路径。 func RefForCheckpoint(sessionID string, checkpointID string) string { return fmt.Sprintf("refs/neocode/sessions/%s/checkpoints/%s", sessionID, checkpointID) diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go index 9f296bf6..1dc4a7ea 100644 --- a/internal/cli/gateway_runtime_bridge.go +++ b/internal/cli/gateway_runtime_bridge.go @@ -1363,3 +1363,24 @@ type manualModelPayload struct { } var _ gateway.RuntimePort = (*gatewayRuntimePortBridge)(nil) + +func (b *gatewayRuntimePortBridge) ListCheckpoints(ctx context.Context, input gateway.ListCheckpointsInput) ([]gateway.CheckpointEntry, error) { + if b == nil { + return nil, fmt.Errorf(bridgeRuntimeUnavailableErrMsg) + } + return nil, fmt.Errorf("checkpoint list: not yet implemented in CLI bridge") +} + +func (b *gatewayRuntimePortBridge) RestoreCheckpoint(ctx context.Context, input gateway.CheckpointRestoreInput) (gateway.CheckpointRestoreResult, error) { + if b == nil { + return gateway.CheckpointRestoreResult{}, fmt.Errorf(bridgeRuntimeUnavailableErrMsg) + } + return gateway.CheckpointRestoreResult{}, fmt.Errorf("checkpoint restore: not yet implemented in CLI bridge") +} + +func (b *gatewayRuntimePortBridge) UndoRestore(ctx context.Context, input gateway.UndoRestoreInput) (gateway.CheckpointRestoreResult, error) { + if b == nil { + return gateway.CheckpointRestoreResult{}, fmt.Errorf(bridgeRuntimeUnavailableErrMsg) + } + return gateway.CheckpointRestoreResult{}, fmt.Errorf("checkpoint undo restore: not yet implemented in CLI bridge") +} diff --git a/internal/gateway/bootstrap.go b/internal/gateway/bootstrap.go index 0d97e67d..2cf5dc4e 100644 --- a/internal/gateway/bootstrap.go +++ b/internal/gateway/bootstrap.go @@ -1792,6 +1792,18 @@ func readStringValue(payload map[string]any, key string) string { return strings.TrimSpace(stringValue) } +func readBoolValue(payload map[string]any, key string) bool { + rawValue, exists := payload[key] + if !exists { + return false + } + boolValue, ok := rawValue.(bool) + if !ok { + return false + } + return boolValue +} + // decodeWakeIntent 将任意 payload 解码为 WakeIntent。 func decodeWakeIntent(payload any) (protocol.WakeIntent, error) { if payload == nil { @@ -1844,3 +1856,112 @@ func toFrameError(err *handlers.WakeError) *FrameError { } return NewFrameError(ErrorCodeInternalError, err.Message) } + +func handleListCheckpointsFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + if runtimePort == nil { + return runtimePortUnavailableFrame(frame) + } + subjectID, subjectErr := requireAuthenticatedSubjectID(ctx) + if subjectErr != nil { + return errorFrame(frame, subjectErr) + } + + callCtx, cancel := withRuntimeOperationTimeout(ctx) + defer cancel() + + entries, err := runtimePort.ListCheckpoints(callCtx, ListCheckpointsInput{ + SubjectID: subjectID, + SessionID: strings.TrimSpace(frame.SessionID), + }) + if err != nil { + return runtimeCallFailedFrame(callCtx, frame, err, "checkpoint_list") + } + + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionListCheckpoints, + RequestID: frame.RequestID, + SessionID: strings.TrimSpace(frame.SessionID), + Payload: entries, + } +} + +func handleRestoreCheckpointFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + if runtimePort == nil { + return runtimePortUnavailableFrame(frame) + } + subjectID, subjectErr := requireAuthenticatedSubjectID(ctx) + if subjectErr != nil { + return errorFrame(frame, subjectErr) + } + + input := decodeCheckpointRestorePayload(frame.Payload) + input.SubjectID = subjectID + if input.SessionID == "" { + input.SessionID = strings.TrimSpace(frame.SessionID) + } + + callCtx, cancel := withRuntimeOperationTimeout(ctx) + defer cancel() + + result, err := runtimePort.RestoreCheckpoint(callCtx, input) + if err != nil { + return runtimeCallFailedFrame(callCtx, frame, err, "checkpoint_restore") + } + + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionRestoreCheckpoint, + RequestID: frame.RequestID, + SessionID: input.SessionID, + Payload: result, + } +} + +func handleUndoRestoreFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + if runtimePort == nil { + return runtimePortUnavailableFrame(frame) + } + subjectID, subjectErr := requireAuthenticatedSubjectID(ctx) + if subjectErr != nil { + return errorFrame(frame, subjectErr) + } + + callCtx, cancel := withRuntimeOperationTimeout(ctx) + defer cancel() + + result, err := runtimePort.UndoRestore(callCtx, UndoRestoreInput{ + SubjectID: subjectID, + SessionID: strings.TrimSpace(frame.SessionID), + }) + if err != nil { + return runtimeCallFailedFrame(callCtx, frame, err, "checkpoint_undo_restore") + } + + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionUndoRestore, + RequestID: frame.RequestID, + SessionID: strings.TrimSpace(frame.SessionID), + Payload: result, + } +} + +func decodeCheckpointRestorePayload(payload any) CheckpointRestoreInput { + switch typed := payload.(type) { + case map[string]any: + return CheckpointRestoreInput{ + SessionID: readStringValue(typed, "session_id"), + CheckpointID: readStringValue(typed, "checkpoint_id"), + Force: readBoolValue(typed, "force"), + } + default: + raw, marshalErr := json.Marshal(payload) + if marshalErr != nil { + return CheckpointRestoreInput{} + } + var decoded CheckpointRestoreInput + _ = json.Unmarshal(raw, &decoded) + return decoded + } +} diff --git a/internal/gateway/bootstrap_test.go b/internal/gateway/bootstrap_test.go index 26428156..2a508986 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -16,35 +16,35 @@ import ( ) type bootstrapRuntimeStub struct { - runFn func(ctx context.Context, input RunInput) error - createSessionFn func(ctx context.Context, input CreateSessionInput) (string, error) - compactFn func(ctx context.Context, input CompactInput) (CompactResult, error) - executeSystemToolFn func(ctx context.Context, input ExecuteSystemToolInput) (tools.ToolResult, error) - activateSkillFn func(ctx context.Context, input SessionSkillMutationInput) error - deactivateSkillFn func(ctx context.Context, input SessionSkillMutationInput) error - listSessionSkillsFn func(ctx context.Context, input ListSessionSkillsInput) ([]SessionSkillState, error) - listAvailableFn func(ctx context.Context, input ListAvailableSkillsInput) ([]AvailableSkillState, error) - resolvePermissionFn func(ctx context.Context, input PermissionResolutionInput) error - cancelRunFn func(ctx context.Context, input CancelInput) (bool, error) - events <-chan RuntimeEvent - listSessionsFn func(ctx context.Context) ([]SessionSummary, error) - loadSessionFn func(ctx context.Context, input LoadSessionInput) (Session, error) - listSessionTodosFn func(ctx context.Context, input ListSessionTodosInput) (TodoSnapshot, error) + runFn func(ctx context.Context, input RunInput) error + createSessionFn func(ctx context.Context, input CreateSessionInput) (string, error) + compactFn func(ctx context.Context, input CompactInput) (CompactResult, error) + executeSystemToolFn func(ctx context.Context, input ExecuteSystemToolInput) (tools.ToolResult, error) + activateSkillFn func(ctx context.Context, input SessionSkillMutationInput) error + deactivateSkillFn func(ctx context.Context, input SessionSkillMutationInput) error + listSessionSkillsFn func(ctx context.Context, input ListSessionSkillsInput) ([]SessionSkillState, error) + listAvailableFn func(ctx context.Context, input ListAvailableSkillsInput) ([]AvailableSkillState, error) + resolvePermissionFn func(ctx context.Context, input PermissionResolutionInput) error + cancelRunFn func(ctx context.Context, input CancelInput) (bool, error) + events <-chan RuntimeEvent + listSessionsFn func(ctx context.Context) ([]SessionSummary, error) + loadSessionFn func(ctx context.Context, input LoadSessionInput) (Session, error) + listSessionTodosFn func(ctx context.Context, input ListSessionTodosInput) (TodoSnapshot, error) getRuntimeSnapshotFn func(ctx context.Context, input GetRuntimeSnapshotInput) (RuntimeSnapshot, error) - deleteSessionFn func(ctx context.Context, input DeleteSessionInput) (bool, error) - renameSessionFn func(ctx context.Context, input RenameSessionInput) error - listFilesFn func(ctx context.Context, input ListFilesInput) ([]FileEntry, error) - listModelsFn func(ctx context.Context, input ListModelsInput) ([]ModelEntry, error) - setSessionModelFn func(ctx context.Context, input SetSessionModelInput) error - getSessionModelFn func(ctx context.Context, input GetSessionModelInput) (SessionModelResult, error) - listProvidersFn func(ctx context.Context, input ListProvidersInput) ([]ProviderOption, error) - createProviderFn func(ctx context.Context, input CreateProviderInput) (ProviderSelectionResult, error) - deleteProviderFn func(ctx context.Context, input DeleteProviderInput) error - selectProviderFn func(ctx context.Context, input SelectProviderModelInput) (ProviderSelectionResult, error) - listMCPServersFn func(ctx context.Context, input ListMCPServersInput) ([]MCPServerEntry, error) - upsertMCPServerFn func(ctx context.Context, input UpsertMCPServerInput) error - setMCPEnabledFn func(ctx context.Context, input SetMCPServerEnabledInput) error - deleteMCPServerFn func(ctx context.Context, input DeleteMCPServerInput) error + deleteSessionFn func(ctx context.Context, input DeleteSessionInput) (bool, error) + renameSessionFn func(ctx context.Context, input RenameSessionInput) error + listFilesFn func(ctx context.Context, input ListFilesInput) ([]FileEntry, error) + listModelsFn func(ctx context.Context, input ListModelsInput) ([]ModelEntry, error) + setSessionModelFn func(ctx context.Context, input SetSessionModelInput) error + getSessionModelFn func(ctx context.Context, input GetSessionModelInput) (SessionModelResult, error) + listProvidersFn func(ctx context.Context, input ListProvidersInput) ([]ProviderOption, error) + createProviderFn func(ctx context.Context, input CreateProviderInput) (ProviderSelectionResult, error) + deleteProviderFn func(ctx context.Context, input DeleteProviderInput) error + selectProviderFn func(ctx context.Context, input SelectProviderModelInput) (ProviderSelectionResult, error) + listMCPServersFn func(ctx context.Context, input ListMCPServersInput) ([]MCPServerEntry, error) + upsertMCPServerFn func(ctx context.Context, input UpsertMCPServerInput) error + setMCPEnabledFn func(ctx context.Context, input SetMCPServerEnabledInput) error + deleteMCPServerFn func(ctx context.Context, input DeleteMCPServerInput) error } func (s *bootstrapRuntimeStub) Run(ctx context.Context, input RunInput) error { @@ -249,6 +249,18 @@ func (s *bootstrapRuntimeStub) CreateSession(ctx context.Context, input CreateSe return strings.TrimSpace(input.SessionID), nil } +func (s *bootstrapRuntimeStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} + +func (s *bootstrapRuntimeStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *bootstrapRuntimeStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + func TestDispatchRequestFramePing(t *testing.T) { response := dispatchRequestFrame(context.Background(), MessageFrame{ Type: FrameTypeRequest, @@ -3799,7 +3811,9 @@ func TestDecodeRenameSessionPayloadBranches(t *testing.T) { }) t.Run("marshal error", func(t *testing.T) { - _, err := decodeRenameSessionPayload(struct{ Bad chan int `json:"bad"` }{Bad: make(chan int)}) + _, err := decodeRenameSessionPayload(struct { + Bad chan int `json:"bad"` + }{Bad: make(chan int)}) if err == nil || err.Code != ErrorCodeInvalidFrame.String() { t.Fatalf("expected invalid frame error, got %#v", err) } @@ -3845,7 +3859,9 @@ func TestDecodeListFilesPayloadBranches(t *testing.T) { }) t.Run("marshal error", func(t *testing.T) { - _, err := decodeListFilesPayload(struct{ Bad chan int `json:"bad"` }{Bad: make(chan int)}) + _, err := decodeListFilesPayload(struct { + Bad chan int `json:"bad"` + }{Bad: make(chan int)}) if err == nil || err.Code != ErrorCodeInvalidFrame.String() { t.Fatalf("expected invalid frame error, got %#v", err) } @@ -3888,7 +3904,9 @@ func TestDecodeSetSessionModelPayloadBranches(t *testing.T) { }) t.Run("marshal error", func(t *testing.T) { - _, err := decodeSetSessionModelPayload(struct{ Bad chan int `json:"bad"` }{Bad: make(chan int)}) + _, err := decodeSetSessionModelPayload(struct { + Bad chan int `json:"bad"` + }{Bad: make(chan int)}) if err == nil || err.Code != ErrorCodeInvalidFrame.String() { t.Fatalf("expected invalid frame error, got %#v", err) } @@ -3956,7 +3974,7 @@ func TestHandleAuthenticateFrameAdditionalBranches(t *testing.T) { type emptySubjectAuthenticator struct{} -func (emptySubjectAuthenticator) ValidateToken(token string) bool { return true } +func (emptySubjectAuthenticator) ValidateToken(token string) bool { return true } func (emptySubjectAuthenticator) ResolveSubjectID(token string) (string, bool) { return "", true } func TestRuntimeCallFailedFrameErrorCodes(t *testing.T) { @@ -3996,39 +4014,70 @@ func TestRuntimeCallFailedFrameErrorCodes(t *testing.T) { // runtimeOnlyStub implements RuntimePort but NOT ManagementRuntimePort. type runtimeOnlyStub struct{} -func (runtimeOnlyStub) Run(ctx context.Context, input RunInput) error { return nil } -func (runtimeOnlyStub) Compact(ctx context.Context, input CompactInput) (CompactResult, error) { return CompactResult{}, nil } +func (runtimeOnlyStub) Run(ctx context.Context, input RunInput) error { return nil } +func (runtimeOnlyStub) Compact(ctx context.Context, input CompactInput) (CompactResult, error) { + return CompactResult{}, nil +} func (runtimeOnlyStub) ExecuteSystemTool(ctx context.Context, input ExecuteSystemToolInput) (tools.ToolResult, error) { return tools.ToolResult{}, nil } -func (runtimeOnlyStub) ActivateSessionSkill(ctx context.Context, input SessionSkillMutationInput) error { return nil } -func (runtimeOnlyStub) DeactivateSessionSkill(ctx context.Context, input SessionSkillMutationInput) error { return nil } +func (runtimeOnlyStub) ActivateSessionSkill(ctx context.Context, input SessionSkillMutationInput) error { + return nil +} +func (runtimeOnlyStub) DeactivateSessionSkill(ctx context.Context, input SessionSkillMutationInput) error { + return nil +} func (runtimeOnlyStub) ListSessionSkills(ctx context.Context, input ListSessionSkillsInput) ([]SessionSkillState, error) { return nil, nil } func (runtimeOnlyStub) ListAvailableSkills(ctx context.Context, input ListAvailableSkillsInput) ([]AvailableSkillState, error) { return nil, nil } -func (runtimeOnlyStub) ResolvePermission(ctx context.Context, input PermissionResolutionInput) error { return nil } -func (runtimeOnlyStub) CancelRun(ctx context.Context, input CancelInput) (bool, error) { return false, nil } -func (runtimeOnlyStub) Events() <-chan RuntimeEvent { return nil } -func (runtimeOnlyStub) ListSessions(ctx context.Context) ([]SessionSummary, error) { return nil, nil } -func (runtimeOnlyStub) LoadSession(ctx context.Context, input LoadSessionInput) (Session, error) { return Session{}, nil } +func (runtimeOnlyStub) ResolvePermission(ctx context.Context, input PermissionResolutionInput) error { + return nil +} +func (runtimeOnlyStub) CancelRun(ctx context.Context, input CancelInput) (bool, error) { + return false, nil +} +func (runtimeOnlyStub) Events() <-chan RuntimeEvent { return nil } +func (runtimeOnlyStub) ListSessions(ctx context.Context) ([]SessionSummary, error) { return nil, nil } +func (runtimeOnlyStub) LoadSession(ctx context.Context, input LoadSessionInput) (Session, error) { + return Session{}, nil +} func (runtimeOnlyStub) ListSessionTodos(ctx context.Context, input ListSessionTodosInput) (TodoSnapshot, error) { return TodoSnapshot{}, nil } func (runtimeOnlyStub) GetRuntimeSnapshot(ctx context.Context, input GetRuntimeSnapshotInput) (RuntimeSnapshot, error) { return RuntimeSnapshot{}, nil } -func (runtimeOnlyStub) CreateSession(ctx context.Context, input CreateSessionInput) (string, error) { return "", nil } -func (runtimeOnlyStub) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { return false, nil } -func (runtimeOnlyStub) RenameSession(ctx context.Context, input RenameSessionInput) error { return nil } -func (runtimeOnlyStub) ListFiles(ctx context.Context, input ListFilesInput) ([]FileEntry, error) { return nil, nil } -func (runtimeOnlyStub) ListModels(ctx context.Context, input ListModelsInput) ([]ModelEntry, error) { return nil, nil } -func (runtimeOnlyStub) SetSessionModel(ctx context.Context, input SetSessionModelInput) error { return nil } +func (runtimeOnlyStub) CreateSession(ctx context.Context, input CreateSessionInput) (string, error) { + return "", nil +} +func (runtimeOnlyStub) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { + return false, nil +} +func (runtimeOnlyStub) RenameSession(ctx context.Context, input RenameSessionInput) error { return nil } +func (runtimeOnlyStub) ListFiles(ctx context.Context, input ListFilesInput) ([]FileEntry, error) { + return nil, nil +} +func (runtimeOnlyStub) ListModels(ctx context.Context, input ListModelsInput) ([]ModelEntry, error) { + return nil, nil +} +func (runtimeOnlyStub) SetSessionModel(ctx context.Context, input SetSessionModelInput) error { + return nil +} func (runtimeOnlyStub) GetSessionModel(ctx context.Context, input GetSessionModelInput) (SessionModelResult, error) { return SessionModelResult{}, nil } +func (runtimeOnlyStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} +func (runtimeOnlyStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} +func (runtimeOnlyStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} type managementRuntimeStub struct { bootstrapRuntimeStub diff --git a/internal/gateway/contracts.go b/internal/gateway/contracts.go index 3a1da391..959fe9a3 100644 --- a/internal/gateway/contracts.go +++ b/internal/gateway/contracts.go @@ -265,6 +265,55 @@ type SessionModelResult struct { Provider string `json:"provider,omitempty"` } +// ListCheckpointsInput 描述查询 checkpoint 列表的输入。 +type ListCheckpointsInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是目标会话标识。 + SessionID string + // Limit 限制返回数量,0 表示不限制。 + Limit int + // RestorableOnly 仅返回可恢复的 checkpoint。 + RestorableOnly bool +} + +// CheckpointEntry 描述单个 checkpoint 的列表视图。 +type CheckpointEntry struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + Reason string `json:"reason"` + Status string `json:"status"` + Restorable bool `json:"restorable"` + CreatedAt int64 `json:"created_at_ms"` +} + +// CheckpointRestoreInput 描述恢复 checkpoint 的输入。 +type CheckpointRestoreInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是目标会话标识。 + SessionID string + // CheckpointID 是要恢复的 checkpoint 标识。 + CheckpointID string + // Force 强制恢复,忽略冲突检测。 + Force bool +} + +// UndoRestoreInput 描述撤销 restore 的输入。 +type UndoRestoreInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是目标会话标识。 + SessionID string +} + +// CheckpointRestoreResult 描述 checkpoint 恢复操作的结果。 +type CheckpointRestoreResult struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + HasConflict bool `json:"has_conflict,omitempty"` +} + // ProviderOption 表示前端管理面可见的 provider 及模型候选。 type ProviderOption struct { // ID 是 provider 标识。 @@ -596,6 +645,12 @@ type RuntimePort interface { SetSessionModel(ctx context.Context, input SetSessionModelInput) error // GetSessionModel 获取当前会话模型。 GetSessionModel(ctx context.Context, input GetSessionModelInput) (SessionModelResult, error) + // ListCheckpoints 查询指定会话的 checkpoint 列表。 + ListCheckpoints(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) + // RestoreCheckpoint 恢复到指定 checkpoint。 + RestoreCheckpoint(ctx context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error) + // UndoRestore 撤销最近一次 checkpoint 恢复。 + UndoRestore(ctx context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) } // ManagementRuntimePort 定义前端管理面访问配置能力的可选下游端口。 diff --git a/internal/gateway/contracts_test.go b/internal/gateway/contracts_test.go index 2b4f498d..105e3600 100644 --- a/internal/gateway/contracts_test.go +++ b/internal/gateway/contracts_test.go @@ -126,6 +126,18 @@ func (s *runtimePortCompileStub) CreateSession(_ context.Context, _ CreateSessio return "", nil } +func (s *runtimePortCompileStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} + +func (s *runtimePortCompileStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *runtimePortCompileStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + var _ RuntimePort = (*runtimePortCompileStub)(nil) var _ TransportAdapter = (*Server)(nil) var _ TransportAdapter = (*NetworkServer)(nil) diff --git a/internal/gateway/registry.go b/internal/gateway/registry.go index 90d66cb1..0a38f588 100644 --- a/internal/gateway/registry.go +++ b/internal/gateway/registry.go @@ -71,6 +71,9 @@ func (r *ActionRegistry) initCore() { r.core[FrameActionUpsertMCPServer] = handleUpsertMCPServerFrame r.core[FrameActionSetMCPServerEnabled] = handleSetMCPServerEnabledFrame r.core[FrameActionDeleteMCPServer] = handleDeleteMCPServerFrame + r.core[FrameActionListCheckpoints] = handleListCheckpointsFrame + r.core[FrameActionRestoreCheckpoint] = handleRestoreCheckpointFrame + r.core[FrameActionUndoRestore] = handleUndoRestoreFrame } // Lookup returns the handler for an action. diff --git a/internal/gateway/rpc_dispatch_test.go b/internal/gateway/rpc_dispatch_test.go index 3c1dac8d..3100369c 100644 --- a/internal/gateway/rpc_dispatch_test.go +++ b/internal/gateway/rpc_dispatch_test.go @@ -198,6 +198,18 @@ func (s *rpcRunCaptureRuntimeStub) GetRuntimeSnapshot( return RuntimeSnapshot{}, nil } +func (s *rpcRunCaptureRuntimeStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} + +func (s *rpcRunCaptureRuntimeStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *rpcRunCaptureRuntimeStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + func TestDispatchRPCRequestResultEncodeError(t *testing.T) { installHandlerRegistryForTest(t, map[FrameAction]requestFrameHandler{ FrameActionPing: func(_ context.Context, frame MessageFrame, _ RuntimePort) MessageFrame { @@ -957,6 +969,18 @@ func (s *runtimePortOnlyStub) GetSessionModel(_ context.Context, _ GetSessionMod return SessionModelResult{}, nil } +func (s *runtimePortOnlyStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} + +func (s *runtimePortOnlyStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *runtimePortOnlyStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + func TestDispatchRPCRequestProviderMethodsManagementPortUnavailable(t *testing.T) { ctx := WithRequestSource(context.Background(), RequestSourceIPC) ctx = WithRequestACL(ctx, NewStrictControlPlaneACL()) diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go index a4e27b04..e7f815ac 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -457,6 +457,18 @@ func (s *runtimePortEventStub) GetRuntimeSnapshot( return RuntimeSnapshot{}, nil } +func (s *runtimePortEventStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { + return nil, nil +} + +func (s *runtimePortEventStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + +func (s *runtimePortEventStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { + return CheckpointRestoreResult{}, nil +} + func decodeJSONRPCResultFrame(response protocol.JSONRPCResponse) (MessageFrame, error) { if response.Result == nil { return MessageFrame{}, errors.New("rpc result is nil") diff --git a/internal/gateway/types.go b/internal/gateway/types.go index eca4b8e0..85eca473 100644 --- a/internal/gateway/types.go +++ b/internal/gateway/types.go @@ -80,6 +80,12 @@ const ( FrameActionDeleteMCPServer FrameAction = "delete_mcp_server" // FrameActionWakeOpenURL 表示处理 URL Scheme 唤醒请求。 FrameActionWakeOpenURL FrameAction = "wake.openUrl" + // FrameActionListCheckpoints 表示查询会话 checkpoint 列表。 + FrameActionListCheckpoints FrameAction = "checkpoint_list" + // FrameActionRestoreCheckpoint 表示恢复到指定 checkpoint。 + FrameActionRestoreCheckpoint FrameAction = "checkpoint_restore" + // FrameActionUndoRestore 表示撤销最近一次 checkpoint 恢复。 + FrameActionUndoRestore FrameAction = "checkpoint_undo_restore" ) // InputPartType 表示多模态输入分片类型。 diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go index 1f5e83e9..0d3ae0ed 100644 --- a/internal/runtime/checkpoint_gate.go +++ b/internal/runtime/checkpoint_gate.go @@ -14,11 +14,9 @@ import ( ) // createPreWriteCheckpoint 在工具执行前创建 checkpoint,采用两阶段提交。 +// shadowRepo 可用时做完整快照,不可用时降级为 session-only checkpoint。 // 失败时不阻塞工具执行,仅返回 error 由调用方发 warning event。 func (s *Service) createPreWriteCheckpoint(ctx context.Context, state *runState) error { - if s.shadowRepo == nil || !s.shadowRepo.IsAvailable() { - return nil - } if s.checkpointStore == nil { return nil } @@ -28,6 +26,11 @@ func (s *Service) createPreWriteCheckpoint(ctx context.Context, state *runState) runID := state.runID state.mu.Unlock() + // 降级模式:shadowRepo 不可用时创建 session-only checkpoint + if s.shadowRepo == nil || !s.shadowRepo.IsAvailable() { + return s.createDegradedCheckpoint(ctx, session, runID) + } + checkpointID := agentsession.NewID("checkpoint") ref := checkpoint.RefForCheckpoint(session.ID, checkpointID) commitMsg := fmt.Sprintf("pre_write checkpoint for session %s", session.ID) @@ -131,3 +134,55 @@ func isBashWriteCommand(argumentsJSON string) bool { return intent.Classification == tools.BashIntentClassificationLocalMutation || intent.Classification == tools.BashIntentClassificationDestructive } + +// createDegradedCheckpoint 创建 session-only checkpoint(无代码快照),用于 shadowRepo 不可用时。 +func (s *Service) createDegradedCheckpoint(ctx context.Context, session agentsession.Session, runID string) error { + checkpointID := agentsession.NewID("checkpoint") + now := time.Now() + + head := session.HeadSnapshot() + headJSON, err := json.Marshal(head) + if err != nil { + return fmt.Errorf("checkpoint: marshal degraded head: %w", err) + } + messagesJSON, err := json.Marshal(session.Messages) + if err != nil { + return fmt.Errorf("checkpoint: marshal degraded messages: %w", err) + } + + record := agentsession.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + SessionID: session.ID, + RunID: runID, + Workdir: session.Workdir, + CreatedAt: now, + Reason: agentsession.CheckpointReasonPreWriteDegraded, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + } + sessionCP := agentsession.SessionCheckpoint{ + ID: agentsession.NewID("sc"), + SessionID: session.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: now, + } + + saved, err := s.checkpointStore.CreateCheckpoint(ctx, checkpoint.CreateCheckpointInput{ + Record: record, + SessionCP: sessionCP, + }) + if err != nil { + return fmt.Errorf("checkpoint: degraded create: %w", err) + } + + s.emitRunScoped(ctx, EventCheckpointCreated, nil, CheckpointCreatedPayload{ + CheckpointID: saved.CheckpointID, + CodeCheckpointRef: "", + SessionCheckpointRef: saved.SessionCheckpointRef, + CommitHash: "", + Reason: string(saved.Reason), + }) + return nil +} diff --git a/internal/runtime/checkpoint_restore.go b/internal/runtime/checkpoint_restore.go new file mode 100644 index 00000000..81613b45 --- /dev/null +++ b/internal/runtime/checkpoint_restore.go @@ -0,0 +1,265 @@ +package runtime + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "neo-code/internal/checkpoint" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +// GatewayRestoreInput 描述来自 Gateway 的 checkpoint 恢复请求。 +type GatewayRestoreInput struct { + SessionID string `json:"session_id"` + CheckpointID string `json:"checkpoint_id"` + Force bool `json:"force,omitempty"` +} + +// RestoreResult 描述 restore/undo 操作的结果。 +type RestoreResult struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + Conflict *checkpoint.ConflictResult `json:"conflict,omitempty"` +} + +// RestoreCheckpoint 恢复指定 checkpoint 的会话和工作区状态。 +func (s *Service) RestoreCheckpoint(ctx context.Context, input GatewayRestoreInput) (RestoreResult, error) { + if s.checkpointStore == nil { + return RestoreResult{}, fmt.Errorf("checkpoint: store not available") + } + + sessionID := strings.TrimSpace(input.SessionID) + checkpointID := strings.TrimSpace(input.CheckpointID) + if sessionID == "" || checkpointID == "" { + return RestoreResult{}, fmt.Errorf("checkpoint: session_id and checkpoint_id required") + } + + // 1. Load checkpoint record + record, sessionCP, err := s.checkpointStore.GetCheckpoint(ctx, checkpointID) + if err != nil { + return RestoreResult{}, err + } + if record.SessionID != sessionID { + return RestoreResult{}, fmt.Errorf("checkpoint: session mismatch") + } + if record.Status != agentsession.CheckpointStatusAvailable { + return RestoreResult{}, fmt.Errorf("checkpoint: status is %s, expected available", record.Status) + } + if !record.Restorable { + return RestoreResult{}, fmt.Errorf("checkpoint: not restorable") + } + + // 2. Conflict detection + if s.shadowRepo != nil && record.CodeCheckpointRef != "" && !input.Force { + commitHash, resolveErr := s.resolveCommitHashForRef(ctx, record.CodeCheckpointRef) + if resolveErr == nil && commitHash != "" { + conflict, conflictErr := s.shadowRepo.DetectConflicts(ctx, commitHash) + if conflictErr == nil && conflict.HasConflict { + return RestoreResult{ + CheckpointID: checkpointID, + SessionID: sessionID, + Conflict: &conflict, + }, fmt.Errorf("checkpoint: conflicts detected, use force to override") + } + } + } + + // 3. Create pre-restore guard checkpoint (current state snapshot) + guardID := agentsession.NewID("checkpoint") + guardRef := checkpoint.RefForCheckpoint(sessionID, guardID) + guardCommitHash := "" + + if s.shadowRepo != nil && s.shadowRepo.IsAvailable() { + guardCommitHash, _ = s.shadowRepo.Snapshot(ctx, guardRef, fmt.Sprintf("pre_restore_guard for session %s", sessionID)) + } + + guardRecord, guardErr := s.createGuardCheckpoint(ctx, sessionID, record.RunID, guardID, guardRef, guardCommitHash) + if guardErr != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: create guard: %w", guardErr) + } + + // 4. Restore code (git checkout) + if s.shadowRepo != nil && record.CodeCheckpointRef != "" { + restoreCommitHash, resolveErr := s.resolveCommitHashForRef(ctx, record.CodeCheckpointRef) + if resolveErr == nil && restoreCommitHash != "" { + if err := s.shadowRepo.Restore(ctx, restoreCommitHash); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: restore code: %w", err) + } + } + } + + // 5. Unmarshal session checkpoint + if sessionCP == nil { + return RestoreResult{}, fmt.Errorf("checkpoint: no session checkpoint data") + } + var head agentsession.SessionHead + if err := json.Unmarshal([]byte(sessionCP.HeadJSON), &head); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: unmarshal head: %w", err) + } + var messages []providertypes.Message + if err := json.Unmarshal([]byte(sessionCP.MessagesJSON), &messages); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: unmarshal messages: %w", err) + } + + // 6. Determine checkpoint IDs to mark + markAvailableIDs := []string{guardRecord.CheckpointID} + var markRestoredIDs []string + allRecords, listErr := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{}) + if listErr == nil { + for _, r := range allRecords { + if r.CreatedAt.After(record.CreatedAt) && r.Status == agentsession.CheckpointStatusAvailable { + markRestoredIDs = append(markRestoredIDs, r.CheckpointID) + } + } + } + + // 7. Restore session + update checkpoint statuses (single transaction) + restoreInput := checkpoint.RestoreCheckpointInput{ + SessionID: sessionID, + Head: head, + Messages: messages, + UpdatedAt: time.Now(), + MarkAvailableIDs: markAvailableIDs, + MarkRestoredIDs: markRestoredIDs, + } + if err := s.checkpointStore.RestoreCheckpoint(ctx, restoreInput); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: restore: %w", err) + } + + // 8. Update runtime session if it's the current session + s.updateRuntimeSessionAfterRestore(sessionID, head, messages) + + s.emitRunScoped(ctx, EventCheckpointRestored, nil, CheckpointRestoredPayload{ + CheckpointID: checkpointID, + SessionID: sessionID, + GuardCheckpointID: guardRecord.CheckpointID, + }) + return RestoreResult{ + CheckpointID: checkpointID, + SessionID: sessionID, + }, nil +} + +// UndoRestoreCheckpoint 撤销最近一次 restore,通过 pre_restore_guard 恢复到 restore 前的状态。 +func (s *Service) UndoRestoreCheckpoint(ctx context.Context, sessionID string) (RestoreResult, error) { + if s.checkpointStore == nil { + return RestoreResult{}, fmt.Errorf("checkpoint: store not available") + } + + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return RestoreResult{}, fmt.Errorf("checkpoint: session_id required") + } + + // Find latest guard checkpoint + records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{ + Limit: 1, + RestorableOnly: true, + }) + if err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: list for undo: %w", err) + } + + var guardRecord *agentsession.CheckpointRecord + for _, r := range records { + if r.Reason == agentsession.CheckpointReasonGuard { + guardRecord = &r + break + } + } + if guardRecord == nil { + return RestoreResult{}, fmt.Errorf("checkpoint: no guard checkpoint found for undo") + } + + // Recursively call RestoreCheckpoint with force + result, err := s.RestoreCheckpoint(ctx, GatewayRestoreInput{ + SessionID: sessionID, + CheckpointID: guardRecord.CheckpointID, + Force: true, + }) + if err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: undo restore: %w", err) + } + + s.emitRunScoped(ctx, EventCheckpointUndoRestore, nil, CheckpointUndoRestorePayload{ + GuardCheckpointID: guardRecord.CheckpointID, + SessionID: sessionID, + }) + return result, nil +} + +// createGuardCheckpoint 创建 pre_restore_guard 类型的 checkpoint。 +func (s *Service) createGuardCheckpoint(ctx context.Context, sessionID, runID, checkpointID, ref, commitHash string) (agentsession.CheckpointRecord, error) { + session, err := s.sessionStore.LoadSession(ctx, sessionID) + if err != nil { + return agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: load session for guard: %w", err) + } + + head := session.HeadSnapshot() + headJSON, err := json.Marshal(head) + if err != nil { + return agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: marshal guard head: %w", err) + } + messagesJSON, err := json.Marshal(session.Messages) + if err != nil { + return agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: marshal guard messages: %w", err) + } + + now := time.Now() + record := agentsession.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + SessionID: sessionID, + RunID: runID, + Workdir: session.Workdir, + CreatedAt: now, + Reason: agentsession.CheckpointReasonGuard, + CodeCheckpointRef: ref, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + } + sessionCP := agentsession.SessionCheckpoint{ + ID: agentsession.NewID("sc"), + SessionID: sessionID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: now, + } + + saved, err := s.checkpointStore.CreateCheckpoint(ctx, checkpoint.CreateCheckpointInput{ + Record: record, + SessionCP: sessionCP, + }) + if err != nil { + return agentsession.CheckpointRecord{}, err + } + + if commitHash != "" { + s.emitRunScoped(ctx, EventCheckpointCreated, nil, CheckpointCreatedPayload{ + CheckpointID: saved.CheckpointID, + CodeCheckpointRef: saved.CodeCheckpointRef, + SessionCheckpointRef: saved.SessionCheckpointRef, + CommitHash: commitHash, + Reason: string(saved.Reason), + }) + } + return saved, nil +} + +// resolveCommitHashForRef 通过 git rev-parse 解析 ref 对应的 commit hash。 +func (s *Service) resolveCommitHashForRef(ctx context.Context, ref string) (string, error) { + if s.shadowRepo == nil || ref == "" { + return "", fmt.Errorf("shadow repo not available") + } + // Use the shadow repo's existing git infrastructure + return s.shadowRepo.ResolveRef(ctx, ref) +} + +// updateRuntimeSessionAfterRestore 在 restore 后更新运行时会话状态。 +// 当前实现不做运行时状态直接修改,依赖下次 session 加载时从 DB 读取恢复后的状态。 +func (s *Service) updateRuntimeSessionAfterRestore(sessionID string, head agentsession.SessionHead, messages []providertypes.Message) { +} diff --git a/internal/runtime/checkpoint_resume.go b/internal/runtime/checkpoint_resume.go new file mode 100644 index 00000000..1c3b027e --- /dev/null +++ b/internal/runtime/checkpoint_resume.go @@ -0,0 +1,38 @@ +package runtime + +import ( + "context" + "log" + "time" + + agentsession "neo-code/internal/session" +) + +// updateResumeCheckpoint 在 phase 转换时写入或更新 ResumeCheckpoint。 +// 失败仅 log,不阻塞主流程。 +func (s *Service) updateResumeCheckpoint(ctx context.Context, state *runState, phase string, completionState string) { + if s.checkpointStore == nil { + return + } + + state.mu.Lock() + session := state.session + runID := state.runID + turn := state.turn + state.mu.Unlock() + + rc := agentsession.ResumeCheckpoint{ + ID: agentsession.NewID("rc"), + WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + RunID: runID, + SessionID: session.ID, + Turn: turn, + Phase: phase, + CompletionState: completionState, + UpdatedAt: time.Now(), + } + + if err := s.checkpointStore.SetResumeCheckpoint(ctx, rc); err != nil { + log.Printf("checkpoint: set resume checkpoint for %s: %v", session.ID, err) + } +} diff --git a/internal/runtime/compact.go b/internal/runtime/compact.go index e39cfb59..f8624e1f 100644 --- a/internal/runtime/compact.go +++ b/internal/runtime/compact.go @@ -2,10 +2,12 @@ package runtime import ( "context" + "encoding/json" "errors" "strings" "time" + "neo-code/internal/checkpoint" "neo-code/internal/config" contextcompact "neo-code/internal/context/compact" providertypes "neo-code/internal/provider/types" @@ -155,6 +157,8 @@ func (s *Service) runCompactForSession( return failCompact(errors.New(reason)) } + s.createCompactCheckpoint(ctx, runID, session) + s.emit(ctx, EventCompactStart, runID, session.ID, string(mode)) result, err := runner.Run(ctx, contextcompact.Input{ @@ -248,3 +252,59 @@ func resolveCompactProviderSelection(session agentsession.Session, cfg config.Co } return resolved, strings.TrimSpace(cfg.CurrentModel), nil } + +// createCompactCheckpoint 为 compact 操作创建 session-only checkpoint。 +func (s *Service) createCompactCheckpoint(ctx context.Context, runID string, session agentsession.Session) { + if s.checkpointStore == nil { + return + } + + checkpointID := agentsession.NewID("checkpoint") + now := time.Now() + + head := session.HeadSnapshot() + headJSON, err := json.Marshal(head) + if err != nil { + return + } + messagesJSON, err := json.Marshal(session.Messages) + if err != nil { + return + } + + record := agentsession.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + SessionID: session.ID, + RunID: runID, + Workdir: session.Workdir, + CreatedAt: now, + Reason: agentsession.CheckpointReasonCompact, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + } + + // Shadow snapshot if available + if s.shadowRepo != nil && s.shadowRepo.IsAvailable() { + ref := checkpoint.RefForCheckpoint(session.ID, checkpointID) + if commitHash, err := s.shadowRepo.Snapshot(ctx, ref, "compact checkpoint"); err == nil { + record.CodeCheckpointRef = ref + _ = commitHash + } + } + + sessionCP := agentsession.SessionCheckpoint{ + ID: agentsession.NewID("sc"), + SessionID: session.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: now, + } + + if _, err := s.checkpointStore.CreateCheckpoint(ctx, checkpoint.CreateCheckpointInput{ + Record: record, + SessionCP: sessionCP, + }); err != nil { + return + } +} diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 351789e6..8a95620e 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -420,6 +420,10 @@ const ( EventCheckpointCreated EventType = "checkpoint_created" // EventCheckpointWarning 表示 checkpoint 创建过程中出现非致命告警。 EventCheckpointWarning EventType = "checkpoint_warning" + // EventCheckpointRestored 表示 checkpoint 已成功恢复。 + EventCheckpointRestored EventType = "checkpoint_restored" + // EventCheckpointUndoRestore 表示 restore 已撤销。 + EventCheckpointUndoRestore EventType = "checkpoint_undo_restore" ) // TokenUsagePayload 承载单轮 token 用量统计。 @@ -447,3 +451,16 @@ type CheckpointWarningPayload struct { Error string `json:"error"` Phase string `json:"phase"` } + +// CheckpointRestoredPayload 描述 checkpoint 恢复成功事件。 +type CheckpointRestoredPayload struct { + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + GuardCheckpointID string `json:"guard_checkpoint_id"` +} + +// CheckpointUndoRestorePayload 描述 restore 撤销事件。 +type CheckpointUndoRestorePayload struct { + GuardCheckpointID string `json:"guard_checkpoint_id"` + SessionID string `json:"session_id"` +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index aa1524c5..bfd6919c 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -90,6 +90,9 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.finishRun(runToken) }() defer func() { + if statePtr != nil { + s.updateResumeCheckpoint(runCtx, statePtr, "stopped", "completed") + } s.emitRunTermination(runCtx, input, statePtr, err) }() ctx = runCtx @@ -170,6 +173,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { return s.handleRunError(err) } s.emitRuntimeSnapshotUpdated(ctx, &state, "session_start") + s.updateResumeCheckpoint(ctx, &state, "plan", "") maxTurns := resolveRuntimeMaxTurns(initialCfg.Runtime) for turn := 0; ; turn++ { @@ -344,6 +348,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, controlplane.RunStateVerify); err != nil { return s.handleRunError(err) } + s.updateResumeCheckpoint(ctx, &state, "verify", "completed") completionHookOutput := s.runHookPoint( ctx, &state, @@ -485,6 +490,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, controlplane.RunStateExecute); err != nil { return s.handleRunError(err) } + s.updateResumeCheckpoint(ctx, &state, "execute", "") summary, err := s.executeAssistantToolCalls(ctx, &state, snapshot, turnOutput.assistant) if err != nil { return s.handleRunError(err) @@ -524,6 +530,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, controlplane.RunStateVerify); err != nil { return s.handleRunError(err) } + s.updateResumeCheckpoint(ctx, &state, "verify", "completed") break } } diff --git a/internal/session/checkpoint_types.go b/internal/session/checkpoint_types.go index 1cfd148d..a44ed771 100644 --- a/internal/session/checkpoint_types.go +++ b/internal/session/checkpoint_types.go @@ -6,11 +6,12 @@ import "time" type CheckpointReason string const ( - CheckpointReasonPreWrite CheckpointReason = "pre_write" - CheckpointReasonCompact CheckpointReason = "compact" - CheckpointReasonPlanMode CheckpointReason = "plan_mode" - CheckpointReasonManual CheckpointReason = "manual" - CheckpointReasonGuard CheckpointReason = "pre_restore_guard" + CheckpointReasonPreWrite CheckpointReason = "pre_write" + CheckpointReasonCompact CheckpointReason = "compact" + CheckpointReasonPlanMode CheckpointReason = "plan_mode" + CheckpointReasonManual CheckpointReason = "manual" + CheckpointReasonGuard CheckpointReason = "pre_restore_guard" + CheckpointReasonPreWriteDegraded CheckpointReason = "pre_write_degraded" ) // CheckpointStatus 描述 checkpoint 的生命周期状态。 From 9c2e4b8b3de5e0a8b4045c48b3fa5b7a3c4266ce Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Sun, 3 May 2026 22:56:12 +0800 Subject: [PATCH 3/8] =?UTF-8?q?feat(checkpoint):=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E6=A3=80=E6=9F=A5=E7=82=B9=E5=BB=BA=E7=AB=8B=E8=8A=82=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/app/bootstrap.go | 21 +++++++++++---------- internal/checkpoint/checkpoint_manager.go | 16 ++++++++++++++-- internal/cli/root_test.go | 12 ++++++++++++ internal/session/sqlite_store.go | 9 +++++++++ 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 5aae3f5e..94669476 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -232,35 +232,36 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime } // Checkpoint 基础设施:影子仓库 + SQLite checkpoint 存储 - dbPath := agentsession.DatabasePath(sharedDeps.ConfigManager.BaseDir(), cfg.Workdir) + // 复用 sessionStore 的 *sql.DB 连接,避免 Windows 上多连接文件锁定。 + sessionDB := sessionStore.DB() + var checkpointStore *checkpoint.SQLiteCheckpointStore + if sessionDB != nil { + checkpointStore = checkpoint.NewSQLiteCheckpointStoreWithDB(sessionDB) + } if gitAvail, _ := checkpoint.CheckGitAvailability(ctx); gitAvail { projectDir := agentsession.HashWorkspaceRoot(cfg.Workdir) shadowDir := filepath.Join(sharedDeps.ConfigManager.BaseDir(), "projects", projectDir) shadowRepo := checkpoint.NewShadowRepo(shadowDir, cfg.Workdir) if err := shadowRepo.Init(ctx); err != nil { log.Printf("checkpoint shadow repo init warning: %v", err) - } else { - checkpointStore := checkpoint.NewSQLiteCheckpointStore(dbPath) + } else if checkpointStore != nil { runtimeSvc.SetCheckpointDependencies(checkpointStore, shadowRepo) } - } else { + } else if checkpointStore != nil { // 降级模式:git 不可用时仍可创建 session-only checkpoint - checkpointStore := checkpoint.NewSQLiteCheckpointStore(dbPath) runtimeSvc.SetCheckpointDependencies(checkpointStore, nil) } // 启动时修复残留的 creating 状态 checkpoint - { - repairStore := checkpoint.NewSQLiteCheckpointStore(dbPath) - if repaired, err := repairStore.RepairCreatingCheckpoints(ctx); err != nil { + if checkpointStore != nil { + if repaired, err := checkpointStore.RepairCreatingCheckpoints(ctx); err != nil { log.Printf("checkpoint repair warning: %v", err) } else if repaired > 0 { log.Printf("checkpoint repair: fixed %d stale checkpoints", repaired) } - _ = repairStore.Close() } runtimeImpl := agentruntime.Runtime(runtimeSvc) - closeFns := []func() error{toolsCleanup, sessionStore.Close} + closeFns := []func() error{toolsCleanup, checkpointStore.Close, sessionStore.Close} needCleanup = false diff --git a/internal/checkpoint/checkpoint_manager.go b/internal/checkpoint/checkpoint_manager.go index d9137620..6dae9662 100644 --- a/internal/checkpoint/checkpoint_manager.go +++ b/internal/checkpoint/checkpoint_manager.go @@ -53,6 +53,7 @@ type SQLiteCheckpointStore struct { dbPath string initMu sync.Mutex db *sql.DB + ownsDB bool // true 表示本实例打开的连接,Close 时需释放 } // NewSQLiteCheckpointStore 创建 checkpoint 存储实例。 @@ -63,9 +64,19 @@ func NewSQLiteCheckpointStore(dbPath string) *SQLiteCheckpointStore { } } -// Close 释放数据库连接。 +// NewSQLiteCheckpointStoreWithDB 创建 checkpoint 存储实例,复用已有的 *sql.DB 连接。 +// 适用于与 session store 共享同一数据库文件的场景,避免 Windows 上多连接文件锁定。 +// Close 不会关闭传入的连接,由调用方管理连接生命周期。 +func NewSQLiteCheckpointStoreWithDB(db *sql.DB) *SQLiteCheckpointStore { + return &SQLiteCheckpointStore{ + db: db, + ownsDB: false, + } +} + +// Close 释放数据库连接。仅当本实例拥有连接时(ownsDB=true)才实际关闭。 func (s *SQLiteCheckpointStore) Close() error { - if s == nil || s.db == nil { + if s == nil || s.db == nil || !s.ownsDB { return nil } return s.db.Close() @@ -97,6 +108,7 @@ func (s *SQLiteCheckpointStore) ensureDB(ctx context.Context) (*sql.DB, error) { } } s.db = db + s.ownsDB = true return db, nil } diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 9576962f..6cab64fb 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -1092,6 +1092,18 @@ func (stubRuntimePort) GetRuntimeSnapshot( return gateway.RuntimeSnapshot{}, nil } +func (stubRuntimePort) ListCheckpoints(context.Context, gateway.ListCheckpointsInput) ([]gateway.CheckpointEntry, error) { + return nil, nil +} + +func (stubRuntimePort) RestoreCheckpoint(context.Context, gateway.CheckpointRestoreInput) (gateway.CheckpointRestoreResult, error) { + return gateway.CheckpointRestoreResult{}, nil +} + +func (stubRuntimePort) UndoRestore(context.Context, gateway.UndoRestoreInput) (gateway.CheckpointRestoreResult, error) { + return gateway.CheckpointRestoreResult{}, nil +} + func (s *stubGatewayServer) ListenAddress() string { return s.listenAddress } diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index 6a94e645..63a46786 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -94,6 +94,15 @@ func (s *SQLiteStore) Close() error { return s.db.Close() } +// DB 返回底层 *sql.DB 连接,供需要共享同一数据库连接的组件使用。 +// 调用前必须已触发过 ensureDB(如通过任何读写操作)。 +func (s *SQLiteStore) DB() *sql.DB { + if s == nil { + return nil + } + return s.db +} + // CleanupExpiredSessions 删除超过指定时长未更新的会话及其附件,返回删除数量。 func (s *SQLiteStore) CleanupExpiredSessions(ctx context.Context, maxAge time.Duration) (int, error) { if err := ctx.Err(); err != nil { From 8c7083684300dbc1700e468052279bbb09d1ed3a Mon Sep 17 00:00:00 2001 From: xgopilot Date: Sun, 3 May 2026 15:07:46 +0000 Subject: [PATCH 4/8] =?UTF-8?q?test(gateway):=20=E8=A1=A5=E9=BD=90=20urlsc?= =?UTF-8?q?heme=20runtime=20stub?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- .../dispatcher_integration_unix_test.go | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go index caadece3..6c3c0bfe 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go @@ -212,6 +212,27 @@ func (s *urlschemeIntegrationRuntimeStub) GetSessionModel(context.Context, gatew return gateway.SessionModelResult{}, nil } +func (s *urlschemeIntegrationRuntimeStub) ListCheckpoints( + context.Context, + gateway.ListCheckpointsInput, +) ([]gateway.CheckpointEntry, error) { + return nil, nil +} + +func (s *urlschemeIntegrationRuntimeStub) RestoreCheckpoint( + context.Context, + gateway.CheckpointRestoreInput, +) (gateway.CheckpointRestoreResult, error) { + return gateway.CheckpointRestoreResult{}, nil +} + +func (s *urlschemeIntegrationRuntimeStub) UndoRestore( + context.Context, + gateway.UndoRestoreInput, +) (gateway.CheckpointRestoreResult, error) { + return gateway.CheckpointRestoreResult{}, nil +} + func waitGatewayReady(address string, timeout time.Duration) error { deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { From c1ab89ce6f8ae83851033db59386734b33dee082 Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Sun, 3 May 2026 23:41:02 +0800 Subject: [PATCH 5/8] =?UTF-8?q?feat(checkpoint):=E4=BF=AEcheckpoint?= =?UTF-8?q?=E5=88=9B=E5=BB=BA=E8=8A=82=E7=82=B9=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/checkpoint/shadow_repo.go | 36 ++++++++++++++ internal/cli/gateway_runtime_bridge.go | 66 ++++++++++++++++++++++---- internal/runtime/checkpoint_gate.go | 58 ++++++---------------- internal/runtime/checkpoint_restore.go | 12 ++++- internal/runtime/run.go | 17 +++---- 5 files changed, 127 insertions(+), 62 deletions(-) diff --git a/internal/checkpoint/shadow_repo.go b/internal/checkpoint/shadow_repo.go index 33c2eab4..65862b72 100644 --- a/internal/checkpoint/shadow_repo.go +++ b/internal/checkpoint/shadow_repo.go @@ -84,6 +84,18 @@ func (r *ShadowRepo) Init(ctx context.Context) error { return fmt.Errorf("checkpoint: set core.worktree: %w", err) } + // 设置提交者身份,避免机器上无全局 git 配置时 commit 失败 + ctx3, cancel3 := context.WithTimeout(context.Background(), gitCommandTimeout) + defer cancel3() + if err := r.gitExec(ctx3, "config", "user.name", "neocode-checkpoint"); err != nil { + return fmt.Errorf("checkpoint: set user.name: %w", err) + } + ctx4, cancel4 := context.WithTimeout(context.Background(), gitCommandTimeout) + defer cancel4() + if err := r.gitExec(ctx4, "config", "user.email", "checkpoint@neocode.local"); err != nil { + return fmt.Errorf("checkpoint: set user.email: %w", err) + } + r.gitAvailable = true return nil } @@ -190,6 +202,30 @@ func (r *ShadowRepo) DeleteRef(ctx context.Context, ref string) error { return nil } +// HasCodeChanges 检查工作区是否有未提交的代码变更。 +// 使用 git diff --quiet HEAD -- . 检测,退出码 1 表示有变更。 +// git 不可用时返回 true(保守策略)。 +func (r *ShadowRepo) HasCodeChanges(ctx context.Context) bool { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.gitAvailable { + return true + } + + ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) + defer cancel() + + cmd := r.buildGitCommand(ctx, "diff", "--quiet", "HEAD", "--", ".") + if err := cmd.Run(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return exitErr.ExitCode() == 1 + } + return true + } + return false +} + // HealthCheck 验证 bare 仓库存在且可执行 git 操作。 func (r *ShadowRepo) HealthCheck(ctx context.Context) error { r.mu.Lock() diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go index 1dc4a7ea..716605d0 100644 --- a/internal/cli/gateway_runtime_bridge.go +++ b/internal/cli/gateway_runtime_bridge.go @@ -13,6 +13,7 @@ import ( "time" "neo-code/internal/app" + "neo-code/internal/checkpoint" "neo-code/internal/config" configstate "neo-code/internal/config/state" "neo-code/internal/gateway" @@ -42,6 +43,12 @@ type runtimeSnapshotGetter interface { GetRuntimeSnapshot(ctx context.Context, sessionID string) (agentruntime.RuntimeSnapshot, error) } +type runtimeCheckpointer interface { + ListCheckpoints(ctx context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) + RestoreCheckpoint(ctx context.Context, input agentruntime.GatewayRestoreInput) (agentruntime.RestoreResult, error) + UndoRestoreCheckpoint(ctx context.Context, sessionID string) (agentruntime.RestoreResult, error) +} + // bridgeSessionStore 定义桥接层对会话存储的最低需求。 type bridgeSessionStore interface { DeleteSession(ctx context.Context, sessionID string) error @@ -1365,22 +1372,63 @@ type manualModelPayload struct { var _ gateway.RuntimePort = (*gatewayRuntimePortBridge)(nil) func (b *gatewayRuntimePortBridge) ListCheckpoints(ctx context.Context, input gateway.ListCheckpointsInput) ([]gateway.CheckpointEntry, error) { - if b == nil { - return nil, fmt.Errorf(bridgeRuntimeUnavailableErrMsg) + cp, ok := b.runtime.(runtimeCheckpointer) + if !ok { + return nil, fmt.Errorf("gateway runtime bridge: runtime does not support checkpoint operations") } - return nil, fmt.Errorf("checkpoint list: not yet implemented in CLI bridge") + records, err := cp.ListCheckpoints(ctx, strings.TrimSpace(input.SessionID), checkpoint.ListCheckpointOpts{ + Limit: input.Limit, + RestorableOnly: input.RestorableOnly, + }) + if err != nil { + return nil, err + } + entries := make([]gateway.CheckpointEntry, 0, len(records)) + for _, r := range records { + entries = append(entries, gateway.CheckpointEntry{ + CheckpointID: r.CheckpointID, + SessionID: r.SessionID, + Reason: string(r.Reason), + Status: string(r.Status), + Restorable: r.Restorable, + CreatedAt: r.CreatedAt.UnixMilli(), + }) + } + return entries, nil } func (b *gatewayRuntimePortBridge) RestoreCheckpoint(ctx context.Context, input gateway.CheckpointRestoreInput) (gateway.CheckpointRestoreResult, error) { - if b == nil { - return gateway.CheckpointRestoreResult{}, fmt.Errorf(bridgeRuntimeUnavailableErrMsg) + cp, ok := b.runtime.(runtimeCheckpointer) + if !ok { + return gateway.CheckpointRestoreResult{}, fmt.Errorf("gateway runtime bridge: runtime does not support checkpoint operations") } - return gateway.CheckpointRestoreResult{}, fmt.Errorf("checkpoint restore: not yet implemented in CLI bridge") + result, err := cp.RestoreCheckpoint(ctx, agentruntime.GatewayRestoreInput{ + SessionID: strings.TrimSpace(input.SessionID), + CheckpointID: strings.TrimSpace(input.CheckpointID), + Force: input.Force, + }) + if err != nil { + return gateway.CheckpointRestoreResult{}, err + } + return gateway.CheckpointRestoreResult{ + CheckpointID: result.CheckpointID, + SessionID: result.SessionID, + HasConflict: result.Conflict != nil && result.Conflict.HasConflict, + }, nil } func (b *gatewayRuntimePortBridge) UndoRestore(ctx context.Context, input gateway.UndoRestoreInput) (gateway.CheckpointRestoreResult, error) { - if b == nil { - return gateway.CheckpointRestoreResult{}, fmt.Errorf(bridgeRuntimeUnavailableErrMsg) + cp, ok := b.runtime.(runtimeCheckpointer) + if !ok { + return gateway.CheckpointRestoreResult{}, fmt.Errorf("gateway runtime bridge: runtime does not support checkpoint operations") } - return gateway.CheckpointRestoreResult{}, fmt.Errorf("checkpoint undo restore: not yet implemented in CLI bridge") + result, err := cp.UndoRestoreCheckpoint(ctx, strings.TrimSpace(input.SessionID)) + if err != nil { + return gateway.CheckpointRestoreResult{}, err + } + return gateway.CheckpointRestoreResult{ + CheckpointID: result.CheckpointID, + SessionID: result.SessionID, + HasConflict: result.Conflict != nil && result.Conflict.HasConflict, + }, nil } diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go index 0d3ae0ed..5c4c132b 100644 --- a/internal/runtime/checkpoint_gate.go +++ b/internal/runtime/checkpoint_gate.go @@ -8,15 +8,13 @@ import ( "time" "neo-code/internal/checkpoint" - providertypes "neo-code/internal/provider/types" agentsession "neo-code/internal/session" - "neo-code/internal/tools" ) -// createPreWriteCheckpoint 在工具执行前创建 checkpoint,采用两阶段提交。 -// shadowRepo 可用时做完整快照,不可用时降级为 session-only checkpoint。 -// 失败时不阻塞工具执行,仅返回 error 由调用方发 warning event。 -func (s *Service) createPreWriteCheckpoint(ctx context.Context, state *runState) error { +// createPerTurnCheckpoint 在每轮 turn 开始时创建 checkpoint。 +// shadowRepo 可用且有代码变更时做完整快照,否则仅做 session-only 快照。 +// 失败时不阻塞执行,仅返回 error 由调用方发 warning event。 +func (s *Service) createPerTurnCheckpoint(ctx context.Context, state *runState) error { if s.checkpointStore == nil { return nil } @@ -31,9 +29,19 @@ func (s *Service) createPreWriteCheckpoint(ctx context.Context, state *runState) return s.createDegradedCheckpoint(ctx, session, runID) } + // 无代码变更时跳过代码快照,仅做 session-only checkpoint + if !s.shadowRepo.HasCodeChanges(ctx) { + return s.createDegradedCheckpoint(ctx, session, runID) + } + + return s.createFullCheckpoint(ctx, session, runID, state) +} + +// createFullCheckpoint 创建完整 checkpoint(代码快照 + 会话快照)。 +func (s *Service) createFullCheckpoint(ctx context.Context, session agentsession.Session, runID string, state *runState) error { checkpointID := agentsession.NewID("checkpoint") ref := checkpoint.RefForCheckpoint(session.ID, checkpointID) - commitMsg := fmt.Sprintf("pre_write checkpoint for session %s", session.ID) + commitMsg := fmt.Sprintf("per-turn checkpoint for session %s", session.ID) // Phase 1: shadow snapshot commitHash, err := s.shadowRepo.Snapshot(ctx, ref, commitMsg) @@ -99,42 +107,6 @@ func (s *Service) createPreWriteCheckpoint(ctx context.Context, state *runState) return nil } -// toolCallsContainWorkspaceWrite 检查工具调用列表中是否包含会修改工作区的调用。 -func toolCallsContainWorkspaceWrite(calls []providertypes.ToolCall) bool { - for _, call := range calls { - if isWorkspaceWriteToolCall(call) { - return true - } - } - return false -} - -func isWorkspaceWriteToolCall(call providertypes.ToolCall) bool { - switch call.Name { - case tools.ToolNameFilesystemWriteFile, tools.ToolNameFilesystemEdit: - return true - case tools.ToolNameBash: - return isBashWriteCommand(call.Arguments) - } - return false -} - -func isBashWriteCommand(argumentsJSON string) bool { - trimmed := strings.TrimSpace(argumentsJSON) - if trimmed == "" { - return false - } - var args struct { - Command string `json:"command"` - } - if err := json.Unmarshal([]byte(trimmed), &args); err != nil { - return false - } - intent := tools.AnalyzeBashCommand(args.Command) - return intent.Classification == tools.BashIntentClassificationLocalMutation || - intent.Classification == tools.BashIntentClassificationDestructive -} - // createDegradedCheckpoint 创建 session-only checkpoint(无代码快照),用于 shadowRepo 不可用时。 func (s *Service) createDegradedCheckpoint(ctx context.Context, session agentsession.Session, runID string) error { checkpointID := agentsession.NewID("checkpoint") diff --git a/internal/runtime/checkpoint_restore.go b/internal/runtime/checkpoint_restore.go index 81613b45..a3451fdf 100644 --- a/internal/runtime/checkpoint_restore.go +++ b/internal/runtime/checkpoint_restore.go @@ -111,7 +111,7 @@ func (s *Service) RestoreCheckpoint(ctx context.Context, input GatewayRestoreInp allRecords, listErr := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{}) if listErr == nil { for _, r := range allRecords { - if r.CreatedAt.After(record.CreatedAt) && r.Status == agentsession.CheckpointStatusAvailable { + if r.CreatedAt.After(record.CreatedAt) && r.Status == agentsession.CheckpointStatusAvailable && r.Reason != agentsession.CheckpointReasonGuard { markRestoredIDs = append(markRestoredIDs, r.CheckpointID) } } @@ -157,7 +157,7 @@ func (s *Service) UndoRestoreCheckpoint(ctx context.Context, sessionID string) ( // Find latest guard checkpoint records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{ - Limit: 1, + Limit: 20, RestorableOnly: true, }) if err != nil { @@ -259,6 +259,14 @@ func (s *Service) resolveCommitHashForRef(ctx context.Context, ref string) (stri return s.shadowRepo.ResolveRef(ctx, ref) } +// ListCheckpoints 查询指定会话的 checkpoint 列表。 +func (s *Service) ListCheckpoints(ctx context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + if s.checkpointStore == nil { + return nil, fmt.Errorf("checkpoint: store not available") + } + return s.checkpointStore.ListCheckpoints(ctx, sessionID, opts) +} + // updateRuntimeSessionAfterRestore 在 restore 后更新运行时会话状态。 // 当前实现不做运行时状态直接修改,依赖下次 session 加载时从 DB 读取恢复后的状态。 func (s *Service) updateRuntimeSessionAfterRestore(sessionID string, head agentsession.SessionHead, messages []providertypes.Message) { diff --git a/internal/runtime/run.go b/internal/runtime/run.go index bfd6919c..35883f34 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -189,6 +189,14 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, baseRunStateForPlanningStage(stage)); err != nil { return s.handleRunError(err) } + if s.checkpointStore != nil { + if cpErr := s.createPerTurnCheckpoint(ctx, &state); cpErr != nil { + s.emitRunScoped(ctx, EventCheckpointWarning, &state, CheckpointWarningPayload{ + Error: cpErr.Error(), + Phase: "per_turn", + }) + } + } turnAttempt: for { @@ -479,14 +487,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { beforeTask := state.session.TaskState.Clone() beforeTodos := cloneTodosForPersistence(state.session.Todos) - if s.checkpointStore != nil && toolCallsContainWorkspaceWrite(turnOutput.assistant.ToolCalls) { - if cpErr := s.createPreWriteCheckpoint(ctx, &state); cpErr != nil { - s.emitRunScoped(ctx, EventCheckpointWarning, &state, CheckpointWarningPayload{ - Error: cpErr.Error(), - Phase: "pre_write", - }) - } - } + if err := s.setBaseRunState(ctx, &state, controlplane.RunStateExecute); err != nil { return s.handleRunError(err) } From ee6464f236d15812f8a23d03136c3247497cca6f Mon Sep 17 00:00:00 2001 From: xgopilot Date: Sun, 3 May 2026 15:55:21 +0000 Subject: [PATCH 6/8] =?UTF-8?q?test(checkpoint):=20=E8=A1=A5=E5=85=85?= =?UTF-8?q?=E5=9B=9E=E6=BB=9A=E9=93=BE=E8=B7=AF=E8=A6=86=E7=9B=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/checkpoint/checkpoint_manager.go | 20 +- .../checkpoint/checkpoint_manager_test.go | 417 ++++++++++++++++++ internal/checkpoint/shadow_repo_test.go | 147 ++++++ internal/cli/gateway_runtime_bridge_test.go | 107 ++++- internal/runtime/checkpoint_flow_test.go | 404 +++++++++++++++++ 5 files changed, 1088 insertions(+), 7 deletions(-) create mode 100644 internal/checkpoint/checkpoint_manager_test.go create mode 100644 internal/checkpoint/shadow_repo_test.go create mode 100644 internal/runtime/checkpoint_flow_test.go diff --git a/internal/checkpoint/checkpoint_manager.go b/internal/checkpoint/checkpoint_manager.go index 6dae9662..d8d4a950 100644 --- a/internal/checkpoint/checkpoint_manager.go +++ b/internal/checkpoint/checkpoint_manager.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "reflect" "sync" "time" @@ -423,7 +424,7 @@ func (s *SQLiteCheckpointStore) RestoreCheckpoint(ctx context.Context, input Res toUnixMillis(input.UpdatedAt), h.Provider, h.Model, h.Workdir, marshalHeadField(h.TaskState), marshalHeadField(h.Todos), marshalHeadField(h.ActivatedSkills), h.TokenInputTotal, h.TokenOutputTotal, boolToInt(h.HasUnknownUsage), h.AgentMode, - marshalHeadField(h.CurrentPlan), h.LastFullPlanRevision, + marshalPlanField(h.CurrentPlan), h.LastFullPlanRevision, boolToInt(h.PlanApprovalPendingFullAlign), boolToInt(h.PlanCompletionPendingFullReview), boolToInt(h.PlanContextDirty), boolToInt(h.PlanRestorePendingAlign), len(input.Messages), len(input.Messages), input.SessionID, @@ -456,12 +457,25 @@ func (s *SQLiteCheckpointStore) RestoreCheckpoint(ctx context.Context, input Res } func marshalHeadField(value any) string { + data, err := json.Marshal(value) + if err != nil { + return "null" + } + return string(data) +} + +// marshalPlanField 将可选计划字段编码为 session 兼容的持久化格式,nil 计划统一写为空串。 +func marshalPlanField(value any) string { if value == nil { - return "{}" + return "" + } + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Pointer && rv.IsNil() { + return "" } data, err := json.Marshal(value) if err != nil { - return "{}" + return "" } return string(data) } diff --git a/internal/checkpoint/checkpoint_manager_test.go b/internal/checkpoint/checkpoint_manager_test.go new file mode 100644 index 00000000..b8986d2f --- /dev/null +++ b/internal/checkpoint/checkpoint_manager_test.go @@ -0,0 +1,417 @@ +package checkpoint + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + "time" + + providertypes "neo-code/internal/provider/types" + "neo-code/internal/session" +) + +type checkpointStoreFixture struct { + sessionStore *session.SQLiteStore + checkpointStore *SQLiteCheckpointStore + baseDir string + workspaceRoot string +} + +func newCheckpointStoreFixture(t *testing.T) checkpointStoreFixture { + t.Helper() + + baseDir := t.TempDir() + workspaceRoot := t.TempDir() + + sessionStore := session.NewSQLiteStore(baseDir, workspaceRoot) + t.Cleanup(func() { + _ = sessionStore.Close() + }) + + checkpointStore := NewSQLiteCheckpointStore(session.DatabasePath(baseDir, workspaceRoot)) + t.Cleanup(func() { + _ = checkpointStore.Close() + }) + + return checkpointStoreFixture{ + sessionStore: sessionStore, + checkpointStore: checkpointStore, + baseDir: baseDir, + workspaceRoot: workspaceRoot, + } +} + +func createCheckpointTestSession(t *testing.T, store *session.SQLiteStore, id string, workdir string) session.Session { + t.Helper() + + created, err := store.CreateSession(context.Background(), session.CreateSessionInput{ + ID: id, + Title: "checkpoint test", + Head: session.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + TaskState: session.TaskState{ + Goal: "before restore", + VerificationProfile: session.VerificationProfileTaskOnly, + }, + Todos: []session.TodoItem{ + {ID: "todo-1", Content: "before restore"}, + }, + }, + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + + messages := []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("before restore"), + }, + }, + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("tool planned"), + }, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "bash", Arguments: `{"cmd":"pwd"}`}, + }, + ToolMetadata: map[string]string{"source": "test"}, + }, + } + if err := store.AppendMessages(context.Background(), session.AppendMessagesInput{ + SessionID: created.ID, + Messages: messages, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) + } + + loaded, err := store.LoadSession(context.Background(), created.ID) + if err != nil { + t.Fatalf("LoadSession() error = %v", err) + } + return loaded +} + +func checkpointInputFromSession(t *testing.T, loaded session.Session, checkpointID string, reason session.CheckpointReason, createdAt time.Time) CreateCheckpointInput { + t.Helper() + + headJSON, err := json.Marshal(loaded.HeadSnapshot()) + if err != nil { + t.Fatalf("Marshal(head) error = %v", err) + } + messagesJSON, err := json.Marshal(loaded.Messages) + if err != nil { + t.Fatalf("Marshal(messages) error = %v", err) + } + + return CreateCheckpointInput{ + Record: session.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: session.WorkspacePathKey(loaded.Workdir), + SessionID: loaded.ID, + RunID: "run-" + checkpointID, + Workdir: loaded.Workdir, + CreatedAt: createdAt, + Reason: reason, + CodeCheckpointRef: RefForCheckpoint(loaded.ID, checkpointID), + Restorable: true, + Status: session.CheckpointStatusCreating, + }, + SessionCP: session.SessionCheckpoint{ + ID: "sc-" + checkpointID, + SessionID: loaded.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: createdAt, + }, + } +} + +func TestSQLiteCheckpointStoreCreateRestoreAndResume(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_restore", fixture.workspaceRoot) + checkpointCreatedAt := time.Now().Add(-time.Minute) + + input := checkpointInputFromSession(t, loaded, "cp-restore", session.CheckpointReasonPreWrite, checkpointCreatedAt) + saved, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), input) + if err != nil { + t.Fatalf("CreateCheckpoint() error = %v", err) + } + if saved.SessionCheckpointRef == "" || saved.Status != session.CheckpointStatusAvailable { + t.Fatalf("CreateCheckpoint() = %#v, want available checkpoint with session ref", saved) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), loaded.ID, ListCheckpointOpts{ + Limit: 10, + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].CheckpointID != saved.CheckpointID { + t.Fatalf("ListCheckpoints() = %#v, want only %q", records, saved.CheckpointID) + } + + record, sessionCP, err := fixture.checkpointStore.GetCheckpoint(context.Background(), saved.CheckpointID) + if err != nil { + t.Fatalf("GetCheckpoint() error = %v", err) + } + if record.CheckpointID != saved.CheckpointID || sessionCP == nil || sessionCP.ID != saved.SessionCheckpointRef { + t.Fatalf("GetCheckpoint() = (%#v, %#v), want saved record and session snapshot", record, sessionCP) + } + + if err := fixture.sessionStore.UpdateSessionState(context.Background(), session.UpdateSessionStateInput{ + SessionID: loaded.ID, + UpdatedAt: time.Now(), + Title: "mutated", + Head: session.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: loaded.Workdir, + TaskState: session.TaskState{ + Goal: "after restore", + VerificationProfile: session.VerificationProfileTaskOnly, + }, + Todos: []session.TodoItem{ + {ID: "todo-2", Content: "after restore"}, + }, + }, + }); err != nil { + t.Fatalf("UpdateSessionState() error = %v", err) + } + if err := fixture.sessionStore.AppendMessages(context.Background(), session.AppendMessagesInput{ + SessionID: loaded.ID, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("after restore"), + }, + }, + }, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: loaded.Workdir, + }); err != nil { + t.Fatalf("AppendMessages(after) error = %v", err) + } + + if err := fixture.checkpointStore.RestoreCheckpoint(context.Background(), RestoreCheckpointInput{ + SessionID: loaded.ID, + Head: loaded.HeadSnapshot(), + Messages: loaded.Messages, + UpdatedAt: time.Now(), + }); err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + + restored, err := fixture.sessionStore.LoadSession(context.Background(), loaded.ID) + if err != nil { + t.Fatalf("LoadSession(restored) error = %v", err) + } + if restored.TaskState.Goal != loaded.TaskState.Goal { + t.Fatalf("restored goal = %q, want %q", restored.TaskState.Goal, loaded.TaskState.Goal) + } + if len(restored.Messages) != len(loaded.Messages) { + t.Fatalf("restored message count = %d, want %d", len(restored.Messages), len(loaded.Messages)) + } + if restored.Messages[1].ToolMetadata["source"] != "test" { + t.Fatalf("restored tool metadata = %#v, want preserved metadata", restored.Messages[1].ToolMetadata) + } + + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), saved.CheckpointID, session.CheckpointStatusRestored); err != nil { + t.Fatalf("UpdateCheckpointStatus() error = %v", err) + } + filtered, err := fixture.checkpointStore.ListCheckpoints(context.Background(), loaded.ID, ListCheckpointOpts{ + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints(filtered) error = %v", err) + } + if len(filtered) != 0 { + t.Fatalf("expected no restorable checkpoints after status change, got %#v", filtered) + } + + firstResume := session.ResumeCheckpoint{ + ID: "rc-1", + WorkspaceKey: session.WorkspacePathKey(loaded.Workdir), + RunID: "run-1", + SessionID: loaded.ID, + Turn: 1, + Phase: "plan", + CompletionState: "running", + TranscriptRevision: 3, + UpdatedAt: time.Now().Add(-time.Minute), + } + secondResume := firstResume + secondResume.ID = "rc-2" + secondResume.RunID = "run-2" + secondResume.Turn = 2 + secondResume.Phase = "execute" + secondResume.UpdatedAt = time.Now() + + if err := fixture.checkpointStore.SetResumeCheckpoint(context.Background(), firstResume); err != nil { + t.Fatalf("SetResumeCheckpoint(first) error = %v", err) + } + if err := fixture.checkpointStore.SetResumeCheckpoint(context.Background(), secondResume); err != nil { + t.Fatalf("SetResumeCheckpoint(second) error = %v", err) + } + gotResume, err := fixture.checkpointStore.GetLatestResumeCheckpoint(context.Background(), loaded.ID) + if err != nil { + t.Fatalf("GetLatestResumeCheckpoint() error = %v", err) + } + if gotResume == nil || gotResume.ID != secondResume.ID || gotResume.Turn != secondResume.Turn { + t.Fatalf("GetLatestResumeCheckpoint() = %#v, want %#v", gotResume, secondResume) + } +} + +func TestSQLiteCheckpointStorePruneAndRepair(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_prune", fixture.workspaceRoot) + + createdAt := time.Now().Add(-10 * time.Minute) + for i := 0; i < 4; i++ { + checkpointID := "cp-auto-" + string(rune('a'+i)) + input := checkpointInputFromSession(t, loaded, checkpointID, session.CheckpointReasonPreWrite, createdAt.Add(time.Duration(i)*time.Minute)) + if _, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), input); err != nil { + t.Fatalf("CreateCheckpoint(%s) error = %v", checkpointID, err) + } + } + if _, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession(t, loaded, "cp-manual", session.CheckpointReasonManual, time.Now())); err != nil { + t.Fatalf("CreateCheckpoint(manual) error = %v", err) + } + if _, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession(t, loaded, "cp-guard", session.CheckpointReasonGuard, time.Now().Add(time.Minute))); err != nil { + t.Fatalf("CreateCheckpoint(guard) error = %v", err) + } + + pruned, err := fixture.checkpointStore.PruneExpiredCheckpoints(context.Background(), loaded.ID, 2) + if err != nil { + t.Fatalf("PruneExpiredCheckpoints() error = %v", err) + } + if pruned != 2 { + t.Fatalf("PruneExpiredCheckpoints() = %d, want 2", pruned) + } + + prunedRecord, prunedSessionCP, err := fixture.checkpointStore.GetCheckpoint(context.Background(), "cp-auto-a") + if err != nil { + t.Fatalf("GetCheckpoint(pruned) error = %v", err) + } + if prunedRecord.Status != session.CheckpointStatusPruned || prunedRecord.Restorable { + t.Fatalf("pruned record = %#v, want pruned and not restorable", prunedRecord) + } + if prunedSessionCP != nil { + t.Fatalf("expected pruned session snapshot to be deleted, got %#v", prunedSessionCP) + } + + manualRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), "cp-manual") + if err != nil { + t.Fatalf("GetCheckpoint(manual) error = %v", err) + } + if manualRecord.Status != session.CheckpointStatusAvailable || !manualRecord.Restorable { + t.Fatalf("manual record = %#v, want still available", manualRecord) + } + + db, err := fixture.checkpointStore.ensureDB(context.Background()) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) + } + withSessionCPID := "cp-creating-with-session" + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO session_checkpoints (id, session_id, head_json, messages_json, created_at_ms) +VALUES (?, ?, ?, ?, ?) +`, "sc-creating", loaded.ID, `{}`, `[]`, time.Now().UnixMilli()); err != nil { + t.Fatalf("insert session_checkpoint error = %v", err) + } + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO checkpoint_records ( + id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +`, + withSessionCPID, + session.WorkspacePathKey(loaded.Workdir), + loaded.ID, + "run-repair", + loaded.Workdir, + time.Now().UnixMilli(), + string(session.CheckpointReasonPreWrite), + "", + "sc-creating", + "", + 0, + 1, + string(session.CheckpointStatusCreating), + ); err != nil { + t.Fatalf("insert creating checkpoint with session ref error = %v", err) + } + + orphanID := "cp-creating-orphan" + if _, err := db.ExecContext(context.Background(), ` +INSERT INTO checkpoint_records ( + id, workspace_key, session_id, run_id, workdir, created_at_ms, + reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref, + transcript_revision, restorable, status +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +`, + orphanID, + session.WorkspacePathKey(loaded.Workdir), + loaded.ID, + "run-repair", + loaded.Workdir, + time.Now().UnixMilli(), + string(session.CheckpointReasonPreWrite), + "", + "", + "", + 0, + 1, + string(session.CheckpointStatusCreating), + ); err != nil { + t.Fatalf("insert orphan checkpoint error = %v", err) + } + + repaired, err := fixture.checkpointStore.RepairCreatingCheckpoints(context.Background()) + if err != nil { + t.Fatalf("RepairCreatingCheckpoints() error = %v", err) + } + if repaired != 2 { + t.Fatalf("RepairCreatingCheckpoints() = %d, want 2", repaired) + } + + repairedRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), withSessionCPID) + if err != nil { + t.Fatalf("GetCheckpoint(repaired) error = %v", err) + } + if repairedRecord.Status != session.CheckpointStatusAvailable { + t.Fatalf("repaired record status = %q, want available", repairedRecord.Status) + } + + if _, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), orphanID); err == nil { + t.Fatalf("expected orphan creating checkpoint to be deleted") + } +} + +func TestSQLiteCheckpointStoreUsesSessionDatabasePath(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + expected := filepath.Clean(session.DatabasePath(fixture.baseDir, fixture.workspaceRoot)) + if filepath.Clean(fixture.checkpointStore.dbPath) != expected { + t.Fatalf("dbPath = %q, want %q", fixture.checkpointStore.dbPath, expected) + } +} diff --git a/internal/checkpoint/shadow_repo_test.go b/internal/checkpoint/shadow_repo_test.go new file mode 100644 index 00000000..109ddce5 --- /dev/null +++ b/internal/checkpoint/shadow_repo_test.go @@ -0,0 +1,147 @@ +package checkpoint + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestShadowRepoSnapshotRestoreAndConflictDetection(t *testing.T) { + t.Parallel() + + if available, _ := CheckGitAvailability(context.Background()); !available { + t.Skip("git is not available in test environment") + } + + projectDir := t.TempDir() + workdir := t.TempDir() + repo := NewShadowRepo(projectDir, workdir) + if err := repo.Init(context.Background()); err != nil { + t.Fatalf("Init() error = %v", err) + } + + targetFile := filepath.Join(workdir, "main.go") + if err := os.WriteFile(targetFile, []byte("package main\nconst version = 1\n"), 0o644); err != nil { + t.Fatalf("WriteFile(version1) error = %v", err) + } + + refOne := RefForCheckpoint("session-1", "cp-1") + hashOne, err := repo.Snapshot(context.Background(), refOne, "snapshot one") + if err != nil { + t.Fatalf("Snapshot(first) error = %v", err) + } + if strings.TrimSpace(hashOne) == "" { + t.Fatalf("Snapshot(first) returned empty hash") + } + + if repo.HasCodeChanges(context.Background()) { + t.Fatalf("expected clean worktree after first snapshot") + } + + if err := os.WriteFile(targetFile, []byte("package main\nconst version = 2\n"), 0o644); err != nil { + t.Fatalf("WriteFile(version2) error = %v", err) + } + if !repo.HasCodeChanges(context.Background()) { + t.Fatalf("expected HasCodeChanges() to detect modified file") + } + + refTwo := RefForCheckpoint("session-1", "cp-2") + if _, err := repo.Snapshot(context.Background(), refTwo, "snapshot two"); err != nil { + t.Fatalf("Snapshot(second) error = %v", err) + } + + resolved, err := repo.ResolveRef(context.Background(), refOne) + if err != nil { + t.Fatalf("ResolveRef() error = %v", err) + } + if resolved != hashOne { + t.Fatalf("ResolveRef() = %q, want %q", resolved, hashOne) + } + + conflict, err := repo.DetectConflicts(context.Background(), hashOne) + if err != nil { + t.Fatalf("DetectConflicts() error = %v", err) + } + if !conflict.HasConflict || len(conflict.ModifiedFiles) != 1 || conflict.ModifiedFiles[0] != "main.go" { + t.Fatalf("DetectConflicts() = %#v, want modified main.go", conflict) + } + + if err := repo.Restore(context.Background(), hashOne); err != nil { + t.Fatalf("Restore() error = %v", err) + } + content, err := os.ReadFile(targetFile) + if err != nil { + t.Fatalf("ReadFile(restored) error = %v", err) + } + if !strings.Contains(string(content), "version = 1") { + t.Fatalf("restored content = %q, want version 1", string(content)) + } + + if err := repo.HealthCheck(context.Background()); err != nil { + t.Fatalf("HealthCheck() error = %v", err) + } +} + +func TestShadowRepoInitRebuildsDamagedRepository(t *testing.T) { + t.Parallel() + + if available, _ := CheckGitAvailability(context.Background()); !available { + t.Skip("git is not available in test environment") + } + + projectDir := t.TempDir() + workdir := t.TempDir() + shadowDir := filepath.Join(projectDir, ".shadow") + if err := os.MkdirAll(shadowDir, 0o755); err != nil { + t.Fatalf("MkdirAll(shadowDir) error = %v", err) + } + if err := os.WriteFile(filepath.Join(shadowDir, "corrupted"), []byte("not a git dir"), 0o644); err != nil { + t.Fatalf("WriteFile(corrupted) error = %v", err) + } + + repo := NewShadowRepo(projectDir, workdir) + if err := repo.Init(context.Background()); err != nil { + t.Fatalf("Init() error = %v", err) + } + if !repo.IsAvailable() { + t.Fatalf("expected repo to be available after rebuild") + } + + backups, err := filepath.Glob(shadowDir + ".bak.*") + if err != nil { + t.Fatalf("Glob() error = %v", err) + } + if len(backups) == 0 { + t.Fatalf("expected damaged shadow repo backup to be created") + } + + if err := repo.Rebuild(context.Background()); err != nil { + t.Fatalf("Rebuild() error = %v", err) + } + backups, err = filepath.Glob(shadowDir + ".bak.*") + if err != nil { + t.Fatalf("Glob(after rebuild) error = %v", err) + } + if len(backups) < 2 { + t.Fatalf("expected rebuild to create another backup, got %v", backups) + } +} + +func TestShadowRepoHelpers(t *testing.T) { + t.Parallel() + + ref := RefForCheckpoint("session-a", "checkpoint-b") + if ref != "refs/neocode/sessions/session-a/checkpoints/checkpoint-b" { + t.Fatalf("RefForCheckpoint() = %q", ref) + } + + repo := NewShadowRepo(t.TempDir(), t.TempDir()) + if repo.HasCodeChanges(context.Background()) != true { + t.Fatalf("expected unavailable shadow repo to conservatively report code changes") + } + if err := repo.DeleteRef(context.Background(), "refs/unused"); err != nil { + t.Fatalf("DeleteRef() on unavailable repo error = %v", err) + } +} diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index b936544d..893279bc 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "neo-code/internal/checkpoint" "neo-code/internal/config" configstate "neo-code/internal/config/state" "neo-code/internal/gateway" @@ -63,6 +64,16 @@ type runtimeStub struct { getSnapshotSessionID string getSnapshotResult agentruntime.RuntimeSnapshot getSnapshotErr error + listCheckpointsID string + listCheckpointsOpts checkpoint.ListCheckpointOpts + listCheckpointsResult []agentsession.CheckpointRecord + listCheckpointsErr error + restoreCheckpointIn agentruntime.GatewayRestoreInput + restoreCheckpointOut agentruntime.RestoreResult + restoreCheckpointErr error + undoRestoreSessionID string + undoRestoreOut agentruntime.RestoreResult + undoRestoreErr error } const testBridgeSubjectID = bridgeLocalSubjectID @@ -150,6 +161,19 @@ func (s *runtimeStub) GetRuntimeSnapshot(_ context.Context, sessionID string) (a s.getSnapshotSessionID = sessionID return s.getSnapshotResult, s.getSnapshotErr } +func (s *runtimeStub) ListCheckpoints(_ context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + s.listCheckpointsID = sessionID + s.listCheckpointsOpts = opts + return s.listCheckpointsResult, s.listCheckpointsErr +} +func (s *runtimeStub) RestoreCheckpoint(_ context.Context, input agentruntime.GatewayRestoreInput) (agentruntime.RestoreResult, error) { + s.restoreCheckpointIn = input + return s.restoreCheckpointOut, s.restoreCheckpointErr +} +func (s *runtimeStub) UndoRestoreCheckpoint(_ context.Context, sessionID string) (agentruntime.RestoreResult, error) { + s.undoRestoreSessionID = sessionID + return s.undoRestoreOut, s.undoRestoreErr +} func (s *runtimeStub) DeleteSession(_ context.Context, _ string) error { return nil } @@ -207,6 +231,15 @@ func (r *runtimeWithoutCreator) ListAvailableSkills( ) ([]agentruntime.AvailableSkillState, error) { return r.base.ListAvailableSkills(ctx, sessionID) } +func (r *runtimeWithoutCreator) ListCheckpoints(ctx context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + return r.base.ListCheckpoints(ctx, sessionID, opts) +} +func (r *runtimeWithoutCreator) RestoreCheckpoint(ctx context.Context, input agentruntime.GatewayRestoreInput) (agentruntime.RestoreResult, error) { + return r.base.RestoreCheckpoint(ctx, input) +} +func (r *runtimeWithoutCreator) UndoRestoreCheckpoint(ctx context.Context, sessionID string) (agentruntime.RestoreResult, error) { + return r.base.UndoRestoreCheckpoint(ctx, sessionID) +} type bridgeSessionStoreStub struct { deleteFn func(ctx context.Context, id string) error @@ -226,6 +259,72 @@ func (s *bridgeSessionStoreStub) UpdateSessionState(ctx context.Context, input a return nil } +func TestGatewayRuntimePortBridgeCheckpointOperations(t *testing.T) { + stub := &runtimeStub{ + listCheckpointsResult: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-1", + SessionID: "session-1", + Reason: agentsession.CheckpointReasonCompact, + Status: agentsession.CheckpointStatusAvailable, + Restorable: true, + CreatedAt: time.UnixMilli(1234), + }, + }, + restoreCheckpointOut: agentruntime.RestoreResult{ + CheckpointID: "cp-1", + SessionID: "session-1", + }, + undoRestoreOut: agentruntime.RestoreResult{ + CheckpointID: "guard-1", + SessionID: "session-1", + }, + } + + bridge := &gatewayRuntimePortBridge{runtime: stub} + + entries, err := bridge.ListCheckpoints(context.Background(), gateway.ListCheckpointsInput{ + SessionID: " session-1 ", + Limit: 5, + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if stub.listCheckpointsID != "session-1" || stub.listCheckpointsOpts.Limit != 5 || !stub.listCheckpointsOpts.RestorableOnly { + t.Fatalf("ListCheckpoints() forwarded (%q, %#v)", stub.listCheckpointsID, stub.listCheckpointsOpts) + } + if len(entries) != 1 || entries[0].CheckpointID != "cp-1" || entries[0].Reason != string(agentsession.CheckpointReasonCompact) { + t.Fatalf("ListCheckpoints() = %#v", entries) + } + + restoreResult, err := bridge.RestoreCheckpoint(context.Background(), gateway.CheckpointRestoreInput{ + SessionID: " session-1 ", + CheckpointID: " cp-1 ", + Force: true, + }) + if err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + if stub.restoreCheckpointIn.SessionID != "session-1" || stub.restoreCheckpointIn.CheckpointID != "cp-1" || !stub.restoreCheckpointIn.Force { + t.Fatalf("RestoreCheckpoint() forwarded %#v", stub.restoreCheckpointIn) + } + if restoreResult.CheckpointID != "cp-1" || restoreResult.SessionID != "session-1" || restoreResult.HasConflict { + t.Fatalf("RestoreCheckpoint() = %#v", restoreResult) + } + + undoResult, err := bridge.UndoRestore(context.Background(), gateway.UndoRestoreInput{SessionID: " session-1 "}) + if err != nil { + t.Fatalf("UndoRestore() error = %v", err) + } + if stub.undoRestoreSessionID != "session-1" { + t.Fatalf("UndoRestore() forwarded session %q", stub.undoRestoreSessionID) + } + if undoResult.CheckpointID != "guard-1" || undoResult.SessionID != "session-1" { + t.Fatalf("UndoRestore() = %#v", undoResult) + } +} + var testSessionStore bridgeSessionStore = &bridgeSessionStoreStub{} func TestNewGatewayRuntimePortBridgeRuntimeUnavailable(t *testing.T) { @@ -1351,7 +1450,7 @@ func TestResolveListFilesRootPriorities(t *testing.T) { // priority 2: session workdir (store implements bridgeSessionLoader) loaderStore := &bridgeSessionStoreWithLoader{ bridgeSessionStoreStub: bridgeSessionStoreStub{}, - session: agentsession.Session{Workdir: subDir}, + session: agentsession.Session{Workdir: subDir}, } bridge2, _ := newGatewayRuntimePortBridge(context.Background(), &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, loaderStore) defer bridge2.Close() @@ -1958,9 +2057,9 @@ func TestModelDisplayName(t *testing.T) { func TestGatewayRuntimePortBridgeCancelRunAndSnapshots(t *testing.T) { stub := &runtimeStub{ - eventsCh: make(chan agentruntime.RuntimeEvent, 1), - cancelReturn: true, - listTodosErr: errors.New("todo failed"), + eventsCh: make(chan agentruntime.RuntimeEvent, 1), + cancelReturn: true, + listTodosErr: errors.New("todo failed"), getSnapshotErr: errors.New("snapshot failed"), } bridge, err := newGatewayRuntimePortBridge(context.Background(), stub, testSessionStore) diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go new file mode 100644 index 00000000..d9e81a57 --- /dev/null +++ b/internal/runtime/checkpoint_flow_test.go @@ -0,0 +1,404 @@ +package runtime + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "neo-code/internal/checkpoint" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +type checkpointStoreSpy struct { + lastResume agentsession.ResumeCheckpoint +} + +func (s *checkpointStoreSpy) CreateCheckpoint(context.Context, checkpoint.CreateCheckpointInput) (agentsession.CheckpointRecord, error) { + return agentsession.CheckpointRecord{}, nil +} + +func (s *checkpointStoreSpy) ListCheckpoints(context.Context, string, checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + return nil, nil +} + +func (s *checkpointStoreSpy) GetCheckpoint(context.Context, string) (agentsession.CheckpointRecord, *agentsession.SessionCheckpoint, error) { + return agentsession.CheckpointRecord{}, nil, nil +} + +func (s *checkpointStoreSpy) UpdateCheckpointStatus(context.Context, string, agentsession.CheckpointStatus) error { + return nil +} + +func (s *checkpointStoreSpy) GetLatestResumeCheckpoint(context.Context, string) (*agentsession.ResumeCheckpoint, error) { + return nil, nil +} + +func (s *checkpointStoreSpy) RestoreCheckpoint(context.Context, checkpoint.RestoreCheckpointInput) error { + return nil +} + +func (s *checkpointStoreSpy) SetResumeCheckpoint(_ context.Context, rc agentsession.ResumeCheckpoint) error { + s.lastResume = rc + return nil +} + +func (s *checkpointStoreSpy) PruneExpiredCheckpoints(context.Context, string, int) (int, error) { + return 0, nil +} + +func (s *checkpointStoreSpy) RepairCreatingCheckpoints(context.Context) (int, error) { + return 0, nil +} + +type runtimeCheckpointFixture struct { + service *Service + sessionStore *agentsession.SQLiteStore + checkpointStore *checkpoint.SQLiteCheckpointStore + shadowRepo *checkpoint.ShadowRepo + workdir string + projectDir string + session agentsession.Session +} + +func newRuntimeCheckpointFixture(t *testing.T, withShadow bool) runtimeCheckpointFixture { + t.Helper() + + baseDir := t.TempDir() + workdir := t.TempDir() + projectDir := t.TempDir() + + sessionStore := agentsession.NewSQLiteStore(baseDir, workdir) + t.Cleanup(func() { + _ = sessionStore.Close() + }) + + checkpointStore := checkpoint.NewSQLiteCheckpointStore(agentsession.DatabasePath(baseDir, workdir)) + t.Cleanup(func() { + _ = checkpointStore.Close() + }) + + created, err := sessionStore.CreateSession(context.Background(), agentsession.CreateSessionInput{ + ID: "runtime-checkpoint-session", + Title: "runtime checkpoint", + Head: agentsession.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + TaskState: agentsession.TaskState{ + Goal: "initial goal", + VerificationProfile: agentsession.VerificationProfileTaskOnly, + }, + }, + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + if err := sessionStore.AppendMessages(context.Background(), agentsession.AppendMessagesInput{ + SessionID: created.ID, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("before restore"), + }, + }, + }, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: workdir, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) + } + loaded, err := sessionStore.LoadSession(context.Background(), created.ID) + if err != nil { + t.Fatalf("LoadSession() error = %v", err) + } + + var shadowRepo *checkpoint.ShadowRepo + if withShadow { + shadowRepo = checkpoint.NewShadowRepo(projectDir, workdir) + if err := shadowRepo.Init(context.Background()); err != nil { + t.Fatalf("Init shadow repo error = %v", err) + } + } + + return runtimeCheckpointFixture{ + service: &Service{ + sessionStore: sessionStore, + checkpointStore: checkpointStore, + shadowRepo: shadowRepo, + events: make(chan RuntimeEvent, 32), + }, + sessionStore: sessionStore, + checkpointStore: checkpointStore, + shadowRepo: shadowRepo, + workdir: workdir, + projectDir: projectDir, + session: loaded, + } +} + +func createStoredCheckpointFromSession( + t *testing.T, + cpStore *checkpoint.SQLiteCheckpointStore, + shadowRepo *checkpoint.ShadowRepo, + loaded agentsession.Session, + checkpointID string, +) agentsession.CheckpointRecord { + t.Helper() + + headJSON, err := json.Marshal(loaded.HeadSnapshot()) + if err != nil { + t.Fatalf("Marshal(head) error = %v", err) + } + messagesJSON, err := json.Marshal(loaded.Messages) + if err != nil { + t.Fatalf("Marshal(messages) error = %v", err) + } + + ref := checkpoint.RefForCheckpoint(loaded.ID, checkpointID) + if _, err := shadowRepo.Snapshot(context.Background(), ref, checkpointID); err != nil { + t.Fatalf("Snapshot(%s) error = %v", checkpointID, err) + } + + record, err := cpStore.CreateCheckpoint(context.Background(), checkpoint.CreateCheckpointInput{ + Record: agentsession.CheckpointRecord{ + CheckpointID: checkpointID, + WorkspaceKey: agentsession.WorkspacePathKey(loaded.Workdir), + SessionID: loaded.ID, + RunID: "run-" + checkpointID, + Workdir: loaded.Workdir, + CreatedAt: time.Now().Add(-time.Minute), + Reason: agentsession.CheckpointReasonPreWrite, + CodeCheckpointRef: ref, + Restorable: true, + Status: agentsession.CheckpointStatusCreating, + }, + SessionCP: agentsession.SessionCheckpoint{ + ID: "sc-" + checkpointID, + SessionID: loaded.ID, + HeadJSON: string(headJSON), + MessagesJSON: string(messagesJSON), + CreatedAt: time.Now().Add(-time.Minute), + }, + }) + if err != nil { + t.Fatalf("CreateCheckpoint(%s) error = %v", checkpointID, err) + } + return record +} + +func TestCreatePerTurnCheckpointVariants(t *testing.T) { + t.Run("full checkpoint when code changed", func(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t, true) + target := fixture.workdir + "/main.go" + if err := os.WriteFile(target, []byte("package main\nconst value = 1\n"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if _, err := fixture.shadowRepo.Snapshot(context.Background(), "refs/heads/base", "baseline"); err != nil { + t.Fatalf("Snapshot(baseline) error = %v", err) + } + if err := os.WriteFile(target, []byte("package main\nconst value = 2\n"), 0o644); err != nil { + t.Fatalf("WriteFile(modified) error = %v", err) + } + + state := newRunState("run-full", fixture.session) + if err := fixture.service.createPerTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createPerTurnCheckpoint() error = %v", err) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].Reason != agentsession.CheckpointReasonPreWrite || records[0].CodeCheckpointRef == "" { + t.Fatalf("records = %#v, want one full checkpoint", records) + } + }) + + t.Run("degraded checkpoint when repo unavailable", func(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t, false) + state := newRunState("run-degraded", fixture.session) + if err := fixture.service.createPerTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createPerTurnCheckpoint() error = %v", err) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].Reason != agentsession.CheckpointReasonPreWriteDegraded || records[0].CodeCheckpointRef != "" { + t.Fatalf("records = %#v, want one degraded checkpoint", records) + } + }) + + t.Run("degraded checkpoint when no code changes", func(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t, true) + target := fixture.workdir + "/main.go" + if err := os.WriteFile(target, []byte("package main\nconst value = 1\n"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if _, err := fixture.shadowRepo.Snapshot(context.Background(), "refs/heads/base", "baseline"); err != nil { + t.Fatalf("Snapshot(baseline) error = %v", err) + } + + state := newRunState("run-noop", fixture.session) + if err := fixture.service.createPerTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createPerTurnCheckpoint() error = %v", err) + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].Reason != agentsession.CheckpointReasonPreWriteDegraded || records[0].CodeCheckpointRef != "" { + t.Fatalf("records = %#v, want session-only checkpoint for no-op turn", records) + } + }) +} + +func TestCreateCompactCheckpointAndResumeCheckpoint(t *testing.T) { + t.Parallel() + + fixture := newRuntimeCheckpointFixture(t, true) + if err := os.WriteFile(fixture.workdir+"/compact.txt", []byte("compact"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + fixture.service.createCompactCheckpoint(context.Background(), "run-compact", fixture.session) + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 || records[0].Reason != agentsession.CheckpointReasonCompact { + t.Fatalf("records = %#v, want compact checkpoint", records) + } + + state := newRunState("run-resume", fixture.session) + state.turn = 3 + spy := &checkpointStoreSpy{} + service := &Service{checkpointStore: spy} + service.updateResumeCheckpoint(context.Background(), &state, "verify", "running") + + if spy.lastResume.SessionID != fixture.session.ID || spy.lastResume.RunID != "run-resume" || spy.lastResume.Turn != 3 || spy.lastResume.Phase != "verify" { + t.Fatalf("SetResumeCheckpoint() captured %#v", spy.lastResume) + } +} + +func TestRestoreCheckpointAndUndoRestore(t *testing.T) { + t.Parallel() + + fixture := newRuntimeCheckpointFixture(t, true) + target := fixture.workdir + "/restore.txt" + if err := os.WriteFile(target, []byte("version one"), 0o644); err != nil { + t.Fatalf("WriteFile(version one) error = %v", err) + } + + originalSession, err := fixture.sessionStore.LoadSession(context.Background(), fixture.session.ID) + if err != nil { + t.Fatalf("LoadSession(original) error = %v", err) + } + record := createStoredCheckpointFromSession(t, fixture.checkpointStore, fixture.shadowRepo, originalSession, "cp-restore") + + if err := os.WriteFile(target, []byte("version two"), 0o644); err != nil { + t.Fatalf("WriteFile(version two) error = %v", err) + } + if err := fixture.sessionStore.UpdateSessionState(context.Background(), agentsession.UpdateSessionStateInput{ + SessionID: originalSession.ID, + UpdatedAt: time.Now(), + Title: "mutated", + Head: agentsession.SessionHead{ + Provider: "openai", + Model: "gpt-test", + Workdir: fixture.workdir, + TaskState: agentsession.TaskState{ + Goal: "mutated goal", + VerificationProfile: agentsession.VerificationProfileTaskOnly, + }, + }, + }); err != nil { + t.Fatalf("UpdateSessionState() error = %v", err) + } + if err := fixture.sessionStore.AppendMessages(context.Background(), agentsession.AppendMessagesInput{ + SessionID: originalSession.ID, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("after snapshot"), + }, + }, + }, + UpdatedAt: time.Now(), + Provider: "openai", + Model: "gpt-test", + Workdir: fixture.workdir, + }); err != nil { + t.Fatalf("AppendMessages(mutated) error = %v", err) + } + + conflictResult, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: originalSession.ID, + CheckpointID: record.CheckpointID, + }) + if err == nil || conflictResult.Conflict == nil || !conflictResult.Conflict.HasConflict { + t.Fatalf("RestoreCheckpoint(conflict) = (%#v, %v), want conflict error", conflictResult, err) + } + + restoreResult, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: originalSession.ID, + CheckpointID: record.CheckpointID, + Force: true, + }) + if err != nil { + t.Fatalf("RestoreCheckpoint(force) error = %v", err) + } + if restoreResult.CheckpointID != record.CheckpointID { + t.Fatalf("RestoreCheckpoint(force) = %#v", restoreResult) + } + + restoredContent, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile(restored) error = %v", err) + } + if string(restoredContent) != "version one" { + t.Fatalf("restored content = %q, want version one", string(restoredContent)) + } + + restoredSession, err := fixture.sessionStore.LoadSession(context.Background(), originalSession.ID) + if err != nil { + t.Fatalf("LoadSession(restored) error = %v", err) + } + if restoredSession.TaskState.Goal != originalSession.TaskState.Goal || len(restoredSession.Messages) != len(originalSession.Messages) { + t.Fatalf("restored session = %#v, want original goal/messages", restoredSession) + } + + undoResult, err := fixture.service.UndoRestoreCheckpoint(context.Background(), originalSession.ID) + if err != nil { + t.Fatalf("UndoRestoreCheckpoint() error = %v", err) + } + if undoResult.SessionID != originalSession.ID { + t.Fatalf("UndoRestoreCheckpoint() = %#v", undoResult) + } + + undoneContent, err := os.ReadFile(target) + if err != nil { + t.Fatalf("ReadFile(undone) error = %v", err) + } + if string(undoneContent) != "version two" { + t.Fatalf("undone content = %q, want version two", string(undoneContent)) + } + + undoneSession, err := fixture.sessionStore.LoadSession(context.Background(), originalSession.ID) + if err != nil { + t.Fatalf("LoadSession(undone) error = %v", err) + } + if undoneSession.TaskState.Goal != "mutated goal" || len(undoneSession.Messages) != len(originalSession.Messages)+1 { + t.Fatalf("undone session = %#v, want mutated session restored", undoneSession) + } +} From bc5070864a05583da3cdc042992027bb6dc3d007 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Sun, 3 May 2026 16:07:54 +0000 Subject: [PATCH 7/8] =?UTF-8?q?test(checkpoint):=20=E8=A1=A5=E5=85=85=20ch?= =?UTF-8?q?eckpoint=20=E8=A6=86=E7=9B=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- .../checkpoint/checkpoint_manager_test.go | 80 ++++++++++ internal/gateway/bootstrap_test.go | 138 +++++++++++++++++- internal/runtime/checkpoint_flow_test.go | 58 +++++++- 3 files changed, 270 insertions(+), 6 deletions(-) diff --git a/internal/checkpoint/checkpoint_manager_test.go b/internal/checkpoint/checkpoint_manager_test.go index b8986d2f..5d2110bb 100644 --- a/internal/checkpoint/checkpoint_manager_test.go +++ b/internal/checkpoint/checkpoint_manager_test.go @@ -2,8 +2,10 @@ package checkpoint import ( "context" + "database/sql" "encoding/json" "path/filepath" + "strings" "testing" "time" @@ -415,3 +417,81 @@ func TestSQLiteCheckpointStoreUsesSessionDatabasePath(t *testing.T) { t.Fatalf("dbPath = %q, want %q", fixture.checkpointStore.dbPath, expected) } } + +func TestSQLiteCheckpointStoreSharedDBAndHelpers(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_shared_db", fixture.workspaceRoot) + db, err := fixture.checkpointStore.ensureDB(context.Background()) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) + } + + shared := NewSQLiteCheckpointStoreWithDB(db) + if shared.ownsDB { + t.Fatal("shared checkpoint store should not own injected db") + } + if err := shared.Close(); err != nil { + t.Fatalf("Close(shared) error = %v", err) + } + if err := db.PingContext(context.Background()); err != nil { + t.Fatalf("db should remain open after shared Close(), got %v", err) + } + if _, err := shared.ListCheckpoints(context.Background(), loaded.ID, ListCheckpointOpts{}); err != nil { + t.Fatalf("shared ListCheckpoints() error = %v", err) + } + + if got := marshalPlanField(nil); got != "" { + t.Fatalf("marshalPlanField(nil) = %q, want empty", got) + } + var nilPlan *session.PlanArtifact + if got := marshalPlanField(nilPlan); got != "" { + t.Fatalf("marshalPlanField(nil pointer) = %q, want empty", got) + } + if got := marshalPlanField(map[string]any{"step": "verify"}); !strings.Contains(got, `"step":"verify"`) { + t.Fatalf("marshalPlanField(map) = %q", got) + } + if got := marshalPlanField(func() {}); got != "" { + t.Fatalf("marshalPlanField(unmarshalable) = %q, want empty", got) + } + if got := marshalHeadField(func() {}); got != "null" { + t.Fatalf("marshalHeadField(unmarshalable) = %q, want null", got) + } +} + +func TestSQLiteCheckpointStoreErrorsAndEmptyResults(t *testing.T) { + t.Parallel() + + fixture := newCheckpointStoreFixture(t) + loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_empty_resume", fixture.workspaceRoot) + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), "missing", session.CheckpointStatusAvailable); err == nil { + t.Fatal("expected UpdateCheckpointStatus() to fail for missing checkpoint") + } + + rc, err := fixture.checkpointStore.GetLatestResumeCheckpoint(context.Background(), loaded.ID) + if err != nil { + t.Fatalf("GetLatestResumeCheckpoint(missing) error = %v", err) + } + if rc != nil { + t.Fatalf("GetLatestResumeCheckpoint(missing) = %#v, want nil", rc) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := fixture.checkpointStore.ensureDB(context.Background()); err != nil { + t.Fatalf("ensureDB() error = %v", err) + } + if _, err := fixture.checkpointStore.CreateCheckpoint(ctx, CreateCheckpointInput{}); err == nil { + t.Fatal("expected CreateCheckpoint() to honor canceled context") + } +} + +func TestNewSQLiteCheckpointStoreWithNilDBClose(t *testing.T) { + t.Parallel() + + store := NewSQLiteCheckpointStoreWithDB((*sql.DB)(nil)) + if err := store.Close(); err != nil { + t.Fatalf("Close(nil db) error = %v", err) + } +} diff --git a/internal/gateway/bootstrap_test.go b/internal/gateway/bootstrap_test.go index 2a508986..620924b7 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -45,6 +45,9 @@ type bootstrapRuntimeStub struct { upsertMCPServerFn func(ctx context.Context, input UpsertMCPServerInput) error setMCPEnabledFn func(ctx context.Context, input SetMCPServerEnabledInput) error deleteMCPServerFn func(ctx context.Context, input DeleteMCPServerInput) error + listCheckpointsFn func(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) + restoreCheckpointFn func(ctx context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error) + undoRestoreFn func(ctx context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) } func (s *bootstrapRuntimeStub) Run(ctx context.Context, input RunInput) error { @@ -249,15 +252,24 @@ func (s *bootstrapRuntimeStub) CreateSession(ctx context.Context, input CreateSe return strings.TrimSpace(input.SessionID), nil } -func (s *bootstrapRuntimeStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { +func (s *bootstrapRuntimeStub) ListCheckpoints(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) { + if s != nil && s.listCheckpointsFn != nil { + return s.listCheckpointsFn(ctx, input) + } return nil, nil } -func (s *bootstrapRuntimeStub) RestoreCheckpoint(_ context.Context, _ CheckpointRestoreInput) (CheckpointRestoreResult, error) { +func (s *bootstrapRuntimeStub) RestoreCheckpoint(ctx context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error) { + if s != nil && s.restoreCheckpointFn != nil { + return s.restoreCheckpointFn(ctx, input) + } return CheckpointRestoreResult{}, nil } -func (s *bootstrapRuntimeStub) UndoRestore(_ context.Context, _ UndoRestoreInput) (CheckpointRestoreResult, error) { +func (s *bootstrapRuntimeStub) UndoRestore(ctx context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) { + if s != nil && s.undoRestoreFn != nil { + return s.undoRestoreFn(ctx, input) + } return CheckpointRestoreResult{}, nil } @@ -431,6 +443,126 @@ func TestDecodeSessionSkillAndSnapshotPayloadBranches(t *testing.T) { } } +func TestCheckpointFrameHandlers(t *testing.T) { + t.Run("list checkpoints success", func(t *testing.T) { + runtime := &bootstrapRuntimeStub{ + listCheckpointsFn: func(_ context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) { + if input.SubjectID != "subject-1" || input.SessionID != "session-1" { + t.Fatalf("input = %#v", input) + } + return []CheckpointEntry{{CheckpointID: "cp-1", SessionID: "session-1"}}, nil + }, + } + authState := NewConnectionAuthState() + authState.MarkAuthenticated("subject-1") + ctx := WithConnectionAuthState(context.Background(), authState) + + response := handleListCheckpointsFrame(ctx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionListCheckpoints, + RequestID: "req-checkpoint-list", + SessionID: " session-1 ", + }, runtime) + + if response.Type != FrameTypeAck || response.Action != FrameActionListCheckpoints { + t.Fatalf("response = %#v", response) + } + entries, ok := response.Payload.([]CheckpointEntry) + if !ok || len(entries) != 1 || entries[0].CheckpointID != "cp-1" { + t.Fatalf("payload = %#v", response.Payload) + } + }) + + t.Run("restore checkpoint success", func(t *testing.T) { + runtime := &bootstrapRuntimeStub{ + restoreCheckpointFn: func(_ context.Context, input CheckpointRestoreInput) (CheckpointRestoreResult, error) { + if input.SubjectID != "subject-1" || input.SessionID != "session-1" || input.CheckpointID != "cp-1" || !input.Force { + t.Fatalf("input = %#v", input) + } + return CheckpointRestoreResult{CheckpointID: input.CheckpointID, SessionID: input.SessionID}, nil + }, + } + authState := NewConnectionAuthState() + authState.MarkAuthenticated("subject-1") + ctx := WithConnectionAuthState(context.Background(), authState) + + response := handleRestoreCheckpointFrame(ctx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRestoreCheckpoint, + RequestID: "req-checkpoint-restore", + SessionID: " session-1 ", + Payload: map[string]any{ + "checkpoint_id": " cp-1 ", + "force": true, + }, + }, runtime) + + if response.Type != FrameTypeAck || response.Action != FrameActionRestoreCheckpoint || response.SessionID != "session-1" { + t.Fatalf("response = %#v", response) + } + result, ok := response.Payload.(CheckpointRestoreResult) + if !ok || result.CheckpointID != "cp-1" { + t.Fatalf("payload = %#v", response.Payload) + } + }) + + t.Run("undo restore success", func(t *testing.T) { + runtime := &bootstrapRuntimeStub{ + undoRestoreFn: func(_ context.Context, input UndoRestoreInput) (CheckpointRestoreResult, error) { + if input.SubjectID != "subject-1" || input.SessionID != "session-1" { + t.Fatalf("input = %#v", input) + } + return CheckpointRestoreResult{CheckpointID: "cp-guard", SessionID: input.SessionID}, nil + }, + } + authState := NewConnectionAuthState() + authState.MarkAuthenticated("subject-1") + ctx := WithConnectionAuthState(context.Background(), authState) + + response := handleUndoRestoreFrame(ctx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionUndoRestore, + RequestID: "req-checkpoint-undo", + SessionID: " session-1 ", + }, runtime) + + if response.Type != FrameTypeAck || response.Action != FrameActionUndoRestore || response.SessionID != "session-1" { + t.Fatalf("response = %#v", response) + } + result, ok := response.Payload.(CheckpointRestoreResult) + if !ok || result.CheckpointID != "cp-guard" { + t.Fatalf("payload = %#v", response.Payload) + } + }) +} + +func TestDecodeCheckpointRestorePayloadBranches(t *testing.T) { + t.Parallel() + + params := decodeCheckpointRestorePayload(map[string]any{ + "session_id": " session-1 ", + "checkpoint_id": " cp-1 ", + "force": true, + }) + if params.SessionID != "session-1" || params.CheckpointID != "cp-1" || !params.Force { + t.Fatalf("decode map payload = %#v", params) + } + + params = decodeCheckpointRestorePayload(CheckpointRestoreInput{ + SessionID: "session-2", + CheckpointID: "cp-2", + Force: true, + }) + if params.SessionID != "session-2" || params.CheckpointID != "cp-2" || !params.Force { + t.Fatalf("decode struct payload = %#v", params) + } + + params = decodeCheckpointRestorePayload(invalidJSONMarshaler{}) + if params != (CheckpointRestoreInput{}) { + t.Fatalf("marshal failure should return zero input, got %#v", params) + } +} + func TestDispatchRequestFrameWakeOpenURLReviewSuccess(t *testing.T) { createInputs := make(chan CreateSessionInput, 1) stub := &bootstrapRuntimeStub{ diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go index d9e81a57..4ebe81f3 100644 --- a/internal/runtime/checkpoint_flow_test.go +++ b/internal/runtime/checkpoint_flow_test.go @@ -13,15 +13,21 @@ import ( ) type checkpointStoreSpy struct { - lastResume agentsession.ResumeCheckpoint + lastResume agentsession.ResumeCheckpoint + listRecords []agentsession.CheckpointRecord + listSessionID string + listOpts checkpoint.ListCheckpointOpts + listErr error } func (s *checkpointStoreSpy) CreateCheckpoint(context.Context, checkpoint.CreateCheckpointInput) (agentsession.CheckpointRecord, error) { return agentsession.CheckpointRecord{}, nil } -func (s *checkpointStoreSpy) ListCheckpoints(context.Context, string, checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { - return nil, nil +func (s *checkpointStoreSpy) ListCheckpoints(_ context.Context, sessionID string, opts checkpoint.ListCheckpointOpts) ([]agentsession.CheckpointRecord, error) { + s.listSessionID = sessionID + s.listOpts = opts + return s.listRecords, s.listErr } func (s *checkpointStoreSpy) GetCheckpoint(context.Context, string) (agentsession.CheckpointRecord, *agentsession.SessionCheckpoint, error) { @@ -290,6 +296,52 @@ func TestCreateCompactCheckpointAndResumeCheckpoint(t *testing.T) { } } +func TestRuntimeCheckpointFacadeMethods(t *testing.T) { + t.Run("list checkpoints delegates to store", func(t *testing.T) { + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{{CheckpointID: "cp-1"}}, + } + service := &Service{checkpointStore: spy} + + records, err := service.ListCheckpoints(context.Background(), "session-1", checkpoint.ListCheckpointOpts{ + Limit: 5, + RestorableOnly: true, + }) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if spy.listSessionID != "session-1" || spy.listOpts.Limit != 5 || !spy.listOpts.RestorableOnly { + t.Fatalf("spy captured session=%q opts=%#v", spy.listSessionID, spy.listOpts) + } + if len(records) != 1 || records[0].CheckpointID != "cp-1" { + t.Fatalf("records = %#v", records) + } + }) + + t.Run("list checkpoints reports unavailable store", func(t *testing.T) { + service := &Service{} + if _, err := service.ListCheckpoints(context.Background(), "session-1", checkpoint.ListCheckpointOpts{}); err == nil { + t.Fatal("expected error when checkpoint store is unavailable") + } + }) + + t.Run("set checkpoint dependencies stores references", func(t *testing.T) { + service := &Service{} + store := &checkpointStoreSpy{} + repo := checkpoint.NewShadowRepo(t.TempDir(), t.TempDir()) + + service.SetCheckpointDependencies(store, repo) + if service.checkpointStore != store || service.shadowRepo != repo { + t.Fatalf("service checkpoint dependencies not set correctly") + } + }) + + t.Run("update runtime session after restore is no-op", func(t *testing.T) { + service := &Service{} + service.updateRuntimeSessionAfterRestore("session-1", agentsession.SessionHead{}, nil) + }) +} + func TestRestoreCheckpointAndUndoRestore(t *testing.T) { t.Parallel() From 95573f5baa52ea0c3cf19e6b56e1bf3704364296 Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Mon, 4 May 2026 00:56:09 +0800 Subject: [PATCH 8/8] =?UTF-8?q?pref(checkpoint):=E4=BF=AE=E5=A4=8Duntracke?= =?UTF-8?q?d=E6=96=87=E4=BB=B6=E6=9C=AA=E8=A2=AB=E8=BF=BD=E8=B8=AA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/checkpoint/shadow_repo.go | 26 ++++++++++++++++-------- internal/runtime/checkpoint_flow_test.go | 11 ++++++++-- internal/runtime/checkpoint_restore.go | 11 ++++++++-- internal/runtime/run.go | 10 ++++++++- 4 files changed, 45 insertions(+), 13 deletions(-) diff --git a/internal/checkpoint/shadow_repo.go b/internal/checkpoint/shadow_repo.go index 65862b72..45eb3d3d 100644 --- a/internal/checkpoint/shadow_repo.go +++ b/internal/checkpoint/shadow_repo.go @@ -202,8 +202,8 @@ func (r *ShadowRepo) DeleteRef(ctx context.Context, ref string) error { return nil } -// HasCodeChanges 检查工作区是否有未提交的代码变更。 -// 使用 git diff --quiet HEAD -- . 检测,退出码 1 表示有变更。 +// HasCodeChanges 检查工作区是否有未提交的代码变更,包括未跟踪文件。 +// 先用 git diff --quiet HEAD -- . 检测已跟踪文件变更,再用 ls-files 检测未跟踪文件。 // git 不可用时返回 true(保守策略)。 func (r *ShadowRepo) HasCodeChanges(ctx context.Context) bool { r.mu.Lock() @@ -213,17 +213,27 @@ func (r *ShadowRepo) HasCodeChanges(ctx context.Context) bool { return true } - ctx, cancel := context.WithTimeout(ctx, gitCommandTimeout) - defer cancel() - - cmd := r.buildGitCommand(ctx, "diff", "--quiet", "HEAD", "--", ".") + // 检测已跟踪文件变更 + diffCtx, diffCancel := context.WithTimeout(ctx, gitCommandTimeout) + defer diffCancel() + cmd := r.buildGitCommand(diffCtx, "diff", "--quiet", "HEAD", "--", ".") if err := cmd.Run(); err != nil { if exitErr, ok := err.(*exec.ExitError); ok { - return exitErr.ExitCode() == 1 + if exitErr.ExitCode() == 1 { + return true + } } return true } - return false + + // 检测未跟踪文件(排除 .gitignore 忽略的文件) + lsCtx, lsCancel := context.WithTimeout(ctx, gitCommandTimeout) + defer lsCancel() + out, err := r.gitOutput(lsCtx, "ls-files", "--others", "--exclude-standard") + if err != nil { + return true + } + return strings.TrimSpace(out) != "" } // HealthCheck 验证 bare 仓库存在且可执行 git 操作。 diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go index 4ebe81f3..f8c95933 100644 --- a/internal/runtime/checkpoint_flow_test.go +++ b/internal/runtime/checkpoint_flow_test.go @@ -336,9 +336,16 @@ func TestRuntimeCheckpointFacadeMethods(t *testing.T) { } }) - t.Run("update runtime session after restore is no-op", func(t *testing.T) { - service := &Service{} + t.Run("update runtime session after restore invalidates cache", func(t *testing.T) { + service := &Service{ + runtimeSnapshots: map[string]RuntimeSnapshot{ + "session-1": {SessionID: "session-1", Phase: "execute"}, + }, + } service.updateRuntimeSessionAfterRestore("session-1", agentsession.SessionHead{}, nil) + if _, ok := service.runtimeSnapshots["session-1"]; ok { + t.Fatal("expected cached snapshot to be deleted after restore") + } }) } diff --git a/internal/runtime/checkpoint_restore.go b/internal/runtime/checkpoint_restore.go index a3451fdf..df1971e9 100644 --- a/internal/runtime/checkpoint_restore.go +++ b/internal/runtime/checkpoint_restore.go @@ -267,7 +267,14 @@ func (s *Service) ListCheckpoints(ctx context.Context, sessionID string, opts ch return s.checkpointStore.ListCheckpoints(ctx, sessionID, opts) } -// updateRuntimeSessionAfterRestore 在 restore 后更新运行时会话状态。 -// 当前实现不做运行时状态直接修改,依赖下次 session 加载时从 DB 读取恢复后的状态。 +// updateRuntimeSessionAfterRestore 使运行时快照缓存失效。 +// GetRuntimeSnapshot 会从 DB 重新加载恢复后的状态,而非返回旧缓存。 func (s *Service) updateRuntimeSessionAfterRestore(sessionID string, head agentsession.SessionHead, messages []providertypes.Message) { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return + } + s.runtimeSnapshotMu.Lock() + delete(s.runtimeSnapshots, normalized) + s.runtimeSnapshotMu.Unlock() } diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 35883f34..a724bf9f 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -91,7 +91,15 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { }() defer func() { if statePtr != nil { - s.updateResumeCheckpoint(runCtx, statePtr, "stopped", "completed") + completion := "completed" + if err != nil { + if errors.Is(err, context.Canceled) { + completion = "cancelled" + } else { + completion = "error" + } + } + s.updateResumeCheckpoint(runCtx, statePtr, "stopped", completion) } s.emitRunTermination(runCtx, input, statePtr, err) }()