diff --git a/apps/code/src/main/di/container.ts b/apps/code/src/main/di/container.ts index f45163c53..ba30b8fa3 100644 --- a/apps/code/src/main/di/container.ts +++ b/apps/code/src/main/di/container.ts @@ -32,6 +32,7 @@ import { LinearIntegrationService } from "../services/linear-integration/service import { LlmGatewayService } from "../services/llm-gateway/service"; import { McpAppsService } from "../services/mcp-apps/service"; import { McpCallbackService } from "../services/mcp-callback/service"; +import { McpProxyService } from "../services/mcp-proxy/service"; import { NotificationService } from "../services/notification/service"; import { OAuthService } from "../services/oauth/service"; import { PosthogPluginService } from "../services/posthog-plugin/service"; @@ -66,6 +67,7 @@ container.bind(MAIN_TOKENS.AgentAuthAdapter).to(AgentAuthAdapter); container.bind(MAIN_TOKENS.AgentService).to(AgentService); container.bind(MAIN_TOKENS.AuthService).to(AuthService); container.bind(MAIN_TOKENS.AuthProxyService).to(AuthProxyService); +container.bind(MAIN_TOKENS.McpProxyService).to(McpProxyService); container.bind(MAIN_TOKENS.ArchiveService).to(ArchiveService); container.bind(MAIN_TOKENS.SuspensionService).to(SuspensionService); container.bind(MAIN_TOKENS.AppLifecycleService).to(AppLifecycleService); diff --git a/apps/code/src/main/di/tokens.ts b/apps/code/src/main/di/tokens.ts index 27bdbcafc..1b980855b 100644 --- a/apps/code/src/main/di/tokens.ts +++ b/apps/code/src/main/di/tokens.ts @@ -23,6 +23,7 @@ export const MAIN_TOKENS = Object.freeze({ AgentService: Symbol.for("Main.AgentService"), AuthService: Symbol.for("Main.AuthService"), AuthProxyService: Symbol.for("Main.AuthProxyService"), + McpProxyService: Symbol.for("Main.McpProxyService"), ArchiveService: Symbol.for("Main.ArchiveService"), SuspensionService: Symbol.for("Main.SuspensionService"), AppLifecycleService: Symbol.for("Main.AppLifecycleService"), diff --git a/apps/code/src/main/menu.ts b/apps/code/src/main/menu.ts index 48e475207..ba6418676 100644 --- a/apps/code/src/main/menu.ts +++ b/apps/code/src/main/menu.ts @@ -10,6 +10,7 @@ import { } from "electron"; import { container } from "./di/container"; import { MAIN_TOKENS } from "./di/tokens"; +import type { AuthService } from "./services/auth/service"; import type { McpAppsService } from "./services/mcp-apps/service"; import type { UIService } from "./services/ui/service"; import type { UpdatesService } from "./services/updates/service"; @@ -132,6 +133,28 @@ function buildFileMenu(): MenuItemConstructorOptions { .invalidateToken(); }, }, + { + label: "Force refresh of OAuth token", + click: () => { + container + .get(MAIN_TOKENS.AuthService) + .refreshAccessToken() + .then(() => { + dialog.showMessageBox({ + type: "info", + title: "OAuth Token Refreshed", + message: "Access token refreshed successfully.", + }); + }) + .catch((err: Error) => { + dialog.showMessageBox({ + type: "error", + title: "OAuth Token Refresh Failed", + message: err.message, + }); + }); + }, + }, { label: "Refresh MCP Apps discovery", click: () => { diff --git a/apps/code/src/main/services/agent/auth-adapter.test.ts b/apps/code/src/main/services/agent/auth-adapter.test.ts index acb58fd45..102c2ddb8 100644 --- a/apps/code/src/main/services/agent/auth-adapter.test.ts +++ b/apps/code/src/main/services/agent/auth-adapter.test.ts @@ -50,6 +50,14 @@ function createDependencies() { authProxy: { start: vi.fn().mockResolvedValue("http://127.0.0.1:9999"), }, + mcpProxy: { + start: vi.fn().mockResolvedValue(undefined), + register: vi + .fn() + .mockImplementation( + (id: string) => `http://127.0.0.1:9998/${encodeURIComponent(id)}`, + ), + }, }; } @@ -68,6 +76,7 @@ describe("AgentAuthAdapter", () => { adapter = new AgentAuthAdapter( deps.authService as never, deps.authProxy as never, + deps.mcpProxy as never, ); }); @@ -75,20 +84,21 @@ describe("AgentAuthAdapter", () => { vi.restoreAllMocks(); }); - it("builds the default PostHog MCP server", async () => { + it("builds the default PostHog MCP server routed through the local proxy", async () => { const servers = await adapter.buildMcpServers(baseCredentials); + expect(deps.mcpProxy.register).toHaveBeenCalledWith( + "posthog", + "https://mcp.posthog.com/mcp", + ); expect(servers).toEqual( expect.arrayContaining([ expect.objectContaining({ name: "posthog", type: "http", - url: "https://mcp.posthog.com/mcp", - headers: expect.arrayContaining([ - { - name: "Authorization", - value: "Bearer test-access-token", - }, + url: "http://127.0.0.1:9998/posthog", + headers: expect.not.arrayContaining([ + expect.objectContaining({ name: "Authorization" }), ]), }), ]), @@ -152,14 +162,16 @@ describe("AgentAuthAdapter", () => { const servers = await adapter.buildMcpServers(baseCredentials); + expect(deps.mcpProxy.register).toHaveBeenCalledWith( + "installation-inst-2", + "https://proxy.posthog.com/inst-2/", + ); expect(servers).toEqual( expect.arrayContaining([ expect.objectContaining({ name: "secure-server", - url: "https://proxy.posthog.com/inst-2/", - headers: [ - { name: "Authorization", value: "Bearer test-access-token" }, - ], + url: "http://127.0.0.1:9998/installation-inst-2", + headers: [], }), ]), ); diff --git a/apps/code/src/main/services/agent/auth-adapter.ts b/apps/code/src/main/services/agent/auth-adapter.ts index ad635474a..9136cd485 100644 --- a/apps/code/src/main/services/agent/auth-adapter.ts +++ b/apps/code/src/main/services/agent/auth-adapter.ts @@ -5,6 +5,7 @@ import { MAIN_TOKENS } from "../../di/tokens"; import { logger } from "../../utils/logger"; import type { AuthService } from "../auth/service"; import type { AuthProxyService } from "../auth-proxy/service"; +import type { McpProxyService } from "../mcp-proxy/service"; import type { Credentials } from "./schemas"; const log = logger.scope("agent-auth-adapter"); @@ -37,6 +38,8 @@ export class AgentAuthAdapter { private readonly authService: AuthService, @inject(MAIN_TOKENS.AuthProxyService) private readonly authProxy: AuthProxyService, + @inject(MAIN_TOKENS.McpProxyService) + private readonly mcpProxy: McpProxyService, ) {} createPosthogConfig(credentials: Credentials): AgentPosthogConfig { @@ -51,14 +54,19 @@ export class AgentAuthAdapter { async buildMcpServers(credentials: Credentials): Promise { const servers: AcpMcpServer[] = []; const mcpUrl = this.getPostHogMcpUrl(credentials.apiHost); - const token = await this.getValidToken(); + // Warm the token so authenticatedFetch() has something cached, but do not + // bake it into the MCP config — the proxy injects a fresh one on every + // forwarded request. + await this.getValidToken(); + + await this.mcpProxy.start(); + const proxiedPosthogUrl = this.mcpProxy.register("posthog", mcpUrl); servers.push({ name: "posthog", type: "http", - url: mcpUrl, + url: proxiedPosthogUrl, headers: [ - { name: "Authorization", value: `Bearer ${token}` }, { name: "x-posthog-project-id", value: String(credentials.projectId), @@ -72,10 +80,12 @@ export class AgentAuthAdapter { for (const installation of installations) { if (installation.url === mcpUrl) continue; + const name = + installation.name || installation.display_name || installation.url; + if (installation.auth_type === "none") { servers.push({ - name: - installation.name || installation.display_name || installation.url, + name, type: "http", url: installation.url, headers: [], @@ -83,12 +93,15 @@ export class AgentAuthAdapter { continue; } + const proxiedUrl = this.mcpProxy.register( + `installation-${installation.id}`, + installation.proxy_url, + ); servers.push({ - name: - installation.name || installation.display_name || installation.url, + name, type: "http", - url: installation.proxy_url, - headers: [{ name: "Authorization", value: `Bearer ${token}` }], + url: proxiedUrl, + headers: [], }); } diff --git a/apps/code/src/main/services/mcp-proxy/service.test.ts b/apps/code/src/main/services/mcp-proxy/service.test.ts new file mode 100644 index 000000000..290aaf202 --- /dev/null +++ b/apps/code/src/main/services/mcp-proxy/service.test.ts @@ -0,0 +1,269 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { AuthService } from "../auth/service"; +import { McpProxyService } from "./service"; + +vi.mock("../../utils/logger.js", () => ({ + logger: { + scope: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), + }, +})); + +type AuthServiceMock = { + authenticatedFetch: ReturnType; + refreshAccessToken: ReturnType; + getValidAccessToken: ReturnType; +}; + +function createAuthServiceMock(): AuthServiceMock { + return { + authenticatedFetch: vi.fn(), + refreshAccessToken: vi.fn().mockResolvedValue({ + accessToken: "refreshed-token", + apiHost: "https://app.posthog.com", + }), + getValidAccessToken: vi.fn().mockResolvedValue({ + accessToken: "access-token", + apiHost: "https://app.posthog.com", + }), + }; +} + +describe("McpProxyService", () => { + let authServiceMock: AuthServiceMock; + let service: McpProxyService; + + beforeEach(() => { + authServiceMock = createAuthServiceMock(); + service = new McpProxyService(authServiceMock as unknown as AuthService); + }); + + afterEach(async () => { + await service.stop(); + vi.restoreAllMocks(); + }); + + describe("lifecycle", () => { + it("starts on a loopback port and returns a URL for register()", async () => { + await service.start(); + const url = service.register("alpha", "https://upstream.example/path"); + expect(url).toMatch(/^http:\/\/127\.0\.0\.1:\d+\/alpha$/); + }); + + it("throws from register() before start()", () => { + expect(() => + service.register("alpha", "https://upstream.example"), + ).toThrowError(/not started/); + }); + + it("handles concurrent start() calls without races", async () => { + await Promise.all([service.start(), service.start(), service.start()]); + const url = service.register("alpha", "https://upstream.example"); + expect(url).toMatch(/^http:\/\/127\.0\.0\.1:\d+\/alpha$/); + }); + + it("stop() closes the server and clears registered targets", async () => { + await service.start(); + service.register("alpha", "https://upstream.example"); + await service.stop(); + expect(() => + service.register("alpha", "https://upstream.example"), + ).toThrowError(/not started/); + }); + }); + + describe("request forwarding", () => { + it("returns 404 for unknown targets", async () => { + await service.start(); + const proxyUrl = service.register("alpha", "https://upstream.example"); + const unknownUrl = proxyUrl.replace("/alpha", "/bravo"); + + const res = await fetch(unknownUrl); + + expect(res.status).toBe(404); + expect(await res.text()).toBe("Unknown target"); + expect(authServiceMock.authenticatedFetch).not.toHaveBeenCalled(); + }); + + it("forwards GET requests and returns the upstream body and status", async () => { + authServiceMock.authenticatedFetch.mockResolvedValue( + new Response('{"ok":true}', { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); + + await service.start(); + const proxyUrl = service.register("alpha", "https://upstream.example"); + + const res = await fetch(proxyUrl); + + expect(res.status).toBe(200); + expect(await res.text()).toBe('{"ok":true}'); + expect(authServiceMock.authenticatedFetch).toHaveBeenCalledTimes(1); + const [, url] = authServiceMock.authenticatedFetch.mock.calls[0]; + expect(url).toBe("https://upstream.example"); + }); + + it("forwards POST body bytes to the upstream URL", async () => { + authServiceMock.authenticatedFetch.mockResolvedValue( + new Response('{"ok":true}', { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); + + await service.start(); + const proxyUrl = service.register("alpha", "https://upstream.example"); + + await fetch(proxyUrl, { + method: "POST", + headers: { "content-type": "application/json" }, + body: '{"hello":"world"}', + }); + + expect(authServiceMock.authenticatedFetch).toHaveBeenCalledTimes(1); + const [, , options] = authServiceMock.authenticatedFetch.mock.calls[0]; + expect(options.method).toBe("POST"); + expect(Buffer.from(options.body).toString("utf8")).toBe( + '{"hello":"world"}', + ); + }); + + it("strips Authorization and Host headers before forwarding", async () => { + authServiceMock.authenticatedFetch.mockResolvedValue( + new Response('{"ok":true}', { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); + + await service.start(); + const proxyUrl = service.register("alpha", "https://upstream.example"); + + await fetch(proxyUrl, { + headers: { + Authorization: "Bearer leaked", + "X-Custom": "keep-me", + }, + }); + + const [, , options] = authServiceMock.authenticatedFetch.mock.calls[0]; + const forwardedHeaderKeys = Object.keys(options.headers).map((k) => + k.toLowerCase(), + ); + expect(forwardedHeaderKeys).not.toContain("authorization"); + expect(forwardedHeaderKeys).not.toContain("host"); + expect(forwardedHeaderKeys).not.toContain("connection"); + expect(options.headers["x-custom"]).toBe("keep-me"); + }); + + it("joins path suffix without producing a double slash for trailing-slash targets", async () => { + authServiceMock.authenticatedFetch.mockResolvedValue( + new Response("{}", { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); + + await service.start(); + service.register("alpha", "https://upstream.example/inst-2/"); + const port = new URL( + service.register("alpha", "https://upstream.example/inst-2/"), + ).port; + + await fetch(`http://127.0.0.1:${port}/alpha/tools/list`); + + const [, url] = + authServiceMock.authenticatedFetch.mock.calls.at(-1) ?? []; + expect(url).toBe("https://upstream.example/inst-2/tools/list"); + }); + + it("preserves the incoming query string on the upstream URL", async () => { + authServiceMock.authenticatedFetch.mockResolvedValue( + new Response("{}", { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); + + await service.start(); + const proxyUrl = service.register("alpha", "https://upstream.example"); + + await fetch(`${proxyUrl}?token=abc&foo=bar`); + + const [, url] = authServiceMock.authenticatedFetch.mock.calls[0]; + expect(url).toBe("https://upstream.example?token=abc&foo=bar"); + }); + }); + + describe("auth error retry", () => { + it("refreshes the token and retries once when the body contains authentication_failed", async () => { + authServiceMock.authenticatedFetch + .mockResolvedValueOnce( + new Response( + JSON.stringify({ error: { code: "authentication_failed" } }), + { status: 200, headers: { "content-type": "application/json" } }, + ), + ) + .mockResolvedValueOnce( + new Response('{"ok":true}', { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); + + await service.start(); + const proxyUrl = service.register("alpha", "https://upstream.example"); + + const res = await fetch(proxyUrl, { method: "POST", body: "payload" }); + + expect(res.status).toBe(200); + expect(await res.text()).toBe('{"ok":true}'); + expect(authServiceMock.refreshAccessToken).toHaveBeenCalledTimes(1); + expect(authServiceMock.authenticatedFetch).toHaveBeenCalledTimes(2); + }); + + it("does not retry when the body looks healthy", async () => { + authServiceMock.authenticatedFetch.mockResolvedValue( + new Response('{"ok":true}', { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); + + await service.start(); + const proxyUrl = service.register("alpha", "https://upstream.example"); + + await fetch(proxyUrl); + + expect(authServiceMock.refreshAccessToken).not.toHaveBeenCalled(); + expect(authServiceMock.authenticatedFetch).toHaveBeenCalledTimes(1); + }); + }); + + describe("SSE streaming", () => { + it("streams event-stream responses through to the client", async () => { + const sseBody = "data: one\n\ndata: two\n\n"; + authServiceMock.authenticatedFetch.mockResolvedValue( + new Response(sseBody, { + status: 200, + headers: { "content-type": "text/event-stream" }, + }), + ); + + await service.start(); + const proxyUrl = service.register("alpha", "https://upstream.example"); + + const res = await fetch(proxyUrl); + + expect(res.headers.get("content-type")).toContain("text/event-stream"); + expect(await res.text()).toBe(sseBody); + expect(authServiceMock.refreshAccessToken).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/apps/code/src/main/services/mcp-proxy/service.ts b/apps/code/src/main/services/mcp-proxy/service.ts new file mode 100644 index 000000000..5f6b5d832 --- /dev/null +++ b/apps/code/src/main/services/mcp-proxy/service.ts @@ -0,0 +1,289 @@ +import http from "node:http"; +import { inject, injectable, preDestroy } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { logger } from "../../utils/logger"; +import type { AuthService } from "../auth/service"; + +const log = logger.scope("mcp-proxy"); + +/** + * Local HTTP proxy for MCP servers. Allows routing MCP requests through a + * stable loopback URL while injecting a fresh access token on every forwarded + * request. MCP transports bake their headers at construction time, so without + * this proxy we would either need to tear the transport down on every token + * rotation (expensive, racy) or leave it serving stale tokens. + * + * The proxy only listens on 127.0.0.1 and strips inbound Authorization headers + * before forwarding, but any local process can still use it to issue requests + * on the user's behalf — acceptable for a single-user desktop app. + */ +@injectable() +export class McpProxyService { + private server: http.Server | null = null; + private port: number | null = null; + private startPromise: Promise | null = null; + private targets = new Map(); + + constructor( + @inject(MAIN_TOKENS.AuthService) + private readonly authService: AuthService, + ) {} + + async start(): Promise { + if (this.server && this.port) return; + if (this.startPromise) return this.startPromise; + this.startPromise = this.doStart().catch((err) => { + this.startPromise = null; + throw err; + }); + return this.startPromise; + } + + private async doStart(): Promise { + const server = http.createServer((req, res) => { + this.handleRequest(req, res); + }); + this.server = server; + + await new Promise((resolve, reject) => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address(); + if (typeof addr === "object" && addr) { + this.port = addr.port; + log.info("MCP proxy started", { port: this.port }); + resolve(); + } else { + reject(new Error("Failed to get proxy address")); + } + }); + + server.on("error", (err) => { + log.error("MCP proxy server error", err); + reject(err); + }); + }); + } + + /** + * Register a target URL under a stable ID. Returns the loopback URL that + * should be passed to the MCP transport. Subsequent registrations with the + * same ID overwrite the target. + */ + register(id: string, targetUrl: string): string { + if (!this.port) { + throw new Error("MCP proxy not started"); + } + this.targets.set(id, targetUrl); + return `http://127.0.0.1:${this.port}/${encodeURIComponent(id)}`; + } + + @preDestroy() + async stop(): Promise { + if (!this.server) return; + const server = this.server; + await new Promise((resolve) => { + server.close(() => { + log.info("MCP proxy stopped"); + resolve(); + }); + }); + this.server = null; + this.port = null; + this.startPromise = null; + this.targets.clear(); + } + + private handleRequest( + req: http.IncomingMessage, + res: http.ServerResponse, + ): void { + const incoming = new URL(req.url ?? "/", "http://placeholder"); + const segments = incoming.pathname.split("/").filter(Boolean); + const [rawId, ...rest] = segments; + const id = rawId ? decodeURIComponent(rawId) : ""; + const target = this.targets.get(id); + + if (!target) { + log.warn("Unknown MCP proxy target", { id, url: req.url }); + res.writeHead(404); + res.end("Unknown target"); + return; + } + + const suffix = rest.join("/"); + const targetBase = target.replace(/\/+$/, ""); + const targetUrl = + (suffix ? `${targetBase}/${suffix}` : targetBase) + incoming.search; + + const strippedAuthHeaders = new Set([ + "authorization", + "proxy-authorization", + ]); + const headers: Record = {}; + for (const [key, value] of Object.entries(req.headers)) { + if ( + key === "host" || + key === "connection" || + strippedAuthHeaders.has(key) + ) { + continue; + } + if (typeof value === "string") { + headers[key] = value; + } + } + + const fetchOptions: RequestInit = { + method: req.method ?? "GET", + headers, + }; + + if (req.method !== "GET" && req.method !== "HEAD") { + const chunks: Buffer[] = []; + req.on("data", (chunk: Buffer) => chunks.push(chunk)); + req.on("end", () => { + fetchOptions.body = Buffer.concat(chunks); + this.forwardRequest(id, targetUrl, fetchOptions, res); + }); + } else { + this.forwardRequest(id, targetUrl, fetchOptions, res); + } + } + + private async forwardRequest( + id: string, + url: string, + options: RequestInit, + res: http.ServerResponse, + ): Promise { + try { + let response = await this.authService.authenticatedFetch( + fetch, + url, + options, + ); + + // MCP servers return HTTP 200 with auth failures encoded in the JSON-RPC + // body, so authenticatedFetch's 401/403 retry never kicks in. Detect the + // known error shape and retry once with a force-refreshed token. + const contentType = response.headers.get("content-type") ?? ""; + const isSse = contentType.includes("text/event-stream"); + + if (!isSse) { + const buf = Buffer.from(await response.arrayBuffer()); + const bodyText = buf.toString("utf8"); + + if (this.isAuthErrorBody(bodyText)) { + log.warn("MCP auth error in body — refreshing token and retrying", { + id, + url, + }); + await this.authService.refreshAccessToken(); + response = await this.authService.authenticatedFetch( + fetch, + url, + options, + ); + const retryContentType = response.headers.get("content-type") ?? ""; + if (!retryContentType.includes("text/event-stream")) { + const retryBuf = Buffer.from(await response.arrayBuffer()); + this.writeBufferedResponse(response, retryBuf, res); + return; + } + this.writeStreamingResponse(response, res); + return; + } + + if (/"isError"\s*:\s*true/.test(bodyText) || response.status >= 400) { + log.warn("MCP proxy non-OK body", { + id, + url, + status: response.status, + body: bodyText.slice(0, 2000), + }); + } else { + log.debug("MCP proxy response", { + id, + url, + status: response.status, + }); + } + + this.writeBufferedResponse(response, buf, res); + return; + } + + log.debug("MCP proxy response", { + id, + url, + status: response.status, + streaming: true, + }); + this.writeStreamingResponse(response, res); + } catch (err) { + log.error("MCP proxy forward error", { id, url, err }); + if (!res.headersSent) { + res.writeHead(502); + } + res.end("Proxy error"); + } + } + + private isAuthErrorBody(bodyText: string): boolean { + return ( + bodyText.includes('"authentication_failed"') || + bodyText.includes('"authentication_error"') + ); + } + + private buildResponseHeaders(response: Response): Record { + const stripHeaders = new Set([ + "transfer-encoding", + "content-encoding", + "content-length", + ]); + const headers: Record = {}; + response.headers.forEach((value: string, key: string) => { + if (stripHeaders.has(key)) return; + headers[key] = value; + }); + return headers; + } + + private writeBufferedResponse( + response: Response, + buf: Buffer, + res: http.ServerResponse, + ): void { + res.writeHead(response.status, this.buildResponseHeaders(response)); + res.end(buf); + } + + private async writeStreamingResponse( + response: Response, + res: http.ServerResponse, + ): Promise { + res.writeHead(response.status, this.buildResponseHeaders(response)); + if (!response.body) { + res.end(); + return; + } + const reader = response.body.getReader(); + res.on("close", () => { + void reader.cancel().catch(() => {}); + }); + const pump = async (): Promise => { + const { done, value } = await reader.read(); + if (done) { + res.end(); + return; + } + const canContinue = res.write(value); + if (canContinue) { + return pump(); + } + res.once("drain", () => pump()); + }; + await pump(); + } +}