diff --git a/internal/checkpoint/checkpoint_manager.go b/internal/checkpoint/checkpoint_manager.go
index 347c5582..01f8a185 100644
--- a/internal/checkpoint/checkpoint_manager.go
+++ b/internal/checkpoint/checkpoint_manager.go
@@ -58,6 +58,29 @@ type SQLiteCheckpointStore struct {
ownsDB bool // true 表示本实例打开的连接,Close 时需释放
}
+type workspaceFingerprintRow struct {
+ WorkspaceKey string
+ FingerprintPayload string
+ UpdatedAtMS int64
+}
+
+// WorkspaceCheckpointState 保存工作区当前权威代码基线与文件指纹。
+type WorkspaceCheckpointState struct {
+ WorkspaceKey string
+ CurrentCheckpointID string
+ FingerprintPayload string
+ UpdatedAt time.Time
+}
+
+// RunCheckpointBaseline 保存单次 run 的权威回退基线,供异步 diff 查询跨 run 结束后读取。
+type RunCheckpointBaseline struct {
+ SessionID string
+ RunID string
+ CheckpointID string
+ Drifted bool
+ UpdatedAt time.Time
+}
+
// NewSQLiteCheckpointStore 创建 checkpoint 存储实例。
// dbPath 为 session.db 文件路径,可通过 session.DatabasePath 获取。
func NewSQLiteCheckpointStore(dbPath string) *SQLiteCheckpointStore {
@@ -114,6 +137,184 @@ func (s *SQLiteCheckpointStore) ensureDB(ctx context.Context) (*sql.DB, error) {
return db, nil
}
+// SaveWorkspaceFingerprint 把 workspace 最新指纹保存到 SQLite,用于跨会话/重启后的 drift 检测。
+func (s *SQLiteCheckpointStore) SaveWorkspaceFingerprint(
+ ctx context.Context,
+ workspaceKey string,
+ fingerprintPayload string,
+ updatedAt time.Time,
+) error {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ db, err := s.ensureDB(ctx)
+ if err != nil {
+ return err
+ }
+ if err := ensureWorkspaceFingerprintTable(ctx, db); err != nil {
+ return err
+ }
+ _, err = db.ExecContext(ctx, `
+INSERT INTO workspace_fingerprints(workspace_key, fingerprint_payload, updated_at_ms)
+VALUES(?, ?, ?)
+ON CONFLICT(workspace_key) DO UPDATE SET
+ fingerprint_payload=excluded.fingerprint_payload,
+ updated_at_ms=excluded.updated_at_ms
+`, workspaceKey, fingerprintPayload, toUnixMillis(updatedAt))
+ if err != nil {
+ return fmt.Errorf("checkpoint: save workspace fingerprint %s: %w", workspaceKey, err)
+ }
+ return nil
+}
+
+// LoadWorkspaceFingerprint 从 SQLite 读取 workspace 最近指纹,不存在时 ok=false。
+func (s *SQLiteCheckpointStore) LoadWorkspaceFingerprint(
+ ctx context.Context,
+ workspaceKey string,
+) (string, bool, error) {
+ if err := ctx.Err(); err != nil {
+ return "", false, err
+ }
+ db, err := s.ensureDB(ctx)
+ if err != nil {
+ return "", false, err
+ }
+ if err := ensureWorkspaceFingerprintTable(ctx, db); err != nil {
+ return "", false, err
+ }
+ var row workspaceFingerprintRow
+ err = db.QueryRowContext(ctx, `
+SELECT workspace_key, fingerprint_payload, updated_at_ms
+FROM workspace_fingerprints
+WHERE workspace_key = ?
+`, workspaceKey).Scan(&row.WorkspaceKey, &row.FingerprintPayload, &row.UpdatedAtMS)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return "", false, nil
+ }
+ return "", false, fmt.Errorf("checkpoint: load workspace fingerprint %s: %w", workspaceKey, err)
+ }
+ return row.FingerprintPayload, true, nil
+}
+
+// SaveWorkspaceCheckpointState 持久化工作区当前代码基线与文件指纹。
+func (s *SQLiteCheckpointStore) SaveWorkspaceCheckpointState(ctx context.Context, state WorkspaceCheckpointState) error {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ if state.WorkspaceKey == "" {
+ return fmt.Errorf("checkpoint: workspace key required")
+ }
+ db, err := s.ensureDB(ctx)
+ if err != nil {
+ return err
+ }
+ if err := ensureWorkspaceCheckpointStateTable(ctx, db); err != nil {
+ return err
+ }
+ _, err = db.ExecContext(ctx, `
+INSERT INTO workspace_checkpoint_states(workspace_key, current_checkpoint_id, fingerprint_payload, updated_at_ms)
+VALUES(?, ?, ?, ?)
+ON CONFLICT(workspace_key) DO UPDATE SET
+ current_checkpoint_id=excluded.current_checkpoint_id,
+ fingerprint_payload=excluded.fingerprint_payload,
+ updated_at_ms=excluded.updated_at_ms
+`, state.WorkspaceKey, state.CurrentCheckpointID, state.FingerprintPayload, toUnixMillis(state.UpdatedAt))
+ if err != nil {
+ return fmt.Errorf("checkpoint: save workspace checkpoint state %s: %w", state.WorkspaceKey, err)
+ }
+ return nil
+}
+
+// LoadWorkspaceCheckpointState 读取工作区当前代码基线与文件指纹,不存在时 ok=false。
+func (s *SQLiteCheckpointStore) LoadWorkspaceCheckpointState(ctx context.Context, workspaceKey string) (WorkspaceCheckpointState, bool, error) {
+ if err := ctx.Err(); err != nil {
+ return WorkspaceCheckpointState{}, false, err
+ }
+ db, err := s.ensureDB(ctx)
+ if err != nil {
+ return WorkspaceCheckpointState{}, false, err
+ }
+ if err := ensureWorkspaceCheckpointStateTable(ctx, db); err != nil {
+ return WorkspaceCheckpointState{}, false, err
+ }
+ var state WorkspaceCheckpointState
+ var updatedAtMS int64
+ err = db.QueryRowContext(ctx, `
+SELECT workspace_key, current_checkpoint_id, fingerprint_payload, updated_at_ms
+FROM workspace_checkpoint_states
+WHERE workspace_key = ?
+`, workspaceKey).Scan(&state.WorkspaceKey, &state.CurrentCheckpointID, &state.FingerprintPayload, &updatedAtMS)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return WorkspaceCheckpointState{}, false, nil
+ }
+ return WorkspaceCheckpointState{}, false, fmt.Errorf("checkpoint: load workspace checkpoint state %s: %w", workspaceKey, err)
+ }
+ state.UpdatedAt = fromUnixMillis(updatedAtMS)
+ return state, true, nil
+}
+
+// SaveRunCheckpointBaseline 持久化单次 run 的权威回退基线。
+func (s *SQLiteCheckpointStore) SaveRunCheckpointBaseline(ctx context.Context, baseline RunCheckpointBaseline) error {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ if baseline.SessionID == "" || baseline.RunID == "" {
+ return fmt.Errorf("checkpoint: session_id and run_id required for run baseline")
+ }
+ db, err := s.ensureDB(ctx)
+ if err != nil {
+ return err
+ }
+ if err := ensureRunCheckpointBaselineTable(ctx, db); err != nil {
+ return err
+ }
+ _, err = db.ExecContext(ctx, `
+INSERT INTO run_checkpoint_baselines(session_id, run_id, checkpoint_id, drifted, updated_at_ms)
+VALUES(?, ?, ?, ?, ?)
+ON CONFLICT(session_id, run_id) DO UPDATE SET
+ checkpoint_id=excluded.checkpoint_id,
+ drifted=excluded.drifted,
+ updated_at_ms=excluded.updated_at_ms
+`, baseline.SessionID, baseline.RunID, baseline.CheckpointID, boolToInt(baseline.Drifted), toUnixMillis(baseline.UpdatedAt))
+ if err != nil {
+ return fmt.Errorf("checkpoint: save run baseline %s/%s: %w", baseline.SessionID, baseline.RunID, err)
+ }
+ return nil
+}
+
+// LoadRunCheckpointBaseline 读取单次 run 的权威回退基线,不存在时 ok=false。
+func (s *SQLiteCheckpointStore) LoadRunCheckpointBaseline(ctx context.Context, sessionID, runID string) (RunCheckpointBaseline, bool, error) {
+ if err := ctx.Err(); err != nil {
+ return RunCheckpointBaseline{}, false, err
+ }
+ db, err := s.ensureDB(ctx)
+ if err != nil {
+ return RunCheckpointBaseline{}, false, err
+ }
+ if err := ensureRunCheckpointBaselineTable(ctx, db); err != nil {
+ return RunCheckpointBaseline{}, false, err
+ }
+ var baseline RunCheckpointBaseline
+ var drifted int
+ var updatedAtMS int64
+ err = db.QueryRowContext(ctx, `
+SELECT session_id, run_id, checkpoint_id, drifted, updated_at_ms
+FROM run_checkpoint_baselines
+WHERE session_id = ? AND run_id = ?
+`, sessionID, runID).Scan(&baseline.SessionID, &baseline.RunID, &baseline.CheckpointID, &drifted, &updatedAtMS)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return RunCheckpointBaseline{}, false, nil
+ }
+ return RunCheckpointBaseline{}, false, fmt.Errorf("checkpoint: load run baseline %s/%s: %w", sessionID, runID, err)
+ }
+ baseline.Drifted = drifted != 0
+ baseline.UpdatedAt = fromUnixMillis(updatedAtMS)
+ return baseline, true, nil
+}
+
// CreateCheckpoint 在单一事务内写入 checkpoint record + session checkpoint。
// 事务内完成 record INSERT → session_cp INSERT → record UPDATE(设置 session_checkpoint_ref + status=available)。
func (s *SQLiteCheckpointStore) CreateCheckpoint(ctx context.Context, input CreateCheckpointInput) (session.CheckpointRecord, error) {
@@ -663,3 +864,52 @@ func rollbackTx(tx *sql.Tx) {
_ = tx.Rollback()
}
}
+
+func ensureWorkspaceFingerprintTable(ctx context.Context, db *sql.DB) error {
+ if db == nil {
+ return fmt.Errorf("checkpoint: workspace fingerprint db is nil")
+ }
+ if _, err := db.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS workspace_fingerprints (
+ workspace_key TEXT PRIMARY KEY,
+ fingerprint_payload TEXT NOT NULL,
+ updated_at_ms INTEGER NOT NULL
+)`); err != nil {
+ return fmt.Errorf("checkpoint: ensure workspace_fingerprints table: %w", err)
+ }
+ return nil
+}
+
+func ensureWorkspaceCheckpointStateTable(ctx context.Context, db *sql.DB) error {
+ if db == nil {
+ return fmt.Errorf("checkpoint: workspace checkpoint state db is nil")
+ }
+ if _, err := db.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS workspace_checkpoint_states (
+ workspace_key TEXT PRIMARY KEY,
+ current_checkpoint_id TEXT NOT NULL,
+ fingerprint_payload TEXT NOT NULL,
+ updated_at_ms INTEGER NOT NULL
+)`); err != nil {
+ return fmt.Errorf("checkpoint: ensure workspace_checkpoint_states table: %w", err)
+ }
+ return nil
+}
+
+func ensureRunCheckpointBaselineTable(ctx context.Context, db *sql.DB) error {
+ if db == nil {
+ return fmt.Errorf("checkpoint: run checkpoint baseline db is nil")
+ }
+ if _, err := db.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS run_checkpoint_baselines (
+ session_id TEXT NOT NULL,
+ run_id TEXT NOT NULL,
+ checkpoint_id TEXT NOT NULL,
+ drifted INTEGER NOT NULL,
+ updated_at_ms INTEGER NOT NULL,
+ PRIMARY KEY(session_id, run_id)
+)`); err != nil {
+ return fmt.Errorf("checkpoint: ensure run_checkpoint_baselines table: %w", err)
+ }
+ return nil
+}
diff --git a/internal/checkpoint/checkpoint_manager_test.go b/internal/checkpoint/checkpoint_manager_test.go
index 6afd96a1..67d3f58a 100644
--- a/internal/checkpoint/checkpoint_manager_test.go
+++ b/internal/checkpoint/checkpoint_manager_test.go
@@ -588,3 +588,81 @@ func TestNewSQLiteCheckpointStoreWithNilDBClose(t *testing.T) {
t.Fatalf("Close(nil db) error = %v", err)
}
}
+
+func TestSQLiteCheckpointStoreWorkspaceFingerprintPersistence(t *testing.T) {
+ fixture := newCheckpointStoreFixture(t)
+ db, err := fixture.sessionStore.InitDB(context.Background())
+ if err != nil {
+ t.Fatalf("InitDB() error = %v", err)
+ }
+ shared := NewSQLiteCheckpointStoreWithDB(db)
+
+ workspaceKey := session.WorkspacePathKey(fixture.workspaceRoot)
+ payload := `{"a.txt":{"size":1,"mod_time":"2026-01-01T00:00:00Z","head_hash":"abc"}}`
+ if err := shared.SaveWorkspaceFingerprint(context.Background(), workspaceKey, payload, time.Now()); err != nil {
+ t.Fatalf("SaveWorkspaceFingerprint() error = %v", err)
+ }
+
+ got, ok, err := shared.LoadWorkspaceFingerprint(context.Background(), workspaceKey)
+ if err != nil {
+ t.Fatalf("LoadWorkspaceFingerprint() error = %v", err)
+ }
+ if !ok {
+ t.Fatal("expected persisted fingerprint to exist")
+ }
+ if got != payload {
+ t.Fatalf("fingerprint payload = %q, want %q", got, payload)
+ }
+
+ missing, ok, err := shared.LoadWorkspaceFingerprint(context.Background(), "workspace/missing")
+ if err != nil {
+ t.Fatalf("LoadWorkspaceFingerprint(missing) error = %v", err)
+ }
+ if ok || missing != "" {
+ t.Fatalf("expected missing fingerprint, got ok=%v payload=%q", ok, missing)
+ }
+}
+
+func TestSQLiteCheckpointStoreWorkspaceStateAndRunBaselinePersistence(t *testing.T) {
+ fixture := newCheckpointStoreFixture(t)
+ db, err := fixture.sessionStore.InitDB(context.Background())
+ if err != nil {
+ t.Fatalf("InitDB() error = %v", err)
+ }
+ shared := NewSQLiteCheckpointStoreWithDB(db)
+
+ workspaceKey := session.WorkspacePathKey(fixture.workspaceRoot)
+ payload := `{"tracked.txt":{"size":7}}`
+ if err := shared.SaveWorkspaceCheckpointState(context.Background(), WorkspaceCheckpointState{
+ WorkspaceKey: workspaceKey,
+ CurrentCheckpointID: "cp-current",
+ FingerprintPayload: payload,
+ UpdatedAt: time.Now(),
+ }); err != nil {
+ t.Fatalf("SaveWorkspaceCheckpointState() error = %v", err)
+ }
+ state, ok, err := shared.LoadWorkspaceCheckpointState(context.Background(), workspaceKey)
+ if err != nil {
+ t.Fatalf("LoadWorkspaceCheckpointState() error = %v", err)
+ }
+ if !ok || state.CurrentCheckpointID != "cp-current" || state.FingerprintPayload != payload {
+ t.Fatalf("workspace state = %#v ok=%v", state, ok)
+ }
+
+ if err := shared.SaveRunCheckpointBaseline(context.Background(), RunCheckpointBaseline{
+ SessionID: "session-1",
+ RunID: "run-1",
+ CheckpointID: "cp-current",
+ Drifted: true,
+ UpdatedAt: time.Now(),
+ }); err != nil {
+ t.Fatalf("SaveRunCheckpointBaseline() error = %v", err)
+ }
+ baseline, ok, err := shared.LoadRunCheckpointBaseline(context.Background(), "session-1", "run-1")
+ if err != nil {
+ t.Fatalf("LoadRunCheckpointBaseline() error = %v", err)
+ }
+ if !ok || baseline.CheckpointID != "cp-current" || !baseline.Drifted {
+ t.Fatalf("run baseline = %#v ok=%v", baseline, ok)
+ }
+}
diff --git a/internal/checkpoint/per_edit_snapshot.go b/internal/checkpoint/per_edit_snapshot.go
index 9efa78d2..0580bc5c 100644
--- a/internal/checkpoint/per_edit_snapshot.go
+++ b/internal/checkpoint/per_edit_snapshot.go
@@ -367,9 +367,9 @@ func (s *PerEditSnapshotStore) Reset() {
// guardID 为 pre-restore 固化的快照(restoreCheckpointCore 中的 guard checkpoint),
// 用于对比确定每个文件的目标操作;guardID 为空时仅处理 target checkpoint 内的文件。
//
-// 对比逻辑:对 target 与 guard 中出现的每个文件,分别计算"目标状态"与"当前状态",
-// 据此执行写回 / 删除 / 跳过,覆盖文件创建、修改、删除三种变更方向。
-func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID string) error {
+// 对比逻辑:存在 relatedCheckpointIDs 时先收敛到 target 与后续 checkpoint 的净变化路径,
+// 再分别计算"目标状态"与"当前状态"并执行写回 / 删除 / 跳过。
+func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID string, relatedCheckpointIDs ...string) error {
targetCP, err := s.readCheckpointMeta(targetID)
if err != nil {
return err
@@ -378,11 +378,6 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st
s.indexMu.Lock()
defer s.indexMu.Unlock()
- hashSet := make(map[string]struct{}, len(targetCP.FileVersions))
- for h := range targetCP.FileVersions {
- hashSet[h] = struct{}{}
- }
-
var guardCP CheckpointMeta
hasGuard := guardID != ""
if hasGuard {
@@ -390,15 +385,29 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st
if err != nil {
return err
}
- for h := range guardCP.FileVersions {
- hashSet[h] = struct{}{}
+ }
+ relatedCPs := make([]CheckpointMeta, 0, len(relatedCheckpointIDs))
+ for _, checkpointID := range relatedCheckpointIDs {
+ if strings.TrimSpace(checkpointID) == "" || checkpointID == targetID || checkpointID == guardID {
+ continue
+ }
+ relatedCP, err := s.readCheckpointMeta(checkpointID)
+ if err != nil {
+ return err
}
+ relatedCPs = append(relatedCPs, relatedCP)
}
- // 无论有无 guard,都必须合并全量 pathToVersions。
- // guard 是 pending-only 的,不包含此前创建的、本 turn 未触碰的新文件;
- // 不合并则这些文件在 restore 后仍会残留。
- for h := range s.pathToVersions {
- hashSet[h] = struct{}{}
+
+ hashSet := make(map[string]struct{})
+ if len(relatedCPs) > 0 {
+ if err := s.collectRestoreChangedHashesLocked(hashSet, targetCP, relatedCPs); err != nil {
+ return err
+ }
+ } else {
+ addCheckpointHashes(hashSet, targetCP)
+ }
+ if hasGuard {
+ addCheckpointHashes(hashSet, guardCP)
}
for hash := range hashSet {
@@ -407,7 +416,7 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st
}
// 目标状态:target checkpoint 时刻文件应如何。
- toContent, toIsDir, toExists, toMode, toDisplay, err := s.contentAtCheckpointLocked(hash, targetCP.FileVersions, false)
+ toContent, toIsDir, toExists, toMode, toDisplay, err := s.contentAtCheckpointStateLocked(hash, targetCP, false)
if err != nil {
return err
}
@@ -418,7 +427,7 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st
var fromMode os.FileMode
var fromDisplay string
if hasGuard {
- fromContent, fromIsDir, fromExists, fromMode, fromDisplay, err = s.contentAtCheckpointLocked(hash, guardCP.FileVersions, true)
+ fromContent, fromIsDir, fromExists, fromMode, fromDisplay, err = s.contentAtCheckpointStateLocked(hash, guardCP, true)
if err != nil {
return err
}
@@ -492,10 +501,26 @@ func (s *PerEditSnapshotStore) RestoreExact(ctx context.Context, checkpointID st
s.indexMu.Lock()
defer s.indexMu.Unlock()
- for hash, vAt := range cp.FileVersions {
+ hashSet := make(map[string]struct{}, len(cp.FileVersions)+len(cp.ExactFileVersions))
+ for hash := range cp.FileVersions {
+ hashSet[hash] = struct{}{}
+ }
+ for hash := range cp.ExactFileVersions {
+ hashSet[hash] = struct{}{}
+ }
+ hashes := make([]string, 0, len(hashSet))
+ for hash := range hashSet {
+ hashes = append(hashes, hash)
+ }
+ sort.Strings(hashes)
+ for _, hash := range hashes {
if err := ctx.Err(); err != nil {
return err
}
+ vAt, ok := cp.ExactFileVersions[hash]
+ if !ok {
+ vAt = cp.FileVersions[hash]
+ }
meta, err := s.readVersionMeta(hash, vAt)
if err != nil {
return fmt.Errorf("per-edit: read meta v%d: %w", vAt, err)
@@ -1311,6 +1336,82 @@ func readWorkdirMode(absPath string) os.FileMode {
return info.Mode()
}
+// addCheckpointHashes 把 checkpoint 中显式记录的文件 hash 加入 restore 候选集合。
+func addCheckpointHashes(hashSet map[string]struct{}, cp CheckpointMeta) {
+ for hash := range cp.FileVersions {
+ hashSet[hash] = struct{}{}
+ }
+ for hash := range cp.ExactFileVersions {
+ hashSet[hash] = struct{}{}
+ }
+}
+
+// collectRestoreChangedHashesLocked 只收集 target 到后续 checkpoint 之间净变化的路径。
+// target 可作为全工作区重锚点;真正 restore 时不应因为 target 是全量快照而误碰无关文件。
+func (s *PerEditSnapshotStore) collectRestoreChangedHashesLocked(
+ hashSet map[string]struct{},
+ targetCP CheckpointMeta,
+ relatedCPs []CheckpointMeta,
+) error {
+ if len(relatedCPs) == 0 {
+ return nil
+ }
+ latestCP := relatedCPs[0]
+ candidateHashes := make(map[string]struct{})
+ addCheckpointHashes(candidateHashes, targetCP)
+ for _, cp := range relatedCPs {
+ addCheckpointHashes(candidateHashes, cp)
+ if cp.CreatedAt.After(latestCP.CreatedAt) {
+ latestCP = cp
+ }
+ }
+ for hash := range candidateHashes {
+ targetContent, targetIsDir, targetExists, targetMode, _, err := s.contentAtCheckpointStateLocked(hash, targetCP, false)
+ if err != nil {
+ return err
+ }
+ latestContent, latestIsDir, latestExists, latestMode, _, err := s.contentAtCheckpointStateLocked(hash, latestCP, false)
+ if err != nil {
+ return err
+ }
+ if targetExists != latestExists ||
+ targetIsDir != latestIsDir ||
+ targetMode != latestMode ||
+ !bytes.Equal(targetContent, latestContent) {
+ hashSet[hash] = struct{}{}
+ }
+ }
+ return nil
+}
+
+// contentAtCheckpointStateLocked 读取 checkpoint 结束态;新 checkpoint 优先使用 ExactFileVersions。
+// 旧 checkpoint 没有 exact 版本时,继续使用 v_next 兼容语义。
+func (s *PerEditSnapshotStore) contentAtCheckpointStateLocked(
+ hash string,
+ cp CheckpointMeta,
+ fallbackIfMissing bool,
+) ([]byte, bool, bool, os.FileMode, string, error) {
+ if version, ok := cp.ExactFileVersions[hash]; ok {
+ meta, err := s.readVersionMeta(hash, version)
+ if err != nil {
+ return nil, false, false, 0, "", fmt.Errorf("per-edit: read exact meta v%d for %s: %w", version, hash, err)
+ }
+ display := s.resolveDisplayPathLocked(hash, meta.DisplayPath)
+ if !meta.Existed {
+ return nil, false, false, meta.Mode, display, nil
+ }
+ if meta.IsDir {
+ return nil, true, true, meta.Mode, display, nil
+ }
+ content, err := s.readVersionBin(hash, version)
+ if err != nil {
+ return nil, false, false, 0, display, fmt.Errorf("per-edit: read exact bin v%d for %s: %w", version, hash, err)
+ }
+ return content, false, true, meta.Mode, display, nil
+ }
+ return s.contentAtCheckpointLocked(hash, cp.FileVersions, fallbackIfMissing)
+}
+
// contentAtCheckpointLocked 计算 hash 在某个 checkpoint 时刻的 workdir 内容。
// 在 cp.FileVersions 中:找下一版本读 .bin(或 Existed=false 时返回 nil);
// 没有下一版本时:以当前 workdir 实际内容为准。
diff --git a/internal/checkpoint/per_edit_snapshot_test.go b/internal/checkpoint/per_edit_snapshot_test.go
index 27b7d94b..1ff35005 100644
--- a/internal/checkpoint/per_edit_snapshot_test.go
+++ b/internal/checkpoint/per_edit_snapshot_test.go
@@ -1733,3 +1733,93 @@ func TestRunAggregateDiff_HistoricalFileFilteredByVersion(t *testing.T) {
t.Fatalf("patch should NOT contain a.txt (filtered by version):\n%s", patch)
}
}
+
+func TestRestore_DoesNotUseGlobalHistoryWhenNoRelatedScope(t *testing.T) {
+ store, workdir := newTestStore(t)
+
+ target := writeWorkdirFile(t, workdir, "target.txt", "old\n")
+ if _, err := store.CapturePreWrite(target); err != nil {
+ t.Fatalf("capture target: %v", err)
+ }
+ if err := os.WriteFile(target, []byte("new\n"), 0o644); err != nil {
+ t.Fatalf("write target: %v", err)
+ }
+ if _, err := store.FinalizeWithExactState("cp-target"); err != nil {
+ t.Fatalf("finalize target: %v", err)
+ }
+ store.Reset()
+
+ unrelated := writeWorkdirFile(t, workdir, "unrelated.txt", "must stay\n")
+ if _, err := store.CapturePreWrite(unrelated); err != nil {
+ t.Fatalf("capture unrelated: %v", err)
+ }
+ if err := store.Restore(context.Background(), "cp-target", ""); err != nil {
+ t.Fatalf("restore cp-target: %v", err)
+ }
+ if got := mustReadFile(t, unrelated); got != "must stay\n" {
+ t.Fatalf("unrelated file should not be touched, got %q", got)
+ }
+}
+
+func TestRestore_WithRelatedScopeSkipsUnchangedFilesFromFullRebase(t *testing.T) {
+ store, workdir := newTestStore(t)
+
+ agentPath := writeWorkdirFile(t, workdir, "agent.txt", "baseline\n")
+ unrelatedPath := writeWorkdirFile(t, workdir, "unrelated.txt", "keep\n")
+ if _, err := store.CaptureBatch([]string{agentPath, unrelatedPath}); err != nil {
+ t.Fatalf("CaptureBatch baseline: %v", err)
+ }
+ if _, err := store.FinalizeWithExactState("cp-rebase"); err != nil {
+ t.Fatalf("FinalizeWithExactState(cp-rebase): %v", err)
+ }
+ store.Reset()
+
+ if _, err := store.CapturePreWrite(agentPath); err != nil {
+ t.Fatalf("CapturePreWrite agent: %v", err)
+ }
+ if err := os.WriteFile(agentPath, []byte("agent edit\n"), 0o644); err != nil {
+ t.Fatalf("write agent edit: %v", err)
+ }
+ if _, err := store.FinalizeWithExactState("cp-agent"); err != nil {
+ t.Fatalf("FinalizeWithExactState(cp-agent): %v", err)
+ }
+ store.Reset()
+
+ if err := os.WriteFile(unrelatedPath, []byte("user post-run edit\n"), 0o644); err != nil {
+ t.Fatalf("write unrelated post-run edit: %v", err)
+ }
+ if err := store.Restore(context.Background(), "cp-rebase", "", "cp-agent"); err != nil {
+ t.Fatalf("Restore(cp-rebase): %v", err)
+ }
+ if got := mustReadFile(t, agentPath); got != "baseline\n" {
+ t.Fatalf("agent file = %q, want baseline", got)
+ }
+ if got := mustReadFile(t, unrelatedPath); got != "user post-run edit\n" {
+ t.Fatalf("unrelated file should not be restored, got %q", got)
+ }
+}
+
+func TestRestoreExact_PrefersExactCheckpointState(t *testing.T) {
+ store, workdir := newTestStore(t)
+ target := writeWorkdirFile(t, workdir, "exact.txt", "before\n")
+ if _, err := store.CapturePreWrite(target); err != nil {
+ t.Fatalf("capture: %v", err)
+ }
+ if err := os.WriteFile(target, []byte("checkpoint state\n"), 0o644); err != nil {
+ t.Fatalf("write checkpoint state: %v", err)
+ }
+ if _, err := store.FinalizeWithExactState("cp-exact"); err != nil {
+ t.Fatalf("FinalizeWithExactState: %v", err)
+ }
+ store.Reset()
+
+ if err := os.WriteFile(target, []byte("drift\n"), 0o644); err != nil {
+ t.Fatalf("write drift: %v", err)
+ }
+ if err := store.RestoreExact(context.Background(), "cp-exact"); err != nil {
+ t.Fatalf("RestoreExact: %v", err)
+ }
+ if got := mustReadFile(t, target); got != "checkpoint state\n" {
+ t.Fatalf("RestoreExact should use exact state, got %q", got)
+ }
+}
diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go
index a4ec44f5..d8db9ca5 100644
--- a/internal/cli/gateway_runtime_bridge.go
+++ b/internal/cli/gateway_runtime_bridge.go
@@ -2456,6 +2456,8 @@ func (b *gatewayRuntimePortBridge) CheckpointDiff(ctx context.Context, input gat
Deleted: result.Files.Deleted,
Modified: result.Files.Modified,
},
- Patch: result.Patch,
+ Patch: result.Patch,
+ WorkspaceDrifted: result.WorkspaceDrifted,
+ Warning: result.Warning,
}, nil
}
diff --git a/internal/gateway/contracts.go b/internal/gateway/contracts.go
index 69f6692e..533cac1e 100644
--- a/internal/gateway/contracts.go
+++ b/internal/gateway/contracts.go
@@ -476,6 +476,8 @@ type CheckpointDiffResult struct {
PrevCommitHash string `json:"prev_commit_hash,omitempty"`
Files FileDiffs `json:"files"`
Patch string `json:"patch,omitempty"`
+ WorkspaceDrifted bool `json:"workspace_drifted,omitempty"`
+ Warning string `json:"warning,omitempty"`
}
// FileDiffs 描述 diff 中的文件变更列表。
diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go
index 3075ac75..44cd0a6d 100644
--- a/internal/runtime/checkpoint_flow_test.go
+++ b/internal/runtime/checkpoint_flow_test.go
@@ -11,6 +11,7 @@ import (
"neo-code/internal/checkpoint"
providertypes "neo-code/internal/provider/types"
+ "neo-code/internal/repository"
agentsession "neo-code/internal/session"
)
@@ -207,6 +208,370 @@ func TestCreateStartOfTurnCheckpoint_NoPending_SessionOnly(t *testing.T) {
}
}
+func TestCreatePreRunDriftRebaseCheckpoint_UsesDriftReasonAndPerEditRef(t *testing.T) {
+ fixture := newRuntimeCheckpointFixture(t)
+
+ absKeep := filepath.Join(fixture.workdir, "keep.txt")
+ if err := os.WriteFile(absKeep, []byte("keep"), 0o644); err != nil {
+ t.Fatalf("WriteFile(keep) error = %v", err)
+ }
+ absDeleted := filepath.Join(fixture.workdir, "deleted.txt")
+ if err := os.WriteFile(absDeleted, []byte("deleted"), 0o644); err != nil {
+ t.Fatalf("WriteFile(deleted) error = %v", err)
+ }
+ if err := os.Remove(absDeleted); err != nil {
+ t.Fatalf("Remove(deleted) error = %v", err)
+ }
+
+ state := newRunState("run-drift", fixture.session)
+ checkpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint(context.Background(), &state, repository.FingerprintDiff{
+ Deleted: []string{"deleted.txt"},
+ })
+ if err != nil {
+ t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err)
+ }
+ if checkpointID == "" {
+ t.Fatal("expected non-empty drift baseline checkpoint id")
+ }
+
+ records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{})
+ if err != nil {
+ t.Fatalf("ListCheckpoints() error = %v", err)
+ }
+ if len(records) != 1 {
+ t.Fatalf("expected 1 checkpoint, got %d", len(records))
+ }
+ if records[0].Reason != agentsession.CheckpointReasonPreRunDriftRebase {
+ t.Fatalf("reason = %s, want %s", records[0].Reason, agentsession.CheckpointReasonPreRunDriftRebase)
+ }
+ if !checkpoint.IsPerEditRef(records[0].CodeCheckpointRef) {
+ t.Fatalf("code ref = %q, want peredit ref", records[0].CodeCheckpointRef)
+ }
+}
+
+func TestCreatePreRunDriftRebaseCheckpoint_UsesEffectiveWorkdirWhenSessionWorkdirEmpty(t *testing.T) {
+ fixture := newRuntimeCheckpointFixture(t)
+
+ absPath := filepath.Join(fixture.workdir, "effective.txt")
+ if err := os.WriteFile(absPath, []byte("effective"), 0o644); err != nil {
+ t.Fatalf("WriteFile(effective) error = %v", err)
+ }
+
+ session := fixture.session
+ session.Workdir = ""
+ state := newRunState("run-effective-workdir", session)
+ state.effectiveWorkdir = fixture.workdir
+ checkpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint(context.Background(), &state, repository.FingerprintDiff{
+ Added: []string{"effective.txt"},
+ })
+ if err != nil {
+ t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err)
+ }
+ if checkpointID == "" {
+ t.Fatal("expected non-empty drift baseline checkpoint id")
+ }
+
+ records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{})
+ if err != nil {
+ t.Fatalf("ListCheckpoints() error = %v", err)
+ }
+ if len(records) != 1 {
+ t.Fatalf("expected 1 checkpoint, got %d", len(records))
+ }
+ if records[0].Workdir != fixture.workdir {
+ t.Fatalf("record workdir = %q, want %q", records[0].Workdir, fixture.workdir)
+ }
+ if records[0].WorkspaceKey != agentsession.WorkspacePathKey(fixture.workdir) {
+ t.Fatalf("workspace key = %q, want %q", records[0].WorkspaceKey, agentsession.WorkspacePathKey(fixture.workdir))
+ }
+}
+
+func TestRecordRunEndWorkspaceStateUsesEffectiveWorkdirWhenSessionWorkdirEmpty(t *testing.T) {
+ fixture := newRuntimeCheckpointFixture(t)
+
+ absPath := filepath.Join(fixture.workdir, "run-end.txt")
+ if err := os.WriteFile(absPath, []byte("run end"), 0o644); err != nil {
+ t.Fatalf("WriteFile(run-end) error = %v", err)
+ }
+
+ session := fixture.session
+ session.Workdir = ""
+ state := newRunState("run-end-effective-workdir", session)
+ state.effectiveWorkdir = fixture.workdir
+ fixture.service.recordRunEndWorkspaceState(
+ context.Background(),
+ session.ID,
+ effectiveWorkdirForCheckpointState(&state, session),
+ "cp-current-effective",
+ )
+
+ workspaceKey := agentsession.WorkspacePathKey(fixture.workdir)
+ loaded, ok, err := fixture.checkpointStore.LoadWorkspaceCheckpointState(context.Background(), workspaceKey)
+ if err != nil {
+ t.Fatalf("LoadWorkspaceCheckpointState() error = %v", err)
+ }
+ if !ok {
+ t.Fatal("expected workspace checkpoint state to be persisted")
+ }
+ if loaded.CurrentCheckpointID != "cp-current-effective" {
+ t.Fatalf("current checkpoint id = %q, want cp-current-effective", loaded.CurrentCheckpointID)
+ }
+ if loaded.WorkspaceKey != workspaceKey {
+ t.Fatalf("workspace key = %q, want %q", loaded.WorkspaceKey, workspaceKey)
+ }
+}
+
+func TestCheckpointDiff_ScopeRun_RecreatedFileAfterDriftBaselineIsAdded(t *testing.T) {
+ fixture := newRuntimeCheckpointFixture(t)
+ state := newRunState("run-drift-add", fixture.session)
+
+ absPath := filepath.Join(fixture.workdir, "recreate.txt")
+ if err := os.WriteFile(absPath, []byte("legacy"), 0o644); err != nil {
+ t.Fatalf("WriteFile(legacy) error = %v", err)
+ }
+ if err := os.Remove(absPath); err != nil {
+ t.Fatalf("Remove(recreate.txt) error = %v", err)
+ }
+
+ rebasedCheckpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint(
+ context.Background(),
+ &state,
+ repository.FingerprintDiff{Deleted: []string{"recreate.txt"}},
+ )
+ if err != nil {
+ t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err)
+ }
+ if rebasedCheckpointID == "" {
+ t.Fatal("expected drift rebase checkpoint id")
+ }
+
+ if _, err := fixture.perEditStore.CapturePreWrite(absPath); err != nil {
+ t.Fatalf("CapturePreWrite(recreate) error = %v", err)
+ }
+ if err := os.WriteFile(absPath, []byte("new"), 0o644); err != nil {
+ t.Fatalf("WriteFile(new) error = %v", err)
+ }
+ if _, err := fixture.perEditStore.FinalizeWithExactState("cp-end"); err != nil {
+ t.Fatalf("FinalizeWithExactState(cp-end) error = %v", err)
+ }
+ fixture.perEditStore.Reset()
+ if err := fixture.service.createCheckpointRecord(
+ context.Background(),
+ fixture.session,
+ state.runID,
+ &state,
+ "cp-end",
+ agentsession.CheckpointReasonEndOfTurn,
+ ); err != nil {
+ t.Fatalf("createCheckpointRecord(cp-end) error = %v", err)
+ }
+
+ fixture.service.setRunWorkspaceDrift(fixture.session.ID, state.runID, true)
+ fixture.service.setRunRollbackBaseline(fixture.session.ID, state.runID, rebasedCheckpointID)
+ result, err := fixture.service.CheckpointDiff(context.Background(), CheckpointDiffInput{
+ SessionID: fixture.session.ID,
+ Scope: "run",
+ RunID: state.runID,
+ CheckpointID: "cp-end",
+ })
+ if err != nil {
+ t.Fatalf("CheckpointDiff(scope=run) error = %v", err)
+ }
+ if result.PrevCheckpointID != rebasedCheckpointID {
+ t.Fatalf("PrevCheckpointID = %q, want %q", result.PrevCheckpointID, rebasedCheckpointID)
+ }
+ if len(result.Files.Added) != 1 || result.Files.Added[0] != "recreate.txt" {
+ t.Fatalf("added files = %+v, want recreate.txt", result.Files.Added)
+ }
+ if len(result.Files.Modified) != 0 {
+ t.Fatalf("modified files = %+v, want empty", result.Files.Modified)
+ }
+}
+
+func TestCheckpointDiff_ScopeRun_LoadsPersistedBaselineAfterCacheClear(t *testing.T) {
+ fixture := newRuntimeCheckpointFixture(t)
+ state := newRunState("run-persisted-baseline", fixture.session)
+
+ absPath := filepath.Join(fixture.workdir, "persisted-baseline.txt")
+ if err := os.WriteFile(absPath, []byte("manual\n"), 0o644); err != nil {
+ t.Fatalf("WriteFile(manual) error = %v", err)
+ }
+ rebasedCheckpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint(
+ context.Background(),
+ &state,
+ repository.FingerprintDiff{Added: []string{"persisted-baseline.txt"}},
+ )
+ if err != nil {
+ t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err)
+ }
+ fixture.service.persistRunRollbackBaseline(context.Background(), fixture.session.ID, state.runID, rebasedCheckpointID, true)
+ fixture.service.clearRunCheckpointCaches(fixture.session.ID, state.runID)
+
+ if _, err := fixture.perEditStore.CapturePreWrite(absPath); err != nil {
+ t.Fatalf("CapturePreWrite(agent) error = %v", err)
+ }
+ if err := os.WriteFile(absPath, []byte("agent\n"), 0o644); err != nil {
+ t.Fatalf("WriteFile(agent) error = %v", err)
+ }
+ if _, err := fixture.perEditStore.FinalizeWithExactState("cp-end-persisted"); err != nil {
+ t.Fatalf("FinalizeWithExactState(cp-end-persisted) error = %v", err)
+ }
+ fixture.perEditStore.Reset()
+ if err := fixture.service.createCheckpointRecord(
+ context.Background(),
+ fixture.session,
+ state.runID,
+ &state,
+ "cp-end-persisted",
+ agentsession.CheckpointReasonEndOfTurn,
+ ); err != nil {
+ t.Fatalf("createCheckpointRecord(cp-end-persisted) error = %v", err)
+ }
+
+ result, err := fixture.service.CheckpointDiff(context.Background(), CheckpointDiffInput{
+ SessionID: fixture.session.ID,
+ Scope: "run",
+ RunID: state.runID,
+ CheckpointID: "cp-end-persisted",
+ })
+ if err != nil {
+ t.Fatalf("CheckpointDiff(scope=run) error = %v", err)
+ }
+ if result.PrevCheckpointID != rebasedCheckpointID {
+ t.Fatalf("PrevCheckpointID = %q, want %q", result.PrevCheckpointID, rebasedCheckpointID)
+ }
+ if !result.WorkspaceDrifted {
+ t.Fatal("expected persisted drift flag")
+ }
+}
+
+func TestRestoreCheckpoint_AfterExternalModifyRestoresToRebasedState(t *testing.T) {
+ fixture := newRuntimeCheckpointFixture(t)
+ state := newRunState("run-restore-rebased", fixture.session)
+
+ absPath := filepath.Join(fixture.workdir, "external.txt")
+ if err := os.WriteFile(absPath, []byte("manual\n"), 0o644); err != nil {
+ t.Fatalf("WriteFile(manual) error = %v", err)
+ }
+ unrelatedPath := filepath.Join(fixture.workdir, "unrelated.txt")
+ if err := os.WriteFile(unrelatedPath, []byte("keep\n"), 0o644); err != nil {
+ t.Fatalf("WriteFile(unrelated) error = %v", err)
+ }
+ rebasedCheckpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint(
+ context.Background(),
+ &state,
+ repository.FingerprintDiff{Modified: []string{"external.txt"}},
+ )
+ if err != nil {
+ t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err)
+ }
+
+ if _, err := fixture.perEditStore.CapturePreWrite(absPath); err != nil {
+ t.Fatalf("CapturePreWrite(agent) error = %v", err)
+ }
+ if err := os.WriteFile(absPath, []byte("agent\n"), 0o644); err != nil {
+ t.Fatalf("WriteFile(agent) error = %v", err)
+ }
+ if _, err := fixture.perEditStore.FinalizeWithExactState("cp-agent-end"); err != nil {
+ t.Fatalf("FinalizeWithExactState(cp-agent-end) error = %v", err)
+ }
+ fixture.perEditStore.Reset()
+ if err := fixture.service.createCheckpointRecord(
+ context.Background(),
+ fixture.session,
+ state.runID,
+ &state,
+ "cp-agent-end",
+ agentsession.CheckpointReasonEndOfTurn,
+ ); err != nil {
+ t.Fatalf("createCheckpointRecord(cp-agent-end) error = %v", err)
+ }
+ if err := os.WriteFile(unrelatedPath, []byte("post-run user edit\n"), 0o644); err != nil {
+ t.Fatalf("WriteFile(unrelated post-run) error = %v", err)
+ }
+
+ if _, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{
+ SessionID: fixture.session.ID,
+ CheckpointID: rebasedCheckpointID,
+ }); err != nil {
+ t.Fatalf("RestoreCheckpoint() error = %v", err)
+ }
+ data, err := os.ReadFile(absPath)
+ if err != nil {
+ t.Fatalf("ReadFile(restored) error = %v", err)
+ }
+ if got := string(data); got != "manual\n" {
+ t.Fatalf("restored content = %q, want manual", got)
+ }
+ unrelatedData, err := os.ReadFile(unrelatedPath)
+ if err != nil {
+ t.Fatalf("ReadFile(unrelated) error = %v", err)
+ }
+ if got := string(unrelatedData); got != "post-run user edit\n" {
+ t.Fatalf("unrelated content = %q, want post-run user edit", got)
+ }
+}
+
+func TestRecordRunStartFingerprint_DetectsDriftAcrossSessionsByWorkspace(t *testing.T) {
+ fixture := newRuntimeCheckpointFixture(t)
+ service := fixture.service
+ ctx := context.Background()
+
+ absPath := filepath.Join(fixture.workdir, "cross-session.txt")
+ if err := os.WriteFile(absPath, []byte("before"), 0o644); err != nil {
+ t.Fatalf("WriteFile(before) error = %v", err)
+ }
+ service.recordRunEndFingerprint(ctx, "sess-a", fixture.workdir)
+
+ if err := os.Remove(absPath); err != nil {
+ t.Fatalf("Remove(cross-session.txt) error = %v", err)
+ }
+
+ drifted, driftDiff := service.recordRunStartFingerprint(ctx, "sess-b", "run-b", fixture.workdir)
+ if !drifted {
+ t.Fatal("expected drift across sessions on the same workspace")
+ }
+ if !containsString(driftDiff.Deleted, "cross-session.txt") {
+ t.Fatalf("deleted diff = %#v, want cross-session.txt", driftDiff.Deleted)
+ }
+}
+
+func TestRecordRunStartFingerprint_LoadsWorkspaceFingerprintFromPersistence(t *testing.T) {
+ fixture := newRuntimeCheckpointFixture(t)
+ ctx := context.Background()
+
+ absPath := filepath.Join(fixture.workdir, "persisted.txt")
+ if err := os.WriteFile(absPath, []byte("before"), 0o644); err != nil {
+ t.Fatalf("WriteFile(before) error = %v", err)
+ }
+ fixture.service.recordRunEndFingerprint(ctx, "sess-a", fixture.workdir)
+
+ if err := os.Remove(absPath); err != nil {
+ t.Fatalf("Remove(persisted.txt) error = %v", err)
+ }
+
+ // 模拟进程重启:新 Service 仅复用 checkpointStore,不复用内存指纹缓存。
+ restarted := &Service{
+ checkpointStore: fixture.checkpointStore,
+ }
+ drifted, driftDiff := restarted.recordRunStartFingerprint(ctx, "sess-b", "run-b", fixture.workdir)
+ if !drifted {
+ t.Fatal("expected drift when loading workspace fingerprint from persistence")
+ }
+ if !containsString(driftDiff.Deleted, "persisted.txt") {
+ t.Fatalf("deleted diff = %#v, want persisted.txt", driftDiff.Deleted)
+ }
+}
+
+func containsString(items []string, target string) bool {
+ for _, item := range items {
+ if item == target {
+ return true
+ }
+ }
+ return false
+}
+
func TestCreateEndOfTurnCheckpoint_NoWriteSkipped(t *testing.T) {
fixture := newRuntimeCheckpointFixture(t)
fixture.captureFile(t, "main.go", []byte("package main\n"))
@@ -924,6 +1289,9 @@ func TestCheckpointDiffRunScopeAggregatesCurrentRun(t *testing.T) {
if result.CheckpointID != "cp-2" {
t.Fatalf("CheckpointID = %q, want cp-2", result.CheckpointID)
}
+ if result.PrevCheckpointID != "" {
+ t.Fatalf("PrevCheckpointID = %q, want empty", result.PrevCheckpointID)
+ }
if len(result.Files.Modified) != 1 || result.Files.Modified[0] != "tracked.txt" {
t.Fatalf("modified files = %+v, want tracked.txt", result.Files.Modified)
}
@@ -941,7 +1309,7 @@ func mustReadRuntimeFile(t *testing.T, path string) []byte {
return data
}
-// ──────── scope=run diff tests ────────
+// scope=run diff tests
func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) {
workdir := t.TempDir()
@@ -988,6 +1356,15 @@ func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) {
CreatedAt: now,
CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-1"),
},
+ {
+ CheckpointID: "cp-0",
+ SessionID: "session-1",
+ RunID: "run-prev",
+ CreatedAt: now.Add(-time.Second),
+ Reason: agentsession.CheckpointReasonEndOfTurn,
+ Status: agentsession.CheckpointStatusAvailable,
+ CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-0"),
+ },
},
}
service := &Service{
@@ -1026,10 +1403,12 @@ func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) {
if len(modifiedPaths) != 1 || modifiedPaths[0] != "a.txt" {
t.Fatalf("expected a.txt modified, got added=%v modified=%v", addedPaths, modifiedPaths)
}
- // 当前 run-scope diff 默认返回目标 checkpoint(未显式指定时为最新 checkpoint)。
if result.CheckpointID != "cp-2" {
t.Fatalf("CheckpointID = %q, want cp-2", result.CheckpointID)
}
+ if result.PrevCheckpointID != "cp-0" {
+ t.Fatalf("PrevCheckpointID = %q, want cp-0", result.PrevCheckpointID)
+ }
}
func TestCheckpointDiff_ScopeRun_RejectsMissingRunID(t *testing.T) {
@@ -1077,6 +1456,161 @@ func TestCheckpointDiff_ScopeRun_NoCheckpointsForRunID(t *testing.T) {
}
}
+func TestCheckpointDiff_ScopeRun_RejectsTargetCheckpointFromAnotherRun(t *testing.T) {
+ spy := &checkpointStoreSpy{
+ listRecords: []agentsession.CheckpointRecord{
+ {
+ CheckpointID: "cp-other-run",
+ SessionID: "session-1",
+ RunID: "run-other",
+ CreatedAt: time.Now().UTC(),
+ CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-other-run"),
+ },
+ },
+ }
+ service := &Service{
+ checkpointStore: spy,
+ perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), t.TempDir()),
+ }
+ _, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{
+ SessionID: "session-1",
+ Scope: "run",
+ RunID: "run-target",
+ CheckpointID: "cp-other-run",
+ })
+ if err == nil {
+ t.Fatal("expected error for target checkpoint from another run")
+ }
+ if !strings.Contains(err.Error(), "does not belong to run") {
+ t.Fatalf("error = %v, want run mismatch", err)
+ }
+}
+
+func TestCheckpointDiff_ScopeRun_WarnsWhenBaselineMissing(t *testing.T) {
+ workdir := t.TempDir()
+ projectDir := t.TempDir()
+ store := checkpoint.NewPerEditSnapshotStore(projectDir, workdir)
+ now := time.Now().UTC()
+
+ absA := filepath.Join(workdir, "a.txt")
+ _ = os.WriteFile(absA, []byte("old a\n"), 0o644)
+ if _, err := store.CapturePreWrite(absA); err != nil {
+ t.Fatalf("CapturePreWrite a: %v", err)
+ }
+ _ = os.WriteFile(absA, []byte("new a\n"), 0o644)
+ if _, err := store.Finalize("cp-1"); err != nil {
+ t.Fatalf("Finalize cp-1: %v", err)
+ }
+ store.Reset()
+
+ spy := &checkpointStoreSpy{
+ listRecords: []agentsession.CheckpointRecord{
+ {
+ CheckpointID: "cp-1",
+ SessionID: "session-1",
+ RunID: "run-target",
+ CreatedAt: now,
+ CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-1"),
+ },
+ },
+ }
+ service := &Service{
+ checkpointStore: spy,
+ perEditStore: store,
+ }
+
+ result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{
+ SessionID: "session-1",
+ Scope: "run",
+ RunID: "run-target",
+ })
+ if err != nil {
+ t.Fatalf("CheckpointDiff(scope=run) error = %v", err)
+ }
+ if result.PrevCheckpointID != "" {
+ t.Fatalf("PrevCheckpointID = %q, want empty", result.PrevCheckpointID)
+ }
+ if !strings.Contains(result.Warning, "run baseline checkpoint is missing") {
+ t.Fatalf("Warning = %q, want missing-baseline warning", result.Warning)
+ }
+}
+
+func TestCheckpointDiff_ScopeRun_PrefersRebasedBaselineWhenDrifted(t *testing.T) {
+ workdir := t.TempDir()
+ projectDir := t.TempDir()
+ store := checkpoint.NewPerEditSnapshotStore(projectDir, workdir)
+ now := time.Now().UTC()
+
+ absA := filepath.Join(workdir, "a.txt")
+ _ = os.WriteFile(absA, []byte("one\n"), 0o644)
+ if _, err := store.CapturePreWrite(absA); err != nil {
+ t.Fatalf("CapturePreWrite cp-drift: %v", err)
+ }
+ _ = os.WriteFile(absA, []byte("two\n"), 0o644)
+ if _, err := store.Finalize("cp-drift"); err != nil {
+ t.Fatalf("Finalize cp-drift: %v", err)
+ }
+ store.Reset()
+
+ if _, err := store.CapturePreWrite(absA); err != nil {
+ t.Fatalf("CapturePreWrite cp-target: %v", err)
+ }
+ _ = os.WriteFile(absA, []byte("three\n"), 0o644)
+ if _, err := store.Finalize("cp-target"); err != nil {
+ t.Fatalf("Finalize cp-target: %v", err)
+ }
+ store.Reset()
+
+ spy := &checkpointStoreSpy{
+ listRecords: []agentsession.CheckpointRecord{
+ {
+ CheckpointID: "cp-target",
+ SessionID: "session-1",
+ RunID: "run-target",
+ CreatedAt: now.Add(time.Second),
+ CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-target"),
+ },
+ {
+ CheckpointID: "cp-drift",
+ SessionID: "session-1",
+ RunID: "run-target",
+ CreatedAt: now,
+ CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-drift"),
+ },
+ {
+ CheckpointID: "cp-old",
+ SessionID: "session-1",
+ RunID: "run-prev",
+ CreatedAt: now.Add(-time.Second),
+ Reason: agentsession.CheckpointReasonEndOfTurn,
+ Status: agentsession.CheckpointStatusAvailable,
+ CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-old"),
+ },
+ },
+ }
+ service := &Service{
+ checkpointStore: spy,
+ perEditStore: store,
+ }
+ service.setRunWorkspaceDrift("session-1", "run-target", true)
+ service.setRunRollbackBaseline("session-1", "run-target", "cp-drift")
+
+ result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{
+ SessionID: "session-1",
+ Scope: "run",
+ RunID: "run-target",
+ })
+ if err != nil {
+ t.Fatalf("CheckpointDiff(scope=run) error = %v", err)
+ }
+ if result.PrevCheckpointID != "cp-drift" {
+ t.Fatalf("PrevCheckpointID = %q, want cp-drift", result.PrevCheckpointID)
+ }
+ if !strings.Contains(result.Warning, "baseline re-anchored") {
+ t.Fatalf("Warning = %q, want re-anchored hint", result.Warning)
+ }
+}
+
func TestCheckpointDiff_DefaultScopePreservesExistingBehavior(t *testing.T) {
// Verify empty scope still uses checkpoint-to-checkpoint comparison.
workdir := t.TempDir()
diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go
index bae5f42f..db4dfd5c 100644
--- a/internal/runtime/checkpoint_gate.go
+++ b/internal/runtime/checkpoint_gate.go
@@ -5,10 +5,13 @@ import (
"encoding/json"
"fmt"
"log"
+ "path/filepath"
+ "sort"
"strings"
"time"
"neo-code/internal/checkpoint"
+ "neo-code/internal/repository"
agentsession "neo-code/internal/session"
)
@@ -38,6 +41,73 @@ func (s *Service) createStartOfTurnCheckpoint(ctx context.Context, state *runSta
return s.createCheckpointRecord(ctx, session, runID, state, checkpointID, agentsession.CheckpointReasonPreWrite)
}
+// createPreRunDriftRebaseCheckpoint 在 run 开始前检测到外部漂移时,固化当前工作区为权威基线。
+func (s *Service) createPreRunDriftRebaseCheckpoint(
+ ctx context.Context,
+ state *runState,
+ diff repository.FingerprintDiff,
+) (string, error) {
+ if s.checkpointStore == nil || s.perEditStore == nil {
+ return "", nil
+ }
+
+ state.mu.Lock()
+ session := state.session
+ runID := state.runID
+ workdir := effectiveWorkdirForCheckpointState(state, session)
+ state.mu.Unlock()
+
+ if workdir == "" {
+ return "", fmt.Errorf("checkpoint: workdir is empty when creating drift rebase checkpoint")
+ }
+
+ currentFingerprint, _, err := repository.ScanWorkdir(ctx, workdir, repository.DefaultFingerprintOptions())
+ if err != nil {
+ return "", fmt.Errorf("checkpoint: scan workdir for drift rebase: %w", err)
+ }
+ absPaths := make([]string, 0, len(currentFingerprint))
+ for relPath := range currentFingerprint {
+ absPaths = append(absPaths, filepath.Join(workdir, filepath.FromSlash(relPath)))
+ }
+ sort.Strings(absPaths)
+ if len(absPaths) > 0 {
+ if _, err := s.perEditStore.CaptureBatch(absPaths); err != nil {
+ return "", fmt.Errorf("checkpoint: capture current files for drift rebase: %w", err)
+ }
+ }
+
+ deleted := append([]string(nil), diff.Deleted...)
+ sort.Strings(deleted)
+ for _, relPath := range deleted {
+ absPath := filepath.Join(workdir, filepath.FromSlash(relPath))
+ if _, err := s.perEditStore.CapturePreWrite(absPath); err != nil {
+ return "", fmt.Errorf("checkpoint: capture deleted file for drift rebase (%s): %w", relPath, err)
+ }
+ }
+
+ checkpointID := agentsession.NewID("checkpoint")
+ written, err := s.perEditStore.FinalizeWithExactState(checkpointID)
+ if err != nil {
+ return "", fmt.Errorf("checkpoint: finalize drift rebase checkpoint: %w", err)
+ }
+ if !written {
+ return "", fmt.Errorf("checkpoint: drift rebase checkpoint has no captured files")
+ }
+ defer s.perEditStore.Reset()
+
+ if err := s.createCheckpointRecord(
+ ctx,
+ session,
+ runID,
+ state,
+ checkpointID,
+ agentsession.CheckpointReasonPreRunDriftRebase,
+ ); err != nil {
+ return "", err
+ }
+ return checkpointID, nil
+}
+
// createEndOfTurnCheckpoint 在工具执行完成后创建代码检查点。
// hasWorkspaceWrite=false 时不创建(避免空 checkpoint);为 true 时 Finalize 当前 pending。
// 失败仅 log,不阻塞主流程。
@@ -95,7 +165,7 @@ func (s *Service) createCheckpointRecord(
return fmt.Errorf("checkpoint: marshal messages: %w", err)
}
- effectiveWorkdir := strings.TrimSpace(session.Workdir)
+ effectiveWorkdir := effectiveWorkdirForCheckpointState(state, session)
now := time.Now()
ref := checkpoint.RefForPerEditCheckpoint(checkpointID)
@@ -159,12 +229,13 @@ func (s *Service) createSessionOnlyCheckpoint(
return fmt.Errorf("checkpoint: marshal session-only messages: %w", err)
}
+ effectiveWorkdir := effectiveWorkdirForCheckpointState(state, session)
record := agentsession.CheckpointRecord{
CheckpointID: checkpointID,
- WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir),
+ WorkspaceKey: agentsession.WorkspacePathKey(effectiveWorkdir),
SessionID: session.ID,
RunID: runID,
- Workdir: session.Workdir,
+ Workdir: effectiveWorkdir,
CreatedAt: now,
Reason: reason,
Restorable: true,
@@ -224,3 +295,13 @@ func (s *Service) findPreviousEndOfTurnCheckpoint(ctx context.Context, sessionID
}
return ""
}
+
+// effectiveWorkdirForCheckpointState 返回 checkpoint 流程应使用的工作目录,优先使用本次 run 已归一化的目录。
+func effectiveWorkdirForCheckpointState(state *runState, session agentsession.Session) string {
+ if state != nil {
+ if workdir := strings.TrimSpace(state.effectiveWorkdir); workdir != "" {
+ return workdir
+ }
+ }
+ return strings.TrimSpace(session.Workdir)
+}
diff --git a/internal/runtime/checkpoint_restore.go b/internal/runtime/checkpoint_restore.go
index 4ba43124..ccc8366a 100644
--- a/internal/runtime/checkpoint_restore.go
+++ b/internal/runtime/checkpoint_restore.go
@@ -86,7 +86,9 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi
return RestoreResult{}, agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: create guard: %w", guardErr)
}
- // 3. Restore code via per-edit store(不在 cp.FileVersions 中的文件保持不变)。
+ relatedPerEditIDs := s.restoreRelatedPerEditIDs(ctx, sessionID, record.CreatedAt)
+
+ // 3. Restore code via per-edit store(仅处理目标、guard 与后续相关 checkpoint 涉及的文件)。
// Guard checkpoint 恢复时使用 RestoreExact:guard 中存储的 version 就是 restore 前的 pre-write 状态,
// 而 Restore 的 v_next 语义在 guard 上通常是 no-op(guard 之后没有新的 capture)。
isGuardRestore := record.Reason == agentsession.CheckpointReasonGuard
@@ -102,7 +104,7 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi
if guardWritten {
guardCheckpointID = guardID
}
- if err := s.perEditStore.Restore(ctx, perEditID, guardCheckpointID); err != nil {
+ if err := s.perEditStore.Restore(ctx, perEditID, guardCheckpointID, relatedPerEditIDs...); err != nil {
return RestoreResult{}, agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: restore code: %w", err)
}
}
@@ -149,6 +151,7 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi
// 7. Update runtime session if it's the current session
s.updateRuntimeSessionAfterRestore(sessionID, head, messages)
+ s.recordRunEndWorkspaceState(ctx, sessionID, head.Workdir, checkpointID)
return RestoreResult{
CheckpointID: checkpointID,
@@ -156,6 +159,34 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi
}, guardRecord, nil
}
+// restoreRelatedPerEditIDs 收集目标 checkpoint 之后仍 available 的代码 checkpoint,用于限定 restore 清理范围。
+func (s *Service) restoreRelatedPerEditIDs(ctx context.Context, sessionID string, targetCreatedAt time.Time) []string {
+ if s.checkpointStore == nil {
+ return nil
+ }
+ records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{})
+ if err != nil {
+ return nil
+ }
+ out := make([]string, 0)
+ for _, record := range records {
+ if !record.CreatedAt.After(targetCreatedAt) {
+ continue
+ }
+ if record.Status != "" && record.Status != agentsession.CheckpointStatusAvailable {
+ continue
+ }
+ if record.Reason == agentsession.CheckpointReasonGuard {
+ continue
+ }
+ perEditID := checkpoint.PerEditCheckpointIDFromRef(record.CodeCheckpointRef)
+ if perEditID != "" {
+ out = append(out, perEditID)
+ }
+ }
+ return out
+}
+
// RestoreCheckpoint 恢复指定 checkpoint 的会话和工作区状态,并发出 checkpoint_restored 事件。
func (s *Service) RestoreCheckpoint(ctx context.Context, input GatewayRestoreInput) (RestoreResult, error) {
result, guardRecord, err := s.restoreCheckpointCore(ctx, input.SessionID, input.CheckpointID)
@@ -315,6 +346,8 @@ type CheckpointDiffResult struct {
PrevCommitHash string `json:"prev_commit_hash,omitempty"`
Files FileDiffs `json:"files"`
Patch string `json:"patch,omitempty"`
+ WorkspaceDrifted bool `json:"workspace_drifted,omitempty"`
+ Warning string `json:"warning,omitempty"`
}
// FileDiffs 描述 diff 中的文件变更列表。
@@ -452,6 +485,9 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff
if runID == "" {
return CheckpointDiffResult{}, fmt.Errorf("checkpoint: run_id required for run scope diff")
}
+ if targetRecord != nil && strings.TrimSpace(targetRecord.RunID) != runID {
+ return CheckpointDiffResult{}, fmt.Errorf("checkpoint: target checkpoint %s does not belong to run %s", targetRecord.CheckpointID, runID)
+ }
codeRecords := make([]agentsession.CheckpointRecord, 0)
for _, record := range records {
@@ -479,6 +515,11 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff
targetRecord = &codeRecords[len(codeRecords)-1]
}
+ prevCheckpointID, workspaceDrifted := s.getPersistentRunRollbackBaseline(ctx, sessionID, runID)
+ if prevCheckpointID == "" {
+ prevCheckpointID = resolveRunPrevCheckpointID(records, codeRecords[0].CreatedAt, runID)
+ }
+
perEditIDs := make([]string, 0, len(codeRecords))
for _, record := range codeRecords {
perEditID := checkpoint.PerEditCheckpointIDFromRef(record.CodeCheckpointRef)
@@ -492,7 +533,30 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff
return CheckpointDiffResult{}, fmt.Errorf("checkpoint: per-edit run diff: %w", err)
}
- result := CheckpointDiffResult{CheckpointID: targetRecord.CheckpointID, Patch: patch}
+ result := CheckpointDiffResult{
+ CheckpointID: targetRecord.CheckpointID,
+ PrevCheckpointID: prevCheckpointID,
+ Patch: patch,
+ WorkspaceDrifted: workspaceDrifted,
+ }
+ if result.WorkspaceDrifted {
+ if prevCheckpointID != "" {
+ result.Warning = "workspace drift detected before run start; baseline re-anchored"
+ } else {
+ result.Warning = "workspace drift detected before run start"
+ }
+ }
+ if result.PrevCheckpointID == "" {
+ if result.Warning != "" {
+ result.Warning += "; run baseline checkpoint is missing"
+ } else {
+ result.Warning = "run baseline checkpoint is missing"
+ }
+ }
+ if result.PrevCheckpointID != "" {
+ s.persistRunRollbackBaseline(ctx, sessionID, runID, result.PrevCheckpointID, result.WorkspaceDrifted)
+ }
+
for _, c := range changes {
switch c.Kind {
case checkpoint.FileChangeAdded:
@@ -505,3 +569,45 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff
}
return result, nil
}
+
+// resolveRunPrevCheckpointID 解析 run 级 diff 的权威 baseline checkpoint_id。
+// 优先选择 run 开始前最近的 end_of_turn 代码 checkpoint;若不存在,退化为最近的可用代码 checkpoint。
+func resolveRunPrevCheckpointID(records []agentsession.CheckpointRecord, runStartAt time.Time, runID string) string {
+ var preferred *agentsession.CheckpointRecord
+ var fallback *agentsession.CheckpointRecord
+ for i := range records {
+ record := records[i]
+ if strings.TrimSpace(record.RunID) == strings.TrimSpace(runID) {
+ continue
+ }
+ if record.Status != "" && record.Status != agentsession.CheckpointStatusAvailable {
+ continue
+ }
+ if !checkpoint.IsPerEditRef(record.CodeCheckpointRef) {
+ continue
+ }
+ if record.Reason == agentsession.CheckpointReasonGuard {
+ continue
+ }
+ if !record.CreatedAt.Before(runStartAt) {
+ continue
+ }
+ if fallback == nil || record.CreatedAt.After(fallback.CreatedAt) {
+ candidate := record
+ fallback = &candidate
+ }
+ if record.Reason == agentsession.CheckpointReasonEndOfTurn {
+ if preferred == nil || record.CreatedAt.After(preferred.CreatedAt) {
+ candidate := record
+ preferred = &candidate
+ }
+ }
+ }
+ if preferred != nil {
+ return preferred.CheckpointID
+ }
+ if fallback != nil {
+ return fallback.CheckpointID
+ }
+ return ""
+}
diff --git a/internal/runtime/checkpoint_run_baseline.go b/internal/runtime/checkpoint_run_baseline.go
new file mode 100644
index 00000000..a299d440
--- /dev/null
+++ b/internal/runtime/checkpoint_run_baseline.go
@@ -0,0 +1,277 @@
+package runtime
+
+import (
+ "context"
+ "encoding/json"
+ "strings"
+ "time"
+
+ "neo-code/internal/checkpoint"
+ "neo-code/internal/repository"
+ agentsession "neo-code/internal/session"
+)
+
+// buildRunBaselineKey 生成 run 级缓存 key,保证同一会话不同 run 的基线隔离。
+func buildRunBaselineKey(sessionID, runID string) string {
+ return strings.TrimSpace(sessionID) + "\n" + strings.TrimSpace(runID)
+}
+
+// ensureRunBaselineMapsLocked 懒初始化 run 级基线与漂移缓存,允许测试中用字面量 Service 直接调用。
+func (s *Service) ensureRunBaselineMapsLocked() {
+ if s.runRollbackBaselineByRunKey == nil {
+ s.runRollbackBaselineByRunKey = make(map[string]string)
+ }
+ if s.runWorkspaceDriftByRunKey == nil {
+ s.runWorkspaceDriftByRunKey = make(map[string]bool)
+ }
+ if s.lastRunFingerprintByWorkspaceKey == nil {
+ s.lastRunFingerprintByWorkspaceKey = make(map[string]repository.WorkdirFingerprint)
+ }
+}
+
+type workspaceFingerprintPersistenceStore interface {
+ SaveWorkspaceFingerprint(ctx context.Context, workspaceKey string, fingerprintPayload string, updatedAt time.Time) error
+ LoadWorkspaceFingerprint(ctx context.Context, workspaceKey string) (string, bool, error)
+}
+
+type workspaceCheckpointStatePersistenceStore interface {
+ SaveWorkspaceCheckpointState(ctx context.Context, state checkpoint.WorkspaceCheckpointState) error
+ LoadWorkspaceCheckpointState(ctx context.Context, workspaceKey string) (checkpoint.WorkspaceCheckpointState, bool, error)
+ SaveRunCheckpointBaseline(ctx context.Context, baseline checkpoint.RunCheckpointBaseline) error
+ LoadRunCheckpointBaseline(ctx context.Context, sessionID, runID string) (checkpoint.RunCheckpointBaseline, bool, error)
+}
+
+// workspaceKeyFromWorkdir 统一把工作目录映射为跨会话稳定 key。
+func workspaceKeyFromWorkdir(workdir string) string {
+ return strings.TrimSpace(agentsession.WorkspacePathKey(strings.TrimSpace(workdir)))
+}
+
+// setRunRollbackBaseline 记录当前 run 的权威回退基线 checkpoint_id(由后端计算,前端仅消费)。
+func (s *Service) setRunRollbackBaseline(sessionID, runID, checkpointID string) {
+ key := buildRunBaselineKey(sessionID, runID)
+ if key == "\n" {
+ return
+ }
+ s.rollbackBaselineMu.Lock()
+ defer s.rollbackBaselineMu.Unlock()
+ s.ensureRunBaselineMapsLocked()
+ if strings.TrimSpace(checkpointID) == "" {
+ delete(s.runRollbackBaselineByRunKey, key)
+ return
+ }
+ s.runRollbackBaselineByRunKey[key] = strings.TrimSpace(checkpointID)
+}
+
+// getRunRollbackBaseline 返回 run 的权威回退基线 checkpoint_id。
+func (s *Service) getRunRollbackBaseline(sessionID, runID string) string {
+ key := buildRunBaselineKey(sessionID, runID)
+ if key == "\n" {
+ return ""
+ }
+ s.rollbackBaselineMu.Lock()
+ defer s.rollbackBaselineMu.Unlock()
+ s.ensureRunBaselineMapsLocked()
+ return s.runRollbackBaselineByRunKey[key]
+}
+
+// getPersistentRunRollbackBaseline 从内存或 SQLite 读取 run 的权威回退基线。
+func (s *Service) getPersistentRunRollbackBaseline(ctx context.Context, sessionID, runID string) (string, bool) {
+ if baseline := s.getRunRollbackBaseline(sessionID, runID); baseline != "" {
+ return baseline, s.getRunWorkspaceDrift(sessionID, runID)
+ }
+ store, ok := s.checkpointStore.(workspaceCheckpointStatePersistenceStore)
+ if !ok {
+ return "", false
+ }
+ loaded, found, err := store.LoadRunCheckpointBaseline(ctx, strings.TrimSpace(sessionID), strings.TrimSpace(runID))
+ if err != nil || !found {
+ return "", false
+ }
+ s.setRunRollbackBaseline(sessionID, runID, loaded.CheckpointID)
+ s.setRunWorkspaceDrift(sessionID, runID, loaded.Drifted)
+ return strings.TrimSpace(loaded.CheckpointID), loaded.Drifted
+}
+
+// persistRunRollbackBaseline 双写 run 级回退基线,保证异步 diff 在 run 结束后仍可读取。
+func (s *Service) persistRunRollbackBaseline(ctx context.Context, sessionID, runID, checkpointID string, drifted bool) {
+ sessionID = strings.TrimSpace(sessionID)
+ runID = strings.TrimSpace(runID)
+ checkpointID = strings.TrimSpace(checkpointID)
+ if sessionID == "" || runID == "" || checkpointID == "" {
+ return
+ }
+ s.setRunRollbackBaseline(sessionID, runID, checkpointID)
+ s.setRunWorkspaceDrift(sessionID, runID, drifted)
+ store, ok := s.checkpointStore.(workspaceCheckpointStatePersistenceStore)
+ if !ok {
+ return
+ }
+ _ = store.SaveRunCheckpointBaseline(ctx, checkpoint.RunCheckpointBaseline{
+ SessionID: sessionID,
+ RunID: runID,
+ CheckpointID: checkpointID,
+ Drifted: drifted,
+ UpdatedAt: time.Now(),
+ })
+}
+
+// setRunWorkspaceDrift 标记当前 run 是否检测到空闲期工作区漂移。
+func (s *Service) setRunWorkspaceDrift(sessionID, runID string, drifted bool) {
+ key := buildRunBaselineKey(sessionID, runID)
+ if key == "\n" {
+ return
+ }
+ s.rollbackBaselineMu.Lock()
+ defer s.rollbackBaselineMu.Unlock()
+ s.ensureRunBaselineMapsLocked()
+ if drifted {
+ s.runWorkspaceDriftByRunKey[key] = true
+ return
+ }
+ delete(s.runWorkspaceDriftByRunKey, key)
+}
+
+// getRunWorkspaceDrift 返回 run 是否检测到空闲期工作区漂移。
+func (s *Service) getRunWorkspaceDrift(sessionID, runID string) bool {
+ key := buildRunBaselineKey(sessionID, runID)
+ if key == "\n" {
+ return false
+ }
+ s.rollbackBaselineMu.Lock()
+ defer s.rollbackBaselineMu.Unlock()
+ s.ensureRunBaselineMapsLocked()
+ return s.runWorkspaceDriftByRunKey[key]
+}
+
+// clearRunCheckpointCaches 清理 run 级 checkpoint 缓存,避免内存与跨 run 污染。
+func (s *Service) clearRunCheckpointCaches(sessionID, runID string) {
+ key := buildRunBaselineKey(sessionID, runID)
+ if key == "\n" {
+ return
+ }
+ s.rollbackBaselineMu.Lock()
+ defer s.rollbackBaselineMu.Unlock()
+ s.ensureRunBaselineMapsLocked()
+ delete(s.runRollbackBaselineByRunKey, key)
+ delete(s.runWorkspaceDriftByRunKey, key)
+}
+
+// recordRunStartFingerprint 在 run 开始时比对上次 run 结束指纹,返回是否发生空闲期漂移及详细差异。
+func (s *Service) recordRunStartFingerprint(
+ ctx context.Context,
+ sessionID, runID, workdir string,
+) (bool, repository.FingerprintDiff) {
+ emptyDiff := repository.FingerprintDiff{}
+ normalizedRunID := strings.TrimSpace(runID)
+ normalizedWorkdir := strings.TrimSpace(workdir)
+ workspaceKey := workspaceKeyFromWorkdir(normalizedWorkdir)
+ if strings.TrimSpace(sessionID) == "" || normalizedRunID == "" || normalizedWorkdir == "" || workspaceKey == "" {
+ return false, emptyDiff
+ }
+ current, _, err := repository.ScanWorkdir(ctx, normalizedWorkdir, repository.DefaultFingerprintOptions())
+ if err != nil {
+ return false, emptyDiff
+ }
+
+ s.rollbackBaselineMu.Lock()
+ s.ensureRunBaselineMapsLocked()
+ previous, ok := s.lastRunFingerprintByWorkspaceKey[workspaceKey]
+ s.rollbackBaselineMu.Unlock()
+ var currentCheckpointID string
+ if !ok {
+ if store, okStore := s.checkpointStore.(workspaceCheckpointStatePersistenceStore); okStore {
+ state, found, loadErr := store.LoadWorkspaceCheckpointState(ctx, workspaceKey)
+ if loadErr == nil && found {
+ currentCheckpointID = strings.TrimSpace(state.CurrentCheckpointID)
+ var restored repository.WorkdirFingerprint
+ if err := json.Unmarshal([]byte(state.FingerprintPayload), &restored); err == nil {
+ s.rollbackBaselineMu.Lock()
+ s.ensureRunBaselineMapsLocked()
+ s.lastRunFingerprintByWorkspaceKey[workspaceKey] = restored
+ s.rollbackBaselineMu.Unlock()
+ previous = restored
+ ok = true
+ }
+ }
+ }
+ if !ok {
+ if store, okStore := s.checkpointStore.(workspaceFingerprintPersistenceStore); okStore {
+ payload, found, loadErr := store.LoadWorkspaceFingerprint(ctx, workspaceKey)
+ if loadErr == nil && found {
+ var restored repository.WorkdirFingerprint
+ if err := json.Unmarshal([]byte(payload), &restored); err == nil {
+ s.rollbackBaselineMu.Lock()
+ s.ensureRunBaselineMapsLocked()
+ s.lastRunFingerprintByWorkspaceKey[workspaceKey] = restored
+ s.rollbackBaselineMu.Unlock()
+ previous = restored
+ ok = true
+ }
+ }
+ }
+ }
+ if !ok {
+ return false, emptyDiff
+ }
+ }
+ diff := repository.DiffFingerprints(previous, current)
+ drifted := len(diff.Added) > 0 || len(diff.Modified) > 0 || len(diff.Deleted) > 0
+ if drifted {
+ s.setRunWorkspaceDrift(sessionID, normalizedRunID, true)
+ return true, diff
+ }
+ if currentCheckpointID != "" {
+ s.persistRunRollbackBaseline(ctx, sessionID, normalizedRunID, currentCheckpointID, false)
+ }
+ return false, diff
+}
+
+// recordRunEndFingerprint 在 run 结束后保存最新指纹,供下次 run 开始时进行漂移检测。
+func (s *Service) recordRunEndFingerprint(ctx context.Context, sessionID, workdir string) {
+ s.recordRunEndWorkspaceState(ctx, sessionID, workdir, "")
+}
+
+// recordRunEndWorkspaceState 保存最新指纹和当前 checkpoint;checkpointID 为空时保留已有基线。
+func (s *Service) recordRunEndWorkspaceState(ctx context.Context, sessionID, workdir, checkpointID string) {
+ normalizedWorkdir := strings.TrimSpace(workdir)
+ workspaceKey := workspaceKeyFromWorkdir(normalizedWorkdir)
+ if strings.TrimSpace(sessionID) == "" || normalizedWorkdir == "" || workspaceKey == "" {
+ return
+ }
+ current, _, err := repository.ScanWorkdir(ctx, normalizedWorkdir, repository.DefaultFingerprintOptions())
+ if err != nil {
+ return
+ }
+ s.rollbackBaselineMu.Lock()
+ s.ensureRunBaselineMapsLocked()
+ s.lastRunFingerprintByWorkspaceKey[workspaceKey] = current
+ s.rollbackBaselineMu.Unlock()
+
+ store, ok := s.checkpointStore.(workspaceFingerprintPersistenceStore)
+ if !ok {
+ return
+ }
+ payload, err := json.Marshal(current)
+ if err != nil {
+ return
+ }
+ now := time.Now()
+ _ = store.SaveWorkspaceFingerprint(ctx, workspaceKey, string(payload), now)
+
+ stateStore, ok := s.checkpointStore.(workspaceCheckpointStatePersistenceStore)
+ if !ok {
+ return
+ }
+ currentCheckpointID := strings.TrimSpace(checkpointID)
+ if currentCheckpointID == "" {
+ if existing, found, err := stateStore.LoadWorkspaceCheckpointState(ctx, workspaceKey); err == nil && found {
+ currentCheckpointID = strings.TrimSpace(existing.CurrentCheckpointID)
+ }
+ }
+ _ = stateStore.SaveWorkspaceCheckpointState(ctx, checkpoint.WorkspaceCheckpointState{
+ WorkspaceKey: workspaceKey,
+ CurrentCheckpointID: currentCheckpointID,
+ FingerprintPayload: string(payload),
+ UpdatedAt: now,
+ })
+}
diff --git a/internal/runtime/run.go b/internal/runtime/run.go
index c4bdb542..6887ee55 100644
--- a/internal/runtime/run.go
+++ b/internal/runtime/run.go
@@ -135,6 +135,11 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) {
ChangedFiles: changedFiles,
})
}
+ if statePtr != nil {
+ runEndCtx := context.Background()
+ s.recordRunEndWorkspaceState(runEndCtx, statePtr.session.ID, effectiveWorkdirForCheckpointState(statePtr, statePtr.session), statePtr.lastEndOfTurnCheckpointID)
+ s.clearRunCheckpointCaches(statePtr.session.ID, statePtr.runID)
+ }
s.emitRunTermination(runCtx, input, statePtr, err)
}()
ctx = runCtx
@@ -170,6 +175,8 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) {
}
state := newRunState(input.RunID, session)
+ effectiveWorkdir := agentsession.EffectiveWorkdir(state.session.Workdir, initialCfg.Workdir)
+ state.effectiveWorkdir = effectiveWorkdir
state.runToken = runToken
state.thinkingOverride = cloneThinkingOverride(input.ThinkingOverride)
state.disableTools = input.DisableTools
@@ -190,7 +197,28 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) {
return s.handleRunError(err)
}
- effectiveWorkdir := agentsession.EffectiveWorkdir(state.session.Workdir, initialCfg.Workdir)
+ runBaselineCheckpointID := ""
+ if drifted, driftDiff := s.recordRunStartFingerprint(ctx, state.session.ID, state.runID, effectiveWorkdir); drifted {
+ if checkpointID, cpErr := s.createPreRunDriftRebaseCheckpoint(ctx, &state, driftDiff); cpErr != nil {
+ s.emitRunScopedOptional(EventCheckpointWarning, &state, CheckpointWarningPayload{
+ Error: fmt.Sprintf("workspace drift detected but baseline rebase failed: %v", cpErr),
+ Phase: "run_start_drift_rebase",
+ })
+ } else if checkpointID != "" {
+ runBaselineCheckpointID = checkpointID
+ s.persistRunRollbackBaseline(ctx, state.session.ID, state.runID, checkpointID, true)
+ s.recordRunEndWorkspaceState(ctx, state.session.ID, effectiveWorkdir, checkpointID)
+ }
+ s.emitRunScopedOptional(EventCheckpointWarning, &state, CheckpointWarningPayload{
+ Error: "workspace drift detected before run start",
+ Phase: "run_start_drift",
+ })
+ }
+ if runBaselineCheckpointID == "" {
+ if baseline, _ := s.getPersistentRunRollbackBaseline(ctx, state.session.ID, state.runID); baseline != "" {
+ runBaselineCheckpointID = baseline
+ }
+ }
_ = s.runHookPoint(ctx, &state, runtimehooks.HookPointSessionStart, runtimehooks.HookContext{
Metadata: map[string]any{
"session_id": state.session.ID,
@@ -223,7 +251,11 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) {
s.updateResumeCheckpoint(ctx, &state, "plan", "")
maxTurns := resolveRuntimeMaxTurns(initialCfg.Runtime)
- state.baselineCheckpointID = s.findPreviousEndOfTurnCheckpoint(ctx, sessionID, input.RunID)
+ if runBaselineCheckpointID != "" {
+ state.baselineCheckpointID = runBaselineCheckpointID
+ } else {
+ state.baselineCheckpointID = s.findPreviousEndOfTurnCheckpoint(ctx, sessionID, input.RunID)
+ }
for turn := 0; ; turn++ {
if turn >= maxTurns {
state.maxTurnsReached = true
diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go
index 0da88617..f1db6601 100644
--- a/internal/runtime/runtime.go
+++ b/internal/runtime/runtime.go
@@ -193,19 +193,26 @@ type Service struct {
events chan RuntimeEvent
runtimeSnapshotMu sync.Mutex
runtimeSnapshots map[string]RuntimeSnapshot
- sessionMu sync.Mutex
- sessionLocks map[string]*sessionLockEntry
- runMu sync.Mutex
- activeRunToken uint64
- nextRunToken uint64
- activeRunCancels map[uint64]context.CancelFunc
- activeRunByID map[string]uint64
- activeRunTokenIDs map[uint64]string
- activeRunStates map[uint64]*runState
- permissionAskMapMu sync.Mutex
- permissionAskLocks map[string]*permissionAskLockEntry
- askStore AskSessionStore
- askSequence uint64
+ rollbackBaselineMu sync.Mutex
+ // runRollbackBaselineByRunKey 记录一次 run 的权威回退基线 checkpoint_id,key=session_id+"\n"+run_id。
+ runRollbackBaselineByRunKey map[string]string
+ // runWorkspaceDriftByRunKey 标记 run 开始前工作区是否发生空闲期漂移,key=session_id+"\n"+run_id。
+ runWorkspaceDriftByRunKey map[string]bool
+ // lastRunFingerprintByWorkspaceKey 保存每个工作区上一次 run 结束时的指纹,用于跨会话漂移检测。
+ lastRunFingerprintByWorkspaceKey map[string]repository.WorkdirFingerprint
+ sessionMu sync.Mutex
+ sessionLocks map[string]*sessionLockEntry
+ runMu sync.Mutex
+ activeRunToken uint64
+ nextRunToken uint64
+ activeRunCancels map[uint64]context.CancelFunc
+ activeRunByID map[string]uint64
+ activeRunTokenIDs map[uint64]string
+ activeRunStates map[uint64]*runState
+ permissionAskMapMu sync.Mutex
+ permissionAskLocks map[string]*permissionAskLockEntry
+ askStore AskSessionStore
+ askSequence uint64
thinkingEnabled bool
@@ -261,24 +268,27 @@ func NewWithFactory(
}
service := &Service{
- configManager: configManager,
- sessionStore: sessionStore,
- toolManager: toolManager,
- providerFactory: providerFactory,
- contextBuilder: contextBuilder,
- repositoryService: repository.NewService(),
- approvalBroker: approval.NewBroker(),
- askUserBroker: askuser.NewBroker(),
- events: make(chan RuntimeEvent, 128),
- runtimeSnapshots: make(map[string]RuntimeSnapshot),
- sessionLocks: make(map[string]*sessionLockEntry),
- permissionAskLocks: make(map[string]*permissionAskLockEntry),
- activeRunCancels: make(map[uint64]context.CancelFunc),
- activeRunByID: make(map[string]uint64),
- activeRunTokenIDs: make(map[uint64]string),
- activeRunStates: make(map[uint64]*runState),
- askStore: newInMemoryAskSessionStore(askSessionTTL),
- thinkingEnabled: true,
+ configManager: configManager,
+ sessionStore: sessionStore,
+ toolManager: toolManager,
+ providerFactory: providerFactory,
+ contextBuilder: contextBuilder,
+ repositoryService: repository.NewService(),
+ approvalBroker: approval.NewBroker(),
+ askUserBroker: askuser.NewBroker(),
+ events: make(chan RuntimeEvent, 128),
+ runtimeSnapshots: make(map[string]RuntimeSnapshot),
+ runRollbackBaselineByRunKey: make(map[string]string),
+ runWorkspaceDriftByRunKey: make(map[string]bool),
+ lastRunFingerprintByWorkspaceKey: make(map[string]repository.WorkdirFingerprint),
+ sessionLocks: make(map[string]*sessionLockEntry),
+ permissionAskLocks: make(map[string]*permissionAskLockEntry),
+ activeRunCancels: make(map[uint64]context.CancelFunc),
+ activeRunByID: make(map[string]uint64),
+ activeRunTokenIDs: make(map[uint64]string),
+ activeRunStates: make(map[uint64]*runState),
+ askStore: newInMemoryAskSessionStore(askSessionTTL),
+ thinkingEnabled: true,
}
baseHookExecutor := runtimehooks.NewExecutor(runtimehooks.NewRegistry(), newHookRuntimeEventEmitter(service), runtimehooks.DefaultHookTimeout)
baseHookExecutor.SetAsyncResultSink(newHookAsyncResultSink(service))
diff --git a/internal/runtime/state.go b/internal/runtime/state.go
index 6606f552..2440bddf 100644
--- a/internal/runtime/state.go
+++ b/internal/runtime/state.go
@@ -17,6 +17,7 @@ type runState struct {
runID string
runToken uint64
session agentsession.Session
+ effectiveWorkdir string
compactCount int
reactiveCompactAttempts int
rememberedThisRun bool
diff --git a/internal/session/checkpoint_types.go b/internal/session/checkpoint_types.go
index d21ac79b..c0a4e62c 100644
--- a/internal/session/checkpoint_types.go
+++ b/internal/session/checkpoint_types.go
@@ -6,13 +6,14 @@ import "time"
type CheckpointReason string
const (
- CheckpointReasonPreWrite CheckpointReason = "pre_write"
- CheckpointReasonCompact CheckpointReason = "compact"
- CheckpointReasonPlanMode CheckpointReason = "plan_mode"
- CheckpointReasonManual CheckpointReason = "manual"
- CheckpointReasonGuard CheckpointReason = "pre_restore_guard"
- CheckpointReasonPreWriteDegraded CheckpointReason = "pre_write_degraded"
- CheckpointReasonEndOfTurn CheckpointReason = "end_of_turn"
+ CheckpointReasonPreWrite CheckpointReason = "pre_write"
+ CheckpointReasonCompact CheckpointReason = "compact"
+ CheckpointReasonPlanMode CheckpointReason = "plan_mode"
+ CheckpointReasonManual CheckpointReason = "manual"
+ CheckpointReasonGuard CheckpointReason = "pre_restore_guard"
+ CheckpointReasonPreRunDriftRebase CheckpointReason = "pre_run_drift_rebase"
+ CheckpointReasonPreWriteDegraded CheckpointReason = "pre_write_degraded"
+ CheckpointReasonEndOfTurn CheckpointReason = "end_of_turn"
)
// CheckpointStatus 描述 checkpoint 的生命周期状态。
diff --git a/internal/tui/services/runtime_contract.go b/internal/tui/services/runtime_contract.go
index 7c4d1cf8..8ce28ade 100644
--- a/internal/tui/services/runtime_contract.go
+++ b/internal/tui/services/runtime_contract.go
@@ -245,6 +245,8 @@ type CheckpointDiffResult struct {
PrevCommitHash string `json:"prev_commit_hash,omitempty"`
Files CheckpointDiffFiles `json:"files"`
Patch string `json:"patch,omitempty"`
+ WorkspaceDrifted bool `json:"workspace_drifted,omitempty"`
+ Warning string `json:"warning,omitempty"`
}
// WorkspaceRecord 描述工作区登记信息。
diff --git a/web/src/App.test.tsx b/web/src/App.test.tsx
new file mode 100644
index 00000000..cb65a5eb
--- /dev/null
+++ b/web/src/App.test.tsx
@@ -0,0 +1,57 @@
+import { describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import App from './App'
+
+const runtimeState = {
+ status: 'connected',
+ mode: 'browser',
+ error: '',
+ loadingMessage: '',
+ retry: vi.fn(),
+}
+
+vi.mock('./context/RuntimeProvider', () => ({
+ useRuntime: () => runtimeState,
+}))
+
+vi.mock('./pages/ChatPage', () => ({
+ default: () =>
chat-page
,
+}))
+
+vi.mock('./pages/ConnectPage', () => ({
+ default: () => connect-page
,
+}))
+
+describe('App routes by runtime status', () => {
+ it('renders loading screen', () => {
+ runtimeState.status = 'loading'
+ runtimeState.loadingMessage = 'booting'
+ render()
+ expect(screen.getByText('booting')).toBeInTheDocument()
+ })
+
+ it('renders connect page for browser needs_config', () => {
+ runtimeState.status = 'needs_config'
+ runtimeState.mode = 'browser'
+ render()
+ expect(screen.getByText('connect-page')).toBeInTheDocument()
+ })
+
+ it('renders electron error screen with retry', () => {
+ runtimeState.status = 'error'
+ runtimeState.mode = 'electron'
+ runtimeState.error = 'boom'
+ render()
+ expect(screen.getByText('Gateway connection failed')).toBeInTheDocument()
+ fireEvent.click(screen.getByRole('button', { name: 'Retry connection' }))
+ expect(runtimeState.retry).toHaveBeenCalled()
+ })
+
+ it('renders chat routes when connected', () => {
+ runtimeState.status = 'connected'
+ runtimeState.mode = 'browser'
+ render()
+ expect(screen.getByText('chat-page')).toBeInTheDocument()
+ })
+})
+
diff --git a/web/src/api/gateway.test.ts b/web/src/api/gateway.test.ts
new file mode 100644
index 00000000..336b3f40
--- /dev/null
+++ b/web/src/api/gateway.test.ts
@@ -0,0 +1,62 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest'
+import { GatewayAPI } from './gateway'
+import { Method } from './protocol'
+
+describe('GatewayAPI', () => {
+ const call = vi.fn()
+ const ws = { call } as any
+ let api: GatewayAPI
+
+ beforeEach(() => {
+ call.mockReset()
+ call.mockResolvedValue({ type: 'ack', payload: {} })
+ api = new GatewayAPI(ws)
+ })
+
+ it('maps authenticate and run methods', async () => {
+ await api.authenticate('tok')
+ await api.run({ input_text: 'hello' })
+
+ expect(call).toHaveBeenNthCalledWith(1, Method.Authenticate, { token: 'tok' })
+ expect(call).toHaveBeenNthCalledWith(2, Method.Run, { input_text: 'hello' })
+ })
+
+ it('maps optional session_id in listModels', async () => {
+ await api.listModels()
+ await api.listModels('s1')
+
+ expect(call).toHaveBeenNthCalledWith(1, Method.ListModels, undefined)
+ expect(call).toHaveBeenNthCalledWith(2, Method.ListModels, { session_id: 's1' })
+ })
+
+ it('maps optional provider_id in setSessionModel', async () => {
+ await api.setSessionModel('s1', 'm1')
+ await api.setSessionModel('s1', 'm1', 'p1')
+
+ expect(call).toHaveBeenNthCalledWith(1, Method.SetSessionModel, { session_id: 's1', model_id: 'm1' })
+ expect(call).toHaveBeenNthCalledWith(2, Method.SetSessionModel, { session_id: 's1', model_id: 'm1', provider_id: 'p1' })
+ })
+
+ it('maps workspace methods and optional remove_data', async () => {
+ await api.listWorkspaces()
+ await api.createWorkspace('/tmp/a', 'A')
+ await api.switchWorkspace('h1')
+ await api.renameWorkspace('h1', 'B')
+ await api.deleteWorkspace('h1', true)
+
+ expect(call).toHaveBeenNthCalledWith(1, Method.ListWorkspaces)
+ expect(call).toHaveBeenNthCalledWith(2, Method.CreateWorkspace, { path: '/tmp/a', name: 'A' })
+ expect(call).toHaveBeenNthCalledWith(3, Method.SwitchWorkspace, { workspace_hash: 'h1' })
+ expect(call).toHaveBeenNthCalledWith(4, Method.RenameWorkspace, { workspace_hash: 'h1', name: 'B' })
+ expect(call).toHaveBeenNthCalledWith(5, Method.DeleteWorkspace, { workspace_hash: 'h1', remove_data: true })
+ })
+
+ it('maps permission and user question resolution', async () => {
+ await api.resolvePermission({ request_id: 'r1', decision: 'allow_once' })
+ await api.resolveUserQuestion({ request_id: 'q1', status: 'answered', message: 'ok' })
+
+ expect(call).toHaveBeenNthCalledWith(1, Method.ResolvePermission, { request_id: 'r1', decision: 'allow_once' })
+ expect(call).toHaveBeenNthCalledWith(2, Method.UserQuestionAnswer, { request_id: 'q1', status: 'answered', message: 'ok' })
+ })
+})
+
diff --git a/web/src/api/protocol.ts b/web/src/api/protocol.ts
index fbe1011c..249b0dc1 100644
--- a/web/src/api/protocol.ts
+++ b/web/src/api/protocol.ts
@@ -81,6 +81,7 @@ export const EventType = {
ToolStart: 'tool_start',
ToolResult: 'tool_result',
ToolDiff: 'tool_diff',
+ BashSideEffect: 'bash_side_effect',
ToolChunk: 'tool_chunk',
ToolCallThinking: 'tool_call_thinking',
ThinkingDelta: 'thinking_delta',
@@ -484,6 +485,8 @@ export interface CheckpointDiffResultPayload {
prev_commit_hash?: string
files: FileDiffs
patch?: string
+ workspace_drifted?: boolean
+ warning?: string
}
export interface CheckpointRestoreResultPayload {
@@ -916,6 +919,7 @@ export interface ToolDiffFileEntry {
path: string
diff?: string
was_new?: boolean
+ kind?: string
}
/** tool_diff 事件载荷:写工具修改了哪些文件 */
@@ -928,3 +932,12 @@ export interface ToolDiffPayload {
files?: ToolDiffFileChange[]
diffs?: ToolDiffFileEntry[]
}
+
+/** bash_side_effect 文件变更条目 */
+export interface BashSideEffectPayload {
+ tool_call_id: string
+ command?: string
+ changes?: ToolDiffFileChange[]
+ preemptively_captured_paths?: string[]
+ uncovered_paths?: string[]
+}
diff --git a/web/src/api/wsClient.test.ts b/web/src/api/wsClient.test.ts
new file mode 100644
index 00000000..f107786c
--- /dev/null
+++ b/web/src/api/wsClient.test.ts
@@ -0,0 +1,164 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import { createWSClient } from './wsClient'
+import { JSONRPC_VERSION, Method } from './protocol'
+
+class MockWebSocket {
+ static CONNECTING = 0
+ static OPEN = 1
+ static CLOSING = 2
+ static CLOSED = 3
+
+ static instances: MockWebSocket[] = []
+
+ url: string
+ readyState = MockWebSocket.CONNECTING
+ onopen: (() => void) | null = null
+ onclose: (() => void) | null = null
+ onerror: (() => void) | null = null
+ onmessage: ((event: { data: string }) => void) | null = null
+ sent: string[] = []
+
+ constructor(url: string) {
+ this.url = url
+ MockWebSocket.instances.push(this)
+ }
+
+ send(data: string) {
+ this.sent.push(data)
+ }
+
+ close() {
+ this.readyState = MockWebSocket.CLOSED
+ this.onclose?.()
+ }
+
+ open() {
+ this.readyState = MockWebSocket.OPEN
+ this.onopen?.()
+ }
+
+ emit(data: unknown) {
+ this.onmessage?.({ data: typeof data === 'string' ? data : JSON.stringify(data) })
+ }
+}
+
+function latestWS(): MockWebSocket {
+ const ws = MockWebSocket.instances.at(-1)
+ if (!ws) throw new Error('websocket not created')
+ return ws
+}
+
+describe('createWSClient', () => {
+ beforeEach(() => {
+ vi.useFakeTimers()
+ MockWebSocket.instances = []
+ vi.stubGlobal('WebSocket', MockWebSocket as any)
+ })
+
+ afterEach(() => {
+ vi.clearAllTimers()
+ vi.useRealTimers()
+ vi.unstubAllGlobals()
+ })
+
+ it('builds ws url with endpoint and token', () => {
+ const client = createWSClient({
+ baseURL: 'http://127.0.0.1:8080/',
+ endpoint: 'gateway/ws',
+ token: 'a b',
+ })
+ client.connect()
+ expect(latestWS().url).toBe('ws://127.0.0.1:8080/gateway/ws?token=a%20b')
+ client.disconnect()
+ })
+
+ it('sends rpc request and resolves rpc response', async () => {
+ const client = createWSClient({ baseURL: 'http://127.0.0.1:8080', rpcTimeout: 50 })
+ client.connect()
+ const ws = latestWS()
+ ws.open()
+
+ const promise = client.call('gateway.ping', { x: 1 })
+ const req = JSON.parse(ws.sent[0])
+ expect(req).toMatchObject({
+ jsonrpc: JSONRPC_VERSION,
+ method: 'gateway.ping',
+ params: { x: 1 },
+ })
+
+ ws.emit({ jsonrpc: JSONRPC_VERSION, id: req.id, result: { ok: true } })
+ await expect(promise).resolves.toEqual({ ok: true })
+ client.disconnect()
+ })
+
+ it('rejects rpc call on timeout', async () => {
+ const client = createWSClient({ baseURL: 'http://127.0.0.1:8080', rpcTimeout: 20 })
+ client.connect()
+ latestWS().open()
+ const promise = client.call('x.y')
+ promise.catch(() => {})
+ await vi.advanceTimersByTimeAsync(21)
+ await expect(promise).rejects.toThrow('RPC timeout: x.y')
+ client.disconnect()
+ })
+
+ it('dispatches gateway.event notifications to event handlers', () => {
+ const client = createWSClient({ baseURL: 'http://127.0.0.1:8080' })
+ const handler = vi.fn()
+ client.onEvent(handler)
+
+ client.connect()
+ const ws = latestWS()
+ ws.open()
+ ws.emit({
+ jsonrpc: JSONRPC_VERSION,
+ method: Method.Event,
+ params: { type: 'event', payload: { a: 1 } },
+ })
+ expect(handler).toHaveBeenCalledWith({ type: 'event', payload: { a: 1 } })
+ client.disconnect()
+ })
+
+ it('transitions to error then reconnects and fires onReconnect', async () => {
+ const client = createWSClient({
+ baseURL: 'http://127.0.0.1:8080',
+ reconnectBaseInterval: 10,
+ reconnectMaxInterval: 10,
+ })
+ const states: string[] = []
+ const onReconnect = vi.fn()
+ client.onStateChange((s) => states.push(s))
+ client.onReconnect(onReconnect)
+
+ client.connect()
+ const ws1 = latestWS()
+ ws1.open()
+ ws1.close()
+
+ expect(states).toContain('error')
+ await vi.advanceTimersByTimeAsync(20)
+ const ws2 = latestWS()
+ expect(ws2).not.toBe(ws1)
+ ws2.open()
+ expect(onReconnect).toHaveBeenCalledTimes(1)
+ expect(client.getState()).toBe('connected')
+ client.disconnect()
+ })
+
+ it('closes stale connection by heartbeat timeout', async () => {
+ const client = createWSClient({
+ baseURL: 'http://127.0.0.1:8080',
+ heartbeatTimeout: 10,
+ heartbeatCheckInterval: 5,
+ reconnectBaseInterval: 50,
+ reconnectMaxInterval: 50,
+ })
+ client.connect()
+ const ws = latestWS()
+ const closeSpy = vi.spyOn(ws, 'close')
+ ws.open()
+ await vi.advanceTimersByTimeAsync(20)
+ expect(closeSpy).toHaveBeenCalled()
+ client.disconnect()
+ })
+})
diff --git a/web/src/components/ErrorBoundary.test.tsx b/web/src/components/ErrorBoundary.test.tsx
new file mode 100644
index 00000000..b4d99dbf
--- /dev/null
+++ b/web/src/components/ErrorBoundary.test.tsx
@@ -0,0 +1,34 @@
+import { describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import type { ReactElement } from 'react'
+import { ErrorBoundary } from './ErrorBoundary'
+
+function Crash(): ReactElement {
+ throw new Error('boom')
+}
+
+describe('ErrorBoundary', () => {
+ it('renders fallback UI and supports retry', () => {
+ const errSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
+ render(
+
+
+ ,
+ )
+ expect(screen.getByText('The application encountered an error')).toBeInTheDocument()
+ expect(screen.getByText('boom')).toBeInTheDocument()
+ fireEvent.click(screen.getByRole('button', { name: 'Retry' }))
+ errSpy.mockRestore()
+ })
+
+ it('uses custom fallback render prop', () => {
+ const errSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
+ render(
+ }>
+
+ ,
+ )
+ expect(screen.getByRole('button', { name: 'custom:boom' })).toBeInTheDocument()
+ errSpy.mockRestore()
+ })
+})
diff --git a/web/src/components/UpdateNotification.test.tsx b/web/src/components/UpdateNotification.test.tsx
new file mode 100644
index 00000000..37a5cf2a
--- /dev/null
+++ b/web/src/components/UpdateNotification.test.tsx
@@ -0,0 +1,55 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
+import { UpdateNotification } from './UpdateNotification'
+
+describe('UpdateNotification', () => {
+ beforeEach(() => {
+ Object.defineProperty(window, 'electronAPI', {
+ value: undefined,
+ configurable: true,
+ writable: true,
+ })
+ })
+
+ it('renders nothing when electron api is missing', () => {
+ const { container } = render()
+ expect(container.firstChild).toBeNull()
+ })
+
+ it('shows available/downloaded updates and handles actions', async () => {
+ let onAvailable: ((info: any) => void) | null = null
+ let onDownloaded: ((info: any) => void) | null = null
+ const quitAndInstall = vi.fn().mockResolvedValue(undefined)
+ Object.defineProperty(window, 'electronAPI', {
+ value: {
+ onUpdateAvailable: (cb: any) => {
+ onAvailable = cb
+ return vi.fn()
+ },
+ onUpdateDownloaded: (cb: any) => {
+ onDownloaded = cb
+ return vi.fn()
+ },
+ quitAndInstall,
+ },
+ configurable: true,
+ })
+ render()
+
+ act(() => {
+ onAvailable?.({ version: '1.2.3' })
+ })
+ await waitFor(() => {
+ expect(screen.getByText('A new version v1.2.3 is available, downloading...')).toBeInTheDocument()
+ })
+
+ act(() => {
+ onDownloaded?.({ version: '1.2.3' })
+ })
+ fireEvent.click(screen.getByRole('button', { name: 'Restart Now' }))
+ expect(quitAndInstall).toHaveBeenCalled()
+
+ fireEvent.click(screen.getByTitle('Dismiss'))
+ expect(screen.queryByText('NeoCode v1.2.3 is ready to install')).not.toBeInTheDocument()
+ })
+})
diff --git a/web/src/components/chat/AcceptanceMessage.test.tsx b/web/src/components/chat/AcceptanceMessage.test.tsx
new file mode 100644
index 00000000..b33e6e65
--- /dev/null
+++ b/web/src/components/chat/AcceptanceMessage.test.tsx
@@ -0,0 +1,31 @@
+import { describe, expect, it } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import AcceptanceMessage from './AcceptanceMessage'
+
+describe('AcceptanceMessage', () => {
+ it('renders accepted summary and expandable details', () => {
+ render(
+ ,
+ )
+ expect(screen.getByText('已接受')).toBeInTheDocument()
+ expect(screen.getByText('all good')).toBeInTheDocument()
+ fireEvent.click(screen.getByRole('button', { name: '展开详情' }))
+ expect(screen.getByText('detail')).toBeInTheDocument()
+ })
+})
+
diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx
index f9de008e..8c610e7b 100644
--- a/web/src/components/chat/ChatInput.test.tsx
+++ b/web/src/components/chat/ChatInput.test.tsx
@@ -1,25 +1,60 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
-import { fireEvent, render, screen } from '@testing-library/react'
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import ChatInput from './ChatInput'
import { useChatStore } from '@/stores/useChatStore'
import { useComposerStore } from '@/stores/useComposerStore'
import { useSessionStore } from '@/stores/useSessionStore'
+const mockGatewayAPI = {
+ listAvailableSkills: vi.fn(),
+ listModels: vi.fn(),
+ run: vi.fn(),
+ bindStream: vi.fn(),
+ cancel: vi.fn(),
+ compact: vi.fn(),
+ executeSystemTool: vi.fn(),
+ activateSessionSkill: vi.fn(),
+ deactivateSessionSkill: vi.fn(),
+}
+
vi.mock('@/context/RuntimeProvider', () => ({
- useGatewayAPI: () => null,
+ useGatewayAPI: () => mockGatewayAPI,
+}))
+
+vi.mock('./ModelSelector', () => ({
+ default: () => ,
}))
describe('ChatInput', () => {
beforeEach(() => {
+ vi.clearAllMocks()
+ mockGatewayAPI.listAvailableSkills.mockResolvedValue({
+ payload: {
+ skills: [
+ {
+ descriptor: { id: 'skill.demo', description: 'demo skill' },
+ active: true,
+ },
+ ],
+ },
+ })
+ mockGatewayAPI.listModels.mockResolvedValue({
+ payload: {
+ models: [],
+ selected_provider_id: '',
+ selected_model_id: '',
+ },
+ })
+
useComposerStore.setState({ composerText: '' })
- useSessionStore.setState({ currentSessionId: '' } as any)
+ useSessionStore.setState({ currentSessionId: '' } as never)
useChatStore.setState({
isGenerating: false,
messages: [],
permissionRequests: [],
agentMode: 'build',
permissionMode: 'default',
- } as any)
+ } as never)
})
it('shows the default/bypass selector in build mode', () => {
@@ -40,10 +75,63 @@ describe('ChatInput', () => {
expect(screen.queryByRole('button', { name: 'bypass' })).not.toBeInTheDocument()
})
+ it('opens slash suggestions for bare slash and loads skills immediately', async () => {
+ render()
+
+ const textarea = screen.getByRole('textbox')
+ fireEvent.change(textarea, { target: { value: '/' } })
+
+ await waitFor(() => {
+ expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument()
+ })
+ await waitFor(() => {
+ expect(mockGatewayAPI.listAvailableSkills).toHaveBeenCalledWith(undefined)
+ })
+ await waitFor(() => {
+ expect(screen.getByText('/skill.demo')).toBeInTheDocument()
+ })
+ })
+
+ it('keeps slash menu visible for fuzzy inputs like /w', async () => {
+ render()
+
+ const textarea = screen.getByRole('textbox')
+ fireEvent.change(textarea, { target: { value: '/w' } })
+
+ await waitFor(() => {
+ expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument()
+ })
+ expect(screen.getAllByText((_, el) => Boolean(el?.textContent?.includes('/help'))).length).toBeGreaterThan(0)
+ })
+
+ it('supports keyboard navigation, tab completion, and escape for slash menu', async () => {
+ render()
+
+ const textarea = screen.getByRole('textbox') as HTMLTextAreaElement
+ fireEvent.change(textarea, { target: { value: '/' } })
+
+ await waitFor(() => {
+ expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument()
+ })
+
+ fireEvent.keyDown(textarea, { key: 'ArrowDown' })
+ fireEvent.keyDown(textarea, { key: 'Tab' })
+ expect(textarea.value).toBe('/compact ')
+
+ fireEvent.change(textarea, { target: { value: '/' } })
+ await waitFor(() => {
+ expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument()
+ })
+ fireEvent.keyDown(textarea, { key: 'Escape' })
+ await waitFor(() => {
+ expect(screen.queryByTestId('slash-command-menu')).not.toBeInTheDocument()
+ })
+ })
+
it('does not render the unimplemented attachment and mention buttons', () => {
render()
- expect(screen.queryByTitle('附加文件')).not.toBeInTheDocument()
+ expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument()
expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument()
})
})
diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx
index f35c6c5d..ff4258e4 100644
--- a/web/src/components/chat/ChatInput.tsx
+++ b/web/src/components/chat/ChatInput.tsx
@@ -1,4 +1,4 @@
-import { useState, useRef, useEffect, useCallback } from 'react'
+import { useState, useRef, useEffect, useCallback, useMemo } from 'react'
import { useChatStore, createUserMessage } from '@/stores/useChatStore'
import { useGatewayStore } from '@/stores/useGatewayStore'
import { useSessionStore, isValidSessionId } from '@/stores/useSessionStore'
@@ -21,52 +21,97 @@ import SkillPicker from './SkillPicker'
import ModelSelector from './ModelSelector'
import { Send, Square } from 'lucide-react'
+const slashMenuAnchorStyle: React.CSSProperties = {
+ position: 'absolute',
+ left: 0,
+ bottom: 'calc(100% + 8px)',
+ zIndex: 100,
+}
+
+/** 将网关返回的技能列表转换成输入框使用的 slash 命令结构。 */
+function buildSkillSlashCommands(
+ skills: Array<{ descriptor: { id: string; description?: string }; active?: boolean }>,
+): SkillSlashCommand[] {
+ return skills.map((skill) => ({
+ id: `skill-${skill.descriptor.id}`,
+ usage: `/${skill.descriptor.id}`,
+ description: skill.descriptor.description || '技能',
+ hasArgument: false,
+ isSkill: true,
+ skillId: skill.descriptor.id,
+ active: Boolean(skill.active),
+ }))
+}
+
+/** 用当前命令定义生成帮助文本,避免菜单与帮助内容漂移。 */
+function buildSlashHelpText(commands: AnySlashCommand[]): string {
+ const helpCommands = commands.length > 0 ? commands : builtinSlashCommands
+ const maxLen = helpCommands.reduce((max, command) => Math.max(max, command.usage.length), 0)
+ const lines = helpCommands.map((command) => {
+ const pad = ' '.repeat(maxLen - command.usage.length)
+ const description = isSkillCommand(command) && command.active
+ ? `${command.description} (已激活)`
+ : command.description
+ return ` ${command.usage}${pad} - ${description}`
+ })
+ return ['可用命令:', ...lines].join('\n')
+}
+
export default function ChatInput() {
const gatewayAPI = useGatewayAPI()
- const text = useComposerStore((s) => s.composerText)
- const setText = useComposerStore((s) => s.setComposerText)
+ const text = useComposerStore((state) => state.composerText)
+ const setText = useComposerStore((state) => state.setComposerText)
const [rows, setRows] = useState(1)
const textareaRef = useRef(null)
const runCancelledRef = useRef(false)
const composingRef = useRef(false)
- const isGenerating = useChatStore((s) => s.isGenerating)
- const addMessage = useChatStore((s) => s.addMessage)
- const addSystemMessage = useChatStore((s) => s.addSystemMessage)
- const setGenerating = useChatStore((s) => s.setGenerating)
- const sessionId = useSessionStore((s) => s.currentSessionId)
- const agentMode = useChatStore((s) => s.agentMode)
- const setAgentMode = useChatStore((s) => s.setAgentMode)
- const permissionMode = useChatStore((s) => s.permissionMode)
- const setPermissionMode = useChatStore((s) => s.setPermissionMode)
+ const isGenerating = useChatStore((state) => state.isGenerating)
+ const addMessage = useChatStore((state) => state.addMessage)
+ const addSystemMessage = useChatStore((state) => state.addSystemMessage)
+ const setGenerating = useChatStore((state) => state.setGenerating)
+ const sessionId = useSessionStore((state) => state.currentSessionId)
+ const agentMode = useChatStore((state) => state.agentMode)
+ const setAgentMode = useChatStore((state) => state.setAgentMode)
+ const permissionMode = useChatStore((state) => state.permissionMode)
+ const setPermissionMode = useChatStore((state) => state.setPermissionMode)
const [showSlashMenu, setShowSlashMenu] = useState(false)
const [selectedIndex, setSelectedIndex] = useState(0)
const [matchedCommands, setMatchedCommands] = useState([])
const [availableSkillCommands, setAvailableSkillCommands] = useState([])
const [showSkillPicker, setShowSkillPicker] = useState(false)
+ const allSlashCommands = useMemo(
+ () => [...builtinSlashCommands, ...availableSkillCommands],
+ [availableSkillCommands],
+ )
useEffect(() => {
- if (!showSlashMenu || !gatewayAPI) return
+ if (!gatewayAPI || !text.trimLeft().startsWith('/')) return
+ let cancelled = false
gatewayAPI.listAvailableSkills(sessionId || undefined).then((result) => {
+ if (cancelled) return
const skills = result.payload?.skills || []
- const skillCommands: SkillSlashCommand[] = skills.map((s) => ({
- id: `skill-${s.descriptor.id}`,
- usage: `/${s.descriptor.id}`,
- description: s.descriptor.description || '技能',
- hasArgument: false, isSkill: true,
- skillId: s.descriptor.id, active: s.active,
- }))
- setAvailableSkillCommands(skillCommands)
- }).catch(() => { setAvailableSkillCommands([]) })
- }, [showSlashMenu, gatewayAPI, sessionId])
+ setAvailableSkillCommands(buildSkillSlashCommands(skills))
+ }).catch(() => {
+ if (!cancelled) setAvailableSkillCommands([])
+ })
+ return () => {
+ cancelled = true
+ }
+ }, [text, gatewayAPI, sessionId])
useEffect(() => {
- if (!isSlashCommand(text)) { setShowSlashMenu(false); return }
- const allCommands: AnySlashCommand[] = [...builtinSlashCommands, ...availableSkillCommands]
- const matched = matchSlashCommands(text, allCommands)
- if (matched.length > 0) { setMatchedCommands(matched); setShowSlashMenu(true); setSelectedIndex(0) }
- else { setShowSlashMenu(false) }
- }, [text, availableSkillCommands])
+ if (!isSlashCommand(text)) {
+ setMatchedCommands([])
+ setShowSlashMenu(false)
+ return
+ }
+
+ const matched = matchSlashCommands(text, allSlashCommands)
+ setMatchedCommands(matched)
+ setShowSlashMenu(matched.length > 0)
+ if (matched.length > 0) setSelectedIndex(0)
+ }, [text, allSlashCommands])
useEffect(() => {
const lines = text.split('\n').length
@@ -76,137 +121,249 @@ export default function ChatInput() {
const executeSlashCommand = useCallback(async (input: string): Promise => {
const parsed = parseSlashCommand(input)
if (!parsed) return false
+
const { command, argument } = parsed
const currentSessionId = sessionId
const api = gatewayAPI
- if (!api) { useUIStore.getState().showToast('Gateway not connected', 'error'); return true }
+ if (!api) {
+ useUIStore.getState().showToast('Gateway not connected', 'error')
+ return true
+ }
switch (command) {
case '/help': {
- const allUsages = [...builtinSlashCommands.map((c) => c.usage), '/']
- const maxLen = Math.max(...allUsages.map((u) => u.length))
- const helpLines = [
- '可用命令:',
- ...builtinSlashCommands.map((cmd) => {
- const pad = ' '.repeat(maxLen - cmd.usage.length)
- return ` ${cmd.usage}${pad} — ${cmd.description}`
- }),
- ` /${''.padEnd(maxLen - 1)} — 激活/停用技能`,
- ]
- addSystemMessage(helpLines.join('\n'))
+ addSystemMessage(buildSlashHelpText(allSlashCommands))
return true
}
case '/compact': {
- if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true }
- try { await api.compact(currentSessionId, '') } catch (err) { console.error('Compact failed:', err); useUIStore.getState().showToast('Compaction failed', 'error') }
+ if (!isValidSessionId(currentSessionId)) {
+ useUIStore.getState().showToast('Send a message first to start a session', 'error')
+ return true
+ }
+ try {
+ await api.compact(currentSessionId, '')
+ } catch (err) {
+ console.error('Compact failed:', err)
+ useUIStore.getState().showToast('Compaction failed', 'error')
+ }
return true
}
case '/memo': {
- if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true }
+ if (!isValidSessionId(currentSessionId)) {
+ useUIStore.getState().showToast('Send a message first to start a session', 'error')
+ return true
+ }
try {
const result = await api.executeSystemTool(currentSessionId, '', 'memo_list', {})
- addSystemMessage((result as any)?.payload?.content || 'Memo query complete')
- } catch (err) { console.error('Memo list failed:', err); useUIStore.getState().showToast('Failed to query memo', 'error') }
+ addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo query complete')
+ } catch (err) {
+ console.error('Memo list failed:', err)
+ useUIStore.getState().showToast('Failed to query memo', 'error')
+ }
return true
}
case '/remember': {
- if (!argument) { useUIStore.getState().showToast('Usage: /remember ', 'error'); return true }
- if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true }
+ if (!argument) {
+ useUIStore.getState().showToast('Usage: /remember ', 'error')
+ return true
+ }
+ if (!isValidSessionId(currentSessionId)) {
+ useUIStore.getState().showToast('Send a message first to start a session', 'error')
+ return true
+ }
try {
- const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', { type: 'user', title: argument, content: argument })
- addSystemMessage((result as any)?.payload?.content || 'Memo saved')
- } catch (err) { console.error('Remember failed:', err); useUIStore.getState().showToast('Failed to save memo', 'error') }
+ const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', {
+ type: 'user',
+ title: argument,
+ content: argument,
+ })
+ addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo saved')
+ } catch (err) {
+ console.error('Remember failed:', err)
+ useUIStore.getState().showToast('Failed to save memo', 'error')
+ }
return true
}
case '/forget': {
- if (!argument) { useUIStore.getState().showToast('Usage: /forget ', 'error'); return true }
- if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true }
+ if (!argument) {
+ useUIStore.getState().showToast('Usage: /forget ', 'error')
+ return true
+ }
+ if (!isValidSessionId(currentSessionId)) {
+ useUIStore.getState().showToast('Send a message first to start a session', 'error')
+ return true
+ }
try {
- const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', { keyword: argument, scope: 'all' })
- addSystemMessage((result as any)?.payload?.content || 'Memo deleted')
- } catch (err) { console.error('Forget failed:', err); useUIStore.getState().showToast('Failed to delete memo', 'error') }
+ const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', {
+ keyword: argument,
+ scope: 'all',
+ })
+ addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo deleted')
+ } catch (err) {
+ console.error('Forget failed:', err)
+ useUIStore.getState().showToast('Failed to delete memo', 'error')
+ }
+ return true
+ }
+ case '/skills': {
+ setShowSkillPicker(true)
return true
}
- case '/skills': { setShowSkillPicker(true); return true }
default: {
- if (isGenerating) { useUIStore.getState().showToast('Cannot toggle skill while generating', 'info'); return true }
- const skillCmd = availableSkillCommands.find((s) => s.usage === command)
- if (skillCmd && isValidSessionId(currentSessionId)) {
+ if (isGenerating) {
+ useUIStore.getState().showToast('Cannot toggle skill while generating', 'info')
+ return true
+ }
+ const skillCommand = availableSkillCommands.find((skill) => skill.usage === command)
+ if (skillCommand && isValidSessionId(currentSessionId)) {
try {
- if (skillCmd.active) await api.deactivateSessionSkill(currentSessionId, skillCmd.skillId)
- else await api.activateSessionSkill(currentSessionId, skillCmd.skillId)
- setAvailableSkillCommands((prev) => prev.map((item) => item.skillId === skillCmd.skillId ? { ...item, active: !item.active } : item))
- } catch (err) { console.error('Skill toggle failed:', err); useUIStore.getState().showToast('Skill operation failed', 'error') }
+ if (skillCommand.active) {
+ await api.deactivateSessionSkill(currentSessionId, skillCommand.skillId)
+ } else {
+ await api.activateSessionSkill(currentSessionId, skillCommand.skillId)
+ }
+ setAvailableSkillCommands((prev) => prev.map((item) => (
+ item.skillId === skillCommand.skillId
+ ? { ...item, active: !item.active }
+ : item
+ )))
+ } catch (err) {
+ console.error('Skill toggle failed:', err)
+ useUIStore.getState().showToast('Skill operation failed', 'error')
+ }
return true
}
return false
}
}
- }, [gatewayAPI, sessionId, addSystemMessage, availableSkillCommands, isGenerating])
+ }, [gatewayAPI, sessionId, addSystemMessage, availableSkillCommands, isGenerating, allSlashCommands])
async function handleSubmit() {
const input = text.trim()
if (!input) return
+
if (isGenerating) {
if (isSlashCommand(input)) useUIStore.getState().showToast('Cannot run commands while generating', 'info')
return
}
+
if (isSlashCommand(input)) {
- setText(''); setShowSlashMenu(false)
+ setText('')
+ setShowSlashMenu(false)
const handled = await executeSlashCommand(input)
if (handled) return
}
+
setText('')
const userMsg = createUserMessage(input)
addMessage(userMsg)
useRuntimeInsightStore.getState().setTodoSnapshot({
- items: [], summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 },
+ items: [],
+ summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 },
})
setGenerating(true)
runCancelledRef.current = false
+
try {
if (!gatewayAPI) return
const isNewSession = !isValidSessionId(sessionId)
- const ack = await gatewayAPI.run({ session_id: isNewSession ? undefined : sessionId, new_session: isNewSession ? true : undefined, input_text: input, mode: agentMode })
+ const ack = await gatewayAPI.run({
+ session_id: isNewSession ? undefined : sessionId,
+ new_session: isNewSession ? true : undefined,
+ input_text: input,
+ mode: agentMode,
+ })
if (!runCancelledRef.current) {
- const gwStore = useGatewayStore.getState()
- const sessStore = useSessionStore.getState()
- if (ack.run_id) gwStore.setCurrentRunId(ack.run_id)
- if (ack.session_id) { sessStore.setCurrentSessionId(ack.session_id); gatewayAPI?.bindStream({ session_id: ack.session_id, channel: 'all' }).catch(() => {}) }
+ const gatewayStore = useGatewayStore.getState()
+ const sessionStore = useSessionStore.getState()
+ if (ack.run_id) gatewayStore.setCurrentRunId(ack.run_id)
+ if (ack.session_id) {
+ sessionStore.setCurrentSessionId(ack.session_id)
+ gatewayAPI.bindStream({ session_id: ack.session_id, channel: 'all' }).catch(() => {})
+ }
}
} catch (err) {
if (!runCancelledRef.current) {
- setGenerating(false); useChatStore.getState().removeMessage(userMsg.id)
- console.error('Run failed:', err); useUIStore.getState().showToast('Failed to send message', 'error')
+ setGenerating(false)
+ useChatStore.getState().removeMessage(userMsg.id)
+ console.error('Run failed:', err)
+ useUIStore.getState().showToast('Failed to send message', 'error')
}
}
}
function handleKeyDown(e: React.KeyboardEvent) {
if (composingRef.current) return
+
if (!showSlashMenu) {
- if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); handleSubmit() }
+ if (e.key === 'Enter' && !e.shiftKey) {
+ e.preventDefault()
+ handleSubmit()
+ }
return
}
+
switch (e.key) {
- case 'ArrowDown': e.preventDefault(); setSelectedIndex((prev) => (prev + 1) % matchedCommands.length); return
- case 'ArrowUp': e.preventDefault(); setSelectedIndex((prev) => (prev - 1 + matchedCommands.length) % matchedCommands.length); return
- case 'Enter': e.preventDefault(); const cmd = matchedCommands[selectedIndex]; if (cmd) handleSelectCommand(cmd); return
- case 'Escape': e.preventDefault(); setShowSlashMenu(false); return
- case 'Tab': e.preventDefault(); const c = matchedCommands[selectedIndex]; if (c) { setText(c.usage + ' '); textareaRef.current?.focus() }; return
+ case 'ArrowDown':
+ e.preventDefault()
+ setSelectedIndex((prev) => (prev + 1) % matchedCommands.length)
+ return
+ case 'ArrowUp':
+ e.preventDefault()
+ setSelectedIndex((prev) => (prev - 1 + matchedCommands.length) % matchedCommands.length)
+ return
+ case 'Enter': {
+ e.preventDefault()
+ const command = matchedCommands[selectedIndex]
+ if (command) handleSelectCommand(command)
+ return
+ }
+ case 'Escape':
+ e.preventDefault()
+ setShowSlashMenu(false)
+ return
+ case 'Tab': {
+ e.preventDefault()
+ const command = matchedCommands[selectedIndex]
+ if (command) {
+ setText(`${command.usage} `)
+ textareaRef.current?.focus()
+ }
+ return
+ }
}
}
function handleSelectCommand(cmd: AnySlashCommand) {
- if (isSkillCommand(cmd)) { setText(cmd.usage); setShowSlashMenu(false); executeSlashCommand(cmd.usage); return }
- if (cmd.hasArgument) { setText(cmd.usage + ' '); setShowSlashMenu(false); textareaRef.current?.focus() }
- else { setText(''); setShowSlashMenu(false); executeSlashCommand(cmd.usage) }
+ if (isSkillCommand(cmd)) {
+ setText(cmd.usage)
+ setShowSlashMenu(false)
+ void executeSlashCommand(cmd.usage)
+ return
+ }
+
+ if (cmd.hasArgument) {
+ setText(`${cmd.usage} `)
+ setShowSlashMenu(false)
+ textareaRef.current?.focus()
+ return
+ }
+
+ setText('')
+ setShowSlashMenu(false)
+ void executeSlashCommand(cmd.usage)
}
async function handleCancel() {
runCancelledRef.current = true
const runId = useGatewayStore.getState().currentRunId
- if (runId && gatewayAPI) { try { await gatewayAPI.cancel({ run_id: runId }) } catch (err) { console.error('Cancel failed:', err) } }
+ if (runId && gatewayAPI) {
+ try {
+ await gatewayAPI.cancel({ run_id: runId })
+ } catch (err) {
+ console.error('Cancel failed:', err)
+ }
+ }
useChatStore.getState().resetGeneratingState()
}
@@ -218,9 +375,9 @@ export default function ChatInput() {
setShowSkillPicker(false)} />
)}
- {showSlashMenu && matchedCommands.length > 0 && (
-
-
+
+ {showSlashMenu && matchedCommands.length > 0 && (
+
-
- )}
-
@@ -309,11 +466,11 @@ export default function ChatInput() {
}
function BudgetTokenStrip() {
- const budgetChecked = useRuntimeInsightStore((s) => s.budgetChecked)
- const budgetUsageRatio = useRuntimeInsightStore((s) => s.budgetUsageRatio)
- const budgetEstimateFailed = useRuntimeInsightStore((s) => s.budgetEstimateFailed)
- const ledgerReconciled = useRuntimeInsightStore((s) => s.ledgerReconciled)
- const tokenUsage = useChatStore((s) => s.tokenUsage)
+ const budgetChecked = useRuntimeInsightStore((state) => state.budgetChecked)
+ const budgetUsageRatio = useRuntimeInsightStore((state) => state.budgetUsageRatio)
+ const budgetEstimateFailed = useRuntimeInsightStore((state) => state.budgetEstimateFailed)
+ const ledgerReconciled = useRuntimeInsightStore((state) => state.ledgerReconciled)
+ const tokenUsage = useChatStore((state) => state.tokenUsage)
const [open, setOpen] = useState(false)
const ref = useRef
(null)
const [popoverStyle, setPopoverStyle] = useState({})
@@ -334,7 +491,7 @@ function BudgetTokenStrip() {
useEffect(() => {
if (!open) return
- /** 根据锚点动态计算弹层位置,避免被容器裁剪或超出视口。 */
+ /** 根据锚点位置动态计算弹层,避免被容器裁剪或超出视口。 */
function updatePopoverPosition() {
const anchor = ref.current
if (!anchor) return
@@ -370,17 +527,36 @@ function BudgetTokenStrip() {
if (!budgetChecked && !totalTokens) return null
return (
- setOpen(true)}
onMouseLeave={() => setOpen(false)}
>
-
+
@@ -410,7 +586,9 @@ function BudgetTokenStrip() {
)}
已用
- {formatTokenCount(budgetChecked.estimated_input_tokens)} ({pct}%)
+
+ {formatTokenCount(budgetChecked.estimated_input_tokens)} ({pct}%)
+
{totalTokens > 0 && (
diff --git a/web/src/components/chat/CheckpointInlineMark.test.tsx b/web/src/components/chat/CheckpointInlineMark.test.tsx
new file mode 100644
index 00000000..de4a5d26
--- /dev/null
+++ b/web/src/components/chat/CheckpointInlineMark.test.tsx
@@ -0,0 +1,39 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import CheckpointInlineMark from './CheckpointInlineMark'
+import { useSessionStore } from '@/stores/useSessionStore'
+import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore'
+
+let gatewayAPI: any = null
+vi.mock('@/context/RuntimeProvider', () => ({
+ useGatewayAPI: () => gatewayAPI,
+}))
+
+describe('CheckpointInlineMark', () => {
+ beforeEach(() => {
+ gatewayAPI = {
+ restoreCheckpoint: vi.fn().mockResolvedValue({ payload: {} }),
+ undoRestore: vi.fn().mockResolvedValue({ payload: {} }),
+ checkpointDiff: vi.fn().mockResolvedValue({ payload: { files: { added: [], modified: [], deleted: [] }, patch: '' } }),
+ }
+ useSessionStore.setState({ currentSessionId: 's1' } as any)
+ useRuntimeInsightStore.getState().reset()
+ vi.spyOn(window, 'confirm').mockReturnValue(true)
+ })
+
+ it('restores checkpoint from available state', async () => {
+ render(
)
+ fireEvent.click(screen.getByRole('button', { name: /cp_abcdef/i }))
+ await waitFor(() => expect(gatewayAPI.restoreCheckpoint).toHaveBeenCalledWith({
+ session_id: 's1',
+ checkpoint_id: 'abcdef123456',
+ }))
+ })
+
+ it('renders restored state and can undo restore', async () => {
+ render(
)
+ fireEvent.click(screen.getByRole('button', { name: /已撤回/ }))
+ await waitFor(() => expect(gatewayAPI.undoRestore).toHaveBeenCalledWith('s1'))
+ })
+})
+
diff --git a/web/src/components/chat/CodeBlock.test.tsx b/web/src/components/chat/CodeBlock.test.tsx
new file mode 100644
index 00000000..d8d10a69
--- /dev/null
+++ b/web/src/components/chat/CodeBlock.test.tsx
@@ -0,0 +1,27 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import CodeBlock from './CodeBlock'
+
+describe('CodeBlock', () => {
+ beforeEach(() => {
+ Object.assign(navigator, {
+ clipboard: { writeText: vi.fn() },
+ })
+ })
+
+ it('renders inline code and copies content', () => {
+ render(
)
+ const container = screen.getByText('const a = 1').closest('div') as HTMLElement
+ fireEvent.mouseEnter(container)
+ fireEvent.click(screen.getByTitle('复制'))
+ expect(navigator.clipboard.writeText).toHaveBeenCalledWith('const a = 1')
+ })
+
+ it('renders file code block with line numbers', () => {
+ render(
)
+ expect(screen.getByText('a.ts')).toBeInTheDocument()
+ expect(screen.getByText('line1')).toBeInTheDocument()
+ expect(screen.getByText('line2')).toBeInTheDocument()
+ })
+})
+
diff --git a/web/src/components/chat/MessageItem.test.tsx b/web/src/components/chat/MessageItem.test.tsx
new file mode 100644
index 00000000..cfe3db05
--- /dev/null
+++ b/web/src/components/chat/MessageItem.test.tsx
@@ -0,0 +1,49 @@
+import { describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import MessageItem from './MessageItem'
+
+vi.mock('./ToolCallCard', () => ({ default: () =>
tool-card
}))
+vi.mock('./VerificationMessage', () => ({ default: () =>
verification-card
}))
+vi.mock('./AcceptanceMessage', () => ({ default: () =>
acceptance-card
}))
+vi.mock('./CodeBlock', () => ({ default: ({ code }: { code: string }) =>
{code} }))
+vi.mock('./MarkdownContent', () => ({ default: ({ content }: { content: string }) =>
{content} }))
+vi.mock('@/context/RuntimeProvider', () => ({ useGatewayAPI: () => null }))
+
+describe('MessageItem', () => {
+ it('renders system message', () => {
+ render(
)
+ expect(screen.getByText('sys')).toBeInTheDocument()
+ })
+
+ it('renders welcome message', () => {
+ render(
)
+ expect(screen.getByText('hello')).toBeInTheDocument()
+ })
+
+ it('renders thinking message and toggles details', () => {
+ render(
+
,
+ )
+ fireEvent.click(screen.getByText('AI 思考过程'))
+ expect(screen.getByText('reasoning')).toBeInTheDocument()
+ })
+
+ it('renders tool/verification/acceptance delegates', () => {
+ const { rerender } = render(
)
+ expect(screen.getByText('tool-card')).toBeInTheDocument()
+ rerender(
)
+ expect(screen.getByText('verification-card')).toBeInTheDocument()
+ rerender(
)
+ expect(screen.getByText('acceptance-card')).toBeInTheDocument()
+ })
+
+ it('renders code and plain assistant messages', () => {
+ const { rerender } = render(
)
+ expect(screen.getByText('const a=1')).toBeInTheDocument()
+ rerender(
)
+ expect(screen.getByText('answer')).toBeInTheDocument()
+ })
+})
+
diff --git a/web/src/components/chat/MessageList.test.tsx b/web/src/components/chat/MessageList.test.tsx
new file mode 100644
index 00000000..a8af748b
--- /dev/null
+++ b/web/src/components/chat/MessageList.test.tsx
@@ -0,0 +1,36 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { render, screen } from '@testing-library/react'
+import MessageList from './MessageList'
+import { useChatStore } from '@/stores/useChatStore'
+
+vi.mock('./MessageItem', () => ({
+ default: ({ message, groupedWithPrev }: any) => (
+
{message.id}:{groupedWithPrev ? 'group' : 'solo'}
+ ),
+}))
+
+describe('MessageList', () => {
+ beforeEach(() => {
+ useChatStore.setState({ messages: [], isGenerating: false } as any)
+ })
+
+ it('renders empty state when no messages', () => {
+ render(
)
+ expect(screen.getByText('开始你的 AI 编程之旅')).toBeInTheDocument()
+ })
+
+ it('reorders process messages before assistant text within AI turn', () => {
+ useChatStore.setState({
+ messages: [
+ { id: 'u1', role: 'user', type: 'text', content: 'q', timestamp: 1 },
+ { id: 'a1', role: 'assistant', type: 'text', content: 'answer', timestamp: 2 },
+ { id: 't1', role: 'tool', type: 'tool_call', content: '', timestamp: 3 },
+ { id: 'a2', role: 'assistant', type: 'thinking', content: 'thinking', timestamp: 4 },
+ ],
+ } as any)
+
+ render(
)
+ const ids = screen.getAllByTestId(/msg-/).map((x) => x.textContent)
+ expect(ids).toEqual(['u1:solo', 't1:solo', 'a2:group', 'a1:group'])
+ })
+})
diff --git a/web/src/components/chat/SkillPicker.test.tsx b/web/src/components/chat/SkillPicker.test.tsx
new file mode 100644
index 00000000..2db59a17
--- /dev/null
+++ b/web/src/components/chat/SkillPicker.test.tsx
@@ -0,0 +1,89 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import SkillPicker from './SkillPicker'
+import { useChatStore } from '@/stores/useChatStore'
+import { useUIStore } from '@/stores/useUIStore'
+
+describe('SkillPicker', () => {
+ beforeEach(() => {
+ useChatStore.setState({ isGenerating: false } as any)
+ useUIStore.setState({ showToast: vi.fn() } as any)
+ })
+
+ it('renders empty state when no skills', async () => {
+ const api = {
+ listAvailableSkills: vi.fn().mockResolvedValue({ payload: { skills: [] } }),
+ } as any
+ render(
)
+ await screen.findByText('暂无可用技能')
+ })
+
+ it('toggles skill activation and reloads list', async () => {
+ const api = {
+ listAvailableSkills: vi
+ .fn()
+ .mockResolvedValueOnce({
+ payload: {
+ skills: [{
+ active: false,
+ descriptor: { id: 'sk1', name: 'Skill 1', description: 'desc', scope: 'explicit' },
+ }],
+ },
+ })
+ .mockResolvedValueOnce({
+ payload: {
+ skills: [{
+ active: true,
+ descriptor: { id: 'sk1', name: 'Skill 1', description: 'desc', scope: 'explicit' },
+ }],
+ },
+ }),
+ activateSessionSkill: vi.fn().mockResolvedValue({ payload: {} }),
+ deactivateSessionSkill: vi.fn().mockResolvedValue({ payload: {} }),
+ } as any
+
+ render(
)
+ const activateBtn = await screen.findByRole('button', { name: '激活' })
+ fireEvent.click(activateBtn)
+
+ await waitFor(() => expect(api.activateSessionSkill).toHaveBeenCalledWith('s1', 'sk1'))
+ expect(api.listAvailableSkills).toHaveBeenCalledTimes(2)
+ })
+
+ it('blocks toggle when session is missing', async () => {
+ const showToast = vi.fn()
+ useUIStore.setState({ showToast } as any)
+ const api = {
+ listAvailableSkills: vi.fn().mockResolvedValue({
+ payload: {
+ skills: [{ active: false, descriptor: { id: 'sk1', name: 'Skill 1' } }],
+ },
+ }),
+ activateSessionSkill: vi.fn(),
+ } as any
+
+ render(
)
+ const activateBtn = await screen.findByRole('button', { name: '激活' })
+ fireEvent.click(activateBtn)
+
+ expect(api.activateSessionSkill).not.toHaveBeenCalled()
+ expect(showToast).toHaveBeenCalledWith('Send a message first to start a session', 'error')
+ })
+
+ it('disables operation while generating', async () => {
+ useChatStore.setState({ isGenerating: true } as any)
+ const api = {
+ listAvailableSkills: vi.fn().mockResolvedValue({
+ payload: {
+ skills: [{ active: false, descriptor: { id: 'sk1', name: 'Skill 1' } }],
+ },
+ }),
+ activateSessionSkill: vi.fn(),
+ } as any
+
+ render(
)
+ const activateBtn = await screen.findByRole('button', { name: '激活' })
+ expect(activateBtn).toBeDisabled()
+ })
+})
+
diff --git a/web/src/components/chat/SlashCommandMenu.test.tsx b/web/src/components/chat/SlashCommandMenu.test.tsx
new file mode 100644
index 00000000..9f124a1b
--- /dev/null
+++ b/web/src/components/chat/SlashCommandMenu.test.tsx
@@ -0,0 +1,73 @@
+import { describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import SlashCommandMenu from './SlashCommandMenu'
+import type { AnySlashCommand } from '@/utils/slashCommands'
+
+const builtin: AnySlashCommand = {
+ id: 'compact',
+ usage: '/compact',
+ description: 'compress context',
+ hasArgument: false,
+}
+
+const skill: AnySlashCommand = {
+ id: 'skill.demo',
+ usage: '/skill.demo',
+ description: 'demo skill',
+ hasArgument: false,
+ isSkill: true,
+ skillId: 'skill.demo',
+ active: true,
+}
+
+describe('SlashCommandMenu', () => {
+ ;(HTMLElement.prototype as { scrollIntoView?: () => void }).scrollIntoView = vi.fn()
+
+ it('returns null when commands is empty', () => {
+ const { container } = render(
+
,
+ )
+ expect(container.firstChild).toBeNull()
+ })
+
+ it('renders builtin and skill sections without owning absolute positioning', () => {
+ render(
+
,
+ )
+
+ const menu = screen.getByTestId('slash-command-menu')
+ expect(menu).toBeInTheDocument()
+ expect(menu).not.toHaveStyle({ position: 'absolute' })
+ expect(screen.getByText('命令')).toBeInTheDocument()
+ expect(screen.getByText('技能')).toBeInTheDocument()
+ expect(screen.getByText('已激活')).toBeInTheDocument()
+ expect(screen.getAllByText((_, el) => Boolean(el?.textContent?.includes('/compact'))).length).toBeGreaterThan(0)
+ })
+
+ it('triggers hover and select callbacks', () => {
+ const onSelect = vi.fn()
+ const onHover = vi.fn()
+
+ render(
+
,
+ )
+
+ fireEvent.mouseEnter(screen.getByText('/compact'))
+ fireEvent.click(screen.getByText('/skill.demo'))
+
+ expect(onHover).toHaveBeenCalledWith(0)
+ expect(onSelect).toHaveBeenCalledWith(skill)
+ })
+})
diff --git a/web/src/components/chat/SlashCommandMenu.tsx b/web/src/components/chat/SlashCommandMenu.tsx
index 71869f08..d2a2c582 100644
--- a/web/src/components/chat/SlashCommandMenu.tsx
+++ b/web/src/components/chat/SlashCommandMenu.tsx
@@ -26,41 +26,50 @@ function getCommandIcon(cmd: AnySlashCommand): React.ReactNode {
}
function highlightMatch(text: string, query: string): React.ReactNode {
- if (!query || query === '/') return text
+ const normalizedQuery = query.trim().toLowerCase()
+ if (!normalizedQuery || normalizedQuery === '/') return text
+
const lowerText = text.toLowerCase()
- const lowerQuery = query.toLowerCase().trim()
- const idx = lowerText.indexOf(lowerQuery)
- if (idx === -1) return text
+ const matchIndex = lowerText.indexOf(normalizedQuery)
+ if (matchIndex < 0) return text
return (
<>
- {text.slice(0, idx)}
-
{text.slice(idx, idx + lowerQuery.length)}
- {text.slice(idx + lowerQuery.length)}
+ {text.slice(0, matchIndex)}
+
+ {text.slice(matchIndex, matchIndex + normalizedQuery.length)}
+
+ {text.slice(matchIndex + normalizedQuery.length)}
>
)
}
-/** Slash 命令浮动菜单 */
-export default function SlashCommandMenu({ commands, selectedIndex, onSelect, onHover, query }: SlashCommandMenuProps) {
+/** Slash 命令菜单只负责渲染内容,不再自行决定浮层定位。 */
+export default function SlashCommandMenu({
+ commands,
+ selectedIndex,
+ onSelect,
+ onHover,
+ query,
+}: SlashCommandMenuProps) {
useEffect(() => {
const el = document.querySelector(`[data-slash-index="${selectedIndex}"]`)
- if (el) {
+ if (el instanceof HTMLElement && typeof el.scrollIntoView === 'function') {
el.scrollIntoView({ block: 'nearest' })
}
}, [selectedIndex])
if (commands.length === 0) return null
- const builtinCmds = commands.filter(isBuiltinCommand)
- const skillCmds = commands.filter(isSkillCommand)
+ const builtinCommands = commands.filter(isBuiltinCommand)
+ const skillCommands = commands.filter(isSkillCommand)
return (
-
- {builtinCmds.length > 0 && (
+
+ {builtinCommands.length > 0 && (
命令
- {builtinCmds.map((cmd) => {
+ {builtinCommands.map((cmd) => {
const globalIndex = commands.indexOf(cmd)
return (
)}
- {skillCmds.length > 0 && (
+ {skillCommands.length > 0 && (
- {builtinCmds.length > 0 &&
}
+ {builtinCommands.length > 0 &&
}
技能
- {skillCmds.map((cmd) => {
+ {skillCommands.map((cmd) => {
const globalIndex = commands.indexOf(cmd)
return (
= {
container: {
- position: 'absolute',
- bottom: '100%',
- left: 0,
- marginBottom: 8,
minWidth: 280,
maxWidth: 360,
maxHeight: 320,
overflowY: 'auto',
- background: 'var(--bg-secondary)',
+ background: 'var(--bg-overlay)',
border: '1px solid var(--border-primary)',
borderRadius: 'var(--radius-lg)',
- boxShadow: '0 4px 24px rgba(0,0,0,0.15)',
- zIndex: 100,
+ boxShadow: 'var(--shadow-elevated)',
padding: '6px 0',
},
sectionLabel: {
diff --git a/web/src/components/chat/TodoStrip.test.tsx b/web/src/components/chat/TodoStrip.test.tsx
new file mode 100644
index 00000000..f5911467
--- /dev/null
+++ b/web/src/components/chat/TodoStrip.test.tsx
@@ -0,0 +1,49 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import TodoStrip from './TodoStrip'
+import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore'
+import { useUIStore } from '@/stores/useUIStore'
+
+describe('TodoStrip', () => {
+ beforeEach(() => {
+ useRuntimeInsightStore.getState().reset()
+ useUIStore.setState({
+ todoStripExpanded: false,
+ setTodoStripExpanded: vi.fn(),
+ } as any)
+ })
+
+ it('renders nothing when no snapshot and no conflict', () => {
+ const { container } = render()
+ expect(container.firstChild).toBeNull()
+ })
+
+ it('renders summary and items from snapshot', () => {
+ useRuntimeInsightStore.setState({
+ todoSnapshot: {
+ items: [
+ { id: '1', content: 'Task 1', status: 'in_progress', required: true, revision: 1 },
+ { id: '2', content: 'Task 2', status: 'completed', required: true, revision: 1 },
+ ],
+ summary: { total: 2, required_total: 2, required_completed: 1, required_failed: 0, required_open: 1 },
+ },
+ todoHistory: {
+ '1': { id: '1', content: 'Task 1', status: 'in_progress', required: true, revision: 1, firstSeenAt: 1, lastSeenAt: 1 },
+ '2': { id: '2', content: 'Task 2', status: 'completed', required: true, revision: 1, firstSeenAt: 1, lastSeenAt: 1 },
+ },
+ } as any)
+ render()
+ expect(screen.getByText('Task 1')).toBeInTheDocument()
+ fireEvent.click(screen.getByRole('button', { expanded: false }))
+ expect(useUIStore.getState().setTodoStripExpanded).toHaveBeenCalled()
+ })
+
+ it('forces expanded conflict state', () => {
+ useRuntimeInsightStore.setState({
+ todoConflict: { action: 'conflict', reason: 'manual check needed' },
+ } as any)
+ render()
+ expect(screen.getByText(/Todo 冲突/)).toBeInTheDocument()
+ })
+})
+
diff --git a/web/src/components/chat/ToolCallCard.test.tsx b/web/src/components/chat/ToolCallCard.test.tsx
new file mode 100644
index 00000000..cb79f8df
--- /dev/null
+++ b/web/src/components/chat/ToolCallCard.test.tsx
@@ -0,0 +1,58 @@
+import { describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import ToolCallCard from './ToolCallCard'
+
+vi.mock('./CheckpointInlineMark', () => ({
+ default: ({ checkpointId }: { checkpointId: string }) => cp:{checkpointId},
+}))
+
+describe('ToolCallCard', () => {
+ it('shows running state and expands/collapses', () => {
+ render(
+ ,
+ )
+ expect(screen.getByText('bash')).toBeInTheDocument()
+ expect(screen.getByText('$')).toBeInTheDocument()
+ fireEvent.click(screen.getByRole('button', { expanded: true }))
+ })
+
+ it('renders file edit diff detail', () => {
+ render(
+ ,
+ )
+ fireEvent.click(screen.getByRole('button', { expanded: false }))
+ expect(screen.getAllByText('a.ts').length).toBeGreaterThan(0)
+ expect(screen.getByText('old')).toBeInTheDocument()
+ expect(screen.getByText('new')).toBeInTheDocument()
+ expect(screen.getByText('cp:cp1')).toBeInTheDocument()
+ })
+})
diff --git a/web/src/components/chat/VerificationMessage.test.tsx b/web/src/components/chat/VerificationMessage.test.tsx
new file mode 100644
index 00000000..dadfb094
--- /dev/null
+++ b/web/src/components/chat/VerificationMessage.test.tsx
@@ -0,0 +1,32 @@
+import { describe, expect, it } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import VerificationMessage from './VerificationMessage'
+
+describe('VerificationMessage', () => {
+ it('renders running summary and stage details', () => {
+ render(
+ ,
+ )
+ expect(screen.getByText(/Verify running/)).toBeInTheDocument()
+ fireEvent.click(screen.getByRole('button'))
+ expect(screen.getByText('test')).toBeInTheDocument()
+ })
+})
+
diff --git a/web/src/components/layout/AppLayout.test.tsx b/web/src/components/layout/AppLayout.test.tsx
new file mode 100644
index 00000000..e98098b4
--- /dev/null
+++ b/web/src/components/layout/AppLayout.test.tsx
@@ -0,0 +1,60 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { fireEvent, render, screen } from '@testing-library/react'
+import AppLayout from './AppLayout'
+import { useUIStore } from '@/stores/useUIStore'
+import { useSessionStore } from '@/stores/useSessionStore'
+import { useWorkspaceStore } from '@/stores/useWorkspaceStore'
+
+vi.mock('./Sidebar', () => ({ default: ({ collapsed }: { collapsed?: boolean }) => {collapsed ? 'sidebar-collapsed' : 'sidebar-open'}
}))
+vi.mock('@/components/chat/ChatPanel', () => ({ default: () => chat-panel
}))
+vi.mock('@/components/panels/FileChangePanel', () => ({ default: () => changes-panel
}))
+vi.mock('@/components/panels/FileTreePanel', () => ({ default: () => tree-panel
}))
+vi.mock('@/components/status/StatusBar', () => ({ default: () => status-bar
}))
+vi.mock('@/components/ui/ToastContainer', () => ({ default: () => toast-container
}))
+
+describe('AppLayout', () => {
+ beforeEach(() => {
+ useUIStore.setState({
+ sidebarOpen: true,
+ sidebarWidth: 280,
+ setSidebarWidth: vi.fn(),
+ changesPanelOpen: false,
+ changesPanelWidth: 360,
+ setChangesPanelWidth: vi.fn(),
+ fileTreePanelOpen: false,
+ fileTreePanelWidth: 320,
+ setFileTreePanelWidth: vi.fn(),
+ } as any)
+ useSessionStore.setState({
+ prepareNewChat: vi.fn(),
+ } as any)
+ useWorkspaceStore.setState({ currentWorkspaceHash: '' } as any)
+ })
+
+ it('renders main layout with sidebar open', () => {
+ render()
+ expect(screen.getByText('sidebar-open')).toBeInTheDocument()
+ expect(screen.getByText('chat-panel')).toBeInTheDocument()
+ })
+
+ it('renders collapsed sidebar and right panels when toggled', () => {
+ useUIStore.setState({
+ sidebarOpen: false,
+ changesPanelOpen: true,
+ fileTreePanelOpen: true,
+ } as any)
+ render()
+ expect(screen.getByText('sidebar-collapsed')).toBeInTheDocument()
+ expect(screen.getByText('changes-panel')).toBeInTheDocument()
+ expect(screen.getByText('tree-panel')).toBeInTheDocument()
+ })
+
+ it('handles ctrl/cmd+n shortcut', () => {
+ const prepareNewChat = vi.fn()
+ useSessionStore.setState({ prepareNewChat } as any)
+ render()
+ fireEvent.keyDown(window, { key: 'n', ctrlKey: true })
+ expect(prepareNewChat).toHaveBeenCalled()
+ })
+})
+
diff --git a/web/src/components/panels/FileChangePanel.test.tsx b/web/src/components/panels/FileChangePanel.test.tsx
index 5e47ccda..52f38b1f 100644
--- a/web/src/components/panels/FileChangePanel.test.tsx
+++ b/web/src/components/panels/FileChangePanel.test.tsx
@@ -1,11 +1,17 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { fireEvent, render, screen, waitFor, within } from '@testing-library/react'
import FileChangePanel from './FileChangePanel'
+import { useChatStore } from '@/stores/useChatStore'
+import { useSessionStore } from '@/stores/useSessionStore'
import { CHANGES_PREVIEW_TAB_ID, GIT_DIFF_PREVIEW_TAB_ID, useUIStore } from '@/stores/useUIStore'
const mockGatewayAPI = {
listGitDiffFiles: vi.fn(),
readGitDiffFile: vi.fn(),
+ restoreCheckpoint: vi.fn(),
+ loadSession: vi.fn(),
+ listSessionTodos: vi.fn(),
+ listCheckpoints: vi.fn(),
}
vi.mock('@/context/RuntimeProvider', () => ({
@@ -71,6 +77,27 @@ describe('FileChangePanel', () => {
size_modified: 5,
},
})
+ mockGatewayAPI.restoreCheckpoint.mockResolvedValue({
+ payload: {
+ checkpoint_id: 'cp-1',
+ session_id: 'sess-1',
+ },
+ })
+ mockGatewayAPI.loadSession.mockResolvedValue({
+ payload: {
+ id: 'sess-1',
+ agent_mode: 'build',
+ messages: [{ role: 'assistant', content: 'reloaded' }],
+ },
+ })
+ mockGatewayAPI.listSessionTodos.mockResolvedValue({
+ payload: { items: [] },
+ })
+ mockGatewayAPI.listCheckpoints.mockResolvedValue({
+ payload: [],
+ })
+ useChatStore.setState({ isGenerating: false } as never)
+ useSessionStore.setState({ currentSessionId: 'sess-1' } as never)
useUIStore.setState({
fileChanges: [
{
@@ -79,6 +106,7 @@ describe('FileChangePanel', () => {
status: 'modified',
additions: 2,
deletions: 2,
+ checkpoint_id: 'cp-1',
hunks: [
{
header: '@@ -1,3 +1,3 @@',
@@ -128,6 +156,7 @@ describe('FileChangePanel', () => {
changesPanelOpen: true,
changesPanelWidth: 560,
theme: 'dark',
+ isRestoringCheckpoint: false,
} as never)
})
@@ -136,16 +165,113 @@ describe('FileChangePanel', () => {
fireEvent.click(screen.getByText('src/a.txt'))
- expect(screen.getByText('接受')).toBeTruthy()
+ expect(screen.getByTestId('accept-change-fc-1')).toBeTruthy()
expect(screen.getAllByTestId(/diff-hunk-fc-1-/)).toHaveLength(1)
expect(screen.getByText('line 1')).toBeTruthy()
expect(screen.getByText('line 2 new')).toBeTruthy()
- fireEvent.click(screen.getByText('接受'))
+ fireEvent.click(screen.getByTestId('accept-change-fc-1'))
expect(useUIStore.getState().fileChanges[0]?.status).toBe('accepted')
})
+ it('renders restore button per file change and disables it when checkpoint is missing', () => {
+ useUIStore.setState({
+ fileChanges: [
+ {
+ id: 'fc-with-cp',
+ path: 'src/with-cp.txt',
+ status: 'modified',
+ additions: 1,
+ deletions: 0,
+ checkpoint_id: 'cp-1',
+ hunks: [],
+ },
+ {
+ id: 'fc-without-cp',
+ path: 'src/no-cp.txt',
+ status: 'modified',
+ additions: 1,
+ deletions: 0,
+ hunks: [],
+ },
+ ],
+ } as never)
+
+ render()
+
+ fireEvent.click(screen.getByText('src/with-cp.txt'))
+ fireEvent.click(screen.getByText('src/no-cp.txt'))
+
+ expect(screen.getByTestId('restore-change-fc-with-cp')).toBeEnabled()
+ expect(screen.getByTestId('restore-change-fc-without-cp')).toBeDisabled()
+ })
+
+ it('calls restoreCheckpoint after confirming restore', async () => {
+ render()
+
+ fireEvent.click(screen.getByText('src/a.txt'))
+ fireEvent.click(screen.getByTestId('restore-change-fc-1'))
+
+ const confirmButtons = screen.getAllByRole('button', { name: 'Restore' })
+ fireEvent.click(confirmButtons[confirmButtons.length - 1])
+
+ await waitFor(() => {
+ expect(mockGatewayAPI.restoreCheckpoint).toHaveBeenCalledWith({
+ session_id: 'sess-1',
+ checkpoint_id: 'cp-1',
+ })
+ })
+ expect(mockGatewayAPI.loadSession).not.toHaveBeenCalled()
+ })
+
+ it('disables accept and restore actions while the session is generating', () => {
+ useChatStore.setState({ isGenerating: true } as never)
+ render()
+
+ fireEvent.click(screen.getByText('src/a.txt'))
+
+ expect(screen.getByTestId('accept-change-fc-1')).toBeDisabled()
+ expect(screen.getByTestId('restore-change-fc-1')).toBeDisabled()
+ fireEvent.click(screen.getByTestId('restore-change-fc-1'))
+ expect(mockGatewayAPI.restoreCheckpoint).not.toHaveBeenCalled()
+ })
+
+ it('disables accept and restore actions while checkpoint restore is in progress', () => {
+ useUIStore.setState({ isRestoringCheckpoint: true } as never)
+ render()
+
+ fireEvent.click(screen.getByText('src/a.txt'))
+
+ expect(screen.getByTestId('accept-change-fc-1')).toBeDisabled()
+ expect(screen.getByTestId('restore-change-fc-1')).toBeDisabled()
+ })
+
+ it('prevents duplicate restore calls while restore is in-flight', async () => {
+ let resolveRestore: ((value?: unknown) => void) | undefined
+ mockGatewayAPI.restoreCheckpoint.mockImplementation(() => new Promise((resolve) => {
+ resolveRestore = resolve
+ }))
+
+ render()
+
+ fireEvent.click(screen.getByText('src/a.txt'))
+ fireEvent.click(screen.getByTestId('restore-change-fc-1'))
+ const confirmButtons = screen.getAllByRole('button', { name: 'Restore' })
+ fireEvent.click(confirmButtons[confirmButtons.length - 1])
+
+ await waitFor(() => {
+ expect(mockGatewayAPI.restoreCheckpoint).toHaveBeenCalledTimes(1)
+ })
+
+ const restoreButton = screen.getByTestId('restore-change-fc-1')
+ expect(restoreButton).toBeDisabled()
+ fireEvent.click(restoreButton)
+ expect(mockGatewayAPI.restoreCheckpoint).toHaveBeenCalledTimes(1)
+
+ resolveRestore?.()
+ })
+
it('keeps the panel body clipped and uses the content area as the scroll container', () => {
render()
@@ -374,7 +500,7 @@ describe('FileChangePanel', () => {
{
id: CHANGES_PREVIEW_TAB_ID,
kind: 'changes',
- title: '鏂囦欢鍙樻洿',
+ title: '文件变更',
closable: false,
},
{
@@ -449,7 +575,7 @@ describe('FileChangePanel', () => {
{
id: CHANGES_PREVIEW_TAB_ID,
kind: 'changes',
- title: '鏂囦欢鍙樻洿',
+ title: '文件变更',
closable: false,
},
{
diff --git a/web/src/components/panels/FileChangePanel.tsx b/web/src/components/panels/FileChangePanel.tsx
index e323f147..4842c26b 100644
--- a/web/src/components/panels/FileChangePanel.tsx
+++ b/web/src/components/panels/FileChangePanel.tsx
@@ -1,6 +1,9 @@
import { lazy, Suspense, useCallback, useEffect, useMemo, useRef, useState, type CSSProperties, type KeyboardEvent as ReactKeyboardEvent, type RefObject, type WheelEvent as ReactWheelEvent } from 'react'
-import { ChevronRight, Check, FileDiff, Loader2, PanelRightClose, RefreshCw, X } from 'lucide-react'
+import { ChevronRight, Check, FileDiff, Loader2, PanelRightClose, RefreshCw, RotateCcw, X } from 'lucide-react'
import { useGatewayAPI } from '@/context/RuntimeProvider'
+import ConfirmDialog from '@/components/ui/ConfirmDialog'
+import { useChatStore } from '@/stores/useChatStore'
+import { useSessionStore } from '@/stores/useSessionStore'
import { useWorkspaceStore } from '@/stores/useWorkspaceStore'
import {
CHANGES_PREVIEW_TAB_ID,
@@ -19,6 +22,7 @@ const LazyCodePreviewEditor = lazy(() => import('./CodePreviewEditor'))
const LazyGitDiffPreviewEditor = lazy(() => import('./GitDiffPreviewEditor'))
const changeStatusMeta: Record = {
+ pending: { label: '待定', color: 'var(--text-tertiary)', bg: 'var(--bg-active)' },
added: { label: '新增', color: 'var(--success)', bg: 'rgba(22, 163, 74, 0.12)' },
modified: { label: '修改', color: 'var(--warning)', bg: 'rgba(217, 119, 6, 0.14)' },
deleted: { label: '删除', color: 'var(--error)', bg: 'rgba(220, 38, 38, 0.12)' },
@@ -41,6 +45,7 @@ function getChangeStatusMeta(status: string) {
}
function getChangeType(change: { status: string; additions: number; deletions: number }) {
+ if (change.status === 'pending') return 'pending' as const
if (['added', 'modified', 'deleted'].includes(change.status)) return change.status as 'added' | 'modified' | 'deleted'
if (change.additions > 0 && change.deletions > 0) return 'modified'
if (change.additions > 0) return 'added'
@@ -52,12 +57,13 @@ function getChangeCounts(fileChanges: { status: string; additions: number; delet
return fileChanges.reduce(
(counts, change) => {
const type = getChangeType(change)
+ if (type === 'pending') counts.pending += 1
if (type === 'added') counts.added += 1
if (type === 'modified') counts.modified += 1
if (type === 'deleted') counts.deleted += 1
return counts
},
- { added: 0, modified: 0, deleted: 0 },
+ { pending: 0, added: 0, modified: 0, deleted: 0 },
)
}
@@ -99,6 +105,7 @@ function getPreviewBadge(tab: PreviewTab, fileChanges: FileChange[]) {
const matched = fileChanges.find((change) => change.path === tab.path)
if (!matched) return null
const type = getChangeType(matched)
+ if (type === 'pending') return { label: 'P', color: 'var(--text-tertiary)' }
if (type === 'added') return { label: 'A', color: 'var(--success)' }
if (type === 'deleted') return { label: 'D', color: 'var(--error)' }
return { label: 'M', color: 'var(--warning)' }
@@ -146,12 +153,31 @@ function FileChangeItem({
onToggle: () => void
scrollContainerRef: RefObject
}) {
+ const gatewayAPI = useGatewayAPI()
+ const sessionId = useSessionStore((state) => state.currentSessionId)
+ const isGenerating = useChatStore((state) => state.isGenerating)
const acceptFileChange = useUIStore((state) => state.acceptFileChange)
+ const isRestoringCheckpoint = useUIStore((state) => state.isRestoringCheckpoint)
+ const setRestoringCheckpoint = useUIStore((state) => state.setRestoringCheckpoint)
+ const showToast = useUIStore((state) => state.showToast)
const meta = getChangeStatusMeta(change.status)
const reviewed = change.status === 'accepted' || change.status === 'rejected'
const displayHunks = getDisplayHunks(change)
+ const [confirmingRestore, setConfirmingRestore] = useState(false)
+ const disabledByRunning = isGenerating
+ const disabledByRestore = isRestoringCheckpoint
+ const acceptDisabled = reviewed || disabledByRestore || disabledByRunning
+ const restoreDisabled = reviewed || disabledByRestore || disabledByRunning || !change.checkpoint_id
+ const acceptTitle = disabledByRunning || disabledByRestore ? 'Running; action is disabled' : 'Mark as reviewed'
+ const restoreTitle = disabledByRunning
+ ? 'Running; action is disabled'
+ : disabledByRestore
+ ? 'Checkpoint restore in progress'
+ : change.checkpoint_id
+ ? 'Restore to this point (impacts all later changes)'
+ : 'No checkpoint available for this file change'
- // 将 diff 内层收到的纵向滚轮显式转发给外层列表,避免 hover 在 hunk 上时卡住整列滚动。
+ // 横向 diff 区域优先消费纵向滚轮,避免 hover 在 hunk 上时页面滚动。
const handleHunkWheel = (event: ReactWheelEvent) => {
if (Math.abs(event.deltaY) <= Math.abs(event.deltaX)) {
return
@@ -164,6 +190,24 @@ function FileChangeItem({
event.preventDefault()
}
+ // handleConfirmRestore 只负责触发 checkpoint 回退,成功后的状态刷新由事件链路驱动。
+ async function handleConfirmRestore() {
+ setConfirmingRestore(false)
+ if (!gatewayAPI || !sessionId || !change.checkpoint_id || disabledByRestore || disabledByRunning) return
+
+ setRestoringCheckpoint(true)
+ try {
+ await gatewayAPI.restoreCheckpoint({
+ session_id: sessionId,
+ checkpoint_id: change.checkpoint_id,
+ })
+ } catch (err) {
+ const message = err instanceof Error ? err.message : 'Unknown error'
+ showToast(`Restore failed: ${message}`, 'error')
+ setRestoringCheckpoint(false)
+ }
+ }
+
return (