-
Notifications
You must be signed in to change notification settings - Fork 555
Open
Labels
enhancement: New feature or request
Description
@chat.post("/agent/{agent_id}")
async def chat_agent
外部可以调用这个接口 (Request: make this endpoint callable by external programs.)
# ==================================================================
# > === 外部调用接口 (无需认证) ===
# ==================================================================
@chat.post("/external/thread", response_model=ThreadResponse)
async def external_create_thread(
    thread: ThreadCreate,
    db: AsyncSession = Depends(get_db),
):
    """Create a new conversation thread on behalf of an external program.

    No login is required: the thread is always owned by the fixed system user
    ``user_id=1``. The caller-supplied metadata is stored together with an
    extra ``external_api=True`` marker.

    NOTE(review): this endpoint is unauthenticated yet writes data as user 1;
    confirm that exposure is intended (or gate it behind an API key / ACL).

    Args:
        thread: Request body; ``agent_id``, ``title`` (falls back to
            "新的对话") and the optional ``metadata`` mapping are read.
        db: Request-scoped async database session.

    Returns:
        A dict shaped like ``ThreadResponse``: thread id, owner id, agent id,
        title, and ``isoformat()`` strings for the created/updated timestamps.

    Raises:
        HTTPException: status 500 when the system user (id=1) does not exist.
    """
    lookup = await db.execute(select(User).where(User.id == 1))
    system_user = lookup.scalar_one_or_none()
    if system_user is None:
        raise HTTPException(status_code=500, detail="系统用户(user_id=1)不存在,请联系管理员")

    new_thread_id = str(uuid.uuid4())
    logger.debug(f"[EXTERNAL] Creating thread with agent_id: {thread.agent_id}")

    # Copy the caller's metadata and tag the conversation as API-created.
    stored_metadata = dict(thread.metadata or {})
    stored_metadata["external_api"] = True

    manager = ConversationManager(db)
    conv = await manager.create_conversation(
        user_id=str(system_user.id),
        agent_id=thread.agent_id,
        title=thread.title or "新的对话",
        thread_id=new_thread_id,
        metadata=stored_metadata,
    )
    logger.info(f"[EXTERNAL] Created conversation with thread_id: {new_thread_id}")

    return dict(
        id=conv.thread_id,
        user_id=conv.user_id,
        agent_id=conv.agent_id,
        title=conv.title,
        created_at=conv.created_at.isoformat(),
        updated_at=conv.updated_at.isoformat(),
    )
