diff --git a/packages/tools/src/openai/middleware.test.ts b/packages/tools/src/openai/middleware.test.ts new file mode 100644 index 000000000..d4490a827 --- /dev/null +++ b/packages/tools/src/openai/middleware.test.ts @@ -0,0 +1,180 @@ +/** + * Unit tests for createOpenAIMiddleware — no network calls required. + * + * These tests verify the double-wrapping protection added to guard against + * cross-tenant memory leakage when the same OpenAI client is passed to + * `createOpenAIMiddleware` (or `withSupermemory`) multiple times, which is + * the common pattern in multi-user server applications. + */ + +import { describe, it, expect, vi, beforeEach } from "vitest" +import type OpenAI from "openai" + +// The Supermemory SDK client constructor throws when SUPERMEMORY_API_KEY is +// absent. Provide a fake key here — the actual value is irrelevant because +// all network calls are intercepted by a `vi.stubGlobal("fetch", ...)` mock. +process.env.SUPERMEMORY_API_KEY = "sm_test_fake_key_for_unit_tests" + +// We import the internal implementation directly to avoid the SUPERMEMORY_API_KEY +// check in the public `withSupermemory` wrapper. +import { createOpenAIMiddleware } from "./middleware" + +// --------------------------------------------------------------------------- +// Minimal OpenAI client stub +// --------------------------------------------------------------------------- + +/** Number of times the real (innermost) `create` was called. */ +let realCreateCallCount = 0 + +/** + * Build a minimal mock OpenAI client whose `chat.completions.create` is a + * spy tracking how many times the underlying SDK method is invoked. + */ +function makeOpenAIStub(): OpenAI { + realCreateCallCount = 0 + + const realCreate = vi.fn(async () => { + realCreateCallCount++ + return { choices: [{ message: { role: "assistant", content: "ok" } }] } + }) + + const stub = { + chat: { + completions: { + create: realCreate, + }, + }, + responses: undefined, + } as unknown as OpenAI + + return stub +} + +// --------------------------------------------------------------------------- +// Helper: patch fetch so the profile endpoint returns an empty result +// (prevents actual network calls during the test). +// --------------------------------------------------------------------------- + +function mockFetchEmptyProfile() { + vi.stubGlobal( + "fetch", + vi.fn(async () => ({ + ok: true, + json: async () => ({ + profile: { static: [], dynamic: [] }, + searchResults: { results: [] }, + }), + text: async () => "", + })), + ) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("createOpenAIMiddleware — double-wrapping protection", () => { + beforeEach(() => { + mockFetchEmptyProfile() + vi.unstubAllGlobals() + mockFetchEmptyProfile() + }) + + it("calling createOpenAIMiddleware twice on the same client only invokes the real SDK create once per completion", async () => { + const openai = makeOpenAIStub() + + // First call — simulates request from user-A. + createOpenAIMiddleware(openai, "user-A", { + containerTag: "user-A", + customId: "conv-1", + }) + + // Second call — simulates request from user-B on the same shared client. + createOpenAIMiddleware(openai, "user-B", { + containerTag: "user-B", + customId: "conv-2", + }) + + // Trigger a completion through the (now-wrapped) client. + await openai.chat.completions.create({ + model: "gpt-4o-mini", + messages: [{ role: "user", content: "hello" }], + } as Parameters[0]) + + // Without the fix the real create would be called twice (once per wrapper layer). + expect(realCreateCallCount).toBe(1) + }) + + it("calling createOpenAIMiddleware three times still only invokes the real SDK create once", async () => { + const openai = makeOpenAIStub() + + createOpenAIMiddleware(openai, "user-A", { + containerTag: "user-A", + customId: "conv-1", + }) + createOpenAIMiddleware(openai, "user-B", { + containerTag: "user-B", + customId: "conv-2", + }) + createOpenAIMiddleware(openai, "user-C", { + containerTag: "user-C", + customId: "conv-3", + }) + + await openai.chat.completions.create({ + model: "gpt-4o-mini", + messages: [{ role: "user", content: "hello" }], + } as Parameters[0]) + + expect(realCreateCallCount).toBe(1) + }) + + it("the wrapper installed by the second call replaces the first (latest wins)", async () => { + const openai = makeOpenAIStub() + const capturedContainerTags: string[] = [] + + // Spy on fetch to capture which containerTag was used for memory search. + vi.stubGlobal( + "fetch", + vi.fn(async (url: string, init?: RequestInit) => { + if (typeof init?.body === "string") { + try { + const body = JSON.parse(init.body) as { containerTag?: string } + if (body.containerTag) capturedContainerTags.push(body.containerTag) + } catch { + // ignore + } + } + return { + ok: true, + json: async () => ({ + profile: { static: [], dynamic: [] }, + searchResults: { results: [] }, + }), + text: async () => "", + } + }), + ) + + createOpenAIMiddleware(openai, "user-A", { + containerTag: "user-A", + customId: "conv-1", + addMemory: "never", // avoid write path for this assertion + }) + createOpenAIMiddleware(openai, "user-B", { + containerTag: "user-B", + customId: "conv-2", + addMemory: "never", + }) + + await openai.chat.completions.create({ + model: "gpt-4o-mini", + messages: [{ role: "user", content: "hello" }], + } as Parameters[0]) + + // Only user-B's containerTag should appear — user-A's wrapper was replaced, + // not stacked on top. + expect(capturedContainerTags).not.toContain("user-A") + expect(capturedContainerTags).toContain("user-B") + }) +}) diff --git a/packages/tools/src/openai/middleware.ts b/packages/tools/src/openai/middleware.ts index 7fc4267e7..b89c2e43a 100644 --- a/packages/tools/src/openai/middleware.ts +++ b/packages/tools/src/openai/middleware.ts @@ -11,6 +11,20 @@ const normalizeBaseUrl = (url?: string): string => { return url.endsWith("/") ? url.slice(0, -1) : url } +/** + * Symbol used to stash the real (un-wrapped) SDK method on the wrapper function + * so that subsequent calls to `createOpenAIMiddleware` on the same client always + * capture the original SDK implementation rather than a previously-installed + * wrapper. Without this guard, calling the function twice on the same shared + * `OpenAI` instance would double-wrap the method: every completion would + * inject memories twice and could surface memories from a different user's + * container tag (cross-tenant leakage in multi-request server usage). + */ +const ORIGINAL_CREATE_SYM = Symbol("supermemory.originalCreate") +const ORIGINAL_RESPONSES_CREATE_SYM = Symbol( + "supermemory.originalResponsesCreate", +) + export interface OpenAIMiddlewareOptions { /** Container tag/identifier for memory search (e.g., user ID, project ID). Required. */ containerTag: string @@ -430,8 +444,24 @@ export function createOpenAIMiddleware( const mode = options?.mode ?? "profile" const addMemory = options?.addMemory ?? "always" - const originalCreate = openaiClient.chat.completions.create - const originalResponsesCreate = openaiClient.responses?.create + // Unwrap any previously-installed Supermemory wrapper so we always call the + // real SDK method. This makes the function safe to call multiple times on + // the same shared `openai` instance (e.g. per-request in a server) without + // double-wrapping. + const currentCreate = openaiClient.chat.completions.create as typeof openaiClient.chat.completions.create & { + [ORIGINAL_CREATE_SYM]?: typeof openaiClient.chat.completions.create + } + const originalCreate = + currentCreate[ORIGINAL_CREATE_SYM] ?? currentCreate + + const currentResponsesCreate = openaiClient.responses?.create as + | (typeof openaiClient.responses.create & { + [ORIGINAL_RESPONSES_CREATE_SYM]?: typeof openaiClient.responses.create + }) + | undefined + const originalResponsesCreate = + currentResponsesCreate?.[ORIGINAL_RESPONSES_CREATE_SYM] ?? + currentResponsesCreate /** * Searches for memories and formats them for injection into API calls. @@ -635,11 +665,25 @@ export function createOpenAIMiddleware( }) } + // Stamp the original SDK method on the wrapper so future calls to + // `createOpenAIMiddleware` on the same client can unwrap it (see above). + ;( + createWithMemory as typeof originalCreate & { + [ORIGINAL_CREATE_SYM]?: typeof originalCreate + } + )[ORIGINAL_CREATE_SYM] = originalCreate + openaiClient.chat.completions.create = createWithMemory as typeof originalCreate // Wrap Responses API if available if (originalResponsesCreate) { + ;( + createResponsesWithMemory as typeof originalResponsesCreate & { + [ORIGINAL_RESPONSES_CREATE_SYM]?: typeof originalResponsesCreate + } + )[ORIGINAL_RESPONSES_CREATE_SYM] = originalResponsesCreate + openaiClient.responses.create = createResponsesWithMemory as typeof originalResponsesCreate }