Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 250 additions & 0 deletions internal/checkpoint/checkpoint_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,29 @@ type SQLiteCheckpointStore struct {
ownsDB bool // true 表示本实例打开的连接,Close 时需释放
}

type workspaceFingerprintRow struct {
WorkspaceKey string
FingerprintPayload string
UpdatedAtMS int64
}

// WorkspaceCheckpointState 保存工作区当前权威代码基线与文件指纹。
type WorkspaceCheckpointState struct {
WorkspaceKey string
CurrentCheckpointID string
FingerprintPayload string
UpdatedAt time.Time
}

// RunCheckpointBaseline 保存单次 run 的权威回退基线,供异步 diff 查询跨 run 结束后读取。
type RunCheckpointBaseline struct {
SessionID string
RunID string
CheckpointID string
Drifted bool
UpdatedAt time.Time
}

// NewSQLiteCheckpointStore 创建 checkpoint 存储实例。
// dbPath 为 session.db 文件路径,可通过 session.DatabasePath 获取。
func NewSQLiteCheckpointStore(dbPath string) *SQLiteCheckpointStore {
Expand Down Expand Up @@ -114,6 +137,184 @@ func (s *SQLiteCheckpointStore) ensureDB(ctx context.Context) (*sql.DB, error) {
return db, nil
}

// SaveWorkspaceFingerprint 把 workspace 最新指纹保存到 SQLite,用于跨会话/重启后的 drift 检测。
func (s *SQLiteCheckpointStore) SaveWorkspaceFingerprint(
ctx context.Context,
workspaceKey string,
fingerprintPayload string,
updatedAt time.Time,
) error {
if err := ctx.Err(); err != nil {
return err
}
db, err := s.ensureDB(ctx)
if err != nil {
return err
}
if err := ensureWorkspaceFingerprintTable(ctx, db); err != nil {
return err
}
_, err = db.ExecContext(ctx, `
INSERT INTO workspace_fingerprints(workspace_key, fingerprint_payload, updated_at_ms)
VALUES(?, ?, ?)
ON CONFLICT(workspace_key) DO UPDATE SET
fingerprint_payload=excluded.fingerprint_payload,
updated_at_ms=excluded.updated_at_ms
`, workspaceKey, fingerprintPayload, toUnixMillis(updatedAt))
if err != nil {
return fmt.Errorf("checkpoint: save workspace fingerprint %s: %w", workspaceKey, err)
}
return nil
}

// LoadWorkspaceFingerprint 从 SQLite 读取 workspace 最近指纹,不存在时 ok=false。
func (s *SQLiteCheckpointStore) LoadWorkspaceFingerprint(
ctx context.Context,
workspaceKey string,
) (string, bool, error) {
if err := ctx.Err(); err != nil {
return "", false, err
}
db, err := s.ensureDB(ctx)
if err != nil {
return "", false, err
}
if err := ensureWorkspaceFingerprintTable(ctx, db); err != nil {
return "", false, err
}
var row workspaceFingerprintRow
err = db.QueryRowContext(ctx, `
SELECT workspace_key, fingerprint_payload, updated_at_ms
FROM workspace_fingerprints
WHERE workspace_key = ?
`, workspaceKey).Scan(&row.WorkspaceKey, &row.FingerprintPayload, &row.UpdatedAtMS)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", false, nil
}
return "", false, fmt.Errorf("checkpoint: load workspace fingerprint %s: %w", workspaceKey, err)
}
return row.FingerprintPayload, true, nil
}

// SaveWorkspaceCheckpointState 持久化工作区当前代码基线与文件指纹。
func (s *SQLiteCheckpointStore) SaveWorkspaceCheckpointState(ctx context.Context, state WorkspaceCheckpointState) error {
if err := ctx.Err(); err != nil {
return err
}
if state.WorkspaceKey == "" {
return fmt.Errorf("checkpoint: workspace key required")
}
db, err := s.ensureDB(ctx)
if err != nil {
return err
}
if err := ensureWorkspaceCheckpointStateTable(ctx, db); err != nil {
return err
}
_, err = db.ExecContext(ctx, `
INSERT INTO workspace_checkpoint_states(workspace_key, current_checkpoint_id, fingerprint_payload, updated_at_ms)
VALUES(?, ?, ?, ?)
ON CONFLICT(workspace_key) DO UPDATE SET
current_checkpoint_id=excluded.current_checkpoint_id,
fingerprint_payload=excluded.fingerprint_payload,
updated_at_ms=excluded.updated_at_ms
`, state.WorkspaceKey, state.CurrentCheckpointID, state.FingerprintPayload, toUnixMillis(state.UpdatedAt))
if err != nil {
return fmt.Errorf("checkpoint: save workspace checkpoint state %s: %w", state.WorkspaceKey, err)
}
return nil
}

