diff --git a/.github/workflows/e2e-api.yml b/.github/workflows/e2e-api.yml index 439dc29..559b68c 100644 --- a/.github/workflows/e2e-api.yml +++ b/.github/workflows/e2e-api.yml @@ -96,9 +96,9 @@ jobs: if: steps.doc_check.outputs.run_tests == 'true' run: bun install --frozen-lockfile - - name: Install Chromium + - name: Install Playwright browsers if: steps.doc_check.outputs.run_tests == 'true' - run: bunx playwright install --with-deps chromium + run: ./packages/backend/node_modules/.bin/playwright install --with-deps - name: Install Venom if: steps.doc_check.outputs.run_tests == 'true' diff --git a/apps/web/src/app/api/auth/refresh/route.test.ts b/apps/web/src/app/api/auth/refresh/route.test.ts new file mode 100644 index 0000000..106f189 --- /dev/null +++ b/apps/web/src/app/api/auth/refresh/route.test.ts @@ -0,0 +1,119 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const refreshWorkOsAccessToken = vi.fn(); + +vi.mock("@/lib/workos-device-auth", () => ({ + refreshWorkOsAccessToken, +})); + +describe("POST /api/auth/refresh", () => { + beforeEach(() => { + refreshWorkOsAccessToken.mockReset(); + }); + + it("returns 400 for invalid JSON payloads", async () => { + const { POST } = await import("./route"); + const traceId = "11111111-1111-4111-8111-111111111111"; + + const response = await POST( + new Request("http://localhost/api/auth/refresh", { + method: "POST", + headers: { + "content-type": "application/json", + "x-trace-id": traceId, + }, + body: "{", + }) + ); + + expect(response.status).toBe(400); + expect(response.headers.get("x-trace-id")).toBe(traceId); + await expect(response.json()).resolves.toEqual({ + error: "Invalid JSON payload", + }); + }); + + it("returns 400 when refreshToken is missing", async () => { + const { POST } = await import("./route"); + const traceId = "22222222-2222-4222-8222-222222222222"; + + const response = await POST( + new Request("http://localhost/api/auth/refresh", { + method: "POST", + headers: { + "content-type": "application/json", + "x-trace-id": traceId, + }, + body: JSON.stringify({}), + }) + ); + + expect(response.status).toBe(400); + expect(response.headers.get("x-trace-id")).toBe(traceId); + await expect(response.json()).resolves.toEqual({ + error: "refreshToken is required", + }); + expect(refreshWorkOsAccessToken).not.toHaveBeenCalled(); + }); + + it("returns refreshed credentials on success", async () => { + refreshWorkOsAccessToken.mockResolvedValue({ + status: "success", + accessToken: "fresh-access", + refreshToken: "fresh-refresh", + accessTokenExpiresAt: 123, + }); + + const { POST } = await import("./route"); + const traceId = "33333333-3333-4333-8333-333333333333"; + + const response = await POST( + new Request("http://localhost/api/auth/refresh", { + method: "POST", + headers: { + "content-type": "application/json", + "x-trace-id": traceId, + }, + body: JSON.stringify({ refreshToken: "stored-refresh-token" }), + }) + ); + + expect(response.status).toBe(200); + expect(response.headers.get("x-trace-id")).toBe(traceId); + expect(refreshWorkOsAccessToken).toHaveBeenCalledWith({ + refreshToken: "stored-refresh-token", + }); + await expect(response.json()).resolves.toEqual({ + status: "success", + accessToken: "fresh-access", + refreshToken: "fresh-refresh", + accessTokenExpiresAt: 123, + }); + }); + + it("passes through invalid_grant responses", async () => { + refreshWorkOsAccessToken.mockResolvedValue({ + status: "invalid_grant", + }); + + const { POST } = await import("./route"); + const traceId = "44444444-4444-4444-8444-444444444444"; + + const response = await POST( + new Request("http://localhost/api/auth/refresh", { + method: "POST", + headers: { + "content-type": "application/json", + "x-trace-id": traceId, + }, + body: JSON.stringify({ refreshToken: "expired-refresh-token" }), + }) + ); + + expect(response.status).toBe(200); + expect(response.headers.get("x-trace-id")).toBe(traceId); + await expect(response.json()).resolves.toEqual({ + status: "invalid_grant", + }); + }); +}); diff --git a/apps/web/src/app/api/auth/refresh/route.ts b/apps/web/src/app/api/auth/refresh/route.ts new file mode 100644 index 0000000..f67c9ae --- /dev/null +++ b/apps/web/src/app/api/auth/refresh/route.ts @@ -0,0 +1,69 @@ +import { createTraceId, normalizeTraceId } from "@sketchi/shared"; + +import { refreshWorkOsAccessToken } from "@/lib/workos-device-auth"; + +export async function POST(request: Request) { + const traceId = + normalizeTraceId(request.headers.get("x-trace-id")) ?? createTraceId(); + + let payload: unknown; + try { + payload = await request.json(); + } catch { + return Response.json( + { error: "Invalid JSON payload" }, + { + status: 400, + headers: { + "cache-control": "no-store", + "x-trace-id": traceId, + }, + } + ); + } + + const refreshToken = + payload && typeof payload === "object" && "refreshToken" in payload + ? (payload.refreshToken as string) + : undefined; + + if (!(typeof refreshToken === "string" && refreshToken.trim().length > 0)) { + return Response.json( + { error: "refreshToken is required" }, + { + status: 400, + headers: { + "cache-control": "no-store", + "x-trace-id": traceId, + }, + } + ); + } + + try { + const result = await refreshWorkOsAccessToken({ + refreshToken, + }); + + return Response.json(result, { + headers: { + "cache-control": "no-store", + "x-trace-id": traceId, + }, + }); + } catch (error) { + const message = + error instanceof Error ? error.message : "Failed to refresh access token"; + + return Response.json( + { error: message }, + { + status: 500, + headers: { + "cache-control": "no-store", + "x-trace-id": traceId, + }, + } + ); + } +} diff --git a/apps/web/src/lib/orpc/openapi.ts b/apps/web/src/lib/orpc/openapi.ts index b762dc1..9e4367d 100644 --- a/apps/web/src/lib/orpc/openapi.ts +++ b/apps/web/src/lib/orpc/openapi.ts @@ -91,10 +91,11 @@ function withDeviceAuthPaths(spec: unknown): unknown { }, DeviceAuthTokenSuccessResponse: { type: "object", - required: ["status", "accessToken"], + required: ["status", "accessToken", "refreshToken"], properties: { status: { type: "string", enum: ["success"] }, accessToken: { type: "string" }, + refreshToken: { type: "string" }, accessTokenExpiresAt: { type: "integer", description: @@ -103,6 +104,22 @@ function withDeviceAuthPaths(spec: unknown): unknown { }, additionalProperties: false, }, + DeviceAuthRefreshRequest: { + type: "object", + required: ["refreshToken"], + properties: { + refreshToken: { type: "string" }, + }, + additionalProperties: false, + }, + DeviceAuthRefreshInvalidGrantResponse: { + type: "object", + required: ["status"], + properties: { + status: { type: "string", enum: ["invalid_grant"] }, + }, + additionalProperties: false, + }, DeviceAuthTokenInvalidGrantResponse: { type: "object", required: ["status"], @@ -278,12 +295,87 @@ function withDeviceAuthPaths(spec: unknown): unknown { }, }, }, + "/auth/refresh": { + post: { + tags: ["Device Auth"], + summary: "Refresh access token", + description: + "Exchanges a stored WorkOS refresh token for a new access token and rotated refresh token. This endpoint does not require user bearer auth.", + requestBody: { + required: true, + content: { + "application/json": { + schema: { + $ref: "#/components/schemas/DeviceAuthRefreshRequest", + }, + example: { + refreshToken: "workos_refresh_token", + }, + }, + }, + }, + responses: { + 200: { + description: "Refresh result.", + headers: { + "x-trace-id": { + schema: { type: "string" }, + }, + }, + content: { + "application/json": { + schema: { + oneOf: [ + { + $ref: "#/components/schemas/DeviceAuthTokenSuccessResponse", + }, + { + $ref: "#/components/schemas/DeviceAuthRefreshInvalidGrantResponse", + }, + ], + }, + }, + }, + }, + 400: { + description: "Invalid request payload.", + headers: { + "x-trace-id": { + schema: { type: "string" }, + }, + }, + content: { + "application/json": { + schema: { + $ref: "#/components/schemas/DeviceAuthErrorResponse", + }, + }, + }, + }, + 500: { + description: "Failed to refresh access token.", + headers: { + "x-trace-id": { + schema: { type: "string" }, + }, + }, + content: { + "application/json": { + schema: { + $ref: "#/components/schemas/DeviceAuthErrorResponse", + }, + }, + }, + }, + }, + }, + }, } satisfies Record; const existingDescription = document.info?.description ?? ""; const notes = [ existingDescription, - "Device auth endpoints are included for CLI integrations (`/auth/device/start`, `/auth/device/token`).", + "Device auth endpoints are included for CLI integrations (`/auth/device/start`, `/auth/device/token`, `/auth/refresh`).", ] .filter(Boolean) .join(" "); diff --git a/apps/web/src/lib/orpc/router.ts b/apps/web/src/lib/orpc/router.ts index 1ea4b5a..2197923 100644 --- a/apps/web/src/lib/orpc/router.ts +++ b/apps/web/src/lib/orpc/router.ts @@ -66,6 +66,7 @@ export function createOrpcContext( const orpc = os.$context(); type PublicErrorReason = + | "NOT_FOUND" | "UNAUTHORIZED" | "AI_NO_OUTPUT" | "AI_PROVIDER_ERROR" @@ -84,6 +85,10 @@ function classifyError(error: unknown): { const name = error.name; const lower = message.toLowerCase(); + if (lower.includes("session not found")) { + return { reason: "NOT_FOUND", message, name }; + } + if (lower.includes("no output generated")) { return { reason: "AI_NO_OUTPUT", message, name }; } @@ -154,6 +159,19 @@ function throwInternalError(params: { }); } + if (reason === "NOT_FOUND") { + throw new ORPCError("NOT_FOUND", { + message: `${params.action} could not find the requested diagram session. traceId=${params.traceId}`, + data: { + traceId: params.traceId, + stage: params.stage, + action: params.action, + errorName: name, + errorMessage: message.slice(0, 600), + }, + }); + } + withScope((scope) => { scope.setTag("traceId", params.traceId); scope.setTag("orpc.route", params.action); diff --git a/apps/web/src/lib/workos-device-auth.test.ts b/apps/web/src/lib/workos-device-auth.test.ts index 0e597a4..b0c929e 100644 --- a/apps/web/src/lib/workos-device-auth.test.ts +++ b/apps/web/src/lib/workos-device-auth.test.ts @@ -2,6 +2,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { pollWorkOsDeviceFlow, + refreshWorkOsAccessToken, startWorkOsDeviceFlow, } from "./workos-device-auth"; @@ -119,6 +120,7 @@ describe("workos-device-auth", () => { new Response( JSON.stringify({ access_token: "access-token-123", + refresh_token: "refresh-token-123", expires_in: 3600, }), { status: 200 } @@ -136,10 +138,57 @@ describe("workos-device-auth", () => { } expect(result.accessToken).toBe("access-token-123"); + expect(result.refreshToken).toBe("refresh-token-123"); expect(result.accessTokenExpiresAt).toBeDefined(); expect(result.accessTokenExpiresAt).toBeGreaterThanOrEqual( before + 3_599_000 ); expect(result.accessTokenExpiresAt).toBeLessThanOrEqual(after + 3_601_000); }); + + it("refreshWorkOsAccessToken maps successful refresh payload", async () => { + const fetchMock = vi.fn().mockResolvedValue( + new Response( + JSON.stringify({ + access_token: "next-access-token", + refresh_token: "next-refresh-token", + expires_in: 1800, + }), + { status: 200 } + ) + ); + globalThis.fetch = fetchMock as unknown as typeof fetch; + + const before = Date.now(); + const result = await refreshWorkOsAccessToken({ + refreshToken: "stored-refresh-token", + }); + const after = Date.now(); + + expect(result.status).toBe("success"); + if (result.status !== "success") { + return; + } + + expect(result.accessToken).toBe("next-access-token"); + expect(result.refreshToken).toBe("next-refresh-token"); + expect(result.accessTokenExpiresAt).toBeDefined(); + expect(result.accessTokenExpiresAt).toBeGreaterThanOrEqual( + before + 1_799_000 + ); + expect(result.accessTokenExpiresAt).toBeLessThanOrEqual(after + 1_801_000); + }); + + it("refreshWorkOsAccessToken maps invalid_grant payloads", async () => { + const fetchMock = vi.fn().mockResolvedValue( + new Response(JSON.stringify({ error: "invalid_grant" }), { + status: 400, + }) + ); + globalThis.fetch = fetchMock as unknown as typeof fetch; + + await expect( + refreshWorkOsAccessToken({ refreshToken: "stale-refresh-token" }) + ).resolves.toEqual({ status: "invalid_grant" }); + }); }); diff --git a/apps/web/src/lib/workos-device-auth.ts b/apps/web/src/lib/workos-device-auth.ts index 323fe3a..2c67fc7 100644 --- a/apps/web/src/lib/workos-device-auth.ts +++ b/apps/web/src/lib/workos-device-auth.ts @@ -15,6 +15,7 @@ interface WorkOsDeviceStartSuccess { interface WorkOsTokenSuccess { access_token: string; expires_in?: number; + refresh_token: string; } interface WorkOsTokenError { @@ -142,6 +143,7 @@ export async function pollWorkOsDeviceFlow(input: { | { status: "success"; accessToken: string; + refreshToken: string; accessTokenExpiresAt?: number; } | { @@ -167,7 +169,7 @@ export async function pollWorkOsDeviceFlow(input: { const payload = await parseJsonResponse(response); if (response.ok) { - if (!("access_token" in payload)) { + if (!("access_token" in payload && "refresh_token" in payload)) { throw new Error("Unexpected WorkOS token response"); } @@ -179,6 +181,7 @@ export async function pollWorkOsDeviceFlow(input: { return { status: "success", accessToken: payload.access_token, + refreshToken: payload.refresh_token, accessTokenExpiresAt, }; } @@ -221,3 +224,75 @@ export async function pollWorkOsDeviceFlow(input: { }) ); } + +export async function refreshWorkOsAccessToken(input: { + refreshToken: string; + organizationId?: string; + ipAddress?: string; + userAgent?: string; +}): Promise< + | { + status: "success"; + accessToken: string; + refreshToken: string; + accessTokenExpiresAt?: number; + } + | { + status: "invalid_grant"; + } +> { + const clientId = getWorkOsClientId(); + const response = await fetch( + `${resolveWorkOsBaseUrl()}/user_management/authenticate`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + grant_type: "refresh_token", + client_id: clientId, + refresh_token: input.refreshToken, + organization_id: input.organizationId, + ip_address: input.ipAddress, + user_agent: input.userAgent, + }), + cache: "no-store", + } + ); + + const payload = await parseJsonResponse(response); + if (response.ok) { + if (!("access_token" in payload && "refresh_token" in payload)) { + throw new Error("Unexpected WorkOS refresh response"); + } + + const accessTokenExpiresAt = + typeof payload.expires_in === "number" + ? Date.now() + Math.max(1, payload.expires_in) * 1000 + : undefined; + + return { + status: "success", + accessToken: payload.access_token, + refreshToken: payload.refresh_token, + accessTokenExpiresAt, + }; + } + + if ( + "error" in payload && + (payload.error === "invalid_grant" || + payload.error === "access_denied" || + payload.error === "expired_token") + ) { + return { status: "invalid_grant" }; + } + + throw new Error( + getErrorMessage({ + payload, + fallback: "Failed to refresh WorkOS access token", + }) + ); +} diff --git a/packages/opencode-excalidraw/src/index.test.ts b/packages/opencode-excalidraw/src/index.test.ts index 7863cd0..9804510 100644 --- a/packages/opencode-excalidraw/src/index.test.ts +++ b/packages/opencode-excalidraw/src/index.test.ts @@ -221,6 +221,89 @@ describe("SketchiPlugin", () => { } }); + test("device-flow callback returns the rotated refresh token", async () => { + const originalFetch = globalThis.fetch; + const originalSetTimeout = globalThis.setTimeout; + const requestedUrls: string[] = []; + + globalThis.fetch = ((input) => { + let url: string; + if (typeof input === "string") { + url = input; + } else if (input instanceof URL) { + url = input.toString(); + } else { + url = input.url; + } + + requestedUrls.push(url); + + if (url.endsWith("/api/auth/device/start")) { + return Promise.resolve( + new Response( + JSON.stringify({ + deviceCode: "device-code", + userCode: "ABCD-EFGH", + interval: 1, + expiresIn: 600, + verificationUrl: "https://www.sketchi.app/device", + }), + { + status: 200, + headers: { "content-type": "application/json" }, + } + ) + ); + } + + return Promise.resolve( + new Response( + JSON.stringify({ + status: "success", + accessToken: "fresh-access-token", + refreshToken: "fresh-refresh-token", + accessTokenExpiresAt: Date.now() + 60_000, + }), + { + status: 200, + headers: { "content-type": "application/json" }, + } + ) + ); + }) as typeof fetch; + + globalThis.setTimeout = ((handler: TimerHandler) => { + if (typeof handler === "function") { + handler(); + } + return 0 as never; + }) as typeof setTimeout; + + try { + const plugin = await SketchiPlugin(createPluginInput()); + const method = plugin.auth?.methods?.[0]; + expect(method?.type).toBe("oauth"); + + const authStart = await method?.authorize(); + const authResult = await authStart?.callback(); + + expect(authResult).toEqual({ + type: "success", + provider: "sketchi", + access: "fresh-access-token", + refresh: "fresh-refresh-token", + expires: authResult?.expires, + }); + expect(requestedUrls).toEqual([ + "https://www.sketchi.app/api/auth/device/start", + "https://www.sketchi.app/api/auth/device/token", + ]); + } finally { + globalThis.fetch = originalFetch; + globalThis.setTimeout = originalSetTimeout; + } + }); + test("diagram_grade blocks concurrent calls for the same message", async () => { const deferred = createDeferred<{ data: { parts: Array<{ type: string; text: string }> }; @@ -309,4 +392,60 @@ describe("SketchiPlugin", () => { expect(promptCalls.length).toBe(1); }); + + test("diagram tools fail fast when stored Sketchi oauth is expired and refresh fails", async () => { + const originalFetch = globalThis.fetch; + const authSetCalls: unknown[] = []; + + globalThis.fetch = (async () => + new Response( + JSON.stringify({ + status: "invalid_grant", + }), + { + status: 200, + headers: { "content-type": "application/json" }, + } + )) as typeof fetch; + + try { + const plugin = await SketchiPlugin( + createPluginInput({ + auth: { + set: (input: unknown) => { + authSetCalls.push(input); + return Promise.resolve(); + }, + }, + }) + ); + + await plugin.auth?.loader?.( + async () => ({ + type: "oauth", + access: "expired-access", + refresh: "legacy-refresh", + expires: Date.now() - 1, + }), + {} as never + ); + + const fromPrompt = plugin.tool?.diagram_from_prompt; + expect(fromPrompt).toBeDefined(); + if (!fromPrompt) { + throw new Error("diagram_from_prompt tool missing"); + } + + await expect( + fromPrompt.execute( + { prompt: "Create a simple flowchart." }, + createToolContext("message-expired-sketchi-auth") as never + ) + ).rejects.toThrow("opencode auth login --provider sketchi"); + + expect(authSetCalls).toHaveLength(1); + } finally { + globalThis.fetch = originalFetch; + } + }); }); diff --git a/packages/opencode-excalidraw/src/index.ts b/packages/opencode-excalidraw/src/index.ts index 2cec2bf..9848399 100644 --- a/packages/opencode-excalidraw/src/index.ts +++ b/packages/opencode-excalidraw/src/index.ts @@ -11,6 +11,15 @@ import { gradeDiagram } from "./lib/grade"; import { buildDefaultPngPath, resolveOutputPath, writePng } from "./lib/output"; import { closeBrowser, renderElementsToPng } from "./lib/render"; import { resolveExcalidrawFromShareUrl } from "./lib/resolve-share-url"; +import { + resolveSessionCandidate, + withRecoveredCachedSession, +} from "./lib/session-continuity"; +import { + accessTokenExpired, + isOAuthAuth, + refreshSketchiAccessToken, +} from "./lib/sketchi-oauth"; import { createToolTraceId } from "./lib/trace"; const DEFAULT_API_BASE = "https://www.sketchi.app"; @@ -218,14 +227,8 @@ function cacheSketchiSession( sketchiSessionByOpenCodeSession.set(opencodeSessionID, sketchiSessionID); } -function resolveSessionCandidate(input: { - explicitSessionId?: string; - opencodeSessionID: string; -}): string | undefined { - if (input.explicitSessionId) { - return input.explicitSessionId; - } - return sketchiSessionByOpenCodeSession.get(input.opencodeSessionID); +function clearCachedSketchiSession(opencodeSessionID: string): void { + sketchiSessionByOpenCodeSession.delete(opencodeSessionID); } function createStudioUrl(apiBase: string, sessionId: string): string { @@ -363,6 +366,7 @@ type DeviceTokenResponse = | { status: "success"; accessToken: string; + refreshToken: string; accessTokenExpiresAt?: number; } | { @@ -431,18 +435,27 @@ export const SketchiPlugin: Plugin = (input) => { const sessionCandidate = resolveSessionCandidate({ explicitSessionId: args.sessionId, - opencodeSessionID: context.sessionID, + getCachedSessionId: () => + sketchiSessionByOpenCodeSession.get(context.sessionID), }); - const runResult = await runThreadPrompt({ - apiBase, - authorizationHeader, - traceId, - sessionId: sessionCandidate, - prompt: args.prompt, - promptMessageId, - abort: context.abort, - timeoutMs: DEFAULT_THREAD_RUN_TIMEOUT_MS, + const runResult = await withRecoveredCachedSession({ + sessionCandidate, + clearCachedSession: () => + clearCachedSketchiSession(context.sessionID), + explicitSessionErrorMessage: + "Provided Sketchi sessionId was not found. Omit sessionId to create a new diagram.", + call: async (sessionId) => + await runThreadPrompt({ + apiBase, + authorizationHeader, + traceId, + sessionId, + prompt: args.prompt, + promptMessageId, + abort: context.abort, + timeoutMs: DEFAULT_THREAD_RUN_TIMEOUT_MS, + }), }); if (runResult.status !== "persisted") { @@ -607,10 +620,12 @@ export const SketchiPlugin: Plugin = (input) => { prompt: args.request, }); - let workingSessionId = resolveSessionCandidate({ + let sessionCandidate = resolveSessionCandidate({ explicitSessionId: args.sessionId, - opencodeSessionID: context.sessionID, + getCachedSessionId: () => + sketchiSessionByOpenCodeSession.get(context.sessionID), }); + let workingSessionId = sessionCandidate.sessionId; const sceneSeed = await resolveSceneSeed({ apiBase, @@ -624,15 +639,27 @@ export const SketchiPlugin: Plugin = (input) => { }); if (sceneSeed) { - const seeded = await seedSessionFromScene({ - apiBase, - authorizationHeader, - traceId, - sessionId: workingSessionId, - scene: sceneSeed, - abort: context.abort, + const seeded = await withRecoveredCachedSession({ + sessionCandidate, + clearCachedSession: () => + clearCachedSketchiSession(context.sessionID), + explicitSessionErrorMessage: + "Provided Sketchi sessionId was not found. Omit sessionId or provide scene input to start a new diagram.", + call: async (sessionId) => + await seedSessionFromScene({ + apiBase, + authorizationHeader, + traceId, + sessionId, + scene: sceneSeed, + abort: context.abort, + }), }); workingSessionId = seeded.sessionId; + sessionCandidate = { + sessionId: seeded.sessionId, + source: "explicit", + }; } if (!workingSessionId) { @@ -641,15 +668,32 @@ export const SketchiPlugin: Plugin = (input) => { ); } - const runResult = await runThreadPrompt({ - apiBase, - authorizationHeader, - traceId, - sessionId: workingSessionId, - prompt: `Tactical tweak request:\\n${args.request}`, - promptMessageId, - abort: context.abort, - timeoutMs: Math.max(serverTimeoutMs, DEFAULT_THREAD_RUN_TIMEOUT_MS), + const runResult = await withRecoveredCachedSession({ + sessionCandidate: + workingSessionId === sessionCandidate.sessionId + ? sessionCandidate + : { + sessionId: workingSessionId, + source: "explicit", + }, + clearCachedSession: () => + clearCachedSketchiSession(context.sessionID), + explicitSessionErrorMessage: + "Provided Sketchi sessionId was not found. Supply an existing sessionId, shareUrl, or excalidraw input.", + call: async (sessionId) => + await runThreadPrompt({ + apiBase, + authorizationHeader, + traceId, + sessionId, + prompt: `Tactical tweak request:\\n${args.request}`, + promptMessageId, + abort: context.abort, + timeoutMs: Math.max( + serverTimeoutMs, + DEFAULT_THREAD_RUN_TIMEOUT_MS + ), + }), }); if (runResult.status !== "persisted") { @@ -816,10 +860,12 @@ export const SketchiPlugin: Plugin = (input) => { prompt: args.prompt, }); - let workingSessionId = resolveSessionCandidate({ + let sessionCandidate = resolveSessionCandidate({ explicitSessionId: args.sessionId, - opencodeSessionID: context.sessionID, + getCachedSessionId: () => + sketchiSessionByOpenCodeSession.get(context.sessionID), }); + let workingSessionId = sessionCandidate.sessionId; const sceneSeed = await resolveSceneSeed({ apiBase, @@ -833,15 +879,27 @@ export const SketchiPlugin: Plugin = (input) => { }); if (sceneSeed) { - const seeded = await seedSessionFromScene({ - apiBase, - authorizationHeader, - traceId, - sessionId: workingSessionId, - scene: sceneSeed, - abort: context.abort, + const seeded = await withRecoveredCachedSession({ + sessionCandidate, + clearCachedSession: () => + clearCachedSketchiSession(context.sessionID), + explicitSessionErrorMessage: + "Provided Sketchi sessionId was not found. Omit sessionId or provide scene input to start a new diagram.", + call: async (sessionId) => + await seedSessionFromScene({ + apiBase, + authorizationHeader, + traceId, + sessionId, + scene: sceneSeed, + abort: context.abort, + }), }); workingSessionId = seeded.sessionId; + sessionCandidate = { + sessionId: seeded.sessionId, + source: "explicit", + }; } if (!workingSessionId) { @@ -850,15 +908,32 @@ export const SketchiPlugin: Plugin = (input) => { ); } - const runResult = await runThreadPrompt({ - apiBase, - authorizationHeader, - traceId, - sessionId: workingSessionId, - prompt: `Structural restructure request:\\n${args.prompt}`, - promptMessageId, - abort: context.abort, - timeoutMs: Math.max(serverTimeoutMs, DEFAULT_THREAD_RUN_TIMEOUT_MS), + const runResult = await withRecoveredCachedSession({ + sessionCandidate: + workingSessionId === sessionCandidate.sessionId + ? sessionCandidate + : { + sessionId: workingSessionId, + source: "explicit", + }, + clearCachedSession: () => + clearCachedSketchiSession(context.sessionID), + explicitSessionErrorMessage: + "Provided Sketchi sessionId was not found. Supply an existing sessionId, shareUrl, or excalidraw input.", + call: async (sessionId) => + await runThreadPrompt({ + apiBase, + authorizationHeader, + traceId, + sessionId, + prompt: `Structural restructure request:\\n${args.prompt}`, + promptMessageId, + abort: context.abort, + timeoutMs: Math.max( + serverTimeoutMs, + DEFAULT_THREAD_RUN_TIMEOUT_MS + ), + }), }); if (runResult.status !== "persisted") { @@ -1193,8 +1268,36 @@ export const SketchiPlugin: Plugin = (input) => { loader(getAuth) { getAuthorizationHeader = async () => { const auth = await getAuth(); + if (!isOAuthAuth(auth)) { + return ( + resolveAuthorizationHeaderFromAuth(auth) ?? envAuthorizationHeader + ); + } + + let effectiveAuth = auth; + if (accessTokenExpired(effectiveAuth)) { + const refreshed = await refreshSketchiAccessToken({ + apiBase, + auth: effectiveAuth, + client: input.client, + traceId: createToolTraceId(), + }); + + if (refreshed) { + effectiveAuth = refreshed; + } else { + if (envAuthorizationHeader) { + return envAuthorizationHeader; + } + throw new Error( + "Sketchi OAuth is expired or invalid. Run `opencode auth login --provider sketchi`." + ); + } + } + return ( - resolveAuthorizationHeaderFromAuth(auth) ?? envAuthorizationHeader + resolveAuthorizationHeaderFromAuth(effectiveAuth) ?? + envAuthorizationHeader ); }; return Promise.resolve({}); @@ -1264,7 +1367,7 @@ export const SketchiPlugin: Plugin = (input) => { return { type: "success" as const, - refresh: pollResult.accessToken, + refresh: pollResult.refreshToken, access: pollResult.accessToken, expires: pollResult.accessTokenExpiresAt ?? diff --git a/packages/opencode-excalidraw/src/lib/session-continuity.test.ts b/packages/opencode-excalidraw/src/lib/session-continuity.test.ts new file mode 100644 index 0000000..6328830 --- /dev/null +++ b/packages/opencode-excalidraw/src/lib/session-continuity.test.ts @@ -0,0 +1,93 @@ +import { describe, expect, test } from "bun:test"; + +import { + isRecoverableSessionContinuationError, + resolveSessionCandidate, + withRecoveredCachedSession, +} from "./session-continuity"; + +describe("session continuity helpers", () => { + test("prefers explicit session IDs over cached continuity", () => { + const candidate = resolveSessionCandidate({ + explicitSessionId: "explicit-session", + getCachedSessionId: () => "cached-session", + }); + + expect(candidate).toEqual({ + sessionId: "explicit-session", + source: "explicit", + }); + }); + + test("falls back to cached continuity when explicit session is absent", () => { + const candidate = resolveSessionCandidate({ + getCachedSessionId: () => "cached-session", + }); + + expect(candidate).toEqual({ + sessionId: "cached-session", + source: "cached", + }); + }); + + test("detects session-not-found transport errors", () => { + const error = new Error( + 'Request failed (500): {"message":"diagrams.threadRun failed","data":{"errorMessage":"Session not found"}}' + ); + + expect(isRecoverableSessionContinuationError(error)).toBe(true); + }); + + test("retries once without sessionId when cached continuity is stale", async () => { + const calls: Array = []; + let cleared = false; + + const result = await withRecoveredCachedSession({ + sessionCandidate: { + sessionId: "stale-session", + source: "cached", + }, + clearCachedSession: () => { + cleared = true; + }, + explicitSessionErrorMessage: "should not be used", + call: (sessionId) => { + calls.push(sessionId); + if (sessionId) { + return Promise.reject( + new Error( + 'Request failed (500): {"message":"diagrams.threadRun failed","data":{"errorMessage":"Session not found"}}' + ) + ); + } + return Promise.resolve("recovered"); + }, + }); + + expect(result).toBe("recovered"); + expect(cleared).toBe(true); + expect(calls).toEqual(["stale-session", undefined]); + }); + + test("surfaces a clearer error for invalid explicit session IDs", async () => { + await expect( + withRecoveredCachedSession({ + sessionCandidate: { + sessionId: "invented-session", + source: "explicit", + }, + clearCachedSession: () => undefined, + explicitSessionErrorMessage: + "Provided Sketchi sessionId was not found. Omit sessionId to create a new diagram.", + call: () => + Promise.reject( + new Error( + 'Request failed (500): {"message":"diagrams.threadRun failed","data":{"errorMessage":"Session not found"}}' + ) + ), + }) + ).rejects.toThrow( + "Provided Sketchi sessionId was not found. Omit sessionId to create a new diagram." + ); + }); +}); diff --git a/packages/opencode-excalidraw/src/lib/session-continuity.ts b/packages/opencode-excalidraw/src/lib/session-continuity.ts new file mode 100644 index 0000000..781de61 --- /dev/null +++ b/packages/opencode-excalidraw/src/lib/session-continuity.ts @@ -0,0 +1,64 @@ +export type SessionCandidateSource = "none" | "cached" | "explicit"; + +export interface ResolvedSessionCandidate { + sessionId?: string; + source: SessionCandidateSource; +} + +export function resolveSessionCandidate(input: { + explicitSessionId?: string; + getCachedSessionId: () => string | undefined; +}): ResolvedSessionCandidate { + if (input.explicitSessionId) { + return { + sessionId: input.explicitSessionId, + source: "explicit", + }; + } + + const cachedSessionId = input.getCachedSessionId(); + if (cachedSessionId) { + return { + sessionId: cachedSessionId, + source: "cached", + }; + } + + return { + source: "none", + }; +} + +export function isRecoverableSessionContinuationError(error: unknown): boolean { + if (!(error instanceof Error)) { + return false; + } + + return error.message.toLowerCase().includes("session not found"); +} + +export async function withRecoveredCachedSession(input: { + sessionCandidate: ResolvedSessionCandidate; + call: (sessionId?: string) => Promise; + clearCachedSession: () => void; + explicitSessionErrorMessage: string; +}): Promise { + try { + return await input.call(input.sessionCandidate.sessionId); + } catch (error) { + if (!isRecoverableSessionContinuationError(error)) { + throw error; + } + + if (input.sessionCandidate.source === "cached") { + input.clearCachedSession(); + return await input.call(undefined); + } + + if (input.sessionCandidate.source === "explicit") { + throw new Error(input.explicitSessionErrorMessage, { cause: error }); + } + + throw error; + } +} diff --git a/packages/opencode-excalidraw/src/lib/sketchi-oauth.test.ts b/packages/opencode-excalidraw/src/lib/sketchi-oauth.test.ts new file mode 100644 index 0000000..e840dd2 --- /dev/null +++ b/packages/opencode-excalidraw/src/lib/sketchi-oauth.test.ts @@ -0,0 +1,151 @@ +import { describe, expect, test } from "bun:test"; + +import { + accessTokenExpired, + clearPersistedOAuthAuth, + isOAuthAuth, + refreshSketchiAccessToken, +} from "./sketchi-oauth"; + +describe("sketchi oauth helpers", () => { + test("recognizes oauth auth payloads", () => { + expect( + isOAuthAuth({ + type: "oauth", + access: "access-token", + refresh: "refresh-token", + expires: Date.now() + 60_000, + }) + ).toBe(true); + expect(isOAuthAuth({ type: "api", key: "secret" })).toBe(false); + }); + + test("treats missing or stale access tokens as expired", () => { + expect( + accessTokenExpired({ + type: "oauth", + refresh: "refresh-token", + }) + ).toBe(true); + + expect( + accessTokenExpired({ + type: "oauth", + access: "access-token", + refresh: "refresh-token", + expires: Date.now() - 1, + }) + ).toBe(true); + }); + + test("refreshSketchiAccessToken persists rotated auth", async () => { + const originalFetch = globalThis.fetch; + const setCalls: unknown[] = []; + + globalThis.fetch = (async () => + new Response( + JSON.stringify({ + status: "success", + accessToken: "new-access-token", + refreshToken: "new-refresh-token", + accessTokenExpiresAt: Date.now() + 3_600_000, + }), + { status: 200 } + )) as typeof fetch; + + try { + const result = await refreshSketchiAccessToken({ + apiBase: "https://www.sketchi.app", + auth: { + type: "oauth", + access: "old-access-token", + refresh: "old-refresh-token", + expires: Date.now() - 1, + }, + client: { + auth: { + set: (input) => { + setCalls.push(input); + return Promise.resolve(); + }, + }, + }, + traceId: "trace-refresh", + }); + + expect(result).toEqual({ + type: "oauth", + access: "new-access-token", + refresh: "new-refresh-token", + expires: result?.expires, + }); + expect(setCalls).toHaveLength(1); + } finally { + globalThis.fetch = originalFetch; + } + }); + + test("refreshSketchiAccessToken clears invalid_grant auth", async () => { + const originalFetch = globalThis.fetch; + const setCalls: unknown[] = []; + + globalThis.fetch = (async () => + new Response( + JSON.stringify({ + status: "invalid_grant", + }), + { status: 200 } + )) as typeof fetch; + + try { + const result = await refreshSketchiAccessToken({ + apiBase: "https://www.sketchi.app", + auth: { + type: "oauth", + access: "old-access-token", + refresh: "old-refresh-token", + expires: Date.now() - 1, + }, + client: { + auth: { + set: (input) => { + setCalls.push(input); + return Promise.resolve(); + }, + }, + }, + traceId: "trace-invalid", + }); + + expect(result).toBeUndefined(); + expect(setCalls).toHaveLength(1); + } finally { + globalThis.fetch = originalFetch; + } + }); + + test("clearPersistedOAuthAuth writes an empty oauth record", async () => { + const setCalls: unknown[] = []; + + await clearPersistedOAuthAuth({ + auth: { + set: (input) => { + setCalls.push(input); + return Promise.resolve(); + }, + }, + }); + + expect(setCalls).toEqual([ + { + path: { id: "sketchi" }, + body: { + type: "oauth", + access: "", + refresh: "", + expires: 0, + }, + }, + ]); + }); +}); diff --git a/packages/opencode-excalidraw/src/lib/sketchi-oauth.ts b/packages/opencode-excalidraw/src/lib/sketchi-oauth.ts new file mode 100644 index 0000000..c3eda62 --- /dev/null +++ b/packages/opencode-excalidraw/src/lib/sketchi-oauth.ts @@ -0,0 +1,129 @@ +import { fetchJson } from "./api"; + +const SKETCHI_PROVIDER_ID = "sketchi"; +const ACCESS_TOKEN_EXPIRY_BUFFER_MS = 60 * 1000; +const DEFAULT_TOKEN_TTL_MS = 60 * 60 * 1000; + +export interface OAuthAuthDetails { + access: string; + expires: number; + refresh: string; + type: "oauth"; +} + +export interface OAuthAuthLike { + access?: string; + expires?: number; + refresh?: string; + type: "oauth"; +} + +export interface NonOAuthAuthDetails { + type?: string; + [key: string]: unknown; +} + +export type AuthDetails = OAuthAuthLike | NonOAuthAuthDetails; + +export interface SketchiPluginClient { + auth: { + set(input: { + path: { id: string }; + body: OAuthAuthDetails; + }): Promise; + }; +} + +interface RefreshTokenResponse { + accessToken?: string; + accessTokenExpiresAt?: number; + refreshToken?: string; + status: "success" | "invalid_grant"; +} + +export function isOAuthAuth( + auth: AuthDetails | null | undefined +): auth is OAuthAuthLike { + return auth?.type === "oauth"; +} + +export function accessTokenExpired(auth: OAuthAuthLike): boolean { + if (!auth.access || typeof auth.expires !== "number") { + return true; + } + + return auth.expires <= Date.now() + ACCESS_TOKEN_EXPIRY_BUFFER_MS; +} + +async function persistOAuthAuth( + client: SketchiPluginClient, + auth: OAuthAuthDetails +): Promise { + await client.auth.set({ + path: { id: SKETCHI_PROVIDER_ID }, + body: auth, + }); +} + +export async function clearPersistedOAuthAuth( + client: SketchiPluginClient +): Promise { + await persistOAuthAuth(client, { + type: "oauth", + access: "", + refresh: "", + expires: 0, + }); +} + +export async function refreshSketchiAccessToken(input: { + abort?: AbortSignal; + apiBase: string; + auth: OAuthAuthLike; + client: SketchiPluginClient; + traceId: string; +}): Promise { + const refreshToken = input.auth.refresh?.trim(); + if (!refreshToken) { + return undefined; + } + + try { + const result = await fetchJson( + `${input.apiBase}/api/auth/refresh`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + "x-trace-id": input.traceId, + }, + body: JSON.stringify({ + refreshToken, + traceId: input.traceId, + }), + }, + input.abort + ); + + if ( + result.status !== "success" || + !result.accessToken || + !result.refreshToken + ) { + await clearPersistedOAuthAuth(input.client).catch(() => undefined); + return undefined; + } + + const updatedAuth: OAuthAuthDetails = { + type: "oauth", + access: result.accessToken, + refresh: result.refreshToken, + expires: result.accessTokenExpiresAt ?? Date.now() + DEFAULT_TOKEN_TTL_MS, + }; + + await persistOAuthAuth(input.client, updatedAuth).catch(() => undefined); + return updatedAuth; + } catch { + return undefined; + } +}