diff --git a/src/app/v1/_lib/proxy/response-handler.ts b/src/app/v1/_lib/proxy/response-handler.ts index cad152770..b9e2edee7 100644 --- a/src/app/v1/_lib/proxy/response-handler.ts +++ b/src/app/v1/_lib/proxy/response-handler.ts @@ -57,6 +57,240 @@ import { } from "./stream-finalization"; const CLIENT_ABORT_DRAIN_MAX_MS = 60_000; +const STREAM_STATS_MAX_BUFFER_BYTES = 10 * 1024 * 1024; +const STREAM_STATS_HEAD_BYTES = 1024 * 1024; +const STREAM_STATS_TAIL_BYTES = STREAM_STATS_MAX_BUFFER_BYTES - STREAM_STATS_HEAD_BYTES; +const STREAM_STATS_TAIL_CHUNKS = 8192; +const STREAM_STATS_TRUNCATED_MARKER = "\n\n: [cch_truncated]\n\n"; + +type BoundedStreamTextSnapshot = { + text: string; + truncated: boolean; + totalBytes: number; + bufferedBytes: number; + chunkCount: number; +}; + +function copyUint8Range(value: Uint8Array, start = 0, end = value.byteLength): Uint8Array { + return new Uint8Array(value.subarray(start, end)); +} + +function resolveNonStreamTaskStaleTimeoutMs(provider: Provider): number { + return provider.requestTimeoutNonStreamingMs > 0 + ? provider.requestTimeoutNonStreamingMs + : Number.POSITIVE_INFINITY; +} + +function resolveStreamTaskStaleTimeoutMs(provider: Provider): number { + if (provider.streamingIdleTimeoutMs <= 0) { + return Number.POSITIVE_INFINITY; + } + + if (provider.firstByteTimeoutStreamingMs > 0) { + return Math.max(provider.firstByteTimeoutStreamingMs, provider.streamingIdleTimeoutMs); + } + + return Number.POSITIVE_INFINITY; +} + +// 流式统计只需要头部元信息和尾部 usage/final event。按字节保存窗口,避免 +// string[] 无界增长,也避免 subarray 持有超大原始 ArrayBuffer。 +export class BoundedStreamTextAccumulator { + private readonly headChunks: Uint8Array[] = []; + private readonly tailChunks: Uint8Array[] = []; + private readonly tailChunkBytes: number[] = []; + private headBufferedBytes = 0; + private tailBufferedBytes = 0; + private tailHead = 0; + private tailMode = false; + private truncated = false; + private totalBytes = 0; + private chunksSeen = 0; + private finishedSnapshot: BoundedStreamTextSnapshot | null = null; + + get chunkCount(): number { + return this.chunksSeen; + } + + get totalByteCount(): number { + return this.totalBytes; + } + + get bufferedByteCount(): number { + return this.headBufferedBytes + this.tailBufferedBytes; + } + + get isTruncated(): boolean { + return this.truncated; + } + + pushBytes(value: Uint8Array): void { + if (!value || value.byteLength === 0) { + return; + } + + this.finishedSnapshot = null; + this.chunksSeen += 1; + this.totalBytes += value.byteLength; + + if (!this.tailMode && this.headBufferedBytes < STREAM_STATS_HEAD_BYTES) { + const remainingHeadBytes = STREAM_STATS_HEAD_BYTES - this.headBufferedBytes; + if (value.byteLength <= remainingHeadBytes) { + this.headChunks.push(copyUint8Range(value)); + this.headBufferedBytes += value.byteLength; + return; + } + + this.headChunks.push(copyUint8Range(value, 0, remainingHeadBytes)); + this.headBufferedBytes += remainingHeadBytes; + this.tailMode = true; + this.pushTailBytes(value.subarray(remainingHeadBytes)); + return; + } + + this.tailMode = true; + this.pushTailBytes(value); + } + + finish(): BoundedStreamTextSnapshot { + if (this.finishedSnapshot) { + return this.finishedSnapshot; + } + + const text = this.createSnapshotText(); + + this.finishedSnapshot = { + text, + truncated: this.truncated, + totalBytes: this.totalBytes, + bufferedBytes: this.headBufferedBytes + this.tailBufferedBytes, + chunkCount: this.chunksSeen, + }; + + return this.finishedSnapshot; + } + + private createSnapshotText(): string { + if (!this.tailMode) { + return this.decodeChunks(this.headChunks, 0, this.headBufferedBytes); + } + + if (!this.truncated) { + return this.decodeContiguousBufferedBytes(); + } + + const headText = this.decodeChunks(this.headChunks, 0, this.headBufferedBytes); + const tailText = this.decodeChunks(this.tailChunks, this.tailHead, this.tailBufferedBytes); + return `${headText}${STREAM_STATS_TRUNCATED_MARKER}${tailText}`; + } + + private pushTailBytes(value: Uint8Array): void { + if (!value || value.byteLength === 0) { + return; + } + + if (value.byteLength > STREAM_STATS_TAIL_BYTES) { + this.tailChunks.length = 0; + this.tailChunkBytes.length = 0; + this.tailHead = 0; + const tail = copyUint8Range(value, value.byteLength - STREAM_STATS_TAIL_BYTES); + this.tailChunks.push(tail); + this.tailChunkBytes.push(tail.byteLength); + this.tailBufferedBytes = tail.byteLength; + this.truncated = true; + return; + } + + const copy = copyUint8Range(value); + this.tailChunks.push(copy); + this.tailChunkBytes.push(copy.byteLength); + this.tailBufferedBytes += copy.byteLength; + + while ( + this.tailBufferedBytes > STREAM_STATS_TAIL_BYTES && + this.tailHead < this.tailChunkBytes.length + ) { + const overflowBytes = this.tailBufferedBytes - STREAM_STATS_TAIL_BYTES; + const oldestChunkBytes = this.tailChunkBytes[this.tailHead] ?? 0; + + if (oldestChunkBytes <= 0) { + this.tailHead += 1; + continue; + } + + if (overflowBytes >= oldestChunkBytes) { + this.tailBufferedBytes -= oldestChunkBytes; + this.tailChunks[this.tailHead] = new Uint8Array(); + this.tailChunkBytes[this.tailHead] = 0; + this.tailHead += 1; + this.truncated = true; + continue; + } + + const oldestChunk = this.tailChunks[this.tailHead]!; + this.tailChunks[this.tailHead] = copyUint8Range(oldestChunk, overflowBytes); + this.tailChunkBytes[this.tailHead] = oldestChunkBytes - overflowBytes; + this.tailBufferedBytes -= overflowBytes; + this.truncated = true; + } + + if (this.tailHead > 4096) { + this.tailChunks.splice(0, this.tailHead); + this.tailChunkBytes.splice(0, this.tailHead); + this.tailHead = 0; + } + + const keptCount = this.tailChunks.length - this.tailHead; + if (keptCount > STREAM_STATS_TAIL_CHUNKS) { + const joined = this.concatChunks(this.tailChunks, this.tailHead, this.tailBufferedBytes); + this.tailChunks.length = 0; + this.tailChunkBytes.length = 0; + this.tailHead = 0; + this.tailChunks.push(joined); + this.tailChunkBytes.push(joined.byteLength); + this.tailBufferedBytes = joined.byteLength; + } + } + + private decodeChunks(chunks: Uint8Array[], startIndex: number, totalBytes: number): string { + if (totalBytes <= 0) { + return ""; + } + return new TextDecoder().decode(this.concatChunks(chunks, startIndex, totalBytes)); + } + + private decodeContiguousBufferedBytes(): string { + const totalBytes = this.headBufferedBytes + this.tailBufferedBytes; + if (totalBytes <= 0) { + return ""; + } + + const headBytes = this.concatChunks(this.headChunks, 0, this.headBufferedBytes); + const tailBytes = this.concatChunks(this.tailChunks, this.tailHead, this.tailBufferedBytes); + const out = new Uint8Array(headBytes.byteLength + tailBytes.byteLength); + out.set(headBytes, 0); + out.set(tailBytes, headBytes.byteLength); + return new TextDecoder().decode(out); + } + + private concatChunks(chunks: Uint8Array[], startIndex: number, totalBytes: number): Uint8Array { + if (totalBytes <= 0) { + return new Uint8Array(); + } + + const out = new Uint8Array(totalBytes); + let offset = 0; + for (let i = startIndex; i < chunks.length; i++) { + const chunk = chunks[i]; + if (!chunk || chunk.byteLength === 0) { + continue; + } + out.set(chunk, offset); + offset += chunk.byteLength; + } + return offset === totalBytes ? out : out.slice(0, offset); + } +} /** * Idempotent helper to release the agent pool reference count attached to a session. @@ -74,6 +308,82 @@ function releaseSessionAgent(session: ProxySession): void { } } +function bindTaskAbortToUpstreamResponse( + session: ProxySession, + abortController: AbortController, + taskId: string +): () => void { + const abortUpstream = () => { + const sessionWithController = session as typeof session & { + responseController?: AbortController; + }; + const upstreamController = sessionWithController.responseController; + if (!upstreamController || upstreamController.signal.aborted) { + return; + } + + const reason = + abortController.signal.reason instanceof Error + ? abortController.signal.reason + : new Error("async_task_aborted"); + try { + upstreamController.abort(reason); + } catch (error) { + logger.warn("[ResponseHandler] Failed to abort upstream response for async task", { + taskId, + error, + }); + } + }; + + abortController.signal.addEventListener("abort", abortUpstream, { once: true }); + if (abortController.signal.aborted) { + abortUpstream(); + } + + return () => { + abortController.signal.removeEventListener("abort", abortUpstream); + }; +} + +async function readResponseTextWithTaskActivity( + response: Response, + taskId: string +): Promise { + if (!response.body) { + AsyncTaskManager.touch(taskId); + return response.text(); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + const chunks: string[] = []; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + if (!value || value.byteLength === 0) { + continue; + } + + AsyncTaskManager.touch(taskId); + chunks.push(decoder.decode(value, { stream: true })); + } + + const finalText = decoder.decode(); + if (finalText) { + chunks.push(finalText); + } + AsyncTaskManager.touch(taskId); + return chunks.join(""); + } finally { + reader.releaseLock(); + } +} + function takeBeforeResponseBodySnapshotSource(session: ProxySession): Response | null { const snapshotSession = session as ProxySession & { detailSnapshotResponseBeforeSource?: Response | null; @@ -1049,17 +1359,17 @@ export class ProxyResponseHandler { const statusCode = response.status; let finalResponse = response; - const persistNonStreamAfterSnapshot = async (targetResponse: Response) => { + let finalResponseBodyForSnapshot: string | null = null; + const persistNonStreamAfterSnapshot = (targetResponse: Response, body: string) => { if (!session.sessionId || !session.shouldPersistSessionDebugArtifacts()) { return; } - const finalBody = await targetResponse.clone().text(); const responseAfterSnapshotTask = SessionManager.storeSessionResponsePhaseSnapshot?.( session.sessionId, "after", { - body: finalBody, + body, headers: targetResponse.headers, meta: { upstreamUrl: null, @@ -1096,9 +1406,15 @@ export class ProxyResponseHandler { const statusCode = response.status; const taskId = `non-stream-passthrough-${messageContext?.id || `unknown-${Date.now()}`}`; + const statsAbortController = new AbortController(); + const cleanupTaskAbortBinding = bindTaskAbortToUpstreamResponse( + session, + statsAbortController, + taskId + ); const statsPromise = (async () => { try { - const responseText = await responseForStats.text(); + const responseText = await readResponseTextWithTaskActivity(responseForStats, taskId); const sessionWithCleanup = session as typeof session & { clearResponseTimeout?: () => void; @@ -1203,12 +1519,16 @@ export class ProxyResponseHandler { ); } } finally { + cleanupTaskAbortBinding(); releaseSessionAgent(session); - AsyncTaskManager.cleanup(taskId); } })(); - AsyncTaskManager.register(taskId, statsPromise, "non-stream-passthrough-stats"); + AsyncTaskManager.register(taskId, statsPromise, { + taskType: "non-stream-passthrough-stats", + abortController: statsAbortController, + staleTimeoutMs: resolveNonStreamTaskStaleTimeoutMs(provider), + }); statsPromise.catch((error) => { if (session.sessionId && session.shouldPersistSessionDebugArtifacts()) { void discardBeforeResponseBodySnapshot(session); @@ -1246,6 +1566,7 @@ export class ProxyResponseHandler { const responseData = JSON.parse(responseText) as GeminiResponse; const transformed = GeminiAdapter.transformResponse(responseData, false); + const transformedBody = JSON.stringify(transformed); logger.debug( "[ResponseHandler] Transformed Gemini non-stream response to client format", @@ -1257,7 +1578,8 @@ export class ProxyResponseHandler { ); // ⭐ 清理传输 headers(body 已从流转为 JSON 字符串) - finalResponse = new Response(JSON.stringify(transformed), { + finalResponseBodyForSnapshot = transformedBody; + finalResponse = new Response(transformedBody, { status: response.status, statusText: response.statusText, headers: cleanResponseHeaders(response.headers), @@ -1265,6 +1587,7 @@ export class ProxyResponseHandler { } catch (error) { logger.error("[ResponseHandler] Failed to transform Gemini non-stream response:", error); finalResponse = response; + finalResponseBodyForSnapshot = null; } } } @@ -1272,6 +1595,11 @@ export class ProxyResponseHandler { // 使用 AsyncTaskManager 管理后台处理任务 const taskId = `non-stream-${messageContext?.id || `unknown-${Date.now()}`}`; const abortController = new AbortController(); + const cleanupTaskAbortBinding = bindTaskAbortToUpstreamResponse( + session, + abortController, + taskId + ); const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => { AsyncTaskManager.cancel(taskId); abortController.abort(); @@ -1334,7 +1662,7 @@ export class ProxyResponseHandler { } // ⭐ 非流式:读取完整响应体(会等待所有数据下载完成) - const responseText = await responseForLog.text(); + const responseText = await readResponseTextWithTaskActivity(responseForLog, taskId); // ⭐ 响应体读取完成:清除响应超时定时器 const sessionWithCleanup = session as typeof session & { @@ -1420,6 +1748,13 @@ export class ProxyResponseHandler { responseBeforeSnapshotTask?.catch((err) => { logger.error("[ResponseHandler] Failed to store response before snapshot:", err); }); + + // after 快照复用本任务已经读取到的响应文本,避免再启动一个未受 + // AsyncTaskManager 管理的 clone().text() 读取分支。 + persistNonStreamAfterSnapshot( + finalResponse, + finalResponseBodyForSnapshot ?? responseText + ); } if (billableUsageMetrics && messageContext) { @@ -1706,14 +2041,18 @@ export class ProxyResponseHandler { }); } } finally { + cleanupTaskAbortBinding(); cleanupClientAbortListener(); releaseSessionAgent(session); - AsyncTaskManager.cleanup(taskId); } })(); // 注册任务并添加全局错误捕获 - AsyncTaskManager.register(taskId, processingPromise, "non-stream-processing"); + AsyncTaskManager.register(taskId, processingPromise, { + taskType: "non-stream-processing", + abortController, + staleTimeoutMs: resolveNonStreamTaskStaleTimeoutMs(provider), + }); processingPromise.catch(async (error) => { logger.error("ResponseHandler: Uncaught error in non-stream processing", { taskId, @@ -1731,9 +2070,6 @@ export class ProxyResponseHandler { }); }); - void persistNonStreamAfterSnapshot(finalResponse).catch((error) => { - logger.error("[ResponseHandler] Failed to persist non-stream after snapshot", { error }); - }); return finalResponse; } @@ -1778,6 +2114,13 @@ export class ProxyResponseHandler { const statusCode = response.status; const taskId = `stream-passthrough-${messageContext.id}`; + const streamTaskStaleTimeoutMs = resolveStreamTaskStaleTimeoutMs(provider); + const statsAbortController = new AbortController(); + const cleanupTaskAbortBinding = bindTaskAbortToUpstreamResponse( + session, + statsAbortController, + taskId + ); const statsPromise = (async () => { const sessionWithCleanup = session as typeof session & { clearResponseTimeout?: () => void; @@ -1787,106 +2130,10 @@ export class ProxyResponseHandler { }; let reader: ReadableStreamDefaultReader | null = null; - // 保护:避免透传 stats 任务把超大响应体无界缓存在内存中(DoS/OOM 风险) - // 说明:用于统计/结算的内容采用“头部 + 尾部窗口”: - // - 头部保留前 MAX_STATS_HEAD_BYTES(便于解析可能前置的 metadata) - // - 尾部保留最近 MAX_STATS_TAIL_BYTES(便于解析结尾 usage/假 200 等) - // - 中间部分会被丢弃(wasTruncated=true),统计将退化为 best-effort - const MAX_STATS_BUFFER_BYTES = 10 * 1024 * 1024; // 10MB - const MAX_STATS_HEAD_BYTES = 1024 * 1024; // 1MB - const MAX_STATS_TAIL_BYTES = MAX_STATS_BUFFER_BYTES - MAX_STATS_HEAD_BYTES; - const MAX_STATS_TAIL_CHUNKS = 8192; - - const headChunks: string[] = []; - let headBufferedBytes = 0; - - const tailChunks: string[] = []; - const tailChunkBytes: number[] = []; - let tailHead = 0; - let tailBufferedBytes = 0; - let wasTruncated = false; - let inTailMode = false; - - const joinTailChunks = (): string => { - if (tailHead <= 0) return tailChunks.join(""); - return tailChunks.slice(tailHead).join(""); - }; - - const joinChunks = (): string => { - const headText = headChunks.join(""); - if (!inTailMode) { - return headText; - } - - const tailText = joinTailChunks(); - - // 用 SSE comment 标记被截断的中间段;parseSSEData 会忽略 ":" 开头的行 - if (wasTruncated) { - // 插入空行强制 flush event,避免“头+尾”拼接后跨 event 误拼接数据行 - return `${headText}\n\n: [cch_truncated]\n\n${tailText}`; - } - - return `${headText}${tailText}`; - }; - - const pushChunk = (text: string, bytes: number) => { - if (!text) return; - - const pushToTail = (tailText: string, tailBytes: number) => { - if (!tailText) return; - - tailChunks.push(tailText); - tailChunkBytes.push(tailBytes); - tailBufferedBytes += tailBytes; - - // 仅保留尾部窗口,避免内存无界增长 - while (tailBufferedBytes > MAX_STATS_TAIL_BYTES && tailHead < tailChunkBytes.length) { - tailBufferedBytes -= tailChunkBytes[tailHead] ?? 0; - tailChunks[tailHead] = ""; - tailChunkBytes[tailHead] = 0; - tailHead += 1; - wasTruncated = true; - } - - // 定期压缩数组,避免 head 指针过大导致 slice/join 性能退化 - if (tailHead > 4096) { - tailChunks.splice(0, tailHead); - tailChunkBytes.splice(0, tailHead); - tailHead = 0; - } - - // 防御:限制 chunk 数量,避免大量超小 chunk 导致对象/数组膨胀(即使总字节数已受限) - const keptCount = tailChunks.length - tailHead; - if (keptCount > MAX_STATS_TAIL_CHUNKS) { - const joined = joinTailChunks(); - tailChunks.length = 0; - tailChunkBytes.length = 0; - tailHead = 0; - tailChunks.push(joined); - tailChunkBytes.push(tailBufferedBytes); - } - }; - - // 优先填充 head;超过 head 上限后切到 tail(但不代表一定发生截断,只有 tail 溢出才算截断) - if (!inTailMode && headBufferedBytes < MAX_STATS_HEAD_BYTES) { - const remainingHeadBytes = MAX_STATS_HEAD_BYTES - headBufferedBytes; - if (remainingHeadBytes > 0 && bytes > remainingHeadBytes) { - const headPart = text.substring(0, remainingHeadBytes); - const tailPart = text.substring(remainingHeadBytes); - - pushChunk(headPart, remainingHeadBytes); - - inTailMode = true; - pushToTail(tailPart, bytes - remainingHeadBytes); - } else { - headChunks.push(text); - headBufferedBytes += bytes; - } - } else { - pushToTail(text, bytes); - } - }; - const decoder = new TextDecoder(); + const streamTextAccumulator = new BoundedStreamTextAccumulator(); + let lastStreamTextSnapshot: BoundedStreamTextSnapshot | null = null; + const getCollectedChunkCount = () => + lastStreamTextSnapshot?.chunkCount ?? streamTextAccumulator.chunkCount; let isFirstChunk = true; let streamEndedNormally = false; let responseTimeoutCleared = false; @@ -1894,7 +2141,9 @@ export class ProxyResponseHandler { // 静默期 Watchdog:透传也需要支持中途卡住(无新数据推送) const idleTimeoutMs = - provider.streamingIdleTimeoutMs > 0 ? provider.streamingIdleTimeoutMs : Infinity; + provider.streamingIdleTimeoutMs > 0 + ? provider.streamingIdleTimeoutMs + : Number.POSITIVE_INFINITY; let idleTimeoutId: NodeJS.Timeout | null = null; const clearIdleTimer = () => { if (idleTimeoutId) { @@ -1912,11 +2161,10 @@ export class ProxyResponseHandler { providerId: provider.id, providerName: provider.name, idleTimeoutMs, - chunksCollected: headChunks.length + Math.max(0, tailChunks.length - tailHead), - headBufferedBytes, - tailBufferedBytes, - bufferedBytes: headBufferedBytes + tailBufferedBytes, - wasTruncated, + chunksCollected: getCollectedChunkCount(), + totalBytes: streamTextAccumulator.totalByteCount, + bufferedBytes: streamTextAccumulator.bufferedByteCount, + wasTruncated: streamTextAccumulator.isTruncated, }); // 终止上游连接:让透传到客户端的连接也尽快结束,避免永久悬挂占用资源 try { @@ -1945,10 +2193,14 @@ export class ProxyResponseHandler { } }; + const flushAndSnapshot = (): BoundedStreamTextSnapshot => { + const snapshot = streamTextAccumulator.finish(); + lastStreamTextSnapshot = snapshot; + return snapshot; + }; + const flushAndJoin = (): string => { - const flushed = decoder.decode(); - if (flushed) pushChunk(flushed, 0); - return joinChunks(); + return flushAndSnapshot().text; }; try { @@ -1989,25 +2241,8 @@ export class ProxyResponseHandler { clearResponseTimeoutOnce(chunkSize); } - // 尽量填满 head:边界 chunk 可能跨过 head 上限,按 byte 切分以避免 head 少于 1MB - if (!inTailMode && headBufferedBytes < MAX_STATS_HEAD_BYTES) { - const remainingHeadBytes = MAX_STATS_HEAD_BYTES - headBufferedBytes; - if (remainingHeadBytes > 0 && chunkSize > remainingHeadBytes) { - const headPart = value.subarray(0, remainingHeadBytes); - const tailPart = value.subarray(remainingHeadBytes); - - const headText = decoder.decode(headPart, { stream: true }); - pushChunk(headText, remainingHeadBytes); - - inTailMode = true; - const tailText = decoder.decode(tailPart, { stream: true }); - pushChunk(tailText, chunkSize - remainingHeadBytes); - } else { - pushChunk(decoder.decode(value, { stream: true }), chunkSize); - } - } else { - pushChunk(decoder.decode(value, { stream: true }), chunkSize); - } + streamTextAccumulator.pushBytes(value); + AsyncTaskManager.touch(taskId); } // 首块数据到达后才启动 idle timer(避免与首字节超时职责重叠) @@ -2017,13 +2252,14 @@ export class ProxyResponseHandler { } clearIdleTimer(); - const allContent = flushAndJoin(); + const streamSnapshot = flushAndSnapshot(); + const allContent = streamSnapshot.text; const clientAborted = session.clientAbortSignal?.aborted ?? false; // 存储响应体到 Redis(5分钟过期) if ( session.sessionId && - !wasTruncated && + !streamSnapshot?.truncated && session.shouldPersistSessionDebugArtifacts() ) { void SessionManager.storeSessionResponse( @@ -2053,12 +2289,14 @@ export class ProxyResponseHandler { responseAfterSnapshotTask?.catch((err) => { logger.error("[ResponseHandler] Failed to store response after snapshot:", err); }); - } else if (session.sessionId && wasTruncated) { + } else if (session.sessionId && streamSnapshot?.truncated) { logger.warn("[ResponseHandler] Skip storing passthrough response: body too large", { taskId, providerId: provider.id, providerName: provider.name, - maxBytes: MAX_STATS_BUFFER_BYTES, + maxBytes: STREAM_STATS_MAX_BUFFER_BYTES, + totalBytes: streamSnapshot.totalBytes, + bufferedBytes: streamSnapshot.bufferedBytes, }); } @@ -2157,6 +2395,7 @@ export class ProxyResponseHandler { }); } } finally { + cleanupTaskAbortBinding(); clearIdleTimer(); // 兜底:在流结束/中断后清理首字节超时,避免定时器泄漏 // 注意:不应在流仍可能继续时清理(否则会让首字节超时失效) @@ -2219,11 +2458,14 @@ export class ProxyResponseHandler { }); } releaseSessionAgent(session); - AsyncTaskManager.cleanup(taskId); } })(); - AsyncTaskManager.register(taskId, statsPromise, "stream-passthrough-stats"); + AsyncTaskManager.register(taskId, statsPromise, { + taskType: "stream-passthrough-stats", + abortController: statsAbortController, + staleTimeoutMs: streamTaskStaleTimeoutMs, + }); statsPromise.catch((error) => { if (session.sessionId && session.shouldPersistSessionDebugArtifacts()) { void discardBeforeResponseBodySnapshot(session); @@ -2322,14 +2564,25 @@ export class ProxyResponseHandler { // 使用 AsyncTaskManager 管理后台处理任务 const taskId = `stream-${messageContext?.id || `unknown-${Date.now()}`}`; const abortController = new AbortController(); + const cleanupTaskAbortBinding = bindTaskAbortToUpstreamResponse( + session, + abortController, + taskId + ); const idleTimeoutMs = - provider.streamingIdleTimeoutMs > 0 ? provider.streamingIdleTimeoutMs : Infinity; + provider.streamingIdleTimeoutMs > 0 + ? provider.streamingIdleTimeoutMs + : Number.POSITIVE_INFINITY; + const streamTaskStaleTimeoutMs = resolveStreamTaskStaleTimeoutMs(provider); const clientAbortDrainTimeoutMs = CLIENT_ABORT_DRAIN_MAX_MS; // 提升 idleTimeoutId 到外部作用域,以便客户端断开时能清除 let idleTimeoutId: NodeJS.Timeout | null = null; let clientAbortDrainTimeoutId: NodeJS.Timeout | null = null; - const chunks: string[] = []; + const streamTextAccumulator = new BoundedStreamTextAccumulator(); + let lastStreamTextSnapshot: BoundedStreamTextSnapshot | null = null; + const getCollectedChunkCount = () => + lastStreamTextSnapshot?.chunkCount ?? streamTextAccumulator.chunkCount; const clearClientAbortDrainTimer = () => { if (clientAbortDrainTimeoutId) { clearTimeout(clientAbortDrainTimeoutId); @@ -2350,7 +2603,7 @@ export class ProxyResponseHandler { taskId, providerId: provider.id, idleTimeoutMs, - chunksCollected: chunks.length, + chunksCollected: getCollectedChunkCount(), }); // 1. 关闭客户端流(让客户端收到连接关闭通知,避免悬挂) @@ -2435,10 +2688,7 @@ export class ProxyResponseHandler { const processingPromise = (async () => { const reader = internalStream.getReader(); - const decoder = new TextDecoder(); - // 注意:即使 STORE_SESSION_RESPONSE_BODY=false(不写入 Redis),这里也会在内存中累积完整流内容: - // - 用于解析 usage/cost 与内部结算(例如“假 200”检测) - // 因此该开关仅影响“是否持久化”,不用于控制流式内存占用。 + // 统计/结算只保留有界的“头 + 尾”文本快照,避免长流式响应把进程堆撑满。 let usageForCost: UsageMetrics | null = null; let isFirstChunk = true; // 标记是否为第一块数据 @@ -2448,11 +2698,9 @@ export class ProxyResponseHandler { // 静默一直等到 60s drain 总上限。 const flushAndJoin = (): string => { - const flushed = decoder.decode(); - if (flushed) { - chunks.push(flushed); - } - return chunks.join(""); + const snapshot = streamTextAccumulator.finish(); + lastStreamTextSnapshot = snapshot; + return snapshot.text; }; const finalizeStream = async ( @@ -2473,8 +2721,14 @@ export class ProxyResponseHandler { const streamErrorMessage = finalized.errorMessage; const providerIdForPersistence = finalized.providerIdForPersistence; - // 存储响应体到 Redis(5分钟过期) - if (session.sessionId && session.shouldPersistSessionDebugArtifacts()) { + const streamSnapshot = lastStreamTextSnapshot; + + // 存储响应体到 Redis(5分钟过期)。截断后的统计快照不是完整正文,不能伪装成完整调试正文落盘。 + if ( + session.sessionId && + session.shouldPersistSessionDebugArtifacts() && + !streamSnapshot?.truncated + ) { const beforeBody = (await consumeBeforeResponseBodySnapshot(session)) ?? allContent; void SessionManager.storeSessionResponse( session.sessionId, @@ -2503,6 +2757,16 @@ export class ProxyResponseHandler { responseBeforeSnapshotTask?.catch((err) => { logger.error("[ResponseHandler] Failed to store response before snapshot:", err); }); + } else if (session.sessionId && streamSnapshot?.truncated) { + discardBeforeResponseBodySnapshot(session); + logger.warn("[ResponseHandler] Skip storing stream response: body too large", { + taskId, + providerId: provider.id, + providerName: provider.name, + maxBytes: STREAM_STATS_MAX_BUFFER_BYTES, + totalBytes: streamSnapshot.totalBytes, + bufferedBytes: streamSnapshot.bufferedBytes, + }); } const duration = Date.now() - session.startTime; @@ -2746,7 +3010,7 @@ export class ProxyResponseHandler { statusCode: effectiveStatusCode, durationMs: duration, isStreaming: true, - sseEventCount: chunks.length, + sseEventCount: getCollectedChunkCount(), errorMessage: streamErrorMessage ?? undefined, }); }; @@ -2760,7 +3024,7 @@ export class ProxyResponseHandler { taskId, providerId: provider.id, providerName: provider.name, - chunksCollected: chunks.length, + chunksCollected: getCollectedChunkCount(), }); break; // 提前终止 } @@ -2772,14 +3036,15 @@ export class ProxyResponseHandler { } if (value) { const chunkSize = value.length; - chunks.push(decoder.decode(value, { stream: true })); + streamTextAccumulator.pushBytes(value); + AsyncTaskManager.touch(taskId); // 每次收到数据后重置静默期计时器(首次收到数据时启动) startIdleTimer(); logger.trace("ResponseHandler: Idle timer reset (data received)", { taskId, providerId: provider.id, - chunksCollected: chunks.length, + chunksCollected: getCollectedChunkCount(), lastChunkSize: chunkSize, idleTimeoutMs: idleTimeoutMs === Infinity ? "disabled" : idleTimeoutMs, }); @@ -2852,7 +3117,7 @@ export class ProxyResponseHandler { providerId: provider.id, providerName: provider.name, messageId: messageContext.id, - chunksCollected: chunks.length, + chunksCollected: getCollectedChunkCount(), errorName: err.name, }); @@ -2887,7 +3152,7 @@ export class ProxyResponseHandler { providerId: provider.id, providerName: provider.name, messageId: messageContext.id, - chunksCollected: chunks.length, + chunksCollected: getCollectedChunkCount(), }); // 注意:无法重试,因为客户端已收到 HTTP 200 @@ -2926,7 +3191,7 @@ export class ProxyResponseHandler { providerId: provider.id, providerName: provider.name, messageId: messageContext.id, - chunksCollected: chunks.length, + chunksCollected: getCollectedChunkCount(), errorName: err.name, errorMessage: err.message || "(empty message)", }); @@ -2959,7 +3224,7 @@ export class ProxyResponseHandler { providerId: provider.id, providerName: provider.name, messageId: messageContext.id, - chunksCollected: chunks.length, + chunksCollected: getCollectedChunkCount(), errorName: err.name, reason: err.name === "ResponseAborted" @@ -2985,7 +3250,7 @@ export class ProxyResponseHandler { providerId: provider.id, providerName: provider.name, messageId: messageContext.id, - chunksCollected: chunks.length, + chunksCollected: getCollectedChunkCount(), errorName: err.name, errorMessage: err.message || "(empty message)", errorCode: (err as NodeJS.ErrnoException).code, @@ -3037,6 +3302,7 @@ export class ProxyResponseHandler { } } finally { // 确保资源释放 + cleanupTaskAbortBinding(); cleanupClientAbortListener(); clearClientAbortDrainTimer(); clearIdleTimer(); // 清除静默期计时器(防止泄漏) @@ -3049,12 +3315,15 @@ export class ProxyResponseHandler { }); } releaseSessionAgent(session); - AsyncTaskManager.cleanup(taskId); } })(); // 注册任务并添加全局错误捕获 - AsyncTaskManager.register(taskId, processingPromise, "stream-processing"); + AsyncTaskManager.register(taskId, processingPromise, { + taskType: "stream-processing", + abortController, + staleTimeoutMs: streamTaskStaleTimeoutMs, + }); processingPromise.catch(async (error) => { logger.error("ResponseHandler: Uncaught error in stream processing", { taskId, diff --git a/src/lib/async-task-manager.ts b/src/lib/async-task-manager.ts index 7fa7184d2..912cf4140 100644 --- a/src/lib/async-task-manager.ts +++ b/src/lib/async-task-manager.ts @@ -20,9 +20,19 @@ interface TaskInfo { promise: Promise; abortController: AbortController; createdAt: number; + lastActivityAt: number; taskType: string; + staleTimeoutMs: number; } +interface RegisterTaskOptions { + taskType?: string; + abortController?: AbortController; + staleTimeoutMs?: number; +} + +const DEFAULT_STALE_TASK_TIMEOUT_MS = 10 * 60 * 1000; + class AsyncTaskManagerClass { private tasks: Map = new Map(); private cleanupInterval: NodeJS.Timeout | null = null; @@ -60,7 +70,7 @@ class AsyncTaskManagerClass { this.cleanupAll(); }); - // 每分钟检查并清理超时任务(>10 分钟未完成,防止内存泄漏) + // 每分钟检查并清理空闲超时任务,防止挂死后台任务长期强引用上下文。 this.cleanupInterval = setInterval(() => { this.cleanupCompletedTasks(); }, 60000); @@ -74,25 +84,42 @@ class AsyncTaskManagerClass { * @param taskType 任务类型(用于日志) * @returns AbortController(可用于取消任务) */ - register(taskId: string, promise: Promise, taskType = "unknown"): AbortController { + register( + taskId: string, + promise: Promise, + taskTypeOrOptions: string | RegisterTaskOptions = "unknown" + ): AbortController { this.initializeIfNeeded(); + const options = + typeof taskTypeOrOptions === "string" ? { taskType: taskTypeOrOptions } : taskTypeOrOptions; + const taskType = options.taskType ?? "unknown"; + // 如果任务已存在,先取消旧任务 - if (this.tasks.has(taskId)) { + const oldTaskInfo = this.tasks.get(taskId); + if (oldTaskInfo) { logger.warn("[AsyncTaskManager] Task already exists, cancelling old task", { taskId, taskType, }); this.cancel(taskId); + this.cleanup(taskId, oldTaskInfo); } - const abortController = new AbortController(); + const abortController = options.abortController ?? new AbortController(); + const staleTimeoutMs = + options.staleTimeoutMs === undefined || options.staleTimeoutMs <= 0 + ? DEFAULT_STALE_TASK_TIMEOUT_MS + : options.staleTimeoutMs; + const now = Date.now(); const taskInfo: TaskInfo = { promise, abortController, - createdAt: Date.now(), + createdAt: now, + lastActivityAt: now, taskType, + staleTimeoutMs, }; this.tasks.set(taskId, taskInfo); @@ -126,7 +153,7 @@ class AsyncTaskManagerClass { } }) .finally(() => { - this.cleanup(taskId); + this.cleanup(taskId, taskInfo); }); logger.debug("[AsyncTaskManager] Task registered", { @@ -138,6 +165,20 @@ class AsyncTaskManagerClass { return abortController; } + /** + * 标记任务仍在推进。流式任务每次读到 chunk 都应 touch,避免长时间活跃流被 + * wall-clock stale cleanup 误判为挂死任务。 + */ + touch(taskId: string): boolean { + const taskInfo = this.tasks.get(taskId); + if (!taskInfo) { + return false; + } + + taskInfo.lastActivityAt = Date.now(); + return true; + } + /** * 取消一个任务 * @@ -150,7 +191,9 @@ class AsyncTaskManagerClass { return; } - taskInfo.abortController.abort(); + if (!taskInfo.abortController.signal.aborted) { + taskInfo.abortController.abort(); + } logger.info("[AsyncTaskManager] Task cancelled", { taskId, @@ -160,11 +203,15 @@ class AsyncTaskManagerClass { } /** - * 清理单个任务 + * 清理单个任务。必须带上注册时的任务实例,避免旧任务 finally 误删同 taskId 的新任务。 * * @param taskId 任务唯一标识 */ - cleanup(taskId: string): void { + private cleanup(taskId: string, expectedTask: TaskInfo): boolean { + if (this.tasks.get(taskId) !== expectedTask) { + return false; + } + const deleted = this.tasks.delete(taskId); if (deleted) { logger.debug("[AsyncTaskManager] Task cleaned up", { @@ -172,33 +219,40 @@ class AsyncTaskManagerClass { remainingTasks: this.tasks.size, }); } + return deleted; } /** * 检查并清理超时任务 * - * 遍历所有活跃任务,对于超过 10 分钟还未完成的任务: + * 遍历所有活跃任务,对于空闲时间超过任务级 staleTimeoutMs 的任务: * 1. 记录警告日志 * 2. 触发 AbortController 取消任务 * 3. 从任务 Map 中移除 * - * ⚠️ 注意:这不是清理"已完成"的任务,而是清理"超时未完成"的任务 + * 注意:这是清理"空闲超时"的任务。活跃流应在收到上游 chunk 时 + * 调用 touch() 更新 lastActivityAt,避免被误判为挂死任务。 */ private cleanupCompletedTasks(): void { const now = Date.now(); - const staleThreshold = 10 * 60 * 1000; // 10 分钟 for (const [taskId, taskInfo] of this.tasks.entries()) { const age = now - taskInfo.createdAt; + const idleAge = now - taskInfo.lastActivityAt; + + const staleTimeoutMs = taskInfo.staleTimeoutMs || DEFAULT_STALE_TASK_TIMEOUT_MS; - // 如果任务超过 10 分钟还没完成,记录警告并取消 - if (age > staleThreshold) { - logger.warn("[AsyncTaskManager] Task timeout, cancelling", { + // 如果任务超过阈值没有任何进展,记录警告、取消并从 Map 断开强引用。 + if (idleAge > staleTimeoutMs) { + logger.warn("[AsyncTaskManager] Task timeout, cancelling and detaching", { taskId, taskType: taskInfo.taskType, age, + idleAge, + staleTimeoutMs, }); this.cancel(taskId); + this.cleanup(taskId, taskInfo); } } } @@ -211,8 +265,9 @@ class AsyncTaskManagerClass { count: this.tasks.size, }); - for (const taskId of this.tasks.keys()) { + for (const [taskId, taskInfo] of Array.from(this.tasks.entries())) { this.cancel(taskId); + this.cleanup(taskId, taskInfo); } if (this.cleanupInterval) { diff --git a/src/lib/langfuse/emit-proxy-trace.ts b/src/lib/langfuse/emit-proxy-trace.ts index 64e4ffe78..1d143c117 100644 --- a/src/lib/langfuse/emit-proxy-trace.ts +++ b/src/lib/langfuse/emit-proxy-trace.ts @@ -3,6 +3,10 @@ import type { ProxySession } from "@/app/v1/_lib/proxy/session"; import { logger } from "@/lib/logger"; import type { CostBreakdown } from "@/lib/utils/cost-calculation"; +const LANGFUSE_RESPONSE_TEXT_MAX_CHARS = 1024 * 1024; +const LANGFUSE_RESPONSE_TEXT_EDGE_CHARS = 128 * 1024; +const LANGFUSE_TRUNCATED_MARKER = "\n\n[langfuse_response_truncated]\n\n"; + export interface EmitProxyLangfuseTraceData { responseHeaders: Headers; responseText: string; @@ -16,6 +20,81 @@ export interface EmitProxyLangfuseTraceData { errorMessage?: string; } +function truncateResponseTextForLangfuse(text: string): string { + if (text.length <= LANGFUSE_RESPONSE_TEXT_MAX_CHARS) { + return text; + } + + return `${text.slice(0, LANGFUSE_RESPONSE_TEXT_EDGE_CHARS)}${LANGFUSE_TRUNCATED_MARKER}${text.slice( + -LANGFUSE_RESPONSE_TEXT_EDGE_CHARS + )}`; +} + +function buildRequestMessagePreview(message: Record): Record { + return { + truncatedForLangfuse: true, + model: typeof message.model === "string" ? message.model : undefined, + stream: typeof message.stream === "boolean" ? message.stream : undefined, + max_tokens: typeof message.max_tokens === "number" ? message.max_tokens : undefined, + temperature: typeof message.temperature === "number" ? message.temperature : undefined, + messageCount: Array.isArray(message.messages) ? message.messages.length : undefined, + contentsCount: Array.isArray(message.contents) ? message.contents.length : undefined, + toolsCount: Array.isArray(message.tools) ? message.tools.length : undefined, + hasSystemPrompt: + (Array.isArray(message.system) && message.system.length > 0) || + (typeof message.system === "string" && message.system.length > 0), + }; +} + +function buildLangfuseSessionSnapshot(session: ProxySession): ProxySession { + const providerChain = session.getProviderChain().map((item) => ({ ...item })); + const specialSettings = session.getSpecialSettings(); + const cacheTtlResolved = session.getCacheTtlResolved(); + const context1mApplied = session.getContext1mApplied(); + const currentModel = session.getCurrentModel(); + const originalModel = session.getOriginalModel(); + const modelRedirected = session.isModelRedirected(); + const endpoint = session.getEndpoint(); + const requestSequence = session.getRequestSequence(); + const messagesLength = session.getMessagesLength(); + const forwardedRequestBody = + typeof session.forwardedRequestBody === "string" + ? truncateResponseTextForLangfuse(session.forwardedRequestBody) + : null; + const requestMessage = buildRequestMessagePreview(session.request.message); + + return { + startTime: session.startTime, + method: session.method, + headers: new Headers(session.headers), + request: { + message: requestMessage, + log: truncateResponseTextForLangfuse(session.request.log ?? ""), + note: session.request.note, + model: session.request.model, + imageRequestMetadata: null, + }, + userAgent: session.userAgent, + provider: session.provider, + messageContext: session.messageContext, + ttfbMs: session.ttfbMs, + forwardStartTime: session.forwardStartTime, + forwardedRequestBody, + sessionId: session.sessionId, + originalFormat: session.originalFormat, + getMessagesLength: () => messagesLength, + getEndpoint: () => endpoint, + getCurrentModel: () => currentModel, + getProviderChain: () => providerChain, + getRequestSequence: () => requestSequence, + getOriginalModel: () => originalModel, + isModelRedirected: () => modelRedirected, + getSpecialSettings: () => specialSettings, + getCacheTtlResolved: () => cacheTtlResolved, + getContext1mApplied: () => context1mApplied, + } as unknown as ProxySession; +} + /** * 异步发送代理请求的 Langfuse trace。 * @@ -27,20 +106,35 @@ export function emitProxyLangfuseTrace( ): void { if (!process.env.LANGFUSE_PUBLIC_KEY || !process.env.LANGFUSE_SECRET_KEY) return; + // 必须在异步 import 之前截断,避免动态加载/SDK 发送期间闭包继续强引用完整大响应。 + const responseText = truncateResponseTextForLangfuse(data.responseText); + const sessionSnapshot = buildLangfuseSessionSnapshot(session); + const { + responseHeaders, + durationMs, + statusCode, + isStreaming, + usageMetrics, + costUsd, + costBreakdown, + sseEventCount, + errorMessage, + } = data; + void import("@/lib/langfuse/trace-proxy-request") .then(({ traceProxyRequest }) => { void traceProxyRequest({ - session, - responseHeaders: data.responseHeaders, - durationMs: data.durationMs, - statusCode: data.statusCode, - isStreaming: data.isStreaming, - responseText: data.responseText, - usageMetrics: data.usageMetrics, - costUsd: data.costUsd, - costBreakdown: data.costBreakdown, - sseEventCount: data.sseEventCount, - errorMessage: data.errorMessage, + session: sessionSnapshot, + responseHeaders, + durationMs, + statusCode, + isStreaming, + responseText, + usageMetrics, + costUsd, + costBreakdown, + sseEventCount, + errorMessage, }); }) .catch((err) => { diff --git a/src/lib/langfuse/trace-proxy-request.ts b/src/lib/langfuse/trace-proxy-request.ts index 09a23f7fd..abbe359ea 100644 --- a/src/lib/langfuse/trace-proxy-request.ts +++ b/src/lib/langfuse/trace-proxy-request.ts @@ -4,13 +4,27 @@ import { isLangfuseEnabled } from "@/lib/langfuse/index"; import { logger } from "@/lib/logger"; import type { CostBreakdown } from "@/lib/utils/cost-calculation"; +const LANGFUSE_JSON_PARSE_MAX_CHARS = 1024 * 1024; +const LANGFUSE_TEXT_PREVIEW_EDGE_CHARS = 128 * 1024; + function buildRequestBodySummary(session: ProxySession): Record { const msg = session.request.message as Record; + const hasSystemPrompt = + typeof msg.hasSystemPrompt === "boolean" + ? msg.hasSystemPrompt + : Array.isArray(msg.system) && msg.system.length > 0; + const toolsCount = + typeof msg.toolsCount === "number" + ? msg.toolsCount + : Array.isArray(msg.tools) + ? msg.tools.length + : 0; + return { model: session.request.model, messageCount: session.getMessagesLength(), - hasSystemPrompt: Array.isArray(msg.system) && msg.system.length > 0, - toolsCount: Array.isArray(msg.tools) ? msg.tools.length : 0, + hasSystemPrompt, + toolsCount, stream: msg.stream === true, maxTokens: typeof msg.max_tokens === "number" ? msg.max_tokens : undefined, temperature: typeof msg.temperature === "number" ? msg.temperature : undefined, @@ -126,6 +140,15 @@ function buildResponseOutput(ctx: TraceContext): unknown { return output; } +function buildLargeTextPreview(text: string): Record { + return { + truncated: true, + totalChars: text.length, + head: text.slice(0, LANGFUSE_TEXT_PREVIEW_EDGE_CHARS), + tail: text.slice(-LANGFUSE_TEXT_PREVIEW_EDGE_CHARS), + }; +} + /** * Send a trace to Langfuse for a completed proxy request. * Fully async and non-blocking. Errors are caught and logged. @@ -422,6 +445,10 @@ export async function traceProxyRequest(ctx: TraceContext): Promise { } function tryParseJsonSafe(text: string): unknown { + if (text.length > LANGFUSE_JSON_PARSE_MAX_CHARS) { + return buildLargeTextPreview(text); + } + try { return JSON.parse(text); } catch { diff --git a/tests/unit/langfuse/langfuse-trace.test.ts b/tests/unit/langfuse/langfuse-trace.test.ts index 3c3c105f8..91bd63899 100644 --- a/tests/unit/langfuse/langfuse-trace.test.ts +++ b/tests/unit/langfuse/langfuse-trace.test.ts @@ -350,6 +350,48 @@ describe("traceProxyRequest", () => { ); }); + test("should preserve request summary fields from lightweight Langfuse previews", async () => { + const { traceProxyRequest } = await import("@/lib/langfuse/trace-proxy-request"); + + await traceProxyRequest({ + session: createMockSession({ + request: { + message: { + truncatedForLangfuse: true, + model: "claude-sonnet-4-20250514", + stream: true, + max_tokens: 1024, + temperature: 0.7, + messageCount: 3, + toolsCount: 2, + hasSystemPrompt: true, + }, + model: "claude-sonnet-4-20250514", + }, + getMessagesLength: () => 3, + }), + responseHeaders: new Headers(), + durationMs: 500, + statusCode: 200, + isStreaming: true, + }); + + const llmCall = mockRootSpan.startObservation.mock.calls.find( + (c: unknown[]) => c[0] === "llm-call" + ); + expect(llmCall[1].metadata.requestSummary).toEqual( + expect.objectContaining({ + model: "claude-sonnet-4-20250514", + messageCount: 3, + hasSystemPrompt: true, + toolsCount: 2, + stream: true, + maxTokens: 1024, + temperature: 0.7, + }) + ); + }); + test("should handle model redirect metadata", async () => { const { traceProxyRequest } = await import("@/lib/langfuse/trace-proxy-request"); diff --git a/tests/unit/lib/async-task-manager-edge-runtime.test.ts b/tests/unit/lib/async-task-manager-edge-runtime.test.ts index 4ee32cf1e..dbc21f36f 100644 --- a/tests/unit/lib/async-task-manager-edge-runtime.test.ts +++ b/tests/unit/lib/async-task-manager-edge-runtime.test.ts @@ -215,6 +215,37 @@ describe.sequential("AsyncTaskManager edge runtime", () => { await Promise.all([firstPromise, secondPromise]); }); + it("does not let an old task finalizer remove a newer task with the same taskId", async () => { + process.env.CI = "true"; + process.env.NEXT_RUNTIME = "nodejs"; + + const { AsyncTaskManager } = await import("@/lib/async-task-manager"); + + let resolveFirst: () => void; + const firstPromise = new Promise((resolve) => { + resolveFirst = resolve; + }); + AsyncTaskManager.register("t1", firstPromise); + + let resolveSecond: () => void; + const secondPromise = new Promise((resolve) => { + resolveSecond = resolve; + }); + AsyncTaskManager.register("t1", secondPromise); + + resolveFirst!(); + await firstPromise; + await new Promise((resolve) => queueMicrotask(() => resolve())); + + expect(AsyncTaskManager.getActiveTaskCount()).toBe(1); + + resolveSecond!(); + await secondPromise; + await new Promise((resolve) => queueMicrotask(() => resolve())); + + expect(AsyncTaskManager.getActiveTaskCount()).toBe(0); + }); + it("logs task cancelled when isClientAbortError returns true", async () => { process.env.CI = "true"; process.env.NEXT_RUNTIME = "nodejs"; @@ -266,12 +297,14 @@ describe.sequential("AsyncTaskManager edge runtime", () => { const controller = AsyncTaskManager.register("stale-task", taskPromise, "custom_type"); const managerAny = AsyncTaskManager as unknown as { - tasks: Map; + tasks: Map; cleanupCompletedTasks: () => void; }; const info = managerAny.tasks.get("stale-task"); expect(info).toBeDefined(); - info!.createdAt = Date.now() - 11 * 60 * 1000; + const oldTimestamp = Date.now() - 11 * 60 * 1000; + info!.createdAt = oldTimestamp; + info!.lastActivityAt = oldTimestamp; let resolveFresh: () => void; const freshPromise = new Promise((resolve) => { @@ -283,6 +316,7 @@ describe.sequential("AsyncTaskManager edge runtime", () => { expect(controller.signal.aborted).toBe(true); expect(freshController.signal.aborted).toBe(false); + expect(AsyncTaskManager.getActiveTaskCount()).toBe(1); expect(vi.mocked(logger.warn)).toHaveBeenCalled(); resolveTask!(); @@ -290,6 +324,77 @@ describe.sequential("AsyncTaskManager edge runtime", () => { await Promise.all([taskPromise, freshPromise]); }); + it("does not cancel a long-running task that was recently touched", async () => { + process.env.CI = "true"; + process.env.NEXT_RUNTIME = "nodejs"; + + const { AsyncTaskManager } = await import("@/lib/async-task-manager"); + + let resolveTask: () => void; + const taskPromise = new Promise((resolve) => { + resolveTask = resolve; + }); + + const controller = AsyncTaskManager.register("active-stream", taskPromise, "stream-processing"); + + const managerAny = AsyncTaskManager as unknown as { + tasks: Map; + cleanupCompletedTasks: () => void; + }; + const info = managerAny.tasks.get("active-stream"); + expect(info).toBeDefined(); + const oldTimestamp = Date.now() - 11 * 60 * 1000; + info!.createdAt = oldTimestamp; + info!.lastActivityAt = oldTimestamp; + + expect(AsyncTaskManager.touch("active-stream")).toBe(true); + managerAny.cleanupCompletedTasks(); + + expect(controller.signal.aborted).toBe(false); + expect(AsyncTaskManager.getActiveTaskCount()).toBe(1); + expect(AsyncTaskManager.touch("missing-task")).toBe(false); + + resolveTask!(); + await taskPromise; + }); + + it("cleanupCompletedTasks aborts a provided controller and detaches stale tasks", async () => { + process.env.CI = "true"; + process.env.NEXT_RUNTIME = "nodejs"; + + const { AsyncTaskManager } = await import("@/lib/async-task-manager"); + + let resolveTask: () => void; + const taskPromise = new Promise((resolve) => { + resolveTask = resolve; + }); + const controller = new AbortController(); + + const returnedController = AsyncTaskManager.register("stale-task", taskPromise, { + taskType: "stream-processing", + abortController: controller, + }); + expect(returnedController).toBe(controller); + + const managerAny = AsyncTaskManager as unknown as { + tasks: Map; + cleanupCompletedTasks: () => void; + }; + const info = managerAny.tasks.get("stale-task"); + expect(info).toBeDefined(); + const oldTimestamp = Date.now() - 11 * 60 * 1000; + info!.createdAt = oldTimestamp; + info!.lastActivityAt = oldTimestamp; + + managerAny.cleanupCompletedTasks(); + + expect(controller.signal.aborted).toBe(true); + expect(AsyncTaskManager.getActiveTaskCount()).toBe(0); + + resolveTask!(); + await taskPromise; + }); + it("cleanupAll cancels tasks and clears interval", async () => { process.env.CI = "true"; process.env.NEXT_RUNTIME = "nodejs"; @@ -313,6 +418,7 @@ describe.sequential("AsyncTaskManager edge runtime", () => { managerAny.cleanupAll(); expect(controller.signal.aborted).toBe(true); + expect(AsyncTaskManager.getActiveTaskCount()).toBe(0); expect(clearIntervalSpy).toHaveBeenCalledWith(intervalId); expect(managerAny.cleanupInterval).toBeNull(); diff --git a/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts b/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts index 350c33d06..73f6933b9 100644 --- a/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts +++ b/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts @@ -22,6 +22,7 @@ vi.mock("@/lib/async-task-manager", () => ({ testState.asyncTasks.push(promise); return new AbortController(); }, + touch: () => true, cleanup: testState.cleanupTask, cancel: testState.cancelTask, }, diff --git a/tests/unit/proxy/response-handler-bill-non-success.test.ts b/tests/unit/proxy/response-handler-bill-non-success.test.ts index 86e0e441a..f4d04466e 100644 --- a/tests/unit/proxy/response-handler-bill-non-success.test.ts +++ b/tests/unit/proxy/response-handler-bill-non-success.test.ts @@ -17,6 +17,7 @@ vi.mock("@/lib/logger", () => ({ vi.mock("@/lib/async-task-manager", () => ({ AsyncTaskManager: { register: () => new AbortController(), + touch: () => true, cleanup: () => {}, cancel: () => {}, }, diff --git a/tests/unit/proxy/response-handler-client-abort-drain.test.ts b/tests/unit/proxy/response-handler-client-abort-drain.test.ts index 7ce1d7f77..094a8614c 100644 --- a/tests/unit/proxy/response-handler-client-abort-drain.test.ts +++ b/tests/unit/proxy/response-handler-client-abort-drain.test.ts @@ -1,13 +1,19 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; -import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler"; +import { + BoundedStreamTextAccumulator, + ProxyResponseHandler, +} from "@/app/v1/_lib/proxy/response-handler"; import { ProxySession } from "@/app/v1/_lib/proxy/session"; import { setDeferredStreamingFinalization } from "@/app/v1/_lib/proxy/stream-finalization"; import { AsyncTaskManager } from "@/lib/async-task-manager"; +import { emitProxyLangfuseTrace } from "@/lib/langfuse/emit-proxy-trace"; +import { SessionManager } from "@/lib/session-manager"; import { updateMessageRequestDetails, updateMessageRequestDuration } from "@/repository/message"; import type { Provider } from "@/types/provider"; const asyncTasks: Promise[] = []; +const STREAM_STATS_HEAD_BYTES_FOR_TEST = 1024 * 1024; vi.mock("@/app/v1/_lib/proxy/response-fixer", () => ({ ResponseFixer: { @@ -21,6 +27,7 @@ vi.mock("@/lib/async-task-manager", () => ({ asyncTasks.push(promise); return new AbortController(); }), + touch: vi.fn(() => true), cleanup: vi.fn(), cancel: vi.fn(), }, @@ -72,7 +79,7 @@ vi.mock("@/lib/session-manager", () => ({ SessionManager: { clearSessionProvider: vi.fn(), extractCodexPromptCacheKey: vi.fn(), - storeSessionResponse: vi.fn(), + storeSessionResponse: vi.fn(async () => undefined), storeSessionRequestPhaseSnapshot: vi.fn(), storeSessionResponsePhaseSnapshot: vi.fn(), storeSessionRequestHeaders: vi.fn(), @@ -82,6 +89,7 @@ vi.mock("@/lib/session-manager", () => ({ storeSessionUpstreamResponseMeta: vi.fn(), updateSessionProvider: vi.fn(), updateSessionUsage: vi.fn(), + updateSessionBindingSmart: vi.fn(async () => ({ updated: false, reason: "test" })), updateSessionWithCodexCacheKey: vi.fn(), }, })); @@ -273,6 +281,131 @@ function createResponsesSse(): Response { }); } +function createResponsesJson(): Response { + return new Response( + JSON.stringify({ + id: "resp_non_stream", + model: "gpt-5.4-mini-2026-03-17", + usage: { + input_tokens: 463, + output_tokens: 11, + }, + }), + { + status: 200, + headers: { "content-type": "application/json" }, + } + ); +} + +function createOversizedResponsesSse(): Response { + const oversizedDelta = "x".repeat(11 * 1024 * 1024); + const body = [ + `event: response.output_text.delta\ndata: ${JSON.stringify({ + type: "response.output_text.delta", + delta: oversizedDelta, + })}`, + `event: response.completed\ndata: ${JSON.stringify({ + type: "response.completed", + response: { + id: "resp_large", + model: "gpt-5.4-mini-2026-03-17", + usage: { + input_tokens: 463, + output_tokens: 11, + }, + }, + })}`, + "", + ].join("\n\n"); + + return new Response(body, { + status: 200, + headers: { "content-type": "text/event-stream" }, + }); +} + +function createUtf8SplitHeadTailResponsesSse(): Response { + const encoder = new TextEncoder(); + const eventPrefix = `event: response.output_text.delta\ndata: {"type":"response.output_text.delta","delta":"`; + const splitChar = "界"; + const prefixBytes = encoder.encode(eventPrefix).byteLength; + const fillBytes = STREAM_STATS_HEAD_BYTES_FOR_TEST - prefixBytes - 1; + if (fillBytes < 0) { + throw new Error("test event prefix is too large for the head window"); + } + + const completedEvent = `event: response.completed\ndata: ${JSON.stringify({ + type: "response.completed", + response: { + id: "resp_utf8_boundary", + model: "gpt-5.4-mini-2026-03-17", + usage: { + input_tokens: 463, + output_tokens: 11, + }, + }, + })}\n\n`; + const body = `${eventPrefix}${"a".repeat(fillBytes)}${splitChar}"}\n\n${completedEvent}`; + const chunk = encoder.encode(body); + const stream = new ReadableStream({ + start(controller) { + controller.enqueue(chunk); + controller.close(); + }, + }); + + return new Response(stream, { + status: 200, + headers: { "content-type": "text/event-stream" }, + }); +} + +function createSplitTailBoundaryResponsesSse(): Response { + const encoder = new TextEncoder(); + const completedEvent = `event: response.completed\ndata: ${JSON.stringify({ + type: "response.completed", + response: { + id: "resp_split_tail", + model: "gpt-5.4-mini-2026-03-17", + usage: { + input_tokens: 463, + output_tokens: 11, + }, + }, + })}\n\n`; + const splitAt = Math.floor(completedEvent.length / 2); + const firstChunk = encoder.encode( + `event: response.output_text.delta\ndata: ${JSON.stringify({ + type: "response.output_text.delta", + delta: "x".repeat(9 * 1024 * 1024), + })}\n\n${completedEvent.slice(0, splitAt)}` + ); + const secondChunk = encoder.encode( + `${completedEvent.slice(splitAt)}event: response.output_text.delta\ndata: ${JSON.stringify({ + type: "response.output_text.delta", + delta: "y".repeat(2 * 1024 * 1024), + })}\n\n` + ); + const chunks = [firstChunk, secondChunk]; + let index = 0; + + const stream = new ReadableStream({ + pull(controller) { + if (index < chunks.length) { + controller.enqueue(chunks[index++]); + return; + } + controller.close(); + }, + }); + + return new Response(stream, { + status: 200, + headers: { "content-type": "text/event-stream" }, + }); +} + function createErroredResponsesSse(): Response { const encoder = new TextEncoder(); const stream = new ReadableStream({ @@ -527,6 +660,80 @@ describe("ProxyResponseHandler stream client abort finalization", () => { vi.clearAllMocks(); }); + it("copies Buffer-backed stream windows before retaining stats snapshots", () => { + const accumulator = new BoundedStreamTextAccumulator(); + const headMarker = "head-copy-marker"; + const tailMarker = "tail-copy-marker"; + const originalChunk = Buffer.from(`${headMarker}${"x".repeat(11 * 1024 * 1024)}${tailMarker}`); + const originalLength = originalChunk.byteLength; + + accumulator.pushBytes(originalChunk); + originalChunk.fill("z"); + + const snapshot = accumulator.finish(); + + expect(snapshot.truncated).toBe(true); + expect(snapshot.totalBytes).toBe(originalLength); + expect(snapshot.bufferedBytes).toBe(10 * 1024 * 1024); + expect(snapshot.text).toContain(headMarker); + expect(snapshot.text).toContain(tailMarker); + expect(snapshot.text).not.toContain("zzzzzzzzzzzzzzzz"); + }); + + it("does not apply the default stale cleanup when stream idle timeout is disabled", async () => { + const controller = new AbortController(); + const session = createSession(controller.signal); + session.provider.streamingIdleTimeoutMs = 0; + setDeferredStreamingFinalization(session, { + providerId: 1, + providerName: "avemujica-responses", + providerPriority: 1, + attemptNumber: 1, + totalProvidersAttempted: 1, + isFirstAttempt: true, + isFailoverSuccess: false, + endpointId: 42, + endpointUrl: "https://api.test.invalid/v1", + upstreamStatusCode: 200, + }); + + await ProxyResponseHandler.dispatch(session, createResponsesSse()); + await drainAsyncTasks(); + + const streamRegisterCall = vi.mocked(AsyncTaskManager.register).mock.calls.find((call) => { + const options = call[2] as { taskType?: string } | undefined; + return options?.taskType === "stream-processing"; + }); + + expect(streamRegisterCall).toBeDefined(); + expect(streamRegisterCall?.[2]).toEqual( + expect.objectContaining({ + staleTimeoutMs: Number.POSITIVE_INFINITY, + }) + ); + }); + + it("does not apply the default stale cleanup when non-stream request timeout is disabled", async () => { + const controller = new AbortController(); + const session = createSession(controller.signal); + session.provider.requestTimeoutNonStreamingMs = 0; + + await ProxyResponseHandler.dispatch(session, createResponsesJson()); + await drainAsyncTasks(); + + const nonStreamRegisterCall = vi.mocked(AsyncTaskManager.register).mock.calls.find((call) => { + const options = call[2] as { taskType?: string } | undefined; + return options?.taskType === "non-stream-processing"; + }); + + expect(nonStreamRegisterCall).toBeDefined(); + expect(nonStreamRegisterCall?.[2]).toEqual( + expect.objectContaining({ + staleTimeoutMs: Number.POSITIVE_INFINITY, + }) + ); + }); + it("finalizes a complete upstream responses stream as success when the downstream client already closed", async () => { const controller = new AbortController(); controller.abort(); @@ -559,6 +766,121 @@ describe("ProxyResponseHandler stream client abort finalization", () => { ); }); + it("keeps stream accounting bounded for oversized successful streams", async () => { + const controller = new AbortController(); + const session = createSession(controller.signal); + session.sessionId = "session_large"; + Object.assign(session, { + shouldPersistSessionDebugArtifacts: () => true, + }); + setDeferredStreamingFinalization(session, { + providerId: 1, + providerName: "avemujica-responses", + providerPriority: 1, + attemptNumber: 1, + totalProvidersAttempted: 1, + isFirstAttempt: true, + isFailoverSuccess: false, + endpointId: 42, + endpointUrl: "https://api.test.invalid/v1", + upstreamStatusCode: 200, + }); + + await ProxyResponseHandler.dispatch(session, createOversizedResponsesSse()); + await drainAsyncTasks(); + + expect(updateMessageRequestDetails).toHaveBeenCalledWith( + 123, + expect.objectContaining({ + statusCode: 200, + inputTokens: 463, + outputTokens: 11, + }) + ); + expect(SessionManager.storeSessionResponse).not.toHaveBeenCalled(); + + const traceCall = vi.mocked(emitProxyLangfuseTrace).mock.calls.at(-1); + expect(traceCall).toBeDefined(); + const traceData = traceCall?.[1]; + const responseText = traceData?.responseText ?? ""; + expect(responseText).toContain("[cch_truncated]"); + expect(responseText.length).toBeLessThan(10 * 1024 * 1024 + 1024); + }); + + it("decodes an untruncated stream as contiguous UTF-8 across the head/tail split", async () => { + const controller = new AbortController(); + const session = createSession(controller.signal); + session.sessionId = "session_utf8_boundary"; + Object.assign(session, { + shouldPersistSessionDebugArtifacts: () => true, + }); + setDeferredStreamingFinalization(session, { + providerId: 1, + providerName: "avemujica-responses", + providerPriority: 1, + attemptNumber: 1, + totalProvidersAttempted: 1, + isFirstAttempt: true, + isFailoverSuccess: false, + endpointId: 42, + endpointUrl: "https://api.test.invalid/v1", + upstreamStatusCode: 200, + }); + + await ProxyResponseHandler.dispatch(session, createUtf8SplitHeadTailResponsesSse()); + await drainAsyncTasks(); + + expect(updateMessageRequestDetails).toHaveBeenCalledWith( + 123, + expect.objectContaining({ + statusCode: 200, + inputTokens: 463, + outputTokens: 11, + }) + ); + + const traceCall = vi.mocked(emitProxyLangfuseTrace).mock.calls.at(-1); + expect(traceCall).toBeDefined(); + const responseText = traceCall?.[1].responseText ?? ""; + expect(responseText).toContain("界"); + expect(responseText).not.toContain("\uFFFD"); + expect(responseText).not.toContain("[cch_truncated]"); + }); + + it("keeps usage when a terminal responses event is split across tail chunk eviction", async () => { + const controller = new AbortController(); + const session = createSession(controller.signal); + session.sessionId = "session_split_tail"; + Object.assign(session, { + shouldPersistSessionDebugArtifacts: () => true, + }); + setDeferredStreamingFinalization(session, { + providerId: 1, + providerName: "avemujica-responses", + providerPriority: 1, + attemptNumber: 1, + totalProvidersAttempted: 1, + isFirstAttempt: true, + isFailoverSuccess: false, + endpointId: 42, + endpointUrl: "https://api.test.invalid/v1", + upstreamStatusCode: 200, + }); + + await ProxyResponseHandler.dispatch(session, createSplitTailBoundaryResponsesSse()); + await drainAsyncTasks(); + + expect(updateMessageRequestDetails).toHaveBeenCalledWith( + 123, + expect.objectContaining({ + statusCode: 200, + inputTokens: 463, + outputTokens: 11, + }) + ); + expect(SessionManager.storeSessionResponse).not.toHaveBeenCalled(); + }); + it("reclassifies a client-aborted stream as success when final usage was already received", async () => { const controller = new AbortController(); controller.abort(); diff --git a/tests/unit/proxy/response-handler-endpoint-circuit-isolation.test.ts b/tests/unit/proxy/response-handler-endpoint-circuit-isolation.test.ts index 05b6bfa7a..5b6f88a28 100644 --- a/tests/unit/proxy/response-handler-endpoint-circuit-isolation.test.ts +++ b/tests/unit/proxy/response-handler-endpoint-circuit-isolation.test.ts @@ -21,6 +21,7 @@ vi.mock("@/lib/async-task-manager", () => ({ asyncTasks.push(promise); return new AbortController(); }, + touch: () => true, cleanup: () => {}, cancel: () => {}, }, diff --git a/tests/unit/proxy/response-handler-gemini-stream-passthrough-timeouts.test.ts b/tests/unit/proxy/response-handler-gemini-stream-passthrough-timeouts.test.ts index 8142f190f..da67fcf81 100644 --- a/tests/unit/proxy/response-handler-gemini-stream-passthrough-timeouts.test.ts +++ b/tests/unit/proxy/response-handler-gemini-stream-passthrough-timeouts.test.ts @@ -5,7 +5,9 @@ import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder"; import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler"; import { ProxySession } from "@/app/v1/_lib/proxy/session"; +import { AsyncTaskManager } from "@/lib/async-task-manager"; import { SessionManager } from "@/lib/session-manager"; +import { updateMessageRequestDetails } from "@/repository/message"; import type { Provider } from "@/types/provider"; const asyncTasks: Promise[] = []; @@ -37,10 +39,11 @@ vi.mock("@/app/v1/_lib/proxy/response-fixer", () => ({ vi.mock("@/lib/async-task-manager", () => ({ AsyncTaskManager: { - register: (_taskId: string, promise: Promise) => { + register: vi.fn((_taskId: string, promise: Promise) => { asyncTasks.push(promise); return new AbortController(); - }, + }), + touch: () => true, cleanup: () => {}, cancel: () => {}, }, @@ -75,7 +78,7 @@ vi.mock("@/repository/model-price", () => ({ vi.mock("@/lib/session-manager", () => ({ SessionManager: { storeSessionResponse: vi.fn(), - updateSessionUsage: vi.fn(), + updateSessionUsage: vi.fn(async () => undefined), clearSessionProvider: vi.fn(), storeSessionRequestPhaseSnapshot: vi.fn(async () => undefined), storeSessionResponsePhaseSnapshot: vi.fn(async () => undefined), @@ -394,6 +397,48 @@ describe("ProxyResponseHandler - Gemini stream passthrough timeouts", () => { await Promise.allSettled(asyncTasks); }); + test("Gemini 流式透传禁用 idle timeout 时不应回落到默认 stale cleanup", async () => { + asyncTasks.length = 0; + vi.mocked(AsyncTaskManager.register).mockClear(); + + const provider = createProvider({ + firstByteTimeoutStreamingMs: 1000, + streamingIdleTimeoutMs: 0, + }); + const session = createSession({ + clientAbortSignal: new AbortController().signal, + messageId: 12, + userId: 22, + }); + session.setProvider(provider); + + const upstreamResponse = new Response('data: {"usageMetadata":{"promptTokenCount":1}}\n\n', { + status: 200, + headers: { "content-type": "text/event-stream" }, + }); + + const returned = await ( + ProxyResponseHandler as unknown as { + handleStream: (session: ProxySession, response: Response) => Promise; + } + ).handleStream(session, upstreamResponse); + + await returned.text(); + await Promise.allSettled(asyncTasks); + + const statsRegisterCall = vi.mocked(AsyncTaskManager.register).mock.calls.find((call) => { + const options = call[2] as { taskType?: string } | undefined; + return options?.taskType === "stream-passthrough-stats"; + }); + + expect(statsRegisterCall).toBeDefined(); + expect(statsRegisterCall?.[2]).toEqual( + expect.objectContaining({ + staleTimeoutMs: Number.POSITIVE_INFINITY, + }) + ); + }); + test("不应在仅收到 headers 时清除首字节超时:无首块数据时应在窗口内中断避免悬挂", async () => { asyncTasks.length = 0; const { baseUrl, close } = await startSseServer((_req, res) => { @@ -662,4 +707,64 @@ describe("ProxyResponseHandler - Gemini stream passthrough timeouts", () => { await Promise.allSettled(asyncTasks); } }); + + test("Gemini 流式透传超大单 chunk 应保留尾部 usage 且不把截断快照作为完整正文存储", async () => { + asyncTasks.length = 0; + vi.mocked(SessionManager.storeSessionResponse).mockClear(); + vi.mocked(updateMessageRequestDetails).mockClear(); + + const clientAbortController = new AbortController(); + const provider = createProvider({ + firstByteTimeoutStreamingMs: 1000, + streamingIdleTimeoutMs: 0, + }); + const session = createSession({ + clientAbortSignal: clientAbortController.signal, + messageId: 77, + userId: 1, + }); + session.setProvider(provider); + session.setSessionId("gemini-large-single-chunk"); + ( + session as ProxySession & { + shouldPersistSessionDebugArtifacts?: () => boolean; + } + ).shouldPersistSessionDebugArtifacts = () => true; + + const hugeText = "x".repeat(11 * 1024 * 1024); + const bodyText = `data: {"text":"${hugeText}"}\n\ndata: {"usageMetadata":{"promptTokenCount":463,"candidatesTokenCount":11}}\n\n`; + const bodyBytes = new TextEncoder().encode(bodyText); + + const upstreamResponse = new Response( + new ReadableStream({ + start(controller) { + controller.enqueue(bodyBytes); + controller.close(); + }, + }), + { + status: 200, + headers: { "content-type": "text/event-stream" }, + } + ); + + const returned = await ( + ProxyResponseHandler as unknown as { + handleStream: (session: ProxySession, response: Response) => Promise; + } + ).handleStream(session, upstreamResponse); + + await returned.text(); + await Promise.allSettled(asyncTasks); + + expect(updateMessageRequestDetails).toHaveBeenCalledWith( + 77, + expect.objectContaining({ + statusCode: 200, + inputTokens: 463, + outputTokens: 11, + }) + ); + expect(SessionManager.storeSessionResponse).not.toHaveBeenCalled(); + }); }); diff --git a/tests/unit/proxy/response-handler-hedge-loser-priority.test.ts b/tests/unit/proxy/response-handler-hedge-loser-priority.test.ts index d22d19f36..9b0c1f7b7 100644 --- a/tests/unit/proxy/response-handler-hedge-loser-priority.test.ts +++ b/tests/unit/proxy/response-handler-hedge-loser-priority.test.ts @@ -31,6 +31,7 @@ vi.mock("@/lib/logger", () => ({ vi.mock("@/lib/async-task-manager", () => ({ AsyncTaskManager: { register: () => new AbortController(), + touch: vi.fn(() => true), cleanup: vi.fn(), cancel: vi.fn(), }, diff --git a/tests/unit/proxy/response-handler-lease-decrement.test.ts b/tests/unit/proxy/response-handler-lease-decrement.test.ts index 12a6fe282..28bc86d65 100644 --- a/tests/unit/proxy/response-handler-lease-decrement.test.ts +++ b/tests/unit/proxy/response-handler-lease-decrement.test.ts @@ -21,6 +21,7 @@ vi.mock("@/lib/async-task-manager", () => ({ asyncTasks.push(promise); return new AbortController(); }, + touch: vi.fn(() => true), cleanup: () => {}, cancel: () => {}, }, @@ -59,6 +60,7 @@ vi.mock("@/lib/session-manager", () => ({ SessionManager: { updateSessionUsage: vi.fn(async () => undefined), storeSessionResponse: vi.fn(), + storeSessionResponsePhaseSnapshot: vi.fn(async () => undefined), extractCodexPromptCacheKey: vi.fn(), updateSessionWithCodexCacheKey: vi.fn(), }, @@ -88,6 +90,7 @@ vi.mock("@/lib/proxy-status-tracker", () => ({ import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler"; import { ProxySession } from "@/app/v1/_lib/proxy/session"; +import { AsyncTaskManager } from "@/lib/async-task-manager"; import { SessionManager } from "@/lib/session-manager"; import { RateLimitService } from "@/lib/rate-limit"; import { SessionTracker } from "@/lib/session-tracker"; @@ -265,6 +268,38 @@ function createNonStreamResponse(usage: { input_tokens: number; output_tokens: n ); } +function createChunkedNonStreamResponse(usage: { + input_tokens: number; + output_tokens: number; +}): Response { + const body = JSON.stringify({ + type: "message", + usage, + }); + const encoder = new TextEncoder(); + const chunks = [ + encoder.encode(body.slice(0, 8)), + encoder.encode(body.slice(8, 24)), + encoder.encode(body.slice(24)), + ]; + let index = 0; + + const stream = new ReadableStream({ + pull(controller) { + if (index < chunks.length) { + controller.enqueue(chunks[index++]); + return; + } + controller.close(); + }, + }); + + return new Response(stream, { + status: 200, + headers: { "content-type": "application/json" }, + }); +} + function createStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response { const sseText = `event: message_delta\ndata: ${JSON.stringify({ usage })}\n\n`; const encoder = new TextEncoder(); @@ -303,6 +338,7 @@ describe("Lease Budget Decrement after trackCostToRedis", () => { vi.mocked(updateMessageRequestDetails).mockResolvedValue(undefined); vi.mocked(updateMessageRequestDuration).mockResolvedValue(undefined); vi.mocked(SessionManager.storeSessionResponse).mockResolvedValue(undefined); + vi.mocked(SessionManager.storeSessionResponsePhaseSnapshot).mockResolvedValue(undefined); vi.mocked(RateLimitService.trackCost).mockResolvedValue(undefined); vi.mocked(RateLimitService.trackUserDailyCost).mockResolvedValue(undefined); vi.mocked(RateLimitService.decrementLeaseBudget).mockResolvedValue({ @@ -356,6 +392,46 @@ describe("Lease Budget Decrement after trackCostToRedis", () => { } }); + it("should refresh task activity while reading chunked non-stream response bodies", async () => { + const messageId = 5010; + const session = createSession({ + originalModel, + redirectedModel: originalModel, + sessionId: "sess-non-stream-chunked-touch", + messageId, + }); + + const response = createChunkedNonStreamResponse(usage); + const cloneSpy = vi.spyOn(response, "clone"); + + await ProxyResponseHandler.dispatch(session, response); + await drainAsyncTasks(); + + const taskId = `non-stream-${messageId}`; + const touchCalls = vi + .mocked(AsyncTaskManager.touch) + .mock.calls.filter(([calledTaskId]) => calledTaskId === taskId); + expect(touchCalls.length).toBeGreaterThanOrEqual(2); + expect(cloneSpy).toHaveBeenCalledTimes(1); + expect(SessionManager.storeSessionResponsePhaseSnapshot).toHaveBeenCalledWith( + session.sessionId, + "after", + expect.objectContaining({ + body: expect.stringContaining('"type":"message"'), + meta: expect.objectContaining({ statusCode: 200 }), + }), + session.requestSequence + ); + expect(updateMessageRequestDetails).toHaveBeenCalledWith( + messageId, + expect.objectContaining({ + statusCode: 200, + inputTokens: usage.input_tokens, + outputTokens: usage.output_tokens, + }) + ); + }); + it("should call decrementLeaseBudget for all windows and entity types (stream)", async () => { const session = createSession({ originalModel, diff --git a/tests/unit/proxy/response-handler-non200.test.ts b/tests/unit/proxy/response-handler-non200.test.ts index 74e0bc909..ef62b25a7 100644 --- a/tests/unit/proxy/response-handler-non200.test.ts +++ b/tests/unit/proxy/response-handler-non200.test.ts @@ -23,6 +23,7 @@ vi.mock("@/lib/async-task-manager", () => ({ asyncTasks.push(promise); return new AbortController(); }, + touch: () => true, cleanup: () => {}, cancel: () => {}, },