Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions internal/checkpoint/checkpoint_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,99 @@ func TestSQLiteCheckpointStoreCreateRestoreAndResume(t *testing.T) {
}
}

func TestSQLiteCheckpointStoreRestoreCheckpointUpdatesStatuses(t *testing.T) {
t.Parallel()

fixture := newCheckpointStoreFixture(t)
loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_restore_status", fixture.workspaceRoot)

recordA, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession(
t, loaded, "cp-available", session.CheckpointReasonPreWrite, time.Now().Add(-2*time.Minute),
))
if err != nil {
t.Fatalf("CreateCheckpoint(cp-available) error = %v", err)
}
recordB, err := fixture.checkpointStore.CreateCheckpoint(context.Background(), checkpointInputFromSession(
t, loaded, "cp-restored", session.CheckpointReasonEndOfTurn, time.Now().Add(-time.Minute),
))
if err != nil {
t.Fatalf("CreateCheckpoint(cp-restored) error = %v", err)
}

if err := fixture.checkpointStore.RestoreCheckpoint(context.Background(), RestoreCheckpointInput{
SessionID: loaded.ID,
Head: loaded.HeadSnapshot(),
Messages: loaded.Messages,
UpdatedAt: time.Now(),
MarkAvailableIDs: []string{recordA.CheckpointID},
MarkRestoredIDs: []string{recordB.CheckpointID},
}); err != nil {
t.Fatalf("RestoreCheckpoint() error = %v", err)
}

availableRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), recordA.CheckpointID)
if err != nil {
t.Fatalf("GetCheckpoint(cp-available) error = %v", err)
}
if availableRecord.Status != session.CheckpointStatusAvailable {
t.Fatalf("available record status = %q, want available", availableRecord.Status)
}

restoredRecord, _, err := fixture.checkpointStore.GetCheckpoint(context.Background(), recordB.CheckpointID)
if err != nil {
t.Fatalf("GetCheckpoint(cp-restored) error = %v", err)
}
if restoredRecord.Status != session.CheckpointStatusRestored {
t.Fatalf("restored record status = %q, want restored", restoredRecord.Status)
}
}