// LoadWorkspaceCheckpointState 读取工作区当前代码基线与文件指纹,不存在时 ok=false。
func (s *SQLiteCheckpointStore) LoadWorkspaceCheckpointState(ctx context.Context, workspaceKey string) (WorkspaceCheckpointState, bool, error) {
if err := ctx.Err(); err != nil {
return WorkspaceCheckpointState{}, false, err
}
db, err := s.ensureDB(ctx)
if err != nil {
return WorkspaceCheckpointState{}, false, err
}
if err := ensureWorkspaceCheckpointStateTable(ctx, db); err != nil {
return WorkspaceCheckpointState{}, false, err
}
var state WorkspaceCheckpointState
var updatedAtMS int64
err = db.QueryRowContext(ctx, `
SELECT workspace_key, current_checkpoint_id, fingerprint_payload, updated_at_ms
FROM workspace_checkpoint_states
WHERE workspace_key = ?
`, workspaceKey).Scan(&state.WorkspaceKey, &state.CurrentCheckpointID, &state.FingerprintPayload, &updatedAtMS)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return WorkspaceCheckpointState{}, false, nil
}
return WorkspaceCheckpointState{}, false, fmt.Errorf("checkpoint: load workspace checkpoint state %s: %w", workspaceKey, err)
}
state.UpdatedAt = fromUnixMillis(updatedAtMS)
return state, true, nil
}

// SaveRunCheckpointBaseline 持久化单次 run 的权威回退基线。
func (s *SQLiteCheckpointStore) SaveRunCheckpointBaseline(ctx context.Context, baseline RunCheckpointBaseline) error {
if err := ctx.Err(); err != nil {
return err
}
if baseline.SessionID == "" || baseline.RunID == "" {
return fmt.Errorf("checkpoint: session_id and run_id required for run baseline")
}
db, err := s.ensureDB(ctx)
if err != nil {
return err
}
if err := ensureRunCheckpointBaselineTable(ctx, db); err != nil {
return err
}
_, err = db.ExecContext(ctx, `
INSERT INTO run_checkpoint_baselines(session_id, run_id, checkpoint_id, drifted, updated_at_ms)
VALUES(?, ?, ?, ?, ?)
ON CONFLICT(session_id, run_id) DO UPDATE SET
checkpoint_id=excluded.checkpoint_id,
drifted=excluded.drifted,
updated_at_ms=excluded.updated_at_ms
`, baseline.SessionID, baseline.RunID, baseline.CheckpointID, boolToInt(baseline.Drifted), toUnixMillis(baseline.UpdatedAt))
if err != nil {
return fmt.Errorf("checkpoint: save run baseline %s/%s: %w", baseline.SessionID, baseline.RunID, err)
}
return nil
}

// LoadRunCheckpointBaseline 读取单次 run 的权威回退基线,不存在时 ok=false。
func (s *SQLiteCheckpointStore) LoadRunCheckpointBaseline(ctx context.Context, sessionID, runID string) (RunCheckpointBaseline, bool, error) {
if err := ctx.Err(); err != nil {
return RunCheckpointBaseline{}, false, err
}
db, err := s.ensureDB(ctx)
if err != nil {
return RunCheckpointBaseline{}, false, err
}
if err := ensureRunCheckpointBaselineTable(ctx, db); err != nil {
return RunCheckpointBaseline{}, false, err
}
var baseline RunCheckpointBaseline
var drifted int
var updatedAtMS int64
err = db.QueryRowContext(ctx, `
SELECT session_id, run_id, checkpoint_id, drifted, updated_at_ms
FROM run_checkpoint_baselines
WHERE session_id = ? AND run_id = ?
`, sessionID, runID).Scan(&baseline.SessionID, &baseline.RunID, &baseline.CheckpointID, &drifted, &updatedAtMS)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return RunCheckpointBaseline{}, false, nil
}
return RunCheckpointBaseline{}, false, fmt.Errorf("checkpoint: load run baseline %s/%s: %w", sessionID, runID, err)
}
baseline.Drifted = drifted != 0
baseline.UpdatedAt = fromUnixMillis(updatedAtMS)
return baseline, true, nil
}

// CreateCheckpoint 在单一事务内写入 checkpoint record + session checkpoint。
// 事务内完成 record INSERT → session_cp INSERT → record UPDATE(设置 session_checkpoint_ref + status=available)。
func (s *SQLiteCheckpointStore) CreateCheckpoint(ctx context.Context, input CreateCheckpointInput) (session.CheckpointRecord, error) {
Expand Down Expand Up @@ -663,3 +864,52 @@ func rollbackTx(tx *sql.Tx) {
_ = tx.Rollback()
}
}

