diff --git a/apps/code/src/main/db/migrations/0006_fork_relationships.sql b/apps/code/src/main/db/migrations/0006_fork_relationships.sql new file mode 100644 index 000000000..372a41fcb --- /dev/null +++ b/apps/code/src/main/db/migrations/0006_fork_relationships.sql @@ -0,0 +1,12 @@ +CREATE TABLE `fork_relationships` ( + `id` text PRIMARY KEY NOT NULL, + `forked_task_id` text NOT NULL, + `source_task_id` text NOT NULL, + `source_task_run_id` text NOT NULL, + `source_task_title` text NOT NULL, + `fork_at_message_index` integer NOT NULL, + `forked_at` text NOT NULL, + `created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL +); +--> statement-breakpoint +CREATE UNIQUE INDEX `fork_relationships_forked_task_id_unique` ON `fork_relationships` (`forked_task_id`); diff --git a/apps/code/src/main/db/migrations/meta/_journal.json b/apps/code/src/main/db/migrations/meta/_journal.json index 5ea0be65d..34b127a59 100644 --- a/apps/code/src/main/db/migrations/meta/_journal.json +++ b/apps/code/src/main/db/migrations/meta/_journal.json @@ -43,6 +43,13 @@ "when": 1775755977659, "tag": "0005_youthful_scarlet_spider", "breakpoints": true + }, + { + "idx": 6, + "version": "6", + "when": 1779508277499, + "tag": "0006_fork_relationships", + "breakpoints": true } ] } diff --git a/apps/code/src/main/db/repositories/fork-relationship-repository.ts b/apps/code/src/main/db/repositories/fork-relationship-repository.ts new file mode 100644 index 000000000..6b0983539 --- /dev/null +++ b/apps/code/src/main/db/repositories/fork-relationship-repository.ts @@ -0,0 +1,54 @@ +import { eq } from "drizzle-orm"; +import { inject, injectable } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { forkRelationships } from "../schema"; +import type { DatabaseService } from "../service"; + +export type ForkRelationship = typeof forkRelationships.$inferSelect; + +export interface CreateForkRelationshipData { + forkedTaskId: string; + sourceTaskId: string; + sourceTaskRunId: string; + sourceTaskTitle: string; + forkAtMessageIndex: number; + forkedAt: string; +} + +export interface IForkRelationshipRepository { + create(data: CreateForkRelationshipData): ForkRelationship; + findByForkedTaskId(forkedTaskId: string): ForkRelationship | null; +} + +@injectable() +export class ForkRelationshipRepository implements IForkRelationshipRepository { + constructor( + @inject(MAIN_TOKENS.DatabaseService) + private readonly databaseService: DatabaseService, + ) {} + + private get db() { + return this.databaseService.db; + } + + create(data: CreateForkRelationshipData): ForkRelationship { + return this.db + .insert(forkRelationships) + .values({ + id: crypto.randomUUID(), + ...data, + }) + .returning() + .get(); + } + + findByForkedTaskId(forkedTaskId: string): ForkRelationship | null { + return ( + this.db + .select() + .from(forkRelationships) + .where(eq(forkRelationships.forkedTaskId, forkedTaskId)) + .get() ?? null + ); + } +} diff --git a/apps/code/src/main/db/schema.ts b/apps/code/src/main/db/schema.ts index 8e4f14404..1b5950563 100644 --- a/apps/code/src/main/db/schema.ts +++ b/apps/code/src/main/db/schema.ts @@ -88,6 +88,18 @@ export const authSessions = sqliteTable("auth_sessions", { updatedAt: updatedAt(), }); +export const forkRelationships = sqliteTable("fork_relationships", { + id: id(), + forkedTaskId: text().notNull().unique(), + sourceTaskId: text().notNull(), + sourceTaskRunId: text().notNull(), + /** Title of the source task captured at fork time, shown if parent is later deleted. */ + sourceTaskTitle: text().notNull(), + forkAtMessageIndex: integer().notNull(), + forkedAt: text().notNull(), + createdAt: createdAt(), +}); + export const authPreferences = sqliteTable( "auth_preferences", { diff --git a/apps/code/src/main/di/container.ts b/apps/code/src/main/di/container.ts index 959ea1431..242968610 100644 --- a/apps/code/src/main/di/container.ts +++ b/apps/code/src/main/di/container.ts @@ -4,6 +4,7 @@ import { Container } from "inversify"; import { ArchiveRepository } from "../db/repositories/archive-repository"; import { AuthPreferenceRepository } from "../db/repositories/auth-preference-repository"; import { AuthSessionRepository } from "../db/repositories/auth-session-repository"; +import { ForkRelationshipRepository } from "../db/repositories/fork-relationship-repository"; import { RepositoryRepository } from "../db/repositories/repository-repository"; import { SuspensionRepositoryImpl } from "../db/repositories/suspension-repository"; import { WorkspaceRepository } from "../db/repositories/workspace-repository"; @@ -41,6 +42,7 @@ import { FileWatcherService } from "../services/file-watcher/service"; import { FocusService } from "../services/focus/service"; import { FocusSyncService } from "../services/focus/sync-service"; import { FoldersService } from "../services/folders/service"; +import { ForkService } from "../services/fork/service"; import { FsService } from "../services/fs/service"; import { GitService } from "../services/git/service"; import { GitHubIntegrationService } from "../services/github-integration/service"; @@ -95,6 +97,9 @@ container.bind(MAIN_TOKENS.AuthSessionRepository).to(AuthSessionRepository); container.bind(MAIN_TOKENS.RepositoryRepository).to(RepositoryRepository); container.bind(MAIN_TOKENS.WorkspaceRepository).to(WorkspaceRepository); container.bind(MAIN_TOKENS.WorktreeRepository).to(WorktreeRepository); +container + .bind(MAIN_TOKENS.ForkRelationshipRepository) + .to(ForkRelationshipRepository); container.bind(MAIN_TOKENS.ArchiveRepository).to(ArchiveRepository); container.bind(MAIN_TOKENS.SuspensionRepository).to(SuspensionRepositoryImpl); container.bind(MAIN_TOKENS.AgentAuthAdapter).to(AgentAuthAdapter); @@ -142,5 +147,6 @@ container.bind(MAIN_TOKENS.TaskLinkService).to(TaskLinkService); container.bind(MAIN_TOKENS.InboxLinkService).to(InboxLinkService); container.bind(MAIN_TOKENS.WatcherRegistryService).to(WatcherRegistryService); container.bind(MAIN_TOKENS.WorkspaceService).to(WorkspaceService); +container.bind(MAIN_TOKENS.ForkService).to(ForkService); container.bind(MAIN_TOKENS.SettingsStore).toConstantValue(settingsStore); diff --git a/apps/code/src/main/di/tokens.ts b/apps/code/src/main/di/tokens.ts index c8225b2b1..ffb790748 100644 --- a/apps/code/src/main/di/tokens.ts +++ b/apps/code/src/main/di/tokens.ts @@ -32,6 +32,7 @@ export const MAIN_TOKENS = Object.freeze({ RepositoryRepository: Symbol.for("Main.RepositoryRepository"), WorkspaceRepository: Symbol.for("Main.WorkspaceRepository"), WorktreeRepository: Symbol.for("Main.WorktreeRepository"), + ForkRelationshipRepository: Symbol.for("Main.ForkRelationshipRepository"), ArchiveRepository: Symbol.for("Main.ArchiveRepository"), SuspensionRepository: Symbol.for("Main.SuspensionRepository"), @@ -76,5 +77,6 @@ export const MAIN_TOKENS = Object.freeze({ EnvironmentService: Symbol.for("Main.EnvironmentService"), ProvisioningService: Symbol.for("Main.ProvisioningService"), WorkspaceService: Symbol.for("Main.WorkspaceService"), + ForkService: Symbol.for("Main.ForkService"), EnrichmentService: Symbol.for("Main.EnrichmentService"), }); diff --git a/apps/code/src/main/services/agent/service.ts b/apps/code/src/main/services/agent/service.ts index 4c3eecb07..a0546d278 100644 --- a/apps/code/src/main/services/agent/service.ts +++ b/apps/code/src/main/services/agent/service.ts @@ -37,6 +37,7 @@ import { import { getLlmGatewayUrl } from "@posthog/agent/posthog-api"; import { extractCreatedPrUrl } from "@posthog/agent/pr-url-detector"; import type * as AgentTypes from "@posthog/agent/types"; +import { createGitClient } from "@posthog/git/client"; import { getCurrentBranch } from "@posthog/git/queries"; import type { IAppMeta } from "@posthog/platform/app-meta"; import type { IBundledResources } from "@posthog/platform/bundled-resources"; @@ -890,6 +891,44 @@ When creating pull requests, add the following footer at the end of the PR descr if (!this.hasActiveSessions()) { this.emit(AgentServiceEvent.SessionsIdle, undefined); } + + void this.captureAgentCheckpoint( + sessionId, + session.taskId, + session.repoPath, + session.agent, + ); + } + } + + private async captureAgentCheckpoint( + taskRunId: string, + taskId: string, + repoPath: string, + agent: Agent, + ): Promise { + if (taskId === "__preview__") return; + + try { + const git = createGitClient(repoPath); + const headSha = await git.revparse(["HEAD"]); + + const posthogAPI = agent.getPosthogAPI(); + if (!posthogAPI) return; + + await posthogAPI.appendTaskRunLog(taskId, taskRunId, [ + { + type: "notification", + timestamp: new Date().toISOString(), + notification: { + jsonrpc: "2.0", + method: "_posthog/agent_checkpoint", + params: { headSha }, + }, + }, + ]); + } catch (err) { + log.debug("Failed to capture agent checkpoint", { taskRunId, err }); } } diff --git a/apps/code/src/main/services/fork/service.ts b/apps/code/src/main/services/fork/service.ts new file mode 100644 index 000000000..cb803b12a --- /dev/null +++ b/apps/code/src/main/services/fork/service.ts @@ -0,0 +1,243 @@ +import fs from "node:fs/promises"; +import { homedir } from "node:os"; +import path from "node:path"; +import { + conversationTurnsToJsonlEntries, + filterEntriesUpToMessage, + getSessionJsonlPath, + rebuildConversation, + selectRecentTurns, +} from "@posthog/agent/adapters/claude/session/jsonl-hydration"; +import { PostHogAPIClient } from "@posthog/agent/posthog-api"; +import type * as AgentTypes from "@posthog/agent/types"; +import { createGitClient } from "@posthog/git/client"; +import { inject, injectable } from "inversify"; +import { v7 as uuidv7 } from "uuid"; +import type { ForkRelationshipRepository } from "../../db/repositories/fork-relationship-repository"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { logger } from "../../utils/logger"; +import type { AgentAuthAdapter } from "../agent/auth-adapter"; +import type { WorkspaceService } from "../workspace/service"; + +const log = logger.scope("fork-service"); + +export interface PrepareForkInput { + sourceTaskId: string; + sourceTaskRunId: string; + /** Title of the source task — stored so we can show it even if the task is later deleted. */ + sourceTaskTitle: string; + forkAtMessageIndex: number; + newTaskId: string; + newTaskRunId: string; + sourceWorktreePath: string; + mainRepoPath: string; + apiHost: string; + projectId: number; + model?: string; +} + +export interface PrepareForkResult { + newWorktreePath: string; + newSessionId: string; +} + +@injectable() +export class ForkService { + constructor( + @inject(MAIN_TOKENS.AgentAuthAdapter) + private readonly agentAuthAdapter: AgentAuthAdapter, + @inject(MAIN_TOKENS.WorkspaceService) + private readonly workspaceService: WorkspaceService, + @inject(MAIN_TOKENS.ForkRelationshipRepository) + private readonly forkRepo: ForkRelationshipRepository, + ) {} + + async prepareFork(input: PrepareForkInput): Promise { + const { + sourceTaskId, + sourceTaskRunId, + sourceTaskTitle, + forkAtMessageIndex, + newTaskId, + newTaskRunId, + sourceWorktreePath, + mainRepoPath, + apiHost, + projectId, + model, + } = input; + + log.info("Preparing fork", { + sourceTaskId, + sourceTaskRunId, + forkAtMessageIndex, + newTaskId, + }); + + const posthogConfig = this.agentAuthAdapter.createPosthogConfig({ + apiHost, + projectId, + }); + const posthogAPI = new PostHogAPIClient(posthogConfig); + + // 1. Fetch the full S3 log for the source run + const sourceTaskRun = await posthogAPI.getTaskRun( + sourceTaskId, + sourceTaskRunId, + ); + const allEntries = await posthogAPI.fetchTaskRunLogs(sourceTaskRun); + + // 2. Slice entries up to and including the response to message N + const entries = filterEntriesUpToMessage(allEntries, forkAtMessageIndex); + + // 3. Find git HEAD SHA from the last checkpoint entry in the filtered set + const headSha = + this.extractHeadSha(entries) ?? + (await this.getCurrentHead(sourceWorktreePath)); + + if (!headSha) { + throw new Error( + `Cannot determine git HEAD for fork at message ${forkAtMessageIndex} (sourceTaskRunId: ${sourceTaskRunId})`, + ); + } + + // 4. Create the new worktree + workspace record + const workspace = await this.workspaceService.createWorkspaceFromFork({ + taskId: newTaskId, + mainRepoPath, + headSha, + }); + + // 5. Record the lineage so the UI can show "Forked from: [title]" + // Non-fatal: if the migration hasn't run yet the banner simply won't show. + try { + this.forkRepo.create({ + forkedTaskId: newTaskId, + sourceTaskId, + sourceTaskRunId, + sourceTaskTitle, + forkAtMessageIndex, + forkedAt: new Date().toISOString(), + }); + } catch (err) { + log.warn("Could not record fork lineage (migration pending?)", { err }); + } + + const newWorktreePath = workspace.worktree?.worktreePath; + if (!newWorktreePath) { + throw new Error(`Fork workspace creation did not return a worktree path`); + } + + // 6. Write the truncated conversation as a JSONL file for the new SDK session + const newSessionId = uuidv7(); + await this.writeSessionJsonl({ + entries, + sessionId: newSessionId, + cwd: newWorktreePath, + model, + }); + + // 7. Write truncated log to local cache so the new run can be resumed instantly. + // Include newSessionId so the session service can discover which JSONL file to load. + await this.writeLocalCache(newTaskRunId, entries, newSessionId); + + log.info("Fork prepared", { + newTaskId, + newTaskRunId, + newWorktreePath, + newSessionId, + headSha, + entriesIncluded: entries.length, + }); + + return { newWorktreePath, newSessionId }; + } + + /** Extract the HEAD SHA from the last `_posthog/agent_checkpoint` entry. */ + private extractHeadSha(entries: AgentTypes.StoredEntry[]): string | null { + for (let i = entries.length - 1; i >= 0; i--) { + const entry = entries[i]; + if (entry.notification?.method === "_posthog/agent_checkpoint") { + const headSha = (entry.notification.params as Record) + ?.headSha; + if (typeof headSha === "string" && headSha.trim()) { + return headSha.trim(); + } + } + } + return null; + } + + /** Fallback: get the current HEAD SHA from the source worktree on disk. */ + private async getCurrentHead(worktreePath: string): Promise { + try { + const git = createGitClient(worktreePath); + return await git.revparse(["HEAD"]); + } catch { + return null; + } + } + + private async writeSessionJsonl(params: { + entries: AgentTypes.StoredEntry[]; + sessionId: string; + cwd: string; + model?: string; + }): Promise { + const { entries, sessionId, cwd, model } = params; + + const allTurns = rebuildConversation(entries); + if (allTurns.length === 0) return; + + const turns = selectRecentTurns(allTurns); + const lines = conversationTurnsToJsonlEntries(turns, { + sessionId, + cwd, + model, + }); + + const jsonlPath = getSessionJsonlPath(sessionId, cwd); + await fs.mkdir(path.dirname(jsonlPath), { recursive: true }); + + const tmpPath = `${jsonlPath}.tmp.${Date.now()}`; + await fs.writeFile(tmpPath, `${lines.join("\n")}\n`); + await fs.rename(tmpPath, jsonlPath); + + log.info("Wrote fork JSONL", { sessionId, turns: turns.length }); + } + + private async writeLocalCache( + taskRunId: string, + entries: AgentTypes.StoredEntry[], + newSessionId: string, + ): Promise { + if (entries.length === 0) return; + + const sessionDir = path.join( + homedir(), + ".posthog-code", + "sessions", + taskRunId, + ); + const logPath = path.join(sessionDir, "logs.ndjson"); + + await fs.mkdir(sessionDir, { recursive: true }); + + // Append a synthetic entry so parseLogContent picks up the fork's session ID, + // not the source session's ID that may appear in the copied entries. + const metaEntry: AgentTypes.StoredEntry = { + type: "notification", + timestamp: new Date().toISOString(), + notification: { + jsonrpc: "2.0", + method: "_posthog/fork_session_meta", + params: { sdkSessionId: newSessionId }, + }, + }; + const allEntries = [...entries, metaEntry]; + const lines = `${allEntries.map((e) => JSON.stringify(e)).join("\n")}\n`; + await fs.writeFile(logPath, lines, "utf-8"); + + log.info("Wrote fork local cache", { taskRunId, entries: entries.length }); + } +} diff --git a/apps/code/src/main/services/workspace/service.ts b/apps/code/src/main/services/workspace/service.ts index 10ddfc363..58e941d76 100644 --- a/apps/code/src/main/services/workspace/service.ts +++ b/apps/code/src/main/services/workspace/service.ts @@ -645,6 +645,50 @@ export class WorkspaceService extends TypedEventEmitter }; } + async createWorkspaceFromFork(params: { + taskId: string; + mainRepoPath: string; + headSha: string; + }): Promise { + const { taskId, mainRepoPath, headSha } = params; + + const worktreeBasePath = getWorktreeLocation(); + const worktreeManager = new WorktreeManager({ + mainRepoPath, + worktreeBasePath, + }); + + const worktree = + await worktreeManager.createDetachedWorktreeAtCommit(headSha); + + const repository = this.repositoryRepo.findByPath(mainRepoPath); + const repositoryId = repository?.id ?? null; + + const createdWorkspace = this.workspaceRepo.create({ + taskId, + repositoryId, + mode: "worktree", + }); + + this.worktreeRepo.create({ + workspaceId: createdWorkspace.id, + name: worktree.worktreeName, + path: worktree.worktreePath, + }); + + log.info( + `Created fork workspace for task ${taskId} at ${worktree.worktreePath} (sha: ${headSha})`, + ); + + return { + taskId, + mode: "worktree", + worktree, + branchName: worktree.branchName || null, + linkedBranch: null, + }; + } + async deleteWorkspace(taskId: string, mainRepoPath: string): Promise { log.info(`Deleting workspace for task ${taskId}`); diff --git a/apps/code/src/main/trpc/router.ts b/apps/code/src/main/trpc/router.ts index 75a5c85c2..31d941821 100644 --- a/apps/code/src/main/trpc/router.ts +++ b/apps/code/src/main/trpc/router.ts @@ -13,6 +13,7 @@ import { externalAppsRouter } from "./routers/external-apps"; import { fileWatcherRouter } from "./routers/file-watcher"; import { focusRouter } from "./routers/focus"; import { foldersRouter } from "./routers/folders"; +import { forkRouter } from "./routers/fork"; import { fsRouter } from "./routers/fs"; import { gitRouter } from "./routers/git"; import { githubIntegrationRouter } from "./routers/github-integration"; @@ -52,6 +53,7 @@ export const trpcRouter = router({ externalApps: externalAppsRouter, fileWatcher: fileWatcherRouter, focus: focusRouter, + fork: forkRouter, folders: foldersRouter, fs: fsRouter, git: gitRouter, diff --git a/apps/code/src/main/trpc/routers/fork.ts b/apps/code/src/main/trpc/routers/fork.ts new file mode 100644 index 000000000..6c8b6aaf7 --- /dev/null +++ b/apps/code/src/main/trpc/routers/fork.ts @@ -0,0 +1,66 @@ +import { z } from "zod"; +import type { IForkRelationshipRepository } from "../../db/repositories/fork-relationship-repository"; +import { container } from "../../di/container"; +import { MAIN_TOKENS } from "../../di/tokens"; +import type { ForkService } from "../../services/fork/service"; +import { publicProcedure, router } from "../trpc"; + +const getForkService = () => + container.get(MAIN_TOKENS.ForkService); +const getForkRepo = () => + container.get( + MAIN_TOKENS.ForkRelationshipRepository, + ); + +const prepareForkInput = z.object({ + sourceTaskId: z.string(), + sourceTaskRunId: z.string(), + sourceTaskTitle: z.string(), + forkAtMessageIndex: z.number().int().nonnegative(), + newTaskId: z.string(), + newTaskRunId: z.string(), + sourceWorktreePath: z.string(), + mainRepoPath: z.string(), + apiHost: z.string(), + projectId: z.number(), + model: z.string().optional(), +}); + +const prepareForkOutput = z.object({ + newWorktreePath: z.string(), + newSessionId: z.string(), +}); + +const forkRelationshipOutput = z + .object({ + forkedTaskId: z.string(), + sourceTaskId: z.string(), + sourceTaskRunId: z.string(), + sourceTaskTitle: z.string(), + forkAtMessageIndex: z.number(), + forkedAt: z.string(), + }) + .nullable(); + +export const forkRouter = router({ + prepareFork: publicProcedure + .input(prepareForkInput) + .output(prepareForkOutput) + .mutation(({ input }) => getForkService().prepareFork(input)), + + getForkRelationship: publicProcedure + .input(z.object({ taskId: z.string() })) + .output(forkRelationshipOutput) + .query(({ input }) => { + const rel = getForkRepo().findByForkedTaskId(input.taskId); + if (!rel) return null; + return { + forkedTaskId: rel.forkedTaskId, + sourceTaskId: rel.sourceTaskId, + sourceTaskRunId: rel.sourceTaskRunId, + sourceTaskTitle: rel.sourceTaskTitle, + forkAtMessageIndex: rel.forkAtMessageIndex, + forkedAt: rel.forkedAt, + }; + }), +}); diff --git a/apps/code/src/renderer/features/sessions/components/ConversationView.tsx b/apps/code/src/renderer/features/sessions/components/ConversationView.tsx index 6d0773c56..8676e5827 100644 --- a/apps/code/src/renderer/features/sessions/components/ConversationView.tsx +++ b/apps/code/src/renderer/features/sessions/components/ConversationView.tsx @@ -1,6 +1,7 @@ import { CHAT_CONTENT_MAX_WIDTH } from "@features/sessions/constants"; import { useContextUsage } from "@features/sessions/hooks/useContextUsage"; import { useConversationSearch } from "@features/sessions/hooks/useConversationSearch"; +import { useForkSession } from "@features/sessions/hooks/useForkSession"; import { sessionStoreSetters, useOptimisticItemsForTask, @@ -112,6 +113,7 @@ export function ConversationView({ const optimisticItems = useOptimisticItemsForTask(taskId); const session = useSessionForTask(taskId); const pausedDurationMs = session?.pausedDurationMs ?? 0; + const { fork } = useForkSession({ taskId, task }); const queuedItems = useMemo[]>( () => @@ -136,6 +138,17 @@ export function ConversationView({ [conversationItems, optimisticItems, queuedItems, isCloud], ); + const userMessageIndexMap = useMemo(() => { + const map = new Map(); + let count = 0; + for (const item of items) { + if (item.type === "user_message") { + map.set(item.id, count++); + } + } + return map; + }, [items]); + // Keep MCP App tool call items mounted so their iframes and bridges // survive scrolling out of the virtualized viewport. const mcpAppIndices = useMemo(() => { @@ -184,7 +197,8 @@ export function ConversationView({ const renderItem = useCallback( (item: ConversationItem) => { switch (item.type) { - case "user_message": + case "user_message": { + const messageIndex = userMessageIndexMap.get(item.id) ?? 0; return ( void fork(messageIndex) : undefined} /> ); + } case "git_action": return ; case "skill_button_action": @@ -239,7 +255,15 @@ export function ConversationView({ ); } }, - [repoPath, taskId, slackThreadUrl, firstUserMessageId, initialItemIds], + [ + repoPath, + taskId, + slackThreadUrl, + firstUserMessageId, + initialItemIds, + userMessageIndexMap, + fork, + ], ); const getItemKey = useCallback((item: ConversationItem) => item.id, []); diff --git a/apps/code/src/renderer/features/sessions/components/ForkedFromBanner.tsx b/apps/code/src/renderer/features/sessions/components/ForkedFromBanner.tsx new file mode 100644 index 000000000..1f457c6ab --- /dev/null +++ b/apps/code/src/renderer/features/sessions/components/ForkedFromBanner.tsx @@ -0,0 +1,57 @@ +import { GitBranch } from "@phosphor-icons/react"; +import { Flex, Text } from "@radix-ui/themes"; +import { useTRPC } from "@renderer/trpc/client"; +import type { Task } from "@shared/types"; +import { useNavigationStore } from "@stores/navigationStore"; +import { useQuery } from "@tanstack/react-query"; + +interface ForkedFromBannerProps { + taskId: string; +} + +export function ForkedFromBanner({ taskId }: ForkedFromBannerProps) { + const trpc = useTRPC(); + const navigateToTask = useNavigationStore((s) => s.navigateToTask); + + const { data: relationship } = useQuery( + trpc.fork.getForkRelationship.queryOptions({ taskId }), + ); + + if (!relationship) return null; + + const handleClick = () => { + const placeholder: Task = { + id: relationship.sourceTaskId, + task_number: null, + slug: "", + title: relationship.sourceTaskTitle, + description: "", + origin_product: "user_created", + created_at: "", + updated_at: "", + }; + navigateToTask(placeholder); + }; + + return ( + + + + Forked from: + + + + ); +} diff --git a/apps/code/src/renderer/features/sessions/components/SessionView.tsx b/apps/code/src/renderer/features/sessions/components/SessionView.tsx index b95675e7d..7a34e5793 100644 --- a/apps/code/src/renderer/features/sessions/components/SessionView.tsx +++ b/apps/code/src/renderer/features/sessions/components/SessionView.tsx @@ -39,6 +39,7 @@ import { import { CloudInitializingView } from "./CloudInitializingView"; import { ConversationView } from "./ConversationView"; import { DropZoneOverlay } from "./DropZoneOverlay"; +import { ForkedFromBanner } from "./ForkedFromBanner"; import { ModelSelector } from "./ModelSelector"; import { PlanStatusBar } from "./PlanStatusBar"; import { ReasoningLevelSelector } from "./ReasoningLevelSelector"; @@ -453,6 +454,7 @@ export function SessionView({ > {isSuspended ? ( <> + {taskId && } + {taskId && } void; } function formatTimestamp(ts: number): string { @@ -49,6 +51,7 @@ export function UserMessage({ sourceUrl, attachments = [], animate = true, + onFork, }: UserMessageProps) { const containsFileMentions = hasFileMentions(content); const showAttachmentChips = attachments.length > 0 && !containsFileMentions; @@ -171,6 +174,18 @@ export function UserMessage({ {copied ? : } + {onFork && ( + + + + + + )} diff --git a/apps/code/src/renderer/features/sessions/hooks/useForkSession.ts b/apps/code/src/renderer/features/sessions/hooks/useForkSession.ts new file mode 100644 index 000000000..25a58cbab --- /dev/null +++ b/apps/code/src/renderer/features/sessions/hooks/useForkSession.ts @@ -0,0 +1,130 @@ +import { useOptionalAuthenticatedClient } from "@features/auth/hooks/authClient"; +import { + fetchAuthState, + useAuthStateValue, +} from "@features/auth/hooks/authQueries"; +import { useSessionForTask } from "@features/sessions/hooks/useSession"; +import { useWorkspace } from "@features/workspace/hooks/useWorkspace"; +import { useTRPC } from "@renderer/trpc/client"; +import type { Task } from "@shared/types"; +import { getCloudUrlFromRegion } from "@shared/utils/urls"; +import { useNavigationStore } from "@stores/navigationStore"; +import { useMutation } from "@tanstack/react-query"; +import { toast } from "@utils/toast"; +import { useCallback } from "react"; + +interface UseForkSessionOptions { + taskId: string | undefined; + task: Task | undefined; +} + +export function useForkSession({ taskId, task }: UseForkSessionOptions) { + const workspace = useWorkspace(taskId); + const session = useSessionForTask(taskId); + const posthogClient = useOptionalAuthenticatedClient(); + const cloudRegion = useAuthStateValue((s) => s.cloudRegion); + const projectId = useAuthStateValue((s) => s.projectId); + const navigateToTask = useNavigationStore((s) => s.navigateToTask); + const trpc = useTRPC(); + + const prepareForkMutation = useMutation( + trpc.fork.prepareFork.mutationOptions(), + ); + + const fork = useCallback( + async (messageIndex: number) => { + if ( + !taskId || + !task || + !workspace || + !session || + !cloudRegion || + !projectId || + !posthogClient + ) { + toast.error("Cannot fork: missing task, workspace, or session data"); + return; + } + + const sourceWorktreePath = workspace.worktreePath ?? workspace.folderPath; + const mainRepoPath = workspace.folderPath; + + if (!sourceWorktreePath || !mainRepoPath) { + toast.error("Cannot fork: workspace has no repository path"); + return; + } + + const model = session.configOptions?.find((o) => o.id === "model") + ?.currentValue as string | undefined; + + const toastId = toast.loading("Creating fork…"); + + try { + // 1. Create a new task on PostHog + const newTask = await posthogClient.createTask({ + title: `Fork of: ${task.title}`, + description: task.description ?? "", + repository: task.repository ?? undefined, + origin_product: "user_created", + }); + + // 2. Create a task run for it + const newTaskRun = await posthogClient.createTaskRun(newTask.id); + + // 3. Prepare the fork in the main process (worktree + JSONL) + const authState = await fetchAuthState(); + if (authState.status !== "authenticated" || !authState.cloudRegion) { + throw new Error("Not authenticated"); + } + const apiHost = getCloudUrlFromRegion(authState.cloudRegion); + + await prepareForkMutation.mutateAsync({ + sourceTaskId: taskId, + sourceTaskRunId: session.taskRunId, + sourceTaskTitle: task.title, + forkAtMessageIndex: messageIndex, + newTaskId: newTask.id, + newTaskRunId: newTaskRun.id, + sourceWorktreePath, + mainRepoPath, + apiHost, + projectId, + model, + }); + + toast.success("Fork created", { id: toastId }); + + // 4. Navigate to the new task — include latest_run so the session service + // knows to look for the pre-seeded local cache rather than creating a new run. + const forkTask: Task = { + ...(newTask as unknown as Task), + latest_run: newTaskRun, + }; + navigateToTask(forkTask); + } catch (err) { + toast.error( + err instanceof Error ? err.message : "Failed to create fork", + { + id: toastId, + }, + ); + } + }, + [ + taskId, + task, + workspace, + session, + cloudRegion, + projectId, + posthogClient, + prepareForkMutation, + navigateToTask, + ], + ); + + return { + fork, + isForkPending: prepareForkMutation.isPending, + }; +} diff --git a/apps/code/src/renderer/features/sessions/service/service.ts b/apps/code/src/renderer/features/sessions/service/service.ts index efa33cc25..cef083643 100644 --- a/apps/code/src/renderer/features/sessions/service/service.ts +++ b/apps/code/src/renderer/features/sessions/service/service.ts @@ -442,6 +442,33 @@ export class SessionService { return; } + // If a pre-seeded local cache exists for this run (e.g. a fork), reconnect + // to that run instead of creating a brand-new session. + if (latestRun?.id) { + try { + const localContent = await trpcClient.logs.readLocalLogs.query({ + taskRunId: latestRun.id, + }); + if (localContent?.trim()) { + log.info("Found pre-seeded local cache for run, reconnecting", { + taskId, + taskRunId: latestRun.id, + }); + await this.reconnectToLocalSession( + taskId, + latestRun.id, + taskTitle, + undefined, + repoPath, + auth, + ); + return; + } + } catch { + // No local cache — fall through to createNewLocalSession + } + } + await this.createNewLocalSession( taskId, taskTitle, diff --git a/packages/agent/src/adapters/claude/session/jsonl-hydration.ts b/packages/agent/src/adapters/claude/session/jsonl-hydration.ts index 16d1a54ce..326dc1c61 100644 --- a/packages/agent/src/adapters/claude/session/jsonl-hydration.ts +++ b/packages/agent/src/adapters/claude/session/jsonl-hydration.ts @@ -64,6 +64,64 @@ export function getSessionJsonlPath(sessionId: string, cwd: string): string { return path.join(configDir, "projects", projectKey, `${sessionId}.jsonl`); } +/** + * Return entries up to and including the full assistant turn that responded to + * the Nth user message (0-based), plus any trailing `_posthog/agent_checkpoint` + * entry so fork callers can read the git HEAD at that exact point. + * + * If messageIndex is beyond the last user message, returns all entries. + */ +export function filterEntriesUpToMessage( + entries: StoredEntry[], + messageIndex: number, +): StoredEntry[] { + let userMessageCount = 0; + + for (let i = 0; i < entries.length; i++) { + const entry = entries[i]; + const params = entry.notification?.params as Record; + + if (entry.notification?.method === "session/update" && params?.update) { + const update = params.update as { sessionUpdate?: string }; + const isUserMessage = + update.sessionUpdate === "user_message" || + update.sessionUpdate === "user_message_chunk"; + + if (isUserMessage) { + if (userMessageCount === messageIndex) { + // Scan forward through the assistant response that follows this user message. + // Stop at the next user message. Include any checkpoint entries. + let cutoffIndex = i + 1; + for (let j = i + 1; j < entries.length; j++) { + const jEntry = entries[j]; + const jParams = jEntry.notification?.params as Record< + string, + unknown + >; + if ( + jEntry.notification?.method === "session/update" && + jParams?.update + ) { + const jUpdate = jParams.update as { sessionUpdate?: string }; + if ( + jUpdate.sessionUpdate === "user_message" || + jUpdate.sessionUpdate === "user_message_chunk" + ) { + break; + } + } + cutoffIndex = j + 1; + } + return entries.slice(0, cutoffIndex); + } + userMessageCount++; + } + } + } + + return entries; +} + export function rebuildConversation( entries: StoredEntry[], ): ConversationTurn[] {