Skip to content
147 changes: 147 additions & 0 deletions internal/runtime/checkpoint_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,25 @@ func (f runtimeCheckpointFixture) captureFile(t *testing.T, relPath string, cont
return abs
}

func readCheckpointRestoredPayload(t *testing.T, events <-chan RuntimeEvent) CheckpointRestoredPayload {
t.Helper()
for {
select {
case evt := <-events:
if evt.Type != EventCheckpointRestored {
continue
}
payload, ok := evt.Payload.(CheckpointRestoredPayload)
if !ok {
t.Fatalf("checkpoint restored payload type = %T", evt.Payload)
}
return payload
default:
t.Fatal("expected checkpoint restored event")
}
}
}

func TestCreateStartOfTurnCheckpoint_PendingWrite(t *testing.T) {
fixture := newRuntimeCheckpointFixture(t)
fixture.captureFile(t, "main.go", []byte("package main\nconst v = 1\n"))
Expand Down Expand Up @@ -392,6 +411,134 @@ func TestRestoreCheckpoint_RecoversCapturedFile(t *testing.T) {
if string(got) != "version one" {
t.Fatalf("restored content = %q, want %q", string(got), "version one")
}

payload := readCheckpointRestoredPayload(t, fixture.service.events)
if payload.Mode != "exact" {
t.Fatalf("restore payload mode = %q, want exact", payload.Mode)
}
if len(payload.Paths) != 0 {
t.Fatalf("restore payload paths = %#v, want empty", payload.Paths)
}
if payload.GuardCheckpointID == "" {
t.Fatal("restore payload guard checkpoint id is empty")
}
}

func TestRestoreCheckpointBaselineEmitsModeAndPaths(t *testing.T) {
fixture := newRuntimeCheckpointFixture(t)
target := filepath.Join(fixture.workdir, "baseline.txt")
if err := os.WriteFile(target, []byte("before baseline"), 0o644); err != nil {
t.Fatalf("WriteFile(before baseline) error = %v", err)
}
if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil {
t.Fatalf("CapturePreWrite() error = %v", err)
}

state := newRunState("run-baseline", fixture.session)
if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil {
t.Fatalf("createStartOfTurnCheckpoint() error = %v", err)
}
records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{})
if err != nil {
t.Fatalf("ListCheckpoints() error = %v", err)
}
if len(records) != 1 {
t.Fatalf("records = %#v, want 1", records)
}
cpRecord := records[0]
if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), cpRecord.CheckpointID, agentsession.CheckpointStatusAvailable); err != nil {
t.Fatalf("UpdateCheckpointStatus() error = %v", err)
}
if err := os.WriteFile(target, []byte("after baseline"), 0o644); err != nil {
t.Fatalf("WriteFile(after baseline) error = %v", err)
}

if _, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{
SessionID: fixture.session.ID,
CheckpointID: cpRecord.CheckpointID,
Mode: "baseline",
Paths: []string{"./baseline.txt", "baseline.txt"},
}); err != nil {
t.Fatalf("RestoreCheckpoint(baseline) error = %v", err)
}
if got := string(mustReadRuntimeFile(t, target)); got != "before baseline" {
t.Fatalf("baseline restored content = %q, want before baseline", got)
}

payload := readCheckpointRestoredPayload(t, fixture.service.events)
if payload.Mode != "baseline" {
t.Fatalf("restore payload mode = %q, want baseline", payload.Mode)
}
if len(payload.Paths) != 1 || payload.Paths[0] != "baseline.txt" {
t.Fatalf("restore payload paths = %#v, want [baseline.txt]", payload.Paths)
}
if payload.GuardCheckpointID != "" {
t.Fatalf("baseline restore guard checkpoint id = %q, want empty", payload.GuardCheckpointID)
}
}

func TestRestoreCheckpointBaselineRejectsPathsThatNormalizeEmpty(t *testing.T) {
fixture := newRuntimeCheckpointFixture(t)
target := filepath.Join(fixture.workdir, "baseline.txt")
if err := os.WriteFile(target, []byte("before baseline"), 0o644); err != nil {
t.Fatalf("WriteFile(before baseline) error = %v", err)
}
if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil {
t.Fatalf("CapturePreWrite() error = %v", err)
}

state := newRunState("run-baseline-empty-paths", fixture.session)
if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil {
t.Fatalf("createStartOfTurnCheckpoint() error = %v", err)
}
records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{})
if err != nil {
t.Fatalf("ListCheckpoints() error = %v", err)
}
if len(records) != 1 {
t.Fatalf("records = %#v, want 1", records)
}
cpRecord := records[0]
if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), cpRecord.CheckpointID, agentsession.CheckpointStatusAvailable); err != nil {
t.Fatalf("UpdateCheckpointStatus() error = %v", err)
}

_, err = fixture.service.restoreCheckpointBaseline(context.Background(), fixture.session.ID, cpRecord.CheckpointID, []string{"./", " . "})
if err == nil || !strings.Contains(err.Error(), "baseline restore paths required") {
t.Fatalf("restoreCheckpointBaseline() error = %v, want baseline restore paths required", err)
}
}