@chat.post("/external/agent/{agent_id}")
async def external_chat_agent(
    agent_id: str,
    query: str = Body(...),
    config: dict = Body({}),
    meta: dict = Body({}),
    image_content: str | None = Body(None),
    db: AsyncSession = Depends(get_db),
):
    """Stream a chat with a specific agent for an external program.

    No login is required: every request runs as the fixed system user
    ``user_id=1``.

    NOTE(review): this endpoint is unauthenticated but acts on behalf of (and
    writes conversations as) user 1. Anyone who can reach the server can use
    it; confirm this is intended, or gate it behind an API key / network ACL.

    Args:
        agent_id: ID of the agent, resolved with ``agent_manager.get_agent``.
        query: The user's text input.
        config: Optional run settings. ``thread_id`` selects the conversation
            (a new UUID is generated when it is missing or empty); ``model`` is
            only echoed back as ``meta["server_model_name"]``.
        meta: Optional client metadata. It is updated with request details and
            echoed back in chunks; a ``request_id`` is generated if absent.
        image_content: Optional base64 image, sent to the model as a
            ``data:image/jpeg;base64,...`` URL.
        db: Request-scoped session, used only to look up the system user.

    Returns:
        A ``StreamingResponse`` of newline-delimited JSON chunks. Each chunk
        carries ``request_id`` and ``response`` plus a ``status`` such as
        ``init``, ``loading``, ``agent_state``, ``finished``, ``interrupted``
        or ``error`` (``check_and_handle_interrupts`` may emit further chunks).

    Raises:
        HTTPException: 500 if the system user (id=1) does not exist.
    """
    start_time = asyncio.get_running_loop().time()
    logger.info(f"[EXTERNAL] agent_id: {agent_id}, query: {query}, config: {config}, meta: {meta}")
    logger.info(f"[EXTERNAL] image_content present: {image_content is not None}")
    if image_content:
        logger.info(f"[EXTERNAL] image_content length: {len(image_content)}")
        logger.info(f"[EXTERNAL] image_content preview: {image_content[:50]}...")

    # Look up the fixed system user that all external calls run as.
    result = await db.execute(select(User).filter(User.id == 1))
    current_user = result.scalar_one_or_none()
    if current_user is None:
        raise HTTPException(
            status_code=500,
            detail="系统用户(user_id=1)不存在,请联系管理员"
        )
    user_id = str(current_user.id)

    # Resolve thread_id once, up front, so the agent's runtime context, the
    # persisted messages and the meta echoed to the client all use the same id.
    # Building the context before generating the id would make the agent run
    # with thread_id=None while messages are saved under the generated id.
    # NOTE(review): confirm add_message_by_thread_id copes with an id that was
    # never created via /external/thread.
    thread_id = config.get("thread_id")
    if not thread_id:
        thread_id = str(uuid.uuid4())
        logger.warning(f"[EXTERNAL] No thread_id provided, generated new thread_id: {thread_id}")

    # Make sure a request_id exists.
    if not meta.get("request_id"):
        meta["request_id"] = str(uuid.uuid4())
    meta.update(
        {
            "query": query,
            "agent_id": agent_id,
            "server_model_name": config.get("model", agent_id),
            "thread_id": thread_id,
            "user_id": current_user.id,
            "has_image": bool(image_content),
            "external_api": True,  # mark as an external-API call
        }
    )

    def make_chunk(content=None, **kwargs):
        """Serialize one NDJSON line tagged with this request's request_id."""
        return (
            json.dumps(
                {"request_id": meta.get("request_id"), "response": content, **kwargs}, ensure_ascii=False
            ).encode("utf-8")
            + b"\n"
        )

    # Build the human message: multimodal when an image is attached, plain text otherwise.
    if image_content:
        human_message = HumanMessage(
            content=[
                {"type": "text", "text": query},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_content}"}},
            ]
        )
        message_type = "multimodal_image"
    else:
        human_message = HumanMessage(content=query)
        message_type = "text"

    async def stream_messages():
        """Run the agent, yield NDJSON chunks, and persist the conversation."""
        # Acknowledge receipt using the frontend-friendly message format.
        init_msg = {"role": "user", "content": query, "type": "human", "message_type": message_type}
        if image_content:
            init_msg["image_content"] = image_content
        yield make_chunk(status="init", meta=meta, msg=init_msg)

        # Input guard.
        if conf.enable_content_guard and await content_guard.check(query):
            yield make_chunk(
                status="error", error_type="content_guard_blocked", error_message="输入内容包含敏感词", meta=meta
            )
            return

        try:
            agent = agent_manager.get_agent(agent_id)
        except Exception as e:
            logger.error(f"Error getting agent {agent_id}: {e}, {traceback.format_exc()}")
            yield make_chunk(
                status="error",
                error_type="agent_error",
                error_message=f"智能体 {agent_id} 获取失败: {str(e)}",
                meta=meta,
            )
            return

        messages = [human_message]
        input_context = {"user_id": user_id, "thread_id": thread_id}
        # Shares the input_context dict, so attachments added below are visible here too.
        langgraph_config = {"configurable": input_context}

        # Bound before the try so the except handlers can always read them,
        # even when the failure happens before streaming starts.
        full_msg = None
        accumulated_content = []

        try:
            async with db_manager.get_async_session_context() as session:
                conv_manager = ConversationManager(session)

                # Persist the user message; best-effort (failures are only logged).
                try:
                    await conv_manager.add_message_by_thread_id(
                        thread_id=thread_id,
                        role="user",
                        content=query,
                        message_type=message_type,
                        image_content=image_content,
                        extra_metadata={"raw_message": human_message.model_dump()},
                    )
                except Exception as e:
                    logger.error(f"[EXTERNAL] Error saving user message: {e}")

                try:
                    attachments = await conv_manager.get_attachments_by_thread_id(thread_id)
                    input_context["attachments"] = attachments
                    logger.debug(f"Loaded {len(attachments)} attachments for thread_id={thread_id}")
                except Exception as e:
                    logger.error(f"Error loading attachments for thread_id={thread_id}: {e}")
                    input_context["attachments"] = []

                async for msg, metadata in agent.stream_messages(messages, input_context=input_context):
                    if isinstance(msg, AIMessageChunk):
                        accumulated_content.append(msg.content)
                        # Only the last 10 chunks are re-checked on each step.
                        content_for_check = "".join(accumulated_content[-10:])
                        if conf.enable_content_guard and await content_guard.check_with_keywords(content_for_check):
                            logger.warning("Sensitive content detected in stream")
                            full_msg = AIMessage(content="".join(accumulated_content))
                            # "content_guard_blocked" is an error *type* (as in the input
                            # guard above), so pass it by keyword rather than positionally.
                            await save_partial_message(
                                conv_manager, thread_id, full_msg=full_msg, error_type="content_guard_blocked"
                            )
                            meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                            yield make_chunk(status="interrupted", message="检测到敏感内容,已中断输出", meta=meta)
                            return
                        yield make_chunk(content=msg.content, msg=msg.model_dump(), metadata=metadata, status="loading")
                    else:
                        yield make_chunk(msg=msg.model_dump(), metadata=metadata, status="loading")

                if not full_msg and accumulated_content:
                    full_msg = AIMessage(content="".join(accumulated_content))

                # Output guard over the complete answer.
                if (
                    conf.enable_content_guard
                    and full_msg is not None
                    and await content_guard.check(full_msg.content)
                ):
                    logger.warning("Sensitive content detected in final message")
                    await save_partial_message(
                        conv_manager, thread_id, full_msg=full_msg, error_type="content_guard_blocked"
                    )
                    meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                    yield make_chunk(status="interrupted", message="检测到敏感内容,已中断输出", meta=meta)
                    return

                # Streaming finished: surface any human-approval interrupts.
                async for chunk in check_and_handle_interrupts(agent, langgraph_config, make_chunk, meta, thread_id):
                    yield chunk

                meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                try:
                    graph = await agent.get_graph()
                    state = await graph.aget_state(langgraph_config)
                    agent_state = _extract_agent_state(getattr(state, "values", {})) if state else {}
                except Exception as exc:
                    # Agent state is optional extra output; log instead of failing the stream.
                    logger.warning(f"[EXTERNAL] Could not read agent state for thread_id={thread_id}: {exc}")
                    agent_state = {}
                if agent_state:
                    yield make_chunk(status="agent_state", agent_state=agent_state, meta=meta)

                yield make_chunk(status="finished", meta=meta)

                # Persist all messages from the LangGraph state.
                await save_messages_from_langgraph_state(
                    agent_instance=agent,
                    thread_id=thread_id,
                    conv_mgr=conv_manager,
                    config_dict=langgraph_config,
                )
        except (asyncio.CancelledError, ConnectionError) as e:
            # The client went away: persist whatever was generated so far.
            logger.warning(f"Client disconnected, cancelling stream: {e}")

            async def save_cleanup():
                partial = full_msg
                if not partial and accumulated_content:
                    partial = AIMessage(content="".join(accumulated_content))
                async with db_manager.get_async_session_context() as new_db:
                    new_conv_manager = ConversationManager(new_db)
                    await save_partial_message(
                        new_conv_manager,
                        thread_id,
                        full_msg=partial,
                        error_message="对话已中断" if not partial else None,
                        error_type="interrupted",
                    )

            # Run the save as its own task and shield it, so the DB write can
            # complete even though this generator is being cancelled.
            cleanup_task = asyncio.create_task(save_cleanup())
            try:
                await asyncio.shield(cleanup_task)
            except asyncio.CancelledError:
                pass
            except Exception as exc:
                logger.error(f"Error during cleanup save: {exc}")

            # Notify the client (may never arrive; sent for consistency).
            yield make_chunk(status="interrupted", message="对话已中断", meta=meta)
        except Exception as e:
            logger.error(f"Error streaming messages: {e}, {traceback.format_exc()}")
            error_msg = f"Error streaming messages: {e}"
            error_type = "unexpected_error"
            if not full_msg and accumulated_content:
                full_msg = AIMessage(content="".join(accumulated_content))

            # Persist the error; a failure here must not stop the error chunk below.
            try:
                async with db_manager.get_async_session_context() as new_db:
                    new_conv_manager = ConversationManager(new_db)
                    await save_partial_message(
                        new_conv_manager,
                        thread_id,
                        full_msg=full_msg,
                        error_message=error_msg,
                        error_type=error_type,
                    )
            except Exception as save_exc:
                logger.error(f"[EXTERNAL] Error saving partial message after stream failure: {save_exc}")

            yield make_chunk(status="error", error_type=error_type, error_message=error_msg, meta=meta)

    return StreamingResponse(stream_messages(), media_type="application/json")