func ensureWorkspaceFingerprintTable(ctx context.Context, db *sql.DB) error {
if db == nil {
return fmt.Errorf("checkpoint: workspace fingerprint db is nil")
}
if _, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS workspace_fingerprints (
workspace_key TEXT PRIMARY KEY,
fingerprint_payload TEXT NOT NULL,
updated_at_ms INTEGER NOT NULL
)`); err != nil {
return fmt.Errorf("checkpoint: ensure workspace_fingerprints table: %w", err)
}
return nil
}

func ensureWorkspaceCheckpointStateTable(ctx context.Context, db *sql.DB) error {
if db == nil {
return fmt.Errorf("checkpoint: workspace checkpoint state db is nil")
}
if _, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS workspace_checkpoint_states (
workspace_key TEXT PRIMARY KEY,
current_checkpoint_id TEXT NOT NULL,
fingerprint_payload TEXT NOT NULL,
updated_at_ms INTEGER NOT NULL
)`); err != nil {
return fmt.Errorf("checkpoint: ensure workspace_checkpoint_states table: %w", err)
}
return nil
}

func ensureRunCheckpointBaselineTable(ctx context.Context, db *sql.DB) error {
if db == nil {
return fmt.Errorf("checkpoint: run checkpoint baseline db is nil")
}
if _, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS run_checkpoint_baselines (
session_id TEXT NOT NULL,
run_id TEXT NOT NULL,
checkpoint_id TEXT NOT NULL,
drifted INTEGER NOT NULL,
updated_at_ms INTEGER NOT NULL,
PRIMARY KEY(session_id, run_id)
)`); err != nil {
return fmt.Errorf("checkpoint: ensure run_checkpoint_baselines table: %w", err)
}
return nil
}
78 changes: 78 additions & 0 deletions internal/checkpoint/checkpoint_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,3 +588,81 @@ func TestNewSQLiteCheckpointStoreWithNilDBClose(t *testing.T) {
t.Fatalf("Close(nil db) error = %v", err)
}
}

func TestSQLiteCheckpointStoreWorkspaceFingerprintPersistence(t *testing.T) {
fixture := newCheckpointStoreFixture(t)
db, err := fixture.sessionStore.InitDB(context.Background())
if err != nil {
t.Fatalf("InitDB() error = %v", err)
}
shared := NewSQLiteCheckpointStoreWithDB(db)

workspaceKey := session.WorkspacePathKey(fixture.workspaceRoot)
payload := `{"a.txt":{"size":1,"mod_time":"2026-01-01T00:00:00Z","head_hash":"abc"}}`
if err := shared.SaveWorkspaceFingerprint(context.Background(), workspaceKey, payload, time.Now()); err != nil {
t.Fatalf("SaveWorkspaceFingerprint() error = %v", err)
}

got, ok, err := shared.LoadWorkspaceFingerprint(context.Background(), workspaceKey)
if err != nil {
t.Fatalf("LoadWorkspaceFingerprint() error = %v", err)
}
if !ok {
t.Fatal("expected persisted fingerprint to exist")
}
if got != payload {
t.Fatalf("fingerprint payload = %q, want %q", got, payload)
}

missing, ok, err := shared.LoadWorkspaceFingerprint(context.Background(), "workspace/missing")
if err != nil {
t.Fatalf("LoadWorkspaceFingerprint(missing) error = %v", err)
}
if ok || missing != "" {
t.Fatalf("expected missing fingerprint, got ok=%v payload=%q", ok, missing)
}
}

func TestSQLiteCheckpointStoreWorkspaceStateAndRunBaselinePersistence(t *testing.T) {
fixture := newCheckpointStoreFixture(t)
db, err := fixture.sessionStore.InitDB(context.Background())
if err != nil {
t.Fatalf("InitDB() error = %v", err)
}
shared := NewSQLiteCheckpointStoreWithDB(db)

workspaceKey := session.WorkspacePathKey(fixture.workspaceRoot)
payload := `{"tracked.txt":{"size":7}}`
if err := shared.SaveWorkspaceCheckpointState(context.Background(), WorkspaceCheckpointState{
WorkspaceKey: workspaceKey,
CurrentCheckpointID: "cp-current",
FingerprintPayload: payload,
UpdatedAt: time.Now(),
}); err != nil {
t.Fatalf("SaveWorkspaceCheckpointState() error = %v", err)
}
state, ok, err := shared.LoadWorkspaceCheckpointState(context.Background(), workspaceKey)
if err != nil {
t.Fatalf("LoadWorkspaceCheckpointState() error = %v", err)
}
if !ok || state.CurrentCheckpointID != "cp-current" || state.FingerprintPayload != payload {
t.Fatalf("workspace state = %#v ok=%v", state, ok)
}

if err := shared.SaveRunCheckpointBaseline(context.Background(), RunCheckpointBaseline{
SessionID: "session-1",
RunID: "run-1",
CheckpointID: "cp-current",
Drifted: true,
UpdatedAt: time.Now(),
}); err != nil {
t.Fatalf("SaveRunCheckpointBaseline() error = %v", err)
}
baseline, ok, err := shared.LoadRunCheckpointBaseline(context.Background(), "session-1", "run-1")
if err != nil {
t.Fatalf("LoadRunCheckpointBaseline() error = %v", err)
}
if !ok || baseline.CheckpointID != "cp-current" || !baseline.Drifted {
t.Fatalf("run baseline = %#v ok=%v", baseline, ok)
}
}
Loading
Loading