func TestRestoreCheckpointBaselineWrapsRestoreBaselineError(t *testing.T) {
fixture := newRuntimeCheckpointFixture(t)
target := filepath.Join(fixture.workdir, "baseline.txt")
if err := os.WriteFile(target, []byte("before baseline"), 0o644); err != nil {
t.Fatalf("WriteFile(before baseline) error = %v", err)
}
if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil {
t.Fatalf("CapturePreWrite() error = %v", err)
}

state := newRunState("run-baseline-missing-path", fixture.session)
if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil {
t.Fatalf("createStartOfTurnCheckpoint() error = %v", err)
}
records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{})
if err != nil {
t.Fatalf("ListCheckpoints() error = %v", err)
}
if len(records) != 1 {
t.Fatalf("records = %#v, want 1", records)
}
cpRecord := records[0]
if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), cpRecord.CheckpointID, agentsession.CheckpointStatusAvailable); err != nil {
t.Fatalf("UpdateCheckpointStatus() error = %v", err)
}

_, err = fixture.service.restoreCheckpointBaseline(context.Background(), fixture.session.ID, cpRecord.CheckpointID, []string{"missing.txt"})
if err == nil || !strings.Contains(err.Error(), "baseline restore code") || !strings.Contains(err.Error(), "missing.txt") {
t.Fatalf("restoreCheckpointBaseline() error = %v, want wrapped missing baseline path error", err)
}
}

func TestUndoRestoreCheckpoint_RestoresGuardState(t *testing.T) {
Expand Down
31 changes: 23 additions & 8 deletions internal/runtime/checkpoint_restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,18 +191,22 @@ func (s *Service) RestoreCheckpoint(ctx context.Context, input GatewayRestoreInp
CheckpointID: result.CheckpointID,
SessionID: result.SessionID,
GuardCheckpointID: guardRecord.CheckpointID,
Mode: "exact",
})
return result, nil
}
if mode == "baseline" {
result, err := s.restoreCheckpointBaseline(ctx, input.SessionID, input.CheckpointID, input.Paths)
paths := normalizeBaselineRestorePaths(input.Paths)
result, err := s.restoreCheckpointBaseline(ctx, input.SessionID, input.CheckpointID, paths)
if err != nil {
return RestoreResult{}, err
}
_ = s.emit(ctx, EventCheckpointRestored, "", result.SessionID, CheckpointRestoredPayload{
CheckpointID: result.CheckpointID,
SessionID: result.SessionID,
GuardCheckpointID: "",
Mode: "baseline",
Paths: paths,
})
return result, nil
}
Expand Down Expand Up @@ -351,7 +355,20 @@ func (s *Service) restoreCheckpointBaseline(
if perEditID == "" {
return RestoreResult{}, fmt.Errorf("checkpoint: %s has no code snapshot", checkpointID)
}
relPaths := normalizeBaselineRestorePaths(paths)
if len(relPaths) == 0 {
return RestoreResult{}, fmt.Errorf("checkpoint: baseline restore paths required")
}
if err := s.perEditStore.RestoreBaseline(ctx, perEditID, relPaths); err != nil {
return RestoreResult{}, fmt.Errorf("checkpoint: baseline restore code: %w", err)
}
return RestoreResult{CheckpointID: checkpointID, SessionID: sessionID}, nil
}

