Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 180 additions & 0 deletions packages/tools/src/openai/middleware.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/**
* Unit tests for createOpenAIMiddleware — no network calls required.
*
* These tests verify the double-wrapping protection added to guard against
* cross-tenant memory leakage when the same OpenAI client is passed to
* `createOpenAIMiddleware` (or `withSupermemory`) multiple times, which is
* the common pattern in multi-user server applications.
*/

import { describe, it, expect, vi, beforeEach } from "vitest"
import type OpenAI from "openai"

// The Supermemory SDK client constructor throws when SUPERMEMORY_API_KEY is
// absent. Provide a fake key here — the actual value is irrelevant because
// all network calls are intercepted by a `vi.stubGlobal("fetch", ...)` mock.
process.env.SUPERMEMORY_API_KEY = "sm_test_fake_key_for_unit_tests"

// We import the internal implementation directly to avoid the SUPERMEMORY_API_KEY
// check in the public `withSupermemory` wrapper.
import { createOpenAIMiddleware } from "./middleware"

// ---------------------------------------------------------------------------
// Minimal OpenAI client stub
// ---------------------------------------------------------------------------

/** Number of times the real (innermost) `create` was called. */
let realCreateCallCount = 0

/**
* Build a minimal mock OpenAI client whose `chat.completions.create` is a
* spy tracking how many times the underlying SDK method is invoked.
*/
function makeOpenAIStub(): OpenAI {
realCreateCallCount = 0

const realCreate = vi.fn(async () => {
realCreateCallCount++
return { choices: [{ message: { role: "assistant", content: "ok" } }] }
})

const stub = {
chat: {
completions: {
create: realCreate,
},
},
responses: undefined,
} as unknown as OpenAI

return stub
}

// ---------------------------------------------------------------------------
// Helper: patch fetch so the profile endpoint returns an empty result
// (prevents actual network calls during the test).
// ---------------------------------------------------------------------------

