Skip to content

有没有计划将智能体分享给第三方使用? #502

@zbage

Description

@zbage

@chat.post("/agent/{agent_id}")
async def chat_agent
外部可以调用这个接口

# ==================================================================
# > === 外部调用接口 (无需认证) ===
# ==================================================================

@chat.post("/external/thread", response_model=ThreadResponse)
async def external_create_thread(
    thread: ThreadCreate,
    db: AsyncSession = Depends(get_db),
):
    """Create a conversation thread on behalf of an external caller.

    No login is required: the conversation is always owned by the fixed
    system user ``user_id=1``. ``"external_api": True`` is merged into the
    caller-supplied metadata so these threads can be told apart later.

    SECURITY NOTE(review): this route performs no authentication at all;
    expose it only behind a trusted gateway or add an API-key check.

    Returns:
        dict with the new thread id, owner id, agent id, title and the
        ISO-8601 ``created_at`` / ``updated_at`` timestamps.

    Raises:
        HTTPException: 500 when the system user (id=1) does not exist.
    """
    lookup = await db.execute(select(User).filter(User.id == 1))
    system_user = lookup.scalar_one_or_none()
    if system_user is None:
        raise HTTPException(
            status_code=500,
            detail="系统用户(user_id=1)不存在,请联系管理员",
        )

    new_thread_id = str(uuid.uuid4())
    logger.debug(f"[EXTERNAL] Creating thread with agent_id: {thread.agent_id}")

    # Copy the caller's metadata and tag it as coming from the external API.
    tagged_metadata = dict(thread.metadata or {})
    tagged_metadata["external_api"] = True

    manager = ConversationManager(db)
    conversation = await manager.create_conversation(
        user_id=str(system_user.id),
        agent_id=thread.agent_id,
        title=thread.title or "新的对话",
        thread_id=new_thread_id,
        metadata=tagged_metadata,
    )
    logger.info(f"[EXTERNAL] Created conversation with thread_id: {new_thread_id}")

    created, updated = conversation.created_at, conversation.updated_at
    return dict(
        id=conversation.thread_id,
        user_id=conversation.user_id,
        agent_id=conversation.agent_id,
        title=conversation.title,
        created_at=created.isoformat(),
        updated_at=updated.isoformat(),
    )


