From a3eaea8b1b92e038fb81939748e6765fbe324b4f Mon Sep 17 00:00:00 2001 From: Night Yu <57441406+NightYuYyy@users.noreply.github.com> Date: Sat, 13 Jun 2026 16:26:49 +0800 Subject: [PATCH 1/3] =?UTF-8?q?feat(settings):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=89=B9=E9=87=8F=E6=B5=8B=E8=AF=95=E4=BE=9B=E5=BA=94=E5=95=86?= =?UTF-8?q?=E8=83=BD=E5=8A=9B=20(#1276)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(settings): 增加批量测试供应商能力 -【前端】新增批量测试弹窗、并发执行与结果筛选 -【前端】在批量操作中加入“测试”入口与多语言文案 -【后端】新增按供应商 ID 执行测试的接口与 Action -【类型】补充接口与 schema,支持批量测试请求 -【测试】增加动作与 hook 的单元测试覆盖 * fix(settings): 修复批量测试评审与流水线问题 -【流水线】重新生成 OpenAPI 类型,移除本仓库不存在的接口定义 -【前端】关闭弹窗时停止派发新测试,避免后台消耗供应商额度 -【前端】重新测试时重置结果筛选,避免列表显示为空 -【后端】Gemini JSON 凭证先换取 access token 并改用 Bearer 认证 -【测试】provider-manager 测试 mock 批量测试弹窗;补充 Gemini 凭证与请求头用例 * fix(api): 以最新依赖重新生成 OpenAPI 类型对齐 CI 校验 --------- Co-authored-by: NightYu --- .gitignore | 1 + messages/en/settings/index.ts | 2 + messages/en/settings/providers/batchEdit.json | 1 + messages/en/settings/providers/batchTest.json | 57 +++ messages/ja/settings/index.ts | 2 + messages/ja/settings/providers/batchEdit.json | 1 + messages/ja/settings/providers/batchTest.json | 57 +++ messages/ru/settings/index.ts | 2 + messages/ru/settings/providers/batchEdit.json | 1 + messages/ru/settings/providers/batchTest.json | 57 +++ messages/zh-CN/settings/index.ts | 2 + .../zh-CN/settings/providers/batchEdit.json | 1 + .../zh-CN/settings/providers/batchTest.json | 57 +++ messages/zh-TW/settings/index.ts | 2 + .../zh-TW/settings/providers/batchEdit.json | 1 + .../zh-TW/settings/providers/batchTest.json | 57 +++ src/actions/providers.ts | 155 +++++- .../batch-edit/provider-batch-actions.tsx | 9 +- .../batch-test/batch-test-dialog.tsx | 445 ++++++++++++++++++ .../providers/_components/batch-test/index.ts | 8 + .../batch-test/use-batch-provider-test.ts | 145 ++++++ .../_components/provider-manager.tsx | 18 + .../api/v1/resources/providers/handlers.ts | 23 + src/app/api/v1/resources/providers/router.ts | 31 ++ src/lib/api-client/v1/actions/providers.ts | 6 + src/lib/api-client/v1/openapi-types.gen.ts | 203 ++++++++ src/lib/api/v1/schemas/providers.ts | 6 + src/lib/provider-testing/test-service.ts | 9 +- src/lib/provider-testing/types.ts | 2 + .../provider-testing/utils/test-prompts.ts | 14 +- .../unit/actions/providers-test-by-id.test.ts | 276 +++++++++++ .../test-prompts-headers.test.ts | 16 + .../providers/provider-manager.test.tsx | 5 + .../use-batch-provider-test.test.tsx | 177 +++++++ 34 files changed, 1816 insertions(+), 33 deletions(-) create mode 100644 messages/en/settings/providers/batchTest.json create mode 100644 messages/ja/settings/providers/batchTest.json create mode 100644 messages/ru/settings/providers/batchTest.json create mode 100644 messages/zh-CN/settings/providers/batchTest.json create mode 100644 messages/zh-TW/settings/providers/batchTest.json create mode 100644 src/app/[locale]/settings/providers/_components/batch-test/batch-test-dialog.tsx create mode 100644 src/app/[locale]/settings/providers/_components/batch-test/index.ts create mode 100644 src/app/[locale]/settings/providers/_components/batch-test/use-batch-provider-test.ts create mode 100644 tests/unit/actions/providers-test-by-id.test.ts create mode 100644 tests/unit/settings/providers/use-batch-provider-test.test.tsx diff --git a/.gitignore b/.gitignore index 1ec9ea7d4..b3f0bfc70 100644 --- a/.gitignore +++ b/.gitignore @@ -84,6 +84,7 @@ docs-site/node_modules/ # local scratch tmp/ +.playwright-mcp/ .trae/ .sisyphus .ace-tool/ diff --git a/messages/en/settings/index.ts b/messages/en/settings/index.ts index cfa249145..3cd1471cd 100644 --- a/messages/en/settings/index.ts +++ b/messages/en/settings/index.ts @@ -15,6 +15,7 @@ import strings from "./strings.json"; import providersAutoSort from "./providers/autoSort.json"; import providersBatchEdit from "./providers/batchEdit.json"; +import providersBatchTest from "./providers/batchTest.json"; import providersDispatchSimulator from "./providers/dispatchSimulator.json"; import providersFilter from "./providers/filter.json"; import providersGuide from "./providers/guide.json"; @@ -84,6 +85,7 @@ const providers = { ...providersStrings, autoSort: providersAutoSort, batchEdit: providersBatchEdit, + batchTest: providersBatchTest, dispatchSimulator: providersDispatchSimulator, filter: providersFilter, form: providersForm, diff --git a/messages/en/settings/providers/batchEdit.json b/messages/en/settings/providers/batchEdit.json index 0946acf05..b6ea94acf 100644 --- a/messages/en/settings/providers/batchEdit.json +++ b/messages/en/settings/providers/batchEdit.json @@ -2,6 +2,7 @@ "selectedCount": "{count} selected", "actions": { "edit": "Edit", + "test": "Test", "resetCircuit": "Reset Circuit", "delete": "Delete" }, diff --git a/messages/en/settings/providers/batchTest.json b/messages/en/settings/providers/batchTest.json new file mode 100644 index 000000000..2725bf696 --- /dev/null +++ b/messages/en/settings/providers/batchTest.json @@ -0,0 +1,57 @@ +{ + "title": "Batch test providers", + "description": "{count} providers will be tested.", + "overLimit": "Selection exceeds the limit; only the first {max} will be tested.", + "model": { + "label": "Test model", + "placeholder": "Leave empty to use type defaults", + "hint": "The model is applied to every selected provider. Leave empty to use the default test model for each provider type." + }, + "start": "Start test", + "retest": "Test again", + "cancelRemaining": "Cancel remaining", + "close": "Close", + "summary": { + "progress": "{done}/{total}", + "green": "Available {count}", + "yellow": "Degraded {count}", + "failed": "Failed {count}" + }, + "filter": { + "all": "All", + "green": "Available", + "yellow": "Degraded", + "failed": "Failed" + }, + "table": { + "provider": "Provider", + "group": "Group", + "status": "Status", + "latency": "Latency", + "message": "Message", + "enabled": "Enabled", + "empty": "No matching results" + }, + "status": { + "pending": "Pending", + "testing": "Testing", + "green": "Available", + "yellow": "Degraded", + "red": "Unavailable", + "error": "Request failed", + "canceled": "Canceled" + }, + "bulk": { + "disableFailed": "Disable all failed ({count})", + "enableGreen": "Enable all available ({count})" + }, + "toast": { + "toggleFailed": "Failed to toggle status: {error}", + "bulkFailed": "Bulk operation failed: {error}", + "bulkApplied": "Updated {count} providers", + "undo": "Undo", + "undoSuccess": "Restored {count} providers", + "undoFailed": "Undo failed: {error}", + "unknownError": "Unknown error" + } +} diff --git a/messages/ja/settings/index.ts b/messages/ja/settings/index.ts index cfa249145..3cd1471cd 100644 --- a/messages/ja/settings/index.ts +++ b/messages/ja/settings/index.ts @@ -15,6 +15,7 @@ import strings from "./strings.json"; import providersAutoSort from "./providers/autoSort.json"; import providersBatchEdit from "./providers/batchEdit.json"; +import providersBatchTest from "./providers/batchTest.json"; import providersDispatchSimulator from "./providers/dispatchSimulator.json"; import providersFilter from "./providers/filter.json"; import providersGuide from "./providers/guide.json"; @@ -84,6 +85,7 @@ const providers = { ...providersStrings, autoSort: providersAutoSort, batchEdit: providersBatchEdit, + batchTest: providersBatchTest, dispatchSimulator: providersDispatchSimulator, filter: providersFilter, form: providersForm, diff --git a/messages/ja/settings/providers/batchEdit.json b/messages/ja/settings/providers/batchEdit.json index be7f18739..1382ed8d3 100644 --- a/messages/ja/settings/providers/batchEdit.json +++ b/messages/ja/settings/providers/batchEdit.json @@ -2,6 +2,7 @@ "selectedCount": "{count} 件選択中", "actions": { "edit": "編集", + "test": "テスト", "resetCircuit": "サーキットブレーカーをリセット", "delete": "削除" }, diff --git a/messages/ja/settings/providers/batchTest.json b/messages/ja/settings/providers/batchTest.json new file mode 100644 index 000000000..7d4ec1a43 --- /dev/null +++ b/messages/ja/settings/providers/batchTest.json @@ -0,0 +1,57 @@ +{ + "title": "プロバイダーの一括テスト", + "description": "{count} 件のプロバイダーをテストします。", + "overLimit": "上限を超えています。今回は先頭の {max} 件のみテストします。", + "model": { + "label": "テストモデル", + "placeholder": "空欄の場合はタイプ別のデフォルトを使用", + "hint": "指定したモデルは選択中のすべてのプロバイダーに適用されます。空欄の場合はプロバイダータイプごとのデフォルトテストモデルを使用します。" + }, + "start": "テスト開始", + "retest": "再テスト", + "cancelRemaining": "残りをキャンセル", + "close": "閉じる", + "summary": { + "progress": "{done}/{total}", + "green": "利用可能 {count}", + "yellow": "不安定 {count}", + "failed": "失敗 {count}" + }, + "filter": { + "all": "すべて", + "green": "利用可能", + "yellow": "不安定", + "failed": "失敗" + }, + "table": { + "provider": "プロバイダー", + "group": "グループ", + "status": "ステータス", + "latency": "レイテンシ", + "message": "メッセージ", + "enabled": "有効", + "empty": "該当する結果がありません" + }, + "status": { + "pending": "待機中", + "testing": "テスト中", + "green": "利用可能", + "yellow": "不安定", + "red": "利用不可", + "error": "リクエスト失敗", + "canceled": "キャンセル済み" + }, + "bulk": { + "disableFailed": "失敗をすべて無効化 ({count})", + "enableGreen": "利用可能をすべて有効化 ({count})" + }, + "toast": { + "toggleFailed": "ステータスの切り替えに失敗しました: {error}", + "bulkFailed": "一括操作に失敗しました: {error}", + "bulkApplied": "{count} 件のプロバイダーを更新しました", + "undo": "元に戻す", + "undoSuccess": "{count} 件のプロバイダーを復元しました", + "undoFailed": "元に戻せませんでした: {error}", + "unknownError": "不明なエラー" + } +} diff --git a/messages/ru/settings/index.ts b/messages/ru/settings/index.ts index cfa249145..3cd1471cd 100644 --- a/messages/ru/settings/index.ts +++ b/messages/ru/settings/index.ts @@ -15,6 +15,7 @@ import strings from "./strings.json"; import providersAutoSort from "./providers/autoSort.json"; import providersBatchEdit from "./providers/batchEdit.json"; +import providersBatchTest from "./providers/batchTest.json"; import providersDispatchSimulator from "./providers/dispatchSimulator.json"; import providersFilter from "./providers/filter.json"; import providersGuide from "./providers/guide.json"; @@ -84,6 +85,7 @@ const providers = { ...providersStrings, autoSort: providersAutoSort, batchEdit: providersBatchEdit, + batchTest: providersBatchTest, dispatchSimulator: providersDispatchSimulator, filter: providersFilter, form: providersForm, diff --git a/messages/ru/settings/providers/batchEdit.json b/messages/ru/settings/providers/batchEdit.json index 986226327..5f6fc9154 100644 --- a/messages/ru/settings/providers/batchEdit.json +++ b/messages/ru/settings/providers/batchEdit.json @@ -2,6 +2,7 @@ "selectedCount": "Выбрано: {count}", "actions": { "edit": "Редактировать", + "test": "Тест", "resetCircuit": "Сбросить предохранитель", "delete": "Удалить" }, diff --git a/messages/ru/settings/providers/batchTest.json b/messages/ru/settings/providers/batchTest.json new file mode 100644 index 000000000..2aa69b7a2 --- /dev/null +++ b/messages/ru/settings/providers/batchTest.json @@ -0,0 +1,57 @@ +{ + "title": "Массовое тестирование провайдеров", + "description": "Будет протестировано провайдеров: {count}.", + "overLimit": "Превышен лимит за один запуск; будут протестированы только первые {max}.", + "model": { + "label": "Модель для теста", + "placeholder": "Оставьте пустым для модели по умолчанию", + "hint": "Указанная модель применяется ко всем выбранным провайдерам. Если оставить пустым, используется модель по умолчанию для каждого типа провайдера." + }, + "start": "Начать тест", + "retest": "Повторить тест", + "cancelRemaining": "Отменить оставшиеся", + "close": "Закрыть", + "summary": { + "progress": "{done}/{total}", + "green": "Доступно {count}", + "yellow": "Нестабильно {count}", + "failed": "Сбой {count}" + }, + "filter": { + "all": "Все", + "green": "Доступные", + "yellow": "Нестабильные", + "failed": "Сбойные" + }, + "table": { + "provider": "Провайдер", + "group": "Группа", + "status": "Статус", + "latency": "Задержка", + "message": "Сообщение", + "enabled": "Включен", + "empty": "Нет подходящих результатов" + }, + "status": { + "pending": "В очереди", + "testing": "Тестируется", + "green": "Доступен", + "yellow": "Нестабилен", + "red": "Недоступен", + "error": "Ошибка запроса", + "canceled": "Отменено" + }, + "bulk": { + "disableFailed": "Отключить все сбойные ({count})", + "enableGreen": "Включить все доступные ({count})" + }, + "toast": { + "toggleFailed": "Не удалось переключить статус: {error}", + "bulkFailed": "Массовая операция не удалась: {error}", + "bulkApplied": "Обновлено провайдеров: {count}", + "undo": "Отменить", + "undoSuccess": "Восстановлено провайдеров: {count}", + "undoFailed": "Не удалось отменить: {error}", + "unknownError": "Неизвестная ошибка" + } +} diff --git a/messages/zh-CN/settings/index.ts b/messages/zh-CN/settings/index.ts index cfa249145..3cd1471cd 100644 --- a/messages/zh-CN/settings/index.ts +++ b/messages/zh-CN/settings/index.ts @@ -15,6 +15,7 @@ import strings from "./strings.json"; import providersAutoSort from "./providers/autoSort.json"; import providersBatchEdit from "./providers/batchEdit.json"; +import providersBatchTest from "./providers/batchTest.json"; import providersDispatchSimulator from "./providers/dispatchSimulator.json"; import providersFilter from "./providers/filter.json"; import providersGuide from "./providers/guide.json"; @@ -84,6 +85,7 @@ const providers = { ...providersStrings, autoSort: providersAutoSort, batchEdit: providersBatchEdit, + batchTest: providersBatchTest, dispatchSimulator: providersDispatchSimulator, filter: providersFilter, form: providersForm, diff --git a/messages/zh-CN/settings/providers/batchEdit.json b/messages/zh-CN/settings/providers/batchEdit.json index f3978ef7d..a51fb54a6 100644 --- a/messages/zh-CN/settings/providers/batchEdit.json +++ b/messages/zh-CN/settings/providers/batchEdit.json @@ -2,6 +2,7 @@ "selectedCount": "已选择 {count} 个", "actions": { "edit": "编辑", + "test": "测试", "resetCircuit": "重置熔断器", "delete": "删除" }, diff --git a/messages/zh-CN/settings/providers/batchTest.json b/messages/zh-CN/settings/providers/batchTest.json new file mode 100644 index 000000000..6c06f886b --- /dev/null +++ b/messages/zh-CN/settings/providers/batchTest.json @@ -0,0 +1,57 @@ +{ + "title": "批量测试供应商", + "description": "将测试 {count} 个供应商。", + "overLimit": "超过单次上限,本次仅测试前 {max} 个。", + "model": { + "label": "测试模型", + "placeholder": "留空使用各类型默认模型", + "hint": "指定的模型将应用于所有选中的供应商;留空则按供应商类型使用默认测试模型。" + }, + "start": "开始测试", + "retest": "重新测试", + "cancelRemaining": "取消剩余", + "close": "关闭", + "summary": { + "progress": "{done}/{total}", + "green": "可用 {count}", + "yellow": "波动 {count}", + "failed": "失败 {count}" + }, + "filter": { + "all": "全部", + "green": "可用", + "yellow": "波动", + "failed": "失败" + }, + "table": { + "provider": "供应商", + "group": "分组", + "status": "状态", + "latency": "延迟", + "message": "信息", + "enabled": "启用", + "empty": "暂无匹配的结果" + }, + "status": { + "pending": "等待中", + "testing": "测试中", + "green": "可用", + "yellow": "波动", + "red": "不可用", + "error": "请求失败", + "canceled": "已取消" + }, + "bulk": { + "disableFailed": "禁用所有失败 ({count})", + "enableGreen": "启用所有可用 ({count})" + }, + "toast": { + "toggleFailed": "状态切换失败: {error}", + "bulkFailed": "批量操作失败: {error}", + "bulkApplied": "已更新 {count} 个供应商", + "undo": "撤销", + "undoSuccess": "已恢复 {count} 个供应商", + "undoFailed": "撤销失败: {error}", + "unknownError": "未知错误" + } +} diff --git a/messages/zh-TW/settings/index.ts b/messages/zh-TW/settings/index.ts index cfa249145..3cd1471cd 100644 --- a/messages/zh-TW/settings/index.ts +++ b/messages/zh-TW/settings/index.ts @@ -15,6 +15,7 @@ import strings from "./strings.json"; import providersAutoSort from "./providers/autoSort.json"; import providersBatchEdit from "./providers/batchEdit.json"; +import providersBatchTest from "./providers/batchTest.json"; import providersDispatchSimulator from "./providers/dispatchSimulator.json"; import providersFilter from "./providers/filter.json"; import providersGuide from "./providers/guide.json"; @@ -84,6 +85,7 @@ const providers = { ...providersStrings, autoSort: providersAutoSort, batchEdit: providersBatchEdit, + batchTest: providersBatchTest, dispatchSimulator: providersDispatchSimulator, filter: providersFilter, form: providersForm, diff --git a/messages/zh-TW/settings/providers/batchEdit.json b/messages/zh-TW/settings/providers/batchEdit.json index eede51fab..2a434e97e 100644 --- a/messages/zh-TW/settings/providers/batchEdit.json +++ b/messages/zh-TW/settings/providers/batchEdit.json @@ -2,6 +2,7 @@ "selectedCount": "已選擇 {count} 個", "actions": { "edit": "編輯", + "test": "測試", "resetCircuit": "重設熔斷器", "delete": "刪除" }, diff --git a/messages/zh-TW/settings/providers/batchTest.json b/messages/zh-TW/settings/providers/batchTest.json new file mode 100644 index 000000000..2b4db89ba --- /dev/null +++ b/messages/zh-TW/settings/providers/batchTest.json @@ -0,0 +1,57 @@ +{ + "title": "批次測試供應商", + "description": "將測試 {count} 個供應商。", + "overLimit": "超過單次上限,本次僅測試前 {max} 個。", + "model": { + "label": "測試模型", + "placeholder": "留空使用各類型預設模型", + "hint": "指定的模型將套用於所有選中的供應商;留空則依供應商類型使用預設測試模型。" + }, + "start": "開始測試", + "retest": "重新測試", + "cancelRemaining": "取消剩餘", + "close": "關閉", + "summary": { + "progress": "{done}/{total}", + "green": "可用 {count}", + "yellow": "波動 {count}", + "failed": "失敗 {count}" + }, + "filter": { + "all": "全部", + "green": "可用", + "yellow": "波動", + "failed": "失敗" + }, + "table": { + "provider": "供應商", + "group": "分組", + "status": "狀態", + "latency": "延遲", + "message": "訊息", + "enabled": "啟用", + "empty": "暫無符合的結果" + }, + "status": { + "pending": "等待中", + "testing": "測試中", + "green": "可用", + "yellow": "波動", + "red": "不可用", + "error": "請求失敗", + "canceled": "已取消" + }, + "bulk": { + "disableFailed": "停用所有失敗 ({count})", + "enableGreen": "啟用所有可用 ({count})" + }, + "toast": { + "toggleFailed": "狀態切換失敗: {error}", + "bulkFailed": "批次操作失敗: {error}", + "bulkApplied": "已更新 {count} 個供應商", + "undo": "復原", + "undoSuccess": "已還原 {count} 個供應商", + "undoFailed": "復原失敗: {error}", + "unknownError": "未知錯誤" + } +} diff --git a/src/actions/providers.ts b/src/actions/providers.ts index e5992862e..4c09f2bf6 100644 --- a/src/actions/providers.ts +++ b/src/actions/providers.ts @@ -37,6 +37,7 @@ import { import { executeProviderTest, type ProviderTestConfig, + type ProviderTestResult, type TestStatus, type TestSubStatus, } from "@/lib/provider-testing"; @@ -4714,37 +4715,141 @@ export async function testProviderUnified(data: UnifiedTestArgs): Promise["data"]; + +/** + * Map an internal ProviderTestResult to the unified test response payload + */ +function buildUnifiedTestSuccessData(result: ProviderTestResult): UnifiedTestSuccessData { + const statusText = + result.status === "green" ? "可用" : result.status === "yellow" ? "波动" : "不可用"; + const message = `供应商 ${statusText}: ${SUB_STATUS_MESSAGES[result.subStatus]}`; + + return { + success: result.success, + status: result.status, + subStatus: result.subStatus, + message, + latencyMs: result.latencyMs, + firstByteMs: result.firstByteMs, + httpStatusCode: result.httpStatusCode, + httpStatusText: result.httpStatusText, + model: result.model, + content: result.content, + requestUrl: result.requestUrl, + rawResponse: result.rawResponse, + usage: result.usage, + streamInfo: result.streamInfo, + errorMessage: result.errorMessage, + errorType: result.errorType, + testedAt: result.testedAt.toISOString(), + validationDetails: result.validationDetails, + }; +} + +// ============================================================================ +// Test Provider By Id +// ============================================================================ + +/** + * Arguments for testing an existing provider by id + */ +export type TestProviderByIdArgs = { + /** Optional model override; falls back to the provider type preset default */ + model?: string; +}; + +/** Timeout for by-id tests; Gemini needs longer because of thinking output */ +const BY_ID_TEST_TIMEOUT_MS = 15000; +const BY_ID_TEST_GEMINI_TIMEOUT_MS = 60000; + +/** + * Run the unified provider test against a stored provider. + * + * Server-side variant of testProviderUnified: loads URL, key, proxy and custom + * headers from the database so the plaintext key never reaches the client. + * Used by the batch provider testing UI. The test result does not touch the + * circuit breaker or usage statistics. + */ +export async function testProviderById( + providerId: number, + args?: TestProviderByIdArgs +): Promise { + const session = await getSession(); + if (!session || session.user.role !== "admin") { + return { + ok: false, + error: "未授权", + }; + } - const message = `供应商 ${statusText}: ${SUB_STATUS_MESSAGES[result.subStatus]}`; + const provider = await findProviderById(providerId); + if (!provider) { + return { + ok: false, + error: "供应商不存在", + errorCode: "provider.not_found", + }; + } + + const urlValidation = await isUrlSafeForApiTest(provider.url); + if (!urlValidation.safe) { + return { + ok: false, + error: urlValidation.reason ?? "无效的 URL", + }; + } + + const isGeminiType = provider.providerType === "gemini" || provider.providerType === "gemini-cli"; + + // JSON credentials must be exchanged for an access token and sent as a + // Bearer header, mirroring the dedicated Gemini test/model-fetch flows + let apiKey = provider.key; + let geminiBearerAuth = false; + if (isGeminiType) { + try { + apiKey = await GeminiAuth.getAccessToken(provider.key); + geminiBearerAuth = GeminiAuth.isJson(provider.key); + } catch (error) { + logger.warn("testProviderById: gemini auth preprocess failed", { error, providerId }); + } + } + + try { + const config: ProviderTestConfig = { + providerId: String(provider.id), + providerUrl: provider.url, + apiKey, + providerType: provider.providerType, + model: args?.model?.trim() || undefined, + proxyUrl: provider.proxyUrl ?? undefined, + proxyFallbackToDirect: provider.proxyFallbackToDirect, + customHeaders: provider.customHeaders ?? undefined, + timeoutMs: isGeminiType ? BY_ID_TEST_GEMINI_TIMEOUT_MS : BY_ID_TEST_TIMEOUT_MS, + geminiBearerAuth: geminiBearerAuth || undefined, + }; + + const result = await executeProviderTest(config); return { ok: true, - data: { - success: result.success, - status: result.status, - subStatus: result.subStatus, - message, - latencyMs: result.latencyMs, - firstByteMs: result.firstByteMs, - httpStatusCode: result.httpStatusCode, - httpStatusText: result.httpStatusText, - model: result.model, - content: result.content, - requestUrl: result.requestUrl, - rawResponse: result.rawResponse, - usage: result.usage, - streamInfo: result.streamInfo, - errorMessage: result.errorMessage, - errorType: result.errorType, - testedAt: result.testedAt.toISOString(), - validationDetails: result.validationDetails, - }, + data: buildUnifiedTestSuccessData(result), }; } catch (error) { - logger.error("testProviderUnified error", { error }); + logger.error("testProviderById error", { error, providerId }); return { ok: false, error: error instanceof Error ? error.message : "测试执行失败", diff --git a/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-actions.tsx b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-actions.tsx index fbf7ece72..39dbf1533 100644 --- a/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-actions.tsx +++ b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-actions.tsx @@ -1,12 +1,12 @@ "use client"; -import { Pencil, RotateCcw, Trash2 } from "lucide-react"; +import { FlaskConical, Pencil, RotateCcw, Trash2 } from "lucide-react"; import { useTranslations } from "next-intl"; import { Button } from "@/components/ui/button"; import { Separator } from "@/components/ui/separator"; import { cn } from "@/lib/utils"; -export type BatchActionMode = "edit" | "delete" | "resetCircuit" | null; +export type BatchActionMode = "edit" | "delete" | "resetCircuit" | "test" | null; export interface ProviderBatchActionsProps { selectedCount: number; @@ -47,6 +47,11 @@ export function ProviderBatchActions({ {t("actions.edit")} + + + ) : ( + + )} + + + {/* Progress summary */} + {hasResults && ( +
+
+ + {t("summary.progress", { done: summary.done, total: targets.length })} + + + {t("summary.green", { count: summary.green })} + + + {t("summary.yellow", { count: summary.yellow })} + + + {t("summary.failed", { count: summary.failed })} + + {isRunning && } +
+ +
+ {(["all", "green", "yellow", "failed"] as const).map((filterKey) => ( + + ))} +
+
+ )} + + {/* Result table */} +
+ + + + {t("table.provider")} + {t("table.group")} + {t("table.status")} + {t("table.latency")} + {t("table.message")} + {t("table.enabled")} + + + + {visibleRows.map((provider) => { + const row: BatchTestRowResult = results[provider.id] ?? { status: "pending" }; + return ( + + +
{provider.name}
+
{provider.providerType}
+
+ + + {provider.groupTag || "default"} + + + + + {row.status === "testing" && ( + + )} + {t(`status.${row.status}`)} + + + + {row.latencyMs != null ? `${row.latencyMs}ms` : "-"} + + + + {row.message ?? "-"} + + + + handleToggleEnabled(provider, checked)} + aria-label={t("table.enabled")} + /> + +
+ ); + })} + {visibleRows.length === 0 && ( + + + {t("table.empty")} + + + )} +
+
+
+ + + {hasFinishedRun && ( +
+ + +
+ )} + +
+ + + ); +} diff --git a/src/app/[locale]/settings/providers/_components/batch-test/index.ts b/src/app/[locale]/settings/providers/_components/batch-test/index.ts new file mode 100644 index 000000000..847b26b4d --- /dev/null +++ b/src/app/[locale]/settings/providers/_components/batch-test/index.ts @@ -0,0 +1,8 @@ +export { BatchTestDialog, type BatchTestDialogProps } from "./batch-test-dialog"; +export { + BATCH_TEST_CONCURRENCY, + BATCH_TEST_MAX_PROVIDERS, + type BatchTestRowResult, + type BatchTestRowStatus, + useBatchProviderTest, +} from "./use-batch-provider-test"; diff --git a/src/app/[locale]/settings/providers/_components/batch-test/use-batch-provider-test.ts b/src/app/[locale]/settings/providers/_components/batch-test/use-batch-provider-test.ts new file mode 100644 index 000000000..cbb121244 --- /dev/null +++ b/src/app/[locale]/settings/providers/_components/batch-test/use-batch-provider-test.ts @@ -0,0 +1,145 @@ +"use client"; + +import { useCallback, useRef, useState } from "react"; +import { testProviderById } from "@/lib/api-client/v1/actions/providers"; + +/** Max providers tested in one batch run */ +export const BATCH_TEST_MAX_PROVIDERS = 100; +/** Number of providers tested concurrently */ +export const BATCH_TEST_CONCURRENCY = 5; + +export type BatchTestRowStatus = + | "pending" + | "testing" + | "green" + | "yellow" + | "red" + | "error" + | "canceled"; + +export interface BatchTestRowResult { + status: BatchTestRowStatus; + latencyMs?: number; + message?: string; + responseModel?: string; + httpStatusCode?: number; +} + +interface UnifiedTestData { + success: boolean; + status: "green" | "yellow" | "red"; + subStatus: string; + message: string; + latencyMs: number; + httpStatusCode?: number; + model?: string; + errorMessage?: string; +} + +export interface UseBatchProviderTestResult { + results: Record; + isRunning: boolean; + run: (providerIds: number[], model?: string) => Promise; + cancel: () => void; + reset: () => void; +} + +/** + * Client-side concurrency pool that tests providers one by one through the + * by-id endpoint. Cancelling stops launching new tests; in-flight requests + * finish naturally and keep their results. + */ +export function useBatchProviderTest(): UseBatchProviderTestResult { + const [results, setResults] = useState>({}); + const [isRunning, setIsRunning] = useState(false); + const cancelRef = useRef(false); + const runIdRef = useRef(0); + + const setRow = useCallback((providerId: number, row: BatchTestRowResult) => { + setResults((prev) => ({ ...prev, [providerId]: row })); + }, []); + + const run = useCallback( + async (providerIds: number[], model?: string) => { + const targets = providerIds.slice(0, BATCH_TEST_MAX_PROVIDERS); + if (targets.length === 0) return; + + const runId = ++runIdRef.current; + cancelRef.current = false; + setIsRunning(true); + setResults(Object.fromEntries(targets.map((id) => [id, { status: "pending" as const }]))); + + const trimmedModel = model?.trim() || undefined; + let cursor = 0; + + const worker = async (): Promise => { + while (true) { + if (cancelRef.current || runIdRef.current !== runId) return; + const index = cursor; + cursor += 1; + if (index >= targets.length) return; + const providerId = targets[index]; + + setRow(providerId, { status: "testing" }); + try { + const result = await testProviderById( + providerId, + trimmedModel ? { model: trimmedModel } : undefined + ); + if (runIdRef.current !== runId) return; + if (result.ok) { + const data = result.data as UnifiedTestData; + setRow(providerId, { + status: data.status, + latencyMs: data.latencyMs, + message: data.errorMessage ?? data.message, + responseModel: data.model, + httpStatusCode: data.httpStatusCode, + }); + } else { + setRow(providerId, { status: "error", message: result.error }); + } + } catch (error) { + if (runIdRef.current !== runId) return; + setRow(providerId, { + status: "error", + message: error instanceof Error ? error.message : String(error), + }); + } + } + }; + + const workerCount = Math.min(BATCH_TEST_CONCURRENCY, targets.length); + await Promise.all(Array.from({ length: workerCount }, () => worker())); + + if (runIdRef.current !== runId) return; + + if (cancelRef.current) { + setResults((prev) => { + const next = { ...prev }; + for (const id of targets) { + if (next[id]?.status === "pending") { + next[id] = { status: "canceled" }; + } + } + return next; + }); + } + setIsRunning(false); + }, + [setRow] + ); + + const cancel = useCallback(() => { + cancelRef.current = true; + }, []); + + const reset = useCallback(() => { + runIdRef.current += 1; + cancelRef.current = false; + setResults({}); + setIsRunning(false); + }, []); + + return { results, isRunning, run, cancel, reset }; +} diff --git a/src/app/[locale]/settings/providers/_components/provider-manager.tsx b/src/app/[locale]/settings/providers/_components/provider-manager.tsx index e099d686d..aab19b4d4 100644 --- a/src/app/[locale]/settings/providers/_components/provider-manager.tsx +++ b/src/app/[locale]/settings/providers/_components/provider-manager.tsx @@ -41,6 +41,7 @@ import { ProviderBatchDialog, ProviderBatchToolbar, } from "./batch-edit"; +import { BatchTestDialog } from "./batch-test"; import { ProviderForm } from "./forms/provider-form"; import { ProviderFormDialogContent } from "./provider-form-dialog-content"; import { ProviderGroupTab } from "./provider-group-tab"; @@ -109,6 +110,7 @@ export function ProviderManager({ const [selectedProviderIds, setSelectedProviderIds] = useState>(new Set()); const [batchDialogOpen, setBatchDialogOpen] = useState(false); const [batchActionMode, setBatchActionMode] = useState(null); + const [batchTestOpen, setBatchTestOpen] = useState(false); const [editingProviderId, setEditingProviderId] = useState(null); // Helper: check if a provider has any circuit open (key-level or endpoint-level) @@ -310,10 +312,20 @@ export function ProviderManager({ }, []); const handleBatchAction = useCallback((mode: BatchActionMode) => { + if (mode === "test") { + setBatchTestOpen(true); + return; + } setBatchActionMode(mode); setBatchDialogOpen(true); }, []); + // 批量测试基于全量列表取已选项:筛选条件变化不会丢失已勾选的供应商 + const selectedProviders = useMemo( + () => providers.filter((p) => selectedProviderIds.has(p.id)), + [providers, selectedProviderIds] + ); + const handleSelectByType = useCallback( (type: ProviderType) => { setSelectedProviderIds((prev) => { @@ -709,6 +721,12 @@ export function ProviderManager({ onSuccess={handleBatchSuccess} /> + + !open && setEditingProviderId(null)} diff --git a/src/app/api/v1/resources/providers/handlers.ts b/src/app/api/v1/resources/providers/handlers.ts index e9575468c..d4c108cb3 100644 --- a/src/app/api/v1/resources/providers/handlers.ts +++ b/src/app/api/v1/resources/providers/handlers.ts @@ -37,6 +37,7 @@ import { ProviderModelSuggestionsQuerySchema, ProviderProxyTestSchema, type ProviderSummaryResponse, + ProviderTestByIdSchema, ProviderTypeQuerySchema, ProviderUndoBodySchema, ProviderUnifiedTestSchema, @@ -428,6 +429,28 @@ export async function testProviderUnified(c: Context): Promise { ); } +export async function testProviderById(c: Context): Promise { + const id = Number(c.req.param("id")); + if (!Number.isInteger(id) || id <= 0) { + return createProblemResponse({ + status: 400, + instance: new URL(c.req.url).pathname, + errorCode: "request.validation_failed", + detail: "Provider id is invalid.", + }); + } + const body = await parseJson(c, ProviderTestByIdSchema); + if (body instanceof Response) return body; + const existing = await findVisibleProvider(c, id); + if (existing instanceof Response) return existing; + if (!existing) return providerNotFound(c); + const providerActions = await import("@/actions/providers"); + return actionJson( + c, + await callAction(c, providerActions.testProviderById, [id, body] as never[], c.get("auth")) + ); +} + export async function testProviderAnthropic(c: Context): Promise { return callProviderTest(c, ProviderApiTestSchema, "testProviderAnthropicMessages"); } diff --git a/src/app/api/v1/resources/providers/router.ts b/src/app/api/v1/resources/providers/router.ts index 26124f273..b3db03f7e 100644 --- a/src/app/api/v1/resources/providers/router.ts +++ b/src/app/api/v1/resources/providers/router.ts @@ -22,6 +22,7 @@ import { ProviderModelSuggestionsQuerySchema, ProviderProxyTestSchema, ProviderSummarySchema, + ProviderTestByIdSchema, ProviderTypeQuerySchema, ProviderUndoBodySchema, ProviderUnifiedTestSchema, @@ -50,6 +51,7 @@ import { resetProviderUsage, revealProviderKey, testProviderAnthropic, + testProviderById, testProviderGemini, testProviderOpenAIChat, testProviderOpenAIResponses, @@ -657,6 +659,35 @@ providersRouter.openapi( testProviderUnified as never ); +providersRouter.openapi( + createRoute({ + method: "post", + path: "/providers/{id}/test", + middleware: requireAuth("admin"), + tags: ["Providers"], + summary: "Run provider test by id", + description: + "Runs the unified relay-style provider API test against a stored provider using its saved credentials and proxy configuration.", + "x-required-access": "admin", + security, + request: { + params: ProviderIdParamSchema, + body: { + required: true, + content: { "application/json": { schema: ProviderTestByIdSchema } }, + }, + }, + responses: { + 200: { + description: "Unified test result.", + content: { "application/json": { schema: ProviderGenericResponseSchema } }, + }, + ...problemResponses, + }, + }), + testProviderById as never +); + providersRouter.openapi( createRoute({ method: "post", diff --git a/src/lib/api-client/v1/actions/providers.ts b/src/lib/api-client/v1/actions/providers.ts index 7ceee8c4c..53bd2db19 100644 --- a/src/lib/api-client/v1/actions/providers.ts +++ b/src/lib/api-client/v1/actions/providers.ts @@ -204,6 +204,12 @@ export function testProviderUnified(data: unknown) { return toActionResult(apiPost("/api/v1/providers/test:unified", data, dashboardCompatOptions)); } +export function testProviderById(providerId: number, data?: { model?: string }) { + return toActionResult( + apiPost(`/api/v1/providers/${providerId}/test`, data ?? {}, dashboardCompatOptions) + ); +} + export function getProviderTestPresets(providerType: string) { return toActionResult( apiGet( diff --git a/src/lib/api-client/v1/openapi-types.gen.ts b/src/lib/api-client/v1/openapi-types.gen.ts index ef097e7bf..5306ae919 100644 --- a/src/lib/api-client/v1/openapi-types.gen.ts +++ b/src/lib/api-client/v1/openapi-types.gen.ts @@ -504,6 +504,26 @@ export interface paths { patch?: never; trace?: never; }; + "/api/v1/providers/{id}/test": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Run provider test by id + * @description Runs the unified relay-style provider API test against a stored provider using its saved credentials and proxy configuration. + */ + post: operations["postProvidersByIdTest"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/providers/test:anthropic-messages": { parameters: { query?: never; @@ -9136,6 +9156,189 @@ export interface operations { }; }; }; + postProvidersByIdTest: { + parameters: { + query?: never; + header?: { + /** @description Required only when authenticating with the auth-token cookie on mutation requests. */ + "X-CCH-CSRF"?: string; + }; + path: { + /** @description Provider id. */ + id: number; + }; + cookie?: never; + }; + requestBody: { + content: { + "application/json": { + /** @description Optional model override. */ + model?: string; + }; + }; + }; + responses: { + /** @description Unified test result. */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": { + [key: string]: unknown; + }; + }; + }; + /** @description Invalid request. */ + 400: { + headers: { + [name: string]: unknown; + }; + content: { + "application/problem+json": { + /** @description Stable problem type URI or URN. */ + type: string; + /** @description Short problem title. */ + title: string; + /** @description HTTP status code. */ + status: number; + /** @description Human-readable error detail. */ + detail: string; + /** @description Request path that produced the problem. */ + instance: string; + /** @description Application error code for frontend i18n. */ + errorCode: string; + /** @description Optional i18n parameters. */ + errorParams?: { + [key: string]: unknown; + }; + /** @description Optional request trace identifier. */ + traceId?: string; + /** @description Validation failure details. */ + invalidParams?: { + /** @description Path to the invalid input field. */ + path: (string | number)[]; + /** @description Machine-readable validation error code. */ + code: string; + /** @description Validation error message. */ + message: string; + }[]; + }; + }; + }; + /** @description Authentication required. */ + 401: { + headers: { + [name: string]: unknown; + }; + content: { + "application/problem+json": { + /** @description Stable problem type URI or URN. */ + type: string; + /** @description Short problem title. */ + title: string; + /** @description HTTP status code. */ + status: number; + /** @description Human-readable error detail. */ + detail: string; + /** @description Request path that produced the problem. */ + instance: string; + /** @description Application error code for frontend i18n. */ + errorCode: string; + /** @description Optional i18n parameters. */ + errorParams?: { + [key: string]: unknown; + }; + /** @description Optional request trace identifier. */ + traceId?: string; + /** @description Validation failure details. */ + invalidParams?: { + /** @description Path to the invalid input field. */ + path: (string | number)[]; + /** @description Machine-readable validation error code. */ + code: string; + /** @description Validation error message. */ + message: string; + }[]; + }; + }; + }; + /** @description Admin access required. */ + 403: { + headers: { + [name: string]: unknown; + }; + content: { + "application/problem+json": { + /** @description Stable problem type URI or URN. */ + type: string; + /** @description Short problem title. */ + title: string; + /** @description HTTP status code. */ + status: number; + /** @description Human-readable error detail. */ + detail: string; + /** @description Request path that produced the problem. */ + instance: string; + /** @description Application error code for frontend i18n. */ + errorCode: string; + /** @description Optional i18n parameters. */ + errorParams?: { + [key: string]: unknown; + }; + /** @description Optional request trace identifier. */ + traceId?: string; + /** @description Validation failure details. */ + invalidParams?: { + /** @description Path to the invalid input field. */ + path: (string | number)[]; + /** @description Machine-readable validation error code. */ + code: string; + /** @description Validation error message. */ + message: string; + }[]; + }; + }; + }; + /** @description Provider not found. */ + 404: { + headers: { + [name: string]: unknown; + }; + content: { + "application/problem+json": { + /** @description Stable problem type URI or URN. */ + type: string; + /** @description Short problem title. */ + title: string; + /** @description HTTP status code. */ + status: number; + /** @description Human-readable error detail. */ + detail: string; + /** @description Request path that produced the problem. */ + instance: string; + /** @description Application error code for frontend i18n. */ + errorCode: string; + /** @description Optional i18n parameters. */ + errorParams?: { + [key: string]: unknown; + }; + /** @description Optional request trace identifier. */ + traceId?: string; + /** @description Validation failure details. */ + invalidParams?: { + /** @description Path to the invalid input field. */ + path: (string | number)[]; + /** @description Machine-readable validation error code. */ + code: string; + /** @description Validation error message. */ + message: string; + }[]; + }; + }; + }; + }; + }; postProvidersTestAnthropicMessages: { parameters: { query?: never; diff --git a/src/lib/api/v1/schemas/providers.ts b/src/lib/api/v1/schemas/providers.ts index 83aad458b..f4cdd1d14 100644 --- a/src/lib/api/v1/schemas/providers.ts +++ b/src/lib/api/v1/schemas/providers.ts @@ -310,6 +310,12 @@ export const ProviderUnifiedTestSchema = ProviderApiTestSchema.extend({ customHeaders: z.record(z.string(), z.string()).optional().describe("Optional custom headers."), }).strict(); +export const ProviderTestByIdSchema = z + .object({ + model: z.string().trim().min(1).optional().describe("Optional model override."), + }) + .strict(); + export const ProviderTypeQuerySchema = z.object({ providerType: ProviderTypeSchema.describe("Provider type."), }); diff --git a/src/lib/provider-testing/test-service.ts b/src/lib/provider-testing/test-service.ts index f686a31ff..4ec3b4f82 100644 --- a/src/lib/provider-testing/test-service.ts +++ b/src/lib/provider-testing/test-service.ts @@ -57,7 +57,9 @@ function buildAttemptPlans(config: ProviderTestConfig): AttemptPlan[] { { body: parsed, headers: { - ...getTestHeaders(config.providerType, config.apiKey, config.providerUrl), + ...getTestHeaders(config.providerType, config.apiKey, config.providerUrl, { + geminiBearerAuth: config.geminiBearerAuth, + }), ...(config.customHeaders || {}), }, model: config.model, @@ -90,7 +92,9 @@ function buildAttemptPlans(config: ProviderTestConfig): AttemptPlan[] { { body: getTestBody(config.providerType, config.model), headers: { - ...getTestHeaders(config.providerType, config.apiKey, config.providerUrl), + ...getTestHeaders(config.providerType, config.apiKey, config.providerUrl, { + geminiBearerAuth: config.geminiBearerAuth, + }), ...(config.customHeaders || {}), }, model: config.model, @@ -109,6 +113,7 @@ function buildAttemptPlans(config: ProviderTestConfig): AttemptPlan[] { ...getTestHeaders(config.providerType, config.apiKey, config.providerUrl, { userAgent: preset.userAgent, extraHeaders: preset.extraHeaders, + geminiBearerAuth: config.geminiBearerAuth, }), ...(config.customHeaders || {}), }, diff --git a/src/lib/provider-testing/types.ts b/src/lib/provider-testing/types.ts index 7d216a7e4..5953d1ba8 100644 --- a/src/lib/provider-testing/types.ts +++ b/src/lib/provider-testing/types.ts @@ -93,6 +93,8 @@ export interface ProviderTestConfig { successContains?: string; /** Request timeout in ms (default: 10000) */ timeoutMs?: number; + /** Send the Gemini key as Authorization Bearer instead of x-goog-api-key (JSON credentials) */ + geminiBearerAuth?: boolean; // =========== Custom Configuration Fields =========== diff --git a/src/lib/provider-testing/utils/test-prompts.ts b/src/lib/provider-testing/utils/test-prompts.ts index ba6ce56c2..a5ccfd2d9 100644 --- a/src/lib/provider-testing/utils/test-prompts.ts +++ b/src/lib/provider-testing/utils/test-prompts.ts @@ -167,6 +167,7 @@ export function getTestHeaders( overrides?: { userAgent?: string; extraHeaders?: Record; + geminiBearerAuth?: boolean; } ): Record { const headers: Record = { @@ -198,10 +199,15 @@ export function getTestHeaders( break; case "gemini": case "gemini-cli": - Object.assign(headers, { - ...GEMINI_TEST_HEADERS, - "x-goog-api-key": apiKey, - }); + Object.assign( + headers, + GEMINI_TEST_HEADERS, + // JSON credentials are exchanged for an OAuth access token upstream, + // which Gemini only accepts as a Bearer token + overrides?.geminiBearerAuth + ? { Authorization: `Bearer ${apiKey}` } + : { "x-goog-api-key": apiKey } + ); break; default: throw new Error(`Unsupported provider type: ${providerType}`); diff --git a/tests/unit/actions/providers-test-by-id.test.ts b/tests/unit/actions/providers-test-by-id.test.ts new file mode 100644 index 000000000..dd6802a69 --- /dev/null +++ b/tests/unit/actions/providers-test-by-id.test.ts @@ -0,0 +1,276 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import type { Provider } from "@/types/provider"; + +const getSessionMock = vi.fn(); +const executeProviderTestMock = vi.fn(); +const findProviderByIdMock = vi.fn(); +const getPresetsForProviderMock = vi.fn(); +const validateProviderUrlForConnectivityMock = vi.fn(); +const createProxyAgentForProviderMock = vi.fn(); + +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("@/repository/provider", () => ({ + createProvider: vi.fn(), + deleteProvider: vi.fn(), + findAllProviders: vi.fn(async () => []), + findAllProvidersFresh: vi.fn(async () => []), + findProviderById: findProviderByIdMock, + getProviderStatistics: vi.fn(), + resetProviderTotalCostResetAt: vi.fn(async () => {}), + updateProvider: vi.fn(), + updateProviderPrioritiesBatch: vi.fn(), +})); + +vi.mock("@/lib/cache/provider-cache", () => ({ + publishProviderCacheInvalidation: vi.fn(), +})); + +vi.mock("@/lib/redis/circuit-breaker-config", () => ({ + deleteProviderCircuitConfig: vi.fn(), + saveProviderCircuitConfig: vi.fn(), +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + clearConfigCache: vi.fn(), + clearProviderState: vi.fn(), + getAllHealthStatusAsync: vi.fn(async () => ({})), + publishCircuitBreakerConfigInvalidation: vi.fn(), + forceCloseCircuitState: vi.fn(), + resetCircuit: vi.fn(), +})); + +vi.mock("@/lib/session-manager", () => ({ + SessionManager: { + terminateProviderSessionsBatch: vi.fn(), + terminateStickySessionsForProviders: vi.fn(), + }, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})); + +vi.mock("next/cache", () => ({ + revalidatePath: vi.fn(), +})); + +vi.mock("@/lib/provider-testing", () => ({ + executeProviderTest: executeProviderTestMock, +})); + +vi.mock("@/lib/provider-testing/presets", () => ({ + getPresetsForProvider: getPresetsForProviderMock, +})); + +vi.mock("@/lib/validation/provider-url", () => ({ + validateProviderUrlForConnectivity: validateProviderUrlForConnectivityMock, +})); + +vi.mock("@/lib/proxy-agent", () => ({ + createProxyAgentForProvider: createProxyAgentForProviderMock, + isValidProxyUrl: vi.fn(() => true), +})); + +const geminiGetAccessTokenMock = vi.fn(async (apiKey: string) => apiKey); +const geminiIsJsonMock = vi.fn(() => false); + +vi.mock("@/app/v1/_lib/gemini/auth", () => ({ + GeminiAuth: { + getAccessToken: geminiGetAccessTokenMock, + isJson: geminiIsJsonMock, + }, +})); + +function buildProvider(overrides: Partial = {}): Provider { + return { + id: 7, + name: "p-claude", + url: "https://api.example.com", + key: "sk-stored-secret", + providerType: "claude", + proxyUrl: null, + proxyFallbackToDirect: false, + customHeaders: null, + ...overrides, + } as Provider; +} + +const GREEN_RESULT = { + success: true, + status: "green", + subStatus: "success", + latencyMs: 88, + firstByteMs: 30, + httpStatusCode: 200, + httpStatusText: "OK", + model: "claude-sonnet-4-5", + content: "pong", + rawResponse: '{"content":"pong"}', + requestUrl: "https://api.example.com/v1/messages", + testedAt: new Date("2026-06-12T00:00:00.000Z"), + validationDetails: { + httpPassed: true, + httpStatusCode: 200, + latencyPassed: true, + latencyMs: 88, + contentPassed: true, + contentTarget: "pong", + }, +}; + +describe("testProviderById", () => { + beforeEach(() => { + vi.clearAllMocks(); + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + validateProviderUrlForConnectivityMock.mockImplementation((providerUrl: string) => ({ + valid: true, + normalizedUrl: providerUrl, + })); + createProxyAgentForProviderMock.mockReturnValue(null); + getPresetsForProviderMock.mockReturnValue([]); + findProviderByIdMock.mockResolvedValue(buildProvider()); + executeProviderTestMock.mockResolvedValue(GREEN_RESULT); + geminiGetAccessTokenMock.mockImplementation(async (apiKey: string) => apiKey); + geminiIsJsonMock.mockReturnValue(false); + }); + + test("非 admin 会话应返回未授权且不执行测试", async () => { + getSessionMock.mockResolvedValue({ user: { id: 2, role: "user" } }); + + const { testProviderById } = await import("@/actions/providers"); + const result = await testProviderById(7); + + expect(result.ok).toBe(false); + expect(executeProviderTestMock).not.toHaveBeenCalled(); + expect(findProviderByIdMock).not.toHaveBeenCalled(); + }); + + test("供应商不存在时返回 provider.not_found", async () => { + findProviderByIdMock.mockResolvedValue(null); + + const { testProviderById } = await import("@/actions/providers"); + const result = await testProviderById(404); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.errorCode).toBe("provider.not_found"); + } + expect(executeProviderTestMock).not.toHaveBeenCalled(); + }); + + test("URL 校验失败时不执行测试", async () => { + validateProviderUrlForConnectivityMock.mockReturnValue({ + valid: false, + error: { message: "blocked url" }, + }); + + const { testProviderById } = await import("@/actions/providers"); + const result = await testProviderById(7); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toBe("blocked url"); + } + expect(executeProviderTestMock).not.toHaveBeenCalled(); + }); + + test("使用库内配置执行测试,密钥来自数据库", async () => { + findProviderByIdMock.mockResolvedValue( + buildProvider({ + proxyUrl: "http://proxy.local:8080", + proxyFallbackToDirect: true, + customHeaders: { "x-extra": "1" }, + }) + ); + + const { testProviderById } = await import("@/actions/providers"); + const result = await testProviderById(7, { model: " claude-sonnet-4-5 " }); + + expect(result.ok).toBe(true); + expect(executeProviderTestMock).toHaveBeenCalledTimes(1); + const config = executeProviderTestMock.mock.calls[0]?.[0]; + expect(config).toMatchObject({ + providerId: "7", + providerUrl: "https://api.example.com", + apiKey: "sk-stored-secret", + providerType: "claude", + model: "claude-sonnet-4-5", + proxyUrl: "http://proxy.local:8080", + proxyFallbackToDirect: true, + customHeaders: { "x-extra": "1" }, + timeoutMs: 15000, + }); + if (result.ok) { + expect(result.data?.status).toBe("green"); + expect(result.data?.testedAt).toBe("2026-06-12T00:00:00.000Z"); + } + }); + + test("空白 model 覆盖会被忽略并回退到类型默认", async () => { + const { testProviderById } = await import("@/actions/providers"); + const result = await testProviderById(7, { model: " " }); + + expect(result.ok).toBe(true); + const config = executeProviderTestMock.mock.calls[0]?.[0]; + expect(config?.model).toBeUndefined(); + }); + + test("gemini 类型使用 60 秒超时", async () => { + findProviderByIdMock.mockResolvedValue(buildProvider({ providerType: "gemini" })); + + const { testProviderById } = await import("@/actions/providers"); + await testProviderById(7); + + const config = executeProviderTestMock.mock.calls[0]?.[0]; + expect(config?.timeoutMs).toBe(60000); + }); + + test("gemini JSON 凭证转换为 access token 并使用 Bearer 认证", async () => { + const jsonKey = JSON.stringify({ type: "authorized_user", access_token: "ya29.token" }); + findProviderByIdMock.mockResolvedValue( + buildProvider({ providerType: "gemini-cli", key: jsonKey }) + ); + geminiGetAccessTokenMock.mockResolvedValue("ya29.token"); + geminiIsJsonMock.mockReturnValue(true); + + const { testProviderById } = await import("@/actions/providers"); + const result = await testProviderById(7); + + expect(result.ok).toBe(true); + expect(geminiGetAccessTokenMock).toHaveBeenCalledWith(jsonKey); + const config = executeProviderTestMock.mock.calls[0]?.[0]; + expect(config?.apiKey).toBe("ya29.token"); + expect(config?.geminiBearerAuth).toBe(true); + }); + + test("非 gemini 类型不做凭证预处理", async () => { + const { testProviderById } = await import("@/actions/providers"); + await testProviderById(7); + + expect(geminiGetAccessTokenMock).not.toHaveBeenCalled(); + const config = executeProviderTestMock.mock.calls[0]?.[0]; + expect(config?.apiKey).toBe("sk-stored-secret"); + expect(config?.geminiBearerAuth).toBeUndefined(); + }); + + test("executeProviderTest 抛错时返回失败结果", async () => { + executeProviderTestMock.mockRejectedValue(new Error("upstream exploded")); + + const { testProviderById } = await import("@/actions/providers"); + const result = await testProviderById(7); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toBe("upstream exploded"); + } + }); +}); diff --git a/tests/unit/provider-testing/test-prompts-headers.test.ts b/tests/unit/provider-testing/test-prompts-headers.test.ts index 22988bd78..b62c1580e 100644 --- a/tests/unit/provider-testing/test-prompts-headers.test.ts +++ b/tests/unit/provider-testing/test-prompts-headers.test.ts @@ -40,3 +40,19 @@ describe("provider-testing getTestHeaders — Anthropic auth header selection", expect(headers.Authorization).toBeUndefined(); }); }); + +describe("provider-testing getTestHeaders — Gemini auth header selection", () => { + it("sends x-goog-api-key by default for gemini", () => { + const headers = getTestHeaders("gemini", "AIza-test", "https://gemini.example.com"); + expect(headers["x-goog-api-key"]).toBe("AIza-test"); + expect(headers.Authorization).toBeUndefined(); + }); + + it("sends Bearer-only when geminiBearerAuth is set (JSON credentials)", () => { + const headers = getTestHeaders("gemini-cli", "ya29.token", "https://gemini.example.com", { + geminiBearerAuth: true, + }); + expect(headers.Authorization).toBe("Bearer ya29.token"); + expect(headers["x-goog-api-key"]).toBeUndefined(); + }); +}); diff --git a/tests/unit/settings/providers/provider-manager.test.tsx b/tests/unit/settings/providers/provider-manager.test.tsx index 8f383e147..bd7d4240b 100644 --- a/tests/unit/settings/providers/provider-manager.test.tsx +++ b/tests/unit/settings/providers/provider-manager.test.tsx @@ -24,6 +24,11 @@ vi.mock("@/app/[locale]/settings/providers/_components/batch-edit", () => ({ ProviderBatchToolbar: () => null, })); +// Batch-test dialog (requires QueryClientProvider, irrelevant to this test scope) +vi.mock("@/app/[locale]/settings/providers/_components/batch-test", () => ({ + BatchTestDialog: () => null, +})); + // ProviderList -- render a simple list so we can inspect filtered output vi.mock("@/app/[locale]/settings/providers/_components/provider-list", () => ({ ProviderList: ({ providers }: { providers: ProviderDisplay[] }) => ( diff --git a/tests/unit/settings/providers/use-batch-provider-test.test.tsx b/tests/unit/settings/providers/use-batch-provider-test.test.tsx new file mode 100644 index 000000000..c5a32c7ff --- /dev/null +++ b/tests/unit/settings/providers/use-batch-provider-test.test.tsx @@ -0,0 +1,177 @@ +/** + * @vitest-environment happy-dom + */ + +import { act } from "react"; +import { createRoot, type Root } from "react-dom/client"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { + BATCH_TEST_CONCURRENCY, + type UseBatchProviderTestResult, + useBatchProviderTest, +} from "@/app/[locale]/settings/providers/_components/batch-test/use-batch-provider-test"; + +const { testProviderByIdMock } = vi.hoisted(() => ({ + testProviderByIdMock: vi.fn(), +})); + +vi.mock("@/lib/api-client/v1/actions/providers", () => ({ + testProviderById: testProviderByIdMock, +})); + +function greenData(latencyMs = 100) { + return { + success: true, + status: "green" as const, + subStatus: "success", + message: "ok", + latencyMs, + httpStatusCode: 200, + model: "claude-sonnet-4-5", + }; +} + +function redData() { + return { + success: false, + status: "red" as const, + subStatus: "auth_error", + message: "auth failed", + latencyMs: 50, + httpStatusCode: 401, + errorMessage: "Invalid key", + }; +} + +describe("useBatchProviderTest", () => { + let hook: UseBatchProviderTestResult; + let root: Root; + let container: HTMLDivElement; + + function HookProbe() { + hook = useBatchProviderTest(); + return null; + } + + beforeEach(() => { + vi.clearAllMocks(); + container = document.createElement("div"); + document.body.appendChild(container); + root = createRoot(container); + act(() => { + root.render(); + }); + + return () => { + act(() => { + root.unmount(); + }); + container.remove(); + }; + }); + + test("按结果状态记录每个供应商:green/yellow/red 与失败信息", async () => { + testProviderByIdMock.mockImplementation(async (providerId: number) => { + if (providerId === 1) return { ok: true, data: greenData(80) }; + if (providerId === 2) return { ok: true, data: { ...greenData(6000), status: "yellow" } }; + if (providerId === 3) return { ok: true, data: redData() }; + return { ok: false, error: "network down" }; + }); + + await act(async () => { + await hook.run([1, 2, 3, 4]); + }); + + expect(hook.isRunning).toBe(false); + expect(hook.results[1]).toMatchObject({ status: "green", latencyMs: 80 }); + expect(hook.results[2]).toMatchObject({ status: "yellow" }); + expect(hook.results[3]).toMatchObject({ status: "red", message: "Invalid key" }); + expect(hook.results[4]).toMatchObject({ status: "error", message: "network down" }); + }); + + test("并发执行不超过上限", async () => { + let inFlight = 0; + let maxInFlight = 0; + testProviderByIdMock.mockImplementation(async () => { + inFlight += 1; + maxInFlight = Math.max(maxInFlight, inFlight); + await new Promise((resolve) => setTimeout(resolve, 1)); + inFlight -= 1; + return { ok: true, data: greenData() }; + }); + + const ids = Array.from({ length: 12 }, (_, index) => index + 1); + await act(async () => { + await hook.run(ids); + }); + + expect(testProviderByIdMock).toHaveBeenCalledTimes(12); + expect(maxInFlight).toBeLessThanOrEqual(BATCH_TEST_CONCURRENCY); + }); + + test("model 覆盖会去除空白后传给接口,空白则不传", async () => { + testProviderByIdMock.mockResolvedValue({ ok: true, data: greenData() }); + + await act(async () => { + await hook.run([1], " claude-sonnet-4-5 "); + }); + expect(testProviderByIdMock).toHaveBeenLastCalledWith(1, { model: "claude-sonnet-4-5" }); + + await act(async () => { + await hook.run([1], " "); + }); + expect(testProviderByIdMock).toHaveBeenLastCalledWith(1, undefined); + }); + + test("取消后不再发起新请求,剩余标记为 canceled,已发出的保留结果", async () => { + const resolvers: Array<() => void> = []; + testProviderByIdMock.mockImplementation( + () => + new Promise((resolve) => { + resolvers.push(() => resolve({ ok: true, data: greenData() })); + }) + ); + + const ids = Array.from({ length: 8 }, (_, index) => index + 1); + let runPromise: Promise = Promise.resolve(); + await act(async () => { + runPromise = hook.run(ids); + // Let the first wave of workers start + await Promise.resolve(); + }); + + expect(testProviderByIdMock).toHaveBeenCalledTimes(BATCH_TEST_CONCURRENCY); + + act(() => { + hook.cancel(); + }); + + await act(async () => { + for (const resolve of resolvers) resolve(); + await runPromise; + }); + + // No new requests were launched after cancel + expect(testProviderByIdMock).toHaveBeenCalledTimes(BATCH_TEST_CONCURRENCY); + const statuses = ids.map((id) => hook.results[id]?.status); + expect(statuses.filter((status) => status === "green")).toHaveLength(BATCH_TEST_CONCURRENCY); + expect(statuses.filter((status) => status === "canceled")).toHaveLength( + ids.length - BATCH_TEST_CONCURRENCY + ); + expect(hook.isRunning).toBe(false); + }); + + test("reset 清空结果并结束运行状态", async () => { + testProviderByIdMock.mockResolvedValue({ ok: true, data: greenData() }); + await act(async () => { + await hook.run([1]); + }); + expect(Object.keys(hook.results)).toHaveLength(1); + + act(() => { + hook.reset(); + }); + expect(hook.results).toEqual({}); + expect(hook.isRunning).toBe(false); + }); +}); From 1fe972b4cf465c9a45a32b2abc5409b692ad8de9 Mon Sep 17 00:00:00 2001 From: Brisbanehuang Date: Sat, 13 Jun 2026 16:27:08 +0800 Subject: [PATCH 2/3] fix(proxy): cap client abort drain window (#1277) * fix(proxy): cap client abort drain window * fix(proxy): decouple client abort drain from idle timeout * fix(proxy): enforce idle timeout during abort drain * fix(proxy): preserve abort drain idle deadline * fix(proxy): keep abort-drain idle as client abort --- src/app/v1/_lib/proxy/response-handler.ts | 153 +++++++------ ...esponse-handler-client-abort-drain.test.ts | 206 +++++++++++++++++- 2 files changed, 287 insertions(+), 72 deletions(-) diff --git a/src/app/v1/_lib/proxy/response-handler.ts b/src/app/v1/_lib/proxy/response-handler.ts index e4cb75244..cad152770 100644 --- a/src/app/v1/_lib/proxy/response-handler.ts +++ b/src/app/v1/_lib/proxy/response-handler.ts @@ -56,6 +56,8 @@ import { peekDeferredStreamingFinalization, } from "./stream-finalization"; +const CLIENT_ABORT_DRAIN_MAX_MS = 60_000; + /** * Idempotent helper to release the agent pool reference count attached to a session. * Prevents double-release by clearing the callback after first invocation. @@ -2300,7 +2302,7 @@ export class ProxyResponseHandler { } } - // ⭐ 使用 TransformStream 包装流,以便在 idle timeout 时能关闭客户端流 + // 使用 TransformStream 包装流,以便在 idle timeout 时能关闭客户端流 // 这解决了 tee() 后 internalStream abort 不影响 clientStream 的问题 let streamController: TransformStreamDefaultController | null = null; const controllableStream = processedStream.pipeThrough( @@ -2322,17 +2324,76 @@ export class ProxyResponseHandler { const abortController = new AbortController(); const idleTimeoutMs = provider.streamingIdleTimeoutMs > 0 ? provider.streamingIdleTimeoutMs : Infinity; - const clientAbortDrainTimeoutMs = idleTimeoutMs === Infinity ? 60_000 : idleTimeoutMs; + const clientAbortDrainTimeoutMs = CLIENT_ABORT_DRAIN_MAX_MS; - // ⭐ 提升 idleTimeoutId 到外部作用域,以便客户端断开时能清除 + // 提升 idleTimeoutId 到外部作用域,以便客户端断开时能清除 let idleTimeoutId: NodeJS.Timeout | null = null; let clientAbortDrainTimeoutId: NodeJS.Timeout | null = null; + const chunks: string[] = []; const clearClientAbortDrainTimer = () => { if (clientAbortDrainTimeoutId) { clearTimeout(clientAbortDrainTimeoutId); clientAbortDrainTimeoutId = null; } }; + const clearIdleTimer = () => { + if (idleTimeoutId) { + clearTimeout(idleTimeoutId); + idleTimeoutId = null; + } + }; + const startIdleTimer = () => { + if (idleTimeoutMs === Infinity) return; // 禁用时跳过 + clearIdleTimer(); // 清除旧的 + idleTimeoutId = setTimeout(() => { + logger.warn("ResponseHandler: Streaming idle timeout triggered", { + taskId, + providerId: provider.id, + idleTimeoutMs, + chunksCollected: chunks.length, + }); + + // 1. 关闭客户端流(让客户端收到连接关闭通知,避免悬挂) + try { + if (streamController) { + streamController.error(new Error("Streaming idle timeout")); + logger.debug("ResponseHandler: Client stream closed due to idle timeout", { + taskId, + providerId: provider.id, + }); + } + } catch (e) { + logger.warn("ResponseHandler: Failed to close client stream", { + taskId, + providerId: provider.id, + error: e, + }); + } + + // 2. 终止上游连接(避免资源泄漏) + try { + const sessionWithController = session as typeof session & { + responseController?: AbortController; + }; + if (sessionWithController.responseController) { + sessionWithController.responseController.abort(new Error("streaming_idle")); + logger.debug("ResponseHandler: Upstream connection aborted due to idle timeout", { + taskId, + providerId: provider.id, + }); + } + } catch (e) { + logger.warn("ResponseHandler: Failed to abort upstream connection", { + taskId, + providerId: provider.id, + error: e, + }); + } + + // 3. 终止后台读取任务 + abortController.abort(new Error("streaming_idle")); + }, idleTimeoutMs); + }; const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => { logger.debug("ResponseHandler: Client disconnected, cleaning up", { taskId, @@ -2344,6 +2405,9 @@ export class ProxyResponseHandler { // still drain buffered final usage and record the request as successful. // Idle/response timeout paths still abort via abortController. clearClientAbortDrainTimer(); + if (!idleTimeoutId) { + startIdleTimer(); + } clientAbortDrainTimeoutId = setTimeout(() => { logger.info("ResponseHandler: Client abort drain window exceeded", { taskId, @@ -2375,71 +2439,13 @@ export class ProxyResponseHandler { // 注意:即使 STORE_SESSION_RESPONSE_BODY=false(不写入 Redis),这里也会在内存中累积完整流内容: // - 用于解析 usage/cost 与内部结算(例如“假 200”检测) // 因此该开关仅影响“是否持久化”,不用于控制流式内存占用。 - const chunks: string[] = []; let usageForCost: UsageMetrics | null = null; - let isFirstChunk = true; // ⭐ 标记是否为第一块数据 + let isFirstChunk = true; // 标记是否为第一块数据 - const startIdleTimer = () => { - if (idleTimeoutMs === Infinity) return; // 禁用时跳过 - clearIdleTimer(); // 清除旧的 - idleTimeoutId = setTimeout(() => { - logger.warn("ResponseHandler: Streaming idle timeout triggered", { - taskId, - providerId: provider.id, - idleTimeoutMs, - chunksCollected: chunks.length, - }); - - // ⭐ 1. 关闭客户端流(让客户端收到连接关闭通知,避免悬挂) - try { - if (streamController) { - streamController.error(new Error("Streaming idle timeout")); - logger.debug("ResponseHandler: Client stream closed due to idle timeout", { - taskId, - providerId: provider.id, - }); - } - } catch (e) { - logger.warn("ResponseHandler: Failed to close client stream", { - taskId, - providerId: provider.id, - error: e, - }); - } - - // ⭐ 2. 终止上游连接(避免资源泄漏) - try { - const sessionWithController = session as typeof session & { - responseController?: AbortController; - }; - if (sessionWithController.responseController) { - sessionWithController.responseController.abort(new Error("streaming_idle")); - logger.debug("ResponseHandler: Upstream connection aborted due to idle timeout", { - taskId, - providerId: provider.id, - }); - } - } catch (e) { - logger.warn("ResponseHandler: Failed to abort upstream connection", { - taskId, - providerId: provider.id, - error: e, - }); - } - - // ⭐ 3. 终止后台读取任务 - abortController.abort(new Error("streaming_idle")); - }, idleTimeoutMs); - }; - const clearIdleTimer = () => { - if (idleTimeoutId) { - clearTimeout(idleTimeoutId); - idleTimeoutId = null; - } - }; - - // ⭐ 不在首次读取前启动 idle timer(避免与首字节超时职责重叠) - // idle timer 仅在首块数据到达后启动,用于检测流中途静默 + // 不在首次读取前启动 idle timer(避免与首字节超时职责重叠) + // idle timer 仅在首块数据到达后启动,用于检测流中途静默。 + // 客户端断开后例外:后台 drain 也会启动 idle timer,避免 pre-body + // 静默一直等到 60s drain 总上限。 const flushAndJoin = (): string => { const flushed = decoder.decode(); @@ -2768,7 +2774,7 @@ export class ProxyResponseHandler { const chunkSize = value.length; chunks.push(decoder.decode(value, { stream: true })); - // ⭐ 每次收到数据后重置静默期计时器(首次收到数据时启动) + // 每次收到数据后重置静默期计时器(首次收到数据时启动) startIdleTimer(); logger.trace("ResponseHandler: Idle timer reset (data received)", { taskId, @@ -2778,7 +2784,7 @@ export class ProxyResponseHandler { idleTimeoutMs: idleTimeoutMs === Infinity ? "disabled" : idleTimeoutMs, }); - // ⭐ 流式:读到第一块数据后立即清除响应超时定时器 + // 流式:读到第一块数据后立即清除响应超时定时器 if (isFirstChunk) { session.recordTtfb(); isFirstChunk = false; @@ -2797,7 +2803,7 @@ export class ProxyResponseHandler { } } - // ⭐ 流式读取完成:清除静默期计时器 + // 流式读取完成:清除静默期计时器 clearIdleTimer(); const allContent = flushAndJoin(); const clientAborted = session.clientAbortSignal?.aborted ?? false; @@ -2890,7 +2896,12 @@ export class ProxyResponseHandler { // 结算并消费 deferred meta,确保 provider chain/熔断归因完整 try { const allContent = flushAndJoin(); - await finalizeStream(allContent, false, false, "STREAM_IDLE_TIMEOUT"); + await finalizeStream( + allContent, + false, + clientAborted, + clientAborted ? "CLIENT_ABORTED" : "STREAM_IDLE_TIMEOUT" + ); } catch (finalizeError) { logger.error("ResponseHandler: Failed to finalize idle-timeout stream", { taskId, @@ -3028,7 +3039,7 @@ export class ProxyResponseHandler { // 确保资源释放 cleanupClientAbortListener(); clearClientAbortDrainTimer(); - clearIdleTimer(); // ⭐ 清除静默期计时器(防止泄漏) + clearIdleTimer(); // 清除静默期计时器(防止泄漏) try { reader.releaseLock(); } catch (releaseError) { diff --git a/tests/unit/proxy/response-handler-client-abort-drain.test.ts b/tests/unit/proxy/response-handler-client-abort-drain.test.ts index 8f3efa29b..7ce1d7f77 100644 --- a/tests/unit/proxy/response-handler-client-abort-drain.test.ts +++ b/tests/unit/proxy/response-handler-client-abort-drain.test.ts @@ -327,6 +327,65 @@ function createHangingResponsesSse(upstreamSignal: AbortSignal): Response { }); } +function createPreBodyHangingResponsesSse(upstreamSignal: AbortSignal): Response { + const stream = new ReadableStream({ + start(controller) { + upstreamSignal.addEventListener( + "abort", + () => { + const error = new Error("streaming_idle"); + error.name = "AbortError"; + controller.error(error); + }, + { once: true } + ); + }, + }); + + return new Response(stream, { + status: 200, + headers: { "content-type": "text/event-stream" }, + }); +} + +function createActiveHangingResponsesSse(upstreamSignal: AbortSignal): Response { + const encoder = new TextEncoder(); + let index = 0; + let intervalId: ReturnType | null = null; + + const encodeChunk = (delta: string) => + encoder.encode( + `event: response.output_text.delta\ndata: ${JSON.stringify({ + type: "response.output_text.delta", + delta, + })}\n\n` + ); + + const stream = new ReadableStream({ + start(controller) { + controller.enqueue(encodeChunk("短")); + intervalId = setInterval(() => { + controller.enqueue(encodeChunk(`持续-${++index}`)); + }, 4_000); + upstreamSignal.addEventListener( + "abort", + () => { + if (intervalId) clearInterval(intervalId); + const error = new Error("client_abort_drain_timeout"); + error.name = "AbortError"; + controller.error(error); + }, + { once: true } + ); + }, + }); + + return new Response(stream, { + status: 200, + headers: { "content-type": "text/event-stream" }, + }); +} + function createCompletedThenErroredResponsesSse(): Response { const encoder = new TextEncoder(); const chunks = [ @@ -642,12 +701,157 @@ describe("ProxyResponseHandler stream client abort finalization", () => { ); }); - it("bounds client-abort drain when the upstream stream hangs", async () => { + it("keeps client-abort drain independent from a small idle timeout while chunks are active", async () => { + vi.useFakeTimers(); + try { + const clientController = new AbortController(); + const upstreamController = new AbortController(); + const session = createSession(clientController.signal); + session.provider.streamingIdleTimeoutMs = 5_000; + Object.assign(session, { responseController: upstreamController }); + setDeferredStreamingFinalization(session, { + providerId: 1, + providerName: "avemujica-responses", + providerPriority: 1, + attemptNumber: 1, + totalProvidersAttempted: 1, + isFirstAttempt: true, + isFailoverSuccess: false, + endpointId: 42, + endpointUrl: "https://api.test.invalid/v1", + upstreamStatusCode: 200, + }); + + await ProxyResponseHandler.dispatch( + session, + createActiveHangingResponsesSse(upstreamController.signal) + ); + clientController.abort(); + + await vi.advanceTimersByTimeAsync(59_000); + expect(upstreamController.signal.aborted).toBe(false); + + await vi.advanceTimersByTimeAsync(1_000); + const tasks = asyncTasks.splice(0, asyncTasks.length); + await Promise.allSettled(tasks); + + expect(upstreamController.signal.aborted).toBe(true); + expect(AsyncTaskManager.cancel).not.toHaveBeenCalled(); + expect(updateMessageRequestDetails).toHaveBeenCalledWith( + 123, + expect.objectContaining({ + statusCode: 499, + errorMessage: "CLIENT_ABORTED", + }) + ); + } finally { + vi.useRealTimers(); + } + }); + + it("uses idle timeout for client-aborted streams that hang before the first chunk", async () => { + vi.useFakeTimers(); + try { + const clientController = new AbortController(); + const upstreamController = new AbortController(); + const session = createSession(clientController.signal); + session.provider.streamingIdleTimeoutMs = 5_000; + Object.assign(session, { responseController: upstreamController }); + setDeferredStreamingFinalization(session, { + providerId: 1, + providerName: "avemujica-responses", + providerPriority: 1, + attemptNumber: 1, + totalProvidersAttempted: 1, + isFirstAttempt: true, + isFailoverSuccess: false, + endpointId: 42, + endpointUrl: "https://api.test.invalid/v1", + upstreamStatusCode: 200, + }); + + await ProxyResponseHandler.dispatch( + session, + createPreBodyHangingResponsesSse(upstreamController.signal) + ); + clientController.abort(); + + await vi.advanceTimersByTimeAsync(4_999); + expect(upstreamController.signal.aborted).toBe(false); + + await vi.advanceTimersByTimeAsync(1); + const tasks = asyncTasks.splice(0, asyncTasks.length); + await Promise.allSettled(tasks); + + expect(upstreamController.signal.aborted).toBe(true); + expect(AsyncTaskManager.cancel).not.toHaveBeenCalled(); + expect(updateMessageRequestDetails).toHaveBeenCalledWith( + 123, + expect.objectContaining({ + statusCode: 499, + errorMessage: "CLIENT_ABORTED", + }) + ); + } finally { + vi.useRealTimers(); + } + }); + + it("preserves an existing idle deadline when the client aborts after a chunk", async () => { + vi.useFakeTimers(); + try { + const clientController = new AbortController(); + const upstreamController = new AbortController(); + const session = createSession(clientController.signal); + session.provider.streamingIdleTimeoutMs = 5_000; + Object.assign(session, { responseController: upstreamController }); + setDeferredStreamingFinalization(session, { + providerId: 1, + providerName: "avemujica-responses", + providerPriority: 1, + attemptNumber: 1, + totalProvidersAttempted: 1, + isFirstAttempt: true, + isFailoverSuccess: false, + endpointId: 42, + endpointUrl: "https://api.test.invalid/v1", + upstreamStatusCode: 200, + }); + + await ProxyResponseHandler.dispatch( + session, + createHangingResponsesSse(upstreamController.signal) + ); + await vi.advanceTimersByTimeAsync(0); + await vi.advanceTimersByTimeAsync(4_999); + expect(upstreamController.signal.aborted).toBe(false); + + clientController.abort(); + await vi.advanceTimersByTimeAsync(1); + const tasks = asyncTasks.splice(0, asyncTasks.length); + await Promise.allSettled(tasks); + + expect(upstreamController.signal.aborted).toBe(true); + expect(AsyncTaskManager.cancel).not.toHaveBeenCalled(); + expect(updateMessageRequestDetails).toHaveBeenCalledWith( + 123, + expect.objectContaining({ + statusCode: 499, + errorMessage: "CLIENT_ABORTED", + }) + ); + } finally { + vi.useRealTimers(); + } + }); + + it("caps client-abort drain at 60s when the upstream stream hangs", async () => { vi.useFakeTimers(); try { const clientController = new AbortController(); const upstreamController = new AbortController(); const session = createSession(clientController.signal); + session.provider.streamingIdleTimeoutMs = 120_000; Object.assign(session, { responseController: upstreamController }); setDeferredStreamingFinalization(session, { providerId: 1, From 74b586a59ecde90e27361701b2f331c1ff9a01b8 Mon Sep 17 00:00:00 2001 From: Ding <44717411+ding113@users.noreply.github.com> Date: Sun, 14 Jun 2026 10:50:46 +0800 Subject: [PATCH 3/3] =?UTF-8?q?feat(proxy):=20=E7=AB=9E=E9=80=9F=E8=B5=A2?= =?UTF-8?q?=E5=AE=B6=E6=97=A0=E6=9D=A1=E4=BB=B6=E6=94=B9=E7=BB=91=20Sessio?= =?UTF-8?q?n=20=E5=A4=8D=E7=94=A8=E7=BB=91=E5=AE=9A=20(#1279)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 产生竞速赢家(launchedProviderCount>1)时,无条件把 Session 复用绑定改绑到赢家。updateSessionBindingSmart 新增 forceUpdate 短路智能决策;commitWinner 传 isActualHedgeWin。含 TDD 单测与深度评审加固。 --- src/app/v1/_lib/proxy/forwarder.ts | 4 +- src/lib/session-manager.ts | 32 ++- .../lib/session-manager-binding-smart.test.ts | 194 ++++++++++++++++++ .../proxy-forwarder-hedge-first-byte.test.ts | 17 +- 4 files changed, 234 insertions(+), 13 deletions(-) create mode 100644 tests/unit/lib/session-manager-binding-smart.test.ts diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index 28bc7eb3f..8cefd5c56 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -4419,7 +4419,9 @@ export class ProxyForwarder { attempt.provider.priority || 0, launchedProviderCount === 1 && attempt.provider.id === initialProvider.id, attempt.provider.id !== initialProvider.id, - session.authState?.key?.id ?? null + session.authState?.key?.id ?? null, + // 产生了真实竞速赢家时,无条件把 Session 复用绑定改绑到赢家。 + isActualHedgeWin ); if (bindingResult.updated) { diff --git a/src/lib/session-manager.ts b/src/lib/session-manager.ts index 1acafdd1b..eaec22ae8 100644 --- a/src/lib/session-manager.ts +++ b/src/lib/session-manager.ts @@ -743,7 +743,7 @@ export class SessionManager { /** * 智能更新 Session 绑定 * - * 策略:首次绑定用 SET NX;故障转移后无条件更新;其他情况按优先级和熔断状态决策 + * 策略:首次绑定用 SET NX;故障转移成功或竞速赢家强制改绑时无条件更新;其他情况按优先级和熔断状态决策 */ static async updateSessionBindingSmart( sessionId: string, @@ -751,7 +751,8 @@ export class SessionManager { newProviderPriority: number, isFirstAttempt: boolean = false, isFailoverSuccess: boolean = false, - keyId?: number | null + keyId?: number | null, + forceUpdate: boolean = false ): Promise<{ updated: boolean; reason: string; details?: string }> { const redis = getRedisClient(); if (!redis || redis.status !== "ready") { @@ -801,8 +802,9 @@ export class SessionManager { // ========== 情况 2:重试成功(需要智能决策)========== - // 2.0 故障转移成功:无条件更新绑定(减少缓存切换) - if (isFailoverSuccess) { + // 2.0 故障转移成功 或 竞速赢家强制改绑:无条件更新绑定 + // forceUpdate 在读取当前绑定/优先级/熔断状态之前短路,确保竞速赢家一定成为复用绑定。 + if (isFailoverSuccess || forceUpdate) { const pipeline = redis.pipeline(); pipeline.setex( `session:${sessionId}:provider`, @@ -814,16 +816,24 @@ export class SessionManager { } await pipeline.exec(); - logger.info("SessionManager: Updated binding after failover", { - sessionId, - newProviderId, - newPriority: newProviderPriority, - }); + const reason = isFailoverSuccess ? "failover_success" : "race_winner_forced"; + logger.info( + isFailoverSuccess + ? "SessionManager: Updated binding after failover" + : "SessionManager: Forced binding to race winner", + { + sessionId, + newProviderId, + newPriority: newProviderPriority, + } + ); return { updated: true, - reason: "failover_success", - details: `故障转移成功,绑定到供应商 ${newProviderId}`, + reason, + details: isFailoverSuccess + ? `故障转移成功,绑定到供应商 ${newProviderId}` + : `竞速赢家强制改绑到供应商 ${newProviderId}`, }; } diff --git a/tests/unit/lib/session-manager-binding-smart.test.ts b/tests/unit/lib/session-manager-binding-smart.test.ts new file mode 100644 index 000000000..8186ff9b8 --- /dev/null +++ b/tests/unit/lib/session-manager-binding-smart.test.ts @@ -0,0 +1,194 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +/** + * Tests for SessionManager.updateSessionBindingSmart forceUpdate semantics. + * + * Hedge race winners must unconditionally rebind the session-reuse binding to + * the winner. forceUpdate short-circuits the smart-decision path (priority / + * circuit health) that would otherwise keep a healthy higher-priority binding. + */ + +let redisClientRef: { + status: string; + get: ReturnType; + set: ReturnType; + setex: ReturnType; + pipeline: ReturnType; +} | null; + +let lastPipeline: { + setex: ReturnType; + exec: ReturnType; +}; + +const makePipeline = () => { + const pipeline = { + setex: vi.fn(() => pipeline), + exec: vi.fn(async () => []), + }; + lastPipeline = pipeline; + return pipeline; +}; + +vi.mock("@/lib/logger", () => ({ + logger: { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + trace: vi.fn(), + }, +})); + +vi.mock("@/lib/redis", () => ({ + getRedisClient: () => redisClientRef, +})); + +// Both are loaded via `await import(...)` inside updateSessionBindingSmart; the +// static vi.mock still intercepts the dynamic import. +vi.mock("@/repository/provider", () => ({ + findProviderById: vi.fn(), +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + isCircuitOpen: vi.fn(), +})); + +import { isCircuitOpen } from "@/lib/circuit-breaker"; +import { SessionManager } from "@/lib/session-manager"; +import { findProviderById } from "@/repository/provider"; + +const SID = "sess-binding"; +const TTL = 300; + +beforeEach(() => { + vi.clearAllMocks(); + redisClientRef = { + status: "ready", + get: vi.fn(async () => null), + set: vi.fn(async () => "OK"), + setex: vi.fn(async () => "OK"), + pipeline: vi.fn(() => makePipeline()), + }; +}); + +describe("SessionManager.updateSessionBindingSmart forceUpdate", () => { + it("forceUpdate=true overrides a healthy higher-priority existing binding", async () => { + // Existing binding -> provider 1 (healthy, higher priority than the winner) + redisClientRef!.get.mockResolvedValue("1"); + vi.mocked(findProviderById).mockResolvedValue({ id: 1, name: "main", priority: 5 } as never); + vi.mocked(isCircuitOpen).mockResolvedValue(false); + + const result = await SessionManager.updateSessionBindingSmart( + SID, + 2, // winner id + 10, // winner priority (lower priority than current's 5) + false, // isFirstAttempt + false, // isFailoverSuccess + null, + true // forceUpdate + ); + + expect(result).toMatchObject({ updated: true, reason: "race_winner_forced" }); + expect(lastPipeline.setex).toHaveBeenCalledWith(`session:${SID}:provider`, TTL, "2"); + // Guard against a regression that queues setex but forgets to flush the pipeline. + expect(lastPipeline.exec).toHaveBeenCalledTimes(1); + }); + + it("forceUpdate=true rebinds even when the winner equals the current binding", async () => { + // Production winner==initialProvider race: the bound provider is already the winner, + // but the race result must still (re)write the binding and refresh its TTL. + redisClientRef!.get.mockResolvedValue("2"); + + const result = await SessionManager.updateSessionBindingSmart( + SID, + 2, // winner id == currently bound id + 10, + false, + false, + null, + true // forceUpdate + ); + + expect(result).toMatchObject({ updated: true, reason: "race_winner_forced" }); + expect(redisClientRef!.get).not.toHaveBeenCalled(); + expect(lastPipeline.setex).toHaveBeenCalledWith(`session:${SID}:provider`, TTL, "2"); + expect(lastPipeline.exec).toHaveBeenCalledTimes(1); + }); + + it("forceUpdate=false keeps the healthy higher-priority binding (documents the gap)", async () => { + redisClientRef!.get.mockResolvedValue("1"); + vi.mocked(findProviderById).mockResolvedValue({ id: 1, name: "main", priority: 5 } as never); + vi.mocked(isCircuitOpen).mockResolvedValue(false); + + const result = await SessionManager.updateSessionBindingSmart( + SID, + 2, + 10, + false, + false, + null, + false // forceUpdate + ); + + expect(result).toMatchObject({ updated: false, reason: "keep_healthy_higher_priority" }); + }); + + it("forceUpdate=true short-circuits before consulting provider/circuit state", async () => { + redisClientRef!.get.mockResolvedValue("1"); + + await SessionManager.updateSessionBindingSmart(SID, 2, 10, false, false, null, true); + + expect(findProviderById).not.toHaveBeenCalled(); + expect(isCircuitOpen).not.toHaveBeenCalled(); + // forceUpdate goes straight to the unconditional pipeline path. + expect(redisClientRef!.get).not.toHaveBeenCalled(); + }); + + it("forceUpdate=true also persists the keyId binding with TTL", async () => { + const result = await SessionManager.updateSessionBindingSmart( + SID, + 2, + 10, + false, + false, + 42, // keyId + true + ); + + expect(result.updated).toBe(true); + expect(lastPipeline.setex).toHaveBeenCalledWith(`session:${SID}:provider`, TTL, "2"); + expect(lastPipeline.setex).toHaveBeenCalledWith(`session:${SID}:key`, TTL, "42"); + expect(lastPipeline.exec).toHaveBeenCalledTimes(1); + }); + + it("isFailoverSuccess=true keeps reason failover_success even when forceUpdate=true", async () => { + const result = await SessionManager.updateSessionBindingSmart( + SID, + 2, + 10, + false, + true, // isFailoverSuccess + null, + true // forceUpdate + ); + + expect(result).toMatchObject({ updated: true, reason: "failover_success" }); + }); + + it("returns redis_not_ready regardless of forceUpdate when redis is unavailable", async () => { + redisClientRef!.status = "connecting"; + + const result = await SessionManager.updateSessionBindingSmart( + SID, + 2, + 10, + false, + false, + null, + true + ); + + expect(result).toMatchObject({ updated: false, reason: "redis_not_ready" }); + }); +}); diff --git a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts index 2f8f1d46f..6dcbccd33 100644 --- a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts +++ b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts @@ -968,13 +968,16 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { expect(mocks.recordFailure).not.toHaveBeenCalled(); expect(mocks.recordSuccess).not.toHaveBeenCalled(); expect(session.provider?.id).toBe(2); + // Actual hedge win (launchedProviderCount > 1) forces the session-reuse + // binding to the winner (forceUpdate=true, the trailing arg). expect(mocks.updateSessionBindingSmart).toHaveBeenCalledWith( "sess-hedge", 2, 0, false, true, - null + null, + true ); expect(mocks.releaseProviderSession).toHaveBeenCalledWith(1, "sess-hedge"); } finally { @@ -1267,6 +1270,18 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { expect(mocks.recordFailure).not.toHaveBeenCalled(); expect(mocks.recordSuccess).not.toHaveBeenCalled(); expect(session.provider?.id).toBe(1); + // Initial provider won the race (launchedProviderCount > 1): the binding + // must still be force-updated to the winner (forceUpdate=true), closing + // the gap where the smart path could keep a stale/higher-priority binding. + expect(mocks.updateSessionBindingSmart).toHaveBeenCalledWith( + "sess-hedge", + 1, + 0, + false, + false, + null, + true + ); expect(mocks.releaseProviderSession).toHaveBeenCalledWith(2, "sess-hedge"); } finally { vi.useRealTimers();