// normalizeBaselineRestorePaths 统一 baseline 文件回退路径,保证恢复执行与事件通知使用同一组相对路径。
func normalizeBaselineRestorePaths(paths []string) []string {
relPaths := make([]string, 0, len(paths))
seen := make(map[string]struct{}, len(paths))
for _, path := range paths {
clean := filepath.ToSlash(strings.TrimSpace(path))
for strings.HasPrefix(clean, "./") {
Expand All @@ -360,15 +377,13 @@ func (s *Service) restoreCheckpointBaseline(
if clean == "" || clean == "." {
continue
}
if _, ok := seen[clean]; ok {
continue
}
seen[clean] = struct{}{}
relPaths = append(relPaths, clean)
}
if len(relPaths) == 0 {
return RestoreResult{}, fmt.Errorf("checkpoint: baseline restore paths required")
}
if err := s.perEditStore.RestoreBaseline(ctx, perEditID, relPaths); err != nil {
return RestoreResult{}, fmt.Errorf("checkpoint: baseline restore code: %w", err)
}
return RestoreResult{CheckpointID: checkpointID, SessionID: sessionID}, nil
return relPaths
}

// ListCheckpoints 查询指定会话的 checkpoint 列表。
Expand Down
8 changes: 5 additions & 3 deletions internal/runtime/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,11 @@ type CheckpointWarningPayload struct {

// CheckpointRestoredPayload 描述 checkpoint 恢复成功事件。
type CheckpointRestoredPayload struct {
CheckpointID string `json:"checkpoint_id"`
SessionID string `json:"session_id"`
GuardCheckpointID string `json:"guard_checkpoint_id"`
CheckpointID string `json:"checkpoint_id"`
SessionID string `json:"session_id"`
GuardCheckpointID string `json:"guard_checkpoint_id"`
Mode string `json:"mode,omitempty"`
Paths []string `json:"paths,omitempty"`
}

// CheckpointUndoRestorePayload 描述 restore 撤销事件。
Expand Down
76 changes: 76 additions & 0 deletions internal/tui/services/gateway_stream_client_additional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,82 @@ func TestRestoreRuntimePayloadAdditionalBranches(t *testing.T) {
}
}

func TestRestoreRuntimePayloadCheckpointRestoredKeepsBaselineFields(t *testing.T) {
t.Parallel()

payload, err := restoreRuntimePayload(EventCheckpointRestored, map[string]any{
"checkpoint_id": "cp-baseline",
"session_id": "session-1",
"guard_checkpoint_id": "",
"mode": "baseline",
"paths": []any{"a.txt", "nested/b.txt"},
})
if err != nil {
t.Fatalf("restoreRuntimePayload(CheckpointRestored) error = %v", err)
}
restored, ok := payload.(CheckpointRestoredPayload)
if !ok {
t.Fatalf("payload type = %T, want CheckpointRestoredPayload", payload)
}
if restored.CheckpointID != "cp-baseline" || restored.SessionID != "session-1" {
t.Fatalf("unexpected restored identity payload: %+v", restored)
}
if restored.Mode != "baseline" {
t.Fatalf("restored.Mode = %q, want baseline", restored.Mode)
}
if !reflect.DeepEqual(restored.Paths, []string{"a.txt", "nested/b.txt"}) {
t.Fatalf("restored.Paths = %#v, want [a.txt nested/b.txt]", restored.Paths)
}
if restored.GuardCheckpointID != "" {
t.Fatalf("restored.GuardCheckpointID = %q, want empty for baseline restore", restored.GuardCheckpointID)
}
}

func TestDecodeRuntimeEventCheckpointRestoredKeepsModeAndPaths(t *testing.T) {
t.Parallel()

notification := buildGatewayEventNotification(t, gateway.MessageFrame{
Type: gateway.FrameTypeEvent,
Action: gateway.FrameActionRun,
RunID: "run-1",
SessionID: "session-1",
Payload: map[string]any{
"event_type": "run_progress",
"payload": map[string]any{
"runtime_event_type": string(EventCheckpointRestored),
"payload_version": runtimeEventPayloadVersion,
"turn": 2,
"phase": "restore",
"payload": map[string]any{
"checkpoint_id": "cp-baseline",
"session_id": "session-1",
"guard_checkpoint_id": "",
"mode": "baseline",
"paths": []string{"a.txt", "nested/b.txt"},
},
},
},
})

event, err := decodeRuntimeEventFromGatewayNotification(notification)
if err != nil {
t.Fatalf("decodeRuntimeEventFromGatewayNotification() error = %v", err)
}
if event.Type != EventCheckpointRestored || event.RunID != "run-1" || event.SessionID != "session-1" || event.Turn != 2 || event.Phase != "restore" {
t.Fatalf("unexpected event metadata: %+v", event)
}
restored, ok := event.Payload.(CheckpointRestoredPayload)
if !ok {
t.Fatalf("event.Payload type = %T, want CheckpointRestoredPayload", event.Payload)
}
if restored.CheckpointID != "cp-baseline" || restored.Mode != "baseline" {
t.Fatalf("unexpected checkpoint restored payload: %+v", restored)
}
if !reflect.DeepEqual(restored.Paths, []string{"a.txt", "nested/b.txt"}) {
t.Fatalf("restored.Paths = %#v, want [a.txt nested/b.txt]", restored.Paths)
}
}

func TestStreamHelperBranches(t *testing.T) {
t.Parallel()

Expand Down
8 changes: 5 additions & 3 deletions internal/tui/services/runtime_contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,11 @@ type CheckpointWarningPayload struct {

// CheckpointRestoredPayload 描述 checkpoint 恢复成功事件。
type CheckpointRestoredPayload struct {
CheckpointID string `json:"checkpoint_id"`
SessionID string `json:"session_id"`
GuardCheckpointID string `json:"guard_checkpoint_id"`
CheckpointID string `json:"checkpoint_id"`
SessionID string `json:"session_id"`
GuardCheckpointID string `json:"guard_checkpoint_id"`
Mode string `json:"mode,omitempty"`
Paths []string `json:"paths,omitempty"`
}

// CheckpointUndoRestorePayload 描述 restore 撤销事件。
Expand Down
2 changes: 2 additions & 0 deletions web/src/api/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ export interface CheckpointRestoredPayload {
checkpoint_id: string;
session_id: string;
guard_checkpoint_id: string;
mode?: "exact" | "baseline" | string;
paths?: string[];
}

export interface CheckpointUndoRestorePayload {
Expand Down
Loading
Loading