diff --git a/internal/checkpoint/checkpoint_manager.go b/internal/checkpoint/checkpoint_manager.go index 347c5582..01f8a185 100644 --- a/internal/checkpoint/checkpoint_manager.go +++ b/internal/checkpoint/checkpoint_manager.go @@ -58,6 +58,29 @@ type SQLiteCheckpointStore struct { ownsDB bool // true 表示本实例打开的连接,Close 时需释放 } +type workspaceFingerprintRow struct { + WorkspaceKey string + FingerprintPayload string + UpdatedAtMS int64 +} + +// WorkspaceCheckpointState 保存工作区当前权威代码基线与文件指纹。 +type WorkspaceCheckpointState struct { + WorkspaceKey string + CurrentCheckpointID string + FingerprintPayload string + UpdatedAt time.Time +} + +// RunCheckpointBaseline 保存单次 run 的权威回退基线,供异步 diff 查询跨 run 结束后读取。 +type RunCheckpointBaseline struct { + SessionID string + RunID string + CheckpointID string + Drifted bool + UpdatedAt time.Time +} + // NewSQLiteCheckpointStore 创建 checkpoint 存储实例。 // dbPath 为 session.db 文件路径,可通过 session.DatabasePath 获取。 func NewSQLiteCheckpointStore(dbPath string) *SQLiteCheckpointStore { @@ -114,6 +137,184 @@ func (s *SQLiteCheckpointStore) ensureDB(ctx context.Context) (*sql.DB, error) { return db, nil } +// SaveWorkspaceFingerprint 把 workspace 最新指纹保存到 SQLite,用于跨会话/重启后的 drift 检测。 +func (s *SQLiteCheckpointStore) SaveWorkspaceFingerprint( + ctx context.Context, + workspaceKey string, + fingerprintPayload string, + updatedAt time.Time, +) error { + if err := ctx.Err(); err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + if err := ensureWorkspaceFingerprintTable(ctx, db); err != nil { + return err + } + _, err = db.ExecContext(ctx, ` +INSERT INTO workspace_fingerprints(workspace_key, fingerprint_payload, updated_at_ms) +VALUES(?, ?, ?) +ON CONFLICT(workspace_key) DO UPDATE SET + fingerprint_payload=excluded.fingerprint_payload, + updated_at_ms=excluded.updated_at_ms +`, workspaceKey, fingerprintPayload, toUnixMillis(updatedAt)) + if err != nil { + return fmt.Errorf("checkpoint: save workspace fingerprint %s: %w", workspaceKey, err) + } + return nil +} + +// LoadWorkspaceFingerprint 从 SQLite 读取 workspace 最近指纹,不存在时 ok=false。 +func (s *SQLiteCheckpointStore) LoadWorkspaceFingerprint( + ctx context.Context, + workspaceKey string, +) (string, bool, error) { + if err := ctx.Err(); err != nil { + return "", false, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return "", false, err + } + if err := ensureWorkspaceFingerprintTable(ctx, db); err != nil { + return "", false, err + } + var row workspaceFingerprintRow + err = db.QueryRowContext(ctx, ` +SELECT workspace_key, fingerprint_payload, updated_at_ms +FROM workspace_fingerprints +WHERE workspace_key = ? +`, workspaceKey).Scan(&row.WorkspaceKey, &row.FingerprintPayload, &row.UpdatedAtMS) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + return "", false, fmt.Errorf("checkpoint: load workspace fingerprint %s: %w", workspaceKey, err) + } + return row.FingerprintPayload, true, nil +} + +// SaveWorkspaceCheckpointState 持久化工作区当前代码基线与文件指纹。 +func (s *SQLiteCheckpointStore) SaveWorkspaceCheckpointState(ctx context.Context, state WorkspaceCheckpointState) error { + if err := ctx.Err(); err != nil { + return err + } + if state.WorkspaceKey == "" { + return fmt.Errorf("checkpoint: workspace key required") + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + if err := ensureWorkspaceCheckpointStateTable(ctx, db); err != nil { + return err + } + _, err = db.ExecContext(ctx, ` +INSERT INTO workspace_checkpoint_states(workspace_key, current_checkpoint_id, fingerprint_payload, updated_at_ms) +VALUES(?, ?, ?, ?) +ON CONFLICT(workspace_key) DO UPDATE SET + current_checkpoint_id=excluded.current_checkpoint_id, + fingerprint_payload=excluded.fingerprint_payload, + updated_at_ms=excluded.updated_at_ms +`, state.WorkspaceKey, state.CurrentCheckpointID, state.FingerprintPayload, toUnixMillis(state.UpdatedAt)) + if err != nil { + return fmt.Errorf("checkpoint: save workspace checkpoint state %s: %w", state.WorkspaceKey, err) + } + return nil +} + +// LoadWorkspaceCheckpointState 读取工作区当前代码基线与文件指纹,不存在时 ok=false。 +func (s *SQLiteCheckpointStore) LoadWorkspaceCheckpointState(ctx context.Context, workspaceKey string) (WorkspaceCheckpointState, bool, error) { + if err := ctx.Err(); err != nil { + return WorkspaceCheckpointState{}, false, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return WorkspaceCheckpointState{}, false, err + } + if err := ensureWorkspaceCheckpointStateTable(ctx, db); err != nil { + return WorkspaceCheckpointState{}, false, err + } + var state WorkspaceCheckpointState + var updatedAtMS int64 + err = db.QueryRowContext(ctx, ` +SELECT workspace_key, current_checkpoint_id, fingerprint_payload, updated_at_ms +FROM workspace_checkpoint_states +WHERE workspace_key = ? +`, workspaceKey).Scan(&state.WorkspaceKey, &state.CurrentCheckpointID, &state.FingerprintPayload, &updatedAtMS) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return WorkspaceCheckpointState{}, false, nil + } + return WorkspaceCheckpointState{}, false, fmt.Errorf("checkpoint: load workspace checkpoint state %s: %w", workspaceKey, err) + } + state.UpdatedAt = fromUnixMillis(updatedAtMS) + return state, true, nil +} + +// SaveRunCheckpointBaseline 持久化单次 run 的权威回退基线。 +func (s *SQLiteCheckpointStore) SaveRunCheckpointBaseline(ctx context.Context, baseline RunCheckpointBaseline) error { + if err := ctx.Err(); err != nil { + return err + } + if baseline.SessionID == "" || baseline.RunID == "" { + return fmt.Errorf("checkpoint: session_id and run_id required for run baseline") + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + if err := ensureRunCheckpointBaselineTable(ctx, db); err != nil { + return err + } + _, err = db.ExecContext(ctx, ` +INSERT INTO run_checkpoint_baselines(session_id, run_id, checkpoint_id, drifted, updated_at_ms) +VALUES(?, ?, ?, ?, ?) +ON CONFLICT(session_id, run_id) DO UPDATE SET + checkpoint_id=excluded.checkpoint_id, + drifted=excluded.drifted, + updated_at_ms=excluded.updated_at_ms +`, baseline.SessionID, baseline.RunID, baseline.CheckpointID, boolToInt(baseline.Drifted), toUnixMillis(baseline.UpdatedAt)) + if err != nil { + return fmt.Errorf("checkpoint: save run baseline %s/%s: %w", baseline.SessionID, baseline.RunID, err) + } + return nil +} + +// LoadRunCheckpointBaseline 读取单次 run 的权威回退基线,不存在时 ok=false。 +func (s *SQLiteCheckpointStore) LoadRunCheckpointBaseline(ctx context.Context, sessionID, runID string) (RunCheckpointBaseline, bool, error) { + if err := ctx.Err(); err != nil { + return RunCheckpointBaseline{}, false, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return RunCheckpointBaseline{}, false, err + } + if err := ensureRunCheckpointBaselineTable(ctx, db); err != nil { + return RunCheckpointBaseline{}, false, err + } + var baseline RunCheckpointBaseline + var drifted int + var updatedAtMS int64 + err = db.QueryRowContext(ctx, ` +SELECT session_id, run_id, checkpoint_id, drifted, updated_at_ms +FROM run_checkpoint_baselines +WHERE session_id = ? AND run_id = ? +`, sessionID, runID).Scan(&baseline.SessionID, &baseline.RunID, &baseline.CheckpointID, &drifted, &updatedAtMS) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return RunCheckpointBaseline{}, false, nil + } + return RunCheckpointBaseline{}, false, fmt.Errorf("checkpoint: load run baseline %s/%s: %w", sessionID, runID, err) + } + baseline.Drifted = drifted != 0 + baseline.UpdatedAt = fromUnixMillis(updatedAtMS) + return baseline, true, nil +} + // CreateCheckpoint 在单一事务内写入 checkpoint record + session checkpoint。 // 事务内完成 record INSERT → session_cp INSERT → record UPDATE(设置 session_checkpoint_ref + status=available)。 func (s *SQLiteCheckpointStore) CreateCheckpoint(ctx context.Context, input CreateCheckpointInput) (session.CheckpointRecord, error) { @@ -663,3 +864,52 @@ func rollbackTx(tx *sql.Tx) { _ = tx.Rollback() } } + +func ensureWorkspaceFingerprintTable(ctx context.Context, db *sql.DB) error { + if db == nil { + return fmt.Errorf("checkpoint: workspace fingerprint db is nil") + } + if _, err := db.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS workspace_fingerprints ( + workspace_key TEXT PRIMARY KEY, + fingerprint_payload TEXT NOT NULL, + updated_at_ms INTEGER NOT NULL +)`); err != nil { + return fmt.Errorf("checkpoint: ensure workspace_fingerprints table: %w", err) + } + return nil +} + +func ensureWorkspaceCheckpointStateTable(ctx context.Context, db *sql.DB) error { + if db == nil { + return fmt.Errorf("checkpoint: workspace checkpoint state db is nil") + } + if _, err := db.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS workspace_checkpoint_states ( + workspace_key TEXT PRIMARY KEY, + current_checkpoint_id TEXT NOT NULL, + fingerprint_payload TEXT NOT NULL, + updated_at_ms INTEGER NOT NULL +)`); err != nil { + return fmt.Errorf("checkpoint: ensure workspace_checkpoint_states table: %w", err) + } + return nil +} + +func ensureRunCheckpointBaselineTable(ctx context.Context, db *sql.DB) error { + if db == nil { + return fmt.Errorf("checkpoint: run checkpoint baseline db is nil") + } + if _, err := db.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS run_checkpoint_baselines ( + session_id TEXT NOT NULL, + run_id TEXT NOT NULL, + checkpoint_id TEXT NOT NULL, + drifted INTEGER NOT NULL, + updated_at_ms INTEGER NOT NULL, + PRIMARY KEY(session_id, run_id) +)`); err != nil { + return fmt.Errorf("checkpoint: ensure run_checkpoint_baselines table: %w", err) + } + return nil +} diff --git a/internal/checkpoint/checkpoint_manager_test.go b/internal/checkpoint/checkpoint_manager_test.go index 6afd96a1..67d3f58a 100644 --- a/internal/checkpoint/checkpoint_manager_test.go +++ b/internal/checkpoint/checkpoint_manager_test.go @@ -588,3 +588,81 @@ func TestNewSQLiteCheckpointStoreWithNilDBClose(t *testing.T) { t.Fatalf("Close(nil db) error = %v", err) } } + +func TestSQLiteCheckpointStoreWorkspaceFingerprintPersistence(t *testing.T) { + fixture := newCheckpointStoreFixture(t) + db, err := fixture.sessionStore.InitDB(context.Background()) + if err != nil { + t.Fatalf("InitDB() error = %v", err) + } + shared := NewSQLiteCheckpointStoreWithDB(db) + + workspaceKey := session.WorkspacePathKey(fixture.workspaceRoot) + payload := `{"a.txt":{"size":1,"mod_time":"2026-01-01T00:00:00Z","head_hash":"abc"}}` + if err := shared.SaveWorkspaceFingerprint(context.Background(), workspaceKey, payload, time.Now()); err != nil { + t.Fatalf("SaveWorkspaceFingerprint() error = %v", err) + } + + got, ok, err := shared.LoadWorkspaceFingerprint(context.Background(), workspaceKey) + if err != nil { + t.Fatalf("LoadWorkspaceFingerprint() error = %v", err) + } + if !ok { + t.Fatal("expected persisted fingerprint to exist") + } + if got != payload { + t.Fatalf("fingerprint payload = %q, want %q", got, payload) + } + + missing, ok, err := shared.LoadWorkspaceFingerprint(context.Background(), "workspace/missing") + if err != nil { + t.Fatalf("LoadWorkspaceFingerprint(missing) error = %v", err) + } + if ok || missing != "" { + t.Fatalf("expected missing fingerprint, got ok=%v payload=%q", ok, missing) + } +} + +func TestSQLiteCheckpointStoreWorkspaceStateAndRunBaselinePersistence(t *testing.T) { + fixture := newCheckpointStoreFixture(t) + db, err := fixture.sessionStore.InitDB(context.Background()) + if err != nil { + t.Fatalf("InitDB() error = %v", err) + } + shared := NewSQLiteCheckpointStoreWithDB(db) + + workspaceKey := session.WorkspacePathKey(fixture.workspaceRoot) + payload := `{"tracked.txt":{"size":7}}` + if err := shared.SaveWorkspaceCheckpointState(context.Background(), WorkspaceCheckpointState{ + WorkspaceKey: workspaceKey, + CurrentCheckpointID: "cp-current", + FingerprintPayload: payload, + UpdatedAt: time.Now(), + }); err != nil { + t.Fatalf("SaveWorkspaceCheckpointState() error = %v", err) + } + state, ok, err := shared.LoadWorkspaceCheckpointState(context.Background(), workspaceKey) + if err != nil { + t.Fatalf("LoadWorkspaceCheckpointState() error = %v", err) + } + if !ok || state.CurrentCheckpointID != "cp-current" || state.FingerprintPayload != payload { + t.Fatalf("workspace state = %#v ok=%v", state, ok) + } + + if err := shared.SaveRunCheckpointBaseline(context.Background(), RunCheckpointBaseline{ + SessionID: "session-1", + RunID: "run-1", + CheckpointID: "cp-current", + Drifted: true, + UpdatedAt: time.Now(), + }); err != nil { + t.Fatalf("SaveRunCheckpointBaseline() error = %v", err) + } + baseline, ok, err := shared.LoadRunCheckpointBaseline(context.Background(), "session-1", "run-1") + if err != nil { + t.Fatalf("LoadRunCheckpointBaseline() error = %v", err) + } + if !ok || baseline.CheckpointID != "cp-current" || !baseline.Drifted { + t.Fatalf("run baseline = %#v ok=%v", baseline, ok) + } +} diff --git a/internal/checkpoint/per_edit_snapshot.go b/internal/checkpoint/per_edit_snapshot.go index 9efa78d2..0580bc5c 100644 --- a/internal/checkpoint/per_edit_snapshot.go +++ b/internal/checkpoint/per_edit_snapshot.go @@ -367,9 +367,9 @@ func (s *PerEditSnapshotStore) Reset() { // guardID 为 pre-restore 固化的快照(restoreCheckpointCore 中的 guard checkpoint), // 用于对比确定每个文件的目标操作;guardID 为空时仅处理 target checkpoint 内的文件。 // -// 对比逻辑:对 target 与 guard 中出现的每个文件,分别计算"目标状态"与"当前状态", -// 据此执行写回 / 删除 / 跳过,覆盖文件创建、修改、删除三种变更方向。 -func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID string) error { +// 对比逻辑:存在 relatedCheckpointIDs 时先收敛到 target 与后续 checkpoint 的净变化路径, +// 再分别计算"目标状态"与"当前状态"并执行写回 / 删除 / 跳过。 +func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID string, relatedCheckpointIDs ...string) error { targetCP, err := s.readCheckpointMeta(targetID) if err != nil { return err @@ -378,11 +378,6 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st s.indexMu.Lock() defer s.indexMu.Unlock() - hashSet := make(map[string]struct{}, len(targetCP.FileVersions)) - for h := range targetCP.FileVersions { - hashSet[h] = struct{}{} - } - var guardCP CheckpointMeta hasGuard := guardID != "" if hasGuard { @@ -390,15 +385,29 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st if err != nil { return err } - for h := range guardCP.FileVersions { - hashSet[h] = struct{}{} + } + relatedCPs := make([]CheckpointMeta, 0, len(relatedCheckpointIDs)) + for _, checkpointID := range relatedCheckpointIDs { + if strings.TrimSpace(checkpointID) == "" || checkpointID == targetID || checkpointID == guardID { + continue + } + relatedCP, err := s.readCheckpointMeta(checkpointID) + if err != nil { + return err } + relatedCPs = append(relatedCPs, relatedCP) } - // 无论有无 guard,都必须合并全量 pathToVersions。 - // guard 是 pending-only 的,不包含此前创建的、本 turn 未触碰的新文件; - // 不合并则这些文件在 restore 后仍会残留。 - for h := range s.pathToVersions { - hashSet[h] = struct{}{} + + hashSet := make(map[string]struct{}) + if len(relatedCPs) > 0 { + if err := s.collectRestoreChangedHashesLocked(hashSet, targetCP, relatedCPs); err != nil { + return err + } + } else { + addCheckpointHashes(hashSet, targetCP) + } + if hasGuard { + addCheckpointHashes(hashSet, guardCP) } for hash := range hashSet { @@ -407,7 +416,7 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st } // 目标状态:target checkpoint 时刻文件应如何。 - toContent, toIsDir, toExists, toMode, toDisplay, err := s.contentAtCheckpointLocked(hash, targetCP.FileVersions, false) + toContent, toIsDir, toExists, toMode, toDisplay, err := s.contentAtCheckpointStateLocked(hash, targetCP, false) if err != nil { return err } @@ -418,7 +427,7 @@ func (s *PerEditSnapshotStore) Restore(ctx context.Context, targetID, guardID st var fromMode os.FileMode var fromDisplay string if hasGuard { - fromContent, fromIsDir, fromExists, fromMode, fromDisplay, err = s.contentAtCheckpointLocked(hash, guardCP.FileVersions, true) + fromContent, fromIsDir, fromExists, fromMode, fromDisplay, err = s.contentAtCheckpointStateLocked(hash, guardCP, true) if err != nil { return err } @@ -492,10 +501,26 @@ func (s *PerEditSnapshotStore) RestoreExact(ctx context.Context, checkpointID st s.indexMu.Lock() defer s.indexMu.Unlock() - for hash, vAt := range cp.FileVersions { + hashSet := make(map[string]struct{}, len(cp.FileVersions)+len(cp.ExactFileVersions)) + for hash := range cp.FileVersions { + hashSet[hash] = struct{}{} + } + for hash := range cp.ExactFileVersions { + hashSet[hash] = struct{}{} + } + hashes := make([]string, 0, len(hashSet)) + for hash := range hashSet { + hashes = append(hashes, hash) + } + sort.Strings(hashes) + for _, hash := range hashes { if err := ctx.Err(); err != nil { return err } + vAt, ok := cp.ExactFileVersions[hash] + if !ok { + vAt = cp.FileVersions[hash] + } meta, err := s.readVersionMeta(hash, vAt) if err != nil { return fmt.Errorf("per-edit: read meta v%d: %w", vAt, err) @@ -1311,6 +1336,82 @@ func readWorkdirMode(absPath string) os.FileMode { return info.Mode() } +// addCheckpointHashes 把 checkpoint 中显式记录的文件 hash 加入 restore 候选集合。 +func addCheckpointHashes(hashSet map[string]struct{}, cp CheckpointMeta) { + for hash := range cp.FileVersions { + hashSet[hash] = struct{}{} + } + for hash := range cp.ExactFileVersions { + hashSet[hash] = struct{}{} + } +} + +// collectRestoreChangedHashesLocked 只收集 target 到后续 checkpoint 之间净变化的路径。 +// target 可作为全工作区重锚点;真正 restore 时不应因为 target 是全量快照而误碰无关文件。 +func (s *PerEditSnapshotStore) collectRestoreChangedHashesLocked( + hashSet map[string]struct{}, + targetCP CheckpointMeta, + relatedCPs []CheckpointMeta, +) error { + if len(relatedCPs) == 0 { + return nil + } + latestCP := relatedCPs[0] + candidateHashes := make(map[string]struct{}) + addCheckpointHashes(candidateHashes, targetCP) + for _, cp := range relatedCPs { + addCheckpointHashes(candidateHashes, cp) + if cp.CreatedAt.After(latestCP.CreatedAt) { + latestCP = cp + } + } + for hash := range candidateHashes { + targetContent, targetIsDir, targetExists, targetMode, _, err := s.contentAtCheckpointStateLocked(hash, targetCP, false) + if err != nil { + return err + } + latestContent, latestIsDir, latestExists, latestMode, _, err := s.contentAtCheckpointStateLocked(hash, latestCP, false) + if err != nil { + return err + } + if targetExists != latestExists || + targetIsDir != latestIsDir || + targetMode != latestMode || + !bytes.Equal(targetContent, latestContent) { + hashSet[hash] = struct{}{} + } + } + return nil +} + +// contentAtCheckpointStateLocked 读取 checkpoint 结束态;新 checkpoint 优先使用 ExactFileVersions。 +// 旧 checkpoint 没有 exact 版本时,继续使用 v_next 兼容语义。 +func (s *PerEditSnapshotStore) contentAtCheckpointStateLocked( + hash string, + cp CheckpointMeta, + fallbackIfMissing bool, +) ([]byte, bool, bool, os.FileMode, string, error) { + if version, ok := cp.ExactFileVersions[hash]; ok { + meta, err := s.readVersionMeta(hash, version) + if err != nil { + return nil, false, false, 0, "", fmt.Errorf("per-edit: read exact meta v%d for %s: %w", version, hash, err) + } + display := s.resolveDisplayPathLocked(hash, meta.DisplayPath) + if !meta.Existed { + return nil, false, false, meta.Mode, display, nil + } + if meta.IsDir { + return nil, true, true, meta.Mode, display, nil + } + content, err := s.readVersionBin(hash, version) + if err != nil { + return nil, false, false, 0, display, fmt.Errorf("per-edit: read exact bin v%d for %s: %w", version, hash, err) + } + return content, false, true, meta.Mode, display, nil + } + return s.contentAtCheckpointLocked(hash, cp.FileVersions, fallbackIfMissing) +} + // contentAtCheckpointLocked 计算 hash 在某个 checkpoint 时刻的 workdir 内容。 // 在 cp.FileVersions 中:找下一版本读 .bin(或 Existed=false 时返回 nil); // 没有下一版本时:以当前 workdir 实际内容为准。 diff --git a/internal/checkpoint/per_edit_snapshot_test.go b/internal/checkpoint/per_edit_snapshot_test.go index 27b7d94b..1ff35005 100644 --- a/internal/checkpoint/per_edit_snapshot_test.go +++ b/internal/checkpoint/per_edit_snapshot_test.go @@ -1733,3 +1733,93 @@ func TestRunAggregateDiff_HistoricalFileFilteredByVersion(t *testing.T) { t.Fatalf("patch should NOT contain a.txt (filtered by version):\n%s", patch) } } + +func TestRestore_DoesNotUseGlobalHistoryWhenNoRelatedScope(t *testing.T) { + store, workdir := newTestStore(t) + + target := writeWorkdirFile(t, workdir, "target.txt", "old\n") + if _, err := store.CapturePreWrite(target); err != nil { + t.Fatalf("capture target: %v", err) + } + if err := os.WriteFile(target, []byte("new\n"), 0o644); err != nil { + t.Fatalf("write target: %v", err) + } + if _, err := store.FinalizeWithExactState("cp-target"); err != nil { + t.Fatalf("finalize target: %v", err) + } + store.Reset() + + unrelated := writeWorkdirFile(t, workdir, "unrelated.txt", "must stay\n") + if _, err := store.CapturePreWrite(unrelated); err != nil { + t.Fatalf("capture unrelated: %v", err) + } + if err := store.Restore(context.Background(), "cp-target", ""); err != nil { + t.Fatalf("restore cp-target: %v", err) + } + if got := mustReadFile(t, unrelated); got != "must stay\n" { + t.Fatalf("unrelated file should not be touched, got %q", got) + } +} + +func TestRestore_WithRelatedScopeSkipsUnchangedFilesFromFullRebase(t *testing.T) { + store, workdir := newTestStore(t) + + agentPath := writeWorkdirFile(t, workdir, "agent.txt", "baseline\n") + unrelatedPath := writeWorkdirFile(t, workdir, "unrelated.txt", "keep\n") + if _, err := store.CaptureBatch([]string{agentPath, unrelatedPath}); err != nil { + t.Fatalf("CaptureBatch baseline: %v", err) + } + if _, err := store.FinalizeWithExactState("cp-rebase"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-rebase): %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(agentPath); err != nil { + t.Fatalf("CapturePreWrite agent: %v", err) + } + if err := os.WriteFile(agentPath, []byte("agent edit\n"), 0o644); err != nil { + t.Fatalf("write agent edit: %v", err) + } + if _, err := store.FinalizeWithExactState("cp-agent"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-agent): %v", err) + } + store.Reset() + + if err := os.WriteFile(unrelatedPath, []byte("user post-run edit\n"), 0o644); err != nil { + t.Fatalf("write unrelated post-run edit: %v", err) + } + if err := store.Restore(context.Background(), "cp-rebase", "", "cp-agent"); err != nil { + t.Fatalf("Restore(cp-rebase): %v", err) + } + if got := mustReadFile(t, agentPath); got != "baseline\n" { + t.Fatalf("agent file = %q, want baseline", got) + } + if got := mustReadFile(t, unrelatedPath); got != "user post-run edit\n" { + t.Fatalf("unrelated file should not be restored, got %q", got) + } +} + +func TestRestoreExact_PrefersExactCheckpointState(t *testing.T) { + store, workdir := newTestStore(t) + target := writeWorkdirFile(t, workdir, "exact.txt", "before\n") + if _, err := store.CapturePreWrite(target); err != nil { + t.Fatalf("capture: %v", err) + } + if err := os.WriteFile(target, []byte("checkpoint state\n"), 0o644); err != nil { + t.Fatalf("write checkpoint state: %v", err) + } + if _, err := store.FinalizeWithExactState("cp-exact"); err != nil { + t.Fatalf("FinalizeWithExactState: %v", err) + } + store.Reset() + + if err := os.WriteFile(target, []byte("drift\n"), 0o644); err != nil { + t.Fatalf("write drift: %v", err) + } + if err := store.RestoreExact(context.Background(), "cp-exact"); err != nil { + t.Fatalf("RestoreExact: %v", err) + } + if got := mustReadFile(t, target); got != "checkpoint state\n" { + t.Fatalf("RestoreExact should use exact state, got %q", got) + } +} diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go index a4ec44f5..d8db9ca5 100644 --- a/internal/cli/gateway_runtime_bridge.go +++ b/internal/cli/gateway_runtime_bridge.go @@ -2456,6 +2456,8 @@ func (b *gatewayRuntimePortBridge) CheckpointDiff(ctx context.Context, input gat Deleted: result.Files.Deleted, Modified: result.Files.Modified, }, - Patch: result.Patch, + Patch: result.Patch, + WorkspaceDrifted: result.WorkspaceDrifted, + Warning: result.Warning, }, nil } diff --git a/internal/gateway/contracts.go b/internal/gateway/contracts.go index 69f6692e..533cac1e 100644 --- a/internal/gateway/contracts.go +++ b/internal/gateway/contracts.go @@ -476,6 +476,8 @@ type CheckpointDiffResult struct { PrevCommitHash string `json:"prev_commit_hash,omitempty"` Files FileDiffs `json:"files"` Patch string `json:"patch,omitempty"` + WorkspaceDrifted bool `json:"workspace_drifted,omitempty"` + Warning string `json:"warning,omitempty"` } // FileDiffs 描述 diff 中的文件变更列表。 diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go index 3075ac75..44cd0a6d 100644 --- a/internal/runtime/checkpoint_flow_test.go +++ b/internal/runtime/checkpoint_flow_test.go @@ -11,6 +11,7 @@ import ( "neo-code/internal/checkpoint" providertypes "neo-code/internal/provider/types" + "neo-code/internal/repository" agentsession "neo-code/internal/session" ) @@ -207,6 +208,370 @@ func TestCreateStartOfTurnCheckpoint_NoPending_SessionOnly(t *testing.T) { } } +func TestCreatePreRunDriftRebaseCheckpoint_UsesDriftReasonAndPerEditRef(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + + absKeep := filepath.Join(fixture.workdir, "keep.txt") + if err := os.WriteFile(absKeep, []byte("keep"), 0o644); err != nil { + t.Fatalf("WriteFile(keep) error = %v", err) + } + absDeleted := filepath.Join(fixture.workdir, "deleted.txt") + if err := os.WriteFile(absDeleted, []byte("deleted"), 0o644); err != nil { + t.Fatalf("WriteFile(deleted) error = %v", err) + } + if err := os.Remove(absDeleted); err != nil { + t.Fatalf("Remove(deleted) error = %v", err) + } + + state := newRunState("run-drift", fixture.session) + checkpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint(context.Background(), &state, repository.FingerprintDiff{ + Deleted: []string{"deleted.txt"}, + }) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + if checkpointID == "" { + t.Fatal("expected non-empty drift baseline checkpoint id") + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("expected 1 checkpoint, got %d", len(records)) + } + if records[0].Reason != agentsession.CheckpointReasonPreRunDriftRebase { + t.Fatalf("reason = %s, want %s", records[0].Reason, agentsession.CheckpointReasonPreRunDriftRebase) + } + if !checkpoint.IsPerEditRef(records[0].CodeCheckpointRef) { + t.Fatalf("code ref = %q, want peredit ref", records[0].CodeCheckpointRef) + } +} + +func TestCreatePreRunDriftRebaseCheckpoint_UsesEffectiveWorkdirWhenSessionWorkdirEmpty(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + + absPath := filepath.Join(fixture.workdir, "effective.txt") + if err := os.WriteFile(absPath, []byte("effective"), 0o644); err != nil { + t.Fatalf("WriteFile(effective) error = %v", err) + } + + session := fixture.session + session.Workdir = "" + state := newRunState("run-effective-workdir", session) + state.effectiveWorkdir = fixture.workdir + checkpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint(context.Background(), &state, repository.FingerprintDiff{ + Added: []string{"effective.txt"}, + }) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + if checkpointID == "" { + t.Fatal("expected non-empty drift baseline checkpoint id") + } + + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("expected 1 checkpoint, got %d", len(records)) + } + if records[0].Workdir != fixture.workdir { + t.Fatalf("record workdir = %q, want %q", records[0].Workdir, fixture.workdir) + } + if records[0].WorkspaceKey != agentsession.WorkspacePathKey(fixture.workdir) { + t.Fatalf("workspace key = %q, want %q", records[0].WorkspaceKey, agentsession.WorkspacePathKey(fixture.workdir)) + } +} + +func TestRecordRunEndWorkspaceStateUsesEffectiveWorkdirWhenSessionWorkdirEmpty(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + + absPath := filepath.Join(fixture.workdir, "run-end.txt") + if err := os.WriteFile(absPath, []byte("run end"), 0o644); err != nil { + t.Fatalf("WriteFile(run-end) error = %v", err) + } + + session := fixture.session + session.Workdir = "" + state := newRunState("run-end-effective-workdir", session) + state.effectiveWorkdir = fixture.workdir + fixture.service.recordRunEndWorkspaceState( + context.Background(), + session.ID, + effectiveWorkdirForCheckpointState(&state, session), + "cp-current-effective", + ) + + workspaceKey := agentsession.WorkspacePathKey(fixture.workdir) + loaded, ok, err := fixture.checkpointStore.LoadWorkspaceCheckpointState(context.Background(), workspaceKey) + if err != nil { + t.Fatalf("LoadWorkspaceCheckpointState() error = %v", err) + } + if !ok { + t.Fatal("expected workspace checkpoint state to be persisted") + } + if loaded.CurrentCheckpointID != "cp-current-effective" { + t.Fatalf("current checkpoint id = %q, want cp-current-effective", loaded.CurrentCheckpointID) + } + if loaded.WorkspaceKey != workspaceKey { + t.Fatalf("workspace key = %q, want %q", loaded.WorkspaceKey, workspaceKey) + } +} + +func TestCheckpointDiff_ScopeRun_RecreatedFileAfterDriftBaselineIsAdded(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-drift-add", fixture.session) + + absPath := filepath.Join(fixture.workdir, "recreate.txt") + if err := os.WriteFile(absPath, []byte("legacy"), 0o644); err != nil { + t.Fatalf("WriteFile(legacy) error = %v", err) + } + if err := os.Remove(absPath); err != nil { + t.Fatalf("Remove(recreate.txt) error = %v", err) + } + + rebasedCheckpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint( + context.Background(), + &state, + repository.FingerprintDiff{Deleted: []string{"recreate.txt"}}, + ) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + if rebasedCheckpointID == "" { + t.Fatal("expected drift rebase checkpoint id") + } + + if _, err := fixture.perEditStore.CapturePreWrite(absPath); err != nil { + t.Fatalf("CapturePreWrite(recreate) error = %v", err) + } + if err := os.WriteFile(absPath, []byte("new"), 0o644); err != nil { + t.Fatalf("WriteFile(new) error = %v", err) + } + if _, err := fixture.perEditStore.FinalizeWithExactState("cp-end"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-end) error = %v", err) + } + fixture.perEditStore.Reset() + if err := fixture.service.createCheckpointRecord( + context.Background(), + fixture.session, + state.runID, + &state, + "cp-end", + agentsession.CheckpointReasonEndOfTurn, + ); err != nil { + t.Fatalf("createCheckpointRecord(cp-end) error = %v", err) + } + + fixture.service.setRunWorkspaceDrift(fixture.session.ID, state.runID, true) + fixture.service.setRunRollbackBaseline(fixture.session.ID, state.runID, rebasedCheckpointID) + result, err := fixture.service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: fixture.session.ID, + Scope: "run", + RunID: state.runID, + CheckpointID: "cp-end", + }) + if err != nil { + t.Fatalf("CheckpointDiff(scope=run) error = %v", err) + } + if result.PrevCheckpointID != rebasedCheckpointID { + t.Fatalf("PrevCheckpointID = %q, want %q", result.PrevCheckpointID, rebasedCheckpointID) + } + if len(result.Files.Added) != 1 || result.Files.Added[0] != "recreate.txt" { + t.Fatalf("added files = %+v, want recreate.txt", result.Files.Added) + } + if len(result.Files.Modified) != 0 { + t.Fatalf("modified files = %+v, want empty", result.Files.Modified) + } +} + +func TestCheckpointDiff_ScopeRun_LoadsPersistedBaselineAfterCacheClear(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-persisted-baseline", fixture.session) + + absPath := filepath.Join(fixture.workdir, "persisted-baseline.txt") + if err := os.WriteFile(absPath, []byte("manual\n"), 0o644); err != nil { + t.Fatalf("WriteFile(manual) error = %v", err) + } + rebasedCheckpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint( + context.Background(), + &state, + repository.FingerprintDiff{Added: []string{"persisted-baseline.txt"}}, + ) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + fixture.service.persistRunRollbackBaseline(context.Background(), fixture.session.ID, state.runID, rebasedCheckpointID, true) + fixture.service.clearRunCheckpointCaches(fixture.session.ID, state.runID) + + if _, err := fixture.perEditStore.CapturePreWrite(absPath); err != nil { + t.Fatalf("CapturePreWrite(agent) error = %v", err) + } + if err := os.WriteFile(absPath, []byte("agent\n"), 0o644); err != nil { + t.Fatalf("WriteFile(agent) error = %v", err) + } + if _, err := fixture.perEditStore.FinalizeWithExactState("cp-end-persisted"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-end-persisted) error = %v", err) + } + fixture.perEditStore.Reset() + if err := fixture.service.createCheckpointRecord( + context.Background(), + fixture.session, + state.runID, + &state, + "cp-end-persisted", + agentsession.CheckpointReasonEndOfTurn, + ); err != nil { + t.Fatalf("createCheckpointRecord(cp-end-persisted) error = %v", err) + } + + result, err := fixture.service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: fixture.session.ID, + Scope: "run", + RunID: state.runID, + CheckpointID: "cp-end-persisted", + }) + if err != nil { + t.Fatalf("CheckpointDiff(scope=run) error = %v", err) + } + if result.PrevCheckpointID != rebasedCheckpointID { + t.Fatalf("PrevCheckpointID = %q, want %q", result.PrevCheckpointID, rebasedCheckpointID) + } + if !result.WorkspaceDrifted { + t.Fatal("expected persisted drift flag") + } +} + +func TestRestoreCheckpoint_AfterExternalModifyRestoresToRebasedState(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-restore-rebased", fixture.session) + + absPath := filepath.Join(fixture.workdir, "external.txt") + if err := os.WriteFile(absPath, []byte("manual\n"), 0o644); err != nil { + t.Fatalf("WriteFile(manual) error = %v", err) + } + unrelatedPath := filepath.Join(fixture.workdir, "unrelated.txt") + if err := os.WriteFile(unrelatedPath, []byte("keep\n"), 0o644); err != nil { + t.Fatalf("WriteFile(unrelated) error = %v", err) + } + rebasedCheckpointID, err := fixture.service.createPreRunDriftRebaseCheckpoint( + context.Background(), + &state, + repository.FingerprintDiff{Modified: []string{"external.txt"}}, + ) + if err != nil { + t.Fatalf("createPreRunDriftRebaseCheckpoint() error = %v", err) + } + + if _, err := fixture.perEditStore.CapturePreWrite(absPath); err != nil { + t.Fatalf("CapturePreWrite(agent) error = %v", err) + } + if err := os.WriteFile(absPath, []byte("agent\n"), 0o644); err != nil { + t.Fatalf("WriteFile(agent) error = %v", err) + } + if _, err := fixture.perEditStore.FinalizeWithExactState("cp-agent-end"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-agent-end) error = %v", err) + } + fixture.perEditStore.Reset() + if err := fixture.service.createCheckpointRecord( + context.Background(), + fixture.session, + state.runID, + &state, + "cp-agent-end", + agentsession.CheckpointReasonEndOfTurn, + ); err != nil { + t.Fatalf("createCheckpointRecord(cp-agent-end) error = %v", err) + } + if err := os.WriteFile(unrelatedPath, []byte("post-run user edit\n"), 0o644); err != nil { + t.Fatalf("WriteFile(unrelated post-run) error = %v", err) + } + + if _, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: fixture.session.ID, + CheckpointID: rebasedCheckpointID, + }); err != nil { + t.Fatalf("RestoreCheckpoint() error = %v", err) + } + data, err := os.ReadFile(absPath) + if err != nil { + t.Fatalf("ReadFile(restored) error = %v", err) + } + if got := string(data); got != "manual\n" { + t.Fatalf("restored content = %q, want manual", got) + } + unrelatedData, err := os.ReadFile(unrelatedPath) + if err != nil { + t.Fatalf("ReadFile(unrelated) error = %v", err) + } + if got := string(unrelatedData); got != "post-run user edit\n" { + t.Fatalf("unrelated content = %q, want post-run user edit", got) + } +} + +func TestRecordRunStartFingerprint_DetectsDriftAcrossSessionsByWorkspace(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + service := fixture.service + ctx := context.Background() + + absPath := filepath.Join(fixture.workdir, "cross-session.txt") + if err := os.WriteFile(absPath, []byte("before"), 0o644); err != nil { + t.Fatalf("WriteFile(before) error = %v", err) + } + service.recordRunEndFingerprint(ctx, "sess-a", fixture.workdir) + + if err := os.Remove(absPath); err != nil { + t.Fatalf("Remove(cross-session.txt) error = %v", err) + } + + drifted, driftDiff := service.recordRunStartFingerprint(ctx, "sess-b", "run-b", fixture.workdir) + if !drifted { + t.Fatal("expected drift across sessions on the same workspace") + } + if !containsString(driftDiff.Deleted, "cross-session.txt") { + t.Fatalf("deleted diff = %#v, want cross-session.txt", driftDiff.Deleted) + } +} + +func TestRecordRunStartFingerprint_LoadsWorkspaceFingerprintFromPersistence(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + ctx := context.Background() + + absPath := filepath.Join(fixture.workdir, "persisted.txt") + if err := os.WriteFile(absPath, []byte("before"), 0o644); err != nil { + t.Fatalf("WriteFile(before) error = %v", err) + } + fixture.service.recordRunEndFingerprint(ctx, "sess-a", fixture.workdir) + + if err := os.Remove(absPath); err != nil { + t.Fatalf("Remove(persisted.txt) error = %v", err) + } + + // 模拟进程重启:新 Service 仅复用 checkpointStore,不复用内存指纹缓存。 + restarted := &Service{ + checkpointStore: fixture.checkpointStore, + } + drifted, driftDiff := restarted.recordRunStartFingerprint(ctx, "sess-b", "run-b", fixture.workdir) + if !drifted { + t.Fatal("expected drift when loading workspace fingerprint from persistence") + } + if !containsString(driftDiff.Deleted, "persisted.txt") { + t.Fatalf("deleted diff = %#v, want persisted.txt", driftDiff.Deleted) + } +} + +func containsString(items []string, target string) bool { + for _, item := range items { + if item == target { + return true + } + } + return false +} + func TestCreateEndOfTurnCheckpoint_NoWriteSkipped(t *testing.T) { fixture := newRuntimeCheckpointFixture(t) fixture.captureFile(t, "main.go", []byte("package main\n")) @@ -924,6 +1289,9 @@ func TestCheckpointDiffRunScopeAggregatesCurrentRun(t *testing.T) { if result.CheckpointID != "cp-2" { t.Fatalf("CheckpointID = %q, want cp-2", result.CheckpointID) } + if result.PrevCheckpointID != "" { + t.Fatalf("PrevCheckpointID = %q, want empty", result.PrevCheckpointID) + } if len(result.Files.Modified) != 1 || result.Files.Modified[0] != "tracked.txt" { t.Fatalf("modified files = %+v, want tracked.txt", result.Files.Modified) } @@ -941,7 +1309,7 @@ func mustReadRuntimeFile(t *testing.T, path string) []byte { return data } -// ──────── scope=run diff tests ──────── +// scope=run diff tests func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) { workdir := t.TempDir() @@ -988,6 +1356,15 @@ func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) { CreatedAt: now, CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-1"), }, + { + CheckpointID: "cp-0", + SessionID: "session-1", + RunID: "run-prev", + CreatedAt: now.Add(-time.Second), + Reason: agentsession.CheckpointReasonEndOfTurn, + Status: agentsession.CheckpointStatusAvailable, + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-0"), + }, }, } service := &Service{ @@ -1026,10 +1403,12 @@ func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) { if len(modifiedPaths) != 1 || modifiedPaths[0] != "a.txt" { t.Fatalf("expected a.txt modified, got added=%v modified=%v", addedPaths, modifiedPaths) } - // 当前 run-scope diff 默认返回目标 checkpoint(未显式指定时为最新 checkpoint)。 if result.CheckpointID != "cp-2" { t.Fatalf("CheckpointID = %q, want cp-2", result.CheckpointID) } + if result.PrevCheckpointID != "cp-0" { + t.Fatalf("PrevCheckpointID = %q, want cp-0", result.PrevCheckpointID) + } } func TestCheckpointDiff_ScopeRun_RejectsMissingRunID(t *testing.T) { @@ -1077,6 +1456,161 @@ func TestCheckpointDiff_ScopeRun_NoCheckpointsForRunID(t *testing.T) { } } +func TestCheckpointDiff_ScopeRun_RejectsTargetCheckpointFromAnotherRun(t *testing.T) { + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-other-run", + SessionID: "session-1", + RunID: "run-other", + CreatedAt: time.Now().UTC(), + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-other-run"), + }, + }, + } + service := &Service{ + checkpointStore: spy, + perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), t.TempDir()), + } + _, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: "session-1", + Scope: "run", + RunID: "run-target", + CheckpointID: "cp-other-run", + }) + if err == nil { + t.Fatal("expected error for target checkpoint from another run") + } + if !strings.Contains(err.Error(), "does not belong to run") { + t.Fatalf("error = %v, want run mismatch", err) + } +} + +func TestCheckpointDiff_ScopeRun_WarnsWhenBaselineMissing(t *testing.T) { + workdir := t.TempDir() + projectDir := t.TempDir() + store := checkpoint.NewPerEditSnapshotStore(projectDir, workdir) + now := time.Now().UTC() + + absA := filepath.Join(workdir, "a.txt") + _ = os.WriteFile(absA, []byte("old a\n"), 0o644) + if _, err := store.CapturePreWrite(absA); err != nil { + t.Fatalf("CapturePreWrite a: %v", err) + } + _ = os.WriteFile(absA, []byte("new a\n"), 0o644) + if _, err := store.Finalize("cp-1"); err != nil { + t.Fatalf("Finalize cp-1: %v", err) + } + store.Reset() + + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-1", + SessionID: "session-1", + RunID: "run-target", + CreatedAt: now, + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-1"), + }, + }, + } + service := &Service{ + checkpointStore: spy, + perEditStore: store, + } + + result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: "session-1", + Scope: "run", + RunID: "run-target", + }) + if err != nil { + t.Fatalf("CheckpointDiff(scope=run) error = %v", err) + } + if result.PrevCheckpointID != "" { + t.Fatalf("PrevCheckpointID = %q, want empty", result.PrevCheckpointID) + } + if !strings.Contains(result.Warning, "run baseline checkpoint is missing") { + t.Fatalf("Warning = %q, want missing-baseline warning", result.Warning) + } +} + +func TestCheckpointDiff_ScopeRun_PrefersRebasedBaselineWhenDrifted(t *testing.T) { + workdir := t.TempDir() + projectDir := t.TempDir() + store := checkpoint.NewPerEditSnapshotStore(projectDir, workdir) + now := time.Now().UTC() + + absA := filepath.Join(workdir, "a.txt") + _ = os.WriteFile(absA, []byte("one\n"), 0o644) + if _, err := store.CapturePreWrite(absA); err != nil { + t.Fatalf("CapturePreWrite cp-drift: %v", err) + } + _ = os.WriteFile(absA, []byte("two\n"), 0o644) + if _, err := store.Finalize("cp-drift"); err != nil { + t.Fatalf("Finalize cp-drift: %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(absA); err != nil { + t.Fatalf("CapturePreWrite cp-target: %v", err) + } + _ = os.WriteFile(absA, []byte("three\n"), 0o644) + if _, err := store.Finalize("cp-target"); err != nil { + t.Fatalf("Finalize cp-target: %v", err) + } + store.Reset() + + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-target", + SessionID: "session-1", + RunID: "run-target", + CreatedAt: now.Add(time.Second), + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-target"), + }, + { + CheckpointID: "cp-drift", + SessionID: "session-1", + RunID: "run-target", + CreatedAt: now, + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-drift"), + }, + { + CheckpointID: "cp-old", + SessionID: "session-1", + RunID: "run-prev", + CreatedAt: now.Add(-time.Second), + Reason: agentsession.CheckpointReasonEndOfTurn, + Status: agentsession.CheckpointStatusAvailable, + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-old"), + }, + }, + } + service := &Service{ + checkpointStore: spy, + perEditStore: store, + } + service.setRunWorkspaceDrift("session-1", "run-target", true) + service.setRunRollbackBaseline("session-1", "run-target", "cp-drift") + + result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: "session-1", + Scope: "run", + RunID: "run-target", + }) + if err != nil { + t.Fatalf("CheckpointDiff(scope=run) error = %v", err) + } + if result.PrevCheckpointID != "cp-drift" { + t.Fatalf("PrevCheckpointID = %q, want cp-drift", result.PrevCheckpointID) + } + if !strings.Contains(result.Warning, "baseline re-anchored") { + t.Fatalf("Warning = %q, want re-anchored hint", result.Warning) + } +} + func TestCheckpointDiff_DefaultScopePreservesExistingBehavior(t *testing.T) { // Verify empty scope still uses checkpoint-to-checkpoint comparison. workdir := t.TempDir() diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go index bae5f42f..db4dfd5c 100644 --- a/internal/runtime/checkpoint_gate.go +++ b/internal/runtime/checkpoint_gate.go @@ -5,10 +5,13 @@ import ( "encoding/json" "fmt" "log" + "path/filepath" + "sort" "strings" "time" "neo-code/internal/checkpoint" + "neo-code/internal/repository" agentsession "neo-code/internal/session" ) @@ -38,6 +41,73 @@ func (s *Service) createStartOfTurnCheckpoint(ctx context.Context, state *runSta return s.createCheckpointRecord(ctx, session, runID, state, checkpointID, agentsession.CheckpointReasonPreWrite) } +// createPreRunDriftRebaseCheckpoint 在 run 开始前检测到外部漂移时,固化当前工作区为权威基线。 +func (s *Service) createPreRunDriftRebaseCheckpoint( + ctx context.Context, + state *runState, + diff repository.FingerprintDiff, +) (string, error) { + if s.checkpointStore == nil || s.perEditStore == nil { + return "", nil + } + + state.mu.Lock() + session := state.session + runID := state.runID + workdir := effectiveWorkdirForCheckpointState(state, session) + state.mu.Unlock() + + if workdir == "" { + return "", fmt.Errorf("checkpoint: workdir is empty when creating drift rebase checkpoint") + } + + currentFingerprint, _, err := repository.ScanWorkdir(ctx, workdir, repository.DefaultFingerprintOptions()) + if err != nil { + return "", fmt.Errorf("checkpoint: scan workdir for drift rebase: %w", err) + } + absPaths := make([]string, 0, len(currentFingerprint)) + for relPath := range currentFingerprint { + absPaths = append(absPaths, filepath.Join(workdir, filepath.FromSlash(relPath))) + } + sort.Strings(absPaths) + if len(absPaths) > 0 { + if _, err := s.perEditStore.CaptureBatch(absPaths); err != nil { + return "", fmt.Errorf("checkpoint: capture current files for drift rebase: %w", err) + } + } + + deleted := append([]string(nil), diff.Deleted...) + sort.Strings(deleted) + for _, relPath := range deleted { + absPath := filepath.Join(workdir, filepath.FromSlash(relPath)) + if _, err := s.perEditStore.CapturePreWrite(absPath); err != nil { + return "", fmt.Errorf("checkpoint: capture deleted file for drift rebase (%s): %w", relPath, err) + } + } + + checkpointID := agentsession.NewID("checkpoint") + written, err := s.perEditStore.FinalizeWithExactState(checkpointID) + if err != nil { + return "", fmt.Errorf("checkpoint: finalize drift rebase checkpoint: %w", err) + } + if !written { + return "", fmt.Errorf("checkpoint: drift rebase checkpoint has no captured files") + } + defer s.perEditStore.Reset() + + if err := s.createCheckpointRecord( + ctx, + session, + runID, + state, + checkpointID, + agentsession.CheckpointReasonPreRunDriftRebase, + ); err != nil { + return "", err + } + return checkpointID, nil +} + // createEndOfTurnCheckpoint 在工具执行完成后创建代码检查点。 // hasWorkspaceWrite=false 时不创建(避免空 checkpoint);为 true 时 Finalize 当前 pending。 // 失败仅 log,不阻塞主流程。 @@ -95,7 +165,7 @@ func (s *Service) createCheckpointRecord( return fmt.Errorf("checkpoint: marshal messages: %w", err) } - effectiveWorkdir := strings.TrimSpace(session.Workdir) + effectiveWorkdir := effectiveWorkdirForCheckpointState(state, session) now := time.Now() ref := checkpoint.RefForPerEditCheckpoint(checkpointID) @@ -159,12 +229,13 @@ func (s *Service) createSessionOnlyCheckpoint( return fmt.Errorf("checkpoint: marshal session-only messages: %w", err) } + effectiveWorkdir := effectiveWorkdirForCheckpointState(state, session) record := agentsession.CheckpointRecord{ CheckpointID: checkpointID, - WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), + WorkspaceKey: agentsession.WorkspacePathKey(effectiveWorkdir), SessionID: session.ID, RunID: runID, - Workdir: session.Workdir, + Workdir: effectiveWorkdir, CreatedAt: now, Reason: reason, Restorable: true, @@ -224,3 +295,13 @@ func (s *Service) findPreviousEndOfTurnCheckpoint(ctx context.Context, sessionID } return "" } + +// effectiveWorkdirForCheckpointState 返回 checkpoint 流程应使用的工作目录,优先使用本次 run 已归一化的目录。 +func effectiveWorkdirForCheckpointState(state *runState, session agentsession.Session) string { + if state != nil { + if workdir := strings.TrimSpace(state.effectiveWorkdir); workdir != "" { + return workdir + } + } + return strings.TrimSpace(session.Workdir) +} diff --git a/internal/runtime/checkpoint_restore.go b/internal/runtime/checkpoint_restore.go index 4ba43124..ccc8366a 100644 --- a/internal/runtime/checkpoint_restore.go +++ b/internal/runtime/checkpoint_restore.go @@ -86,7 +86,9 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi return RestoreResult{}, agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: create guard: %w", guardErr) } - // 3. Restore code via per-edit store(不在 cp.FileVersions 中的文件保持不变)。 + relatedPerEditIDs := s.restoreRelatedPerEditIDs(ctx, sessionID, record.CreatedAt) + + // 3. Restore code via per-edit store(仅处理目标、guard 与后续相关 checkpoint 涉及的文件)。 // Guard checkpoint 恢复时使用 RestoreExact:guard 中存储的 version 就是 restore 前的 pre-write 状态, // 而 Restore 的 v_next 语义在 guard 上通常是 no-op(guard 之后没有新的 capture)。 isGuardRestore := record.Reason == agentsession.CheckpointReasonGuard @@ -102,7 +104,7 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi if guardWritten { guardCheckpointID = guardID } - if err := s.perEditStore.Restore(ctx, perEditID, guardCheckpointID); err != nil { + if err := s.perEditStore.Restore(ctx, perEditID, guardCheckpointID, relatedPerEditIDs...); err != nil { return RestoreResult{}, agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: restore code: %w", err) } } @@ -149,6 +151,7 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi // 7. Update runtime session if it's the current session s.updateRuntimeSessionAfterRestore(sessionID, head, messages) + s.recordRunEndWorkspaceState(ctx, sessionID, head.Workdir, checkpointID) return RestoreResult{ CheckpointID: checkpointID, @@ -156,6 +159,34 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi }, guardRecord, nil } +// restoreRelatedPerEditIDs 收集目标 checkpoint 之后仍 available 的代码 checkpoint,用于限定 restore 清理范围。 +func (s *Service) restoreRelatedPerEditIDs(ctx context.Context, sessionID string, targetCreatedAt time.Time) []string { + if s.checkpointStore == nil { + return nil + } + records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{}) + if err != nil { + return nil + } + out := make([]string, 0) + for _, record := range records { + if !record.CreatedAt.After(targetCreatedAt) { + continue + } + if record.Status != "" && record.Status != agentsession.CheckpointStatusAvailable { + continue + } + if record.Reason == agentsession.CheckpointReasonGuard { + continue + } + perEditID := checkpoint.PerEditCheckpointIDFromRef(record.CodeCheckpointRef) + if perEditID != "" { + out = append(out, perEditID) + } + } + return out +} + // RestoreCheckpoint 恢复指定 checkpoint 的会话和工作区状态,并发出 checkpoint_restored 事件。 func (s *Service) RestoreCheckpoint(ctx context.Context, input GatewayRestoreInput) (RestoreResult, error) { result, guardRecord, err := s.restoreCheckpointCore(ctx, input.SessionID, input.CheckpointID) @@ -315,6 +346,8 @@ type CheckpointDiffResult struct { PrevCommitHash string `json:"prev_commit_hash,omitempty"` Files FileDiffs `json:"files"` Patch string `json:"patch,omitempty"` + WorkspaceDrifted bool `json:"workspace_drifted,omitempty"` + Warning string `json:"warning,omitempty"` } // FileDiffs 描述 diff 中的文件变更列表。 @@ -452,6 +485,9 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff if runID == "" { return CheckpointDiffResult{}, fmt.Errorf("checkpoint: run_id required for run scope diff") } + if targetRecord != nil && strings.TrimSpace(targetRecord.RunID) != runID { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: target checkpoint %s does not belong to run %s", targetRecord.CheckpointID, runID) + } codeRecords := make([]agentsession.CheckpointRecord, 0) for _, record := range records { @@ -479,6 +515,11 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff targetRecord = &codeRecords[len(codeRecords)-1] } + prevCheckpointID, workspaceDrifted := s.getPersistentRunRollbackBaseline(ctx, sessionID, runID) + if prevCheckpointID == "" { + prevCheckpointID = resolveRunPrevCheckpointID(records, codeRecords[0].CreatedAt, runID) + } + perEditIDs := make([]string, 0, len(codeRecords)) for _, record := range codeRecords { perEditID := checkpoint.PerEditCheckpointIDFromRef(record.CodeCheckpointRef) @@ -492,7 +533,30 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff return CheckpointDiffResult{}, fmt.Errorf("checkpoint: per-edit run diff: %w", err) } - result := CheckpointDiffResult{CheckpointID: targetRecord.CheckpointID, Patch: patch} + result := CheckpointDiffResult{ + CheckpointID: targetRecord.CheckpointID, + PrevCheckpointID: prevCheckpointID, + Patch: patch, + WorkspaceDrifted: workspaceDrifted, + } + if result.WorkspaceDrifted { + if prevCheckpointID != "" { + result.Warning = "workspace drift detected before run start; baseline re-anchored" + } else { + result.Warning = "workspace drift detected before run start" + } + } + if result.PrevCheckpointID == "" { + if result.Warning != "" { + result.Warning += "; run baseline checkpoint is missing" + } else { + result.Warning = "run baseline checkpoint is missing" + } + } + if result.PrevCheckpointID != "" { + s.persistRunRollbackBaseline(ctx, sessionID, runID, result.PrevCheckpointID, result.WorkspaceDrifted) + } + for _, c := range changes { switch c.Kind { case checkpoint.FileChangeAdded: @@ -505,3 +569,45 @@ func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiff } return result, nil } + +// resolveRunPrevCheckpointID 解析 run 级 diff 的权威 baseline checkpoint_id。 +// 优先选择 run 开始前最近的 end_of_turn 代码 checkpoint;若不存在,退化为最近的可用代码 checkpoint。 +func resolveRunPrevCheckpointID(records []agentsession.CheckpointRecord, runStartAt time.Time, runID string) string { + var preferred *agentsession.CheckpointRecord + var fallback *agentsession.CheckpointRecord + for i := range records { + record := records[i] + if strings.TrimSpace(record.RunID) == strings.TrimSpace(runID) { + continue + } + if record.Status != "" && record.Status != agentsession.CheckpointStatusAvailable { + continue + } + if !checkpoint.IsPerEditRef(record.CodeCheckpointRef) { + continue + } + if record.Reason == agentsession.CheckpointReasonGuard { + continue + } + if !record.CreatedAt.Before(runStartAt) { + continue + } + if fallback == nil || record.CreatedAt.After(fallback.CreatedAt) { + candidate := record + fallback = &candidate + } + if record.Reason == agentsession.CheckpointReasonEndOfTurn { + if preferred == nil || record.CreatedAt.After(preferred.CreatedAt) { + candidate := record + preferred = &candidate + } + } + } + if preferred != nil { + return preferred.CheckpointID + } + if fallback != nil { + return fallback.CheckpointID + } + return "" +} diff --git a/internal/runtime/checkpoint_run_baseline.go b/internal/runtime/checkpoint_run_baseline.go new file mode 100644 index 00000000..a299d440 --- /dev/null +++ b/internal/runtime/checkpoint_run_baseline.go @@ -0,0 +1,277 @@ +package runtime + +import ( + "context" + "encoding/json" + "strings" + "time" + + "neo-code/internal/checkpoint" + "neo-code/internal/repository" + agentsession "neo-code/internal/session" +) + +// buildRunBaselineKey 生成 run 级缓存 key,保证同一会话不同 run 的基线隔离。 +func buildRunBaselineKey(sessionID, runID string) string { + return strings.TrimSpace(sessionID) + "\n" + strings.TrimSpace(runID) +} + +// ensureRunBaselineMapsLocked 懒初始化 run 级基线与漂移缓存,允许测试中用字面量 Service 直接调用。 +func (s *Service) ensureRunBaselineMapsLocked() { + if s.runRollbackBaselineByRunKey == nil { + s.runRollbackBaselineByRunKey = make(map[string]string) + } + if s.runWorkspaceDriftByRunKey == nil { + s.runWorkspaceDriftByRunKey = make(map[string]bool) + } + if s.lastRunFingerprintByWorkspaceKey == nil { + s.lastRunFingerprintByWorkspaceKey = make(map[string]repository.WorkdirFingerprint) + } +} + +type workspaceFingerprintPersistenceStore interface { + SaveWorkspaceFingerprint(ctx context.Context, workspaceKey string, fingerprintPayload string, updatedAt time.Time) error + LoadWorkspaceFingerprint(ctx context.Context, workspaceKey string) (string, bool, error) +} + +type workspaceCheckpointStatePersistenceStore interface { + SaveWorkspaceCheckpointState(ctx context.Context, state checkpoint.WorkspaceCheckpointState) error + LoadWorkspaceCheckpointState(ctx context.Context, workspaceKey string) (checkpoint.WorkspaceCheckpointState, bool, error) + SaveRunCheckpointBaseline(ctx context.Context, baseline checkpoint.RunCheckpointBaseline) error + LoadRunCheckpointBaseline(ctx context.Context, sessionID, runID string) (checkpoint.RunCheckpointBaseline, bool, error) +} + +// workspaceKeyFromWorkdir 统一把工作目录映射为跨会话稳定 key。 +func workspaceKeyFromWorkdir(workdir string) string { + return strings.TrimSpace(agentsession.WorkspacePathKey(strings.TrimSpace(workdir))) +} + +// setRunRollbackBaseline 记录当前 run 的权威回退基线 checkpoint_id(由后端计算,前端仅消费)。 +func (s *Service) setRunRollbackBaseline(sessionID, runID, checkpointID string) { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + if strings.TrimSpace(checkpointID) == "" { + delete(s.runRollbackBaselineByRunKey, key) + return + } + s.runRollbackBaselineByRunKey[key] = strings.TrimSpace(checkpointID) +} + +// getRunRollbackBaseline 返回 run 的权威回退基线 checkpoint_id。 +func (s *Service) getRunRollbackBaseline(sessionID, runID string) string { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return "" + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + return s.runRollbackBaselineByRunKey[key] +} + +// getPersistentRunRollbackBaseline 从内存或 SQLite 读取 run 的权威回退基线。 +func (s *Service) getPersistentRunRollbackBaseline(ctx context.Context, sessionID, runID string) (string, bool) { + if baseline := s.getRunRollbackBaseline(sessionID, runID); baseline != "" { + return baseline, s.getRunWorkspaceDrift(sessionID, runID) + } + store, ok := s.checkpointStore.(workspaceCheckpointStatePersistenceStore) + if !ok { + return "", false + } + loaded, found, err := store.LoadRunCheckpointBaseline(ctx, strings.TrimSpace(sessionID), strings.TrimSpace(runID)) + if err != nil || !found { + return "", false + } + s.setRunRollbackBaseline(sessionID, runID, loaded.CheckpointID) + s.setRunWorkspaceDrift(sessionID, runID, loaded.Drifted) + return strings.TrimSpace(loaded.CheckpointID), loaded.Drifted +} + +// persistRunRollbackBaseline 双写 run 级回退基线,保证异步 diff 在 run 结束后仍可读取。 +func (s *Service) persistRunRollbackBaseline(ctx context.Context, sessionID, runID, checkpointID string, drifted bool) { + sessionID = strings.TrimSpace(sessionID) + runID = strings.TrimSpace(runID) + checkpointID = strings.TrimSpace(checkpointID) + if sessionID == "" || runID == "" || checkpointID == "" { + return + } + s.setRunRollbackBaseline(sessionID, runID, checkpointID) + s.setRunWorkspaceDrift(sessionID, runID, drifted) + store, ok := s.checkpointStore.(workspaceCheckpointStatePersistenceStore) + if !ok { + return + } + _ = store.SaveRunCheckpointBaseline(ctx, checkpoint.RunCheckpointBaseline{ + SessionID: sessionID, + RunID: runID, + CheckpointID: checkpointID, + Drifted: drifted, + UpdatedAt: time.Now(), + }) +} + +// setRunWorkspaceDrift 标记当前 run 是否检测到空闲期工作区漂移。 +func (s *Service) setRunWorkspaceDrift(sessionID, runID string, drifted bool) { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + if drifted { + s.runWorkspaceDriftByRunKey[key] = true + return + } + delete(s.runWorkspaceDriftByRunKey, key) +} + +// getRunWorkspaceDrift 返回 run 是否检测到空闲期工作区漂移。 +func (s *Service) getRunWorkspaceDrift(sessionID, runID string) bool { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return false + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + return s.runWorkspaceDriftByRunKey[key] +} + +// clearRunCheckpointCaches 清理 run 级 checkpoint 缓存,避免内存与跨 run 污染。 +func (s *Service) clearRunCheckpointCaches(sessionID, runID string) { + key := buildRunBaselineKey(sessionID, runID) + if key == "\n" { + return + } + s.rollbackBaselineMu.Lock() + defer s.rollbackBaselineMu.Unlock() + s.ensureRunBaselineMapsLocked() + delete(s.runRollbackBaselineByRunKey, key) + delete(s.runWorkspaceDriftByRunKey, key) +} + +// recordRunStartFingerprint 在 run 开始时比对上次 run 结束指纹,返回是否发生空闲期漂移及详细差异。 +func (s *Service) recordRunStartFingerprint( + ctx context.Context, + sessionID, runID, workdir string, +) (bool, repository.FingerprintDiff) { + emptyDiff := repository.FingerprintDiff{} + normalizedRunID := strings.TrimSpace(runID) + normalizedWorkdir := strings.TrimSpace(workdir) + workspaceKey := workspaceKeyFromWorkdir(normalizedWorkdir) + if strings.TrimSpace(sessionID) == "" || normalizedRunID == "" || normalizedWorkdir == "" || workspaceKey == "" { + return false, emptyDiff + } + current, _, err := repository.ScanWorkdir(ctx, normalizedWorkdir, repository.DefaultFingerprintOptions()) + if err != nil { + return false, emptyDiff + } + + s.rollbackBaselineMu.Lock() + s.ensureRunBaselineMapsLocked() + previous, ok := s.lastRunFingerprintByWorkspaceKey[workspaceKey] + s.rollbackBaselineMu.Unlock() + var currentCheckpointID string + if !ok { + if store, okStore := s.checkpointStore.(workspaceCheckpointStatePersistenceStore); okStore { + state, found, loadErr := store.LoadWorkspaceCheckpointState(ctx, workspaceKey) + if loadErr == nil && found { + currentCheckpointID = strings.TrimSpace(state.CurrentCheckpointID) + var restored repository.WorkdirFingerprint + if err := json.Unmarshal([]byte(state.FingerprintPayload), &restored); err == nil { + s.rollbackBaselineMu.Lock() + s.ensureRunBaselineMapsLocked() + s.lastRunFingerprintByWorkspaceKey[workspaceKey] = restored + s.rollbackBaselineMu.Unlock() + previous = restored + ok = true + } + } + } + if !ok { + if store, okStore := s.checkpointStore.(workspaceFingerprintPersistenceStore); okStore { + payload, found, loadErr := store.LoadWorkspaceFingerprint(ctx, workspaceKey) + if loadErr == nil && found { + var restored repository.WorkdirFingerprint + if err := json.Unmarshal([]byte(payload), &restored); err == nil { + s.rollbackBaselineMu.Lock() + s.ensureRunBaselineMapsLocked() + s.lastRunFingerprintByWorkspaceKey[workspaceKey] = restored + s.rollbackBaselineMu.Unlock() + previous = restored + ok = true + } + } + } + } + if !ok { + return false, emptyDiff + } + } + diff := repository.DiffFingerprints(previous, current) + drifted := len(diff.Added) > 0 || len(diff.Modified) > 0 || len(diff.Deleted) > 0 + if drifted { + s.setRunWorkspaceDrift(sessionID, normalizedRunID, true) + return true, diff + } + if currentCheckpointID != "" { + s.persistRunRollbackBaseline(ctx, sessionID, normalizedRunID, currentCheckpointID, false) + } + return false, diff +} + +// recordRunEndFingerprint 在 run 结束后保存最新指纹,供下次 run 开始时进行漂移检测。 +func (s *Service) recordRunEndFingerprint(ctx context.Context, sessionID, workdir string) { + s.recordRunEndWorkspaceState(ctx, sessionID, workdir, "") +} + +// recordRunEndWorkspaceState 保存最新指纹和当前 checkpoint;checkpointID 为空时保留已有基线。 +func (s *Service) recordRunEndWorkspaceState(ctx context.Context, sessionID, workdir, checkpointID string) { + normalizedWorkdir := strings.TrimSpace(workdir) + workspaceKey := workspaceKeyFromWorkdir(normalizedWorkdir) + if strings.TrimSpace(sessionID) == "" || normalizedWorkdir == "" || workspaceKey == "" { + return + } + current, _, err := repository.ScanWorkdir(ctx, normalizedWorkdir, repository.DefaultFingerprintOptions()) + if err != nil { + return + } + s.rollbackBaselineMu.Lock() + s.ensureRunBaselineMapsLocked() + s.lastRunFingerprintByWorkspaceKey[workspaceKey] = current + s.rollbackBaselineMu.Unlock() + + store, ok := s.checkpointStore.(workspaceFingerprintPersistenceStore) + if !ok { + return + } + payload, err := json.Marshal(current) + if err != nil { + return + } + now := time.Now() + _ = store.SaveWorkspaceFingerprint(ctx, workspaceKey, string(payload), now) + + stateStore, ok := s.checkpointStore.(workspaceCheckpointStatePersistenceStore) + if !ok { + return + } + currentCheckpointID := strings.TrimSpace(checkpointID) + if currentCheckpointID == "" { + if existing, found, err := stateStore.LoadWorkspaceCheckpointState(ctx, workspaceKey); err == nil && found { + currentCheckpointID = strings.TrimSpace(existing.CurrentCheckpointID) + } + } + _ = stateStore.SaveWorkspaceCheckpointState(ctx, checkpoint.WorkspaceCheckpointState{ + WorkspaceKey: workspaceKey, + CurrentCheckpointID: currentCheckpointID, + FingerprintPayload: string(payload), + UpdatedAt: now, + }) +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index c4bdb542..6887ee55 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -135,6 +135,11 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { ChangedFiles: changedFiles, }) } + if statePtr != nil { + runEndCtx := context.Background() + s.recordRunEndWorkspaceState(runEndCtx, statePtr.session.ID, effectiveWorkdirForCheckpointState(statePtr, statePtr.session), statePtr.lastEndOfTurnCheckpointID) + s.clearRunCheckpointCaches(statePtr.session.ID, statePtr.runID) + } s.emitRunTermination(runCtx, input, statePtr, err) }() ctx = runCtx @@ -170,6 +175,8 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } state := newRunState(input.RunID, session) + effectiveWorkdir := agentsession.EffectiveWorkdir(state.session.Workdir, initialCfg.Workdir) + state.effectiveWorkdir = effectiveWorkdir state.runToken = runToken state.thinkingOverride = cloneThinkingOverride(input.ThinkingOverride) state.disableTools = input.DisableTools @@ -190,7 +197,28 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { return s.handleRunError(err) } - effectiveWorkdir := agentsession.EffectiveWorkdir(state.session.Workdir, initialCfg.Workdir) + runBaselineCheckpointID := "" + if drifted, driftDiff := s.recordRunStartFingerprint(ctx, state.session.ID, state.runID, effectiveWorkdir); drifted { + if checkpointID, cpErr := s.createPreRunDriftRebaseCheckpoint(ctx, &state, driftDiff); cpErr != nil { + s.emitRunScopedOptional(EventCheckpointWarning, &state, CheckpointWarningPayload{ + Error: fmt.Sprintf("workspace drift detected but baseline rebase failed: %v", cpErr), + Phase: "run_start_drift_rebase", + }) + } else if checkpointID != "" { + runBaselineCheckpointID = checkpointID + s.persistRunRollbackBaseline(ctx, state.session.ID, state.runID, checkpointID, true) + s.recordRunEndWorkspaceState(ctx, state.session.ID, effectiveWorkdir, checkpointID) + } + s.emitRunScopedOptional(EventCheckpointWarning, &state, CheckpointWarningPayload{ + Error: "workspace drift detected before run start", + Phase: "run_start_drift", + }) + } + if runBaselineCheckpointID == "" { + if baseline, _ := s.getPersistentRunRollbackBaseline(ctx, state.session.ID, state.runID); baseline != "" { + runBaselineCheckpointID = baseline + } + } _ = s.runHookPoint(ctx, &state, runtimehooks.HookPointSessionStart, runtimehooks.HookContext{ Metadata: map[string]any{ "session_id": state.session.ID, @@ -223,7 +251,11 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.updateResumeCheckpoint(ctx, &state, "plan", "") maxTurns := resolveRuntimeMaxTurns(initialCfg.Runtime) - state.baselineCheckpointID = s.findPreviousEndOfTurnCheckpoint(ctx, sessionID, input.RunID) + if runBaselineCheckpointID != "" { + state.baselineCheckpointID = runBaselineCheckpointID + } else { + state.baselineCheckpointID = s.findPreviousEndOfTurnCheckpoint(ctx, sessionID, input.RunID) + } for turn := 0; ; turn++ { if turn >= maxTurns { state.maxTurnsReached = true diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 0da88617..f1db6601 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -193,19 +193,26 @@ type Service struct { events chan RuntimeEvent runtimeSnapshotMu sync.Mutex runtimeSnapshots map[string]RuntimeSnapshot - sessionMu sync.Mutex - sessionLocks map[string]*sessionLockEntry - runMu sync.Mutex - activeRunToken uint64 - nextRunToken uint64 - activeRunCancels map[uint64]context.CancelFunc - activeRunByID map[string]uint64 - activeRunTokenIDs map[uint64]string - activeRunStates map[uint64]*runState - permissionAskMapMu sync.Mutex - permissionAskLocks map[string]*permissionAskLockEntry - askStore AskSessionStore - askSequence uint64 + rollbackBaselineMu sync.Mutex + // runRollbackBaselineByRunKey 记录一次 run 的权威回退基线 checkpoint_id,key=session_id+"\n"+run_id。 + runRollbackBaselineByRunKey map[string]string + // runWorkspaceDriftByRunKey 标记 run 开始前工作区是否发生空闲期漂移,key=session_id+"\n"+run_id。 + runWorkspaceDriftByRunKey map[string]bool + // lastRunFingerprintByWorkspaceKey 保存每个工作区上一次 run 结束时的指纹,用于跨会话漂移检测。 + lastRunFingerprintByWorkspaceKey map[string]repository.WorkdirFingerprint + sessionMu sync.Mutex + sessionLocks map[string]*sessionLockEntry + runMu sync.Mutex + activeRunToken uint64 + nextRunToken uint64 + activeRunCancels map[uint64]context.CancelFunc + activeRunByID map[string]uint64 + activeRunTokenIDs map[uint64]string + activeRunStates map[uint64]*runState + permissionAskMapMu sync.Mutex + permissionAskLocks map[string]*permissionAskLockEntry + askStore AskSessionStore + askSequence uint64 thinkingEnabled bool @@ -261,24 +268,27 @@ func NewWithFactory( } service := &Service{ - configManager: configManager, - sessionStore: sessionStore, - toolManager: toolManager, - providerFactory: providerFactory, - contextBuilder: contextBuilder, - repositoryService: repository.NewService(), - approvalBroker: approval.NewBroker(), - askUserBroker: askuser.NewBroker(), - events: make(chan RuntimeEvent, 128), - runtimeSnapshots: make(map[string]RuntimeSnapshot), - sessionLocks: make(map[string]*sessionLockEntry), - permissionAskLocks: make(map[string]*permissionAskLockEntry), - activeRunCancels: make(map[uint64]context.CancelFunc), - activeRunByID: make(map[string]uint64), - activeRunTokenIDs: make(map[uint64]string), - activeRunStates: make(map[uint64]*runState), - askStore: newInMemoryAskSessionStore(askSessionTTL), - thinkingEnabled: true, + configManager: configManager, + sessionStore: sessionStore, + toolManager: toolManager, + providerFactory: providerFactory, + contextBuilder: contextBuilder, + repositoryService: repository.NewService(), + approvalBroker: approval.NewBroker(), + askUserBroker: askuser.NewBroker(), + events: make(chan RuntimeEvent, 128), + runtimeSnapshots: make(map[string]RuntimeSnapshot), + runRollbackBaselineByRunKey: make(map[string]string), + runWorkspaceDriftByRunKey: make(map[string]bool), + lastRunFingerprintByWorkspaceKey: make(map[string]repository.WorkdirFingerprint), + sessionLocks: make(map[string]*sessionLockEntry), + permissionAskLocks: make(map[string]*permissionAskLockEntry), + activeRunCancels: make(map[uint64]context.CancelFunc), + activeRunByID: make(map[string]uint64), + activeRunTokenIDs: make(map[uint64]string), + activeRunStates: make(map[uint64]*runState), + askStore: newInMemoryAskSessionStore(askSessionTTL), + thinkingEnabled: true, } baseHookExecutor := runtimehooks.NewExecutor(runtimehooks.NewRegistry(), newHookRuntimeEventEmitter(service), runtimehooks.DefaultHookTimeout) baseHookExecutor.SetAsyncResultSink(newHookAsyncResultSink(service)) diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 6606f552..2440bddf 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -17,6 +17,7 @@ type runState struct { runID string runToken uint64 session agentsession.Session + effectiveWorkdir string compactCount int reactiveCompactAttempts int rememberedThisRun bool diff --git a/internal/session/checkpoint_types.go b/internal/session/checkpoint_types.go index d21ac79b..c0a4e62c 100644 --- a/internal/session/checkpoint_types.go +++ b/internal/session/checkpoint_types.go @@ -6,13 +6,14 @@ import "time" type CheckpointReason string const ( - CheckpointReasonPreWrite CheckpointReason = "pre_write" - CheckpointReasonCompact CheckpointReason = "compact" - CheckpointReasonPlanMode CheckpointReason = "plan_mode" - CheckpointReasonManual CheckpointReason = "manual" - CheckpointReasonGuard CheckpointReason = "pre_restore_guard" - CheckpointReasonPreWriteDegraded CheckpointReason = "pre_write_degraded" - CheckpointReasonEndOfTurn CheckpointReason = "end_of_turn" + CheckpointReasonPreWrite CheckpointReason = "pre_write" + CheckpointReasonCompact CheckpointReason = "compact" + CheckpointReasonPlanMode CheckpointReason = "plan_mode" + CheckpointReasonManual CheckpointReason = "manual" + CheckpointReasonGuard CheckpointReason = "pre_restore_guard" + CheckpointReasonPreRunDriftRebase CheckpointReason = "pre_run_drift_rebase" + CheckpointReasonPreWriteDegraded CheckpointReason = "pre_write_degraded" + CheckpointReasonEndOfTurn CheckpointReason = "end_of_turn" ) // CheckpointStatus 描述 checkpoint 的生命周期状态。 diff --git a/internal/tui/services/runtime_contract.go b/internal/tui/services/runtime_contract.go index 7c4d1cf8..8ce28ade 100644 --- a/internal/tui/services/runtime_contract.go +++ b/internal/tui/services/runtime_contract.go @@ -245,6 +245,8 @@ type CheckpointDiffResult struct { PrevCommitHash string `json:"prev_commit_hash,omitempty"` Files CheckpointDiffFiles `json:"files"` Patch string `json:"patch,omitempty"` + WorkspaceDrifted bool `json:"workspace_drifted,omitempty"` + Warning string `json:"warning,omitempty"` } // WorkspaceRecord 描述工作区登记信息。 diff --git a/web/src/App.test.tsx b/web/src/App.test.tsx new file mode 100644 index 00000000..cb65a5eb --- /dev/null +++ b/web/src/App.test.tsx @@ -0,0 +1,57 @@ +import { describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import App from './App' + +const runtimeState = { + status: 'connected', + mode: 'browser', + error: '', + loadingMessage: '', + retry: vi.fn(), +} + +vi.mock('./context/RuntimeProvider', () => ({ + useRuntime: () => runtimeState, +})) + +vi.mock('./pages/ChatPage', () => ({ + default: () =>
chat-page
, +})) + +vi.mock('./pages/ConnectPage', () => ({ + default: () =>
connect-page
, +})) + +describe('App routes by runtime status', () => { + it('renders loading screen', () => { + runtimeState.status = 'loading' + runtimeState.loadingMessage = 'booting' + render() + expect(screen.getByText('booting')).toBeInTheDocument() + }) + + it('renders connect page for browser needs_config', () => { + runtimeState.status = 'needs_config' + runtimeState.mode = 'browser' + render() + expect(screen.getByText('connect-page')).toBeInTheDocument() + }) + + it('renders electron error screen with retry', () => { + runtimeState.status = 'error' + runtimeState.mode = 'electron' + runtimeState.error = 'boom' + render() + expect(screen.getByText('Gateway connection failed')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'Retry connection' })) + expect(runtimeState.retry).toHaveBeenCalled() + }) + + it('renders chat routes when connected', () => { + runtimeState.status = 'connected' + runtimeState.mode = 'browser' + render() + expect(screen.getByText('chat-page')).toBeInTheDocument() + }) +}) + diff --git a/web/src/api/gateway.test.ts b/web/src/api/gateway.test.ts new file mode 100644 index 00000000..336b3f40 --- /dev/null +++ b/web/src/api/gateway.test.ts @@ -0,0 +1,62 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { GatewayAPI } from './gateway' +import { Method } from './protocol' + +describe('GatewayAPI', () => { + const call = vi.fn() + const ws = { call } as any + let api: GatewayAPI + + beforeEach(() => { + call.mockReset() + call.mockResolvedValue({ type: 'ack', payload: {} }) + api = new GatewayAPI(ws) + }) + + it('maps authenticate and run methods', async () => { + await api.authenticate('tok') + await api.run({ input_text: 'hello' }) + + expect(call).toHaveBeenNthCalledWith(1, Method.Authenticate, { token: 'tok' }) + expect(call).toHaveBeenNthCalledWith(2, Method.Run, { input_text: 'hello' }) + }) + + it('maps optional session_id in listModels', async () => { + await api.listModels() + await api.listModels('s1') + + expect(call).toHaveBeenNthCalledWith(1, Method.ListModels, undefined) + expect(call).toHaveBeenNthCalledWith(2, Method.ListModels, { session_id: 's1' }) + }) + + it('maps optional provider_id in setSessionModel', async () => { + await api.setSessionModel('s1', 'm1') + await api.setSessionModel('s1', 'm1', 'p1') + + expect(call).toHaveBeenNthCalledWith(1, Method.SetSessionModel, { session_id: 's1', model_id: 'm1' }) + expect(call).toHaveBeenNthCalledWith(2, Method.SetSessionModel, { session_id: 's1', model_id: 'm1', provider_id: 'p1' }) + }) + + it('maps workspace methods and optional remove_data', async () => { + await api.listWorkspaces() + await api.createWorkspace('/tmp/a', 'A') + await api.switchWorkspace('h1') + await api.renameWorkspace('h1', 'B') + await api.deleteWorkspace('h1', true) + + expect(call).toHaveBeenNthCalledWith(1, Method.ListWorkspaces) + expect(call).toHaveBeenNthCalledWith(2, Method.CreateWorkspace, { path: '/tmp/a', name: 'A' }) + expect(call).toHaveBeenNthCalledWith(3, Method.SwitchWorkspace, { workspace_hash: 'h1' }) + expect(call).toHaveBeenNthCalledWith(4, Method.RenameWorkspace, { workspace_hash: 'h1', name: 'B' }) + expect(call).toHaveBeenNthCalledWith(5, Method.DeleteWorkspace, { workspace_hash: 'h1', remove_data: true }) + }) + + it('maps permission and user question resolution', async () => { + await api.resolvePermission({ request_id: 'r1', decision: 'allow_once' }) + await api.resolveUserQuestion({ request_id: 'q1', status: 'answered', message: 'ok' }) + + expect(call).toHaveBeenNthCalledWith(1, Method.ResolvePermission, { request_id: 'r1', decision: 'allow_once' }) + expect(call).toHaveBeenNthCalledWith(2, Method.UserQuestionAnswer, { request_id: 'q1', status: 'answered', message: 'ok' }) + }) +}) + diff --git a/web/src/api/protocol.ts b/web/src/api/protocol.ts index fbe1011c..249b0dc1 100644 --- a/web/src/api/protocol.ts +++ b/web/src/api/protocol.ts @@ -81,6 +81,7 @@ export const EventType = { ToolStart: 'tool_start', ToolResult: 'tool_result', ToolDiff: 'tool_diff', + BashSideEffect: 'bash_side_effect', ToolChunk: 'tool_chunk', ToolCallThinking: 'tool_call_thinking', ThinkingDelta: 'thinking_delta', @@ -484,6 +485,8 @@ export interface CheckpointDiffResultPayload { prev_commit_hash?: string files: FileDiffs patch?: string + workspace_drifted?: boolean + warning?: string } export interface CheckpointRestoreResultPayload { @@ -916,6 +919,7 @@ export interface ToolDiffFileEntry { path: string diff?: string was_new?: boolean + kind?: string } /** tool_diff 事件载荷:写工具修改了哪些文件 */ @@ -928,3 +932,12 @@ export interface ToolDiffPayload { files?: ToolDiffFileChange[] diffs?: ToolDiffFileEntry[] } + +/** bash_side_effect 文件变更条目 */ +export interface BashSideEffectPayload { + tool_call_id: string + command?: string + changes?: ToolDiffFileChange[] + preemptively_captured_paths?: string[] + uncovered_paths?: string[] +} diff --git a/web/src/api/wsClient.test.ts b/web/src/api/wsClient.test.ts new file mode 100644 index 00000000..f107786c --- /dev/null +++ b/web/src/api/wsClient.test.ts @@ -0,0 +1,164 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { createWSClient } from './wsClient' +import { JSONRPC_VERSION, Method } from './protocol' + +class MockWebSocket { + static CONNECTING = 0 + static OPEN = 1 + static CLOSING = 2 + static CLOSED = 3 + + static instances: MockWebSocket[] = [] + + url: string + readyState = MockWebSocket.CONNECTING + onopen: (() => void) | null = null + onclose: (() => void) | null = null + onerror: (() => void) | null = null + onmessage: ((event: { data: string }) => void) | null = null + sent: string[] = [] + + constructor(url: string) { + this.url = url + MockWebSocket.instances.push(this) + } + + send(data: string) { + this.sent.push(data) + } + + close() { + this.readyState = MockWebSocket.CLOSED + this.onclose?.() + } + + open() { + this.readyState = MockWebSocket.OPEN + this.onopen?.() + } + + emit(data: unknown) { + this.onmessage?.({ data: typeof data === 'string' ? data : JSON.stringify(data) }) + } +} + +function latestWS(): MockWebSocket { + const ws = MockWebSocket.instances.at(-1) + if (!ws) throw new Error('websocket not created') + return ws +} + +describe('createWSClient', () => { + beforeEach(() => { + vi.useFakeTimers() + MockWebSocket.instances = [] + vi.stubGlobal('WebSocket', MockWebSocket as any) + }) + + afterEach(() => { + vi.clearAllTimers() + vi.useRealTimers() + vi.unstubAllGlobals() + }) + + it('builds ws url with endpoint and token', () => { + const client = createWSClient({ + baseURL: 'http://127.0.0.1:8080/', + endpoint: 'gateway/ws', + token: 'a b', + }) + client.connect() + expect(latestWS().url).toBe('ws://127.0.0.1:8080/gateway/ws?token=a%20b') + client.disconnect() + }) + + it('sends rpc request and resolves rpc response', async () => { + const client = createWSClient({ baseURL: 'http://127.0.0.1:8080', rpcTimeout: 50 }) + client.connect() + const ws = latestWS() + ws.open() + + const promise = client.call('gateway.ping', { x: 1 }) + const req = JSON.parse(ws.sent[0]) + expect(req).toMatchObject({ + jsonrpc: JSONRPC_VERSION, + method: 'gateway.ping', + params: { x: 1 }, + }) + + ws.emit({ jsonrpc: JSONRPC_VERSION, id: req.id, result: { ok: true } }) + await expect(promise).resolves.toEqual({ ok: true }) + client.disconnect() + }) + + it('rejects rpc call on timeout', async () => { + const client = createWSClient({ baseURL: 'http://127.0.0.1:8080', rpcTimeout: 20 }) + client.connect() + latestWS().open() + const promise = client.call('x.y') + promise.catch(() => {}) + await vi.advanceTimersByTimeAsync(21) + await expect(promise).rejects.toThrow('RPC timeout: x.y') + client.disconnect() + }) + + it('dispatches gateway.event notifications to event handlers', () => { + const client = createWSClient({ baseURL: 'http://127.0.0.1:8080' }) + const handler = vi.fn() + client.onEvent(handler) + + client.connect() + const ws = latestWS() + ws.open() + ws.emit({ + jsonrpc: JSONRPC_VERSION, + method: Method.Event, + params: { type: 'event', payload: { a: 1 } }, + }) + expect(handler).toHaveBeenCalledWith({ type: 'event', payload: { a: 1 } }) + client.disconnect() + }) + + it('transitions to error then reconnects and fires onReconnect', async () => { + const client = createWSClient({ + baseURL: 'http://127.0.0.1:8080', + reconnectBaseInterval: 10, + reconnectMaxInterval: 10, + }) + const states: string[] = [] + const onReconnect = vi.fn() + client.onStateChange((s) => states.push(s)) + client.onReconnect(onReconnect) + + client.connect() + const ws1 = latestWS() + ws1.open() + ws1.close() + + expect(states).toContain('error') + await vi.advanceTimersByTimeAsync(20) + const ws2 = latestWS() + expect(ws2).not.toBe(ws1) + ws2.open() + expect(onReconnect).toHaveBeenCalledTimes(1) + expect(client.getState()).toBe('connected') + client.disconnect() + }) + + it('closes stale connection by heartbeat timeout', async () => { + const client = createWSClient({ + baseURL: 'http://127.0.0.1:8080', + heartbeatTimeout: 10, + heartbeatCheckInterval: 5, + reconnectBaseInterval: 50, + reconnectMaxInterval: 50, + }) + client.connect() + const ws = latestWS() + const closeSpy = vi.spyOn(ws, 'close') + ws.open() + await vi.advanceTimersByTimeAsync(20) + expect(closeSpy).toHaveBeenCalled() + client.disconnect() + }) +}) diff --git a/web/src/components/ErrorBoundary.test.tsx b/web/src/components/ErrorBoundary.test.tsx new file mode 100644 index 00000000..b4d99dbf --- /dev/null +++ b/web/src/components/ErrorBoundary.test.tsx @@ -0,0 +1,34 @@ +import { describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import type { ReactElement } from 'react' +import { ErrorBoundary } from './ErrorBoundary' + +function Crash(): ReactElement { + throw new Error('boom') +} + +describe('ErrorBoundary', () => { + it('renders fallback UI and supports retry', () => { + const errSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + render( + + + , + ) + expect(screen.getByText('The application encountered an error')).toBeInTheDocument() + expect(screen.getByText('boom')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'Retry' })) + errSpy.mockRestore() + }) + + it('uses custom fallback render prop', () => { + const errSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + render( + }> + + , + ) + expect(screen.getByRole('button', { name: 'custom:boom' })).toBeInTheDocument() + errSpy.mockRestore() + }) +}) diff --git a/web/src/components/UpdateNotification.test.tsx b/web/src/components/UpdateNotification.test.tsx new file mode 100644 index 00000000..37a5cf2a --- /dev/null +++ b/web/src/components/UpdateNotification.test.tsx @@ -0,0 +1,55 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { UpdateNotification } from './UpdateNotification' + +describe('UpdateNotification', () => { + beforeEach(() => { + Object.defineProperty(window, 'electronAPI', { + value: undefined, + configurable: true, + writable: true, + }) + }) + + it('renders nothing when electron api is missing', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('shows available/downloaded updates and handles actions', async () => { + let onAvailable: ((info: any) => void) | null = null + let onDownloaded: ((info: any) => void) | null = null + const quitAndInstall = vi.fn().mockResolvedValue(undefined) + Object.defineProperty(window, 'electronAPI', { + value: { + onUpdateAvailable: (cb: any) => { + onAvailable = cb + return vi.fn() + }, + onUpdateDownloaded: (cb: any) => { + onDownloaded = cb + return vi.fn() + }, + quitAndInstall, + }, + configurable: true, + }) + render() + + act(() => { + onAvailable?.({ version: '1.2.3' }) + }) + await waitFor(() => { + expect(screen.getByText('A new version v1.2.3 is available, downloading...')).toBeInTheDocument() + }) + + act(() => { + onDownloaded?.({ version: '1.2.3' }) + }) + fireEvent.click(screen.getByRole('button', { name: 'Restart Now' })) + expect(quitAndInstall).toHaveBeenCalled() + + fireEvent.click(screen.getByTitle('Dismiss')) + expect(screen.queryByText('NeoCode v1.2.3 is ready to install')).not.toBeInTheDocument() + }) +}) diff --git a/web/src/components/chat/AcceptanceMessage.test.tsx b/web/src/components/chat/AcceptanceMessage.test.tsx new file mode 100644 index 00000000..b33e6e65 --- /dev/null +++ b/web/src/components/chat/AcceptanceMessage.test.tsx @@ -0,0 +1,31 @@ +import { describe, expect, it } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import AcceptanceMessage from './AcceptanceMessage' + +describe('AcceptanceMessage', () => { + it('renders accepted summary and expandable details', () => { + render( + , + ) + expect(screen.getByText('已接受')).toBeInTheDocument() + expect(screen.getByText('all good')).toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: '展开详情' })) + expect(screen.getByText('detail')).toBeInTheDocument() + }) +}) + diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index f9de008e..8c610e7b 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -1,25 +1,60 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' -import { fireEvent, render, screen } from '@testing-library/react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' import ChatInput from './ChatInput' import { useChatStore } from '@/stores/useChatStore' import { useComposerStore } from '@/stores/useComposerStore' import { useSessionStore } from '@/stores/useSessionStore' +const mockGatewayAPI = { + listAvailableSkills: vi.fn(), + listModels: vi.fn(), + run: vi.fn(), + bindStream: vi.fn(), + cancel: vi.fn(), + compact: vi.fn(), + executeSystemTool: vi.fn(), + activateSessionSkill: vi.fn(), + deactivateSessionSkill: vi.fn(), +} + vi.mock('@/context/RuntimeProvider', () => ({ - useGatewayAPI: () => null, + useGatewayAPI: () => mockGatewayAPI, +})) + +vi.mock('./ModelSelector', () => ({ + default: () =>
, })) describe('ChatInput', () => { beforeEach(() => { + vi.clearAllMocks() + mockGatewayAPI.listAvailableSkills.mockResolvedValue({ + payload: { + skills: [ + { + descriptor: { id: 'skill.demo', description: 'demo skill' }, + active: true, + }, + ], + }, + }) + mockGatewayAPI.listModels.mockResolvedValue({ + payload: { + models: [], + selected_provider_id: '', + selected_model_id: '', + }, + }) + useComposerStore.setState({ composerText: '' }) - useSessionStore.setState({ currentSessionId: '' } as any) + useSessionStore.setState({ currentSessionId: '' } as never) useChatStore.setState({ isGenerating: false, messages: [], permissionRequests: [], agentMode: 'build', permissionMode: 'default', - } as any) + } as never) }) it('shows the default/bypass selector in build mode', () => { @@ -40,10 +75,63 @@ describe('ChatInput', () => { expect(screen.queryByRole('button', { name: 'bypass' })).not.toBeInTheDocument() }) + it('opens slash suggestions for bare slash and loads skills immediately', async () => { + render() + + const textarea = screen.getByRole('textbox') + fireEvent.change(textarea, { target: { value: '/' } }) + + await waitFor(() => { + expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument() + }) + await waitFor(() => { + expect(mockGatewayAPI.listAvailableSkills).toHaveBeenCalledWith(undefined) + }) + await waitFor(() => { + expect(screen.getByText('/skill.demo')).toBeInTheDocument() + }) + }) + + it('keeps slash menu visible for fuzzy inputs like /w', async () => { + render() + + const textarea = screen.getByRole('textbox') + fireEvent.change(textarea, { target: { value: '/w' } }) + + await waitFor(() => { + expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument() + }) + expect(screen.getAllByText((_, el) => Boolean(el?.textContent?.includes('/help'))).length).toBeGreaterThan(0) + }) + + it('supports keyboard navigation, tab completion, and escape for slash menu', async () => { + render() + + const textarea = screen.getByRole('textbox') as HTMLTextAreaElement + fireEvent.change(textarea, { target: { value: '/' } }) + + await waitFor(() => { + expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument() + }) + + fireEvent.keyDown(textarea, { key: 'ArrowDown' }) + fireEvent.keyDown(textarea, { key: 'Tab' }) + expect(textarea.value).toBe('/compact ') + + fireEvent.change(textarea, { target: { value: '/' } }) + await waitFor(() => { + expect(screen.getByTestId('slash-command-menu')).toBeInTheDocument() + }) + fireEvent.keyDown(textarea, { key: 'Escape' }) + await waitFor(() => { + expect(screen.queryByTestId('slash-command-menu')).not.toBeInTheDocument() + }) + }) + it('does not render the unimplemented attachment and mention buttons', () => { render() - expect(screen.queryByTitle('附加文件')).not.toBeInTheDocument() + expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument() expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() }) }) diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index f35c6c5d..ff4258e4 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -1,4 +1,4 @@ -import { useState, useRef, useEffect, useCallback } from 'react' +import { useState, useRef, useEffect, useCallback, useMemo } from 'react' import { useChatStore, createUserMessage } from '@/stores/useChatStore' import { useGatewayStore } from '@/stores/useGatewayStore' import { useSessionStore, isValidSessionId } from '@/stores/useSessionStore' @@ -21,52 +21,97 @@ import SkillPicker from './SkillPicker' import ModelSelector from './ModelSelector' import { Send, Square } from 'lucide-react' +const slashMenuAnchorStyle: React.CSSProperties = { + position: 'absolute', + left: 0, + bottom: 'calc(100% + 8px)', + zIndex: 100, +} + +/** 将网关返回的技能列表转换成输入框使用的 slash 命令结构。 */ +function buildSkillSlashCommands( + skills: Array<{ descriptor: { id: string; description?: string }; active?: boolean }>, +): SkillSlashCommand[] { + return skills.map((skill) => ({ + id: `skill-${skill.descriptor.id}`, + usage: `/${skill.descriptor.id}`, + description: skill.descriptor.description || '技能', + hasArgument: false, + isSkill: true, + skillId: skill.descriptor.id, + active: Boolean(skill.active), + })) +} + +/** 用当前命令定义生成帮助文本,避免菜单与帮助内容漂移。 */ +function buildSlashHelpText(commands: AnySlashCommand[]): string { + const helpCommands = commands.length > 0 ? commands : builtinSlashCommands + const maxLen = helpCommands.reduce((max, command) => Math.max(max, command.usage.length), 0) + const lines = helpCommands.map((command) => { + const pad = ' '.repeat(maxLen - command.usage.length) + const description = isSkillCommand(command) && command.active + ? `${command.description} (已激活)` + : command.description + return ` ${command.usage}${pad} - ${description}` + }) + return ['可用命令:', ...lines].join('\n') +} + export default function ChatInput() { const gatewayAPI = useGatewayAPI() - const text = useComposerStore((s) => s.composerText) - const setText = useComposerStore((s) => s.setComposerText) + const text = useComposerStore((state) => state.composerText) + const setText = useComposerStore((state) => state.setComposerText) const [rows, setRows] = useState(1) const textareaRef = useRef(null) const runCancelledRef = useRef(false) const composingRef = useRef(false) - const isGenerating = useChatStore((s) => s.isGenerating) - const addMessage = useChatStore((s) => s.addMessage) - const addSystemMessage = useChatStore((s) => s.addSystemMessage) - const setGenerating = useChatStore((s) => s.setGenerating) - const sessionId = useSessionStore((s) => s.currentSessionId) - const agentMode = useChatStore((s) => s.agentMode) - const setAgentMode = useChatStore((s) => s.setAgentMode) - const permissionMode = useChatStore((s) => s.permissionMode) - const setPermissionMode = useChatStore((s) => s.setPermissionMode) + const isGenerating = useChatStore((state) => state.isGenerating) + const addMessage = useChatStore((state) => state.addMessage) + const addSystemMessage = useChatStore((state) => state.addSystemMessage) + const setGenerating = useChatStore((state) => state.setGenerating) + const sessionId = useSessionStore((state) => state.currentSessionId) + const agentMode = useChatStore((state) => state.agentMode) + const setAgentMode = useChatStore((state) => state.setAgentMode) + const permissionMode = useChatStore((state) => state.permissionMode) + const setPermissionMode = useChatStore((state) => state.setPermissionMode) const [showSlashMenu, setShowSlashMenu] = useState(false) const [selectedIndex, setSelectedIndex] = useState(0) const [matchedCommands, setMatchedCommands] = useState([]) const [availableSkillCommands, setAvailableSkillCommands] = useState([]) const [showSkillPicker, setShowSkillPicker] = useState(false) + const allSlashCommands = useMemo( + () => [...builtinSlashCommands, ...availableSkillCommands], + [availableSkillCommands], + ) useEffect(() => { - if (!showSlashMenu || !gatewayAPI) return + if (!gatewayAPI || !text.trimLeft().startsWith('/')) return + let cancelled = false gatewayAPI.listAvailableSkills(sessionId || undefined).then((result) => { + if (cancelled) return const skills = result.payload?.skills || [] - const skillCommands: SkillSlashCommand[] = skills.map((s) => ({ - id: `skill-${s.descriptor.id}`, - usage: `/${s.descriptor.id}`, - description: s.descriptor.description || '技能', - hasArgument: false, isSkill: true, - skillId: s.descriptor.id, active: s.active, - })) - setAvailableSkillCommands(skillCommands) - }).catch(() => { setAvailableSkillCommands([]) }) - }, [showSlashMenu, gatewayAPI, sessionId]) + setAvailableSkillCommands(buildSkillSlashCommands(skills)) + }).catch(() => { + if (!cancelled) setAvailableSkillCommands([]) + }) + return () => { + cancelled = true + } + }, [text, gatewayAPI, sessionId]) useEffect(() => { - if (!isSlashCommand(text)) { setShowSlashMenu(false); return } - const allCommands: AnySlashCommand[] = [...builtinSlashCommands, ...availableSkillCommands] - const matched = matchSlashCommands(text, allCommands) - if (matched.length > 0) { setMatchedCommands(matched); setShowSlashMenu(true); setSelectedIndex(0) } - else { setShowSlashMenu(false) } - }, [text, availableSkillCommands]) + if (!isSlashCommand(text)) { + setMatchedCommands([]) + setShowSlashMenu(false) + return + } + + const matched = matchSlashCommands(text, allSlashCommands) + setMatchedCommands(matched) + setShowSlashMenu(matched.length > 0) + if (matched.length > 0) setSelectedIndex(0) + }, [text, allSlashCommands]) useEffect(() => { const lines = text.split('\n').length @@ -76,137 +121,249 @@ export default function ChatInput() { const executeSlashCommand = useCallback(async (input: string): Promise => { const parsed = parseSlashCommand(input) if (!parsed) return false + const { command, argument } = parsed const currentSessionId = sessionId const api = gatewayAPI - if (!api) { useUIStore.getState().showToast('Gateway not connected', 'error'); return true } + if (!api) { + useUIStore.getState().showToast('Gateway not connected', 'error') + return true + } switch (command) { case '/help': { - const allUsages = [...builtinSlashCommands.map((c) => c.usage), '/'] - const maxLen = Math.max(...allUsages.map((u) => u.length)) - const helpLines = [ - '可用命令:', - ...builtinSlashCommands.map((cmd) => { - const pad = ' '.repeat(maxLen - cmd.usage.length) - return ` ${cmd.usage}${pad} — ${cmd.description}` - }), - ` /${''.padEnd(maxLen - 1)} — 激活/停用技能`, - ] - addSystemMessage(helpLines.join('\n')) + addSystemMessage(buildSlashHelpText(allSlashCommands)) return true } case '/compact': { - if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true } - try { await api.compact(currentSessionId, '') } catch (err) { console.error('Compact failed:', err); useUIStore.getState().showToast('Compaction failed', 'error') } + if (!isValidSessionId(currentSessionId)) { + useUIStore.getState().showToast('Send a message first to start a session', 'error') + return true + } + try { + await api.compact(currentSessionId, '') + } catch (err) { + console.error('Compact failed:', err) + useUIStore.getState().showToast('Compaction failed', 'error') + } return true } case '/memo': { - if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true } + if (!isValidSessionId(currentSessionId)) { + useUIStore.getState().showToast('Send a message first to start a session', 'error') + return true + } try { const result = await api.executeSystemTool(currentSessionId, '', 'memo_list', {}) - addSystemMessage((result as any)?.payload?.content || 'Memo query complete') - } catch (err) { console.error('Memo list failed:', err); useUIStore.getState().showToast('Failed to query memo', 'error') } + addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo query complete') + } catch (err) { + console.error('Memo list failed:', err) + useUIStore.getState().showToast('Failed to query memo', 'error') + } return true } case '/remember': { - if (!argument) { useUIStore.getState().showToast('Usage: /remember ', 'error'); return true } - if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true } + if (!argument) { + useUIStore.getState().showToast('Usage: /remember ', 'error') + return true + } + if (!isValidSessionId(currentSessionId)) { + useUIStore.getState().showToast('Send a message first to start a session', 'error') + return true + } try { - const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', { type: 'user', title: argument, content: argument }) - addSystemMessage((result as any)?.payload?.content || 'Memo saved') - } catch (err) { console.error('Remember failed:', err); useUIStore.getState().showToast('Failed to save memo', 'error') } + const result = await api.executeSystemTool(currentSessionId, '', 'memo_remember', { + type: 'user', + title: argument, + content: argument, + }) + addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo saved') + } catch (err) { + console.error('Remember failed:', err) + useUIStore.getState().showToast('Failed to save memo', 'error') + } return true } case '/forget': { - if (!argument) { useUIStore.getState().showToast('Usage: /forget ', 'error'); return true } - if (!isValidSessionId(currentSessionId)) { useUIStore.getState().showToast('Send a message first to start a session', 'error'); return true } + if (!argument) { + useUIStore.getState().showToast('Usage: /forget ', 'error') + return true + } + if (!isValidSessionId(currentSessionId)) { + useUIStore.getState().showToast('Send a message first to start a session', 'error') + return true + } try { - const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', { keyword: argument, scope: 'all' }) - addSystemMessage((result as any)?.payload?.content || 'Memo deleted') - } catch (err) { console.error('Forget failed:', err); useUIStore.getState().showToast('Failed to delete memo', 'error') } + const result = await api.executeSystemTool(currentSessionId, '', 'memo_remove', { + keyword: argument, + scope: 'all', + }) + addSystemMessage((result as { payload?: { content?: string } })?.payload?.content || 'Memo deleted') + } catch (err) { + console.error('Forget failed:', err) + useUIStore.getState().showToast('Failed to delete memo', 'error') + } + return true + } + case '/skills': { + setShowSkillPicker(true) return true } - case '/skills': { setShowSkillPicker(true); return true } default: { - if (isGenerating) { useUIStore.getState().showToast('Cannot toggle skill while generating', 'info'); return true } - const skillCmd = availableSkillCommands.find((s) => s.usage === command) - if (skillCmd && isValidSessionId(currentSessionId)) { + if (isGenerating) { + useUIStore.getState().showToast('Cannot toggle skill while generating', 'info') + return true + } + const skillCommand = availableSkillCommands.find((skill) => skill.usage === command) + if (skillCommand && isValidSessionId(currentSessionId)) { try { - if (skillCmd.active) await api.deactivateSessionSkill(currentSessionId, skillCmd.skillId) - else await api.activateSessionSkill(currentSessionId, skillCmd.skillId) - setAvailableSkillCommands((prev) => prev.map((item) => item.skillId === skillCmd.skillId ? { ...item, active: !item.active } : item)) - } catch (err) { console.error('Skill toggle failed:', err); useUIStore.getState().showToast('Skill operation failed', 'error') } + if (skillCommand.active) { + await api.deactivateSessionSkill(currentSessionId, skillCommand.skillId) + } else { + await api.activateSessionSkill(currentSessionId, skillCommand.skillId) + } + setAvailableSkillCommands((prev) => prev.map((item) => ( + item.skillId === skillCommand.skillId + ? { ...item, active: !item.active } + : item + ))) + } catch (err) { + console.error('Skill toggle failed:', err) + useUIStore.getState().showToast('Skill operation failed', 'error') + } return true } return false } } - }, [gatewayAPI, sessionId, addSystemMessage, availableSkillCommands, isGenerating]) + }, [gatewayAPI, sessionId, addSystemMessage, availableSkillCommands, isGenerating, allSlashCommands]) async function handleSubmit() { const input = text.trim() if (!input) return + if (isGenerating) { if (isSlashCommand(input)) useUIStore.getState().showToast('Cannot run commands while generating', 'info') return } + if (isSlashCommand(input)) { - setText(''); setShowSlashMenu(false) + setText('') + setShowSlashMenu(false) const handled = await executeSlashCommand(input) if (handled) return } + setText('') const userMsg = createUserMessage(input) addMessage(userMsg) useRuntimeInsightStore.getState().setTodoSnapshot({ - items: [], summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 }, + items: [], + summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 }, }) setGenerating(true) runCancelledRef.current = false + try { if (!gatewayAPI) return const isNewSession = !isValidSessionId(sessionId) - const ack = await gatewayAPI.run({ session_id: isNewSession ? undefined : sessionId, new_session: isNewSession ? true : undefined, input_text: input, mode: agentMode }) + const ack = await gatewayAPI.run({ + session_id: isNewSession ? undefined : sessionId, + new_session: isNewSession ? true : undefined, + input_text: input, + mode: agentMode, + }) if (!runCancelledRef.current) { - const gwStore = useGatewayStore.getState() - const sessStore = useSessionStore.getState() - if (ack.run_id) gwStore.setCurrentRunId(ack.run_id) - if (ack.session_id) { sessStore.setCurrentSessionId(ack.session_id); gatewayAPI?.bindStream({ session_id: ack.session_id, channel: 'all' }).catch(() => {}) } + const gatewayStore = useGatewayStore.getState() + const sessionStore = useSessionStore.getState() + if (ack.run_id) gatewayStore.setCurrentRunId(ack.run_id) + if (ack.session_id) { + sessionStore.setCurrentSessionId(ack.session_id) + gatewayAPI.bindStream({ session_id: ack.session_id, channel: 'all' }).catch(() => {}) + } } } catch (err) { if (!runCancelledRef.current) { - setGenerating(false); useChatStore.getState().removeMessage(userMsg.id) - console.error('Run failed:', err); useUIStore.getState().showToast('Failed to send message', 'error') + setGenerating(false) + useChatStore.getState().removeMessage(userMsg.id) + console.error('Run failed:', err) + useUIStore.getState().showToast('Failed to send message', 'error') } } } function handleKeyDown(e: React.KeyboardEvent) { if (composingRef.current) return + if (!showSlashMenu) { - if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); handleSubmit() } + if (e.key === 'Enter' && !e.shiftKey) { + e.preventDefault() + handleSubmit() + } return } + switch (e.key) { - case 'ArrowDown': e.preventDefault(); setSelectedIndex((prev) => (prev + 1) % matchedCommands.length); return - case 'ArrowUp': e.preventDefault(); setSelectedIndex((prev) => (prev - 1 + matchedCommands.length) % matchedCommands.length); return - case 'Enter': e.preventDefault(); const cmd = matchedCommands[selectedIndex]; if (cmd) handleSelectCommand(cmd); return - case 'Escape': e.preventDefault(); setShowSlashMenu(false); return - case 'Tab': e.preventDefault(); const c = matchedCommands[selectedIndex]; if (c) { setText(c.usage + ' '); textareaRef.current?.focus() }; return + case 'ArrowDown': + e.preventDefault() + setSelectedIndex((prev) => (prev + 1) % matchedCommands.length) + return + case 'ArrowUp': + e.preventDefault() + setSelectedIndex((prev) => (prev - 1 + matchedCommands.length) % matchedCommands.length) + return + case 'Enter': { + e.preventDefault() + const command = matchedCommands[selectedIndex] + if (command) handleSelectCommand(command) + return + } + case 'Escape': + e.preventDefault() + setShowSlashMenu(false) + return + case 'Tab': { + e.preventDefault() + const command = matchedCommands[selectedIndex] + if (command) { + setText(`${command.usage} `) + textareaRef.current?.focus() + } + return + } } } function handleSelectCommand(cmd: AnySlashCommand) { - if (isSkillCommand(cmd)) { setText(cmd.usage); setShowSlashMenu(false); executeSlashCommand(cmd.usage); return } - if (cmd.hasArgument) { setText(cmd.usage + ' '); setShowSlashMenu(false); textareaRef.current?.focus() } - else { setText(''); setShowSlashMenu(false); executeSlashCommand(cmd.usage) } + if (isSkillCommand(cmd)) { + setText(cmd.usage) + setShowSlashMenu(false) + void executeSlashCommand(cmd.usage) + return + } + + if (cmd.hasArgument) { + setText(`${cmd.usage} `) + setShowSlashMenu(false) + textareaRef.current?.focus() + return + } + + setText('') + setShowSlashMenu(false) + void executeSlashCommand(cmd.usage) } async function handleCancel() { runCancelledRef.current = true const runId = useGatewayStore.getState().currentRunId - if (runId && gatewayAPI) { try { await gatewayAPI.cancel({ run_id: runId }) } catch (err) { console.error('Cancel failed:', err) } } + if (runId && gatewayAPI) { + try { + await gatewayAPI.cancel({ run_id: runId }) + } catch (err) { + console.error('Cancel failed:', err) + } + } useChatStore.getState().resetGeneratingState() } @@ -218,9 +375,9 @@ export default function ChatInput() { setShowSkillPicker(false)} /> )}
- {showSlashMenu && matchedCommands.length > 0 && ( -
-
+
+ {showSlashMenu && matchedCommands.length > 0 && ( +
-
- )} -
-