@chat.post("/external/agent/{agent_id}")
async def external_chat_agent(
    agent_id: str,
    query: str = Body(...),
    config: dict = Body({}),
    meta: dict = Body({}),
    image_content: str | None = Body(None),
    db: AsyncSession = Depends(get_db),
):
    """Stream a conversation turn with a specific agent for an external caller.

    No login is required: every request runs as the fixed system user
    ``user_id=1``.

    SECURITY NOTE(review): this route performs no authentication, so any
    caller can invoke any agent as user 1 and append to any ``thread_id`` it
    names (presumably without an ownership check — verify
    ``ConversationManager.add_message_by_thread_id``). Expose it only behind a
    trusted gateway or add an API-key check.

    Args:
        agent_id: Identifier passed to ``agent_manager.get_agent``.
        query: User text sent to the agent.
        config: Optional run config; ``thread_id`` and ``model`` are read.
            When ``thread_id`` is missing a new uuid4 is generated and echoed
            back in ``meta["thread_id"]``.
        meta: Optional client metadata echoed in chunks; a ``request_id`` is
            generated when absent.
        image_content: Optional base64 image, sent to the model as a
            ``data:image/jpeg;base64,...`` URL.
        db: Request-scoped session, used only to look up the system user.

    Returns:
        ``StreamingResponse`` of newline-delimited JSON chunks. Each chunk has
        ``request_id``, ``response`` and a ``status`` of ``init``,
        ``loading``, ``agent_state``, ``finished``, ``interrupted`` or
        ``error`` (plus whatever chunks ``check_and_handle_interrupts`` emits).

    Raises:
        HTTPException: 500 when the system user (id=1) does not exist.
    """
    start_time = asyncio.get_running_loop().time()

    logger.info(f"[EXTERNAL] agent_id: {agent_id}, query: {query}, config: {config}, meta: {meta}")
    logger.info(f"[EXTERNAL] image_content present: {image_content is not None}")
    if image_content:
        logger.info(f"[EXTERNAL] image_content length: {len(image_content)}")
        logger.info(f"[EXTERNAL] image_content preview: {image_content[:50]}...")

    # All external calls run as the fixed system user (user_id=1).
    result = await db.execute(select(User).filter(User.id == 1))
    current_user = result.scalar_one_or_none()
    if current_user is None:
        raise HTTPException(
            status_code=500,
            detail="系统用户(user_id=1)不存在,请联系管理员"
        )
    # Read plain values now: the generator below only runs after this handler
    # has returned the StreamingResponse, so avoid touching ORM objects there.
    user_db_id = current_user.id
    user_id = str(user_db_id)

    # Ensure every request carries a request_id.
    if not meta.get("request_id"):
        meta["request_id"] = str(uuid.uuid4())

    # Resolve the thread id before anything uses it, so that persistence, the
    # LangGraph runtime config and the metadata returned to the client all
    # share the same id (a generated id must not be lost as ``None``).
    thread_id = config.get("thread_id")
    if not thread_id:
        thread_id = str(uuid.uuid4())
        logger.warning(f"[EXTERNAL] No thread_id provided, generated new thread_id: {thread_id}")

    meta.update(
        {
            "query": query,
            "agent_id": agent_id,
            "server_model_name": config.get("model", agent_id),
            "thread_id": thread_id,
            "user_id": user_db_id,
            "has_image": bool(image_content),
            "external_api": True,  # mark as an external API call
        }
    )

    def make_chunk(content=None, **kwargs):
        """Serialize one NDJSON line tagged with this request's ``request_id``."""
        return (
            json.dumps(
                {"request_id": meta.get("request_id"), "response": content, **kwargs}, ensure_ascii=False
            ).encode("utf-8")
            + b"\n"
        )

    async def stream_messages():
        # Build the user message (multimodal when an image is attached).
        if image_content:
            human_message = HumanMessage(
                content=[
                    {"type": "text", "text": query},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_content}"}},
                ]
            )
            message_type = "multimodal_image"
        else:
            human_message = HumanMessage(content=query)
            message_type = "text"

        # Acknowledge receipt with a frontend-friendly echo of the user message.
        init_msg = {"role": "user", "content": query, "type": "human", "message_type": message_type}
        if image_content:
            init_msg["image_content"] = image_content
        yield make_chunk(status="init", meta=meta, msg=init_msg)

        # Input guard
        if conf.enable_content_guard and await content_guard.check(query):
            yield make_chunk(
                status="error", error_type="content_guard_blocked", error_message="输入内容包含敏感词", meta=meta
            )
            return

        try:
            agent = agent_manager.get_agent(agent_id)
        except Exception as e:
            logger.error(f"Error getting agent {agent_id}: {e}, {traceback.format_exc()}")
            yield make_chunk(
                status="error",
                error_type="agent_error",
                error_message=f"智能体 {agent_id} 获取失败: {str(e)}",
                meta=meta,
            )
            return

        messages = [human_message]
        input_context = {"user_id": user_id, "thread_id": thread_id}
        # Shares the input_context dict, so attachments added below are visible.
        langgraph_config = {"configurable": input_context}

        # Bound before the try so every except handler (and save_cleanup's
        # ``nonlocal``) can reference them even if setup fails early.
        full_msg = None
        accumulated_content = []

        try:
            async with db_manager.get_async_session_context() as session:
                conv_manager = ConversationManager(session)

                # Persist the user message; failures are logged, not fatal.
                try:
                    await conv_manager.add_message_by_thread_id(
                        thread_id=thread_id,
                        role="user",
                        content=query,
                        message_type=message_type,
                        image_content=image_content,
                        extra_metadata={"raw_message": human_message.model_dump()},
                    )
                except Exception as e:
                    logger.error(f"[EXTERNAL] Error saving user message: {e}")

                try:
                    attachments = await conv_manager.get_attachments_by_thread_id(thread_id)
                    input_context["attachments"] = attachments
                    logger.debug(f"Loaded {len(attachments)} attachments for thread_id={thread_id}")
                except Exception as e:
                    logger.error(f"Error loading attachments for thread_id={thread_id}: {e}")
                    input_context["attachments"] = []

                async for msg, metadata in agent.stream_messages(messages, input_context=input_context):
                    if isinstance(msg, AIMessageChunk):
                        accumulated_content.append(msg.content)

                        # Keyword-check only the recent tail to keep the check cheap.
                        content_for_check = "".join(accumulated_content[-10:])
                        if conf.enable_content_guard and await content_guard.check_with_keywords(content_for_check):
                            logger.warning("Sensitive content detected in stream")
                            full_msg = AIMessage(content="".join(accumulated_content))
                            await save_partial_message(conv_manager, thread_id, full_msg, "content_guard_blocked")
                            meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                            yield make_chunk(status="interrupted", message="检测到敏感内容,已中断输出", meta=meta)
                            return

                        yield make_chunk(content=msg.content, msg=msg.model_dump(), metadata=metadata, status="loading")
                    else:
                        yield make_chunk(msg=msg.model_dump(), metadata=metadata, status="loading")

                if not full_msg and accumulated_content:
                    full_msg = AIMessage(content="".join(accumulated_content))

                # Full-text guard on the assembled answer.
                if (
                    conf.enable_content_guard
                    and full_msg is not None
                    and await content_guard.check(full_msg.content)
                ):
                    logger.warning("Sensitive content detected in final message")
                    await save_partial_message(conv_manager, thread_id, full_msg, "content_guard_blocked")
                    meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                    yield make_chunk(status="interrupted", message="检测到敏感内容,已中断输出", meta=meta)
                    return

                # After streaming: surface any human-approval interrupts.
                async for chunk in check_and_handle_interrupts(agent, langgraph_config, make_chunk, meta, thread_id):
                    yield chunk

                meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                try:
                    graph = await agent.get_graph()
                    state = await graph.aget_state(langgraph_config)
                    agent_state = _extract_agent_state(getattr(state, "values", {})) if state else {}
                except Exception:
                    agent_state = {}

                if agent_state:
                    yield make_chunk(status="agent_state", agent_state=agent_state, meta=meta)

                yield make_chunk(status="finished", meta=meta)

                # Persist every message recorded in the LangGraph state.
                await save_messages_from_langgraph_state(
                    agent_instance=agent,
                    thread_id=thread_id,
                    conv_mgr=conv_manager,
                    config_dict=langgraph_config,
                )

        except (asyncio.CancelledError, ConnectionError) as e:
            # Client disconnected: save whatever partial content exists.
            logger.warning(f"Client disconnected, cancelling stream: {e}")

            async def save_cleanup():
                nonlocal full_msg
                if not full_msg and accumulated_content:
                    full_msg = AIMessage(content="".join(accumulated_content))

                async with db_manager.get_async_session_context() as new_db:
                    new_conv_manager = ConversationManager(new_db)
                    await save_partial_message(
                        new_conv_manager,
                        thread_id,
                        full_msg=full_msg,
                        error_message="对话已中断" if not full_msg else None,
                        error_type="interrupted",
                    )

            # Run the save as a shielded task so the DB write completes even
            # though this generator is being cancelled.
            cleanup_task = asyncio.create_task(save_cleanup())
            try:
                await asyncio.shield(cleanup_task)
            except asyncio.CancelledError:
                pass
            except Exception as exc:
                logger.error(f"Error during cleanup save: {exc}")

            # Notify the client (may not be delivered; kept for consistency).
            yield make_chunk(status="interrupted", message="对话已中断", meta=meta)

        except Exception as e:
            logger.error(f"Error streaming messages: {e}, {traceback.format_exc()}")

            error_msg = f"Error streaming messages: {e}"
            error_type = "unexpected_error"

            if not full_msg and accumulated_content:
                full_msg = AIMessage(content="".join(accumulated_content))

            # Best-effort persistence: a failing save must not stop the error
            # chunk from reaching the client.
            try:
                async with db_manager.get_async_session_context() as new_db:
                    new_conv_manager = ConversationManager(new_db)
                    await save_partial_message(
                        new_conv_manager,
                        thread_id,
                        full_msg=full_msg,
                        error_message=error_msg,
                        error_type=error_type,
                    )
            except Exception as save_exc:
                logger.error(f"[EXTERNAL] Error saving error message: {save_exc}")

            yield make_chunk(status="error", error_type=error_type, error_message=error_msg, meta=meta)

    return StreamingResponse(stream_messages(), media_type="application/json")


