From 7be75e28c1ceef4f513526057104de6abf5a1222 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sat, 9 May 2026 21:02:48 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat:web=E7=AB=AF=E6=8E=A5=E5=85=A5?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=9B=9E=E9=80=80=E6=8C=89=E9=94=AE=E5=92=8C?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=EF=BC=8C=E4=BF=AE=E5=A4=8D=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=9B=9E=E9=80=80=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/checkpoint/checkpoint_manager.go | 250 ++++++++ .../checkpoint/checkpoint_manager_test.go | 78 +++ internal/checkpoint/per_edit_snapshot.go | 137 +++- internal/checkpoint/per_edit_snapshot_test.go | 90 +++ internal/cli/gateway_runtime_bridge.go | 4 +- internal/gateway/contracts.go | 2 + internal/runtime/checkpoint_flow_test.go | 466 +++++++++++++- internal/runtime/checkpoint_gate.go | 70 ++ internal/runtime/checkpoint_restore.go | 112 +++- internal/runtime/checkpoint_run_baseline.go | 277 ++++++++ internal/runtime/run.go | 33 +- internal/runtime/runtime.go | 72 ++- internal/session/checkpoint_types.go | 15 +- internal/tui/services/runtime_contract.go | 2 + web/src/api/protocol.ts | 13 + .../panels/FileChangePanel.test.tsx | 134 +++- web/src/components/panels/FileChangePanel.tsx | 85 ++- web/src/stores/useSessionStore.ts | 37 ++ web/src/stores/useUIStore.ts | 6 +- web/src/utils/eventBridge.test.ts | 603 +++++++++++++++++- web/src/utils/eventBridge.ts | 207 +++++- 21 files changed, 2597 insertions(+), 96 deletions(-) create mode 100644 internal/runtime/checkpoint_run_baseline.go diff --git a/internal/checkpoint/checkpoint_manager.go b/internal/checkpoint/checkpoint_manager.go index 347c5582..01f8a185 100644 --- a/internal/checkpoint/checkpoint_manager.go +++ b/internal/checkpoint/checkpoint_manager.go @@ -58,6 +58,29 @@ type SQLiteCheckpointStore struct { ownsDB bool // true 表示本实例打开的连接,Close 时需释放 } +type workspaceFingerprintRow struct { + WorkspaceKey string + FingerprintPayload string + UpdatedAtMS int64 +} + +// WorkspaceCheckpointState 保存工作区当前权威代码基线与文件指纹。 +type WorkspaceCheckpointState struct { + WorkspaceKey string + CurrentCheckpointID string + FingerprintPayload string + UpdatedAt time.Time +} + +// RunCheckpointBaseline 保存单次 run 的权威回退基线,供异步 diff 查询跨 run 结束后读取。 +type RunCheckpointBaseline struct { + SessionID string + RunID string + CheckpointID string + Drifted bool + UpdatedAt time.Time +} + // NewSQLiteCheckpointStore 创建 checkpoint 存储实例。 // dbPath 为 session.db 文件路径,可通过 session.DatabasePath 获取。 func NewSQLiteCheckpointStore(dbPath string) *SQLiteCheckpointStore { @@ -114,6 +137,184 @@ func (s *SQLiteCheckpointStore) ensureDB(ctx context.Context) (*sql.DB, error) { return db, nil } +// SaveWorkspaceFingerprint 把 workspace 最新指纹保存到 SQLite,用于跨会话/重启后的 drift 检测。 +func (s *SQLiteCheckpointStore) SaveWorkspaceFingerprint( + ctx context.Context, + workspaceKey string, + fingerprintPayload string, + updatedAt time.Time, +) error { + if err := ctx.Err(); err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + if err := ensureWorkspaceFingerprintTable(ctx, db); err != nil { + return err + } + _, err = db.ExecContext(ctx, ` +INSERT INTO workspace_fingerprints(workspace_key, fingerprint_payload, updated_at_ms) +VALUES(?, ?, ?) +ON CONFLICT(workspace_key) DO UPDATE SET + fingerprint_payload=excluded.fingerprint_payload, + updated_at_ms=excluded.updated_at_ms +`, workspaceKey, fingerprintPayload, toUnixMillis(updatedAt)) + if err != nil { + return fmt.Errorf("checkpoint: save workspace fingerprint %s: %w", workspaceKey, err) + } + return nil +} + +// LoadWorkspaceFingerprint 从 SQLite 读取 workspace 最近指纹,不存在时 ok=false。 +func (s *SQLiteCheckpointStore) LoadWorkspaceFingerprint( + ctx context.Context, + workspaceKey string, +) (string, bool, error) { + if err := ctx.Err(); err != nil { + return "", false, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return "", false, err + } + if err := ensureWorkspaceFingerprintTable(ctx, db); err != nil { + return "", false, err + } + var row workspaceFingerprintRow + err = db.QueryRowContext(ctx, ` +SELECT workspace_key, fingerprint_payload, updated_at_ms +FROM workspace_fingerprints +WHERE workspace_key = ? +`, workspaceKey).Scan(&row.WorkspaceKey, &row.FingerprintPayload, &row.UpdatedAtMS) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + return "", false, fmt.Errorf("checkpoint: load workspace fingerprint %s: %w", workspaceKey, err) + } + return row.FingerprintPayload, true, nil +} + +// SaveWorkspaceCheckpointState 持久化工作区当前代码基线与文件指纹。 +func (s *SQLiteCheckpointStore) SaveWorkspaceCheckpointState(ctx context.Context, state WorkspaceCheckpointState) error { + if err := ctx.Err(); err != nil { + return err + } + if state.WorkspaceKey == "" { + return fmt.Errorf("checkpoint: workspace key required") + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + if err := ensureWorkspaceCheckpointStateTable(ctx, db); err != nil { + return err + } + _, err = db.ExecContext(ctx, ` +INSERT INTO workspace_checkpoint_states(workspace_key, current_checkpoint_id, fingerprint_payload, updated_at_ms) +VALUES(?, ?, ?, ?) +ON CONFLICT(workspace_key) DO UPDATE SET + current_checkpoint_id=excluded.current_checkpoint_id, + fingerprint_payload=excluded.fingerprint_payload, + updated_at_ms=excluded.updated_at_ms +`, state.WorkspaceKey, state.CurrentCheckpointID, state.FingerprintPayload, toUnixMillis(state.UpdatedAt)) + if err != nil { + return fmt.Errorf("checkpoint: save workspace checkpoint state %s: %w", state.WorkspaceKey, err) + } + return nil +} + +// LoadWorkspaceCheckpointState 读取工作区当前代码基线与文件指纹,不存在时 ok=false。 +func (s *SQLiteCheckpointStore) LoadWorkspaceCheckpointState(ctx context.Context, workspaceKey string) (WorkspaceCheckpointState, bool, error) { + if err := ctx.Err(); err != nil { + return WorkspaceCheckpointState{}, false, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return WorkspaceCheckpointState{}, false, err + } + if err := ensureWorkspaceCheckpointStateTable(ctx, db); err != nil { + return WorkspaceCheckpointState{}, false, err + } + var state WorkspaceCheckpointState + var updatedAtMS int64 + err = db.QueryRowContext(ctx, ` +SELECT workspace_key, current_checkpoint_id, fingerprint_payload, updated_at_ms +FROM workspace_checkpoint_states +WHERE workspace_key = ? +`, workspaceKey).Scan(&state.WorkspaceKey, &state.CurrentCheckpointID, &state.FingerprintPayload, &updatedAtMS) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return WorkspaceCheckpointState{}, false, nil + } + return WorkspaceCheckpointState{}, false, fmt.Errorf("checkpoint: load workspace checkpoint state %s: %w", workspaceKey, err) + } + state.UpdatedAt = fromUnixMillis(updatedAtMS) + return state, true, nil +} + +// SaveRunCheckpointBaseline 持久化单次 run 的权威回退基线。 +func (s *SQLiteCheckpointStore) SaveRunCheckpointBaseline(ctx context.Context, baseline RunCheckpointBaseline) error { + if err := ctx.Err(); err != nil { + return err + } + if baseline.SessionID == "" || baseline.RunID == "" { + return fmt.Errorf("checkpoint: session_id and run_id required for run baseline") + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + if err := ensureRunCheckpointBaselineTable(ctx, db); err != nil { + return err + } + _, err = db.ExecContext(ctx, ` +INSERT INTO run_checkpoint_baselines(session_id, run_id, checkpoint_id, drifted, updated_at_ms) +VALUES(?, ?, ?, ?, ?) +ON CONFLICT(session_id, run_id) DO UPDATE SET + checkpoint_id=excluded.checkpoint_id, + drifted=excluded.drifted, + updated_at_ms=excluded.updated_at_ms +`, baseline.SessionID, baseline.RunID, baseline.CheckpointID, boolToInt(baseline.Drifted), toUnixMillis(baseline.UpdatedAt)) + if err != nil { + return fmt.Errorf("checkpoint: save run baseline %s/%s: %w", baseline.SessionID, baseline.RunID, err) + } + return nil +} + +// LoadRunCheckpointBaseline 读取单次 run 的权威回退基线,不存在时 ok=false。 +func (s *SQLiteCheckpointStore) LoadRunCheckpointBaseline(ctx context.Context, sessionID, runID string) (RunCheckpointBaseline, bool, error) { + if err := ctx.Err(); err != nil { + return RunCheckpointBaseline{}, false, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return RunCheckpointBaseline{}, false, err + } + if err := ensureRunCheckpointBaselineTable(ctx, db); err != nil { + return RunCheckpointBaseline{}, false, err + } + var baseline RunCheckpointBaseline + var drifted int + var updatedAtMS int64 + err = db.QueryRowContext(ctx, ` +SELECT session_id, run_id, checkpoint_id, drifted, updated_at_ms +FROM run_checkpoint_baselines +WHERE session_id = ? AND run_id = ? +`, sessionID, runID).Scan(&baseline.SessionID, &baseline.RunID, &baseline.CheckpointID, &drifted, &updatedAtMS) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return RunCheckpointBaseline{}, false, nil + } + return RunCheckpointBaseline{}, false, fmt.Errorf("checkpoint: load run baseline %s/%s: %w", sessionID, runID, err) + } + baseline.Drifted = drifted != 0 + baseline.UpdatedAt = fromUnixMillis(updatedAtMS) + return baseline, true, nil +} + // CreateCheckpoint 在单一事务内写入 checkpoint record + session checkpoint。 // 事务内完成 record INSERT → session_cp INSERT → record UPDATE(设置 session_checkpoint_ref + status=available)。 func (s *SQLiteCheckpointStore) CreateCheckpoint(ctx context.Context, input CreateCheckpointInput) (session.CheckpointRecord, error) { @@ -663,3 +864,52 @@ func rollbackTx(tx *sql.Tx) { _ = tx.Rollback() } } + +func ensureWorkspaceFingerprintTable(ctx context.Context, db *sql.DB) error { + if db == nil { + return fmt.Errorf("checkpoint: workspace fingerprint db is nil") + } + if _, err := db.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS workspace_fingerprints ( + workspace_key TEXT PRIMARY KEY, + fingerprint_payload TEXT NOT NULL, + updated_at_ms INTEGER NOT NULL +)`); err != nil { + return fmt.Errorf("checkpoint: ensure workspace_fingerprints table: %w", err) + } + return nil +} + +func ensureWorkspaceCheckpointStateTable(ctx context.Context, db *sql.DB) error { + if db == nil { + return fmt.Errorf("checkpoint: workspace checkpoint state db is nil") + } + if _, err := db.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS workspace_checkpoint_states ( + workspace_key TEXT PRIMARY KEY, + current_checkpoint_id TEXT NOT NULL, + fingerprint_payload TEXT NOT NULL, + updated_at_ms INTEGER NOT NULL +)`); err != nil { + return fmt.Errorf("checkpoint: ensure workspace_checkpoint_states table: %w", err) + } + return nil +} + +func ensureRunCheckpointBaselineTable(ctx context.Context, db *sql.DB) error { + if db == nil { + return fmt.Errorf("checkpoint: run checkpoint baseline db is nil") + } + if _, err := db.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS run_checkpoint_baselines ( + session_id TEXT NOT NULL, + run_id TEXT NOT NULL, + checkpoint_id TEXT NOT NULL, + drifted INTEGER NOT NULL, + updated_at_ms INTEGER NOT NULL, + PRIMARY KEY(session_id, run_id) +)`); err != nil { + return fmt.Errorf("checkpoint: ensure run_checkpoint_baselines table: %w", err) + } + return nil +} diff --git a/internal/checkpoint/checkpoint_manager_test.go b/internal/checkpoint/checkpoint_manager_test.go index 6afd96a1..67d3f58a 100644 --- a/internal/checkpoint/checkpoint_manager_test.go +++ b/internal/checkpoint/checkpoint_manager_test.go @@ -588,3 +588,81 @@ func TestNewSQLiteCheckpointStoreWithNilDBClose(t *testing.T) { t.Fatalf("Close(nil db) error = %v", err) } } + +func TestSQLiteCheckpointStoreWorkspaceFingerprintPersistence(t *testing.T) { + fixture := newCheckpointStoreFixture(t) + db, err := fixture.sessionStore.InitDB(context.Background()) + if err != nil { + t.Fatalf("InitDB() error = %v", err) + } + shared := NewSQLiteCheckpointStoreWithDB(db) + + workspaceKey := session.WorkspacePathKey(fixture.workspaceRoot) + payload := `{"a.txt":{"size":1,"mod_time":"2026-01-01T00:00:00Z","head_hash":"abc"}}` + if err := shared.SaveWorkspaceFingerprint(context.Background(), workspaceKey, payload, time.Now()); err != nil { + t.Fatalf("SaveWorkspaceFingerprint() error = %v", err) + } + + got, ok, err := shared.LoadWorkspaceFingerprint(context.Background(), workspaceKey) + if err != nil { + t.Fatalf("LoadWorkspaceFingerprint() error = %v", err) + } + if !ok { + t.Fatal("expected persisted fingerprint to exist") + } + if got != payload { + t.Fatalf("fingerprint payload = %q, want %q", got, payload) + } + + missing, ok, err := shared.LoadWorkspaceFingerprint(context.Background(), "workspace/missing") + if err != nil { + t.Fatalf("LoadWorkspaceFingerprint(missing) error = %v", err) + } + if ok || missing != "" { + t.Fatalf("expected missing fingerprint, got ok=%v payload=%q", ok, missing) + } +} + +func TestSQLiteCheckpointStoreWorkspaceStateAndRunBaselinePersistence(t *testing.T) { + fixture := newCheckpointStoreFixture(t) + db, err := fixture.sessionStore.InitDB(context.Background()) + if err != nil { + t.Fatalf("InitDB() error = %v", err) + } + shared := NewSQLiteCheckpointStoreWithDB(db) + + workspaceKey := session.WorkspacePathKey(fixture.workspaceRoot) + payload := `{"tracked.txt":{"size":7}}` + if err := shared.SaveWorkspaceCheckpointState(context.Background(), WorkspaceCheckpointState{ + WorkspaceKey: workspaceKey, + CurrentCheckpointID: "cp-current", + FingerprintPayload: payload, + UpdatedAt: time.Now(), + }); err != nil { + t.Fatalf("SaveWorkspaceCheckpointState() error = %v", err) + } + state, ok, err := shared.LoadWorkspaceCheckpointState(context.Background(), workspaceKey) + if err != nil { + t.Fatalf("LoadWorkspaceCheckpointState() error = %v", err) + } + if !ok || state.CurrentCheckpointID != "cp-current" || state.FingerprintPayload != payload { + t.Fatalf("workspace state = %#v ok=%v", state, ok) + } + + if err := shared.SaveRunCheckpointBaseline(context.Background(), RunCheckpointBaseline{ + SessionID: "session-1", + RunID: "run-1", + CheckpointID: "cp-current", + Drifted: true, + UpdatedAt: time.Now(), + }); err != nil { + t.Fatalf("SaveRunCheckpointBaseline() error = %v", err) + } + baseline, ok, err := shared.LoadRunCheckpointBaseline(context.Background(), "session-1", "run-1") + if err != nil { + t.Fatalf("LoadRunCheckpointBaseline() error = %v", err) + } + if !ok || baseline.CheckpointID != "cp-current" || !baseline.Drifted { + t.Fatalf("run baseline = %#v ok=%v", baseline, ok) + } +} diff --git a/internal/checkpoint/per_edit_snapshot.go b/internal/checkpoint/per_edit_snapshot.go index 9efa78d2..0580bc5c 100644 --- a/internal/checkpoint/per_edit_snapshot.go +++ b/internal/checkpoint/per_edit_snapshot.go @@ -367,9 +367,9 @@ func (s *PerEditSnapshotStore) Reset() { // guardID 为 pre-restore 固化的快照(restoreCheckpointCore 中的 guard checkpoint), // 用于对比确定每个文件的目标操作;guardID 为空时仅处理 target checkpoint 内的文件。 // -// 对比逻辑:对 target 与 guard 中出现的每个文件,分别计算"目标状态"与"当前状态", -// 据此执行写回 / 删除 / 跳过,覆盖文件创建、修改、删除三种变更方向。 -func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID string) error { +// 对比逻辑:存在 relatedCheckpointIDs 时先收敛到 target 与后续 checkpoint 的净变化路径, +// 再分别计算"目标状态"与"当前状态"并执行写回 / 删除 / 跳过。 +func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID string, relatedCheckpointIDs ...string) error { targetCP, err := s.readCheckpointMeta(targetID) if err != nil { return err @@ -378,11 +378,6 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st s.indexMu.Lock() defer s.indexMu.Unlock() - hashSet := make(map[string]struct{}, len(targetCP.FileVersions)) - for h := range targetCP.FileVersions { - hashSet[h] = struct{}{} - } - var guardCP CheckpointMeta hasGuard := guardID != "" if hasGuard { @@ -390,15 +385,29 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st if err != nil { return err } - for h := range guardCP.FileVersions { - hashSet[h] = struct{}{} + } + relatedCPs := make([]CheckpointMeta, 0, len(relatedCheckpointIDs)) + for _, checkpointID := range relatedCheckpointIDs { + if strings.TrimSpace(checkpointID) == "" || checkpointID == targetID || checkpointID == guardID { + continue + } + relatedCP, err := s.readCheckpointMeta(checkpointID) + if err != nil { + return err } + relatedCPs = append(relatedCPs, relatedCP) } - // 无论有无 guard,都必须合并全量 pathToVersions。 - // guard 是 pending-only 的,不包含此前创建的、本 turn 未触碰的新文件; - // 不合并则这些文件在 restore 后仍会残留。 - for h := range s.pathToVersions { - hashSet[h] = struct{}{} + + hashSet := make(map[string]struct{}) + if len(relatedCPs) > 0 { + if err := s.collectRestoreChangedHashesLocked(hashSet, targetCP, relatedCPs); err != nil { + return err + } + } else { + addCheckpointHashes(hashSet, targetCP) + } + if hasGuard { + addCheckpointHashes(hashSet, guardCP) } for hash := range hashSet { @@ -407,7 +416,7 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st } // 目标状态:target checkpoint 时刻文件应如何。 - toContent, toIsDir, toExists, toMode, toDisplay, err := s.contentAtCheckpointLocked(hash, targetCP.FileVersions, false) + toContent, toIsDir, toExists, toMode, toDisplay, err := s.contentAtCheckpointStateLocked(hash, targetCP, false) if err != nil { return err } @@ -418,7 +427,7 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st var fromMode os.FileMode var fromDisplay string if hasGuard { - fromContent, fromIsDir, fromExists, fromMode, fromDisplay, err = s.contentAtCheckpointLocked(hash, guardCP.FileVersions, true) + fromContent, fromIsDir, fromExists, fromMode, fromDisplay, err = s.contentAtCheckpointStateLocked(hash, guardCP, true) if err != nil { return err } @@ -492,10 +501,26 @@ func (s *PerEditSnapshotStore) RestoreExact(ctx context.Context, checkpointID st s.indexMu.Lock() defer s.indexMu.Unlock() - for hash, vAt := range cp.FileVersions { + hashSet := make(map[string]struct{}, len(cp.FileVersions)+len(cp.ExactFileVersions)) + for hash := range cp.FileVersions { + hashSet[hash] = struct{}{} + } + for hash := range cp.ExactFileVersions { + hashSet[hash] = struct{}{} + } + hashes := make([]string, 0, len(hashSet)) + for hash := range hashSet { + hashes = append(hashes, hash) + } + sort.Strings(hashes) + for _, hash := range hashes { if err := ctx.Err(); err != nil { return err } + vAt, ok := cp.ExactFileVersions[hash] + if !ok { + vAt = cp.FileVersions[hash] + } meta, err := s.readVersionMeta(hash, vAt) if err != nil { return fmt.Errorf("per-edit: read meta v%d: %w", vAt, err) @@ -1311,6 +1336,82 @@ func readWorkdirMode(absPath string) os.FileMode { return info.Mode() } +// addCheckpointHashes 把 checkpoint 中显式记录的文件 hash 加入 restore 候选集合。 +func addCheckpointHashes(hashSet map[string]struct{}, cp CheckpointMeta) { + for hash := range cp.FileVersions { + hashSet[hash] = struct{}{} + } + for hash := range cp.ExactFileVersions { + hashSet[hash] = struct{}{} + } +} + +// collectRestoreChangedHashesLocked 只收集 target 到后续 checkpoint 之间净变化的路径。 +// target 可作为全工作区重锚点;真正 restore 时不应因为 target 是全量快照而误碰无关文件。 +func (s *PerEditSnapshotStore) collectRestoreChangedHashesLocked( + hashSet map[string]struct{}, + targetCP CheckpointMeta, + relatedCPs []CheckpointMeta, +) error { + if len(relatedCPs) == 0 { + return nil + } + latestCP := relatedCPs[0] + candidateHashes := make(map[string]struct{}) + addCheckpointHashes(candidateHashes, targetCP) + for _, cp := range relatedCPs { + addCheckpointHashes(candidateHashes, cp) + if cp.CreatedAt.After(latestCP.CreatedAt) { + latestCP = cp + } + } + for hash := range candidateHashes { + targetContent, targetIsDir, targetExists, targetMode, _, err := s.contentAtCheckpointStateLocked(hash, targetCP, false) + if err != nil { + return err + } + latestContent, latestIsDir, latestExists, latestMode, _, err := s.contentAtCheckpointStateLocked(hash, latestCP, false) + if err != nil { + return err + } + if targetExists != latestExists || + targetIsDir != latestIsDir || + targetMode != latestMode || + !bytes.Equal(targetContent, latestContent) { + hashSet[hash] = struct{}{} + } + } + return nil +} + +// contentAtCheckpointStateLocked 读取 checkpoint 结束态;新 checkpoint 优先使用 ExactFileVersions。 +// 旧 checkpoint 没有 exact 版本时,继续使用 v_next 兼容语义。 +func (s *PerEditSnapshotStore) contentAtCheckpointStateLocked( + hash string, + cp CheckpointMeta, + fallbackIfMissing bool, +) ([]byte, bool, bool, os.FileMode, string, error) { + if version, ok := cp.ExactFileVersions[hash]; ok { + meta, err := s.readVersionMeta(hash, version) + if err != nil { + return nil, false, false, 0, "", fmt.Errorf("per-edit: read exact meta v%d for %s: %w", version, hash, err) + } + display := s.resolveDisplayPathLocked(hash, meta.DisplayPath) + if !meta.Existed { + return nil, false, false, meta.Mode, display, nil + } + if meta.IsDir { + return nil, true, true, meta.Mode, display, nil + } + content, err := s.readVersionBin(hash, version) + if err != nil { + return nil, false, false, 0, display, fmt.Errorf("per-edit: read exact bin v%d for %s: %w", version, hash, err) + } + return content, false, true, meta.Mode, display, nil + } + return s.contentAtCheckpointLocked(hash, cp.FileVersions, fallbackIfMissing) +} + // contentAtCheckpointLocked 计算 hash 在某个 checkpoint 时刻的 workdir 内容。 // 在 cp.FileVersions 中:找下一版本读 .bin(或 Existed=false 时返回 nil); // 没有下一版本时:以当前 workdir 实际内容为准。 diff --git a/internal/checkpoint/per_edit_snapshot_test.go b/internal/checkpoint/per_edit_snapshot_test.go index 27b7d94b..1ff35005 100644 --- a/internal/checkpoint/per_edit_snapshot_test.go +++ b/internal/checkpoint/per_edit_snapshot_test.go @@ -1733,3 +1733,93 @@ func TestRunAggregateDiff_HistoricalFileFilteredByVersion(t *testing.T) { t.Fatalf("patch should NOT contain a.txt (filtered by version):\n%s", patch) } } + +func TestRestore_DoesNotUseGlobalHistoryWhenNoRelatedScope(t *testing.T) { + store, workdir := newTestStore(t) + + target := writeWorkdirFile(t, workdir, "target.txt", "old\n") + if _, err := store.CapturePreWrite(target); err != nil { + t.Fatalf("capture target: %v", err) + } + if err := os.WriteFile(target, []byte("new\n"), 0o644); err != nil { + t.Fatalf("write target: %v", err) + } + if _, err := store.FinalizeWithExactState("cp-target"); err != nil { + t.Fatalf("finalize target: %v", err) + } + store.Reset() + + unrelated := writeWorkdirFile(t, workdir, "unrelated.txt", "must stay\n") + if _, err := store.CapturePreWrite(unrelated); err != nil { + t.Fatalf("capture unrelated: %v", err) + } + if err := store.Restore(context.Background(), "cp-target", ""); err != nil { + t.Fatalf("restore cp-target: %v", err) + } + if got := mustReadFile(t, unrelated); got != "must stay\n" { + t.Fatalf("unrelated file should not be touched, got %q", got) + } +} + +func TestRestore_WithRelatedScopeSkipsUnchangedFilesFromFullRebase(t *testing.T) { + store, workdir := newTestStore(t) + + agentPath := writeWorkdirFile(t, workdir, "agent.txt", "baseline\n") + unrelatedPath := writeWorkdirFile(t, workdir, "unrelated.txt", "keep\n") + if _, err := store.CaptureBatch([]string{agentPath, unrelatedPath}); err != nil { + t.Fatalf("CaptureBatch baseline: %v", err) + } + if _, err := store.FinalizeWithExactState("cp-rebase"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-rebase): %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(agentPath); err != nil { + t.Fatalf("CapturePreWrite agent: %v", err) + } + if err := os.WriteFile(agentPath, []byte("agent edit\n"), 0o644); err != nil { + t.Fatalf("write agent edit: %v", err) + } + if _, err := store.FinalizeWithExactState("cp-agent"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-agent): %v", err) + } + store.Reset() + + if err := os.WriteFile(unrelatedPath, []byte("user post-run edit\n"), 0o644); err != nil { + t.Fatalf("write unrelated post-run edit: %v", err) + } + if err := store.Restore(context.Background(), "cp-rebase", "", "cp-agent"); err != nil { + t.Fatalf("Restore(cp-rebase): %v", err) + } + if got := mustReadFile(t, agentPath); got != "baseline\n" { + t.Fatalf("agent file = %q, want baseline", got) + } + if got := mustReadFile(t, unrelatedPath); got != "user post-run edit\n" { + t.Fatalf("unrelated file should not be restored, got %q", got) + } +} + +func TestRestoreExact_PrefersExactCheckpointState(t *testing.T) { + store, workdir := newTestStore(t) + target := writeWorkdirFile(t, workdir, "exact.txt", "before\n") + if _, err := store.CapturePreWrite(target); err != nil { + t.Fatalf("capture: %v", err) + } + if err := os.WriteFile(target, []byte("checkpoint state\n"), 0o644); err != nil { + t.Fatalf("write checkpoint state: %v", err) + } + if _, err := store.FinalizeWithExactState("cp-exact"); err != nil { + t.Fatalf("FinalizeWithExactState: %v", err) + } + store.Reset() + + if err := os.WriteFile(target, []byte("drift\n"), 0o644); err != nil { + t.Fatalf("write drift: %v", err) + } + if err := store.RestoreExact(context.Background(), "cp-exact"); err != nil { + t.Fatalf("RestoreExact: %v", err) + } + if got := mustReadFile(t, target); got != "checkpoint state\n" { + t.Fatalf("RestoreExact should use exact state, got %q", got) + } +} diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go index a4ec44f5..d8db9ca5 100644 --- a/internal/cli/gateway_runtime_bridge.go +++ b/internal/cli/gateway_runtime_bridge.go @@ -2456,6 +2456,8 @@ func (b *gatewayRuntimePortBridge) CheckpointDiff(ctx context.Context, input gat Deleted: result.Files.Deleted, Modified: result.Files.Modified, }, - Patch: result.Patch, + Patch: result.Patch, + WorkspaceDrifted: result.WorkspaceDrifted, + Warning: result.Warning, }, nil } diff --git a/internal/gateway/contracts.go b/internal/gateway/contracts.go index 69f6692e..533cac1e 100644 --- a/internal/gateway/contracts.go +++ b/internal/gateway/contracts.go @@ -476,6 +476,8 @@ type CheckpointDiffResult struct { PrevCommitHash string `json:"prev_commit_hash,omitempty"` Files FileDiffs `json:"files"` Patch string `json:"patch,omitempty"` + WorkspaceDrifted bool `json:"workspace_drifted,omitempty"` + Warning string `json:"warning,omitempty"` } // FileDiffs 描述 diff 中的文件变更列表。 diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go index 3075ac75..13361b3a 100644 --- a/internal/runtime/checkpoint_flow_test.go +++ b/internal/runtime/checkpoint_flow_test.go @@ -11,6 +11,7 @@ import ( "neo-code/internal/checkpoint" providertypes "neo-code/internal/provider/types" + "neo-code/internal/repository" agentsession "neo-code/internal/session" ) @@ -207,6 +208,298 @@ func TestCreateStartOfTurnCheckpoint_NoPending_SessionOnly(t *testing.T) { } } +func TestCreatePreRunDriftRebaseCheckpoint_UsesDriftReasonAndPerEditRef(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + + absKeep := filepath.Join(fixture.workdir, "keep.txt") + if err := os.WriteFile(absKeep, []byte("keep"), 0o644); err != nil { + t.Fatalf("WriteFile(keep) error = %v", err) + } + absDeleted := filepath.Join(fixture.workdir, "deleted.txt") + if err := os.WriteFile(absDeleted, []byte("deleted"), 0o644); err != nil { + t.Fatalf("WriteFile(deleted) error = %v", err) + } + if err := os.Remove(absDeleted); err != nil { + t.Fatalf("Remove(deleted) error = %v", err) + } + + state := newRunState("run-drift", fixture.session) + checkpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint(context.Background(), &state, repository.FingerprintDiff{ + Deleted: []string{"deleted.txt"}, + }) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + if checkpointID == "" { + t.Fatal("expected non-empty drift baseline checkpoint id") + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("expected 1 checkpoint, got %d", len(records)) + } + if records[0].Reason != agentsession.CheckpointReasonPreRunDriftRebase { + t.Fatalf("reason = %s, want %s", records[0].Reason, agentsession.CheckpointReasonPreRunDriftRebase) + } + if !checkpoint.IsPerEditRef(records[0].CodeCheckpointRef) { + t.Fatalf("code ref = %q, want peredit ref", records[0].CodeCheckpointRef) + } +} + +func TestCheckpointDiff_ScopeRun_RecreatedFileAfterDriftBaselineIsAdded(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-drift-add", fixture.session) + + absPath := filepath.Join(fixture.workdir, "recreate.txt") + if err := os.WriteFile(absPath, []byte("legacy"), 0o644); err != nil { + t.Fatalf("WriteFile(legacy) error = %v", err) + } + if err := os.Remove(absPath); err != nil { + t.Fatalf("Remove(recreate.txt) error = %v", err) + } + + rebasedCheckpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint( + context.Background(), + &state, + repository.FingerprintDiff{Deleted: []string{"recreate.txt"}}, + ) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + if rebasedCheckpointID == "" { + t.Fatal("expected drift rebase checkpoint id") + } + + if _, err := fixture.perEditStore.CapturePreWrite(absPath); err != nil { + t.Fatalf("CapturePreWrite(recreate) error = %v", err) + } + if err := os.WriteFile(absPath, []byte("new"), 0o644); err != nil { + t.Fatalf("WriteFile(new) error = %v", err) + } + if _, err := fixture.perEditStore.FinalizeWithExactState("cp-end"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-end) error = %v", err) + } + fixture.perEditStore.Reset() + if err := fixture.service.createCheckpointRecord( + context.Background(), + fixture.session, + state.runID, + &state, + "cp-end", + agentsession.CheckpointReasonEndOfTurn, + ); err != nil { + t.Fatalf("createCheckpointRecord(cp-end) error = %v", err) + } + + fixture.service.setRunWorkspaceDrift(fixture.session.ID, state.runID, true) + fixture.service.setRunRollbackBaseline(fixture.session.ID, state.runID, rebasedCheckpointID) + result, err := fixture.service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: fixture.session.ID, + Scope: "run", + RunID: state.runID, + CheckpointID: "cp-end", + }) + if err != nil { + t.Fatalf("CheckpointDiff(scope=run) error = %v", err) + } + if result.PrevCheckpointID != rebasedCheckpointID { + t.Fatalf("PrevCheckpointID = %q, want %q", result.PrevCheckpointID, rebasedCheckpointID) + } + if len(result.Files.Added) != 1 || result.Files.Added[0] != "recreate.txt" { + t.Fatalf("added files = %+v, want recreate.txt", result.Files.Added) + } + if len(result.Files.Modified) != 0 { + t.Fatalf("modified files = %+v, want empty", result.Files.Modified) + } +} + +func TestCheckpointDiff_ScopeRun_LoadsPersistedBaselineAfterCacheClear(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-persisted-baseline", fixture.session) + + absPath := filepath.Join(fixture.workdir, "persisted-baseline.txt") + if err := os.WriteFile(absPath, []byte("manual\n"), 0o644); err != nil { + t.Fatalf("WriteFile(manual) error = %v", err) + } + rebasedCheckpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint( + context.Background(), + &state, + repository.FingerprintDiff{Added: []string{"persisted-baseline.txt"}}, + ) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + fixture.service.persistRunRollbackBaseline(context.Background(), fixture.session.ID, state.runID, rebasedCheckpointID, true) + fixture.service.clearRunCheckpointCaches(fixture.session.ID, state.runID) + + if _, err := fixture.perEditStore.CapturePreWrite(absPath); err != nil { + t.Fatalf("CapturePreWrite(agent) error = %v", err) + } + if err := os.WriteFile(absPath, []byte("agent\n"), 0o644); err != nil { + t.Fatalf("WriteFile(agent) error = %v", err) + } + if _, err := fixture.perEditStore.FinalizeWithExactState("cp-end-persisted"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-end-persisted) error = %v", err) + } + fixture.perEditStore.Reset() + if err := fixture.service.createCheckpointRecord( + context.Background(), + fixture.session, + state.runID, + &state, + "cp-end-persisted", + agentsession.CheckpointReasonEndOfTurn, + ); err != nil { + t.Fatalf("createCheckpointRecord(cp-end-persisted) error = %v", err) + } + + result, err := fixture.service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: fixture.session.ID, + Scope: "run", + RunID: state.runID, + CheckpointID: "cp-end-persisted", + }) + if err != nil { + t.Fatalf("CheckpointDiff(scope=run) error = %v", err) + } + if result.PrevCheckpointID != rebasedCheckpointID { + t.Fatalf("PrevCheckpointID = %q, want %q", result.PrevCheckpointID, rebasedCheckpointID) + } + if !result.WorkspaceDrifted { + t.Fatal("expected persisted drift flag") + } +} + +func TestRestoreCheckpoint_AfterExternalModifyRestoresToRebasedState(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-restore-rebased", fixture.session) + + absPath := filepath.Join(fixture.workdir, "external.txt") + if err := os.WriteFile(absPath, []byte("manual\n"), 0o644); err != nil { + t.Fatalf("WriteFile(manual) error = %v", err) + } + unrelatedPath := filepath.Join(fixture.workdir, "unrelated.txt") + if err := os.WriteFile(unrelatedPath, []byte("keep\n"), 0o644); err != nil { + t.Fatalf("WriteFile(unrelated) error = %v", err) + } + rebasedCheckpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint( + context.Background(), + &state, + repository.FingerprintDiff{Modified: []string{"external.txt"}}, + ) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + + if _, err := fixture.perEditStore.CapturePreWrite(absPath); err != nil { + t.Fatalf("CapturePreWrite(agent) error = %v", err) + } + if err := os.WriteFile(absPath, []byte("agent\n"), 0o644); err != nil { + t.Fatalf("WriteFile(agent) error = %v", err) + } + if _, err := fixture.perEditStore.FinalizeWithExactState("cp-agent-end"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-agent-end) error = %v", err) + } + fixture.perEditStore.Reset() + if err := fixture.service.createCheckpointRecord( + context.Background(), + fixture.session, + state.runID, + &state, + "cp-agent-end", + agentsession.CheckpointReasonEndOfTurn, + ); err != nil { + t.Fatalf("createCheckpointRecord(cp-agent-end) error = %v", err) + } + if err := os.WriteFile(unrelatedPath, []byte("post-run user edit\n"), 0o644); err != nil { + t.Fatalf("WriteFile(unrelated post-run) error = %v", err) + } + + if _, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: fixture.session.ID, + CheckpointID: rebasedCheckpointID, + }); err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + data, err := os.ReadFile(absPath) + if err != nil { + t.Fatalf("ReadFile(restored) error = %v", err) + } + if got := string(data); got != "manual\n" { + t.Fatalf("restored content = %q, want manual", got) + } + unrelatedData, err := os.ReadFile(unrelatedPath) + if err != nil { + t.Fatalf("ReadFile(unrelated) error = %v", err) + } + if got := string(unrelatedData); got != "post-run user edit\n" { + t.Fatalf("unrelated content = %q, want post-run user edit", got) + } +} + +func TestRecordRunStartFingerprint_DetectsDriftAcrossSessionsByWorkspace(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + service := fixture.service + ctx := context.Background() + + absPath := filepath.Join(fixture.workdir, "cross-session.txt") + if err := os.WriteFile(absPath, []byte("before"), 0o644); err != nil { + t.Fatalf("WriteFile(before) error = %v", err) + } + service.recordRunEndFingerprint(ctx, "sess-a", fixture.workdir) + + if err := os.Remove(absPath); err != nil { + t.Fatalf("Remove(cross-session.txt) error = %v", err) + } + + drifted, driftDiff := service.recordRunStartFingerprint(ctx, "sess-b", "run-b", fixture.workdir) + if !drifted { + t.Fatal("expected drift across sessions on the same workspace") + } + if !containsString(driftDiff.Deleted, "cross-session.txt") { + t.Fatalf("deleted diff = %#v, want cross-session.txt", driftDiff.Deleted) + } +} + +func TestRecordRunStartFingerprint_LoadsWorkspaceFingerprintFromPersistence(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + ctx := context.Background() + + absPath := filepath.Join(fixture.workdir, "persisted.txt") + if err := os.WriteFile(absPath, []byte("before"), 0o644); err != nil { + t.Fatalf("WriteFile(before) error = %v", err) + } + fixture.service.recordRunEndFingerprint(ctx, "sess-a", fixture.workdir) + + if err := os.Remove(absPath); err != nil { + t.Fatalf("Remove(persisted.txt) error = %v", err) + } + + // 模拟进程重启:新 Service 仅复用 checkpointStore,不复用内存指纹缓存。 + restarted := &Service{ + checkpointStore: fixture.checkpointStore, + } + drifted, driftDiff := restarted.recordRunStartFingerprint(ctx, "sess-b", "run-b", fixture.workdir) + if !drifted { + t.Fatal("expected drift when loading workspace fingerprint from persistence") + } + if !containsString(driftDiff.Deleted, "persisted.txt") { + t.Fatalf("deleted diff = %#v, want persisted.txt", driftDiff.Deleted) + } +} + +func containsString(items []string, target string) bool { + for _, item := range items { + if item == target { + return true + } + } + return false +} + func TestCreateEndOfTurnCheckpoint_NoWriteSkipped(t *testing.T) { fixture := newRuntimeCheckpointFixture(t) fixture.captureFile(t, "main.go", []byte("package main\n")) @@ -924,6 +1217,9 @@ func TestCheckpointDiffRunScopeAggregatesCurrentRun(t *testing.T) { if result.CheckpointID != "cp-2" { t.Fatalf("CheckpointID = %q, want cp-2", result.CheckpointID) } + if result.PrevCheckpointID != "" { + t.Fatalf("PrevCheckpointID = %q, want empty", result.PrevCheckpointID) + } if len(result.Files.Modified) != 1 || result.Files.Modified[0] != "tracked.txt" { t.Fatalf("modified files = %+v, want tracked.txt", result.Files.Modified) } @@ -941,7 +1237,7 @@ func mustReadRuntimeFile(t *testing.T, path string) []byte { return data } -// ──────── scope=run diff tests ──────── +// scope=run diff tests func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) { workdir := t.TempDir() @@ -988,6 +1284,15 @@ func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) { CreatedAt: now, CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-1"), }, + { + CheckpointID: "cp-0", + SessionID: "session-1", + RunID: "run-prev", + CreatedAt: now.Add(-time.Second), + Reason: agentsession.CheckpointReasonEndOfTurn, + Status: agentsession.CheckpointStatusAvailable, + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-0"), + }, }, } service := &Service{ @@ -1026,10 +1331,12 @@ func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) { if len(modifiedPaths) != 1 || modifiedPaths[0] != "a.txt" { t.Fatalf("expected a.txt modified, got added=%v modified=%v", addedPaths, modifiedPaths) } - // 当前 run-scope diff 默认返回目标 checkpoint(未显式指定时为最新 checkpoint)。 if result.CheckpointID != "cp-2" { t.Fatalf("CheckpointID = %q, want cp-2", result.CheckpointID) } + if result.PrevCheckpointID != "cp-0" { + t.Fatalf("PrevCheckpointID = %q, want cp-0", result.PrevCheckpointID) + } } func TestCheckpointDiff_ScopeRun_RejectsMissingRunID(t *testing.T) { @@ -1077,6 +1384,161 @@ func TestCheckpointDiff_ScopeRun_NoCheckpointsForRunID(t *testing.T) { } } +func TestCheckpointDiff_ScopeRun_RejectsTargetCheckpointFromAnotherRun(t *testing.T) { + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-other-run", + SessionID: "session-1", + RunID: "run-other", + CreatedAt: time.Now().UTC(), + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-other-run"), + }, + }, + } + service := &Service{ + checkpointStore: spy, + perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), t.TempDir()), + } + _, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: "session-1", + Scope: "run", + RunID: "run-target", + CheckpointID: "cp-other-run", + }) + if err == nil { + t.Fatal("expected error for target checkpoint from another run") + } + if !strings.Contains(err.Error(), "does not belong to run") { + t.Fatalf("error = %v, want run mismatch", err) + } +} + +func TestCheckpointDiff_ScopeRun_WarnsWhenBaselineMissing(t *testing.T) { + workdir := t.TempDir() + projectDir := t.TempDir() + store := checkpoint.NewPerEditSnapshotStore(projectDir, workdir) + now := time.Now().UTC() + + absA := filepath.Join(workdir, "a.txt") + _ = os.WriteFile(absA, []byte("old a\n"), 0o644) + if _, err := store.CapturePreWrite(absA); err != nil { + t.Fatalf("CapturePreWrite a: %v", err) + } + _ = os.WriteFile(absA, []byte("new a\n"), 0o644) + if _, err := store.Finalize("cp-1"); err != nil { + t.Fatalf("Finalize cp-1: %v", err) + } + store.Reset() + + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-1", + SessionID: "session-1", + RunID: "run-target", + CreatedAt: now, + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-1"), + }, + }, + } + service := &Service{ + checkpointStore: spy, + perEditStore: store, + } + + result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: "session-1", + Scope: "run", + RunID: "run-target", + }) + if err != nil { + t.Fatalf("CheckpointDiff(scope=run) error = %v", err) + } + if result.PrevCheckpointID != "" { + t.Fatalf("PrevCheckpointID = %q, want empty", result.PrevCheckpointID) + } + if !strings.Contains(result.Warning, "run baseline checkpoint is missing") { + t.Fatalf("Warning = %q, want missing-baseline warning", result.Warning) + } +} + +func TestCheckpointDiff_ScopeRun_PrefersRebasedBaselineWhenDrifted(t *testing.T) { + workdir := t.TempDir() + projectDir := t.TempDir() + store := checkpoint.NewPerEditSnapshotStore(projectDir, workdir) + now := time.Now().UTC() + + absA := filepath.Join(workdir, "a.txt") + _ = os.WriteFile(absA, []byte("one\n"), 0o644) + if _, err := store.CapturePreWrite(absA); err != nil { + t.Fatalf("CapturePreWrite cp-drift: %v", err) + } + _ = os.WriteFile(absA, []byte("two\n"), 0o644) + if _, err := store.Finalize("cp-drift"); err != nil { + t.Fatalf("Finalize cp-drift: %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(absA); err != nil { + t.Fatalf("CapturePreWrite cp-target: %v", err) + } + _ = os.WriteFile(absA, []byte("three\n"), 0o644) + if _, err := store.Finalize("cp-target"); err != nil { + t.Fatalf("Finalize cp-target: %v", err) + } + store.Reset() + + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-target", + SessionID: "session-1", + RunID: "run-target", + CreatedAt: now.Add(time.Second), + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-target"), + }, + { + CheckpointID: "cp-drift", + SessionID: "session-1", + RunID: "run-target", + CreatedAt: now, + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-drift"), + }, + { + CheckpointID: "cp-old", + SessionID: "session-1", + RunID: "run-prev", + CreatedAt: now.Add(-time.Second), + Reason: agentsession.CheckpointReasonEndOfTurn, + Status: agentsession.CheckpointStatusAvailable, + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-old"), + }, + }, + } + service := &Service{ + checkpointStore: spy, + perEditStore: store, + } + service.setRunWorkspaceDrift("session-1", "run-target", true) + service.setRunRollbackBaseline("session-1", "run-target", "cp-drift") + + result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: "session-1", + Scope: "run", + RunID: "run-target", + }) + if err != nil { + t.Fatalf("CheckpointDiff(scope=run) error = %v", err) + } + if result.PrevCheckpointID != "cp-drift" { + t.Fatalf("PrevCheckpointID = %q, want cp-drift", result.PrevCheckpointID) + } + if !strings.Contains(result.Warning, "baseline re-anchored") { + t.Fatalf("Warning = %q, want re-anchored hint", result.Warning) + } +} + func TestCheckpointDiff_DefaultScopePreservesExistingBehavior(t *testing.T) { // Verify empty scope still uses checkpoint-to-checkpoint comparison. workdir := t.TempDir() diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go index bae5f42f..2c5ec5f7 100644 --- a/internal/runtime/checkpoint_gate.go +++ b/internal/runtime/checkpoint_gate.go @@ -5,10 +5,13 @@ import ( "encoding/json" "fmt" "log" + "path/filepath" + "sort" "strings" "time" "neo-code/internal/checkpoint" + "neo-code/internal/repository" agentsession "neo-code/internal/session" ) @@ -38,6 +41,73 @@ func (s *Service) createStartOfTurnCheckpoint(ctx context.Context, state *runSta return s.createCheckpointRecord(ctx, session, runID, state, checkpointID, agentsession.CheckpointReasonPreWrite) } +// createPreRunDriftRebaseCheckpoint 在 run 开始前检测到外部漂移时,固化当前工作区为权威基线。 +func (s *Service) createPreRunDriftRebaseCheckpoint( + ctx context.Context, + state *runState, + diff repository.FingerprintDiff, +) (string, error) { + if s.checkpointStore == nil || s.perEditStore == nil { + return "", nil + } + + state.mu.Lock() + session := state.session + runID := state.runID + state.mu.Unlock() + + workdir := strings.TrimSpace(session.Workdir) + if workdir == "" { + return "", fmt.Errorf("checkpoint: workdir is empty when creating drift rebase checkpoint") + } + + currentFingerprint, _, err := repository.ScanWorkdir(ctx, workdir, repository.DefaultFingerprintOptions()) + if err != nil { + return "", fmt.Errorf("checkpoint: scan workdir for drift rebase: %w", err) + } + absPaths := make([]string, 0, len(currentFingerprint)) + for relPath := range currentFingerprint { + absPaths = append(absPaths, filepath.Join(workdir, filepath.FromSlash(relPath))) + } + sort.Strings(absPaths) + if len(absPaths) > 0 { + if _, err := s.perEditStore.CaptureBatch(absPaths); err != nil { + return "", fmt.Errorf("checkpoint: capture current files for drift rebase: %w", err) + } + } + + deleted := append([]string(nil), diff.Deleted...) + sort.Strings(deleted) + for _, relPath := range deleted { + absPath := filepath.Join(workdir, filepath.FromSlash(relPath)) + if _, err := s.perEditStore.CapturePreWrite(absPath); err != nil { + return "", fmt.Errorf("checkpoint: capture deleted file for drift rebase (%s): %w", relPath, err) + } + } + + checkpointID := agentsession.NewID("checkpoint") + written, err := s.perEditStore.FinalizeWithExactState(checkpointID) + if err != nil { + return "", fmt.Errorf("checkpoint: finalize drift rebase checkpoint: %w", err) + } + if !written { + return "", fmt.Errorf("checkpoint: drift rebase checkpoint has no captured files") + } + defer s.perEditStore.Reset() + + if err := s.createCheckpointRecord( + ctx, + session, + runID, + state, + checkpointID, + agentsession.CheckpointReasonPreRunDriftRebase, + ); err != nil { + return "", err + } + return checkpointID, nil +} + // createEndOfTurnCheckpoint 在工具执行完成后创建代码检查点。 // hasWorkspaceWrite=false 时不创建(避免空 checkpoint);为 true 时 Finalize 当前 pending。 // 失败仅 log,不阻塞主流程。 diff --git a/internal/runtime/checkpoint_restore.go b/internal/runtime/checkpoint_restore.go index 4ba43124..ccc8366a 100644 --- a/internal/runtime/checkpoint_restore.go +++ b/internal/runtime/checkpoint_restore.go @@ -86,7 +86,9 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi return RestoreResult{}, agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: create guard: %w", guardErr) } - // 3. Restore code via per-edit store(不在 cp.FileVersions 中的文件保持不变)。 + relatedPerEditIDs := s.restoreRelatedPerEditIDs(ctx, sessionID, record.CreatedAt) + + // 3. Restore code via per-edit store(仅处理目标、guard 与后续相关 checkpoint 涉及的文件)。 // Guard checkpoint 恢复时使用 RestoreExact:guard 中存储的 version 就是 restore 前的 pre-write 状态, // 而 Restore 的 v_next 语义在 guard 上通常是 no-op(guard 之后没有新的 capture)。 isGuardRestore := record.Reason == agentsession.CheckpointReasonGuard @@ -102,7 +104,7 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi if guardWritten { guardCheckpointID = guardID } - if err := s.perEditStore.Restore(ctx, perEditID, guardCheckpointID); err != nil { + if err := s.perEditStore.Restore(ctx, perEditID, guardCheckpointID, relatedPerEditIDs...); err != nil { return RestoreResult{}, agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: restore code: %w", err) } } @@ -149,6 +151,7 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi // 7. Update runtime session if it's the current session s.updateRuntimeSessionAfterRestore(sessionID, head, messages) + s.recordRunEndWorkspaceState(ctx, sessionID, head.Workdir, checkpointID) return RestoreResult{ CheckpointID: checkpointID, @@ -156,6 +159,34 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi }, guardRecord, nil } +// restoreRelatedPerEditIDs 收集目标 checkpoint 之后仍 available 的代码 checkpoint,用于限定 restore 清理范围。 +func (s *Service) restoreRelatedPerEditIDs(ctx context.Context, sessionID string, targetCreatedAt time.Time) []string { + if s.checkpointStore == nil { + return nil + } + records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{}) + if err != nil { + return nil + } + out := make([]string, 0) + for _, record := range records { + if !record.CreatedAt.After(targetCreatedAt) { + continue + } + if record.Status != "" && record.Status != agentsession.CheckpointStatusAvailable { + continue + } + if record.Reason == agentsession.CheckpointReasonGuard { + continue + } + perEditID := checkpoint.PerEditCheckpointIDFromRef(record.CodeCheckpointRef) + if perEditID != "" { + out = append(out, perEditID) + } + } + return out +} + // RestoreCheckpoint 恢复指定 checkpoint 的会话和工作区状态,并发出 checkpoint_restored 事件。 func (s *Service) RestoreCheckpoint(ctx context.Context, input GatewayRestoreInput) (RestoreResult, error) { result, guardRecord, err := s.restoreCheckpointCore(ctx, input.SessionID, input.CheckpointID) @@ -315,6 +346,8 @@ type CheckpointDiffResult struct { PrevCommitHash string `json:"prev_commit_hash,omitempty"` Files FileDiffs `json:"files"` Patch string `json:"patch,omitempty"` + WorkspaceDrifted bool `json:"workspace_drifted,omitempty"` + Warning string `json:"warning,omitempty"` } // FileDiffs 描述 diff 中的文件变更列表。 @@ -452,6 +485,9 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff if runID == "" { return CheckpointDiffResult{}, fmt.Errorf("checkpoint: run_id required for run scope diff") } + if targetRecord != nil && strings.TrimSpace(targetRecord.RunID) != runID { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: target checkpoint %s does not belong to run %s", targetRecord.CheckpointID, runID) + } codeRecords := make([]agentsession.CheckpointRecord, 0) for _, record := range records { @@ -479,6 +515,11 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff targetRecord = &codeRecords[len(codeRecords)-1] } + prevCheckpointID, workspaceDrifted := s.getPersistentRunRollbackBaseline(ctx, sessionID, runID) + if prevCheckpointID == "" { + prevCheckpointID = resolveRunPrevCheckpointID(records, codeRecords[0].CreatedAt, runID) + } + perEditIDs := make([]string, 0, len(codeRecords)) for _, record := range codeRecords { perEditID := checkpoint.PerEditCheckpointIDFromRef(record.CodeCheckpointRef) @@ -492,7 +533,30 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff return CheckpointDiffResult{}, fmt.Errorf("checkpoint: per-edit run diff: %w", err) } - result := CheckpointDiffResult{CheckpointID: targetRecord.CheckpointID, Patch: patch} + result := CheckpointDiffResult{ + CheckpointID: targetRecord.CheckpointID, + PrevCheckpointID: prevCheckpointID, + Patch: patch, + WorkspaceDrifted: workspaceDrifted, + } + if result.WorkspaceDrifted { + if prevCheckpointID != "" { + result.Warning = "workspace drift detected before run start; baseline re-anchored" + } else { + result.Warning = "workspace drift detected before run start" + } + } + if result.PrevCheckpointID == "" { + if result.Warning != "" { + result.Warning += "; run baseline checkpoint is missing" + } else { + result.Warning = "run baseline checkpoint is missing" + } + } + if result.PrevCheckpointID != "" { + s.persistRunRollbackBaseline(ctx, sessionID, runID, result.PrevCheckpointID, result.WorkspaceDrifted) + } + for _, c := range changes { switch c.Kind { case checkpoint.FileChangeAdded: @@ -505,3 +569,45 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff } return result, nil } + +// resolveRunPrevCheckpointID 解析 run 级 diff 的权威 baseline checkpoint_id。 +// 优先选择 run 开始前最近的 end_of_turn 代码 checkpoint;若不存在,退化为最近的可用代码 checkpoint。 +func resolveRunPrevCheckpointID(records []agentsession.CheckpointRecord, runStartAt time.Time, runID string) string { + var preferred *agentsession.CheckpointRecord + var fallback *agentsession.CheckpointRecord + for i := range records { + record := records[i] + if strings.TrimSpace(record.RunID) == strings.TrimSpace(runID) { + continue + } + if record.Status != "" && record.Status != agentsession.CheckpointStatusAvailable { + continue + } + if !checkpoint.IsPerEditRef(record.CodeCheckpointRef) { + continue + } + if record.Reason == agentsession.CheckpointReasonGuard { + continue + } + if !record.CreatedAt.Before(runStartAt) { + continue + } + if fallback == nil || record.CreatedAt.After(fallback.CreatedAt) { + candidate := record + fallback = &candidate + } + if record.Reason == agentsession.CheckpointReasonEndOfTurn { + if preferred == nil || record.CreatedAt.After(preferred.CreatedAt) { + candidate := record + preferred = &candidate + } + } + } + if preferred != nil { + return preferred.CheckpointID + } + if fallback != nil { + return fallback.CheckpointID + } + return "" +} diff --git a/internal/runtime/checkpoint_run_baseline.go b/internal/runtime/checkpoint_run_baseline.go new file mode 100644 index 00000000..a299d440 --- /dev/null +++ b/internal/runtime/checkpoint_run_baseline.go @@ -0,0 +1,277 @@ +package runtime + +import ( + "context" + "encoding/json" + "strings" + "time" + + "neo-code/internal/checkpoint" + "neo-code/internal/repository" + agentsession "neo-code/internal/session" +) + +// buildRunBaselineKey 生成 run 级缓存 key,保证同一会话不同 run 的基线隔离。 +func buildRunBaselineKey(sessionID, runID string) string { + return strings.TrimSpace(sessionID) + "\n" + strings.TrimSpace(runID) +} + +// ensureRunBaselineMapsLocked 懒初始化 run 级基线与漂移缓存,允许测试中用字面量 Service 直接调用。 +func (s *Service) ensureRunBaselineMapsLocked() { + if s.runRollbackBaselineByRunKey == nil { + s.runRollbackBaselineByRunKey = make(map[string]string) + } + if s.runWorkspaceDriftByRunKey == nil { + s.runWorkspaceDriftByRunKey = make(map[string]bool) + } + if s.lastRunFingerprintByWorkspaceKey == nil { + s.lastRunFingerprintByWorkspaceKey = make(map[string]repository.WorkdirFingerprint) + } +} + +type workspaceFingerprintPersistenceStore interface { + SaveWorkspaceFingerprint(ctx context.Context, workspaceKey string, fingerprintPayload string, updatedAt time.Time) error + LoadWorkspaceFingerprint(ctx context.Context, workspaceKey string) (string, bool, error) +} + +type workspaceCheckpointStatePersistenceStore interface { + SaveWorkspaceCheckpointState(ctx context.Context, state checkpoint.WorkspaceCheckpointState) error + LoadWorkspaceCheckpointState(ctx context.Context, workspaceKey string) (checkpoint.WorkspaceCheckpointState, bool, error) + SaveRunCheckpointBaseline(ctx context.Context, baseline checkpoint.RunCheckpointBaseline) error + LoadRunCheckpointBaseline(ctx context.Context, sessionID, runID string) (checkpoint.RunCheckpointBaseline, bool, error) +} + +// workspaceKeyFromWorkdir 统一把工作目录映射为跨会话稳定 key。 +func workspaceKeyFromWorkdir(workdir string) string { + return strings.TrimSpace(agentsession.WorkspacePathKey(strings.TrimSpace(workdir))) +} + +// setRunRollbackBaseline 记录当前 run 的权威回退基线 checkpoint_id(由后端计算,前端仅消费)。 +func (s *Service) setRunRollbackBaseline(sessionID, runID, checkpointID string) { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + if strings.TrimSpace(checkpointID) == "" { + delete(s.runRollbackBaselineByRunKey, key) + return + } + s.runRollbackBaselineByRunKey[key] = strings.TrimSpace(checkpointID) +} + +// getRunRollbackBaseline 返回 run 的权威回退基线 checkpoint_id。 +func (s *Service) getRunRollbackBaseline(sessionID, runID string) string { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return "" + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + return s.runRollbackBaselineByRunKey[key] +} + +// getPersistentRunRollbackBaseline 从内存或 SQLite 读取 run 的权威回退基线。 +func (s *Service) getPersistentRunRollbackBaseline(ctx context.Context, sessionID, runID string) (string, bool) { + if baseline := s.getRunRollbackBaseline(sessionID, runID); baseline != "" { + return baseline, s.getRunWorkspaceDrift(sessionID, runID) + } + store, ok := s.checkpointStore.(workspaceCheckpointStatePersistenceStore) + if !ok { + return "", false + } + loaded, found, err := store.LoadRunCheckpointBaseline(ctx, strings.TrimSpace(sessionID), strings.TrimSpace(runID)) + if err != nil || !found { + return "", false + } + s.setRunRollbackBaseline(sessionID, runID, loaded.CheckpointID) + s.setRunWorkspaceDrift(sessionID, runID, loaded.Drifted) + return strings.TrimSpace(loaded.CheckpointID), loaded.Drifted +} + +// persistRunRollbackBaseline 双写 run 级回退基线,保证异步 diff 在 run 结束后仍可读取。 +func (s *Service) persistRunRollbackBaseline(ctx context.Context, sessionID, runID, checkpointID string, drifted bool) { + sessionID = strings.TrimSpace(sessionID) + runID = strings.TrimSpace(runID) + checkpointID = strings.TrimSpace(checkpointID) + if sessionID == "" || runID == "" || checkpointID == "" { + return + } + s.setRunRollbackBaseline(sessionID, runID, checkpointID) + s.setRunWorkspaceDrift(sessionID, runID, drifted) + store, ok := s.checkpointStore.(workspaceCheckpointStatePersistenceStore) + if !ok { + return + } + _ = store.SaveRunCheckpointBaseline(ctx, checkpoint.RunCheckpointBaseline{ + SessionID: sessionID, + RunID: runID, + CheckpointID: checkpointID, + Drifted: drifted, + UpdatedAt: time.Now(), + }) +} + +// setRunWorkspaceDrift 标记当前 run 是否检测到空闲期工作区漂移。 +func (s *Service) setRunWorkspaceDrift(sessionID, runID string, drifted bool) { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + if drifted { + s.runWorkspaceDriftByRunKey[key] = true + return + } + delete(s.runWorkspaceDriftByRunKey, key) +} + +// getRunWorkspaceDrift 返回 run 是否检测到空闲期工作区漂移。 +func (s *Service) getRunWorkspaceDrift(sessionID, runID string) bool { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return false + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + return s.runWorkspaceDriftByRunKey[key] +} + +// clearRunCheckpointCaches 清理 run 级 checkpoint 缓存,避免内存与跨 run 污染。 +func (s *Service) clearRunCheckpointCaches(sessionID, runID string) { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + delete(s.runRollbackBaselineByRunKey, key) + delete(s.runWorkspaceDriftByRunKey, key) +} + +// recordRunStartFingerprint 在 run 开始时比对上次 run 结束指纹,返回是否发生空闲期漂移及详细差异。 +func (s *Service) recordRunStartFingerprint( + ctx context.Context, + sessionID, runID, workdir string, +) (bool, repository.FingerprintDiff) { + emptyDiff := repository.FingerprintDiff{} + normalizedRunID := strings.TrimSpace(runID) + normalizedWorkdir := strings.TrimSpace(workdir) + workspaceKey := workspaceKeyFromWorkdir(normalizedWorkdir) + if strings.TrimSpace(sessionID) == "" || normalizedRunID == "" || normalizedWorkdir == "" || workspaceKey == "" { + return false, emptyDiff + } + current, _, err := repository.ScanWorkdir(ctx, normalizedWorkdir, repository.DefaultFingerprintOptions()) + if err != nil { + return false, emptyDiff + } + + s.rollbackBaselineMu.Lock() + s.ensureRunBaselineMapsLocked() + previous, ok := s.lastRunFingerprintByWorkspaceKey[workspaceKey] + s.rollbackBaselineMu.Unlock() + var currentCheckpointID string + if !ok { + if store, okStore := s.checkpointStore.(workspaceCheckpointStatePersistenceStore); okStore { + state, found, loadErr := store.LoadWorkspaceCheckpointState(ctx, workspaceKey) + if loadErr == nil && found { + currentCheckpointID = strings.TrimSpace(state.CurrentCheckpointID) + var restored repository.WorkdirFingerprint + if err := json.Unmarshal([]byte(state.FingerprintPayload), &restored); err == nil { + s.rollbackBaselineMu.Lock() + s.ensureRunBaselineMapsLocked() + s.lastRunFingerprintByWorkspaceKey[workspaceKey] = restored + s.rollbackBaselineMu.Unlock() + previous = restored + ok = true + } + } + } + if !ok { + if store, okStore := s.checkpointStore.(workspaceFingerprintPersistenceStore); okStore { + payload, found, loadErr := store.LoadWorkspaceFingerprint(ctx, workspaceKey) + if loadErr == nil && found { + var restored repository.WorkdirFingerprint + if err := json.Unmarshal([]byte(payload), &restored); err == nil { + s.rollbackBaselineMu.Lock() + s.ensureRunBaselineMapsLocked() + s.lastRunFingerprintByWorkspaceKey[workspaceKey] = restored + s.rollbackBaselineMu.Unlock() + previous = restored + ok = true + } + } + } + } + if !ok { + return false, emptyDiff + } + } + diff := repository.DiffFingerprints(previous, current) + drifted := len(diff.Added) > 0 || len(diff.Modified) > 0 || len(diff.Deleted) > 0 + if drifted { + s.setRunWorkspaceDrift(sessionID, normalizedRunID, true) + return true, diff + } + if currentCheckpointID != "" { + s.persistRunRollbackBaseline(ctx, sessionID, normalizedRunID, currentCheckpointID, false) + } + return false, diff +} + +// recordRunEndFingerprint 在 run 结束后保存最新指纹,供下次 run 开始时进行漂移检测。 +func (s *Service) recordRunEndFingerprint(ctx context.Context, sessionID, workdir string) { + s.recordRunEndWorkspaceState(ctx, sessionID, workdir, "") +} + +// recordRunEndWorkspaceState 保存最新指纹和当前 checkpoint;checkpointID 为空时保留已有基线。 +func (s *Service) recordRunEndWorkspaceState(ctx context.Context, sessionID, workdir, checkpointID string) { + normalizedWorkdir := strings.TrimSpace(workdir) + workspaceKey := workspaceKeyFromWorkdir(normalizedWorkdir) + if strings.TrimSpace(sessionID) == "" || normalizedWorkdir == "" || workspaceKey == "" { + return + } + current, _, err := repository.ScanWorkdir(ctx, normalizedWorkdir, repository.DefaultFingerprintOptions()) + if err != nil { + return + } + s.rollbackBaselineMu.Lock() + s.ensureRunBaselineMapsLocked() + s.lastRunFingerprintByWorkspaceKey[workspaceKey] = current + s.rollbackBaselineMu.Unlock() + + store, ok := s.checkpointStore.(workspaceFingerprintPersistenceStore) + if !ok { + return + } + payload, err := json.Marshal(current) + if err != nil { + return + } + now := time.Now() + _ = store.SaveWorkspaceFingerprint(ctx, workspaceKey, string(payload), now) + + stateStore, ok := s.checkpointStore.(workspaceCheckpointStatePersistenceStore) + if !ok { + return + } + currentCheckpointID := strings.TrimSpace(checkpointID) + if currentCheckpointID == "" { + if existing, found, err := stateStore.LoadWorkspaceCheckpointState(ctx, workspaceKey); err == nil && found { + currentCheckpointID = strings.TrimSpace(existing.CurrentCheckpointID) + } + } + _ = stateStore.SaveWorkspaceCheckpointState(ctx, checkpoint.WorkspaceCheckpointState{ + WorkspaceKey: workspaceKey, + CurrentCheckpointID: currentCheckpointID, + FingerprintPayload: string(payload), + UpdatedAt: now, + }) +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index c4bdb542..58bc4cfd 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -135,6 +135,11 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { ChangedFiles: changedFiles, }) } + if statePtr != nil { + runEndCtx := context.Background() + s.recordRunEndWorkspaceState(runEndCtx, statePtr.session.ID, statePtr.session.Workdir, statePtr.lastEndOfTurnCheckpointID) + s.clearRunCheckpointCaches(statePtr.session.ID, statePtr.runID) + } s.emitRunTermination(runCtx, input, statePtr, err) }() ctx = runCtx @@ -191,6 +196,28 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } effectiveWorkdir := agentsession.EffectiveWorkdir(state.session.Workdir, initialCfg.Workdir) + runBaselineCheckpointID := "" + if drifted, driftDiff := s.recordRunStartFingerprint(ctx, state.session.ID, state.runID, effectiveWorkdir); drifted { + if checkpointID, cpErr := s.createPreRunDriftRebaseCheckpoint(ctx, &state, driftDiff); cpErr != nil { + s.emitRunScopedOptional(EventCheckpointWarning, &state, CheckpointWarningPayload{ + Error: fmt.Sprintf("workspace drift detected but baseline rebase failed: %v", cpErr), + Phase: "run_start_drift_rebase", + }) + } else if checkpointID != "" { + runBaselineCheckpointID = checkpointID + s.persistRunRollbackBaseline(ctx, state.session.ID, state.runID, checkpointID, true) + s.recordRunEndWorkspaceState(ctx, state.session.ID, effectiveWorkdir, checkpointID) + } + s.emitRunScopedOptional(EventCheckpointWarning, &state, CheckpointWarningPayload{ + Error: "workspace drift detected before run start", + Phase: "run_start_drift", + }) + } + if runBaselineCheckpointID == "" { + if baseline, _ := s.getPersistentRunRollbackBaseline(ctx, state.session.ID, state.runID); baseline != "" { + runBaselineCheckpointID = baseline + } + } _ = s.runHookPoint(ctx, &state, runtimehooks.HookPointSessionStart, runtimehooks.HookContext{ Metadata: map[string]any{ "session_id": state.session.ID, @@ -223,7 +250,11 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.updateResumeCheckpoint(ctx, &state, "plan", "") maxTurns := resolveRuntimeMaxTurns(initialCfg.Runtime) - state.baselineCheckpointID = s.findPreviousEndOfTurnCheckpoint(ctx, sessionID, input.RunID) + if runBaselineCheckpointID != "" { + state.baselineCheckpointID = runBaselineCheckpointID + } else { + state.baselineCheckpointID = s.findPreviousEndOfTurnCheckpoint(ctx, sessionID, input.RunID) + } for turn := 0; ; turn++ { if turn >= maxTurns { state.maxTurnsReached = true diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 0da88617..f1db6601 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -193,19 +193,26 @@ type Service struct { events chan RuntimeEvent runtimeSnapshotMu sync.Mutex runtimeSnapshots map[string]RuntimeSnapshot - sessionMu sync.Mutex - sessionLocks map[string]*sessionLockEntry - runMu sync.Mutex - activeRunToken uint64 - nextRunToken uint64 - activeRunCancels map[uint64]context.CancelFunc - activeRunByID map[string]uint64 - activeRunTokenIDs map[uint64]string - activeRunStates map[uint64]*runState - permissionAskMapMu sync.Mutex - permissionAskLocks map[string]*permissionAskLockEntry - askStore AskSessionStore - askSequence uint64 + rollbackBaselineMu sync.Mutex + // runRollbackBaselineByRunKey 记录一次 run 的权威回退基线 checkpoint_id,key=session_id+"\n"+run_id。 + runRollbackBaselineByRunKey map[string]string + // runWorkspaceDriftByRunKey 标记 run 开始前工作区是否发生空闲期漂移,key=session_id+"\n"+run_id。 + runWorkspaceDriftByRunKey map[string]bool + // lastRunFingerprintByWorkspaceKey 保存每个工作区上一次 run 结束时的指纹,用于跨会话漂移检测。 + lastRunFingerprintByWorkspaceKey map[string]repository.WorkdirFingerprint + sessionMu sync.Mutex + sessionLocks map[string]*sessionLockEntry + runMu sync.Mutex + activeRunToken uint64 + nextRunToken uint64 + activeRunCancels map[uint64]context.CancelFunc + activeRunByID map[string]uint64 + activeRunTokenIDs map[uint64]string + activeRunStates map[uint64]*runState + permissionAskMapMu sync.Mutex + permissionAskLocks map[string]*permissionAskLockEntry + askStore AskSessionStore + askSequence uint64 thinkingEnabled bool @@ -261,24 +268,27 @@ func NewWithFactory( } service := &Service{ - configManager: configManager, - sessionStore: sessionStore, - toolManager: toolManager, - providerFactory: providerFactory, - contextBuilder: contextBuilder, - repositoryService: repository.NewService(), - approvalBroker: approval.NewBroker(), - askUserBroker: askuser.NewBroker(), - events: make(chan RuntimeEvent, 128), - runtimeSnapshots: make(map[string]RuntimeSnapshot), - sessionLocks: make(map[string]*sessionLockEntry), - permissionAskLocks: make(map[string]*permissionAskLockEntry), - activeRunCancels: make(map[uint64]context.CancelFunc), - activeRunByID: make(map[string]uint64), - activeRunTokenIDs: make(map[uint64]string), - activeRunStates: make(map[uint64]*runState), - askStore: newInMemoryAskSessionStore(askSessionTTL), - thinkingEnabled: true, + configManager: configManager, + sessionStore: sessionStore, + toolManager: toolManager, + providerFactory: providerFactory, + contextBuilder: contextBuilder, + repositoryService: repository.NewService(), + approvalBroker: approval.NewBroker(), + askUserBroker: askuser.NewBroker(), + events: make(chan RuntimeEvent, 128), + runtimeSnapshots: make(map[string]RuntimeSnapshot), + runRollbackBaselineByRunKey: make(map[string]string), + runWorkspaceDriftByRunKey: make(map[string]bool), + lastRunFingerprintByWorkspaceKey: make(map[string]repository.WorkdirFingerprint), + sessionLocks: make(map[string]*sessionLockEntry), + permissionAskLocks: make(map[string]*permissionAskLockEntry), + activeRunCancels: make(map[uint64]context.CancelFunc), + activeRunByID: make(map[string]uint64), + activeRunTokenIDs: make(map[uint64]string), + activeRunStates: make(map[uint64]*runState), + askStore: newInMemoryAskSessionStore(askSessionTTL), + thinkingEnabled: true, } baseHookExecutor := runtimehooks.NewExecutor(runtimehooks.NewRegistry(), newHookRuntimeEventEmitter(service), runtimehooks.DefaultHookTimeout) baseHookExecutor.SetAsyncResultSink(newHookAsyncResultSink(service)) diff --git a/internal/session/checkpoint_types.go b/internal/session/checkpoint_types.go index d21ac79b..c0a4e62c 100644 --- a/internal/session/checkpoint_types.go +++ b/internal/session/checkpoint_types.go @@ -6,13 +6,14 @@ import "time" type CheckpointReason string const ( - CheckpointReasonPreWrite CheckpointReason = "pre_write" - CheckpointReasonCompact CheckpointReason = "compact" - CheckpointReasonPlanMode CheckpointReason = "plan_mode" - CheckpointReasonManual CheckpointReason = "manual" - CheckpointReasonGuard CheckpointReason = "pre_restore_guard" - CheckpointReasonPreWriteDegraded CheckpointReason = "pre_write_degraded" - CheckpointReasonEndOfTurn CheckpointReason = "end_of_turn" + CheckpointReasonPreWrite CheckpointReason = "pre_write" + CheckpointReasonCompact CheckpointReason = "compact" + CheckpointReasonPlanMode CheckpointReason = "plan_mode" + CheckpointReasonManual CheckpointReason = "manual" + CheckpointReasonGuard CheckpointReason = "pre_restore_guard" + CheckpointReasonPreRunDriftRebase CheckpointReason = "pre_run_drift_rebase" + CheckpointReasonPreWriteDegraded CheckpointReason = "pre_write_degraded" + CheckpointReasonEndOfTurn CheckpointReason = "end_of_turn" ) // CheckpointStatus 描述 checkpoint 的生命周期状态。 diff --git a/internal/tui/services/runtime_contract.go b/internal/tui/services/runtime_contract.go index 7c4d1cf8..8ce28ade 100644 --- a/internal/tui/services/runtime_contract.go +++ b/internal/tui/services/runtime_contract.go @@ -245,6 +245,8 @@ type CheckpointDiffResult struct { PrevCommitHash string `json:"prev_commit_hash,omitempty"` Files CheckpointDiffFiles `json:"files"` Patch string `json:"patch,omitempty"` + WorkspaceDrifted bool `json:"workspace_drifted,omitempty"` + Warning string `json:"warning,omitempty"` } // WorkspaceRecord 描述工作区登记信息。 diff --git a/web/src/api/protocol.ts b/web/src/api/protocol.ts index fbe1011c..249b0dc1 100644 --- a/web/src/api/protocol.ts +++ b/web/src/api/protocol.ts @@ -81,6 +81,7 @@ export const EventType = { ToolStart: 'tool_start', ToolResult: 'tool_result', ToolDiff: 'tool_diff', + BashSideEffect: 'bash_side_effect', ToolChunk: 'tool_chunk', ToolCallThinking: 'tool_call_thinking', ThinkingDelta: 'thinking_delta', @@ -484,6 +485,8 @@ export interface CheckpointDiffResultPayload { prev_commit_hash?: string files: FileDiffs patch?: string + workspace_drifted?: boolean + warning?: string } export interface CheckpointRestoreResultPayload { @@ -916,6 +919,7 @@ export interface ToolDiffFileEntry { path: string diff?: string was_new?: boolean + kind?: string } /** tool_diff 事件载荷:写工具修改了哪些文件 */ @@ -928,3 +932,12 @@ export interface ToolDiffPayload { files?: ToolDiffFileChange[] diffs?: ToolDiffFileEntry[] } + +/** bash_side_effect 文件变更条目 */ +export interface BashSideEffectPayload { + tool_call_id: string + command?: string + changes?: ToolDiffFileChange[] + preemptively_captured_paths?: string[] + uncovered_paths?: string[] +} diff --git a/web/src/components/panels/FileChangePanel.test.tsx b/web/src/components/panels/FileChangePanel.test.tsx index 5e47ccda..52f38b1f 100644 --- a/web/src/components/panels/FileChangePanel.test.tsx +++ b/web/src/components/panels/FileChangePanel.test.tsx @@ -1,11 +1,17 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' import FileChangePanel from './FileChangePanel' +import { useChatStore } from '@/stores/useChatStore' +import { useSessionStore } from '@/stores/useSessionStore' import { CHANGES_PREVIEW_TAB_ID, GIT_DIFF_PREVIEW_TAB_ID, useUIStore } from '@/stores/useUIStore' const mockGatewayAPI = { listGitDiffFiles: vi.fn(), readGitDiffFile: vi.fn(), + restoreCheckpoint: vi.fn(), + loadSession: vi.fn(), + listSessionTodos: vi.fn(), + listCheckpoints: vi.fn(), } vi.mock('@/context/RuntimeProvider', () => ({ @@ -71,6 +77,27 @@ describe('FileChangePanel', () => { size_modified: 5, }, }) + mockGatewayAPI.restoreCheckpoint.mockResolvedValue({ + payload: { + checkpoint_id: 'cp-1', + session_id: 'sess-1', + }, + }) + mockGatewayAPI.loadSession.mockResolvedValue({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [{ role: 'assistant', content: 'reloaded' }], + }, + }) + mockGatewayAPI.listSessionTodos.mockResolvedValue({ + payload: { items: [] }, + }) + mockGatewayAPI.listCheckpoints.mockResolvedValue({ + payload: [], + }) + useChatStore.setState({ isGenerating: false } as never) + useSessionStore.setState({ currentSessionId: 'sess-1' } as never) useUIStore.setState({ fileChanges: [ { @@ -79,6 +106,7 @@ describe('FileChangePanel', () => { status: 'modified', additions: 2, deletions: 2, + checkpoint_id: 'cp-1', hunks: [ { header: '@@ -1,3 +1,3 @@', @@ -128,6 +156,7 @@ describe('FileChangePanel', () => { changesPanelOpen: true, changesPanelWidth: 560, theme: 'dark', + isRestoringCheckpoint: false, } as never) }) @@ -136,16 +165,113 @@ describe('FileChangePanel', () => { fireEvent.click(screen.getByText('src/a.txt')) - expect(screen.getByText('接受')).toBeTruthy() + expect(screen.getByTestId('accept-change-fc-1')).toBeTruthy() expect(screen.getAllByTestId(/diff-hunk-fc-1-/)).toHaveLength(1) expect(screen.getByText('line 1')).toBeTruthy() expect(screen.getByText('line 2 new')).toBeTruthy() - fireEvent.click(screen.getByText('接受')) + fireEvent.click(screen.getByTestId('accept-change-fc-1')) expect(useUIStore.getState().fileChanges[0]?.status).toBe('accepted') }) + it('renders restore button per file change and disables it when checkpoint is missing', () => { + useUIStore.setState({ + fileChanges: [ + { + id: 'fc-with-cp', + path: 'src/with-cp.txt', + status: 'modified', + additions: 1, + deletions: 0, + checkpoint_id: 'cp-1', + hunks: [], + }, + { + id: 'fc-without-cp', + path: 'src/no-cp.txt', + status: 'modified', + additions: 1, + deletions: 0, + hunks: [], + }, + ], + } as never) + + render() + + fireEvent.click(screen.getByText('src/with-cp.txt')) + fireEvent.click(screen.getByText('src/no-cp.txt')) + + expect(screen.getByTestId('restore-change-fc-with-cp')).toBeEnabled() + expect(screen.getByTestId('restore-change-fc-without-cp')).toBeDisabled() + }) + + it('calls restoreCheckpoint after confirming restore', async () => { + render() + + fireEvent.click(screen.getByText('src/a.txt')) + fireEvent.click(screen.getByTestId('restore-change-fc-1')) + + const confirmButtons = screen.getAllByRole('button', { name: 'Restore' }) + fireEvent.click(confirmButtons[confirmButtons.length - 1]) + + await waitFor(() => { + expect(mockGatewayAPI.restoreCheckpoint).toHaveBeenCalledWith({ + session_id: 'sess-1', + checkpoint_id: 'cp-1', + }) + }) + expect(mockGatewayAPI.loadSession).not.toHaveBeenCalled() + }) + + it('disables accept and restore actions while the session is generating', () => { + useChatStore.setState({ isGenerating: true } as never) + render() + + fireEvent.click(screen.getByText('src/a.txt')) + + expect(screen.getByTestId('accept-change-fc-1')).toBeDisabled() + expect(screen.getByTestId('restore-change-fc-1')).toBeDisabled() + fireEvent.click(screen.getByTestId('restore-change-fc-1')) + expect(mockGatewayAPI.restoreCheckpoint).not.toHaveBeenCalled() + }) + + it('disables accept and restore actions while checkpoint restore is in progress', () => { + useUIStore.setState({ isRestoringCheckpoint: true } as never) + render() + + fireEvent.click(screen.getByText('src/a.txt')) + + expect(screen.getByTestId('accept-change-fc-1')).toBeDisabled() + expect(screen.getByTestId('restore-change-fc-1')).toBeDisabled() + }) + + it('prevents duplicate restore calls while restore is in-flight', async () => { + let resolveRestore: ((value?: unknown) => void) | undefined + mockGatewayAPI.restoreCheckpoint.mockImplementation(() => new Promise((resolve) => { + resolveRestore = resolve + })) + + render() + + fireEvent.click(screen.getByText('src/a.txt')) + fireEvent.click(screen.getByTestId('restore-change-fc-1')) + const confirmButtons = screen.getAllByRole('button', { name: 'Restore' }) + fireEvent.click(confirmButtons[confirmButtons.length - 1]) + + await waitFor(() => { + expect(mockGatewayAPI.restoreCheckpoint).toHaveBeenCalledTimes(1) + }) + + const restoreButton = screen.getByTestId('restore-change-fc-1') + expect(restoreButton).toBeDisabled() + fireEvent.click(restoreButton) + expect(mockGatewayAPI.restoreCheckpoint).toHaveBeenCalledTimes(1) + + resolveRestore?.() + }) + it('keeps the panel body clipped and uses the content area as the scroll container', () => { render() @@ -374,7 +500,7 @@ describe('FileChangePanel', () => { { id: CHANGES_PREVIEW_TAB_ID, kind: 'changes', - title: '鏂囦欢鍙樻洿', + title: '文件变更', closable: false, }, { @@ -449,7 +575,7 @@ describe('FileChangePanel', () => { { id: CHANGES_PREVIEW_TAB_ID, kind: 'changes', - title: '鏂囦欢鍙樻洿', + title: '文件变更', closable: false, }, { diff --git a/web/src/components/panels/FileChangePanel.tsx b/web/src/components/panels/FileChangePanel.tsx index e323f147..4842c26b 100644 --- a/web/src/components/panels/FileChangePanel.tsx +++ b/web/src/components/panels/FileChangePanel.tsx @@ -1,6 +1,9 @@ import { lazy, Suspense, useCallback, useEffect, useMemo, useRef, useState, type CSSProperties, type KeyboardEvent as ReactKeyboardEvent, type RefObject, type WheelEvent as ReactWheelEvent } from 'react' -import { ChevronRight, Check, FileDiff, Loader2, PanelRightClose, RefreshCw, X } from 'lucide-react' +import { ChevronRight, Check, FileDiff, Loader2, PanelRightClose, RefreshCw, RotateCcw, X } from 'lucide-react' import { useGatewayAPI } from '@/context/RuntimeProvider' +import ConfirmDialog from '@/components/ui/ConfirmDialog' +import { useChatStore } from '@/stores/useChatStore' +import { useSessionStore } from '@/stores/useSessionStore' import { useWorkspaceStore } from '@/stores/useWorkspaceStore' import { CHANGES_PREVIEW_TAB_ID, @@ -19,6 +22,7 @@ const LazyCodePreviewEditor = lazy(() => import('./CodePreviewEditor')) const LazyGitDiffPreviewEditor = lazy(() => import('./GitDiffPreviewEditor')) const changeStatusMeta: Record = { + pending: { label: '待定', color: 'var(--text-tertiary)', bg: 'var(--bg-active)' }, added: { label: '新增', color: 'var(--success)', bg: 'rgba(22, 163, 74, 0.12)' }, modified: { label: '修改', color: 'var(--warning)', bg: 'rgba(217, 119, 6, 0.14)' }, deleted: { label: '删除', color: 'var(--error)', bg: 'rgba(220, 38, 38, 0.12)' }, @@ -41,6 +45,7 @@ function getChangeStatusMeta(status: string) { } function getChangeType(change: { status: string; additions: number; deletions: number }) { + if (change.status === 'pending') return 'pending' as const if (['added', 'modified', 'deleted'].includes(change.status)) return change.status as 'added' | 'modified' | 'deleted' if (change.additions > 0 && change.deletions > 0) return 'modified' if (change.additions > 0) return 'added' @@ -52,12 +57,13 @@ function getChangeCounts(fileChanges: { status: string; additions: number; delet return fileChanges.reduce( (counts, change) => { const type = getChangeType(change) + if (type === 'pending') counts.pending += 1 if (type === 'added') counts.added += 1 if (type === 'modified') counts.modified += 1 if (type === 'deleted') counts.deleted += 1 return counts }, - { added: 0, modified: 0, deleted: 0 }, + { pending: 0, added: 0, modified: 0, deleted: 0 }, ) } @@ -99,6 +105,7 @@ function getPreviewBadge(tab: PreviewTab, fileChanges: FileChange[]) { const matched = fileChanges.find((change) => change.path === tab.path) if (!matched) return null const type = getChangeType(matched) + if (type === 'pending') return { label: 'P', color: 'var(--text-tertiary)' } if (type === 'added') return { label: 'A', color: 'var(--success)' } if (type === 'deleted') return { label: 'D', color: 'var(--error)' } return { label: 'M', color: 'var(--warning)' } @@ -146,12 +153,31 @@ function FileChangeItem({ onToggle: () => void scrollContainerRef: RefObject }) { + const gatewayAPI = useGatewayAPI() + const sessionId = useSessionStore((state) => state.currentSessionId) + const isGenerating = useChatStore((state) => state.isGenerating) const acceptFileChange = useUIStore((state) => state.acceptFileChange) + const isRestoringCheckpoint = useUIStore((state) => state.isRestoringCheckpoint) + const setRestoringCheckpoint = useUIStore((state) => state.setRestoringCheckpoint) + const showToast = useUIStore((state) => state.showToast) const meta = getChangeStatusMeta(change.status) const reviewed = change.status === 'accepted' || change.status === 'rejected' const displayHunks = getDisplayHunks(change) + const [confirmingRestore, setConfirmingRestore] = useState(false) + const disabledByRunning = isGenerating + const disabledByRestore = isRestoringCheckpoint + const acceptDisabled = reviewed || disabledByRestore || disabledByRunning + const restoreDisabled = reviewed || disabledByRestore || disabledByRunning || !change.checkpoint_id + const acceptTitle = disabledByRunning || disabledByRestore ? 'Running; action is disabled' : 'Mark as reviewed' + const restoreTitle = disabledByRunning + ? 'Running; action is disabled' + : disabledByRestore + ? 'Checkpoint restore in progress' + : change.checkpoint_id + ? 'Restore to this point (impacts all later changes)' + : 'No checkpoint available for this file change' - // 将 diff 内层收到的纵向滚轮显式转发给外层列表,避免 hover 在 hunk 上时卡住整列滚动。 + // 横向 diff 区域优先消费纵向滚轮,避免 hover 在 hunk 上时页面滚动。 const handleHunkWheel = (event: ReactWheelEvent) => { if (Math.abs(event.deltaY) <= Math.abs(event.deltaX)) { return @@ -164,6 +190,24 @@ function FileChangeItem({ event.preventDefault() } + // handleConfirmRestore 只负责触发 checkpoint 回退,成功后的状态刷新由事件链路驱动。 + async function handleConfirmRestore() { + setConfirmingRestore(false) + if (!gatewayAPI || !sessionId || !change.checkpoint_id || disabledByRestore || disabledByRunning) return + + setRestoringCheckpoint(true) + try { + await gatewayAPI.restoreCheckpoint({ + session_id: sessionId, + checkpoint_id: change.checkpoint_id, + }) + } catch (err) { + const message = err instanceof Error ? err.message : 'Unknown error' + showToast(`Restore failed: ${message}`, 'error') + setRestoringCheckpoint(false) + } + } + return (
+
{displayHunks.length === 0 ? ( -
当前文件没有可展示的 diff
+
暂无可展示的 diff
) : ( displayHunks.map((hunk, index) => (
)} + {confirmingRestore && ( + setConfirmingRestore(false)} + /> + )}
) } @@ -247,6 +317,7 @@ function ChangesView() {
{fileChanges.length} 个文件 + {counts.pending} 待定 {counts.added} 新增 {counts.modified} 修改 {counts.deleted} 删除 diff --git a/web/src/stores/useSessionStore.ts b/web/src/stores/useSessionStore.ts index 5f61eecc..10755991 100644 --- a/web/src/stores/useSessionStore.ts +++ b/web/src/stores/useSessionStore.ts @@ -214,6 +214,43 @@ export function mapHistoryMessages(backendMessages: BackendMessage[]): Array { + const normalizedSessionId = sessionId.trim() + if (!normalizedSessionId) return false + + useUIStore.getState().clearFileChanges() + useChatStore.getState().clearMessages() + useRuntimeInsightStore.getState().reset() + + const sessionFrame = await loadSessionWithInsights(gatewayAPI, normalizedSessionId) + if (reloadSeq !== _latestCheckpointRestoreReloadSeq) return false + const sessionData = sessionFrame.payload as { messages?: BackendMessage[]; agent_mode?: string } + + if (sessionData.messages && sessionData.messages.length > 0) { + const mapped = mapHistoryMessages(sessionData.messages) + for (const msg of mapped) { + useChatStore.getState().addMessage(msg) + } + } + + const restoredMode = sessionData.agent_mode === 'plan' ? 'plan' : 'build' + useChatStore.getState().setAgentMode(restoredMode) + return true +} + let _fetchSessionsPromise: Promise | null = null export const useSessionStore = create((set, get) => ({ diff --git a/web/src/stores/useUIStore.ts b/web/src/stores/useUIStore.ts index 3b37e19e..06980768 100644 --- a/web/src/stores/useUIStore.ts +++ b/web/src/stores/useUIStore.ts @@ -11,7 +11,7 @@ export interface Toast { export interface FileChange { id: string path: string - status: 'added' | 'modified' | 'deleted' | 'accepted' | 'rejected' + status: 'pending' | 'added' | 'modified' | 'deleted' | 'accepted' | 'rejected' additions: number deletions: number diff?: DiffLine[] @@ -196,6 +196,7 @@ interface UIState { theme: 'light' | 'dark' searchQuery: string fileChanges: FileChange[] + isRestoringCheckpoint: boolean gitDiffSummary: GitDiffSummary gitDiffLoading: boolean gitDiffError: string @@ -219,6 +220,7 @@ interface UIState { acceptFileChange: (id: string) => void rejectFileChange: (id: string) => void clearFileChanges: () => void + setRestoringCheckpoint: (restoring: boolean) => void openPreviewTab: (path: string) => OpenPreviewTabResult openGitDiffTab: (path: string) => OpenPreviewTabResult activatePreviewTab: (id: string) => void @@ -251,6 +253,7 @@ export const useUIStore = create((set) => ({ theme: (localStorage.getItem('neocode-theme') as 'light' | 'dark') || 'dark', searchQuery: '', fileChanges: [], + isRestoringCheckpoint: false, gitDiffSummary: createEmptyGitDiffSummary(), gitDiffLoading: false, gitDiffError: '', @@ -287,6 +290,7 @@ export const useUIStore = create((set) => ({ fileChanges: state.fileChanges.map((change) => (change.id === id ? { ...change, status: 'rejected' as const } : change)), })), clearFileChanges: () => set({ fileChanges: [] }), + setRestoringCheckpoint: (isRestoringCheckpoint) => set({ isRestoringCheckpoint }), openPreviewTab: (path) => { const normalizedPath = path.trim() const tabID = `file:${normalizedPath}` diff --git a/web/src/utils/eventBridge.test.ts b/web/src/utils/eventBridge.test.ts index 1bffaa8b..b6a7c281 100644 --- a/web/src/utils/eventBridge.test.ts +++ b/web/src/utils/eventBridge.test.ts @@ -1,4 +1,5 @@ import { describe, it, expect, beforeEach, vi } from 'vitest' +import { waitFor } from '@testing-library/react' import { useChatStore } from '@/stores/useChatStore' import { useGatewayStore } from '@/stores/useGatewayStore' import { useSessionStore } from '@/stores/useSessionStore' @@ -36,7 +37,8 @@ beforeEach(() => { authenticated: false, } as any) useRuntimeInsightStore.getState().reset() - useUIStore.setState({ toasts: [], fileChanges: [] } as any) + useSessionStore.setState({ currentSessionId: '' } as any) + useUIStore.setState({ toasts: [], fileChanges: [], isRestoringCheckpoint: false } as any) }) describe('eventBridge', () => { @@ -193,6 +195,19 @@ describe('eventBridge', () => { expect(msgs[0].toolName).toBe('read_file') }) + it('ToolStart file placeholders are pending before tool diff arrives', () => { + const api = createMockGatewayAPI() + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc-pending', arguments: '{"path":"pending.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'pending.txt') + expect(change?.status).toBe('pending') + }) + it('ToolResult updates an existing tool call message', () => { const api = createMockGatewayAPI() // 先触发 ToolStart 创建工具消息 @@ -472,6 +487,87 @@ describe('eventBridge', () => { expect(useRuntimeInsightStore.getState().checkpointEvents[0]).toMatchObject({ checkpoint_id: 'cp1' }) }) + it('CheckpointRestored reloads state for current session and clears stale file changes first', async () => { + const loadSession = vi.fn(async () => ({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [{ role: 'assistant', content: 'after restore' }], + }, + })) + const api = createMockGatewayAPI({ loadSession }) + useSessionStore.setState({ currentSessionId: 'sess-1' } as any) + useUIStore.setState({ + fileChanges: [ + { id: 'fc-1', path: 'stale.txt', status: 'modified', additions: 1, deletions: 0 }, + ], + } as any) + + handleGatewayEvent({ + type: EventType.CheckpointRestored, + payload: { payload: { runtime_event_type: EventType.CheckpointRestored, payload: { checkpoint_id: 'cp1', session_id: 'sess-1', guard_checkpoint_id: 'guard-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + expect(loadSession).toHaveBeenCalledWith('sess-1') + expect(useUIStore.getState().fileChanges).toHaveLength(0) + }) + + it('CheckpointRestored does not reload when event session differs from current session', async () => { + const loadSession = vi.fn(async () => ({ payload: { id: 'sess-other', messages: [] } })) + const api = createMockGatewayAPI({ loadSession }) + useSessionStore.setState({ currentSessionId: 'sess-current' } as any) + useUIStore.setState({ + fileChanges: [ + { id: 'fc-1', path: 'keep.txt', status: 'modified', additions: 1, deletions: 0 }, + ], + } as any) + + handleGatewayEvent({ + type: EventType.CheckpointRestored, + payload: { payload: { runtime_event_type: EventType.CheckpointRestored, payload: { checkpoint_id: 'cp1', session_id: 'sess-other', guard_checkpoint_id: 'guard-1' } } }, + session_id: 'sess-other', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + expect(loadSession).not.toHaveBeenCalled() + expect(useUIStore.getState().fileChanges).toHaveLength(1) + }) + + it('CheckpointUndoRestore reloads current session with the same restore-sync flow', async () => { + const loadSession = vi.fn(async () => ({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [{ role: 'assistant', content: 'after undo restore' }], + }, + })) + const api = createMockGatewayAPI({ loadSession }) + useSessionStore.setState({ currentSessionId: 'sess-1' } as any) + useUIStore.setState({ + fileChanges: [ + { id: 'fc-1', path: 'stale.txt', status: 'modified', additions: 1, deletions: 0 }, + ], + } as any) + + handleGatewayEvent({ + type: EventType.CheckpointUndoRestore, + payload: { payload: { runtime_event_type: EventType.CheckpointUndoRestore, payload: { session_id: 'sess-1', guard_checkpoint_id: 'guard-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + expect(loadSession).toHaveBeenCalledWith('sess-1') + expect(useUIStore.getState().fileChanges).toHaveLength(0) + }) + it('VerificationStarted creates a verification ChatMessage', () => { const api = createMockGatewayAPI() handleGatewayEvent({ @@ -571,6 +667,33 @@ describe('eventBridge', () => { expect(toolMsg?.checkpointStatus).toBe('available') }) + it('CheckpointCreated with pre_restore_guard does not override latest rollback baseline', () => { + const api = createMockGatewayAPI() + + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-base', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: 'abc', reason: 'pre_write' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-guard', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: 'abc', reason: 'pre_restore_guard' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"baseline.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + const baselineChange = useUIStore.getState().fileChanges.find((entry) => entry.path === 'baseline.txt') + expect(baselineChange?.checkpoint_id).toBe('cp-base') + }) + it('clearMessages resets eventBridge cursors so new session does not inherit prior checkpoint', () => { const api = createMockGatewayAPI() @@ -698,6 +821,177 @@ describe('eventBridge', () => { ]) }) + it('uses run diff prev_checkpoint_id as rollback baseline authority', async () => { + const checkpointDiff = vi.fn(async () => ({ + payload: { + checkpoint_id: 'cp-latest', + prev_checkpoint_id: 'cp-authoritative', + files: { modified: ['a.txt'] }, + patch: '--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-old\n+new\n', + }, + })) + const api = createMockGatewayAPI({ checkpointDiff }) + + handleGatewayEvent({ + type: EventType.InputNormalized, + payload: { payload: { runtime_event_type: EventType.InputNormalized, payload: { session_id: 'sess-1', run_id: 'run-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-base', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'pre_write' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"a.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-latest', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'end_of_turn' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'a.txt') + expect(change?.checkpoint_id).toBe('cp-authoritative') + }) + + it('shows warning toast when run diff carries warning text', async () => { + const checkpointDiff = vi.fn(async () => ({ + payload: { + checkpoint_id: 'cp-latest', + files: { modified: ['a.txt'] }, + patch: '--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-old\n+new\n', + warning: 'run baseline checkpoint is missing', + }, + })) + const api = createMockGatewayAPI({ checkpointDiff }) + + handleGatewayEvent({ + type: EventType.InputNormalized, + payload: { payload: { runtime_event_type: EventType.InputNormalized, payload: { session_id: 'sess-1', run_id: 'run-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"a.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-latest', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'end_of_turn' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + expect(useUIStore.getState().toasts.some((toast) => toast.message.includes('run baseline checkpoint is missing'))).toBe(true) + }) + + it('keeps first-touch checkpoint for a file after run-scoped diff replacement', async () => { + const checkpointDiff = vi.fn(async () => ({ + payload: { + checkpoint_id: 'cp-latest', + files: { modified: ['a.txt'] }, + patch: '--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-old\n+new\n', + }, + })) + const api = createMockGatewayAPI({ checkpointDiff }) + + handleGatewayEvent({ + type: EventType.InputNormalized, + payload: { payload: { runtime_event_type: EventType.InputNormalized, payload: { session_id: 'sess-1', run_id: 'run-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-base', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'pre_write' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"a.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolDiff, + payload: { payload: { runtime_event_type: EventType.ToolDiff, payload: { tool_call_id: 'tc1', tool_name: 'filesystem_write_file', file_path: 'a.txt', diff: '--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-old\n+new\n' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-latest', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'end_of_turn' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'a.txt') + expect(change?.checkpoint_id).toBe('cp-base') + }) + + it('does not overwrite first-touch checkpoint when the same file is edited multiple times', () => { + const api = createMockGatewayAPI() + + handleGatewayEvent({ + type: EventType.InputNormalized, + payload: { payload: { runtime_event_type: EventType.InputNormalized, payload: { session_id: 'sess-1', run_id: 'run-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-base', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'pre_write' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"a.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolDiff, + payload: { payload: { runtime_event_type: EventType.ToolDiff, payload: { tool_call_id: 'tc1', tool_name: 'filesystem_write_file', file_path: 'a.txt', diff: '--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-old\n+mid\n' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-next', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'pre_write' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolDiff, + payload: { payload: { runtime_event_type: EventType.ToolDiff, payload: { tool_call_id: 'tc2', tool_name: 'filesystem_write_file', file_path: 'a.txt', diff: '--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-mid\n+new\n' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'a.txt') + expect(change?.checkpoint_id).toBe('cp-base') + }) + it('stores hunk structure for transient tool diffs before aggregate checkpoint diff arrives', () => { const api = createMockGatewayAPI() @@ -735,6 +1029,61 @@ describe('eventBridge', () => { ]) }) + it('uses backend kind for multi-file tool diffs instead of was_new fallback', () => { + const api = createMockGatewayAPI() + + handleGatewayEvent({ + type: EventType.ToolDiff, + payload: { + payload: { + runtime_event_type: EventType.ToolDiff, + payload: { + tool_call_id: 'tc-move', + tool_name: 'filesystem_move_file', + file_path: 'old.txt', + files: [ + { path: 'old.txt', kind: 'deleted' }, + { path: 'new.txt', kind: 'added' }, + ], + diffs: [ + { path: 'old.txt', kind: 'deleted', was_new: false, diff: '--- a/old.txt\n+++ b/old.txt\n-old\n' }, + { path: 'new.txt', kind: 'added', was_new: false, diff: '--- a/new.txt\n+++ b/new.txt\n+new\n' }, + ], + }, + }, + }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + const changes = useUIStore.getState().fileChanges + expect(changes.find((entry) => entry.path === 'old.txt')?.status).toBe('deleted') + expect(changes.find((entry) => entry.path === 'new.txt')?.status).toBe('added') + }) + + it('tracks bash side-effect changes using backend kind', () => { + const api = createMockGatewayAPI() + + handleGatewayEvent({ + type: EventType.BashSideEffect, + payload: { + payload: { + runtime_event_type: EventType.BashSideEffect, + payload: { + tool_call_id: 'bash-1', + command: 'echo hi > generated.txt', + changes: [{ path: 'generated.txt', kind: 'added' }], + }, + }, + }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'generated.txt') + expect(change?.status).toBe('added') + }) + it('keeps transient tool diffs visible when backend sends simplified diff without @@ header', () => { const api = createMockGatewayAPI() @@ -828,4 +1177,256 @@ describe('eventBridge', () => { deletions: 0, }) }) + + it('rebinds post-restore first-touch baseline to restored checkpoint', async () => { + const loadSession = vi.fn(async () => ({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [], + }, + })) + const api = createMockGatewayAPI({ loadSession }) + useSessionStore.setState({ currentSessionId: 'sess-1' } as any) + + handleGatewayEvent({ + type: EventType.InputNormalized, + payload: { payload: { runtime_event_type: EventType.InputNormalized, payload: { session_id: 'sess-1', run_id: 'run-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-old', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'pre_write' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"a.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + handleGatewayEvent({ + type: EventType.CheckpointRestored, + payload: { payload: { runtime_event_type: EventType.CheckpointRestored, payload: { checkpoint_id: 'cp-restored', session_id: 'sess-1', guard_checkpoint_id: 'guard-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + useUIStore.getState().clearFileChanges() + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc2', arguments: '{"path":"fresh.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-2', + }, api) + const fresh = useUIStore.getState().fileChanges.find((entry) => entry.path === 'fresh.txt') + expect(fresh?.checkpoint_id).toBe('cp-restored') + }) + + it('keeps restored pending baseline even when InputNormalized is missing before next run', async () => { + const loadSession = vi.fn(async () => ({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [], + }, + })) + const api = createMockGatewayAPI({ loadSession }) + useSessionStore.setState({ currentSessionId: 'sess-1' } as any) + + handleGatewayEvent({ + type: EventType.CheckpointRestored, + payload: { payload: { runtime_event_type: EventType.CheckpointRestored, payload: { checkpoint_id: 'cp-restored', session_id: 'sess-1', guard_checkpoint_id: 'guard-1' } } }, + session_id: 'sess-1', + run_id: 'run-restore', + }, api) + await Promise.resolve() + await Promise.resolve() + + // 注意: 故意不发送 InputNormalized + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc-next', arguments: '{"path":"no-input-normalized.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-next', + }, api) + + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'no-input-normalized.txt') + expect(change?.checkpoint_id).toBe('cp-restored') + }) + + it('rebinds post-undo first-touch baseline to guard checkpoint', async () => { + const loadSession = vi.fn(async () => ({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [], + }, + })) + const api = createMockGatewayAPI({ loadSession }) + useSessionStore.setState({ currentSessionId: 'sess-1' } as any) + + handleGatewayEvent({ + type: EventType.CheckpointUndoRestore, + payload: { payload: { runtime_event_type: EventType.CheckpointUndoRestore, payload: { session_id: 'sess-1', guard_checkpoint_id: 'guard-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + useUIStore.getState().clearFileChanges() + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc3', arguments: '{"path":"undo-fresh.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-2', + }, api) + const fresh = useUIStore.getState().fileChanges.find((entry) => entry.path === 'undo-fresh.txt') + expect(fresh?.checkpoint_id).toBe('guard-1') + }) + + it('resets first-touch cache when run_id changes for the same file path', () => { + const api = createMockGatewayAPI() + + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-old', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'pre_write' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc-old', arguments: '{"path":"same.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + let change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'same.txt') + expect(change?.checkpoint_id).toBe('cp-old') + + useUIStore.getState().clearFileChanges() + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-new', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'pre_write' } } }, + session_id: 'sess-1', + run_id: 'run-2', + }, api) + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc-new', arguments: '{"path":"same.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-2', + }, api) + change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'same.txt') + expect(change?.checkpoint_id).toBe('cp-new') + }) + + it('does not let pre_restore_guard overwrite pending restore baseline for the next run', async () => { + const loadSession = vi.fn(async () => ({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [], + }, + })) + const api = createMockGatewayAPI({ loadSession }) + useSessionStore.setState({ currentSessionId: 'sess-1' } as any) + + handleGatewayEvent({ + type: EventType.CheckpointRestored, + payload: { payload: { runtime_event_type: EventType.CheckpointRestored, payload: { checkpoint_id: 'cp-restored', session_id: 'sess-1', guard_checkpoint_id: 'guard-1' } } }, + session_id: 'sess-1', + run_id: 'run-restore', + }, api) + await Promise.resolve() + await Promise.resolve() + + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-guard', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'pre_restore_guard' } } }, + session_id: 'sess-1', + run_id: 'run-restore', + }, api) + + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc-next', arguments: '{"path":"pending-protected.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-next', + }, api) + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'pending-protected.txt') + expect(change?.checkpoint_id).toBe('cp-restored') + }) + + it('applies only the latest restore reload when restore events arrive back-to-back', async () => { + let resolveFirst: ((value: unknown) => void) | undefined + let resolveSecond: ((value: unknown) => void) | undefined + const loadSession = vi.fn(() => { + if (!resolveFirst) { + return new Promise((resolve) => { + resolveFirst = resolve + }) + } + return new Promise((resolve) => { + resolveSecond = resolve + }) + }) + const api = createMockGatewayAPI({ loadSession }) + useSessionStore.setState({ currentSessionId: 'sess-1' } as any) + + handleGatewayEvent({ + type: EventType.CheckpointRestored, + payload: { payload: { runtime_event_type: EventType.CheckpointRestored, payload: { checkpoint_id: 'cp-old', session_id: 'sess-1', guard_checkpoint_id: 'guard-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.CheckpointRestored, + payload: { payload: { runtime_event_type: EventType.CheckpointRestored, payload: { checkpoint_id: 'cp-new', session_id: 'sess-1', guard_checkpoint_id: 'guard-2' } } }, + session_id: 'sess-1', + run_id: 'run-2', + }, api) + + resolveSecond?.({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [{ role: 'assistant', content: 'newest' }], + }, + }) + await Promise.resolve() + await Promise.resolve() + + resolveFirst?.({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [{ role: 'assistant', content: 'stale' }], + }, + }) + await Promise.resolve() + await Promise.resolve() + + await waitFor(() => { + expect(useUIStore.getState().isRestoringCheckpoint).toBe(false) + }) + + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc-after-restore', arguments: '{"path":"after-double-restore.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-3', + }, api) + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'after-double-restore.txt') + expect(change?.checkpoint_id).toBe('cp-new') + + const textMessages = useChatStore.getState().messages.filter((m) => m.type === 'text') + expect(textMessages.map((m) => m.content)).toContain('newest') + expect(textMessages.map((m) => m.content)).not.toContain('stale') + }) }) diff --git a/web/src/utils/eventBridge.ts b/web/src/utils/eventBridge.ts index 91740fcf..8ebe17b0 100644 --- a/web/src/utils/eventBridge.ts +++ b/web/src/utils/eventBridge.ts @@ -1,6 +1,7 @@ import { EventType, type AcceptanceDecidedPayload, + type BashSideEffectPayload, type BudgetCheckedPayload, type BudgetEstimateFailedPayload, type CheckpointCreatedPayload, @@ -25,7 +26,11 @@ import { type GatewayAPI } from '@/api/gateway' import { useChatStore } from '@/stores/useChatStore' import { useUIStore } from '@/stores/useUIStore' import { useGatewayStore } from '@/stores/useGatewayStore' -import { useSessionStore } from '@/stores/useSessionStore' +import { + beginCheckpointRestoreReloadSeq, + reloadSessionAfterCheckpointRestore, + useSessionStore, +} from '@/stores/useSessionStore' import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' import { useWorkspaceStore } from '@/stores/useWorkspaceStore' import { parseSingleFileDiff, parseUnifiedPatch, type ParsedFileDiff } from '@/utils/patchParser' @@ -40,13 +45,48 @@ let _latestDoneToolCallId: string | undefined // 模块级缓存最新的 checkpoint_id,用于工具占位条目关联后续端到端 diff。 let _latestCheckpointId: string | undefined let _latestRunDiffRequestId = 0 +let _latestRestoreSyncRequestId = 0 +// 文件首次触碰时的回退基线 checkpoint:key=标准化路径,value=checkpoint_id。 +let _firstTouchRollbackCheckpointByPath = new Map() +// restore/undo 后“下一轮”回退基线锚点,仅由 restore/undo 事件写入。 +let _pendingNextRunRollbackCheckpointId: string | undefined +// 当前用于回退基线绑定的 run 边界(按 frame.run_id 检测)。 +let _currentRollbackRunId: string | undefined +// 标记 pending 基线已应用到哪个 run;切到下一 run 时自动失效。 +let _pendingRollbackAppliedRunId: string | undefined +const CHECKPOINT_REASON_PRE_RESTORE_GUARD = 'pre_restore_guard' /** 重置模块级游标 —— 在截断聊天历史 / 切换会话等场景调用,避免后续事件挂到已被移除的消息上 */ export function resetEventBridgeCursors() { + const keepCheckpointBaseline = useUIStore.getState().isRestoringCheckpoint _latestVerificationMsgId = undefined _latestDoneToolCallId = undefined - _latestCheckpointId = undefined + _latestCheckpointId = keepCheckpointBaseline ? _latestCheckpointId : undefined + _firstTouchRollbackCheckpointByPath = new Map() + _currentRollbackRunId = undefined + _pendingRollbackAppliedRunId = keepCheckpointBaseline ? _pendingRollbackAppliedRunId : undefined + _pendingNextRunRollbackCheckpointId = keepCheckpointBaseline ? _pendingNextRunRollbackCheckpointId : undefined _latestRunDiffRequestId += 1 + if (!keepCheckpointBaseline) { + _latestRestoreSyncRequestId += 1 + useUIStore.getState().setRestoringCheckpoint(false) + } +} + +// trackRollbackRunBoundary 按 run_id 切分文件回退基线缓存,避免跨 run 复用旧 first-touch 映射。 +function trackRollbackRunBoundary(runId: string) { + const normalizedRunId = runId.trim() + if (!normalizedRunId) return + if (_currentRollbackRunId === normalizedRunId) return + + _currentRollbackRunId = normalizedRunId + _firstTouchRollbackCheckpointByPath = new Map() + + // pending 基线只作用于“下一轮”;一旦已在某个 run 消费,切到后续 run 即失效。 + if (_pendingRollbackAppliedRunId && _pendingRollbackAppliedRunId !== normalizedRunId) { + _pendingNextRunRollbackCheckpointId = undefined + _pendingRollbackAppliedRunId = undefined + } } /** @@ -71,13 +111,42 @@ function normalizeFilePath(input: string): string { return p } +// resolveRollbackCheckpointID 计算文件项的回退 checkpoint,优先首次触碰基线,避免被后续 checkpoint 覆盖。 +function resolveRollbackCheckpointID( + path: string, + fallback?: string, + allowLatestFallback: boolean = true, +): string | undefined { + const normalizedPath = normalizeFilePath(path) + if (!normalizedPath) return fallback + + const firstTouch = _firstTouchRollbackCheckpointByPath.get(normalizedPath) + if (firstTouch) return firstTouch + + const pending = _pendingNextRunRollbackCheckpointId + if (pending) { + _firstTouchRollbackCheckpointByPath.set(normalizedPath, pending) + if (_currentRollbackRunId && !_pendingRollbackAppliedRunId) { + _pendingRollbackAppliedRunId = _currentRollbackRunId + } + return pending + } + + const candidate = fallback || (allowLatestFallback ? _latestCheckpointId : undefined) + if (candidate) { + _firstTouchRollbackCheckpointByPath.set(normalizedPath, candidate) + } + return candidate +} + function _upsertFileChange( rawPath: string, - status: 'added' | 'modified' | 'deleted', + status: 'pending' | 'added' | 'modified' | 'deleted', parsed?: ParsedFileDiff, ) { const path = normalizeFilePath(rawPath) if (!path) return + const checkpointID = resolveRollbackCheckpointID(path) const existing = useUIStore.getState().fileChanges.find((c) => c.path === path) if (existing) { useUIStore.setState((s) => ({ @@ -90,7 +159,7 @@ function _upsertFileChange( deletions: parsed?.deletions ?? c.deletions, diff: parsed?.lines ?? c.diff, hunks: parsed?.hunks ?? c.hunks, - checkpoint_id: _latestCheckpointId ?? c.checkpoint_id, + checkpoint_id: c.checkpoint_id ?? checkpointID, } : c, ), @@ -104,11 +173,16 @@ function _upsertFileChange( deletions: parsed?.deletions ?? 0, diff: parsed?.lines, hunks: parsed?.hunks, - checkpoint_id: _latestCheckpointId, + checkpoint_id: checkpointID, }) } } +function normalizeChangeStatus(kind: unknown): 'added' | 'modified' | 'deleted' | undefined { + if (kind === 'added' || kind === 'modified' || kind === 'deleted') return kind + return undefined +} + /** 写文件工具名集合 */ const FILE_WRITE_TOOLS = new Set([ 'filesystem_write_file', @@ -129,15 +203,15 @@ function _trackFileChangeFromTool(toolName: string, argsRaw: string) { return } - // 统一用 modified 占位,真实状态由 tool_diff 事件覆盖 + // 统一用 pending 占位,真实状态由 tool_diff/run diff 事件覆盖 if (toolName === 'filesystem_move_file' || toolName === 'filesystem_copy_file') { const src = typeof args.source_path === 'string' ? args.source_path : '' const dst = typeof args.destination_path === 'string' ? args.destination_path : '' - if (src) _upsertFileChange(src, 'modified') - if (dst) _upsertFileChange(dst, 'modified') + if (src) _upsertFileChange(src, 'pending') + if (dst) _upsertFileChange(dst, 'pending') } else { const path = typeof args.path === 'string' ? args.path : '' - if (path) _upsertFileChange(path, 'modified') + if (path) _upsertFileChange(path, 'pending') } if (!useUIStore.getState().changesPanelOpen) { @@ -149,8 +223,15 @@ function _trackFileChangeFromTool(toolName: string, argsRaw: string) { function _applyToolDiff(payload: ToolDiffPayload) { // 多文件工具(move/copy) if (payload.diffs && payload.diffs.length > 0) { + const kindByPath = new Map() + for (const file of payload.files ?? []) { + const normalized = normalizeFilePath(file.path) + const status = normalizeChangeStatus(file.kind) + if (normalized && status) kindByPath.set(normalized, status) + } for (const entry of payload.diffs) { - const status: 'added' | 'modified' | 'deleted' = entry.was_new ? 'added' : 'modified' + const normalized = normalizeFilePath(entry.path) + const status = normalizeChangeStatus(entry.kind) ?? kindByPath.get(normalized) ?? (entry.was_new ? 'added' : 'modified') const parsed = entry.diff ? parseSingleFileDiff(entry.diff) : undefined _upsertFileChange(entry.path, status, parsed) } @@ -170,7 +251,24 @@ function _applyToolDiff(payload: ToolDiffPayload) { } } -function _fileChangesFromCheckpointDiff(diff: CheckpointDiffResultPayload) { +function _applyBashSideEffect(payload: BashSideEffectPayload) { + let changed = false + for (const change of payload.changes ?? []) { + const status = normalizeChangeStatus(change.kind) + if (!status) continue + _upsertFileChange(change.path, status) + changed = true + } + if (changed && !useUIStore.getState().changesPanelOpen) { + useUIStore.getState().toggleChangesPanel() + } +} + +function _fileChangesFromCheckpointDiff( + diff: CheckpointDiffResultPayload, + existingCheckpointByPath: Map, +) { + const authoritativeBaseline = diff.prev_checkpoint_id?.trim() || undefined const parsed = diff.patch ? parseUnifiedPatch(diff.patch) : {} const parsedByPath = new Map() for (const [path, parsedDiff] of Object.entries(parsed)) { @@ -191,6 +289,7 @@ function _fileChangesFromCheckpointDiff(diff: CheckpointDiffResultPayload) { .sort(([a], [b]) => a.localeCompare(b)) .map(([path, status]) => { const parsedDiff = parsedByPath.get(path) + const existingCheckpointID = existingCheckpointByPath.get(path) return { id: `fc_${path}`, path, @@ -199,7 +298,7 @@ function _fileChangesFromCheckpointDiff(diff: CheckpointDiffResultPayload) { deletions: parsedDiff?.deletions ?? 0, diff: parsedDiff?.lines, hunks: parsedDiff?.hunks, - checkpoint_id: diff.checkpoint_id, + checkpoint_id: authoritativeBaseline ?? existingCheckpointID, } }) } @@ -221,12 +320,57 @@ function _refreshRunFileChanges( if (runId !== useGatewayStore.getState().currentRunId) return if (sessionId !== useSessionStore.getState().currentSessionId) return if (!result?.payload) return - useUIStore.getState().replaceFileChanges(_fileChangesFromCheckpointDiff(result.payload)) + if (result.payload.warning) { + useUIStore.getState().showToast(`Checkpoint warning: ${result.payload.warning}`, 'info') + } + const existingCheckpointByPath = new Map( + useUIStore.getState().fileChanges.map((change) => [change.path, change.checkpoint_id]), + ) + useUIStore.getState().replaceFileChanges(_fileChangesFromCheckpointDiff(result.payload, existingCheckpointByPath)) }).catch((error) => { console.warn('[eventBridge] checkpoint.diff run scope failed:', error) }) } +// refreshSessionAfterCheckpointRestoreEvent 仅在当前会话收到 restore/undo 事件时刷新会话与文件变更视图。 +function refreshSessionAfterCheckpointRestoreEvent( + gatewayAPI: GatewayAPI, + payloadSessionId: string, + nextCheckpointId: string | undefined, +) { + const sessionId = payloadSessionId.trim() + const currentSessionId = useSessionStore.getState().currentSessionId.trim() + if (!sessionId || !currentSessionId || sessionId !== currentSessionId) { + return + } + + const normalizedNextCheckpointId = nextCheckpointId?.trim() + if (normalizedNextCheckpointId) { + _latestCheckpointId = normalizedNextCheckpointId + _pendingNextRunRollbackCheckpointId = normalizedNextCheckpointId + _pendingRollbackAppliedRunId = undefined + } + + const requestId = ++_latestRestoreSyncRequestId + const reloadSeq = beginCheckpointRestoreReloadSeq() + _firstTouchRollbackCheckpointByPath = new Map() + useUIStore.getState().setRestoringCheckpoint(true) + useUIStore.getState().clearFileChanges() + void reloadSessionAfterCheckpointRestore(gatewayAPI, sessionId, reloadSeq).then(() => { + if (requestId !== _latestRestoreSyncRequestId) return + if (normalizedNextCheckpointId) { + _latestCheckpointId = normalizedNextCheckpointId + _pendingNextRunRollbackCheckpointId = normalizedNextCheckpointId + } + useUIStore.getState().setRestoringCheckpoint(false) + }).catch((error) => { + if (requestId !== _latestRestoreSyncRequestId) return + useUIStore.getState().setRestoringCheckpoint(false) + console.warn('[eventBridge] failed to reload session after checkpoint restore:', error) + useUIStore.getState().showToast('Failed to refresh session after restore', 'error') + }) +} + function normalizePermissionPayload(raw: unknown): PermissionRequestPayload | null { const r = raw as Record | undefined if (!r || typeof r !== 'object') return null @@ -405,6 +549,7 @@ export function handleGatewayEvent(frame: MessageFrame, gatewayAPI: GatewayAPI) } case EventType.ToolStart: { + trackRollbackRunBoundary(frameRunId || useGatewayStore.getState().currentRunId) const toolName = strField(eventPayload, 'name') const toolArgs = strField(eventPayload, 'arguments') const msg = { @@ -432,11 +577,19 @@ export function handleGatewayEvent(frame: MessageFrame, gatewayAPI: GatewayAPI) } case EventType.ToolDiff: { + trackRollbackRunBoundary(frameRunId || useGatewayStore.getState().currentRunId) const diffPayload = eventPayload as ToolDiffPayload | undefined if (diffPayload) _applyToolDiff(diffPayload) break } + case EventType.BashSideEffect: { + trackRollbackRunBoundary(frameRunId || useGatewayStore.getState().currentRunId) + const payload = eventPayload as BashSideEffectPayload | undefined + if (payload) _applyBashSideEffect(payload) + break + } + case EventType.ToolChunk: { const toolCallId = strField(eventPayload, 'tool_call_id') if (toolCallId) useChatStore.getState().appendToolOutput(toolCallId, strField(eventPayload, 'content')) @@ -450,6 +603,8 @@ export function handleGatewayEvent(frame: MessageFrame, gatewayAPI: GatewayAPI) const sessionId = strField(eventPayload, 'session_id') || frameSessionId || '' const runId = strField(eventPayload, 'run_id') || frameRunId || '' if (runId) gwStore.setCurrentRunId(runId) + trackRollbackRunBoundary(runId) + _firstTouchRollbackCheckpointByPath = new Map() useUIStore.getState().clearFileChanges() if (sessionId && sessionId !== useSessionStore.getState().currentSessionId) { useSessionStore.getState().setCurrentSessionId(sessionId) @@ -717,7 +872,9 @@ export function handleGatewayEvent(frame: MessageFrame, gatewayAPI: GatewayAPI) if (_latestDoneToolCallId) { chatStore.attachCheckpointToToolCall(_latestDoneToolCallId, payload.checkpoint_id) } - _latestCheckpointId = payload.checkpoint_id + if (payload.reason !== CHECKPOINT_REASON_PRE_RESTORE_GUARD) { + _latestCheckpointId = payload.checkpoint_id + } if (payload.reason === 'end_of_turn' && gatewayAPI && frameSessionId && frameRunId) { _refreshRunFileChanges(gatewayAPI, frameSessionId, frameRunId, payload.checkpoint_id) } @@ -734,17 +891,27 @@ export function handleGatewayEvent(frame: MessageFrame, gatewayAPI: GatewayAPI) case EventType.CheckpointRestored: { const payload = eventPayload as CheckpointRestoredPayload | undefined - if (payload) insightStore.addCheckpointEvent(payload) - chatStore.markAllCheckpointsRestored() - uiStore.showToast('Checkpoint restored', 'success') + if (payload) { + insightStore.addCheckpointEvent(payload) + if (payload.session_id === useSessionStore.getState().currentSessionId) { + chatStore.markAllCheckpointsRestored() + refreshSessionAfterCheckpointRestoreEvent(gatewayAPI, payload.session_id, payload.checkpoint_id) + uiStore.showToast('Checkpoint restored', 'success') + } + } break } case EventType.CheckpointUndoRestore: { const payload = eventPayload as CheckpointUndoRestorePayload | undefined - if (payload) insightStore.addCheckpointEvent(payload) - chatStore.markAllCheckpointsAvailable() - uiStore.showToast('Checkpoint restore undone', 'success') + if (payload) { + insightStore.addCheckpointEvent(payload) + if (payload.session_id === useSessionStore.getState().currentSessionId) { + chatStore.markAllCheckpointsAvailable() + refreshSessionAfterCheckpointRestoreEvent(gatewayAPI, payload.session_id, payload.guard_checkpoint_id) + uiStore.showToast('Checkpoint restore undone', 'success') + } + } break } From 346eb106615409f9730474d68c728015f8df2a25 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sat, 9 May 2026 22:34:39 +0800 Subject: [PATCH 2/4] =?UTF-8?q?test:=E8=A1=A5=E5=85=85=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/App.test.tsx | 57 ++++++ web/src/api/gateway.test.ts | 62 +++++++ web/src/api/wsClient.test.ts | 164 ++++++++++++++++++ web/src/components/ErrorBoundary.test.tsx | 33 ++++ .../components/UpdateNotification.test.tsx | 55 ++++++ .../chat/AcceptanceMessage.test.tsx | 31 ++++ .../chat/CheckpointInlineMark.test.tsx | 39 +++++ web/src/components/chat/CodeBlock.test.tsx | 27 +++ web/src/components/chat/MessageItem.test.tsx | 49 ++++++ web/src/components/chat/MessageList.test.tsx | 36 ++++ web/src/components/chat/SkillPicker.test.tsx | 89 ++++++++++ .../components/chat/SlashCommandMenu.test.tsx | 66 +++++++ web/src/components/chat/TodoStrip.test.tsx | 49 ++++++ web/src/components/chat/ToolCallCard.test.tsx | 58 +++++++ .../chat/VerificationMessage.test.tsx | 32 ++++ web/src/components/layout/AppLayout.test.tsx | 60 +++++++ .../permission/PermissionDialog.test.tsx | 42 +++++ .../status/BudgetIndicator.test.tsx | 43 +++++ web/src/components/status/StatusBar.test.tsx | 32 ++++ web/src/components/ui/ToastContainer.test.tsx | 31 ++++ .../RuntimeProvider.lifecycle.test.tsx | 153 ++++++++++++++++ web/src/main.test.tsx | 23 +++ web/src/pages/ChatPage.test.tsx | 15 ++ web/src/pages/ConnectPage.test.tsx | 61 +++++++ web/src/stores/useGatewayStore.test.ts | 19 ++ web/src/stores/useWorkspaceStore.test.ts | 119 +++++++++++++ web/src/utils/slashCommands.test.ts | 68 ++++++++ web/vitest.config.ts | 9 + 28 files changed, 1522 insertions(+) create mode 100644 web/src/App.test.tsx create mode 100644 web/src/api/gateway.test.ts create mode 100644 web/src/api/wsClient.test.ts create mode 100644 web/src/components/ErrorBoundary.test.tsx create mode 100644 web/src/components/UpdateNotification.test.tsx create mode 100644 web/src/components/chat/AcceptanceMessage.test.tsx create mode 100644 web/src/components/chat/CheckpointInlineMark.test.tsx create mode 100644 web/src/components/chat/CodeBlock.test.tsx create mode 100644 web/src/components/chat/MessageItem.test.tsx create mode 100644 web/src/components/chat/MessageList.test.tsx create mode 100644 web/src/components/chat/SkillPicker.test.tsx create mode 100644 web/src/components/chat/SlashCommandMenu.test.tsx create mode 100644 web/src/components/chat/TodoStrip.test.tsx create mode 100644 web/src/components/chat/ToolCallCard.test.tsx create mode 100644 web/src/components/chat/VerificationMessage.test.tsx create mode 100644 web/src/components/layout/AppLayout.test.tsx create mode 100644 web/src/components/permission/PermissionDialog.test.tsx create mode 100644 web/src/components/status/BudgetIndicator.test.tsx create mode 100644 web/src/components/status/StatusBar.test.tsx create mode 100644 web/src/components/ui/ToastContainer.test.tsx create mode 100644 web/src/context/RuntimeProvider.lifecycle.test.tsx create mode 100644 web/src/main.test.tsx create mode 100644 web/src/pages/ChatPage.test.tsx create mode 100644 web/src/pages/ConnectPage.test.tsx create mode 100644 web/src/stores/useGatewayStore.test.ts create mode 100644 web/src/stores/useWorkspaceStore.test.ts create mode 100644 web/src/utils/slashCommands.test.ts diff --git a/web/src/App.test.tsx b/web/src/App.test.tsx new file mode 100644 index 00000000..cb65a5eb --- /dev/null +++ b/web/src/App.test.tsx @@ -0,0 +1,57 @@ +import { describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import App from './App' + +const runtimeState = { + status: 'connected', + mode: 'browser', + error: '', + loadingMessage: '', + retry: vi.fn(), +} + +vi.mock('./context/RuntimeProvider', () => ({ + useRuntime: () => runtimeState, +})) + +vi.mock('./pages/ChatPage', () => ({ + default: () =>
chat-page
, +})) + +vi.mock('./pages/ConnectPage', () => ({ + default: () =>
connect-page
, +})) + +describe('App routes by runtime status', () => { + it('renders loading screen', () => { + runtimeState.status = 'loading' + runtimeState.loadingMessage = 'booting' + render() + expect(screen.getByText('booting')).toBeInTheDocument() + }) + + it('renders connect page for browser needs_config', () => { + runtimeState.status = 'needs_config' + runtimeState.mode = 'browser' + render() + expect(screen.getByText('connect-page')).toBeInTheDocument() + }) + + it('renders electron error screen with retry', () => { + runtimeState.status = 'error' + runtimeState.mode = 'electron' + runtimeState.error = 'boom' + render() + expect(screen.getByText('Gateway connection failed')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'Retry connection' })) + expect(runtimeState.retry).toHaveBeenCalled() + }) + + it('renders chat routes when connected', () => { + runtimeState.status = 'connected' + runtimeState.mode = 'browser' + render() + expect(screen.getByText('chat-page')).toBeInTheDocument() + }) +}) + diff --git a/web/src/api/gateway.test.ts b/web/src/api/gateway.test.ts new file mode 100644 index 00000000..336b3f40 --- /dev/null +++ b/web/src/api/gateway.test.ts @@ -0,0 +1,62 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { GatewayAPI } from './gateway' +import { Method } from './protocol' + +describe('GatewayAPI', () => { + const call = vi.fn() + const ws = { call } as any + let api: GatewayAPI + + beforeEach(() => { + call.mockReset() + call.mockResolvedValue({ type: 'ack', payload: {} }) + api = new GatewayAPI(ws) + }) + + it('maps authenticate and run methods', async () => { + await api.authenticate('tok') + await api.run({ input_text: 'hello' }) + + expect(call).toHaveBeenNthCalledWith(1, Method.Authenticate, { token: 'tok' }) + expect(call).toHaveBeenNthCalledWith(2, Method.Run, { input_text: 'hello' }) + }) + + it('maps optional session_id in listModels', async () => { + await api.listModels() + await api.listModels('s1') + + expect(call).toHaveBeenNthCalledWith(1, Method.ListModels, undefined) + expect(call).toHaveBeenNthCalledWith(2, Method.ListModels, { session_id: 's1' }) + }) + + it('maps optional provider_id in setSessionModel', async () => { + await api.setSessionModel('s1', 'm1') + await api.setSessionModel('s1', 'm1', 'p1') + + expect(call).toHaveBeenNthCalledWith(1, Method.SetSessionModel, { session_id: 's1', model_id: 'm1' }) + expect(call).toHaveBeenNthCalledWith(2, Method.SetSessionModel, { session_id: 's1', model_id: 'm1', provider_id: 'p1' }) + }) + + it('maps workspace methods and optional remove_data', async () => { + await api.listWorkspaces() + await api.createWorkspace('/tmp/a', 'A') + await api.switchWorkspace('h1') + await api.renameWorkspace('h1', 'B') + await api.deleteWorkspace('h1', true) + + expect(call).toHaveBeenNthCalledWith(1, Method.ListWorkspaces) + expect(call).toHaveBeenNthCalledWith(2, Method.CreateWorkspace, { path: '/tmp/a', name: 'A' }) + expect(call).toHaveBeenNthCalledWith(3, Method.SwitchWorkspace, { workspace_hash: 'h1' }) + expect(call).toHaveBeenNthCalledWith(4, Method.RenameWorkspace, { workspace_hash: 'h1', name: 'B' }) + expect(call).toHaveBeenNthCalledWith(5, Method.DeleteWorkspace, { workspace_hash: 'h1', remove_data: true }) + }) + + it('maps permission and user question resolution', async () => { + await api.resolvePermission({ request_id: 'r1', decision: 'allow_once' }) + await api.resolveUserQuestion({ request_id: 'q1', status: 'answered', message: 'ok' }) + + expect(call).toHaveBeenNthCalledWith(1, Method.ResolvePermission, { request_id: 'r1', decision: 'allow_once' }) + expect(call).toHaveBeenNthCalledWith(2, Method.UserQuestionAnswer, { request_id: 'q1', status: 'answered', message: 'ok' }) + }) +}) + diff --git a/web/src/api/wsClient.test.ts b/web/src/api/wsClient.test.ts new file mode 100644 index 00000000..f107786c --- /dev/null +++ b/web/src/api/wsClient.test.ts @@ -0,0 +1,164 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { createWSClient } from './wsClient' +import { JSONRPC_VERSION, Method } from './protocol' + +class MockWebSocket { + static CONNECTING = 0 + static OPEN = 1 + static CLOSING = 2 + static CLOSED = 3 + + static instances: MockWebSocket[] = [] + + url: string + readyState = MockWebSocket.CONNECTING + onopen: (() => void) | null = null + onclose: (() => void) | null = null + onerror: (() => void) | null = null + onmessage: ((event: { data: string }) => void) | null = null + sent: string[] = [] + + constructor(url: string) { + this.url = url + MockWebSocket.instances.push(this) + } + + send(data: string) { + this.sent.push(data) + } + + close() { + this.readyState = MockWebSocket.CLOSED + this.onclose?.() + } + + open() { + this.readyState = MockWebSocket.OPEN + this.onopen?.() + } + + emit(data: unknown) { + this.onmessage?.({ data: typeof data === 'string' ? data : JSON.stringify(data) }) + } +} + +function latestWS(): MockWebSocket { + const ws = MockWebSocket.instances.at(-1) + if (!ws) throw new Error('websocket not created') + return ws +} + +describe('createWSClient', () => { + beforeEach(() => { + vi.useFakeTimers() + MockWebSocket.instances = [] + vi.stubGlobal('WebSocket', MockWebSocket as any) + }) + + afterEach(() => { + vi.clearAllTimers() + vi.useRealTimers() + vi.unstubAllGlobals() + }) + + it('builds ws url with endpoint and token', () => { + const client = createWSClient({ + baseURL: 'http://127.0.0.1:8080/', + endpoint: 'gateway/ws', + token: 'a b', + }) + client.connect() + expect(latestWS().url).toBe('ws://127.0.0.1:8080/gateway/ws?token=a%20b') + client.disconnect() + }) + + it('sends rpc request and resolves rpc response', async () => { + const client = createWSClient({ baseURL: 'http://127.0.0.1:8080', rpcTimeout: 50 }) + client.connect() + const ws = latestWS() + ws.open() + + const promise = client.call('gateway.ping', { x: 1 }) + const req = JSON.parse(ws.sent[0]) + expect(req).toMatchObject({ + jsonrpc: JSONRPC_VERSION, + method: 'gateway.ping', + params: { x: 1 }, + }) + + ws.emit({ jsonrpc: JSONRPC_VERSION, id: req.id, result: { ok: true } }) + await expect(promise).resolves.toEqual({ ok: true }) + client.disconnect() + }) + + it('rejects rpc call on timeout', async () => { + const client = createWSClient({ baseURL: 'http://127.0.0.1:8080', rpcTimeout: 20 }) + client.connect() + latestWS().open() + const promise = client.call('x.y') + promise.catch(() => {}) + await vi.advanceTimersByTimeAsync(21) + await expect(promise).rejects.toThrow('RPC timeout: x.y') + client.disconnect() + }) + + it('dispatches gateway.event notifications to event handlers', () => { + const client = createWSClient({ baseURL: 'http://127.0.0.1:8080' }) + const handler = vi.fn() + client.onEvent(handler) + + client.connect() + const ws = latestWS() + ws.open() + ws.emit({ + jsonrpc: JSONRPC_VERSION, + method: Method.Event, + params: { type: 'event', payload: { a: 1 } }, + }) + expect(handler).toHaveBeenCalledWith({ type: 'event', payload: { a: 1 } }) + client.disconnect() + }) + + it('transitions to error then reconnects and fires onReconnect', async () => { + const client = createWSClient({ + baseURL: 'http://127.0.0.1:8080', + reconnectBaseInterval: 10, + reconnectMaxInterval: 10, + }) + const states: string[] = [] + const onReconnect = vi.fn() + client.onStateChange((s) => states.push(s)) + client.onReconnect(onReconnect) + + client.connect() + const ws1 = latestWS() + ws1.open() + ws1.close() + + expect(states).toContain('error') + await vi.advanceTimersByTimeAsync(20) + const ws2 = latestWS() + expect(ws2).not.toBe(ws1) + ws2.open() + expect(onReconnect).toHaveBeenCalledTimes(1) + expect(client.getState()).toBe('connected') + client.disconnect() + }) + + it('closes stale connection by heartbeat timeout', async () => { + const client = createWSClient({ + baseURL: 'http://127.0.0.1:8080', + heartbeatTimeout: 10, + heartbeatCheckInterval: 5, + reconnectBaseInterval: 50, + reconnectMaxInterval: 50, + }) + client.connect() + const ws = latestWS() + const closeSpy = vi.spyOn(ws, 'close') + ws.open() + await vi.advanceTimersByTimeAsync(20) + expect(closeSpy).toHaveBeenCalled() + client.disconnect() + }) +}) diff --git a/web/src/components/ErrorBoundary.test.tsx b/web/src/components/ErrorBoundary.test.tsx new file mode 100644 index 00000000..dc43123b --- /dev/null +++ b/web/src/components/ErrorBoundary.test.tsx @@ -0,0 +1,33 @@ +import { describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import { ErrorBoundary } from './ErrorBoundary' + +function Crash() { + throw new Error('boom') +} + +describe('ErrorBoundary', () => { + it('renders fallback UI and supports retry', () => { + const errSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + render( + + + , + ) + expect(screen.getByText('The application encountered an error')).toBeInTheDocument() + expect(screen.getByText('boom')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'Retry' })) + errSpy.mockRestore() + }) + + it('uses custom fallback render prop', () => { + const errSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + render( + }> + + , + ) + expect(screen.getByRole('button', { name: 'custom:boom' })).toBeInTheDocument() + errSpy.mockRestore() + }) +}) diff --git a/web/src/components/UpdateNotification.test.tsx b/web/src/components/UpdateNotification.test.tsx new file mode 100644 index 00000000..37a5cf2a --- /dev/null +++ b/web/src/components/UpdateNotification.test.tsx @@ -0,0 +1,55 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { UpdateNotification } from './UpdateNotification' + +describe('UpdateNotification', () => { + beforeEach(() => { + Object.defineProperty(window, 'electronAPI', { + value: undefined, + configurable: true, + writable: true, + }) + }) + + it('renders nothing when electron api is missing', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('shows available/downloaded updates and handles actions', async () => { + let onAvailable: ((info: any) => void) | null = null + let onDownloaded: ((info: any) => void) | null = null + const quitAndInstall = vi.fn().mockResolvedValue(undefined) + Object.defineProperty(window, 'electronAPI', { + value: { + onUpdateAvailable: (cb: any) => { + onAvailable = cb + return vi.fn() + }, + onUpdateDownloaded: (cb: any) => { + onDownloaded = cb + return vi.fn() + }, + quitAndInstall, + }, + configurable: true, + }) + render() + + act(() => { + onAvailable?.({ version: '1.2.3' }) + }) + await waitFor(() => { + expect(screen.getByText('A new version v1.2.3 is available, downloading...')).toBeInTheDocument() + }) + + act(() => { + onDownloaded?.({ version: '1.2.3' }) + }) + fireEvent.click(screen.getByRole('button', { name: 'Restart Now' })) + expect(quitAndInstall).toHaveBeenCalled() + + fireEvent.click(screen.getByTitle('Dismiss')) + expect(screen.queryByText('NeoCode v1.2.3 is ready to install')).not.toBeInTheDocument() + }) +}) diff --git a/web/src/components/chat/AcceptanceMessage.test.tsx b/web/src/components/chat/AcceptanceMessage.test.tsx new file mode 100644 index 00000000..b33e6e65 --- /dev/null +++ b/web/src/components/chat/AcceptanceMessage.test.tsx @@ -0,0 +1,31 @@ +import { describe, expect, it } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import AcceptanceMessage from './AcceptanceMessage' + +describe('AcceptanceMessage', () => { + it('renders accepted summary and expandable details', () => { + render( + , + ) + expect(screen.getByText('已接受')).toBeInTheDocument() + expect(screen.getByText('all good')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: '展开详情' })) + expect(screen.getByText('detail')).toBeInTheDocument() + }) +}) + diff --git a/web/src/components/chat/CheckpointInlineMark.test.tsx b/web/src/components/chat/CheckpointInlineMark.test.tsx new file mode 100644 index 00000000..de4a5d26 --- /dev/null +++ b/web/src/components/chat/CheckpointInlineMark.test.tsx @@ -0,0 +1,39 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import CheckpointInlineMark from './CheckpointInlineMark' +import { useSessionStore } from '@/stores/useSessionStore' +import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' + +let gatewayAPI: any = null +vi.mock('@/context/RuntimeProvider', () => ({ + useGatewayAPI: () => gatewayAPI, +})) + +describe('CheckpointInlineMark', () => { + beforeEach(() => { + gatewayAPI = { + restoreCheckpoint: vi.fn().mockResolvedValue({ payload: {} }), + undoRestore: vi.fn().mockResolvedValue({ payload: {} }), + checkpointDiff: vi.fn().mockResolvedValue({ payload: { files: { added: [], modified: [], deleted: [] }, patch: '' } }), + } + useSessionStore.setState({ currentSessionId: 's1' } as any) + useRuntimeInsightStore.getState().reset() + vi.spyOn(window, 'confirm').mockReturnValue(true) + }) + + it('restores checkpoint from available state', async () => { + render() + fireEvent.click(screen.getByRole('button', { name: /cp_abcdef/i })) + await waitFor(() => expect(gatewayAPI.restoreCheckpoint).toHaveBeenCalledWith({ + session_id: 's1', + checkpoint_id: 'abcdef123456', + })) + }) + + it('renders restored state and can undo restore', async () => { + render() + fireEvent.click(screen.getByRole('button', { name: /已撤回/ })) + await waitFor(() => expect(gatewayAPI.undoRestore).toHaveBeenCalledWith('s1')) + }) +}) + diff --git a/web/src/components/chat/CodeBlock.test.tsx b/web/src/components/chat/CodeBlock.test.tsx new file mode 100644 index 00000000..d8d10a69 --- /dev/null +++ b/web/src/components/chat/CodeBlock.test.tsx @@ -0,0 +1,27 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import CodeBlock from './CodeBlock' + +describe('CodeBlock', () => { + beforeEach(() => { + Object.assign(navigator, { + clipboard: { writeText: vi.fn() }, + }) + }) + + it('renders inline code and copies content', () => { + render() + const container = screen.getByText('const a = 1').closest('div') as HTMLElement + fireEvent.mouseEnter(container) + fireEvent.click(screen.getByTitle('复制')) + expect(navigator.clipboard.writeText).toHaveBeenCalledWith('const a = 1') + }) + + it('renders file code block with line numbers', () => { + render() + expect(screen.getByText('a.ts')).toBeInTheDocument() + expect(screen.getByText('line1')).toBeInTheDocument() + expect(screen.getByText('line2')).toBeInTheDocument() + }) +}) + diff --git a/web/src/components/chat/MessageItem.test.tsx b/web/src/components/chat/MessageItem.test.tsx new file mode 100644 index 00000000..cfe3db05 --- /dev/null +++ b/web/src/components/chat/MessageItem.test.tsx @@ -0,0 +1,49 @@ +import { describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import MessageItem from './MessageItem' + +vi.mock('./ToolCallCard', () => ({ default: () =>
tool-card
})) +vi.mock('./VerificationMessage', () => ({ default: () =>
verification-card
})) +vi.mock('./AcceptanceMessage', () => ({ default: () =>
acceptance-card
})) +vi.mock('./CodeBlock', () => ({ default: ({ code }: { code: string }) =>
{code}
})) +vi.mock('./MarkdownContent', () => ({ default: ({ content }: { content: string }) => {content} })) +vi.mock('@/context/RuntimeProvider', () => ({ useGatewayAPI: () => null })) + +describe('MessageItem', () => { + it('renders system message', () => { + render() + expect(screen.getByText('sys')).toBeInTheDocument() + }) + + it('renders welcome message', () => { + render() + expect(screen.getByText('hello')).toBeInTheDocument() + }) + + it('renders thinking message and toggles details', () => { + render( + , + ) + fireEvent.click(screen.getByText('AI 思考过程')) + expect(screen.getByText('reasoning')).toBeInTheDocument() + }) + + it('renders tool/verification/acceptance delegates', () => { + const { rerender } = render() + expect(screen.getByText('tool-card')).toBeInTheDocument() + rerender() + expect(screen.getByText('verification-card')).toBeInTheDocument() + rerender() + expect(screen.getByText('acceptance-card')).toBeInTheDocument() + }) + + it('renders code and plain assistant messages', () => { + const { rerender } = render() + expect(screen.getByText('const a=1')).toBeInTheDocument() + rerender() + expect(screen.getByText('answer')).toBeInTheDocument() + }) +}) + diff --git a/web/src/components/chat/MessageList.test.tsx b/web/src/components/chat/MessageList.test.tsx new file mode 100644 index 00000000..a8af748b --- /dev/null +++ b/web/src/components/chat/MessageList.test.tsx @@ -0,0 +1,36 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { render, screen } from '@testing-library/react' +import MessageList from './MessageList' +import { useChatStore } from '@/stores/useChatStore' + +vi.mock('./MessageItem', () => ({ + default: ({ message, groupedWithPrev }: any) => ( +
{message.id}:{groupedWithPrev ? 'group' : 'solo'}
+ ), +})) + +describe('MessageList', () => { + beforeEach(() => { + useChatStore.setState({ messages: [], isGenerating: false } as any) + }) + + it('renders empty state when no messages', () => { + render() + expect(screen.getByText('开始你的 AI 编程之旅')).toBeInTheDocument() + }) + + it('reorders process messages before assistant text within AI turn', () => { + useChatStore.setState({ + messages: [ + { id: 'u1', role: 'user', type: 'text', content: 'q', timestamp: 1 }, + { id: 'a1', role: 'assistant', type: 'text', content: 'answer', timestamp: 2 }, + { id: 't1', role: 'tool', type: 'tool_call', content: '', timestamp: 3 }, + { id: 'a2', role: 'assistant', type: 'thinking', content: 'thinking', timestamp: 4 }, + ], + } as any) + + render() + const ids = screen.getAllByTestId(/msg-/).map((x) => x.textContent) + expect(ids).toEqual(['u1:solo', 't1:solo', 'a2:group', 'a1:group']) + }) +}) diff --git a/web/src/components/chat/SkillPicker.test.tsx b/web/src/components/chat/SkillPicker.test.tsx new file mode 100644 index 00000000..2db59a17 --- /dev/null +++ b/web/src/components/chat/SkillPicker.test.tsx @@ -0,0 +1,89 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import SkillPicker from './SkillPicker' +import { useChatStore } from '@/stores/useChatStore' +import { useUIStore } from '@/stores/useUIStore' + +describe('SkillPicker', () => { + beforeEach(() => { + useChatStore.setState({ isGenerating: false } as any) + useUIStore.setState({ showToast: vi.fn() } as any) + }) + + it('renders empty state when no skills', async () => { + const api = { + listAvailableSkills: vi.fn().mockResolvedValue({ payload: { skills: [] } }), + } as any + render() + await screen.findByText('暂无可用技能') + }) + + it('toggles skill activation and reloads list', async () => { + const api = { + listAvailableSkills: vi + .fn() + .mockResolvedValueOnce({ + payload: { + skills: [{ + active: false, + descriptor: { id: 'sk1', name: 'Skill 1', description: 'desc', scope: 'explicit' }, + }], + }, + }) + .mockResolvedValueOnce({ + payload: { + skills: [{ + active: true, + descriptor: { id: 'sk1', name: 'Skill 1', description: 'desc', scope: 'explicit' }, + }], + }, + }), + activateSessionSkill: vi.fn().mockResolvedValue({ payload: {} }), + deactivateSessionSkill: vi.fn().mockResolvedValue({ payload: {} }), + } as any + + render() + const activateBtn = await screen.findByRole('button', { name: '激活' }) + fireEvent.click(activateBtn) + + await waitFor(() => expect(api.activateSessionSkill).toHaveBeenCalledWith('s1', 'sk1')) + expect(api.listAvailableSkills).toHaveBeenCalledTimes(2) + }) + + it('blocks toggle when session is missing', async () => { + const showToast = vi.fn() + useUIStore.setState({ showToast } as any) + const api = { + listAvailableSkills: vi.fn().mockResolvedValue({ + payload: { + skills: [{ active: false, descriptor: { id: 'sk1', name: 'Skill 1' } }], + }, + }), + activateSessionSkill: vi.fn(), + } as any + + render() + const activateBtn = await screen.findByRole('button', { name: '激活' }) + fireEvent.click(activateBtn) + + expect(api.activateSessionSkill).not.toHaveBeenCalled() + expect(showToast).toHaveBeenCalledWith('Send a message first to start a session', 'error') + }) + + it('disables operation while generating', async () => { + useChatStore.setState({ isGenerating: true } as any) + const api = { + listAvailableSkills: vi.fn().mockResolvedValue({ + payload: { + skills: [{ active: false, descriptor: { id: 'sk1', name: 'Skill 1' } }], + }, + }), + activateSessionSkill: vi.fn(), + } as any + + render() + const activateBtn = await screen.findByRole('button', { name: '激活' }) + expect(activateBtn).toBeDisabled() + }) +}) + diff --git a/web/src/components/chat/SlashCommandMenu.test.tsx b/web/src/components/chat/SlashCommandMenu.test.tsx new file mode 100644 index 00000000..b4f0c3ec --- /dev/null +++ b/web/src/components/chat/SlashCommandMenu.test.tsx @@ -0,0 +1,66 @@ +import { describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import SlashCommandMenu from './SlashCommandMenu' +import type { AnySlashCommand } from '@/utils/slashCommands' + +const builtin: AnySlashCommand = { + id: 'compact', + usage: '/compact', + description: 'compress context', + hasArgument: false, +} + +const skill: AnySlashCommand = { + id: 'skill.demo', + usage: '/skill.demo', + description: 'demo skill', + hasArgument: false, + isSkill: true, + skillId: 'skill.demo', + active: true, +} + +describe('SlashCommandMenu', () => { + ;(HTMLElement.prototype as any).scrollIntoView = vi.fn() + + it('returns null when commands is empty', () => { + const { container } = render( + , + ) + expect(container.firstChild).toBeNull() + }) + + it('renders builtin and skill sections and highlights query', () => { + render( + , + ) + expect(screen.getByText('命令')).toBeInTheDocument() + expect(screen.getByText('技能')).toBeInTheDocument() + expect(screen.getByText('已激活')).toBeInTheDocument() + expect(screen.getAllByText((_, el) => Boolean(el?.textContent?.includes('/compact'))).length).toBeGreaterThan(0) + }) + + it('triggers hover/select callbacks', () => { + const onSelect = vi.fn() + const onHover = vi.fn() + render( + , + ) + fireEvent.mouseEnter(screen.getByText('/compact')) + fireEvent.click(screen.getByText('/skill.demo')) + expect(onHover).toHaveBeenCalledWith(0) + expect(onSelect).toHaveBeenCalledWith(skill) + }) +}) diff --git a/web/src/components/chat/TodoStrip.test.tsx b/web/src/components/chat/TodoStrip.test.tsx new file mode 100644 index 00000000..f5911467 --- /dev/null +++ b/web/src/components/chat/TodoStrip.test.tsx @@ -0,0 +1,49 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import TodoStrip from './TodoStrip' +import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' +import { useUIStore } from '@/stores/useUIStore' + +describe('TodoStrip', () => { + beforeEach(() => { + useRuntimeInsightStore.getState().reset() + useUIStore.setState({ + todoStripExpanded: false, + setTodoStripExpanded: vi.fn(), + } as any) + }) + + it('renders nothing when no snapshot and no conflict', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('renders summary and items from snapshot', () => { + useRuntimeInsightStore.setState({ + todoSnapshot: { + items: [ + { id: '1', content: 'Task 1', status: 'in_progress', required: true, revision: 1 }, + { id: '2', content: 'Task 2', status: 'completed', required: true, revision: 1 }, + ], + summary: { total: 2, required_total: 2, required_completed: 1, required_failed: 0, required_open: 1 }, + }, + todoHistory: { + '1': { id: '1', content: 'Task 1', status: 'in_progress', required: true, revision: 1, firstSeenAt: 1, lastSeenAt: 1 }, + '2': { id: '2', content: 'Task 2', status: 'completed', required: true, revision: 1, firstSeenAt: 1, lastSeenAt: 1 }, + }, + } as any) + render() + expect(screen.getByText('Task 1')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { expanded: false })) + expect(useUIStore.getState().setTodoStripExpanded).toHaveBeenCalled() + }) + + it('forces expanded conflict state', () => { + useRuntimeInsightStore.setState({ + todoConflict: { action: 'conflict', reason: 'manual check needed' }, + } as any) + render() + expect(screen.getByText(/Todo 冲突/)).toBeInTheDocument() + }) +}) + diff --git a/web/src/components/chat/ToolCallCard.test.tsx b/web/src/components/chat/ToolCallCard.test.tsx new file mode 100644 index 00000000..cb79f8df --- /dev/null +++ b/web/src/components/chat/ToolCallCard.test.tsx @@ -0,0 +1,58 @@ +import { describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import ToolCallCard from './ToolCallCard' + +vi.mock('./CheckpointInlineMark', () => ({ + default: ({ checkpointId }: { checkpointId: string }) => cp:{checkpointId}, +})) + +describe('ToolCallCard', () => { + it('shows running state and expands/collapses', () => { + render( + , + ) + expect(screen.getByText('bash')).toBeInTheDocument() + expect(screen.getByText('$')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { expanded: true })) + }) + + it('renders file edit diff detail', () => { + render( + , + ) + fireEvent.click(screen.getByRole('button', { expanded: false })) + expect(screen.getAllByText('a.ts').length).toBeGreaterThan(0) + expect(screen.getByText('old')).toBeInTheDocument() + expect(screen.getByText('new')).toBeInTheDocument() + expect(screen.getByText('cp:cp1')).toBeInTheDocument() + }) +}) diff --git a/web/src/components/chat/VerificationMessage.test.tsx b/web/src/components/chat/VerificationMessage.test.tsx new file mode 100644 index 00000000..dadfb094 --- /dev/null +++ b/web/src/components/chat/VerificationMessage.test.tsx @@ -0,0 +1,32 @@ +import { describe, expect, it } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import VerificationMessage from './VerificationMessage' + +describe('VerificationMessage', () => { + it('renders running summary and stage details', () => { + render( + , + ) + expect(screen.getByText(/Verify running/)).toBeInTheDocument() + fireEvent.click(screen.getByRole('button')) + expect(screen.getByText('test')).toBeInTheDocument() + }) +}) + diff --git a/web/src/components/layout/AppLayout.test.tsx b/web/src/components/layout/AppLayout.test.tsx new file mode 100644 index 00000000..e98098b4 --- /dev/null +++ b/web/src/components/layout/AppLayout.test.tsx @@ -0,0 +1,60 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import AppLayout from './AppLayout' +import { useUIStore } from '@/stores/useUIStore' +import { useSessionStore } from '@/stores/useSessionStore' +import { useWorkspaceStore } from '@/stores/useWorkspaceStore' + +vi.mock('./Sidebar', () => ({ default: ({ collapsed }: { collapsed?: boolean }) =>
{collapsed ? 'sidebar-collapsed' : 'sidebar-open'}
})) +vi.mock('@/components/chat/ChatPanel', () => ({ default: () =>
chat-panel
})) +vi.mock('@/components/panels/FileChangePanel', () => ({ default: () =>
changes-panel
})) +vi.mock('@/components/panels/FileTreePanel', () => ({ default: () =>
tree-panel
})) +vi.mock('@/components/status/StatusBar', () => ({ default: () =>
status-bar
})) +vi.mock('@/components/ui/ToastContainer', () => ({ default: () =>
toast-container
})) + +describe('AppLayout', () => { + beforeEach(() => { + useUIStore.setState({ + sidebarOpen: true, + sidebarWidth: 280, + setSidebarWidth: vi.fn(), + changesPanelOpen: false, + changesPanelWidth: 360, + setChangesPanelWidth: vi.fn(), + fileTreePanelOpen: false, + fileTreePanelWidth: 320, + setFileTreePanelWidth: vi.fn(), + } as any) + useSessionStore.setState({ + prepareNewChat: vi.fn(), + } as any) + useWorkspaceStore.setState({ currentWorkspaceHash: '' } as any) + }) + + it('renders main layout with sidebar open', () => { + render() + expect(screen.getByText('sidebar-open')).toBeInTheDocument() + expect(screen.getByText('chat-panel')).toBeInTheDocument() + }) + + it('renders collapsed sidebar and right panels when toggled', () => { + useUIStore.setState({ + sidebarOpen: false, + changesPanelOpen: true, + fileTreePanelOpen: true, + } as any) + render() + expect(screen.getByText('sidebar-collapsed')).toBeInTheDocument() + expect(screen.getByText('changes-panel')).toBeInTheDocument() + expect(screen.getByText('tree-panel')).toBeInTheDocument() + }) + + it('handles ctrl/cmd+n shortcut', () => { + const prepareNewChat = vi.fn() + useSessionStore.setState({ prepareNewChat } as any) + render() + fireEvent.keyDown(window, { key: 'n', ctrlKey: true }) + expect(prepareNewChat).toHaveBeenCalled() + }) +}) + diff --git a/web/src/components/permission/PermissionDialog.test.tsx b/web/src/components/permission/PermissionDialog.test.tsx new file mode 100644 index 00000000..7db1340e --- /dev/null +++ b/web/src/components/permission/PermissionDialog.test.tsx @@ -0,0 +1,42 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import PermissionDialog from './PermissionDialog' +import { useChatStore } from '@/stores/useChatStore' + +let gatewayAPI: any = null +vi.mock('@/context/RuntimeProvider', () => ({ + useGatewayAPI: () => gatewayAPI, +})) + +describe('PermissionDialog', () => { + beforeEach(() => { + useChatStore.setState({ permissionRequests: [] } as any) + gatewayAPI = null + }) + + it('returns null without request or gateway api', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('renders request details and resolves decisions', async () => { + gatewayAPI = { resolvePermission: vi.fn().mockResolvedValue(undefined) } + useChatStore.setState({ + permissionRequests: [{ + request_id: 'r1', + tool_name: 'bash', + operation: 'run', + target: '/tmp', + }], + } as any) + + render() + expect(screen.getByText('权限请求')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: /允许一次/i })) + expect(gatewayAPI.resolvePermission).toHaveBeenCalledWith({ + request_id: 'r1', + decision: 'allow_once', + }) + }) +}) + diff --git a/web/src/components/status/BudgetIndicator.test.tsx b/web/src/components/status/BudgetIndicator.test.tsx new file mode 100644 index 00000000..5387d12f --- /dev/null +++ b/web/src/components/status/BudgetIndicator.test.tsx @@ -0,0 +1,43 @@ +import { beforeEach, describe, expect, it } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import BudgetIndicator from './BudgetIndicator' +import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' + +describe('BudgetIndicator', () => { + beforeEach(() => { + useRuntimeInsightStore.getState().reset() + }) + + it('renders null when no budget data', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('shows budget and popover details', () => { + useRuntimeInsightStore.setState({ + budgetChecked: { + attempt_seq: 1, + request_hash: 'h', + action: 'allow', + estimated_input_tokens: 100, + prompt_budget: 200, + context_window: 1000, + }, + budgetUsageRatio: 0.5, + } as any) + render() + fireEvent.click(screen.getByTitle('点击查看预算明细')) + expect(screen.getByText('预算明细')).toBeInTheDocument() + expect(screen.getByText('allow')).toBeInTheDocument() + }) + + it('shows estimate failed message', () => { + useRuntimeInsightStore.setState({ + budgetEstimateFailed: { attempt_seq: 1, request_hash: 'h', message: 'estimate failed' }, + } as any) + render() + fireEvent.click(screen.getByTitle('点击查看预算明细')) + expect(screen.getByText('estimate failed')).toBeInTheDocument() + }) +}) + diff --git a/web/src/components/status/StatusBar.test.tsx b/web/src/components/status/StatusBar.test.tsx new file mode 100644 index 00000000..5b18dc3c --- /dev/null +++ b/web/src/components/status/StatusBar.test.tsx @@ -0,0 +1,32 @@ +import { describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import StatusBar from './StatusBar' + +const runtime = { + mode: 'browser', + workdir: '', + selectWorkdir: vi.fn().mockResolvedValue(''), +} + +vi.mock('@/context/RuntimeProvider', () => ({ + useRuntime: () => runtime, +})) + +describe('StatusBar', () => { + it('does not render workdir in browser mode', () => { + runtime.mode = 'browser' + runtime.workdir = '/a' + render() + expect(screen.queryByTitle('点击切换工作区')).not.toBeInTheDocument() + }) + + it('renders and triggers workdir picker in electron mode', async () => { + runtime.mode = 'electron' + runtime.workdir = '/repo' + render() + const btn = screen.getByTitle('点击切换工作区') + fireEvent.click(btn) + expect(runtime.selectWorkdir).toHaveBeenCalled() + }) +}) + diff --git a/web/src/components/ui/ToastContainer.test.tsx b/web/src/components/ui/ToastContainer.test.tsx new file mode 100644 index 00000000..99e61dcb --- /dev/null +++ b/web/src/components/ui/ToastContainer.test.tsx @@ -0,0 +1,31 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import ToastContainer from './ToastContainer' +import { useUIStore } from '@/stores/useUIStore' + +describe('ToastContainer', () => { + beforeEach(() => { + useUIStore.setState({ + toasts: [], + dismissToast: vi.fn(), + } as any) + }) + + it('renders null when no toasts', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('renders toasts and dismisses on close click', () => { + const dismissToast = vi.fn() + useUIStore.setState({ + toasts: [{ id: 't1', message: 'ok', type: 'success' }], + dismissToast, + } as any) + render() + expect(screen.getByText('ok')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button')) + expect(dismissToast).toHaveBeenCalledWith('t1') + }) +}) + diff --git a/web/src/context/RuntimeProvider.lifecycle.test.tsx b/web/src/context/RuntimeProvider.lifecycle.test.tsx new file mode 100644 index 00000000..a21b9d55 --- /dev/null +++ b/web/src/context/RuntimeProvider.lifecycle.test.tsx @@ -0,0 +1,153 @@ +import { act, render, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { RuntimeProvider, useRuntime } from './RuntimeProvider' +import { useChatStore } from '@/stores/useChatStore' +import { useSessionStore } from '@/stores/useSessionStore' +import { useWorkspaceStore } from '@/stores/useWorkspaceStore' + +const clients: any[] = [] + +vi.mock('@/api/wsClient', () => ({ + createWSClient: vi.fn(() => { + let onState: ((s: any) => void) | null = null + let onEvent: ((f: any) => void) | null = null + let onReconnect: (() => void) | null = null + const client = { + connect: vi.fn(() => onState?.('connected')), + disconnect: vi.fn(() => onState?.('disconnected')), + reconnect: vi.fn(), + call: vi.fn(async (method: string) => { + if (method === 'gateway.authenticate') return { payload: {} } + if (method === 'gateway.listWorkspaces') return { payload: { workspaces: [] } } + if (method === 'gateway.ping') return { payload: {} } + return { payload: {} } + }), + onEvent: vi.fn((h: any) => { + onEvent = h + return () => { + if (onEvent === h) onEvent = null + } + }), + onStateChange: vi.fn((h: any) => { + onState = h + return () => { + if (onState === h) onState = null + } + }), + onReconnect: vi.fn((h: any) => { + onReconnect = h + return () => { + if (onReconnect === h) onReconnect = null + } + }), + _emitState: (s: any) => onState?.(s), + } + clients.push(client) + return client + }), +})) + +function RuntimeProbe({ onReady }: { onReady: (value: ReturnType) => void }) { + const runtime = useRuntime() + onReady(runtime) + return null +} + +describe('RuntimeProvider lifecycle', () => { + beforeEach(() => { + clients.length = 0 + sessionStorage.clear() + Object.defineProperty(window.navigator, 'userAgent', { + value: 'Mozilla/5.0 Chrome/120 Safari/537.36', + configurable: true, + }) + Object.defineProperty(window, 'electronAPI', { + value: undefined, + configurable: true, + writable: true, + }) + + useSessionStore.setState({ + fetchSessions: vi.fn().mockResolvedValue(undefined), + initializeActiveSession: vi.fn().mockResolvedValue(undefined), + setProjects: vi.fn(), + setCurrentSessionId: vi.fn(), + setCurrentProjectId: vi.fn(), + currentSessionId: '', + } as any) + useWorkspaceStore.setState({ + fetchWorkspaces: vi.fn().mockResolvedValue(undefined), + setWorkspaces: vi.fn(), + setCurrentWorkspaceHash: vi.fn(), + workspaces: [], + } as any) + useChatStore.setState({ + clearMessages: vi.fn(), + clearPendingUserQuestion: vi.fn(), + resetGeneratingState: vi.fn(), + } as any) + }) + + it('connects from stored browser config and exposes connected runtime', async () => { + sessionStorage.setItem( + 'neocode.browserRuntimeConfig', + JSON.stringify({ mode: 'browser', gatewayBaseURL: 'http://127.0.0.1:8080', token: 'tok' }), + ) + let runtimeSnapshot: any = null + render( + + { runtimeSnapshot = rt }} /> + , + ) + + await waitFor(() => { + expect(runtimeSnapshot?.status).toBe('connected') + expect(runtimeSnapshot?.gatewayAPI).toBeTruthy() + }) + expect(clients).toHaveLength(1) + expect(clients[0].connect).toHaveBeenCalled() + }) + + it('retry reconnects with existing config', async () => { + sessionStorage.setItem( + 'neocode.browserRuntimeConfig', + JSON.stringify({ mode: 'browser', gatewayBaseURL: 'http://127.0.0.1:8080', token: 'tok' }), + ) + let runtimeSnapshot: any = null + render( + + { runtimeSnapshot = rt }} /> + , + ) + await waitFor(() => expect(runtimeSnapshot?.status).toBe('connected')) + + await act(async () => { + await runtimeSnapshot.retry() + }) + expect(clients.length).toBeGreaterThanOrEqual(2) + }) + + it('resetBrowserConfig clears store-facing runtime state', async () => { + sessionStorage.setItem( + 'neocode.browserRuntimeConfig', + JSON.stringify({ mode: 'browser', gatewayBaseURL: 'http://127.0.0.1:8080', token: 'tok' }), + ) + let runtimeSnapshot: any = null + const chatClear = useChatStore.getState().clearMessages as any + render( + + { runtimeSnapshot = rt }} /> + , + ) + await waitFor(() => expect(runtimeSnapshot?.status).toBe('connected')) + + act(() => { + runtimeSnapshot.resetBrowserConfig() + }) + + expect(sessionStorage.getItem('neocode.browserRuntimeConfig')).toBeNull() + expect(chatClear).toHaveBeenCalled() + expect(runtimeSnapshot.status).toBe('needs_config') + }) +}) + diff --git a/web/src/main.test.tsx b/web/src/main.test.tsx new file mode 100644 index 00000000..562be0e3 --- /dev/null +++ b/web/src/main.test.tsx @@ -0,0 +1,23 @@ +import { describe, expect, it, vi } from 'vitest' + +describe('main entry', () => { + it('mounts app with createRoot', async () => { + document.body.innerHTML = '
' + const render = vi.fn() + const createRoot = vi.fn(() => ({ render })) + + vi.doMock('react-dom/client', () => ({ + default: { createRoot }, + createRoot, + })) + vi.doMock('./App', () => ({ default: () => null })) + vi.doMock('./context/RuntimeProvider', () => ({ + RuntimeProvider: ({ children }: { children: React.ReactNode }) => <>{children}, + })) + + await import('./main') + expect(createRoot).toHaveBeenCalledWith(document.getElementById('root')) + expect(render).toHaveBeenCalled() + }) +}) + diff --git a/web/src/pages/ChatPage.test.tsx b/web/src/pages/ChatPage.test.tsx new file mode 100644 index 00000000..00d25aef --- /dev/null +++ b/web/src/pages/ChatPage.test.tsx @@ -0,0 +1,15 @@ +import { describe, expect, it, vi } from 'vitest' +import { render, screen } from '@testing-library/react' +import ChatPage from './ChatPage' + +vi.mock('@/components/layout/AppLayout', () => ({ + default: () =>
layout-mounted
, +})) + +describe('ChatPage', () => { + it('renders AppLayout', () => { + render() + expect(screen.getByText('layout-mounted')).toBeInTheDocument() + }) +}) + diff --git a/web/src/pages/ConnectPage.test.tsx b/web/src/pages/ConnectPage.test.tsx new file mode 100644 index 00000000..cc348fe8 --- /dev/null +++ b/web/src/pages/ConnectPage.test.tsx @@ -0,0 +1,61 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import ConnectPage from './ConnectPage' + +const runtime = { + connectBrowser: vi.fn().mockResolvedValue(undefined), + startLocalGateway: vi.fn().mockResolvedValue(undefined), + retry: vi.fn().mockResolvedValue(undefined), + error: '', + status: 'needs_config', + vitePluginAvailable: false, + defaultBrowserGatewayBaseURL: 'http://127.0.0.1:8080', +} + +vi.mock('@/context/RuntimeProvider', () => ({ + useRuntime: () => runtime, +})) + +describe('ConnectPage', () => { + beforeEach(() => { + runtime.connectBrowser.mockClear() + runtime.startLocalGateway.mockClear() + runtime.retry.mockClear() + runtime.error = '' + runtime.status = 'needs_config' + runtime.vitePluginAvailable = false + }) + + it('submits manual connect form', async () => { + render() + fireEvent.change(screen.getByLabelText('Gateway 地址'), { target: { value: 'http://localhost:9000' } }) + fireEvent.change(screen.getByLabelText('Token(本地模式可留空)'), { target: { value: 'tok' } }) + fireEvent.click(screen.getByRole('button', { name: '连接' })) + + await waitFor(() => { + expect(runtime.connectBrowser).toHaveBeenCalledWith({ + gatewayBaseURL: 'http://localhost:9000', + token: 'tok', + }) + }) + }) + + it('validates local gateway port before start', async () => { + runtime.vitePluginAvailable = true + render() + const portInput = screen.getByLabelText('端口') + fireEvent.change(portInput, { target: { value: '99999' } }) + fireEvent.submit(portInput.closest('form') as HTMLFormElement) + expect(await screen.findByText('Please enter a valid port (1-65535)')).toBeInTheDocument() + expect(runtime.startLocalGateway).not.toHaveBeenCalled() + }) + + it('shows retry button when error and triggers retry', () => { + runtime.status = 'error' + runtime.error = 'connect failed' + render() + fireEvent.click(screen.getByRole('button', { name: '重试' })) + expect(runtime.retry).toHaveBeenCalled() + expect(screen.getByText('connect failed')).toBeInTheDocument() + }) +}) diff --git a/web/src/stores/useGatewayStore.test.ts b/web/src/stores/useGatewayStore.test.ts new file mode 100644 index 00000000..d31a8e40 --- /dev/null +++ b/web/src/stores/useGatewayStore.test.ts @@ -0,0 +1,19 @@ +import { describe, expect, it } from 'vitest' +import { useGatewayStore } from './useGatewayStore' + +describe('useGatewayStore', () => { + it('updates and resets gateway state', () => { + const store = useGatewayStore.getState() + store.setConnectionState('connected') + store.setToken('tok') + store.setCurrentRunId('run1') + store.setAuthenticated(true) + store.notifyProviderChanged() + expect(useGatewayStore.getState().providerChangeTick).toBe(1) + expect(useGatewayStore.getState().authenticated).toBe(true) + store.reset() + expect(useGatewayStore.getState().connectionState).toBe('disconnected') + expect(useGatewayStore.getState().token).toBe('') + }) +}) + diff --git a/web/src/stores/useWorkspaceStore.test.ts b/web/src/stores/useWorkspaceStore.test.ts new file mode 100644 index 00000000..2ffe4a86 --- /dev/null +++ b/web/src/stores/useWorkspaceStore.test.ts @@ -0,0 +1,119 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useWorkspaceStore } from './useWorkspaceStore' +import { useChatStore } from './useChatStore' +import { useSessionStore } from './useSessionStore' +import { useUIStore } from './useUIStore' +import { useGatewayStore } from './useGatewayStore' + +function flushPromises() { + return new Promise((resolve) => setTimeout(resolve, 0)) +} + +describe('useWorkspaceStore', () => { + beforeEach(() => { + useWorkspaceStore.setState({ + workspaces: [], + currentWorkspaceHash: '', + loading: false, + } as any) + + useChatStore.setState({ + isGenerating: false, + clearMessages: vi.fn(), + setTransitioning: vi.fn(), + } as any) + useSessionStore.setState({ + fetchSessions: vi.fn().mockResolvedValue(undefined), + resetForWorkspaceSwitch: vi.fn(), + currentSessionId: '', + currentProjectId: '', + projects: [], + } as any) + useUIStore.setState({ + showToast: vi.fn(), + clearFileChanges: vi.fn(), + resetPreviewTabs: vi.fn(), + setSearchQuery: vi.fn(), + } as any) + useGatewayStore.setState({ + setCurrentRunId: vi.fn(), + notifyProviderChanged: vi.fn(), + } as any) + }) + + it('deduplicates concurrent fetchWorkspaces calls', async () => { + let resolveList: ((value: any) => void) | null = null + const gatewayAPI = { + listWorkspaces: vi.fn(() => new Promise((resolve) => { resolveList = resolve })), + } as any + + const p1 = useWorkspaceStore.getState().fetchWorkspaces(gatewayAPI) + const p2 = useWorkspaceStore.getState().fetchWorkspaces(gatewayAPI) + expect(gatewayAPI.listWorkspaces).toHaveBeenCalledTimes(1) + + resolveList?.({ + payload: { + workspaces: [{ hash: 'w1', path: '/a', name: 'A', created_at: '1', updated_at: '2' }], + }, + }) + await Promise.all([p1, p2]) + expect(useWorkspaceStore.getState().currentWorkspaceHash).toBe('w1') + }) + + it('blocks switchWorkspace while generating', async () => { + const showToast = vi.fn() + useChatStore.setState({ isGenerating: true } as any) + useUIStore.setState({ showToast } as any) + const gatewayAPI = { switchWorkspace: vi.fn() } as any + + await useWorkspaceStore.getState().switchWorkspace('w2', gatewayAPI) + + expect(gatewayAPI.switchWorkspace).not.toHaveBeenCalled() + expect(showToast).toHaveBeenCalledWith('Cannot switch workspace while generating; stop the current run first.', 'info') + }) + + it('switchWorkspace clears session/UI state then fetches sessions', async () => { + const gatewayAPI = { switchWorkspace: vi.fn().mockResolvedValue(undefined) } as any + const fetchSessions = useSessionStore.getState().fetchSessions as any + + await useWorkspaceStore.getState().switchWorkspace('w2', gatewayAPI) + + expect(useChatStore.getState().clearMessages).toHaveBeenCalled() + expect(useSessionStore.getState().resetForWorkspaceSwitch).toHaveBeenCalled() + expect(useUIStore.getState().clearFileChanges).toHaveBeenCalled() + expect(useUIStore.getState().resetPreviewTabs).toHaveBeenCalled() + expect(gatewayAPI.switchWorkspace).toHaveBeenCalledWith('w2') + expect(useGatewayStore.getState().notifyProviderChanged).toHaveBeenCalled() + expect(fetchSessions).toHaveBeenCalledWith(gatewayAPI, true) + expect(useWorkspaceStore.getState().currentWorkspaceHash).toBe('w2') + }) + + it('createWorkspace failure reports toast', async () => { + const showToast = vi.fn() + useUIStore.setState({ showToast } as any) + const gatewayAPI = { + createWorkspace: vi.fn().mockRejectedValue(new Error('boom')), + } as any + + await useWorkspaceStore.getState().createWorkspace('/x', gatewayAPI) + expect(showToast).toHaveBeenCalledWith('Failed to create workspace', 'error') + }) + + it('deleteWorkspace switches to remaining first workspace when current is removed', async () => { + const switchWorkspace = vi.spyOn(useWorkspaceStore.getState(), 'switchWorkspace') + useWorkspaceStore.setState({ + workspaces: [ + { hash: 'w1', path: '/1', name: '1', createdAt: '1', updatedAt: '1' }, + { hash: 'w2', path: '/2', name: '2', createdAt: '1', updatedAt: '1' }, + ], + currentWorkspaceHash: 'w1', + } as any) + const gatewayAPI = { deleteWorkspace: vi.fn().mockResolvedValue(undefined) } as any + + await useWorkspaceStore.getState().deleteWorkspace('w1', gatewayAPI) + await flushPromises() + expect(gatewayAPI.deleteWorkspace).toHaveBeenCalledWith('w1') + expect(useWorkspaceStore.getState().workspaces.map((w) => w.hash)).toEqual(['w2']) + expect(switchWorkspace).toHaveBeenCalledWith('w2', gatewayAPI) + }) +}) diff --git a/web/src/utils/slashCommands.test.ts b/web/src/utils/slashCommands.test.ts new file mode 100644 index 00000000..5533139a --- /dev/null +++ b/web/src/utils/slashCommands.test.ts @@ -0,0 +1,68 @@ +import { describe, expect, it } from 'vitest' +import { + builtinSlashCommands, + isBuiltinCommand, + isKnownSlashCommand, + isSkillCommand, + isSlashCommand, + matchSlashCommands, + parseSlashCommand, + type AnySlashCommand, +} from './slashCommands' + +describe('slashCommands utils', () => { + it('parses command with and without argument', () => { + expect(parseSlashCommand('/help')).toEqual({ command: '/help', argument: '' }) + expect(parseSlashCommand('/remember user is Alice')).toEqual({ + command: '/remember', + argument: 'user is Alice', + }) + expect(parseSlashCommand('hello')).toBeNull() + }) + + it('detects slash command shape', () => { + expect(isSlashCommand('/')).toBe(false) + expect(isSlashCommand('/a')).toBe(true) + expect(isSlashCommand(' /skills')).toBe(true) + expect(isSlashCommand('abc')).toBe(false) + }) + + it('matches commands by usage/description/id', () => { + const skill: AnySlashCommand = { + id: 'my-skill', + usage: '/skill.my', + description: 'my custom skill', + hasArgument: false, + isSkill: true, + skillId: 'my-skill', + active: false, + } + const commands = [...builtinSlashCommands, skill] + expect(matchSlashCommands('/help', commands).map((c) => c.id)).toContain('help') + expect(matchSlashCommands('/com', commands).map((c) => c.id)).toContain('compact') + expect(matchSlashCommands('/my', commands).map((c) => c.id)).toContain('my-skill') + }) + + it('checks known builtin slash command', () => { + expect(isKnownSlashCommand('/help')).toBe(true) + expect(isKnownSlashCommand('/help foo')).toBe(true) + expect(isKnownSlashCommand('/skill.my')).toBe(false) + }) + + it('guards builtin and skill commands', () => { + const builtin = builtinSlashCommands[0] + const skill: AnySlashCommand = { + id: 'my-skill', + usage: '/skill.my', + description: 'desc', + hasArgument: false, + isSkill: true, + skillId: 'my-skill', + active: true, + } + expect(isBuiltinCommand(builtin)).toBe(true) + expect(isSkillCommand(builtin)).toBe(false) + expect(isBuiltinCommand(skill)).toBe(false) + expect(isSkillCommand(skill)).toBe(true) + }) +}) diff --git a/web/vitest.config.ts b/web/vitest.config.ts index 131d4d35..d632db9a 100644 --- a/web/vitest.config.ts +++ b/web/vitest.config.ts @@ -8,6 +8,15 @@ export default defineConfig({ environment: 'jsdom', globals: true, setupFiles: './src/test/setup.ts', + coverage: { + include: ['src/**/*.{ts,tsx}'], + exclude: [ + 'src/**/*.d.ts', + 'src/**/*.test.{ts,tsx}', + 'src/test/**', + 'src/**/*.css', + ], + }, }, resolve: { alias: { From 3cd87e56415a5a63ecee5e686a0d181b5eed7457 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sat, 9 May 2026 23:32:00 +0800 Subject: [PATCH 3/4] =?UTF-8?q?fix:=E6=94=B6=E6=95=9B=E9=A3=8E=E9=99=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/runtime/checkpoint_flow_test.go | 72 +++++++++++++++++++++++ internal/runtime/checkpoint_gate.go | 19 ++++-- internal/runtime/run.go | 5 +- internal/runtime/state.go | 1 + web/src/components/ErrorBoundary.test.tsx | 3 +- web/src/stores/useWorkspaceStore.test.ts | 4 +- web/src/utils/eventBridge.test.ts | 46 +++++++++++++++ web/src/utils/eventBridge.ts | 1 + 8 files changed, 142 insertions(+), 9 deletions(-) diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go index 13361b3a..44cd0a6d 100644 --- a/internal/runtime/checkpoint_flow_test.go +++ b/internal/runtime/checkpoint_flow_test.go @@ -249,6 +249,78 @@ func TestCreatePreRunDriftRebaseCheckpoint_UsesDriftReasonAndPerEditRef(t *testi } } +func TestCreatePreRunDriftRebaseCheckpoint_UsesEffectiveWorkdirWhenSessionWorkdirEmpty(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + + absPath := filepath.Join(fixture.workdir, "effective.txt") + if err := os.WriteFile(absPath, []byte("effective"), 0o644); err != nil { + t.Fatalf("WriteFile(effective) error = %v", err) + } + + session := fixture.session + session.Workdir = "" + state := newRunState("run-effective-workdir", session) + state.effectiveWorkdir = fixture.workdir + checkpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint(context.Background(), &state, repository.FingerprintDiff{ + Added: []string{"effective.txt"}, + }) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + if checkpointID == "" { + t.Fatal("expected non-empty drift baseline checkpoint id") + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("expected 1 checkpoint, got %d", len(records)) + } + if records[0].Workdir != fixture.workdir { + t.Fatalf("record workdir = %q, want %q", records[0].Workdir, fixture.workdir) + } + if records[0].WorkspaceKey != agentsession.WorkspacePathKey(fixture.workdir) { + t.Fatalf("workspace key = %q, want %q", records[0].WorkspaceKey, agentsession.WorkspacePathKey(fixture.workdir)) + } +} + +func TestRecordRunEndWorkspaceStateUsesEffectiveWorkdirWhenSessionWorkdirEmpty(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + + absPath := filepath.Join(fixture.workdir, "run-end.txt") + if err := os.WriteFile(absPath, []byte("run end"), 0o644); err != nil { + t.Fatalf("WriteFile(run-end) error = %v", err) + } + + session := fixture.session + session.Workdir = "" + state := newRunState("run-end-effective-workdir", session) + state.effectiveWorkdir = fixture.workdir + fixture.service.recordRunEndWorkspaceState( + context.Background(), + session.ID, + effectiveWorkdirForCheckpointState(&state, session), + "cp-current-effective", + ) + + workspaceKey := agentsession.WorkspacePathKey(fixture.workdir) + loaded, ok, err := fixture.checkpointStore.LoadWorkspaceCheckpointState(context.Background(), workspaceKey) + if err != nil { + t.Fatalf("LoadWorkspaceCheckpointState() error = %v", err) + } + if !ok { + t.Fatal("expected workspace checkpoint state to be persisted") + } + if loaded.CurrentCheckpointID != "cp-current-effective" { + t.Fatalf("current checkpoint id = %q, want cp-current-effective", loaded.CurrentCheckpointID) + } + if loaded.WorkspaceKey != workspaceKey { + t.Fatalf("workspace key = %q, want %q", loaded.WorkspaceKey, workspaceKey) + } +} + func TestCheckpointDiff_ScopeRun_RecreatedFileAfterDriftBaselineIsAdded(t *testing.T) { fixture := newRuntimeCheckpointFixture(t) state := newRunState("run-drift-add", fixture.session) diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go index 2c5ec5f7..db4dfd5c 100644 --- a/internal/runtime/checkpoint_gate.go +++ b/internal/runtime/checkpoint_gate.go @@ -54,9 +54,9 @@ func (s *Service) createPreRunDriftRebaseCheckpoint( state.mu.Lock() session := state.session runID := state.runID + workdir := effectiveWorkdirForCheckpointState(state, session) state.mu.Unlock() - workdir := strings.TrimSpace(session.Workdir) if workdir == "" { return "", fmt.Errorf("checkpoint: workdir is empty when creating drift rebase checkpoint") } @@ -165,7 +165,7 @@ func (s *Service) createCheckpointRecord( return fmt.Errorf("checkpoint: marshal messages: %w", err) } - effectiveWorkdir := strings.TrimSpace(session.Workdir) + effectiveWorkdir := effectiveWorkdirForCheckpointState(state, session) now := time.Now() ref := checkpoint.RefForPerEditCheckpoint(checkpointID) @@ -229,12 +229,13 @@ func (s *Service) createSessionOnlyCheckpoint( return fmt.Errorf("checkpoint: marshal session-only messages: %w", err) } + effectiveWorkdir := effectiveWorkdirForCheckpointState(state, session) record := agentsession.CheckpointRecord{ CheckpointID: checkpointID, - WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + WorkspaceKey: agentsession.WorkspacePathKey(effectiveWorkdir), SessionID: session.ID, RunID: runID, - Workdir: session.Workdir, + Workdir: effectiveWorkdir, CreatedAt: now, Reason: reason, Restorable: true, @@ -294,3 +295,13 @@ func (s *Service) findPreviousEndOfTurnCheckpoint(ctx context.Context, sessionID } return "" } + +// effectiveWorkdirForCheckpointState 返回 checkpoint 流程应使用的工作目录,优先使用本次 run 已归一化的目录。 +func effectiveWorkdirForCheckpointState(state *runState, session agentsession.Session) string { + if state != nil { + if workdir := strings.TrimSpace(state.effectiveWorkdir); workdir != "" { + return workdir + } + } + return strings.TrimSpace(session.Workdir) +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 58bc4cfd..6887ee55 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -137,7 +137,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } if statePtr != nil { runEndCtx := context.Background() - s.recordRunEndWorkspaceState(runEndCtx, statePtr.session.ID, statePtr.session.Workdir, statePtr.lastEndOfTurnCheckpointID) + s.recordRunEndWorkspaceState(runEndCtx, statePtr.session.ID, effectiveWorkdirForCheckpointState(statePtr, statePtr.session), statePtr.lastEndOfTurnCheckpointID) s.clearRunCheckpointCaches(statePtr.session.ID, statePtr.runID) } s.emitRunTermination(runCtx, input, statePtr, err) @@ -175,6 +175,8 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } state := newRunState(input.RunID, session) + effectiveWorkdir := agentsession.EffectiveWorkdir(state.session.Workdir, initialCfg.Workdir) + state.effectiveWorkdir = effectiveWorkdir state.runToken = runToken state.thinkingOverride = cloneThinkingOverride(input.ThinkingOverride) state.disableTools = input.DisableTools @@ -195,7 +197,6 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { return s.handleRunError(err) } - effectiveWorkdir := agentsession.EffectiveWorkdir(state.session.Workdir, initialCfg.Workdir) runBaselineCheckpointID := "" if drifted, driftDiff := s.recordRunStartFingerprint(ctx, state.session.ID, state.runID, effectiveWorkdir); drifted { if checkpointID, cpErr := s.createPreRunDriftRebaseCheckpoint(ctx, &state, driftDiff); cpErr != nil { diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 6606f552..2440bddf 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -17,6 +17,7 @@ type runState struct { runID string runToken uint64 session agentsession.Session + effectiveWorkdir string compactCount int reactiveCompactAttempts int rememberedThisRun bool diff --git a/web/src/components/ErrorBoundary.test.tsx b/web/src/components/ErrorBoundary.test.tsx index dc43123b..b4d99dbf 100644 --- a/web/src/components/ErrorBoundary.test.tsx +++ b/web/src/components/ErrorBoundary.test.tsx @@ -1,8 +1,9 @@ import { describe, expect, it, vi } from 'vitest' import { fireEvent, render, screen } from '@testing-library/react' +import type { ReactElement } from 'react' import { ErrorBoundary } from './ErrorBoundary' -function Crash() { +function Crash(): ReactElement { throw new Error('boom') } diff --git a/web/src/stores/useWorkspaceStore.test.ts b/web/src/stores/useWorkspaceStore.test.ts index 2ffe4a86..a55103e3 100644 --- a/web/src/stores/useWorkspaceStore.test.ts +++ b/web/src/stores/useWorkspaceStore.test.ts @@ -42,7 +42,7 @@ describe('useWorkspaceStore', () => { }) it('deduplicates concurrent fetchWorkspaces calls', async () => { - let resolveList: ((value: any) => void) | null = null + let resolveList!: (value: any) => void const gatewayAPI = { listWorkspaces: vi.fn(() => new Promise((resolve) => { resolveList = resolve })), } as any @@ -51,7 +51,7 @@ describe('useWorkspaceStore', () => { const p2 = useWorkspaceStore.getState().fetchWorkspaces(gatewayAPI) expect(gatewayAPI.listWorkspaces).toHaveBeenCalledTimes(1) - resolveList?.({ + resolveList({ payload: { workspaces: [{ hash: 'w1', path: '/a', name: 'A', created_at: '1', updated_at: '2' }], }, diff --git a/web/src/utils/eventBridge.test.ts b/web/src/utils/eventBridge.test.ts index b6a7c281..9c41d350 100644 --- a/web/src/utils/eventBridge.test.ts +++ b/web/src/utils/eventBridge.test.ts @@ -516,6 +516,52 @@ describe('eventBridge', () => { expect(useUIStore.getState().fileChanges).toHaveLength(0) }) + it('CheckpointRestored invalidates in-flight run-scoped file change refreshes', async () => { + let resolveDiff: ((value: unknown) => void) | undefined + const checkpointDiff = vi.fn(() => new Promise((resolve) => { + resolveDiff = resolve + })) + const loadSession = vi.fn(async () => ({ + payload: { + id: 'sess-1', + agent_mode: 'build', + messages: [{ role: 'assistant', content: 'after restore' }], + }, + })) + const api = createMockGatewayAPI({ checkpointDiff, loadSession }) + useSessionStore.setState({ currentSessionId: 'sess-1' } as any) + useGatewayStore.setState({ currentRunId: 'run-1' } as any) + + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp-end', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'end_of_turn' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + expect(checkpointDiff).toHaveBeenCalled() + + handleGatewayEvent({ + type: EventType.CheckpointRestored, + payload: { payload: { runtime_event_type: EventType.CheckpointRestored, payload: { checkpoint_id: 'cp-restored', session_id: 'sess-1', guard_checkpoint_id: 'guard-1' } } }, + session_id: 'sess-1', + run_id: 'run-restore', + }, api) + await Promise.resolve() + await Promise.resolve() + + resolveDiff?.({ + payload: { + checkpoint_id: 'cp-end', + files: { modified: ['stale.txt'] }, + patch: '--- a/stale.txt\n+++ b/stale.txt\n@@ -1 +1 @@\n-old\n+new\n', + }, + }) + await Promise.resolve() + await Promise.resolve() + + expect(useUIStore.getState().fileChanges).toHaveLength(0) + }) + it('CheckpointRestored does not reload when event session differs from current session', async () => { const loadSession = vi.fn(async () => ({ payload: { id: 'sess-other', messages: [] } })) const api = createMockGatewayAPI({ loadSession }) diff --git a/web/src/utils/eventBridge.ts b/web/src/utils/eventBridge.ts index 8ebe17b0..5b0a37be 100644 --- a/web/src/utils/eventBridge.ts +++ b/web/src/utils/eventBridge.ts @@ -352,6 +352,7 @@ function refreshSessionAfterCheckpointRestoreEvent( } const requestId = ++_latestRestoreSyncRequestId + _latestRunDiffRequestId += 1 const reloadSeq = beginCheckpointRestoreReloadSeq() _firstTouchRollbackCheckpointByPath = new Map() useUIStore.getState().setRestoringCheckpoint(true) From 379f4bacb733731a3a0dae0f8bd4e570ef6d5b00 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sun, 10 May 2026 10:10:02 +0800 Subject: [PATCH 4/4] fix(web): restore slash command suggestions --- web/src/components/chat/ChatInput.test.tsx | 98 +++- web/src/components/chat/ChatInput.tsx | 512 ++++++++++++------ .../components/chat/SlashCommandMenu.test.tsx | 107 ++-- web/src/components/chat/SlashCommandMenu.tsx | 54 +- web/src/utils/slashCommands.test.ts | 129 +++-- web/src/utils/slashCommands.ts | 93 +++- 6 files changed, 667 insertions(+), 326 deletions(-) diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index f9de008e..8c610e7b 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -1,25 +1,60 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' -import { fireEvent, render, screen } from '@testing-library/react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' import ChatInput from './ChatInput' import { useChatStore } from '@/stores/useChatStore' import { useComposerStore } from '@/stores/useComposerStore' import { useSessionStore } from '@/stores/useSessionStore' +const mockGatewayAPI = { + listAvailableSkills: vi.fn(), + listModels: vi.fn(), + run: vi.fn(), + bindStream: vi.fn(), + cancel: vi.fn(), + compact: vi.fn(), + executeSystemTool: vi.fn(), + activateSessionSkill: vi.fn(), + deactivateSessionSkill: vi.fn(), +} + vi.mock('@/context/RuntimeProvider', () => ({ - useGatewayAPI: () => null, + useGatewayAPI: () => mockGatewayAPI, +})) + +vi.mock('./ModelSelector', () => ({ + default: () =>
, })) describe('ChatInput', () => { beforeEach(() => { + vi.clearAllMocks() + mockGatewayAPI.listAvailableSkills.mockResolvedValue({ + payload: { + skills: [ + { + descriptor: { id: 'skill.demo', description: 'demo skill' }, + active: true, + }, + ], + }, + }) + mockGatewayAPI.listModels.mockResolvedValue({ + payload: { + models: [], + selected_provider_id: '', + selected_model_id: '', + }, + }) + useComposerStore.setState({ composerText: '' }) - useSessionStore.setState({ currentSessionId: '' } as any) + useSessionStore.setState({ currentSessionId: '' } as never) useChatStore.setState({ isGenerating: false, messages: [], permissionRequests: [], agentMode: 'build', permissionMode: 'default', - } as any) + } as never) }) it('shows the default/bypass selector in build mode', () => { @@ -40,10 +75,63 @@ describe('ChatInput', () => { expect(screen.queryByRole('button', { name: 'bypass' })).not.toBeInTheDocument() }) + it('opens slash suggestions for bare slash and loads skills immediately', async () => { + render() + + const textarea = screen.getByRole('textbox') + fireEvent.change(textarea, { target: { value: '/' } }) + + await waitFor(() => { + expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument() + }) + await waitFor(() => { + expect(mockGatewayAPI.listAvailableSkills).toHaveBeenCalledWith(undefined) + }) + await waitFor(() => { + expect(screen.getByText('/skill.demo')).toBeInTheDocument() + }) + }) + + it('keeps slash menu visible for fuzzy inputs like /w', async () => { + render() + + const textarea = screen.getByRole('textbox') + fireEvent.change(textarea, { target: { value: '/w' } }) + + await waitFor(() => { + expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument() + }) + expect(screen.getAllByText((_, el) => Boolean(el?.textContent?.includes('/help'))).length).toBeGreaterThan(0) + }) + + it('supports keyboard navigation, tab completion, and escape for slash menu', async () => { + render() + + const textarea = screen.getByRole('textbox') as HTMLTextAreaElement + fireEvent.change(textarea, { target: { value: '/' } }) + + await waitFor(() => { + expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument() + }) + + fireEvent.keyDown(textarea, { key: 'ArrowDown' }) + fireEvent.keyDown(textarea, { key: 'Tab' }) + expect(textarea.value).toBe('/compact ') + + fireEvent.change(textarea, { target: { value: '/' } }) + await waitFor(() => { + expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument() + }) + fireEvent.keyDown(textarea, { key: 'Escape' }) + await waitFor(() => { + expect(screen.queryByTestId('slash-command-menu')).not.toBeInTheDocument() + }) + }) + it('does not render the unimplemented attachment and mention buttons', () => { render() - expect(screen.queryByTitle('附加文件')).not.toBeInTheDocument() + expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument() expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() }) }) diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index f35c6c5d..ff4258e4 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -1,4 +1,4 @@ -import { useState, useRef, useEffect, useCallback } from 'react' +import { useState, useRef, useEffect, useCallback, useMemo } from 'react' import { useChatStore, createUserMessage } from '@/stores/useChatStore' import { useGatewayStore } from '@/stores/useGatewayStore' import { useSessionStore, isValidSessionId } from '@/stores/useSessionStore' @@ -21,52 +21,97 @@ import SkillPicker from './SkillPicker' import ModelSelector from './ModelSelector' import { Send, Square } from 'lucide-react' +const slashMenuAnchorStyle: React.CSSProperties = { + position: 'absolute', + left: 0, + bottom: 'calc(100% + 8px)', + zIndex: 100, +} + +/** 将网关返回的技能列表转换成输入框使用的 slash 命令结构。 */ +function buildSkillSlashCommands( + skills: Array<{ descriptor: { id: string; description?: string }; active?: boolean }>, +): SkillSlashCommand[] { + return skills.map((skill) => ({ + id: `skill-${skill.descriptor.id}`, + usage: `/${skill.descriptor.id}`, + description: skill.descriptor.description || '技能', + hasArgument: false, + isSkill: true, + skillId: skill.descriptor.id, + active: Boolean(skill.active), + })) +} + +/** 用当前命令定义生成帮助文本,避免菜单与帮助内容漂移。 */ +function buildSlashHelpText(commands: AnySlashCommand[]): string { + const helpCommands = commands.length > 0 ? commands : builtinSlashCommands + const maxLen = helpCommands.reduce((max, command) => Math.max(max, command.usage.length), 0) + const lines = helpCommands.map((command) => { + const pad = ' '.repeat(maxLen - command.usage.length) + const description = isSkillCommand(command) && command.active + ? `${command.description} (已激活)` + : command.description + return ` ${command.usage}${pad} - ${description}` + }) + return ['可用命令:', ...lines].join('\n') +} + export default function ChatInput() { const gatewayAPI = useGatewayAPI() - const text = useComposerStore((s) => s.composerText) - const setText = useComposerStore((s) => s.setComposerText) + const text = useComposerStore((state) => state.composerText) + const setText = useComposerStore((state) => state.setComposerText) const [rows, setRows] = useState(1) const textareaRef = useRef(null) const runCancelledRef = useRef(false) const composingRef = useRef(false) - const isGenerating = useChatStore((s) => s.isGenerating) - const addMessage = useChatStore((s) => s.addMessage) - const addSystemMessage = useChatStore((s) => s.addSystemMessage) - const setGenerating = useChatStore((s) => s.setGenerating) - const sessionId = useSessionStore((s) => s.currentSessionId) - const agentMode = useChatStore((s) => s.agentMode) - const setAgentMode = useChatStore((s) => s.setAgentMode) - const permissionMode = useChatStore((s) => s.permissionMode) - const setPermissionMode = useChatStore((s) => s.setPermissionMode) + const isGenerating = useChatStore((state) => state.isGenerating) + const addMessage = useChatStore((state) => state.addMessage) + const addSystemMessage = useChatStore((state) => state.addSystemMessage) + const setGenerating = useChatStore((state) => state.setGenerating) + const sessionId = useSessionStore((state) => state.currentSessionId) + const agentMode = useChatStore((state) => state.agentMode) + const setAgentMode = useChatStore((state) => state.setAgentMode) + const permissionMode = useChatStore((state) => state.permissionMode) + const setPermissionMode = useChatStore((state) => state.setPermissionMode) const [showSlashMenu, setShowSlashMenu] = useState(false) const [selectedIndex, setSelectedIndex] = useState(0) const [matchedCommands, setMatchedCommands] = useState([]) const [availableSkillCommands, setAvailableSkillCommands] = useState([]) const [showSkillPicker, setShowSkillPicker] = useState(false) + const allSlashCommands = useMemo( + () => [...builtinSlashCommands, ...availableSkillCommands], + [availableSkillCommands], + ) useEffect(() => { - if (!showSlashMenu || !gatewayAPI) return + if (!gatewayAPI || !text.trimLeft().startsWith('/')) return + let cancelled = false gatewayAPI.listAvailableSkills(sessionId || undefined).then((result) => { + if (cancelled) return const skills = result.payload?.skills || [] - const skillCommands: SkillSlashCommand[] = skills.map((s) => ({ - id: `skill-${s.descriptor.id}`, - usage: `/${s.descriptor.id}`, - description: s.descriptor.description || '技能', - hasArgument: false, isSkill: true, - skillId: s.descriptor.id, active: s.active, - })) - setAvailableSkillCommands(skillCommands) - }).catch(() => { setAvailableSkillCommands([]) }) - }, [showSlashMenu, gatewayAPI, sessionId]) + setAvailableSkillCommands(buildSkillSlashCommands(skills)) + }).catch(() => { + if (!cancelled) setAvailableSkillCommands([]) + }) + return () => { + cancelled = true + } + }, [text, gatewayAPI, sessionId]) useEffect(() => { - if (!isSlashCommand(text)) { setShowSlashMenu(false); return } - const allCommands: AnySlashCommand[] = [...builtinSlashCommands, ...availableSkillCommands] - const matched = matchSlashCommands(text, allCommands) - if (matched.length > 0) { setMatchedCommands(matched); setShowSlashMenu(true); setSelectedIndex(0) } - else { setShowSlashMenu(false) } - }, [text, availableSkillCommands]) + if (!isSlashCommand(text)) { + setMatchedCommands([]) + setShowSlashMenu(false) + return + } + + const matched = matchSlashCommands(text, allSlashCommands) + setMatchedCommands(matched) + setShowSlashMenu(matched.length > 0) + if (matched.length > 0) setSelectedIndex(0) + }, [text, allSlashCommands]) useEffect(() => { const lines = text.split('\n').length @@ -76,137 +121,249 @@ export default function ChatInput() { const executeSlashCommand = useCallback(async (input: string): Promise => { const parsed = parseSlashCommand(input) if (!parsed) return false + const { command, argument } = parsed const currentSessionId = sessionId const api = gatewayAPI - if (!api) { useUIStore.getState().showToast('Gateway not connected', 'error'); return true } + if (!api) { + useUIStore.getState().showToast('Gateway not connected', 'error') + return true + } switch (command) { case '/help': { - const allUsages = [...builtinSlashCommands.map((c) => c.usage), '/'] - const maxLen = Math.max(...allUsages.map((u) => u.length)) - const helpLines = [ - '可用命令:', - ...builtinSlashCommands.map((cmd) => { - const pad = ' '.repeat(maxLen - cmd.usage.length) - return ` ${cmd.usage}${pad} — ${cmd.description}` - }), - ` /${''.padEnd(maxLen - 1)} — 激活/停用技能`, - ] - addSystemMessage(helpLines.join('\n')) + addSystemMessage(buildSlashHelpText(allSlashCommands)) return true } case '/compact': { - if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true } - try { await api.compact(currentSessionId, '') } catch (err) { console.error('Compact failed:', err); useUIStore.getState().showToast('Compaction failed', 'error') } + if (!isValidSessionId(currentSessionId)) { + useUIStore.getState().showToast('Send a message first to start a session', 'error') + return true + } + try { + await api.compact(currentSessionId, '') + } catch (err) { + console.error('Compact failed:', err) + useUIStore.getState().showToast('Compaction failed', 'error') + } return true } case '/memo': { - if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true } + if (!isValidSessionId(currentSessionId)) { + useUIStore.getState().showToast('Send a message first to start a session', 'error') + return true + } try { const result = await api.executeSystemTool(currentSessionId, '', 'memo_list', {}) - addSystemMessage((result as any)?.payload?.content || 'Memo query complete') - } catch (err) { console.error('Memo list failed:', err); useUIStore.getState().showToast('Failed to query memo', 'error') } + addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo query complete') + } catch (err) { + console.error('Memo list failed:', err) + useUIStore.getState().showToast('Failed to query memo', 'error') + } return true } case '/remember': { - if (!argument) { useUIStore.getState().showToast('Usage: /remember ', 'error'); return true } - if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true } + if (!argument) { + useUIStore.getState().showToast('Usage: /remember ', 'error') + return true + } + if (!isValidSessionId(currentSessionId)) { + useUIStore.getState().showToast('Send a message first to start a session', 'error') + return true + } try { - const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', { type: 'user', title: argument, content: argument }) - addSystemMessage((result as any)?.payload?.content || 'Memo saved') - } catch (err) { console.error('Remember failed:', err); useUIStore.getState().showToast('Failed to save memo', 'error') } + const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', { + type: 'user', + title: argument, + content: argument, + }) + addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo saved') + } catch (err) { + console.error('Remember failed:', err) + useUIStore.getState().showToast('Failed to save memo', 'error') + } return true } case '/forget': { - if (!argument) { useUIStore.getState().showToast('Usage: /forget ', 'error'); return true } - if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true } + if (!argument) { + useUIStore.getState().showToast('Usage: /forget ', 'error') + return true + } + if (!isValidSessionId(currentSessionId)) { + useUIStore.getState().showToast('Send a message first to start a session', 'error') + return true + } try { - const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', { keyword: argument, scope: 'all' }) - addSystemMessage((result as any)?.payload?.content || 'Memo deleted') - } catch (err) { console.error('Forget failed:', err); useUIStore.getState().showToast('Failed to delete memo', 'error') } + const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', { + keyword: argument, + scope: 'all', + }) + addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo deleted') + } catch (err) { + console.error('Forget failed:', err) + useUIStore.getState().showToast('Failed to delete memo', 'error') + } + return true + } + case '/skills': { + setShowSkillPicker(true) return true } - case '/skills': { setShowSkillPicker(true); return true } default: { - if (isGenerating) { useUIStore.getState().showToast('Cannot toggle skill while generating', 'info'); return true } - const skillCmd = availableSkillCommands.find((s) => s.usage === command) - if (skillCmd && isValidSessionId(currentSessionId)) { + if (isGenerating) { + useUIStore.getState().showToast('Cannot toggle skill while generating', 'info') + return true + } + const skillCommand = availableSkillCommands.find((skill) => skill.usage === command) + if (skillCommand && isValidSessionId(currentSessionId)) { try { - if (skillCmd.active) await api.deactivateSessionSkill(currentSessionId, skillCmd.skillId) - else await api.activateSessionSkill(currentSessionId, skillCmd.skillId) - setAvailableSkillCommands((prev) => prev.map((item) => item.skillId === skillCmd.skillId ? { ...item, active: !item.active } : item)) - } catch (err) { console.error('Skill toggle failed:', err); useUIStore.getState().showToast('Skill operation failed', 'error') } + if (skillCommand.active) { + await api.deactivateSessionSkill(currentSessionId, skillCommand.skillId) + } else { + await api.activateSessionSkill(currentSessionId, skillCommand.skillId) + } + setAvailableSkillCommands((prev) => prev.map((item) => ( + item.skillId === skillCommand.skillId + ? { ...item, active: !item.active } + : item + ))) + } catch (err) { + console.error('Skill toggle failed:', err) + useUIStore.getState().showToast('Skill operation failed', 'error') + } return true } return false } } - }, [gatewayAPI, sessionId, addSystemMessage, availableSkillCommands, isGenerating]) + }, [gatewayAPI, sessionId, addSystemMessage, availableSkillCommands, isGenerating, allSlashCommands]) async function handleSubmit() { const input = text.trim() if (!input) return + if (isGenerating) { if (isSlashCommand(input)) useUIStore.getState().showToast('Cannot run commands while generating', 'info') return } + if (isSlashCommand(input)) { - setText(''); setShowSlashMenu(false) + setText('') + setShowSlashMenu(false) const handled = await executeSlashCommand(input) if (handled) return } + setText('') const userMsg = createUserMessage(input) addMessage(userMsg) useRuntimeInsightStore.getState().setTodoSnapshot({ - items: [], summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 }, + items: [], + summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 }, }) setGenerating(true) runCancelledRef.current = false + try { if (!gatewayAPI) return const isNewSession = !isValidSessionId(sessionId) - const ack = await gatewayAPI.run({ session_id: isNewSession ? undefined : sessionId, new_session: isNewSession ? true : undefined, input_text: input, mode: agentMode }) + const ack = await gatewayAPI.run({ + session_id: isNewSession ? undefined : sessionId, + new_session: isNewSession ? true : undefined, + input_text: input, + mode: agentMode, + }) if (!runCancelledRef.current) { - const gwStore = useGatewayStore.getState() - const sessStore = useSessionStore.getState() - if (ack.run_id) gwStore.setCurrentRunId(ack.run_id) - if (ack.session_id) { sessStore.setCurrentSessionId(ack.session_id); gatewayAPI?.bindStream({ session_id: ack.session_id, channel: 'all' }).catch(() => {}) } + const gatewayStore = useGatewayStore.getState() + const sessionStore = useSessionStore.getState() + if (ack.run_id) gatewayStore.setCurrentRunId(ack.run_id) + if (ack.session_id) { + sessionStore.setCurrentSessionId(ack.session_id) + gatewayAPI.bindStream({ session_id: ack.session_id, channel: 'all' }).catch(() => {}) + } } } catch (err) { if (!runCancelledRef.current) { - setGenerating(false); useChatStore.getState().removeMessage(userMsg.id) - console.error('Run failed:', err); useUIStore.getState().showToast('Failed to send message', 'error') + setGenerating(false) + useChatStore.getState().removeMessage(userMsg.id) + console.error('Run failed:', err) + useUIStore.getState().showToast('Failed to send message', 'error') } } } function handleKeyDown(e: React.KeyboardEvent) { if (composingRef.current) return + if (!showSlashMenu) { - if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); handleSubmit() } + if (e.key === 'Enter' && !e.shiftKey) { + e.preventDefault() + handleSubmit() + } return } + switch (e.key) { - case 'ArrowDown': e.preventDefault(); setSelectedIndex((prev) => (prev + 1) % matchedCommands.length); return - case 'ArrowUp': e.preventDefault(); setSelectedIndex((prev) => (prev - 1 + matchedCommands.length) % matchedCommands.length); return - case 'Enter': e.preventDefault(); const cmd = matchedCommands[selectedIndex]; if (cmd) handleSelectCommand(cmd); return - case 'Escape': e.preventDefault(); setShowSlashMenu(false); return - case 'Tab': e.preventDefault(); const c = matchedCommands[selectedIndex]; if (c) { setText(c.usage + ' '); textareaRef.current?.focus() }; return + case 'ArrowDown': + e.preventDefault() + setSelectedIndex((prev) => (prev + 1) % matchedCommands.length) + return + case 'ArrowUp': + e.preventDefault() + setSelectedIndex((prev) => (prev - 1 + matchedCommands.length) % matchedCommands.length) + return + case 'Enter': { + e.preventDefault() + const command = matchedCommands[selectedIndex] + if (command) handleSelectCommand(command) + return + } + case 'Escape': + e.preventDefault() + setShowSlashMenu(false) + return + case 'Tab': { + e.preventDefault() + const command = matchedCommands[selectedIndex] + if (command) { + setText(`${command.usage} `) + textareaRef.current?.focus() + } + return + } } } function handleSelectCommand(cmd: AnySlashCommand) { - if (isSkillCommand(cmd)) { setText(cmd.usage); setShowSlashMenu(false); executeSlashCommand(cmd.usage); return } - if (cmd.hasArgument) { setText(cmd.usage + ' '); setShowSlashMenu(false); textareaRef.current?.focus() } - else { setText(''); setShowSlashMenu(false); executeSlashCommand(cmd.usage) } + if (isSkillCommand(cmd)) { + setText(cmd.usage) + setShowSlashMenu(false) + void executeSlashCommand(cmd.usage) + return + } + + if (cmd.hasArgument) { + setText(`${cmd.usage} `) + setShowSlashMenu(false) + textareaRef.current?.focus() + return + } + + setText('') + setShowSlashMenu(false) + void executeSlashCommand(cmd.usage) } async function handleCancel() { runCancelledRef.current = true const runId = useGatewayStore.getState().currentRunId - if (runId && gatewayAPI) { try { await gatewayAPI.cancel({ run_id: runId }) } catch (err) { console.error('Cancel failed:', err) } } + if (runId && gatewayAPI) { + try { + await gatewayAPI.cancel({ run_id: runId }) + } catch (err) { + console.error('Cancel failed:', err) + } + } useChatStore.getState().resetGeneratingState() } @@ -218,9 +375,9 @@ export default function ChatInput() { setShowSkillPicker(false)} /> )}
- {showSlashMenu && matchedCommands.length > 0 && ( -
-
+
+ {showSlashMenu && matchedCommands.length > 0 && ( +
-
- )} -
-