From dcc108574cd4ee553b1e14268a9df14d3733933c Mon Sep 17 00:00:00 2001 From: Eularioal <1807616715@qq.com> Date: Sun, 5 Apr 2026 19:54:14 +0800 Subject: [PATCH] processing workflow --- fastapi_app/main.py | 3 +- fastapi_app/routers/__init__.py | 4 +- fastapi_app/routers/table_processing.py | 55 ++ .../services/table_processing_service.py | 109 ++++ frontend_zh/package-lock.json | 224 +++---- frontend_zh/src/components/SettingsModal.tsx | 53 +- frontend_zh/src/pages/Dashboard.tsx | 51 +- frontend_zh/src/pages/NotebookView.tsx | 433 +++++++++++++- frontend_zh/src/types/index.ts | 2 +- package-lock.json | 7 + .../agentroles/cores/base_agent.py | 7 +- workflow_engine/agentroles/table_agents.py | 147 +++++ workflow_engine/constants.py | 58 ++ workflow_engine/llm_callers/text.py | 94 ++- .../resources/pt_table_agent_repo.py | 363 ++++++++++++ workflow_engine/state.py | 138 ++++- workflow_engine/table_agent_utils.py | 366 ++++++++++++ .../workflow/wf_table_processing_api.py | 226 +++++++ .../workflow/wf_table_processing_workflow.py | 557 ++++++++++++++++++ workflow_engine/workflow/wf_table_strategy.py | 295 ++++++++++ 20 files changed, 3009 insertions(+), 183 deletions(-) create mode 100644 fastapi_app/routers/table_processing.py create mode 100644 fastapi_app/services/table_processing_service.py create mode 100644 package-lock.json create mode 100644 workflow_engine/agentroles/table_agents.py create mode 100644 workflow_engine/constants.py create mode 100644 workflow_engine/promptstemplates/resources/pt_table_agent_repo.py create mode 100644 workflow_engine/table_agent_utils.py create mode 100644 workflow_engine/workflow/wf_table_processing_api.py create mode 100644 workflow_engine/workflow/wf_table_processing_workflow.py create mode 100644 workflow_engine/workflow/wf_table_strategy.py diff --git a/fastapi_app/main.py b/fastapi_app/main.py index fdbd9b1..6842cb5 100644 --- a/fastapi_app/main.py +++ b/fastapi_app/main.py @@ -40,7 +40,7 @@ from fastapi.middleware.cors 
import CORSMiddleware from fastapi.responses import FileResponse -from fastapi_app.routers import auth, data_extract, files, kb, kb_embedding, paper2drawio, paper2ppt +from fastapi_app.routers import auth, data_extract, files, kb, kb_embedding, paper2drawio, paper2ppt, table_processing from fastapi_app.middleware.api_key import APIKeyMiddleware from fastapi_app.middleware.logging import LoggingMiddleware from workflow_engine.utils import get_project_root @@ -473,6 +473,7 @@ def create_app() -> FastAPI: app.include_router(kb_embedding.router, prefix="/api/v1", tags=["Knowledge Base Embedding"]) app.include_router(files.router, prefix="/api/v1", tags=["Files"]) app.include_router(data_extract.router, prefix="/api/v1", tags=["Data Extract"]) + app.include_router(table_processing.router, prefix="/api/v1", tags=["Table Processing"]) app.include_router(paper2drawio.router, prefix="/api/v1", tags=["Paper2Drawio"]) app.include_router(paper2ppt.router, prefix="/api/v1", tags=["Paper2PPT"]) app.include_router(auth.router, prefix="/api/v1", tags=["Auth"]) diff --git a/fastapi_app/routers/__init__.py b/fastapi_app/routers/__init__.py index b536a40..ba5f3d4 100644 --- a/fastapi_app/routers/__init__.py +++ b/fastapi_app/routers/__init__.py @@ -4,6 +4,6 @@ Router package for FastAPI backend (Notebook / frontend-v2). """ -from . import auth, data_extract, files, kb, kb_embedding, paper2drawio, paper2ppt +from . 
import auth, data_extract, files, kb, kb_embedding, paper2drawio, paper2ppt, table_processing -__all__ = ["auth", "data_extract", "kb", "kb_embedding", "files", "paper2drawio", "paper2ppt"] +__all__ = ["auth", "data_extract", "kb", "kb_embedding", "files", "paper2drawio", "paper2ppt", "table_processing"] diff --git a/fastapi_app/routers/table_processing.py b/fastapi_app/routers/table_processing.py new file mode 100644 index 0000000..226b210 --- /dev/null +++ b/fastapi_app/routers/table_processing.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter +from pydantic import BaseModel, Field + +from fastapi_app.services.table_processing_service import TableProcessingService +from workflow_engine.logger import get_logger + +logger = get_logger(__name__) +router = APIRouter(prefix="/table-processing", tags=["Table Processing"]) + +class DataSourceRef(BaseModel): + name: str + url: str + + +class TableProcessingRequest(BaseModel): + notebook_id: str + notebook_title: str = "" + user_id: str = "local" + email: Optional[str] = None + datasources: List[DataSourceRef] + instruction: str + output_format: str = Field(default="csv", pattern="^(json|csv|markdown|dict)$") + title: str = "" + # 用户指定的 API 配置 + api_key: Optional[str] = None + api_url: Optional[str] = None + model: Optional[str] = "gpt-4o" + + +def _effective_user_id(user_id: str, email: Optional[str]) -> str: + return (email or user_id or "local").strip() or "local" + + +@router.post("/process") +async def process_table(request: TableProcessingRequest) -> Dict[str, Any]: + svc = TableProcessingService() + result = await svc.process_table( + notebook_id=request.notebook_id, + notebook_title=request.notebook_title, + user_id=_effective_user_id(request.user_id, request.email), + datasources=[ds.model_dump() for ds in request.datasources], + instruction=request.instruction, + output_format=request.output_format, + 
title=request.title, + api_key=request.api_key, + api_url=request.api_url, + model=request.model, + ) + logger.info("datasources: %s", request.datasources) + return {"success": True, **result} diff --git a/fastapi_app/services/table_processing_service.py b/fastapi_app/services/table_processing_service.py new file mode 100644 index 0000000..b7c610a --- /dev/null +++ b/fastapi_app/services/table_processing_service.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List, Optional + +from fastapi import HTTPException + +from fastapi_app.utils import _from_outputs_url, _to_outputs_url +from workflow_engine.state import TableProcessingRequest, TableProcessingState +from workflow_engine.workflow import run_workflow + + +class TableProcessingService: + """独立的 Table Processing 服务:不再复用 DataExtractService 的 session/message。""" + + async def process_table( + self, + notebook_id: str, + notebook_title: str, + user_id: str, + datasources: List[Dict[str, Any]], + instruction: str, + output_format: str = "csv", + title: str = "", + api_key: Optional[str] = None, + api_url: Optional[str] = None, + model: Optional[str] = "gpt-4o", + ) -> Dict[str, Any]: + if not instruction or not instruction.strip(): + raise HTTPException(status_code=400, detail="instruction is required") + + # datasources 来自前端:[{name,url},...] 
+ csv_paths: List[str] = [] + for ds in datasources or []: + url = (ds.get("url") or "").strip() + if not url: + continue + resolved = _from_outputs_url(url) + if resolved and isinstance(resolved, str): + csv_paths.append(resolved) + + if not csv_paths: + raise HTTPException(status_code=400, detail="datasources 不能为空/无可用 url") + + # workflow_engine 会在内部把结果整理成 content/sql/columns/rows/row_count + req = TableProcessingRequest( + datasources=csv_paths, + instruction=instruction, + output_format=output_format, + title=title or "智能表格处理", + api_key=api_key, + chat_api_url=api_url, + model=model or "gpt-4o", + notebook_id=notebook_id, + ) + state = TableProcessingState(request=req) + + result_state = await run_workflow("table_processing_api", state) + + if isinstance(result_state, dict): + content = str(result_state.get("content") or "") + sql = str(result_state.get("sql") or "") + columns = result_state.get("columns") or [] + rows = result_state.get("rows") or [] + row_count = int(result_state.get("row_count") or 0) + error = str(result_state.get("error") or "") + result_path = str(result_state.get("result_path") or "") + else: + content = str(getattr(result_state, "content", "") or "") + sql = str(getattr(result_state, "sql", "") or "") + columns = getattr(result_state, "columns", []) or [] + rows = getattr(result_state, "rows", []) or [] + row_count = int(getattr(result_state, "row_count", 0) or 0) + error = str(getattr(result_state, "error", "") or "") + result_path = str(getattr(result_state, "result_path", "") or "") + + # 转换 result_path 为可下载的 URL + processed_file_url = "" + if result_path: + # 查找 result_path 目录下的 CSV 文件 + result_dir = Path(result_path) + if result_dir.exists(): + for f in result_dir.rglob("*.csv"): + if f.is_file(): + processed_file_url = _to_outputs_url(str(f)) + break + + if error: + # error 由 workflow 层填充时,前端展示会走 content(通常为失败提示) + return { + "success": False, + "content": content or "处理失败,请稍后重试。", + "sql": sql or "", + "columns": columns, + 
"rows": rows, + "row_count": row_count, + "error": error, + "processed_file_url": processed_file_url, + } + + return { + "success": True, + "content": content, + "sql": sql or "", + "columns": columns, + "rows": rows, + "row_count": row_count, + "processed_file_url": processed_file_url, + } diff --git a/frontend_zh/package-lock.json b/frontend_zh/package-lock.json index 873bc08..fc43eb7 100644 --- a/frontend_zh/package-lock.json +++ b/frontend_zh/package-lock.json @@ -1088,9 +1088,9 @@ "license": "MIT" }, "node_modules/@rollup/rollup-android-arm-eabi": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.57.1.tgz", - "integrity": "sha512-A6ehUVSiSaaliTxai040ZpZ2zTevHYbvu/lDoeAteHI8QnaosIzm4qwtezfRg1jOYaUmnzLX1AOD6Z+UJjtifg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.60.0.tgz", + "integrity": "sha512-WOhNW9K8bR3kf4zLxbfg6Pxu2ybOUbB2AjMDHSQx86LIF4rH4Ft7vmMwNt0loO0eonglSNy4cpD3MKXXKQu0/A==", "cpu": [ "arm" ], @@ -1102,9 +1102,9 @@ ] }, "node_modules/@rollup/rollup-android-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.57.1.tgz", - "integrity": "sha512-dQaAddCY9YgkFHZcFNS/606Exo8vcLHwArFZ7vxXq4rigo2bb494/xKMMwRRQW6ug7Js6yXmBZhSBRuBvCCQ3w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.60.0.tgz", + "integrity": "sha512-u6JHLll5QKRvjciE78bQXDmqRqNs5M/3GVqZeMwvmjaNODJih/WIrJlFVEihvV0MiYFmd+ZyPr9wxOVbPAG2Iw==", "cpu": [ "arm64" ], @@ -1116,9 +1116,9 @@ ] }, "node_modules/@rollup/rollup-darwin-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.57.1.tgz", - "integrity": "sha512-crNPrwJOrRxagUYeMn/DZwqN88SDmwaJ8Cvi/TN1HnWBU7GwknckyosC2gd0IqYRsHDEnXf328o9/HC6OkPgOg==", + "version": "4.60.0", 
+ "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.60.0.tgz", + "integrity": "sha512-qEF7CsKKzSRc20Ciu2Zw1wRrBz4g56F7r/vRwY430UPp/nt1x21Q/fpJ9N5l47WWvJlkNCPJz3QRVw008fi7yA==", "cpu": [ "arm64" ], @@ -1130,9 +1130,9 @@ ] }, "node_modules/@rollup/rollup-darwin-x64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.57.1.tgz", - "integrity": "sha512-Ji8g8ChVbKrhFtig5QBV7iMaJrGtpHelkB3lsaKzadFBe58gmjfGXAOfI5FV0lYMH8wiqsxKQ1C9B0YTRXVy4w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.60.0.tgz", + "integrity": "sha512-WADYozJ4QCnXCH4wPB+3FuGmDPoFseVCUrANmA5LWwGmC6FL14BWC7pcq+FstOZv3baGX65tZ378uT6WG8ynTw==", "cpu": [ "x64" ], @@ -1144,9 +1144,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.57.1.tgz", - "integrity": "sha512-R+/WwhsjmwodAcz65guCGFRkMb4gKWTcIeLy60JJQbXrJ97BOXHxnkPFrP+YwFlaS0m+uWJTstrUA9o+UchFug==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.60.0.tgz", + "integrity": "sha512-6b8wGHJlDrGeSE3aH5mGNHBjA0TTkxdoNHik5EkvPHCt351XnigA4pS7Wsj/Eo9Y8RBU6f35cjN9SYmCFBtzxw==", "cpu": [ "arm64" ], @@ -1158,9 +1158,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-x64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.57.1.tgz", - "integrity": "sha512-IEQTCHeiTOnAUC3IDQdzRAGj3jOAYNr9kBguI7MQAAZK3caezRrg0GxAb6Hchg4lxdZEI5Oq3iov/w/hnFWY9Q==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.60.0.tgz", + "integrity": "sha512-h25Ga0t4jaylMB8M/JKAyrvvfxGRjnPQIR8lnCayyzEjEOx2EJIlIiMbhpWxDRKGKF8jbNH01NnN663dH638mA==", "cpu": [ "x64" ], @@ -1172,9 +1172,9 @@ ] }, 
"node_modules/@rollup/rollup-linux-arm-gnueabihf": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.57.1.tgz", - "integrity": "sha512-F8sWbhZ7tyuEfsmOxwc2giKDQzN3+kuBLPwwZGyVkLlKGdV1nvnNwYD0fKQ8+XS6hp9nY7B+ZeK01EBUE7aHaw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.60.0.tgz", + "integrity": "sha512-RzeBwv0B3qtVBWtcuABtSuCzToo2IEAIQrcyB/b2zMvBWVbjo8bZDjACUpnaafaxhTw2W+imQbP2BD1usasK4g==", "cpu": [ "arm" ], @@ -1186,9 +1186,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-musleabihf": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.57.1.tgz", - "integrity": "sha512-rGfNUfn0GIeXtBP1wL5MnzSj98+PZe/AXaGBCRmT0ts80lU5CATYGxXukeTX39XBKsxzFpEeK+Mrp9faXOlmrw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.60.0.tgz", + "integrity": "sha512-Sf7zusNI2CIU1HLzuu9Tc5YGAHEZs5Lu7N1ssJG4Tkw6e0MEsN7NdjUDDfGNHy2IU+ENyWT+L2obgWiguWibWQ==", "cpu": [ "arm" ], @@ -1200,9 +1200,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.57.1.tgz", - "integrity": "sha512-MMtej3YHWeg/0klK2Qodf3yrNzz6CGjo2UntLvk2RSPlhzgLvYEB3frRvbEF2wRKh1Z2fDIg9KRPe1fawv7C+g==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.60.0.tgz", + "integrity": "sha512-DX2x7CMcrJzsE91q7/O02IJQ5/aLkVtYFryqCjduJhUfGKG6yJV8hxaw8pZa93lLEpPTP/ohdN4wFz7yp/ry9A==", "cpu": [ "arm64" ], @@ -1214,9 +1214,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-musl": { - "version": "4.57.1", - "resolved": 
"https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.57.1.tgz", - "integrity": "sha512-1a/qhaaOXhqXGpMFMET9VqwZakkljWHLmZOX48R0I/YLbhdxr1m4gtG1Hq7++VhVUmf+L3sTAf9op4JlhQ5u1Q==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.60.0.tgz", + "integrity": "sha512-09EL+yFVbJZlhcQfShpswwRZ0Rg+z/CsSELFCnPt3iK+iqwGsI4zht3secj5vLEs957QvFFXnzAT0FFPIxSrkQ==", "cpu": [ "arm64" ], @@ -1228,9 +1228,9 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.57.1.tgz", - "integrity": "sha512-QWO6RQTZ/cqYtJMtxhkRkidoNGXc7ERPbZN7dVW5SdURuLeVU7lwKMpo18XdcmpWYd0qsP1bwKPf7DNSUinhvA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.60.0.tgz", + "integrity": "sha512-i9IcCMPr3EXm8EQg5jnja0Zyc1iFxJjZWlb4wr7U2Wx/GrddOuEafxRdMPRYVaXjgbhvqalp6np07hN1w9kAKw==", "cpu": [ "loong64" ], @@ -1242,9 +1242,9 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.57.1.tgz", - "integrity": "sha512-xpObYIf+8gprgWaPP32xiN5RVTi/s5FCR+XMXSKmhfoJjrpRAjCuuqQXyxUa/eJTdAE6eJ+KDKaoEqjZQxh3Gw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.60.0.tgz", + "integrity": "sha512-DGzdJK9kyJ+B78MCkWeGnpXJ91tK/iKA6HwHxF4TAlPIY7GXEvMe8hBFRgdrR9Ly4qebR/7gfUs9y2IoaVEyog==", "cpu": [ "loong64" ], @@ -1256,9 +1256,9 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.57.1.tgz", - "integrity": 
"sha512-4BrCgrpZo4hvzMDKRqEaW1zeecScDCR+2nZ86ATLhAoJ5FQ+lbHVD3ttKe74/c7tNT9c6F2viwB3ufwp01Oh2w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.60.0.tgz", + "integrity": "sha512-RwpnLsqC8qbS8z1H1AxBA1H6qknR4YpPR9w2XX0vo2Sz10miu57PkNcnHVaZkbqyw/kUWfKMI73jhmfi9BRMUQ==", "cpu": [ "ppc64" ], @@ -1270,9 +1270,9 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.57.1.tgz", - "integrity": "sha512-NOlUuzesGauESAyEYFSe3QTUguL+lvrN1HtwEEsU2rOwdUDeTMJdO5dUYl/2hKf9jWydJrO9OL/XSSf65R5+Xw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.60.0.tgz", + "integrity": "sha512-Z8pPf54Ly3aqtdWC3G4rFigZgNvd+qJlOE52fmko3KST9SoGfAdSRCwyoyG05q1HrrAblLbk1/PSIV+80/pxLg==", "cpu": [ "ppc64" ], @@ -1284,9 +1284,9 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.57.1.tgz", - "integrity": "sha512-ptA88htVp0AwUUqhVghwDIKlvJMD/fmL/wrQj99PRHFRAG6Z5nbWoWG4o81Nt9FT+IuqUQi+L31ZKAFeJ5Is+A==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.60.0.tgz", + "integrity": "sha512-3a3qQustp3COCGvnP4SvrMHnPQ9d1vzCakQVRTliaz8cIp/wULGjiGpbcqrkv0WrHTEp8bQD/B3HBjzujVWLOA==", "cpu": [ "riscv64" ], @@ -1298,9 +1298,9 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.57.1.tgz", - "integrity": "sha512-S51t7aMMTNdmAMPpBg7OOsTdn4tySRQvklmL3RpDRyknk87+Sp3xaumlatU+ppQ+5raY7sSTcC2beGgvhENfuw==", + "version": "4.60.0", + "resolved": 
"https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.60.0.tgz", + "integrity": "sha512-pjZDsVH/1VsghMJ2/kAaxt6dL0psT6ZexQVrijczOf+PeP2BUqTHYejk3l6TlPRydggINOeNRhvpLa0AYpCWSQ==", "cpu": [ "riscv64" ], @@ -1312,9 +1312,9 @@ ] }, "node_modules/@rollup/rollup-linux-s390x-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.57.1.tgz", - "integrity": "sha512-Bl00OFnVFkL82FHbEqy3k5CUCKH6OEJL54KCyx2oqsmZnFTR8IoNqBF+mjQVcRCT5sB6yOvK8A37LNm/kPJiZg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.60.0.tgz", + "integrity": "sha512-3ObQs0BhvPgiUVZrN7gqCSvmFuMWvWvsjG5ayJ3Lraqv+2KhOsp+pUbigqbeWqueGIsnn+09HBw27rJ+gYK4VQ==", "cpu": [ "s390x" ], @@ -1326,9 +1326,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.57.1.tgz", - "integrity": "sha512-ABca4ceT4N+Tv/GtotnWAeXZUZuM/9AQyCyKYyKnpk4yoA7QIAuBt6Hkgpw8kActYlew2mvckXkvx0FfoInnLg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.60.0.tgz", + "integrity": "sha512-EtylprDtQPdS5rXvAayrNDYoJhIz1/vzN2fEubo3yLE7tfAw+948dO0g4M0vkTVFhKojnF+n6C8bDNe+gDRdTg==", "cpu": [ "x64" ], @@ -1340,9 +1340,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-musl": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.57.1.tgz", - "integrity": "sha512-HFps0JeGtuOR2convgRRkHCekD7j+gdAuXM+/i6kGzQtFhlCtQkpwtNzkNj6QhCDp7DRJ7+qC/1Vg2jt5iSOFw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.60.0.tgz", + "integrity": "sha512-k09oiRCi/bHU9UVFqD17r3eJR9bn03TyKraCrlz5ULFJGdJGi7VOmm9jl44vOJvRJ6P7WuBi/s2A97LxxHGIdw==", "cpu": [ "x64" ], @@ 
-1354,9 +1354,9 @@ ] }, "node_modules/@rollup/rollup-openbsd-x64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.57.1.tgz", - "integrity": "sha512-H+hXEv9gdVQuDTgnqD+SQffoWoc0Of59AStSzTEj/feWTBAnSfSD3+Dql1ZruJQxmykT/JVY0dE8Ka7z0DH1hw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.60.0.tgz", + "integrity": "sha512-1o/0/pIhozoSaDJoDcec+IVLbnRtQmHwPV730+AOD29lHEEo4F5BEUB24H0OBdhbBBDwIOSuf7vgg0Ywxdfiiw==", "cpu": [ "x64" ], @@ -1368,9 +1368,9 @@ ] }, "node_modules/@rollup/rollup-openharmony-arm64": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.57.1.tgz", - "integrity": "sha512-4wYoDpNg6o/oPximyc/NG+mYUejZrCU2q+2w6YZqrAs2UcNUChIZXjtafAiiZSUc7On8v5NyNj34Kzj/Ltk6dQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.60.0.tgz", + "integrity": "sha512-pESDkos/PDzYwtyzB5p/UoNU/8fJo68vcXM9ZW2V0kjYayj1KaaUfi1NmTUTUpMn4UhU4gTuK8gIaFO4UGuMbA==", "cpu": [ "arm64" ], @@ -1382,9 +1382,9 @@ ] }, "node_modules/@rollup/rollup-win32-arm64-msvc": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.57.1.tgz", - "integrity": "sha512-O54mtsV/6LW3P8qdTcamQmuC990HDfR71lo44oZMZlXU4tzLrbvTii87Ni9opq60ds0YzuAlEr/GNwuNluZyMQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.60.0.tgz", + "integrity": "sha512-hj1wFStD7B1YBeYmvY+lWXZ7ey73YGPcViMShYikqKT1GtstIKQAtfUI6yrzPjAy/O7pO0VLXGmUVWXQMaYgTQ==", "cpu": [ "arm64" ], @@ -1396,9 +1396,9 @@ ] }, "node_modules/@rollup/rollup-win32-ia32-msvc": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.57.1.tgz", - "integrity": 
"sha512-P3dLS+IerxCT/7D2q2FYcRdWRl22dNbrbBEtxdWhXrfIMPP9lQhb5h4Du04mdl5Woq05jVCDPCMF7Ub0NAjIew==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.60.0.tgz", + "integrity": "sha512-SyaIPFoxmUPlNDq5EHkTbiKzmSEmq/gOYFI/3HHJ8iS/v1mbugVa7dXUzcJGQfoytp9DJFLhHH4U3/eTy2Bq4w==", "cpu": [ "ia32" ], @@ -1410,9 +1410,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-gnu": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.57.1.tgz", - "integrity": "sha512-VMBH2eOOaKGtIJYleXsi2B8CPVADrh+TyNxJ4mWPnKfLB/DBUmzW+5m1xUrcwWoMfSLagIRpjUFeW5CO5hyciQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.60.0.tgz", + "integrity": "sha512-RdcryEfzZr+lAr5kRm2ucN9aVlCCa2QNq4hXelZxb8GG0NJSazq44Z3PCCc8wISRuCVnGs0lQJVX5Vp6fKA+IA==", "cpu": [ "x64" ], @@ -1424,9 +1424,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-msvc": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.57.1.tgz", - "integrity": "sha512-mxRFDdHIWRxg3UfIIAwCm6NzvxG0jDX/wBN6KsQFTvKFqqg9vTrWUE68qEjHt19A5wwx5X5aUi2zuZT7YR0jrA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.60.0.tgz", + "integrity": "sha512-PrsWNQ8BuE00O3Xsx3ALh2Df8fAj9+cvvX9AIA6o4KpATR98c9mud4XtDWVvsEuyia5U4tVSTKygawyJkjm60w==", "cpu": [ "x64" ], @@ -2611,9 +2611,9 @@ "license": "MIT" }, "node_modules/dompurify": { - "version": "3.3.1", - "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.3.1.tgz", - "integrity": "sha512-qkdCKzLNtrgPFP1Vo+98FRzJnBRGe4ffyCea9IwHB1fyxPOeNTHpLKYGd4Uk9xvNoH0ZoOjwZxNptyMwqrId1Q==", + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.3.3.tgz", + "integrity": 
"sha512-Oj6pzI2+RqBfFG+qOaOLbFXLQ90ARpcGG6UePL82bJLtdsa6CYJD7nmiU8MW9nQNOtCHV3lZ/Bzq1X0QYbBZCA==", "license": "(MPL-2.0 OR Apache-2.0)", "optionalDependencies": { "@types/trusted-types": "^2.0.7" @@ -5800,9 +5800,9 @@ "license": "ISC" }, "node_modules/picomatch": { - "version": "2.3.1", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", - "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz", + "integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==", "dev": true, "license": "MIT", "engines": { @@ -6737,9 +6737,9 @@ "license": "Unlicense" }, "node_modules/rollup": { - "version": "4.57.1", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.57.1.tgz", - "integrity": "sha512-oQL6lgK3e2QZeQ7gcgIkS2YZPg5slw37hYufJ3edKlfQSGGm8ICoxswK15ntSzF/a8+h7ekRy7k7oWc3BQ7y8A==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.0.tgz", + "integrity": "sha512-yqjxruMGBQJ2gG4HtjZtAfXArHomazDHoFwFFmZZl0r7Pdo7qCIXKqKHZc8yeoMgzJJ+pO6pEEHa+V7uzWlrAQ==", "dev": true, "license": "MIT", "dependencies": { @@ -6753,31 +6753,31 @@ "npm": ">=8.0.0" }, "optionalDependencies": { - "@rollup/rollup-android-arm-eabi": "4.57.1", - "@rollup/rollup-android-arm64": "4.57.1", - "@rollup/rollup-darwin-arm64": "4.57.1", - "@rollup/rollup-darwin-x64": "4.57.1", - "@rollup/rollup-freebsd-arm64": "4.57.1", - "@rollup/rollup-freebsd-x64": "4.57.1", - "@rollup/rollup-linux-arm-gnueabihf": "4.57.1", - "@rollup/rollup-linux-arm-musleabihf": "4.57.1", - "@rollup/rollup-linux-arm64-gnu": "4.57.1", - "@rollup/rollup-linux-arm64-musl": "4.57.1", - "@rollup/rollup-linux-loong64-gnu": "4.57.1", - "@rollup/rollup-linux-loong64-musl": "4.57.1", - "@rollup/rollup-linux-ppc64-gnu": "4.57.1", - "@rollup/rollup-linux-ppc64-musl": "4.57.1", - 
"@rollup/rollup-linux-riscv64-gnu": "4.57.1", - "@rollup/rollup-linux-riscv64-musl": "4.57.1", - "@rollup/rollup-linux-s390x-gnu": "4.57.1", - "@rollup/rollup-linux-x64-gnu": "4.57.1", - "@rollup/rollup-linux-x64-musl": "4.57.1", - "@rollup/rollup-openbsd-x64": "4.57.1", - "@rollup/rollup-openharmony-arm64": "4.57.1", - "@rollup/rollup-win32-arm64-msvc": "4.57.1", - "@rollup/rollup-win32-ia32-msvc": "4.57.1", - "@rollup/rollup-win32-x64-gnu": "4.57.1", - "@rollup/rollup-win32-x64-msvc": "4.57.1", + "@rollup/rollup-android-arm-eabi": "4.60.0", + "@rollup/rollup-android-arm64": "4.60.0", + "@rollup/rollup-darwin-arm64": "4.60.0", + "@rollup/rollup-darwin-x64": "4.60.0", + "@rollup/rollup-freebsd-arm64": "4.60.0", + "@rollup/rollup-freebsd-x64": "4.60.0", + "@rollup/rollup-linux-arm-gnueabihf": "4.60.0", + "@rollup/rollup-linux-arm-musleabihf": "4.60.0", + "@rollup/rollup-linux-arm64-gnu": "4.60.0", + "@rollup/rollup-linux-arm64-musl": "4.60.0", + "@rollup/rollup-linux-loong64-gnu": "4.60.0", + "@rollup/rollup-linux-loong64-musl": "4.60.0", + "@rollup/rollup-linux-ppc64-gnu": "4.60.0", + "@rollup/rollup-linux-ppc64-musl": "4.60.0", + "@rollup/rollup-linux-riscv64-gnu": "4.60.0", + "@rollup/rollup-linux-riscv64-musl": "4.60.0", + "@rollup/rollup-linux-s390x-gnu": "4.60.0", + "@rollup/rollup-linux-x64-gnu": "4.60.0", + "@rollup/rollup-linux-x64-musl": "4.60.0", + "@rollup/rollup-openbsd-x64": "4.60.0", + "@rollup/rollup-openharmony-arm64": "4.60.0", + "@rollup/rollup-win32-arm64-msvc": "4.60.0", + "@rollup/rollup-win32-ia32-msvc": "4.60.0", + "@rollup/rollup-win32-x64-gnu": "4.60.0", + "@rollup/rollup-win32-x64-msvc": "4.60.0", "fsevents": "~2.3.2" } }, @@ -7065,9 +7065,9 @@ } }, "node_modules/tinyglobby/node_modules/picomatch": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", - "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", + "version": "4.0.4", + 
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", + "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", "dev": true, "license": "MIT", "engines": { diff --git a/frontend_zh/src/components/SettingsModal.tsx b/frontend_zh/src/components/SettingsModal.tsx index adc42cf..c5e5574 100644 --- a/frontend_zh/src/components/SettingsModal.tsx +++ b/frontend_zh/src/components/SettingsModal.tsx @@ -13,6 +13,8 @@ export const SettingsModal: React.FC = ({ open, onClose }) = const { user } = useAuthStore(); const userIdForSettings = user?.id ?? 'default'; const [apiUrl, setApiUrl] = useState(DEFAULT_LLM_API_URL); + const [customApiUrl, setCustomApiUrl] = useState(''); + const [isCustomUrl, setIsCustomUrl] = useState(false); const [apiKey, setApiKey] = useState(''); const [searchProvider, setSearchProvider] = useState('serper'); const [searchApiKey, setSearchApiKey] = useState(''); @@ -24,13 +26,19 @@ export const SettingsModal: React.FC = ({ open, onClose }) = if (open) { const settings = getApiSettings(userIdForSettings); if (settings) { - setApiUrl(settings.apiUrl || DEFAULT_LLM_API_URL); + const savedUrl = settings.apiUrl || DEFAULT_LLM_API_URL; + const isSavedUrlCustom = !API_URL_OPTIONS.includes(savedUrl); + setApiUrl(isSavedUrlCustom ? API_URL_OPTIONS[0] : savedUrl); + setCustomApiUrl(isSavedUrlCustom ? 
savedUrl : ''); + setIsCustomUrl(isSavedUrlCustom); setApiKey(settings.apiKey || ''); setSearchProvider((settings.searchProvider as SearchProvider) || 'serper'); setSearchApiKey(settings.searchApiKey || ''); setSearchEngine((settings.searchEngine as SearchEngine) || 'google'); } else { setApiUrl(DEFAULT_LLM_API_URL); + setCustomApiUrl(''); + setIsCustomUrl(false); setApiKey(''); setSearchProvider('serper'); setSearchApiKey(''); @@ -42,8 +50,10 @@ export const SettingsModal: React.FC = ({ open, onClose }) = const handleSave = () => { setSaving(true); setSaved(false); + // 如果是自定义 URL,使用自定义输入的值 + const finalApiUrl = isCustomUrl ? customApiUrl.trim() : apiUrl; saveApiSettings(userIdForSettings, { - apiUrl: apiUrl.trim(), + apiUrl: finalApiUrl, apiKey: apiKey.trim(), searchProvider, searchApiKey: searchApiKey.trim(), @@ -56,6 +66,15 @@ export const SettingsModal: React.FC = ({ open, onClose }) = }, 1500); }; + const handleApiUrlChange = (value: string) => { + if (value === '__CUSTOM__') { + setIsCustomUrl(true); + } else { + setIsCustomUrl(false); + setApiUrl(value); + } + }; + if (!open) return null; return ( @@ -86,15 +105,27 @@ export const SettingsModal: React.FC = ({ open, onClose }) =
- +
+ + {isCustomUrl && ( + setCustomApiUrl(e.target.value)} + placeholder="https://your-api-endpoint.com/v1" + className="w-full px-4 py-2.5 bg-white border border-blue-300 rounded-xl text-gray-800 placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent transition-all" + /> + )} +

OpenAI 兼容接口地址,如 api.openai.com/v1 或自建服务

diff --git a/frontend_zh/src/pages/Dashboard.tsx b/frontend_zh/src/pages/Dashboard.tsx index d4fa3ca..171dc13 100644 --- a/frontend_zh/src/pages/Dashboard.tsx +++ b/frontend_zh/src/pages/Dashboard.tsx @@ -33,6 +33,8 @@ const Dashboard = ({ onOpenNotebook, refreshTrigger = 0, supabaseConfigured }: { const [createError, setCreateError] = useState(''); const [configOpen, setConfigOpen] = useState(false); const [apiUrl, setApiUrl] = useState(DEFAULT_LLM_API_URL); + const [customApiUrl, setCustomApiUrl] = useState(''); + const [isCustomUrl, setIsCustomUrl] = useState(false); const [apiKey, setApiKey] = useState(''); const [searchProvider, setSearchProvider] = useState('serper'); const [searchApiKey, setSearchApiKey] = useState(''); @@ -48,7 +50,11 @@ const Dashboard = ({ onOpenNotebook, refreshTrigger = 0, supabaseConfigured }: { useEffect(() => { const s = getApiSettings(effectiveUserId); if (s) { - setApiUrl(s.apiUrl || DEFAULT_LLM_API_URL); + const savedUrl = s.apiUrl || DEFAULT_LLM_API_URL; + const isSavedUrlCustom = !API_URL_OPTIONS.includes(savedUrl); + setApiUrl(isSavedUrlCustom ? API_URL_OPTIONS[0] : savedUrl); + setCustomApiUrl(isSavedUrlCustom ? savedUrl : ''); + setIsCustomUrl(isSavedUrlCustom); setApiKey(s.apiKey || ''); setSearchProvider((s.searchProvider as SearchProvider) || 'serper'); setSearchApiKey(s.searchApiKey || ''); @@ -59,8 +65,10 @@ const Dashboard = ({ onOpenNotebook, refreshTrigger = 0, supabaseConfigured }: { const handleSaveConfig = () => { setConfigSaving(true); setConfigSaved(false); + // 如果是自定义 URL,使用自定义输入的值 + const finalApiUrl = isCustomUrl ? 
customApiUrl.trim() : apiUrl; const settings: ApiSettings = { - apiUrl: apiUrl.trim(), + apiUrl: finalApiUrl, apiKey: apiKey.trim(), searchProvider, searchApiKey: searchApiKey.trim(), @@ -74,6 +82,15 @@ const Dashboard = ({ onOpenNotebook, refreshTrigger = 0, supabaseConfigured }: { }, 1500); }; + const handleApiUrlChange = (value: string) => { + if (value === '__CUSTOM__') { + setIsCustomUrl(true); + } else { + setIsCustomUrl(false); + setApiUrl(value); + } + }; + const fetchNotebooks = async (options?: { force?: boolean }) => { const cached = getCachedValue(notebookListCacheKey); if (cached) { @@ -235,15 +252,27 @@ const Dashboard = ({ onOpenNotebook, refreshTrigger = 0, supabaseConfigured }: {

LLM 调用

- +
+ + {isCustomUrl && ( + setCustomApiUrl(e.target.value)} + placeholder="https://your-api-endpoint.com/v1" + className="w-full px-3 py-2.5 border border-blue-300 rounded-ios text-sm focus:ring-2 focus:ring-primary/30 focus:border-primary transition-colors" + /> + )} +
diff --git a/frontend_zh/src/pages/NotebookView.tsx b/frontend_zh/src/pages/NotebookView.tsx index 64c76bd..12b8466 100644 --- a/frontend_zh/src/pages/NotebookView.tsx +++ b/frontend_zh/src/pages/NotebookView.tsx @@ -57,6 +57,19 @@ type DataExtractMessage = { error?: string | null; }; +type TableProcessingMessage = { + id: string; + role: 'user' | 'assistant'; + content: string; + time: string; + sql?: string; + columns?: string[]; + rows?: Record[]; + rowCount?: number; + exportUrl?: string; + error?: string | null; +}; + type DataExtractArtifact = { id: string; session_id: string; @@ -191,6 +204,18 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void ]); const [dataExtractLoading, setDataExtractLoading] = useState(false); const [dataExtractSyncing, setDataExtractSyncing] = useState(false); + + // Table processing states + const [tableProcessingInput, setTableProcessingInput] = useState(''); + const [tableProcessingMessages, setTableProcessingMessages] = useState([ + { id: 'table-processing-welcome', role: 'assistant', content: '选择 CSV 数据源后,输入自然语言指令进行智能处理。', time: new Date().toLocaleTimeString() } + ]); + const [tableProcessingResult, setTableProcessingResult] = useState(null); + const [tableProcessingLoading, setTableProcessingLoading] = useState(false); + const [tableProcessingFormat, setTableProcessingFormat] = useState<'json' | 'csv' | 'markdown' | 'dict'>('csv'); + const [tableProcessingSubView, setTableProcessingSubView] = useState<'current' | 'history'>('current'); + const [tableProcessingSessions, setTableProcessingSessions] = useState>([]); + const [chatLoadingStage, setChatLoadingStage] = useState('思考中...'); // 对话历史:本地持久化 @@ -249,7 +274,7 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void // Output preview const [previewOutput, setPreviewOutput] = useState<{ id: string; - type: 'ppt' | 'mindmap' | 'podcast' | 'drawio' | 'flashcard' | 'quiz'; + type: 'ppt' | 'mindmap' | 'podcast' | 'drawio' | 
'flashcard' | 'quiz' | 'note'; title: string; sources: string; url?: string; @@ -364,8 +389,10 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void // Studio tools const dataExtractTool = { icon: , label: '智能取数', id: 'data_extract' as ToolType }; + const tableProcessingTool = { icon: , label: '智能处理', id: 'table_processing' as ToolType }; const studioTools: Array<{icon: React.ReactNode, label: string, id: ToolType}> = [ dataExtractTool, + tableProcessingTool, { icon: , label: 'PPT生成', id: 'ppt' }, { icon: , label: '思维导图', id: 'mindmap' }, // DrawIO 图表功能暂时隐藏,后续修复 @@ -379,12 +406,13 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void ]; // Studio:每个功能卡片各自配置,点卡片上的「…」翻转进该卡片的设置 - type StudioToolId = 'data_extract' | 'ppt' | 'mindmap' | 'drawio' | 'flashcard' | 'quiz' | 'podcast' | 'video' | 'note'; + type StudioToolId = 'data_extract' | 'table_processing' | 'ppt' | 'mindmap' | 'drawio' | 'flashcard' | 'quiz' | 'podcast' | 'video' | 'note'; const [studioPanelView, setStudioPanelView] = useState<'tools' | 'settings'>('tools'); const [studioSettingsTool, setStudioSettingsTool] = useState(null); const STORAGE_STUDIO_CONFIG = `kb_studio_config_${effectiveUser?.id || 'default'}`; const defaultByTool: Record> = { data_extract: { resultFormat: 'json', executionStrategy: 'auto' }, + table_processing: { resultFormat: 'csv', llmModel: 'gpt-4o' }, ppt: { llmModel: 'deepseek-v3.2', genFigModel: 'gemini-2.5-flash-image', stylePreset: 'modern', stylePrompt: '', language: 'zh', page_count: '10' }, mindmap: { llmModel: 'deepseek-v3.2', mindmapStyle: 'default' }, drawio: { llmModel: 'deepseek-v3.2', diagramType: 'auto', diagramStyle: 'default', language: 'zh' }, @@ -1084,7 +1112,9 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void return candidate.includes('.csv') || candidate.includes('text/csv'); }; - const selectedCsvFiles = files.filter(f => selectedIds.has(f.id) && f.type === 'dataset' && 
isCsvDataExtractFile(f)); + const selectedCsvFiles = files.filter( + f => selectedIds.has(f.id) && f.type === 'dataset' && isCsvDataExtractFile(f) && Boolean(f.url) + ); const selectedCsvFileUrls = new Set( selectedCsvFiles .map(file => file.url) @@ -1169,7 +1199,7 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void }; const fetchDataExtractDatasources = async () => { - if (!notebook?.id) return; + if (!notebook?.id) return []; try { const params = new URLSearchParams({ notebook_id: notebook.id, @@ -1182,8 +1212,10 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void const data = await res.json(); const list = Array.isArray(data?.datasources) ? data.datasources : []; setDataExtractDatasources(list); + return list; } catch (err) { console.error('Failed to fetch data extract datasources:', err); + return []; } }; @@ -1384,8 +1416,24 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void return; } if (!dataExtractSessionId && activeDataExtractDatasourceIds.length === 0) { - alert(selectedCsvFiles.length > 0 ? '请先同步选中的 CSV 数据源' : '请先同步并选择一个数据源'); - return; + if (selectedCsvFiles.length > 0) { + showToast('正在同步选中的 CSV 数据源,请稍候...', 'success'); + await handleSyncDataExtractSources(); + + const refreshedDatasources = await fetchDataExtractDatasources(); + const selectedCsvFileUrlsSet = new Set(selectedCsvFiles.map(f => f.url).filter((url): url is string => Boolean(url))); + const refreshedActive = selectedCsvFileUrlsSet.size > 0 + ? 
refreshedDatasources.filter((ds: DataExtractDatasource) => selectedCsvFileUrlsSet.has(ds.file_path)) + : refreshedDatasources; + + if (refreshedActive.length === 0) { + alert('同步成功后未检测到数据源,请稍后刷新或检查后再试。'); + return; + } + } else { + alert('请先同步并选择一个数据源'); + return; + } } const userMsg: DataExtractMessage = { @@ -1449,6 +1497,124 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void } }; + const handleTableProcessing = async () => { + if (!notebook?.id) { + alert('请先创建或选择一个笔记本'); + return; + } + if (!tableProcessingInput.trim()) { + alert('请先输入处理指令'); + return; + } + if (selectedCsvFiles.length === 0) { + alert('请先选择至少一个 CSV 数据源'); + return; + } + + const userMessage: TableProcessingMessage = { + id: `table-processing-user-${Date.now()}`, + role: 'user', + content: tableProcessingInput, + time: new Date().toLocaleTimeString(), + }; + setTableProcessingMessages(prev => [...prev, userMessage]); + + setTableProcessingLoading(true); + setTableProcessingResult(null); + + // 获取 API 配置 + const settings = getApiSettings(effectiveUser?.id || null); + const apiUrl = settings?.apiUrl?.trim() || ''; + const apiKey = settings?.apiKey?.trim() || ''; + + try { + const tableConfig = getStudioConfig('table_processing'); + const res = await apiFetch('/api/v1/table-processing/process', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + notebook_id: notebook.id, + notebook_title: notebook?.title || notebook?.name || '', + user_id: effectiveUser.id || 'default', + email: effectiveUser.email || effectiveUser.id || 'default', + datasources: selectedCsvFiles.map(f => ({ name: f.name, url: f.url! 
})), + instruction: tableProcessingInput, + output_format: tableProcessingFormat, + title: '智能表格处理', + api_key: apiKey || undefined, + api_url: apiUrl || undefined, + model: tableConfig.llmModel || 'gpt-4o', + }), + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || 'Table processing failed'); + } + const data = await res.json(); + setTableProcessingResult(data); + + // 成功时固定返回"处理成功!",失败时显示错误信息 + const assistantContent = data?.success === false + ? (data?.content || data?.error || '处理失败,请稍后重试') + : '处理成功!'; + + const assistantMessage: TableProcessingMessage = { + id: `table-processing-assistant-${Date.now()}`, + role: 'assistant', + content: assistantContent, + time: new Date().toLocaleTimeString(), + sql: typeof data?.sql === 'string' ? data.sql : undefined, + columns: Array.isArray(data?.columns) ? data.columns : undefined, + rows: Array.isArray(data?.rows) ? data.rows : undefined, + rowCount: typeof data?.row_count === 'number' + ? data.row_count + : typeof data?.rowCount === 'number' + ? data.rowCount + : undefined, + exportUrl: typeof data?.processed_file_url === 'string' ? 
data.processed_file_url : undefined, + }; + setTableProcessingMessages(prev => [...prev, assistantMessage]); + + setTableProcessingSessions(prev => [ + { + id: `${Date.now()}`, + title: `智能处理 ${new Date().toLocaleString()}`, + updated_at: new Date().toISOString(), + instruction: tableProcessingInput, + }, + ...prev, + ]); + setTableProcessingSubView('current'); + showToast('表格处理完成', 'success'); + } catch (error) { + console.error('Table processing error:', error); + const errMsg = (error as any)?.message || String(error); + setTableProcessingMessages(prev => [...prev, { + id: `table-processing-assistant-error-${Date.now()}`, + role: 'assistant', + content: `处理失败:${errMsg}`, + time: new Date().toLocaleTimeString(), + }]); + showToast('表格处理失败,请检查日志', 'error'); + } finally { + setTableProcessingLoading(false); + setTableProcessingInput(''); + } + }; + + const handleNewTableProcessingSession = () => { + setTableProcessingSubView('current'); + setTableProcessingInput(''); + setTableProcessingMessages([ + { id: 'table-processing-welcome', role: 'assistant', content: '选择 CSV 数据源后,输入自然语言指令进行智能处理。', time: new Date().toLocaleTimeString() } + ]); + setTableProcessingResult(null); + }; + + const handleShowTableProcessingHistory = () => { + setTableProcessingSubView('history'); + }; + const handleToggleSelect = (id: string) => { const newSet = new Set(selectedIds); if (newSet.has(id)) { @@ -1920,7 +2086,7 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void uploadQueue.length > 1 ? 
`已添加 ${uploadQueue.length} 个文件,正在处理` : `已添加 ${uploadQueue[0].name},正在处理`, - 'info' + 'success' ); let successCount = 0; @@ -2217,6 +2383,10 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void }; switch (tool) { + case 'table_processing': + setActiveTool('table_processing'); + setToolLoading(false); + return; case 'mindmap': endpoint = '/api/v1/kb/generate-mindmap'; break; @@ -3123,6 +3293,172 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void initialBlocks={editingNote?.blocks} />
+ ) : activeTool === 'table_processing' ? ( +
+
+
+ 智能处理 +

根据自然语言指令处理选中的表格并返回结果

+
+
+ + +
+
+ +
+ {tableProcessingSubView === 'history' ? ( +
+
+

智能处理历史

+ +
+ {tableProcessingSessions.length === 0 ? ( +
暂无历史记录
+ ) : ( +
    + {tableProcessingSessions.map(session => ( +
  • + +
  • + ))} +
+ )} +
+ ) : ( +
+ {selectedCsvFiles.length > 0 && ( +
+

已选 CSV 文件:

+
    + {selectedCsvFiles.map((file, index) => ( +
  • {file.name}
  • + ))} +
+
+ )} +
+ {tableProcessingMessages.map(msg => ( +
+
+ {msg.role === 'assistant' ? : } +
+
+
{msg.content}
+ {msg.sql && ( +
{msg.sql}
+ )} + {msg.role === 'assistant' && msg.rows && msg.rows.length > 0 && ( +
+
+ + + + {msg.columns?.map((col, i) => ( + + ))} + + + + {msg.rows.slice(0, 10).map((row, rowIdx) => ( + + {msg.columns?.map((col, colIdx) => ( + + ))} + + ))} + +
{col}
{String(row?.[col] ?? '')}
+
+ {msg.rows.length > 10 && ( +

仅显示前 10 行,完整结果可通过导出下载查看

+ )} +
+ )} + {msg.role === 'assistant' && msg.exportUrl && ( +
+ +
+ )} +
+
+ ))} +
+
+ )} +
+ {tableProcessingSubView === 'current' && ( +
+
+
+ setTableProcessingInput(e.target.value)} + onKeyDown={e => e.key === 'Enter' && handleTableProcessing()} + placeholder={selectedCsvFiles.length > 0 ? '输入表格处理指令,例如:按城市分组并求销售总额前 20' : '请先选择 CSV 数据源'} + disabled={selectedCsvFiles.length === 0} + className="w-full bg-transparent rounded-ios-xl py-4 pl-6 pr-24 text-lg focus:outline-none disabled:opacity-50" + /> +
+ + {tableProcessingLoading ? '执行中...' : '发送'} + +
+
+
+
+ )} +
) : activeTool === 'data_extract' ? (
@@ -3438,29 +3774,37 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void {dataExtractSubView === 'current' && (
-
-
- setDataExtractInput(e.target.value)} - onKeyDown={e => e.key === 'Enter' && handleSendDataExtractMessage()} - placeholder={activeDataExtractDatasourceIds.length > 0 ? '输入一个数据问题,例如:统计各城市销售额前 10 名' : '请先同步并选择一个数据源'} - disabled={activeDataExtractDatasourceIds.length === 0 && !dataExtractSessionId} - className="w-full bg-transparent rounded-ios-xl py-4 pl-6 pr-24 focus:outline-none text-lg disabled:opacity-50" - /> -
- - {activeDataExtractDatasourceIds.length > 1 ? `联合 ${activeDataExtractDatasourceIds.length} 个数据源` : `${activeDataExtractDatasources.length} 个数据源`} - - - - +
+ {selectedCsvFiles.length > 0 && ( +
+ 已选 CSV 文件 {selectedCsvFiles.length} 个:{selectedCsvFiles.map(f => f.name).slice(0, 3).join(',')}{selectedCsvFiles.length > 3 ? ` 等 ${selectedCsvFiles.length} 个` : ''}。 + {activeDataExtractDatasources.length === 0 && ' 请先点击“同步选中 CSV”完成注册后即可发送问题。'} +
+ )} +
+
+ setDataExtractInput(e.target.value)} + onKeyDown={e => e.key === 'Enter' && handleSendDataExtractMessage()} + placeholder={selectedCsvFiles.length > 0 ? '选中 CSV 后可输入问题:例如统计各城市销售额前 10 名' : '请先同步并选择一个数据源'} + disabled={selectedCsvFiles.length === 0 && !dataExtractSessionId} + className="w-full bg-transparent rounded-ios-xl py-4 pl-6 pr-24 focus:outline-none text-lg disabled:opacity-50" + /> +
+ + {activeDataExtractDatasourceIds.length > 1 ? `联合 ${activeDataExtractDatasourceIds.length} 个数据源` : `${dataExtractDatasources.length} 个数据源`} + + + + +
@@ -3644,6 +3988,7 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void

{studioSettingsTool === 'data_extract' && '智能取数'} + {studioSettingsTool === 'table_processing' && '智能处理'} {studioSettingsTool === 'ppt' && 'PPT 生成'} {studioSettingsTool === 'mindmap' && '思维导图'} {studioSettingsTool === 'drawio' && 'DrawIO 图表'} @@ -3676,6 +4021,17 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void ); })()} + {studioSettingsTool === 'table_processing' && (() => { + const c = getStudioConfig('table_processing'); + return ( + <> +
+ + setStudioConfigForTool('table_processing', { llmModel: e.target.value })} placeholder="gpt-4o" className="w-full px-3 py-2 border border-gray-200 rounded-lg text-sm focus:ring-2 focus:ring-blue-500" /> +
+ + ); + })()} {studioSettingsTool === 'ppt' && (() => { const c = getStudioConfig('ppt'); return ( @@ -4046,7 +4402,7 @@ const NotebookView = ({ notebook, onBack }: { notebook: any, onBack: () => void ))}

- {activeTool !== 'chat' && activeTool !== 'search' && activeTool !== 'data_extract' && ( + {activeTool !== 'chat' && activeTool !== 'search' && activeTool !== 'data_extract' && activeTool !== 'table_processing' && ( void )}
)} + {previewOutput.type === 'note' && previewOutput.url && ( +
+
+

笔记暂不支持内嵌预览,请通过下方链接打开笔记文件

+ + 打开笔记文件 + +
+
+ )} {previewOutput.type === 'mindmap' && !previewOutput.mermaidCode && (
{previewLoading ? '正在加载思维导图内容...' : '暂无预览内容'} diff --git a/frontend_zh/src/types/index.ts b/frontend_zh/src/types/index.ts index 0b17279..15c32f7 100644 --- a/frontend_zh/src/types/index.ts +++ b/frontend_zh/src/types/index.ts @@ -34,4 +34,4 @@ export interface ChatMessage { } export type SectionType = 'library' | 'upload' | 'output' | 'settings'; -export type ToolType = 'chat' | 'ppt' | 'mindmap' | 'podcast' | 'video' | 'search' | 'drawio' | 'flashcard' | 'quiz' | 'note' | 'data_extract'; +export type ToolType = 'chat' | 'ppt' | 'mindmap' | 'podcast' | 'video' | 'search' | 'drawio' | 'flashcard' | 'quiz' | 'note' | 'data_extract' | 'table_processing'; diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000..3c37f59 --- /dev/null +++ b/package-lock.json @@ -0,0 +1,7 @@ +{ + "name": "Open-NotebookLM-test", + "lockfileVersion": 3, + "requires": true, + "packages": {} + } + \ No newline at end of file diff --git a/workflow_engine/agentroles/cores/base_agent.py b/workflow_engine/agentroles/cores/base_agent.py index 60f41f8..e11a056 100644 --- a/workflow_engine/agentroles/cores/base_agent.py +++ b/workflow_engine/agentroles/cores/base_agent.py @@ -249,8 +249,10 @@ def __init__(self, # 创建执行策略 from workflow_engine.agentroles.cores.strategies import StrategyFactory + # 获取 mode:可能是枚举.value 或字符串 + mode_value = execution_config.mode.value if hasattr(execution_config.mode, 'value') else execution_config.mode self._execution_strategy = StrategyFactory.create( - execution_config.mode.value, + mode_value, self, execution_config ) @@ -433,7 +435,8 @@ def build_messages(self, ptg = PromptsTemplateGenerator(state.request.language) # 渲染系统提示词 - sys_prompt = ptg.render(self.system_prompt_template_name) + sys_params = self.get_task_prompt_params(pre_tool_results) + sys_prompt = ptg.render(self.system_prompt_template_name, **sys_params) # 添加解析器格式说明(VLM 模式可能不需要) format_instruction = self.parser.get_format_instruction() diff --git 
a/workflow_engine/agentroles/table_agents.py b/workflow_engine/agentroles/table_agents.py new file mode 100644 index 0000000..1587b42 --- /dev/null +++ b/workflow_engine/agentroles/table_agents.py @@ -0,0 +1,147 @@ +"""Table Agent agents - agents for table processing workflow. + +This module defines agents used in the table processing workflow: +- intent_understanding: Parse user intent into task type and operation +- data_profiling: Profile input tables using ReAct mode +- decompositer: Decompose complex tasks +- generator: Generate Python code for table operations +- debugger: Debug failed code +- summarizer: Summarize results using ReAct mode +""" +from typing import Optional + +from workflow_engine.toolkits.tool_manager import ToolManager, get_tool_manager +from workflow_engine.agentroles.cores.base_agent import BaseAgent +from workflow_engine.agentroles.cores.registry import register + + +@register("intent_understanding") +class IntentUnderstandingAgent(BaseAgent): + """Agent for parsing user intent into structured task specifications.""" + + def __init__(self, tool_manager: Optional[ToolManager] = None, **kwargs): + super().__init__(tool_manager=tool_manager, **kwargs) + if self.tool_manager is None: + self.tool_manager = get_tool_manager() + + @property + def role_name(self) -> str: + return "intent_understanding" + + @property + def system_prompt_template_name(self) -> str: + return "system_prompt_for_intent_understanding" + + @property + def task_prompt_template_name(self) -> str: + return "task_prompt_for_intent_understanding" + + +@register("data_profiling") +class DataProfilingAgent(BaseAgent): + """Agent for profiling input tables using ReAct mode.""" + + def __init__(self, tool_manager: Optional[ToolManager] = None, **kwargs): + super().__init__(tool_manager=tool_manager, **kwargs) + if self.tool_manager is None: + self.tool_manager = get_tool_manager() + + @property + def role_name(self) -> str: + return "data_profiling" + + @property + def 
system_prompt_template_name(self) -> str: + return "system_prompt_for_data_profiling" + + @property + def task_prompt_template_name(self) -> str: + return "task_prompt_for_data_profiling" + + +@register("decompositer") +class DecompositerAgent(BaseAgent): + """Agent for decomposing complex tasks into sub-tasks.""" + + def __init__(self, tool_manager: Optional[ToolManager] = None, **kwargs): + super().__init__(tool_manager=tool_manager, **kwargs) + if self.tool_manager is None: + self.tool_manager = get_tool_manager() + + @property + def role_name(self) -> str: + return "decompositer" + + @property + def system_prompt_template_name(self) -> str: + return "system_prompt_for_decompositer" + + @property + def task_prompt_template_name(self) -> str: + return "task_prompt_for_decompositer" + + +@register("generator") +class GeneratorAgent(BaseAgent): + """Agent for generating Python code for table operations.""" + + def __init__(self, tool_manager: Optional[ToolManager] = None, **kwargs): + super().__init__(tool_manager=tool_manager, **kwargs) + if self.tool_manager is None: + self.tool_manager = get_tool_manager() + + @property + def role_name(self) -> str: + return "generator" + + @property + def system_prompt_template_name(self) -> str: + return "system_prompt_for_generator" + + @property + def task_prompt_template_name(self) -> str: + return "task_prompt_for_generator" + + +@register("debugger") +class DebuggerAgent(BaseAgent): + """Agent for debugging failed code.""" + + def __init__(self, tool_manager: Optional[ToolManager] = None, **kwargs): + super().__init__(tool_manager=tool_manager, **kwargs) + if self.tool_manager is None: + self.tool_manager = get_tool_manager() + + @property + def role_name(self) -> str: + return "debugger" + + @property + def system_prompt_template_name(self) -> str: + return "system_prompt_for_debugger" + + @property + def task_prompt_template_name(self) -> str: + return "task_prompt_for_debugger" + + +@register("summarizer") +class 
SummarizerAgent(BaseAgent): + """Agent for summarizing results using ReAct mode.""" + + def __init__(self, tool_manager: Optional[ToolManager] = None, **kwargs): + super().__init__(tool_manager=tool_manager, **kwargs) + if self.tool_manager is None: + self.tool_manager = get_tool_manager() + + @property + def role_name(self) -> str: + return "summarizer" + + @property + def system_prompt_template_name(self) -> str: + return "system_prompt_for_summarizer" + + @property + def task_prompt_template_name(self) -> str: + return "task_prompt_for_summarizer" diff --git a/workflow_engine/constants.py b/workflow_engine/constants.py new file mode 100644 index 0000000..1394565 --- /dev/null +++ b/workflow_engine/constants.py @@ -0,0 +1,58 @@ +# Centralized constants for TableAgent + +# ReAct / generation limits +MAX_REACT_STEPS = 7 +MAX_GENERATE_ATTEMPTS = 3 +MAX_DEBUG_ATTEMPTS = 5 + +# Tag markers +THINK_TAG_PATTERN = r"(.*?)" +ACTION_TAG_PATTERN = r"(.*?)" +ANSWER_TAG_PATTERN = r"(.*?)" +OBS_TAG_WRAPPER = "{obs}" + +# File naming +GENERATED_SCRIPT_NAME = "generated_step.py" +EVAL_RESULT_FILENAME = "eval_result.txt" + +# Log truncation +STDOUT_LOG_TRUNCATE = 1000 # characters +CODE_LOG_TRUNCATE = 1200 # characters + +# Safety / fallbacks +EMPTY_CODE_FALLBACK = "# Empty code produced by model" + +# Benchmark task types +BENCHMARK_TASK_TYPES = [ + # Table Cleaning + "TableCleaning-ErrorDetectionANDCorrection", + "TableCleaning-ColumnTypeAnnotation", + "TableCleaning-DataImputation", + "TableCleaning-Deduplication", + + # Table Transformation + "TableTransformation-RowToRowTransform", + "TableTransformation-SplittingANDConcatenation", + "TableTransformation-RowColumnSwapping", + "TableTransformation-Filtering", + "TableTransformation-Grouping", + "TableTransformation-Sorting", + "TableTransformation-ListExtraction", + + # Table Augmentation + "TableAugmentation-RowPopulation", + "TableAugmentation-SchemaAugmentation", + "TableAugmentation-ColumnAugmentation", + + # Table 
Matching + "TableMatching-SchemaMatching", + "TableMatching-EntityMatching" +] + +__all__ = [ + 'MAX_REACT_STEPS', 'MAX_GENERATE_ATTEMPTS', 'MAX_DEBUG_ATTEMPTS', + 'THINK_TAG_PATTERN', 'ACTION_TAG_PATTERN', 'ANSWER_TAG_PATTERN', + 'OBS_TAG_WRAPPER', 'GENERATED_SCRIPT_NAME', 'EVAL_RESULT_FILENAME', + 'STDOUT_LOG_TRUNCATE', 'CODE_LOG_TRUNCATE', 'EMPTY_CODE_FALLBACK', + 'BENCHMARK_TASK_TYPES' +] \ No newline at end of file diff --git a/workflow_engine/llm_callers/text.py b/workflow_engine/llm_callers/text.py index e3b19a3..110c147 100644 --- a/workflow_engine/llm_callers/text.py +++ b/workflow_engine/llm_callers/text.py @@ -1,5 +1,5 @@ -from typing import List -from langchain_core.messages import BaseMessage,AIMessage +from typing import List, Dict, Any +from langchain_core.messages import BaseMessage, AIMessage, HumanMessage from langchain_openai import ChatOpenAI from .base import BaseLLMCaller @@ -7,27 +7,101 @@ log = get_logger(__name__) + class TextLLMCaller(BaseLLMCaller): - """文本LLM调用器 - 原有实现""" - + """文本LLM调用器 - 原有实现 + + 支持通过 summary() 方法获取调用统计信息, + 也支持像函数一样调用:result = await caller(messages)。 + """ + + def __init__(self, state, model_name: str = None, temperature: float = 0.0, max_tokens: int = 10000): + super().__init__( + state=state, + model_name=model_name or getattr(state, 'model', None) or state.request.model, + temperature=temperature, + max_tokens=max_tokens, + ) + self._input_tokens = 0 + self._output_tokens = 0 + self._completion_time_sec = 0.0 + self._total_cost_usd = 0.0 + async def call(self, messages: List[BaseMessage], bind_post_tools: bool = False) -> AIMessage: + """调用 LLM""" + import time + start_time = time.time() + log.info(f"TextLLM调用,模型: {self.model_name}") - + llm = ChatOpenAI( openai_api_base=self.state.request.chat_api_url, openai_api_key=self.state.request.api_key, model_name=self.model_name, temperature=self.temperature, - # max_tokens=self.max_tokens, ) - + # 绑定工具(如果需要) if bind_post_tools and self.tool_manager: from 
langchain_core.tools import Tool - tools = self.tool_manager.get_post_tools("current_role") # 需要传入角色名 + tools = self.tool_manager.get_post_tools("current_role") if tools: llm = llm.bind_tools(tools, tool_choice=self.tool_mode) log.info(f"为LLM绑定了 {len(tools)} 个工具") - + response = await llm.ainvoke(messages) - return response \ No newline at end of file + + # 统计 + elapsed = time.time() - start_time + self._completion_time_sec += elapsed + # 估算 token(简化,实际应该从 response 获取 usage 信息) + self._input_tokens += self._estimate_tokens(messages) + self._output_tokens += self._estimate_tokens([response.content]) if hasattr(response, 'content') else 0 + + return response + + async def __call__(self, messages) -> AIMessage: + """支持像函数一样调用""" + if isinstance(messages, list): + return await self.call(messages) + elif isinstance(messages, dict): + # 支持传入消息字典列表 + converted = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "system": + converted.append({"type": "system", "content": content}) + elif role == "user": + converted.append({"type": "user", "content": content}) + elif role == "assistant": + converted.append({"type": "ai", "content": content}) + # 使用 LangChain 消息 + from langchain_core.messages import convert_to_messages + langchain_msgs = convert_to_messages(converted) + return await self.call(langchain_msgs) + else: + raise ValueError(f"Invalid messages type: {type(messages)}") + + def summary(self) -> Dict[str, Any]: + """返回调用统计摘要""" + return { + "input_tokens": self._input_tokens, + "output_tokens": self._output_tokens, + "completion_time_sec": self._completion_time_sec, + "total_cost_usd": self._total_cost_usd, + } + + def _estimate_tokens(self, messages) -> int: + """简单估算 token 数量""" + if isinstance(messages, str): + return len(messages) // 4 + total = 0 + for msg in messages: + if isinstance(msg, str): + total += len(msg) // 4 + elif hasattr(msg, 'content'): + total += len(str(msg.content)) // 4 + elif 
isinstance(msg, dict): + total += len(str(msg.get('content', ''))) // 4 + return total \ No newline at end of file diff --git a/workflow_engine/promptstemplates/resources/pt_table_agent_repo.py b/workflow_engine/promptstemplates/resources/pt_table_agent_repo.py new file mode 100644 index 0000000..bd719a6 --- /dev/null +++ b/workflow_engine/promptstemplates/resources/pt_table_agent_repo.py @@ -0,0 +1,363 @@ +"""Table Agent prompts template - prompts for table processing agents.""" + + +class intent_understanding: + task_prompt_for_intent_understanding = ''' +User request: {user_input}. Note: output only a JSON object, without ```json wrapping. +''' + + system_prompt_for_intent_understanding = ''' + [ROLE] + You are a Table preprocessing intent parsing API. Based on the metadata of the data table {task_meta}, parse the user's natural language instruction into a standardized JSON format to identify the required data processing operator. You must respond with only a JSON object—do not wrap it in ```json. + + [OUTPUT RULES]: + 1. **operation**: Clearly describe the required operator operation in natural language, use orders like 1. ..., 2. ... to list multiple operations if needed. You should describe as detailed as possible. + 2. **reason**: Briefly explain the rationale for selecting this operator (1–2 sentences), possibly referencing missing rates, data distribution, or task objectives in the metadata. + 3. 
**task_type**: Select exactly one matching task type from: + + "TableCleaning-ErrorDetectionANDCorrection", + "TableCleaning-ColumnTypeAnnotation", + "TableCleaning-DataImputation", + "TableCleaning-Deduplication", + + # Table Transformation + "TableTransformation-RowToRowTransform", + "TableTransformation-SplittingANDConcatenation", + "TableTransformation-RowColumnSwapping", + "TableTransformation-Filtering", + "TableTransformation-Grouping", + "TableTransformation-Sorting", + "TableTransformation-ListExtraction", + + # Table Augmentation + "TableAugmentation-RowPopulation", + "TableAugmentation-SchemaAugmentation", + "TableAugmentation-ColumnAugmentation", + + # Table Matching + "TableMatching-SchemaMatching", + "TableMatching-EntityMatching" + + The chosen type must strictly align with both the table metadata and the user's intent. + 4. **suffix**: Specify the output file format, e.g., "csv", "jsonl", etc. + + Your output must be a valid JSON object only, with no additional text, explanations, Markdown formatting (such as ```json), or line breaks. Strictly follow the format in the example below. + + [Example]: + User request: Fill missing values in the age column using an LSTM model trained on other columns + Your output: {{"operation": "1:fill missing values in the column age using LSTM model trained on other columns", "reason": "'age' has 30% missing values, which may impact analysis. Using LSTM can better capture relationships with other columns to impute missing values.", + "task_type": "TableCleaning-DataImputation", "suffix": "csv"}} + ''' + + +class data_profiling: + task_prompt_for_data_profiling = """ + +Begin data profiling for the target. {user_refine_input} + +""" + + system_prompt_for_data_profiling = """ + [ROLE] + You are a careful data profiler to prepare data report for target. Your role is to write code that analyzes tabular data files (CSV/Parquet) and produces a comprehensive data profiling report in JSON format. 
+ You are given up to {MAX_REACT_STEPS} attempts to reach a conclusion. + [Inputs] + Files_paths: {raw_table_paths} + target: {operation} + + [Goal] + Prepare for what the target requires by analyzing the data files and producing a detailed profiling report. You are not to fulfill the target yet — only analyze and report for further processing. + Produce the final data profiling report as JSON inside {{"table_1":{{...}}, "table_2":{{...}}, ...}}, + where `table_x` is replaced by the **filename without extension** (e.g., `sales.csv` → `"sales"`). + At least you need to include the number of rows, number of columns, column names, column types(detect abnormal types like mixed types or unexpected nulls) and so on. + If the goal is to transform some column or correct some column..., you need to also analyze that column in detail, like unique values, missing rate, distribution for numeric columns so that the next step can be better performed. + If the number of unique values or some other statistics is small(like less than 20), you should list them all, otherwise you just need to sample 5~10 values. + Final report should be concise(don't surpass 200 characters unless necessary) but comprehensive, focusing on key statistics and insights. Useless information should be avoided. + + [Rules] + - In each turn: + - Use ... to describe your reasoning. + - Use ```python\n...\n``` to provide **standalone, executable Python code**. + → The code **must be wrapped in triple backticks with language specifier `python`**, like: + + ```python + import pandas as pd + df = pd.read_csv("file.csv") + print({{"columns": list(df.columns), "shape": list(df.shape)}}) + ``` + + - Use ... to provide the final JSON profiling report without any ``` formatting. + - After the action, wait for the observation (the printed output). + - After receiving observation, continue reasoning with , then issue next if needed. 
+ - Your code must be **fully self-contained**: include all imports, data loading, and logic. Do *not* rely on prior context or variables. + - Always load data from the provided file paths. + - You should use print() to output results, which will be captured as observations. Avoid printing raw data or huge outputs. + - For multiple files: profile each one and include a separate entry in the final JSON report. + - Keep code precise and concise (≤50 lines per action unless absolutely necessary). + - Do **not** write any files to disk — only output via `print()`. + - Once profiling is complete, output the full report in ... as valid JSON without ```. + - + + + [EXAMPLE] + + I need to read the first CSV file and get basic column info. + + ```python + import pandas as pd + df = pd.read_csv("data/sales.csv") + print({{"columns": list(df.columns), "shape": list(df.shape)}}) + ``` + + + [Observation] + {{"columns": ["id", "amount", "date"], "shape": [1000, 3]}} + + Now I'll compute statistics for numeric columns... + + ```python + import pandas as pd + df = pd.read_csv("data/sales.csv") + numeric_cols = df.select_dtypes(include='number').columns + stats = df[numeric_cols].describe().to_dict() + print(stats) + ``` + + ... + + All tables are profiled. Compiling final JSON report. + {{...}} + """ + + +class decompositer: + task_prompt_for_decompositer = ''' + Begin decomposing the task. + task: {user_query} +''' + system_prompt_for_decompositer = ''' + [ROLE] + You are an expert in decomposing complex tasks into independent, executable sub-tasks. + + [OUTPUT RULES]: + 1. Decompose the task into independent sub-tasks. + 2. Each sub-task should be executable and can be found in {benchmark_task_types}. + 3. Output the result strictly in JSON format, mapping each sub-task type to its specific operation description, no ```json wrapping. + 4. Each key in the JSON should be a different sub-task type, and the corresponding value should be the specific operation description.
+ + [Example]: + User request: Merge multiple CSV files and deduplicate entries based on a primary key. + Your output: {{"TableTransformation-SplittingANDConcatenation":"Merge multiple CSV files", "TableCleaning-Deduplication":"Deduplicate entries based on a primary key"}} + Note: ensure output's keys are different sub-task types. Concatenate multiple operations under the same sub-task type into one description if needed. + + ''' + + +class decomposition_codes: + task_prompt_for_decomposition_codes = ''' + Begin writing code snippets for each sub-task in {decomposition_result}. + ''' + + system_prompt_for_decomposition_codes = ''' + [ROLE] + You are an expert in writing code snippets for table processing sub-tasks. + Your role is to write bug-free Python code snippets for each sub-task identified in decomposition_result. + Each code snippet should be self-contained and executable, focusing solely on the specific sub-task. + If there exist retrieved code snippets for similar operations in {retrieved_operators}, you can refer to them when writing the code snippets. + + [OUTPUT RULES]: + 1. For each sub-task, write a complete Python code snippet that accomplishes the task. + + [Example]: + Sub-tasks: {{"TableTransformation-SplittingANDConcatenation":"Merge multiple + CSV files", "TableCleaning-Deduplication":"Deduplicate entries based on a primary key"}} + + Retrieved similar operator code snippets: + [ ... ] + + Your output: + def merge_csv_files(file_paths): + import pandas as pd + # read every input file and concatenate the results + dataframes = [pd.read_csv(file_path) for file_path in file_paths] + merged_df = pd.concat(dataframes, ignore_index=True) + return merged_df + def deduplicate_entries(df, primary_key): + deduplicated_df = df.drop_duplicates(subset=primary_key, keep='first') + return deduplicated_df + ''' + + +class generator: + task_prompt_for_generator = ''' + "human", "User request: {user_input}. Note: pay attention to the order of file paths and think step by step."
+ ''' + system_prompt_for_generator = ''' + [ROLE] + You specialize in table preprocessing. Please generate a bug-free Python script based on the following information and the user's request: + + [INPUT] + 1. Metadata of the data table: {task_meta} + 2. Retrieved similar operator code snippets: {retrieved_operators}, If any exist, you can refer to them when writing the code. But do not copy them directly, you need to adapt them to fit the current task. + 3. Operator specification: {user_query} + + [OUTPUT RULES]: + 1. The code must be executable, safe, step by step and output as a complete code block in the format ```python ... ```. + 2. The code must include a main() function that accepts command-line arguments for input and output file paths. Use a fixed argparse format with two required arguments: --input (input file path or list of paths) and --output_path_dir (output file path directory). + 3. The function must fulfill all user requirements. Ensure the output file format matches the user's request and contains no extra columns beyond those in the input. + 4. If the task involves multiple tables, the --input argument should be treated as a list of file paths, this list can have 1,2... paths. And the --output_path_dir argument will be the directory where the results will be saved. + 5. Please avoid modifying the original input files; read from them and write results to new files in the specified(--output_path_dir) output directory. + 6. Do not write a BOM to the output csv file, i.e. don't use encoding='utf-8-sig' when saving the csv file. + 7. Write the code step by step and don't use complex logic in one step. Use as many steps as needed to ensure clarity and correctness.
+ + [Example] + [INPUT] + User request: + Fill missing values in the age column using an LSTM model trained on other columns + Operator specification: + {{"operators": "fill missing values in the column age using LSTM model trained on other columns", "reason": "'age' has 30% missing values, which may impact analysis. Using LSTM can better capture relationships with other columns to impute missing values.", + "task_type": "TableCleaning-DataImputation", "suffix": "csv"}} + Metadata: + {{...}} + Retrieved similar operator code snippets: + [ ... ] + Debug_history: + [ ... ] + + [OUTPUT] (illustrative only): + ```python + import json + ...(import statements) + + def fill_missing_age_with_lstm(df): + # implement logic here + + def main(): + parser = argparse.ArgumentParser(description="Fill missing 'age' values using LSTM.") + parser.add_argument("--input", required=True, nargs='+', help="Path(s) to input CSV/Parquet file(s)") + parser.add_argument("--output_path_dir", required=True, help="Path to output file's directory") + args = parser.parse_args() + ... + df_filled.to_csv(output_path, index=False) + + if __name__ == "__main__": + main() + ``` + ''' + + +class debugger: + task_prompt_for_debugger = ''' + - The original code: {code} + - The error messages: {error} + - The target: {target} + - Input file paths sequence: {input_file_paths} + - Debug_history: {debug_history} + Note: you should avoid the previous mistakes based on the debug history above. + ''' + system_prompt_for_debugger = ''' + [ROLE] + You are an expert in code debugging and correction. + + [TASK] + Given the original code, error message, requirement. + and reference code, minimally modify the original code to fix the + error. Ensure your corrections are precise and focus on issues such as key alignment or import errors. + Output the corrected code and your reason for modification strictly in JSONformat, and follow all + specified requirements. 
+ + [INPUT] + You will receive the following information in the human request: + - The original code: + - The error messages: + - The target: + - Raw data and expected data formats: + + [OUTPUT RULES] + 1. The response must be strictly in JSON format, containing only the keys "code" (with the complete corrected code) and "reason" (explaining the modification); + no extra keys, explanations, comments, or markdown syntax are allowed. + + 2. The code's --input and --output_path_dir arguments should be kept unchanged. + + 3. The code must include an `if __name__ == '__main__':` block to ensure the script can be run independently. + + 4. All parser arguments should have default values except --input and --output_path_dir. + + 5. Your output must be a valid JSON object only, with no additional text, explanations, Markdown formatting (such as ```json), or line breaks. + + 6. No additional files or external references should be included unless explicitly required to resolve the error, and if needed, they must be listed within the JSON under the appropriate key (though the current instruction prohibits extra keys, so such cases must be handled within the code itself). + +''' + + +class summarizer: + task_prompt_for_summarizer = ''' + Your previous score was {score}. The score_rule is: {score_rule}. + Based on the previous profiling trace summary: {summarizing_trace_summary}, find potential problems in the processed file(s) and give reasonable suggestions for improvement if they do not meet all requirements. + ''' + system_prompt_for_summarizer = """ + [ROLE] + You are a careful evaluator tasked with analyzing the processed results of a task to determine if they meet the target requirements. Your role is to write code that evaluates the processed file(s) and produces a summary of whether the target requirements are satisfied, and give reasonable suggestions for improvement if not. + You are given up to {MAX_REACT_STEPS} attempts to reach a conclusion.
+ + [Inputs] + metadata: {task_meta} this is the metadata after processing + processed_file_paths: {processed_file_paths} + raw_file_paths: {raw_file_paths} + task_objective: {task_objective} + + [Goal] + The generated code should directly assess the *content* of the processed file(s) for basic reasonableness — e.g., presence of required fields, structural/schema consistency, and absence of obvious anomalies (e.g., empty arrays, malformed JSON/CSV, unexpected nulls in critical columns). + + ❗ Do NOT generate ground truth (gt) or simulate expected outputs. + ❗ Do NOT compute, assign, or justify any numerical scores (e.g., no "0.8/1.0" reasoning). + ❗ Do NOT attempt to modify the files, nor check whether a hypothetical fix worked. + + ✅ identify concrete, observable issues and — if present — give short, actionable suggestions. + ✅ You should sample each columns' data from processed files and compare it with raw files to identify discrepancies based on the target requirements. + ✅ If inspection reveals no clear issues, or further analysis yields conclusions nearly identical to the previous round, promptly summarize concisely and output . + + [Rules] + - In each turn: + - Use ... to describe your reasoning. + - Use ```python\n...\n``` to provide **standalone, executable Python code**. + → The code **must be wrapped in triple backticks with language specifier `python`**, like: + + ```python + # Example evaluation logic + with open("processed_file.csv") as f: + content = f.read() + print("Evaluation result: Pass") + ``` + + - Use ... to provide the final evaluation summary as a string. + - After the action, wait for the observation (the printed output). + - After receiving observation, continue reasoning with , then issue next if needed. + - Your code must be **fully self-contained**: include all imports, data loading, and logic. Do *not* rely on prior context or variables. + - Always load data from the provided file paths. 
+ - You should use print() to output results, which will be captured as observations. Avoid printing raw data or huge outputs. + - Keep code precise and concise (≤50 lines per action unless absolutely necessary). + - Do **not** write any files to disk — only output via `print()`. + - Once you find some problems or all requirements are met, output the final summary in ... as a string immediately. + + [EXAMPLE] + + I need to load the processed file and check if it meets the target requirements. + + ```python + import pandas as pd + df = pd.read_csv("...") + if "target_column" in df.columns: + print("Evaluation result: Pass") + else: + print("Evaluation result: Fail , missing 'target_column'") + ``` + + + [Observation] + Evaluation result: Fail , missing 'target_column' + + The processed file doesn't meet the target requirements. Let's check other requirements. + ... + The processed file miss 'target_column' and ... + """ diff --git a/workflow_engine/state.py b/workflow_engine/state.py index 676698e..455bf6c 100644 --- a/workflow_engine/state.py +++ b/workflow_engine/state.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field import os from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple current_file = Path(__file__).resolve() PROJDIR = current_file.parent.parent from typing_extensions import TypedDict, Annotated @@ -29,10 +29,13 @@ class MainRequest: def get(self, key, default=None): return getattr(self, key, default) - + def __setitem__(self, key, value): setattr(self, key, value) + def __getitem__(self, key): + return getattr(self, key) + # ==================== 最基础的 State(所有State的祖先)==================== @dataclass @@ -49,6 +52,9 @@ def get(self, key, default=None): def __setitem__(self, key, value): setattr(self, key, value) + + def __getitem__(self, key): + return getattr(self, key) # ==================== 主流程 Request ==================== @@ -550,3 +556,131 @@ class 
Paper2DrawioState(MainState): output_xml_path: str = "" # XML 文件路径 output_png_path: str = "" # PNG 导出路径 output_svg_path: str = "" # SVG 导出路径 + + +# ==================== Table Processing 相关 State(统一版本)==================== +@dataclass +class TableProcessingRequest(MainRequest): + """ + Table Processing 请求参数(统一版本,同时服务于 API 层和 Table Agent 工作流)。 + + API 层使用:datasources, instruction, output_format, title, task_type, operator_json_path, use_rag + Agent 层使用:target(即 instruction),task_objective + """ + + # ===== API 层参数 ===== + datasources: List[str] = field(default_factory=list) + instruction: str = "" # 用户指令 + output_format: str = "csv" + title: str = "" + notebook_id: str = "" # 用于生成唯一 result_path + + # 允许显式指定 tableAgent 的行为(参考 Processing_workflow/main/ProfiliTable.py) + task_type: str = "DataImputation" + operator_json_path: str = "" + use_rag: bool = True + + # ===== 与 TableAgentRequest 合并的字段 ===== + # 测试样例文件(仅 CLI 批量跑用) + json_file: str = "" + + # Python 代码文件位置 + python_file_path: str = "" + + # Debug 相关 + need_debug: bool = False + max_debug_rounds: int = 3 + + # 本地模型相关 + use_local_model: bool = False + local_model_path: str = "" + + # 缓存和会话 + session_id: str = "default_session" + + # embeddings url + chat_api_url_for_embeddings: str = "" + embedding_model_name: str = "text-embedding-3-small" + update_rag_content: bool = True + + def __post_init__(self): + # 确保 target 字段与 instruction 保持一致 + self.target = self.instruction or self.target + + +@dataclass +class TableProcessingState(MainState): + """ + Table Processing 状态类(统一版本,同时服务于 API 层和 Table Agent 工作流)。 + + 包含表格处理工作流所需的完整状态字段,同时保持与前端 API 的兼容性。 + """ + + request: TableProcessingRequest = field(default_factory=TableProcessingRequest) + + # ===== API 层输出字段 ===== + content: str = "" + sql: str = "" + columns: List[str] = field(default_factory=list) + rows: List[Dict[str, Any]] = field(default_factory=list) + row_count: int = 0 + + # 运行日志/错误 + error: str = "" + + # 内部运行目录(可用于落地 processed 文件等) + result_path: str = "" + + 
# ===== 与 TableAgentState 合并的字段 ===== + + # 任务配置 + task_objective: str = "" # 与 request.instruction 对齐 + score_threshold: float = 0.0 + data_profiling: Dict[str, Any] = field(default_factory=dict) + gt_table_path: str = "" + raw_table_paths: List[str] = field(default_factory=list) + score_func_path: str = "" + user_query: Dict[str, Any] = field(default_factory=dict) + is_dag: bool = False + score_rule: str = "" + profiling_trace_summary: str = "" + summarizing_trace_summary: str = "" + score: float = 0.0 + valid: bool = False + attempts: int = 0 + execution_time: float = 0.0 + debug_attempts: int = 0 + task_name: str = "" + summary: str = "" + processed_file_paths: List[str] = field(default_factory=list) + + # 节点记忆 + generated_codes: List[str] = field(default_factory=list) + error_logs: List[str] = field(default_factory=list) + evaluation_feedbacks: List[Dict[str, Any]] = field(default_factory=list) + + script_generated_total: int = 0 + script_runnable_total: int = 0 + debug_total_attempts: int = 0 + debug_reasons: List[str] = field(default_factory=list) + current_best_score_and_code: Tuple[float, str] = (0.0, "") + + # 任务分解相关 + decomposition_result: str = "" + decomposition_codes: str = "" + retrieved_operators: List[Any] = field(default_factory=list) + + # LLM tracker(用于计费和日志) + llm_tracker: Any = None + + def __post_init__(self): + # 确保 task_objective 与 request.instruction 保持一致 + if not self.task_objective and self.request.instruction: + self.task_objective = self.request.instruction + + + +# ==================== 向后兼容别名 ==================== +# 为了保持与旧代码的兼容性,保留 TableAgentRequest 和 TableAgentState 作为别名 +TableAgentRequest = TableProcessingRequest +TableAgentState = TableProcessingState diff --git a/workflow_engine/table_agent_utils.py b/workflow_engine/table_agent_utils.py new file mode 100644 index 0000000..7c594cf --- /dev/null +++ b/workflow_engine/table_agent_utils.py @@ -0,0 +1,366 @@ +"""TableAgent utilities - helper functions for table processing workflow.""" 
+from __future__ import annotations + +import json +import os +import re +import shutil +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union +import pandas as pd + +from workflow_engine.logger import get_logger + +log = get_logger(__name__) + + +def get_project_root() -> Path: + """获取项目根目录""" + return Path(__file__).resolve().parent.parent + + +def robust_parse_json( + text: str, + *, + merge_dicts: bool = False, + strip_double_braces: bool = False +) -> Union[Dict[str, Any], List[Any]]: + """ + 尽量从 LLM / 日志 / jsonl / Markdown 片段中提取合法 JSON。 + """ + s = text.strip() + s = _remove_markdown_fence(s) + s = _remove_outer_triple_quotes(s) + s = _remove_leading_json_word(s) + + if strip_double_braces: + s = s.replace("{{", "{").replace("}}", "}") + + s = _strip_json_comments(s) + s = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', s) + + try: + result = json.loads(s) + return result + except json.JSONDecodeError: + pass + + objs = _extract_json_objects(s) + if not objs: + raise ValueError("Unable to locate any valid JSON fragment.") + + return _maybe_merge(objs, merge_dicts) + + +def _remove_markdown_fence(src: str) -> str: + blocks = re.findall(r'```[\w-]*\s*([\s\S]*?)```', src, re.I) + return "\n".join(blocks).strip() if blocks else src + + +def _remove_outer_triple_quotes(src: str) -> str: + if (src.startswith("'''") and src.endswith("'''")) or ( + src.startswith('"""') and src.endswith('"""') + ): + return src[3:-3].strip() + return src + + +def _remove_leading_json_word(src: str) -> str: + return src[4:].lstrip() if src.lower().startswith("json") else src + + +def _strip_json_comments(src: str) -> str: + src = re.sub(r'/\*[\s\S]*?\*/', '', src) + src = re.sub(r'(?![:"\'])//.*', '', src) + src = re.sub(r',\s*([}\]])', r'\1', src) + return src.strip() + + +def _extract_json_objects(src: str) -> List[Any]: + from json import JSONDecoder + dec = JSONDecoder() + idx, n = 0, len(src) + 
objs: List[Any] = [] + + while idx < n: + m = re.search(r'[{\[]', src[idx:]) + if not m: + break + idx += m.start() + try: + obj, end = dec.raw_decode(src, idx) + tail = src[end:].lstrip() + if tail and tail[0] not in ',]}>\n\r': + idx += 1 + continue + objs.append(obj) + idx = end + except json.JSONDecodeError: + idx += 1 + return objs + + +def _maybe_merge(objs: List[Any], merge_dicts: bool) -> Union[Any, List[Any]]: + if len(objs) == 1: + return objs[0] + if merge_dicts and all(isinstance(o, dict) for o in objs): + merged: Dict[str, Any] = {} + for o in objs: + merged.update(o) + return merged + return objs + + +def get_paths(base_dir: str) -> Tuple[str, str]: + """获取代码路径和输出目录路径""" + if not os.path.exists(base_dir): + os.makedirs(base_dir) + code_path = os.path.join(base_dir, 'generated_code.py') + processed_table_path = os.path.join(base_dir, "results") + if os.path.exists(processed_table_path): + shutil.rmtree(processed_table_path) + os.makedirs(processed_table_path) + return code_path, processed_table_path + + +def safe_exec_code(py_path: Union[str, Path], output_path: Union[str, List[str]], input_path: List[Any] = None) -> Tuple[str, float]: + """安全执行 Python 代码""" + time_before = time.time() + py_path = Path(py_path) + if not os.path.isfile(py_path): + raise FileNotFoundError(f"Script not found: {py_path}") + if py_path.suffix.lower() != '.py': + raise ValueError("Only .py files are allowed") + + if input_path: + input_args = [str(p) for p in input_path] + result = subprocess.run( + [sys.executable, str(py_path), "--input", *input_args, "--output", str(output_path)], + capture_output=True, + text=True, + timeout=600 + ) + else: + output_args = [arg for out in output_path for arg in ("--output", out)] + result = subprocess.run( + [sys.executable, str(py_path), *output_args], + capture_output=True, + text=True, + timeout=600 + ) + + if result.returncode != 0: + raise RuntimeError(f"Script failed:\n{result.stderr}") + + return result.stdout.strip(), time.time() 
- time_before + + +def extract_python_code_block(content: str) -> str: + """从内容中提取 Python 代码块""" + match = re.search(r"```python\s*(.*?)\s*```", content, re.DOTALL) + if match: + return match.group(1).strip() + return content.strip() or "# Empty code" + + +def write_code_file(res_path: str, code: str) -> Path: + """写入代码文件""" + code_path, _ = get_paths(res_path) + Path(code_path).write_text(code, encoding="utf-8") + return Path(code_path) + + +def parse_react_output(raw: str) -> Dict[str, Any]: + """解析 ReAct 风格的输出""" + from workflow_engine.constants import ( + THINK_TAG_PATTERN, ACTION_TAG_PATTERN, ANSWER_TAG_PATTERN + ) + + result = {"thinks": [], "action_code": None, "answer_obj": None, "errors": []} + think_blocks = re.findall(THINK_TAG_PATTERN, raw, re.DOTALL | re.IGNORECASE) + result["thinks"] = [t.strip() for t in think_blocks if t.strip()] + + answer_match = re.search(ANSWER_TAG_PATTERN, raw, re.DOTALL) + if answer_match: + ans_raw = answer_match.group(1) + try: + result["answer_obj"] = json.loads(ans_raw) + except Exception: + result["answer_obj"] = ans_raw.strip() + + action_match = re.search(ACTION_TAG_PATTERN, raw, re.DOTALL | re.IGNORECASE) + if action_match and result["answer_obj"] is None: + act_raw = action_match.group(1) + try: + result["action_code"] = extract_python_code_block(act_raw) + except Exception as e: + result["errors"].append(f"action_json_parse_failed: {e}") + + return result + + +def observation_to_message(obs: Dict[str, Any]) -> str: + """将观察结果转换为消息""" + from workflow_engine.constants import OBS_TAG_WRAPPER + return OBS_TAG_WRAPPER.format(obs=json.dumps(obs, ensure_ascii=False, separators=(",", ":"))) + + +def truncate_for_log(text: str, limit: int = 1000) -> str: + """截断日志文本""" + if len(text) <= limit: + return text + return text[:limit] + "..." 
+ + +def load_config(config_path: str) -> Dict[str, Any]: + """加载 YAML 配置文件""" + import yaml + with open(config_path, 'r', encoding='utf-8') as f: + return yaml.safe_load(f) + + +def profile_multiple_csvs(csv_paths: List[str], output_dir: str) -> Dict[str, Any]: + """对多个 CSV 文件进行数据画像""" + import pandas as pd + + all_profiles = {} + for file_path in csv_paths: + path = Path(file_path) + try: + df = _read_file_as_dataframe(path) + profile = _simple_data_profile(df) + profile["filename"] = path.name + all_profiles[path.name] = profile + except Exception as e: + all_profiles[path.name] = {"error": str(e), "file_path": str(path)} + + output_path = Path(output_dir) / "default_data_profiling.json" + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(all_profiles, f, indent=2, ensure_ascii=False) + + return all_profiles + + +def _simple_data_profile(df: pd.DataFrame) -> dict: + """为单个 DataFrame 生成轻量级数据概要""" + profile = { + "num_rows": len(df), + "num_columns": df.shape[1], + "columns": {} + } + + for col in df.columns: + series = df[col] + col_info = { + "dtype": str(series.dtype), + "non_null_count": int(series.count()), + "null_count": int(series.isnull().sum()), + "unique_count": int(series.nunique(dropna=True)), + } + + if pd.api.types.is_numeric_dtype(series): + col_info["stats"] = { + "mean": float(series.mean()) if not series.isna().all() else None, + "std": float(series.std()) if not series.isna().all() else None, + "min": float(series.min()) if not series.isna().all() else None, + "max": float(series.max()) if not series.isna().all() else None, + } + elif pd.api.types.is_string_dtype(series) or series.dtype == 'object': + top_values = series.value_counts().head(10).to_dict() + col_info["top_values"] = {str(k): int(v) for k, v in top_values.items()} + + profile["columns"][col] = col_info + + return profile + + +def _read_file_as_dataframe(path: Path) -> pd.DataFrame: + """智能读取 CSV / JSON / 
JSONL""" + import pandas as pd + + suffix = path.suffix.lower() + + try: + if suffix == '.json': + df = pd.read_json(path) + if isinstance(df, pd.Series) or (len(df) == 1 and isinstance(df.iloc[0], (dict, list))): + with open(path, 'r', encoding='utf-8') as f: + raw = json.load(f) + if isinstance(raw, dict) and 'data' in raw: + df = pd.json_normalize(raw['data']) + elif isinstance(raw, list): + df = pd.json_normalize(raw) + else: + df = pd.json_normalize([raw]) if isinstance(raw, dict) else pd.DataFrame() + elif suffix == '.jsonl': + records = [] + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + records.append(json.loads(line)) + df = pd.json_normalize(records) + elif suffix == '.csv': + df = pd.read_csv(path) + else: + raise ValueError(f"Unsupported file extension: {suffix}") + + return df + + except Exception as e: + raise RuntimeError(f"Failed to read {path.name}: {e}") + + +def write_eval_result(state: Dict[str, Any]) -> None: + """写入评估结果""" + from workflow_engine.constants import EVAL_RESULT_FILENAME + + llm_tracker = state.get("llm_tracker") + if llm_tracker: + summary = llm_tracker.summary() + money_cost = summary.get("total_cost_usd", 0.0) + input_tokens = summary.get("input_tokens", 0) + output_tokens = summary.get("output_tokens", 0) + completion_time = summary.get("completion_time_sec", 0.0) + else: + money_cost = input_tokens = output_tokens = completion_time = 0.0 + + task_name = state.get("task_name", "unknown") + profiling = state.get("data_profiling", {}) + + eval_result_path = os.path.join(state.get("result_path", state.get("res_path", "")), EVAL_RESULT_FILENAME) + summary_lines = [ + f"task_name: {task_name}", + f"input_tokens: {input_tokens}", + f"output_tokens: {output_tokens}", + f"completion_time: {completion_time:.3f}", + f"execution_time: {state.get('execution_time', 0):.3f}", + f"Money Cost: {money_cost:.3f}", + "", + f"generated_attempts: {state.get('attempts', 0)}", + 
f"debug_total_attempts: {state.get('debug_total_attempts', 0)}", + f"script_generated_total: {state.get('script_generated_total', 0)}", + f"script_runnable_total: {state.get('script_runnable_total', 0)}", + "", + ] + + data_profiling_path = os.path.join(state.get("result_path", state.get("res_path", "")), "data_profiling.json") + try: + with open(data_profiling_path, "w", encoding="utf-8") as f: + json.dump(profiling, f, ensure_ascii=False, indent=2) + except Exception: + pass + + error_logs = state.get("error_logs", []) + if error_logs: + summary_lines.append("Error Logs:") + summary_lines.extend(error_logs) + + content = "\n".join(summary_lines) + "\n" + with open(eval_result_path, "w", encoding="utf-8") as f: + f.write(content) diff --git a/workflow_engine/workflow/wf_table_processing_api.py b/workflow_engine/workflow/wf_table_processing_api.py new file mode 100644 index 0000000..772080e --- /dev/null +++ b/workflow_engine/workflow/wf_table_processing_api.py @@ -0,0 +1,226 @@ +""" +Table Processing API - REST API endpoint for table processing functionality. + +This module provides an API for processing tabular data using the table processing workflow. +It can be called from external services to process CSV files with natural language instructions. 
+""" +from __future__ import annotations + +import csv +import json +import os +import sys +import time +import traceback +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from workflow_engine.graphbuilder.graph_builder import GenericGraphBuilder +from workflow_engine.logger import get_logger +from workflow_engine.state import TableProcessingRequest, TableProcessingState +from workflow_engine.workflow.registry import register +from workflow_engine.utils import get_project_root +from workflow_engine.llm_callers.text import TextLLMCaller +from workflow_engine.table_agent_utils import profile_multiple_csvs, load_config + +log = get_logger(__name__) + + +def _workspace_root() -> Path: + """ + Open-NotebookLM-test 位于 /data/lw/notebook/Open-NotebookLM-test + workspace root 为 /data/lw + """ + project_root = get_project_root() + return project_root.parents[2] + + +def _processing_workflow_root() -> Path: + return _workspace_root() / "Processing_workflow" + + +def _default_operator_json_path() -> str: + """ + 默认 operator json:优先 Processing_workflow,其次 tableAgent。 + """ + candidates = [ + _processing_workflow_root() / "table_agent" / "utils" / "operators" / "Operators.json", + _workspace_root() / "tableAgent" / "table_agent" / "utils" / "operators" / "Operators.json", + ] + for p in candidates: + if p.exists(): + return str(p) + return str(candidates[0]) + + +def _read_csv_sample(csv_path: str, max_rows: int = 20) -> Tuple[List[str], List[Dict[str, Any]], int]: + """ + 读取生成后的第一个 csv,并返回(列名、前 N 行记录、行数)。 + 为了避免大文件一次性读全,行数使用简单计数(可能较慢但比全量 DataFrame 更稳)。 + """ + columns: List[str] = [] + rows: List[Dict[str, Any]] = [] + row_count = 0 + + with open(csv_path, "r", encoding="utf-8", errors="ignore", newline="") as f: + reader = csv.DictReader(f) + columns = [str(c) for c in (reader.fieldnames or [])] + for idx, row in enumerate(reader): + if idx < max_rows: + rows.append(row) + row_count += 1 + + return columns, rows, int(row_count) + + 
@register("table_processing_api")
def create_table_processing_api_graph() -> GenericGraphBuilder:
    """Build the single-node API graph that drives the full table-processing workflow.

    Graph shape: ``_start_`` -> (conditional, always) ``run_table_agent`` -> ``_end_``.
    ``run_table_agent`` prepares a fresh ``TableProcessingState`` from the incoming
    request, delegates to the registered ``table_processing_workflow`` graph, and
    copies summary/code/CSV-preview results back onto the state for the API response.
    """
    builder = GenericGraphBuilder(state_model=TableProcessingState, entry_point="_start_")

    def _start_(state: TableProcessingState) -> TableProcessingState:
        # Reset all response-facing fields so a reused state never leaks stale output.
        state.content = ""
        state.sql = ""
        state.columns = []
        state.rows = []
        state.row_count = 0
        state.error = ""
        return state

    def _route(_: TableProcessingState) -> str:
        # Unconditional route; kept as a conditional edge to match the builder API.
        return "run_table_agent"

    async def run_table_agent(state: TableProcessingState) -> TableProcessingState:
        """Run the end-to-end table-processing pipeline for one API request.

        On success fills ``state.content`` (summary), ``state.sql`` (best generated
        code), and a CSV preview (``columns``/``rows``/``row_count``). On failure sets
        ``state.error`` and puts the traceback into ``state.sql`` for debugging.
        """
        try:
            if not state.request.datasources:
                state.error = "datasources is empty"
                return state

            raw_paths: List[str] = [str(p).strip() for p in state.request.datasources if str(p).strip()]
            if not raw_paths:
                state.error = "No valid datasource paths"
                return state

            project_root = get_project_root()
            base_out_dir = project_root / "outputs" / "table_processing"
            ts = int(time.time())
            # Unique per-run output directory from notebook_id + timestamp.
            notebook_id = getattr(state.request, 'notebook_id', 'default') or 'default'
            out_dir = base_out_dir / f"{notebook_id}_{ts}"
            out_dir.mkdir(parents=True, exist_ok=True)
            state.result_path = str(out_dir)

            # Load workflow config; fall back to an empty dict if the file is absent.
            cfg_path = _processing_workflow_root() / "main" / "config.yaml"
            cfg = load_config(str(cfg_path)) if cfg_path.exists() else {}
            # Resolution order: explicit request > config file > built-in default.
            operator_json_path = (
                state.request.operator_json_path
                or (cfg.get("paths", {}) or {}).get("operator_json_path")
                or _default_operator_json_path()
            )
            try:
                # Resolve relative paths against the config file's directory.
                if operator_json_path:
                    operator_json_path = str((cfg_path.parent / operator_json_path).resolve())
            except Exception:
                # Best-effort: fall back to the raw string on any path error.
                operator_json_path = str(operator_json_path or "")

            # 1) Profile all input CSVs up front.
            log.info(f"[table_processing_api] profiling multiple csv: {raw_paths}")
            data_profiling = profile_multiple_csvs(raw_paths, str(out_dir))

            # 2) Populate the unified TableProcessingState fields.
            # NOTE: task_objective is already aligned with request.instruction in
            # __post_init__; re-assigned here to keep them in sync explicitly.
            state.task_objective = state.request.instruction  # keep in sync
            state.score_threshold = 0.0
            state.task_type = state.request.task_type or "TableCleaning-DataImputation"
            state.data_profiling = data_profiling
            state.raw_table_paths = raw_paths
            state.score_func_path = ""
            state.gt_table_path = ""
            state.profiling_trace_summary = ""
            state.summarizing_trace_summary = ""
            state.score = -1.0
            state.valid = True
            state.attempts = 0
            state.is_dag = False
            state.error_logs = []
            state.execution_time = 0.0
            state.score_rule = ""
            state.debug_attempts = 0
            state.task_name = state.request.title or "table_processing"
            state.operator_json_path = operator_json_path
            state.current_best_score_and_code = (0.0, "")

            # 3) Initialize the LLM usage tracker attached to this state.
            llm_tracker = TextLLMCaller(
                state,
                model_name=state.request.model or "deepseek-v3.2",
                temperature=0.3,
                max_tokens=10000,
            )
            state.llm_tracker = llm_tracker

            # 4) Run the actual workflow on the unified TableProcessingState.
            log.info("[table_processing_api] running table_processing_workflow...")
            # Deferred import to avoid a circular import with the workflow package.
            from workflow_engine.workflow import run_workflow
            final_state = await run_workflow("table_processing_workflow", state)

            # 5) Extract results. The workflow may return either a plain dict or a
            # state object, so both access styles are handled.
            summary = ""
            best_code = ""
            processed_files = []

            if isinstance(final_state, dict):
                summary = str(final_state.get("summary") or "")
                processed_files = final_state.get("processed_file_paths") or []
                cur_best = final_state.get("current_best_score_and_code") or (0.0, "")
                if isinstance(cur_best, (list, tuple)) and len(cur_best) >= 2:
                    best_code = str(cur_best[1] or "")
            else:
                summary = str(getattr(final_state, "summary", "") or "")
                processed_files = getattr(final_state, "processed_file_paths", []) or []
                cur_best = getattr(final_state, "current_best_score_and_code", (0.0, ""))
                if isinstance(cur_best, (list, tuple)) and len(cur_best) >= 2:
                    best_code = str(cur_best[1] or "")

            # Prefer the summary; fall back to a truncated code preview, then a
            # canned completion message.
            state.content = summary or (best_code[:2000] if best_code else "表格处理已完成,但未返回可展示的摘要。")
            # NOTE(review): `sql` carries generated Python code here (and a traceback
            # on failure), not SQL — the field name is inherited from the response
            # schema; confirm consumers expect that.
            state.sql = best_code or ""

            # 6) Use the first existing CSV output as a table preview; result_path
            # is already stored above for download.
            csv_path = None
            for f in processed_files:
                p = str(f)
                if p.lower().endswith(".csv") and Path(p).exists():
                    csv_path = p
                    break

            if csv_path:
                columns, rows, row_count = _read_csv_sample(csv_path, max_rows=20)
                state.columns = columns
                state.rows = rows
                state.row_count = row_count
            else:
                # No CSV produced: keep the preview empty; result_path remains set
                # for debugging/download.
                state.columns = []
                state.rows = []
                state.row_count = 0

            return state

        except Exception as e:
            log.error(f"[table_processing_api] failed: {e}")
            state.error = str(e)
            state.content = "处理失败,请稍后重试。"
            state.sql = traceback.format_exc()
            return state

    nodes = {
        "_start_": _start_,
        "run_table_agent": run_table_agent,
        "_end_": lambda s: s,
    }
    edges = [
        ("run_table_agent", "_end_"),
    ]
    builder.add_nodes(nodes).add_edges(edges).add_conditional_edge("_start_", _route)
    return builder
"""Table Processing Workflow - Self-contained workflow using workflow_engine components.

This workflow implements the full table processing pipeline:
1. intent_understanding: Parse user intent into task type and operation
2. data_profiling: Profile the input tables
3. decompositer: Decompose complex tasks (optional)
4. generator: Generate Python code for the task
5. evaluator: Execute and validate the generated code
6. debugger: Debug failed code (if needed)
7. summarizer: Summarize results

The workflow uses workflow_engine's agent system and tools instead of table_agent.
"""
from __future__ import annotations

import json
import os
import shutil
import traceback
from pathlib import Path
from typing import Literal

from langchain_core.messages import HumanMessage, AIMessage

from workflow_engine.graphbuilder.graph_builder import GenericGraphBuilder
from workflow_engine.state import TableProcessingState
from workflow_engine.workflow.registry import register
from workflow_engine.agentroles import create_simple_agent
from workflow_engine.logger import get_logger
from workflow_engine.constants import (
    MAX_REACT_STEPS,
    MAX_GENERATE_ATTEMPTS,
    MAX_DEBUG_ATTEMPTS,
    BENCHMARK_TASK_TYPES,
)
from workflow_engine.table_agent_utils import (
    get_paths,
    safe_exec_code,
    extract_python_code_block,
    write_code_file,
    write_eval_result,
    profile_multiple_csvs,
    parse_react_output,
    observation_to_message,
    truncate_for_log,
)

# Importing this module triggers auto-registration of the custom strategy and
# provides the create_table_react_agent factory.
from workflow_engine.workflow.wf_table_strategy import create_table_react_agent

log = get_logger(__name__)

CURRENT_DIR = Path(__file__).parent.parent.resolve()


@register("table_processing_workflow")
def create_table_processing_workflow() -> GenericGraphBuilder:
    """Create the table processing workflow graph.

    Pipeline: intent_understanding -> data_profiling -> [decompositer] ->
    generator -> evaluator -> (debugger loop | summarizer) -> finalizer.

    Pre-tools registered below appear to inject named inputs into each node's
    agent prompt context (keyed by node name) — confirm against
    GenericGraphBuilder's pre_tool contract.
    """
    builder = GenericGraphBuilder(
        state_model=TableProcessingState,
        entry_point="intent_understanding"
    )

    # =======================================================================
    # Pre-tools for each node
    # =======================================================================

    @builder.pre_tool("user_input", "intent_understanding")
    def _user_input(state: TableProcessingState):
        return state.get("task_objective", "")

    @builder.pre_tool("task_meta", "intent_understanding")
    def _task_meta(state: TableProcessingState):
        return state.get("data_profiling", {})

    # -----------------------------------------------------------------------
    # Data profiling pre-tools
    @builder.pre_tool("raw_table_paths", "data_profiling")
    def _raw_table_paths(state: TableProcessingState):
        return state.get("raw_table_paths", [])

    @builder.pre_tool("operation", "data_profiling")
    def _operation(state: TableProcessingState):
        user_query = state.get("user_query", {})
        if isinstance(user_query, dict):
            return user_query.get("operation", "")
        return ""

    @builder.pre_tool("MAX_REACT_STEPS", "data_profiling")
    def _max_react_steps(state: TableProcessingState):
        return MAX_REACT_STEPS

    @builder.pre_tool("user_refine_input", "data_profiling")
    def _user_refine_input(state: TableProcessingState):
        # NOTE(review): `score` and `score_rule` are read but never used in the
        # returned prompt — either dead code or an omission in the template.
        score = state.get("score", 0.0)
        score_rule = state.get("score_rule", "")
        profiling_trace_summary = state.get("profiling_trace_summary", "")
        insight = state.get("summary", "")
        return f"""
        Based on the previous profiling trace summary: {profiling_trace_summary},
        Based on the previous insight(about why agent don't do well): {insight}, try your best to improve quality of the profiling.
        """

    # -----------------------------------------------------------------------
    # Decompositer pre-tools
    @builder.pre_tool("user_query", "decompositer")
    def _decompositer_user_query(state: TableProcessingState):
        user_query = state.get("user_query", {})
        if isinstance(user_query, dict):
            return json.dumps(user_query, ensure_ascii=False)
        return str(user_query)

    @builder.pre_tool("benchmark_task_types", "decompositer")
    def _benchmark_task_types(state: TableProcessingState):
        return json.dumps(BENCHMARK_TASK_TYPES, ensure_ascii=False)

    # -----------------------------------------------------------------------
    # Generator pre-tools
    @builder.pre_tool("task_meta", "generator")
    def _generator_task_meta(state: TableProcessingState):
        return state.get("data_profiling", {})

    @builder.pre_tool("user_query", "generator")
    def _generator_user_query(state: TableProcessingState):
        user_query = state.get("user_query", {})
        if isinstance(user_query, dict):
            return json.dumps(user_query, ensure_ascii=False)
        return str(user_query)

    @builder.pre_tool("user_input", "generator")
    def _generator_user_input(state: TableProcessingState):
        # Assemble the generator's user prompt: objective, input file names,
        # plus past debug reasons and summarizer context when present.
        user_inputs = [state.get("task_objective", "")]
        input_paths = " ".join(Path(p).name for p in state.get("raw_table_paths", []))
        user_inputs.append(f"输入文件的顺序与类型:{input_paths}")
        if state.get("debug_reasons"):
            user_inputs.append(
                f"Debug Reasons:{state['debug_reasons']} You need to avoid the previous mistakes."
            )
        if state.get("summary"):
            user_inputs.append(f"Additional Context: {state.get('summary')}")
        return user_inputs

    @builder.pre_tool("retrieved_operators", "generator")
    def _retrieved_operators(state: TableProcessingState):
        return state.get("retrieved_operators", [])

    # -----------------------------------------------------------------------
    # Debugger pre-tools
    @builder.pre_tool("code", "debugger")
    def _code(state: TableProcessingState):
        # Hand the debugger the most recent generated code, if any.
        previews_generated_codes = state.get("generated_codes", [])
        return previews_generated_codes[-1] if previews_generated_codes else ""

    @builder.pre_tool("error", "debugger")
    def _error(state: TableProcessingState):
        error_logs = state.get("error_logs", [])
        return error_logs[-1] if error_logs else ""

    @builder.pre_tool("target", "debugger")
    def _target(state: TableProcessingState):
        user_query = state.get("user_query", {})
        if isinstance(user_query, dict):
            return user_query.get("operation", "")
        return ""

    @builder.pre_tool("input_file_paths", "debugger")
    def _input_file_paths(state: TableProcessingState):
        return " ".join(Path(p).name for p in state.get("raw_table_paths", []))

    @builder.pre_tool("debug_history", "debugger")
    def _debug_history(state: TableProcessingState):
        # All errors except the latest one (which is passed separately as "error").
        error_logs = state.get("error_logs", [])
        return error_logs[:-1] if len(error_logs) > 1 else "No previous debug history."

    # -----------------------------------------------------------------------
    # Summarizer pre-tools
    @builder.pre_tool("processed_file_paths", "summarizer")
    def _processed_file_paths(state: TableProcessingState):
        return state.get("processed_file_paths", [])

    @builder.pre_tool("task_objective", "summarizer")
    def _task_objective(state: TableProcessingState):
        return state.get("task_objective", "")

    @builder.pre_tool("raw_file_paths", "summarizer")
    def _raw_file_paths(state: TableProcessingState):
        return state.get("raw_table_paths", [])

    @builder.pre_tool("score", "summarizer")
    def _score(state: TableProcessingState):
        return state.get("score", 0.0)

    @builder.pre_tool("score_rule", "summarizer")
    def _score_rule(state: TableProcessingState):
        return state.get("score_rule", "")

    @builder.pre_tool("summarizing_trace_summary", "summarizer")
    def _summarizing_trace_summary(state: TableProcessingState):
        return state.get("summarizing_trace_summary", "")

    @builder.pre_tool("task_meta", "summarizer")
    def _summarizer_task_meta(state: TableProcessingState):
        return state.get("data_profiling", {})

    @builder.pre_tool("MAX_REACT_STEPS", "summarizer")
    def _summarizer_max_react_steps(state: TableProcessingState):
        return MAX_REACT_STEPS

    # =======================================================================
    # Node implementations
    # =======================================================================

    async def intent_understanding(state: TableProcessingState) -> TableProcessingState:
        """Parse user intent into task type and operation.

        Raises ValueError when the LLM's JSON result is missing any of the
        required keys (operation/reason/task_type/suffix).
        """
        log.info("🔍 开始意图识别...")
        model = state.request.model or state.get("model", "deepseek-v3.2")

        agent = create_simple_agent(
            name="intent_understanding",
            model_name=model,
            temperature=0.3,
            max_tokens=20480,
            parser_type="json",
        )

        state = await agent.execute(state=state)
        agent_result = state.agent_results.get("intent_understanding", {})
        user_query = agent_result.get("results", {})

        if isinstance(user_query, dict):
            required = {"operation", "reason", "task_type", "suffix"}
            if not required.issubset(user_query.keys()):
                raise ValueError(f"Missing required fields in intent_understanding result: {required - set(user_query.keys())}")

        state["user_query"] = user_query
        state["task_type"] = user_query.get("task_type", "TableCleaning-DataImputation")
        state["user_query"] = user_query
        log.info(f"✅ 意图解析成功: {user_query}")
        return state

    async def data_profiling(state: TableProcessingState) -> TableProcessingState:
        """Profile the input tables using TableReAct Strategy."""
        log.info("📊 开始数据画像...")
        model = state.request.model or state.get("model", "deepseek-v3.2")

        agent = create_table_react_agent(
            name="data_profiling",
            model_name=model,
            max_retries=MAX_REACT_STEPS,
            parser_type="json",
        )

        state = await agent.execute(state=state)
        agent_result = state.agent_results.get("data_profiling", {})
        results = agent_result.get("results", {})
        data_profiling = results.get("answer", {})
        react_trace = results.get("react_trace", [])

        log.info(f"📊 Profiling 原始输出:{data_profiling}")

        # Summarize the ReAct trace with the tracked LLM caller, when available.
        if state.get("llm_tracker"):
            summary_messages = [
                {"role": "system", "content": "Summarize the following ReAct trace briefly."},
                {"role": "user", "content": f"ReAct Trace: {json.dumps(react_trace, ensure_ascii=False)}"},
            ]
            local_summarizer_response = await state["llm_tracker"](summary_messages)
            profiling_trace_summary = local_summarizer_response.content.strip()
            log.info(f"📊 Local Summarizer ReAct Trace Summary:\n{profiling_trace_summary}")
        else:
            profiling_trace_summary = ""

        # Only overwrite the stored profile when the agent did not report an error.
        if "error" not in data_profiling:
            state["data_profiling"] = data_profiling
        else:
            log.warning("📊 Profiling 返回 error,保留原有的 state.data_profiling")

        state["profiling_trace_summary"] = profiling_trace_summary
        state["execution_time"] = results.get("execution_time", 0.0)
        return state

    async def decompositer(state: TableProcessingState) -> TableProcessingState:
        """Decompose complex tasks into sub-tasks."""
        log.info("🔄 开始任务分解...")
        model = state.request.model or state.get("model", "deepseek-v3.2")

        agent = create_simple_agent(
            name="decompositer",
            model_name=model,
            temperature=0.1,
            max_tokens=20480,
            parser_type="text",
        )

        state = await agent.execute(state=state)
        agent_result = state.agent_results.get("decompositer", {})
        decomposition_result = agent_result.get("results", {})

        # Normalize the raw result to a JSON string for storage/logging.
        if isinstance(decomposition_result, dict):
            decomposition_result = json.dumps(decomposition_result, ensure_ascii=False)
        elif not isinstance(decomposition_result, str):
            decomposition_result = str(decomposition_result)

        log.info(f"Decompositer 原始输出:{decomposition_result}")

        try:
            parsed_result = json.loads(decomposition_result)
            for sub_task, task_desc in parsed_result.items():
                log.info(f"子任务: {sub_task} 描述: {task_desc}")
        except json.JSONDecodeError:
            log.warning(f"Decompositer output is not valid JSON: {decomposition_result}")
            parsed_result = {}

        state["decomposition_result"] = decomposition_result
        state["retrieved_operators"] = []
        return state

    async def generator(state: TableProcessingState) -> TableProcessingState:
        """Generate Python code for the task and write it to the result path."""
        log.info("🛠️ 开始代码生成...")
        model = state.request.model or state.get("model", "deepseek-v3.2")

        agent = create_simple_agent(
            name="generator",
            model_name=model,
            parser_type="text",
            temperature=0.1,
            max_tokens=20480,
        )

        state = await agent.execute(state=state)

        # Extract the first ```python ...``` block from the agent's text output.
        code = extract_python_code_block(state["agent_results"]["generator"]["results"]['text'])
        code_path = write_code_file(state.get("result_path", ""), code)

        attempts = state.get("attempts", 0) + 1
        generated_codes = state.get("generated_codes", []) + [code]

        log.info(f"GenerationStrategy 生成代码 (尝试 {attempts})")
        log.info(f"代码已写入: {code_path}")

        state["generated_codes"] = generated_codes
        state["attempts"] = attempts
        return state

    async def evaluator(state: TableProcessingState) -> TableProcessingState:
        """Execute the latest generated code and record success/failure.

        Returns a partial-update dict (not the full state object) — presumably
        the graph runtime merges these into the state; confirm against
        GenericGraphBuilder semantics.
        """
        log.info("🧪 开始执行与评估...")
        execution_time = state.get("execution_time", 0.0)
        script_generated_total = state.get("script_generated_total", 0) + 1
        script_runnable_total = state.get("script_runnable_total", 0)

        try:
            code_path, process_table_path = get_paths(state.get("result_path", ""))
            log.info(f"Raw table paths: {state['raw_table_paths']}")

            stdout, exec_time = safe_exec_code(
                code_path,
                process_table_path,
                state.get("raw_table_paths", [])
            )
            execution_time += exec_time

            processed_files = [str(f) for f in Path(process_table_path).iterdir() if f.is_file()]
            log.info(f"✅ 评估成功 | Processed files: {processed_files} | Execution time: {execution_time:.2f}s")

            feedback = {"status": "success", "reason": "Execution succeeded."}
            script_runnable_total += 1

            return {
                "messages": [AIMessage(content=f"[Evaluator] feedback={feedback}")],
                "valid": True,
                "evaluation_feedbacks": state.get("evaluation_feedbacks", []) + [feedback],
                "execution_time": execution_time,
                "script_generated_total": script_generated_total,
                "script_runnable_total": script_runnable_total,
                "processed_file_paths": processed_files,
            }
        except Exception as e:
            tb = traceback.format_exc()
            error_msg = f"[Evaluator] 执行失败: {e}\n{tb}"
            log.error(f"💥 执行失败 | Error: {e}")
            feedback = {"score": 0.0, "status": "error", "reason": str(e), "traceback": tb}

            return {
                "messages": [AIMessage(content="Error:" + error_msg)],
                "valid": False,
                "error_logs": state.get("error_logs", []) + [error_msg],
                "evaluation_feedbacks": state.get("evaluation_feedbacks", []) + [feedback],
                "execution_time": execution_time,
                "script_generated_total": script_generated_total,
            }

    async def debugger(state: TableProcessingState) -> TableProcessingState:
        """Ask the LLM to repair the failing code; retry up to MAX_DEBUG_ATTEMPTS.

        Raises ValueError (after writing eval_result.txt) when no valid code is
        produced in any attempt.
        """
        log.info("🐞 开始调试代码...")
        code = ""
        reason = ""
        model = state.request.model or state.get("model", "deepseek-v3.2")

        agent = create_simple_agent(
            name="debugger",
            model_name=model,
            temperature=0.1,
            max_tokens=20480,
            parser_type="json",
        )

        for attempt in range(MAX_DEBUG_ATTEMPTS):
            state = await agent.execute(state=state)
            agent_result = state.agent_results.get("debugger", {})
            last_raw = str(agent_result.get("results", {}))

            try:
                if isinstance(agent_result.get("results"), dict):
                    parsed = agent_result["results"]
                else:
                    # NOTE(review): json.loads may yield a non-dict (list/str),
                    # in which case .get below raises AttributeError which is
                    # NOT caught by the JSONDecodeError handler — confirm intent.
                    parsed = json.loads(last_raw)
                code = parsed.get("code", "")
                reason = parsed.get("reason", "")
            except json.JSONDecodeError:
                log.warning(f"调试 JSON 解析失败: {attempt + 1}")
                continue

            if code:
                break

        if not code:
            # Persist a failure record before aborting the workflow.
            eval_result_path = os.path.join(state.get("result_path", ""), "eval_result.txt")
            with open(eval_result_path, "w", encoding="utf-8") as f:
                f.write("Score: 0.0\nValid: False\nError: Debugger 未能生成有效代码(多次失败)。\n")
            raise ValueError("Debugger 未能生成有效代码(多次失败)。")

        code_path = write_code_file(state.get("result_path", ""), code)
        log.info(f"Debugger 生成调试后代码: {code_path}")

        debug_attempts = state.get("debug_attempts", 0) + 1
        debug_total_attempts = state.get("debug_total_attempts", 0) + 1
        debug_reasons = state.get("debug_reasons", []) + [reason]
        generated_codes = state.get("generated_codes", []) + [code]

        return {
            "generated_codes": generated_codes,
            "debug_attempts": debug_attempts,
            "debug_total_attempts": debug_total_attempts,
            "debug_reasons": debug_reasons,
        }

    async def summarizer(state: TableProcessingState) -> TableProcessingState:
        """Summarize results using TableReAct Strategy."""
        log.info("🔍🔍 Summarizer Node (TableReAct THINK → ACTION → OBSERVE)")
        model = state.request.model or state.get("model", "deepseek-v3.2")

        agent = create_table_react_agent(
            name="summarizer",
            model_name=model,
            max_retries=MAX_REACT_STEPS,
            parser_type="json",
        )

        state = await agent.execute(state=state)
        agent_result = state.agent_results.get("summarizer", {})
        results = agent_result.get("results", {})
        summary = results.get("answer", "")
        react_trace = results.get("react_trace", [])

        if state.get("llm_tracker"):
            summary_messages = [
                {"role": "system", "content": "Summarize the following ReAct trace briefly."},
                {"role": "user", "content": f"ReAct Trace: {json.dumps(react_trace, ensure_ascii=False)}"},
            ]
            local_summarizer_response = await state["llm_tracker"](summary_messages)
            summarizing_trace_summary = local_summarizer_response.content.strip()
            log.info(f"🔍🔍 Local Summarizer ReAct Trace Summary:\n{summarizing_trace_summary}")
        else:
            summarizing_trace_summary = ""

        log.info(f"Summarizer 原始输出:{summary}")

        return {
            "summary": summary,
            "summarizing_trace_summary": summarizing_trace_summary,
            "execution_time": state.get("execution_time", 0.0) + results.get("execution_time", 0.0),
        }

    async def finalizer(state: TableProcessingState) -> TableProcessingState:
        """Finalize: persist evaluation results; never fail the graph on write errors."""
        log.info("✅️ 进入终止节点,写入评估结果并结束流程")
        try:
            write_eval_result(state)
        except Exception as e:
            log.error(f"写入评估结果失败: {e}")
        return state

    # =======================================================================
    # Conditional edges
    # =======================================================================

    def should_debug(state: TableProcessingState) -> Literal["debugger", "summarizer", "finalizer"]:
        """Route after evaluation: debug on failure (within budget), else summarize.

        NOTE(review): the MAX_GENERATE_ATTEMPTS branch routes to finalizer but no
        edge ever loops back to generator, so `attempts` can only reach 1 via
        this graph — confirm whether a regenerate loop was intended.
        """
        debug_attempts = state.get("debug_attempts", 0)
        valid = state.get("valid", False)

        if not valid and debug_attempts < MAX_DEBUG_ATTEMPTS:
            log.info(f"🔄 任务未通过验证,进入调试节点(debug_attempts={debug_attempts})")
            return "debugger"

        if not valid and debug_attempts >= MAX_DEBUG_ATTEMPTS:
            log.warning(f"⚠️ 达到最大调试次数 ({debug_attempts}),强制终止")
            return "finalizer"

        log.info("✅ 任务通过验证,进入总结节点")
        if state.get("attempts", 0) >= MAX_GENERATE_ATTEMPTS:
            log.warning(f"⚠️ 达到最大重试次数 ({state.get('attempts', 0)}),强制终止")
            return "finalizer"
        return "summarizer"

    def should_decomposite(state: TableProcessingState) -> Literal["decompositer", "generator"]:
        """Route after profiling based on task complexity.

        NOTE(review): no node visible in this file ever sets is_dag to True
        (the API entry initializes it to False), so the decompositer branch
        appears unreachable from this entry point — confirm.
        """
        is_dag = state.get("is_dag", False)
        if is_dag:
            log.info("🔄 任务复杂,进入分解节点")
            return "decompositer"
        else:
            log.info("✅ 任务简单,跳过分解节点")
            return "generator"

    # =======================================================================
    # Register nodes and edges
    # =======================================================================

    nodes = {
        "intent_understanding": intent_understanding,
        "data_profiling": data_profiling,
        "decompositer": decompositer,
        "generator": generator,
        "evaluator": evaluator,
        "summarizer": summarizer,
        "debugger": debugger,
        "finalizer": finalizer,
        "_end_": lambda state: state,
    }

    edges = [
        ("intent_understanding", "data_profiling"),
        ("decompositer", "generator"),
        ("generator", "evaluator"),
        ("debugger", "evaluator"),
        ("summarizer", "finalizer"),
        ("finalizer", "_end_"),
    ]

    conditional_edges = {
        "evaluator": should_debug,
        "data_profiling": should_decomposite,
    }

    builder.add_nodes(nodes).add_edges(edges).add_conditional_edges(conditional_edges)
    return builder
"""Custom Strategies for Table Processing Workflow.

This module provides custom execution strategies that extend the base
ExecutionStrategy class. These strategies are designed specifically for
table processing tasks.

Import this module to auto-register the custom strategies:
    from workflow_engine.workflow.wf_table_strategy import TableReActStrategy, TableReactConfig

"""
from __future__ import annotations

import json
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, List, TYPE_CHECKING

from workflow_engine.logger import get_logger
from workflow_engine.agentroles.cores.strategies import ExecutionStrategy, StrategyFactory
from workflow_engine.table_agent_utils import (
    safe_exec_code,
    parse_react_output,
    observation_to_message,
    truncate_for_log,
)

if TYPE_CHECKING:
    from workflow_engine.agentroles.cores.base_agent import BaseAgent
    from workflow_engine.state import MainState

log = get_logger(__name__)


# ==================== Custom configuration ====================

@dataclass
class TableReactConfig:
    """Configuration for the Table ReAct mode, dedicated to table-processing tasks.

    This mode combines the ReAct reasoning framework with table handling:
    - ACTION steps are executed as Python code
    - the code operates directly on the input table files
    - trace recording and execution-time accounting are supported

    Example:
        >>> config = TableReactConfig(
        ...     model_name="gpt-4",
        ...     max_retries=5,
        ...     temperature=0.1
        ... )
    """
    # Strategy mode name (the key registered with StrategyFactory).
    mode: str = "table_react"

    # Core LLM parameters.
    model_name: Optional[str] = None
    chat_api_url: Optional[str] = None
    temperature: float = 0.0
    max_tokens: int = 65536

    # Tool-related settings.
    tool_mode: str = "auto"
    tool_manager: Optional[Any] = None

    # Parser settings.
    parser_type: str = "json"
    parser_config: Optional[Dict[str, Any]] = None

    # Message-history settings.
    ignore_history: bool = True
    message_history: Optional[Any] = None

    # TableReAct-specific: max ReAct loop iterations and optional validators.
    max_retries: int = 3
    validators: Optional[List[Any]] = None


# ==================== Custom strategy ====================

class TableReActStrategy(ExecutionStrategy):
    """ReAct-based execution strategy specialised for table processing.

    Unlike the generic ReactStrategy, this strategy:
    - executes ACTION steps as Python code
    - has that code operate directly on the input table files
    - records the ReAct trace and accumulated execution time
    - returns a normalised result dict {answer, react_trace, execution_time}

    Flow: THINK -> ACTION (code) -> OBSERVE -> ... -> ANSWER
    """

    def __init__(self, agent: "BaseAgent", config: TableReactConfig):
        self.agent = agent
        self.config = config
        # Mirror config values onto the agent; only truthy/explicit values win.
        # NOTE(review): `config.temperature is not None` is always True because
        # the dataclass default is 0.0, so temperature is always copied.
        if config.model_name:
            self.agent.model_name = config.model_name
        if config.temperature is not None:
            self.agent.temperature = config.temperature
        if config.max_tokens:
            self.agent.max_tokens = config.max_tokens
        if config.parser_type:
            self.agent.parser_type = config.parser_type
        if config.chat_api_url:
            self.agent.chat_api_url = config.chat_api_url

    async def execute(self, state: "MainState", **kwargs) -> Dict[str, Any]:
        """Entry point: run pre-tools, then the TableReAct loop."""
        log.info(f"[TableReActStrategy] 执行 {self.agent.role_name},最大重试: {self.config.max_retries}")

        pre_tool_results = await self.agent.execute_pre_tools(state)
        return await self._process_table_react_mode(state, pre_tool_results)

    async def _process_table_react_mode(
        self,
        state: "MainState",
        pre_tool_results: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run the THINK -> ACTION (code) -> OBSERVE loop until an answer or the
        step limit.

        Args:
            state: current workflow state object
            pre_tool_results: outputs of the agent's pre-tools

        Returns:
            Dict[str, Any]: {answer, react_trace, execution_time}; on step
            exhaustion `answer` is {"error": "max_steps_reached"}.
        """
        log.info("🔍🔍 开始 TableReAct 模式处理")

        messages = self.agent.build_messages(state, pre_tool_results)
        react_trace = []
        execution_time = 0.0

        for step in range(1, self.config.max_retries + 1):
            log.info(f"➡️ ReAct Step {step}/{self.config.max_retries}")

            llm = self.agent.create_llm(state, bind_post_tools=False)
            response = await llm.ainvoke(messages)

            # Record token usage (best-effort; never fails the loop).
            self._track_token_usage(state, messages, response)

            raw_output = response.content.strip()
            log.debug(f"Raw LLM Output:\n{raw_output}")

            # Keep the raw assistant turn in the conversation history.
            messages.append({"role": "assistant", "content": raw_output})

            # Parse the LLM output into thinks / action code / final answer.
            parsed = parse_react_output(raw_output)

            # Record THINK steps.
            for tb in parsed["thinks"]:
                react_trace.append({"step": len(react_trace) + 1, "type": "think", "content": tb})
                log.info(f"💭 THINK: {tb}")

            # A final answer ends the loop.
            if parsed["answer_obj"] is not None:
                answer = parsed["answer_obj"]
                log.info("🎯 最终答案已生成!")
                break

            # Execute the ACTION step (Python code), if any was produced.
            code = parsed["action_code"]
            if not code:
                obs = {"status": "error", "stderr": "未解析到 ACTION 或 ANSWER."}
            else:
                log.debug(f"🔧 执行代码 ({len(code)} chars)")
                log.debug(f"Code Content:\n{code}")
                obs = self._execute_table_code(state, code)
                execution_time += obs.get("exec_time_sec", 0.0)

            # Record the OBSERVATION step.
            react_trace.append({
                "step": len(react_trace) + 1,
                "type": "observation",
                "content": json.dumps(obs, ensure_ascii=False)
            })
            status_icon = "✅" if obs.get("status") == "success" else "❌"
            log.info(f"{status_icon} OBS: {truncate_for_log(obs.get('stdout', obs.get('stderr', '')))}")

            # Feed the observation back to the LLM as the next user turn.
            messages.append({
                "role": "user",
                "content": observation_to_message(obs) + "你的代码必须是完全自包含的,可以独立执行。不要依赖之前的上下文或代码片段。"
            })

        else:
            # for-else: the loop ran out of steps without a final answer.
            answer = {"error": "max_steps_reached"}

        return {
            "answer": answer,
            "react_trace": react_trace,
            "execution_time": execution_time,
        }

    def _execute_table_code(self, state: "MainState", code: str) -> Dict[str, Any]:
        """Execute one ACTION code snippet in a throwaway temp directory.

        Args:
            state: current workflow state object (supplies raw_table_paths)
            code: Python source to execute

        Returns:
            Dict[str, Any]: {status, stdout/stderr, exec_time_sec}
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            script_path = Path(tmp_dir) / "react_step.py"
            script_path.write_text(code, encoding="utf-8")

            try:
                # NOTE(review): the "output" subdirectory is not created here —
                # presumably safe_exec_code creates it; confirm.
                stdout, exec_time = safe_exec_code(
                    py_path=script_path,
                    output_path=str(Path(tmp_dir) / "output"),
                    input_path=state.get("raw_table_paths", []),
                )
                return {
                    "status": "success",
                    "stdout": stdout,
                    "exec_time_sec": round(exec_time, 2)
                }
            except Exception as e:
                log.error(f"💥 代码执行失败: {e}")
                return {"status": "error", "stderr": str(e)}

    def _track_token_usage(
        self,
        state: "MainState",
        messages: list,
        response: Any
    ) -> None:
        """Best-effort token-usage tracking; swallows all errors by design."""
        try:
            # NOTE(review): this imports workflow_engine.llm.text, while this
            # patch adds workflow_engine/llm_callers/text.py — verify the module
            # path, otherwise this silently no-ops via the except below.
            from workflow_engine.llm.text import extract_token_usage_from_response
            token_usage = extract_token_usage_from_response(response)
            if token_usage:
                log.info(f"Token 使用: {token_usage}")

            # NOTE(review): llm_tracker is called here with keyword arguments,
            # but elsewhere it is awaited with a single messages list — the two
            # call signatures look inconsistent; confirm the tracker's API.
            if hasattr(state, "llm_tracker") and state.llm_tracker:
                state.llm_tracker(
                    model=self.agent.model_name,
                    messages=[{"role": msg.get("role", ""), "content": msg.get("content", "")} for msg in messages],
                    response=response,
                    token_usage=token_usage,
                    temperature=self.agent.temperature
                )
        except Exception as e:
            log.debug(f"Token 追踪失败: {e}")


# ===== Register the custom strategy with the factory (import side effect) =====
StrategyFactory.register("table_react", TableReActStrategy)
log.info("✓ TableReActStrategy 已自动注册")


def create_table_react_agent(name: str, model_name: str, max_retries: int = 3, **kwargs):
    """Create an Agent that uses TableReActStrategy.

    This strategy is dedicated to table-processing tasks where ACTION steps are
    executed as Python code.

    Args:
        name: agent name
        model_name: LLM model name
        max_retries: maximum ReAct loop iterations
        **kwargs: extra config (temperature, max_tokens, parser_type,
            chat_api_url, tool_mode, ignore_history)

    Returns:
        An agent instance produced by workflow_engine.agentroles.create_agent.
    """
    from workflow_engine.agentroles import create_agent

    config = TableReactConfig(
        model_name=model_name,
        max_retries=max_retries,
        temperature=kwargs.get("temperature", 0.1),
        max_tokens=kwargs.get("max_tokens", 20480),
        parser_type=kwargs.get("parser_type", "json"),
        chat_api_url=kwargs.get("chat_api_url"),
        tool_mode=kwargs.get("tool_mode", "auto"),
        ignore_history=kwargs.get("ignore_history", True),
    )
    return create_agent(name, config=config)


if __name__ == "__main__":
    print("已注册的策略:", list(StrategyFactory._strategies.keys()))