func TestSQLiteCheckpointStoreGetCheckpointWithoutSessionSnapshot(t *testing.T) {
t.Parallel()

fixture := newCheckpointStoreFixture(t)
loaded := createCheckpointTestSession(t, fixture.sessionStore, "session_record_only", fixture.workspaceRoot)
db, err := fixture.checkpointStore.ensureDB(context.Background())
if err != nil {
t.Fatalf("ensureDB() error = %v", err)
}

if _, err := db.ExecContext(context.Background(), `
INSERT INTO checkpoint_records (
id, workspace_key, session_id, run_id, workdir, created_at_ms,
reason, code_checkpoint_ref, session_checkpoint_ref, resume_checkpoint_ref,
transcript_revision, restorable, status
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
"cp-record-only",
session.WorkspacePathKey(loaded.Workdir),
loaded.ID,
"run-record-only",
loaded.Workdir,
time.Now().UnixMilli(),
string(session.CheckpointReasonPreWrite),
"",
"",
"",
0,
1,
string(session.CheckpointStatusAvailable),
); err != nil {
t.Fatalf("insert checkpoint record error = %v", err)
}

record, sessionCP, err := fixture.checkpointStore.GetCheckpoint(context.Background(), "cp-record-only")
if err != nil {
t.Fatalf("GetCheckpoint() error = %v", err)
}
if record.CheckpointID != "cp-record-only" {
t.Fatalf("record id = %q, want cp-record-only", record.CheckpointID)
}
if sessionCP != nil {
t.Fatalf("session checkpoint = %#v, want nil", sessionCP)
}
}

func TestSQLiteCheckpointStorePruneAndRepair(t *testing.T) {
t.Parallel()

Expand Down
122 changes: 122 additions & 0 deletions internal/runtime/checkpoint_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,91 @@ func TestRestoreCheckpointRejectsInvalidRequestAndMismatchedSession(t *testing.T
}); err == nil || !strings.Contains(err.Error(), "session mismatch") {
t.Fatalf("RestoreCheckpoint() error = %v, want session mismatch", err)
}

for _, tc := range []struct {
name string
record agentsession.CheckpointRecord
sessionCP *agentsession.SessionCheckpoint
wantSubstr string
}{
{
name: "status must be available",
record: agentsession.CheckpointRecord{
CheckpointID: "cp-status",
SessionID: "session-1",
Status: agentsession.CheckpointStatusRestored,
Restorable: true,
},
sessionCP: &agentsession.SessionCheckpoint{HeadJSON: `{}`, MessagesJSON: `[]`},
wantSubstr: "status is restored",
},
{
name: "checkpoint must be restorable",
record: agentsession.CheckpointRecord{
CheckpointID: "cp-restorable",
SessionID: "session-1",
Status: agentsession.CheckpointStatusAvailable,
Restorable: false,
},
sessionCP: &agentsession.SessionCheckpoint{HeadJSON: `{}`, MessagesJSON: `[]`},
wantSubstr: "not restorable",
},
{
name: "session checkpoint data is required",
record: agentsession.CheckpointRecord{
CheckpointID: "cp-session-data",
SessionID: "session-1",
Status: agentsession.CheckpointStatusAvailable,
Restorable: true,
},
sessionCP: nil,
wantSubstr: "no session checkpoint data",
},
{
name: "head json must be valid",
record: agentsession.CheckpointRecord{
CheckpointID: "cp-head-json",
SessionID: "session-1",
Status: agentsession.CheckpointStatusAvailable,
Restorable: true,
},
sessionCP: &agentsession.SessionCheckpoint{HeadJSON: `{invalid`, MessagesJSON: `[]`},
wantSubstr: "unmarshal head",
},
{
name: "messages json must be valid",
record: agentsession.CheckpointRecord{
CheckpointID: "cp-messages-json",
SessionID: "session-1",
Status: agentsession.CheckpointStatusAvailable,
Restorable: true,
},
sessionCP: &agentsession.SessionCheckpoint{HeadJSON: `{}`, MessagesJSON: `{invalid`},
wantSubstr: "unmarshal messages",
},
} {
t.Run(tc.name, func(t *testing.T) {
fixture := newRuntimeCheckpointFixture(t)
tc.record.SessionID = fixture.session.ID
spy := &checkpointStoreSpy{
getRecord: tc.record,
getSessionCP: tc.sessionCP,
}
service := &Service{
sessionStore: fixture.sessionStore,
checkpointStore: spy,
perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), fixture.workdir),
events: make(chan RuntimeEvent, 8),
}
_, err := service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{
SessionID: fixture.session.ID,
CheckpointID: tc.record.CheckpointID,
})
if err == nil || !strings.Contains(err.Error(), tc.wantSubstr) {
t.Fatalf("RestoreCheckpoint() error = %v, want substring %q", err, tc.wantSubstr)
}
})
}
}

func TestCheckpointDiffSelectsLatestCodeCheckpointAndRejectsSessionOnlyTarget(t *testing.T) {
Expand Down Expand Up @@ -654,6 +739,43 @@ func TestCheckpointDiffSelectsLatestCodeCheckpointAndRejectsSessionOnlyTarget(t
}
}

func TestCheckpointDiffRejectsMissingStateAndReturnsEmptyWhenNoPreviousSnapshot(t *testing.T) {
service := &Service{}
if _, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{}); err == nil {
t.Fatal("expected store availability error")
}

service = &Service{
checkpointStore: &checkpointStoreSpy{},
perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), t.TempDir()),
}
if _, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{}); err == nil || !strings.Contains(err.Error(), "session_id required") {
t.Fatalf("CheckpointDiff() error = %v, want session_id validation", err)
}

spy := &checkpointStoreSpy{
listRecords: []agentsession.CheckpointRecord{
{
CheckpointID: "cp-only",
SessionID: "session-1",
CreatedAt: time.Now().UTC(),
CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-only"),
},
},
}
service = &Service{
checkpointStore: spy,
perEditStore: checkpoint.NewPerEditSnapshotStore(t.TempDir(), t.TempDir()),
}
result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{SessionID: "session-1"})
if err != nil {
t.Fatalf("CheckpointDiff() error = %v", err)
}
if result.CheckpointID != "cp-only" || result.PrevCheckpointID != "" || result.Patch != "" {
t.Fatalf("CheckpointDiff() = %#v, want latest checkpoint without previous diff", result)
}
}

func mustReadRuntimeFile(t *testing.T, path string) []byte {
t.Helper()
data, err := os.ReadFile(path)
Expand Down
40 changes: 40 additions & 0 deletions internal/runtime/tool_diff_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,31 @@ func TestBuildToolDiffPayload(t *testing.T) {
}

func TestToolExecutionHelperFunctions(t *testing.T) {
t.Run("toolCallTouchedPaths covers write and move payloads", func(t *testing.T) {
writePaths := toolCallTouchedPaths(providertypes.ToolCall{
Name: tools.ToolNameFilesystemWriteFile,
Arguments: `{"path":" docs/readme.md "}`,
}, "/repo")
if len(writePaths) != 1 || writePaths[0] != "/repo/docs/readme.md" {
t.Fatalf("write toolCallTouchedPaths() = %#v", writePaths)
}

movePaths := toolCallTouchedPaths(providertypes.ToolCall{
Name: tools.ToolNameFilesystemMoveFile,
Arguments: `{"source_path":"src/a.txt","destination_path":" /tmp/b.txt "}`,
}, "/repo")
if len(movePaths) != 2 || movePaths[0] != "/repo/src/a.txt" || movePaths[1] != "/tmp/b.txt" {
t.Fatalf("move toolCallTouchedPaths() = %#v", movePaths)
}

if got := toolCallTouchedPaths(providertypes.ToolCall{
Name: tools.ToolNameFilesystemCopyFile,
Arguments: `{invalid`,
}, "/repo"); got != nil {
t.Fatalf("malformed toolCallTouchedPaths() = %#v, want nil", got)
}
})

t.Run("toolResultMultiDiffs parses valid entries", func(t *testing.T) {
entries, ok := toolResultMultiDiffs(map[string]any{
"tool_diffs": []map[string]any{
Expand Down Expand Up @@ -158,6 +183,21 @@ func TestEmitHelpersPublishExpectedEvents(t *testing.T) {
t.Fatalf("unexpected bash payload: %#v", payload)
}

service.emitBashSideEffectEvent(
context.Background(),
state,
providertypes.ToolCall{ID: "tool-2"},
"touch noop",
checkpoint.FingerprintDiff{},
nil,
nil,
)
select {
case extra := <-service.events:
t.Fatalf("unexpected empty bash side effect event: %#v", extra)
default:
}

service.emitToolDiffs(context.Background(), state, toolExecutionSummary{
Results: []tools.ToolResult{
{
Expand Down
Loading