function mockFetchEmptyProfile() {
vi.stubGlobal(
"fetch",
vi.fn(async () => ({
ok: true,
json: async () => ({
profile: { static: [], dynamic: [] },
searchResults: { results: [] },
}),
text: async () => "",
})),
)
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

describe("createOpenAIMiddleware — double-wrapping protection", () => {
beforeEach(() => {
mockFetchEmptyProfile()
vi.unstubAllGlobals()
mockFetchEmptyProfile()
})

it("calling createOpenAIMiddleware twice on the same client only invokes the real SDK create once per completion", async () => {
const openai = makeOpenAIStub()

// First call — simulates request from user-A.
createOpenAIMiddleware(openai, "user-A", {
containerTag: "user-A",
customId: "conv-1",
})

// Second call — simulates request from user-B on the same shared client.
createOpenAIMiddleware(openai, "user-B", {
containerTag: "user-B",
customId: "conv-2",
})

// Trigger a completion through the (now-wrapped) client.
await openai.chat.completions.create({
model: "gpt-4o-mini",
messages: [{ role: "user", content: "hello" }],
} as Parameters<typeof openai.chat.completions.create>[0])

// Without the fix the real create would be called twice (once per wrapper layer).
expect(realCreateCallCount).toBe(1)
})

it("calling createOpenAIMiddleware three times still only invokes the real SDK create once", async () => {
const openai = makeOpenAIStub()

createOpenAIMiddleware(openai, "user-A", {
containerTag: "user-A",
customId: "conv-1",
})
createOpenAIMiddleware(openai, "user-B", {
containerTag: "user-B",
customId: "conv-2",
})
createOpenAIMiddleware(openai, "user-C", {
containerTag: "user-C",
customId: "conv-3",
})

await openai.chat.completions.create({
model: "gpt-4o-mini",
messages: [{ role: "user", content: "hello" }],
} as Parameters<typeof openai.chat.completions.create>[0])

expect(realCreateCallCount).toBe(1)
})

it("the wrapper installed by the second call replaces the first (latest wins)", async () => {
const openai = makeOpenAIStub()
const capturedContainerTags: string[] = []

// Spy on fetch to capture which containerTag was used for memory search.
vi.stubGlobal(
"fetch",
vi.fn(async (url: string, init?: RequestInit) => {
if (typeof init?.body === "string") {
try {
const body = JSON.parse(init.body) as { containerTag?: string }
if (body.containerTag) capturedContainerTags.push(body.containerTag)
} catch {
// ignore
}
}
return {
ok: true,
json: async () => ({
profile: { static: [], dynamic: [] },
searchResults: { results: [] },
}),
text: async () => "",
}
}),
)

createOpenAIMiddleware(openai, "user-A", {
containerTag: "user-A",
customId: "conv-1",
addMemory: "never", // avoid write path for this assertion
})
createOpenAIMiddleware(openai, "user-B", {
containerTag: "user-B",
customId: "conv-2",
addMemory: "never",
})

await openai.chat.completions.create({
model: "gpt-4o-mini",
messages: [{ role: "user", content: "hello" }],
} as Parameters<typeof openai.chat.completions.create>[0])

// Only user-B's containerTag should appear — user-A's wrapper was replaced,
// not stacked on top.
expect(capturedContainerTags).not.toContain("user-A")
expect(capturedContainerTags).toContain("user-B")
})
})
48 changes: 46 additions & 2 deletions packages/tools/src/openai/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ const normalizeBaseUrl = (url?: string): string => {
return url.endsWith("/") ? url.slice(0, -1) : url
}

/**
* Symbol used to stash the real (un-wrapped) SDK method on the wrapper function
* so that subsequent calls to `createOpenAIMiddleware` on the same client always
* capture the original SDK implementation rather than a previously-installed
* wrapper. Without this guard, calling the function twice on the same shared
* `OpenAI` instance would double-wrap the method: every completion would
* inject memories twice and could surface memories from a different user's
* container tag (cross-tenant leakage in multi-request server usage).
*/
const ORIGINAL_CREATE_SYM = Symbol("supermemory.originalCreate")
const ORIGINAL_RESPONSES_CREATE_SYM = Symbol(
"supermemory.originalResponsesCreate",
)

export interface OpenAIMiddlewareOptions {
/** Container tag/identifier for memory search (e.g., user ID, project ID). Required. */
containerTag: string
Expand Down Expand Up @@ -430,8 +444,24 @@ export function createOpenAIMiddleware(
const mode = options?.mode ?? "profile"
const addMemory = options?.addMemory ?? "always"

const originalCreate = openaiClient.chat.completions.create
const originalResponsesCreate = openaiClient.responses?.create
// Unwrap any previously-installed Supermemory wrapper so we always call the
// real SDK method. This makes the function safe to call multiple times on
// the same shared `openai` instance (e.g. per-request in a server) without
// double-wrapping.
const currentCreate = openaiClient.chat.completions.create as typeof openaiClient.chat.completions.create & {
[ORIGINAL_CREATE_SYM]?: typeof openaiClient.chat.completions.create
}
const originalCreate =
currentCreate[ORIGINAL_CREATE_SYM] ?? currentCreate

const currentResponsesCreate = openaiClient.responses?.create as
| (typeof openaiClient.responses.create & {
[ORIGINAL_RESPONSES_CREATE_SYM]?: typeof openaiClient.responses.create
})
| undefined
const originalResponsesCreate =
currentResponsesCreate?.[ORIGINAL_RESPONSES_CREATE_SYM] ??
currentResponsesCreate

/**
* Searches for memories and formats them for injection into API calls.
Expand Down Expand Up @@ -635,11 +665,25 @@ export function createOpenAIMiddleware(
})
}

// Stamp the original SDK method on the wrapper so future calls to
// `createOpenAIMiddleware` on the same client can unwrap it (see above).
;(
createWithMemory as typeof originalCreate & {
[ORIGINAL_CREATE_SYM]?: typeof originalCreate
}
)[ORIGINAL_CREATE_SYM] = originalCreate

openaiClient.chat.completions.create =
createWithMemory as typeof originalCreate

// Wrap Responses API if available
if (originalResponsesCreate) {
;(
createResponsesWithMemory as typeof originalResponsesCreate & {
[ORIGINAL_RESPONSES_CREATE_SYM]?: typeof originalResponsesCreate
}
)[ORIGINAL_RESPONSES_CREATE_SYM] = originalResponsesCreate

openaiClient.responses.create =
createResponsesWithMemory as typeof originalResponsesCreate
}
Expand Down