From cb83813dabcc99a0fdb4b0e1bde9d490e58866c9 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Wed, 4 Mar 2026 14:25:14 +0800 Subject: [PATCH 01/75] =?UTF-8?q?=E2=9C=A8=20Enhance=20error=20handling=20?= =?UTF-8?q?and=20messaging=20for=20model=20operations:=20Added=20new=20err?= =?UTF-8?q?or=20codes=20and=20messages=20for=20prompt=20generation=20failu?= =?UTF-8?q?res,=20API=20key=20issues,=20rate=20limits,=20and=20service=20u?= =?UTF-8?q?navailability.=20Updated=20exception=20handling=20in=20the=20pr?= =?UTF-8?q?ompt=20service=20to=20yield=20appropriate=20error=20responses.?= =?UTF-8?q?=20Improved=20frontend=20error=20handling=20to=20display=20loca?= =?UTF-8?q?lized=20messages=20based=20on=20error=20codes.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/consts/error_code.py | 7 + backend/consts/error_message.py | 7 + backend/consts/exceptions.py | 2 +- backend/middleware/exception_handler.py | 6 +- backend/services/prompt_service.py | 73 +++++++--- backend/utils/llm_utils.py | 22 ++- .../agentInfo/AgentGenerateDetail.tsx | 16 ++- frontend/const/errorCode.ts | 7 + frontend/const/errorMessage.ts | 13 ++ frontend/public/locales/en/common.json | 128 +++++++++--------- frontend/public/locales/zh/common.json | 128 +++++++++--------- frontend/services/promptService.ts | 8 +- 12 files changed, 268 insertions(+), 149 deletions(-) diff --git a/backend/consts/error_code.py b/backend/consts/error_code.py index 73decbbf3..7affd2b2f 100644 --- a/backend/consts/error_code.py +++ b/backend/consts/error_code.py @@ -121,6 +121,13 @@ class ErrorCode(Enum): MODEL_CONFIG_INVALID = "090102" # Invalid model configuration MODEL_HEALTH_CHECK_FAILED = "090103" # Health check failed MODEL_PROVIDER_ERROR = "090104" # Model provider error + MODEL_PROMPT_GENERATION_FAILED = "090105" # Model prompt generation failed + # 02 - Model API errors + MODEL_API_KEY_INVALID = "090201" # API key is invalid or expired + MODEL_API_KEY_NO_PERMISSION = "090202" # API key does not have permission + MODEL_RATE_LIMIT_EXCEEDED = "090203" # Rate limit exceeded + MODEL_SERVICE_UNAVAILABLE = "090204" # Model service is temporarily unavailable + MODEL_CONNECTION_ERROR = "090205" # Failed to connect to model service # ==================== 10 Memory / 记忆管理 ==================== # 01 - Memory diff --git a/backend/consts/error_message.py b/backend/consts/error_message.py index aa7bf45e3..4ff1141c7 100644 --- a/backend/consts/error_message.py +++ b/backend/consts/error_message.py @@ -84,6 +84,13 @@ class ErrorMessage: ErrorCode.MODEL_CONFIG_INVALID: "Model configuration is invalid.", ErrorCode.MODEL_HEALTH_CHECK_FAILED: "Model health check failed.", ErrorCode.MODEL_PROVIDER_ERROR: "Model provider error.", + ErrorCode.MODEL_PROMPT_GENERATION_FAILED: "Model is unavailable. Please check the model status and try again.", + # 02 - Model API errors + ErrorCode.MODEL_API_KEY_INVALID: "Model API key is invalid or expired. Please check your API key configuration.", + ErrorCode.MODEL_API_KEY_NO_PERMISSION: "Model API key does not have permission. Please check your API key permissions.", + ErrorCode.MODEL_RATE_LIMIT_EXCEEDED: "Rate limit exceeded. Please try again later.", + ErrorCode.MODEL_SERVICE_UNAVAILABLE: "Model service is temporarily unavailable. Please try again later.", + ErrorCode.MODEL_CONNECTION_ERROR: "Failed to connect to model service. Please check your network and model configuration.", # ==================== 10 Memory / 记忆管理 ==================== ErrorCode.MEMORY_NOT_FOUND: "Memory not found.", diff --git a/backend/consts/exceptions.py b/backend/consts/exceptions.py index e9d270673..7f058c25c 100644 --- a/backend/consts/exceptions.py +++ b/backend/consts/exceptions.py @@ -43,7 +43,7 @@ def __init__(self, error_code: ErrorCode, message: str = None, details: dict = N def to_dict(self) -> dict: return { - "code": int(self.error_code.value), + "code": str(self.error_code.value), # Keep as string to preserve leading zeros "message": self.message, "details": self.details if self.details else None } diff --git a/backend/middleware/exception_handler.py b/backend/middleware/exception_handler.py index 14d9ebb38..6ec521f12 100644 --- a/backend/middleware/exception_handler.py +++ b/backend/middleware/exception_handler.py @@ -74,7 +74,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: return JSONResponse( status_code=http_status, content={ - "code": int(exc.error_code.value), + "code": exc.error_code.value, # Keep as string to preserve leading zeros "message": exc.message, "trace_id": trace_id, "details": exc.details if exc.details else None @@ -88,7 +88,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: return JSONResponse( status_code=exc.status_code, content={ - "code": int(error_code.value), + "code": error_code.value, # Keep as string to preserve leading zeros "message": exc.detail, "trace_id": trace_id } @@ -141,7 +141,7 @@ def create_error_response( return JSONResponse( status_code=status, content={ - "code": int(error_code.value), + "code": error_code.value, # Keep as string to preserve leading zeros "message": message or ErrorMessage.get_message(error_code), "trace_id": trace_id, "details": details diff --git a/backend/services/prompt_service.py b/backend/services/prompt_service.py index a505f28f4..3706c3cc5 100644 --- a/backend/services/prompt_service.py +++ b/backend/services/prompt_service.py @@ -7,8 +7,10 @@ from jinja2 import StrictUndefined, Template from consts.const import LANGUAGE -from consts.model import AgentInfoRequest -from database.agent_db import update_agent, search_agent_info_by_agent_id, query_all_agent_info_by_tenant_id, \ +from consts.error_code import ErrorCode +from consts.error_message import ErrorMessage +from consts.exceptions import AppException +from database.agent_db import search_agent_info_by_agent_id, query_all_agent_info_by_tenant_id, \ query_sub_agents_id_list from database.tool_db import query_tools_by_ids from services.agent_service import ( @@ -28,18 +30,30 @@ def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: str, user_id: str, tenant_id: str, language: str, tool_ids: Optional[List[int]] = None, sub_agent_ids: Optional[List[int]] = None): - for system_prompt in generate_and_save_system_prompt_impl( - agent_id=agent_id, - model_id=model_id, - task_description=task_description, - user_id=user_id, - tenant_id=tenant_id, - language=language, - tool_ids=tool_ids, - sub_agent_ids=sub_agent_ids - ): - # SSE format, each message ends with \n\n - yield f"data: {json.dumps({'success': True, 'data': system_prompt}, ensure_ascii=False)}\n\n" + try: + for system_prompt in generate_and_save_system_prompt_impl( + agent_id=agent_id, + model_id=model_id, + task_description=task_description, + user_id=user_id, + tenant_id=tenant_id, + language=language, + tool_ids=tool_ids, + sub_agent_ids=sub_agent_ids + ): + # SSE format, each message ends with \n\n + yield f"data: {json.dumps({'success': True, 'data': system_prompt}, ensure_ascii=False)}\n\n" + except Exception as e: + # Catch model unavailable or other errors and return error through SSE + logger.error(f"Error generating prompt: {e}") + # Use original error code if it's an AppException, otherwise use default + if isinstance(e, AppException): + error_code = e.error_code + error_message = e.message + else: + error_code = ErrorCode.MODEL_PROMPT_GENERATION_FAILED + error_message = ErrorMessage.get_message(error_code) + yield f"data: {json.dumps({'success': False, 'error': {'code': error_code.value, 'message': error_message}}, ensure_ascii=False)}\n\n" def generate_and_save_system_prompt_impl(agent_id: int, @@ -200,6 +214,14 @@ def generate_and_save_system_prompt_impl(agent_id: int, "Updating agent with business_description and prompt segments") logger.info("Prompt generation and agent update completed successfully") + # Check if any content was generated - if all fields are empty, model likely failed + all_fields = ["duty", "constraint", "few_shots", + "agent_var_name", "agent_display_name", "agent_description"] + has_content = any(final_results.get(field, "").strip() + for field in all_fields) + if not has_content: + raise Exception("Failed to generate prompt content.") + def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id: str, model_id: int, language: str = LANGUAGE["ZH"]): """Main function for generating system prompts""" @@ -222,15 +244,18 @@ def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list "agent_var_name": False, "agent_display_name": False, "agent_description": False} # Start all generation threads - threads = _start_generation_threads( + threads, error_holder = _start_generation_threads( content, prompt_for_generate, produce_queue, latest, stop_flags, tenant_id, model_id) # Stream results - yield from _stream_results(produce_queue, latest, stop_flags, threads) + yield from _stream_results(produce_queue, latest, stop_flags, threads, error_holder) def _start_generation_threads(content, prompt_for_generate, produce_queue, latest, stop_flags, tenant_id, model_id): """Start all prompt generation threads""" + # Shared error tracking across threads + error_holder = {"error": None} + def make_callback(tag): def callback_fn(current_text): latest[tag] = current_text @@ -243,6 +268,7 @@ def run_and_flag(tag, sys_prompt): model_id, content, sys_prompt, make_callback(tag), tenant_id) except Exception as e: logger.error(f"Error in {tag} generation: {e}") + error_holder["error"] = e finally: stop_flags[tag] = True @@ -266,10 +292,10 @@ def run_and_flag(tag, sys_prompt): thread.start() threads.append(thread) - return threads + return threads, error_holder -def _stream_results(produce_queue, latest, stop_flags, threads): +def _stream_results(produce_queue, latest, stop_flags, threads, error_holder): """Stream prompt generation results""" # Real-time streaming output for the first three sections @@ -277,6 +303,13 @@ def _stream_results(produce_queue, latest, stop_flags, threads): "agent_var_name": "", "agent_display_name": "", "agent_description": ""} while not all(stop_flags.values()): + # Check if error occurred in any thread - raise immediately + if error_holder.get("error"): + # Wait for threads to finish + for thread in threads: + thread.join(timeout=5) + raise error_holder["error"] + try: produce_queue.get(timeout=0.5) except queue.Empty: @@ -293,6 +326,10 @@ def _stream_results(produce_queue, latest, stop_flags, threads): yield result_data last_results[tag] = latest[tag] + # Check if error occurred before final output + if error_holder.get("error"): + raise error_holder["error"] + # Wait for all threads to complete for thread in threads: thread.join(timeout=5) diff --git a/backend/utils/llm_utils.py b/backend/utils/llm_utils.py index 0ede9a263..d1aa6fcf3 100644 --- a/backend/utils/llm_utils.py +++ b/backend/utils/llm_utils.py @@ -2,8 +2,9 @@ from typing import Callable, List, Optional from consts.const import MESSAGE_ROLE, THINK_END_PATTERN, THINK_START_PATTERN +from consts.error_code import ErrorCode +from consts.exceptions import AppException from database.model_management_db import get_model_by_model_id -from nexent.core.utils.observer import MessageObserver from nexent.core.models import OpenAIModel from utils.config_utils import get_model_name_from_config @@ -122,8 +123,23 @@ def call_llm_for_system_prompt( return result except Exception as exc: logger.error("Failed to generate prompt from LLM: %s", str(exc)) - raise + # Parse error code from exception message and raise appropriate AppException + # Use specific error codes for different scenarios + error_msg = str(exc) + if "401" in error_msg or "api key" in error_msg.lower() or "unauthorized" in error_msg.lower(): + raise AppException(ErrorCode.MODEL_API_KEY_INVALID) + elif "403" in error_msg or "forbidden" in error_msg.lower(): + raise AppException(ErrorCode.MODEL_API_KEY_NO_PERMISSION) + elif "404" in error_msg or "not found" in error_msg.lower(): + raise AppException(ErrorCode.MODEL_NOT_FOUND) + elif "429" in error_msg or "rate limit" in error_msg.lower(): + raise AppException(ErrorCode.MODEL_RATE_LIMIT_EXCEEDED) + elif "500" in error_msg or "502" in error_msg or "503" in error_msg or "504" in error_msg: + raise AppException(ErrorCode.MODEL_SERVICE_UNAVAILABLE) + elif "connection" in error_msg.lower() or "timeout" in error_msg.lower() or "refused" in error_msg.lower(): + raise AppException(ErrorCode.MODEL_CONNECTION_ERROR) + else: + raise AppException(ErrorCode.MODEL_PROMPT_GENERATION_FAILED) __all__ = ["call_llm_for_system_prompt", "_process_thinking_tokens"] - diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index 27129594d..d161503b8 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -546,7 +546,21 @@ export default function AgentGenerateDetail({ }, (error) => { log.error("Generate prompt stream error:", error); - message.error(t("businessLogic.config.message.generateError")); + // Try to get i18n translated message using error code, fallback to backend message or default + let errorMessage = t("businessLogic.config.message.generateError"); + if (error?.code) { + const i18nKey = `errorCode.${error.code}`; + const translated = t(i18nKey); + // Check if translation exists (i18next returns the key if not found) + if (translated !== i18nKey) { + errorMessage = translated; + } else if (error?.message) { + errorMessage = error.message; + } + } else if (error?.message) { + errorMessage = error.message; + } + message.error(errorMessage); setIsGenerating(false); }, () => { diff --git a/frontend/const/errorCode.ts b/frontend/const/errorCode.ts index 88b8ba0cb..9d1154dc6 100644 --- a/frontend/const/errorCode.ts +++ b/frontend/const/errorCode.ts @@ -118,6 +118,13 @@ export const ErrorCode = { MODEL_CONFIG_INVALID: "090102", MODEL_HEALTH_CHECK_FAILED: "090103", MODEL_PROVIDER_ERROR: "090104", + MODEL_PROMPT_GENERATION_FAILED: "090105", + // 02 - Model API errors + MODEL_API_KEY_INVALID: "090201", + MODEL_API_KEY_NO_PERMISSION: "090202", + MODEL_RATE_LIMIT_EXCEEDED: "090203", + MODEL_SERVICE_UNAVAILABLE: "090204", + MODEL_CONNECTION_ERROR: "090205", // ==================== 10 Memory / 记忆管理 ==================== // 01 - Memory diff --git a/frontend/const/errorMessage.ts b/frontend/const/errorMessage.ts index 90ae1c286..02026f9d4 100644 --- a/frontend/const/errorMessage.ts +++ b/frontend/const/errorMessage.ts @@ -105,6 +105,19 @@ export const DEFAULT_ERROR_MESSAGES: Record = { [ErrorCode.MODEL_CONFIG_INVALID]: "Model configuration is invalid.", [ErrorCode.MODEL_HEALTH_CHECK_FAILED]: "Model health check failed.", [ErrorCode.MODEL_PROVIDER_ERROR]: "Model provider error.", + [ErrorCode.MODEL_PROMPT_GENERATION_FAILED]: + "Model is unavailable. Please check the model status and try again.", + // 02 - Model API errors + [ErrorCode.MODEL_API_KEY_INVALID]: + "Model API key is invalid or expired. Please check your API key configuration.", + [ErrorCode.MODEL_API_KEY_NO_PERMISSION]: + "Model API key does not have permission. Please check your API key permissions.", + [ErrorCode.MODEL_RATE_LIMIT_EXCEEDED]: + "Rate limit exceeded. Please try again later.", + [ErrorCode.MODEL_SERVICE_UNAVAILABLE]: + "Model service is temporarily unavailable. Please try again later.", + [ErrorCode.MODEL_CONNECTION_ERROR]: + "Failed to connect to model service. Please check your network and model configuration.", // ==================== 10 Memory / 记忆管理 ==================== // 01 - Memory diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index e69033126..f4d81f30a 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -1851,67 +1851,73 @@ "errorCode.130206": "Failed to parse Dify response. Please check API URL.", "errorCode.130301": "Failed to connect to ME service.", - "errorCode.101": "Validation failed.", - "errorCode.102": "Invalid parameter.", - "errorCode.103": "Required field is missing.", - - "errorCode.201": "You are not authorized to perform this action.", - "errorCode.202": "Access forbidden.", - "errorCode.203": "Your session has expired. Please login again.", - "errorCode.204": "Invalid token. Please login again.", - - "errorCode.301": "External service error.", - "errorCode.302": "Too many requests. Please try again later.", - - "errorCode.401": "File not found.", - "errorCode.402": "Failed to upload file.", - "errorCode.403": "File size exceeds limit.", - "errorCode.404": "File type not allowed.", - "errorCode.405": "File preprocessing failed.", - - "errorCode.501": "Resource not found.", - "errorCode.502": "Resource already exists.", - "errorCode.503": "Resource is disabled.", - - "errorCode.10101": "Conversation not found.", - "errorCode.10102": "Message not found.", - "errorCode.10103": "Failed to save conversation.", - "errorCode.10104": "Failed to generate conversation title.", - - "errorCode.20101": "Invalid configuration.", - "errorCode.20102": "Sync configuration failed.", - - "errorCode.30101": "Agent not found.", - "errorCode.30102": "Agent is disabled.", - "errorCode.30103": "Failed to run agent. Please try again later.", - "errorCode.30104": "Agent name already exists.", - "errorCode.30105": "Agent version not found.", - - "errorCode.40101": "Agent not found in market.", - - "errorCode.50101": "Invalid agent configuration.", - "errorCode.50102": "Invalid prompt.", - - "errorCode.60101": "Knowledge base not found.", - "errorCode.60102": "Failed to upload knowledge.", - "errorCode.60103": "Failed to sync knowledge base.", - "errorCode.60104": "Search index not found.", - "errorCode.60105": "Knowledge search failed.", - - "errorCode.70101": "Tool not found.", - "errorCode.70102": "Tool execution failed.", - "errorCode.70103": "Tool configuration is invalid.", - "errorCode.70201": "Failed to connect to MCP service.", - "errorCode.70202": "MCP container operation failed.", - "errorCode.70301": "MCP name contains invalid characters.", - - "errorCode.80101": "Metric query failed.", - "errorCode.80201": "Invalid alert configuration.", - - "errorCode.90101": "Model not found.", - "errorCode.90102": "Model configuration is invalid.", - "errorCode.90103": "Model health check failed.", - "errorCode.90104": "Model provider error.", + "errorCode.000101": "Validation failed.", + "errorCode.000102": "Invalid parameter.", + "errorCode.000103": "Required field is missing.", + + "errorCode.000201": "You are not authorized to perform this action.", + "errorCode.000202": "Access forbidden.", + "errorCode.000203": "Your session has expired. Please login again.", + "errorCode.000204": "Invalid token. Please login again.", + + "errorCode.000301": "External service error.", + "errorCode.000302": "Too many requests. Please try again later.", + + "errorCode.000401": "File not found.", + "errorCode.000402": "Failed to upload file.", + "errorCode.000403": "File size exceeds limit.", + "errorCode.000404": "File type not allowed.", + "errorCode.000405": "File preprocessing failed.", + + "errorCode.000501": "Resource not found.", + "errorCode.000502": "Resource already exists.", + "errorCode.000503": "Resource is disabled.", + + "errorCode.010101": "Conversation not found.", + "errorCode.010102": "Message not found.", + "errorCode.010103": "Failed to save conversation.", + "errorCode.010104": "Failed to generate conversation title.", + + "errorCode.020101": "Invalid configuration.", + "errorCode.020102": "Sync configuration failed.", + + "errorCode.030101": "Agent not found.", + "errorCode.030102": "Agent is disabled.", + "errorCode.030103": "Failed to run agent. Please try again later.", + "errorCode.030104": "Agent name already exists.", + "errorCode.030105": "Agent version not found.", + + "errorCode.040101": "Agent not found in market.", + + "errorCode.050101": "Invalid agent configuration.", + "errorCode.050102": "Invalid prompt.", + + "errorCode.060101": "Knowledge base not found.", + "errorCode.060102": "Failed to upload knowledge.", + "errorCode.060103": "Failed to sync knowledge base.", + "errorCode.060104": "Search index not found.", + "errorCode.060105": "Knowledge search failed.", + + "errorCode.070101": "Tool not found.", + "errorCode.070102": "Tool execution failed.", + "errorCode.070103": "Tool configuration is invalid.", + "errorCode.070201": "Failed to connect to MCP service.", + "errorCode.070202": "MCP container operation failed.", + "errorCode.070301": "MCP name contains invalid characters.", + + "errorCode.080101": "Metric query failed.", + "errorCode.080201": "Invalid alert configuration.", + + "errorCode.090101": "Model not found.", + "errorCode.090102": "Model configuration is invalid.", + "errorCode.090103": "Model health check failed.", + "errorCode.090104": "Model provider error.", + "errorCode.090105": "Model is unavailable. Please check the model status and try again.", + "errorCode.090201": "Model API key is invalid or expired. Please check your API key configuration.", + "errorCode.090202": "Model API key does not have permission. Please check your API key permissions.", + "errorCode.090203": "Rate limit exceeded. Please try again later.", + "errorCode.090204": "Model service is temporarily unavailable. Please try again later.", + "errorCode.090205": "Failed to connect to model service. Please check your network and model configuration.", "errorCode.100101": "Memory not found.", "errorCode.100102": "Failed to prepare memory.", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 7b16eea84..65a7553b4 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -1868,67 +1868,73 @@ "errorCode.130206": "Dify响应解析失败,请检查API URL", "errorCode.130301": "连接ME服务失败", - "errorCode.101": "验证失败", - "errorCode.102": "参数无效", - "errorCode.103": "缺少必填字段", - - "errorCode.201": "您没有执行此操作的权限", - "errorCode.202": "禁止访问", - "errorCode.203": "您的登录已过期,请重新登录", - "errorCode.204": "登录令牌无效,请重新登录", - - "errorCode.301": "外部服务错误", - "errorCode.302": "请求过于频繁,请稍后重试", - - "errorCode.401": "文件不存在", - "errorCode.402": "文件上传失败", - "errorCode.403": "文件大小超出限制", - "errorCode.404": "不支持的文件类型", - "errorCode.405": "文件预处理失败", - - "errorCode.501": "资源不存在", - "errorCode.502": "资源已存在", - "errorCode.503": "资源已被禁用", - - "errorCode.10101": "对话不存在", - "errorCode.10102": "消息不存在", - "errorCode.10103": "保存对话失败", - "errorCode.10104": "生成对话标题失败", - - "errorCode.20101": "配置无效", - "errorCode.20102": "同步配置失败", - - "errorCode.30101": "智能体不存在", - "errorCode.30102": "智能体已被禁用", - "errorCode.30103": "运行智能体失败,请稍后重试", - "errorCode.30104": "智能体名称已存在", - "errorCode.30105": "智能体版本不存在", - - "errorCode.40101": "市场中智能体不存在", - - "errorCode.50101": "智能体配置无效", - "errorCode.50102": "提示词无效", - - "errorCode.60101": "知识库不存在", - "errorCode.60102": "上传知识失败", - "errorCode.60103": "同步知识库失败", - "errorCode.60104": "搜索索引不存在", - "errorCode.60105": "知识搜索失败", - - "errorCode.70101": "工具不存在", - "errorCode.70102": "工具执行失败", - "errorCode.70103": "工具配置无效", - "errorCode.70201": "连接MCP服务失败", - "errorCode.70202": "MCP容器操作失败", - "errorCode.70301": "MCP名称包含非法字符", - - "errorCode.80101": "指标查询失败", - "errorCode.80201": "告警配置无效", - - "errorCode.90101": "模型不存在", - "errorCode.90102": "模型配置无效", - "errorCode.90103": "模型健康检查失败", - "errorCode.90104": "模型提供商错误", + "errorCode.000101": "验证失败", + "errorCode.000102": "参数无效", + "errorCode.000103": "缺少必填字段", + + "errorCode.000201": "您没有执行此操作的权限", + "errorCode.000202": "禁止访问", + "errorCode.000203": "您的登录已过期,请重新登录", + "errorCode.000204": "登录令牌无效,请重新登录", + + "errorCode.000301": "外部服务错误", + "errorCode.000302": "请求过于频繁,请稍后重试", + + "errorCode.000401": "文件不存在", + "errorCode.000402": "文件上传失败", + "errorCode.000403": "文件大小超出限制", + "errorCode.000404": "不支持的文件类型", + "errorCode.000405": "文件预处理失败", + + "errorCode.000501": "资源不存在", + "errorCode.000502": "资源已存在", + "errorCode.000503": "资源已被禁用", + + "errorCode.010101": "对话不存在", + "errorCode.010102": "消息不存在", + "errorCode.010103": "保存对话失败", + "errorCode.010104": "生成对话标题失败", + + "errorCode.020101": "配置无效", + "errorCode.020102": "同步配置失败", + + "errorCode.030101": "智能体不存在", + "errorCode.030102": "智能体已被禁用", + "errorCode.030103": "运行智能体失败,请稍后重试", + "errorCode.030104": "智能体名称已存在", + "errorCode.030105": "智能体版本不存在", + + "errorCode.040101": "市场中智能体不存在", + + "errorCode.050101": "智能体配置无效", + "errorCode.050102": "提示词无效", + + "errorCode.060101": "知识库不存在", + "errorCode.060102": "上传知识失败", + "errorCode.060103": "同步知识库失败", + "errorCode.060104": "搜索索引不存在", + "errorCode.060105": "知识搜索失败", + + "errorCode.070101": "工具不存在", + "errorCode.070102": "工具执行失败", + "errorCode.070103": "工具配置无效", + "errorCode.070201": "连接MCP服务失败", + "errorCode.070202": "MCP容器操作失败", + "errorCode.070301": "MCP名称包含非法字符", + + "errorCode.080101": "指标查询失败", + "errorCode.080201": "告警配置无效", + + "errorCode.090101": "模型不存在", + "errorCode.090102": "模型配置无效", + "errorCode.090103": "模型健康检查失败", + "errorCode.090104": "模型提供商错误", + "errorCode.090105": "模型不可用,请检查模型状态后重试", + "errorCode.090201": "模型 API 密钥无效或已过期,请检查 API 密钥配置", + "errorCode.090202": "模型 API 密钥没有权限,请检查 API 密钥权限", + "errorCode.090203": "请求频率超限,请稍后重试", + "errorCode.090204": "模型服务暂时不可用,请稍后重试", + "errorCode.090205": "连接模型服务失败,请检查网络和模型配置", "errorCode.100101": "记忆不存在", "errorCode.100102": "准备记忆失败", diff --git a/frontend/services/promptService.ts b/frontend/services/promptService.ts index 8d066556f..3b6c49395 100644 --- a/frontend/services/promptService.ts +++ b/frontend/services/promptService.ts @@ -30,6 +30,7 @@ export const generatePromptStream = async ( const reader = response.body.getReader(); const decoder = new TextDecoder('utf-8'); let buffer = ''; + let hasError = false; while (true) { const { value, done } = await reader.read(); @@ -44,6 +45,10 @@ export const generatePromptStream = async ( const json = JSON.parse(line.replace('data: ', '')); if (json.success) { onData(json.data); + } else if (json.success === false && json.error) { + // Handle error response from backend + hasError = true; + if (onError) onError(json.error); } } catch (e) { if (onError) onError(e); @@ -51,7 +56,8 @@ export const generatePromptStream = async ( } } } - if (onComplete) onComplete(); + // Only call onComplete if no error occurred + if (!hasError && onComplete) onComplete(); } catch (err) { if (onError) onError(err); if (onComplete) onComplete(); From 516cc950ef669b3b83e8e77426d44163a3c17603 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Tue, 3 Mar 2026 17:05:01 +0800 Subject: [PATCH 02/75] implement DashScope and TokenPony model providers --- backend/consts/provider.py | 10 ++ backend/services/model_management_service.py | 6 +- backend/services/model_provider_service.py | 8 ++ .../services/providers/dashscope_provider.py | 131 +++++++++++++++++ .../services/providers/tokenpony_provider.py | 120 ++++++++++++++++ .../components/model/ModelAddDialog.tsx | 34 ++++- .../models/components/model/ModelListCard.tsx | 116 ++++++++++++++- frontend/const/modelConfig.ts | 6 + frontend/hooks/model/useDashscopeModelList.ts | 133 ++++++++++++++++++ frontend/hooks/model/useTokenponyModelList.ts | 133 ++++++++++++++++++ frontend/package.json | 1 + frontend/public/locales/en/common.json | 6 + frontend/public/locales/zh/common.json | 6 + frontend/public/tokenpony.png | Bin 0 -> 1296 bytes frontend/types/modelConfig.ts | 2 + 15 files changed, 702 insertions(+), 10 deletions(-) create mode 100644 backend/services/providers/dashscope_provider.py create mode 100644 backend/services/providers/tokenpony_provider.py create mode 100644 frontend/hooks/model/useDashscopeModelList.ts create mode 100644 frontend/hooks/model/useTokenponyModelList.ts create mode 100644 frontend/public/tokenpony.png diff --git a/backend/consts/provider.py b/backend/consts/provider.py index 7fd783015..e2a0f0235 100644 --- a/backend/consts/provider.py +++ b/backend/consts/provider.py @@ -6,11 +6,21 @@ class ProviderEnum(str, Enum): SILICON = "silicon" OPENAI = "openai" MODELENGINE = "modelengine" + DASHSCOPE = "dashscope" + TOKENPONY = "tokenpony" # Silicon Flow SILICON_BASE_URL = "https://api.siliconflow.cn/v1/" SILICON_GET_URL = "https://api.siliconflow.cn/v1/models" +# Dashcope +DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" +DASHSCOPE_GET_URL = "https://dashscope.aliyuncs.com/api/v1/models" + +# TokenPony +TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1" +TOKENPONY_GET_URL = "https://api.tokenpony.cn/v1/models" + # ModelEngine # Base URL and API key are loaded from environment variables at runtime diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 4b8265028..a18c16c36 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -3,7 +3,7 @@ from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST from consts.model import ModelConnectStatusEnum -from consts.provider import ProviderEnum, SILICON_BASE_URL +from consts.provider import ProviderEnum, SILICON_BASE_URL, DASHSCOPE_BASE_URL, TOKENPONY_BASE_URL from database.model_management_db import ( create_model_record, @@ -142,6 +142,10 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay elif provider == ProviderEnum.MODELENGINE.value: # ModelEngine models carry their own base_url in each model dict model_url = "" + elif provider == ProviderEnum.DASHSCOPE.value: + model_url = DASHSCOPE_BASE_URL + elif provider == ProviderEnum.TOKENPONY.value: + model_url = TOKENPONY_BASE_URL else: model_url = "" diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index a302eb999..3c916eb8c 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -11,6 +11,8 @@ from services.model_health_service import embedding_dimension_check from services.providers.base import AbstractModelProvider from services.providers.silicon_provider import SiliconModelProvider +from services.providers.tokenpony_provider import TokenPonyModelProvider +from services.providers.dashscope_provider import DashScopeModelProvider from services.providers.modelengine_provider import ModelEngineProvider, get_model_engine_raw_url, MODEL_ENGINE_NORTH_PREFIX from utils.model_name_utils import split_repo_name, add_repo_to_name @@ -40,6 +42,12 @@ async def get_provider_models(model_data: dict) -> List[dict]: elif model_data["provider"] == ProviderEnum.MODELENGINE.value: provider = ModelEngineProvider() model_list = await provider.get_models(model_data) + elif model_data["provider"] == ProviderEnum.DASHSCOPE.value: + provider = DashScopeModelProvider() + model_list = await provider.get_models(model_data) + elif model_data["provider"] == ProviderEnum.TOKENPONY.value: + provider = TokenPonyModelProvider() + model_list = await provider.get_models(model_data) return model_list diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py new file mode 100644 index 000000000..2a34823ed --- /dev/null +++ b/backend/services/providers/dashscope_provider.py @@ -0,0 +1,131 @@ +import httpx +from typing import Dict, List +import asyncio +from consts.const import DEFAULT_LLM_MAX_TOKENS +from consts.provider import DASHSCOPE_GET_URL +from services.providers.base import AbstractModelProvider, _classify_provider_error + + +class DashScopeModelProvider(AbstractModelProvider): + """Concrete implementation for DashScope (Aliyun) provider.""" + + async def get_models(self, provider_config: Dict) -> List[Dict]: + """ + Fetch models from DashScope API, categorize them, and return + the requested model type. + + Args: + provider_config: Configuration dict containing model_type and api_key + + Returns: + List of models with canonical fields. Returns error dict if API call fails. + """ + try: + target_model_type: str = provider_config["model_type"] + model_api_key: str = provider_config["api_key"] + + headers = {"Authorization": f"Bearer {model_api_key}"} + base_url = DASHSCOPE_GET_URL + + all_models: List[Dict] = [] + current_page = 1 + + # Fetch all models with pagination asynchronously + async with httpx.AsyncClient(verify=False) as client: + while True: + params = {"page_size": 100, "page_no": current_page} + response = await client.get(base_url, headers=headers, params=params) + response.raise_for_status() + + data = response.json() + models = data.get("output", {}).get("models", []) + + if response.status_code == 429: + await asyncio.sleep(2) + continue + if not models : # Break loop if no more models on the current page + break + + all_models.extend(models) + if(len(models)<100): + break + current_page += 1 + await asyncio.sleep(0.5) + + # Initialize containers for the 6 main categories + categorized_models = { + "chat": [], # Maps to "llm" + "vlm": [], # Maps to "vlm" + "embedding": [], # Maps to "embedding" / "multi_embedding" + "reranker": [], # Maps to "reranker" + "tts": [], # Maps to "tts" + "stt": [] # Maps to "stt" + } + + # Classify models and inject canonical fields expected downstream + for model_obj in all_models: + # Extract key fields for logical determination (lowercased for robustness) + m_id = model_obj.get('model', '').lower() + desc = model_obj.get('description', '') + metadata = model_obj.get('inference_metadata', {}) + req_mod = metadata.get('request_modality', []) + res_mod = metadata.get('response_modality', []) + model_obj.setdefault("object", model_obj.get("object", "model")) + model_obj.setdefault("owned_by", model_obj.get("owned_by", "dashscope")) + cleaned_model = { + "id": m_id, + "object": model_obj.get("object"), + "created": 0, + "owned_by": model_obj.get("owned_by"), + "model_tag": "", + "model_type": "", + "max_tokens": DEFAULT_LLM_MAX_TOKENS + } + # 1. Embedding + if 'embedding' in m_id.lower() or '向量' in desc: + cleaned_model.update({"model_tag": "embedding", "model_type": "embedding"}) + categorized_models['embedding'].append(cleaned_model) + continue + + # 2. Reranker + if 'rerank' in m_id.lower() or '重排序' in desc: + cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) + categorized_models['reranker'].append(cleaned_model) + continue + + # 3. STT + if 'Audio' in req_mod and 'Text' in res_mod: + cleaned_model.update({"model_tag": "stt", "model_type": "stt"}) + categorized_models['stt'].append(cleaned_model) + continue + + # 4. TTS + if 'Audio' in res_mod and 'Video' not in res_mod: + cleaned_model.update({"model_tag": "tts", "model_type": "tts"}) + categorized_models['tts'].append(cleaned_model) + continue + + # 5. VLM + vision_mods = {'Image', 'Video'} + if (set(req_mod) & vision_mods) or (set(res_mod) & vision_mods) or '视觉' in desc: + cleaned_model.update({"model_tag": "chat", "model_type": "vlm"}) + categorized_models['vlm'].append(cleaned_model) + continue + + # 6. Chat / LLM + if 'Text' in req_mod or 'Text' in res_mod: + cleaned_model.update({"model_tag": "chat", "model_type": "llm"}) + categorized_models['chat'].append(cleaned_model) + + # Return the specific list based on the requested target_model_type + if target_model_type == "llm": + return categorized_models["chat"] + elif target_model_type in ("embedding", "multi_embedding"): + return categorized_models["embedding"] + elif target_model_type in categorized_models: + return categorized_models[target_model_type] + else: + return [] + except (httpx.HTTPStatusError, httpx.ConnectTimeout, httpx.ConnectError, Exception) as e: + return _classify_provider_error("DashScope", exception=e) + diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py new file mode 100644 index 000000000..62972b698 --- /dev/null +++ b/backend/services/providers/tokenpony_provider.py @@ -0,0 +1,120 @@ +import httpx +import ssl + +from typing import Dict, List + + +from consts.const import DEFAULT_LLM_MAX_TOKENS +from consts.provider import TOKENPONY_GET_URL +from services.providers.base import AbstractModelProvider, _classify_provider_error + + +class TokenPonyModelProvider(AbstractModelProvider): + """Concrete implementation for TokenPony provider.""" + + async def get_models(self, provider_config: Dict) -> List[Dict]: + """ + Fetch models from TokenPony API, categorize them based on modality/ID, + and return the requested model type. + + Args: + provider_config: Configuration dict containing model_type and api_key + + Returns: + List of models with canonical fields. Returns error dict if API call fails. + """ + try: + target_model_type: str = provider_config["model_type"] + model_api_key: str = provider_config["api_key"] + + headers = {"Authorization": f"Bearer {model_api_key}"} + url = TOKENPONY_GET_URL + + + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + ssl_context.set_ciphers("DEFAULT@SECLEVEL=1") + # response = requests.get(url, headers=headers) + # all_models=[] + # if response.status_code == 200: + # data = response.json() + # # 注意:OpenAI 标准返回是在 "data" 字段下 + # all_models=data.get("data", []) + # Fetch all models asynchronously + async with httpx.AsyncClient(http2=True) as client: + response = await client.get(url, headers=headers) + response.raise_for_status() + # OpenAI standard response puts the model list inside the "data" array + all_models: List[Dict] = response.json().get("data", []) + + # Initialize containers for the 6 main categories + categorized_models = { + "chat": [], # Maps to "llm" + "vlm": [], # Maps to "vlm" + "embedding": [], # Maps to "embedding" / "multi_embedding" + "reranker": [], # Maps to "reranker" + "tts": [], # Maps to "tts" + "stt": [] # Maps to "stt" + } + + # Classify models and inject canonical fields expected downstream + for model_obj in all_models: + m_id = model_obj['id'].lower() + model_obj.setdefault("object", model_obj.get("object", "model")) + model_obj.setdefault("owned_by", model_obj.get("owned_by", "tokenpony")) + cleaned_model = { + "id": m_id, + "object": model_obj.get("object"), + "created": 0, + "owned_by": model_obj.get("owned_by"), + "model_tag": "", + "model_type": "", + "max_tokens": DEFAULT_LLM_MAX_TOKENS + } + # 1. Embedding + if 'embedding' in m_id or m_id.startswith('bge-'): + cleaned_model.update({"model_tag": "embedding", "model_type": "embedding", "max_tokens": 0}) + categorized_models['embedding'].append(cleaned_model) + + # 2. Reranker + elif 'rerank' in m_id: + cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) + categorized_models['reranker'].append(cleaned_model) + + + # 3. STT (Speech-to-Text / Audio understanding) + elif 'stt' in m_id: + cleaned_model.update({"model_tag": "stt", "model_type": "stt"}) + categorized_models['stt'].append(cleaned_model) + + + # 4. TTS (Text-to-Speech) + elif 'tts' in m_id: + cleaned_model.update({"model_tag": "tts", "model_type": "tts"}) + categorized_models['tts'].append(cleaned_model) + + # 5. VLM (Vision Language Model / Image & Video Generation) + + elif any(keyword in m_id for keyword in ['-vl', 'vl-', 'ocr', 'vision']): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm"}) + categorized_models['vlm'].append(cleaned_model) + + # 6. Chat (Pure Text Conversation / Reasoning) + # Fallback check added: 'not metadata' catches standard OpenAI models that lack modality data + else : + cleaned_model.update({"model_tag": "chat", "model_type": "llm"}) + categorized_models['chat'].append(cleaned_model) + + # Return the specific list based on the requested target_model_type + if target_model_type == "llm": + return categorized_models["chat"] + elif target_model_type in ("embedding", "multi_embedding"): + return categorized_models["embedding"] + elif target_model_type in categorized_models: + return categorized_models[target_model_type] + else: + return [] + + except (httpx.HTTPStatusError, httpx.ConnectTimeout, httpx.ConnectError, Exception) as e: + return _classify_provider_error("TokenPony", exception=e) diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 2df9643a9..cd258abc8 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -16,6 +16,8 @@ import { modelService } from "@/services/modelService"; import { ModelType, SingleModelConfig } from "@/types/modelConfig"; import { MODEL_TYPES, PROVIDER_LINKS } from "@/const/modelConfig"; import { useSiliconModelList } from "@/hooks/model/useSiliconModelList"; +import { useDashscopeModelList } from "@/hooks/model/useDashscopeModelList"; +import { useTokenPonyModelList } from "@/hooks/model/useTokenponyModelList"; import log from "@/lib/logger"; import { ModelChunkSizeSlider, @@ -248,7 +250,7 @@ export const ModelAddDialog = ({ const [modelMaxTokens, setModelMaxTokens] = useState("4096"); // Use the silicon model list hook - const { getModelList, getProviderSelectedModalList } = useSiliconModelList({ + const siliconHook = useSiliconModelList({ form, setModelList, setSelectedModelIds, @@ -256,7 +258,33 @@ export const ModelAddDialog = ({ setLoadingModelList, tenantId, }); - + const dashscopeHook = useDashscopeModelList({ + form, + setModelList, + setSelectedModelIds, + setShowModelList, + setLoadingModelList, + tenantId, + }); + const tokenponyHook = useTokenPonyModelList({ + form, + setModelList, + setSelectedModelIds, + setShowModelList, + setLoadingModelList, + tenantId, + }); + let getModelList; + let getProviderSelectedModalList; + +// 2. 根据条件赋值 + if (form.provider === "silicon") { + ({ getModelList, getProviderSelectedModalList } = siliconHook); + } else if (form.provider === "dashscope") { + ({ getModelList, getProviderSelectedModalList } = dashscopeHook); + } else if (form.provider === "tokenpony") { + ({ getModelList, getProviderSelectedModalList } = tokenponyHook); + } // Reset form to default state const resetForm = useCallback(() => { setForm(DEFAULT_FORM_STATE); @@ -794,6 +822,8 @@ export const ModelAddDialog = ({ {t("model.provider.modelengine")} + + {/* ModelEngine URL input (only when provider is ModelEngine) */} {form.provider === "modelengine" && ( diff --git a/frontend/app/[locale]/models/components/model/ModelListCard.tsx b/frontend/app/[locale]/models/components/model/ModelListCard.tsx index ae966ae35..8bf6e00a6 100644 --- a/frontend/app/[locale]/models/components/model/ModelListCard.tsx +++ b/frontend/app/[locale]/models/components/model/ModelListCard.tsx @@ -33,12 +33,12 @@ const PULSE_ANIMATION = ` transform: scale(0.95); box-shadow: 0 0 0 0 rgba(41, 128, 185, 0.7); } - + 70% { transform: scale(1); box-shadow: 0 0 0 5px rgba(41, 128, 185, 0); } - + 100% { transform: scale(0.95); box-shadow: 0 0 0 0 rgba(41, 128, 185, 0); @@ -162,27 +162,33 @@ export const ModelListCard = ({ const model = modelsData.find( (m) => m.type === type && m.displayName === displayName ); - + if (!model) return t("model.source.unknown"); - + // Return source label based on model.source if (model.source === "modelengine") { return t("model.source.modelEngine"); } else if (model.source === "silicon") { return t("model.source.silicon"); + } else if (model.source==="dashscope"){ + return t("model.source.dashscope"); + }else if (model.source==="tokenpony"){ + return t("model.source.tokenpony"); } else if (model.source === "OpenAI-API-Compatible") { return t("model.source.custom"); } - + return t("model.source.unknown"); }; const filteredModels = getFilteredModels(); - + // Group models by source for display const groupedModels = { modelengine: filteredModels.filter((m) => m.source === "modelengine"), silicon: filteredModels.filter((m) => m.source === "silicon"), + dashscope: filteredModels.filter((m) => m.source === "dashscope"), + tokenpony: filteredModels.filter((m) => m.source === "tokenpony"), custom: filteredModels.filter((m) => m.source === "OpenAI-API-Compatible"), }; @@ -343,6 +349,102 @@ export const ModelListCard = ({ ))} )} + {groupedModels.dashscope.length > 0 && ( + + {groupedModels.dashscope.map((model) => ( + + ))} + + )} + {groupedModels.tokenpony.length > 0 && ( + + {groupedModels.tokenpony.map((model) => ( + + ))} + + )} {groupedModels.custom.length > 0 && ( {groupedModels.custom.map((model) => ( @@ -394,4 +496,4 @@ export const ModelListCard = ({ ); -}; \ No newline at end of file +}; diff --git a/frontend/const/modelConfig.ts b/frontend/const/modelConfig.ts index ce7f1841d..9b0128529 100644 --- a/frontend/const/modelConfig.ts +++ b/frontend/const/modelConfig.ts @@ -40,6 +40,8 @@ export const MODEL_PROVIDER_KEYS = [ "jina", "deepseek", "aliyuncs", + "tokenpony", + "dashscope", ] as const; export type ModelProviderKey = (typeof MODEL_PROVIDER_KEYS)[number]; @@ -52,6 +54,8 @@ export const PROVIDER_HINTS: Record = { jina: "jina", deepseek: "deepseek", aliyuncs: "aliyuncs", + tokenpony: "tokenpony", + dashscope: "dashscope", }; // Icon filenames for providers @@ -62,6 +66,8 @@ export const PROVIDER_ICON_MAP: Record = { jina: "/jina.png", deepseek: "/deepseek.png", aliyuncs: "/aliyuncs.png", + dashscope:"/aliyuncs.png", + tokenpony: "/tokenpony.png", }; export const OFFICIAL_PROVIDER_ICON = "/modelengine-logo.png"; diff --git a/frontend/hooks/model/useDashscopeModelList.ts b/frontend/hooks/model/useDashscopeModelList.ts new file mode 100644 index 000000000..b44348fe5 --- /dev/null +++ b/frontend/hooks/model/useDashscopeModelList.ts @@ -0,0 +1,133 @@ +import { useEffect } from "react"; +import { message } from "antd"; +import { useTranslation } from "react-i18next"; +import { modelService } from "@/services/modelService"; +import { ModelType } from "@/types/modelConfig"; +import { processProviderResponse } from "@/lib/providerError"; +import log from "@/lib/logger"; + +interface UseDashscopeModelListProps { + form: { + type: ModelType; + isBatchImport: boolean; + apiKey: string; + provider: string; // Expected to be "dashscope" + maxTokens: string; + isMultimodal: boolean; + }; + setModelList: (models: any[]) => void; + setSelectedModelIds: (ids: Set) => void; + setShowModelList: (show: boolean) => void; + setLoadingModelList: (loading: boolean) => void; + tenantId?: string; // Optional tenant ID for manage operations +} + +export const useDashscopeModelList = ({ + form, + setModelList, + setSelectedModelIds, + setShowModelList, + setLoadingModelList, + tenantId, +}: UseDashscopeModelListProps) => { + const { t } = useTranslation(); + + const getModelList = async () => { + setShowModelList(true); + setLoadingModelList(true); + + const modelType = + form.type === "embedding" && form.isMultimodal + ? ("multi_embedding" as ModelType) + : form.type; + + try { + // Use manage interface if tenantId is provided (for super admin) + const result = tenantId + ? await modelService.addManageProviderModel({ + tenantId, + provider: form.provider, + type: modelType, + apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }) + : await modelService.addProviderModel({ + provider: form.provider, + type: modelType, + apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }); + + // Use centralized error processing + const { models, error } = processProviderResponse( + result, + form.provider, + t + ); + + if (error) { + message.error(error); + setModelList([]); + setSelectedModelIds(new Set()); + setLoadingModelList(false); + return; + } + + // Ensure each model has a default max_tokens value + const modelsWithDefaults = models.map((model: any) => ({ + ...model, + max_tokens: model.max_tokens || parseInt(form.maxTokens) || 4096, + })); + setModelList(modelsWithDefaults); + + const selectedModels = (await getProviderSelectedModalList()) || []; + + // Key logic: Sync previously selected models + if (!selectedModels.length) { + // Select none + setSelectedModelIds(new Set()); + } else { + // Only select selectedModels + setSelectedModelIds(new Set(selectedModels.map((m: any) => m.id))); + } + } catch (error) { + message.error(t("model.dialog.error.addFailed", { error })); + log.error(t("model.dialog.error.addFailedLog"), error); + } finally { + setLoadingModelList(false); + } + }; + + const getProviderSelectedModalList = async () => { + const modelType = + form.type === "embedding" && form.isMultimodal + ? ("multi_embedding" as ModelType) + : form.type; + + // Use manage interface if tenantId is provided (for super admin) + const result = tenantId + ? await modelService.getManageProviderSelectedModalList({ + tenantId, + provider: form.provider, + type: modelType, + }) + : await modelService.getProviderSelectedModalList({ + provider: form.provider, + type: modelType, + api_key: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }); + + return result; + }; + + // Auto-fetch model list when batch import is enabled and API key is provided + useEffect(() => { + if (form.isBatchImport && form.apiKey.trim() !== "") { + getModelList(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [form.type, form.isBatchImport]); + + return { + getModelList, + getProviderSelectedModalList, + }; +}; diff --git a/frontend/hooks/model/useTokenponyModelList.ts b/frontend/hooks/model/useTokenponyModelList.ts new file mode 100644 index 000000000..0a7e23581 --- /dev/null +++ b/frontend/hooks/model/useTokenponyModelList.ts @@ -0,0 +1,133 @@ +import { useEffect } from "react"; +import { message } from "antd"; +import { useTranslation } from "react-i18next"; +import { modelService } from "@/services/modelService"; +import { ModelType } from "@/types/modelConfig"; +import { processProviderResponse } from "@/lib/providerError"; +import log from "@/lib/logger"; + +interface UseTokenPonyModelListProps { + form: { + type: ModelType; + isBatchImport: boolean; + apiKey: string; + provider: string; // Expected to be "tokenpony" + maxTokens: string; + isMultimodal: boolean; + }; + setModelList: (models: any[]) => void; + setSelectedModelIds: (ids: Set) => void; + setShowModelList: (show: boolean) => void; + setLoadingModelList: (loading: boolean) => void; + tenantId?: string; // Optional tenant ID for manage operations +} + +export const useTokenPonyModelList = ({ + form, + setModelList, + setSelectedModelIds, + setShowModelList, + setLoadingModelList, + tenantId, +}: UseTokenPonyModelListProps) => { + const { t } = useTranslation(); + + const getModelList = async () => { + setShowModelList(true); + setLoadingModelList(true); + + const modelType = + form.type === "embedding" && form.isMultimodal + ? ("multi_embedding" as ModelType) + : form.type; + + try { + // Use manage interface if tenantId is provided (for super admin) + const result = tenantId + ? await modelService.addManageProviderModel({ + tenantId, + provider: form.provider, + type: modelType, + apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }) + : await modelService.addProviderModel({ + provider: form.provider, + type: modelType, + apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }); + + // Use centralized error processing + const { models, error } = processProviderResponse( + result, + form.provider, + t + ); + + if (error) { + message.error(error); + setModelList([]); + setSelectedModelIds(new Set()); + setLoadingModelList(false); + return; + } + + // Ensure each model has a default max_tokens value + const modelsWithDefaults = models.map((model: any) => ({ + ...model, + max_tokens: model.max_tokens || parseInt(form.maxTokens) || 4096, + })); + setModelList(modelsWithDefaults); + + const selectedModels = (await getProviderSelectedModalList()) || []; + + // Key logic: Sync previously selected models + if (!selectedModels.length) { + // Select none + setSelectedModelIds(new Set()); + } else { + // Only select selectedModels + setSelectedModelIds(new Set(selectedModels.map((m: any) => m.id))); + } + } catch (error) { + message.error(t("model.dialog.error.addFailed", { error })); + log.error(t("model.dialog.error.addFailedLog"), error); + } finally { + setLoadingModelList(false); + } + }; + + const getProviderSelectedModalList = async () => { + const modelType = + form.type === "embedding" && form.isMultimodal + ? ("multi_embedding" as ModelType) + : form.type; + + // Use manage interface if tenantId is provided (for super admin) + const result = tenantId + ? await modelService.getManageProviderSelectedModalList({ + tenantId, + provider: form.provider, + type: modelType, + }) + : await modelService.getProviderSelectedModalList({ + provider: form.provider, + type: modelType, + api_key: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }); + + return result; + }; + + // Auto-fetch model list when batch import is enabled and API key is provided + useEffect(() => { + if (form.isBatchImport && form.apiKey.trim() !== "") { + getModelList(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [form.type, form.isBatchImport]); + + return { + getModelList, + getProviderSelectedModalList, + }; +}; diff --git a/frontend/package.json b/frontend/package.json index ba8ce8a67..db2e48756 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -25,6 +25,7 @@ "bootstrap-icons": "^1.11.3", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", + "cross-env": "^10.1.0", "dayjs": "^1.11.19", "dicebear": "^9.2.2", "dotenv": "^16.4.7", diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 775eae675..986140c83 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -674,6 +674,8 @@ "model.dialog.hint.batchImportEnabled": "Batch add enabled. Multiple models will be added at once.", "model.dialog.hint.batchImportDisabled": "Batch add disabled. Only a single model will be added.", "model.provider.silicon": "SiliconFlow", + "model.provider.dashscope": "DashScope", + "model.provider.tokenpony": "TokenPony", "model.provider.modelengine": "ModelEngine", "model.dialog.modelList.title": "Show Models", "model.dialog.modelList.searchPlaceholder": "Search models by name", @@ -746,12 +748,16 @@ "model.source.modelEngine": "ModelEngine", "model.source.openai": "OpenAI", "model.source.silicon": "Silicon Flow", + "model.source.dashscope": "DashScope", + "model.source.tokenpony": "TokenPony", "model.source.unknown": "Unknown Source", "model.warning.updateNotFound": "Model not found for update: {{displayName}}, type: {{type}}", "model.type.main": "LLM Model", "model.select.placeholder": "Select Model", "model.group.modelEngine": "ModelEngine Models", "model.group.silicon": "Silicon Flow Models", + "model.group.dashscope": "DashScope Models", + "model.group.tokenpony": "TokenPony Models", "model.group.custom": "Custom Models", "model.status.tooltip": "Click to verify connectivity", "model.dialog.embeddingConfig.title": "Edit Embedding Model: {{modelName}}", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 88ef18fdc..b830b1792 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -676,6 +676,8 @@ "model.dialog.hint.batchImportEnabled": "批量添加模式已启用,可通过API Key一次性导入多个模型", "model.dialog.hint.batchImportDisabled": "批量添加模式已关闭,仅添加单个模型", "model.provider.silicon": "硅基流动", + "model.provider.dashscope": "阿里灵积", + "model.provider.tokenpony": "小马算力", "model.provider.modelengine": "ModelEngine", "model.dialog.modelList.title": "显示模型", "model.dialog.modelList.searchPlaceholder": "按名称搜索模型", @@ -748,11 +750,15 @@ "model.source.unknown": "未知来源", "model.source.openai": "OpenAI", "model.source.silicon": "硅基流动", + "model.source.dashscope": "阿里灵积", + "model.source.tokenpony": "小马算力", "model.warning.updateNotFound": "未找到要更新的模型: {{displayName}}, 类型: {{type}}", "model.type.main": "大语言模型", "model.select.placeholder": "选择模型", "model.group.modelEngine": "ModelEngine模型", "model.group.silicon": "硅基流动模型", + "model.group.dashscope": "阿里灵积模型", + "model.group.tokenpony": "小马算力模型", "model.group.custom": "自定义模型", "model.status.tooltip": "点击可验证连通性", "model.dialog.success.updateSuccess": "更新成功", diff --git a/frontend/public/tokenpony.png b/frontend/public/tokenpony.png new file mode 100644 index 0000000000000000000000000000000000000000..d582ae86b2b3a14192759a9d89d39d25bcc1508f GIT binary patch literal 1296 zcmV+r1@HQaP)Px#1ZP1_K>z@;j|==^1poj532;bRa{vGi!TsHp)@UQ zK(JwufJzK26bR5FD4XDTOdC7KA+hVkwH^C#&wc;djtA7VsJ_x|Uf;XtocrH!-?IZ0 zQ0pg8&STHMB>HZSVO3Wc9)Ael6(Rk*9Jd&9kc*t;uGgz1P$=Z_+)xt!%Oi07NqAOF z3V|fh5xnpO1gjjvMNV>4gIY@xC>HbBHjt2@v;ZRTbqawc6C$TWT-wWW!4Rv-h?A4t z)Yvu&w7dE!Eaz;&wRxp!qAHuC-xoe>{QV7g88Lkiug-6X@;@ z8^YSyTC9vQhOxvDOAd0?<8;PFTN02)Wm?1~aH|}Xc%mq~6eyR=>a^)(5f45buae0* zS6AGGjxmN-iz62~EpCaX$wV4lD|ZT(VMBvwgW z9Q!z?Cb7g2OAZ=2$<0`5Rdj8`621FgS^)&^efsapSMJ)A)IQqCv5X-i0bSbcGgj7E zizA1{k(;sh>z5?pc!Tg=7Q(4NS)qV$zFm-DnKBBg2k_-rtmVM?cwWjX!6G8gHv4Fd zsmGCvMs8|Qt1$s5h?R12*fo6p%~|Oihu0I3QYu11^y+J~IyN`Ah&5MoW3fM|ZT8U^ z!&skSh~-Hj7meK1Af{;ot9!zjoi#*5!zo=p40g4_j5?47SXd_ zj*y!`&-ySD(?yK_m_z>^W}Od@q;ceE8vd0b>_0f8oaCm)c_lD7ltTZVr?pKR{V9h> z9!+44#BA6o7naY=ihixMd#{S1d!0FgznnOyn0&shQ+`1S9DIpi9mJtm{?RWfmCQyW z=?|kheEi8-y!CckH`b@0W$~+smKIIS{Co*VMlyJIP&Tp5+{?m2TUG)D5$yAsjkIx7 z1W)do#)lteFcq(!bz@~jqXis&KZ6}lB+z>`mxnWBu+;=g|DAwtLRm}N8@dUFq>~R4 zGu=1zMRC`?XRvmIneI*3M{&)yQ7NfmtP|1u?w`U<^7>({M=)N&vP~`;xm#~S_Za?h z${eTfzHdB+$B&z1$j#=Jcsu5mJuvojPRG8Mn0g$!Xyk5r3~eEL=Ww+UJ@t%Rh(uub z-fBTP_CaGI;`GzFEW7s!PB6rh!{W%z*ye@EOP9uoEEo4 zv(Q literal 0 HcmV?d00001 diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index 04d6a5ff3..2897c762d 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -17,6 +17,8 @@ export type ModelSource = | "openai" | "custom" | "silicon" + | "dashscope" + | "tokenpony" | "OpenAI-API-Compatible" | "modelengine"; From 463ebd525d404bef9dd9187a940626de14083d93 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Wed, 4 Mar 2026 17:09:35 +0800 Subject: [PATCH 03/75] New Requirement: Support for provider Zhipu AI Models (LLM and Embedding) --- backend/consts/provider.py | 4 +- backend/services/model_provider_service.py | 3 +- .../services/providers/dashscope_provider.py | 10 +- .../components/model/ModelDeleteDialog.tsx | 159 +++++++++++++++++- frontend/const/modelConfig.ts | 2 + 5 files changed, 166 insertions(+), 12 deletions(-) diff --git a/backend/consts/provider.py b/backend/consts/provider.py index e2a0f0235..38bbc4027 100644 --- a/backend/consts/provider.py +++ b/backend/consts/provider.py @@ -15,11 +15,11 @@ class ProviderEnum(str, Enum): SILICON_GET_URL = "https://api.siliconflow.cn/v1/models" # Dashcope -DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" +DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/" DASHSCOPE_GET_URL = "https://dashscope.aliyuncs.com/api/v1/models" # TokenPony -TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1" +TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1/" TOKENPONY_GET_URL = "https://api.tokenpony.cn/v1/models" # ModelEngine diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index 3c916eb8c..8c397dc70 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -125,7 +125,8 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a # dimension by performing a real connectivity check. if model["model_type"] in ["embedding", "multi_embedding"]: if provider != ProviderEnum.MODELENGINE.value: - model_dict["base_url"] = f"{model_url}embeddings" + # Ensure proper slash between base URL and endpoint + model_dict["base_url"] = f"{model_url.rstrip('/')}/embeddings" else: # For ModelEngine embedding models, append the embeddings path model_dict["base_url"] = f"{model_url.rstrip('/')}/{MODEL_ENGINE_NORTH_PREFIX}/embeddings" diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py index 2a34823ed..cde54b60a 100644 --- a/backend/services/providers/dashscope_provider.py +++ b/backend/services/providers/dashscope_provider.py @@ -35,16 +35,16 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: while True: params = {"page_size": 100, "page_no": current_page} response = await client.get(base_url, headers=headers, params=params) - response.raise_for_status() - - data = response.json() - models = data.get("output", {}).get("models", []) - if response.status_code == 429: await asyncio.sleep(2) continue if not models : # Break loop if no more models on the current page break + response.raise_for_status() + + data = response.json() + models = data.get("output", {}).get("models", []) + all_models.extend(models) if(len(models)<100): diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index 541ed6266..579908d95 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -183,6 +183,10 @@ export const ModelDeleteDialog = ({ return t("model.source.modelEngine"); case MODEL_SOURCES.OPENAI_API_COMPATIBLE: return t("model.source.custom"); + case MODEL_SOURCES.DASHSCOPE: + return t("model.source.dashscope"); + case MODEL_SOURCES.TOKENPONY: + return t("model.source.tokenpony"); default: return t("model.source.unknown"); } @@ -217,6 +221,18 @@ export const ModelDeleteDialog = ({ text: "text-rose-600", border: "border-rose-100", }; + case MODEL_SOURCES.DASHSCOPE: + return { + bg: "bg-orange-50", + text: "text-orange-600", + border: "border-orange-100", + }; + case MODEL_SOURCES.TOKENPONY: + return { + bg: "bg-cyan-50", + text: "text-cyan-600", + border: "border-cyan-100", + }; default: return { bg: "bg-gray-50", @@ -253,6 +269,14 @@ export const ModelDeleteDialog = ({ 🛠️ ); + case MODEL_SOURCES.DASHSCOPE: + return ( + DashScope + ); + case MODEL_SOURCES.TOKENPONY: + return ( + TokenPony + ); default: return ( @@ -288,6 +312,16 @@ export const ModelDeleteDialog = ({ ); if (byModelEngine?.apiKey) return byModelEngine.apiKey; + const byDashScope = models.find( + (m) => m.source === MODEL_SOURCES.DASHSCOPE && m.type === type && m.apiKey + ); + if (byDashScope?.apiKey) return byDashScope.apiKey; + + const byTokenPony = models.find( + (m) => m.source === MODEL_SOURCES.TOKENPONY && m.type === type && m.apiKey + ); + if (byTokenPony?.apiKey) return byTokenPony.apiKey; + // Fallback: any model that has apiKey const anyWithKey = models.find((m) => m.apiKey); return anyWithKey?.apiKey || ""; @@ -327,7 +361,7 @@ export const ModelDeleteDialog = ({ return anyModelWithUrl?.apiUrl || undefined; }; - // Prefetch provider model list (supports Silicon and ModelEngine) + // Prefetch provider model list (supports Silicon, ModelEngine, DashScope, TokenPony) const prefetchProviderModels = async ( provider: ModelSource, modelType: ModelType | null @@ -351,6 +385,20 @@ export const ModelDeleteDialog = ({ apiKey: apiKey && apiKey.trim() !== "" ? apiKey : "sk-no-api-key", baseUrl: baseUrl || undefined, }); + } else if (provider === MODEL_SOURCES.DASHSCOPE) { + const apiKey = getApiKeyByType(modelType, MODEL_SOURCES.DASHSCOPE); + result = await modelService.addProviderModel({ + provider: MODEL_SOURCES.DASHSCOPE, + type: modelType, + apiKey: apiKey && apiKey.trim() !== "" ? apiKey : "sk-no-api-key", + }); + } else if (provider === MODEL_SOURCES.TOKENPONY) { + const apiKey = getApiKeyByType(modelType, MODEL_SOURCES.TOKENPONY); + result = await modelService.addProviderModel({ + provider: MODEL_SOURCES.TOKENPONY, + type: modelType, + apiKey: apiKey && apiKey.trim() !== "" ? apiKey : "sk-no-api-key", + }); } else { // Unsupported provider for prefetching return; @@ -383,7 +431,12 @@ export const ModelDeleteDialog = ({ const handleSourceSelect = async (source: ModelSource) => { setLoadingSource(source); try { - if (source === MODEL_SOURCES.SILICON || source === MODEL_SOURCES.MODELENGINE) { + if ( + source === MODEL_SOURCES.SILICON || + source === MODEL_SOURCES.MODELENGINE || + source === MODEL_SOURCES.DASHSCOPE || + source === MODEL_SOURCES.TOKENPONY + ) { await prefetchProviderModels(source, deletingModelType); } else if (source === MODEL_SOURCES.OPENAI) { // For OpenAI source, just set the selected source without prefetching @@ -543,7 +596,9 @@ export const ModelDeleteDialog = ({ setMaxTokens(maxTokens); if ( (selectedSource === MODEL_SOURCES.SILICON || - selectedSource === MODEL_SOURCES.MODELENGINE) && + selectedSource === MODEL_SOURCES.MODELENGINE || + selectedSource === MODEL_SOURCES.DASHSCOPE || + selectedSource === MODEL_SOURCES.TOKENPONY) && deletingModelType ) { try { @@ -839,6 +894,98 @@ export const ModelDeleteDialog = ({ t("model.dialog.error.addFailed", { error: e as any }) ); } + } else if ( + selectedSource === MODEL_SOURCES.DASHSCOPE && + deletingModelType + ) { + try { + const allEnabledModels = providerModels.filter( + (pm: any) => pendingSelectedProviderIds.has(pm.id) + ); + + if (allEnabledModels) { + const apiKey = getApiKeyByType(deletingModelType, MODEL_SOURCES.DASHSCOPE); + const isEmbeddingType = + deletingModelType === MODEL_TYPES.EMBEDDING || + deletingModelType === MODEL_TYPES.MULTI_EMBEDDING; + await modelService.addBatchCustomModel({ + api_key: + apiKey && apiKey.trim() !== "" + ? apiKey + : "sk-no-api-key", + provider: MODEL_SOURCES.DASHSCOPE, + type: deletingModelType, + models: allEnabledModels.map((model) => { + if (isEmbeddingType) { + const { max_tokens, ...modelWithoutMaxTokens } = + model; + return modelWithoutMaxTokens; + } else { + return { + ...model, + max_tokens: model.max_tokens || 4096, + }; + } + }), + }); + } + + await onSuccess(); + await prefetchProviderModels(selectedSource, deletingModelType); + message.success(t("model.dialog.success.updateSuccess")); + handleClose(); + } catch (e) { + log.error("Failed to apply DashScope model updates", e); + message.error( + t("model.dialog.error.addFailed", { error: e as any }) + ); + } + } else if ( + selectedSource === MODEL_SOURCES.TOKENPONY && + deletingModelType + ) { + try { + const allEnabledModels = providerModels.filter( + (pm: any) => pendingSelectedProviderIds.has(pm.id) + ); + + if (allEnabledModels) { + const apiKey = getApiKeyByType(deletingModelType, MODEL_SOURCES.TOKENPONY); + const isEmbeddingType = + deletingModelType === MODEL_TYPES.EMBEDDING || + deletingModelType === MODEL_TYPES.MULTI_EMBEDDING; + await modelService.addBatchCustomModel({ + api_key: + apiKey && apiKey.trim() !== "" + ? apiKey + : "sk-no-api-key", + provider: MODEL_SOURCES.TOKENPONY, + type: deletingModelType, + models: allEnabledModels.map((model) => { + if (isEmbeddingType) { + const { max_tokens, ...modelWithoutMaxTokens } = + model; + return modelWithoutMaxTokens; + } else { + return { + ...model, + max_tokens: model.max_tokens || 4096, + }; + } + }), + }); + } + + await onSuccess(); + await prefetchProviderModels(selectedSource, deletingModelType); + message.success(t("model.dialog.success.updateSuccess")); + handleClose(); + } catch (e) { + log.error("Failed to apply TokenPony model updates", e); + message.error( + t("model.dialog.error.addFailed", { error: e as any }) + ); + } } else if ( selectedSource === MODEL_SOURCES.OPENAI && deletingModelType @@ -976,6 +1123,8 @@ export const ModelDeleteDialog = ({ MODEL_SOURCES.OPENAI, MODEL_SOURCES.SILICON, MODEL_SOURCES.OPENAI_API_COMPATIBLE, + MODEL_SOURCES.DASHSCOPE, + MODEL_SOURCES.TOKENPONY, ] as ModelSource[] ).map((source) => { const modelsOfSource = models.filter( @@ -1074,7 +1223,9 @@ export const ModelDeleteDialog = ({ onClick={async () => { if ( (selectedSource === MODEL_SOURCES.SILICON || - selectedSource === MODEL_SOURCES.MODELENGINE) && + selectedSource === MODEL_SOURCES.MODELENGINE || + selectedSource === MODEL_SOURCES.DASHSCOPE || + selectedSource === MODEL_SOURCES.TOKENPONY) && deletingModelType ) { try { diff --git a/frontend/const/modelConfig.ts b/frontend/const/modelConfig.ts index 9b0128529..4c412824a 100644 --- a/frontend/const/modelConfig.ts +++ b/frontend/const/modelConfig.ts @@ -16,6 +16,8 @@ export const MODEL_SOURCES = { MODELENGINE: "modelengine", OPENAI_API_COMPATIBLE: "OpenAI-API-Compatible", CUSTOM: "custom", + DASHSCOPE: "dashscope", + TOKENPONY: "tokenpony", } as const; // Model status constants From eb12b0abd47689c326470de62419536af7fe3cc2 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Wed, 4 Mar 2026 17:23:38 +0800 Subject: [PATCH 04/75] New Requirement: Support for provider dashscope and tokenpony Models (LLM and Embedding) --- backend/services/providers/tokenpony_provider.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index 62972b698..6fe67502e 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -72,16 +72,14 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "model_type": "", "max_tokens": DEFAULT_LLM_MAX_TOKENS } - # 1. Embedding - if 'embedding' in m_id or m_id.startswith('bge-'): - cleaned_model.update({"model_tag": "embedding", "model_type": "embedding", "max_tokens": 0}) - categorized_models['embedding'].append(cleaned_model) - - # 2. Reranker - elif 'rerank' in m_id: + # 1. reranker + if 'rerank' in m_id: cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) categorized_models['reranker'].append(cleaned_model) - + #2. embedding + elif 'embedding' in m_id or m_id.startswith('bge-'): + cleaned_model.update({"model_tag": "embedding", "model_type": "embedding", "max_tokens": 0}) + categorized_models['embedding'].append(cleaned_model) # 3. STT (Speech-to-Text / Audio understanding) elif 'stt' in m_id: From 36b8be90cf49945e8d7b58f77572221b204d3cc8 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Thu, 5 Mar 2026 13:37:24 +0800 Subject: [PATCH 05/75] bug fix : embedding model max_tokens changes --- backend/services/providers/tokenpony_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index 6fe67502e..844dd1859 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -78,7 +78,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: categorized_models['reranker'].append(cleaned_model) #2. embedding elif 'embedding' in m_id or m_id.startswith('bge-'): - cleaned_model.update({"model_tag": "embedding", "model_type": "embedding", "max_tokens": 0}) + cleaned_model.update({"model_tag": "embedding", "model_type": "embedding"}) categorized_models['embedding'].append(cleaned_model) # 3. STT (Speech-to-Text / Audio understanding) From 347066293e7aef4d3dcaa72bea3f4de8d59c2090 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Thu, 5 Mar 2026 13:39:22 +0800 Subject: [PATCH 06/75] bug fix : embedding model max_tokens changes --- backend/services/providers/tokenpony_provider.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index 844dd1859..42e5d178c 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -35,13 +35,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE ssl_context.set_ciphers("DEFAULT@SECLEVEL=1") - # response = requests.get(url, headers=headers) - # all_models=[] - # if response.status_code == 200: - # data = response.json() - # # 注意:OpenAI 标准返回是在 "data" 字段下 - # all_models=data.get("data", []) - # Fetch all models asynchronously + async with httpx.AsyncClient(http2=True) as client: response = await client.get(url, headers=headers) response.raise_for_status() From fb16d9365d4eaa8fa911641db058aaa3f82b1ad7 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Thu, 5 Mar 2026 14:50:23 +0800 Subject: [PATCH 07/75] create test files for the backend providers --- .../providers/test_dashscope_provider.py | 718 ++++++++++++++++++ .../providers/test_tokenpony_provider.py | 711 +++++++++++++++++ .../services/test_model_management_service.py | 4 + .../services/test_model_provider_service.py | 124 +++ 4 files changed, 1557 insertions(+) create mode 100644 test/backend/services/providers/test_dashscope_provider.py create mode 100644 test/backend/services/providers/test_tokenpony_provider.py diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py new file mode 100644 index 000000000..2dc3a8f27 --- /dev/null +++ b/test/backend/services/providers/test_dashscope_provider.py @@ -0,0 +1,718 @@ +"""Unit tests for DashScopeModelProvider module. + +Tests cover model fetching, type classification, and error handling. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from pytest_mock import MockFixture + +import httpx + +from backend.services.providers.dashscope_provider import DashScopeModelProvider + + +class TestDashScopeModelProvider: + """Tests for DashScopeModelProvider class.""" + + @pytest.mark.asyncio + async def test_get_models_llm_success(self, mocker: MockFixture): + """Test successful model retrieval for LLM models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-turbo", + "description": "Text generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen-plus", + "description": "Advanced text generation", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DEFAULT_LLM_MAX_TOKENS", + 4096 + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 2 + assert result[0]["id"] == "qwen-turbo" + assert result[0]["model_type"] == "llm" + assert result[0]["model_tag"] == "chat" + assert result[0]["max_tokens"] == 4096 + + @pytest.mark.asyncio + async def test_get_models_embedding_success(self, mocker: MockFixture): + """Test successful model retrieval for embedding models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "text-embedding-v3", + "description": "Embedding model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "text-embedding-v3" + assert result[0]["model_type"] == "embedding" + assert result[0]["model_tag"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_vlm_success(self, mocker: MockFixture): + """Test successful model retrieval for VLM models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-vl-plus", + "description": "Vision language model", + "inference_metadata": { + "request_modality": ["Image", "Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "vlm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "qwen-vl-plus" + assert result[0]["model_type"] == "vlm" + assert result[0]["model_tag"] == "chat" + + @pytest.mark.asyncio + async def test_get_models_reranker_success(self, mocker: MockFixture): + """Test successful model retrieval for reranker models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "gte-reranker", + "description": "Reranking model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "reranker", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "gte-reranker" + assert result[0]["model_type"] == "reranker" + assert result[0]["model_tag"] == "reranker" + + @pytest.mark.asyncio + async def test_get_models_tts_success(self, mocker: MockFixture): + """Test successful model retrieval for TTS models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "sambert-tts", + "description": "Text to speech", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Audio"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "tts", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "sambert-tts" + assert result[0]["model_type"] == "tts" + assert result[0]["model_tag"] == "tts" + + @pytest.mark.asyncio + async def test_get_models_stt_success(self, mocker: MockFixture): + """Test successful model retrieval for STT models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "paraformer-realtime-v2", + "description": "Speech recognition", + "inference_metadata": { + "request_modality": ["Audio"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "stt", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "paraformer-realtime-v2" + assert result[0]["model_type"] == "stt" + assert result[0]["model_tag"] == "stt" + + @pytest.mark.asyncio + async def test_get_models_multi_embedding_success(self, mocker: MockFixture): + """Test successful model retrieval for multi-embedding models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "text-embedding-multimodal-v3", + "description": "Multimodal embedding", + "inference_metadata": { + "request_modality": ["Text", "Image"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "multi_embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "text-embedding-multimodal-v3" + assert result[0]["model_type"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_empty_response(self, mocker: MockFixture): + """Test handling of empty model list from API.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"output": {"models": []}} + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert result == [] + + @pytest.mark.asyncio + async def test_get_models_http_error(self, mocker: MockFixture): + """Test handling of HTTP error.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.HTTPStatusError( + "Error", + request=MagicMock(), + response=MagicMock(status_code=500) + ) + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_connect_error(self, mocker: MockFixture): + """Test handling of connection error.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectError("Connection failed") + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_timeout(self, mocker: MockFixture): + """Test handling of connection timeout.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectTimeout("Timeout") + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_authorization_header(self, mocker: MockFixture): + """Test that Authorization header is correctly set.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-turbo", + "description": "Test", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "my-secret-key" + } + + await provider.get_models(provider_config) + + # Verify Authorization header + call_args = mock_client.get.call_args + headers = call_args[1]["headers"] + assert headers["Authorization"] == "Bearer my-secret-key" + + @pytest.mark.asyncio + async def test_get_models_pagination(self, mocker: MockFixture): + """Test that pagination works correctly.""" + # First page returns 100 models + mock_response_page1 = MagicMock() + mock_response_page1.status_code = 200 + mock_response_page1.json.return_value = { + "output": { + "models": [{"model": f"model-{i}", "description": "test", + "inference_metadata": {"request_modality": ["Text"], "response_modality": ["Text"]}} + for i in range(100)] + } + } + mock_response_page1.raise_for_status = MagicMock() + + # Second page returns 50 models (less than page_size) + mock_response_page2 = MagicMock() + mock_response_page2.status_code = 200 + mock_response_page2.json.return_value = { + "output": { + "models": [{"model": f"model-{i}", "description": "test", + "inference_metadata": {"request_modality": ["Text"], "response_modality": ["Text"]}} + for i in range(100, 150)] + } + } + mock_response_page2.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.side_effect = [mock_response_page1, mock_response_page2] + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + # Should get models from both pages + assert len(result) == 150 + + @pytest.mark.asyncio + async def test_get_models_unknown_type_returns_empty(self, mocker: MockFixture): + """Test that unknown model type returns empty list.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-turbo", + "description": "Text generation", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "unknown_type", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert result == [] + + @pytest.mark.asyncio + async def test_get_models_with_chinese_description(self, mocker: MockFixture): + """Test model classification by Chinese description.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "embedding-v1", + "description": "向量embedding模型", # Chinese description + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + }, + { + "model": "rerank-v1", + "description": "重排序模型", # Chinese description + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + + # Test embedding classification by Chinese description + result = await provider.get_models({"model_type": "embedding", "api_key": "test-key"}) + assert len(result) == 1 + assert result[0]["id"] == "embedding-v1" + + # Test reranker classification by Chinese description + result = await provider.get_models({"model_type": "reranker", "api_key": "test-key"}) + assert len(result) == 1 + assert result[0]["id"] == "rerank-v1" + diff --git a/test/backend/services/providers/test_tokenpony_provider.py b/test/backend/services/providers/test_tokenpony_provider.py new file mode 100644 index 000000000..4f4a564e1 --- /dev/null +++ b/test/backend/services/providers/test_tokenpony_provider.py @@ -0,0 +1,711 @@ +"""Unit tests for TokenPonyModelProvider module. + +Tests cover model fetching, type classification, and error handling. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from pytest_mock import MockFixture + +import httpx + +from backend.services.providers.tokenpony_provider import TokenPonyModelProvider + + +class TestTokenPonyModelProvider: + """Tests for TokenPonyModelProvider class.""" + + @pytest.mark.asyncio + async def test_get_models_llm_success(self, mocker: MockFixture): + """Test successful model retrieval for LLM models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "owned_by": "openai" + }, + { + "id": "claude-3-opus", + "object": "model", + "owned_by": "anthropic" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.DEFAULT_LLM_MAX_TOKENS", + 4096 + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 2 + assert result[0]["id"] == "gpt-4" + assert result[0]["model_type"] == "llm" + assert result[0]["model_tag"] == "chat" + assert result[0]["max_tokens"] == 4096 + + @pytest.mark.asyncio + async def test_get_models_embedding_success(self, mocker: MockFixture): + """Test successful model retrieval for embedding models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "text-embedding-ada-002", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "text-embedding-ada-002" + assert result[0]["model_type"] == "embedding" + assert result[0]["model_tag"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_vlm_success(self, mocker: MockFixture): + """Test successful model retrieval for VLM models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "qwen-vl-plus", + "object": "model", + "owned_by": "qwen" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "vlm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "qwen-vl-plus" + assert result[0]["model_type"] == "vlm" + assert result[0]["model_tag"] == "chat" + + @pytest.mark.asyncio + async def test_get_models_reranker_success(self, mocker: MockFixture): + """Test successful model retrieval for reranker models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gte-reranker-base", + "object": "model", + "owned_by": "gte" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "reranker", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "gte-reranker-base" + assert result[0]["model_type"] == "reranker" + assert result[0]["model_tag"] == "reranker" + + @pytest.mark.asyncio + async def test_get_models_tts_success(self, mocker: MockFixture): + """Test successful model retrieval for TTS models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "tts-1-hd", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "tts", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "tts-1-hd" + assert result[0]["model_type"] == "tts" + assert result[0]["model_tag"] == "tts" + + @pytest.mark.asyncio + async def test_get_models_stt_success(self, mocker: MockFixture): + """Test successful model retrieval for STT models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "whisper-1", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "stt", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "whisper-1" + assert result[0]["model_type"] == "stt" + assert result[0]["model_tag"] == "stt" + + @pytest.mark.asyncio + async def test_get_models_multi_embedding_success(self, mocker: MockFixture): + """Test successful model retrieval for multi-embedding models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "bge-large", + "object": "model", + "owned_by": "bge" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "multi_embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "bge-large" + assert result[0]["model_type"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_empty_response(self, mocker: MockFixture): + """Test handling of empty model list from API.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"data": []} + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert result == [] + + @pytest.mark.asyncio + async def test_get_models_http_error(self, mocker: MockFixture): + """Test handling of HTTP error.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.HTTPStatusError( + "Error", + request=MagicMock(), + response=MagicMock(status_code=500) + ) + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_connect_error(self, mocker: MockFixture): + """Test handling of connection error.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectError("Connection failed") + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_timeout(self, mocker: MockFixture): + """Test handling of connection timeout.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectTimeout("Timeout") + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_authorization_header(self, mocker: MockFixture): + """Test that Authorization header is correctly set.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "my-secret-key" + } + + await provider.get_models(provider_config) + + # Verify Authorization header + call_args = mock_client.get.call_args + headers = call_args[1]["headers"] + assert headers["Authorization"] == "Bearer my-secret-key" + + @pytest.mark.asyncio + async def test_get_models_unknown_type_returns_empty(self, mocker: MockFixture): + """Test that unknown model type returns empty list.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "unknown_type", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert result == [] + + @pytest.mark.asyncio + async def test_get_models_vlm_by_keyword(self, mocker: MockFixture): + """Test VLM classification by keywords like -vl, vl-, ocr, vision.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "qwen-vl-plus", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "vl-ocr-v1", + "object": "model", + "owned_by": "ocr" + }, + { + "id": "vision-model-v2", + "object": "model", + "owned_by": "vision" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "vlm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 3 + for model in result: + assert model["model_type"] == "vlm" + assert model["model_tag"] == "chat" + + @pytest.mark.asyncio + async def test_get_models_bge_prefix_embedding(self, mocker: MockFixture): + """Test that models with bge- prefix are classified as embedding.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "bge-large-zh-v1.5", + "object": "model", + "owned_by": "bge" + }, + { + "id": "bge-base-en-v1.5", + "object": "model", + "owned_by": "bge" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 2 + for model in result: + assert model["model_type"] == "embedding" + assert model["model_tag"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_llm_has_max_tokens(self, mocker: MockFixture): + """Test that LLM models have max_tokens set.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.DEFAULT_LLM_MAX_TOKENS", + 4096 + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["max_tokens"] == 4096 + diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 6d0806299..e5d52d31a 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -120,10 +120,14 @@ class _Func: class _ProviderEnum: SILICON = _EnumItem("silicon") MODELENGINE = _EnumItem("modelengine") + DASHSCOPE = _EnumItem("dashscope") + TOKENPONY = _EnumItem("tokenpony") consts_provider_mod.ProviderEnum = _ProviderEnum consts_provider_mod.SILICON_BASE_URL = "http://silicon.test" +consts_provider_mod.DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/" +consts_provider_mod.TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1/" sys.modules["consts.provider"] = consts_provider_mod # Stub services.model_provider_service used by service diff --git a/test/backend/services/test_model_provider_service.py b/test/backend/services/test_model_provider_service.py index f81222056..992025754 100644 --- a/test/backend/services/test_model_provider_service.py +++ b/test/backend/services/test_model_provider_service.py @@ -157,6 +157,8 @@ def __init__(self): class _ProviderEnumStub: SILICON = mock.Mock(value="silicon") MODELENGINE = mock.Mock(value="modelengine") + DASHSCOPE = mock.Mock(value="dashscope") + TOKENPONY = mock.Mock(value="tokenpony") sys.modules["consts.provider"].ProviderEnum = _ProviderEnumStub @@ -1903,3 +1905,125 @@ def test_get_model_engine_raw_url_trailing_slash(): for input_url, expected in test_cases: result = get_model_engine_raw_url(input_url) assert result == expected, f"Failed for input: {input_url}" + + +# ============================================================================ +# Test-cases for get_provider_models with DashScope provider +# ============================================================================ + + +@pytest.mark.asyncio +async def test_get_provider_models_dashscope_success(): + """Should successfully get models from DashScope provider.""" + from backend.services.model_provider_service import DashScopeModelProvider + + model_data = { + "provider": "dashscope", + "model_type": "llm", + "api_key": "test-key", + } + + expected_models = [ + { + "id": "qwen-turbo", + "model_tag": "chat", + "model_type": "llm", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + } + ] + + with mock.patch( + "backend.services.model_provider_service.DashScopeModelProvider" + ) as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = expected_models + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == expected_models + mock_provider_class.assert_called_once() + mock_provider_instance.get_models.assert_called_once_with(model_data) + + +@pytest.mark.asyncio +async def test_get_provider_models_dashscope_empty_result(): + """Should handle empty result from DashScope provider.""" + model_data = { + "provider": "dashscope", + "model_type": "embedding", + "api_key": "test-key", + } + + with mock.patch( + "backend.services.model_provider_service.DashScopeModelProvider" + ) as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = [] + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == [] + mock_provider_instance.get_models.assert_called_once_with(model_data) + + +# ============================================================================ +# Test-cases for get_provider_models with TokenPony provider +# ============================================================================ + + +@pytest.mark.asyncio +async def test_get_provider_models_tokenpony_success(): + """Should successfully get models from TokenPony provider.""" + from backend.services.model_provider_service import TokenPonyModelProvider + + model_data = { + "provider": "tokenpony", + "model_type": "llm", + "api_key": "test-key", + } + + expected_models = [ + { + "id": "gpt-4", + "model_tag": "chat", + "model_type": "llm", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + } + ] + + with mock.patch( + "backend.services.model_provider_service.TokenPonyModelProvider" + ) as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = expected_models + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == expected_models + mock_provider_class.assert_called_once() + mock_provider_instance.get_models.assert_called_once_with(model_data) + + +@pytest.mark.asyncio +async def test_get_provider_models_tokenpony_empty_result(): + """Should handle empty result from TokenPony provider.""" + model_data = { + "provider": "tokenpony", + "model_type": "embedding", + "api_key": "test-key", + } + + with mock.patch( + "backend.services.model_provider_service.TokenPonyModelProvider" + ) as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = [] + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == [] + mock_provider_instance.get_models.assert_called_once_with(model_data) \ No newline at end of file From 941cac22d6498c841267d1f72e8b2f6d96f6061c Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Thu, 5 Mar 2026 15:50:48 +0800 Subject: [PATCH 08/75] bugfix for test files of the backend providers --- .../services/providers/dashscope_provider.py | 7 +- .../providers/test_dashscope_provider.py | 164 ++++-------------- .../providers/test_tokenpony_provider.py | 4 +- 3 files changed, 38 insertions(+), 137 deletions(-) diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py index cde54b60a..4ecbcbb1d 100644 --- a/backend/services/providers/dashscope_provider.py +++ b/backend/services/providers/dashscope_provider.py @@ -38,16 +38,17 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: if response.status_code == 429: await asyncio.sleep(2) continue - if not models : # Break loop if no more models on the current page - break response.raise_for_status() data = response.json() models = data.get("output", {}).get("models", []) + # Break loop if no more models on the current page + if not models: + break all_models.extend(models) - if(len(models)<100): + if len(models) < 100: break current_page += 1 await asyncio.sleep(0.5) diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py index 2dc3a8f27..44bbdbda5 100644 --- a/test/backend/services/providers/test_dashscope_provider.py +++ b/test/backend/services/providers/test_dashscope_provider.py @@ -4,7 +4,7 @@ """ import pytest -from unittest.mock import MagicMock, AsyncMock, patch +from unittest.mock import MagicMock, AsyncMock, patch, Mock from pytest_mock import MockFixture import httpx @@ -15,6 +15,27 @@ class TestDashScopeModelProvider: """Tests for DashScopeModelProvider class.""" + def _setup_mock_client(self, mocker, mock_response): + """Set up mock for httpx.AsyncClient with proper context manager.""" + # Create mock client that handles the get request + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + # Create context manager mock + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + # Create a mock class that can be called with verify=False + mock_client_class = Mock(return_value=mock_cm) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + mock_client_class + ) + + return mock_client_class + @pytest.mark.asyncio async def test_get_models_llm_success(self, mocker: MockFixture): """Test successful model retrieval for LLM models.""" @@ -44,17 +65,8 @@ async def test_get_models_llm_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -99,17 +111,8 @@ async def test_get_models_embedding_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) + self._setup_mock_client(mocker, mock_response) - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -149,17 +152,8 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -199,17 +193,8 @@ async def test_get_models_reranker_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) + self._setup_mock_client(mocker, mock_response) - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -249,17 +234,8 @@ async def test_get_models_tts_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) + self._setup_mock_client(mocker, mock_response) - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -299,17 +275,8 @@ async def test_get_models_stt_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -349,17 +316,8 @@ async def test_get_models_multi_embedding_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -385,17 +343,8 @@ async def test_get_models_empty_response(self, mocker: MockFixture): mock_response.json.return_value = {"output": {"models": []}} mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -429,10 +378,6 @@ async def test_get_models_http_error(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -460,10 +405,6 @@ async def test_get_models_connect_error(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -491,10 +432,6 @@ async def test_get_models_timeout(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -540,10 +477,6 @@ async def test_get_models_authorization_header(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -596,10 +529,6 @@ async def test_get_models_pagination(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -633,21 +562,7 @@ async def test_get_models_unknown_type_returns_empty(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) + self._setup_mock_client(mocker, mock_response) provider = DashScopeModelProvider() provider_config = { @@ -688,21 +603,7 @@ async def test_get_models_with_chinese_description(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) + self._setup_mock_client(mocker, mock_response) provider = DashScopeModelProvider() @@ -715,4 +616,3 @@ async def test_get_models_with_chinese_description(self, mocker: MockFixture): result = await provider.get_models({"model_type": "reranker", "api_key": "test-key"}) assert len(result) == 1 assert result[0]["id"] == "rerank-v1" - diff --git a/test/backend/services/providers/test_tokenpony_provider.py b/test/backend/services/providers/test_tokenpony_provider.py index 4f4a564e1..7fd9df9eb 100644 --- a/test/backend/services/providers/test_tokenpony_provider.py +++ b/test/backend/services/providers/test_tokenpony_provider.py @@ -258,7 +258,7 @@ async def test_get_models_stt_success(self, mocker: MockFixture): mock_response.json.return_value = { "data": [ { - "id": "whisper-1", + "id": "stt-whisper-1", "object": "model", "owned_by": "openai" } @@ -291,7 +291,7 @@ async def test_get_models_stt_success(self, mocker: MockFixture): result = await provider.get_models(provider_config) assert len(result) == 1 - assert result[0]["id"] == "whisper-1" + assert result[0]["id"] == "stt-whisper-1" assert result[0]["model_type"] == "stt" assert result[0]["model_tag"] == "stt" From 7b12f126ab661810e6b549dbe433b84c710a0100 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Fri, 6 Mar 2026 10:52:31 +0800 Subject: [PATCH 09/75] Refactor: Source the default available model from configuration, not from the model list --- .../components/agentConfig/ToolManagement.tsx | 94 ++----------------- .../agentInfo/AgentGenerateDetail.tsx | 21 ++++- frontend/hooks/model/useModelList.ts | 40 +------- frontend/hooks/useConfig.ts | 12 +++ 4 files changed, 42 insertions(+), 125 deletions(-) diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index f5815a094..18d467317 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -1,13 +1,12 @@ "use client"; -import { useState, useEffect, useCallback, useMemo } from "react"; +import { useState, useEffect, useCallback } from "react"; import { useTranslation } from "react-i18next"; import ToolConfigModal from "./tool/ToolConfigModal"; import { ToolGroup, Tool, ToolParam } from "@/types/agentConfig"; import { Tabs, Collapse, message, Tooltip } from "antd"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; import { useToolList } from "@/hooks/agent/useToolList"; -import { useModelList } from "@/hooks/model/useModelList"; import { usePrefetchKnowledgeBases } from "@/hooks/useKnowledgeBaseSelector"; import { useConfig } from "@/hooks/useConfig"; import { updateToolConfig } from "@/services/agentConfigService"; @@ -98,73 +97,7 @@ export default function ToolManagement({ // Use tool list hook for data management const { availableTools } = useToolList(); - // Get config for model checks - const { modelConfig: tenantModelConfig } = useConfig(); - - // Get VLM models to check availability - const { availableVlmModels, models } = useModelList(); - - // Check if VLM is properly configured: - // 1. Must have at least one VLM model that passed health check (available) - // 2. Must have a VLM model selected in tenant configuration - const isVlmConfigured = useMemo(() => { - // Check if there's any available VLM model - if (!availableVlmModels || availableVlmModels.length === 0) { - return false; - } - - // Check if tenant configuration has selected a VLM model - try { - const selectedVlmModelName = tenantModelConfig?.vlm?.modelName || tenantModelConfig?.vlm?.displayName; - - if (!selectedVlmModelName) { - return false; - } - - // Check if the selected VLM model exists in available models - const isSelectedModelAvailable = availableVlmModels.some( - (model) => model.name === selectedVlmModelName || model.displayName === selectedVlmModelName - ); - - return isSelectedModelAvailable; - } catch (error) { - return false; - } - }, [availableVlmModels, models, tenantModelConfig]); - - // Get Embedding models to check availability - const { availableEmbeddingModels } = useModelList(); - - // Check if Embedding is properly configured: - // 1. Must have at least one Embedding model that passed health check (available) - // 2. Must have an Embedding model selected in tenant configuration - const isEmbeddingConfigured = useMemo(() => { - // Check if there's any available Embedding model - if (!availableEmbeddingModels || availableEmbeddingModels.length === 0) { - return false; - } - - // Check if tenant configuration has selected an Embedding model - try { - const selectedEmbeddingModelName = - tenantModelConfig?.embedding?.modelName || tenantModelConfig?.embedding?.displayName; - - if (!selectedEmbeddingModelName) { - return false; - } - - // Check if the selected Embedding model exists in available models - const isSelectedModelAvailable = availableEmbeddingModels.some( - (model) => - model.name === selectedEmbeddingModelName || - model.displayName === selectedEmbeddingModelName - ); - - return isSelectedModelAvailable; - } catch (error) { - return false; - } - }, [availableEmbeddingModels, models, tenantModelConfig]); + const { isVlmAvailable, isEmbeddingAvailable } = useConfig(); // Prefetch knowledge bases for KB tools const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases(); @@ -235,9 +168,7 @@ export default function ToolManagement({ (t) => parseInt(t.id) === parseInt(tool.id) ); // Merge configured tool with original tool to ensure all fields are present - const toolToUse = configuredTool - ? { ...tool, ...configuredTool, initParams: configuredTool.initParams } - : tool; + const toolToUse = configuredTool ? { ...tool, ...configuredTool, initParams: configuredTool.initParams } : tool; // Get merged parameters (for editing mode, merge with instance params) const mergedParams = await mergeToolParamsWithInstance( @@ -264,23 +195,18 @@ export default function ToolManagement({ } // Get latest tools directly from store to avoid stale closure issues - const currentSelectdTools = - useAgentConfigStore.getState().editedAgent.tools; + const currentSelectdTools = useAgentConfigStore.getState().editedAgent.tools; const isCurrentlySelected = currentSelectdTools.some( (t) => parseInt(t.id) === numericId ); if (isCurrentlySelected) { // If already selected, deselect it - const newSelectedTools = currentSelectdTools.filter( - (t) => parseInt(t.id) !== numericId - ); + const newSelectedTools = currentSelectdTools.filter((t) => parseInt(t.id) !== numericId); updateTools(newSelectedTools); } else { // If not selected, determine tool params and check if modal is needed - const configuredTool = currentSelectdTools.find( - (t) => parseInt(t.id) === numericId - ); + const configuredTool = currentSelectdTools.find((t) => parseInt(t.id) === numericId); // Merge configured tool with original tool to ensure all fields are present const toolToUse = configuredTool ? { ...tool, ...configuredTool, initParams: configuredTool.initParams } @@ -428,8 +354,8 @@ export default function ToolManagement({ const isSelected = originalSelectedToolIdsSet.has( tool.id ); - const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmConfigured); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingConfigured); + const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly @@ -533,8 +459,8 @@ export default function ToolManagement({ > {group.tools.map((tool) => { const isSelected = originalSelectedToolIdsSet.has(tool.id); - const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmConfigured); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingConfigured); + const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index 80b48fe34..dcd2ed0fb 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -29,6 +29,7 @@ import { generatePromptStream } from "@/services/promptService"; import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; import { useDeployment } from "@/components/providers/deploymentProvider"; import { useModelList } from "@/hooks/model/useModelList"; +import { useConfig } from "@/hooks/useConfig"; import { useTenantList } from "@/hooks/tenant/useTenantList"; import { useGroupList } from "@/hooks/group/useGroupList"; import { USER_ROLES } from "@/const/auth"; @@ -62,8 +63,24 @@ export default function AgentGenerateDetail({ const updateBusinessInfo = useAgentConfigStore((state) => state.updateBusinessInfo); const updateProfileInfo = useAgentConfigStore((state) => state.updateProfileInfo); - // Model data from React Query - const { availableLlmModels, defaultLlmModel, isLoading: loadingModels } = useModelList(); + // Model data: default LLM name from config, resolve to full model from model list + const { defaultLlmModelName } = useConfig(); + const { availableLlmModels, models, isLoading: loadingModels } = useModelList(); + const defaultLlmModel = useMemo(() => { + if (defaultLlmModelName) { + const found = availableLlmModels.find( + (m) => m.name === defaultLlmModelName || m.displayName === defaultLlmModelName + ); + if (found) return found; + return models.find( + (m) => + m.type === "llm" && + (m.name === defaultLlmModelName || m.displayName === defaultLlmModelName) + ); + } + // No default configured: use the first available LLM, or undefined if none + return availableLlmModels[0]; + }, [defaultLlmModelName, availableLlmModels, models]); // Tenant & group data for group selection const { data: tenantData } = useTenantList(); diff --git a/frontend/hooks/model/useModelList.ts b/frontend/hooks/model/useModelList.ts index 7a30255be..f6ff1dce1 100644 --- a/frontend/hooks/model/useModelList.ts +++ b/frontend/hooks/model/useModelList.ts @@ -2,8 +2,6 @@ import { useQuery, useQueryClient } from "@tanstack/react-query"; import { modelService } from "@/services/modelService"; import { ModelOption } from "@/types/modelConfig"; import { useMemo } from "react"; -import { useConfig } from "@/hooks/useConfig"; - export function useModelList(options?: { enabled?: boolean; staleTime?: number }) { const queryClient = useQueryClient(); @@ -48,41 +46,6 @@ export function useModelList(options?: { enabled?: boolean; staleTime?: number } return models.filter((model) => model.type === "vlm" && model.connect_status === "available"); }, [models]); - const { modelConfig: tenantModelConfig } = useConfig(); - - // Get default LLM model from tenant configuration - const defaultLlmModel = useMemo(() => { - try { - const defaultModelName = tenantModelConfig?.llm?.modelName || tenantModelConfig?.llm?.displayName; - - if (defaultModelName) { - // First try to find by name in available LLM models (should be available) - let defaultModel = availableLlmModels.find(model => - model.name === defaultModelName || - model.displayName === defaultModelName - ); - - // If not found in available models, try all models but only if they're LLM type - if (!defaultModel) { - defaultModel = models.find(model => - model.type === "llm" && ( - model.name === defaultModelName || - model.displayName === defaultModelName - ) - ); - } - - return defaultModel; // Return the found model or undefined if not found - } - - // If no default configured, return undefined - return undefined; - } catch (error) { - return undefined; - } - }, [models, availableLlmModels, tenantModelConfig]); - - return { ...query, models, @@ -92,8 +55,7 @@ export function useModelList(options?: { enabled?: boolean; staleTime?: number } embeddingModels, availableEmbeddingModels, vlmModels, - availableVlmModels, - defaultLlmModel, + availableVlmModels, invalidate: () => queryClient.invalidateQueries({ queryKey: ["models"] }), }; } diff --git a/frontend/hooks/useConfig.ts b/frontend/hooks/useConfig.ts index 0032c80c1..75539295d 100644 --- a/frontend/hooks/useConfig.ts +++ b/frontend/hooks/useConfig.ts @@ -258,6 +258,15 @@ export function useConfig() { const config: GlobalConfig = (query.data as GlobalConfig | undefined) ?? defaultConfig; + // Whether config has selected a VLM model + const isVlmAvailable = !!(config?.models?.vlm?.modelName || config?.models?.vlm?.displayName); + + // Whether config has selected an Embedding model + const isEmbeddingAvailable = !!(config?.models?.embedding?.modelName || config?.models?.embedding?.displayName); + + // Default LLM model name from config (modelName or displayName) + const defaultLlmModelName = config?.models?.llm?.modelName || config?.models?.llm?.displayName || ""; + const updateAppConfig = useCallback( (partial: Partial) => { if (!config) return; @@ -332,6 +341,9 @@ export function useConfig() { config, appConfig: config?.app, modelConfig: config?.models, + isVlmAvailable, + isEmbeddingAvailable, + defaultLlmModelName, updateAppConfig, updateModelConfig, updateConfig, From 147d61fe8f34289467f199df2ce376f45cee0510 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Fri, 6 Mar 2026 11:29:20 +0800 Subject: [PATCH 10/75] Bugfix: Stop directly modifying agent-tool when toggling tool selections --- .../components/agentConfig/ToolManagement.tsx | 36 ------- .../agentConfig/tool/ToolConfigModal.tsx | 49 +--------- frontend/hooks/agent/useSaveGuard.ts | 95 ++++++++++++++----- 3 files changed, 76 insertions(+), 104 deletions(-) diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index 18d467317..4ac7b798c 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -243,42 +243,6 @@ export default function ToolManagement({ }, ]; updateTools(newSelectedTools); - - // In non-creating mode, immediately save tool config to backend - if (!isCreatingMode && currentAgentId) { - try { - // Convert params to backend format - const paramsObj = mergedParams.reduce( - (acc, param) => { - acc[param.name] = param.value; - return acc; - }, - {} as Record - ); - - const isEnabled = true; // New tool is enabled by default - const result = await updateToolConfig( - numericId, - currentAgentId, - paramsObj, - isEnabled - ); - - if (result.success) { - // Invalidate queries to refresh tool info - queryClient.invalidateQueries({ - queryKey: ["toolInfo", numericId, currentAgentId], - }); - } else { - message.error( - result.message || t("toolConfig.message.saveError") - ); - } - } catch (error) { - console.error("Failed to save tool config:", error); - message.error(t("toolConfig.message.saveError")); - } - } } } }; diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index 2a616326b..c5884f32b 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -746,51 +746,10 @@ export default function ToolConfigModal({ newSelectedTools = [...currentTools, updatedTool]; } - // For editing mode (when currentAgentId exists), always call API - // For creating mode (isCreatingMode=true), update local state only - if (isCreatingMode) { - // In creating mode, just update local state - updateTools(newSelectedTools); - message.success(t("toolConfig.message.saveSuccess")); - handleClose(); // Close modal - return; - } - - if (!currentAgentId) { - // Should not happen in normal editing mode, but handle gracefully - updateTools(newSelectedTools); - message.success(t("toolConfig.message.saveSuccess")); - handleClose(); // Close modal - return; - } - - // Edit mode: call API to persist changes - try { - setIsLoading(true); - const isEnabled = true; // New tool is enabled by default - const result = await updateToolConfig( - parseInt(toolToSave.id), - currentAgentId, - paramsObj, - isEnabled - ); - setIsLoading(false); - - if (result.success) { - // Update local state and invalidate queries - updateTools(newSelectedTools); - queryClient.invalidateQueries({ - queryKey: ["toolInfo", parseInt(toolToSave.id), currentAgentId], - }); - message.success(t("toolConfig.message.saveSuccess")); - handleClose(); // Close modal - } else { - message.error(result.message || t("toolConfig.message.saveError")); - } - } catch (error) { - setIsLoading(false); - message.error(t("toolConfig.message.saveError")); - } + // Update local state only - actual save will happen when user clicks "Save Agent" + updateTools(newSelectedTools); + message.success(t("toolConfig.message.saveSuccess")); + handleClose(); // Close modal // Call original onSave if provided if (onSave) { diff --git a/frontend/hooks/agent/useSaveGuard.ts b/frontend/hooks/agent/useSaveGuard.ts index a1f4cea35..76a231e8b 100644 --- a/frontend/hooks/agent/useSaveGuard.ts +++ b/frontend/hooks/agent/useSaveGuard.ts @@ -4,10 +4,78 @@ import { App } from "antd"; import { useQueryClient } from "@tanstack/react-query"; import { useConfirmModal } from "../useConfirmModal"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; -import { updateAgentInfo, updateToolConfig } from "@/services/agentConfigService"; +import { updateAgentInfo, updateToolConfig, searchToolConfig } from "@/services/agentConfigService"; import { Agent } from "@/types/agentConfig"; import log from "@/lib/logger"; +/** + * Batch update tool configurations for an agent + * Handles create, update, and enable/disable operations + * + * Logic: + * 1. For newly selected tools (not in baseline): Create tool instance with enable=true + * 2. For previously selected tools (in baseline): Update tool params with enable=true + * 3. For deselected tools (in baseline but not in current): Set enable=false + * + * @param agentId - The agent ID + * @param currentTools - Current tool list from edited agent + * @param baselineTools - Baseline tool list (original state before editing) + */ +async function batchUpdateToolConfigs( + agentId: number, + currentTools: any[], + baselineTools: any[] +) { + // Get the set of currently selected tool IDs + const currentToolIds = new Set( + currentTools.map((tool) => parseInt(tool.id)) + ); + + // Get the set of baseline (original) tool IDs + const baselineToolIds = new Set( + baselineTools.map((tool) => parseInt(tool.id)) + ); + + // Process each tool in the current selection + for (const tool of currentTools) { + const toolId = parseInt(tool.id); + const isEnabled = true; // Selected tools are always enabled + const params = tool.initParams?.reduce((acc: Record, param: any) => { + acc[param.name] = param.value; + return acc; + }, {} as Record) || {}; + + try { + // Update or create tool instance with current params and enabled status + await updateToolConfig(toolId, agentId, params, isEnabled); + } catch (error) { + log.error(`Failed to save tool config for tool ${toolId}:`, error); + // Continue with other tools even if one fails + } + } + + // Disable tools that were previously selected but are now deselected + const toolsToDisable = Array.from(baselineToolIds).filter( + (toolId) => !currentToolIds.has(toolId) + ); + + for (const toolId of toolsToDisable) { + try { + // Fetch existing params to preserve them when disabling + const toolInstance = await searchToolConfig(toolId, agentId); + const existingParams = toolInstance.success && toolInstance.data?.params + ? toolInstance.data.params + : {}; + + // Disable the tool while preserving its params + await updateToolConfig(toolId, agentId, existingParams, false); + } catch (error) { + log.error(`Failed to disable tool ${toolId}:`, error); + // Continue with other tools even if one fails + } + } +} + /** * Hook for handling agent save guard logic * Provides two functions: one with confirmation dialog, one for direct save @@ -83,28 +151,9 @@ export const useSaveGuard = () => { throw new Error("Failed to get agent ID after save operation"); } - // Handle new agent creation - save tool configurations - if (!currentAgentId && result.data?.agent_id) { - // Save tool configurations for the newly created agent - const agentIdNumber = result.data.agent_id; - if (currentEditedAgent.tools && currentEditedAgent.tools.length > 0) { - for (const tool of currentEditedAgent.tools) { - const toolId = parseInt(tool.id); - const isEnabled = tool.is_available !== false; // Default to true if not explicitly set to false - const params = tool.initParams?.reduce((acc, param) => { - acc[param.name] = param.value; - return acc; - }, {} as Record) || {}; - - try { - await updateToolConfig(toolId, agentIdNumber, params, isEnabled); - } catch (error) { - log.error(`Failed to save tool config for tool ${toolId}:`, error); - // Continue with other tools even if one fails - } - } - } - } + // Batch process tool configurations for both create and update modes + const baselineTools = useAgentConfigStore.getState().baselineAgent?.tools || []; + await batchUpdateToolConfigs(finalAgentId, currentEditedAgent.tools || [], baselineTools); // Common logic for both creation and update: refresh cache and update store await queryClient.invalidateQueries({ From f3ed9a391853d060ffc25080daf87e49f15e6434 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Fri, 6 Mar 2026 11:38:33 +0800 Subject: [PATCH 11/75] Bugfix: Always recalculate hasUnsavedChanges to fix false positive dirty state in agent config --- frontend/stores/agentConfigStore.ts | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/frontend/stores/agentConfigStore.ts b/frontend/stores/agentConfigStore.ts index 1cd323b76..5829e9153 100644 --- a/frontend/stores/agentConfigStore.ts +++ b/frontend/stores/agentConfigStore.ts @@ -321,9 +321,7 @@ export const useAgentConfigStore = create((set, get) => ( const editedAgent = { ...state.editedAgent, sub_agent_id_list: nextIds }; // If there are already unsaved changes, keep it true and skip recalculation. // Only when state is clean do we need to check whether sub-agent IDs changed. - const hasUnsavedChanges = state.hasUnsavedChanges - ? true - : isSubAgentIdsDirty(state.baselineAgent, editedAgent); + const hasUnsavedChanges = isSubAgentIdsDirty(state.baselineAgent, editedAgent); return { editedAgent, hasUnsavedChanges, @@ -336,9 +334,7 @@ export const useAgentConfigStore = create((set, get) => ( const editedAgent = { ...state.editedAgent, ...payload }; // If there are already unsaved changes, keep it true and skip recalculation. // Only when state is clean do we need to check whether business info changed. - const hasUnsavedChanges = state.hasUnsavedChanges - ? true - : isBusinessInfoDirty(state.baselineAgent, editedAgent); + const hasUnsavedChanges = isBusinessInfoDirty(state.baselineAgent, editedAgent); return { editedAgent, hasUnsavedChanges, @@ -351,9 +347,7 @@ export const useAgentConfigStore = create((set, get) => ( const editedAgent = { ...state.editedAgent, ...payload }; // If there are already unsaved changes, keep it true and skip recalculation. // Only when state is clean do we need to check whether profile info changed. - const hasUnsavedChanges = state.hasUnsavedChanges - ? true - : isProfileInfoDirty(state.baselineAgent, editedAgent); + const hasUnsavedChanges = isProfileInfoDirty(state.baselineAgent, editedAgent); return { editedAgent, hasUnsavedChanges, From 60036e792ed8484aac4e457560d63148be3762cb Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Fri, 6 Mar 2026 11:49:55 +0800 Subject: [PATCH 12/75] Bugfix: Refactor tool config save logic and fix array parameter comparison --- frontend/stores/agentConfigStore.ts | 34 ++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/frontend/stores/agentConfigStore.ts b/frontend/stores/agentConfigStore.ts index 5829e9153..2ea19d309 100644 --- a/frontend/stores/agentConfigStore.ts +++ b/frontend/stores/agentConfigStore.ts @@ -252,7 +252,39 @@ const isToolsDirty = (baselineAgent: EditableAgent | null, editedAgent: Editable // Compare each param's name and value for (const baseParam of baseParams) { const editParam = editParams.find(p => p.name === baseParam.name); - if (!editParam || baseParam.value !== editParam.value) { + if (!editParam) { + return true; + } + + // Deep comparison for array and object values + const baseValue = baseParam.value; + const editValue = editParam.value; + + // If both are arrays, compare their contents + if (Array.isArray(baseValue) && Array.isArray(editValue)) { + if (baseValue.length !== editValue.length) { + return true; + } + // Sort and compare array elements + const sortedBase = [...baseValue].sort(); + const sortedEdit = [...editValue].sort(); + if (JSON.stringify(sortedBase) !== JSON.stringify(sortedEdit)) { + return true; + } + } + // If both are objects (but not arrays), compare their JSON representation + else if ( + baseValue !== null && + editValue !== null && + typeof baseValue === 'object' && + typeof editValue === 'object' + ) { + if (JSON.stringify(baseValue) !== JSON.stringify(editValue)) { + return true; + } + } + // For primitive values, use strict equality + else if (baseValue !== editValue) { return true; } } From c1fc826d3e75278ffbbc8ad3f7469274732522ea Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Fri, 6 Mar 2026 12:33:44 +0800 Subject: [PATCH 13/75] Bugfix: reset when user enter /agent page --- frontend/app/[locale]/agents/page.tsx | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/frontend/app/[locale]/agents/page.tsx b/frontend/app/[locale]/agents/page.tsx index 9cade7ff4..86a52750b 100644 --- a/frontend/app/[locale]/agents/page.tsx +++ b/frontend/app/[locale]/agents/page.tsx @@ -17,6 +17,7 @@ export default function AgentSetupOrchestrator() { const { pageVariants, pageTransition } = useSetupFlow(); const searchParams = useSearchParams(); const enterCreateMode = useAgentConfigStore((state) => state.enterCreateMode); + const reset = useAgentConfigStore((state) => state.reset); // Local UI state for version panel const [isShowVersionManagePanel, setIsShowVersionManagePanel] = useState(false); @@ -32,6 +33,13 @@ export default function AgentSetupOrchestrator() { } }, [searchParams, enterCreateMode]); + // Reset agent selection state when leaving the page + useEffect(() => { + return () => { + reset(); + }; + }, [reset]); + return (
Date: Fri, 6 Mar 2026 15:01:49 +0800 Subject: [PATCH 14/75] Update opensource-memorial-wall.md add message --- doc/docs/zh/opensource-memorial-wall.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/docs/zh/opensource-memorial-wall.md b/doc/docs/zh/opensource-memorial-wall.md index 54bac7c28..c31428ba5 100644 --- a/doc/docs/zh/opensource-memorial-wall.md +++ b/doc/docs/zh/opensource-memorial-wall.md @@ -711,3 +711,7 @@ Nexent 加油!希望能达成所愿! ::: info sisyphus0x - 2026-03-04 对多智能体编排和协同工作很感兴趣,学习一下 ::: + +::: info hmh_mike - 2026-03-05 +感觉很有意思,试用一下看看对工作有没有帮助 +::: From ec137538c3132ff3fd40fd107aa5ba8ede038b38 Mon Sep 17 00:00:00 2001 From: zwb <1194371519@qq.com> Date: Fri, 6 Mar 2026 16:23:02 +0800 Subject: [PATCH 15/75] =?UTF-8?q?=E2=9C=A8File=20preview:=20Add=20file=20p?= =?UTF-8?q?review=20backend=20service?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/auto-unit-test.yml | 3 + backend/apps/data_process_app.py | 33 ++ backend/apps/file_management_app.py | 85 +++- backend/consts/const.py | 15 + backend/consts/exceptions.py | 15 + backend/database/attachment_db.py | 37 ++ backend/database/client.py | 27 ++ backend/services/data_process_service.py | 84 +++- backend/services/file_management_service.py | 160 ++++++- backend/utils/file_management_utils.py | 65 +++ docker/docker-compose.prod.yml | 1 + docker/docker-compose.yml | 1 + make/data_process/Dockerfile | 3 + sdk/nexent/storage/minio.py | 32 ++ sdk/nexent/storage/storage_client_base.py | 22 + test/backend/app/test_data_process_app.py | 64 +++ test/backend/app/test_file_management_app.py | 304 +++++++++++- test/backend/database/test_attachment_db.py | 89 ++++ test/backend/database/test_client.py | 101 ++++ .../services/test_data_process_service.py | 189 +++++++- .../services/test_file_management_service.py | 439 ++++++++++++++++++ .../utils/test_file_management_utils.py | 95 ++++ test/sdk/storage/test_minio.py | 89 ++++ 23 files changed, 1937 insertions(+), 16 deletions(-) diff --git a/.github/workflows/auto-unit-test.yml b/.github/workflows/auto-unit-test.yml index 6addafa22..29cf3a42d 100644 --- a/.github/workflows/auto-unit-test.yml +++ b/.github/workflows/auto-unit-test.yml @@ -48,6 +48,9 @@ jobs: uv pip install -e "../sdk[dev]" cd .. + - name: Install LibreOffice + run: sudo apt-get update && sudo apt-get install -y libreoffice + - name: Run all tests and collect coverage run: | source backend/.venv/bin/activate && python test/run_all_test.py diff --git a/backend/apps/data_process_app.py b/backend/apps/data_process_app.py index 3ac8b45cf..9138d5ef1 100644 --- a/backend/apps/data_process_app.py +++ b/backend/apps/data_process_app.py @@ -11,6 +11,7 @@ ConvertStateRequest, TaskRequest, ) +from consts.exceptions import OfficeConversionException from data_process.tasks import process_and_forward, process_sync from services.data_process_service import get_data_process_service @@ -311,3 +312,35 @@ async def convert_state(request: ConvertStateRequest): status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error converting state: {str(e)}" ) + + +@router.post("/convert_to_pdf") +async def convert_office_to_pdf( + object_name: str = Form(...), + pdf_object_name: str = Form(...) +): + """ + Convert an Office document stored in MinIO to PDF. + + Parameters: + object_name: Source Office file path in MinIO + pdf_object_name: Destination PDF path in MinIO + """ + try: + await service.convert_office_to_pdf_impl( + object_name=object_name, + pdf_object_name=pdf_object_name, + ) + return JSONResponse(status_code=HTTPStatus.OK, content={"success": True}) + except OfficeConversionException as exc: + logger.error(f"Office conversion failed for '{object_name}': {exc}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=str(exc) + ) + except Exception as exc: + logger.error(f"Unexpected error during conversion for '{object_name}': {exc}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Office conversion failed: {exc}" + ) diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 9ed87cfae..5b7c7bc3c 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -9,22 +9,29 @@ from fastapi import APIRouter, Body, File, Form, Header, HTTPException, Path as PathParam, Query, UploadFile from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse +from consts.exceptions import FileTooLargeException, NotFoundException, OfficeConversionException, UnsupportedFileTypeException from consts.model import ProcessParams from services.file_management_service import upload_to_minio, upload_files_impl, \ - get_file_url_impl, get_file_stream_impl, delete_file_impl, list_files_impl + get_file_url_impl, get_file_stream_impl, delete_file_impl, list_files_impl, \ + preview_file_impl from utils.file_management_utils import trigger_data_process logger = logging.getLogger("file_management_app") -def build_content_disposition_header(filename: Optional[str]) -> str: +def build_content_disposition_header(filename: Optional[str], inline: bool = False) -> str: """ Build a Content-Disposition header that keeps the original filename. + Args: + filename: Original filename to include in header + inline: If True, use 'inline' disposition (for preview); otherwise 'attachment' (for download) + - ASCII filenames are returned directly. - Non-ASCII filenames include both an ASCII fallback and RFC 5987 encoded value so modern browsers keep the original name. """ + disposition = "inline" if inline else "attachment" safe_name = (filename or "download").strip() or "download" def _sanitize_ascii(value: str) -> str: @@ -40,26 +47,26 @@ def _sanitize_ascii(value: str) -> str: try: safe_name.encode("ascii") - return f'attachment; filename="{_sanitize_ascii(safe_name)}"' + return f'{disposition}; filename="{_sanitize_ascii(safe_name)}"' except UnicodeEncodeError: try: encoded = quote(safe_name, safe="") except Exception: # quote failure, fallback to sanitized ASCII only logger.warning("Failed to encode filename '%s', using fallback", safe_name) - return f'attachment; filename="{_sanitize_ascii(safe_name)}"' + return f'{disposition}; filename="{_sanitize_ascii(safe_name)}"' fallback = _sanitize_ascii( safe_name.encode("ascii", "ignore").decode("ascii") or "download" ) - return f'attachment; filename="{fallback}"; filename*=UTF-8\'\'{encoded}' + return f'{disposition}; filename="{fallback}"; filename*=UTF-8\'\'{encoded}' except Exception as exc: # pragma: no cover logger.warning( "Failed to encode filename '%s': %s. Using fallback.", safe_name, exc, ) - return 'attachment; filename="download"' + return f'{disposition}; filename="download"' # Create API router file_management_runtime_router = APIRouter(prefix="/file") @@ -567,3 +574,69 @@ async def get_storage_file_batch_urls( "failed_count": sum(1 for r in results if not r.get("success", False)), "results": results } + +@file_management_config_router.get("/preview/{object_name:path}") +async def preview_file( + object_name: str = PathParam(..., description="File object name to preview"), + filename: Optional[str] = Query(None, description="Original filename for display (optional)") +): + """ + Preview file inline in browser + + - **object_name**: File object name in storage + - **filename**: Original filename for Content-Disposition header (optional) + + Returns file stream with Content-Disposition: inline for browser preview + """ + try: + # Get file stream from preview service + file_stream, content_type = await preview_file_impl(object_name=object_name) + + # Use provided filename or extract from object_name + display_filename = filename + if not display_filename: + display_filename = object_name.split("/")[-1] if "/" in object_name else object_name + + # Build Content-Disposition header for inline display + content_disposition = build_content_disposition_header(display_filename, inline=True) + + return StreamingResponse( + file_stream, + media_type=content_type, + headers={ + "Content-Disposition": content_disposition, + "Cache-Control": "public, max-age=3600", + "ETag": f'"{object_name}"', + } + ) + + except FileTooLargeException as e: + logger.warning(f"[preview_file] File too large: object_name={object_name}, error={str(e)}") + raise HTTPException( + status_code=HTTPStatus.REQUEST_ENTITY_TOO_LARGE, + detail=str(e) + ) + except NotFoundException as e: + logger.error(f"[preview_file] File not found: object_name={object_name}, error={str(e)}") + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=f"File not found: {object_name}" + ) + except UnsupportedFileTypeException as e: + logger.error(f"[preview_file] Unsupported file type: object_name={object_name}, error={str(e)}") + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"File format not supported for preview: {str(e)}" + ) + except OfficeConversionException as e: + logger.error(f"[preview_file] Conversion failed: object_name={object_name}, error={str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Failed to preview file: {str(e)}" + ) + except Exception as e: + logger.error(f"[preview_file] Unexpected error: object_name={object_name}, error={str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Failed to preview file: {str(e)}" + ) \ No newline at end of file diff --git a/backend/consts/const.py b/backend/consts/const.py index 32404bab4..6249af049 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -36,6 +36,21 @@ class VectorDatabaseType(str, Enum): ROOT_DIR = os.getenv("ROOT_DIR") +# Preview Configuration +FILE_PREVIEW_SIZE_LIMIT = 100 * 1024 * 1024 # 100MB +# Limit concurrent Office-to-PDF conversions +MAX_CONCURRENT_CONVERSIONS = 5 +# Supported Office file MIME types +OFFICE_MIME_TYPES = [ + 'application/msword', # .doc + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', # .docx + 'application/vnd.ms-excel', # .xls + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', # .xlsx + 'application/vnd.ms-powerpoint', # .ppt + 'application/vnd.openxmlformats-officedocument.presentationml.presentation' # .pptx +] + + # Supabase Configuration SUPABASE_URL = os.getenv('SUPABASE_URL') SUPABASE_KEY = os.getenv('SUPABASE_KEY') diff --git a/backend/consts/exceptions.py b/backend/consts/exceptions.py index e9d270673..4e6e78734 100644 --- a/backend/consts/exceptions.py +++ b/backend/consts/exceptions.py @@ -115,6 +115,21 @@ class IncorrectInviteCodeException(Exception): pass +class OfficeConversionException(Exception): + """Raised when Office-to-PDF conversion via data-process service fails.""" + pass + + +class UnsupportedFileTypeException(Exception): + """Raised when a file type is not supported for the requested operation.""" + pass + + +class FileTooLargeException(Exception): + """Raised when a file exceeds the maximum allowed size for the requested operation.""" + pass + + class UserRegistrationException(Exception): """Raised when user registration fails.""" pass diff --git a/backend/database/attachment_db.py b/backend/database/attachment_db.py index d7764b3a2..2e6249468 100644 --- a/backend/database/attachment_db.py +++ b/backend/database/attachment_db.py @@ -169,6 +169,42 @@ def get_file_size_from_minio(object_name: str, bucket: Optional[str] = None) -> return minio_client.get_file_size(object_name, bucket) +def file_exists(object_name: str, bucket: Optional[str] = None) -> bool: + """ + Check if a file exists in the bucket. + + Args: + object_name: Object name in storage + bucket: Bucket name, if not specified will use default bucket + + Returns: + bool: True if file exists, False otherwise + """ + try: + return minio_client.file_exists(object_name, bucket) + except Exception: + return False + + +def copy_file(source_object: str, dest_object: str, bucket: Optional[str] = None) -> Dict[str, Any]: + """ + Copy a file within the same bucket (atomic operation in MinIO). + + Args: + source_object: Source object name + dest_object: Destination object name + bucket: Bucket name, if not specified will use default bucket + + Returns: + Dict[str, Any]: Result containing success flag and error message (if any) + """ + success, result = minio_client.copy_file(source_object, dest_object, bucket) + if success: + return {"success": True, "object_name": result} + else: + return {"success": False, "error": result} + + def list_files(prefix: str = "", bucket: Optional[str] = None) -> List[Dict[str, Any]]: """ List files in bucket @@ -269,6 +305,7 @@ def get_content_type(file_path: str) -> str: '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', '.txt': 'text/plain', '.csv': 'text/csv', + '.md': 'text/markdown', '.html': 'text/html', '.htm': 'text/html', '.json': 'application/json', diff --git a/backend/database/client.py b/backend/database/client.py index c82f78df3..37e5dba03 100644 --- a/backend/database/client.py +++ b/backend/database/client.py @@ -213,6 +213,33 @@ def get_file_stream(self, object_name: str, bucket: Optional[str] = None) -> Tup """ return self._storage_client.get_file_stream(object_name, bucket) + def file_exists(self, object_name: str, bucket: Optional[str] = None) -> bool: + """ + Check if file exists in MinIO + + Args: + object_name: Object name + bucket: Bucket name, if not specified use default bucket + + Returns: + bool: True if file exists, False otherwise + """ + return self._storage_client.exists(object_name, bucket) + + def copy_file(self, source_object: str, dest_object: str, bucket: Optional[str] = None) -> Tuple[bool, str]: + """ + Copy a file within the same bucket (atomic operation) + + Args: + source_object: Source object name + dest_object: Destination object name + bucket: Bucket name, if not specified use default bucket + + Returns: + Tuple[bool, str]: (Success status, Destination object name or error message) + """ + return self._storage_client.copy_file(source_object, dest_object, bucket) + # Create global database and MinIO client instances db_client = PostgresClient() diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py index bce279a4c..8c44c15e6 100644 --- a/backend/services/data_process_service.py +++ b/backend/services/data_process_service.py @@ -4,6 +4,7 @@ import io import logging import os +import shutil import tempfile import threading import time @@ -18,12 +19,18 @@ from transformers import CLIPProcessor, CLIPModel from nexent.data_process.core import DataProcessCore -from consts.const import CLIP_MODEL_PATH, IMAGE_FILTER, REDIS_BACKEND_URL, REDIS_URL +from consts.const import CLIP_MODEL_PATH, IMAGE_FILTER, MAX_CONCURRENT_CONVERSIONS, REDIS_BACKEND_URL, REDIS_URL +from consts.exceptions import OfficeConversionException from consts.model import BatchTaskRequest +from database.attachment_db import delete_file, file_exists, get_file_size_from_minio, get_file_stream, upload_file +from utils.file_management_utils import convert_office_to_pdf from data_process.app import app as celery_app from data_process.tasks import process, forward from data_process.utils import get_task_info, get_all_task_ids_from_redis +# Limit concurrent LibreOffice processes to avoid resource exhaustion +_conversion_semaphore = asyncio.Semaphore(MAX_CONCURRENT_CONVERSIONS) + # Configure logging logger = logging.getLogger("data_process.service") @@ -551,6 +558,81 @@ async def process_uploaded_text_file(self, file_content: bytes, filename: str, c "chunking_strategy": chunking_strategy } + async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: str) -> None: + """Full conversion pipeline: download → convert → upload → validate → cleanup. + + All five steps run inside data-process so that LibreOffice only needs to be + installed in this container. + + Args: + object_name: Source Office file path in MinIO. + pdf_object_name: Destination PDF path in MinIO (final, not temp). + """ + async with _conversion_semaphore: + temp_dir = None + try: + temp_dir = tempfile.mkdtemp(prefix='office_convert_') + + # Step 1: Download original Office file from MinIO + original_stream = get_file_stream(object_name) + if original_stream is None: + raise OfficeConversionException(f"Source file not found in storage: {object_name}") + + original_filename = os.path.basename(object_name) + input_path = os.path.join(temp_dir, original_filename) + with open(input_path, 'wb') as f: + while chunk := original_stream.read(8192): + f.write(chunk) + + # Step 2: Local conversion using LibreOffice + try: + pdf_path = await convert_office_to_pdf(input_path, temp_dir, timeout=30) + except Exception as exc: + raise OfficeConversionException(f"LibreOffice conversion failed: {exc}") from exc + + # Step 3: Upload converted PDF to MinIO + result = upload_file(file_path=pdf_path, object_name=pdf_object_name) + if not result.get('success'): + raise OfficeConversionException( + f"Failed to upload PDF to MinIO: {result.get('error', 'Unknown error')}" + ) + + # Step 4: Validate the uploaded PDF (header check + minimum size) + remote_size = get_file_size_from_minio(pdf_object_name) + if remote_size <= 0: + raise OfficeConversionException("PDF validation failed: cannot read remote file size") + if remote_size < 100: + raise OfficeConversionException( + f"PDF validation failed: file too small ({remote_size} bytes)" + ) + remote_stream = get_file_stream(pdf_object_name) + if remote_stream is None: + raise OfficeConversionException("PDF validation failed: cannot read uploaded file") + try: + header = remote_stream.read(5) + finally: + try: + remote_stream.close() + except Exception: + pass + if not header.startswith(b'%PDF-'): + raise OfficeConversionException("PDF validation failed: invalid PDF header") + + except OfficeConversionException: + # Clean up any partially-uploaded remote PDF so a future retry starts clean + if file_exists(pdf_object_name): + delete_file(pdf_object_name) + raise + except Exception as exc: + raise OfficeConversionException(f"Unexpected error during conversion: {exc}") from exc + finally: + # Step 5: Clean up local temporary directory + if temp_dir and os.path.exists(temp_dir): + try: + shutil.rmtree(temp_dir) + except Exception as cleanup_err: + logger.warning(f"Failed to cleanup temp dir '{temp_dir}': {cleanup_err}") + def convert_celery_states_to_custom(self, process_celery_state: Optional[str], forward_celery_state: Optional[str]) -> str: """Map Celery task states to a custom frontend state string. diff --git a/backend/services/file_management_service.py b/backend/services/file_management_service.py index 8215be810..7c7886bdc 100644 --- a/backend/services/file_management_service.py +++ b/backend/services/file_management_service.py @@ -1,20 +1,33 @@ import asyncio +import hashlib import logging import os from io import BytesIO from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Tuple +import httpx from fastapi import UploadFile -from consts.const import UPLOAD_FOLDER, MAX_CONCURRENT_UPLOADS, MODEL_CONFIG_MAPPING +from consts.const import ( + DATA_PROCESS_SERVICE, + FILE_PREVIEW_SIZE_LIMIT, + MAX_CONCURRENT_UPLOADS, + MODEL_CONFIG_MAPPING, + OFFICE_MIME_TYPES, + UPLOAD_FOLDER, +) +from consts.exceptions import FileTooLargeException, NotFoundException, OfficeConversionException, UnsupportedFileTypeException from database.attachment_db import ( - upload_fileobj, - get_file_url, + copy_file, + delete_file, + file_exists, get_content_type, + get_file_size_from_minio, get_file_stream, - delete_file, - list_files + get_file_url, + list_files, + upload_fileobj, ) from services.vectordatabase_service import ElasticSearchService, get_vector_db_core from utils.config_utils import tenant_config_manager, get_model_name_from_config @@ -28,6 +41,10 @@ upload_dir.mkdir(exist_ok=True) upload_semaphore = asyncio.Semaphore(MAX_CONCURRENT_UPLOADS) +# Per-file locks prevent duplicate conversions of the same file +_conversion_locks: dict[str, asyncio.Lock] = {} +_conversion_locks_guard = asyncio.Lock() + logger = logging.getLogger("file_management_service") @@ -195,4 +212,133 @@ def get_llm_model(tenant_id: str): max_context_tokens=main_model_config.get("max_tokens"), ssl_verify=main_model_config.get("ssl_verify", True), ) - return long_text_to_text_model \ No newline at end of file + return long_text_to_text_model + + +async def preview_file_impl(object_name: str) -> Tuple[BytesIO, str]: + """ + Preview a file by returning its contents as a stream. + + Args: + object_name: File object name in storage + + Returns: + Tuple[BytesIO, str]: (file_stream, content_type) + """ + file_size = get_file_size_from_minio(object_name) + if file_size > FILE_PREVIEW_SIZE_LIMIT: + raise FileTooLargeException( + f"File size {file_size} bytes exceeds the {FILE_PREVIEW_SIZE_LIMIT // (1024 * 1024)} MB preview limit" + ) + + content_type = get_content_type(object_name) + + # PDF, images, and text files - return directly + if content_type == 'application/pdf' or content_type.startswith('image/') or content_type in ['text/plain', 'text/csv', 'text/markdown']: + file_stream = get_file_stream(object_name) + if file_stream is None: + raise NotFoundException("File not found or failed to read from storage") + return file_stream, content_type + + # Office documents - convert to PDF with caching + elif content_type in OFFICE_MIME_TYPES: + name_without_ext = object_name.rsplit('.', 1)[0] if '.' in object_name else object_name + hash_suffix = hashlib.md5(object_name.encode()).hexdigest()[:8] + pdf_object_name = f"preview/converted/{name_without_ext}_{hash_suffix}.pdf" + temp_pdf_object_name = f"preview/converting/{name_without_ext}_{hash_suffix}.pdf.tmp" + + # Fast path: return from cache without acquiring any lock + cached_stream = _get_cached_pdf_stream(pdf_object_name) + if cached_stream is not None: + return cached_stream, 'application/pdf' + + # Slow path: convert with locking + file_stream = await _convert_office_to_cached_pdf(object_name, pdf_object_name, temp_pdf_object_name) + return file_stream, 'application/pdf' + + # Unsupported file type + else: + raise UnsupportedFileTypeException(f"Unsupported file type for preview: {content_type}") + + +def _get_cached_pdf_stream(pdf_object_name: str) -> Optional[BytesIO]: + """ + Return the cached PDF stream if available, or None if missing or corrupted. + + If the file exists but cannot be read, the corrupted entry is deleted so + a subsequent call will trigger a fresh conversion. + """ + if file_exists(pdf_object_name): + file_stream = get_file_stream(pdf_object_name) + if file_stream is None: + logger.warning(f"Corrupted cache detected (cannot read), deleting: {pdf_object_name}") + delete_file(pdf_object_name) + return None + return file_stream + return None + + +async def _convert_office_to_cached_pdf( + object_name: str, + pdf_object_name: str, + temp_pdf_object_name: str, +) -> BytesIO: + """ + Convert an Office document to PDF and store the result in MinIO. + + Args: + object_name: Source Office file path in MinIO + pdf_object_name: Final cached PDF path in MinIO + temp_pdf_object_name: Temporary PDF path used during conversion + + Returns: + BytesIO stream of the converted PDF + """ + # Get or create a lock for this specific file to prevent duplicate conversions + async with _conversion_locks_guard: + if object_name not in _conversion_locks: + _conversion_locks[object_name] = asyncio.Lock() + file_lock = _conversion_locks[object_name] + + async with file_lock: + # Double-check: another request may have completed the conversion while we waited + cached_stream = _get_cached_pdf_stream(pdf_object_name) + if cached_stream is not None: + return cached_stream + + # Conversion semaphore is enforced inside the data-process service + try: + # Request conversion: data-process downloads, converts, uploads to temp path, validates + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + f"{DATA_PROCESS_SERVICE}/tasks/convert_to_pdf", + data={ + "object_name": object_name, + "pdf_object_name": temp_pdf_object_name, + }, + ) + if response.status_code != 200: + raise Exception( + f"data-process conversion returned {response.status_code}: {response.text}" + ) + + # Atomic move from temp to final location, then clean up temp + copy_result = copy_file(source_object=temp_pdf_object_name, dest_object=pdf_object_name) + if not copy_result.get('success'): + raise Exception(f"Failed to finalize PDF cache: {copy_result.get('error', 'Unknown error')}") + delete_file(temp_pdf_object_name) + + except Exception as e: + if file_exists(temp_pdf_object_name): + delete_file(temp_pdf_object_name) + logger.error(f"Office conversion failed: {str(e)}") + raise OfficeConversionException(f"Failed to convert Office document to PDF: {str(e)}") from e + finally: + # Clean up the file lock (prevents memory leak for many unique files) + async with _conversion_locks_guard: + _conversion_locks.pop(object_name, None) + + file_stream = get_file_stream(pdf_object_name) + if file_stream is None: + raise NotFoundException("Converted PDF not found or failed to read from storage") + return file_stream diff --git a/backend/utils/file_management_utils.py b/backend/utils/file_management_utils.py index 2a1aa3801..57025e350 100644 --- a/backend/utils/file_management_utils.py +++ b/backend/utils/file_management_utils.py @@ -1,5 +1,7 @@ +import asyncio import logging import os +import subprocess import traceback from pathlib import Path from typing import List @@ -337,3 +339,66 @@ def get_file_size(source_type: str, path_or_url: str) -> int: logging.error(f"Error getting file size for {path_or_url}: {str(e)}") return 0 + +async def convert_office_to_pdf(input_path: str, output_dir: str, timeout: int = 30) -> str: + """ + Convert Office document to PDF using LibreOffice. + + Args: + input_path: Path to input Office file + output_dir: Directory for output PDF file + timeout: Conversion timeout in seconds (default: 30s) + + Returns: + str: Path to generated PDF file + """ + if not os.path.exists(input_path): + raise FileNotFoundError(f"Input file not found: {input_path}") + + def _run_libreoffice_conversion(): + """Synchronous LibreOffice conversion to run in thread executor.""" + cmd = [ + 'libreoffice', + '--headless', + '--convert-to', 'pdf', + '--outdir', output_dir, + input_path + ] + return subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout + ) + + try: + # Run blocking subprocess in thread executor to avoid blocking event loop + result = await asyncio.to_thread(_run_libreoffice_conversion) + + if result.returncode != 0: + error_msg = result.stderr or result.stdout or "Unknown conversion error" + logger.error(f"LibreOffice conversion failed: {error_msg}") + raise RuntimeError(f"Office to PDF conversion failed: {error_msg}") + + # Find generated PDF file + input_filename = os.path.basename(input_path) + pdf_filename = os.path.splitext(input_filename)[0] + '.pdf' + pdf_path = os.path.join(output_dir, pdf_filename) + + if not os.path.exists(pdf_path): + raise RuntimeError(f"Converted PDF not found: {pdf_path}") + + return pdf_path + + except subprocess.TimeoutExpired: + logger.error(f"Office to PDF conversion timeout after {timeout}s: {input_path}") + raise TimeoutError(f"Office to PDF conversion timeout (>{timeout}s)") + + except FileNotFoundError as e: + # LibreOffice executable not found in PATH + logger.error(f"LibreOffice not available: {str(e)}") + raise FileNotFoundError( + "LibreOffice is not installed or not available in PATH. " + ) from e + + diff --git a/docker/docker-compose.prod.yml b/docker/docker-compose.prod.yml index e9d344461..8eef651ae 100644 --- a/docker/docker-compose.prod.yml +++ b/docker/docker-compose.prod.yml @@ -272,6 +272,7 @@ services: mc admin policy attach myadmin readwrite --user=$MINIO_ACCESS_KEY mc mb myadmin/$MINIO_DEFAULT_BUCKET mc anonymous set download myadmin/$MINIO_DEFAULT_BUCKET + mc ilm rule add myadmin/$MINIO_DEFAULT_BUCKET --prefix 'preview/' --expiry-days 7 --id expire-converted-pdfs wait $$MINIO_PID " diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 221ff0c89..321f29665 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -298,6 +298,7 @@ services: mc admin policy attach myadmin readwrite --user=$MINIO_ACCESS_KEY mc mb myadmin/$MINIO_DEFAULT_BUCKET mc anonymous set download myadmin/$MINIO_DEFAULT_BUCKET + mc ilm rule add myadmin/$MINIO_DEFAULT_BUCKET --prefix 'preview/' --expiry-days 7 --id expire-converted-pdfs wait $$MINIO_PID " diff --git a/make/data_process/Dockerfile b/make/data_process/Dockerfile index 35d7a6c48..7903cfd92 100644 --- a/make/data_process/Dockerfile +++ b/make/data_process/Dockerfile @@ -24,6 +24,9 @@ RUN apt-get update && \ libreoffice \ libgl1 \ coreutils \ + fontconfig \ + fonts-noto-cjk \ + && fc-cache -fv \ && apt-get autoremove -y \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* diff --git a/sdk/nexent/storage/minio.py b/sdk/nexent/storage/minio.py index 8815d8751..3a80b6607 100644 --- a/sdk/nexent/storage/minio.py +++ b/sdk/nexent/storage/minio.py @@ -396,3 +396,35 @@ def exists( except ClientError: return False + def copy_file( + self, + source_object: str, + dest_object: str, + bucket: Optional[str] = None + ) -> Tuple[bool, str]: + """ + Copy a file within the same bucket. + + Args: + source_object: Source object name + dest_object: Destination object name + bucket: Bucket name, if not specified use default bucket + + Returns: + Tuple[bool, str]: (Success status, Destination object name or error message) + """ + bucket = bucket or self.default_bucket + if bucket is None: + return False, "Bucket name is required" + + try: + copy_source = {"Bucket": bucket, "Key": source_object} + self.client.copy_object( + Bucket=bucket, + Key=dest_object, + CopySource=copy_source + ) + return True, dest_object + except Exception as e: + logger.error(f"Failed to copy object {source_object} to {dest_object}: {e}") + return False, str(e) diff --git a/sdk/nexent/storage/storage_client_base.py b/sdk/nexent/storage/storage_client_base.py index 095dc43fc..05623a0c0 100644 --- a/sdk/nexent/storage/storage_client_base.py +++ b/sdk/nexent/storage/storage_client_base.py @@ -217,3 +217,25 @@ def exists( """ pass + @abstractmethod + def copy_file( + self, + source_object: str, + dest_object: str, + bucket: Optional[str] = None + ) -> Tuple[bool, str]: + """ + Copy a file within the same bucket. + + Args: + source_object: Source object name + dest_object: Destination object name + bucket: Bucket name, if not specified use default bucket + + Returns: + Tuple[bool, str]: (Success status, Destination object name or error message) + """ + pass + + + diff --git a/test/backend/app/test_data_process_app.py b/test/backend/app/test_data_process_app.py index acbfe889e..b59b0f817 100644 --- a/test/backend/app/test_data_process_app.py +++ b/test/backend/app/test_data_process_app.py @@ -8,6 +8,18 @@ from fastapi.testclient import TestClient from pydantic import BaseModel +# Install consts.exceptions at module level so OfficeConversionException is bound +# in the app module's namespace on first import. +_exc_mod = types.ModuleType("consts.exceptions") + + +class _OfficeConversionException(Exception): + """Stub exception for Office document conversion failures.""" + + +_exc_mod.OfficeConversionException = _OfficeConversionException # type: ignore[attr-defined] +sys.modules["consts.exceptions"] = _exc_mod + class _TaskRequest(BaseModel): source: str @@ -136,6 +148,14 @@ def convert_celery_states_to_custom(self, process_celery_state: str, forward_cel return "COMPLETED" return "WAIT_FOR_PROCESSING" + async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: str) -> None: + """Stub: raise OfficeConversionException for sentinel inputs, otherwise succeed.""" + from consts.exceptions import OfficeConversionException + if object_name == "fail.docx": + raise OfficeConversionException("conversion failed") + if object_name == "err.docx": + raise RuntimeError("unexpected error") + @pytest.fixture(autouse=True) def stub_modules(monkeypatch): @@ -451,3 +471,47 @@ def raise_convert_http(*args, **kwargs): resp = client.post("/tasks/convert_state", json={"process_state": "PENDING", "forward_state": ""}) assert resp.status_code == HTTPStatus.NOT_ACCEPTABLE + + +def test_convert_to_pdf_success(): + """Valid request returns 200 {success: True}.""" + app = _build_app() + client = TestClient(app) + resp = client.post( + "/tasks/convert_to_pdf", + data={"object_name": "uploads/doc.docx", "pdf_object_name": "converted/doc.pdf"}, + ) + assert resp.status_code == HTTPStatus.OK + assert resp.json()["success"] is True + + +def test_convert_to_pdf_office_conversion_exception(monkeypatch): + """OfficeConversionException from service maps to HTTP 500.""" + app = _build_app() + client = TestClient(app) + # Trigger the sentinel path in _ServiceStub + resp = client.post( + "/tasks/convert_to_pdf", + data={"object_name": "fail.docx", "pdf_object_name": "converted/fail.pdf"}, + ) + assert resp.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + assert "conversion failed" in resp.json()["detail"] + + +def test_convert_to_pdf_unexpected_exception(): + """Unexpected RuntimeError from service also maps to HTTP 500.""" + app = _build_app() + client = TestClient(app) + resp = client.post( + "/tasks/convert_to_pdf", + data={"object_name": "err.docx", "pdf_object_name": "converted/err.pdf"}, + ) + assert resp.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + +def test_convert_to_pdf_missing_params(): + """Missing required form fields returns HTTP 422 Unprocessable Entity.""" + app = _build_app() + client = TestClient(app) + resp = client.post("/tasks/convert_to_pdf", data={}) + assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY diff --git a/test/backend/app/test_file_management_app.py b/test/backend/app/test_file_management_app.py index a337a1434..1165f3d9d 100644 --- a/test/backend/app/test_file_management_app.py +++ b/test/backend/app/test_file_management_app.py @@ -55,6 +55,12 @@ async def _stub_preprocess_files_generator(*_: Any, **__: Any) -> AsyncGenerator yield "data: {\"type\": \"progress\", \"progress\": 0}\n\n" yield "data: {\"type\": \"complete\", \"progress\": 100}\n\n" +async def _stub_preview_file_impl(object_name: str): + """Default stub for preview_file_impl""" + from io import BytesIO + return BytesIO(b"PDF content"), "application/pdf" + +sfms_stub.preview_file_impl = _stub_preview_file_impl sfms_stub.upload_to_minio = _stub_upload_to_minio sfms_stub.upload_files_impl = _stub_upload_files_impl sfms_stub.get_file_url_impl = _stub_get_file_url_impl @@ -101,9 +107,22 @@ def __init__(self, chunking_strategy: str, source_type: str, index_name: str, au self.index_name = index_name self.authorization = authorization model_stub.ProcessParams = ProcessParams -sys.modules["consts.model"] = model_stub +sys.modules.setdefault("consts.model", model_stub) setattr(consts_pkg, "model", model_stub) +# Stub consts.exceptions with real exception classes so isinstance checks work +exceptions_stub = types.ModuleType("consts.exceptions") +class NotFoundException(Exception): pass +class OfficeConversionException(Exception): pass +class UnsupportedFileTypeException(Exception): pass +class FileTooLargeException(Exception): pass +exceptions_stub.NotFoundException = NotFoundException +exceptions_stub.OfficeConversionException = OfficeConversionException +exceptions_stub.UnsupportedFileTypeException = UnsupportedFileTypeException +exceptions_stub.FileTooLargeException = FileTooLargeException +sys.modules["consts.exceptions"] = exceptions_stub +setattr(consts_pkg, "exceptions", exceptions_stub) + # Import the module under test after stubbing deps file_management_app = __import__( @@ -444,6 +463,40 @@ def boom(_value: str, safe: str = "") -> str: assert 'filename*=UTF-8' not in result +def test_build_content_disposition_header_inline_ascii(): + """Test build_content_disposition_header with inline=True for ASCII filename""" + result = file_management_app.build_content_disposition_header("test.pdf", inline=True) + assert result == 'inline; filename="test.pdf"' + assert 'attachment' not in result + + +def test_build_content_disposition_header_inline_non_ascii(): + """Test build_content_disposition_header with inline=True for non-ASCII filename""" + result = file_management_app.build_content_disposition_header("测试文档.pdf", inline=True) + assert 'inline; filename=' in result + assert 'attachment' not in result + assert 'filename*=UTF-8' in result + + +def test_build_content_disposition_header_inline_false_explicit(): + """Test build_content_disposition_header with inline=False explicitly""" + result = file_management_app.build_content_disposition_header("test.pdf", inline=False) + assert result == 'attachment; filename="test.pdf"' + assert 'inline' not in result + + +def test_build_content_disposition_header_inline_exception_handling(monkeypatch): + """Test build_content_disposition_header inline mode exception handling""" + def boom(_value: str, safe: str = "") -> str: + raise RuntimeError("quote failure") + + monkeypatch.setattr("backend.apps.file_management_app.quote", boom) + + result = file_management_app.build_content_disposition_header("中文.pdf", inline=True) + assert 'inline; filename=' in result + assert 'attachment' not in result + + # --- Tests for get_storage_file with filename parameter --- @pytest.mark.asyncio @@ -872,3 +925,252 @@ def test_build_datamate_url_from_parts_empty_base_url(): assert "base_url is required" in str(ei.value) +# --- Tests for preview_file endpoint --- + +@pytest.mark.asyncio +async def test_preview_file_pdf_success(monkeypatch): + """Test previewing a PDF file returns StreamingResponse with inline disposition""" + from io import BytesIO + + async def fake_preview(object_name): + return BytesIO(b"PDF content"), "application/pdf" + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + resp = await file_management_app.preview_file( + object_name="documents/test.pdf", + filename="test.pdf" + ) + + assert resp.media_type == "application/pdf" + content_disposition = resp.headers.get("content-disposition", "") + assert "inline" in content_disposition + assert "test.pdf" in content_disposition + assert resp.headers.get("cache-control") == "public, max-age=3600" + + +@pytest.mark.asyncio +async def test_preview_file_image_success(monkeypatch): + """Test previewing an image file returns correct content type""" + from io import BytesIO + + async def fake_preview(object_name): + return BytesIO(b"PNG image data"), "image/png" + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + resp = await file_management_app.preview_file( + object_name="images/photo.png", + filename="photo.png" + ) + + assert resp.media_type == "image/png" + content_disposition = resp.headers.get("content-disposition", "") + assert "inline" in content_disposition + + +@pytest.mark.asyncio +async def test_preview_file_text_success(monkeypatch): + """Test previewing a text file returns correct content type""" + from io import BytesIO + + async def fake_preview(object_name): + return BytesIO(b"Hello World"), "text/plain" + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + resp = await file_management_app.preview_file( + object_name="files/readme.txt", + filename="readme.txt" + ) + + assert resp.media_type == "text/plain" + content_disposition = resp.headers.get("content-disposition", "") + assert "inline" in content_disposition + + +@pytest.mark.asyncio +async def test_preview_file_without_filename_extracts_from_path(monkeypatch): + """Test previewing without filename parameter extracts name from object_name""" + from io import BytesIO + + async def fake_preview(object_name): + return BytesIO(b"PDF content"), "application/pdf" + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + resp = await file_management_app.preview_file( + object_name="folder/subfolder/document.pdf", + filename=None + ) + + content_disposition = resp.headers.get("content-disposition", "") + assert "document.pdf" in content_disposition + + +@pytest.mark.asyncio +async def test_preview_file_chinese_filename(monkeypatch): + """Test previewing with Chinese filename uses UTF-8 encoding""" + from io import BytesIO + + async def fake_preview(object_name): + return BytesIO(b"PDF content"), "application/pdf" + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + resp = await file_management_app.preview_file( + object_name="documents/test.pdf", + filename="测试文档.pdf" + ) + + content_disposition = resp.headers.get("content-disposition", "") + assert "inline" in content_disposition + assert "filename*=UTF-8" in content_disposition or "测试文档" in content_disposition + + +@pytest.mark.asyncio +async def test_preview_file_not_found_error(monkeypatch): + """Test previewing a non-existent file returns 404""" + async def fake_preview(object_name): + raise Exception("File not found") + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + with pytest.raises(Exception) as ei: + await file_management_app.preview_file( + object_name="nonexistent/file.pdf", + filename=None + ) + assert "File not found" in str(ei.value) + + +@pytest.mark.asyncio +async def test_preview_file_too_large_error(monkeypatch): + """Test previewing a file exceeding size limit returns 413""" + _FileTooLargeException = sys.modules["consts.exceptions"].FileTooLargeException + + async def fake_preview(object_name): + raise _FileTooLargeException("File size 110 MB exceeds the 100 MB preview limit") + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + with pytest.raises(Exception) as ei: + await file_management_app.preview_file( + object_name="files/huge.pdf", + filename=None + ) + assert "100 MB" in str(ei.value) + + +@pytest.mark.asyncio +async def test_preview_file_unsupported_format_error(monkeypatch): + """Test previewing an unsupported file format returns 400""" + _UnsupportedFileTypeException = sys.modules["consts.exceptions"].UnsupportedFileTypeException + + async def fake_preview(object_name): + raise _UnsupportedFileTypeException("Unsupported file format for preview") + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + with pytest.raises(Exception) as ei: + await file_management_app.preview_file( + object_name="files/archive.zip", + filename=None + ) + assert "not supported for preview" in str(ei.value) + + +@pytest.mark.asyncio +async def test_preview_file_internal_error(monkeypatch): + """Test previewing with internal error returns 500""" + async def fake_preview(object_name): + raise Exception("Internal server error") + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + with pytest.raises(Exception) as ei: + await file_management_app.preview_file( + object_name="files/test.pdf", + filename=None + ) + assert "Failed to preview file" in str(ei.value) + + +@pytest.mark.asyncio +async def test_preview_file_office_converted_to_pdf(monkeypatch): + """Test previewing an Office document returns converted PDF""" + from io import BytesIO + + async def fake_preview(object_name): + # Office documents are converted to PDF by preview_file_impl + return BytesIO(b"Converted PDF content"), "application/pdf" + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + resp = await file_management_app.preview_file( + object_name="documents/report.docx", + filename="report.docx" + ) + + # Content type should be PDF after conversion + assert resp.media_type == "application/pdf" + content_disposition = resp.headers.get("content-disposition", "") + assert "inline" in content_disposition + + +@pytest.mark.asyncio +async def test_preview_file_has_etag_header(monkeypatch): + """Test preview response includes ETag header for caching""" + from io import BytesIO + + async def fake_preview(object_name): + return BytesIO(b"PDF content"), "application/pdf" + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + resp = await file_management_app.preview_file( + object_name="documents/test.pdf", + filename="test.pdf" + ) + + etag = resp.headers.get("etag", "") + assert "documents/test.pdf" in etag + + +@pytest.mark.asyncio +async def test_preview_file_simple_object_name_without_slash(monkeypatch): + """Test previewing with simple object name without slash""" + from io import BytesIO + + async def fake_preview(object_name): + return BytesIO(b"PDF content"), "application/pdf" + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + resp = await file_management_app.preview_file( + object_name="simple.pdf", + filename=None + ) + + content_disposition = resp.headers.get("content-disposition", "") + assert "simple.pdf" in content_disposition + + +@pytest.mark.asyncio +async def test_preview_file_does_not_exist_error(monkeypatch): + """Test previewing with 'does not exist' error message returns 404""" + _NotFoundException = sys.modules["consts.exceptions"].NotFoundException + + async def fake_preview(object_name): + raise _NotFoundException("The specified key does not exist") + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + with pytest.raises(Exception) as ei: + await file_management_app.preview_file( + object_name="missing/file.pdf", + filename=None + ) + assert "File not found" in str(ei.value) + + diff --git a/test/backend/database/test_attachment_db.py b/test/backend/database/test_attachment_db.py index 4053877fe..7abdd3a07 100644 --- a/test/backend/database/test_attachment_db.py +++ b/test/backend/database/test_attachment_db.py @@ -26,6 +26,14 @@ boto3_mock = MagicMock() sys.modules['boto3'] = boto3_mock +# Mock minio module +minio_mock = MagicMock() +minio_commonconfig_mock = MagicMock() +minio_commonconfig_mock.CopySource = MagicMock() +minio_mock.commonconfig = minio_commonconfig_mock +sys.modules['minio'] = minio_mock +sys.modules['minio.commonconfig'] = minio_commonconfig_mock + # Mock nexent.storage modules nexent_mock = MagicMock() nexent_storage_mock = MagicMock() @@ -58,6 +66,8 @@ download_file, get_file_url, get_file_size_from_minio, + file_exists, + copy_file, list_files, delete_file, get_file_stream, @@ -485,3 +495,82 @@ def test_get_content_type_case_insensitive(self): assert get_content_type('test.PNG') == 'image/png' assert get_content_type('test.PDF') == 'application/pdf' + +class TestFileExists: + """Test cases for file_exists function""" + + def test_file_exists_returns_true_when_file_exists(self): + """Test file_exists returns True when file exists in bucket""" + with patch('backend.database.attachment_db.minio_client') as mock_client: + mock_client.file_exists.return_value = True + + result = file_exists('test/file.txt') + + assert result is True + mock_client.file_exists.assert_called_once_with('test/file.txt', None) + + def test_file_exists_returns_false_when_file_not_exists(self): + """Test file_exists returns False when file does not exist""" + with patch('backend.database.attachment_db.minio_client') as mock_client: + mock_client.file_exists.return_value = False + + result = file_exists('nonexistent/file.txt') + + assert result is False + mock_client.file_exists.assert_called_once_with('nonexistent/file.txt', None) + + def test_file_exists_with_custom_bucket(self): + """Test file_exists with custom bucket parameter""" + with patch('backend.database.attachment_db.minio_client') as mock_client: + mock_client.file_exists.return_value = True + + result = file_exists('test/file.txt', bucket='custom-bucket') + + assert result is True + mock_client.file_exists.assert_called_once_with('test/file.txt', 'custom-bucket') + + def test_file_exists_handles_any_exception(self): + """Test file_exists handles any exception and returns False""" + with patch('backend.database.attachment_db.minio_client') as mock_client: + mock_client.file_exists.side_effect = RuntimeError('Connection failed') + + result = file_exists('test/file.txt') + + assert result is False + mock_client.file_exists.assert_called_once_with('test/file.txt', None) + + +class TestCopyFile: + """Test cases for copy_file function""" + + def test_copy_file_success(self): + """Test successful file copy""" + with patch('backend.database.attachment_db.minio_client') as mock_client: + mock_client.copy_file.return_value = (True, 'dest/file.pdf') + + result = copy_file('source/file.pdf', 'dest/file.pdf') + + assert result['success'] is True + assert result['object_name'] == 'dest/file.pdf' + mock_client.copy_file.assert_called_once_with('source/file.pdf', 'dest/file.pdf', None) + + def test_copy_file_with_custom_bucket(self): + """Test copy_file with custom bucket""" + with patch('backend.database.attachment_db.minio_client') as mock_client: + mock_client.copy_file.return_value = (True, 'dest/file.pdf') + + result = copy_file('source/file.pdf', 'dest/file.pdf', bucket='custom-bucket') + + assert result['success'] is True + mock_client.copy_file.assert_called_once_with('source/file.pdf', 'dest/file.pdf', 'custom-bucket') + + def test_copy_file_failure(self): + """Test copy_file handles errors""" + with patch('backend.database.attachment_db.minio_client') as mock_client: + mock_client.copy_file.return_value = (False, 'Copy failed') + + result = copy_file('source/file.pdf', 'dest/file.pdf') + + assert result['success'] is False + assert 'Copy failed' in result['error'] + diff --git a/test/backend/database/test_client.py b/test/backend/database/test_client.py index b11c7f998..9514fb143 100644 --- a/test/backend/database/test_client.py +++ b/test/backend/database/test_client.py @@ -346,6 +346,107 @@ def test_minio_client_get_file_stream(self, mock_config_class, mock_create_clien mock_storage_client.get_file_stream.assert_called_once_with( 'file.txt', 'bucket') + @patch('backend.database.client.create_storage_client_from_config') + @patch('backend.database.client.MinIOStorageConfig') + def test_minio_client_file_exists_true(self, mock_config_class, mock_create_client): + """Test MinioClient.file_exists returns True when file exists""" + MinioClient._instance = None + + mock_storage_client = MagicMock() + mock_storage_client.exists.return_value = True + mock_create_client.return_value = mock_storage_client + mock_config_class.return_value = MagicMock() + + client = MinioClient() + result = client.file_exists('file.txt', 'bucket') + + assert result is True + mock_storage_client.exists.assert_called_once_with('file.txt', 'bucket') + + @patch('backend.database.client.create_storage_client_from_config') + @patch('backend.database.client.MinIOStorageConfig') + def test_minio_client_file_exists_false(self, mock_config_class, mock_create_client): + """Test MinioClient.file_exists returns False when file does not exist""" + MinioClient._instance = None + + mock_storage_client = MagicMock() + mock_storage_client.exists.return_value = False + mock_create_client.return_value = mock_storage_client + mock_config_class.return_value = MagicMock() + + client = MinioClient() + result = client.file_exists('file.txt', 'bucket') + + assert result is False + mock_storage_client.exists.assert_called_once_with('file.txt', 'bucket') + + @patch('backend.database.client.create_storage_client_from_config') + @patch('backend.database.client.MinIOStorageConfig') + def test_minio_client_copy_file_success(self, mock_config_class, mock_create_client): + """Test MinioClient.copy_file successfully copies file""" + MinioClient._instance = None + + mock_storage_client = MagicMock() + mock_storage_client.copy_file.return_value = (True, 'dest/file.pdf') + mock_create_client.return_value = mock_storage_client + mock_config = MagicMock() + mock_config.default_bucket = 'test-bucket' + mock_config_class.return_value = mock_config + + client = MinioClient() + success, result = client.copy_file('source/file.pdf', 'dest/file.pdf', 'bucket') + + assert success is True + assert result == 'dest/file.pdf' + mock_storage_client.copy_file.assert_called_once_with( + 'source/file.pdf', + 'dest/file.pdf', + 'bucket' + ) + + @patch('backend.database.client.create_storage_client_from_config') + @patch('backend.database.client.MinIOStorageConfig') + def test_minio_client_copy_file_with_default_bucket(self, mock_config_class, mock_create_client): + """Test MinioClient.copy_file uses default bucket when not specified""" + MinioClient._instance = None + + mock_storage_client = MagicMock() + mock_storage_client.copy_file.return_value = (True, 'dest/file.pdf') + mock_create_client.return_value = mock_storage_client + mock_config = MagicMock() + mock_config.default_bucket = 'default-bucket' + mock_config_class.return_value = mock_config + + client = MinioClient() + success, result = client.copy_file('source/file.pdf', 'dest/file.pdf') + + assert success is True + assert result == 'dest/file.pdf' + mock_storage_client.copy_file.assert_called_once_with( + 'source/file.pdf', + 'dest/file.pdf', + None + ) + + @patch('backend.database.client.create_storage_client_from_config') + @patch('backend.database.client.MinIOStorageConfig') + def test_minio_client_copy_file_failure(self, mock_config_class, mock_create_client): + """Test MinioClient.copy_file handles errors properly""" + MinioClient._instance = None + + mock_storage_client = MagicMock() + mock_storage_client.copy_file.return_value = (False, 'Copy failed') + mock_create_client.return_value = mock_storage_client + mock_config = MagicMock() + mock_config.default_bucket = 'test-bucket' + mock_config_class.return_value = mock_config + + client = MinioClient() + success, result = client.copy_file('source/file.pdf', 'dest/file.pdf') + + assert success is False + assert 'Copy failed' in result + class TestGetDbSession: """Test cases for get_db_session context manager""" diff --git a/test/backend/services/test_data_process_service.py b/test/backend/services/test_data_process_service.py index 02c3b11a3..6d33e097a 100644 --- a/test/backend/services/test_data_process_service.py +++ b/test/backend/services/test_data_process_service.py @@ -4,6 +4,7 @@ import io import base64 import asyncio +import types from unittest.mock import patch, MagicMock, AsyncMock import warnings from PIL import Image @@ -42,8 +43,27 @@ mock_const.IMAGE_FILTER = True mock_const.REDIS_BACKEND_URL = "redis://mock:6379/0" mock_const.REDIS_URL = "redis://mock:6379/0" +mock_const.MAX_CONCURRENT_CONVERSIONS = 3 sys.modules['consts.const'] = mock_const +# Stub consts.exceptions with a *real* exception class so assertRaises works correctly +_exceptions_mod = types.ModuleType('consts.exceptions') + + +class OfficeConversionException(Exception): + """Stub OfficeConversionException used in tests.""" + + +_exceptions_mod.OfficeConversionException = OfficeConversionException +sys.modules['consts.exceptions'] = _exceptions_mod + +# Stub utils.file_management_utils (new import in data_process_service) +if 'utils.file_management_utils' not in sys.modules: + import types as _types + _utils_mod = _types.ModuleType('utils.file_management_utils') + _utils_mod.convert_office_to_pdf = AsyncMock() + sys.modules['utils.file_management_utils'] = _utils_mod + # from backend.services.data_process_service import DataProcessService, get_data_process_service with patch('data_process.utils.get_task_info') as mock_get_task_info, \ patch('data_process.utils.get_all_task_ids_from_redis') as mock_get_redis_task_ids: @@ -51,6 +71,21 @@ class TestDataProcessService(unittest.TestCase): + + class _NopSemaphore: + """Drop-in asyncio.Semaphore that never blocks. + + asyncio.Semaphore is bound to the event loop at creation time; using + asyncio.run() in tests creates a new loop each time, so the module-level + semaphore would deadlock. This stub avoids that issue completely. + """ + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return False + def setUp(self): """Set up test environment before each test""" # Create a clean instance for each test @@ -60,18 +95,32 @@ def setUp(self): # Suppress warnings during tests warnings.filterwarnings('ignore', category=UserWarning) + # Replace module-level semaphore with a no-op to avoid asyncio loop issues + import backend.services.data_process_service as _dm + self._dm = _dm + self._orig_sem = _dm._conversion_semaphore + self._nop_sem = TestDataProcessService._NopSemaphore() + _dm._conversion_semaphore = self._nop_sem + # Reset mocks for each test to prevent interference - # Do not import data_process.app here - use the already mocked module mock_celery_app = sys.modules['data_process.app'].app mock_celery_app.reset_mock() self.mock_celery_app = mock_celery_app def tearDown(self): """Clean up after each test""" + # Restore the original semaphore + self._dm._conversion_semaphore = self._orig_sem # Restore environment variables os.environ.clear() os.environ.update(self.original_env) + @staticmethod + def _make_stream(data: bytes): + """Return a BytesIO stream containing *data*.""" + from io import BytesIO + return BytesIO(data) + @patch('backend.services.data_process_service.redis.ConnectionPool.from_url') @patch('backend.services.data_process_service.redis.Redis') def test_init_redis_client_with_url(self, mock_redis, mock_pool): @@ -2162,5 +2211,143 @@ def test_convert_to_base64(self): asyncio.run(self.async_test_convert_to_base64()) + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.upload_file') + @patch('backend.services.data_process_service.get_file_size_from_minio') + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_success( + self, _exists, _mkdtemp, mock_rmtree, + mock_get_stream, mock_get_size, mock_upload, mock_convert + ): + """Happy path: full pipeline completes and temp dir is cleaned up.""" + mock_get_stream.side_effect = [ + self._make_stream(b'DOC data'), # Step 1: original file + self._make_stream(b'%PDF-1.4 ok'), # Step 4: header check + ] + mock_get_size.return_value = 208 + mock_upload.return_value = {'success': True} + mock_convert.return_value = '/tmp/test_cv/doc.pdf' + + with patch('builtins.open', MagicMock()): + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + + mock_convert.assert_called_once() + mock_rmtree.assert_called_once_with('/tmp/test_cv') + + @patch('backend.services.data_process_service.get_file_stream', + return_value=None) + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_source_not_found( + self, _exists, _mkdtemp, mock_rmtree, _get_stream + ): + """Source file missing → OfficeConversionException.""" + # Prevent cleanup path from calling real delete_file + sys.modules['database.attachment_db'].file_exists = MagicMock( + return_value=False + ) + with self.assertRaises(OfficeConversionException) as ctx: + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/missing.docx', 'converted/missing.pdf' + ) + ) + self.assertIn('Source file not found', str(ctx.exception)) + + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_libreoffice_failure( + self, _exists, _mkdtemp, mock_rmtree, mock_get_stream, mock_convert + ): + """LibreOffice error → OfficeConversionException.""" + mock_get_stream.return_value = self._make_stream(b'DOC data') + mock_convert.side_effect = RuntimeError('soffice not found') + sys.modules['database.attachment_db'].file_exists = MagicMock( + return_value=False + ) + with patch('builtins.open', MagicMock()): + with self.assertRaises(OfficeConversionException) as ctx: + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + self.assertIn('LibreOffice conversion failed', str(ctx.exception)) + + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.upload_file') + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_upload_failure( + self, _exists, _mkdtemp, mock_rmtree, + mock_get_stream, mock_upload, mock_convert + ): + """Upload failure → OfficeConversionException with error detail.""" + mock_get_stream.return_value = self._make_stream(b'DOC data') + mock_convert.return_value = '/tmp/test_cv/doc.pdf' + mock_upload.return_value = {'success': False, 'error': 'quota exceeded'} + sys.modules['database.attachment_db'].file_exists = MagicMock( + return_value=False + ) + with patch('builtins.open', MagicMock()): + with self.assertRaises(OfficeConversionException) as ctx: + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + self.assertIn('Failed to upload PDF', str(ctx.exception)) + + @patch('backend.services.data_process_service.delete_file') + @patch('backend.services.data_process_service.file_exists', return_value=True) + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.upload_file') + @patch('backend.services.data_process_service.get_file_size_from_minio') + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_invalid_pdf_header( + self, _exists, _mkdtemp, mock_rmtree, + mock_get_stream, mock_get_size, mock_upload, mock_convert, + mock_file_exists, mock_delete_file + ): + """Invalid PDF header → OfficeConversionException; remote file deleted.""" + mock_get_stream.side_effect = [ + self._make_stream(b'DOC data'), # Step 1: original file + self._make_stream(b'NOT-PDF'), # Step 4: header check + ] + mock_get_size.return_value = 208 + mock_upload.return_value = {'success': True} + mock_convert.return_value = '/tmp/test_cv/doc.pdf' + + with patch('builtins.open', MagicMock()): + with self.assertRaises(OfficeConversionException) as ctx: + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + self.assertIn('invalid PDF header', str(ctx.exception)) + mock_delete_file.assert_called_once_with('converted/doc.pdf') + + if __name__ == '__main__': unittest.main() diff --git a/test/backend/services/test_file_management_service.py b/test/backend/services/test_file_management_service.py index f46f87f13..cc02add6d 100644 --- a/test/backend/services/test_file_management_service.py +++ b/test/backend/services/test_file_management_service.py @@ -82,6 +82,7 @@ def setup_patches(): patch('backend.database.attachment_db.get_file_stream', MagicMock()), patch('backend.database.attachment_db.delete_file', MagicMock()), patch('backend.database.attachment_db.list_files', MagicMock()), + patch('backend.services.file_management_service.get_file_size_from_minio', MagicMock(return_value=0)), patch('backend.services.file_management_service.save_upload_file', AsyncMock()), patch('backend.services.file_management_service.upload_semaphore', MagicMock()), patch('backend.services.file_management_service.upload_dir', @@ -1011,3 +1012,441 @@ def test_get_llm_model_with_different_tenant_ids(self, mock_tenant_config, mock_ assert mock_tenant_config.get_model_config.call_count == 2 assert mock_tenant_config.get_model_config.call_args_list[0][1]["tenant_id"] == "tenant1" assert mock_tenant_config.get_model_config.call_args_list[1][1]["tenant_id"] == "tenant2" + + +class TestPreviewFileImpl: + """Test cases for preview_file_impl function""" + + @pytest.mark.asyncio + async def test_preview_pdf_file_success(self): + """Test previewing a PDF file returns stream directly""" + from backend.services.file_management_service import preview_file_impl + + mock_stream = BytesIO(b"PDF content") + + with patch('backend.services.file_management_service.get_content_type', return_value='application/pdf'), \ + patch('backend.services.file_management_service.get_file_stream', return_value=mock_stream): + + result_stream, result_type = await preview_file_impl("test/document.pdf") + + assert result_type == 'application/pdf' + assert result_stream == mock_stream + + @pytest.mark.asyncio + async def test_preview_image_file_success(self): + """Test previewing an image file returns stream directly""" + from backend.services.file_management_service import preview_file_impl + + mock_stream = BytesIO(b"PNG content") + + with patch('backend.services.file_management_service.get_content_type', return_value='image/png'), \ + patch('backend.services.file_management_service.get_file_stream', return_value=mock_stream): + + result_stream, result_type = await preview_file_impl("test/image.png") + + assert result_type == 'image/png' + assert result_stream == mock_stream + + @pytest.mark.asyncio + async def test_preview_text_file_success(self): + """Test previewing a text file returns stream directly""" + from backend.services.file_management_service import preview_file_impl + + mock_stream = BytesIO(b"Text content") + + with patch('backend.services.file_management_service.get_content_type', return_value='text/plain'), \ + patch('backend.services.file_management_service.get_file_stream', return_value=mock_stream): + + result_stream, result_type = await preview_file_impl("test/readme.txt") + + assert result_type == 'text/plain' + assert result_stream == mock_stream + + @pytest.mark.asyncio + async def test_preview_csv_file_success(self): + """Test previewing a CSV file returns stream directly""" + from backend.services.file_management_service import preview_file_impl + + mock_stream = BytesIO(b"col1,col2\nval1,val2") + + with patch('backend.services.file_management_service.get_content_type', return_value='text/csv'), \ + patch('backend.services.file_management_service.get_file_stream', return_value=mock_stream): + + result_stream, result_type = await preview_file_impl("test/data.csv") + + assert result_type == 'text/csv' + assert result_stream == mock_stream + + @pytest.mark.asyncio + async def test_preview_markdown_file_success(self): + """Test previewing a Markdown file returns stream directly""" + from backend.services.file_management_service import preview_file_impl + + mock_stream = BytesIO(b"# Heading\nContent") + + with patch('backend.services.file_management_service.get_content_type', return_value='text/markdown'), \ + patch('backend.services.file_management_service.get_file_stream', return_value=mock_stream): + + result_stream, result_type = await preview_file_impl("test/readme.md") + + assert result_type == 'text/markdown' + assert result_stream == mock_stream + + @pytest.mark.asyncio + async def test_preview_office_docx_with_cache_hit(self): + """Test previewing a Word document with cached PDF available""" + from backend.services.file_management_service import preview_file_impl + + mock_pdf_stream = BytesIO(b"Cached PDF content") + + with patch('backend.services.file_management_service.get_content_type', + return_value='application/vnd.openxmlformats-officedocument.wordprocessingml.document'), \ + patch('backend.services.file_management_service.file_exists', return_value=True), \ + patch('backend.services.file_management_service.get_file_stream', return_value=mock_pdf_stream): + + result_stream, result_type = await preview_file_impl("test/document.docx") + + assert result_type == 'application/pdf' + assert result_stream == mock_pdf_stream + + @pytest.mark.asyncio + async def test_preview_office_docx_cache_miss_convert_success(self): + """Cache miss: delegates conversion to data-process via HTTP, then serves resulting PDF.""" + from backend.services.file_management_service import preview_file_impl + + mock_pdf_stream = BytesIO(b"%PDF-1.4 converted content") + + # Simulate data-process returning HTTP 200 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "" + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + + mock_http_ctx = MagicMock() + mock_http_ctx.__aenter__ = AsyncMock(return_value=mock_client) + mock_http_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch('backend.services.file_management_service.get_content_type', + return_value='application/vnd.openxmlformats-officedocument.wordprocessingml.document'), \ + patch('backend.services.file_management_service.file_exists', return_value=False), \ + patch('backend.services.file_management_service.get_file_stream', + return_value=mock_pdf_stream), \ + patch('httpx.AsyncClient', return_value=mock_http_ctx), \ + patch('backend.services.file_management_service.copy_file', + return_value={'success': True}), \ + patch('backend.services.file_management_service.delete_file'): + + result_stream, result_type = await preview_file_impl("test/document.docx") + + assert result_type == 'application/pdf' + assert result_stream == mock_pdf_stream + mock_client.post.assert_called_once() + url_called = mock_client.post.call_args[0][0] + assert "convert_to_pdf" in url_called + + @pytest.mark.asyncio + async def test_preview_office_conversion_failure(self): + """HTTP error from data-process service propagates as conversion failure.""" + from backend.services.file_management_service import preview_file_impl + + # Simulate data-process returning HTTP 500 + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + + mock_http_ctx = MagicMock() + mock_http_ctx.__aenter__ = AsyncMock(return_value=mock_client) + mock_http_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch('backend.services.file_management_service.get_content_type', + return_value='application/vnd.openxmlformats-officedocument.wordprocessingml.document'), \ + patch('backend.services.file_management_service.file_exists', return_value=False), \ + patch('httpx.AsyncClient', return_value=mock_http_ctx), \ + patch('backend.services.file_management_service.delete_file'): + + with pytest.raises(Exception) as exc_info: + await preview_file_impl("test/document.docx") + + assert "Failed to convert Office document to PDF" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_preview_unsupported_file_type(self): + """Test previewing an unsupported file type raises exception""" + from backend.services.file_management_service import preview_file_impl + + with patch('backend.services.file_management_service.get_content_type', + return_value='application/octet-stream'): + + with pytest.raises(Exception) as exc_info: + await preview_file_impl("test/unknown.bin") + + assert "Unsupported file type for preview" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_preview_file_not_found(self): + """Test previewing a non-existent file raises exception""" + from backend.services.file_management_service import preview_file_impl + + with patch('backend.services.file_management_service.get_content_type', return_value='application/pdf'), \ + patch('backend.services.file_management_service.get_file_stream', return_value=None): + + with pytest.raises(Exception) as exc_info: + await preview_file_impl("test/nonexistent.pdf") + + assert "File not found" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_preview_file_too_large(self): + """Test that files exceeding FILE_PREVIEW_SIZE_LIMIT raise FileTooLargeException""" + from backend.services.file_management_service import preview_file_impl, FILE_PREVIEW_SIZE_LIMIT + + oversized = FILE_PREVIEW_SIZE_LIMIT + 1 + with patch('backend.services.file_management_service.get_file_size_from_minio', return_value=oversized): + with pytest.raises(Exception) as exc_info: + await preview_file_impl("test/large_file.pdf") + + assert str(FILE_PREVIEW_SIZE_LIMIT // (1024 * 1024)) in str(exc_info.value) + + @pytest.mark.asyncio + @pytest.mark.parametrize("content_type,expected_direct", [ + ('application/pdf', True), + ('image/jpeg', True), + ('image/png', True), + ('image/gif', True), + ('image/webp', True), + ('text/plain', True), + ('text/csv', True), + ('text/markdown', True), + ('application/vnd.openxmlformats-officedocument.wordprocessingml.document', False), + ('application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', False), + ('application/vnd.openxmlformats-officedocument.presentationml.presentation', False), + ('application/msword', False), + ('application/vnd.ms-excel', False), + ('application/vnd.ms-powerpoint', False), + ]) + async def test_preview_file_type_routing(self, content_type, expected_direct): + """Test that different file types are routed correctly""" + from backend.services.file_management_service import preview_file_impl + + mock_stream = BytesIO(b"test content") + get_stream_call_count = 0 + + def mock_get_file_stream(object_name): + nonlocal get_stream_call_count + get_stream_call_count += 1 + return mock_stream + + with patch('backend.services.file_management_service.get_content_type', return_value=content_type), \ + patch('backend.services.file_management_service.file_exists', return_value=True), \ + patch('backend.services.file_management_service.get_file_stream', side_effect=mock_get_file_stream): + + result_stream, result_type = await preview_file_impl("test/file") + + assert result_stream == mock_stream + if expected_direct: + # Direct file types should call get_file_stream once + assert get_stream_call_count == 1 + assert result_type == content_type + else: + # Office files return PDF type + assert result_type == 'application/pdf' + + +class TestGetCachedPdfStream: + """Unit tests for _get_cached_pdf_stream helper.""" + + def test_returns_stream_when_cache_valid(self): + """Returns the stream when file exists and is readable.""" + from backend.services.file_management_service import _get_cached_pdf_stream + + mock_stream = BytesIO(b"%PDF-1.4") + with patch('backend.services.file_management_service.file_exists', return_value=True), \ + patch('backend.services.file_management_service.get_file_stream', return_value=mock_stream): + result = _get_cached_pdf_stream("preview/converted/doc_abc12345.pdf") + assert result is mock_stream + + def test_returns_none_when_file_not_exist(self): + """Returns None immediately when the cached file does not exist.""" + from backend.services.file_management_service import _get_cached_pdf_stream + + with patch('backend.services.file_management_service.file_exists', return_value=False): + result = _get_cached_pdf_stream("preview/converted/doc_abc12345.pdf") + assert result is None + + def test_deletes_and_returns_none_when_cache_corrupted(self): + """Deletes the corrupted cache entry and returns None when stream cannot be read.""" + from backend.services.file_management_service import _get_cached_pdf_stream + + with patch('backend.services.file_management_service.file_exists', return_value=True), \ + patch('backend.services.file_management_service.get_file_stream', return_value=None), \ + patch('backend.services.file_management_service.delete_file') as mock_delete: + result = _get_cached_pdf_stream("preview/converted/doc_abc12345.pdf") + assert result is None + mock_delete.assert_called_once_with("preview/converted/doc_abc12345.pdf") + + +class TestConvertOfficeToCachedPdf: + """Unit tests for _convert_office_to_cached_pdf helper.""" + + @pytest.mark.asyncio + async def test_returns_stream_on_double_check_cache_hit(self): + """If another coroutine completes conversion while we waited for the lock, serves from cache.""" + from backend.services.file_management_service import _convert_office_to_cached_pdf + + mock_stream = BytesIO(b"%PDF-1.4 already done") + # file_exists returns False on the outer check but the helper is called after lock acquisition + with patch('backend.services.file_management_service._get_cached_pdf_stream', + return_value=mock_stream): + result = await _convert_office_to_cached_pdf( + "docs/report.docx", + "preview/converted/docs/report_deadbeef.pdf", + "preview/converting/docs/report_deadbeef.pdf.tmp", + ) + assert result is mock_stream + + @pytest.mark.asyncio + async def test_full_conversion_success(self): + """Happy path: calls data-process, copies result, deletes temp, returns stream.""" + from backend.services.file_management_service import _convert_office_to_cached_pdf + + final_stream = BytesIO(b"%PDF-1.4 fresh") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "" + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + + mock_http_ctx = MagicMock() + mock_http_ctx.__aenter__ = AsyncMock(return_value=mock_client) + mock_http_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch('backend.services.file_management_service._get_cached_pdf_stream', + return_value=None), \ + patch('httpx.AsyncClient', return_value=mock_http_ctx), \ + patch('backend.services.file_management_service.copy_file', + return_value={'success': True}), \ + patch('backend.services.file_management_service.delete_file') as mock_delete, \ + patch('backend.services.file_management_service.file_exists', return_value=False), \ + patch('backend.services.file_management_service.get_file_stream', + return_value=final_stream): + + result = await _convert_office_to_cached_pdf( + "docs/report.docx", + "preview/converted/docs/report_deadbeef.pdf", + "preview/converting/docs/report_deadbeef.pdf.tmp", + ) + + assert result is final_stream + mock_client.post.assert_called_once() + called_url = mock_client.post.call_args[0][0] + assert "convert_to_pdf" in called_url + # Temp file should be deleted after successful copy + mock_delete.assert_called_with("preview/converting/docs/report_deadbeef.pdf.tmp") + + @pytest.mark.asyncio + async def test_http_error_raises_office_conversion_exception(self): + """Non-200 HTTP response from data-process raises OfficeConversionException.""" + from backend.services.file_management_service import _convert_office_to_cached_pdf + from consts.exceptions import OfficeConversionException + + mock_response = MagicMock() + mock_response.status_code = 503 + mock_response.text = "Service Unavailable" + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + + mock_http_ctx = MagicMock() + mock_http_ctx.__aenter__ = AsyncMock(return_value=mock_client) + mock_http_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch('backend.services.file_management_service._get_cached_pdf_stream', + return_value=None), \ + patch('httpx.AsyncClient', return_value=mock_http_ctx), \ + patch('backend.services.file_management_service.file_exists', return_value=False), \ + patch('backend.services.file_management_service.delete_file'): + + with pytest.raises(OfficeConversionException) as exc_info: + await _convert_office_to_cached_pdf( + "docs/report.docx", + "preview/converted/docs/report_deadbeef.pdf", + "preview/converting/docs/report_deadbeef.pdf.tmp", + ) + + assert "Failed to convert Office document to PDF" in str(exc_info.value) + assert "503" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_copy_failure_raises_office_conversion_exception(self): + """copy_file failure raises OfficeConversionException and cleans up temp file.""" + from backend.services.file_management_service import _convert_office_to_cached_pdf + from consts.exceptions import OfficeConversionException + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "" + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + + mock_http_ctx = MagicMock() + mock_http_ctx.__aenter__ = AsyncMock(return_value=mock_client) + mock_http_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch('backend.services.file_management_service._get_cached_pdf_stream', + return_value=None), \ + patch('httpx.AsyncClient', return_value=mock_http_ctx), \ + patch('backend.services.file_management_service.copy_file', + return_value={'success': False, 'error': 'bucket full'}), \ + patch('backend.services.file_management_service.file_exists', return_value=True), \ + patch('backend.services.file_management_service.delete_file') as mock_delete: + + with pytest.raises(OfficeConversionException): + await _convert_office_to_cached_pdf( + "docs/report.docx", + "preview/converted/docs/report_deadbeef.pdf", + "preview/converting/docs/report_deadbeef.pdf.tmp", + ) + + # Cleanup: temp file must be deleted on failure + mock_delete.assert_called_with("preview/converting/docs/report_deadbeef.pdf.tmp") + + @pytest.mark.asyncio + async def test_converted_pdf_not_readable_raises_not_found(self): + """Raises NotFoundException when the final PDF cannot be read after successful conversion.""" + from backend.services.file_management_service import _convert_office_to_cached_pdf + from consts.exceptions import NotFoundException + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "" + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + + mock_http_ctx = MagicMock() + mock_http_ctx.__aenter__ = AsyncMock(return_value=mock_client) + mock_http_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch('backend.services.file_management_service._get_cached_pdf_stream', + return_value=None), \ + patch('httpx.AsyncClient', return_value=mock_http_ctx), \ + patch('backend.services.file_management_service.copy_file', + return_value={'success': True}), \ + patch('backend.services.file_management_service.delete_file'), \ + patch('backend.services.file_management_service.file_exists', return_value=False), \ + patch('backend.services.file_management_service.get_file_stream', return_value=None): + + with pytest.raises(NotFoundException): + await _convert_office_to_cached_pdf( + "docs/report.docx", + "preview/converted/docs/report_deadbeef.pdf", + "preview/converting/docs/report_deadbeef.pdf.tmp", + ) diff --git a/test/backend/utils/test_file_management_utils.py b/test/backend/utils/test_file_management_utils.py index 02553db8f..a7696a682 100644 --- a/test/backend/utils/test_file_management_utils.py +++ b/test/backend/utils/test_file_management_utils.py @@ -704,3 +704,98 @@ async def _fake_convert(*a, **k): # total_chunks should remain from task state (12) since redis_total is None assert out["/p9"]["total_chunks"] == 12 + +class TestConvertOfficeToPdf: + """Test cases for convert_office_to_pdf function""" + + @pytest.mark.asyncio + async def test_convert_office_to_pdf_success(self, fmu, monkeypatch): + """Test successful Office to PDF conversion""" + import subprocess + + mock_result = types.SimpleNamespace(returncode=0, stderr="", stdout="") + + monkeypatch.setattr(fmu.os.path, "exists", lambda p: True) + monkeypatch.setattr(fmu.os.path, "basename", lambda p: "document.docx") + monkeypatch.setattr(fmu.subprocess, "run", lambda *a, **k: mock_result) + + result = await fmu.convert_office_to_pdf('/tmp/document.docx', '/tmp/output') + + assert result == '/tmp/output/document.pdf' + + @pytest.mark.asyncio + async def test_convert_office_to_pdf_input_not_found(self, fmu, monkeypatch): + """Test conversion failure when input file does not exist""" + monkeypatch.setattr(fmu.os.path, "exists", lambda p: False) + + with pytest.raises(FileNotFoundError) as exc_info: + await fmu.convert_office_to_pdf('/tmp/nonexistent.docx', '/tmp/output') + + assert "Input file not found" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_convert_office_to_pdf_libreoffice_error(self, fmu, monkeypatch): + """Test conversion failure when LibreOffice returns error""" + mock_result = types.SimpleNamespace(returncode=1, stderr="Error: LibreOffice crashed", stdout="") + + monkeypatch.setattr(fmu.os.path, "exists", lambda p: True) + monkeypatch.setattr(fmu.subprocess, "run", lambda *a, **k: mock_result) + + with pytest.raises(RuntimeError) as exc_info: + await fmu.convert_office_to_pdf('/tmp/document.docx', '/tmp/output') + + assert "Office to PDF conversion failed" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_convert_office_to_pdf_timeout(self, fmu, monkeypatch): + """Test conversion failure due to timeout""" + import subprocess + + monkeypatch.setattr(fmu.os.path, "exists", lambda p: True) + + def raise_timeout(*a, **k): + raise subprocess.TimeoutExpired(cmd='libreoffice', timeout=30) + + monkeypatch.setattr(fmu.subprocess, "run", raise_timeout) + + with pytest.raises(TimeoutError) as exc_info: + await fmu.convert_office_to_pdf('/tmp/document.docx', '/tmp/output', timeout=30) + + assert "timeout" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_convert_office_to_pdf_libreoffice_not_installed(self, fmu, monkeypatch): + """Test conversion failure when LibreOffice is not installed""" + monkeypatch.setattr(fmu.os.path, "exists", lambda p: True) + + def raise_file_not_found(*a, **k): + raise FileNotFoundError("[Errno 2] No such file or directory: 'libreoffice'") + + monkeypatch.setattr(fmu.subprocess, "run", raise_file_not_found) + + with pytest.raises(FileNotFoundError) as exc_info: + await fmu.convert_office_to_pdf('/tmp/document.docx', '/tmp/output') + + assert "LibreOffice is not installed" in str(exc_info.value) + assert "not available in PATH" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_convert_office_to_pdf_output_not_found(self, fmu, monkeypatch): + """Test conversion failure when output PDF is not generated""" + mock_result = types.SimpleNamespace(returncode=0, stderr="", stdout="") + + def exists_side_effect(path): + # Input file exists, output PDF does not + if 'document.docx' in path: + return True + return False + + monkeypatch.setattr(fmu.os.path, "exists", exists_side_effect) + monkeypatch.setattr(fmu.os.path, "basename", lambda p: "document.docx") + monkeypatch.setattr(fmu.subprocess, "run", lambda *a, **k: mock_result) + + with pytest.raises(RuntimeError) as exc_info: + await fmu.convert_office_to_pdf('/tmp/document.docx', '/tmp/output') + + assert "Converted PDF not found" in str(exc_info.value) + diff --git a/test/sdk/storage/test_minio.py b/test/sdk/storage/test_minio.py index e9ad2972d..75ea1a3dd 100644 --- a/test/sdk/storage/test_minio.py +++ b/test/sdk/storage/test_minio.py @@ -883,3 +883,92 @@ def test_exists_without_bucket(self, mock_boto3): assert exists is False + +class TestMinIOStorageClientCopyFile: + """Test cases for copy_file method""" + + @patch('nexent.storage.minio.boto3') + def test_copy_file_success(self, mock_boto3): + """Test successful file copy within the same bucket""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + mock_client.head_bucket.return_value = None + + client = MinIOStorageClient( + endpoint="http://localhost:9000", + access_key="minioadmin", + secret_key="minioadmin", + default_bucket="test-bucket" + ) + + success, result = client.copy_file('src.txt', 'dst.txt', 'test-bucket') + + assert success is True + assert result == 'dst.txt' + mock_client.copy_object.assert_called_once_with( + Bucket='test-bucket', + Key='dst.txt', + CopySource={'Bucket': 'test-bucket', 'Key': 'src.txt'} + ) + + @patch('nexent.storage.minio.boto3') + def test_copy_file_uses_default_bucket(self, mock_boto3): + """Test copy_file falls back to default bucket when bucket is not specified""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + mock_client.head_bucket.return_value = None + + client = MinIOStorageClient( + endpoint="http://localhost:9000", + access_key="minioadmin", + secret_key="minioadmin", + default_bucket="test-bucket" + ) + + success, result = client.copy_file('src.txt', 'dst.txt') + + assert success is True + assert result == 'dst.txt' + mock_client.copy_object.assert_called_once_with( + Bucket='test-bucket', + Key='dst.txt', + CopySource={'Bucket': 'test-bucket', 'Key': 'src.txt'} + ) + + @patch('nexent.storage.minio.boto3') + def test_copy_file_without_bucket(self, mock_boto3): + """Test copy_file fails when no bucket is configured""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + + client = MinIOStorageClient( + endpoint="http://localhost:9000", + access_key="minioadmin", + secret_key="minioadmin" + ) + + success, result = client.copy_file('src.txt', 'dst.txt') + + assert success is False + assert result == "Bucket name is required" + mock_client.copy_object.assert_not_called() + + @patch('nexent.storage.minio.boto3') + def test_copy_file_exception(self, mock_boto3): + """Test copy_file returns failure on unexpected exception""" + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + mock_client.head_bucket.return_value = None + mock_client.copy_object.side_effect = Exception("copy failed") + + client = MinIOStorageClient( + endpoint="http://localhost:9000", + access_key="minioadmin", + secret_key="minioadmin", + default_bucket="test-bucket" + ) + + success, result = client.copy_file('src.txt', 'dst.txt') + + assert success is False + assert "copy failed" in result From 57a24a4a2905fe4c8933b48ddff14d3f22164d02 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Fri, 6 Mar 2026 17:08:57 +0800 Subject: [PATCH 16/75] improve codecov for testfiles --- .../providers/test_dashscope_provider.py | 51 ++++++++++++++++++ .../services/test_model_management_service.py | 52 +++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py index 44bbdbda5..0bc2d3ad8 100644 --- a/test/backend/services/providers/test_dashscope_provider.py +++ b/test/backend/services/providers/test_dashscope_provider.py @@ -574,6 +574,57 @@ async def test_get_models_unknown_type_returns_empty(self, mocker: MockFixture): assert result == [] + @pytest.mark.asyncio + async def test_get_models_rate_limit_retry(self, mocker: MockFixture): + """Test that a 429 response triggers a retry after sleeping.""" + rate_limit_response = MagicMock() + rate_limit_response.status_code = 429 + + ok_response = MagicMock() + ok_response.status_code = 200 + ok_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-turbo", + "description": "Text generation", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"], + }, + } + ] + } + } + ok_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.side_effect = [rate_limit_response, ok_response] + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm, + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models", + ) + mocker.patch( + "backend.services.providers.dashscope_provider.asyncio.sleep", + new=AsyncMock(), + ) + + provider = DashScopeModelProvider() + result = await provider.get_models({"model_type": "llm", "api_key": "test-key"}) + + assert mock_client.get.call_count == 2 + assert len(result) == 1 + assert result[0]["id"] == "qwen-turbo" + @pytest.mark.asyncio async def test_get_models_with_chinese_description(self, mocker: MockFixture): """Test model classification by Chinese description.""" diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index e5d52d31a..6e504e90a 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -557,6 +557,58 @@ async def test_create_provider_models_for_tenant_exception(): assert "Failed to create provider models" in str(exc.value) +@pytest.mark.asyncio +async def test_batch_create_models_for_tenant_dashscope_provider(): + """Test batch_create_models_for_tenant with DASHSCOPE provider uses DASHSCOPE_BASE_URL.""" + svc = import_svc() + + batch_payload = { + "provider": "dashscope", + "type": "llm", + "models": [{"id": "qwen/qwen-turbo", "max_tokens": 8192}], + "api_key": "dash-key", + } + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + mock.patch.object(svc, "delete_model_record"), \ + mock.patch.object(svc, "split_repo_name", return_value=("qwen", "qwen-turbo")), \ + mock.patch.object(svc, "add_repo_to_name", return_value="qwen/qwen-turbo"), \ + mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ + mock.patch.object(svc, "create_model_record", return_value=True): + + await svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + call_args = svc.prepare_model_dict.call_args + assert call_args[1]["model_url"] == "https://dashscope.aliyuncs.com/compatible-mode/v1/" + + +@pytest.mark.asyncio +async def test_batch_create_models_for_tenant_tokenpony_provider(): + """Test batch_create_models_for_tenant with TOKENPONY provider uses TOKENPONY_BASE_URL.""" + svc = import_svc() + + batch_payload = { + "provider": "tokenpony", + "type": "llm", + "models": [{"id": "gpt/gpt-4o", "max_tokens": 128000}], + "api_key": "tp-key", + } + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + mock.patch.object(svc, "delete_model_record"), \ + mock.patch.object(svc, "split_repo_name", return_value=("gpt", "gpt-4o")), \ + mock.patch.object(svc, "add_repo_to_name", return_value="gpt/gpt-4o"), \ + mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 2})), \ + mock.patch.object(svc, "create_model_record", return_value=True): + + await svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + call_args = svc.prepare_model_dict.call_args + assert call_args[1]["model_url"] == "https://api.tokenpony.cn/v1/" + + @pytest.mark.asyncio async def test_batch_create_models_for_tenant_other_provider(): """Test batch_create_models_for_tenant with non-Silicon/ModelEngine provider (covers lines 138-140)""" From 00854caee381a1dbb29daab35c4e04070441de34 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Tue, 3 Mar 2026 17:05:01 +0800 Subject: [PATCH 17/75] implement DashScope and TokenPony model providers --- backend/consts/provider.py | 10 ++ backend/services/model_management_service.py | 6 +- backend/services/model_provider_service.py | 8 ++ .../services/providers/dashscope_provider.py | 131 +++++++++++++++++ .../services/providers/tokenpony_provider.py | 120 ++++++++++++++++ .../components/model/ModelAddDialog.tsx | 34 ++++- .../models/components/model/ModelListCard.tsx | 116 ++++++++++++++- frontend/const/modelConfig.ts | 6 + frontend/hooks/model/useDashscopeModelList.ts | 133 ++++++++++++++++++ frontend/hooks/model/useTokenponyModelList.ts | 133 ++++++++++++++++++ frontend/public/locales/en/common.json | 6 + frontend/public/locales/zh/common.json | 6 + frontend/public/tokenpony.png | Bin 0 -> 1296 bytes frontend/types/modelConfig.ts | 2 + 14 files changed, 701 insertions(+), 10 deletions(-) create mode 100644 backend/services/providers/dashscope_provider.py create mode 100644 backend/services/providers/tokenpony_provider.py create mode 100644 frontend/hooks/model/useDashscopeModelList.ts create mode 100644 frontend/hooks/model/useTokenponyModelList.ts create mode 100644 frontend/public/tokenpony.png diff --git a/backend/consts/provider.py b/backend/consts/provider.py index 7fd783015..e2a0f0235 100644 --- a/backend/consts/provider.py +++ b/backend/consts/provider.py @@ -6,11 +6,21 @@ class ProviderEnum(str, Enum): SILICON = "silicon" OPENAI = "openai" MODELENGINE = "modelengine" + DASHSCOPE = "dashscope" + TOKENPONY = "tokenpony" # Silicon Flow SILICON_BASE_URL = "https://api.siliconflow.cn/v1/" SILICON_GET_URL = "https://api.siliconflow.cn/v1/models" +# Dashcope +DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" +DASHSCOPE_GET_URL = "https://dashscope.aliyuncs.com/api/v1/models" + +# TokenPony +TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1" +TOKENPONY_GET_URL = "https://api.tokenpony.cn/v1/models" + # ModelEngine # Base URL and API key are loaded from environment variables at runtime diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 4b8265028..a18c16c36 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -3,7 +3,7 @@ from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST from consts.model import ModelConnectStatusEnum -from consts.provider import ProviderEnum, SILICON_BASE_URL +from consts.provider import ProviderEnum, SILICON_BASE_URL, DASHSCOPE_BASE_URL, TOKENPONY_BASE_URL from database.model_management_db import ( create_model_record, @@ -142,6 +142,10 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay elif provider == ProviderEnum.MODELENGINE.value: # ModelEngine models carry their own base_url in each model dict model_url = "" + elif provider == ProviderEnum.DASHSCOPE.value: + model_url = DASHSCOPE_BASE_URL + elif provider == ProviderEnum.TOKENPONY.value: + model_url = TOKENPONY_BASE_URL else: model_url = "" diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index a302eb999..3c916eb8c 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -11,6 +11,8 @@ from services.model_health_service import embedding_dimension_check from services.providers.base import AbstractModelProvider from services.providers.silicon_provider import SiliconModelProvider +from services.providers.tokenpony_provider import TokenPonyModelProvider +from services.providers.dashscope_provider import DashScopeModelProvider from services.providers.modelengine_provider import ModelEngineProvider, get_model_engine_raw_url, MODEL_ENGINE_NORTH_PREFIX from utils.model_name_utils import split_repo_name, add_repo_to_name @@ -40,6 +42,12 @@ async def get_provider_models(model_data: dict) -> List[dict]: elif model_data["provider"] == ProviderEnum.MODELENGINE.value: provider = ModelEngineProvider() model_list = await provider.get_models(model_data) + elif model_data["provider"] == ProviderEnum.DASHSCOPE.value: + provider = DashScopeModelProvider() + model_list = await provider.get_models(model_data) + elif model_data["provider"] == ProviderEnum.TOKENPONY.value: + provider = TokenPonyModelProvider() + model_list = await provider.get_models(model_data) return model_list diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py new file mode 100644 index 000000000..2a34823ed --- /dev/null +++ b/backend/services/providers/dashscope_provider.py @@ -0,0 +1,131 @@ +import httpx +from typing import Dict, List +import asyncio +from consts.const import DEFAULT_LLM_MAX_TOKENS +from consts.provider import DASHSCOPE_GET_URL +from services.providers.base import AbstractModelProvider, _classify_provider_error + + +class DashScopeModelProvider(AbstractModelProvider): + """Concrete implementation for DashScope (Aliyun) provider.""" + + async def get_models(self, provider_config: Dict) -> List[Dict]: + """ + Fetch models from DashScope API, categorize them, and return + the requested model type. + + Args: + provider_config: Configuration dict containing model_type and api_key + + Returns: + List of models with canonical fields. Returns error dict if API call fails. + """ + try: + target_model_type: str = provider_config["model_type"] + model_api_key: str = provider_config["api_key"] + + headers = {"Authorization": f"Bearer {model_api_key}"} + base_url = DASHSCOPE_GET_URL + + all_models: List[Dict] = [] + current_page = 1 + + # Fetch all models with pagination asynchronously + async with httpx.AsyncClient(verify=False) as client: + while True: + params = {"page_size": 100, "page_no": current_page} + response = await client.get(base_url, headers=headers, params=params) + response.raise_for_status() + + data = response.json() + models = data.get("output", {}).get("models", []) + + if response.status_code == 429: + await asyncio.sleep(2) + continue + if not models : # Break loop if no more models on the current page + break + + all_models.extend(models) + if(len(models)<100): + break + current_page += 1 + await asyncio.sleep(0.5) + + # Initialize containers for the 6 main categories + categorized_models = { + "chat": [], # Maps to "llm" + "vlm": [], # Maps to "vlm" + "embedding": [], # Maps to "embedding" / "multi_embedding" + "reranker": [], # Maps to "reranker" + "tts": [], # Maps to "tts" + "stt": [] # Maps to "stt" + } + + # Classify models and inject canonical fields expected downstream + for model_obj in all_models: + # Extract key fields for logical determination (lowercased for robustness) + m_id = model_obj.get('model', '').lower() + desc = model_obj.get('description', '') + metadata = model_obj.get('inference_metadata', {}) + req_mod = metadata.get('request_modality', []) + res_mod = metadata.get('response_modality', []) + model_obj.setdefault("object", model_obj.get("object", "model")) + model_obj.setdefault("owned_by", model_obj.get("owned_by", "dashscope")) + cleaned_model = { + "id": m_id, + "object": model_obj.get("object"), + "created": 0, + "owned_by": model_obj.get("owned_by"), + "model_tag": "", + "model_type": "", + "max_tokens": DEFAULT_LLM_MAX_TOKENS + } + # 1. Embedding + if 'embedding' in m_id.lower() or '向量' in desc: + cleaned_model.update({"model_tag": "embedding", "model_type": "embedding"}) + categorized_models['embedding'].append(cleaned_model) + continue + + # 2. Reranker + if 'rerank' in m_id.lower() or '重排序' in desc: + cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) + categorized_models['reranker'].append(cleaned_model) + continue + + # 3. STT + if 'Audio' in req_mod and 'Text' in res_mod: + cleaned_model.update({"model_tag": "stt", "model_type": "stt"}) + categorized_models['stt'].append(cleaned_model) + continue + + # 4. TTS + if 'Audio' in res_mod and 'Video' not in res_mod: + cleaned_model.update({"model_tag": "tts", "model_type": "tts"}) + categorized_models['tts'].append(cleaned_model) + continue + + # 5. VLM + vision_mods = {'Image', 'Video'} + if (set(req_mod) & vision_mods) or (set(res_mod) & vision_mods) or '视觉' in desc: + cleaned_model.update({"model_tag": "chat", "model_type": "vlm"}) + categorized_models['vlm'].append(cleaned_model) + continue + + # 6. Chat / LLM + if 'Text' in req_mod or 'Text' in res_mod: + cleaned_model.update({"model_tag": "chat", "model_type": "llm"}) + categorized_models['chat'].append(cleaned_model) + + # Return the specific list based on the requested target_model_type + if target_model_type == "llm": + return categorized_models["chat"] + elif target_model_type in ("embedding", "multi_embedding"): + return categorized_models["embedding"] + elif target_model_type in categorized_models: + return categorized_models[target_model_type] + else: + return [] + except (httpx.HTTPStatusError, httpx.ConnectTimeout, httpx.ConnectError, Exception) as e: + return _classify_provider_error("DashScope", exception=e) + diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py new file mode 100644 index 000000000..62972b698 --- /dev/null +++ b/backend/services/providers/tokenpony_provider.py @@ -0,0 +1,120 @@ +import httpx +import ssl + +from typing import Dict, List + + +from consts.const import DEFAULT_LLM_MAX_TOKENS +from consts.provider import TOKENPONY_GET_URL +from services.providers.base import AbstractModelProvider, _classify_provider_error + + +class TokenPonyModelProvider(AbstractModelProvider): + """Concrete implementation for TokenPony provider.""" + + async def get_models(self, provider_config: Dict) -> List[Dict]: + """ + Fetch models from TokenPony API, categorize them based on modality/ID, + and return the requested model type. + + Args: + provider_config: Configuration dict containing model_type and api_key + + Returns: + List of models with canonical fields. Returns error dict if API call fails. + """ + try: + target_model_type: str = provider_config["model_type"] + model_api_key: str = provider_config["api_key"] + + headers = {"Authorization": f"Bearer {model_api_key}"} + url = TOKENPONY_GET_URL + + + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + ssl_context.set_ciphers("DEFAULT@SECLEVEL=1") + # response = requests.get(url, headers=headers) + # all_models=[] + # if response.status_code == 200: + # data = response.json() + # # 注意:OpenAI 标准返回是在 "data" 字段下 + # all_models=data.get("data", []) + # Fetch all models asynchronously + async with httpx.AsyncClient(http2=True) as client: + response = await client.get(url, headers=headers) + response.raise_for_status() + # OpenAI standard response puts the model list inside the "data" array + all_models: List[Dict] = response.json().get("data", []) + + # Initialize containers for the 6 main categories + categorized_models = { + "chat": [], # Maps to "llm" + "vlm": [], # Maps to "vlm" + "embedding": [], # Maps to "embedding" / "multi_embedding" + "reranker": [], # Maps to "reranker" + "tts": [], # Maps to "tts" + "stt": [] # Maps to "stt" + } + + # Classify models and inject canonical fields expected downstream + for model_obj in all_models: + m_id = model_obj['id'].lower() + model_obj.setdefault("object", model_obj.get("object", "model")) + model_obj.setdefault("owned_by", model_obj.get("owned_by", "tokenpony")) + cleaned_model = { + "id": m_id, + "object": model_obj.get("object"), + "created": 0, + "owned_by": model_obj.get("owned_by"), + "model_tag": "", + "model_type": "", + "max_tokens": DEFAULT_LLM_MAX_TOKENS + } + # 1. Embedding + if 'embedding' in m_id or m_id.startswith('bge-'): + cleaned_model.update({"model_tag": "embedding", "model_type": "embedding", "max_tokens": 0}) + categorized_models['embedding'].append(cleaned_model) + + # 2. Reranker + elif 'rerank' in m_id: + cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) + categorized_models['reranker'].append(cleaned_model) + + + # 3. STT (Speech-to-Text / Audio understanding) + elif 'stt' in m_id: + cleaned_model.update({"model_tag": "stt", "model_type": "stt"}) + categorized_models['stt'].append(cleaned_model) + + + # 4. TTS (Text-to-Speech) + elif 'tts' in m_id: + cleaned_model.update({"model_tag": "tts", "model_type": "tts"}) + categorized_models['tts'].append(cleaned_model) + + # 5. VLM (Vision Language Model / Image & Video Generation) + + elif any(keyword in m_id for keyword in ['-vl', 'vl-', 'ocr', 'vision']): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm"}) + categorized_models['vlm'].append(cleaned_model) + + # 6. Chat (Pure Text Conversation / Reasoning) + # Fallback check added: 'not metadata' catches standard OpenAI models that lack modality data + else : + cleaned_model.update({"model_tag": "chat", "model_type": "llm"}) + categorized_models['chat'].append(cleaned_model) + + # Return the specific list based on the requested target_model_type + if target_model_type == "llm": + return categorized_models["chat"] + elif target_model_type in ("embedding", "multi_embedding"): + return categorized_models["embedding"] + elif target_model_type in categorized_models: + return categorized_models[target_model_type] + else: + return [] + + except (httpx.HTTPStatusError, httpx.ConnectTimeout, httpx.ConnectError, Exception) as e: + return _classify_provider_error("TokenPony", exception=e) diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 2df9643a9..cd258abc8 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -16,6 +16,8 @@ import { modelService } from "@/services/modelService"; import { ModelType, SingleModelConfig } from "@/types/modelConfig"; import { MODEL_TYPES, PROVIDER_LINKS } from "@/const/modelConfig"; import { useSiliconModelList } from "@/hooks/model/useSiliconModelList"; +import { useDashscopeModelList } from "@/hooks/model/useDashscopeModelList"; +import { useTokenPonyModelList } from "@/hooks/model/useTokenponyModelList"; import log from "@/lib/logger"; import { ModelChunkSizeSlider, @@ -248,7 +250,7 @@ export const ModelAddDialog = ({ const [modelMaxTokens, setModelMaxTokens] = useState("4096"); // Use the silicon model list hook - const { getModelList, getProviderSelectedModalList } = useSiliconModelList({ + const siliconHook = useSiliconModelList({ form, setModelList, setSelectedModelIds, @@ -256,7 +258,33 @@ export const ModelAddDialog = ({ setLoadingModelList, tenantId, }); - + const dashscopeHook = useDashscopeModelList({ + form, + setModelList, + setSelectedModelIds, + setShowModelList, + setLoadingModelList, + tenantId, + }); + const tokenponyHook = useTokenPonyModelList({ + form, + setModelList, + setSelectedModelIds, + setShowModelList, + setLoadingModelList, + tenantId, + }); + let getModelList; + let getProviderSelectedModalList; + +// 2. 根据条件赋值 + if (form.provider === "silicon") { + ({ getModelList, getProviderSelectedModalList } = siliconHook); + } else if (form.provider === "dashscope") { + ({ getModelList, getProviderSelectedModalList } = dashscopeHook); + } else if (form.provider === "tokenpony") { + ({ getModelList, getProviderSelectedModalList } = tokenponyHook); + } // Reset form to default state const resetForm = useCallback(() => { setForm(DEFAULT_FORM_STATE); @@ -794,6 +822,8 @@ export const ModelAddDialog = ({ {t("model.provider.modelengine")} + + {/* ModelEngine URL input (only when provider is ModelEngine) */} {form.provider === "modelengine" && ( diff --git a/frontend/app/[locale]/models/components/model/ModelListCard.tsx b/frontend/app/[locale]/models/components/model/ModelListCard.tsx index ae966ae35..8bf6e00a6 100644 --- a/frontend/app/[locale]/models/components/model/ModelListCard.tsx +++ b/frontend/app/[locale]/models/components/model/ModelListCard.tsx @@ -33,12 +33,12 @@ const PULSE_ANIMATION = ` transform: scale(0.95); box-shadow: 0 0 0 0 rgba(41, 128, 185, 0.7); } - + 70% { transform: scale(1); box-shadow: 0 0 0 5px rgba(41, 128, 185, 0); } - + 100% { transform: scale(0.95); box-shadow: 0 0 0 0 rgba(41, 128, 185, 0); @@ -162,27 +162,33 @@ export const ModelListCard = ({ const model = modelsData.find( (m) => m.type === type && m.displayName === displayName ); - + if (!model) return t("model.source.unknown"); - + // Return source label based on model.source if (model.source === "modelengine") { return t("model.source.modelEngine"); } else if (model.source === "silicon") { return t("model.source.silicon"); + } else if (model.source==="dashscope"){ + return t("model.source.dashscope"); + }else if (model.source==="tokenpony"){ + return t("model.source.tokenpony"); } else if (model.source === "OpenAI-API-Compatible") { return t("model.source.custom"); } - + return t("model.source.unknown"); }; const filteredModels = getFilteredModels(); - + // Group models by source for display const groupedModels = { modelengine: filteredModels.filter((m) => m.source === "modelengine"), silicon: filteredModels.filter((m) => m.source === "silicon"), + dashscope: filteredModels.filter((m) => m.source === "dashscope"), + tokenpony: filteredModels.filter((m) => m.source === "tokenpony"), custom: filteredModels.filter((m) => m.source === "OpenAI-API-Compatible"), }; @@ -343,6 +349,102 @@ export const ModelListCard = ({ ))} )} + {groupedModels.dashscope.length > 0 && ( + + {groupedModels.dashscope.map((model) => ( + + ))} + + )} + {groupedModels.tokenpony.length > 0 && ( + + {groupedModels.tokenpony.map((model) => ( + + ))} + + )} {groupedModels.custom.length > 0 && ( {groupedModels.custom.map((model) => ( @@ -394,4 +496,4 @@ export const ModelListCard = ({
); -}; \ No newline at end of file +}; diff --git a/frontend/const/modelConfig.ts b/frontend/const/modelConfig.ts index ce7f1841d..9b0128529 100644 --- a/frontend/const/modelConfig.ts +++ b/frontend/const/modelConfig.ts @@ -40,6 +40,8 @@ export const MODEL_PROVIDER_KEYS = [ "jina", "deepseek", "aliyuncs", + "tokenpony", + "dashscope", ] as const; export type ModelProviderKey = (typeof MODEL_PROVIDER_KEYS)[number]; @@ -52,6 +54,8 @@ export const PROVIDER_HINTS: Record = { jina: "jina", deepseek: "deepseek", aliyuncs: "aliyuncs", + tokenpony: "tokenpony", + dashscope: "dashscope", }; // Icon filenames for providers @@ -62,6 +66,8 @@ export const PROVIDER_ICON_MAP: Record = { jina: "/jina.png", deepseek: "/deepseek.png", aliyuncs: "/aliyuncs.png", + dashscope:"/aliyuncs.png", + tokenpony: "/tokenpony.png", }; export const OFFICIAL_PROVIDER_ICON = "/modelengine-logo.png"; diff --git a/frontend/hooks/model/useDashscopeModelList.ts b/frontend/hooks/model/useDashscopeModelList.ts new file mode 100644 index 000000000..b44348fe5 --- /dev/null +++ b/frontend/hooks/model/useDashscopeModelList.ts @@ -0,0 +1,133 @@ +import { useEffect } from "react"; +import { message } from "antd"; +import { useTranslation } from "react-i18next"; +import { modelService } from "@/services/modelService"; +import { ModelType } from "@/types/modelConfig"; +import { processProviderResponse } from "@/lib/providerError"; +import log from "@/lib/logger"; + +interface UseDashscopeModelListProps { + form: { + type: ModelType; + isBatchImport: boolean; + apiKey: string; + provider: string; // Expected to be "dashscope" + maxTokens: string; + isMultimodal: boolean; + }; + setModelList: (models: any[]) => void; + setSelectedModelIds: (ids: Set) => void; + setShowModelList: (show: boolean) => void; + setLoadingModelList: (loading: boolean) => void; + tenantId?: string; // Optional tenant ID for manage operations +} + +export const useDashscopeModelList = ({ + form, + setModelList, + setSelectedModelIds, + setShowModelList, + setLoadingModelList, + tenantId, +}: UseDashscopeModelListProps) => { + const { t } = useTranslation(); + + const getModelList = async () => { + setShowModelList(true); + setLoadingModelList(true); + + const modelType = + form.type === "embedding" && form.isMultimodal + ? ("multi_embedding" as ModelType) + : form.type; + + try { + // Use manage interface if tenantId is provided (for super admin) + const result = tenantId + ? await modelService.addManageProviderModel({ + tenantId, + provider: form.provider, + type: modelType, + apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }) + : await modelService.addProviderModel({ + provider: form.provider, + type: modelType, + apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }); + + // Use centralized error processing + const { models, error } = processProviderResponse( + result, + form.provider, + t + ); + + if (error) { + message.error(error); + setModelList([]); + setSelectedModelIds(new Set()); + setLoadingModelList(false); + return; + } + + // Ensure each model has a default max_tokens value + const modelsWithDefaults = models.map((model: any) => ({ + ...model, + max_tokens: model.max_tokens || parseInt(form.maxTokens) || 4096, + })); + setModelList(modelsWithDefaults); + + const selectedModels = (await getProviderSelectedModalList()) || []; + + // Key logic: Sync previously selected models + if (!selectedModels.length) { + // Select none + setSelectedModelIds(new Set()); + } else { + // Only select selectedModels + setSelectedModelIds(new Set(selectedModels.map((m: any) => m.id))); + } + } catch (error) { + message.error(t("model.dialog.error.addFailed", { error })); + log.error(t("model.dialog.error.addFailedLog"), error); + } finally { + setLoadingModelList(false); + } + }; + + const getProviderSelectedModalList = async () => { + const modelType = + form.type === "embedding" && form.isMultimodal + ? ("multi_embedding" as ModelType) + : form.type; + + // Use manage interface if tenantId is provided (for super admin) + const result = tenantId + ? await modelService.getManageProviderSelectedModalList({ + tenantId, + provider: form.provider, + type: modelType, + }) + : await modelService.getProviderSelectedModalList({ + provider: form.provider, + type: modelType, + api_key: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }); + + return result; + }; + + // Auto-fetch model list when batch import is enabled and API key is provided + useEffect(() => { + if (form.isBatchImport && form.apiKey.trim() !== "") { + getModelList(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [form.type, form.isBatchImport]); + + return { + getModelList, + getProviderSelectedModalList, + }; +}; diff --git a/frontend/hooks/model/useTokenponyModelList.ts b/frontend/hooks/model/useTokenponyModelList.ts new file mode 100644 index 000000000..0a7e23581 --- /dev/null +++ b/frontend/hooks/model/useTokenponyModelList.ts @@ -0,0 +1,133 @@ +import { useEffect } from "react"; +import { message } from "antd"; +import { useTranslation } from "react-i18next"; +import { modelService } from "@/services/modelService"; +import { ModelType } from "@/types/modelConfig"; +import { processProviderResponse } from "@/lib/providerError"; +import log from "@/lib/logger"; + +interface UseTokenPonyModelListProps { + form: { + type: ModelType; + isBatchImport: boolean; + apiKey: string; + provider: string; // Expected to be "tokenpony" + maxTokens: string; + isMultimodal: boolean; + }; + setModelList: (models: any[]) => void; + setSelectedModelIds: (ids: Set) => void; + setShowModelList: (show: boolean) => void; + setLoadingModelList: (loading: boolean) => void; + tenantId?: string; // Optional tenant ID for manage operations +} + +export const useTokenPonyModelList = ({ + form, + setModelList, + setSelectedModelIds, + setShowModelList, + setLoadingModelList, + tenantId, +}: UseTokenPonyModelListProps) => { + const { t } = useTranslation(); + + const getModelList = async () => { + setShowModelList(true); + setLoadingModelList(true); + + const modelType = + form.type === "embedding" && form.isMultimodal + ? ("multi_embedding" as ModelType) + : form.type; + + try { + // Use manage interface if tenantId is provided (for super admin) + const result = tenantId + ? await modelService.addManageProviderModel({ + tenantId, + provider: form.provider, + type: modelType, + apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }) + : await modelService.addProviderModel({ + provider: form.provider, + type: modelType, + apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }); + + // Use centralized error processing + const { models, error } = processProviderResponse( + result, + form.provider, + t + ); + + if (error) { + message.error(error); + setModelList([]); + setSelectedModelIds(new Set()); + setLoadingModelList(false); + return; + } + + // Ensure each model has a default max_tokens value + const modelsWithDefaults = models.map((model: any) => ({ + ...model, + max_tokens: model.max_tokens || parseInt(form.maxTokens) || 4096, + })); + setModelList(modelsWithDefaults); + + const selectedModels = (await getProviderSelectedModalList()) || []; + + // Key logic: Sync previously selected models + if (!selectedModels.length) { + // Select none + setSelectedModelIds(new Set()); + } else { + // Only select selectedModels + setSelectedModelIds(new Set(selectedModels.map((m: any) => m.id))); + } + } catch (error) { + message.error(t("model.dialog.error.addFailed", { error })); + log.error(t("model.dialog.error.addFailedLog"), error); + } finally { + setLoadingModelList(false); + } + }; + + const getProviderSelectedModalList = async () => { + const modelType = + form.type === "embedding" && form.isMultimodal + ? ("multi_embedding" as ModelType) + : form.type; + + // Use manage interface if tenantId is provided (for super admin) + const result = tenantId + ? await modelService.getManageProviderSelectedModalList({ + tenantId, + provider: form.provider, + type: modelType, + }) + : await modelService.getProviderSelectedModalList({ + provider: form.provider, + type: modelType, + api_key: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, + }); + + return result; + }; + + // Auto-fetch model list when batch import is enabled and API key is provided + useEffect(() => { + if (form.isBatchImport && form.apiKey.trim() !== "") { + getModelList(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [form.type, form.isBatchImport]); + + return { + getModelList, + getProviderSelectedModalList, + }; +}; diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 775eae675..986140c83 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -674,6 +674,8 @@ "model.dialog.hint.batchImportEnabled": "Batch add enabled. Multiple models will be added at once.", "model.dialog.hint.batchImportDisabled": "Batch add disabled. Only a single model will be added.", "model.provider.silicon": "SiliconFlow", + "model.provider.dashscope": "DashScope", + "model.provider.tokenpony": "TokenPony", "model.provider.modelengine": "ModelEngine", "model.dialog.modelList.title": "Show Models", "model.dialog.modelList.searchPlaceholder": "Search models by name", @@ -746,12 +748,16 @@ "model.source.modelEngine": "ModelEngine", "model.source.openai": "OpenAI", "model.source.silicon": "Silicon Flow", + "model.source.dashscope": "DashScope", + "model.source.tokenpony": "TokenPony", "model.source.unknown": "Unknown Source", "model.warning.updateNotFound": "Model not found for update: {{displayName}}, type: {{type}}", "model.type.main": "LLM Model", "model.select.placeholder": "Select Model", "model.group.modelEngine": "ModelEngine Models", "model.group.silicon": "Silicon Flow Models", + "model.group.dashscope": "DashScope Models", + "model.group.tokenpony": "TokenPony Models", "model.group.custom": "Custom Models", "model.status.tooltip": "Click to verify connectivity", "model.dialog.embeddingConfig.title": "Edit Embedding Model: {{modelName}}", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 88ef18fdc..b830b1792 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -676,6 +676,8 @@ "model.dialog.hint.batchImportEnabled": "批量添加模式已启用,可通过API Key一次性导入多个模型", "model.dialog.hint.batchImportDisabled": "批量添加模式已关闭,仅添加单个模型", "model.provider.silicon": "硅基流动", + "model.provider.dashscope": "阿里灵积", + "model.provider.tokenpony": "小马算力", "model.provider.modelengine": "ModelEngine", "model.dialog.modelList.title": "显示模型", "model.dialog.modelList.searchPlaceholder": "按名称搜索模型", @@ -748,11 +750,15 @@ "model.source.unknown": "未知来源", "model.source.openai": "OpenAI", "model.source.silicon": "硅基流动", + "model.source.dashscope": "阿里灵积", + "model.source.tokenpony": "小马算力", "model.warning.updateNotFound": "未找到要更新的模型: {{displayName}}, 类型: {{type}}", "model.type.main": "大语言模型", "model.select.placeholder": "选择模型", "model.group.modelEngine": "ModelEngine模型", "model.group.silicon": "硅基流动模型", + "model.group.dashscope": "阿里灵积模型", + "model.group.tokenpony": "小马算力模型", "model.group.custom": "自定义模型", "model.status.tooltip": "点击可验证连通性", "model.dialog.success.updateSuccess": "更新成功", diff --git a/frontend/public/tokenpony.png b/frontend/public/tokenpony.png new file mode 100644 index 0000000000000000000000000000000000000000..d582ae86b2b3a14192759a9d89d39d25bcc1508f GIT binary patch literal 1296 zcmV+r1@HQaP)Px#1ZP1_K>z@;j|==^1poj532;bRa{vGi!TsHp)@UQ zK(JwufJzK26bR5FD4XDTOdC7KA+hVkwH^C#&wc;djtA7VsJ_x|Uf;XtocrH!-?IZ0 zQ0pg8&STHMB>HZSVO3Wc9)Ael6(Rk*9Jd&9kc*t;uGgz1P$=Z_+)xt!%Oi07NqAOF z3V|fh5xnpO1gjjvMNV>4gIY@xC>HbBHjt2@v;ZRTbqawc6C$TWT-wWW!4Rv-h?A4t z)Yvu&w7dE!Eaz;&wRxp!qAHuC-xoe>{QV7g88Lkiug-6X@;@ z8^YSyTC9vQhOxvDOAd0?<8;PFTN02)Wm?1~aH|}Xc%mq~6eyR=>a^)(5f45buae0* zS6AGGjxmN-iz62~EpCaX$wV4lD|ZT(VMBvwgW z9Q!z?Cb7g2OAZ=2$<0`5Rdj8`621FgS^)&^efsapSMJ)A)IQqCv5X-i0bSbcGgj7E zizA1{k(;sh>z5?pc!Tg=7Q(4NS)qV$zFm-DnKBBg2k_-rtmVM?cwWjX!6G8gHv4Fd zsmGCvMs8|Qt1$s5h?R12*fo6p%~|Oihu0I3QYu11^y+J~IyN`Ah&5MoW3fM|ZT8U^ z!&skSh~-Hj7meK1Af{;ot9!zjoi#*5!zo=p40g4_j5?47SXd_ zj*y!`&-ySD(?yK_m_z>^W}Od@q;ceE8vd0b>_0f8oaCm)c_lD7ltTZVr?pKR{V9h> z9!+44#BA6o7naY=ihixMd#{S1d!0FgznnOyn0&shQ+`1S9DIpi9mJtm{?RWfmCQyW z=?|kheEi8-y!CckH`b@0W$~+smKIIS{Co*VMlyJIP&Tp5+{?m2TUG)D5$yAsjkIx7 z1W)do#)lteFcq(!bz@~jqXis&KZ6}lB+z>`mxnWBu+;=g|DAwtLRm}N8@dUFq>~R4 zGu=1zMRC`?XRvmIneI*3M{&)yQ7NfmtP|1u?w`U<^7>({M=)N&vP~`;xm#~S_Za?h z${eTfzHdB+$B&z1$j#=Jcsu5mJuvojPRG8Mn0g$!Xyk5r3~eEL=Ww+UJ@t%Rh(uub z-fBTP_CaGI;`GzFEW7s!PB6rh!{W%z*ye@EOP9uoEEo4 zv(Q literal 0 HcmV?d00001 diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index 04d6a5ff3..2897c762d 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -17,6 +17,8 @@ export type ModelSource = | "openai" | "custom" | "silicon" + | "dashscope" + | "tokenpony" | "OpenAI-API-Compatible" | "modelengine"; From 97bcb3b255b768357697d9b5f0017bfa991a02d7 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Wed, 4 Mar 2026 17:09:35 +0800 Subject: [PATCH 18/75] New Requirement: Support for provider Zhipu AI Models (LLM and Embedding) --- backend/consts/provider.py | 4 +- backend/services/model_provider_service.py | 3 +- .../services/providers/dashscope_provider.py | 10 +- .../components/model/ModelDeleteDialog.tsx | 159 +++++++++++++++++- frontend/const/modelConfig.ts | 2 + 5 files changed, 166 insertions(+), 12 deletions(-) diff --git a/backend/consts/provider.py b/backend/consts/provider.py index e2a0f0235..38bbc4027 100644 --- a/backend/consts/provider.py +++ b/backend/consts/provider.py @@ -15,11 +15,11 @@ class ProviderEnum(str, Enum): SILICON_GET_URL = "https://api.siliconflow.cn/v1/models" # Dashcope -DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" +DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/" DASHSCOPE_GET_URL = "https://dashscope.aliyuncs.com/api/v1/models" # TokenPony -TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1" +TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1/" TOKENPONY_GET_URL = "https://api.tokenpony.cn/v1/models" # ModelEngine diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index 3c916eb8c..8c397dc70 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -125,7 +125,8 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a # dimension by performing a real connectivity check. if model["model_type"] in ["embedding", "multi_embedding"]: if provider != ProviderEnum.MODELENGINE.value: - model_dict["base_url"] = f"{model_url}embeddings" + # Ensure proper slash between base URL and endpoint + model_dict["base_url"] = f"{model_url.rstrip('/')}/embeddings" else: # For ModelEngine embedding models, append the embeddings path model_dict["base_url"] = f"{model_url.rstrip('/')}/{MODEL_ENGINE_NORTH_PREFIX}/embeddings" diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py index 2a34823ed..cde54b60a 100644 --- a/backend/services/providers/dashscope_provider.py +++ b/backend/services/providers/dashscope_provider.py @@ -35,16 +35,16 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: while True: params = {"page_size": 100, "page_no": current_page} response = await client.get(base_url, headers=headers, params=params) - response.raise_for_status() - - data = response.json() - models = data.get("output", {}).get("models", []) - if response.status_code == 429: await asyncio.sleep(2) continue if not models : # Break loop if no more models on the current page break + response.raise_for_status() + + data = response.json() + models = data.get("output", {}).get("models", []) + all_models.extend(models) if(len(models)<100): diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index 541ed6266..579908d95 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -183,6 +183,10 @@ export const ModelDeleteDialog = ({ return t("model.source.modelEngine"); case MODEL_SOURCES.OPENAI_API_COMPATIBLE: return t("model.source.custom"); + case MODEL_SOURCES.DASHSCOPE: + return t("model.source.dashscope"); + case MODEL_SOURCES.TOKENPONY: + return t("model.source.tokenpony"); default: return t("model.source.unknown"); } @@ -217,6 +221,18 @@ export const ModelDeleteDialog = ({ text: "text-rose-600", border: "border-rose-100", }; + case MODEL_SOURCES.DASHSCOPE: + return { + bg: "bg-orange-50", + text: "text-orange-600", + border: "border-orange-100", + }; + case MODEL_SOURCES.TOKENPONY: + return { + bg: "bg-cyan-50", + text: "text-cyan-600", + border: "border-cyan-100", + }; default: return { bg: "bg-gray-50", @@ -253,6 +269,14 @@ export const ModelDeleteDialog = ({ 🛠️
); + case MODEL_SOURCES.DASHSCOPE: + return ( + DashScope + ); + case MODEL_SOURCES.TOKENPONY: + return ( + TokenPony + ); default: return ( @@ -288,6 +312,16 @@ export const ModelDeleteDialog = ({ ); if (byModelEngine?.apiKey) return byModelEngine.apiKey; + const byDashScope = models.find( + (m) => m.source === MODEL_SOURCES.DASHSCOPE && m.type === type && m.apiKey + ); + if (byDashScope?.apiKey) return byDashScope.apiKey; + + const byTokenPony = models.find( + (m) => m.source === MODEL_SOURCES.TOKENPONY && m.type === type && m.apiKey + ); + if (byTokenPony?.apiKey) return byTokenPony.apiKey; + // Fallback: any model that has apiKey const anyWithKey = models.find((m) => m.apiKey); return anyWithKey?.apiKey || ""; @@ -327,7 +361,7 @@ export const ModelDeleteDialog = ({ return anyModelWithUrl?.apiUrl || undefined; }; - // Prefetch provider model list (supports Silicon and ModelEngine) + // Prefetch provider model list (supports Silicon, ModelEngine, DashScope, TokenPony) const prefetchProviderModels = async ( provider: ModelSource, modelType: ModelType | null @@ -351,6 +385,20 @@ export const ModelDeleteDialog = ({ apiKey: apiKey && apiKey.trim() !== "" ? apiKey : "sk-no-api-key", baseUrl: baseUrl || undefined, }); + } else if (provider === MODEL_SOURCES.DASHSCOPE) { + const apiKey = getApiKeyByType(modelType, MODEL_SOURCES.DASHSCOPE); + result = await modelService.addProviderModel({ + provider: MODEL_SOURCES.DASHSCOPE, + type: modelType, + apiKey: apiKey && apiKey.trim() !== "" ? apiKey : "sk-no-api-key", + }); + } else if (provider === MODEL_SOURCES.TOKENPONY) { + const apiKey = getApiKeyByType(modelType, MODEL_SOURCES.TOKENPONY); + result = await modelService.addProviderModel({ + provider: MODEL_SOURCES.TOKENPONY, + type: modelType, + apiKey: apiKey && apiKey.trim() !== "" ? apiKey : "sk-no-api-key", + }); } else { // Unsupported provider for prefetching return; @@ -383,7 +431,12 @@ export const ModelDeleteDialog = ({ const handleSourceSelect = async (source: ModelSource) => { setLoadingSource(source); try { - if (source === MODEL_SOURCES.SILICON || source === MODEL_SOURCES.MODELENGINE) { + if ( + source === MODEL_SOURCES.SILICON || + source === MODEL_SOURCES.MODELENGINE || + source === MODEL_SOURCES.DASHSCOPE || + source === MODEL_SOURCES.TOKENPONY + ) { await prefetchProviderModels(source, deletingModelType); } else if (source === MODEL_SOURCES.OPENAI) { // For OpenAI source, just set the selected source without prefetching @@ -543,7 +596,9 @@ export const ModelDeleteDialog = ({ setMaxTokens(maxTokens); if ( (selectedSource === MODEL_SOURCES.SILICON || - selectedSource === MODEL_SOURCES.MODELENGINE) && + selectedSource === MODEL_SOURCES.MODELENGINE || + selectedSource === MODEL_SOURCES.DASHSCOPE || + selectedSource === MODEL_SOURCES.TOKENPONY) && deletingModelType ) { try { @@ -839,6 +894,98 @@ export const ModelDeleteDialog = ({ t("model.dialog.error.addFailed", { error: e as any }) ); } + } else if ( + selectedSource === MODEL_SOURCES.DASHSCOPE && + deletingModelType + ) { + try { + const allEnabledModels = providerModels.filter( + (pm: any) => pendingSelectedProviderIds.has(pm.id) + ); + + if (allEnabledModels) { + const apiKey = getApiKeyByType(deletingModelType, MODEL_SOURCES.DASHSCOPE); + const isEmbeddingType = + deletingModelType === MODEL_TYPES.EMBEDDING || + deletingModelType === MODEL_TYPES.MULTI_EMBEDDING; + await modelService.addBatchCustomModel({ + api_key: + apiKey && apiKey.trim() !== "" + ? apiKey + : "sk-no-api-key", + provider: MODEL_SOURCES.DASHSCOPE, + type: deletingModelType, + models: allEnabledModels.map((model) => { + if (isEmbeddingType) { + const { max_tokens, ...modelWithoutMaxTokens } = + model; + return modelWithoutMaxTokens; + } else { + return { + ...model, + max_tokens: model.max_tokens || 4096, + }; + } + }), + }); + } + + await onSuccess(); + await prefetchProviderModels(selectedSource, deletingModelType); + message.success(t("model.dialog.success.updateSuccess")); + handleClose(); + } catch (e) { + log.error("Failed to apply DashScope model updates", e); + message.error( + t("model.dialog.error.addFailed", { error: e as any }) + ); + } + } else if ( + selectedSource === MODEL_SOURCES.TOKENPONY && + deletingModelType + ) { + try { + const allEnabledModels = providerModels.filter( + (pm: any) => pendingSelectedProviderIds.has(pm.id) + ); + + if (allEnabledModels) { + const apiKey = getApiKeyByType(deletingModelType, MODEL_SOURCES.TOKENPONY); + const isEmbeddingType = + deletingModelType === MODEL_TYPES.EMBEDDING || + deletingModelType === MODEL_TYPES.MULTI_EMBEDDING; + await modelService.addBatchCustomModel({ + api_key: + apiKey && apiKey.trim() !== "" + ? apiKey + : "sk-no-api-key", + provider: MODEL_SOURCES.TOKENPONY, + type: deletingModelType, + models: allEnabledModels.map((model) => { + if (isEmbeddingType) { + const { max_tokens, ...modelWithoutMaxTokens } = + model; + return modelWithoutMaxTokens; + } else { + return { + ...model, + max_tokens: model.max_tokens || 4096, + }; + } + }), + }); + } + + await onSuccess(); + await prefetchProviderModels(selectedSource, deletingModelType); + message.success(t("model.dialog.success.updateSuccess")); + handleClose(); + } catch (e) { + log.error("Failed to apply TokenPony model updates", e); + message.error( + t("model.dialog.error.addFailed", { error: e as any }) + ); + } } else if ( selectedSource === MODEL_SOURCES.OPENAI && deletingModelType @@ -976,6 +1123,8 @@ export const ModelDeleteDialog = ({ MODEL_SOURCES.OPENAI, MODEL_SOURCES.SILICON, MODEL_SOURCES.OPENAI_API_COMPATIBLE, + MODEL_SOURCES.DASHSCOPE, + MODEL_SOURCES.TOKENPONY, ] as ModelSource[] ).map((source) => { const modelsOfSource = models.filter( @@ -1074,7 +1223,9 @@ export const ModelDeleteDialog = ({ onClick={async () => { if ( (selectedSource === MODEL_SOURCES.SILICON || - selectedSource === MODEL_SOURCES.MODELENGINE) && + selectedSource === MODEL_SOURCES.MODELENGINE || + selectedSource === MODEL_SOURCES.DASHSCOPE || + selectedSource === MODEL_SOURCES.TOKENPONY) && deletingModelType ) { try { diff --git a/frontend/const/modelConfig.ts b/frontend/const/modelConfig.ts index 9b0128529..4c412824a 100644 --- a/frontend/const/modelConfig.ts +++ b/frontend/const/modelConfig.ts @@ -16,6 +16,8 @@ export const MODEL_SOURCES = { MODELENGINE: "modelengine", OPENAI_API_COMPATIBLE: "OpenAI-API-Compatible", CUSTOM: "custom", + DASHSCOPE: "dashscope", + TOKENPONY: "tokenpony", } as const; // Model status constants From a552e1272b5a8d2d0bf126d499efcc429bd0c01a Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Wed, 4 Mar 2026 17:23:38 +0800 Subject: [PATCH 19/75] New Requirement: Support for provider dashscope and tokenpony Models (LLM and Embedding) --- backend/services/providers/tokenpony_provider.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index 62972b698..6fe67502e 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -72,16 +72,14 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "model_type": "", "max_tokens": DEFAULT_LLM_MAX_TOKENS } - # 1. Embedding - if 'embedding' in m_id or m_id.startswith('bge-'): - cleaned_model.update({"model_tag": "embedding", "model_type": "embedding", "max_tokens": 0}) - categorized_models['embedding'].append(cleaned_model) - - # 2. Reranker - elif 'rerank' in m_id: + # 1. reranker + if 'rerank' in m_id: cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) categorized_models['reranker'].append(cleaned_model) - + #2. embedding + elif 'embedding' in m_id or m_id.startswith('bge-'): + cleaned_model.update({"model_tag": "embedding", "model_type": "embedding", "max_tokens": 0}) + categorized_models['embedding'].append(cleaned_model) # 3. STT (Speech-to-Text / Audio understanding) elif 'stt' in m_id: From e21157d68d320e17c14682f3c43a63c0ba6768cf Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Thu, 5 Mar 2026 13:37:24 +0800 Subject: [PATCH 20/75] bug fix : embedding model max_tokens changes --- backend/services/providers/tokenpony_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index 6fe67502e..844dd1859 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -78,7 +78,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: categorized_models['reranker'].append(cleaned_model) #2. embedding elif 'embedding' in m_id or m_id.startswith('bge-'): - cleaned_model.update({"model_tag": "embedding", "model_type": "embedding", "max_tokens": 0}) + cleaned_model.update({"model_tag": "embedding", "model_type": "embedding"}) categorized_models['embedding'].append(cleaned_model) # 3. STT (Speech-to-Text / Audio understanding) From b70e45c311c0bfd4892fd2b5bff1f27c03b4fe2d Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Thu, 5 Mar 2026 13:39:22 +0800 Subject: [PATCH 21/75] bug fix : embedding model max_tokens changes --- backend/services/providers/tokenpony_provider.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index 844dd1859..42e5d178c 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -35,13 +35,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE ssl_context.set_ciphers("DEFAULT@SECLEVEL=1") - # response = requests.get(url, headers=headers) - # all_models=[] - # if response.status_code == 200: - # data = response.json() - # # 注意:OpenAI 标准返回是在 "data" 字段下 - # all_models=data.get("data", []) - # Fetch all models asynchronously + async with httpx.AsyncClient(http2=True) as client: response = await client.get(url, headers=headers) response.raise_for_status() From 2f3af41a53ef66bfc0736f01d34293e14ac36c58 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Thu, 5 Mar 2026 14:50:23 +0800 Subject: [PATCH 22/75] create test files for the backend providers --- .../providers/test_dashscope_provider.py | 718 ++++++++++++++++++ .../providers/test_tokenpony_provider.py | 711 +++++++++++++++++ .../services/test_model_management_service.py | 4 + .../services/test_model_provider_service.py | 124 +++ 4 files changed, 1557 insertions(+) create mode 100644 test/backend/services/providers/test_dashscope_provider.py create mode 100644 test/backend/services/providers/test_tokenpony_provider.py diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py new file mode 100644 index 000000000..2dc3a8f27 --- /dev/null +++ b/test/backend/services/providers/test_dashscope_provider.py @@ -0,0 +1,718 @@ +"""Unit tests for DashScopeModelProvider module. + +Tests cover model fetching, type classification, and error handling. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from pytest_mock import MockFixture + +import httpx + +from backend.services.providers.dashscope_provider import DashScopeModelProvider + + +class TestDashScopeModelProvider: + """Tests for DashScopeModelProvider class.""" + + @pytest.mark.asyncio + async def test_get_models_llm_success(self, mocker: MockFixture): + """Test successful model retrieval for LLM models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-turbo", + "description": "Text generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen-plus", + "description": "Advanced text generation", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DEFAULT_LLM_MAX_TOKENS", + 4096 + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 2 + assert result[0]["id"] == "qwen-turbo" + assert result[0]["model_type"] == "llm" + assert result[0]["model_tag"] == "chat" + assert result[0]["max_tokens"] == 4096 + + @pytest.mark.asyncio + async def test_get_models_embedding_success(self, mocker: MockFixture): + """Test successful model retrieval for embedding models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "text-embedding-v3", + "description": "Embedding model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "text-embedding-v3" + assert result[0]["model_type"] == "embedding" + assert result[0]["model_tag"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_vlm_success(self, mocker: MockFixture): + """Test successful model retrieval for VLM models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-vl-plus", + "description": "Vision language model", + "inference_metadata": { + "request_modality": ["Image", "Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "vlm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "qwen-vl-plus" + assert result[0]["model_type"] == "vlm" + assert result[0]["model_tag"] == "chat" + + @pytest.mark.asyncio + async def test_get_models_reranker_success(self, mocker: MockFixture): + """Test successful model retrieval for reranker models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "gte-reranker", + "description": "Reranking model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "reranker", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "gte-reranker" + assert result[0]["model_type"] == "reranker" + assert result[0]["model_tag"] == "reranker" + + @pytest.mark.asyncio + async def test_get_models_tts_success(self, mocker: MockFixture): + """Test successful model retrieval for TTS models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "sambert-tts", + "description": "Text to speech", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Audio"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "tts", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "sambert-tts" + assert result[0]["model_type"] == "tts" + assert result[0]["model_tag"] == "tts" + + @pytest.mark.asyncio + async def test_get_models_stt_success(self, mocker: MockFixture): + """Test successful model retrieval for STT models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "paraformer-realtime-v2", + "description": "Speech recognition", + "inference_metadata": { + "request_modality": ["Audio"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "stt", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "paraformer-realtime-v2" + assert result[0]["model_type"] == "stt" + assert result[0]["model_tag"] == "stt" + + @pytest.mark.asyncio + async def test_get_models_multi_embedding_success(self, mocker: MockFixture): + """Test successful model retrieval for multi-embedding models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "text-embedding-multimodal-v3", + "description": "Multimodal embedding", + "inference_metadata": { + "request_modality": ["Text", "Image"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "multi_embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "text-embedding-multimodal-v3" + assert result[0]["model_type"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_empty_response(self, mocker: MockFixture): + """Test handling of empty model list from API.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"output": {"models": []}} + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert result == [] + + @pytest.mark.asyncio + async def test_get_models_http_error(self, mocker: MockFixture): + """Test handling of HTTP error.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.HTTPStatusError( + "Error", + request=MagicMock(), + response=MagicMock(status_code=500) + ) + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_connect_error(self, mocker: MockFixture): + """Test handling of connection error.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectError("Connection failed") + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_timeout(self, mocker: MockFixture): + """Test handling of connection timeout.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectTimeout("Timeout") + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_authorization_header(self, mocker: MockFixture): + """Test that Authorization header is correctly set.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-turbo", + "description": "Test", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "my-secret-key" + } + + await provider.get_models(provider_config) + + # Verify Authorization header + call_args = mock_client.get.call_args + headers = call_args[1]["headers"] + assert headers["Authorization"] == "Bearer my-secret-key" + + @pytest.mark.asyncio + async def test_get_models_pagination(self, mocker: MockFixture): + """Test that pagination works correctly.""" + # First page returns 100 models + mock_response_page1 = MagicMock() + mock_response_page1.status_code = 200 + mock_response_page1.json.return_value = { + "output": { + "models": [{"model": f"model-{i}", "description": "test", + "inference_metadata": {"request_modality": ["Text"], "response_modality": ["Text"]}} + for i in range(100)] + } + } + mock_response_page1.raise_for_status = MagicMock() + + # Second page returns 50 models (less than page_size) + mock_response_page2 = MagicMock() + mock_response_page2.status_code = 200 + mock_response_page2.json.return_value = { + "output": { + "models": [{"model": f"model-{i}", "description": "test", + "inference_metadata": {"request_modality": ["Text"], "response_modality": ["Text"]}} + for i in range(100, 150)] + } + } + mock_response_page2.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.side_effect = [mock_response_page1, mock_response_page2] + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + # Should get models from both pages + assert len(result) == 150 + + @pytest.mark.asyncio + async def test_get_models_unknown_type_returns_empty(self, mocker: MockFixture): + """Test that unknown model type returns empty list.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-turbo", + "description": "Text generation", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "unknown_type", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert result == [] + + @pytest.mark.asyncio + async def test_get_models_with_chinese_description(self, mocker: MockFixture): + """Test model classification by Chinese description.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "embedding-v1", + "description": "向量embedding模型", # Chinese description + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + }, + { + "model": "rerank-v1", + "description": "重排序模型", # Chinese description + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models" + ) + + provider = DashScopeModelProvider() + + # Test embedding classification by Chinese description + result = await provider.get_models({"model_type": "embedding", "api_key": "test-key"}) + assert len(result) == 1 + assert result[0]["id"] == "embedding-v1" + + # Test reranker classification by Chinese description + result = await provider.get_models({"model_type": "reranker", "api_key": "test-key"}) + assert len(result) == 1 + assert result[0]["id"] == "rerank-v1" + diff --git a/test/backend/services/providers/test_tokenpony_provider.py b/test/backend/services/providers/test_tokenpony_provider.py new file mode 100644 index 000000000..4f4a564e1 --- /dev/null +++ b/test/backend/services/providers/test_tokenpony_provider.py @@ -0,0 +1,711 @@ +"""Unit tests for TokenPonyModelProvider module. + +Tests cover model fetching, type classification, and error handling. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from pytest_mock import MockFixture + +import httpx + +from backend.services.providers.tokenpony_provider import TokenPonyModelProvider + + +class TestTokenPonyModelProvider: + """Tests for TokenPonyModelProvider class.""" + + @pytest.mark.asyncio + async def test_get_models_llm_success(self, mocker: MockFixture): + """Test successful model retrieval for LLM models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "owned_by": "openai" + }, + { + "id": "claude-3-opus", + "object": "model", + "owned_by": "anthropic" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.DEFAULT_LLM_MAX_TOKENS", + 4096 + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 2 + assert result[0]["id"] == "gpt-4" + assert result[0]["model_type"] == "llm" + assert result[0]["model_tag"] == "chat" + assert result[0]["max_tokens"] == 4096 + + @pytest.mark.asyncio + async def test_get_models_embedding_success(self, mocker: MockFixture): + """Test successful model retrieval for embedding models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "text-embedding-ada-002", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "text-embedding-ada-002" + assert result[0]["model_type"] == "embedding" + assert result[0]["model_tag"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_vlm_success(self, mocker: MockFixture): + """Test successful model retrieval for VLM models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "qwen-vl-plus", + "object": "model", + "owned_by": "qwen" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "vlm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "qwen-vl-plus" + assert result[0]["model_type"] == "vlm" + assert result[0]["model_tag"] == "chat" + + @pytest.mark.asyncio + async def test_get_models_reranker_success(self, mocker: MockFixture): + """Test successful model retrieval for reranker models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gte-reranker-base", + "object": "model", + "owned_by": "gte" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "reranker", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "gte-reranker-base" + assert result[0]["model_type"] == "reranker" + assert result[0]["model_tag"] == "reranker" + + @pytest.mark.asyncio + async def test_get_models_tts_success(self, mocker: MockFixture): + """Test successful model retrieval for TTS models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "tts-1-hd", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "tts", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "tts-1-hd" + assert result[0]["model_type"] == "tts" + assert result[0]["model_tag"] == "tts" + + @pytest.mark.asyncio + async def test_get_models_stt_success(self, mocker: MockFixture): + """Test successful model retrieval for STT models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "whisper-1", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "stt", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "whisper-1" + assert result[0]["model_type"] == "stt" + assert result[0]["model_tag"] == "stt" + + @pytest.mark.asyncio + async def test_get_models_multi_embedding_success(self, mocker: MockFixture): + """Test successful model retrieval for multi-embedding models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "bge-large", + "object": "model", + "owned_by": "bge" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "multi_embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["id"] == "bge-large" + assert result[0]["model_type"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_empty_response(self, mocker: MockFixture): + """Test handling of empty model list from API.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"data": []} + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert result == [] + + @pytest.mark.asyncio + async def test_get_models_http_error(self, mocker: MockFixture): + """Test handling of HTTP error.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.HTTPStatusError( + "Error", + request=MagicMock(), + response=MagicMock(status_code=500) + ) + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_connect_error(self, mocker: MockFixture): + """Test handling of connection error.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectError("Connection failed") + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_timeout(self, mocker: MockFixture): + """Test handling of connection timeout.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.ConnectTimeout("Timeout") + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["_error"] == "connection_failed" + + @pytest.mark.asyncio + async def test_get_models_authorization_header(self, mocker: MockFixture): + """Test that Authorization header is correctly set.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "my-secret-key" + } + + await provider.get_models(provider_config) + + # Verify Authorization header + call_args = mock_client.get.call_args + headers = call_args[1]["headers"] + assert headers["Authorization"] == "Bearer my-secret-key" + + @pytest.mark.asyncio + async def test_get_models_unknown_type_returns_empty(self, mocker: MockFixture): + """Test that unknown model type returns empty list.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "unknown_type", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert result == [] + + @pytest.mark.asyncio + async def test_get_models_vlm_by_keyword(self, mocker: MockFixture): + """Test VLM classification by keywords like -vl, vl-, ocr, vision.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "qwen-vl-plus", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "vl-ocr-v1", + "object": "model", + "owned_by": "ocr" + }, + { + "id": "vision-model-v2", + "object": "model", + "owned_by": "vision" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "vlm", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 3 + for model in result: + assert model["model_type"] == "vlm" + assert model["model_tag"] == "chat" + + @pytest.mark.asyncio + async def test_get_models_bge_prefix_embedding(self, mocker: MockFixture): + """Test that models with bge- prefix are classified as embedding.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "bge-large-zh-v1.5", + "object": "model", + "owned_by": "bge" + }, + { + "id": "bge-base-en-v1.5", + "object": "model", + "owned_by": "bge" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "embedding", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 2 + for model in result: + assert model["model_type"] == "embedding" + assert model["model_tag"] == "embedding" + + @pytest.mark.asyncio + async def test_get_models_llm_has_max_tokens(self, mocker: MockFixture): + """Test that LLM models have max_tokens set.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "gpt-4", + "object": "model", + "owned_by": "openai" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.TOKENPONY_GET_URL", + "https://api.tokenpony.cn/v1/models" + ) + mocker.patch( + "backend.services.providers.tokenpony_provider.DEFAULT_LLM_MAX_TOKENS", + 4096 + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "llm", + "api_key": "test-key" + } + + result = await provider.get_models(provider_config) + + assert len(result) == 1 + assert result[0]["max_tokens"] == 4096 + diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 6d0806299..e5d52d31a 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -120,10 +120,14 @@ class _Func: class _ProviderEnum: SILICON = _EnumItem("silicon") MODELENGINE = _EnumItem("modelengine") + DASHSCOPE = _EnumItem("dashscope") + TOKENPONY = _EnumItem("tokenpony") consts_provider_mod.ProviderEnum = _ProviderEnum consts_provider_mod.SILICON_BASE_URL = "http://silicon.test" +consts_provider_mod.DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/" +consts_provider_mod.TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1/" sys.modules["consts.provider"] = consts_provider_mod # Stub services.model_provider_service used by service diff --git a/test/backend/services/test_model_provider_service.py b/test/backend/services/test_model_provider_service.py index f81222056..992025754 100644 --- a/test/backend/services/test_model_provider_service.py +++ b/test/backend/services/test_model_provider_service.py @@ -157,6 +157,8 @@ def __init__(self): class _ProviderEnumStub: SILICON = mock.Mock(value="silicon") MODELENGINE = mock.Mock(value="modelengine") + DASHSCOPE = mock.Mock(value="dashscope") + TOKENPONY = mock.Mock(value="tokenpony") sys.modules["consts.provider"].ProviderEnum = _ProviderEnumStub @@ -1903,3 +1905,125 @@ def test_get_model_engine_raw_url_trailing_slash(): for input_url, expected in test_cases: result = get_model_engine_raw_url(input_url) assert result == expected, f"Failed for input: {input_url}" + + +# ============================================================================ +# Test-cases for get_provider_models with DashScope provider +# ============================================================================ + + +@pytest.mark.asyncio +async def test_get_provider_models_dashscope_success(): + """Should successfully get models from DashScope provider.""" + from backend.services.model_provider_service import DashScopeModelProvider + + model_data = { + "provider": "dashscope", + "model_type": "llm", + "api_key": "test-key", + } + + expected_models = [ + { + "id": "qwen-turbo", + "model_tag": "chat", + "model_type": "llm", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + } + ] + + with mock.patch( + "backend.services.model_provider_service.DashScopeModelProvider" + ) as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = expected_models + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == expected_models + mock_provider_class.assert_called_once() + mock_provider_instance.get_models.assert_called_once_with(model_data) + + +@pytest.mark.asyncio +async def test_get_provider_models_dashscope_empty_result(): + """Should handle empty result from DashScope provider.""" + model_data = { + "provider": "dashscope", + "model_type": "embedding", + "api_key": "test-key", + } + + with mock.patch( + "backend.services.model_provider_service.DashScopeModelProvider" + ) as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = [] + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == [] + mock_provider_instance.get_models.assert_called_once_with(model_data) + + +# ============================================================================ +# Test-cases for get_provider_models with TokenPony provider +# ============================================================================ + + +@pytest.mark.asyncio +async def test_get_provider_models_tokenpony_success(): + """Should successfully get models from TokenPony provider.""" + from backend.services.model_provider_service import TokenPonyModelProvider + + model_data = { + "provider": "tokenpony", + "model_type": "llm", + "api_key": "test-key", + } + + expected_models = [ + { + "id": "gpt-4", + "model_tag": "chat", + "model_type": "llm", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + } + ] + + with mock.patch( + "backend.services.model_provider_service.TokenPonyModelProvider" + ) as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = expected_models + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == expected_models + mock_provider_class.assert_called_once() + mock_provider_instance.get_models.assert_called_once_with(model_data) + + +@pytest.mark.asyncio +async def test_get_provider_models_tokenpony_empty_result(): + """Should handle empty result from TokenPony provider.""" + model_data = { + "provider": "tokenpony", + "model_type": "embedding", + "api_key": "test-key", + } + + with mock.patch( + "backend.services.model_provider_service.TokenPonyModelProvider" + ) as mock_provider_class: + mock_provider_instance = mock.AsyncMock() + mock_provider_instance.get_models.return_value = [] + mock_provider_class.return_value = mock_provider_instance + + result = await get_provider_models(model_data) + + assert result == [] + mock_provider_instance.get_models.assert_called_once_with(model_data) \ No newline at end of file From 0515bd35fde6047f01ac4ebdb285c21431814252 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Thu, 5 Mar 2026 15:50:48 +0800 Subject: [PATCH 23/75] bugfix for test files of the backend providers --- .../services/providers/dashscope_provider.py | 7 +- .../providers/test_dashscope_provider.py | 164 ++++-------------- .../providers/test_tokenpony_provider.py | 4 +- 3 files changed, 38 insertions(+), 137 deletions(-) diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py index cde54b60a..4ecbcbb1d 100644 --- a/backend/services/providers/dashscope_provider.py +++ b/backend/services/providers/dashscope_provider.py @@ -38,16 +38,17 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: if response.status_code == 429: await asyncio.sleep(2) continue - if not models : # Break loop if no more models on the current page - break response.raise_for_status() data = response.json() models = data.get("output", {}).get("models", []) + # Break loop if no more models on the current page + if not models: + break all_models.extend(models) - if(len(models)<100): + if len(models) < 100: break current_page += 1 await asyncio.sleep(0.5) diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py index 2dc3a8f27..44bbdbda5 100644 --- a/test/backend/services/providers/test_dashscope_provider.py +++ b/test/backend/services/providers/test_dashscope_provider.py @@ -4,7 +4,7 @@ """ import pytest -from unittest.mock import MagicMock, AsyncMock, patch +from unittest.mock import MagicMock, AsyncMock, patch, Mock from pytest_mock import MockFixture import httpx @@ -15,6 +15,27 @@ class TestDashScopeModelProvider: """Tests for DashScopeModelProvider class.""" + def _setup_mock_client(self, mocker, mock_response): + """Set up mock for httpx.AsyncClient with proper context manager.""" + # Create mock client that handles the get request + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + # Create context manager mock + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + # Create a mock class that can be called with verify=False + mock_client_class = Mock(return_value=mock_cm) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + mock_client_class + ) + + return mock_client_class + @pytest.mark.asyncio async def test_get_models_llm_success(self, mocker: MockFixture): """Test successful model retrieval for LLM models.""" @@ -44,17 +65,8 @@ async def test_get_models_llm_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -99,17 +111,8 @@ async def test_get_models_embedding_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) + self._setup_mock_client(mocker, mock_response) - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -149,17 +152,8 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -199,17 +193,8 @@ async def test_get_models_reranker_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) + self._setup_mock_client(mocker, mock_response) - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -249,17 +234,8 @@ async def test_get_models_tts_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) + self._setup_mock_client(mocker, mock_response) - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -299,17 +275,8 @@ async def test_get_models_stt_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -349,17 +316,8 @@ async def test_get_models_multi_embedding_success(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -385,17 +343,8 @@ async def test_get_models_empty_response(self, mocker: MockFixture): mock_response.json.return_value = {"output": {"models": []}} mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response + self._setup_mock_client(mocker, mock_response) - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) mocker.patch( "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", "https://dashscope.aliyuncs.com/api/v1/models" @@ -429,10 +378,6 @@ async def test_get_models_http_error(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -460,10 +405,6 @@ async def test_get_models_connect_error(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -491,10 +432,6 @@ async def test_get_models_timeout(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -540,10 +477,6 @@ async def test_get_models_authorization_header(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -596,10 +529,6 @@ async def test_get_models_pagination(self, mocker: MockFixture): "backend.services.providers.dashscope_provider.httpx.AsyncClient", return_value=mock_cm ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) provider = DashScopeModelProvider() provider_config = { @@ -633,21 +562,7 @@ async def test_get_models_unknown_type_returns_empty(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) + self._setup_mock_client(mocker, mock_response) provider = DashScopeModelProvider() provider_config = { @@ -688,21 +603,7 @@ async def test_get_models_with_chinese_description(self, mocker: MockFixture): } mock_response.raise_for_status = MagicMock() - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - - mock_cm = MagicMock() - mock_cm.__aenter__ = AsyncMock(return_value=mock_client) - mock_cm.__aexit__ = AsyncMock(return_value=None) - - mocker.patch( - "backend.services.providers.dashscope_provider.httpx.AsyncClient", - return_value=mock_cm - ) - mocker.patch( - "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", - "https://dashscope.aliyuncs.com/api/v1/models" - ) + self._setup_mock_client(mocker, mock_response) provider = DashScopeModelProvider() @@ -715,4 +616,3 @@ async def test_get_models_with_chinese_description(self, mocker: MockFixture): result = await provider.get_models({"model_type": "reranker", "api_key": "test-key"}) assert len(result) == 1 assert result[0]["id"] == "rerank-v1" - diff --git a/test/backend/services/providers/test_tokenpony_provider.py b/test/backend/services/providers/test_tokenpony_provider.py index 4f4a564e1..7fd9df9eb 100644 --- a/test/backend/services/providers/test_tokenpony_provider.py +++ b/test/backend/services/providers/test_tokenpony_provider.py @@ -258,7 +258,7 @@ async def test_get_models_stt_success(self, mocker: MockFixture): mock_response.json.return_value = { "data": [ { - "id": "whisper-1", + "id": "stt-whisper-1", "object": "model", "owned_by": "openai" } @@ -291,7 +291,7 @@ async def test_get_models_stt_success(self, mocker: MockFixture): result = await provider.get_models(provider_config) assert len(result) == 1 - assert result[0]["id"] == "whisper-1" + assert result[0]["id"] == "stt-whisper-1" assert result[0]["model_type"] == "stt" assert result[0]["model_tag"] == "stt" From 74e3c1a20d66ce238b07f4387e1786f8bb204694 Mon Sep 17 00:00:00 2001 From: wadecrack <2138269670@qq.com> Date: Fri, 6 Mar 2026 17:08:57 +0800 Subject: [PATCH 24/75] improve codecov for testfiles --- .../providers/test_dashscope_provider.py | 51 ++++++++++++++++++ .../services/test_model_management_service.py | 52 +++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py index 44bbdbda5..0bc2d3ad8 100644 --- a/test/backend/services/providers/test_dashscope_provider.py +++ b/test/backend/services/providers/test_dashscope_provider.py @@ -574,6 +574,57 @@ async def test_get_models_unknown_type_returns_empty(self, mocker: MockFixture): assert result == [] + @pytest.mark.asyncio + async def test_get_models_rate_limit_retry(self, mocker: MockFixture): + """Test that a 429 response triggers a retry after sleeping.""" + rate_limit_response = MagicMock() + rate_limit_response.status_code = 429 + + ok_response = MagicMock() + ok_response.status_code = 200 + ok_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-turbo", + "description": "Text generation", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"], + }, + } + ] + } + } + ok_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.side_effect = [rate_limit_response, ok_response] + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.dashscope_provider.httpx.AsyncClient", + return_value=mock_cm, + ) + mocker.patch( + "backend.services.providers.dashscope_provider.DASHSCOPE_GET_URL", + "https://dashscope.aliyuncs.com/api/v1/models", + ) + mocker.patch( + "backend.services.providers.dashscope_provider.asyncio.sleep", + new=AsyncMock(), + ) + + provider = DashScopeModelProvider() + result = await provider.get_models({"model_type": "llm", "api_key": "test-key"}) + + assert mock_client.get.call_count == 2 + assert len(result) == 1 + assert result[0]["id"] == "qwen-turbo" + @pytest.mark.asyncio async def test_get_models_with_chinese_description(self, mocker: MockFixture): """Test model classification by Chinese description.""" diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index e5d52d31a..6e504e90a 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -557,6 +557,58 @@ async def test_create_provider_models_for_tenant_exception(): assert "Failed to create provider models" in str(exc.value) +@pytest.mark.asyncio +async def test_batch_create_models_for_tenant_dashscope_provider(): + """Test batch_create_models_for_tenant with DASHSCOPE provider uses DASHSCOPE_BASE_URL.""" + svc = import_svc() + + batch_payload = { + "provider": "dashscope", + "type": "llm", + "models": [{"id": "qwen/qwen-turbo", "max_tokens": 8192}], + "api_key": "dash-key", + } + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + mock.patch.object(svc, "delete_model_record"), \ + mock.patch.object(svc, "split_repo_name", return_value=("qwen", "qwen-turbo")), \ + mock.patch.object(svc, "add_repo_to_name", return_value="qwen/qwen-turbo"), \ + mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ + mock.patch.object(svc, "create_model_record", return_value=True): + + await svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + call_args = svc.prepare_model_dict.call_args + assert call_args[1]["model_url"] == "https://dashscope.aliyuncs.com/compatible-mode/v1/" + + +@pytest.mark.asyncio +async def test_batch_create_models_for_tenant_tokenpony_provider(): + """Test batch_create_models_for_tenant with TOKENPONY provider uses TOKENPONY_BASE_URL.""" + svc = import_svc() + + batch_payload = { + "provider": "tokenpony", + "type": "llm", + "models": [{"id": "gpt/gpt-4o", "max_tokens": 128000}], + "api_key": "tp-key", + } + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + mock.patch.object(svc, "delete_model_record"), \ + mock.patch.object(svc, "split_repo_name", return_value=("gpt", "gpt-4o")), \ + mock.patch.object(svc, "add_repo_to_name", return_value="gpt/gpt-4o"), \ + mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 2})), \ + mock.patch.object(svc, "create_model_record", return_value=True): + + await svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + call_args = svc.prepare_model_dict.call_args + assert call_args[1]["model_url"] == "https://api.tokenpony.cn/v1/" + + @pytest.mark.asyncio async def test_batch_create_models_for_tenant_other_provider(): """Test batch_create_models_for_tenant with non-Silicon/ModelEngine provider (covers lines 138-140)""" From 064552f6cebc3e36d397fccd34f2b2c21b07acb8 Mon Sep 17 00:00:00 2001 From: zwb <1194371519@qq.com> Date: Fri, 6 Mar 2026 17:56:56 +0800 Subject: [PATCH 25/75] Increase patch coverage --- test/backend/app/test_file_management_app.py | 18 ++ .../services/test_data_process_service.py | 168 ++++++++++++++++++ 2 files changed, 186 insertions(+) diff --git a/test/backend/app/test_file_management_app.py b/test/backend/app/test_file_management_app.py index 1165f3d9d..cd85e8935 100644 --- a/test/backend/app/test_file_management_app.py +++ b/test/backend/app/test_file_management_app.py @@ -1174,3 +1174,21 @@ async def fake_preview(object_name): assert "File not found" in str(ei.value) +@pytest.mark.asyncio +async def test_preview_file_office_conversion_error(monkeypatch): + """OfficeConversionException from preview_file_impl → HTTP 500 with conversion detail.""" + _OfficeConversionException = sys.modules["consts.exceptions"].OfficeConversionException + + async def fake_preview(object_name): + raise _OfficeConversionException("LibreOffice conversion failed") + + monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) + + with pytest.raises(Exception) as ei: + await file_management_app.preview_file( + object_name="files/report.docx", + filename=None + ) + assert "Failed to preview file" in str(ei.value) + + diff --git a/test/backend/services/test_data_process_service.py b/test/backend/services/test_data_process_service.py index 6d33e097a..ef9d1e926 100644 --- a/test/backend/services/test_data_process_service.py +++ b/test/backend/services/test_data_process_service.py @@ -2348,6 +2348,174 @@ def test_convert_office_to_pdf_impl_invalid_pdf_header( self.assertIn('invalid PDF header', str(ctx.exception)) mock_delete_file.assert_called_once_with('converted/doc.pdf') + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.upload_file') + @patch('backend.services.data_process_service.get_file_size_from_minio') + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_size_zero( + self, _exists, _mkdtemp, mock_rmtree, + mock_get_stream, mock_get_size, mock_upload, mock_convert + ): + """remote_size == 0 → OfficeConversionException: cannot read remote file size.""" + mock_get_stream.return_value = self._make_stream(b'DOC data') + mock_get_size.return_value = 0 + mock_upload.return_value = {'success': True} + mock_convert.return_value = '/tmp/test_cv/doc.pdf' + sys.modules['database.attachment_db'].file_exists = MagicMock(return_value=False) + with patch('builtins.open', MagicMock()): + with self.assertRaises(OfficeConversionException) as ctx: + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + self.assertIn('cannot read remote file size', str(ctx.exception)) + + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.upload_file') + @patch('backend.services.data_process_service.get_file_size_from_minio') + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_size_too_small( + self, _exists, _mkdtemp, mock_rmtree, + mock_get_stream, mock_get_size, mock_upload, mock_convert + ): + """remote_size < 100 (but > 0) → OfficeConversionException: file too small.""" + mock_get_stream.return_value = self._make_stream(b'DOC data') + mock_get_size.return_value = 50 + mock_upload.return_value = {'success': True} + mock_convert.return_value = '/tmp/test_cv/doc.pdf' + sys.modules['database.attachment_db'].file_exists = MagicMock(return_value=False) + with patch('builtins.open', MagicMock()): + with self.assertRaises(OfficeConversionException) as ctx: + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + self.assertIn('file too small', str(ctx.exception)) + + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.upload_file') + @patch('backend.services.data_process_service.get_file_size_from_minio') + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_stream_none( + self, _exists, _mkdtemp, mock_rmtree, + mock_get_stream, mock_get_size, mock_upload, mock_convert + ): + """get_file_stream returns None for header check → OfficeConversionException.""" + mock_get_stream.side_effect = [ + self._make_stream(b'DOC data'), # Step 1: original file + None, # Step 4: header check stream + ] + mock_get_size.return_value = 208 + mock_upload.return_value = {'success': True} + mock_convert.return_value = '/tmp/test_cv/doc.pdf' + sys.modules['database.attachment_db'].file_exists = MagicMock(return_value=False) + with patch('builtins.open', MagicMock()): + with self.assertRaises(OfficeConversionException) as ctx: + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + self.assertIn('cannot read uploaded file', str(ctx.exception)) + + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.upload_file') + @patch('backend.services.data_process_service.get_file_size_from_minio') + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_close_raises( + self, _exists, _mkdtemp, mock_rmtree, + mock_get_stream, mock_get_size, mock_upload, mock_convert + ): + """stream.close() raises during header check → exception swallowed, pipeline succeeds.""" + header_stream = MagicMock() + header_stream.read.return_value = b'%PDF-1.4' + header_stream.close.side_effect = OSError('close failed') + mock_get_stream.side_effect = [ + self._make_stream(b'DOC data'), # Step 1: original file + header_stream, # Step 4: header check + ] + mock_get_size.return_value = 208 + mock_upload.return_value = {'success': True} + mock_convert.return_value = '/tmp/test_cv/doc.pdf' + with patch('builtins.open', MagicMock()): + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + mock_convert.assert_called_once() + + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.upload_file') + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_unexpected_exception( + self, _exists, _mkdtemp, mock_rmtree, + mock_get_stream, mock_upload, mock_convert + ): + """Non-OfficeConversionException from upload_file → wrapped as OfficeConversionException.""" + mock_get_stream.return_value = self._make_stream(b'DOC data') + mock_convert.return_value = '/tmp/test_cv/doc.pdf' + mock_upload.side_effect = ConnectionError('storage unreachable') + with patch('builtins.open', MagicMock()): + with self.assertRaises(OfficeConversionException) as ctx: + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + self.assertIn('Unexpected error', str(ctx.exception)) + + @patch('backend.services.data_process_service.convert_office_to_pdf', + new_callable=AsyncMock) + @patch('backend.services.data_process_service.upload_file') + @patch('backend.services.data_process_service.get_file_size_from_minio') + @patch('backend.services.data_process_service.get_file_stream') + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_cleanup_failure( + self, _exists, _mkdtemp, mock_rmtree, + mock_get_stream, mock_get_size, mock_upload, mock_convert + ): + """shutil.rmtree raises during cleanup → error is logged, not re-raised.""" + mock_get_stream.side_effect = [ + self._make_stream(b'DOC data'), # Step 1: original file + self._make_stream(b'%PDF-1.4 ok'), # Step 4: header check + ] + mock_get_size.return_value = 208 + mock_upload.return_value = {'success': True} + mock_convert.return_value = '/tmp/test_cv/doc.pdf' + mock_rmtree.side_effect = OSError('permission denied') + with patch('builtins.open', MagicMock()): + # Cleanup error must not propagate + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + if __name__ == '__main__': unittest.main() From e4d4d97ae4903046ddf5f1f70baf046cdc694caa Mon Sep 17 00:00:00 2001 From: zwb <1194371519@qq.com> Date: Fri, 6 Mar 2026 19:42:56 +0800 Subject: [PATCH 26/75] fix issues and update tests --- backend/consts/const.py | 2 +- backend/services/file_management_service.py | 73 ++++++++++--------- sdk/nexent/storage/storage_client_base.py | 5 +- test/backend/app/test_file_management_app.py | 16 ---- .../services/test_data_process_service.py | 35 +++++++++ 5 files changed, 74 insertions(+), 57 deletions(-) diff --git a/backend/consts/const.py b/backend/consts/const.py index 6249af049..668540250 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -39,7 +39,7 @@ class VectorDatabaseType(str, Enum): # Preview Configuration FILE_PREVIEW_SIZE_LIMIT = 100 * 1024 * 1024 # 100MB # Limit concurrent Office-to-PDF conversions -MAX_CONCURRENT_CONVERSIONS = 5 +MAX_CONCURRENT_CONVERSIONS = 5 # Supported Office file MIME types OFFICE_MIME_TYPES = [ 'application/msword', # .doc diff --git a/backend/services/file_management_service.py b/backend/services/file_management_service.py index 7c7886bdc..39b3af858 100644 --- a/backend/services/file_management_service.py +++ b/backend/services/file_management_service.py @@ -300,43 +300,44 @@ async def _convert_office_to_cached_pdf( _conversion_locks[object_name] = asyncio.Lock() file_lock = _conversion_locks[object_name] - async with file_lock: - # Double-check: another request may have completed the conversion while we waited - cached_stream = _get_cached_pdf_stream(pdf_object_name) - if cached_stream is not None: - return cached_stream - - # Conversion semaphore is enforced inside the data-process service - try: - # Request conversion: data-process downloads, converts, uploads to temp path, validates - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post( - f"{DATA_PROCESS_SERVICE}/tasks/convert_to_pdf", - data={ - "object_name": object_name, - "pdf_object_name": temp_pdf_object_name, - }, - ) - if response.status_code != 200: - raise Exception( - f"data-process conversion returned {response.status_code}: {response.text}" - ) - - # Atomic move from temp to final location, then clean up temp - copy_result = copy_file(source_object=temp_pdf_object_name, dest_object=pdf_object_name) - if not copy_result.get('success'): - raise Exception(f"Failed to finalize PDF cache: {copy_result.get('error', 'Unknown error')}") - delete_file(temp_pdf_object_name) - - except Exception as e: - if file_exists(temp_pdf_object_name): + try: + async with file_lock: + # Double-check: another request may have completed the conversion while we waited + cached_stream = _get_cached_pdf_stream(pdf_object_name) + if cached_stream is not None: + return cached_stream + + # Conversion semaphore is enforced inside the data-process service + try: + # Request conversion: data-process downloads, converts, uploads to temp path, validates + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + f"{DATA_PROCESS_SERVICE}/tasks/convert_to_pdf", + data={ + "object_name": object_name, + "pdf_object_name": temp_pdf_object_name, + }, + ) + if response.status_code != 200: + raise Exception( + f"data-process conversion returned {response.status_code}: {response.text}" + ) + + # Atomic move from temp to final location, then clean up temp + copy_result = copy_file(source_object=temp_pdf_object_name, dest_object=pdf_object_name) + if not copy_result.get('success'): + raise Exception(f"Failed to finalize PDF cache: {copy_result.get('error', 'Unknown error')}") delete_file(temp_pdf_object_name) - logger.error(f"Office conversion failed: {str(e)}") - raise OfficeConversionException(f"Failed to convert Office document to PDF: {str(e)}") from e - finally: - # Clean up the file lock (prevents memory leak for many unique files) - async with _conversion_locks_guard: - _conversion_locks.pop(object_name, None) + + except Exception as e: + if file_exists(temp_pdf_object_name): + delete_file(temp_pdf_object_name) + logger.error(f"Office conversion failed: {str(e)}") + raise OfficeConversionException(f"Failed to convert Office document to PDF: {str(e)}") from e + finally: + # Clean up the file lock (prevents memory leak for many unique files) + async with _conversion_locks_guard: + _conversion_locks.pop(object_name, None) file_stream = get_file_stream(pdf_object_name) if file_stream is None: diff --git a/sdk/nexent/storage/storage_client_base.py b/sdk/nexent/storage/storage_client_base.py index 05623a0c0..90a37f395 100644 --- a/sdk/nexent/storage/storage_client_base.py +++ b/sdk/nexent/storage/storage_client_base.py @@ -235,7 +235,4 @@ def copy_file( Returns: Tuple[bool, str]: (Success status, Destination object name or error message) """ - pass - - - + pass \ No newline at end of file diff --git a/test/backend/app/test_file_management_app.py b/test/backend/app/test_file_management_app.py index cd85e8935..1721b5f98 100644 --- a/test/backend/app/test_file_management_app.py +++ b/test/backend/app/test_file_management_app.py @@ -1028,22 +1028,6 @@ async def fake_preview(object_name): assert "filename*=UTF-8" in content_disposition or "测试文档" in content_disposition -@pytest.mark.asyncio -async def test_preview_file_not_found_error(monkeypatch): - """Test previewing a non-existent file returns 404""" - async def fake_preview(object_name): - raise Exception("File not found") - - monkeypatch.setattr(file_management_app, "preview_file_impl", fake_preview) - - with pytest.raises(Exception) as ei: - await file_management_app.preview_file( - object_name="nonexistent/file.pdf", - filename=None - ) - assert "File not found" in str(ei.value) - - @pytest.mark.asyncio async def test_preview_file_too_large_error(monkeypatch): """Test previewing a file exceeding size limit returns 413""" diff --git a/test/backend/services/test_data_process_service.py b/test/backend/services/test_data_process_service.py index ef9d1e926..03afeefbe 100644 --- a/test/backend/services/test_data_process_service.py +++ b/test/backend/services/test_data_process_service.py @@ -2348,6 +2348,41 @@ def test_convert_office_to_pdf_impl_invalid_pdf_header( self.assertIn('invalid PDF header', str(ctx.exception)) mock_delete_file.assert_called_once_with('converted/doc.pdf') + @patch('backend.services.data_process_service.file_exists', return_value=False) + @patch('backend.services.data_process_service.get_file_stream', return_value=None) + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', return_value='/tmp/test_cv') + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_no_remote_cleanup_when_not_exists( + self, _exists, _mkdtemp, mock_rmtree, _get_stream, mock_file_exists + ): + """OfficeConversionException raised and file_exists=False → delete_file never called (623->625 branch).""" + with patch('backend.services.data_process_service.delete_file') as mock_del: + with self.assertRaises(OfficeConversionException): + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + mock_del.assert_not_called() + + @patch('backend.services.data_process_service.get_file_stream', return_value=None) + @patch('shutil.rmtree') + @patch('tempfile.mkdtemp', side_effect=OSError('no space left on device')) + @patch('os.path.exists', return_value=True) + def test_convert_office_to_pdf_impl_mkdtemp_failure( + self, _exists, mock_mkdtemp, mock_rmtree, _get_stream + ): + """tempfile.mkdtemp raises → temp_dir stays None → finally skips cleanup (630->exit branch).""" + with self.assertRaises(OfficeConversionException) as ctx: + asyncio.run( + self.service.convert_office_to_pdf_impl( + 'uploads/doc.docx', 'converted/doc.pdf' + ) + ) + self.assertIn('Unexpected error', str(ctx.exception)) + mock_rmtree.assert_not_called() + @patch('backend.services.data_process_service.convert_office_to_pdf', new_callable=AsyncMock) @patch('backend.services.data_process_service.upload_file') From 18ddb91ab57c62653207983ba49398ee9e3ad365 Mon Sep 17 00:00:00 2001 From: zwb <1194371519@qq.com> Date: Fri, 6 Mar 2026 20:10:58 +0800 Subject: [PATCH 27/75] add test for existing lock --- .../services/test_data_process_service.py | 4 ++-- .../services/test_file_management_service.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/test/backend/services/test_data_process_service.py b/test/backend/services/test_data_process_service.py index 03afeefbe..c52e496bb 100644 --- a/test/backend/services/test_data_process_service.py +++ b/test/backend/services/test_data_process_service.py @@ -2356,7 +2356,7 @@ def test_convert_office_to_pdf_impl_invalid_pdf_header( def test_convert_office_to_pdf_impl_no_remote_cleanup_when_not_exists( self, _exists, _mkdtemp, mock_rmtree, _get_stream, mock_file_exists ): - """OfficeConversionException raised and file_exists=False → delete_file never called (623->625 branch).""" + """OfficeConversionException raised and file_exists=False → delete_file never called.""" with patch('backend.services.data_process_service.delete_file') as mock_del: with self.assertRaises(OfficeConversionException): asyncio.run( @@ -2373,7 +2373,7 @@ def test_convert_office_to_pdf_impl_no_remote_cleanup_when_not_exists( def test_convert_office_to_pdf_impl_mkdtemp_failure( self, _exists, mock_mkdtemp, mock_rmtree, _get_stream ): - """tempfile.mkdtemp raises → temp_dir stays None → finally skips cleanup (630->exit branch).""" + """tempfile.mkdtemp raises → temp_dir stays None → finally skips cleanup.""" with self.assertRaises(OfficeConversionException) as ctx: asyncio.run( self.service.convert_office_to_pdf_impl( diff --git a/test/backend/services/test_file_management_service.py b/test/backend/services/test_file_management_service.py index cc02add6d..2e7e4f43a 100644 --- a/test/backend/services/test_file_management_service.py +++ b/test/backend/services/test_file_management_service.py @@ -1450,3 +1450,27 @@ async def test_converted_pdf_not_readable_raises_not_found(self): "preview/converted/docs/report_deadbeef.pdf", "preview/converting/docs/report_deadbeef.pdf.tmp", ) + + @pytest.mark.asyncio + async def test_reuses_existing_lock_for_same_object(self): + """If a lock for object_name already exists, it is reused.""" + import asyncio as _asyncio + import backend.services.file_management_service as _svc + from backend.services.file_management_service import _convert_office_to_cached_pdf + + existing_lock = _asyncio.Lock() + _svc._conversion_locks["docs/existing.docx"] = existing_lock + + mock_stream = BytesIO(b"%PDF-1.4 cached") + try: + with patch('backend.services.file_management_service._get_cached_pdf_stream', + return_value=mock_stream): + result = await _convert_office_to_cached_pdf( + "docs/existing.docx", + "preview/converted/docs/existing_aabbccdd.pdf", + "preview/converting/docs/existing_aabbccdd.pdf.tmp", + ) + finally: + _svc._conversion_locks.pop("docs/existing.docx", None) + + assert result is mock_stream From d0651a8b2f367c620aa5f499c5401c2df76c1084 Mon Sep 17 00:00:00 2001 From: CHGZX <88022755+CHGZX@users.noreply.github.com> Date: Sun, 8 Mar 2026 18:02:44 +0800 Subject: [PATCH 28/75] Add tip section for GZX with gratitude message --- doc/docs/zh/opensource-memorial-wall.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/docs/zh/opensource-memorial-wall.md b/doc/docs/zh/opensource-memorial-wall.md index 54bac7c28..a61e42c64 100644 --- a/doc/docs/zh/opensource-memorial-wall.md +++ b/doc/docs/zh/opensource-memorial-wall.md @@ -711,3 +711,7 @@ Nexent 加油!希望能达成所愿! ::: info sisyphus0x - 2026-03-04 对多智能体编排和协同工作很感兴趣,学习一下 ::: + +::: tip GZX- 2026-03-08 +感谢 Nexent 期待与Nexent一起进步。 +::: From 2dbc9bb42217a4073747a7f632511d2e187f7a87 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Mon, 9 Mar 2026 10:59:58 +0800 Subject: [PATCH 29/75] =?UTF-8?q?=E2=9C=A8=20Update=20exception=20handling?= =?UTF-8?q?=20in=20tests:=20Change=20expected=20error=20codes=20to=20strin?= =?UTF-8?q?gs=20for=20consistency=20and=20enhance=20error=20propagation=20?= =?UTF-8?q?tests=20in=20prompt=20service.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/consts/test_exceptions.py | 8 +- .../middleware/test_exception_handler.py | 5 +- test/backend/services/test_prompt_service.py | 392 ++++++++++++++++-- test/backend/utils/test_llm_utils.py | 294 ++++++++++++- 4 files changed, 665 insertions(+), 34 deletions(-) diff --git a/test/backend/consts/test_exceptions.py b/test/backend/consts/test_exceptions.py index 4ec5d0234..8c954aff9 100644 --- a/test/backend/consts/test_exceptions.py +++ b/test/backend/consts/test_exceptions.py @@ -44,8 +44,8 @@ def test_app_exception_to_dict(self): """Test AppException.to_dict() method.""" exc = AppException(ErrorCode.DIFY_AUTH_ERROR, "Auth failed", {"key": "value"}) result = exc.to_dict() - - assert result["code"] == 130204 + + assert result["code"] == "130204" assert result["message"] == "Auth failed" assert result["details"] == {"key": "value"} @@ -53,7 +53,7 @@ def test_app_exception_to_dict_null_details(self): """Test that to_dict() returns null for empty details.""" exc = AppException(ErrorCode.DIFY_AUTH_ERROR, "Auth failed") result = exc.to_dict() - + assert result["details"] is None def test_app_exception_http_status_property(self): @@ -71,7 +71,7 @@ def test_app_exception_http_status_for_different_codes(self): (ErrorCode.COMMON_TOKEN_EXPIRED, 401), (ErrorCode.COMMON_FORBIDDEN, 403), ] - + for error_code, expected_status in test_cases: exc = AppException(error_code) assert exc.http_status == expected_status, \ diff --git a/test/backend/middleware/test_exception_handler.py b/test/backend/middleware/test_exception_handler.py index 0b2bfc865..0b234ef4b 100644 --- a/test/backend/middleware/test_exception_handler.py +++ b/test/backend/middleware/test_exception_handler.py @@ -412,8 +412,7 @@ def test_error_response_contains_code_as_int(self): import json body = json.loads(response.body) assert "code" in body - # Code should be integer when converted (string "130204" -> int 130204) - assert body["code"] == 130204 + assert body["code"] == "130204" def test_error_response_contains_message(self): """Test that error response contains message.""" @@ -516,7 +515,7 @@ def test_to_dict_contains_code(self): """Test that to_dict contains code as integer.""" exc = AppException(ErrorCode.DIFY_AUTH_ERROR, "Auth failed") result = exc.to_dict() - assert result["code"] == 130204 + assert result["code"] == "130204" def test_to_dict_contains_message(self): """Test that to_dict contains message.""" diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 01474d205..3b33f1a5e 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -728,31 +728,20 @@ def mock_llm_call_with_exception(model_id, content, sys_prompt, callback, tenant mock_tenant_id = "test_tenant" mock_language = "en" - # Execute - should handle exceptions gracefully - result_list = [] - for result in generate_system_prompt( - mock_sub_agents, - mock_task_description, - mock_tools, - mock_tenant_id, - self.test_model_id, - mock_language - ): - result_list.append(result) - - # Assert - should still return results for other prompt types - self.assertGreater(len(result_list), 0) - - # Constraint should work fine - constraint_results = [ - r for r in result_list if r["type"] == "constraint"] - self.assertGreater(len(constraint_results), 0) - - # Verify that duty result exists but might be empty due to exception handling - duty_results = [r for r in result_list if r["type"] == "duty"] - - # Should still have duty result entry with empty content - self.assertGreater(len(duty_results), 0) + # Execute - exception should be raised (this tests the error propagation behavior) + with self.assertRaises(Exception) as context: + for result in generate_system_prompt( + mock_sub_agents, + mock_task_description, + mock_tools, + mock_tenant_id, + self.test_model_id, + mock_language + ): + pass # Consume the generator to trigger the exception + + # Assert - exception message should be present + self.assertIn("LLM error", str(context.exception)) @patch('backend.services.prompt_service.Template') def test_join_info_for_generate_system_prompt(self, mock_template): @@ -844,4 +833,357 @@ def test_get_enabled_sub_agent_description_for_generate_prompt( self.assertEqual(result[0]["agent_id"], 10) self.assertEqual(result[1]["agent_id"], 20) + # ==================== Additional tests for higher coverage ==================== + + @patch('backend.services.prompt_service.generate_and_save_system_prompt_impl') + def test_gen_system_prompt_streamable_with_app_exception(self, mock_generate_impl): + """Test gen_system_prompt_streamable handles AppException and returns error through SSE""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + # Setup - mock generate_and_save_system_prompt_impl to raise AppException + mock_generate_impl.side_effect = AppException( + ErrorCode.MODEL_NOT_FOUND, + "Model not found error" + ) + + # Execute - collect results from the generator + result_list = [] + for result in gen_system_prompt_streamable( + agent_id=123, + model_id=self.test_model_id, + task_description="Test task", + user_id="user123", + tenant_id="tenant456", + language="zh" + ): + result_list.append(result) + + # Assert - should yield error in SSE format + self.assertEqual(len(result_list), 1) + import json + parsed = json.loads(result_list[0].replace("data: ", "").replace("\n\n", "")) + self.assertFalse(parsed['success']) + self.assertEqual(parsed['error']['code'], str(ErrorCode.MODEL_NOT_FOUND.value)) + self.assertEqual(parsed['error']['message'], "Model not found error") + + @patch('backend.services.prompt_service.generate_and_save_system_prompt_impl') + def test_gen_system_prompt_streamable_with_generic_exception(self, mock_generate_impl): + """Test gen_system_prompt_streamable handles generic Exception and returns error through SSE""" + # Setup - mock generate_and_save_system_prompt_impl to raise generic Exception + mock_generate_impl.side_effect = Exception("Some random error") + + # Execute - collect results from the generator + result_list = [] + for result in gen_system_prompt_streamable( + agent_id=123, + model_id=self.test_model_id, + task_description="Test task", + user_id="user123", + tenant_id="tenant456", + language="zh" + ): + result_list.append(result) + + # Assert - should yield error in SSE format with default error code + self.assertEqual(len(result_list), 1) + import json + parsed = json.loads(result_list[0].replace("data: ", "").replace("\n\n", "")) + self.assertFalse(parsed['success']) + # Should use default error code for non-AppException + self.assertIn('error', parsed) + + @patch('backend.services.prompt_service.search_agent_info_by_agent_id') + @patch('backend.services.prompt_service.query_tools_by_ids') + @patch('backend.services.prompt_service.generate_system_prompt') + @patch('backend.services.prompt_service.query_all_agent_info_by_tenant_id') + def test_generate_and_save_system_prompt_impl_sub_agent_exception( + self, + mock_query_all_agents, + mock_generate_system_prompt, + mock_query_tools, + mock_search_agent_info, + ): + """Test generate_and_save_system_prompt_impl handles sub-agent info retrieval exception (lines 88-89)""" + # Setup + mock_query_tools.return_value = [] + mock_query_all_agents.return_value = [] + + # Mock generate_system_prompt to yield data + def mock_gen(*args, **kwargs): + yield {"type": "duty", "content": "duty content", "is_complete": True} + + mock_generate_system_prompt.side_effect = mock_gen + + # Make search_agent_info_by_agent_id raise exception for one sub-agent + mock_search_agent_info.side_effect = [ + {"agent_id": 10, "name": "agent1"}, # First sub-agent succeeds + Exception("Database error"), # Second sub-agent fails + ] + + # Execute - should handle exception gracefully and continue + result_gen = generate_and_save_system_prompt_impl( + agent_id=123, + model_id=self.test_model_id, + task_description="Test task", + user_id="user123", + tenant_id="tenant456", + language="zh", + tool_ids=[1], + sub_agent_ids=[10, 20] # Two sub-agents + ) + result = list(result_gen) + + # Assert - should still return results (exception was logged but not raised) + self.assertGreater(len(result), 0) + + @patch('backend.services.prompt_service._check_agent_display_name_duplicate') + @patch('backend.services.prompt_service._check_agent_name_duplicate') + @patch('backend.services.prompt_service.query_all_agent_info_by_tenant_id') + @patch('backend.services.prompt_service.generate_system_prompt') + @patch('backend.services.prompt_service.query_tools_by_ids') + @patch('backend.services.prompt_service.search_agent_info_by_agent_id') + def test_generate_and_save_system_prompt_impl_empty_content_raises_exception( + self, + mock_search_agent_info, + mock_query_tools, + mock_generate_system_prompt, + mock_query_all_agents, + mock_check_name_dup, + mock_check_display_dup, + ): + """Test generate_and_save_system_prompt_impl raises exception when no content is generated (line 223)""" + # Setup + mock_query_tools.return_value = [] + mock_search_agent_info.return_value = {} + mock_query_all_agents.return_value = [] + mock_check_name_dup.return_value = False + mock_check_display_dup.return_value = False + + # Mock generate_system_prompt to yield empty content + def mock_gen(*args, **kwargs): + yield {"type": "duty", "content": "", "is_complete": True} + yield {"type": "constraint", "content": "", "is_complete": True} + yield {"type": "few_shots", "content": "", "is_complete": True} + yield {"type": "agent_var_name", "content": "", "is_complete": True} + yield {"type": "agent_display_name", "content": "", "is_complete": True} + yield {"type": "agent_description", "content": "", "is_complete": True} + + mock_generate_system_prompt.side_effect = mock_gen + + # Execute and Assert - should raise Exception when all content is empty + with self.assertRaises(Exception) as context: + list(generate_and_save_system_prompt_impl( + agent_id=123, + model_id=self.test_model_id, + task_description="Test task", + user_id="user123", + tenant_id="tenant456", + language="zh", + tool_ids=[1], + sub_agent_ids=[10], + )) + + self.assertIn("Failed to generate prompt content", str(context.exception)) + + @patch('backend.services.prompt_service.call_llm_for_system_prompt') + @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') + @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') + def test_generate_system_prompt_error_before_streaming( + self, + mock_get_prompt_template, + mock_join_info, + mock_call_llm, + ): + """Test generate_system_prompt handles error that occurs before streaming (line 307-311)""" + # Setup + mock_prompt_config = { + "USER_PROMPT": "Test user prompt template", + "DUTY_SYSTEM_PROMPT": "Generate duty prompt", + "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", + "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", + "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", + "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", + "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + } + mock_get_prompt_template.return_value = mock_prompt_config + mock_join_info.return_value = "Joined template content" + + # Mock call_llm_for_system_prompt to raise exception immediately + def mock_llm_call_error(model_id, content, sys_prompt, callback, tenant_id): + if "duty" in sys_prompt.lower(): + raise Exception("LLM connection error") + # Other prompts work normally + if callback: + callback(f"Content for {sys_prompt}") + return f"Content for {sys_prompt}" + + mock_call_llm.side_effect = mock_llm_call_error + + # Execute - should raise the exception during iteration + result_list = [] + with self.assertRaises(Exception) as context: + for result in generate_system_prompt( + [{"name": "agent1"}], + "Test task", + [{"name": "tool1"}], + "tenant123", + self.test_model_id, + "zh" + ): + result_list.append(result) + + self.assertIn("LLM connection error", str(context.exception)) + + @patch('backend.services.prompt_service.call_llm_for_system_prompt') + @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') + @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') + def test_generate_system_prompt_error_during_streaming( + self, + mock_get_prompt_template, + mock_join_info, + mock_call_llm, + ): + """Test generate_system_prompt handles error that occurs during streaming (line 330-331)""" + # Setup + mock_prompt_config = { + "USER_PROMPT": "Test user prompt template", + "DUTY_SYSTEM_PROMPT": "Generate duty prompt", + "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", + "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", + "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", + "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", + "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + } + mock_get_prompt_template.return_value = mock_prompt_config + mock_join_info.return_value = "Joined template content" + + # Track which call we're on + call_count = {"count": 0} + + # Mock call_llm to succeed initially then fail after some streaming + def mock_llm_call_error_after_first( + model_id, content, sys_prompt, callback, tenant_id + ): + call_count["count"] += 1 + + # First few calls succeed + if call_count["count"] <= 3: + if callback: + callback(f"Content for {sys_prompt}") + return f"Content for {sys_prompt}" + else: + # Later calls fail + raise Exception("LLM error during generation") + + mock_call_llm.side_effect = mock_llm_call_error_after_first + + # Execute - error should be raised during streaming + result_list = [] + with self.assertRaises(Exception) as context: + for result in generate_system_prompt( + [{"name": "agent1"}], + "Test task", + [{"name": "tool1"}], + "tenant123", + self.test_model_id, + "zh" + ): + result_list.append(result) + + # Should eventually raise an exception + self.assertIn("LLM error during generation", str(context.exception)) + + @patch('backend.services.prompt_service.query_tools_by_ids') + @patch('backend.services.prompt_service.get_enable_tool_id_by_agent_id') + def test_get_enabled_tool_description_for_generate_prompt_empty_tool_ids( + self, + mock_get_enable_tool_ids, + mock_query_tools, + ): + """Test get_enabled_tool_description_for_generate_prompt with empty tool IDs""" + from backend.services.prompt_service import get_enabled_tool_description_for_generate_prompt + + # Setup - return empty list + mock_get_enable_tool_ids.return_value = [] + mock_query_tools.return_value = [] + + result = get_enabled_tool_description_for_generate_prompt( + agent_id=123, tenant_id="tenant-x" + ) + + # Should return empty list + self.assertEqual(result, []) + + @patch('backend.services.prompt_service.search_agent_info_by_agent_id') + @patch('backend.services.prompt_service.query_sub_agents_id_list') + def test_get_enabled_sub_agent_description_for_generate_prompt_empty( + self, + mock_query_sub_ids, + mock_search_agent, + ): + """Test get_enabled_sub_agent_description_for_generate_prompt with empty sub-agent IDs""" + from backend.services.prompt_service import get_enabled_sub_agent_description_for_generate_prompt + + # Setup - return empty list + mock_query_sub_ids.return_value = [] + + result = get_enabled_sub_agent_description_for_generate_prompt( + agent_id=99, tenant_id="tenant-y" + ) + + # Should return empty list + self.assertEqual(result, []) + mock_search_agent.assert_not_called() + + @patch('backend.services.prompt_service.Template') + def test_join_info_for_generate_system_prompt_english(self, mock_template): + """Test join_info_for_generate_system_prompt with English language""" + # Setup + mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_sub_agents = [ + {"name": "agent1", "description": "Agent 1 desc"} + ] + mock_task_description = "Test task" + mock_tools = [ + {"name": "tool1", "description": "Tool 1 desc", + "inputs": "input1", "output_type": "output1"} + ] + + mock_template_instance = MagicMock() + mock_template.return_value = mock_template_instance + mock_template_instance.render.return_value = "Rendered content" + + # Execute with English language + result = join_info_for_generate_system_prompt( + mock_prompt_for_generate, mock_sub_agents, mock_task_description, mock_tools, + language="en" + ) + + # Assert + self.assertEqual(result, "Rendered content") + # Check that English labels are used + call_args = mock_template_instance.render.call_args[0][0] + self.assertEqual(call_args["task_description"], mock_task_description) + + @patch('backend.services.prompt_service.Template') + def test_join_info_for_generate_system_prompt_empty_tools_and_agents(self, mock_template): + """Test join_info_for_generate_system_prompt with empty tools and sub-agents""" + # Setup + mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_sub_agents = [] + mock_task_description = "Test task" + mock_tools = [] + + mock_template_instance = MagicMock() + mock_template.return_value = mock_template_instance + mock_template_instance.render.return_value = "Rendered content" + + # Execute + result = join_info_for_generate_system_prompt( + mock_prompt_for_generate, mock_sub_agents, mock_task_description, mock_tools + ) + + # Assert + self.assertEqual(result, "Rendered content") diff --git a/test/backend/utils/test_llm_utils.py b/test/backend/utils/test_llm_utils.py index b34a58b71..2c43ea01c 100644 --- a/test/backend/utils/test_llm_utils.py +++ b/test/backend/utils/test_llm_utils.py @@ -139,6 +139,9 @@ def test_call_llm_for_system_prompt_success(self, mocker: MockFixture): ) def test_call_llm_for_system_prompt_exception(self, mocker: MockFixture): + from consts.error_code import ErrorCode + from consts.exceptions import AppException + mock_get_model_by_id = mocker.patch('backend.utils.llm_utils.get_model_by_model_id') mock_get_model_name = mocker.patch('backend.utils.llm_utils.get_model_name_from_config') mock_openai = mocker.patch('backend.utils.llm_utils.OpenAIModel') @@ -155,14 +158,15 @@ def test_call_llm_for_system_prompt_exception(self, mocker: MockFixture): mock_llm_instance.client.chat.completions.create.side_effect = Exception("LLM error") mock_llm_instance._prepare_completion_kwargs.return_value = {} - with pytest.raises(Exception) as exc_info: + with pytest.raises(AppException) as exc_info: call_llm_for_system_prompt( 1, "user prompt", "system prompt", ) - assert "LLM error" in str(exc_info.value) + # Verify AppException is raised with correct error code for unmapped errors + assert exc_info.value.error_code == ErrorCode.MODEL_PROMPT_GENERATION_FAILED class TestProcessThinkingTokens: @@ -813,3 +817,289 @@ def test_call_llm_for_system_prompt_exception_logging(self, mocker: MockFixture) mock_logger.error.assert_called_once() call_args = mock_logger.error.call_args[0][0] assert "Failed to generate prompt" in call_args + + +class TestCallLLMForSystemPromptErrorHandling: + """Tests for error handling in call_llm_for_system_prompt function.""" + + def _create_mock_llm_setup(self, mocker: MockFixture): + """Helper to setup common mocks for LLM error tests.""" + mock_get_model_by_id = mocker.patch('backend.utils.llm_utils.get_model_by_model_id') + mock_get_model_name = mocker.patch('backend.utils.llm_utils.get_model_name_from_config') + mock_openai = mocker.patch('backend.utils.llm_utils.OpenAIModel') + + mock_get_model_by_id.return_value = {"base_url": "http://example.com", "api_key": "fake-key"} + mock_get_model_name.return_value = "gpt-4" + + mock_llm_instance = mock_openai.return_value + mock_llm_instance._prepare_completion_kwargs.return_value = {} + + return mock_llm_instance + + def test_error_401_api_key_invalid(self, mocker: MockFixture): + """Test error handling for 401 status code - API key invalid.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Error 401: Invalid API key" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_API_KEY_INVALID + + def test_error_unauthorized_lowercase(self, mocker: MockFixture): + """Test error handling for 'unauthorized' in error message.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Unauthorized access to the resource" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_API_KEY_INVALID + + def test_error_api_key_in_message(self, mocker: MockFixture): + """Test error handling for 'api key' in error message.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Invalid API key provided" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_API_KEY_INVALID + + def test_error_403_forbidden(self, mocker: MockFixture): + """Test error handling for 403 status code - no permission.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Error 403: Access forbidden" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_API_KEY_NO_PERMISSION + + def test_error_forbidden_lowercase(self, mocker: MockFixture): + """Test error handling for 'forbidden' in error message.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Request forbidden by the server" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_API_KEY_NO_PERMISSION + + def test_error_404_not_found(self, mocker: MockFixture): + """Test error handling for 404 status code - model not found.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Error 404: Model not found" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_NOT_FOUND + + def test_error_not_found_lowercase(self, mocker: MockFixture): + """Test error handling for 'not found' in error message.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "The requested model was not found" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_NOT_FOUND + + def test_error_429_rate_limit(self, mocker: MockFixture): + """Test error handling for 429 status code - rate limit exceeded.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Error 429: Rate limit exceeded" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_RATE_LIMIT_EXCEEDED + + def test_error_rate_limit_lowercase(self, mocker: MockFixture): + """Test error handling for 'rate limit' in error message.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Too many requests, rate limit reached" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_RATE_LIMIT_EXCEEDED + + def test_error_500_service_unavailable(self, mocker: MockFixture): + """Test error handling for 500 status code - service unavailable.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Error 500: Internal server error" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_SERVICE_UNAVAILABLE + + def test_error_502_service_unavailable(self, mocker: MockFixture): + """Test error handling for 502 status code - bad gateway.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Error 502: Bad gateway" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_SERVICE_UNAVAILABLE + + def test_error_503_service_unavailable(self, mocker: MockFixture): + """Test error handling for 503 status code - service unavailable.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Error 503: Service temporarily unavailable" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_SERVICE_UNAVAILABLE + + def test_error_504_service_unavailable(self, mocker: MockFixture): + """Test error handling for 504 status code - gateway timeout.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Error 504: Gateway timeout" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_SERVICE_UNAVAILABLE + + def test_error_connection_error(self, mocker: MockFixture): + """Test error handling for connection error.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Connection error: Unable to reach the server" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_CONNECTION_ERROR + + def test_error_timeout(self, mocker: MockFixture): + """Test error handling for timeout error.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Request timeout occurred" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_CONNECTION_ERROR + + def test_error_connection_refused(self, mocker: MockFixture): + """Test error handling for connection refused error.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Connection refused by the server" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_CONNECTION_ERROR + + def test_error_generic_unmapped_error(self, mocker: MockFixture): + """Test error handling for generic unmapped errors.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception( + "Some unexpected error occurred" + ) + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_PROMPT_GENERATION_FAILED + + def test_error_empty_message(self, mocker: MockFixture): + """Test error handling for exception with empty message.""" + from consts.error_code import ErrorCode + from consts.exceptions import AppException + + mock_llm_instance = self._create_mock_llm_setup(mocker) + mock_llm_instance.client.chat.completions.create.side_effect = Exception() + + with pytest.raises(AppException) as exc_info: + call_llm_for_system_prompt(1, "user prompt", "system prompt") + + assert exc_info.value.error_code == ErrorCode.MODEL_PROMPT_GENERATION_FAILED \ No newline at end of file From 5909e4f1fc3c58943befd39bf662e12a1328b12b Mon Sep 17 00:00:00 2001 From: zwb <1194371519@qq.com> Date: Mon, 9 Mar 2026 11:21:01 +0800 Subject: [PATCH 30/75] Delete install of LibreOffice --- .github/workflows/auto-unit-test.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/auto-unit-test.yml b/.github/workflows/auto-unit-test.yml index 29cf3a42d..6addafa22 100644 --- a/.github/workflows/auto-unit-test.yml +++ b/.github/workflows/auto-unit-test.yml @@ -48,9 +48,6 @@ jobs: uv pip install -e "../sdk[dev]" cd .. - - name: Install LibreOffice - run: sudo apt-get update && sudo apt-get install -y libreoffice - - name: Run all tests and collect coverage run: | source backend/.venv/bin/activate && python test/run_all_test.py From 622f31806ae59e0eff98f95226b4b099de52555a Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Mon, 9 Mar 2026 11:52:14 +0800 Subject: [PATCH 31/75] =?UTF-8?q?=E2=9C=A8=20Enhance=20test=20configuratio?= =?UTF-8?q?ns:=20Add=20language=20support,=20message=20roles,=20and=20thin?= =?UTF-8?q?k=20patterns;=20mock=20error=20codes=20and=20exceptions=20for?= =?UTF-8?q?=20improved=20test=20reliability.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/test_cluster_summarization.py | 10 ++++++++++ test/backend/test_document_vector_integration.py | 11 +++++++++++ test/backend/test_document_vector_utils.py | 10 ++++++++++ test/backend/test_document_vector_utils_coverage.py | 12 +++++++++++- test/backend/test_summary_formatting.py | 11 +++++++++++ 5 files changed, 53 insertions(+), 1 deletion(-) diff --git a/test/backend/test_cluster_summarization.py b/test/backend/test_cluster_summarization.py index e6edd46b3..82af6d5ba 100644 --- a/test/backend/test_cluster_summarization.py +++ b/test/backend/test_cluster_summarization.py @@ -26,9 +26,19 @@ consts_const_mock.POSTGRES_DB = "test_db" consts_const_mock.POSTGRES_PORT = 5432 consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_const_mock.MESSAGE_ROLE = {"USER": "user", "ASSISTANT": "assistant", "SYSTEM": "system"} +consts_const_mock.THINK_START_PATTERN = "" +consts_const_mock.THINK_END_PATTERN = "" consts_mock.const = consts_const_mock +# Mock consts.error_code and consts.exceptions +consts_error_code_mock = MagicMock() +consts_error_code_mock.ErrorCode = MagicMock() +consts_exceptions_mock = MagicMock() +consts_exceptions_mock.AppException = Exception sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock +sys.modules['consts.error_code'] = consts_error_code_mock +sys.modules['consts.exceptions'] = consts_exceptions_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_document_vector_integration.py b/test/backend/test_document_vector_integration.py index 8e05abe86..4fb094618 100644 --- a/test/backend/test_document_vector_integration.py +++ b/test/backend/test_document_vector_integration.py @@ -26,9 +26,20 @@ consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" consts_const_mock.POSTGRES_DB = "test_db" consts_const_mock.POSTGRES_PORT = 5432 +consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_const_mock.MESSAGE_ROLE = {"USER": "user", "ASSISTANT": "assistant", "SYSTEM": "system"} +consts_const_mock.THINK_START_PATTERN = "" +consts_const_mock.THINK_END_PATTERN = "" consts_mock.const = consts_const_mock +# Mock consts.error_code and consts.exceptions +consts_error_code_mock = MagicMock() +consts_error_code_mock.ErrorCode = MagicMock() +consts_exceptions_mock = MagicMock() +consts_exceptions_mock.AppException = Exception sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock +sys.modules['consts.error_code'] = consts_error_code_mock +sys.modules['consts.exceptions'] = consts_exceptions_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_document_vector_utils.py b/test/backend/test_document_vector_utils.py index 9df14475d..9bce2af29 100644 --- a/test/backend/test_document_vector_utils.py +++ b/test/backend/test_document_vector_utils.py @@ -26,9 +26,19 @@ consts_const_mock.POSTGRES_DB = "test_db" consts_const_mock.POSTGRES_PORT = 5432 consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_const_mock.MESSAGE_ROLE = {"USER": "user", "ASSISTANT": "assistant", "SYSTEM": "system"} +consts_const_mock.THINK_START_PATTERN = "" +consts_const_mock.THINK_END_PATTERN = "" consts_mock.const = consts_const_mock +# Mock consts.error_code and consts.exceptions +consts_error_code_mock = MagicMock() +consts_error_code_mock.ErrorCode = MagicMock() +consts_exceptions_mock = MagicMock() +consts_exceptions_mock.AppException = Exception sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock +sys.modules['consts.error_code'] = consts_error_code_mock +sys.modules['consts.exceptions'] = consts_exceptions_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_document_vector_utils_coverage.py b/test/backend/test_document_vector_utils_coverage.py index fc0c69311..23a6923c8 100644 --- a/test/backend/test_document_vector_utils_coverage.py +++ b/test/backend/test_document_vector_utils_coverage.py @@ -24,10 +24,20 @@ consts_const_mock.POSTGRES_USER = "test_user" consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" consts_const_mock.POSTGRES_DB = "test_db" -consts_const_mock.POSTGRES_PORT = 5432 +consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_const_mock.MESSAGE_ROLE = {"USER": "user", "ASSISTANT": "assistant", "SYSTEM": "system"} +consts_const_mock.THINK_START_PATTERN = "" +consts_const_mock.THINK_END_PATTERN = "" consts_mock.const = consts_const_mock +# Mock consts.error_code and consts.exceptions +consts_error_code_mock = MagicMock() +consts_error_code_mock.ErrorCode = MagicMock() +consts_exceptions_mock = MagicMock() +consts_exceptions_mock.AppException = Exception sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock +sys.modules['consts.error_code'] = consts_error_code_mock +sys.modules['consts.exceptions'] = consts_exceptions_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_summary_formatting.py b/test/backend/test_summary_formatting.py index 22f8dec36..be9d6a20d 100644 --- a/test/backend/test_summary_formatting.py +++ b/test/backend/test_summary_formatting.py @@ -22,9 +22,20 @@ consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" consts_const_mock.POSTGRES_DB = "test_db" consts_const_mock.POSTGRES_PORT = 5432 +consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_const_mock.MESSAGE_ROLE = {"USER": "user", "ASSISTANT": "assistant", "SYSTEM": "system"} +consts_const_mock.THINK_START_PATTERN = "" +consts_const_mock.THINK_END_PATTERN = "" consts_mock.const = consts_const_mock +# Mock consts.error_code and consts.exceptions +consts_error_code_mock = MagicMock() +consts_error_code_mock.ErrorCode = MagicMock() +consts_exceptions_mock = MagicMock() +consts_exceptions_mock.AppException = Exception sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock +sys.modules['consts.error_code'] = consts_error_code_mock +sys.modules['consts.exceptions'] = consts_exceptions_mock # Add backend to path before patching backend modules sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend')) From da4b4530e07f4ae84a10134c7f9bda14807d9926 Mon Sep 17 00:00:00 2001 From: fenghuaof2011 Date: Mon, 9 Mar 2026 13:43:39 +0800 Subject: [PATCH 32/75] Update opensource-memorial-wall.md --- doc/docs/zh/opensource-memorial-wall.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/docs/zh/opensource-memorial-wall.md b/doc/docs/zh/opensource-memorial-wall.md index 54bac7c28..cb8553726 100644 --- a/doc/docs/zh/opensource-memorial-wall.md +++ b/doc/docs/zh/opensource-memorial-wall.md @@ -711,3 +711,7 @@ Nexent 加油!希望能达成所愿! ::: info sisyphus0x - 2026-03-04 对多智能体编排和协同工作很感兴趣,学习一下 ::: + +::: info xingzhewujiang - 2026-03-09 +偶然发现Nexent是一个开源的零代码智能体自动生成平台,非常值的研究与尝试,祝福Nexent让零代码走向AI全球。 +::: From 92444cd68df98880612f8912a198ac9258595c80 Mon Sep 17 00:00:00 2001 From: biansimeng Date: Mon, 9 Mar 2026 15:12:48 +0800 Subject: [PATCH 33/75] Unify tavily search's record starting index as 1 --- sdk/nexent/core/tools/tavily_search_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/nexent/core/tools/tavily_search_tool.py b/sdk/nexent/core/tools/tavily_search_tool.py index df64474b8..1c6fe1418 100644 --- a/sdk/nexent/core/tools/tavily_search_tool.py +++ b/sdk/nexent/core/tools/tavily_search_tool.py @@ -37,7 +37,7 @@ def __init__(self, tavily_api_key:str=Field(description="Tavily API key"), self.tavily = TavilyClient(api_key=tavily_api_key) self.max_results = max_results self.image_filter = image_filter - self.record_ops = 0 # Used to record sequence number + self.record_ops = 1 # Used to record sequence number self.running_prompt_en = "Searching the web..." self.running_prompt_zh = "网络搜索中..." From 866ffceeef7417ce66f8b46610bef9e330ae55f7 Mon Sep 17 00:00:00 2001 From: panyehong <2655992392@qq.com> Date: Mon, 9 Mar 2026 15:41:48 +0800 Subject: [PATCH 34/75] =?UTF-8?q?=F0=9F=90=9B=20Bugfix:=20Fixed=20an=20iss?= =?UTF-8?q?ue=20where=20starting=20a=20container=20resulted=20in=20an=20un?= =?UTF-8?q?clear=20error=20message=20when=20no=20MCP=20image=20was=20avail?= =?UTF-8?q?able.=20#2293=20[Specification=20Detail]=201.=20When=20no=20mir?= =?UTF-8?q?ror=20is=20available,=20the=20backend=20returns=20a=20specific?= =?UTF-8?q?=20error=20message,=20and=20the=20frontend=20adds=20internation?= =?UTF-8?q?alized=20error=20messages.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/remote_mcp_app.py | 8 +- frontend/hooks/useMcpConfig.ts | 6 +- frontend/public/locales/en/common.json | 1 + frontend/public/locales/zh/common.json | 1 + frontend/services/mcpService.ts | 20 ++- test/backend/app/test_remote_mcp_app.py | 157 ++++++++++++++++++++++++ 6 files changed, 188 insertions(+), 5 deletions(-) diff --git a/backend/apps/remote_mcp_app.py b/backend/apps/remote_mcp_app.py index cfc82146b..009e5cffa 100644 --- a/backend/apps/remote_mcp_app.py +++ b/backend/apps/remote_mcp_app.py @@ -387,7 +387,13 @@ async def add_mcp_from_config( except MCPContainerError as e: logger.error( f"Failed to start MCP container {service_name}: {e}") - errors.append(f"{service_name}: {str(e)}") + error_str = str(e) + # Check if error is related to image not found + if "not found" in error_str.lower() or "404" in error_str: + errors.append( + f"{service_name}: Image not found - MCP service startup image is missing") + else: + errors.append(f"{service_name}: {error_str}") except Exception as e: logger.error( f"Unexpected error adding MCP {service_name}: {e}") diff --git a/frontend/hooks/useMcpConfig.ts b/frontend/hooks/useMcpConfig.ts index 8478e931a..386a777bf 100644 --- a/frontend/hooks/useMcpConfig.ts +++ b/frontend/hooks/useMcpConfig.ts @@ -255,7 +255,11 @@ export function useMcpConfig(options: UseMcpConfigOptions = {}) { options.onContainerAdded?.(); return { success: true, messageKey: "mcpService.message.addContainerSuccess" }; } else { - return { success: false, message: result.message, messageKey: "mcpConfig.message.addContainerFailed" }; + return { + success: false, + message: result.message, + messageKey: (result as any).messageKey || "mcpConfig.message.addContainerFailed" + }; } } catch (error) { log.error("Failed to add container:", error); diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 775eae675..d4282b83c 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -1082,6 +1082,7 @@ "mcpService.message.invalidUploadParameters": "Invalid upload parameters", "mcpService.message.serviceNameAlreadyExists": "MCP service name already exists", "mcpService.message.fileTooLarge": "File size exceeds limit", + "mcpService.message.missingMcpImage": "Failed to add container: MCP service startup image is missing", "agentConfig.tools.refreshSuccess": "Tool list refreshed successfully", "agentConfig.tools.refreshFailed": "Failed to refresh tool list", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 88ef18fdc..6ca160b47 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -1084,6 +1084,7 @@ "mcpService.message.invalidUploadParameters": "上传参数无效", "mcpService.message.serviceNameAlreadyExists": "MCP服务名称已存在", "mcpService.message.fileTooLarge": "文件大小超过限制", + "mcpService.message.missingMcpImage": "添加容器失败:缺少mcp服务启动镜像", "agentConfig.tools.refreshSuccess": "工具列表已刷新", "agentConfig.tools.refreshFailed": "刷新工具列表失败", diff --git a/frontend/services/mcpService.ts b/frontend/services/mcpService.ts index 1b656cd8d..20383809f 100644 --- a/frontend/services/mcpService.ts +++ b/frontend/services/mcpService.ts @@ -433,17 +433,30 @@ export const addMcpFromConfig = async (mcpConfig: { mcpServers: Record Date: Mon, 9 Mar 2026 20:11:29 +0800 Subject: [PATCH 35/75] =?UTF-8?q?=F0=9F=90=9B=20Bugfix:=20Fixed=20the=20is?= =?UTF-8?q?sue=20of=20suadmin=20account=20not=20being=20generated=20in=20i?= =?UTF-8?q?nfrastructure=20mode.=20#2556=20=E2=99=BB=EF=B8=8F=20Improvemen?= =?UTF-8?q?t:=20During=20deployment,=20the=20user=20is=20prompted=20to=20e?= =?UTF-8?q?nter=20the=20password=20for=20the=20suadmin=20user.=20#2531=20[?= =?UTF-8?q?Specification=20Detail]=201.=20In=20infrastructure=20mode,=20th?= =?UTF-8?q?e=20supabase-db-mini=20container=20is=20used=20to=20perform=20o?= =?UTF-8?q?perations=20such=20as=20creating=20the=20su=20user.=202.=20If?= =?UTF-8?q?=20the=20suadmin=20user=20is=20not=20detected=20during=20deploy?= =?UTF-8?q?ment,=20the=20user=20will=20be=20prompted=20to=20enter=20their?= =?UTF-8?q?=20password.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docker/create-su.sh | 48 ++++++++++++---- docker/deploy.sh | 136 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 159 insertions(+), 25 deletions(-) diff --git a/docker/create-su.sh b/docker/create-su.sh index 8d290a726..639e64553 100644 --- a/docker/create-su.sh +++ b/docker/create-su.sh @@ -54,10 +54,32 @@ wait_for_postgresql_ready() { create_default_super_admin_user() { local email="suadmin@nexent.com" local password - password="$(generate_random_password)" + + # Get password from command line argument, or generate random one if not provided + if [ -n "$1" ]; then + password="$1" + else + # Fallback to random password if no argument provided (for backward compatibility) + password="$(generate_random_password)" + echo " ⚠️ Warning: No password provided, using random password" + fi echo "🔧 Creating super admin user..." - RESPONSE=$(docker exec nexent-config bash -c "curl -s -X POST http://kong:8000/auth/v1/signup -H \"apikey: ${SUPABASE_KEY}\" -H \"Authorization: Bearer ${SUPABASE_KEY}\" -H \"Content-Type: application/json\" -d '{\"email\":\"${email}\",\"password\":\"${password}\",\"email_confirm\":true}'" 2>/dev/null) + + # Determine which container to use for curl command + local curl_container="nexent-config" + if [ "$DEPLOYMENT_MODE" = "infrastructure" ] || ! docker ps | grep -q "nexent-config"; then + # In infrastructure mode or if nexent-config is not running, use supabase-db-mini + if docker ps | grep -q "supabase-db-mini"; then + curl_container="supabase-db-mini" + echo " ℹ️ Using supabase-db-mini container (infrastructure mode)" + else + echo " ❌ Neither nexent-config nor supabase-db-mini container is available." + return 1 + fi + fi + + RESPONSE=$(docker exec "$curl_container" bash -c "curl -s -X POST http://kong:8000/auth/v1/signup -H \"apikey: ${SUPABASE_KEY}\" -H \"Authorization: Bearer ${SUPABASE_KEY}\" -H \"Content-Type: application/json\" -d '{\"email\":\"${email}\",\"password\":\"${password}\",\"email_confirm\":true}'" 2>/dev/null) if [ -z "$RESPONSE" ]; then echo " ❌ No response received from Supabase." @@ -65,21 +87,24 @@ create_default_super_admin_user() { elif echo "$RESPONSE" | grep -q '"access_token"' && echo "$RESPONSE" | grep -q '"user"'; then echo " ✅ Default super admin user has been successfully created." echo "" - echo " Please save the following credentials carefully, which would ONLY be shown once." + echo " Please save the following credentials carefully." echo " 📧 Email: ${email}" - echo " 🔏 Password: ${password}" + if [ -n "$1" ]; then + echo " 🔏 Password: [User provided password]" + else + echo " 🔏 Password: ${password}" + fi # Extract user.id from RESPONSE JSON local user_id - # Try using Python to parse JSON (most reliable) - user_id=$(echo "$RESPONSE" | docker exec -i nexent-config python3 -c "import sys, json; data = json.load(sys.stdin); print(data.get('user', {}).get('id', ''))" 2>/dev/null) - - # Fallback to jq if Python fails - if [ -z "$user_id" ] && command -v jq >/dev/null 2>&1; then + # Try using jq first (if available in the container or on host) + if docker exec "$curl_container" command -v jq >/dev/null 2>&1; then + user_id=$(echo "$RESPONSE" | docker exec -i "$curl_container" jq -r '.user.id // empty' 2>/dev/null) + elif command -v jq >/dev/null 2>&1; then user_id=$(echo "$RESPONSE" | jq -r '.user.id // empty' 2>/dev/null) fi - # Final fallback: use grep and sed + # Fallback: use grep and sed (works without any special tools) if [ -z "$user_id" ]; then user_id=$(echo "$RESPONSE" | grep -o '"user"[^}]*"id":"[^"]*"' | sed -n 's/.*"id":"\([^"]*\)".*/\1/p' 2>/dev/null) fi @@ -150,7 +175,8 @@ create_default_super_admin_user() { } # Main execution -if create_default_super_admin_user; then +# Pass password as first argument if provided +if create_default_super_admin_user "$1"; then exit 0 else exit 1 diff --git a/docker/deploy.sh b/docker/deploy.sh index 83d3f7947..7676ecf60 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -865,24 +865,98 @@ select_terminal_tool() { echo "" } -generate_random_password() { - # Generate a URL/JSON safe random password (alphanumeric only) - local pwd="" - if command -v openssl >/dev/null 2>&1; then - pwd=$(openssl rand -base64 32 | tr -dc 'A-Za-z0-9' | head -c 20) - else - pwd=$(tr -dc 'A-Za-z0-9' /dev/null | tr -d '[:space:]') + if [ "$user_exists" = "1" ]; then + return 0 # User exists + elif [ "$user_exists" = "0" ]; then + return 1 # User does not exist + fi fi - if [ -z "$pwd" ]; then - # Fallback (should be extremely rare) - pwd=$(date +%s%N | tr -dc '0-9' | head -c 20) + + # Fallback: Try to sign in with a dummy password to check if user exists + # This is less reliable but works when database access is not available + local test_response + test_response=$(docker exec "$curl_container" bash -c "curl -s -X POST http://kong:8000/auth/v1/token?grant_type=password -H \"apikey: ${SUPABASE_KEY}\" -H \"Content-Type: application/json\" -d '{\"email\":\"${email}\",\"password\":\"dummy_password_check\"}'" 2>/dev/null) + + if echo "$test_response" | grep -q '"error_code":"invalid_credentials"'; then + return 0 # User exists (wrong password means user exists) + elif echo "$test_response" | grep -q '"error_code":"email_not_confirmed"'; then + return 0 # User exists + else + return 1 # User likely does not exist fi - echo "$pwd" +} + +prompt_super_admin_password() { + # Prompt user to enter password for super admin user with confirmation + # Note: All prompts go to stderr, only password is returned via stdout + local password="" + local password_confirm="" + local max_attempts=3 + local attempts=0 + + echo "" >&2 + echo "🔐 Super Admin User Password Setup" >&2 + echo " Email: suadmin@nexent.com" >&2 + echo "" >&2 + + while [ $attempts -lt $max_attempts ]; do + # First password input + echo " 🔐 Please enter password for super admin user:" >&2 + read -s password + echo "" >&2 + + # Check if password is empty + if [ -z "$password" ]; then + echo " ❌ Password cannot be empty. Please try again." >&2 + attempts=$((attempts + 1)) + continue + fi + + # Confirm password input + echo " 🔐 Please confirm the password:" >&2 + read -s password_confirm + echo "" >&2 + + # Check if passwords match + if [ "$password" != "$password_confirm" ]; then + echo " ❌ Passwords do not match. Please try again." >&2 + attempts=$((attempts + 1)) + continue + fi + + # Passwords match, return the password via stdout + echo "$password" + return 0 + done + + # Max attempts reached + echo " ❌ Maximum attempts reached. Failed to set password." >&2 + return 1 } create_default_super_admin_user() { # Call the dedicated script for creating super admin user local script_path="$SCRIPT_DIR/create-su.sh" + local email="suadmin@nexent.com" if [ ! -f "$script_path" ]; then echo " ❌ ERROR create-su.sh not found at $script_path" @@ -892,15 +966,43 @@ create_default_super_admin_user() { # Make sure the script is executable chmod +x "$script_path" + # Check if super admin user already exists + echo "" + echo "🔍 Checking if super admin user exists..." + local check_result + check_super_admin_user_exists + check_result=$? + + if [ $check_result -eq 0 ]; then + echo " ✅ Super admin user (${email}) already exists." + echo " 💡 Skipping user creation. If you need to reset the password, please do so manually." + return 0 + elif [ $check_result -eq 1 ]; then + echo " ℹ️ Super admin user (${email}) does not exist. Proceeding with creation..." + else + echo " ⚠️ Warning: Could not determine if user exists. Proceeding with creation..." + fi + + # Prompt for password + local password + password="$(prompt_super_admin_password)" + local prompt_result=$? + + if [ $prompt_result -ne 0 ] || [ -z "$password" ]; then + echo " ❌ Failed to get password from user." + return 1 + fi + # Export necessary environment variables for the script export SUPABASE_KEY export POSTGRES_USER export POSTGRES_DB export DEPLOYMENT_VERSION export SUPABASE_POSTGRES_DB + export DEPLOYMENT_MODE - # Execute the script with current environment variables - if bash "$script_path"; then + # Execute the script with password as argument + if bash "$script_path" "$password"; then return 0 else return 1 @@ -939,7 +1041,7 @@ main_deploy() { echo "--------------------------------" echo "" - APP_VERSION="$(get_app_version)" + APP_VERSION="latest" if [ -z "$APP_VERSION" ]; then echo "❌ Failed to get app version, please check the backend/consts/const.py file" exit 1 @@ -984,6 +1086,12 @@ main_deploy() { # Special handling for infrastructure mode if [ "$DEPLOYMENT_MODE" = "infrastructure" ]; then generate_env_for_infrastructure || { echo "❌ Environment generation failed"; exit 1; } + + # Create default super admin user (only for full version) + if [ "$DEPLOYMENT_VERSION" = "full" ]; then + create_default_super_admin_user || { echo "❌ Default super admin user creation failed"; exit 1; } + fi + echo "🎉 Infrastructure deployment completed successfully!" echo " You can now start the core services manually using dev containers" echo " Environment file available at: $(cd .. && pwd)/.env" From 0ef721e3ea27a7949c3efd0c34910e235cfc3fab Mon Sep 17 00:00:00 2001 From: biansimeng Date: Mon, 9 Mar 2026 20:16:02 +0800 Subject: [PATCH 36/75] Unify model list logic to show only availiable LLMs --- .../[locale]/knowledges/components/document/DocumentList.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx index 01c074045..f4cc9c341 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx @@ -283,7 +283,7 @@ const DocumentListContainer = forwardRef( setIsLoadingModels(true); try { const models = await modelService.getLLMModels(); - setAvailableModels(models); + setAvailableModels(models.filter(m => m.connect_status === "available")); // Determine initial selection order: // 1) Knowledge base's own configured model (server-side config) From edef55f9f8095f1fee6e0875303816b056641532 Mon Sep 17 00:00:00 2001 From: panyehong <2655992392@qq.com> Date: Tue, 10 Mar 2026 11:11:52 +0800 Subject: [PATCH 37/75] =?UTF-8?q?=F0=9F=90=9B=20Bugfix:=20Fix=20the=20issu?= =?UTF-8?q?e=20where=20MCP=20services=20with=20the=20same=20tool=20cannot?= =?UTF-8?q?=20be=20displayed.=20#2294=20[Specification=20Details]=201.=20M?= =?UTF-8?q?odify=20the=20backend=20logic=20of=20the=20scan=20and=20update?= =?UTF-8?q?=20tools=20to=20use=20tool=20name,=20source,=20and=20usage=20as?= =?UTF-8?q?=20unique=20identifiers.=202.=20The=20front-end=20should=20prov?= =?UTF-8?q?ide=20a=20prompt=20when=20selecting=20a=20tool=20with=20the=20s?= =?UTF-8?q?ame=20name.=203.=20Add=20test=20cases.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/database/tool_db.py | 30 +- .../components/agentConfig/ToolManagement.tsx | 43 +- frontend/public/locales/en/common.json | 4 + frontend/public/locales/zh/common.json | 4 + test/backend/database/test_tool_db.py | 690 ++++++++++++++++-- 5 files changed, 697 insertions(+), 74 deletions(-) diff --git a/backend/database/tool_db.py b/backend/database/tool_db.py index 0001315a7..0514bc945 100644 --- a/backend/database/tool_db.py +++ b/backend/database/tool_db.py @@ -4,6 +4,7 @@ from database.agent_db import logger from database.client import get_db_session, filter_property, as_dict from database.db_models import ToolInstance, ToolInfo +from consts.model import ToolSourceEnum def create_tool(tool_info, version_no: int = 0): @@ -190,13 +191,23 @@ def check_tool_list_initialized(tenant_id: str) -> bool: def update_tool_table_from_scan_tool_list(tenant_id: str, user_id: str, tool_list: List[ToolInfo]): """ scan all tools and update the tool table in PG database, remove the duplicate tools + For MCP tools, use name&source&usage as unique key to allow same tool name from different MCP servers """ with get_db_session() as session: # get all existing tools (including complete information) existing_tools = session.query(ToolInfo).filter(ToolInfo.delete_flag != 'Y', ToolInfo.author == tenant_id).all() - existing_tool_dict = { - f"{tool.name}&{tool.source}": tool for tool in existing_tools} + # Build existing_tool_dict with different keys for MCP vs non-MCP tools + existing_tool_dict = {} + for tool in existing_tools: + if tool.source == ToolSourceEnum.MCP.value: + # For MCP tools, use name + source + usage (MCP server name) as unique key + key = f"{tool.name}&{tool.source}&{tool.usage or ''}" + else: + # For other tools, use name + source as unique key + key = f"{tool.name}&{tool.source}" + existing_tool_dict[key] = tool + # set all tools to unavailable for tool in existing_tools: tool.is_available = False @@ -208,9 +219,15 @@ def update_tool_table_from_scan_tool_list(tenant_id: str, user_id: str, tool_lis is_available = True if re.match( r'^[a-zA-Z_][a-zA-Z0-9_]*$', tool.name) is not None else False - if f"{tool.name}&{tool.source}" in existing_tool_dict: - # by tool name and source to update the existing tool - existing_tool = existing_tool_dict[f"{tool.name}&{tool.source}"] + # Use same key generation logic as above + if tool.source == ToolSourceEnum.MCP.value: + tool_key = f"{tool.name}&{tool.source}&{tool.usage or ''}" + else: + tool_key = f"{tool.name}&{tool.source}" + + if tool_key in existing_tool_dict: + # by tool name, source, and usage (for MCP) to update the existing tool + existing_tool = existing_tool_dict[tool_key] for key, value in filtered_tool_data.items(): setattr(existing_tool, key, value) existing_tool.updated_by = user_id @@ -308,6 +325,7 @@ def delete_tools_by_agent_id(agent_id, tenant_id, user_id, version_no: int = 0): ToolInstance.delete_flag: 'Y', 'updated_by': user_id }) + def search_last_tool_instance_by_tool_id(tool_id: int, tenant_id: str, user_id: str, version_no: int = 0): """ Query the latest ToolInstance by tool_id. @@ -331,4 +349,4 @@ def search_last_tool_instance_by_tool_id(tool_id: int, tenant_id: str, user_id: ToolInstance.delete_flag != 'Y' ).order_by(ToolInstance.update_time.desc()) tool_instance = query.first() - return as_dict(tool_instance) if tool_instance else None \ No newline at end of file + return as_dict(tool_instance) if tool_instance else None diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index f5815a094..d4eb0e2ac 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -12,6 +12,7 @@ import { usePrefetchKnowledgeBases } from "@/hooks/useKnowledgeBaseSelector"; import { useConfig } from "@/hooks/useConfig"; import { updateToolConfig } from "@/services/agentConfigService"; import { useQueryClient } from "@tanstack/react-query"; +import { useConfirmModal } from "@/hooks/useConfirmModal"; import { Settings, AlertTriangle } from "lucide-react"; @@ -74,6 +75,7 @@ export default function ToolManagement({ }: ToolManagementProps) { const { t } = useTranslation("common"); const queryClient = useQueryClient(); + const { confirm } = useConfirmModal(); // Get current agent permission from store const currentAgentPermission = useAgentConfigStore( @@ -277,7 +279,46 @@ export default function ToolManagement({ ); updateTools(newSelectedTools); } else { - // If not selected, determine tool params and check if modal is needed + // If not selected, check for duplicate tool names first + const duplicateTool = currentSelectdTools.find( + (selectedTool) => selectedTool.name === tool.name + ); + + if (duplicateTool) { + // Show confirmation modal for duplicate tool name + return new Promise((resolve) => { + confirm({ + title: t("toolPool.duplicateToolName.title"), + content: t("toolPool.duplicateToolName.content", { + toolName: tool.name, + }), + okText: t("toolPool.duplicateToolName.confirm"), + cancelText: t("toolPool.duplicateToolName.cancel"), + danger: true, + onOk: async () => { + // User confirmed, proceed with tool selection + await proceedWithToolSelection(); + resolve(); + }, + onCancel: () => { + // User cancelled, do nothing + resolve(); + }, + }); + }); + } + + // No duplicate, proceed with normal tool selection + await proceedWithToolSelection(); + } + + // Helper function to proceed with tool selection after duplicate check + async function proceedWithToolSelection() { + // Get latest tools again to ensure we have the most up-to-date list + const currentSelectdTools = + useAgentConfigStore.getState().editedAgent.tools; + + // Determine tool params and check if modal is needed const configuredTool = currentSelectdTools.find( (t) => parseInt(t.id) === numericId ); diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 986140c83..ef3ac7915 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -439,6 +439,10 @@ "toolPool.vlmDisabledTooltip": "Please contact your administrator to configure an available Vision Language Model", "toolPool.embeddingDisabledTooltip": "Please contact your administrator to configure an available Embedding model", "toolPool.tooltip.functionGuide": "1. For local knowledge base search functionality, please enable the knowledge_base_search tool;\n2. For text file parsing functionality, please enable the analyze_text_file tool;\n3. For image parsing functionality, please enable the analyze_image tool.", + "toolPool.duplicateToolName.title": "Duplicate Tool Name Detected", + "toolPool.duplicateToolName.content": "You have selected tools with the same name ({{toolName}}). Duplicate tool names will cause the agent to fail during runtime. Do you want to continue selecting this tool?", + "toolPool.duplicateToolName.confirm": "Continue", + "toolPool.duplicateToolName.cancel": "Cancel", "tool.message.unavailable": "This tool is currently unavailable and cannot be selected", "tool.error.noMainAgentId": "Main Agent ID is not set, cannot update tool status", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index b830b1792..aaa9a4b54 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -442,6 +442,10 @@ "toolPool.vlmDisabledTooltip": "请联系管理员配置可用的视觉语言模型", "toolPool.embeddingDisabledTooltip": "请联系管理员配置可用的向量模型", "toolPool.tooltip.functionGuide": "1. 本地知识库检索功能,请启用knowledge_base_search工具;\n2. 文本文件解析功能,请启用analyze_text_file工具;\n3. 图片解析功能,请启用analyze_image工具。", + "toolPool.duplicateToolName.title": "检测到重复工具名", + "toolPool.duplicateToolName.content": "您已勾选相同工具名的工具({{toolName}}),重复选择会导致智能体无法正常运行。是否继续勾选?", + "toolPool.duplicateToolName.confirm": "继续", + "toolPool.duplicateToolName.cancel": "取消", "tool.message.unavailable": "该工具当前不可用,无法选择", "tool.error.noMainAgentId": "主代理ID未设置,无法更新工具状态", diff --git a/test/backend/database/test_tool_db.py b/test/backend/database/test_tool_db.py index 604997187..e37f13ffa 100644 --- a/test/backend/database/test_tool_db.py +++ b/test/backend/database/test_tool_db.py @@ -1,3 +1,19 @@ +from backend.database.tool_db import ( + create_tool, + create_or_update_tool_by_tool_info, + query_all_tools, + query_tool_instances_by_id, + query_tool_instances_by_agent_id, + query_tools_by_ids, + query_all_enabled_tool_instances, + update_tool_table_from_scan_tool_list, + add_tool_field, + search_tools_for_sub_agent, + check_tool_is_available, + delete_tools_by_agent_id, + search_last_tool_instance_by_tool_id, + check_tool_list_initialized +) import sys import pytest from unittest.mock import patch, MagicMock @@ -18,14 +34,39 @@ consts_mock.const.POSTGRES_PORT = 5432 consts_mock.const.DEFAULT_TENANT_ID = "default_tenant" +# Mock consts.model module and ToolSourceEnum +# Create a mock ToolSourceEnum that supports .value attribute access + + +class MockEnumMember: + def __init__(self, value): + self.value = value + + +class MockToolSourceEnum: + LOCAL = MockEnumMember("local") + MCP = MockEnumMember("mcp") + LANGCHAIN = MockEnumMember("langchain") + +# Create consts.model as a proper module-like object + + +class MockModelModule: + ToolSourceEnum = MockToolSourceEnum + + +consts_mock.model = MockModelModule() + # Add the mocked consts module to sys.modules sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_mock.const +sys.modules['consts.model'] = consts_mock.model # Mock utils module utils_mock = MagicMock() utils_mock.auth_utils = MagicMock() -utils_mock.auth_utils.get_current_user_id_from_token = MagicMock(return_value="test_user_id") +utils_mock.auth_utils.get_current_user_id_from_token = MagicMock( + return_value="test_user_id") # Add the mocked utils module to sys.modules sys.modules['utils'] = utils_mock @@ -67,22 +108,7 @@ sys.modules['backend.database.agent_db'] = agent_db_mock # Now we can safely import the module being tested -from backend.database.tool_db import ( - create_tool, - create_or_update_tool_by_tool_info, - query_all_tools, - query_tool_instances_by_id, - query_tool_instances_by_agent_id, - query_tools_by_ids, - query_all_enabled_tool_instances, - update_tool_table_from_scan_tool_list, - add_tool_field, - search_tools_for_sub_agent, - check_tool_is_available, - delete_tools_by_agent_id, - search_last_tool_instance_by_tool_id, - check_tool_list_initialized -) + class MockToolInstance: def __init__(self): @@ -103,6 +129,7 @@ def __init__(self): "delete_flag": "N" } + class MockToolInfo: def __init__(self): self.tool_id = 1 @@ -132,6 +159,7 @@ def __init__(self): "class_name": "TestTool" } + @pytest.fixture def mock_session(): """Create a mock database session""" @@ -140,6 +168,7 @@ def mock_session(): mock_session.query.return_value = mock_query return mock_session, mock_query + def test_create_tool_success(monkeypatch, mock_session): """Test successful tool creation""" session, query = mock_session @@ -148,15 +177,19 @@ def test_create_tool_success(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.filter_property", lambda data, model: data) - monkeypatch.setattr("backend.database.tool_db.ToolInstance", lambda **kwargs: MagicMock()) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + monkeypatch.setattr("backend.database.tool_db.ToolInstance", + lambda **kwargs: MagicMock()) tool_info = {"tool_id": 1, "agent_id": 1, "tenant_id": "tenant1"} create_tool(tool_info) session.add.assert_called_once() + def test_create_or_update_tool_by_tool_info_update_existing(monkeypatch, mock_session): """Test updating an existing tool instance""" session, query = mock_session @@ -171,7 +204,8 @@ def test_create_or_update_tool_by_tool_info_update_existing(monkeypatch, mock_se mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) tool_info = MagicMock() tool_info.__dict__ = {"agent_id": 1, "tool_id": 1} @@ -180,6 +214,7 @@ def test_create_or_update_tool_by_tool_info_update_existing(monkeypatch, mock_se assert result == mock_tool_instance + def test_create_or_update_tool_by_tool_info_create_new(monkeypatch, mock_session): """Test creating a new tool instance""" session, query = mock_session @@ -192,7 +227,8 @@ def test_create_or_update_tool_by_tool_info_create_new(monkeypatch, mock_session mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) monkeypatch.setattr("backend.database.tool_db.create_tool", MagicMock()) tool_info = MagicMock() @@ -202,6 +238,7 @@ def test_create_or_update_tool_by_tool_info_create_new(monkeypatch, mock_session assert result is None + def test_query_all_tools(monkeypatch, mock_session): """Test querying all tools""" session, query = mock_session @@ -216,8 +253,10 @@ def test_query_all_tools(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = query_all_tools("tenant1") @@ -225,6 +264,7 @@ def test_query_all_tools(monkeypatch, mock_session): assert result[0]["tool_id"] == 1 assert result[0]["name"] == "test_tool" + def test_query_tool_instances_by_id_found(monkeypatch, mock_session): """Test successfully querying tool instances""" session, query = mock_session @@ -239,14 +279,17 @@ def test_query_tool_instances_by_id_found(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = query_tool_instances_by_id(1, 1, "tenant1") assert result["tool_instance_id"] == 1 assert result["tool_id"] == 1 + def test_query_tool_instances_by_id_not_found(monkeypatch, mock_session): """Test querying non-existent tool instances""" session, query = mock_session @@ -259,12 +302,14 @@ def test_query_tool_instances_by_id_not_found(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) result = query_tool_instances_by_id(1, 1, "tenant1") assert result is None + def test_query_tools_by_ids(monkeypatch, mock_session): """Test querying tools by ID list""" session, query = mock_session @@ -281,14 +326,17 @@ def test_query_tools_by_ids(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = query_tools_by_ids([1, 2]) assert len(result) == 1 assert result[0]["tool_id"] == 1 + def test_query_all_enabled_tool_instances(monkeypatch, mock_session): """Test querying all enabled tool instances""" session, query = mock_session @@ -303,14 +351,17 @@ def test_query_all_enabled_tool_instances(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = query_all_enabled_tool_instances(1, "tenant1") assert len(result) == 1 assert result[0]["tool_instance_id"] == 1 + def test_update_tool_table_from_scan_tool_list_success(monkeypatch, mock_session): """Test successfully updating tool table""" session, query = mock_session @@ -327,8 +378,10 @@ def test_update_tool_table_from_scan_tool_list_success(monkeypatch, mock_session mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.filter_property", lambda data, model: data) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) # Create a mock for ToolInfo class with properly accessible attributes mock_tool_info_class = MagicMock() @@ -336,13 +389,15 @@ def test_update_tool_table_from_scan_tool_list_success(monkeypatch, mock_session mock_tool_info_class.author = "tenant1" mock_tool_info_class.name = "test_tool" mock_tool_info_class.source = "test_source" - monkeypatch.setattr("backend.database.tool_db.ToolInfo", mock_tool_info_class) + monkeypatch.setattr("backend.database.tool_db.ToolInfo", + mock_tool_info_class) tool_list = [MockToolInfo()] update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) # Function executes successfully without throwing exceptions + def test_update_tool_table_from_scan_tool_list_create_new_tool(monkeypatch, mock_session): """Test creating new tool when tool doesn't exist in database""" session, query = mock_session @@ -363,13 +418,16 @@ def test_update_tool_table_from_scan_tool_list_create_new_tool(monkeypatch, mock mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.filter_property", lambda data, model: data) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) # Create a mock for ToolInfo class constructor mock_tool_info_instance = MagicMock() mock_tool_info_class = MagicMock(return_value=mock_tool_info_instance) - monkeypatch.setattr("backend.database.tool_db.ToolInfo", mock_tool_info_class) + monkeypatch.setattr("backend.database.tool_db.ToolInfo", + mock_tool_info_class) # Create a new tool with different name&source that doesn't exist in database new_tool = MockToolInfo() @@ -391,6 +449,7 @@ def test_update_tool_table_from_scan_tool_list_create_new_tool(monkeypatch, mock }) mock_tool_info_class.assert_called_once_with(**expected_call_args) + def test_update_tool_table_from_scan_tool_list_create_new_tool_invalid_name(monkeypatch, mock_session): """Test creating new tool with invalid name (is_available=False)""" session, query = mock_session @@ -411,13 +470,16 @@ def test_update_tool_table_from_scan_tool_list_create_new_tool_invalid_name(monk mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.filter_property", lambda data, model: data) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) # Create a mock for ToolInfo class constructor mock_tool_info_instance = MagicMock() mock_tool_info_class = MagicMock(return_value=mock_tool_info_instance) - monkeypatch.setattr("backend.database.tool_db.ToolInfo", mock_tool_info_class) + monkeypatch.setattr("backend.database.tool_db.ToolInfo", + mock_tool_info_class) # Create a new tool with invalid name (contains special characters) new_tool = MockToolInfo() @@ -439,6 +501,466 @@ def test_update_tool_table_from_scan_tool_list_create_new_tool_invalid_name(monk }) mock_tool_info_class.assert_called_once_with(**expected_call_args) + +def test_update_tool_table_mcp_tools_same_name_different_usage(monkeypatch, mock_session): + """Test MCP tools with same name but different usage (MCP server) should be treated as different tools""" + session, query = mock_session + + # Mock existing tools - one MCP tool from server1 + existing_tool = MockToolInfo() + existing_tool.name = "get_tickets" + existing_tool.source = "mcp" + existing_tool.usage = "mcp_server_1" + + mock_all = MagicMock() + mock_all.return_value = [existing_tool] + mock_filter = MagicMock() + mock_filter.all = mock_all + query.filter.return_value = mock_filter + + session.add = MagicMock() + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Create a mock for ToolInfo class constructor + mock_tool_info_instance = MagicMock() + mock_tool_info_class = MagicMock(return_value=mock_tool_info_instance) + monkeypatch.setattr("backend.database.tool_db.ToolInfo", + mock_tool_info_class) + + # Create a new MCP tool with same name but different usage (different MCP server) + new_tool = MockToolInfo() + new_tool.name = "get_tickets" + new_tool.source = "mcp" + new_tool.usage = "mcp_server_2" # Different MCP server + tool_list = [new_tool] + + update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) + + # Verify that session.add was called to add the new tool (different usage = different tool) + session.add.assert_called_once_with(mock_tool_info_instance) + # Verify that ToolInfo constructor was called with correct parameters + expected_call_args = new_tool.__dict__.copy() + expected_call_args.update({ + "created_by": "user1", + "updated_by": "user1", + "author": "tenant1", + "is_available": True + }) + mock_tool_info_class.assert_called_once_with(**expected_call_args) + + +def test_update_tool_table_mcp_tools_same_name_same_usage(monkeypatch, mock_session): + """Test MCP tools with same name and same usage should update existing tool""" + session, query = mock_session + + # Mock existing MCP tool + existing_tool = MockToolInfo() + existing_tool.name = "get_tickets" + existing_tool.source = "mcp" + existing_tool.usage = "mcp_server_1" + existing_tool.description = "old description" + existing_tool.is_available = True + + mock_all = MagicMock() + mock_all.return_value = [existing_tool] + mock_filter = MagicMock() + mock_filter.all = mock_all + query.filter.return_value = mock_filter + + session.add = MagicMock() + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Create a new MCP tool with same name and same usage (should update existing) + new_tool = MockToolInfo() + new_tool.name = "get_tickets" + new_tool.source = "mcp" + new_tool.usage = "mcp_server_1" # Same MCP server + new_tool.description = "new description" + tool_list = [new_tool] + + update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) + + # Verify that session.add was NOT called (tool should be updated, not created) + session.add.assert_not_called() + # Verify that existing tool was updated + assert existing_tool.description == "new description" + assert existing_tool.updated_by == "user1" + assert existing_tool.is_available is True + + +def test_update_tool_table_mcp_tools_empty_usage(monkeypatch, mock_session): + """Test MCP tools with empty/null usage should be handled correctly""" + session, query = mock_session + + # Mock existing MCP tool with empty usage + existing_tool = MockToolInfo() + existing_tool.name = "get_tickets" + existing_tool.source = "mcp" + existing_tool.usage = None # Empty usage + + mock_all = MagicMock() + mock_all.return_value = [existing_tool] + mock_filter = MagicMock() + mock_filter.all = mock_all + query.filter.return_value = mock_filter + + session.add = MagicMock() + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Create a mock for ToolInfo class constructor + mock_tool_info_instance = MagicMock() + mock_tool_info_class = MagicMock(return_value=mock_tool_info_instance) + monkeypatch.setattr("backend.database.tool_db.ToolInfo", + mock_tool_info_class) + + # Create a new MCP tool with same name and empty usage (should update existing) + new_tool = MockToolInfo() + new_tool.name = "get_tickets" + new_tool.source = "mcp" + new_tool.usage = "" # Empty usage (same as None) + tool_list = [new_tool] + + update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) + + # Verify that session.add was NOT called (tool should be updated, not created) + session.add.assert_not_called() + # Verify that existing tool was updated + assert existing_tool.updated_by == "user1" + + +def test_update_tool_table_non_mcp_tools_use_name_source(monkeypatch, mock_session): + """Test non-MCP tools should still use name&source as unique key""" + session, query = mock_session + + # Mock existing non-MCP tool + existing_tool = MockToolInfo() + existing_tool.name = "test_tool" + existing_tool.source = "local" + existing_tool.usage = "some_usage" # Usage should be ignored for non-MCP tools + + mock_all = MagicMock() + mock_all.return_value = [existing_tool] + mock_filter = MagicMock() + mock_filter.all = mock_all + query.filter.return_value = mock_filter + + session.add = MagicMock() + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Create a new non-MCP tool with same name and source but different usage + new_tool = MockToolInfo() + new_tool.name = "test_tool" + new_tool.source = "local" + # Different usage, but should still match existing tool + new_tool.usage = "different_usage" + tool_list = [new_tool] + + update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) + + # Verify that session.add was NOT called (tool should be updated, not created) + # because non-MCP tools use name&source as unique key, ignoring usage + session.add.assert_not_called() + # Verify that existing tool was updated + assert existing_tool.updated_by == "user1" + + +def test_update_tool_table_mcp_tools_multiple_different_servers(monkeypatch, mock_session): + """Test multiple MCP tools from different servers with same name should all be created""" + session, query = mock_session + + # Mock existing MCP tool from server1 + existing_tool = MockToolInfo() + existing_tool.name = "get_tickets" + existing_tool.source = "mcp" + existing_tool.usage = "mcp_server_1" + + mock_all = MagicMock() + mock_all.return_value = [existing_tool] + mock_filter = MagicMock() + mock_filter.all = mock_all + query.filter.return_value = mock_filter + + session.add = MagicMock() + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Create a mock for ToolInfo class constructor + mock_tool_info_instance = MagicMock() + mock_tool_info_class = MagicMock(return_value=mock_tool_info_instance) + monkeypatch.setattr("backend.database.tool_db.ToolInfo", + mock_tool_info_class) + + # Create two new MCP tools with same name but different usage (different servers) + new_tool1 = MockToolInfo() + new_tool1.name = "get_tickets" + new_tool1.source = "mcp" + new_tool1.usage = "mcp_server_2" # Different server + + new_tool2 = MockToolInfo() + new_tool2.name = "get_tickets" + new_tool2.source = "mcp" + new_tool2.usage = "mcp_server_3" # Another different server + + tool_list = [new_tool1, new_tool2] + + update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) + + # Verify that session.add was called twice (one for each new tool) + assert session.add.call_count == 2 + + +def test_update_tool_table_mixed_mcp_and_non_mcp_tools(monkeypatch, mock_session): + """Test mixed scenario with both MCP and non-MCP tools""" + session, query = mock_session + + # Mock existing tools: one MCP tool and one non-MCP tool + existing_mcp_tool = MockToolInfo() + existing_mcp_tool.name = "get_tickets" + existing_mcp_tool.source = "mcp" + existing_mcp_tool.usage = "mcp_server_1" + + existing_local_tool = MockToolInfo() + existing_local_tool.name = "local_tool" + existing_local_tool.source = "local" + existing_local_tool.usage = "some_usage" + + mock_all = MagicMock() + mock_all.return_value = [existing_mcp_tool, existing_local_tool] + mock_filter = MagicMock() + mock_filter.all = mock_all + query.filter.return_value = mock_filter + + session.add = MagicMock() + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Create a mock for ToolInfo class constructor + mock_tool_info_instance = MagicMock() + mock_tool_info_class = MagicMock(return_value=mock_tool_info_instance) + monkeypatch.setattr("backend.database.tool_db.ToolInfo", + mock_tool_info_class) + + # Create tools: update existing MCP tool, update existing local tool, create new MCP tool + update_mcp_tool = MockToolInfo() + update_mcp_tool.name = "get_tickets" + update_mcp_tool.source = "mcp" + update_mcp_tool.usage = "mcp_server_1" # Same as existing, should update + + update_local_tool = MockToolInfo() + update_local_tool.name = "local_tool" + update_local_tool.source = "local" # Same as existing, should update + + new_mcp_tool = MockToolInfo() + new_mcp_tool.name = "get_tickets" + new_mcp_tool.source = "mcp" + new_mcp_tool.usage = "mcp_server_2" # Different server, should create + + tool_list = [update_mcp_tool, update_local_tool, new_mcp_tool] + + update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) + + # Verify that session.add was called once (only for the new MCP tool) + assert session.add.call_count == 1 + # Verify that existing tools were updated + assert existing_mcp_tool.updated_by == "user1" + assert existing_local_tool.updated_by == "user1" + + +def test_update_tool_table_mcp_tool_update_existing_attributes(monkeypatch, mock_session): + """Test that updating existing MCP tool properly updates all attributes""" + session, query = mock_session + + # Mock existing MCP tool + existing_tool = MockToolInfo() + existing_tool.name = "get_tickets" + existing_tool.source = "mcp" + existing_tool.usage = "mcp_server_1" + existing_tool.description = "old description" + existing_tool.params = [{"name": "old_param"}] + existing_tool.is_available = True + + mock_all = MagicMock() + mock_all.return_value = [existing_tool] + mock_filter = MagicMock() + mock_filter.all = mock_all + query.filter.return_value = mock_filter + + session.add = MagicMock() + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Create updated MCP tool with same name and usage + updated_tool = MockToolInfo() + updated_tool.name = "get_tickets" + updated_tool.source = "mcp" + updated_tool.usage = "mcp_server_1" + updated_tool.description = "new description" + updated_tool.params = [{"name": "new_param"}] + tool_list = [updated_tool] + + update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) + + # Verify that session.add was NOT called (tool should be updated, not created) + session.add.assert_not_called() + # Verify that existing tool attributes were updated + assert existing_tool.description == "new description" + assert existing_tool.params == [{"name": "new_param"}] + assert existing_tool.updated_by == "user1" + assert existing_tool.is_available is True + + +def test_update_tool_table_existing_tools_set_unavailable(monkeypatch, mock_session): + """Test that all existing tools are set to unavailable before processing tool list""" + session, query = mock_session + + # Mock multiple existing tools + existing_tool1 = MockToolInfo() + existing_tool1.name = "tool1" + existing_tool1.source = "local" + existing_tool1.is_available = True + + existing_tool2 = MockToolInfo() + existing_tool2.name = "get_tickets" + existing_tool2.source = "mcp" + existing_tool2.usage = "mcp_server_1" + existing_tool2.is_available = True + + mock_all = MagicMock() + mock_all.return_value = [existing_tool1, existing_tool2] + mock_filter = MagicMock() + mock_filter.all = mock_all + query.filter.return_value = mock_filter + + session.add = MagicMock() + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Create a mock for ToolInfo class constructor + mock_tool_info_instance = MagicMock() + mock_tool_info_class = MagicMock(return_value=mock_tool_info_instance) + monkeypatch.setattr("backend.database.tool_db.ToolInfo", + mock_tool_info_class) + + # Create tool list with only one tool (tool2 will be updated, tool1 will remain unavailable) + updated_tool = MockToolInfo() + updated_tool.name = "get_tickets" + updated_tool.source = "mcp" + updated_tool.usage = "mcp_server_1" + tool_list = [updated_tool] + + update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) + + # Verify that existing_tool1 is set to unavailable (not in tool_list) + assert existing_tool1.is_available is False + # Verify that existing_tool2 is set to available (updated from tool_list) + assert existing_tool2.is_available is True + + +def test_update_tool_table_mcp_tool_invalid_name(monkeypatch, mock_session): + """Test MCP tool with invalid name should set is_available=False""" + session, query = mock_session + + # Mock existing tools + existing_tool = MockToolInfo() + existing_tool.name = "existing_tool" + existing_tool.source = "local" + + mock_all = MagicMock() + mock_all.return_value = [existing_tool] + mock_filter = MagicMock() + mock_filter.all = mock_all + query.filter.return_value = mock_filter + + session.add = MagicMock() + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Create a mock for ToolInfo class constructor + mock_tool_info_instance = MagicMock() + mock_tool_info_class = MagicMock(return_value=mock_tool_info_instance) + monkeypatch.setattr("backend.database.tool_db.ToolInfo", + mock_tool_info_class) + + # Create a new MCP tool with invalid name (contains special characters) + new_tool = MockToolInfo() + new_tool.name = "invalid-tool-name!" # Contains dash and exclamation mark + new_tool.source = "mcp" + new_tool.usage = "mcp_server_1" + tool_list = [new_tool] + + update_tool_table_from_scan_tool_list("tenant1", "user1", tool_list) + + # Verify that session.add was called to add the new tool + session.add.assert_called_once_with(mock_tool_info_instance) + # Verify that ToolInfo constructor was called with is_available=False for invalid name + expected_call_args = new_tool.__dict__.copy() + expected_call_args.update({ + "created_by": "user1", + "updated_by": "user1", + "author": "tenant1", + "is_available": False # Should be False for invalid tool name + }) + mock_tool_info_class.assert_called_once_with(**expected_call_args) + + def test_add_tool_field(monkeypatch, mock_session): """Test adding tool field""" session, query = mock_session @@ -453,8 +975,10 @@ def test_add_tool_field(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) tool_info = {"tool_id": 1, "params": {"param1": "value1"}} result = add_tool_field(tool_info) @@ -463,6 +987,7 @@ def test_add_tool_field(monkeypatch, mock_session): assert result["description"] == "test description" assert result["source"] == "test_source" + def test_search_tools_for_sub_agent(monkeypatch, mock_session): """Test searching tools for sub-agent""" session, query = mock_session @@ -477,15 +1002,19 @@ def test_search_tools_for_sub_agent(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) - monkeypatch.setattr("backend.database.tool_db.add_tool_field", lambda data: data) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.add_tool_field", lambda data: data) result = search_tools_for_sub_agent(1, "tenant1") assert len(result) == 1 assert result[0]["tool_instance_id"] == 1 + def test_check_tool_is_available(monkeypatch, mock_session): """Test checking if tool is available""" session, query = mock_session @@ -499,12 +1028,14 @@ def test_check_tool_is_available(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) result = check_tool_is_available([1, 2]) assert result == [True] + def test_delete_tools_by_agent_id_success(monkeypatch, mock_session): """Test successfully deleting agent's tools""" session, query = mock_session @@ -516,7 +1047,8 @@ def test_delete_tools_by_agent_id_success(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) # Function returns no value, only verify successful execution delete_tools_by_agent_id(1, "tenant1", "user1") @@ -542,8 +1074,10 @@ def test_search_last_tool_instance_by_tool_id_found(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = search_last_tool_instance_by_tool_id(1, "tenant1", "user1") @@ -551,6 +1085,7 @@ def test_search_last_tool_instance_by_tool_id_found(monkeypatch, mock_session): assert result["tool_id"] == 1 assert result["params"] == {"param1": "value1", "param2": "value2"} + def test_search_last_tool_instance_by_tool_id_not_found(monkeypatch, mock_session): """Test searching for non-existent last tool instance""" session, query = mock_session @@ -565,12 +1100,14 @@ def test_search_last_tool_instance_by_tool_id_not_found(monkeypatch, mock_sessio mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) result = search_last_tool_instance_by_tool_id(999, "tenant1", "user1") assert result is None + def test_search_last_tool_instance_by_tool_id_with_deleted_flag(monkeypatch, mock_session): """Test searching for tool instance with deleted flag filter""" session, query = mock_session @@ -588,8 +1125,10 @@ def test_search_last_tool_instance_by_tool_id_with_deleted_flag(monkeypatch, moc mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = search_last_tool_instance_by_tool_id(1, "tenant1", "user1") @@ -597,6 +1136,7 @@ def test_search_last_tool_instance_by_tool_id_with_deleted_flag(monkeypatch, moc # Verify that the filter was called with correct parameters assert query.filter.call_count == 1 + def test_search_last_tool_instance_by_tool_id_ordering(monkeypatch, mock_session): """Test that results are ordered by update_time desc""" session, query = mock_session @@ -613,8 +1153,10 @@ def test_search_last_tool_instance_by_tool_id_ordering(monkeypatch, mock_session mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = search_last_tool_instance_by_tool_id(1, "tenant1", "user1") @@ -622,6 +1164,7 @@ def test_search_last_tool_instance_by_tool_id_ordering(monkeypatch, mock_session mock_filter.order_by.assert_called_once() assert result is not None + def test_search_last_tool_instance_by_tool_id_different_tenants(monkeypatch, mock_session): """Test searching with different tenant and user IDs""" session, query = mock_session @@ -640,8 +1183,10 @@ def test_search_last_tool_instance_by_tool_id_different_tenants(monkeypatch, moc mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = search_last_tool_instance_by_tool_id(1, "tenant2", "user2") @@ -665,8 +1210,10 @@ def test_query_tool_instances_by_agent_id(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = query_tool_instances_by_agent_id(agent_id=1, tenant_id="tenant1") @@ -687,8 +1234,10 @@ def test_query_tool_instances_by_agent_id_empty(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) result = query_tool_instances_by_agent_id(agent_id=1, tenant_id="tenant1") @@ -709,10 +1258,13 @@ def test_query_tool_instances_by_agent_id_with_version(monkeypatch, mock_session mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.as_dict", lambda obj: obj.__dict__) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.tool_db.as_dict", + lambda obj: obj.__dict__) - result = query_tool_instances_by_agent_id(agent_id=1, tenant_id="tenant1", version_no=2) + result = query_tool_instances_by_agent_id( + agent_id=1, tenant_id="tenant1", version_no=2) assert len(result) == 1 assert result[0]["tool_id"] == 1 @@ -730,7 +1282,8 @@ def test_check_tool_list_initialized_has_tools(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) result = check_tool_list_initialized("tenant1") @@ -750,7 +1303,8 @@ def test_check_tool_list_initialized_no_tools(monkeypatch, mock_session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) result = check_tool_list_initialized("new_tenant") @@ -770,7 +1324,8 @@ def test_check_tool_list_initialized_with_deleted_tools_only(monkeypatch, mock_s mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) result = check_tool_list_initialized("tenant_with_only_deleted_tools") @@ -789,7 +1344,8 @@ def test_check_tool_list_initialized_correct_tenant_filter(monkeypatch, mock_ses mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.tool_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.tool_db.get_db_session", lambda: mock_ctx) target_tenant = "specific_tenant_id" check_tool_list_initialized(target_tenant) From 9481891eb7ce95aca60f0a99af22e648fb407ce5 Mon Sep 17 00:00:00 2001 From: panyehong <2655992392@qq.com> Date: Tue, 10 Mar 2026 12:38:36 +0800 Subject: [PATCH 38/75] =?UTF-8?q?=F0=9F=90=9B=20Bugfix:=20Fix=20the=20issu?= =?UTF-8?q?e=20where=20MCP=20services=20with=20the=20same=20tool=20cannot?= =?UTF-8?q?=20be=20displayed.=20#2294=20[Specification=20Details]=201.=20M?= =?UTF-8?q?odify=20test=20cases.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/database/test_tool_db.py | 33 +++++++++++++-------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/test/backend/database/test_tool_db.py b/test/backend/database/test_tool_db.py index e37f13ffa..a5e13dfcb 100644 --- a/test/backend/database/test_tool_db.py +++ b/test/backend/database/test_tool_db.py @@ -1,19 +1,3 @@ -from backend.database.tool_db import ( - create_tool, - create_or_update_tool_by_tool_info, - query_all_tools, - query_tool_instances_by_id, - query_tool_instances_by_agent_id, - query_tools_by_ids, - query_all_enabled_tool_instances, - update_tool_table_from_scan_tool_list, - add_tool_field, - search_tools_for_sub_agent, - check_tool_is_available, - delete_tools_by_agent_id, - search_last_tool_instance_by_tool_id, - check_tool_list_initialized -) import sys import pytest from unittest.mock import patch, MagicMock @@ -108,7 +92,22 @@ class MockModelModule: sys.modules['backend.database.agent_db'] = agent_db_mock # Now we can safely import the module being tested - +from backend.database.tool_db import ( + create_tool, + create_or_update_tool_by_tool_info, + query_all_tools, + query_tool_instances_by_id, + query_tool_instances_by_agent_id, + query_tools_by_ids, + query_all_enabled_tool_instances, + update_tool_table_from_scan_tool_list, + add_tool_field, + search_tools_for_sub_agent, + check_tool_is_available, + delete_tools_by_agent_id, + search_last_tool_instance_by_tool_id, + check_tool_list_initialized +) class MockToolInstance: def __init__(self): From b98d6faed1dd2a7b88d217e74a69c27a61c38749 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 10 Mar 2026 12:51:38 +0800 Subject: [PATCH 39/75] Bugfix: cookie need to be send evne through http --- frontend/server.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/server.js b/frontend/server.js index 798b019fd..05f098402 100644 --- a/frontend/server.js +++ b/frontend/server.js @@ -47,7 +47,7 @@ const isProduction = process.env.NODE_ENV === "production"; function buildCookieOptions(httpOnly) { return { httpOnly, - secure: isProduction, + secure: false, // cookie can be send through http sameSite: "lax", path: "/", }; From 38339d53ecac41f0c5f8e6b30d078c5a91841310 Mon Sep 17 00:00:00 2001 From: WMC001 <46217886+WMC001@users.noreply.github.com> Date: Tue, 10 Mar 2026 14:13:48 +0800 Subject: [PATCH 40/75] =?UTF-8?q?=F0=9F=90=9B=20Bugfix:=20Failed=20to=20mo?= =?UTF-8?q?dify=20Tenant=20Name?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/UserManageComp.tsx | 58 +++++++++++++++---- .../components/resources/GroupList.tsx | 16 +++++ .../components/resources/ModelList.tsx | 27 +++++---- .../components/resources/UserList.tsx | 15 +++++ 4 files changed, 95 insertions(+), 21 deletions(-) diff --git a/frontend/app/[locale]/tenant-resources/components/UserManageComp.tsx b/frontend/app/[locale]/tenant-resources/components/UserManageComp.tsx index ffa68f62b..7c14318d1 100644 --- a/frontend/app/[locale]/tenant-resources/components/UserManageComp.tsx +++ b/frontend/app/[locale]/tenant-resources/components/UserManageComp.tsx @@ -1,6 +1,7 @@ "use client"; import React, { useState, useEffect, useRef } from "react"; +import { useQuery } from "@tanstack/react-query"; import { Row, Col, @@ -27,6 +28,7 @@ import { updateTenant, deleteTenant, getTenantUsers, + getTenant, } from "@/services/tenantService"; import { createInvitation, deleteInvitation } from "@/services/invitationService"; import { authService } from "@/services/authService"; @@ -562,6 +564,31 @@ export default function UserManageComp() { refetch: refetchTenants, } = useTenantList({ page: currentPage, page_size: DEFAULT_PAGE_SIZE }); + // For non-super admins, automatically select their own tenant based on user.tenantId + // This must be declared before useQuery that uses tenantId + const [tenantId, setTenantId] = useState(null); + useEffect(() => { + if (!isSuperAdmin && user?.tenantId && !tenantId) { + setTenantId(user.tenantId); + } + }, [isSuperAdmin, tenantId, user?.tenantId]); + + // For non-super-admin users, directly fetch their tenant details + // This ensures they always get the correct tenant info regardless of pagination + const { + data: directTenantData, + isLoading: directTenantLoading, + refetch: refetchDirectTenant, + } = useQuery({ + queryKey: ["tenant", tenantId], + queryFn: async () => { + if (!tenantId || isSuperAdmin) return null; + return await getTenant(tenantId); + }, + enabled: !!tenantId && !isSuperAdmin, + staleTime: 1000 * 60, // Cache for 1 minute + }); + // Handle page change const handlePageChange = (page: number) => { setCurrentPage(page); @@ -583,17 +610,21 @@ export default function UserManageComp() { // Invitation list refresh key - increment to trigger invitation list refetch const [invitationListRefreshKey, setInvitationListRefreshKey] = useState(0); - // For non-super admins, automatically select their own tenant based on user.tenantId - const [tenantId, setTenantId] = useState(null); - useEffect(() => { - if (!isSuperAdmin && user?.tenantId && !tenantId) { - setTenantId(user.tenantId); - } - }, [isSuperAdmin, tenantId, user?.tenantId]); - // Get current tenant name - const currentTenant = tenantData?.data?.find((t: Tenant) => t.tenant_id === tenantId); - const currentTenantName = currentTenant?.tenant_name || t("tenantResources.tenants.unnamed"); + // For non-super-admin: use directly fetched tenant data (directTenantData) + // For super-admin: use paginated tenant list (tenantData) + let currentTenant: Tenant | undefined; + let currentTenantName: string; + + if (!isSuperAdmin && directTenantData) { + // Non-super-admin: use directly fetched tenant info + currentTenant = directTenantData; + currentTenantName = directTenantData.tenant_name || t("tenantResources.tenants.unnamed"); + } else { + // Super-admin: search in paginated list + currentTenant = tenantData?.data?.find((t: Tenant) => t.tenant_id === tenantId); + currentTenantName = currentTenant?.tenant_name || t("tenantResources.tenants.unnamed"); + } // Tenant name editing states const [isEditingTenantName, setIsEditingTenantName] = useState(false); @@ -625,7 +656,12 @@ export default function UserManageComp() { } try { await updateTenant(tenantId, { tenant_name: trimmedName }); - await refetchTenants(); + // For non-super-admin, refetch the direct tenant data; for super-admin, refetch the list + if (!isSuperAdmin) { + await refetchDirectTenant(); + } else { + await refetchTenants(); + } message.success(t("tenantResources.tenants.updated")); setIsEditingTenantName(false); } catch (error) { diff --git a/frontend/app/[locale]/tenant-resources/components/resources/GroupList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/GroupList.tsx index cf9843889..47aca4334 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/GroupList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/GroupList.tsx @@ -13,6 +13,8 @@ import { message, Select, } from "antd"; +import type { TablePaginationConfig } from "antd"; +import { FilterValue, SorterResult } from "antd/es/table/interface"; import { Edit, Trash2 } from "lucide-react"; import { Tooltip } from "@/components/ui/tooltip"; import { ColumnsType } from "antd/es/table"; @@ -211,6 +213,20 @@ export default function GroupList({ tenantId }: { tenantId: string | null }) { } }; + // Handle pagination change + const handlePageChange = ( + pagination: TablePaginationConfig, + _filters: Record, + _sorter: SorterResult | SorterResult[] + ) => { + const newPage = pagination.current || 1; + const newPageSize = pagination.pageSize || 10; + setPage(newPage); + if (newPageSize !== pageSize) { + setPageSize(newPageSize); + } + }; + const columns: ColumnsType = useMemo( () => [ { title: t("tenantResources.groups.name"), dataIndex: "group_name", key: "group_name" }, diff --git a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx index f3abbe011..6de719127 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx @@ -6,6 +6,8 @@ import { Table, Button, Popconfirm, message, Tag, Pagination } from "antd"; import { Edit, Trash2, RefreshCw } from "lucide-react"; import { Tooltip } from "@/components/ui/tooltip"; import { ColumnsType } from "antd/es/table"; +import type { TablePaginationConfig } from "antd"; +import { FilterValue, SorterResult } from "antd/es/table/interface"; import { useManageTenantModels } from "@/hooks/model/useManageTenantModels"; import { modelService } from "@/services/modelService"; import { type ModelOption, type ModelType } from "@/types/modelConfig"; @@ -121,11 +123,16 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { }; // Handle pagination change - const handlePageChange = (newPage: number, newPageSize: number) => { + const handlePageChange = ( + pagination: TablePaginationConfig, + _filters: Record, + _sorter: SorterResult | SorterResult[] + ) => { + const newPage = pagination.current || 1; + const newPageSize = pagination.pageSize || 10; setPage(newPage); if (newPageSize !== pageSize) { setPageSize(newPageSize); - setPage(1); } }; @@ -135,13 +142,8 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { title: t("common.name"), dataIndex: "displayName", key: "displayName", - width: 170, - render: (text: string, record: ModelOption) => ( -
-
{text || record.name}
-
{record.name}
-
- ), + width: 200, + ellipsis: true, }, { title: t("common.type"), @@ -245,7 +247,12 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { dataSource={models} loading={isLoading} rowKey="id" - pagination={{ pageSize: 10 }} + pagination={{ + current: page, + pageSize: pageSize, + total: total + }} + onChange={handlePageChange} scroll={{ x: true }} className="flex-1" /> diff --git a/frontend/app/[locale]/tenant-resources/components/resources/UserList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/UserList.tsx index 8e7438a5c..59a77c47f 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/UserList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/UserList.tsx @@ -13,6 +13,8 @@ import { message, Tag, } from "antd"; +import type { TablePaginationConfig } from "antd"; +import { FilterValue, SorterResult } from "antd/es/table/interface"; import { Edit, Trash2 } from "lucide-react"; import { Tooltip } from "@/components/ui/tooltip"; import { ColumnsType } from "antd/es/table"; @@ -135,6 +137,19 @@ export default function UserList({ tenantId, refreshKey }: { tenantId: string | } }; + const handlePageChange = ( + pagination: TablePaginationConfig, + _filters: Record, + _sorter: SorterResult | SorterResult[] + ) => { + const newPage = pagination.current || 1; + const newPageSize = pagination.pageSize || 10; + setPage(newPage); + if (newPageSize !== pageSize) { + setPageSize(newPageSize); + } + }; + const columns: ColumnsType = useMemo( () => [ { From f858d331cee11207d2fd0cebb6a9c41fefecd1af Mon Sep 17 00:00:00 2001 From: WMC001 <46217886+WMC001@users.noreply.github.com> Date: Tue, 10 Mar 2026 14:27:10 +0800 Subject: [PATCH 41/75] =?UTF-8?q?=F0=9F=90=9B=20Bugfix:=20Failed=20to=20mo?= =?UTF-8?q?dify=20Tenant=20Name=202?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/resources/GroupList.tsx | 16 ---------------- .../components/resources/UserList.tsx | 15 --------------- 2 files changed, 31 deletions(-) diff --git a/frontend/app/[locale]/tenant-resources/components/resources/GroupList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/GroupList.tsx index 47aca4334..cf9843889 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/GroupList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/GroupList.tsx @@ -13,8 +13,6 @@ import { message, Select, } from "antd"; -import type { TablePaginationConfig } from "antd"; -import { FilterValue, SorterResult } from "antd/es/table/interface"; import { Edit, Trash2 } from "lucide-react"; import { Tooltip } from "@/components/ui/tooltip"; import { ColumnsType } from "antd/es/table"; @@ -213,20 +211,6 @@ export default function GroupList({ tenantId }: { tenantId: string | null }) { } }; - // Handle pagination change - const handlePageChange = ( - pagination: TablePaginationConfig, - _filters: Record, - _sorter: SorterResult | SorterResult[] - ) => { - const newPage = pagination.current || 1; - const newPageSize = pagination.pageSize || 10; - setPage(newPage); - if (newPageSize !== pageSize) { - setPageSize(newPageSize); - } - }; - const columns: ColumnsType = useMemo( () => [ { title: t("tenantResources.groups.name"), dataIndex: "group_name", key: "group_name" }, diff --git a/frontend/app/[locale]/tenant-resources/components/resources/UserList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/UserList.tsx index 59a77c47f..8e7438a5c 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/UserList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/UserList.tsx @@ -13,8 +13,6 @@ import { message, Tag, } from "antd"; -import type { TablePaginationConfig } from "antd"; -import { FilterValue, SorterResult } from "antd/es/table/interface"; import { Edit, Trash2 } from "lucide-react"; import { Tooltip } from "@/components/ui/tooltip"; import { ColumnsType } from "antd/es/table"; @@ -137,19 +135,6 @@ export default function UserList({ tenantId, refreshKey }: { tenantId: string | } }; - const handlePageChange = ( - pagination: TablePaginationConfig, - _filters: Record, - _sorter: SorterResult | SorterResult[] - ) => { - const newPage = pagination.current || 1; - const newPageSize = pagination.pageSize || 10; - setPage(newPage); - if (newPageSize !== pageSize) { - setPageSize(newPageSize); - } - }; - const columns: ColumnsType = useMemo( () => [ { From 24e32d80415387274262ac71ff4ca9fa3715a6b6 Mon Sep 17 00:00:00 2001 From: panyehong <2655992392@qq.com> Date: Tue, 10 Mar 2026 15:15:30 +0800 Subject: [PATCH 42/75] =?UTF-8?q?=F0=9F=90=9B=20Bugfix:=20Solving=20the=20?= =?UTF-8?q?problem=20of=20duplicate=20tool=20instances=20in=20intelligent?= =?UTF-8?q?=20agents=20#2647=20[Specification=20Details]=201.=20When=20upd?= =?UTF-8?q?ating=20the=20tool=5Finstance=20table,=20do=20not=20use=20user?= =?UTF-8?q?=5Fid=20as=20a=20query=20condition=20to=20ensure=20that=20there?= =?UTF-8?q?=20is=20only=20one=20tool=5Finstance=20record=20for=20the=20sam?= =?UTF-8?q?e=20tool.=202.=20Modify=20test=20cases.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/database/tool_db.py | 12 +++++++--- test/backend/database/test_tool_db.py | 32 +++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/backend/database/tool_db.py b/backend/database/tool_db.py index 0514bc945..2a64c47d6 100644 --- a/backend/database/tool_db.py +++ b/backend/database/tool_db.py @@ -37,7 +37,7 @@ def create_or_update_tool_by_tool_info(tool_info, tenant_id: str, user_id: str, Args: tool_info: Dictionary containing tool information tenant_id: Tenant ID for filtering, mandatory - user_id: Optional user ID for filtering + user_id: User ID for updating (will be set as the last updater) version_no: Version number to filter. Default 0 = draft/editing state Returns: @@ -48,9 +48,10 @@ def create_or_update_tool_by_tool_info(tool_info, tenant_id: str, user_id: str, with get_db_session() as session: # Query if there is an existing ToolInstance + # Note: Do not filter by user_id to avoid creating duplicate instances + # for the same agent_id and tool_id when different users save query = session.query(ToolInstance).filter( ToolInstance.tenant_id == tenant_id, - ToolInstance.user_id == user_id, ToolInstance.agent_id == tool_info_dict['agent_id'], ToolInstance.delete_flag != 'Y', ToolInstance.tool_id == tool_info_dict['tool_id'], @@ -63,7 +64,12 @@ def create_or_update_tool_by_tool_info(tool_info, tenant_id: str, user_id: str, if hasattr(tool_instance, key): setattr(tool_instance, key, value) else: - create_tool(tool_info_dict, version_no) + # Create a new ToolInstance + new_tool_instance = ToolInstance( + **filter_property(tool_info_dict, ToolInstance)) + session.add(new_tool_instance) + session.flush() # Flush to get the ID + tool_instance = new_tool_instance return tool_instance diff --git a/test/backend/database/test_tool_db.py b/test/backend/database/test_tool_db.py index a5e13dfcb..936f66dc1 100644 --- a/test/backend/database/test_tool_db.py +++ b/test/backend/database/test_tool_db.py @@ -228,14 +228,42 @@ def test_create_or_update_tool_by_tool_info_create_new(monkeypatch, mock_session mock_ctx.__exit__.return_value = None monkeypatch.setattr( "backend.database.tool_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.tool_db.create_tool", MagicMock()) + monkeypatch.setattr( + "backend.database.tool_db.filter_property", lambda data, model: data) + + # Mock ToolInstance class - needs to have column attributes for query building + mock_tool_instance = MockToolInstance() + + # Create a Mock class that can be used both as a class (for query) and instantiated + class MockToolInstanceClass: + tenant_id = MagicMock() + agent_id = MagicMock() + tool_id = MagicMock() + delete_flag = MagicMock() + version_no = MagicMock() + + def __init__(self, **kwargs): + # Copy attributes from mock_tool_instance + for key, value in mock_tool_instance.__dict__.items(): + setattr(self, key, value) + # Update with any kwargs passed + for key, value in kwargs.items(): + setattr(self, key, value) + + monkeypatch.setattr( + "backend.database.tool_db.ToolInstance", MockToolInstanceClass) + + session.add = MagicMock() + session.flush = MagicMock() tool_info = MagicMock() tool_info.__dict__ = {"agent_id": 1, "tool_id": 1} result = create_or_update_tool_by_tool_info(tool_info, "tenant1", "user1") - assert result is None + assert isinstance(result, MockToolInstanceClass) + session.add.assert_called_once() + session.flush.assert_called_once() def test_query_all_tools(monkeypatch, mock_session): From b7ef52cd0185c9e45f9166cbb19d254dcd43e0fd Mon Sep 17 00:00:00 2001 From: Jasonxia007 Date: Tue, 10 Mar 2026 15:23:07 +0800 Subject: [PATCH 43/75] =?UTF-8?q?=E2=9C=A8=20Add=20access=20key=20in=20use?= =?UTF-8?q?r=20profile=20page=20=E2=99=BB=EF=B8=8F=20Remove=20deprecated?= =?UTF-8?q?=20HMAC=20logics=20in=20northbound=20interfaces?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/northbound_app.py | 229 ++++++++---------- backend/apps/user_management_app.py | 109 ++++++++- backend/database/db_models.py | 29 +++ backend/database/token_db.py | 189 +++++++++++++++ backend/services/northbound_service.py | 184 +++++++++----- backend/services/user_management_service.py | 49 ++++ backend/utils/auth_utils.py | 211 +++++----------- .../sql/v1.8.0.3_0306_add_user_token_info.sql | 112 +++++++++ .../users/components/UserProfileComp.tsx | 151 +++++++++++- frontend/public/locales/en/common.json | 13 + frontend/public/locales/zh/common.json | 13 + frontend/services/api.ts | 2 + frontend/services/tokenService.ts | 69 ++++++ 13 files changed, 1014 insertions(+), 346 deletions(-) create mode 100644 backend/database/token_db.py create mode 100644 docker/sql/v1.8.0.3_0306_add_user_token_info.sql create mode 100644 frontend/services/tokenService.ts diff --git a/backend/apps/northbound_app.py b/backend/apps/northbound_app.py index a39877ded..cc392219f 100644 --- a/backend/apps/northbound_app.py +++ b/backend/apps/northbound_app.py @@ -1,12 +1,12 @@ import logging from http import HTTPStatus -from typing import Optional, Dict +from typing import Optional, Dict, Any import uuid -from fastapi import APIRouter, Body, Header, Request, HTTPException +from fastapi import APIRouter, Body, Header, Request, HTTPException, Query from fastapi.responses import JSONResponse -from consts.exceptions import UnauthorizedError, LimitExceededError, SignatureValidationError +from consts.exceptions import LimitExceededError, UnauthorizedError from services.northbound_service import ( NorthboundContext, get_conversation_history, @@ -14,86 +14,85 @@ start_streaming_chat, stop_chat, get_agent_info_list, - update_conversation_title + update_conversation_title, ) -from utils.auth_utils import get_current_user_id, validate_aksk_authentication +from utils.auth_utils import validate_bearer_token, get_user_and_tenant_by_access_key router = APIRouter(prefix="/nb/v1", tags=["northbound"]) -def _get_header(headers: Dict[str, str], name: str) -> Optional[str]: - for k, v in headers.items(): - if k.lower() == name.lower(): - return v - return None +async def _get_northbound_context(request: Request) -> NorthboundContext: + """ + Build northbound context from request. + Authentication: Bearer Token (API Key) in Authorization header + - Authorization: Bearer -async def _parse_northbound_context(request: Request) -> NorthboundContext: - """ - Build northbound context from headers. + The user_id and tenant_id are derived from the access_key by querying + user_token_info_t and user_tenant_t tables. - - X-Access-Key: Access key for AK/SK authentication - - X-Timestamp: Timestamp for signature validation - - X-Signature: HMAC-SHA256 signature signed with secret key - - Authorization: Bearer , jwt contains sub (user_id) - - X-Request-Id: optional, generated if not provided + Optional headers: + - X-Request-Id: Request ID, generated if not provided """ - # 1. Verify AK/SK signature + # 1. Validate Bearer Token and extract access_key try: - # Get request body for signature verification - request_body = "" - if request.method in ["POST", "PUT", "PATCH"]: - try: - body_bytes = await request.body() - request_body = body_bytes.decode('utf-8') if body_bytes else "" - except Exception as e: - logging.warning( - f"Cannot read request body for signature verification: {e}") - request_body = "" - - validate_aksk_authentication(request.headers, request_body) - except (UnauthorizedError, LimitExceededError, SignatureValidationError) as e: - raise e + auth_header = request.headers.get("Authorization") + is_valid, token_info = validate_bearer_token(auth_header) + + if not is_valid or not token_info: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail="Invalid or missing API key" + ) + + # Extract access_key from the token + access_key = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else auth_header + + # Get user_id and tenant_id from access_key + user_tenant_info = get_user_and_tenant_by_access_key(access_key) + resolved_user_id = user_tenant_info.get("user_id") + resolved_tenant_id = user_tenant_info.get("tenant_id") + token_id = user_tenant_info.get("token_id") + + except HTTPException: + raise + except UnauthorizedError as e: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail=str(e) + ) except Exception as e: - logging.error(f"Failed to parse northbound context: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail="Internal Server Error: cannot parse northbound context") + logging.error(f"Failed to validate bearer token: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: invalid API key" + ) - # 2. Parse JWT token - auth_header = _get_header(request.headers, "Authorization") - if not auth_header: - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: No authorization header found") + if not resolved_user_id: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Missing user information for this access key" + ) - # Use auth_utils to parse JWT token - try: - user_id, tenant_id = get_current_user_id(auth_header) + if not resolved_tenant_id: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Missing tenant information for this access key" + ) - if not user_id: - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: missing user_id in JWT token") - if not tenant_id: - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: unregistered user_id in JWT token") + request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4()) - except HTTPException as e: - # Preserve explicit HTTP errors raised during JWT parsing - raise e - except Exception as e: - logging.error(f"Failed to parse JWT token: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail="Internal Server Error: cannot parse JWT token") - - request_id = _get_header( - request.headers, "X-Request-Id") or str(uuid.uuid4()) + # Get authorization header if present, otherwise use a placeholder + auth_header_value = request.headers.get("Authorization", "Bearer placeholder") return NorthboundContext( request_id=request_id, - tenant_id=tenant_id, - user_id=str(user_id), - authorization=auth_header, + tenant_id=resolved_tenant_id, + user_id=resolved_user_id, + authorization=auth_header_value, + token_id=token_id, ) @@ -105,34 +104,27 @@ async def health_check(): @router.post("/chat/run") async def run_chat( request: Request, - conversation_id: str = Body(..., embed=True), + conversation_id: Optional[int] = Body(None, embed=True), agent_name: str = Body(..., embed=True), query: str = Body(..., embed=True), + meta_data: Optional[Dict[str, Any]] = Body(None, embed=True), idempotency_key: Optional[str] = Header(None, alias="Idempotency-Key"), ): try: - ctx: NorthboundContext = await _parse_northbound_context(request) + ctx: NorthboundContext = await _get_northbound_context(request) return await start_streaming_chat( ctx=ctx, - external_conversation_id=conversation_id, + conversation_id=conversation_id, agent_name=agent_name, query=query, + meta_data=meta_data, idempotency_key=idempotency_key, ) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: - # Propagate HTTP errors from context parsing without altering status/detail raise e except Exception as e: logging.error(f"Failed to run chat: {str(e)}", exc_info=e) @@ -141,22 +133,25 @@ async def run_chat( @router.get("/chat/stop/{conversation_id}") -async def stop_chat_stream(request: Request, conversation_id: str): +async def stop_chat_stream( + request: Request, + conversation_id: int, + meta_data: Optional[str] = Query(None, description="Optional metadata as JSON string"), +): + import json + parsed_meta_data = None + if meta_data: + try: + parsed_meta_data = json.loads(meta_data) + except json.JSONDecodeError: + pass try: - ctx: NorthboundContext = await _parse_northbound_context(request) - return await stop_chat(ctx=ctx, external_conversation_id=conversation_id) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") + ctx: NorthboundContext = await _get_northbound_context(request) + return await stop_chat(ctx=ctx, conversation_id=conversation_id, meta_data=parsed_meta_data) except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: @@ -166,22 +161,17 @@ async def stop_chat_stream(request: Request, conversation_id: str): @router.get("/conversations/{conversation_id}") -async def get_history(request: Request, conversation_id: str): +async def get_history( + request: Request, + conversation_id: int, +): try: - ctx: NorthboundContext = await _parse_northbound_context(request) - return await get_conversation_history(ctx=ctx, external_conversation_id=conversation_id) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") + ctx: NorthboundContext = await _get_northbound_context(request) + return await get_conversation_history(ctx=ctx, conversation_id=conversation_id) except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: @@ -193,20 +183,12 @@ async def get_history(request: Request, conversation_id: str): @router.get("/agents") async def list_agents(request: Request): try: - ctx: NorthboundContext = await _parse_northbound_context(request) + ctx: NorthboundContext = await _get_northbound_context(request) return await get_agent_info_list(ctx=ctx) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: @@ -218,20 +200,12 @@ async def list_agents(request: Request): @router.get("/conversations") async def list_convs(request: Request): try: - ctx: NorthboundContext = await _parse_northbound_context(request) + ctx: NorthboundContext = await _get_northbound_context(request) return await list_conversations(ctx=ctx) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: @@ -243,34 +217,35 @@ async def list_convs(request: Request): @router.put("/conversations/{conversation_id}/title") async def update_convs_title( request: Request, - conversation_id: str, - title: str, + conversation_id: int, + title: str = Query(..., description="New title"), + meta_data: Optional[str] = Query(None, description="Optional metadata as JSON string"), idempotency_key: Optional[str] = Header(None, alias="Idempotency-Key"), ): + import json + parsed_meta_data = None + if meta_data: + try: + parsed_meta_data = json.loads(meta_data) + except json.JSONDecodeError: + pass try: - ctx: NorthboundContext = await _parse_northbound_context(request) + ctx: NorthboundContext = await _get_northbound_context(request) result = await update_conversation_title( ctx=ctx, - external_conversation_id=conversation_id, + conversation_id=conversation_id, title=title, + meta_data=parsed_meta_data, idempotency_key=idempotency_key, ) headers_out = { "Idempotency-Key": result.get("idempotency_key", ""), "X-Request-Id": ctx.request_id} return JSONResponse(content=result, headers=headers_out) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: diff --git a/backend/apps/user_management_app.py b/backend/apps/user_management_app.py index c38b4e73c..956832f52 100644 --- a/backend/apps/user_management_app.py +++ b/backend/apps/user_management_app.py @@ -1,9 +1,10 @@ import logging from dotenv import load_dotenv -from fastapi import APIRouter, Request, HTTPException +from fastapi import APIRouter, Header, Query, Request, HTTPException from fastapi.responses import JSONResponse from http import HTTPStatus +from typing import Optional from supabase_auth.errors import AuthApiError, AuthWeakPasswordError @@ -11,7 +12,7 @@ from consts.exceptions import NoInviteCodeException, IncorrectInviteCodeException, UserRegistrationException from services.user_management_service import get_authorized_client, validate_token, \ check_auth_service_health, signup_user_with_invitation, signin_user, refresh_user_token, \ - get_session_by_authorization, get_user_info + get_session_by_authorization, get_user_info, create_token, list_tokens_by_user, delete_token from services.user_service import delete_user_and_cleanup from consts.exceptions import UnauthorizedError from utils.auth_utils import get_current_user_id @@ -273,3 +274,107 @@ async def revoke_user_account(request: Request): logging.error(f"User revoke failed: {str(e)}") raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="User revoke failed") + +@router.post("/tokens") +async def create_token_endpoint( + authorization: Optional[str] = Header(None) +): + """Create a new token for the authenticated user. + + The user_id is extracted from the Authorization header (JWT token). + Returns the complete token including the secret key. + """ + try: + if not authorization: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: No authorization header found") + + user_id, _ = get_current_user_id(authorization) + if not user_id: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: missing user_id in JWT token") + + result = create_token(str(user_id)) + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": result} + ) + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Failed to create token: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Internal Server Error") + + +@router.get("/tokens") +async def list_tokens_endpoint( + user_id: str = Query(..., description="User ID to query tokens for"), + authorization: Optional[str] = Header(None) +): + """List all tokens for the specified user. + + Returns token information with masked access keys (middle part replaced with *). + """ + try: + if not authorization: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: No authorization header found") + + request_user_id, _ = get_current_user_id(authorization) + if not request_user_id: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: missing user_id in JWT token") + + # Only allow users to list their own tokens + if str(request_user_id) != user_id: + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, + detail="Forbidden: cannot list tokens for other users") + + tokens = list_tokens_by_user(user_id) + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": tokens} + ) + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Failed to list tokens: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Internal Server Error") + + +@router.delete("/tokens/{token_id}") +async def delete_token_endpoint( + token_id: int, + authorization: Optional[str] = Header(None) +): + """Soft delete a token. + + Only the owner of the token can delete it. + """ + try: + if not authorization: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: No authorization header found") + + user_id, _ = get_current_user_id(authorization) + if not user_id: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: missing user_id in JWT token") + + success = delete_token(token_id, str(user_id)) + if not success: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, + detail="Token not found or not owned by user") + + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": {"token_id": token_id}} + ) + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Failed to delete token: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Internal Server Error") diff --git a/backend/database/db_models.py b/backend/database/db_models.py index 36f475f53..80dcc87eb 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -1,4 +1,5 @@ from sqlalchemy import BigInteger, Boolean, Column, Integer, JSON, Numeric, PrimaryKeyConstraint, Sequence, String, Text, TIMESTAMP +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql import func @@ -483,3 +484,31 @@ class AgentVersion(TableBase): source_version_no = Column(Integer, doc="Source version number. If this version is a rollback, record the source version") source_type = Column(String(30), doc="Source type: NORMAL (normal publish) / ROLLBACK (rollback and republish)") status = Column(String(30), default="RELEASED", doc="Version status: RELEASED / DISABLED / ARCHIVED") + + +class UserTokenInfo(TableBase): + """ + User token (AK/SK) information table + """ + __tablename__ = "user_token_info_t" + __table_args__ = {"schema": SCHEMA} + + token_id = Column(Integer, Sequence("user_token_info_t_token_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Token ID, unique primary key") + access_key = Column(String(100), nullable=False, doc="Access Key (AK)") + user_id = Column(String(100), nullable=False, doc="User ID who owns this token") + + +class UserTokenUsageLog(TableBase): + """ + User token usage log table + """ + __tablename__ = "user_token_usage_log_t" + __table_args__ = {"schema": SCHEMA} + + token_usage_id = Column(Integer, Sequence("user_token_usage_log_t_token_usage_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Token usage log ID, unique primary key") + token_id = Column(Integer, nullable=False, doc="Foreign key to user_token_info_t.token_id") + call_function_name = Column(String(100), doc="API function name being called") + related_id = Column(Integer, doc="Related resource ID (e.g., conversation_id)") + meta_data = Column(JSONB, doc="Additional metadata for this usage log entry, stored as JSON") diff --git a/backend/database/token_db.py b/backend/database/token_db.py new file mode 100644 index 000000000..3be0e93f3 --- /dev/null +++ b/backend/database/token_db.py @@ -0,0 +1,189 @@ +""" +Database operations for user API token (API Key) management. +""" +import secrets +from typing import Any, Dict, List, Optional + +from database.client import get_db_session +from database.db_models import UserTokenInfo, UserTokenUsageLog + + +def generate_access_key() -> str: + """Generate a random access key with format nexent-xxxxx...""" + random_part = secrets.token_hex(12) # 24 hex characters for more entropy + return f"nexent-{random_part}" + + +def create_token(access_key: str, user_id: str) -> Dict[str, Any]: + """Create a new token record in the database. + + Args: + access_key: The access key (API Key). + user_id: The user ID who owns this token. + + Returns: + Dictionary containing the created token information. + """ + with get_db_session() as session: + token = UserTokenInfo( + access_key=access_key, + user_id=user_id, + created_by=user_id, + updated_by=user_id, + delete_flag='N' + ) + session.add(token) + session.flush() + + return { + "token_id": token.token_id, + "access_key": token.access_key, + "user_id": token.user_id + } + + +def list_tokens_by_user(user_id: str) -> List[Dict[str, Any]]: + """List all active tokens for the specified user. + + Args: + user_id: The user ID to query tokens for. + + Returns: + List of token information with masked access keys. + """ + with get_db_session() as session: + tokens = session.query(UserTokenInfo).filter( + UserTokenInfo.user_id == user_id, + UserTokenInfo.delete_flag == 'N' + ).order_by(UserTokenInfo.create_time.desc()).all() + + return [ + { + "token_id": token.token_id, + "access_key": token.access_key, + "user_id": token.user_id, + "create_time": token.create_time.isoformat() if token.create_time else None + } + for token in tokens + ] + + +def get_token_by_id(token_id: int) -> UserTokenInfo: + """Get a token by its ID. + + Args: + token_id: The token ID to query. + + Returns: + UserTokenInfo object if found and active, None otherwise. + """ + with get_db_session() as session: + return session.query(UserTokenInfo).filter( + UserTokenInfo.token_id == token_id, + UserTokenInfo.delete_flag == 'N' + ).first() + + +def get_token_by_access_key(access_key: str) -> Optional[Dict[str, Any]]: + """Get a token by its access key. + + Args: + access_key: The access key to query. + + Returns: + Token information dict if found and active, None otherwise. + """ + with get_db_session() as session: + token = session.query(UserTokenInfo).filter( + UserTokenInfo.access_key == access_key, + UserTokenInfo.delete_flag == 'N' + ).first() + + if token: + return { + "token_id": token.token_id, + "access_key": token.access_key, + "user_id": token.user_id, + "delete_flag": token.delete_flag + } + return None + + +def delete_token(token_id: int, user_id: str) -> bool: + """Soft delete a token by setting delete_flag to 'Y'. + + Args: + token_id: The token ID to delete. + user_id: The user ID who owns this token (for authorization). + + Returns: + True if the token was deleted, False if not found or not owned by user. + """ + with get_db_session() as session: + token = session.query(UserTokenInfo).filter( + UserTokenInfo.token_id == token_id, + UserTokenInfo.user_id == user_id, + UserTokenInfo.delete_flag == 'N' + ).first() + + if not token: + return False + + token.delete_flag = 'Y' + token.updated_by = user_id + return True + + +def log_token_usage( + token_id: int, + call_function_name: str, + related_id: Optional[int], + created_by: str, + metadata: Optional[Dict[str, Any]] = None +) -> int: + """Log token usage to the database. + + Args: + token_id: The token ID used. + call_function_name: The API function name being called. + related_id: Related resource ID (e.g., conversation_id). + created_by: User ID who initiated the call. + metadata: Optional additional metadata for this usage log entry. + + Returns: + The created token_usage_id. + """ + with get_db_session() as session: + usage_log = UserTokenUsageLog( + token_id=token_id, + call_function_name=call_function_name, + related_id=related_id, + created_by=created_by, + metadata=metadata + ) + session.add(usage_log) + session.flush() + return usage_log.token_usage_id + + +def get_latest_usage_metadata(token_id: int, related_id: int, call_function_name: str) -> Optional[Dict[str, Any]]: + """Get the latest metadata for a given token, related_id and function name. + + Args: + token_id: The token ID used. + related_id: Related resource ID (e.g., conversation_id). + call_function_name: The API function name. + + Returns: + The metadata dict if found, None otherwise. + """ + with get_db_session() as session: + usage_log = session.query(UserTokenUsageLog).filter( + UserTokenUsageLog.token_id == token_id, + UserTokenUsageLog.related_id == related_id, + UserTokenUsageLog.call_function_name == call_function_name + ).order_by(UserTokenUsageLog.create_time.desc()).first() + + if usage_log and usage_log.metadata: + return usage_log.metadata + return None diff --git a/backend/services/northbound_service.py b/backend/services/northbound_service.py index 6f9164269..140e69a68 100644 --- a/backend/services/northbound_service.py +++ b/backend/services/northbound_service.py @@ -13,11 +13,7 @@ ) from consts.model import AgentRequest from database.conversation_db import get_conversation_messages -from database.partner_db import ( - add_mapping_id, - get_external_id_by_internal, - get_internal_id_by_external -) +from database.token_db import log_token_usage, get_latest_usage_metadata from services.agent_service import ( run_agent_stream, stop_agent_tasks, @@ -40,6 +36,7 @@ class NorthboundContext: tenant_id: str user_id: str authorization: str + token_id: int = 0 # ----------------------------- @@ -114,26 +111,6 @@ def _build_idempotency_key(*parts: Any) -> str: return ":".join(processed) -# ----------------------------- -# ID mapping helpers -# ----------------------------- -async def to_external_conversation_id(internal_id: int) -> str: - if not internal_id: - raise Exception("invalid internal conversation id") - external_id = get_external_id_by_internal(internal_id=internal_id, mapping_type="CONVERSATION") - if not external_id: - logger.error(f"cannot find external id for conversation_id: {internal_id}") - raise Exception("cannot find external id") - return external_id - - -async def to_internal_conversation_id(external_id: str) -> int: - if not external_id: - raise Exception("invalid external conversation id") - internal_id = get_internal_id_by_external(external_id=external_id, mapping_type="CONVERSATION") - return internal_id - - # ----------------------------- # Agent resolver # ----------------------------- @@ -146,30 +123,30 @@ async def get_agent_info_by_name(agent_name: str, tenant_id: str) -> int: async def start_streaming_chat( ctx: NorthboundContext, - external_conversation_id: str, + conversation_id: Optional[int], agent_name: str, query: str, + meta_data: Optional[Dict[str, Any]] = None, idempotency_key: Optional[str] = None ) -> StreamingResponse: try: # Simple rate limit await check_and_consume_rate_limit(ctx.tenant_id) - internal_conversation_id = await to_internal_conversation_id(external_conversation_id) - # Add mapping to postgres database - if internal_conversation_id is None: - logging.info(f"Conversation {external_conversation_id} not found, creating a new conversation") - # Create a new conversation and get its internal ID + # If conversation_id is not provided, create a new conversation + if conversation_id is None: + logging.info("No conversation_id provided, creating a new conversation") new_conversation = create_new_conversation(title="New Conversation", user_id=ctx.user_id) - internal_conversation_id = new_conversation["conversation_id"] - # Add the new mapping to the database - add_mapping_id(internal_id=internal_conversation_id, external_id=external_conversation_id, tenant_id=ctx.tenant_id, user_id=ctx.user_id) + conversation_id = new_conversation["conversation_id"] + logging.info(f"Created new conversation with id: {conversation_id}") + + internal_conversation_id = conversation_id # Get history according to internal_conversation_id - history_resp = await get_conversation_history(ctx, external_conversation_id) + history_resp = await get_conversation_history_internal(ctx, internal_conversation_id) agent_id = await get_agent_id_by_name(agent_name=agent_name, tenant_id=ctx.tenant_id) # Idempotency: only prevent concurrent duplicate starts - composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, external_conversation_id, agent_id, query) + composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, str(conversation_id), agent_id, query) await idempotency_start(composed_key) agent_request = AgentRequest( conversation_id=internal_conversation_id, @@ -192,7 +169,7 @@ async def start_streaming_chat( except UnauthorizedError as _: raise UnauthorizedError("Cannot authenticate.") except Exception as e: - raise Exception(f"Failed to start streaming chat for external conversation id {external_conversation_id}: {str(e)}") + raise Exception(f"Failed to start streaming chat for conversation_id {conversation_id}: {str(e)}") try: response = await run_agent_stream( @@ -207,34 +184,74 @@ async def start_streaming_chat( if composed_key: asyncio.create_task(_release_idempotency_after_delay(composed_key)) - # Attach request id header + # Log token usage + if ctx.token_id > 0: + try: + log_token_usage( + token_id=ctx.token_id, + call_function_name="run_chat", + related_id=conversation_id, + created_by=ctx.user_id, + metadata=meta_data + ) + except Exception as e: + logger.warning(f"Failed to log token usage: {str(e)}") + + # Attach request id header and conversation_id (internal id) response.headers["X-Request-Id"] = ctx.request_id - response.headers["conversation_id"] = external_conversation_id + response.headers["conversation_id"] = str(conversation_id) return response -async def stop_chat(ctx: NorthboundContext, external_conversation_id: str) -> Dict[str, Any]: +async def stop_chat(ctx: NorthboundContext, conversation_id: int, meta_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: try: - internal_id = await to_internal_conversation_id(external_conversation_id) - - stop_result = stop_agent_tasks(internal_id, ctx.user_id) - return {"message": stop_result.get("message", "success"), "data": external_conversation_id, "requestId": ctx.request_id} + stop_result = stop_agent_tasks(conversation_id, ctx.user_id) + + # Log token usage + if ctx.token_id > 0: + try: + log_token_usage( + token_id=ctx.token_id, + call_function_name="stop_chat_stream", + related_id=conversation_id, + created_by=ctx.user_id, + metadata=meta_data + ) + except Exception as e: + logger.warning(f"Failed to log token usage: {str(e)}") + + return {"message": stop_result.get("message", "success"), "data": conversation_id, "requestId": ctx.request_id} except Exception as e: - raise Exception(f"Failed to stop chat for external conversation id {external_conversation_id}: {str(e)}") + raise Exception(f"Failed to stop chat for conversation_id {conversation_id}: {str(e)}") async def list_conversations(ctx: NorthboundContext) -> Dict[str, Any]: conversations = get_conversation_list_service(ctx.user_id) # get_conversation_list_service is sync - for item in conversations: - item["conversation_id"] = await to_external_conversation_id(int(item["conversation_id"])) - return {"message": "success", "data": conversations, "requestId": ctx.request_id} + # Add meta_data from token usage log if available + if ctx.token_id > 0: + for item in conversations: + conversation_id = item.get("conversation_id") + if conversation_id: + try: + meta_data = get_latest_usage_metadata( + token_id=ctx.token_id, + related_id=int(conversation_id), + call_function_name="run_chat" + ) + if meta_data: + item["meta_data"] = meta_data + except Exception as e: + logger.warning(f"Failed to get meta_data for conversation {conversation_id}: {str(e)}") + + # Now return internal conversation_id directly + return {"message": "success", "data": conversations, "requestId": ctx.request_id} -async def get_conversation_history(ctx: NorthboundContext, external_conversation_id: str) -> Dict[str, Any]: - internal_id = await to_internal_conversation_id(external_conversation_id) - history = get_conversation_messages(internal_id) +async def get_conversation_history_internal(ctx: NorthboundContext, conversation_id: int) -> Dict[str, Any]: + """Internal helper to get conversation history without logging.""" + history = get_conversation_messages(conversation_id) # Remove unnecessary fields result = [] for message in history: @@ -244,44 +261,89 @@ async def get_conversation_history(ctx: NorthboundContext, external_conversation }) response = { - "conversation_id": external_conversation_id, + "conversation_id": conversation_id, "history": result } - # Ensure external id in response return {"message": "success", "data": response, "requestId": ctx.request_id} +async def get_conversation_history(ctx: NorthboundContext, conversation_id: int) -> Dict[str, Any]: + try: + # Log token usage + if ctx.token_id > 0: + try: + log_token_usage( + token_id=ctx.token_id, + call_function_name="get_conversation_history", + related_id=conversation_id, + created_by=ctx.user_id, + metadata=None + ) + except Exception as e: + logger.warning(f"Failed to log token usage: {str(e)}") + + return await get_conversation_history_internal(ctx, conversation_id) + except Exception as e: + raise Exception(f"Failed to get conversation history for conversation_id {conversation_id}: {str(e)}") + + async def get_agent_info_list(ctx: NorthboundContext) -> Dict[str, Any]: try: - agent_info_list = await list_all_agent_info_impl(tenant_id=ctx.tenant_id) + agent_info_list = await list_all_agent_info_impl(tenant_id=ctx.tenant_id, user_id=ctx.user_id) # Remove internal information that partner don't need for agent_info in agent_info_list: agent_info.pop("agent_id", None) + + # Log token usage + if ctx.token_id > 0: + try: + log_token_usage( + token_id=ctx.token_id, + call_function_name="get_agent_info_list", + related_id=None, + created_by=ctx.user_id, + metadata=None + ) + except Exception as e: + logger.warning(f"Failed to log token usage: {str(e)}") + return {"message": "success", "data": agent_info_list, "requestId": ctx.request_id} except Exception as e: raise Exception(f"Failed to get agent info list for tenant {ctx.tenant_id}: {str(e)}") -async def update_conversation_title(ctx: NorthboundContext, external_conversation_id: str, title: str, idempotency_key: Optional[str] = None) -> Dict[str, Any]: +async def update_conversation_title(ctx: NorthboundContext, conversation_id: int, title: str, meta_data: Optional[Dict[str, Any]] = None, idempotency_key: Optional[str] = None) -> Dict[str, Any]: composed_key: Optional[str] = None try: - internal_id = await to_internal_conversation_id(external_conversation_id) - # Idempotency: avoid concurrent duplicate title update for same conversation - composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, external_conversation_id, title) + composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, str(conversation_id), title) await idempotency_start(composed_key) - update_conversation_title_service(internal_id, title, ctx.user_id) + update_conversation_title_service(conversation_id, title, ctx.user_id) + + # Log token usage + if ctx.token_id > 0: + try: + log_token_usage( + token_id=ctx.token_id, + call_function_name="update_conversation_title", + related_id=conversation_id, + created_by=ctx.user_id, + metadata=meta_data + ) + except Exception as e: + logger.warning(f"Failed to log token usage: {str(e)}") + return { "message": "success", - "data": external_conversation_id, + "data": conversation_id, "requestId": ctx.request_id, "idempotency_key": composed_key, } except LimitExceededError as _: raise LimitExceededError("Duplicate request is still running, please wait.") except Exception as e: - raise Exception(f"Failed to update conversation title for external conversation id {external_conversation_id}: {str(e)}") + raise Exception(f"Failed to update conversation title for conversation_id {conversation_id}: {str(e)}") finally: if composed_key: asyncio.create_task(_release_idempotency_after_delay(composed_key)) diff --git a/backend/services/user_management_service.py b/backend/services/user_management_service.py index 792887ec5..3499d3170 100644 --- a/backend/services/user_management_service.py +++ b/backend/services/user_management_service.py @@ -1,6 +1,13 @@ import logging from typing import Optional, Any, Tuple, Dict, List +from database.token_db import ( + create_token as create_token_record, + generate_access_key, + list_tokens_by_user as list_tokens_by_user_record, + delete_token as delete_token_record, +) + import aiohttp from fastapi import Header from supabase import Client @@ -472,3 +479,45 @@ def format_role_permissions(permissions: List[Dict[str, Any]]) -> Dict[str, List "permissions": formatted_permissions, "accessibleRoutes": accessible_routes } + + +# ----------------------------- +# Token Management +# ----------------------------- + +def create_token(user_id: str) -> Dict[str, Any]: + """Create a new API token for the specified user. + + Args: + user_id: The user ID who owns this token. + + Returns: + Dictionary containing the API token information including token_id. + """ + access_key = generate_access_key() + return create_token_record(access_key, user_id) + + +def list_tokens_by_user(user_id: str) -> List[Dict[str, Any]]: + """List all tokens for the specified user. + + Args: + user_id: The user ID to query token pairs for. + + Returns: + List of token information with masked access keys. + """ + return list_tokens_by_user_record(user_id) + + +def delete_token(token_id: int, user_id: str) -> bool: + """Soft delete a token. + + Args: + token_id: The token ID to delete. + user_id: The user ID who owns this token (for authorization). + + Returns: + True if the token was deleted, False if not found or not owned by user. + """ + return delete_token_record(token_id, user_id) diff --git a/backend/utils/auth_utils.py b/backend/utils/auth_utils.py index a27a48b38..c614f093d 100644 --- a/backend/utils/auth_utils.py +++ b/backend/utils/auth_utils.py @@ -1,6 +1,4 @@ import logging -import hashlib -import hmac import time from datetime import datetime, timedelta from typing import Optional, Tuple @@ -20,189 +18,94 @@ DEBUG_JWT_EXPIRE_SECONDS, LANGUAGE, ) -from consts.exceptions import LimitExceededError, SignatureValidationError, UnauthorizedError +from consts.exceptions import LimitExceededError, UnauthorizedError from database.user_tenant_db import get_user_tenant_by_user_id +from database.token_db import get_token_by_access_key +from typing import Dict # Module logger logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# AK/SK authentication helpers (merged from aksk_auth_utils.py) +# Bearer Token (API Key) authentication # --------------------------------------------------------------------------- -# Mock AK/SK configuration (replace with DB/config lookup in production) -MOCK_ACCESS_KEY = "mock_access_key_12345" -MOCK_SECRET_KEY = "mock_secret_key_67890abcdef" -MOCK_JWT_SECRET_KEY = "mock_jwt_secret_key_67890abcdef" -# Timestamp validity window in seconds (prevent replay attacks) -TIMESTAMP_VALIDITY_WINDOW = 300 - - -def get_aksk_config(tenant_id: str) -> Tuple[str, str]: +def validate_bearer_token(authorization: Optional[str]) -> Tuple[bool, Optional[dict]]: """ - Get AK/SK configuration according to tenant_id - - Returns: - Tuple[str, str]: (access_key, secret_key) - """ - - # TODO: get ak/sk according to tenant_id from DB - return MOCK_ACCESS_KEY, MOCK_SECRET_KEY - - -def validate_timestamp(timestamp: str) -> bool: - """ - Validate timestamp is within validity window + Validate Bearer token (API Key) from Authorization header. Args: - timestamp: timestamp string + authorization: Authorization header value (e.g., "Bearer nexent-xxxxx") Returns: - bool: whether timestamp is valid + Tuple of (is_valid, token_info_dict) + - is_valid: True if token exists and is active + - token_info: Token information dict if valid, None otherwise """ - try: - timestamp_int = int(timestamp) - current_time = int(time.time()) + if not authorization: + logger.warning("No authorization header provided") + return False, None - if abs(current_time - timestamp_int) > TIMESTAMP_VALIDITY_WINDOW: - logger.warning( - f"Timestamp validation failed: current={current_time}, provided={timestamp_int}" - ) - return False + # Extract token from "Bearer " format + token = authorization.replace("Bearer ", "") if authorization.startswith("Bearer ") else authorization - return True - except (ValueError, TypeError) as e: - logger.error(f"Invalid timestamp format: {timestamp}, error: {e}") - return False - - -def calculate_hmac_signature(secret_key: str, access_key: str, timestamp: str, request_body: str = "") -> str: - """ - Calculate HMAC-SHA256 signature + if not token: + logger.warning("Empty bearer token") + return False, None - Args: - secret_key: secret key - access_key: access key - timestamp: timestamp - request_body: request body (optional) - - Returns: - str: HMAC-SHA256 signature (hex string) - """ - string_to_sign = f"{access_key}{timestamp}{request_body}" - signature = hmac.new( - secret_key.encode("utf-8"), - string_to_sign.encode("utf-8"), - hashlib.sha256, - ).hexdigest() - return signature - - -def verify_aksk_signature( - access_key: str, timestamp: str, signature: str, request_body: str = "" -) -> bool: - """ - Validate AK/SK signature - - Args: - access_key: access key - timestamp: timestamp - signature: provided signature - request_body: request body (optional) - - Returns: - bool: whether signature is valid - """ + # Look up token in database try: - if not validate_timestamp(timestamp): - raise SignatureValidationError("Timestamp is invalid or expired") - - # TODO: get ak/sk according to tenant_id from DB - mock_access_key, mock_secret_key = get_aksk_config( - tenant_id="tenant_id") - - if access_key != mock_access_key: - logger.warning(f"Invalid access key: {access_key}") - return False - - expected_signature = calculate_hmac_signature( - mock_secret_key, access_key, timestamp, request_body - ) - - if not hmac.compare_digest(signature, expected_signature): - logger.warning( - f"Signature mismatch: expected={expected_signature}, provided={signature}" - ) - return False - - return True + token_info = get_token_by_access_key(token) + if token_info and token_info.get("delete_flag") != "Y": + logger.debug(f"Token validated successfully for user {token_info.get('user_id')}") + return True, token_info + else: + logger.warning(f"Invalid or inactive token: {token[:20]}...") + return False, None except Exception as e: - logger.error(f"Error during signature verification: {e}") - return False + logger.error(f"Error validating bearer token: {str(e)}") + return False, None -def extract_aksk_headers(headers: dict) -> Tuple[str, str, str]: +def get_user_and_tenant_by_access_key(access_key: str) -> Dict[str, str]: """ - Extract AK/SK related information from request headers + Get user_id and tenant_id from access_key by querying user_token_info_t and user_tenant_t. Args: - headers: request headers dictionary + access_key: The access key (API Key) from the Authorization header. Returns: - Tuple[str, str, str]: (access_key, timestamp, signature) + Dict containing user_id and tenant_id. Raises: - UnauthorizedError: when required headers are missing + UnauthorizedError: If the access key is not found or invalid. """ - - def get_header(headers: dict, name: str) -> Optional[str]: - for k, v in headers.items(): - if k.lower() == name.lower(): - return v - return None - - access_key = get_header(headers, "X-Access-Key") - timestamp = get_header(headers, "X-Timestamp") - signature = get_header(headers, "X-Signature") - if not access_key: - raise UnauthorizedError("Missing X-Access-Key header") - if not timestamp: - raise UnauthorizedError("Missing X-Timestamp header") - if not signature: - raise UnauthorizedError("Missing X-Signature header") - - return access_key, timestamp, signature - - -def validate_aksk_authentication(headers: dict, request_body: str = "") -> bool: - """ - Validate AK/SK authentication - - Args: - headers: request headers dictionary - request_body: request body (optional) - - Returns: - bool: whether authentication is successful - - Raises: - UnauthorizedError: when authentication fails - SignatureValidationError: when signature verification fails - """ - try: - access_key, timestamp, signature = extract_aksk_headers(headers) - - if not verify_aksk_signature(access_key, timestamp, signature, request_body): - raise SignatureValidationError("Invalid signature") - - return True - except (UnauthorizedError, SignatureValidationError, LimitExceededError) as e: - raise e - except Exception as e: - logger.error(f"Unexpected error during AK/SK authentication: {e}") - raise UnauthorizedError("Authentication failed") + raise UnauthorizedError("Invalid access key") + + # Query token from user_token_info_t + token_info = get_token_by_access_key(access_key) + if not token_info or token_info.get("delete_flag") == "Y": + raise UnauthorizedError("Invalid or inactive access key") + + user_id = token_info.get("user_id") + if not user_id: + raise UnauthorizedError("No user associated with this access key") + + # Query tenant from user_tenant_t + user_tenant_record = get_user_tenant_by_user_id(user_id) + if user_tenant_record and user_tenant_record.get("tenant_id"): + tenant_id = user_tenant_record["tenant_id"] + else: + tenant_id = DEFAULT_TENANT_ID + logger.warning(f"No tenant relationship found for user {user_id}, using default tenant") + + return { + "user_id": user_id, + "tenant_id": tenant_id, + "token_id": token_info.get("token_id") + } def get_supabase_client(): diff --git a/docker/sql/v1.8.0.3_0306_add_user_token_info.sql b/docker/sql/v1.8.0.3_0306_add_user_token_info.sql new file mode 100644 index 000000000..b8f731fbf --- /dev/null +++ b/docker/sql/v1.8.0.3_0306_add_user_token_info.sql @@ -0,0 +1,112 @@ +-- Migration: Add user_token_info_t and user_token_usage_log_t tables +-- Date: 2026-03-06 +-- Description: Create user token (AK/SK) management tables with audit fields + +-- Set search path to nexent schema +SET search_path TO nexent; + +-- Create the user_token_info_t table in the nexent schema +CREATE TABLE IF NOT EXISTS nexent.user_token_info_t ( + token_id SERIAL4 PRIMARY KEY NOT NULL, + access_key VARCHAR(100) NOT NULL, + user_id VARCHAR(100) NOT NULL, + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +ALTER TABLE "user_token_info_t" OWNER TO "root"; + +-- Add comment to the table +COMMENT ON TABLE nexent.user_token_info_t IS 'User token (AK/SK) information table'; + +-- Add comments to the columns +COMMENT ON COLUMN nexent.user_token_info_t.token_id IS 'Token ID, unique primary key'; +COMMENT ON COLUMN nexent.user_token_info_t.access_key IS 'Access Key (AK)'; +COMMENT ON COLUMN nexent.user_token_info_t.user_id IS 'User ID who owns this token'; +COMMENT ON COLUMN nexent.user_token_info_t.create_time IS 'Creation time, audit field'; +COMMENT ON COLUMN nexent.user_token_info_t.update_time IS 'Update time, audit field'; +COMMENT ON COLUMN nexent.user_token_info_t.created_by IS 'Creator ID, audit field'; +COMMENT ON COLUMN nexent.user_token_info_t.updated_by IS 'Last updater ID, audit field'; +COMMENT ON COLUMN nexent.user_token_info_t.delete_flag IS 'Soft delete flag, Y means deleted'; + +-- Create unique index on access_key to ensure uniqueness +CREATE UNIQUE INDEX IF NOT EXISTS idx_user_token_info_access_key ON nexent.user_token_info_t(access_key) WHERE delete_flag = 'N'; + +-- Create index on user_id for query performance +CREATE INDEX IF NOT EXISTS idx_user_token_info_user_id ON nexent.user_token_info_t(user_id) WHERE delete_flag = 'N'; + +-- Create a function to update the update_time column +CREATE OR REPLACE FUNCTION update_user_token_info_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Add comment to the function +COMMENT ON FUNCTION update_user_token_info_update_time() IS 'Function to update the update_time column when a record in user_token_info_t is updated'; + +-- Create a trigger to call the function before each update +DROP TRIGGER IF EXISTS update_user_token_info_update_time_trigger ON nexent.user_token_info_t; +CREATE TRIGGER update_user_token_info_update_time_trigger +BEFORE UPDATE ON nexent.user_token_info_t +FOR EACH ROW +EXECUTE FUNCTION update_user_token_info_update_time(); + +-- Add comment to the trigger +COMMENT ON TRIGGER update_user_token_info_update_time_trigger ON nexent.user_token_info_t IS 'Trigger to call update_user_token_info_update_time function before each update on user_token_info_t table'; + + +-- Create the user_token_usage_log_t table in the nexent schema +CREATE TABLE IF NOT EXISTS nexent.user_token_usage_log_t ( + token_usage_id SERIAL4 PRIMARY KEY NOT NULL, + token_id INT4 NOT NULL, + call_function_name VARCHAR(100), + related_id INT4, + meta_data JSONB, + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100) +); + +ALTER TABLE "user_token_usage_log_t" OWNER TO "root"; + +-- Add comment to the table +COMMENT ON TABLE nexent.user_token_usage_log_t IS 'User token usage log table'; + +-- Add comments to the columns +COMMENT ON COLUMN nexent.user_token_usage_log_t.token_usage_id IS 'Token usage log ID, unique primary key'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.token_id IS 'Foreign key to user_token_info_t.token_id'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.call_function_name IS 'API function name being called'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.related_id IS 'Related resource ID (e.g., conversation_id)'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.meta_data IS 'Additional metadata for this usage log entry, stored as JSON'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.create_time IS 'Creation time, audit field'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.created_by IS 'Creator ID, audit field'; + +-- Create index on token_id for query performance +CREATE INDEX IF NOT EXISTS idx_user_token_usage_log_token_id ON nexent.user_token_usage_log_t(token_id); + +-- Create index on call_function_name for query performance +CREATE INDEX IF NOT EXISTS idx_user_token_usage_log_function_name ON nexent.user_token_usage_log_t(call_function_name); + +-- Add foreign key constraint +ALTER TABLE nexent.user_token_usage_log_t +ADD CONSTRAINT fk_user_token_usage_log_token_id +FOREIGN KEY (token_id) +REFERENCES nexent.user_token_info_t(token_id) +ON DELETE CASCADE; + + +-- Migration: Remove partner_mapping_id_t table for northbound conversation ID mapping +-- Date: 2026-03-10 +-- Description: Remove the external-internal conversation ID mapping table as northbound APIs now use internal conversation IDs directly +-- Note: This table is no longer needed after refactoring northbound authentication logic + +-- Drop the partner_mapping_id_t table if it exists +DROP TABLE IF EXISTS nexent.partner_mapping_id_t CASCADE; + +-- Drop the associated sequence if it exists +DROP SEQUENCE IF EXISTS nexent.partner_mapping_id_t_id_seq; diff --git a/frontend/app/[locale]/users/components/UserProfileComp.tsx b/frontend/app/[locale]/users/components/UserProfileComp.tsx index 6d45b4db0..2a66bd89e 100644 --- a/frontend/app/[locale]/users/components/UserProfileComp.tsx +++ b/frontend/app/[locale]/users/components/UserProfileComp.tsx @@ -1,6 +1,6 @@ "use client"; -import React, { useState } from "react"; +import React, { useState, useEffect } from "react"; import { Button, Typography, @@ -25,6 +25,9 @@ import { Edit, Key, ChevronRight, + KeySquare, + KeyRound, + Copy, } from "lucide-react"; import { USER_ROLES } from "@/const/modelConfig"; import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; @@ -32,6 +35,12 @@ import { useAuthenticationContext } from "@/components/providers/AuthenticationP import { useGroupList } from "@/hooks/group/useGroupList"; import { useMemo } from "react"; import { DeleteAccountModal } from "@/components/auth/DeleteAccountModal"; +import log from "@/lib/logger"; +import { + getUserTokens, + deleteUserToken, + createUserToken, +} from "@/services/tokenService"; /** * UserProfileComp - User profile and account settings component @@ -77,6 +86,12 @@ export default function UserProfileComp() { const [isPasswordModalOpen, setIsPasswordModalOpen] = useState(false); const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); + // AK/SK state + const [akInfo, setAkInfo] = useState(null); + const [existingTokenIds, setExistingTokenIds] = useState([]); + const [isLoadingAkSk, setIsLoadingAkSk] = useState(false); + const [isGeneratingAkSk, setIsGeneratingAkSk] = useState(false); + // Form instances const [editForm] = Form.useForm(); const [passwordForm] = Form.useForm(); @@ -121,6 +136,58 @@ export default function UserProfileComp() { } }; + // Fetch AK/SK info on mount + useEffect(() => { + const fetchAkSkInfo = async () => { + if (!user?.id) return; + setIsLoadingAkSk(true); + try { + const tokens = await getUserTokens(user.id); + if (tokens.length > 0) { + setAkInfo(tokens[0].access_key); + setExistingTokenIds(tokens.map((t) => t.token_id)); + } + } catch (error) { + log.error("Failed to fetch AK/SK info:", error); + } finally { + setIsLoadingAkSk(false); + } + }; + + fetchAkSkInfo(); + }, [user?.id]); + + // Handle generate AK/SK: delete existing tokens first, then create a new one + const handleGenerateAkSk = async () => { + setIsGeneratingAkSk(true); + try { + for (const tokenId of existingTokenIds) { + await deleteUserToken(tokenId); + } + + const newToken = await createUserToken(); + setAkInfo(newToken.access_key); + setExistingTokenIds([newToken.token_id]); + antdMessage.success(t("profile.generateAkSkSuccess") || "Access key generated successfully"); + } catch (error) { + antdMessage.error(t("profile.generateAkSkFailed") || "Failed to generate access key"); + } finally { + setIsGeneratingAkSk(false); + } + }; + + // Handle copy AK to clipboard + const handleCopyAk = async () => { + if (akInfo) { + try { + await navigator.clipboard.writeText(akInfo); + antdMessage.success(t("profile.copyAkSuccess") || "Access key copied to clipboard"); + } catch (error) { + antdMessage.error(t("profile.copyAkFailed") || "Failed to copy access key"); + } + } + }; + // Open edit modal // const openEditModal = () => { // editForm.setFieldsValue({ @@ -272,7 +339,7 @@ export default function UserProfileComp() { >
- +
@@ -286,6 +353,86 @@ export default function UserProfileComp() {
+ {/* Generate Access Token Option */} +
{ + if (akInfo) { + Modal.confirm({ + title: t("profile.generateAkSkConfirmTitle") || "Generate New Access Key", + content: t("profile.generateAkSkConfirmContent") || "You already have an access key. Generating a new one will overwrite the existing key. Continue?", + okText: t("common.confirm") || "Confirm", + cancelText: t("common.cancel") || "Cancel", + onOk: handleGenerateAkSk, + okButtonProps: { loading: isGeneratingAkSk }, + }); + } else { + handleGenerateAkSk(); + } + }} + > +
+
+ +
+
+
+ {t("profile.generateAkSk") || "Generate Access Token"} +
+ {akInfo ? ( +
+ + {akInfo} + +
+ ) : ( +
+ {t("profile.generateAkSkDesc") || "Create or regenerate your API access key"} +
+ )} +
+
+ +
+ - - - - + + + + + + + + + + + {/* Data protection notice - only shown in full version */} @@ -207,7 +211,7 @@ function FeatureCard({ icon, title, description }: FeatureCardProps) { {icon}
-

+

{title}

From dcb6503badf550d64b5ac6a5aa42fd051994178d Mon Sep 17 00:00:00 2001 From: "XUYAQIDE\\xuyaq" Date: Tue, 10 Mar 2026 19:14:57 +0800 Subject: [PATCH 51/75] Bugfix: make auth prompt modal open in full version --- frontend/hooks/auth/useAuthenticationUI.ts | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/frontend/hooks/auth/useAuthenticationUI.ts b/frontend/hooks/auth/useAuthenticationUI.ts index c158b2d8d..8891790e6 100644 --- a/frontend/hooks/auth/useAuthenticationUI.ts +++ b/frontend/hooks/auth/useAuthenticationUI.ts @@ -80,16 +80,6 @@ export function useAuthenticationUI({ handleUnauthenticatedModalClose(); }, [handleUnauthenticatedModalClose]); - /** - * Check if current path is home page - * Home page paths: "/", "/zh", "/en" - */ - const isLocaleHomePath = (path?: string | null) => { - if (!path) return false; - const segments = path.split("/").filter(Boolean); - return segments.length <= 1; - }; - useEffect(() => { if (isSpeedMode) return; @@ -131,14 +121,7 @@ export function useAuthenticationUI({ if (isSessionExpiredModalOpen) return; if (isLoginModalOpen) return; if (isRegisterModalOpen) return; - // Skip if already on home page - if (isLocaleHomePath(pathname)) return; - - // For unauthenticated users accessing protected routes, show auth prompt - const effectivePath = getEffectiveRoutePath(pathname); - if (effectivePath !== "/") { - openAuthPromptModal(); - } + openAuthPromptModal(); }, [pathname, isAuthenticated, isSpeedMode, isAuthChecking, isSessionExpiredModalOpen, openAuthPromptModal]); From 174b0b069f3e12e58387987c349156f8053969d5 Mon Sep 17 00:00:00 2001 From: biansimeng Date: Tue, 10 Mar 2026 19:44:04 +0800 Subject: [PATCH 52/75] Add test case --- test/sdk/core/tools/test_exa_search_tool.py | 12 - .../sdk/core/tools/test_tavily_search_tool.py | 250 ++++++++++++++++++ 2 files changed, 250 insertions(+), 12 deletions(-) create mode 100644 test/sdk/core/tools/test_tavily_search_tool.py diff --git a/test/sdk/core/tools/test_exa_search_tool.py b/test/sdk/core/tools/test_exa_search_tool.py index ad5b15339..846fcb84b 100644 --- a/test/sdk/core/tools/test_exa_search_tool.py +++ b/test/sdk/core/tools/test_exa_search_tool.py @@ -5,28 +5,16 @@ from datetime import datetime # Create all necessary mocks -mock_tavily_client = MagicMock() -mock_tavily = MagicMock() -mock_tavily.TavilyClient = mock_tavily_client - mock_exa = MagicMock() mock_exa_client = MagicMock() mock_exa.Exa = mock_exa_client -mock_linkup = MagicMock() -mock_linkup_client = MagicMock() -mock_linkup.LinkupClient = mock_linkup_client -mock_linkup.LinkupSearchImageResult = MagicMock() -mock_linkup.LinkupSearchTextResult = MagicMock() - mock_aiohttp = MagicMock() mock_aiohttp.ClientSession = MagicMock() # Use module-level mocks module_mocks = { - 'tavily': mock_tavily, 'exa_py': mock_exa, - 'linkup': mock_linkup, 'aiohttp': mock_aiohttp } diff --git a/test/sdk/core/tools/test_tavily_search_tool.py b/test/sdk/core/tools/test_tavily_search_tool.py new file mode 100644 index 000000000..6d157b10e --- /dev/null +++ b/test/sdk/core/tools/test_tavily_search_tool.py @@ -0,0 +1,250 @@ +import pytest +from unittest.mock import MagicMock, patch +import json +import os +from datetime import datetime + +# Create all necessary mocks +mock_tavily_client = MagicMock() +mock_tavily = MagicMock() +mock_tavily.TavilyClient = mock_tavily_client + +mock_aiohttp = MagicMock() +mock_aiohttp.ClientSession = MagicMock() + +# Use module-level mocks +module_mocks = { + 'tavily': mock_tavily, + 'aiohttp': mock_aiohttp +} + +# Apply mocks +with patch.dict('sys.modules', module_mocks): + # Import all required modules + from sdk.nexent.core.utils.observer import MessageObserver, ProcessType + # Import target module + from sdk.nexent.core.tools.tavily_search_tool import TavilySearchTool + + +@pytest.fixture +def mock_observer(): + observer = MagicMock(spec=MessageObserver) + observer.lang = "en" + return observer + + +@pytest.fixture +def tavily_search_tool(mock_observer): + # Reset all mock objects + mock_tavily_client.reset_mock() + + tavily_api_key = "test_api_key" + with patch('tavily.TavilyClient', return_value=mock_tavily_client): + tool = TavilySearchTool( + tavily_api_key=tavily_api_key, + observer=mock_observer, + max_results=3, + image_filter=True + ) + + # Directly set a mock object for tool.tavily + tool.tavily = mock_tavily_client + + # Set environment variables + os.environ["DATA_PROCESS_SERVICE"] = "http://test-service" + tool.data_process_service = "http://test-service" + + return tool + + +def create_mock_tavily_search_result(count=3): + """Helper method to create mock Tavily search results""" + results = [] + for i in range(count): + result = { + "title": f"Test Title {i}", + "url": f"https://example.com/{i}", + "content": f"This is test content {i}", + "published_date": datetime.now().isoformat(), + "score": 0.9 - i * 0.1 + } + results.append(result) + + mock_response = { + "results": results, + "images": [f"https://example.com/image{i}.jpg" for i in range(count)] + } + return mock_response + + +def test_forward_with_results(tavily_search_tool, mock_observer): + """Test forward method with search results""" + # Configure mock + mock_results = create_mock_tavily_search_result(3) + mock_tavily_client.search.return_value = mock_results + + # Mock _filter_images method to prevent creating unawaited coroutines + with patch.object(tavily_search_tool, '_filter_images'): + # Call method + result = tavily_search_tool.forward("test query") + + # Print actual JSON structure to help with understanding + search_results = json.loads(result) + print(f"\nActual search result structure: {json.dumps(search_results[0], indent=2)}") + + # Assertions + mock_tavily_client.search.assert_called_once_with( + query="test query", + max_results=3, + include_images=True + ) + + # Check observer messages + mock_observer.add_message.assert_any_call("", ProcessType.TOOL, "Searching the web...") + mock_observer.add_message.assert_any_call("", ProcessType.CARD, + json.dumps([{"icon": "search", "text": "test query"}], + ensure_ascii=False)) + + # Verify search results were processed + assert len(search_results) == 3 + + # Check that the returned JSON structure contains expected fields + first_result = search_results[0] + assert "title" in first_result + assert first_result["title"] == "Test Title 0" + + # Check all keys to understand the actual structure + keys = first_result.keys() + print(f"\nAvailable keys in result: {keys}") + + # Check if text field exists + assert "text" in first_result + assert first_result["text"].startswith("This is test content") + + # If there's a cite_index field, verify it as well + if "cite_index" in first_result: + assert isinstance(first_result["cite_index"], int) + + +def test_forward_no_results(tavily_search_tool): + """Test forward method with no search results""" + # Configure empty results mock + mock_response = { + "results": [], + "images": [] + } + mock_tavily_client.search.return_value = mock_response + + # Call method and check for exception + with pytest.raises(Exception) as excinfo: + tavily_search_tool.forward("test query") + + assert 'No results found' in str(excinfo.value) + + +def test_forward_without_observer(tavily_search_tool): + """Test forward method without an observer""" + # Mock _filter_images method to prevent creating unawaited coroutines + with patch.object(tavily_search_tool, '_filter_images'), \ + patch.object(TavilySearchTool, 'forward', wraps=tavily_search_tool.forward) as wrapped_forward: + # Directly set observer to None + # Note: This is not recommended in production code, only for testing + wrapped_forward.__defaults__ = (None,) + + # Configure mock and call method + mock_results = create_mock_tavily_search_result(2) + mock_tavily_client.search.return_value = mock_results + + # Call method with parameters directly + result = wrapped_forward("test query") + + # Verify results were processed + search_results = json.loads(result) + assert len(search_results) == 2 + + # Verify Tavily search was called + mock_tavily_client.search.assert_called_with( + query="test query", + max_results=3, + include_images=True + ) + + +def test_chinese_language_observer(tavily_search_tool, mock_observer): + """Test Chinese language observer""" + # Set observer language to Chinese + mock_observer.lang = "zh" + + # Mock _filter_images method to prevent creating unawaited coroutines + with patch.object(tavily_search_tool, '_filter_images'): + # Configure mock + mock_results = create_mock_tavily_search_result(1) + mock_tavily_client.search.return_value = mock_results + + # Call method + tavily_search_tool.forward("测试查询") + + # Check Chinese running prompt + mock_observer.add_message.assert_any_call("", ProcessType.TOOL, "网络搜索中...") + + +def test_filter_images_success(tavily_search_tool, mock_observer): + """Test successful image filtering""" + # Set up test data + images_list = ["https://example.com/image1.jpg", "https://example.com/image2.jpg"] + + # Mock _filter_images method + with patch.object(tavily_search_tool, '_filter_images') as mock_filter: + # Configure mock + mock_results = create_mock_tavily_search_result(1) + mock_tavily_client.search.return_value = mock_results + + # Call forward method, which indirectly calls _filter_images + tavily_search_tool.forward("test query") + + # Verify _filter_images was called with correct parameters + mock_filter.assert_called_once() + # Extract the first argument of the call + called_images = mock_filter.call_args[0][0] + assert isinstance(called_images, list) + + +def test_filter_images_api_error(tavily_search_tool, mock_observer): + """Test image filtering API error handling""" + # Set up test data + images_list = ["https://example.com/image1.jpg"] + + # Send message directly to observer, simulating _filter_images behavior + tavily_search_tool._filter_images = lambda img_list, query: mock_observer.add_message( + "", ProcessType.PICTURE_WEB, json.dumps({"images_url": img_list}, ensure_ascii=False) + ) + + # Configure mock + mock_results = create_mock_tavily_search_result(1) + mock_tavily_client.search.return_value = mock_results + + # Call method + tavily_search_tool.forward("test query") + + # Verify observer was called with unfiltered images + mock_observer.add_message.assert_any_call("", ProcessType.PICTURE_WEB, + json.dumps({"images_url": ["https://example.com/image0.jpg"]}, + ensure_ascii=False)) + + +def test_image_filter_disabled(tavily_search_tool, mock_observer): + """Test behavior when image filtering is disabled""" + # Disable image filtering + tavily_search_tool.image_filter = False + + # Configure mock + mock_results = create_mock_tavily_search_result(1) + mock_tavily_client.search.return_value = mock_results + + # Call method + tavily_search_tool.forward("test query") + + # Verify images were sent to observer without filtering + expected_images = ["https://example.com/image0.jpg"] + mock_observer.add_message.assert_any_call("", ProcessType.PICTURE_WEB, + json.dumps({"images_url": expected_images}, ensure_ascii=False)) From 58ca3feeca65d1b4c1f757078caecff6cd0bb9eb Mon Sep 17 00:00:00 2001 From: panyehong <2655992392@qq.com> Date: Wed, 11 Mar 2026 11:52:49 +0800 Subject: [PATCH 53/75] =?UTF-8?q?=E2=9C=A8=20Feature:=20idata=20search=20t?= =?UTF-8?q?ool=20development=20#2666=20[Specification=20Details]=201.=20Ad?= =?UTF-8?q?d=20the=20idata=5Fsearch=20API=20and=20tool=20to=20the=20backen?= =?UTF-8?q?d.=20Call=20the=20API=20to=20retrieve=20the=20knowledge=20space?= =?UTF-8?q?=20and=20knowledge=20base.=20After=20selecting=20the=20knowledg?= =?UTF-8?q?e=20base,=20save=20the=20tool=20configuration.=202.=20Add=20tes?= =?UTF-8?q?t=20cases.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/config_app.py | 2 + backend/apps/idata_app.py | 109 ++ backend/consts/error_code.py | 14 + backend/services/idata_service.py | 359 ++++++ .../components/agentConfig/ToolManagement.tsx | 4 +- .../agentConfig/tool/ToolConfigModal.tsx | 242 +++- .../KnowledgeBaseSelectorModal.tsx | 4 +- frontend/components/tool-config/index.ts | 10 +- .../useKnowledgeBaseConfigChangeHandler.ts | 58 +- frontend/hooks/useKnowledgeBaseSelector.ts | 71 ++ frontend/services/api.ts | 4 + frontend/services/knowledgeBaseService.ts | 138 ++ sdk/nexent/core/tools/__init__.py | 2 + sdk/nexent/core/tools/idata_search_tool.py | 355 ++++++ sdk/nexent/core/utils/tools_common_message.py | 10 +- test/backend/app/test_config_app.py | 13 +- test/backend/app/test_idata_app.py | 545 ++++++++ test/backend/consts/test_error_code.py | 82 +- test/backend/database/test_tool_db.py | 34 +- test/backend/services/test_idata_service.py | 976 +++++++++++++++ test/sdk/core/tools/test_idata_search_tool.py | 1107 +++++++++++++++++ 21 files changed, 4103 insertions(+), 36 deletions(-) create mode 100644 backend/apps/idata_app.py create mode 100644 backend/services/idata_service.py create mode 100644 sdk/nexent/core/tools/idata_search_tool.py create mode 100644 test/backend/app/test_idata_app.py create mode 100644 test/backend/services/test_idata_service.py create mode 100644 test/sdk/core/tools/test_idata_search_tool.py diff --git a/backend/apps/config_app.py b/backend/apps/config_app.py index fb6a0a4f0..58e2b008b 100644 --- a/backend/apps/config_app.py +++ b/backend/apps/config_app.py @@ -6,6 +6,7 @@ from apps.datamate_app import router as datamate_router from apps.vectordatabase_app import router as vectordatabase_router from apps.dify_app import router as dify_router +from apps.idata_app import router as idata_router from apps.file_management_app import file_management_config_router as file_manager_router from apps.image_app import router as proxy_router from apps.knowledge_summary_app import router as summary_router @@ -39,6 +40,7 @@ app.include_router(proxy_router) app.include_router(tool_config_router) app.include_router(dify_router) +app.include_router(idata_router) # Choose user management router based on IS_SPEED_MODE if IS_SPEED_MODE: diff --git a/backend/apps/idata_app.py b/backend/apps/idata_app.py new file mode 100644 index 000000000..278c1b60f --- /dev/null +++ b/backend/apps/idata_app.py @@ -0,0 +1,109 @@ +""" +iData App Layer +FastAPI endpoints for iData knowledge space operations. + +This module provides API endpoints to interact with iData's API, +including fetching knowledge spaces and transforming responses to a format +compatible with the frontend. +""" +import logging +from http import HTTPStatus + +from fastapi import APIRouter, Query +from fastapi.responses import JSONResponse + +from consts.error_code import ErrorCode +from consts.exceptions import AppException +from services.idata_service import ( + fetch_idata_knowledge_spaces_impl, + fetch_idata_datasets_impl, +) + +router = APIRouter(prefix="/idata") +logger = logging.getLogger("idata_app") + + +@router.get("/knowledge-space") +async def fetch_idata_knowledge_spaces_api( + idata_api_base: str = Query(..., description="iData API base URL"), + api_key: str = Query(..., description="iData API key"), + user_id: str = Query(..., description="iData user ID"), +): + """ + Fetch knowledge spaces from iData API. + + Returns knowledge spaces in a format with id and name for frontend compatibility. + """ + try: + # Normalize URL by removing trailing slash + idata_api_base = idata_api_base.rstrip('/') + except Exception as e: + logger.error(f"Invalid iData configuration: {e}") + raise AppException( + ErrorCode.IDATA_CONFIG_INVALID, + f"Invalid URL format: {str(e)}" + ) + + try: + result = fetch_idata_knowledge_spaces_impl( + idata_api_base=idata_api_base, + api_key=api_key, + user_id=user_id, + ) + return JSONResponse( + status_code=HTTPStatus.OK, + content=result + ) + except AppException: + # Re-raise AppException to be handled by global middleware + raise + except Exception as e: + logger.error(f"Failed to fetch iData knowledge spaces: {e}") + raise AppException( + ErrorCode.IDATA_SERVICE_ERROR, + f"Failed to fetch iData knowledge spaces: {str(e)}" + ) + + +@router.get("/datasets") +async def fetch_idata_datasets_api( + idata_api_base: str = Query(..., description="iData API base URL"), + api_key: str = Query(..., description="iData API key"), + user_id: str = Query(..., description="iData user ID"), + knowledge_space_id: str = Query(..., description="Knowledge space ID"), +): + """ + Fetch datasets (knowledge bases) from iData API. + + Returns knowledge bases in a format consistent with DataMate for frontend compatibility. + """ + try: + # Normalize URL by removing trailing slash + idata_api_base = idata_api_base.rstrip('/') + except Exception as e: + logger.error(f"Invalid iData configuration: {e}") + raise AppException( + ErrorCode.IDATA_CONFIG_INVALID, + f"Invalid URL format: {str(e)}" + ) + + try: + result = fetch_idata_datasets_impl( + idata_api_base=idata_api_base, + api_key=api_key, + user_id=user_id, + knowledge_space_id=knowledge_space_id, + ) + return JSONResponse( + status_code=HTTPStatus.OK, + content=result + ) + except AppException: + # Re-raise AppException to be handled by global middleware + raise + except Exception as e: + logger.error(f"Failed to fetch iData datasets: {e}") + raise AppException( + ErrorCode.IDATA_SERVICE_ERROR, + f"Failed to fetch iData datasets: {str(e)}" + ) diff --git a/backend/consts/error_code.py b/backend/consts/error_code.py index 7affd2b2f..072243de4 100644 --- a/backend/consts/error_code.py +++ b/backend/consts/error_code.py @@ -164,6 +164,14 @@ class ErrorCode(Enum): # 03 - ME Service ME_CONNECTION_FAILED = "130301" # ME service connection failed + # 04 - iData Service + IDATA_SERVICE_ERROR = "130401" # iData service error + IDATA_CONFIG_INVALID = "130402" # Invalid iData configuration + IDATA_CONNECTION_ERROR = "130403" # iData connection error + IDATA_AUTH_ERROR = "130404" # iData auth error + IDATA_RATE_LIMIT = "130405" # iData rate limit + IDATA_RESPONSE_ERROR = "130406" # iData response error + # ==================== 14 Northbound / 北向接口 ==================== # 01 - Request NORTHBOUND_REQUEST_FAILED = "140101" # Northbound request failed @@ -223,4 +231,10 @@ class ErrorCode(Enum): ErrorCode.DIFY_CONNECTION_ERROR: 502, ErrorCode.DIFY_RESPONSE_ERROR: 502, ErrorCode.DIFY_RATE_LIMIT: 429, + # iData (module 13) + ErrorCode.IDATA_CONFIG_INVALID: 400, + ErrorCode.IDATA_AUTH_ERROR: 401, + ErrorCode.IDATA_CONNECTION_ERROR: 502, + ErrorCode.IDATA_RESPONSE_ERROR: 502, + ErrorCode.IDATA_RATE_LIMIT: 429, } diff --git a/backend/services/idata_service.py b/backend/services/idata_service.py new file mode 100644 index 000000000..691130dc0 --- /dev/null +++ b/backend/services/idata_service.py @@ -0,0 +1,359 @@ +""" +iData Service Layer +Handles API calls to iData for knowledge space operations. + +This service layer provides functionality to interact with iData's API, +including fetching knowledge spaces and transforming responses +to a format compatible with the frontend. +""" +import json +import logging +from typing import Any, Dict, List + +import httpx + +from consts.error_code import ErrorCode +from consts.exceptions import AppException +from nexent.utils.http_client_manager import http_client_manager + +logger = logging.getLogger("idata_service") + + +def _validate_idata_base_params( + idata_api_base: str, + api_key: str, + user_id: str, +) -> None: + """ + Validate common iData API parameters. + + Args: + idata_api_base: iData API base URL + api_key: iData API key + user_id: iData user ID + + Raises: + AppException: If any parameter is invalid + """ + if not idata_api_base or not isinstance(idata_api_base, str): + raise AppException( + ErrorCode.IDATA_CONFIG_INVALID, + "iData API URL is required and must be a non-empty string" + ) + + if not (idata_api_base.startswith("http://") or idata_api_base.startswith("https://")): + raise AppException( + ErrorCode.IDATA_CONFIG_INVALID, + "iData API URL must start with http:// or https://" + ) + + if not api_key or not isinstance(api_key, str): + raise AppException( + ErrorCode.IDATA_CONFIG_INVALID, + "iData API key is required and must be a non-empty string" + ) + + if not user_id or not isinstance(user_id, str): + raise AppException( + ErrorCode.IDATA_CONFIG_INVALID, + "iData user ID is required and must be a non-empty string" + ) + + +def _normalize_api_base(idata_api_base: str) -> str: + """ + Normalize API base URL by removing trailing slash. + + Args: + idata_api_base: iData API base URL + + Returns: + Normalized API base URL + """ + return idata_api_base.rstrip("/") + + +def _make_idata_request( + api_base: str, + url: str, + headers: Dict[str, str], + request_body: Dict[str, Any], +) -> Dict[str, Any]: + """ + Make HTTP POST request to iData API and handle common errors. + + Args: + api_base: Normalized API base URL + url: Full request URL + headers: Request headers + request_body: Request body as dictionary + + Returns: + Parsed JSON response + + Raises: + AppException: If request fails or response is invalid + """ + logger.info(f"Making iData API request to: {url}") + + try: + # Use shared HttpClientManager for connection pooling + # Note: ssl_verify is set to False as per requirement (self-signed certificate) + client = http_client_manager.get_sync_client( + base_url=api_base, + timeout=10.0, + verify_ssl=False + ) + response = client.post(url, headers=headers, json=request_body) + response.raise_for_status() + + return response.json() + + except httpx.RequestError as e: + logger.error(f"iData API request failed: {str(e)}") + raise AppException( + ErrorCode.IDATA_CONNECTION_ERROR, + f"iData API request failed: {str(e)}" + ) + except httpx.HTTPStatusError as e: + logger.error( + f"iData API HTTP error: {str(e)}, status_code: {e.response.status_code}") + # Map HTTP status to specific error code + if e.response.status_code == 401: + logger.error("Raising IDATA_AUTH_ERROR for 401 error") + raise AppException( + ErrorCode.IDATA_AUTH_ERROR, + f"iData authentication failed: {str(e)}" + ) + elif e.response.status_code == 403: + logger.error("Raising IDATA_AUTH_ERROR for 403 error") + raise AppException( + ErrorCode.IDATA_AUTH_ERROR, + f"iData access forbidden: {str(e)}" + ) + elif e.response.status_code == 429: + logger.error("Raising IDATA_RATE_LIMIT for 429 error") + raise AppException( + ErrorCode.IDATA_RATE_LIMIT, + f"iData API rate limit exceeded: {str(e)}" + ) + else: + logger.error( + f"Raising IDATA_SERVICE_ERROR for status {e.response.status_code}") + raise AppException( + ErrorCode.IDATA_SERVICE_ERROR, + f"iData API HTTP error {e.response.status_code}: {str(e)}" + ) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse iData API response: {str(e)}") + raise AppException( + ErrorCode.IDATA_RESPONSE_ERROR, + f"Failed to parse iData API response: {str(e)}" + ) + + +def _parse_idata_response(result: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Parse iData API response and validate format. + + Args: + result: Parsed JSON response from iData API + + Returns: + List of data items from response + + Raises: + AppException: If response format is invalid + """ + # Expected format: {"code": "1", "msg": "...", "data": [...], "msgParams": null} + code = result.get("code", "") + if code != "1": + msg = result.get("msg", "Unknown error") + logger.error( + f"iData API returned error code: {code}, message: {msg}") + raise AppException( + ErrorCode.IDATA_SERVICE_ERROR, + f"iData API error: {msg}" + ) + + data = result.get("data", []) + if not isinstance(data, list): + logger.error( + f"Unexpected iData API response format: data is not a list") + raise AppException( + ErrorCode.IDATA_RESPONSE_ERROR, + "Unexpected iData API response format: data is not a list" + ) + + return data + + +def fetch_idata_knowledge_spaces_impl( + idata_api_base: str, + api_key: str, + user_id: str, +) -> List[Dict[str, str]]: + """ + Fetch knowledge spaces from iData API. + + Args: + idata_api_base: iData API base URL + api_key: iData API key with Bearer token + user_id: iData user ID + + Returns: + List of dictionaries containing knowledge spaces with id and name: + [ + { + "id": "6cbf949946bf4b769c073259406b04f8", + "name": "test1" + }, + ... + ] + + Raises: + AppException: If API request fails or response is invalid + """ + # Validate inputs + _validate_idata_base_params(idata_api_base, api_key, user_id) + + # Normalize API base URL + api_base = _normalize_api_base(idata_api_base) + + # Build request URL + url = f"{api_base}/apiaccess/modelmate/north/machine/v1/knowledgeSpaces/query" + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + # Request body + request_body = { + "userId": user_id + } + + # Make request and parse response + result = _make_idata_request(api_base, url, headers, request_body) + data = _parse_idata_response(result) + + # Extract id and name from each knowledge space + knowledge_spaces = [] + for item in data: + if not isinstance(item, dict): + continue + + space_id = item.get("id") + space_name = item.get("name") + + if space_id and space_name: + knowledge_spaces.append({ + "id": str(space_id), + "name": str(space_name) + }) + + return knowledge_spaces + + +def fetch_idata_datasets_impl( + idata_api_base: str, + api_key: str, + user_id: str, + knowledge_space_id: str, +) -> Dict[str, Any]: + """ + Fetch datasets (knowledge bases) from iData API and transform to DataMate-compatible format. + + Args: + idata_api_base: iData API base URL + api_key: iData API key with Bearer token + user_id: iData user ID + knowledge_space_id: Knowledge space ID + + Returns: + Dictionary containing knowledge bases in DataMate-compatible format: + { + "indices": ["dataset_id_1", "dataset_id_2", ...], + "count": 2, + "indices_info": [ + { + "name": "dataset_id_1", + "display_name": "知识库名称", + "stats": { + "base_info": { + "doc_count": 10, + "process_source": "iData" + } + } + }, + ... + ] + } + + Raises: + AppException: If API request fails or response is invalid + """ + # Validate inputs + _validate_idata_base_params(idata_api_base, api_key, user_id) + + if not knowledge_space_id or not isinstance(knowledge_space_id, str): + raise AppException( + ErrorCode.IDATA_CONFIG_INVALID, + "Knowledge space ID is required and must be a non-empty string" + ) + + # Normalize API base URL + api_base = _normalize_api_base(idata_api_base) + + # Build request URL + url = f"{api_base}/apiaccess/modelmate/north/machine/v1/knowledgeBases/query" + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + # Request body + request_body = { + "userId": user_id, + "knowledgeSpaceId": knowledge_space_id + } + + # Make request and parse response + result = _make_idata_request(api_base, url, headers, request_body) + data = _parse_idata_response(result) + + # Transform to DataMate-compatible format + indices = [] + indices_info = [] + + for knowledge_base in data: + if not isinstance(knowledge_base, dict): + continue + + kb_id = knowledge_base.get("id", "") + kb_name = knowledge_base.get("name", "") + file_count = knowledge_base.get("fileCount", 0) + + if not kb_id: + continue + + indices.append(kb_id) + + # Create indices_info entry (compatible with DataMate format) + indices_info.append({ + "name": kb_id, + "display_name": kb_name, + "stats": { + "base_info": { + "doc_count": file_count, + "process_source": "iData" + } + } + }) + + return { + "indices": indices, + "count": len(indices), + "indices_info": indices_info + } diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index f407243e7..850e7095a 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -26,6 +26,7 @@ const TOOLS_REQUIRING_KB_SELECTION = [ "knowledge_base_search", "dify_search", "datamate_search", + "idata_search", ]; // Tool types that require Embedding model @@ -40,10 +41,11 @@ const TOOLS_REQUIRING_VLM = [ function getToolKbType( toolName: string -): "knowledge_base_search" | "dify_search" | "datamate_search" | null { +): "knowledge_base_search" | "dify_search" | "datamate_search" | "idata_search" | null { if (!TOOLS_REQUIRING_KB_SELECTION.includes(toolName)) return null; if (toolName === "dify_search") return "dify_search"; if (toolName === "datamate_search") return "datamate_search"; + if (toolName === "idata_search") return "idata_search"; return "knowledge_base_search"; } diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index c5884f32b..fc927d51d 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -27,6 +27,7 @@ import { useConfig } from "@/hooks/useConfig"; import { useKnowledgeBasesForToolConfig } from "@/hooks/useKnowledgeBaseSelector"; import { useKnowledgeBaseConfigChangeHandler } from "@/hooks/useKnowledgeBaseConfigChangeHandler"; import { API_ENDPOINTS } from "@/services/api"; +import knowledgeBaseService from "@/services/knowledgeBaseService"; import log from "@/lib/logger"; export interface ToolConfigModalProps { @@ -45,6 +46,7 @@ const TOOLS_REQUIRING_KB_SELECTION = [ "knowledge_base_search", "dify_search", "datamate_search", + "idata_search", ]; export default function ToolConfigModal({ @@ -91,6 +93,26 @@ export default function ToolConfigModal({ apiKey: "", }); + // iData configuration state + const [idataConfig, setIdataConfig] = useState<{ + serverUrl: string; + apiKey: string; + userId: string; + knowledgeSpaceId: string; + }>({ + serverUrl: "", + apiKey: "", + userId: "", + knowledgeSpaceId: "", + }); + + // iData knowledge spaces state + const [idataKnowledgeSpaces, setIdataKnowledgeSpaces] = useState< + Array<{ id: string; name: string }> + >([]); + const [idataKnowledgeSpacesLoading, setIdataKnowledgeSpacesLoading] = + useState(false); + // DataMate URL from knowledge base configuration const [knowledgeBaseDataMateUrl, setKnowledgeBaseDataMateUrl] = useState(""); @@ -117,11 +139,13 @@ export default function ToolConfigModal({ | "knowledge_base_search" | "dify_search" | "datamate_search" + | "idata_search" | null => { if (!toolRequiresKbSelection) return null; const name = tool?.name; if (name === "dify_search") return "dify_search"; if (name === "datamate_search") return "datamate_search"; + if (name === "idata_search") return "idata_search"; return "knowledge_base_search"; }, [tool?.name, toolRequiresKbSelection]); @@ -147,6 +171,46 @@ export default function ToolConfigModal({ } }, [toolKbType, difyServerUrlParam, difyApiKeyParam]); + // Get iData configuration from initial params + const idataServerUrlParam = useMemo(() => { + return currentParams.find((param) => param.name === "server_url"); + }, [currentParams]); + + const idataApiKeyParam = useMemo(() => { + return currentParams.find((param) => param.name === "api_key"); + }, [currentParams]); + + const idataUserIdParam = useMemo(() => { + return currentParams.find((param) => param.name === "user_id"); + }, [currentParams]); + + const idataKnowledgeSpaceIdParam = useMemo(() => { + return currentParams.find((param) => param.name === "knowledge_space_id"); + }, [currentParams]); + + // Initialize iData config from params + useEffect(() => { + if (toolKbType === "idata_search") { + const serverUrl = idataServerUrlParam?.value || ""; + const apiKey = idataApiKeyParam?.value || ""; + const userId = idataUserIdParam?.value || ""; + const knowledgeSpaceId = idataKnowledgeSpaceIdParam?.value || ""; + + setIdataConfig({ + serverUrl, + apiKey, + userId, + knowledgeSpaceId, + }); + } + }, [ + toolKbType, + idataServerUrlParam, + idataApiKeyParam, + idataUserIdParam, + idataKnowledgeSpaceIdParam, + ]); + // Fetch knowledge bases for tool config based on tool type (now uses React Query caching) // For datamate_search, use the server_url from the form as config const datamateServerUrl = useMemo(() => { @@ -157,6 +221,40 @@ export default function ToolConfigModal({ return ""; }, [toolKbType, currentParams]); + // Fetch iData knowledge spaces when config is available + useEffect(() => { + if ( + toolKbType === "idata_search" && + idataConfig.serverUrl && + idataConfig.apiKey && + idataConfig.userId + ) { + setIdataKnowledgeSpacesLoading(true); + knowledgeBaseService + .getIdataKnowledgeSpaces( + idataConfig.serverUrl, + idataConfig.apiKey, + idataConfig.userId + ) + .then((spaces) => { + setIdataKnowledgeSpaces(spaces); + setIdataKnowledgeSpacesLoading(false); + }) + .catch((error) => { + log.error("Failed to fetch iData knowledge spaces:", error); + setIdataKnowledgeSpaces([]); + setIdataKnowledgeSpacesLoading(false); + }); + } else if (toolKbType === "idata_search") { + setIdataKnowledgeSpaces([]); + } + }, [ + toolKbType, + idataConfig.serverUrl, + idataConfig.apiKey, + idataConfig.userId, + ]); + const { data: knowledgeBases = [], isLoading: kbLoading, @@ -168,7 +266,19 @@ export default function ToolConfigModal({ ? difyConfig : toolKbType === "datamate_search" ? { serverUrl: datamateServerUrl } - : undefined + : toolKbType === "idata_search" + ? idataConfig.serverUrl && + idataConfig.apiKey && + idataConfig.userId && + idataConfig.knowledgeSpaceId + ? { + serverUrl: idataConfig.serverUrl, + apiKey: idataConfig.apiKey, + userId: idataConfig.userId, + knowledgeSpaceId: idataConfig.knowledgeSpaceId, + } + : undefined + : undefined ); // Handle config change: clear knowledge base selection and refetch @@ -210,10 +320,92 @@ export default function ToolConfigModal({ ? difyConfig : toolKbType === "datamate_search" ? { serverUrl: datamateServerUrl } - : undefined, + : toolKbType === "idata_search" + ? { + serverUrl: idataConfig.serverUrl, + apiKey: idataConfig.apiKey, + userId: idataConfig.userId, + } + : undefined, onConfigChange: handleKbConfigChange, }); + // Handle iData knowledge space ID change: clear knowledge base selection and refetch + const prevKnowledgeSpaceIdRef = useRef(""); + useEffect(() => { + if ( + toolKbType === "idata_search" && + idataConfig.knowledgeSpaceId && + idataConfig.serverUrl && + idataConfig.apiKey && + idataConfig.userId + ) { + // Only trigger if knowledge space ID actually changed + // Skip if this is the initial load (prevKnowledgeSpaceIdRef is empty and we have a value from initialParams) + if (prevKnowledgeSpaceIdRef.current === idataConfig.knowledgeSpaceId) { + return; + } + + // If prevKnowledgeSpaceIdRef is empty, this is likely the initial load + // Don't clear dataset_ids on initial load, only when space ID actually changes + if (prevKnowledgeSpaceIdRef.current === "") { + // This is initial load, just update the ref without clearing + prevKnowledgeSpaceIdRef.current = idataConfig.knowledgeSpaceId; + return; + } + + // Update ref + prevKnowledgeSpaceIdRef.current = idataConfig.knowledgeSpaceId; + + // Clear previous knowledge base selection when space ID changes + setSelectedKbIds([]); + setSelectedKbDisplayNames([]); + + // Clear form value for dataset_ids field + const kbFieldIndex = currentParams.findIndex( + (p) => p.name === "dataset_ids" + ); + if (kbFieldIndex >= 0) { + form.setFieldValue(`param_${kbFieldIndex}`, []); + const updatedParams = [...currentParams]; + updatedParams[kbFieldIndex] = { + ...updatedParams[kbFieldIndex], + value: [], + }; + setCurrentParams(updatedParams); + } + + // Refetch knowledge bases with new space ID + refetchKnowledgeBases(); + } else if (toolKbType === "idata_search") { + // Reset ref when config is cleared + prevKnowledgeSpaceIdRef.current = ""; + } + }, [ + toolKbType, + idataConfig.knowledgeSpaceId, + idataConfig.serverUrl, + idataConfig.apiKey, + idataConfig.userId, + refetchKnowledgeBases, + currentParams, + form, + ]); + + // Reset prevKnowledgeSpaceIdRef when modal opens/closes + useEffect(() => { + if (!isOpen) { + // Reset ref when modal closes + prevKnowledgeSpaceIdRef.current = ""; + } else if (isOpen && toolKbType === "idata_search") { + // Initialize ref with current knowledgeSpaceId when modal opens + // This prevents clearing dataset_ids on initial load + if (idataConfig.knowledgeSpaceId) { + prevKnowledgeSpaceIdRef.current = idataConfig.knowledgeSpaceId; + } + } + }, [isOpen, toolKbType, idataConfig.knowledgeSpaceId]); + // Get current embedding model from config for model matching const currentEmbeddingModel = useMemo(() => { try { @@ -847,7 +1039,8 @@ export default function ToolConfigModal({ const getToolType = (): | "knowledge_base_search" | "dify_search" - | "datamate_search" => { + | "datamate_search" + | "idata_search" => { return toolKbType || "knowledge_base_search"; }; @@ -984,13 +1177,47 @@ export default function ToolConfigModal({ ); const renderParamInput = (param: ToolParam, index: number) => { + // Get field name for form + const fieldName = `param_${index}`; + // Get options from frontend configuration based on tool name and parameter name const options = getToolParamOptions(tool.name, param.name); // Determine if this parameter should be rendered as a select dropdown const isSelectType = options && options.length > 0; + // Special handling for iData knowledge_space_id parameter + const isIdataKnowledgeSpaceId = + toolKbType === "idata_search" && param.name === "knowledge_space_id"; + const inputComponent = (() => { + // Handle iData knowledge space ID selector + if (isIdataKnowledgeSpaceId) { + const currentValue = form.getFieldValue(fieldName); + return ( + {/* Group permission dropdown - second position */} diff --git a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseEditModal.tsx b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseEditModal.tsx index 9baf3a95d..360eb9efd 100644 --- a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseEditModal.tsx +++ b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseEditModal.tsx @@ -1,6 +1,6 @@ "use client"; -import React, { useState, useRef } from "react"; +import React, { useState, useRef, useEffect } from "react"; import { useTranslation } from "react-i18next"; import { Modal, Form, Input, Select, message } from "antd"; import { useGroupList } from "@/hooks/group/useGroupList"; @@ -35,6 +35,9 @@ export function KnowledgeBaseEditModal({ // Store original name for comparison const originalNameRef = useRef(""); + // Track current permission value for conditional logic + const [currentPermission, setCurrentPermission] = useState("READ_ONLY"); + // Fetch groups for group selection const { data: groupData } = useGroupList(tenantId); const groups = groupData?.groups || []; @@ -42,15 +45,18 @@ export function KnowledgeBaseEditModal({ // Reset form and states when knowledge base changes React.useEffect(() => { if (knowledgeBase && open) { + const permission = knowledgeBase.ingroup_permission || "READ_ONLY"; form.setFieldsValue({ knowledge_name: knowledgeBase.name, - ingroup_permission: knowledgeBase.ingroup_permission || "READ_ONLY", - group_ids: knowledgeBase.group_ids || [], + ingroup_permission: permission, + group_ids: permission === "PRIVATE" ? [] : (knowledgeBase.group_ids || []), }); // Store original name for comparison originalNameRef.current = knowledgeBase.name; // Reset error state setNameError(null); + // Set current permission + setCurrentPermission(permission); } }, [knowledgeBase, open, form]); @@ -90,10 +96,13 @@ export function KnowledgeBaseEditModal({ return; // Error message is displayed via Form.Item help } + // Ensure group_ids is empty when permission is PRIVATE + const groupIds = values.ingroup_permission === "PRIVATE" ? [] : values.group_ids; + await knowledgeBaseService.updateKnowledgeBase(knowledgeBase.id, { knowledge_name: values.knowledge_name, ingroup_permission: values.ingroup_permission, - group_ids: values.group_ids, + group_ids: groupIds, }); message.success(t("tenantResources.knowledgeBase.updated")); @@ -103,7 +112,7 @@ export function KnowledgeBaseEditModal({ ...knowledgeBase, name: values.knowledge_name, ingroup_permission: values.ingroup_permission, - group_ids: values.group_ids, + group_ids: groupIds, }; // Trigger knowledge base list refresh to seamlessly update UI @@ -119,6 +128,17 @@ export function KnowledgeBaseEditModal({ } }; + // Handle permission change - clear group_ids when PRIVATE is selected + const handlePermissionChange = (value: string) => { + setCurrentPermission(value); + if (value === "PRIVATE") { + form.setFieldsValue({ group_ids: [] }); + } + }; + + // Check if group select should be disabled + const isGroupSelectDisabled = currentPermission === "PRIVATE"; + return ( ({ label: group.group_name, value: group.group_id, }))} + disabled={isGroupSelectDisabled} /> diff --git a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx index 4a03650fe..d5ec5cdb7 100644 --- a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx +++ b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx @@ -589,16 +589,17 @@ const KnowledgeBaseList: React.FC = ({ )} - {/* User group tags */} + {/* User group tags - only show when not PRIVATE */} - {getGroupNames(kb.group_ids).map((groupName, idx) => ( - - {groupName} - - ))} + {kb.ingroup_permission !== "PRIVATE" && + getGroupNames(kb.group_ids).map((groupName, idx) => ( + + {groupName} + + ))} )} diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index a58bba8a5..17bbe6a69 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -545,7 +545,7 @@ "knowledgeBase.summary.notGenerated": "Knowledge base summary was not generated, please change model configuration and retry", "knowledgeBase.name.new": "new_base", "knowledgeBase.message.getDocumentsFailed": "Failed to get documents", - "knowledgeBase.create.permission.groupPlaceholder": "User groups of this knowledge base", + "knowledgeBase.create.permission.groupPlaceholder": "No user group", "knowledgeBase.ingroup.permission.EDIT": "In Group Read/Write", "knowledgeBase.ingroup.permission.READ_ONLY": "In Group Read Only", "knowledgeBase.ingroup.permission.PRIVATE": "Personal Private", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 2f50f53f0..9388c9a49 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -547,7 +547,7 @@ "knowledgeBase.summary.notGenerated": "未生成知识库总结,请更换模型配置重试", "knowledgeBase.name.new": "新知识库", "knowledgeBase.message.getDocumentsFailed": "获取文档列表失败", - "knowledgeBase.create.permission.groupPlaceholder": "该知识库所属用户组", + "knowledgeBase.create.permission.groupPlaceholder": "无所属用户组", "knowledgeBase.ingroup.permission.EDIT": "同组可编辑", "knowledgeBase.ingroup.permission.READ_ONLY": "同组只读", "knowledgeBase.ingroup.permission.PRIVATE": "私有", From 8fb4a76326b1298c761b1d3dc9739104378a0d9c Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 5 Mar 2026 19:48:08 +0800 Subject: [PATCH 61/75] Delete unused code --- frontend/app/[locale]/chat/internal/chatInterface.tsx | 11 ----------- frontend/components/auth/index.ts | 8 -------- 2 files changed, 19 deletions(-) delete mode 100644 frontend/components/auth/index.ts diff --git a/frontend/app/[locale]/chat/internal/chatInterface.tsx b/frontend/app/[locale]/chat/internal/chatInterface.tsx index 2bbdc7ff3..31b78649d 100644 --- a/frontend/app/[locale]/chat/internal/chatInterface.tsx +++ b/frontend/app/[locale]/chat/internal/chatInterface.tsx @@ -30,8 +30,6 @@ import { createMessageAttachments, cleanupAttachmentUrls, } from "@/app/chat/internal/chatPreprocess"; -import { Tooltip, TooltipProvider } from "@/components/ui/tooltip"; - import { ConversationListItem, ApiConversationDetail } from "@/types/chat"; import { ChatMessageType } from "@/types/chat"; import { handleStreamResponse } from "@/app/chat/streaming/chatStreamHandler"; @@ -1435,15 +1433,6 @@ export function ChatInterface() {

- - -
- - ); } diff --git a/frontend/components/auth/index.ts b/frontend/components/auth/index.ts deleted file mode 100644 index b54ca123a..000000000 --- a/frontend/components/auth/index.ts +++ /dev/null @@ -1,8 +0,0 @@ -/** - * Export all authentication related components - */ - -export * from "./avatarDropdown"; -export * from "./loginModal"; -export * from "./registerModal"; -export * from "./DeleteAccountModal"; From d0030b6a8c3542b656d204bd8bde46043fb3b85c Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 12 Mar 2026 10:55:48 +0800 Subject: [PATCH 62/75] solve conflict --- .../KnowledgeBaseSelectorModal.tsx | 36 ++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx index ab695d869..d7ca0e72f 100644 --- a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx +++ b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx @@ -18,12 +18,40 @@ import { } from "@ant-design/icons"; import { KnowledgeBase } from "@/types/knowledgeBase"; -import { - KnowledgeBaseSelectorProps, - getKnowledgeBaseSourcesForTool, -} from "./index"; import { KB_LAYOUT, KB_TAG_VARIANTS } from "@/const/knowledgeBaseLayout"; +interface KnowledgeBaseSelectorProps { + isOpen: boolean; + onClose: () => void; + onConfirm: (selectedKnowledgeBases: KnowledgeBase[]) => void; + selectedIds: string[]; + toolType: "knowledge_base_search" | "dify_search" | "datamate_search"; + title?: string; + maxSelect?: number; + showCreateButton?: boolean; + showDeleteButton?: boolean; + showCheckbox?: boolean; + difyConfig?: { + serverUrl?: string; + apiKey?: string; + }; +} + +function getKnowledgeBaseSourcesForTool( + toolType: "knowledge_base_search" | "dify_search" | "datamate_search" +): string[] { + switch (toolType) { + case "knowledge_base_search": + return ["nexent"]; + case "dify_search": + return ["dify"]; + case "datamate_search": + return ["datamate"]; + default: + return ["nexent"]; + } +} + interface KnowledgeBaseSelectorModalProps extends KnowledgeBaseSelectorProps { knowledgeBases: KnowledgeBase[]; isLoading?: boolean; From b02a82d02b8234417dca7b5b9bb76e779ebcd11d Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 10 Mar 2026 16:02:02 +0800 Subject: [PATCH 63/75] Refactor: Redesign left sidebar with responsive collapse --- .../[locale]/chat/components/chatHeader.tsx | 59 +- .../chat/components/chatLeftSidebar.tsx | 544 +++++++++--------- .../chat/components/chatRightPanel.tsx | 2 +- .../[locale]/chat/internal/chatInterface.tsx | 413 +++++-------- .../chat/streaming/chatStreamMain.tsx | 2 +- frontend/app/[locale]/layout.client.tsx | 1 + .../hooks/chat/useConversationManagement.ts | 120 ++-- frontend/types/chat.ts | 21 - 8 files changed, 494 insertions(+), 668 deletions(-) diff --git a/frontend/app/[locale]/chat/components/chatHeader.tsx b/frontend/app/[locale]/chat/components/chatHeader.tsx index 730522420..1621c881c 100644 --- a/frontend/app/[locale]/chat/components/chatHeader.tsx +++ b/frontend/app/[locale]/chat/components/chatHeader.tsx @@ -2,17 +2,14 @@ import { useState, useRef, useEffect } from "react"; import { useTranslation } from "react-i18next"; -import { Button } from "antd"; import { Input } from "@/components/ui/input"; import { loadMemoryConfig, setMemorySwitch } from "@/services/memoryService"; import { useConfig } from "@/hooks/useConfig"; import log from "@/lib/logger"; -import { useRouter } from "next/navigation"; import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; import { useDeployment } from "@/components/providers/deploymentProvider"; import { USER_ROLES } from "@/const/auth"; -import { saveView } from "@/lib/viewPersistence"; import { useConfirmModal } from "@/hooks/useConfirmModal"; interface ChatHeaderProps { @@ -21,11 +18,9 @@ interface ChatHeaderProps { } export function ChatHeader({ title, onRename }: ChatHeaderProps) { - const router = useRouter(); const [isEditing, setIsEditing] = useState(false); const [editTitle, setEditTitle] = useState(title); - const inputRef = useRef(null); const { t, i18n } = useTranslation("common"); const { user } = useAuthorizationContext(); @@ -124,41 +119,27 @@ export function ChatHeader({ title, onRename }: ChatHeaderProps) { return ( <>
-
-
-
- {/* Left button area */} -
- -
-
- {isEditing ? ( - setEditTitle(e.target.value)} - onKeyDown={handleKeyDown} - onBlur={handleSubmit} - className="text-xl font-bold text-center h-9 max-w-xs" - autoFocus - /> - ) : ( -

- {title} -

- )} -
-
- -
- {/* Right side controls - now handled by navigation bar */} -
+
+ {isEditing ? ( + setEditTitle(e.target.value)} + onKeyDown={handleKeyDown} + onBlur={handleSubmit} + className="text-xl font-bold text-center h-9 max-w-xs" + autoFocus + /> + ) : ( +

+ {title} +

+ )}
-
diff --git a/frontend/app/[locale]/chat/components/chatLeftSidebar.tsx b/frontend/app/[locale]/chat/components/chatLeftSidebar.tsx index baf569c36..92332b88c 100644 --- a/frontend/app/[locale]/chat/components/chatLeftSidebar.tsx +++ b/frontend/app/[locale]/chat/components/chatLeftSidebar.tsx @@ -1,4 +1,4 @@ -import { useState, useRef, useEffect } from "react"; +import { useState } from "react"; import { Clock, Plus, @@ -8,15 +8,16 @@ import { ChevronLeft, ChevronRight, } from "lucide-react"; -import { useRouter } from "next/navigation"; -import { Button, Dropdown } from "antd"; -import { Input } from "@/components/ui/input"; -import { Tooltip, TooltipProvider } from "@/components/ui/tooltip"; -import { StaticScrollArea } from "@/components/ui/scrollArea"; +import { Button, Dropdown, Layout, Typography, Tooltip } from "antd"; import { useTranslation } from "react-i18next"; import { useConfirmModal } from "@/hooks/useConfirmModal"; -import { ConversationListItem, ChatSidebarProps } from "@/types/chat"; +import { conversationService } from "@/services/conversationService"; +import { + type ConversationManagement, +} from "@/hooks/chat/useConversationManagement"; +import { ConversationListItem, SettingsMenuItem } from "@/types/chat"; +import log from "@/lib/logger"; // conversation status indicator component const ConversationStatusIndicator = ({ @@ -50,7 +51,7 @@ const ConversationStatusIndicator = ({ }; // Helper function - dialog classification -const categorizeDialogs = (dialogs: ConversationListItem[]) => { +const categorizeConversations = (conversations: ConversationListItem[]) => { const now = new Date(); const today = new Date( now.getFullYear(), @@ -59,246 +60,197 @@ const categorizeDialogs = (dialogs: ConversationListItem[]) => { ).getTime(); const weekAgo = today - 7 * 24 * 60 * 60 * 1000; - const todayDialogs: ConversationListItem[] = []; - const weekDialogs: ConversationListItem[] = []; - const olderDialogs: ConversationListItem[] = []; + const todayConversations: ConversationListItem[] = []; + const weekConversations: ConversationListItem[] = []; + const olderConversations: ConversationListItem[] = []; - dialogs.forEach((dialog) => { - const dialogTime = dialog.create_time; + conversations.forEach((conversations) => { + const conversationTime = conversations.create_time; - if (dialogTime >= today) { - todayDialogs.push(dialog); - } else if (dialogTime >= weekAgo) { - weekDialogs.push(dialog); + if (conversationTime >= today) { + todayConversations.push(conversations); + } else if (conversationTime >= weekAgo) { + weekConversations.push(conversations); } else { - olderDialogs.push(dialog); + olderConversations.push(conversations); } }); return { - today: todayDialogs, - week: weekDialogs, - older: olderDialogs, + today: todayConversations, + week: weekConversations, + older: olderConversations, }; }; +// Chat sidebar props type +export interface ChatSidebarProps { + streamingConversations: Set; + completedConversations: Set; + conversationManagement: ConversationManagement; + /** Called when user clicks a conversation - loads messages and updates selection */ + onConversationSelect: (conversation: ConversationListItem) => void | Promise; +} + export function ChatSidebar({ - conversationList, - selectedConversationId, - openDropdownId, streamingConversations, completedConversations, - onNewConversation, - onDialogClick, - onRename, - onDelete, - onSettingsClick, - onDropdownOpenChange, - onToggleSidebar, - expanded, - userEmail, - userAvatarUrl + conversationManagement, + onConversationSelect, }: ChatSidebarProps) { const { t } = useTranslation(); const { confirm } = useConfirmModal(); - const router = useRouter(); - const { today, week, older } = categorizeDialogs(conversationList); + const { today, week, older } = categorizeConversations(conversationManagement.conversationList); const [editingId, setEditingId] = useState(null); - const [editingTitle, setEditingTitle] = useState(""); - const inputRef = useRef(null); - - const [animationComplete, setAnimationComplete] = useState(false); + const [collapsed, setCollapsed] = useState(false); - useEffect(() => { - // Reset animation state when expanded changes - setAnimationComplete(false); + const onToggleSidebar = () => setCollapsed((prev) => !prev); - // Set animation complete after the transition duration (200ms) - const timer = setTimeout(() => { - setAnimationComplete(true); - }, 200); - - return () => clearTimeout(timer); - }, [expanded]); - - // Handle edit start - const handleStartEdit = (dialogId: number, title: string) => { - setEditingId(dialogId); - setEditingTitle(title); - // Close any open dropdown menus - onDropdownOpenChange(false, null); - - // Use setTimeout to ensure that the input box is focused after the DOM is updated - setTimeout(() => { - if (inputRef.current) { - inputRef.current.focus(); - inputRef.current.select(); - } - }, 10); + const handleRenameClick = (conversationId: number) => { + setEditingId(conversationId); }; - // Handle edit submission - const handleSubmitEdit = () => { - if (editingId !== null && editingTitle.trim()) { - onRename(editingId, editingTitle.trim()); + const handleRename = async (conversationId: number, newTitle: string) => { + if (!newTitle.trim()) return; + try { + await conversationService.rename(conversationId, newTitle.trim()); + await conversationManagement.fetchConversationList(); + if (conversationManagement.selectedConversationId === conversationId) { + conversationManagement.setConversationTitle(newTitle.trim()); + } setEditingId(null); + } catch (error) { + log.error(t("chatInterface.renameFailed"), error); } }; - // Handle edit cancellation - const handleCancelEdit = () => { - setEditingId(null); - }; + // Handle delete + const handleDelete = (conversationId: number) => { - // Handle key events - const handleKeyDown = (e: React.KeyboardEvent) => { - if (e.key === "Enter") { - handleSubmitEdit(); - } else if (e.key === "Escape") { - handleCancelEdit(); - } - }; - - // Handle delete click - const handleDeleteClick = (dialogId: number) => { - // Close dropdown menus - onDropdownOpenChange(false, null); - - // Show confirmation modal confirm({ title: t("chatLeftSidebar.confirmDeletionTitle"), content: t("chatLeftSidebar.confirmDeletionDescription"), - onOk: () => { - onDelete(dialogId); + onOk: async () => { + try { + await conversationService.delete(conversationId); + await conversationManagement.fetchConversationList(); + if (conversationManagement.selectedConversationId === conversationId) { + conversationManagement.setSelectedConversationId(null); + conversationManagement.setConversationTitle( + t("chatInterface.newConversation") + ); + conversationManagement.handleNewConversation(); + } + } catch (error) { + log.error(t("chatInterface.deleteFailed"), error); + } }, }); }; // Render dialog list items - const renderDialogList = (dialogs: ConversationListItem[], title: string) => { - if (dialogs.length === 0) return null; + const renderConversationList = (conversation: ConversationListItem[], title: string) => { + if (conversation.length === 0) return null; return ( -
+

{title}

- {dialogs.map((dialog) => ( + {conversation.map((conversation) => (
- {editingId === dialog.conversation_id ? ( - // Edit mode -
- setEditingTitle(e.target.value)} - onKeyDown={handleKeyDown} - onBlur={handleSubmitEdit} - className="h-8 text-base" - autoFocus - /> -
- ) : ( - // Display mode - <> - - {dialog.conversation_title}

- } - placement="right" - styles={{ root: { maxWidth: "300px" } }} +
+ + {conversation.conversation_title} + + } + placement="bottom" + > +
onConversationSelect(conversation)} + > + +
+ handleRename(conversation.conversation_id, value), + // onCancel: () => setEditingId(null), + }} + className="block text-base font-normal text-gray-800 tracking-wide font-sans ml-0.5 flex-1 min-w-0" > - - - + {conversation.conversation_title} + +
+
+
+
- - onDropdownOpenChange( - open, - dialog.conversation_id.toString() - ) +
+ + + {t("chatLeftSidebar.rename")} + + ), + }, + { + key: "delete", + label: ( + + + {t("chatLeftSidebar.delete")} + + ), + }, + ], + onClick: ({ key }) => { + if (key === "rename") { + handleRenameClick(conversation.conversation_id); + } else if (key === "delete") { + handleDelete(conversation.conversation_id); } - menu={{ - items: [ - { - key: "rename", - label: ( - - - {t("chatLeftSidebar.rename")} - - ), - }, - { - key: "delete", - label: ( - - - {t("chatLeftSidebar.delete")} - - ), - }, - ], - onClick: ({ key }) => { - if (key === "rename") { - handleStartEdit( - dialog.conversation_id, - dialog.conversation_title - ); - } else if (key === "delete") { - handleDeleteClick(dialog.conversation_id); - } - }, - }} - placement="bottomRight" - trigger={["click"]} - > - - - - )} + }, + }} + placement="bottomRight" + trigger={["click"]} + > + + +
))}
@@ -311,40 +263,30 @@ export function ChatSidebar({ <> {/* Expand/Collapse button */}
- - + - - + + +
{/* New conversation button */} -
- - + + - - + + +
{/* Spacer */} @@ -354,20 +296,26 @@ export function ChatSidebar({ }; return ( - <> -
- {expanded || !animationComplete ? ( -
+ + {!collapsed ? ( +
- - - + + + +
+
+ +
+
+
+ {conversationManagement.conversationList.length > 0 ? + ( + <> + {renderConversationList(today, t("chatLeftSidebar.today"))} + {renderConversationList(week, t("chatLeftSidebar.last7Days"))} + {renderConversationList(older, t("chatLeftSidebar.older"))} + + ) : ( +
+

+ {t("chatLeftSidebar.recentConversations")} +

- - - +
+ )} +
- - -
- {conversationList.length > 0 ? ( - <> - {renderDialogList(today, t("chatLeftSidebar.today"))} - {renderDialogList(week, t("chatLeftSidebar.last7Days"))} - {renderDialogList(older, t("chatLeftSidebar.older"))} - - ) : ( -
-

- {t("chatLeftSidebar.recentConversations")} -

- -
- )} -
-
) : ( renderCollapsedSidebar() )} -
- + + ); } diff --git a/frontend/app/[locale]/chat/components/chatRightPanel.tsx b/frontend/app/[locale]/chat/components/chatRightPanel.tsx index 83b25c4b5..c11be9679 100644 --- a/frontend/app/[locale]/chat/components/chatRightPanel.tsx +++ b/frontend/app/[locale]/chat/components/chatRightPanel.tsx @@ -479,7 +479,7 @@ export function ChatRightPanel({ {/* Image viewer modal */} {viewingImage && (
setViewingImage(null)} >
diff --git a/frontend/app/[locale]/chat/internal/chatInterface.tsx b/frontend/app/[locale]/chat/internal/chatInterface.tsx index 31b78649d..5762d55a8 100644 --- a/frontend/app/[locale]/chat/internal/chatInterface.tsx +++ b/frontend/app/[locale]/chat/internal/chatInterface.tsx @@ -38,7 +38,7 @@ import { extractAssistantMsgFromResponse, } from "./extractMsgFromHistoryResponse"; -import { X } from "lucide-react"; +import { Layout } from "antd"; import log from "@/lib/logger"; const stepIdCounter = { current: 0 }; @@ -67,8 +67,6 @@ export function ChatInterface() { // Use conversation management hook const conversationManagement = useConversationManagement(); - const [openDropdownId, setOpenDropdownId] = useState(null); - const { appConfig } = useConfig(); // For each conversation, maintain independent SSE connections and states const [streamingConversations, setStreamingConversations] = useState< @@ -90,8 +88,8 @@ export function ChatInterface() { // Monitor changes in currentMessages // Calculate if the current conversation is streaming const isCurrentConversationStreaming = - conversationManagement.conversationId && conversationManagement.conversationId !== -1 - ? streamingConversations.has(conversationManagement.conversationId) + conversationManagement.selectedConversationId != null + ? streamingConversations.has(conversationManagement.selectedConversationId) : false; const [viewingImage, setViewingImage] = useState(null); @@ -104,8 +102,6 @@ export function ChatInterface() { const abortControllerRef = useRef(null); // Add AbortController reference const timeoutRef = useRef(null); // Add timeout reference - // Add sidebar state control - const [sidebarOpen, setSidebarOpen] = useState(true); // Add a state to track if we're loading a historical conversation const [isLoadingHistoricalConversation, setIsLoadingHistoricalConversation] = @@ -116,9 +112,6 @@ export function ChatInterface() { Set >(new Set()); - // Add a ref to track the currently selected conversation ID for real-time access - const currentSelectedConversationRef = useRef(null); - // Ensure right sidebar is closed by default const [showRightPanel, setShowRightPanel] = useState(false); @@ -176,64 +169,43 @@ export function ChatInterface() { setAttachments(newAttachments); }; - // Define sidebar toggle function - const toggleSidebar = () => { - setSidebarOpen(!sidebarOpen); - }; // Handle right panel toggle - keep it simple and clear const toggleRightPanel = () => { setShowRightPanel(!showRightPanel); }; - useEffect(() => { - if (!conversationManagement.initialized.current) { - conversationManagement.initialized.current = true; - - // Get conversation history list, but don't auto-select the latest conversation - conversationManagement.fetchConversationList() - .then((dialogData) => { - // Create new conversation by default regardless of history - handleNewConversation(); - }) - .catch((err) => { - log.error(t("chatInterface.errorFetchingConversationList"), err); - // Create new conversation even if getting conversation list fails - handleNewConversation(); - }); - } - }, [appConfig]); // Add appConfig as dependency - // Add useEffect to listen for conversationId changes, ensure right sidebar is always closed when conversation switches useEffect(() => { // Ensure right sidebar is reset to closed state whenever conversation ID changes setSelectedMessageId(undefined); setShowRightPanel(false); - }, [conversationManagement.conversationId]); + }, [conversationManagement.selectedConversationId]); // Helper function to clear completed conversation indicator const clearCompletedIndicator = useCallback(() => { if ( - conversationManagement.conversationId && - conversationManagement.conversationId !== -1 + conversationManagement.selectedConversationId != null ) { setCompletedConversations((prev) => { // Use functional update to avoid dependency on completedConversations - if (prev.has(conversationManagement.conversationId)) { + if (conversationManagement.selectedConversationId != null && prev.has(conversationManagement.selectedConversationId)) { const newSet = new Set(prev); - newSet.delete(conversationManagement.conversationId); + newSet.delete(conversationManagement.selectedConversationId); return newSet; } return prev; }); } - }, [conversationManagement.conversationId]); + }, [conversationManagement.selectedConversationId]); + + // Add useEffect to clear completed conversation indicator when user is viewing the current conversation useEffect(() => { // If current conversation is in completedConversations, clear it when user is viewing it clearCompletedIndicator(); - }, [conversationManagement.conversationId, clearCompletedIndicator]); + }, [conversationManagement.selectedConversationId, clearCompletedIndicator]); // Add click event listener to clear completed conversation indicator when user clicks anywhere on the page useEffect(() => { @@ -289,13 +261,9 @@ export function ChatInterface() { const userMessageId = uuidv4(); const userMessageContent = input.trim(); - // Get current conversation ID - let currentConversationId = conversationManagement.conversationId; - - // Ensure ref reflects the current conversation state - if (currentConversationId && currentConversationId !== -1) { - conversationManagement.currentSelectedConversationRef.current = currentConversationId; - } + // Get current conversation ID (null when new conversation) + let currentConversationId = conversationManagement.selectedConversationId; + let cid: number | null = null; // set after guard, used in try/catch/finally // Prepare attachment information // Handle file upload @@ -355,8 +323,8 @@ export function ChatInterface() { try { // Check if need to create new conversation - if (!currentConversationId || currentConversationId === -1) { - // If no session ID or ID is -1, create new conversation first + if (currentConversationId == null) { + // No conversation selected: create new conversation first try { const createData = await conversationService.create( t("chatInterface.newConversation") @@ -364,18 +332,14 @@ export function ChatInterface() { currentConversationId = createData.conversation_id; // Update current session state - conversationManagement.setConversationId(currentConversationId); conversationManagement.setSelectedConversationId(currentConversationId); - // Update ref to track current selected conversation - conversationManagement.currentSelectedConversationRef.current = currentConversationId; conversationManagement.setConversationTitle( createData.conversation_title || t("chatInterface.newConversation") ); // After creating new conversation, add it to streaming list setStreamingConversations((prev) => { - const newSet = new Set(prev).add(currentConversationId); - + const newSet = new Set(prev).add(createData.conversation_id); return newSet; }); @@ -406,25 +370,25 @@ export function ChatInterface() { } } - // Ensure valid conversation ID before registering controller and streaming state - if (currentConversationId && currentConversationId !== -1) { - conversationControllersRef.current.set( - currentConversationId, - currentController - ); - setStreamingConversations((prev) => { - const newSet = new Set(prev); - newSet.add(currentConversationId); - return newSet; - }); - } + // Type guard: we have a number here (either from selection or from create above) + if (currentConversationId == null) return; + const id = currentConversationId; + cid = id; + + // Register controller and streaming state for this conversation + conversationControllersRef.current.set(id, currentController); + setStreamingConversations((prev) => { + const newSet = new Set(prev); + newSet.add(id); + return newSet; + }); // Now add messages after conversation is created/confirmed // 1. When sending user message, complete ChatMessageType fields setSessionMessages((prev) => ({ ...prev, - [currentConversationId]: [ - ...(prev[currentConversationId] || []), + [id]: [ + ...(prev[id] || []), { ...userMessage, id: userMessage.id || uuidv4(), @@ -440,8 +404,8 @@ export function ChatInterface() { // 2. When adding AI reply message, complete ChatMessageType fields setSessionMessages((prev) => ({ ...prev, - [currentConversationId]: [ - ...(prev[currentConversationId] || []), + [id]: [ + ...(prev[id] || []), { ...initialAssistantMessage, id: initialAssistantMessage.id || uuidv4(), @@ -478,7 +442,7 @@ export function ChatInterface() { // Send request to backend API, add signal parameter const runAgentParams: any = { query: finalQuery, // Use preprocessed query or original query - conversation_id: currentConversationId, + conversation_id: id, is_set: isSwitchedConversation || currentMessages.length <= 1, history: currentMessages .filter((msg) => msg.id !== userMessage.id) @@ -549,16 +513,12 @@ export function ChatInterface() { // Create resetTimeout function for current conversation const resetTimeout = () => { - const timeout = conversationTimeoutsRef.current.get( - currentConversationId - ); + const timeout = conversationTimeoutsRef.current.get(id); if (timeout) { clearTimeout(timeout); } const newTimeout = setTimeout(async () => { - const controller = conversationControllersRef.current.get( - currentConversationId - ); + const controller = conversationControllersRef.current.get(id); if (controller && !controller.signal.aborted) { try { controller.abort(t("chatInterface.requestTimeout")); @@ -566,9 +526,7 @@ export function ChatInterface() { setSessionMessages((prev) => { const newMessages = { ...prev }; const lastMsg = - newMessages[currentConversationId]?.[ - newMessages[currentConversationId].length - 1 - ]; + newMessages[id]?.[newMessages[id].length - 1]; if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { lastMsg.error = t("chatInterface.requestTimeoutRetry"); lastMsg.isComplete = true; @@ -577,23 +535,21 @@ export function ChatInterface() { return newMessages; }); - if (currentConversationId && currentConversationId !== -1) { - try { - await conversationService.stop(currentConversationId); - } catch (error) { - log.error( - t("chatInterface.stopTimeoutRequestFailed"), - error - ); - } + try { + await conversationService.stop(id); + } catch (error) { + log.error( + t("chatInterface.stopTimeoutRequestFailed"), + error + ); } } catch (error) { log.error(t("chatInterface.errorCancelingRequest"), error); } } - conversationTimeoutsRef.current.delete(currentConversationId); + conversationTimeoutsRef.current.delete(id); }, 120000); - conversationTimeoutsRef.current.set(currentConversationId, newTimeout); + conversationTimeoutsRef.current.set(id, newTimeout); }; // Before processing streaming response, set an initial timeout first @@ -603,14 +559,14 @@ export function ChatInterface() { // Compatible with both function and direct assignment await handleStreamResponse( reader, - setCurrentSessionMessagesFactory(currentConversationId), + setCurrentSessionMessagesFactory(id), resetTimeout, stepIdCounter, setIsSwitchedConversation, conversationManagement.isNewConversation, conversationManagement.setConversationTitle, conversationManagement.fetchConversationList, - currentConversationId, + id, conversationService, false, // isDebug: false for normal chat mode t @@ -621,101 +577,88 @@ export function ChatInterface() { setIsStreaming(false); // Clean up controller and timeout for current conversation - conversationControllersRef.current.delete(currentConversationId); - const timeout = conversationTimeoutsRef.current.get( - currentConversationId - ); + conversationControllersRef.current.delete(id); + const timeout = conversationTimeoutsRef.current.get(id); if (timeout) { clearTimeout(timeout); - conversationTimeoutsRef.current.delete(currentConversationId); + conversationTimeoutsRef.current.delete(id); } - // Remove from streaming list (only when conversationId is not -1) - if (currentConversationId !== -1) { - setStreamingConversations((prev) => { + // Remove from streaming list when we have a valid conversation id + setStreamingConversations((prev) => { + const newSet = new Set(prev); + newSet.delete(id); + return newSet; + }); + + // When conversation is completed, only add to completed conversation list when user is not in current conversation interface + const currentUserConversation = conversationManagement.selectedConversationId; + if (currentUserConversation !== id) { + setCompletedConversations((prev) => { const newSet = new Set(prev); - newSet.delete(currentConversationId); + newSet.add(id); return newSet; }); - - // When conversation is completed, only add to completed conversation list when user is not in current conversation interface - // Use ref to get the actual conversation the user is in - const currentUserConversation = currentSelectedConversationRef.current; - if (currentUserConversation !== currentConversationId) { - setCompletedConversations((prev) => { - const newSet = new Set(prev); - newSet.add(currentConversationId); - return newSet; - }); - } } // Note: Save operation is already implemented in agent run API, no need to save again in frontend } catch (error) { // If user actively canceled, don't show error message const err = error as Error; - if (err.name === "AbortError") { - setSessionMessages((prev) => { - const newMessages = { ...prev }; - const lastMsg = - newMessages[currentConversationId]?.[ - newMessages[currentConversationId].length - 1 - ]; - if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { - lastMsg.content = t("chatInterface.conversationStopped"); - lastMsg.isComplete = true; - lastMsg.thinking = undefined; // Explicitly clear thinking state - } - return newMessages; - }); - } else { - log.error(t("chatInterface.errorLabel"), error); - // Show user-friendly error message instead of technical error details - const errorMessage = t("chatInterface.errorProcessingRequest"); - setSessionMessages((prev) => { - const newMessages = { ...prev }; - const lastMsg = - newMessages[currentConversationId]?.[ - newMessages[currentConversationId].length - 1 - ]; - if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { - lastMsg.content = errorMessage; - lastMsg.isComplete = true; - lastMsg.error = errorMessage; - lastMsg.thinking = undefined; // Explicitly clear thinking state - } - return newMessages; - }); + if (cid != null) { + const idForCatch = cid; + if (err.name === "AbortError") { + setSessionMessages((prev) => { + const newMessages = { ...prev }; + const lastMsg = + newMessages[idForCatch]?.[newMessages[idForCatch].length - 1]; + if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { + lastMsg.content = t("chatInterface.conversationStopped"); + lastMsg.isComplete = true; + lastMsg.thinking = undefined; // Explicitly clear thinking state + } + return newMessages; + }); + } else { + log.error(t("chatInterface.errorLabel"), error); + const errorMessage = t("chatInterface.errorProcessingRequest"); + setSessionMessages((prev) => { + const newMessages = { ...prev }; + const lastMsg = + newMessages[idForCatch]?.[newMessages[idForCatch].length - 1]; + if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { + lastMsg.content = errorMessage; + lastMsg.isComplete = true; + lastMsg.error = errorMessage; + lastMsg.thinking = undefined; // Explicitly clear thinking state + } + return newMessages; + }); + } } setIsLoading(false); setIsStreaming(false); - // Clean up controller and timeout for current conversation - conversationControllersRef.current.delete(currentConversationId); - const timeout = conversationTimeoutsRef.current.get( - currentConversationId - ); - if (timeout) { - clearTimeout(timeout); - conversationTimeoutsRef.current.delete(currentConversationId); - } - - // Remove from streaming list (only when conversationId is not -1) - if (currentConversationId !== -1) { + // Clean up when we had a conversation id (cid is set after the guard in try) + if (cid != null) { + const idForCatch = cid; + conversationControllersRef.current.delete(idForCatch); + const timeout = conversationTimeoutsRef.current.get(idForCatch); + if (timeout) { + clearTimeout(timeout); + conversationTimeoutsRef.current.delete(idForCatch); + } setStreamingConversations((prev) => { const newSet = new Set(prev); - newSet.delete(currentConversationId); + newSet.delete(idForCatch); return newSet; }); - - // When conversation is completed, only add to completed conversation list when user is not in current conversation interface - // Use ref to get the actual conversation the user is in - const currentUserConversation = currentSelectedConversationRef.current; - if (currentUserConversation !== currentConversationId) { + const currentUserConversation = conversationManagement.selectedConversationId; + if (currentUserConversation !== idForCatch) { setCompletedConversations((prev) => { const newSet = new Set(prev); - newSet.add(currentConversationId); + newSet.add(idForCatch); return newSet; }); } @@ -800,7 +743,7 @@ export function ChatInterface() { // Check if there are cached messages const hasCachedMessages = sessionMessages[dialog.conversation_id] !== undefined; - const isCurrentActive = dialog.conversation_id === conversationManagement.conversationId; + const isCurrentActive = dialog.conversation_id === conversationManagement.selectedConversationId; // Log: click conversation // If there are cached messages, ensure not to show loading state @@ -1057,7 +1000,7 @@ export function ChatInterface() { // Create a copy to avoid directly modifying parameters const updatedMessages = [...messages]; let hasUpdates = false; - const conversationIdToUse = targetConversationId || conversationManagement.conversationId; + const conversationIdToUse = targetConversationId ?? conversationManagement.selectedConversationId; // Process attachments for each message for (const message of updatedMessages) { @@ -1086,8 +1029,8 @@ export function ChatInterface() { } } - // If there are updates, set new message array - if (hasUpdates) { + // If there are updates and we have a conversation id, set new message array + if (hasUpdates && conversationIdToUse != null) { setSessionMessages((prev) => ({ ...prev, [conversationIdToUse]: updatedMessages, @@ -1095,76 +1038,6 @@ export function ChatInterface() { } }; - // Left sidebar conversation title update - const handleConversationRename = async (dialogId: number, title: string) => { - try { - await conversationService.rename(dialogId, title); - await conversationManagement.fetchConversationList(); - - if (conversationManagement.selectedConversationId === dialogId) { - conversationManagement.setConversationTitle(title); - } - } catch (error) { - log.error(t("chatInterface.renameFailed"), error); - } - }; - - // Left sidebar conversation deletion - const handleConversationDeleteClick = async (dialogId: number) => { - try { - // If deleting the currently active conversation, stop conversation first - if ( - conversationManagement.selectedConversationId === dialogId && - isStreaming && - conversationManagement.conversationId === dialogId - ) { - // Cancel current ongoing request first - if (abortControllerRef.current) { - try { - abortControllerRef.current.abort( - t("chatInterface.deleteConversation") - ); - } catch (error) { - log.error(t("chatInterface.errorCancelingRequest"), error); - } - abortControllerRef.current = null; - } - - // Clear timeout timer - if (timeoutRef.current) { - clearTimeout(timeoutRef.current); - timeoutRef.current = null; - } - - setIsStreaming(false); - setIsLoading(false); - - try { - await conversationService.stop(dialogId); - } catch (error) { - log.error( - t("chatInterface.stopConversationToDeleteFailed"), - error - ); - // Continue deleting even if stopping fails - } - } - - await conversationService.delete(dialogId); - await conversationManagement.fetchConversationList(); - - if (conversationManagement.selectedConversationId === dialogId) { - conversationManagement.setSelectedConversationId(null); - // Update ref to track current selected conversation - conversationManagement.currentSelectedConversationRef.current = null; - conversationManagement.setConversationTitle(t("chatInterface.newConversation")); - handleNewConversation(); - } - } catch (error) { - log.error(t("chatInterface.deleteFailed"), error); - } - }; - // Add image error handling function const handleImageError = (imageUrl: string) => { log.error(t("chatInterface.imageLoadFailed"), imageUrl); @@ -1173,7 +1046,7 @@ export function ChatInterface() { setSessionMessages((prev) => { const newMessages = { ...prev }; const lastMsg = - newMessages[conversationManagement.conversationId]?.[newMessages[conversationManagement.conversationId].length - 1]; + newMessages[conversationManagement.selectedConversationId!]?.[newMessages[conversationManagement.selectedConversationId!].length - 1]; if (lastMsg && lastMsg.role === ROLE_ASSISTANT && lastMsg.images) { // Filter out failed images @@ -1193,21 +1066,21 @@ export function ChatInterface() { const handleStop = async () => { // Stop agent_run of current conversation const currentController = - conversationControllersRef.current.get(conversationManagement.conversationId); + conversationControllersRef.current.get(conversationManagement.selectedConversationId!); if (currentController) { try { currentController.abort(t("chatInterface.userManuallyStopped")); } catch (error) { log.error(t("chatInterface.errorCancelingRequest"), error); } - conversationControllersRef.current.delete(conversationManagement.conversationId); + conversationControllersRef.current.delete(conversationManagement.selectedConversationId!); } // Clear timeout timer for current conversation - const currentTimeout = conversationTimeoutsRef.current.get(conversationManagement.conversationId); + const currentTimeout = conversationTimeoutsRef.current.get(conversationManagement.selectedConversationId!); if (currentTimeout) { clearTimeout(currentTimeout); - conversationTimeoutsRef.current.delete(conversationManagement.conversationId); + conversationTimeoutsRef.current.delete(conversationManagement.selectedConversationId!); } // Immediately update frontend state @@ -1215,19 +1088,19 @@ export function ChatInterface() { setIsLoading(false); // If no valid conversation ID, just reset frontend state - if (!conversationManagement.conversationId || conversationManagement.conversationId === -1) { + if (conversationManagement.selectedConversationId == null) { return; } try { // Call backend stop API - this will stop both agent run and preprocess tasks - await conversationService.stop(conversationManagement.conversationId); + await conversationService.stop(conversationManagement.selectedConversationId!); // Manually update messages, clear thinking state setSessionMessages((prev) => { const newMessages = { ...prev }; const lastMsg = - newMessages[conversationManagement.conversationId]?.[newMessages[conversationManagement.conversationId].length - 1]; + newMessages[conversationManagement.selectedConversationId!]?.[newMessages[conversationManagement.selectedConversationId!].length - 1]; if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { lastMsg.isComplete = true; lastMsg.thinking = undefined; // Explicitly clear thinking state @@ -1238,16 +1111,16 @@ export function ChatInterface() { // remove from streaming list setStreamingConversations((prev) => { const newSet = new Set(prev); - newSet.delete(conversationManagement.conversationId); + newSet.delete(conversationManagement.selectedConversationId!); return newSet; }); // when conversation is stopped, only add to completed conversations list when user is not in current conversation interface - const currentUserConversation = currentSelectedConversationRef.current; - if (currentUserConversation !== conversationManagement.conversationId) { + const currentUserConversation = conversationManagement.selectedConversationId; + if (currentUserConversation != null && currentUserConversation !== conversationManagement.selectedConversationId) { setCompletedConversations((prev) => { const newSet = new Set(prev); - newSet.add(conversationManagement.conversationId); + newSet.add(conversationManagement.selectedConversationId!); return newSet; }); } @@ -1258,7 +1131,7 @@ export function ChatInterface() { setSessionMessages((prev) => { const newMessages = { ...prev }; const lastMsg = - newMessages[conversationManagement.conversationId]?.[newMessages[conversationManagement.conversationId].length - 1]; + newMessages[conversationManagement.selectedConversationId!]?.[newMessages[conversationManagement.selectedConversationId!].length - 1]; if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { lastMsg.isComplete = true; lastMsg.thinking = undefined; // Explicitly clear thinking state @@ -1360,30 +1233,15 @@ export function ChatInterface() { return ( - <> -
- - setOpenDropdownId(open ? id : null) - } - onToggleSidebar={toggleSidebar} - expanded={sidebarOpen} - userEmail={user?.email} - userAvatarUrl={user?.avatarUrl} - userRole={user?.role} - /> - -
+ + + +
-
-
- + + ); } diff --git a/frontend/app/[locale]/chat/streaming/chatStreamMain.tsx b/frontend/app/[locale]/chat/streaming/chatStreamMain.tsx index ec3d0a7fa..0380a81c8 100644 --- a/frontend/app/[locale]/chat/streaming/chatStreamMain.tsx +++ b/frontend/app/[locale]/chat/streaming/chatStreamMain.tsx @@ -313,7 +313,7 @@ export function ChatStreamMain({ return (
{/* Main message area */} - +
{processedMessages.finalMessages.length === 0 ? ( isLoadingHistoricalConversation ? ( diff --git a/frontend/app/[locale]/layout.client.tsx b/frontend/app/[locale]/layout.client.tsx index 75b49d111..5f8c7d5fa 100644 --- a/frontend/app/[locale]/layout.client.tsx +++ b/frontend/app/[locale]/layout.client.tsx @@ -112,6 +112,7 @@ export function ClientLayout({ children }: { children: ReactNode }) { style={siderStyle} width={SIDER_CONFIG.EXPANDED_WIDTH} collapsed={collapsed} + onCollapse={setCollapsed} trigger={null} breakpoint="lg" collapsedWidth={SIDER_CONFIG.COLLAPSED_WIDTH} diff --git a/frontend/hooks/chat/useConversationManagement.ts b/frontend/hooks/chat/useConversationManagement.ts index 5e53680f3..c07726df7 100644 --- a/frontend/hooks/chat/useConversationManagement.ts +++ b/frontend/hooks/chat/useConversationManagement.ts @@ -1,79 +1,93 @@ -import { useState, useRef, useEffect } from "react"; +import type React from "react"; +import { useState } from "react"; import { useTranslation } from "react-i18next"; +import type { UseQueryResult } from "@tanstack/react-query"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; import { conversationService } from "@/services/conversationService"; import { ConversationListItem } from "@/types/chat"; import log from "@/lib/logger"; -export const useConversationManagement = () => { - const { t } = useTranslation("common"); - - // Conversation state - const [conversationId, setConversationId] = useState(0); - const [conversationTitle, setConversationTitle] = useState( - t("chatInterface.newConversation") - ); - const [conversationList, setConversationList] = useState< - ConversationListItem[] - >([]); - const [selectedConversationId, setSelectedConversationId] = useState< - number | null - >(null); - const [isNewConversation, setIsNewConversation] = useState(true); - const [conversationLoadError, setConversationLoadError] = useState<{ - [conversationId: number]: string; - }>({}); +const CONVERSATION_LIST_QUERY_KEY = ["conversations"] as const; - // Refs - const currentSelectedConversationRef = useRef(null); - const initialized = useRef(false); +/** + * Return type of useConversationManagement hook. + * Use this type when passing conversation management state/handlers between parent and child components. + */ +export interface ConversationManagement { + conversationTitle: string; + conversationList: ConversationListItem[]; + selectedConversationId: number | null; + isNewConversation: boolean; + conversationLoadError: Record; + conversationListQuery: UseQueryResult; + fetchConversationList: () => Promise; + invalidateConversationList: () => void; + handleNewConversation: () => void; + handleConversationSelect: (conversation: ConversationListItem) => Promise; + updateConversationTitle: (conversationId: number, title: string) => Promise; + clearConversationLoadError: (conversationId: number) => void; + setConversationLoadErrorForId: (conversationId: number, error: string) => void; + setSelectedConversationId: React.Dispatch>; + setConversationTitle: React.Dispatch>; + setIsNewConversation: React.Dispatch>; +} - // Ensure currentSelectedConversationRef is synchronized with selectedConversationId - useEffect(() => { - currentSelectedConversationRef.current = selectedConversationId; - }, [selectedConversationId]); +export const useConversationManagement = (): ConversationManagement => { + const { t } = useTranslation("common"); + const queryClient = useQueryClient(); - // Fetch conversation list - const fetchConversationList = async (): Promise => { - try { + const conversationListQuery = useQuery({ + queryKey: CONVERSATION_LIST_QUERY_KEY, + queryFn: async (): Promise => { const dialogHistory = await conversationService.getList(); - // Sort by creation time, newest first dialogHistory.sort((a, b) => b.create_time - a.create_time); - setConversationList(dialogHistory); return dialogHistory; - } catch (error) { - log.error(t("chatInterface.errorFetchingConversationList"), error); - throw error; + }, + staleTime: 30_000, + }); + + const conversationList = conversationListQuery.data ?? []; + + const fetchConversationList = async (): Promise => { + const result = await conversationListQuery.refetch(); + if (result.error) { + log.error(t("chatInterface.errorFetchingConversationList"), result.error); + throw result.error; } + return result.data ?? []; }; + const invalidateConversationList = () => queryClient.invalidateQueries({ queryKey: CONVERSATION_LIST_QUERY_KEY }); + + // Conversation state: null = no selection / new conversation, number = current conversation id + const [conversationTitle, setConversationTitle] = useState(t("chatInterface.newConversation")); + const [selectedConversationId, setSelectedConversationId] = useState(null); + const [isNewConversation, setIsNewConversation] = useState(true); + const [conversationLoadError, setConversationLoadError] = useState<{[conversationId: number]: string;}>({}); + + // Refs + // Handle new conversation const handleNewConversation = () => { - setConversationId(-1); setSelectedConversationId(null); setConversationTitle(t("chatInterface.newConversation")); setIsNewConversation(true); - currentSelectedConversationRef.current = null; }; // Handle conversation selection - const handleConversationSelect = async (dialog: ConversationListItem) => { - // Immediately set conversation state, avoid flashing new conversation interface - setSelectedConversationId(dialog.conversation_id); - setConversationId(dialog.conversation_id); - setConversationTitle(dialog.conversation_title); - - // Update ref to track current selected conversation - currentSelectedConversationRef.current = dialog.conversation_id; + const handleConversationSelect = async (conversation: ConversationListItem) => { + setSelectedConversationId(conversation.conversation_id); + setConversationTitle(conversation.conversation_title); setIsNewConversation(false); }; // Update conversation title - const updateConversationTitle = async (dialogId: number, title: string) => { + const updateConversationTitle = async (conversationId: number, title: string) => { try { - await conversationService.rename(dialogId, title); + await conversationService.rename(conversationId, title); await fetchConversationList(); - if (selectedConversationId === dialogId) { + if (selectedConversationId === conversationId) { setConversationTitle(title); } } catch (error) { @@ -101,27 +115,23 @@ export const useConversationManagement = () => { return { // State (read-only) - conversationId, conversationTitle, conversationList, selectedConversationId, isNewConversation, conversationLoadError, - - // Refs - currentSelectedConversationRef, - initialized, - + conversationListQuery, + // Methods fetchConversationList, + invalidateConversationList, handleNewConversation, handleConversationSelect, updateConversationTitle, clearConversationLoadError, setConversationLoadErrorForId, - + // Setters (for internal use by components) - setConversationId, setSelectedConversationId, setConversationTitle, setIsNewConversation, diff --git a/frontend/types/chat.ts b/frontend/types/chat.ts index 3e9835d0a..af5751295 100644 --- a/frontend/types/chat.ts +++ b/frontend/types/chat.ts @@ -262,27 +262,6 @@ export interface SettingsMenuItem { onClick: () => void; } -// Chat sidebar props type -export interface ChatSidebarProps { - conversationList: ConversationListItem[]; - selectedConversationId: number | null; - openDropdownId: string | null; - streamingConversations: Set; - completedConversations: Set; - onNewConversation: () => void; - onDialogClick: (dialog: ConversationListItem) => void; - onRename: (dialogId: number, title: string) => void; - onDelete: (dialogId: number) => void; - onSettingsClick: () => void; - settingsMenuItems?: SettingsMenuItem[]; - onDropdownOpenChange: (open: boolean, id: string | null) => void; - onToggleSidebar: () => void; - expanded: boolean; - userEmail: string | undefined; - userAvatarUrl: string | undefined; - userRole: string | undefined; -} - // Image item type for chat right panel export interface ImageItem { base64Data: string; From e528716b3d06b767fa1b866f1424b4247b8ce453 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Wed, 11 Mar 2026 10:04:11 +0800 Subject: [PATCH 64/75] Bugfix: Fix background color of root container to white --- frontend/app/[locale]/chat/components/chatInput.tsx | 2 +- frontend/app/[locale]/chat/components/chatRightPanel.tsx | 2 +- frontend/app/[locale]/chat/streaming/chatStreamMain.tsx | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/app/[locale]/chat/components/chatInput.tsx b/frontend/app/[locale]/chat/components/chatInput.tsx index 7665b934c..9b175c8cd 100644 --- a/frontend/app/[locale]/chat/components/chatInput.tsx +++ b/frontend/app/[locale]/chat/components/chatInput.tsx @@ -1257,7 +1257,7 @@ export function ChatInput({
) : ( -
+