@chat.post("/external/agent/{agent_id}/sync")
async def external_chat_agent_sync(
    agent_id: str,
    query: str = Body(...),
    config: dict = Body({}),
    meta: dict = Body({}),
    image_content: str | None = Body(None),
    db: AsyncSession = Depends(get_db),
):
    """Chat with a specific agent for an external program (non-streaming).

    No login is required: every request runs as the fixed system user
    ``user_id=1``. The agent output is collected in full and returned as a
    single JSON object.

    NOTE(review): this endpoint is unauthenticated but acts on behalf of (and
    writes conversations as) user 1; confirm this exposure is intended.
    NOTE(review): unlike the streaming endpoint, this does not call
    ``check_and_handle_interrupts``. A run paused for human approval would
    presumably still be returned with status "finished"; confirm.

    Args:
        agent_id: ID of the agent, resolved with ``agent_manager.get_agent``.
        query: The user's text input.
        config: Optional run settings. ``thread_id`` selects the conversation
            (a new UUID is generated when it is missing or empty); ``model`` is
            only echoed back as ``meta["server_model_name"]``.
        meta: Optional client metadata, updated with request details and
            returned; a ``request_id`` is generated if absent.
        image_content: Optional base64 image, sent to the model as a
            ``data:image/jpeg;base64,...`` URL.
        db: Request-scoped session, used only to look up the system user.

    Returns:
        A dict with ``request_id``, ``status`` and ``meta``. For ``"finished"``
        it also has ``response`` (concatenated AI text), ``tool_messages`` and
        ``agent_state``. For ``"error"`` it has ``error_type`` and
        ``error_message``. For ``"interrupted"`` (content guard) it has
        ``message``. Errors are reported in the body, not via HTTP status.

    Raises:
        HTTPException: 500 if the system user (id=1) does not exist.
    """
    start_time = asyncio.get_running_loop().time()
    logger.info(f"[EXTERNAL-SYNC] agent_id: {agent_id}, query: {query}, config: {config}, meta: {meta}")
    logger.info(f"[EXTERNAL-SYNC] image_content present: {image_content is not None}")
    if image_content:
        logger.info(f"[EXTERNAL-SYNC] image_content length: {len(image_content)}")
        logger.info(f"[EXTERNAL-SYNC] image_content preview: {image_content[:50]}...")

    # Look up the fixed system user that all external calls run as.
    result = await db.execute(select(User).filter(User.id == 1))
    current_user = result.scalar_one_or_none()
    if current_user is None:
        raise HTTPException(
            status_code=500,
            detail="系统用户(user_id=1)不存在,请联系管理员"
        )
    user_id = str(current_user.id)

    # Resolve thread_id once, up front, so the agent's runtime context, the
    # persisted messages and the returned meta all use the same id.
    # Building the context before generating the id would make the agent run
    # with thread_id=None while messages are saved under the generated id.
    thread_id = config.get("thread_id")
    if not thread_id:
        thread_id = str(uuid.uuid4())
        logger.warning(f"[EXTERNAL-SYNC] No thread_id provided, generated new thread_id: {thread_id}")

    # Make sure a request_id exists.
    if not meta.get("request_id"):
        meta["request_id"] = str(uuid.uuid4())
    meta.update(
        {
            "query": query,
            "agent_id": agent_id,
            "server_model_name": config.get("model", agent_id),
            "thread_id": thread_id,
            "user_id": current_user.id,
            "has_image": bool(image_content),
            "external_api": True,
            "sync_mode": True,  # mark as non-streaming mode
        }
    )

    # Build the human message: multimodal when an image is attached, plain text otherwise.
    if image_content:
        human_message = HumanMessage(
            content=[
                {"type": "text", "text": query},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_content}"}},
            ]
        )
        message_type = "multimodal_image"
    else:
        human_message = HumanMessage(content=query)
        message_type = "text"

    # Input guard.
    if conf.enable_content_guard and await content_guard.check(query):
        return {
            "request_id": meta.get("request_id"),
            "status": "error",
            "error_type": "content_guard_blocked",
            "error_message": "输入内容包含敏感词",
            "meta": meta,
        }

    try:
        agent = agent_manager.get_agent(agent_id)
    except Exception as e:
        logger.error(f"Error getting agent {agent_id}: {e}, {traceback.format_exc()}")
        return {
            "request_id": meta.get("request_id"),
            "status": "error",
            "error_type": "agent_error",
            "error_message": f"智能体 {agent_id} 获取失败: {str(e)}",
            "meta": meta,
        }

    messages = [human_message]
    input_context = {"user_id": user_id, "thread_id": thread_id}
    # Shares the input_context dict, so attachments added below are visible here too.
    langgraph_config = {"configurable": input_context}

    # Bound before the try so the except handler can always read them, even
    # when the failure happens before the agent starts streaming.
    full_msg = None
    accumulated_content = []
    # Despite the name, this collects every streamed message that is not an
    # AIMessageChunk, serialized with model_dump().
    tool_messages = []

    try:
        # A dedicated session for persistence; `db` above is only used for the user lookup.
        async with db_manager.get_async_session_context() as session:
            conv_manager = ConversationManager(session)

            # Persist the user message; best-effort (failures are only logged).
            try:
                await conv_manager.add_message_by_thread_id(
                    thread_id=thread_id,
                    role="user",
                    content=query,
                    message_type=message_type,
                    image_content=image_content,
                    extra_metadata={"raw_message": human_message.model_dump()},
                )
            except Exception as e:
                logger.error(f"[EXTERNAL-SYNC] Error saving user message: {e}")

            try:
                attachments = await conv_manager.get_attachments_by_thread_id(thread_id)
                input_context["attachments"] = attachments
                logger.debug(f"Loaded {len(attachments)} attachments for thread_id={thread_id}")
            except Exception as e:
                logger.error(f"Error loading attachments for thread_id={thread_id}: {e}")
                input_context["attachments"] = []

            # Drain the agent stream, collecting text and non-text messages.
            async for msg, metadata in agent.stream_messages(messages, input_context=input_context):
                if isinstance(msg, AIMessageChunk):
                    accumulated_content.append(msg.content)
                    # Only the last 10 chunks are re-checked on each step.
                    content_for_check = "".join(accumulated_content[-10:])
                    if conf.enable_content_guard and await content_guard.check_with_keywords(content_for_check):
                        logger.warning("Sensitive content detected in stream")
                        full_msg = AIMessage(content="".join(accumulated_content))
                        # "content_guard_blocked" is an error *type* (as in the input
                        # guard above), so pass it by keyword rather than positionally.
                        await save_partial_message(
                            conv_manager, thread_id, full_msg=full_msg, error_type="content_guard_blocked"
                        )
                        meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                        return {
                            "request_id": meta.get("request_id"),
                            "status": "interrupted",
                            "message": "检测到敏感内容,已中断输出",
                            "meta": meta,
                        }
                else:
                    tool_messages.append(msg.model_dump())

            if not full_msg and accumulated_content:
                full_msg = AIMessage(content="".join(accumulated_content))

            # Output guard over the complete answer.
            if (
                conf.enable_content_guard
                and full_msg is not None
                and await content_guard.check(full_msg.content)
            ):
                logger.warning("Sensitive content detected in final message")
                await save_partial_message(
                    conv_manager, thread_id, full_msg=full_msg, error_type="content_guard_blocked"
                )
                meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                return {
                    "request_id": meta.get("request_id"),
                    "status": "interrupted",
                    "message": "检测到敏感内容,已中断输出",
                    "meta": meta,
                }

            # Read the final agent state (optional extra output).
            try:
                graph = await agent.get_graph()
                state = await graph.aget_state(langgraph_config)
                agent_state = _extract_agent_state(getattr(state, "values", {})) if state else {}
            except Exception as exc:
                logger.warning(f"[EXTERNAL-SYNC] Could not read agent state for thread_id={thread_id}: {exc}")
                agent_state = {}

            meta["time_cost"] = asyncio.get_running_loop().time() - start_time

            # Persist all messages from the LangGraph state.
            await save_messages_from_langgraph_state(
                agent_instance=agent,
                thread_id=thread_id,
                conv_mgr=conv_manager,
                config_dict=langgraph_config,
            )

            return {
                "request_id": meta.get("request_id"),
                "status": "finished",
                "response": full_msg.content if full_msg else "",
                "tool_messages": tool_messages,
                "agent_state": agent_state,
                "meta": meta,
            }
    except Exception as e:
        logger.error(f"Error processing messages: {e}, {traceback.format_exc()}")
        error_msg = f"Error processing messages: {e}"
        error_type = "unexpected_error"
        if not full_msg and accumulated_content:
            full_msg = AIMessage(content="".join(accumulated_content))

        # Persist the error; a failure here must not turn the structured error
        # response below into an unhandled 500.
        try:
            async with db_manager.get_async_session_context() as new_db:
                new_conv_manager = ConversationManager(new_db)
                await save_partial_message(
                    new_conv_manager,
                    thread_id,
                    full_msg=full_msg,
                    error_message=error_msg,
                    error_type=error_type,
                )
        except Exception as save_exc:
            logger.error(f"[EXTERNAL-SYNC] Error saving partial message after failure: {save_exc}")

        return {
            "request_id": meta.get("request_id"),
            "status": "error",
            "error_type": error_type,
            "error_message": error_msg,
            "meta": meta,
        }
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
enhancement: New feature or request