diff --git a/apps/web/src/app/api/exa/mcp/route.test.ts b/apps/web/src/app/api/exa/mcp/route.test.ts new file mode 100644 index 0000000000..1783ca7b5b --- /dev/null +++ b/apps/web/src/app/api/exa/mcp/route.test.ts @@ -0,0 +1,665 @@ +import { describe, it, expect, beforeEach } from '@jest/globals'; +import { NextResponse } from 'next/server'; +import { getUserFromAuth } from '@/lib/user.server'; +import { failureResult } from '@/lib/maybe-result'; +import type { User } from '@kilocode/db/schema'; +import { + getExaMonthlyUsage, + getExaFreeAllowanceMicrodollars, + recordExaUsage, +} from '@/lib/exa-usage'; +import { EXA_MONTHLY_ALLOWANCE_MICRODOLLARS } from '@/lib/constants'; +import { getBalanceAndOrgSettings } from '@/lib/organizations/organization-usage'; + +let afterCallbacks: (() => Promise)[] = []; + +jest.mock('next/server', () => { + return { + ...(jest.requireActual('next/server') as Record), + after: (fn: () => Promise) => { + afterCallbacks.push(fn); + }, + }; +}); + +async function flushAfterCallbacks() { + for (const fn of afterCallbacks) { + await fn(); + } + afterCallbacks = []; +} + +jest.mock('@/lib/config.server', () => ({ + EXA_API_KEY: 'test-exa-key', +})); + +jest.mock('@/lib/user.server'); +jest.mock('@/lib/exa-usage'); +jest.mock('@/lib/organizations/organization-usage'); + +const mockedGetUserFromAuth = jest.mocked(getUserFromAuth); +const mockedGetExaMonthlyUsage = jest.mocked(getExaMonthlyUsage); +const mockedGetExaFreeAllowanceMicrodollars = jest.mocked(getExaFreeAllowanceMicrodollars); +const mockedRecordExaUsage = jest.mocked(recordExaUsage); +const mockedGetBalanceAndOrgSettings = jest.mocked(getBalanceAndOrgSettings); +const mockedFetch = jest.fn() as jest.MockedFunction; +const originalFetch = globalThis.fetch; + +type RpcRequestBody = { + jsonrpc: '2.0'; + id?: number | string | null; + method: string; + params?: unknown; +}; + +function makeRequest(body: RpcRequestBody | unknown, headers: Record = {}) { + return new Request('http://localhost:3000/api/exa/mcp', { + method: 'POST', + headers: { 'Content-Type': 'application/json', ...headers }, + body: JSON.stringify(body), + }); +} + +function setUserAuth(id = 'user-123', organizationId?: string) { + mockedGetUserFromAuth.mockResolvedValue({ + user: { id } as User, + authFailedResponse: null, + organizationId, + }); +} + +function makeUpstreamResponse(body: unknown, status = 200) { + return new Response(JSON.stringify(body), { + status, + headers: { 'content-type': 'application/json' }, + }); +} + +async function readSseData(response: Response): Promise { + const text = await response.text(); + const match = text.match(/^data: (.+)$/m); + if (!match) throw new Error(`No data line in SSE body: ${text}`); + return JSON.parse(match[1]); +} + +describe('POST /api/exa/mcp', () => { + beforeEach(() => { + jest.resetAllMocks(); + afterCallbacks = []; + globalThis.fetch = mockedFetch; + mockedGetExaMonthlyUsage.mockResolvedValue({ usage: 0, freeAllowance: null }); + mockedGetExaFreeAllowanceMicrodollars.mockReturnValue(EXA_MONTHLY_ALLOWANCE_MICRODOLLARS); + mockedRecordExaUsage.mockResolvedValue(); + }); + + afterAll(() => { + globalThis.fetch = originalFetch; + }); + + describe('authentication', () => { + it('returns auth failure response when not authenticated', async () => { + const authFailedResponse = NextResponse.json(failureResult('Unauthorized'), { status: 401 }); + mockedGetUserFromAuth.mockResolvedValue({ + user: null, + authFailedResponse, + }); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ jsonrpc: '2.0', id: 1, method: 'initialize' }) as never + ); + + expect(response).toBe(authFailedResponse); + expect(mockedFetch).not.toHaveBeenCalled(); + }); + }); + + describe('initialize', () => { + it('returns protocolVersion, capabilities and serverInfo', async () => { + setUserAuth(); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ jsonrpc: '2.0', id: 1, method: 'initialize' }) as never + ); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toContain('text/event-stream'); + const payload = (await readSseData(response)) as { + jsonrpc: string; + id: number; + result: { protocolVersion: string; capabilities: unknown; serverInfo: unknown }; + }; + expect(payload.jsonrpc).toBe('2.0'); + expect(payload.id).toBe(1); + expect(payload.result.protocolVersion).toMatch(/^\d{4}-\d{2}-\d{2}$/); + expect(payload.result.capabilities).toEqual({ tools: {} }); + expect(payload.result.serverInfo).toEqual({ name: 'kilo-exa-mcp', version: '1.0.0' }); + expect(mockedFetch).not.toHaveBeenCalled(); + }); + }); + + describe('tools/list', () => { + it('returns the two Exa tools', async () => { + setUserAuth(); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ jsonrpc: '2.0', id: 2, method: 'tools/list' }) as never + ); + + expect(response.status).toBe(200); + const payload = (await readSseData(response)) as { + result: { tools: Array<{ name: string }> }; + }; + const names = payload.result.tools.map(t => t.name).sort(); + expect(names).toEqual(['get_code_context_exa', 'web_search_exa']); + }); + }); + + describe('notifications', () => { + it('returns 202 with no body for notifications/initialized', async () => { + setUserAuth(); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ jsonrpc: '2.0', method: 'notifications/initialized' }) as never + ); + + expect(response.status).toBe(202); + expect(await response.text()).toBe(''); + }); + + it('returns 202 when id is null (treated as notification)', async () => { + setUserAuth(); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ jsonrpc: '2.0', id: null, method: 'notifications/cancelled' }) as never + ); + + expect(response.status).toBe(202); + }); + }); + + describe('unknown method', () => { + it('returns JSON-RPC method-not-found error', async () => { + setUserAuth(); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ jsonrpc: '2.0', id: 9, method: 'resources/list' }) as never + ); + + expect(response.status).toBe(200); + const payload = (await readSseData(response)) as { + id: number; + error: { code: number; message: string }; + }; + expect(payload.id).toBe(9); + expect(payload.error.code).toBe(-32601); + expect(payload.error.message).toContain('resources/list'); + }); + }); + + describe('tools/call web_search_exa', () => { + it('translates arguments into an Exa /search call', async () => { + setUserAuth(); + mockedFetch.mockResolvedValue( + makeUpstreamResponse({ results: [{ title: 'hit' }], costDollars: { total: 0.007 } }) + ); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ + jsonrpc: '2.0', + id: 3, + method: 'tools/call', + params: { + name: 'web_search_exa', + arguments: { + query: 'typescript hooks', + type: 'auto', + numResults: 8, + livecrawl: 'fallback', + contextMaxCharacters: 10000, + }, + }, + }) as never + ); + + expect(response.status).toBe(200); + expect(mockedFetch).toHaveBeenCalledTimes(1); + expect(mockedFetch).toHaveBeenCalledWith( + 'https://api.exa.ai/search', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-api-key': 'test-exa-key', + }, + }) + ); + const sentBody = JSON.parse(mockedFetch.mock.calls[0][1]?.body as string); + expect(sentBody).toEqual({ + query: 'typescript hooks', + numResults: 8, + type: 'auto', + livecrawl: 'fallback', + contents: { + text: { maxCharacters: 10000 }, + highlights: true, + }, + }); + + const payload = (await readSseData(response)) as { + id: number; + result: { content: Array<{ type: string; text: string }> }; + }; + expect(payload.id).toBe(3); + expect(payload.result.content[0].type).toBe('text'); + const parsed = JSON.parse(payload.result.content[0].text) as { + results: Array<{ title: string }>; + }; + expect(parsed.results[0].title).toBe('hit'); + }); + + it('uses contents.text: true when contextMaxCharacters is omitted, and default numResults/type/livecrawl', async () => { + setUserAuth(); + mockedFetch.mockResolvedValue(makeUpstreamResponse({ results: [] })); + + const { POST } = await import('./route'); + await POST( + makeRequest({ + jsonrpc: '2.0', + id: 4, + method: 'tools/call', + params: { + name: 'web_search_exa', + arguments: { query: 'hello' }, + }, + }) as never + ); + + const sentBody = JSON.parse(mockedFetch.mock.calls[0][1]?.body as string); + expect(sentBody).toEqual({ + query: 'hello', + numResults: 8, + type: 'auto', + livecrawl: 'fallback', + contents: { text: true, highlights: true }, + }); + }); + + it('returns invalid-params error when query is missing', async () => { + setUserAuth(); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ + jsonrpc: '2.0', + id: 5, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: {} }, + }) as never + ); + + const payload = (await readSseData(response)) as { + error: { code: number; message: string }; + }; + expect(payload.error.code).toBe(-32602); + expect(mockedFetch).not.toHaveBeenCalled(); + }); + }); + + describe('tools/call get_code_context_exa', () => { + it('translates tokensNum into maxCharacters (tokensNum * 4)', async () => { + setUserAuth(); + mockedFetch.mockResolvedValue(makeUpstreamResponse({ results: [] })); + + const { POST } = await import('./route'); + await POST( + makeRequest({ + jsonrpc: '2.0', + id: 6, + method: 'tools/call', + params: { + name: 'get_code_context_exa', + arguments: { query: 'react hooks', tokensNum: 5000 }, + }, + }) as never + ); + + const sentBody = JSON.parse(mockedFetch.mock.calls[0][1]?.body as string); + expect(sentBody).toEqual({ + query: 'react hooks', + numResults: 5, + type: 'auto', + contents: { text: { maxCharacters: 20000 } }, + }); + }); + + it('defaults tokensNum to 5000 when omitted', async () => { + setUserAuth(); + mockedFetch.mockResolvedValue(makeUpstreamResponse({ results: [] })); + + const { POST } = await import('./route'); + await POST( + makeRequest({ + jsonrpc: '2.0', + id: 7, + method: 'tools/call', + params: { name: 'get_code_context_exa', arguments: { query: 'react hooks' } }, + }) as never + ); + + const sentBody = JSON.parse(mockedFetch.mock.calls[0][1]?.body as string); + expect(sentBody.contents.text.maxCharacters).toBe(20000); + }); + }); + + describe('tools/call unknown tool', () => { + it('returns invalid-params error', async () => { + setUserAuth(); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ + jsonrpc: '2.0', + id: 8, + method: 'tools/call', + params: { name: 'mystery_tool', arguments: { query: 'x' } }, + }) as never + ); + + const payload = (await readSseData(response)) as { + error: { code: number; message: string }; + }; + expect(payload.error.code).toBe(-32602); + expect(payload.error.message).toContain('Unknown tool'); + expect(mockedFetch).not.toHaveBeenCalled(); + }); + }); + + describe('request signal propagation', () => { + it('passes request.signal to upstream fetch', async () => { + setUserAuth(); + mockedFetch.mockResolvedValue(makeUpstreamResponse({ results: [] })); + + const { POST } = await import('./route'); + const request = makeRequest({ + jsonrpc: '2.0', + id: 10, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: { query: 'x' } }, + }); + await POST(request as never); + + expect(mockedFetch).toHaveBeenCalledWith( + 'https://api.exa.ai/search', + expect.objectContaining({ signal: request.signal }) + ); + }); + }); + + describe('monthly allowance', () => { + it('allows tool call when under the free tier', async () => { + setUserAuth(); + mockedGetExaMonthlyUsage.mockResolvedValue({ usage: 5_000_000, freeAllowance: 10_000_000 }); + mockedFetch.mockResolvedValue(makeUpstreamResponse({ results: [] })); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ + jsonrpc: '2.0', + id: 11, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: { query: 'x' } }, + }) as never + ); + + expect(response.status).toBe(200); + expect(mockedGetBalanceAndOrgSettings).not.toHaveBeenCalled(); + }); + + it('returns JSON-RPC error when free tier is exhausted and no balance', async () => { + setUserAuth(); + mockedGetExaMonthlyUsage.mockResolvedValue({ + usage: 10_000_000, + freeAllowance: 10_000_000, + }); + mockedGetBalanceAndOrgSettings.mockResolvedValue({ balance: 0 }); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ + jsonrpc: '2.0', + id: 12, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: { query: 'x' } }, + }) as never + ); + + const payload = (await readSseData(response)) as { + error: { code: number; message: string }; + }; + expect(payload.error.message).toContain('free allowance exhausted'); + expect(mockedFetch).not.toHaveBeenCalled(); + }); + + it('passes organizationId to balance check', async () => { + const orgId = 'org-456'; + setUserAuth('user-123', orgId); + mockedGetExaMonthlyUsage.mockResolvedValue({ + usage: 10_000_000, + freeAllowance: 10_000_000, + }); + mockedGetBalanceAndOrgSettings.mockResolvedValue({ balance: 10.0 }); + mockedFetch.mockResolvedValue(makeUpstreamResponse({ results: [] })); + + const { POST } = await import('./route'); + await POST( + makeRequest( + { + jsonrpc: '2.0', + id: 13, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: { query: 'x' } }, + }, + { 'X-KiloCode-OrganizationId': orgId } + ) as never + ); + + expect(mockedGetBalanceAndOrgSettings).toHaveBeenCalledWith( + orgId, + expect.objectContaining({ id: 'user-123' }), + expect.anything() + ); + }); + }); + + describe('cost recording', () => { + it('records cost from upstream response via after callback (free tier)', async () => { + setUserAuth(); + mockedFetch.mockResolvedValue( + makeUpstreamResponse({ results: [], costDollars: { total: 0.007 } }) + ); + + const { POST } = await import('./route'); + await POST( + makeRequest({ + jsonrpc: '2.0', + id: 14, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: { query: 'x' } }, + }) as never + ); + await flushAfterCallbacks(); + + expect(mockedRecordExaUsage).toHaveBeenCalledWith({ + userId: 'user-123', + organizationId: undefined, + path: '/search', + costMicrodollars: 7000, + chargedToBalance: false, + freeAllowanceMicrodollars: EXA_MONTHLY_ALLOWANCE_MICRODOLLARS, + }); + }); + + it('records cost with chargedToBalance when over free tier', async () => { + setUserAuth(); + mockedGetExaMonthlyUsage.mockResolvedValue({ + usage: 10_000_000, + freeAllowance: 10_000_000, + }); + mockedGetBalanceAndOrgSettings.mockResolvedValue({ balance: 5.0 }); + mockedFetch.mockResolvedValue( + makeUpstreamResponse({ results: [], costDollars: { total: 0.005 } }) + ); + + const { POST } = await import('./route'); + await POST( + makeRequest({ + jsonrpc: '2.0', + id: 15, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: { query: 'x' } }, + }) as never + ); + await flushAfterCallbacks(); + + expect(mockedRecordExaUsage).toHaveBeenCalledWith({ + userId: 'user-123', + organizationId: undefined, + path: '/search', + costMicrodollars: 5000, + chargedToBalance: true, + freeAllowanceMicrodollars: 10_000_000, + }); + }); + + it('does not record cost when upstream returns no costDollars', async () => { + setUserAuth(); + mockedFetch.mockResolvedValue(makeUpstreamResponse({ results: [] })); + + const { POST } = await import('./route'); + await POST( + makeRequest({ + jsonrpc: '2.0', + id: 16, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: { query: 'x' } }, + }) as never + ); + await flushAfterCallbacks(); + + expect(mockedRecordExaUsage).not.toHaveBeenCalled(); + }); + }); + + describe('upstream error', () => { + it('wraps Exa upstream errors as JSON-RPC errors (HTTP 200 SSE)', async () => { + setUserAuth(); + mockedFetch.mockResolvedValue( + new Response(JSON.stringify({ error: 'rate limited' }), { + status: 429, + headers: { 'content-type': 'application/json' }, + }) + ); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ + jsonrpc: '2.0', + id: 17, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: { query: 'x' } }, + }) as never + ); + + expect(response.status).toBe(200); + const payload = (await readSseData(response)) as { + error: { code: number; message: string }; + }; + expect(payload.error.message).toContain('rate limited'); + }); + + it('does not record cost when upstream returns an error', async () => { + setUserAuth(); + mockedFetch.mockResolvedValue( + new Response(JSON.stringify({ error: 'bad request', costDollars: { total: 0.001 } }), { + status: 400, + headers: { 'content-type': 'application/json' }, + }) + ); + + const { POST } = await import('./route'); + await POST( + makeRequest({ + jsonrpc: '2.0', + id: 18, + method: 'tools/call', + params: { name: 'web_search_exa', arguments: { query: 'x' } }, + }) as never + ); + await flushAfterCallbacks(); + + expect(mockedRecordExaUsage).not.toHaveBeenCalled(); + }); + }); + + describe('response headers', () => { + it('sets Content-Encoding: identity on SSE responses', async () => { + setUserAuth(); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ jsonrpc: '2.0', id: 19, method: 'initialize' }) as never + ); + + expect(response.headers.get('content-encoding')).toBe('identity'); + expect(response.headers.get('content-type')).toContain('text/event-stream'); + }); + }); + + describe('malformed body', () => { + it('returns 400 for invalid JSON', async () => { + setUserAuth(); + + const request = new Request('http://localhost:3000/api/exa/mcp', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: 'not-json', + }); + + const { POST } = await import('./route'); + const response = await POST(request as never); + + expect(response.status).toBe(400); + const body = (await response.json()) as { error: string }; + expect(body.error).toContain('Invalid JSON'); + }); + + it.each([ + ['null', 'null'], + ['a primitive', '42'], + ['an array', '[1,2,3]'], + ])('returns 400 when body is %s (not a JSON-RPC object)', async (_desc, rawBody) => { + setUserAuth(); + + const request = new Request('http://localhost:3000/api/exa/mcp', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: rawBody, + }); + + const { POST } = await import('./route'); + const response = await POST(request as never); + + expect(response.status).toBe(400); + const body = (await response.json()) as { error: string }; + expect(body.error).toContain('JSON-RPC object'); + expect(mockedFetch).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/apps/web/src/app/api/exa/mcp/route.ts b/apps/web/src/app/api/exa/mcp/route.ts new file mode 100644 index 0000000000..322265231c --- /dev/null +++ b/apps/web/src/app/api/exa/mcp/route.ts @@ -0,0 +1,309 @@ +import { NextResponse } from 'next/server'; +import { type NextRequest } from 'next/server'; +import { after } from 'next/server'; +import { getUserFromAuth } from '@/lib/user.server'; +import { EXA_API_KEY } from '@/lib/config.server'; +import { + getExaMonthlyUsage, + getExaFreeAllowanceMicrodollars, + recordExaUsage, +} from '@/lib/exa-usage'; +import { getBalanceAndOrgSettings } from '@/lib/organizations/organization-usage'; +import { readDb } from '@/lib/drizzle'; +import { captureException } from '@sentry/nextjs'; + +const EXA_BASE_URL = 'https://api.exa.ai'; +const EXA_SEARCH_PATH = '/search'; +const MCP_PROTOCOL_VERSION = '2024-11-05'; + +const WEB_SEARCH_TOOL = { + name: 'web_search_exa', + description: + 'Search the web using Exa AI - performs real-time web searches and can scrape content from specific URLs. Provides up-to-date information for current events and recent data. Returns the content from the most relevant websites.', + inputSchema: { + type: 'object', + properties: { + query: { type: 'string', description: 'Search query' }, + type: { + type: 'string', + enum: ['auto', 'fast', 'deep'], + description: 'Search type - auto (balanced), fast (quick), deep (comprehensive)', + }, + numResults: { type: 'number', description: 'Number of results (default 8)' }, + livecrawl: { + type: 'string', + enum: ['fallback', 'preferred'], + description: + 'Live crawl mode - fallback uses live crawling as backup, preferred prioritizes live crawling', + }, + contextMaxCharacters: { + type: 'number', + description: 'Maximum characters of content returned per result', + }, + }, + required: ['query'], + }, +}; + +const CODE_CONTEXT_TOOL = { + name: 'get_code_context_exa', + description: + 'Search and get relevant context for any programming task using Exa Code API. Provides code examples, documentation and API references optimized for finding specific programming patterns and solutions.', + inputSchema: { + type: 'object', + properties: { + query: { type: 'string', description: 'Programming query' }, + tokensNum: { + type: 'number', + description: 'Number of tokens to return (default 5000)', + }, + }, + required: ['query'], + }, +}; + +const TOOLS = [WEB_SEARCH_TOOL, CODE_CONTEXT_TOOL]; + +type JsonRpcId = number | string | null; + +type JsonRpcRequest = { + jsonrpc?: string; + id?: JsonRpcId; + method?: string; + params?: unknown; +}; + +// JSON-RPC 2.0 error codes +// -32601: Method not found +// -32602: Invalid params +// -32603: Internal error +// -32000 to -32099: Server-defined errors +const ERROR_METHOD_NOT_FOUND = -32601; +const ERROR_INVALID_PARAMS = -32602; +const ERROR_INTERNAL = -32603; +const ERROR_SERVER = -32000; + +function sseResponse(payload: unknown, status = 200): Response { + const body = `event: message\ndata: ${JSON.stringify(payload)}\n\n`; + return new Response(body, { + status, + headers: { + 'content-type': 'text/event-stream; charset=utf-8', + 'cache-control': 'no-cache, no-transform', + 'content-encoding': 'identity', + }, + }); +} + +function jsonRpcResult(id: JsonRpcId | undefined, result: unknown) { + return { jsonrpc: '2.0', id: id ?? null, result }; +} + +function jsonRpcError(id: JsonRpcId | undefined, code: number, message: string) { + return { jsonrpc: '2.0', id: id ?? null, error: { code, message } }; +} + +type BuildSearchBodyResult = + | { ok: true; body: Record } + | { ok: false; message: string }; + +function buildSearchBody(toolName: string, args: Record): BuildSearchBodyResult { + if (toolName === 'web_search_exa') { + const query = args.query; + if (typeof query !== 'string' || query.length === 0) { + return { ok: false, message: "missing or invalid 'query' argument" }; + } + const contextMaxCharacters = + typeof args.contextMaxCharacters === 'number' ? args.contextMaxCharacters : undefined; + const body: Record = { + query, + numResults: typeof args.numResults === 'number' ? args.numResults : 8, + type: typeof args.type === 'string' ? args.type : 'auto', + livecrawl: typeof args.livecrawl === 'string' ? args.livecrawl : 'fallback', + contents: { + text: contextMaxCharacters !== undefined ? { maxCharacters: contextMaxCharacters } : true, + highlights: true, + }, + }; + return { ok: true, body }; + } + if (toolName === 'get_code_context_exa') { + const query = args.query; + if (typeof query !== 'string' || query.length === 0) { + return { ok: false, message: "missing or invalid 'query' argument" }; + } + const tokensNum = typeof args.tokensNum === 'number' ? args.tokensNum : 5000; + const body: Record = { + query, + numResults: 5, + type: 'auto', + contents: { text: { maxCharacters: tokensNum * 4 } }, + }; + return { ok: true, body }; + } + return { ok: false, message: `Unknown tool: ${toolName}` }; +} + +function extractCostDollars(responseBody: unknown): number | undefined { + const body = responseBody as { costDollars?: { total?: number } } | null; + return body?.costDollars?.total; +} + +export async function POST(request: NextRequest) { + const { user, authFailedResponse, organizationId } = await getUserFromAuth({ + adminOnly: false, + }); + if (authFailedResponse) return authFailedResponse; + + let parsed: unknown; + try { + parsed = await request.json(); + } catch { + return NextResponse.json({ error: 'Invalid JSON body' }, { status: 400 }); + } + if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) { + return NextResponse.json({ error: 'Request body must be a JSON-RPC object' }, { status: 400 }); + } + const rpcRequest = parsed as JsonRpcRequest; + + const { id, method, params } = rpcRequest; + const isNotification = id === undefined || id === null; + + // Notifications (no id) get 202 Accepted with no body per MCP streamable HTTP spec. + if (isNotification) { + return new Response(null, { status: 202 }); + } + + if (method === 'initialize') { + return sseResponse( + jsonRpcResult(id, { + protocolVersion: MCP_PROTOCOL_VERSION, + capabilities: { tools: {} }, + serverInfo: { name: 'kilo-exa-mcp', version: '1.0.0' }, + }) + ); + } + + if (method === 'tools/list') { + return sseResponse(jsonRpcResult(id, { tools: TOOLS })); + } + + if (method !== 'tools/call') { + return sseResponse( + jsonRpcError(id, ERROR_METHOD_NOT_FOUND, `Method not found: ${method ?? 'unknown'}`) + ); + } + + const callParams = (params ?? {}) as { name?: string; arguments?: Record }; + const toolName = callParams.name; + const args = callParams.arguments ?? {}; + if (typeof toolName !== 'string') { + return sseResponse(jsonRpcError(id, ERROR_INVALID_PARAMS, "Missing 'name' parameter")); + } + + const built = buildSearchBody(toolName, args); + if (!built.ok) { + return sseResponse(jsonRpcError(id, ERROR_INVALID_PARAMS, built.message)); + } + + if (!EXA_API_KEY) { + captureException(new Error('EXA_API_KEY is not configured')); + return sseResponse(jsonRpcError(id, ERROR_INTERNAL, 'Internal Server Error')); + } + + const { usage: monthlyUsage, freeAllowance: storedAllowance } = await getExaMonthlyUsage( + user.id, + readDb + ); + const allowance = storedAllowance ?? getExaFreeAllowanceMicrodollars(new Date(), user); + const isPaidRequest = monthlyUsage >= allowance; + + if (isPaidRequest) { + const { balance } = await getBalanceAndOrgSettings(organizationId, user, readDb); + if (balance <= 0) { + return sseResponse( + jsonRpcError( + id, + ERROR_SERVER, + 'Exa free allowance exhausted and no credit balance available' + ) + ); + } + } + + const upstream = await fetch(`${EXA_BASE_URL}${EXA_SEARCH_PATH}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-api-key': EXA_API_KEY, + }, + body: JSON.stringify(built.body), + signal: request.signal, + }); + + if (upstream.status >= 400) { + // Log only fields that are known to be safe (numeric status, validated IDs). + // toolName is user-controlled, so report it via Sentry tags instead of console. + console.error(`[exa-mcp] upstream error: status=${upstream.status}`); + captureException(new Error(`Exa upstream error (status ${upstream.status})`), { + tags: { route: '/api/exa/mcp', tool: toolName }, + extra: { userId: user.id }, + }); + let message = `Exa upstream error (status ${upstream.status})`; + try { + const errBody = (await upstream.clone().json()) as { + error?: string; + message?: string; + } | null; + if (errBody?.error) message = errBody.error; + else if (errBody?.message) message = errBody.message; + } catch { + // Ignore non-JSON error bodies; keep the default message. + } + return sseResponse(jsonRpcError(id, ERROR_SERVER, message)); + } + + let upstreamBody: unknown; + try { + upstreamBody = await upstream.json(); + } catch (error) { + captureException(error, { + tags: { route: '/api/exa/mcp', tool: toolName }, + extra: { userId: user.id }, + }); + return sseResponse(jsonRpcError(id, ERROR_INTERNAL, 'Invalid upstream response')); + } + + const costDollars = extractCostDollars(upstreamBody); + + after(async () => { + if (costDollars === undefined || costDollars <= 0) return; + try { + const costMicrodollars = Math.round(costDollars * 1_000_000); + await recordExaUsage({ + userId: user.id, + organizationId, + path: EXA_SEARCH_PATH, + costMicrodollars, + chargedToBalance: isPaidRequest, + freeAllowanceMicrodollars: allowance, + }); + } catch (error) { + captureException(error, { + tags: { route: '/api/exa/mcp', tool: toolName }, + extra: { userId: user.id }, + }); + } + }); + + return sseResponse( + jsonRpcResult(id, { + content: [ + { + type: 'text', + text: JSON.stringify(upstreamBody), + }, + ], + }) + ); +}