@chat.post("/external/agent/{agent_id}/sync")
async def external_chat_agent_sync(
    agent_id: str,
    query: str = Body(...),
    config: dict = Body({}),
    meta: dict = Body({}),
    image_content: str | None = Body(None),
    db: AsyncSession = Depends(get_db),
):
    """Run one non-streaming conversation turn with an agent for an external caller.

    No login is required: every request runs as the fixed system user
    ``user_id=1``.

    SECURITY NOTE(review): this route performs no authentication, so any
    caller can invoke any agent as user 1 and append to any ``thread_id`` it
    names. Expose it only behind a trusted gateway or add an API-key check.

    NOTE(review): unlike the streaming endpoint, human-approval interrupts
    (``check_and_handle_interrupts``) are not surfaced here — confirm intended.

    Args:
        agent_id: Identifier passed to ``agent_manager.get_agent``.
        query: User text sent to the agent.
        config: Optional run config; ``thread_id`` and ``model`` are read.
            When ``thread_id`` is missing a new uuid4 is generated and
            returned in ``meta["thread_id"]``.
        meta: Optional client metadata echoed back; a ``request_id`` is
            generated when absent.
        image_content: Optional base64 image, sent to the model as a
            ``data:image/jpeg;base64,...`` URL.
        db: Request-scoped session, used only to look up the system user.

    Returns:
        dict with ``request_id``, ``status`` (``finished``, ``interrupted`` or
        ``error``) and ``meta``. On success it also contains ``response`` (the
        concatenated AI text), ``tool_messages`` (every non-chunk message
        dumped via ``model_dump``) and ``agent_state``.

    Raises:
        HTTPException: 500 when the system user (id=1) does not exist.
    """
    start_time = asyncio.get_running_loop().time()

    logger.info(f"[EXTERNAL-SYNC] agent_id: {agent_id}, query: {query}, config: {config}, meta: {meta}")
    logger.info(f"[EXTERNAL-SYNC] image_content present: {image_content is not None}")
    if image_content:
        logger.info(f"[EXTERNAL-SYNC] image_content length: {len(image_content)}")
        logger.info(f"[EXTERNAL-SYNC] image_content preview: {image_content[:50]}...")

    # All external calls run as the fixed system user (user_id=1).
    result = await db.execute(select(User).filter(User.id == 1))
    current_user = result.scalar_one_or_none()
    if current_user is None:
        raise HTTPException(
            status_code=500,
            detail="系统用户(user_id=1)不存在,请联系管理员"
        )
    user_id = str(current_user.id)

    # Ensure every request carries a request_id.
    if not meta.get("request_id"):
        meta["request_id"] = str(uuid.uuid4())

    # Resolve the thread id before anything uses it, so that persistence, the
    # LangGraph runtime config and the returned metadata share the same id.
    thread_id = config.get("thread_id")
    if not thread_id:
        thread_id = str(uuid.uuid4())
        logger.warning(f"[EXTERNAL-SYNC] No thread_id provided, generated new thread_id: {thread_id}")

    meta.update(
        {
            "query": query,
            "agent_id": agent_id,
            "server_model_name": config.get("model", agent_id),
            "thread_id": thread_id,
            "user_id": current_user.id,
            "has_image": bool(image_content),
            "external_api": True,
            "sync_mode": True,  # mark as synchronous (non-streaming) mode
        }
    )

    # Build the user message (multimodal when an image is attached).
    if image_content:
        human_message = HumanMessage(
            content=[
                {"type": "text", "text": query},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_content}"}},
            ]
        )
        message_type = "multimodal_image"
    else:
        human_message = HumanMessage(content=query)
        message_type = "text"

    # Input guard
    if conf.enable_content_guard and await content_guard.check(query):
        return {
            "request_id": meta.get("request_id"),
            "status": "error",
            "error_type": "content_guard_blocked",
            "error_message": "输入内容包含敏感词",
            "meta": meta,
        }

    try:
        agent = agent_manager.get_agent(agent_id)
    except Exception as e:
        logger.error(f"Error getting agent {agent_id}: {e}, {traceback.format_exc()}")
        return {
            "request_id": meta.get("request_id"),
            "status": "error",
            "error_type": "agent_error",
            "error_message": f"智能体 {agent_id} 获取失败: {str(e)}",
            "meta": meta,
        }

    messages = [human_message]
    input_context = {"user_id": user_id, "thread_id": thread_id}
    # Shares the input_context dict, so attachments added below are visible.
    langgraph_config = {"configurable": input_context}

    # Bound before the try so the except handler can always reference them,
    # even if opening the DB session fails.
    full_msg = None
    accumulated_content = []
    tool_messages = []

    try:
        async with db_manager.get_async_session_context() as session:
            conv_manager = ConversationManager(session)

            # Persist the user message; failures are logged, not fatal.
            try:
                await conv_manager.add_message_by_thread_id(
                    thread_id=thread_id,
                    role="user",
                    content=query,
                    message_type=message_type,
                    image_content=image_content,
                    extra_metadata={"raw_message": human_message.model_dump()},
                )
            except Exception as e:
                logger.error(f"[EXTERNAL-SYNC] Error saving user message: {e}")

            try:
                attachments = await conv_manager.get_attachments_by_thread_id(thread_id)
                input_context["attachments"] = attachments
                logger.debug(f"Loaded {len(attachments)} attachments for thread_id={thread_id}")
            except Exception as e:
                logger.error(f"Error loading attachments for thread_id={thread_id}: {e}")
                input_context["attachments"] = []

            # Drain the agent stream, accumulating text and collecting the rest.
            async for msg, metadata in agent.stream_messages(messages, input_context=input_context):
                if isinstance(msg, AIMessageChunk):
                    accumulated_content.append(msg.content)

                    # Keyword-check only the recent tail to keep the check cheap.
                    content_for_check = "".join(accumulated_content[-10:])
                    if conf.enable_content_guard and await content_guard.check_with_keywords(content_for_check):
                        logger.warning("Sensitive content detected in stream")
                        full_msg = AIMessage(content="".join(accumulated_content))
                        await save_partial_message(conv_manager, thread_id, full_msg, "content_guard_blocked")
                        meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                        return {
                            "request_id": meta.get("request_id"),
                            "status": "interrupted",
                            "message": "检测到敏感内容,已中断输出",
                            "meta": meta,
                        }
                else:
                    # Non-chunk messages (e.g. tool results) are returned verbatim.
                    tool_messages.append(msg.model_dump())

            if not full_msg and accumulated_content:
                full_msg = AIMessage(content="".join(accumulated_content))

            # Full-text guard on the assembled answer.
            if (
                conf.enable_content_guard
                and full_msg is not None
                and await content_guard.check(full_msg.content)
            ):
                logger.warning("Sensitive content detected in final message")
                await save_partial_message(conv_manager, thread_id, full_msg, "content_guard_blocked")
                meta["time_cost"] = asyncio.get_running_loop().time() - start_time
                return {
                    "request_id": meta.get("request_id"),
                    "status": "interrupted",
                    "message": "检测到敏感内容,已中断输出",
                    "meta": meta,
                }

            try:
                graph = await agent.get_graph()
                state = await graph.aget_state(langgraph_config)
                agent_state = _extract_agent_state(getattr(state, "values", {})) if state else {}
            except Exception:
                agent_state = {}

            meta["time_cost"] = asyncio.get_running_loop().time() - start_time

            # Persist every message recorded in the LangGraph state.
            await save_messages_from_langgraph_state(
                agent_instance=agent,
                thread_id=thread_id,
                conv_mgr=conv_manager,
                config_dict=langgraph_config,
            )

            return {
                "request_id": meta.get("request_id"),
                "status": "finished",
                "response": full_msg.content if full_msg else "",
                "tool_messages": tool_messages,
                "agent_state": agent_state,
                "meta": meta,
            }

    except Exception as e:
        logger.error(f"Error processing messages: {e}, {traceback.format_exc()}")

        error_msg = f"Error processing messages: {e}"
        error_type = "unexpected_error"

        if not full_msg and accumulated_content:
            full_msg = AIMessage(content="".join(accumulated_content))

        # Best-effort persistence: a failing save must not replace the error
        # response with an unhandled exception.
        try:
            async with db_manager.get_async_session_context() as new_db:
                new_conv_manager = ConversationManager(new_db)
                await save_partial_message(
                    new_conv_manager,
                    thread_id,
                    full_msg=full_msg,
                    error_message=error_msg,
                    error_type=error_type,
                )
        except Exception as save_exc:
            logger.error(f"[EXTERNAL-SYNC] Error saving error message: {save_exc}")

        meta["time_cost"] = asyncio.get_running_loop().time() - start_time
        return {
            "request_id": meta.get("request_id"),
            "status": "error",
            "error_type": error_type,
            "error_message": error_msg,
            "meta": meta,
        }

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement — New feature or request

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions