diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..227449ad --- /dev/null +++ b/.gitattributes @@ -0,0 +1,38 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +# Audio files - uncompressed +*.pcm filter=lfs diff=lfs merge=lfs -text +*.sam filter=lfs diff=lfs merge=lfs -text +*.raw filter=lfs diff=lfs merge=lfs -text +# Audio files - compressed +*.aac filter=lfs diff=lfs merge=lfs -text +*.flac filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.ogg filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index cd4efcd3..ed84d2a1 100644 --- a/.gitignore +++ b/.gitignore @@ -160,6 +160,8 @@ tutorial/example_deep_finance/scripts/* flash_attn-2.8.*.whl 
tutorial/example_deep_finance/prepare_data/* tutorial/example_deep_finance/judge/analytical_sufficiency/* +tutorial/example_deep_finance/output_report/* +dataset_gsm8k/* .dockerignore benchmark_datasets diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index ea951d5a..685798f3 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -83,7 +83,10 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str openjudge_graders = [ "presentation_quality", "grounding", - "planning" + "planning", + "audit", + "traceability", + "cgcv" ] for grader_name in openjudge_graders: diff --git a/tutorial/example_deep_finance/deep_finance.md b/tutorial/example_deep_finance/deep_finance.md index 1ac6d0c0..33820bff 100644 --- a/tutorial/example_deep_finance/deep_finance.md +++ b/tutorial/example_deep_finance/deep_finance.md @@ -1 +1,358 @@ -# deep_finance \ No newline at end of file +# DeepFinance: 通过强化学习训练金融深度研究 Agent + +## 概述 + +DeepFinance 是基于 AgentJet 框架构建的金融深度研究 Agent 训练方案。其核心目标是:通过 GRPO 强化学习,训练 LLM 自主调用金融工具、收集多源数据、进行交叉验证,并最终生成结构化、有据可查的投资研究报告。 + +与传统 SFT 微调不同,DeepFinance 不依赖人工标注的「标准回答」来监督训练,而是设计了一套 **多维度奖励体系** 作为 RL 训练信号——让模型在「写报告」的过程中自行探索最优策略,并通过 5 个正交维度的评分反馈来持续改进。 + +**训练闭环**: + +```plain +金融问题 → Agent 调用工具收集数据 → 生成研究报告 → 多维度 Judge 评分 → GRPO 策略更新 → 下一轮生成 +``` + +------ + +## Pipeline + +整个训练流水线由 4 个核心模块组成: + +| 模块 | 文件 | 职责 | +| ------------ | ---------------------------------- | --------------------------------------------------- | +| **Reader** | `deep_finance_reader.py` | 加载 JSON 训练数据,组装 System Prompt + User Query | +| **Workflow** | `deep_finance.py` | 定义 ReAct Agent 的多轮交互逻辑,维护对话历史 | +| **Judge** | `deep_finance_judge.py` + `judge/` | 多维度奖励评分(核心创新) | +| **配置** | `deep_finance.yaml` / `*.sh` | 训练参数、奖励权重、环境配置 | + +```plain +┌─────────────────────────────────────────────────────────────┐ +│ AgentJet 训练框架 │ +│ │ +│ ┌──────────────┐ 
┌──────────────────────┐ │ +│ │ DeepFinance │ │ ExampleDeepResearch │ │ +│ │ Reader │───>│ Protocol (Workflow) │ │ +│ │ 数据加载 + │ │ ReAct Agent 多轮交互 │ │ +│ │ Prompt 组装 │ └──────────┬───────────┘ │ +│ └──────────────┘ │ │ +│ v │ +│ ┌────────────────────────┐ │ +│ │ EnvService (FinWorld) │ │ +│ │ 19 个金融工具 + MCP │ │ +│ │ MongoDB 缓存加速 │ │ +│ └────────────┬───────────┘ │ +│ │ │ +│ v │ +│ ┌────────────────────────┐ │ +│ │ DeepFinanceJudge │ │ +│ │ 多 维 Reward 评分 │ │ +│ │ (基于 OpenJudge) │ │ +│ └────────────┬───────────┘ │ +│ │ │ +│ v │ +│ ┌────────────────────────┐ │ +│ │ GRPO Trainer (verl) │ │ +│ │ 多机多卡 Ray 集群 │ │ +│ └────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +------ + +## Workflow设计 + +### 两阶段深度研究流程 + +Agent 的 System Prompt(`prompt/finance_analyst_prompt.md`)要求遵循两阶段研究方法: + +**第一阶段:先大纲后调研** + +1. 理解用户问题类型(个股分析/行业研究/事件解读/宏观分析/股票检索) +2. **先输出研究大纲**(一级/二级标题 + 每节的 Key Questions),此阶段不调用工具 +3. 按大纲逐段调研,每轮调用工具后做小结 + +**第二阶段:深度分析与报告生成** + +1. 当数据充分后,基于真实数据生成 Markdown 格式研究报告 +2. 写作中发现证据不足时允许追加 1-2 轮工具调用补充取证 +3. 
报告末尾添加 `[TASK_COMPLETED]` 标记 + +### 引用规范 + +Agent 被要求使用学术论文风格的引用标注: + +- 所有关键事实句句末必须添加引用编号 `[n]` +- 报告末尾必须包含 `## References` 小节 +- 引用必须可追溯到实际工具返回的数据,禁止伪造 + +------ + +## 工具体系 + +DeepFinance 集成了 **19 个金融工具**,通过 MCP(Model Context Protocol)协议与 EnvService 交互,覆盖金融研究的完整数据需求。 + +| 类别 | 工具 | 功能 | +| ------------------ | ----------------------- | ----------------------------------- | +| **实体与计算** | `extract_entities_code` | 从自然语言中提取金融实体并查找代码 | +| | `history_calculate` | A股历史股价分析(支持自然语言提问) | +| **通用能力** | `dashscope_search` | 互联网搜索 | +| | `execute_code` | Python 代码执行 | +| | `execute_shell` | Shell 命令执行 | +| **同花顺专项数据** | `crawl_ths_company` | 上市公司基本资料 | +| | `crawl_ths_holder` | 股东研究信息 | +| | `crawl_ths_operate` | 经营分析信息 | +| | `crawl_ths_finance` | 财务分析信息 | +| | `crawl_ths_worth` | 盈利预测信息 | +| | `crawl_ths_news` | 新闻公告信息 | +| | `crawl_ths_concept` | 概念题材信息 | +| | `crawl_ths_equity` | 股本结构信息 | +| | `crawl_ths_capital` | 资本运作信息 | +| | `crawl_ths_position` | 主力持仓信息 | +| | `crawl_ths_bonus` | 分红融资信息 | +| | `crawl_ths_event` | 公司大事信息 | +| | `crawl_ths_field` | 行业对比信息 | + +工具调用规范: + +- 每次最多调用 **3 个工具**,采用多轮次渐进式调研 +- Agent 必须先搜索确认信息(如股票代码),再进行深度查询 +- 每轮工具调用后先做小结,再决定下一步调研方向 + +------ + +## 奖励设计(Reward Design) + +这是 DeepFinance 的核心创新。我们设计了 **5 个正交维度** 的评分器(Grader),通过可配置的权重加权融合为最终 reward,并额外引入工具调用惩罚机制。 + +### 总体公式 + +```plain +final_reward = Σ(w_i × grader_i_score) + tool_penalty +``` + +其中各 grader 权重归一化(`Σw_i = 1`),`tool_penalty` 为额外惩罚项。 + +### 5 个评分维度总览 + +| 维度 | 名称 | 评估对象 | 核心问题 | +| ---------------- | ------------------- | ------------------ | ------------------------------------------------ | +| **分析充分性** | RM Gallery | 报告整体质量 | 分析是否充分?逻辑是否合理? | +| **呈现质量** | PresentationQuality | 报告排版与结构 | 读者体验好不好?信息是否易获取? | +| **引用规范性** | Grounding | 引用的覆盖与真实性 | 关键事实是否都有引用?引用是否真实? | +| **证据溯源** | EBTU | 原子断言的证据锚定 | 每个数字/事实能否追溯到工具返回的原始数据? | +| **引用逻辑审计** | Audit | 引用的逻辑蕴含关系 | 引用是否真正支撑了对应的陈述?有没有夸大或捏造? 
| + +默认权重配置(可在 shell 脚本中调整): + +```bash +RM_WEIGHT=0.5 # 分析充分性 +PRESENTATION_QUALITY_WEIGHT=0.2 # 呈现质量 +GROUNDING_WEIGHT=0.1 # 引用规范性 +EBTU_WEIGHT=0.2 # 证据溯源(可选启用) +AUDIT_WEIGHT=0.0 # 引用逻辑审计(可选启用) +``` + +------ + +### 1) 分析充分性(RM Gallery) + +**目标**:评估报告的分析深度、覆盖面和逻辑性——回答「分析得好不好」。 + +**机制**:使用 `finance_composition` 评估器,通过独立的 Judge LLM( `qwen-max`)对生成报告与参考答案进行对比评估。 + +**评估维度(按金融 domain 分域)**: + +- 分析深度:对核心问题的挖掘是否足够深入 +- 覆盖面:是否覆盖了问题涉及的多个分析维度(基本面、财务、估值、行业、新闻等) +- 逻辑性:分析推理链条是否完整、结论是否有据可依 + +**输入输出**: + +- 输入:用户 Query + Agent 生成的报告 + 参考答案 +- 输出:`[0, 1]` 归一化分数 + +------ + +### 2) 呈现质量(Presentation Quality) + +**目标**:评估报告的用户体验与信息架构——回答「写得好不好看、好不好读」。 + +**严格不评估**:事实真伪、引用准确性、内容深度(这些由其他 Grader 负责)。 + +**8 项子指标(1/3/5 分制)**: + +| 分类 | 指标 | 5分标准 | +| -------------------------- | --------------- | ------------------------------------------------ | +| **Scan 可扫描性** | A1 结论先行 | 开头有独立摘要/TL;DR,读者无需滚动即可获取主结论 | +| | A2 结构导航 | 层级分明(H1/H2/H3),长文有清晰小标题路标 | +| | A3 视觉重点 | 精准使用加粗/斜体强调核心洞察,信噪比高 | +| **Structuring 信息结构化** | B1 密集信息解构 | 复杂数据用表格/嵌套列表呈现,一目了然 | +| | B2 对比对齐 | 方案A vs B / 历史 vs 现状使用表格,维度横向可比 | +| | B3 一致性与渲染 | 格式统一,Markdown 渲染完美 | +| **Editorial 编辑清晰度** | C1 论证链可视化 | 逻辑链条可视(主张→证据→结论),引用锚点清晰 | +| | C2 风险与行动 | 独立板块列出风险/局限性及下一步建议 | + +**评分计算**: + +```plain +score = Σ(8项得分) / 40 # 归一化到 [0, 1] +``` + +**反刷分机制**:空表格、无意义重复列表、为格式而格式 → 直接判 1 分。 + +------ + +### 3) 引用规范性(Grounding) + +**目标**:评估报告的引用覆盖率和引用真实性——回答「关键事实都有出处吗?引用是真的吗?」 + +**评估流程**: + +1. 从对话轨迹中提取 User Query、Evidence(工具调用与返回)、最终报告 +2. LLM 审计员识别报告中的所有「关键事实句」(含数字/日期/财务指标/确定性陈述) +3. 检查每个关键事实句句末是否有引用标记 `[n]` +4. 检查引用是否在 References 中有合法条目(有效 URL 或完整的 no-url 记录) +5. 
检查引用内容与 Evidence 是否一致(检测虚假引用) + +**输出字段**: + +- `total_key_facts`:关键事实句总数 +- `cited_key_facts`:句末有引用的关键事实句数 +- `fake_count`:引用内容与证据明显矛盾的数量 +- `missing_count`:缺少引用的关键事实句数 +- `invalid_reference_nums`:不合规的引用编号 + +**评分计算**: + +```plain +citation_coverage = cited_key_facts / total_key_facts # 引用覆盖率 +grounding_score = 1 - fake_count / cited_key_facts # 引用真实性 +final_score = 0.5 × coverage + 0.5 × grounding # 综合分数 +``` + +------ + +### 4) 证据溯源(EBTU - Evidence-Backed Trace Units) + +**目标**:对报告中的每个「原子断言」做证据锚定审计——回答「每个数字、每个事实,能否追溯到工具返回的原始数据?」 + +**核心理念:证据优先(Evidence-first)**。审计官必须先给出证据锚点(step + quote),再下裁决,严禁先下结论再找证据。 + +**审计流程**: + +1. 从报告中提取所有原子断言(Trace Units),标记类型(numeric/temporal/event/comparison/causal 等) +2. 标记硬度:`hard`(确定性事实) / `soft`(明确标注为推测/假设) +3. 对每个断言在 Evidence 中寻找锚点(anchors),要求: + +- - 精确到 step 编号和原文引用(quote ≤ 120 字) + - 数字/日期必须能在 Evidence 原文中找到对应 + +1. 给出裁决(verdict): + +| Verdict | 含义 | +| ---------------- | ----------------------------------------- | +| `supported` | 锚点直接支持断言 | +| `contradicted` | 锚点与断言明确冲突 | +| `no_evidence` | Evidence 中找不到支撑,且断言是确定性表述 | +| `speculative_ok` | 断言明确为推测/假设,未伪装成事实 | +| `unclear` | Evidence 相关但不足以支持或反驳 | + +1. 标记问题类型(issue):`entity_mismatch` / `time_mismatch` / `value_mismatch` / `scope_mismatch` / `logic_leap` / `over_precision` / `missing_anchor` + +**评分计算**(确定性打分,由 Python 代码计算,非 LLM 输出): + +```plain +base = (supported - 1.4×contradicted - 0.9×no_evidence - 0.4×unclear) / hard_units +misattrib_factor = max(0, 1 - 0.7 × misattrib_rate) # 错误归因惩罚 +selection_factor = min(1, extracted_units / expected) # 覆盖率因子 +cov_factor = 0.65 + 0.35 × digit_coverage # 数字/日期覆盖 +score = base × misattrib_factor × selection_factor × cov_factor +``` + +关键设计:LLM 只负责结构化输出(断言提取 + 锚点标注 + 裁决),分数完全由代码确定性计算,避免 LLM 自评分的不稳定性。 + +------ + +### 工具调用惩罚 + +在加权融合分数之外,额外施加工具调用惩罚,鼓励 Agent 积极使用工具收集数据: + +| 工具调用次数 | 惩罚 | +| ------------ | ------------- | +| 0 次 | -1.0 | +| 1-2 次 | -0.5 | +| ≥3 次 | 0.0(无惩罚) | + +------ + +## Quick Start + +### 环境准备 + +1. 
安装 AgentJet 及依赖 + +```bash +cd /path/to/AgentJet +bash install.sh # TODO:把这部分缩减到一个install:https://yuque.alibaba-inc.com/bayotg/wxz7sb/qdesuu33621x2yhi +``` + +1. 配置 `.env` 文件(API 密钥、模型路径、数据路径等): + +```bash +# .env 示例 +MODEL_PATH=/path/to/Qwen3-8B +TRAIN_DATA_PATH=/path/to/train.json +VAL_DATA_PATH=/path/to/val.json +TRAIN_REF_ANS_PATH=/path/to/train_ref_answers.json +VAL_REF_ANS_PATH=/path/to/val_ref_answers.json +CKPT_SAVE_PATH=/path/to/checkpoints +OPENJUDGE_API_KEY=your_api_key +RM_API_KEY=your_api_key +``` + +1. 启动 EnvService(金融工具服务) + +### 单机调试模式 + +```bash +bash tutorial/example_deep_finance/deep_finance_single.sh +``` + +该脚本以 `--backbone="debug"` 模式运行,适合验证工作流和调试。 + +### 多机训练模式 + +```bash +# 在 PAI-DLC 或多机环境中提交 +bash tutorial/example_deep_finance/deep_finance.sh +``` + +该脚本会: + +1. 从 YAML 模板动态生成配置文件 +2. 在 Master 节点启动 Ray Head + 训练任务 +3. Worker 节点自动加入 Ray 集群 + +### 关键参数说明 + +| 参数 | 默认值 | 说明 | +| ----------------------------- | ------ | ------------------------------------- | +| `NUM_REPEAT` | 4 | Group size,每个 query rollout 的次数 | +| `NUM_STEPS` | 6 | 每个样本的最大交互轮数 | +| `TRAIN_BATCH_SIZE` | 32 | 训练 batch size | +| `RM_WEIGHT` | 0.5 | 分析充分性权重 | +| `PRESENTATION_QUALITY_WEIGHT` | 0.25 | 呈现质量权重 | +| `GROUNDING_WEIGHT` | 0.25 | 引用规范性权重 | +| `EBTU_WEIGHT` | 0.0 | 证据溯源权重(可选启用) | +| `AUDIT_WEIGHT` | 0.0 | 引用逻辑审计权重(可选启用) | + +------ + +## 实验结果 + + +![img](https://intranetproxy.alipay.com/skylark/lark/0/2026/png/107756372/1771843906200-9dd35ac4-f71e-40dc-b130-f03e3e6bae6a.png) + +![img](https://intranetproxy.alipay.com/skylark/lark/0/2026/png/107756372/1771843940824-4e3637d7-a16e-4994-8878-242effc2c0d7.png)![img](https://intranetproxy.alipay.com/skylark/lark/0/2026/png/107756372/1771843950142-09def779-5521-41f0-a457-a7715a819cc7.png) + + diff --git a/tutorial/example_deep_finance/deep_finance.py b/tutorial/example_deep_finance/deep_finance.py index 470e6225..baffb0b3 100644 --- a/tutorial/example_deep_finance/deep_finance.py +++ 
b/tutorial/example_deep_finance/deep_finance.py @@ -9,7 +9,7 @@ # 创建信号量,允许同时12个线程运行 -sem = threading.Semaphore(30) +sem = threading.Semaphore(60) class ExampleDeepResearchProtocol(Workflow): @@ -125,9 +125,9 @@ async def execute( if info: if 'tool_stats' in info: latest_tool_stats = info['tool_stats'] - if latest_tool_stats.get('total_calls', 0) > 0: - logger.info(f"步骤 {step + 1} 工具统计: 调用={latest_tool_stats.get('total_calls', 0)}, " - f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") + # if latest_tool_stats.get('total_calls', 0) > 0: + # logger.info(f"步骤 {step + 1} 工具统计: 调用={latest_tool_stats.get('total_calls', 0)}, " + # f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") if 'reward_stats' in info: latest_reward_stats = info['reward_stats'] # 累加工具调用时间 diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index bee02ac2..f6121655 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -15,6 +15,10 @@ JUDGE_CONCURRENCY=10 RM_WEIGHT=0.5 PRESENTATION_QUALITY_WEIGHT=0.25 GROUNDING_WEIGHT=0.25 +CGCV_WEIGHT=0.0 # 不使用 CGCV,设为 0 +AUDIT_WEIGHT=0.0 # 不使用 Audit,设为 0 +TRACEABILITY_WEIGHT=0.0 # 不使用 Traceability,设为 0 +EBTU_WEIGHT=0.0 # 不使用 EBTU,设为 0 # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -22,6 +26,9 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 +# Env Service URL 配置 +ENV_SERVICE_URL="http://127.0.0.1:8080" # 环境服务地址 + # 主目录(需要更改) export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" @@ -57,6 +64,10 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ + -e "s|{{CGCV_WEIGHT}}|${CGCV_WEIGHT}|g" \ + -e "s|{{AUDIT_WEIGHT}}|${AUDIT_WEIGHT}|g" \ + -e "s|{{TRACEABILITY_WEIGHT}}|${TRACEABILITY_WEIGHT}|g" \ + -e 
"s|{{EBTU_WEIGHT}}|${EBTU_WEIGHT}|g" \ -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ @@ -68,10 +79,11 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ -e "s|{{CKPT_SAVE_PATH}}|${CKPT_SAVE_PATH}|g" \ + -e "s|{{ENV_SERVICE_URL}}|${ENV_SERVICE_URL}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" +echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, CGCV=${CGCV_WEIGHT}, Audit=${AUDIT_WEIGHT}, Traceability=${TRACEABILITY_WEIGHT}, EBTU=${EBTU_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" #=============================================================================== # 3. 
环境配置 diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml index 166381da..e5de33da 100644 --- a/tutorial/example_deep_finance/deep_finance.yaml +++ b/tutorial/example_deep_finance/deep_finance.yaml @@ -37,7 +37,7 @@ ajet: max_env_worker: 64 # 增加环境并行数 max_num_seqs: 64 # 增加VLLM并发序列数 max_response_length_in_one_turn: 8000 - max_model_len: 50000 + max_model_len: 40960 agent_madness_reward: 0.0 compute_madness_checklist: None multi_turn: diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index ce859cbc..733cfb00 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -15,7 +15,7 @@ from openjudge.models.openai_chat_model import OpenAIChatModel from openjudge.runner.grading_runner import GraderConfig, GradingRunner -from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader +from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader, CGCVGrader, AuditGrader, TraceabilityRewardGrader, EBTUTraceabilityGrader @@ -103,7 +103,11 @@ def _setup_weights(self): self.w = { "rm": getattr(cfg, "rm_weight", 1.0) if cfg else 1.0, # RM Gallery 权重 "presentation_quality": getattr(cfg, "presentation_quality_weight", 0.25) if cfg else 0.25, - "grounding": getattr(cfg, "grounding_weight", 0.25) if cfg else 0.25, + "grounding": getattr(cfg, "grounding_weight", 0.0) if cfg else 0.0, # 引用规范性评估 + "cgcv": getattr(cfg, "cgcv_weight", 0.25) if cfg else 0.25, # Citation-Grounded Claim Verification + "audit": getattr(cfg, "audit_weight", 0.0) if cfg else 0.0, # Audit Grader: audit reward 引用逻辑审计 + "traceability": getattr(cfg, "traceability_weight", 0.0) if cfg else 0.0, # 可追溯性/可核验性审计 (TVR) + "ebtu": getattr(cfg, "ebtu_weight", 0.0) if cfg else 0.0, # Audit Grader: audit reward EBTU证据优先可追溯性审计 } # 归一化(注意:action_loop 是惩罚项,不参与归一化;rm 需要参与归一化) @@ 
-256,6 +260,26 @@ def extract_report_content(data: Dict) -> str: grader=GroundingGrader(model=model), mapper=lambda data: {"traj": data}, ), + # CGCV: Citation-Grounded Claim Verification - 引用锤定的断言验证 + "cgcv": GraderConfig( + grader=CGCVGrader(model=model), + mapper=lambda data: {"traj": data}, + ), + # Audit: 引用逻辑审计 - 验证引用是否严格符合逻辑蕴含原则 + "audit": GraderConfig( + grader=AuditGrader(model=model), + mapper=lambda data: {"traj": data}, + ), + # Traceability: 可追溯性/可核验性审计 - 验证报告断言是否有证据锚点支撑 + "traceability": GraderConfig( + grader=TraceabilityRewardGrader(model=model), + mapper=lambda data: {"traj": data}, + ), + # Audit Grader: audit reward EBTU证据优先可追溯性审计 - Evidence-Backed Trace Units + "ebtu": GraderConfig( + grader=EBTUTraceabilityGrader(model=model), + mapper=lambda data: {"traj": data}, + ), } def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> Tuple[float, bool]: diff --git a/tutorial/example_deep_finance/deep_finance_single.sh b/tutorial/example_deep_finance/deep_finance_single.sh index e794dff0..67de294d 100644 --- a/tutorial/example_deep_finance/deep_finance_single.sh +++ b/tutorial/example_deep_finance/deep_finance_single.sh @@ -15,6 +15,10 @@ JUDGE_CONCURRENCY=10 RM_WEIGHT=0.5 PRESENTATION_QUALITY_WEIGHT=0.25 GROUNDING_WEIGHT=0.25 +CGCV_WEIGHT=0.0 # 不使用 CGCV,设为 0 +AUDIT_WEIGHT=0.0 # 不使用 Audit,设为 0 +TRACEABILITY_WEIGHT=0.0 # 不使用 Traceability,设为 0 +EBTU_WEIGHT=0.0 # 不使用 EBTU,设为 0 # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -22,10 +26,19 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 +# Env Service URL 配置 +ENV_SERVICE_URL="http://127.0.0.1:8080" # 环境服务地址 + # 主目录(需要更改) export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" -NNODES=${WORLD_SIZE} +# 单机调试配置(默认值) +NNODES=${WORLD_SIZE:-1} +GPUS_PER_NODE=8 +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" 
+TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" +mkdir -p ${LOG_DIR} # 涉密的配置(API_KEY以及模型、数据位置)从.env读取 cd ${AJET_ROOT} @@ -42,6 +55,9 @@ else echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" fi +export MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-8B" + + #=============================================================================== # 2. 动态生成配置文件 (从yaml template生成yaml) #=============================================================================== @@ -57,6 +73,10 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ + -e "s|{{CGCV_WEIGHT}}|${CGCV_WEIGHT}|g" \ + -e "s|{{AUDIT_WEIGHT}}|${AUDIT_WEIGHT}|g" \ + -e "s|{{TRACEABILITY_WEIGHT}}|${TRACEABILITY_WEIGHT}|g" \ + -e "s|{{EBTU_WEIGHT}}|${EBTU_WEIGHT}|g" \ -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ @@ -68,10 +88,11 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ -e "s|{{CKPT_SAVE_PATH}}|${CKPT_SAVE_PATH}|g" \ + -e "s|{{ENV_SERVICE_URL}}|${ENV_SERVICE_URL}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" +echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, CGCV=${CGCV_WEIGHT}, Audit=${AUDIT_WEIGHT}, Traceability=${TRACEABILITY_WEIGHT}, EBTU=${EBTU_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" #=============================================================================== @@ -115,15 +136,16 @@ export RAY_CLUSTER_MODE="multi_node" 
#=============================================================================== # 6. 主流程 #=============================================================================== -log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" -mkdir -p ${LOG_DIR} -mkdir -p $(dirname ${CONFIG_FILE}) +log "单机调试模式: NNODES=${NNODES}, GPUS_PER_NODE=${GPUS_PER_NODE}" #=============================================================================== # 6.1 Master 节点启动流程 #=============================================================================== # 启动训练任务(最核心) +# 请注意只有单节点需要--with-ray 多节点应该删除 python ajet/launcher.py \ --conf ${CONFIG_FILE} \ + --with-deepfinance \ + --with-ray \ --backbone="debug" \ 2>&1 | tee ${TRAIN_LOG} diff --git a/tutorial/example_deep_finance/judge/__init__.py b/tutorial/example_deep_finance/judge/__init__.py index 75c8ceff..235247f9 100644 --- a/tutorial/example_deep_finance/judge/__init__.py +++ b/tutorial/example_deep_finance/judge/__init__.py @@ -1,6 +1,10 @@ # 使得可以通过 from judge import PresentationQualityGrader 直接引用 from .grounding.grader import GroundingGrader from .presentation_quality.grader import PresentationQualityGrader +from .cgcv.grader import CGCVGrader +from .audit.grader import AuditGrader +from .traceability.grader import TraceabilityRewardGrader +from .ebtu.grader import EBTUTraceabilityGrader # from .research_depth.grader import ResearchDepthGrader # from .research_breadth.grader import ResearchBreadthGrader @@ -8,4 +12,4 @@ # from .grounding.grader import GroundingGrader # from .research_breadth.grader import ResearchBreadthGrader # __all__ = ["PresentationQualityGrader", "GroundingGrader", "ResearchDepthGrader", "ResearchBreadthGrader"] -__all__ = ["PresentationQualityGrader", "GroundingGrader"] +__all__ = ["PresentationQualityGrader", "GroundingGrader", "CGCVGrader", "AuditGrader", "TraceabilityRewardGrader", "EBTUTraceabilityGrader"] diff --git a/tutorial/example_deep_finance/judge/audit/__init__.py 
b/tutorial/example_deep_finance/judge/audit/__init__.py new file mode 100644 index 00000000..7e4d05c3 --- /dev/null +++ b/tutorial/example_deep_finance/judge/audit/__init__.py @@ -0,0 +1,4 @@ +"""Grounding Grader - 引用逻辑审计""" +from .grader import AuditGrader + +__all__ = ["AuditGrader"] \ No newline at end of file diff --git a/tutorial/example_deep_finance/judge/audit/grader.py b/tutorial/example_deep_finance/judge/audit/grader.py new file mode 100644 index 00000000..3b8a9806 --- /dev/null +++ b/tutorial/example_deep_finance/judge/audit/grader.py @@ -0,0 +1,215 @@ +"""Audit Grader - 引用逻辑审计 (OpenJudge logic version)""" +from __future__ import annotations + +import os +from typing import Any, Dict, List, Tuple + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +try: + from openjudge.models import OpenAIChatModel +except Exception: + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import CITATION_INTEGRITY_PROMPT_COT, CITATION_INTEGRITY_USER_TEMPLATE +from .json_utils import strict_load_json, validate_integrity_shape, construct_reward_prompt + + +class AuditGrader(BaseGrader): + """ + 引用逻辑审计 Grader + + - 输入:traj (完整对话轨迹) + - 输出:GraderScore(score, reason) + - score: integrity_score (Supported / Total) + - reason: 审计摘要,包括错误分布和定性总结 + """ + + def __init__( + self, + model: OpenAIChatModel, + name: str = "citation_integrity", + **kwargs: Any, + ): + super().__init__(name=name, **kwargs) + self.model = model + + @staticmethod + def create_default_model( + model_name: str, + api_key: str | None = None, + base_url: str | None = None, + deterministic: bool = True, + enable_thinking: bool = False, + seed: int = 42, + ) -> OpenAIChatModel: + api_key = api_key or os.getenv("OPENAI_API_KEY") + base_url = base_url or os.getenv("OPENAI_BASE_URL") + + extra_body: Dict[str, Any] = {} + if deterministic: + extra_body.update( + { + "temperature": 0.0, + "top_p": 1.0, + "seed": seed, + } + ) + if 
enable_thinking is False: + extra_body["enable_thinking"] = False + + kwargs: Dict[str, Any] = {"model": model_name} + if api_key: + kwargs["api_key"] = api_key + if base_url: + kwargs["base_url"] = base_url + if extra_body: + kwargs["extra_body"] = extra_body + + return OpenAIChatModel(**kwargs) + + async def _aevaluate( + self, + traj: Any, + **_: Any, + ) -> GraderScore: + """ + 入口:必须喂 traj(完整对话轨迹) + + Args: + traj: 对话轨迹,支持以下格式: + - [{"role": ..., "content": ...}, ...] 直接消息列表 + - {"messages": [...]} 包含 messages 字段的 dict + - {"traj": [[...]]} 包含 traj 字段的 dict(双重嵌套) + + Returns: + GraderScore(name, score, reason) + """ + # 1. 提取 messages(兼容多种格式) + if isinstance(traj, dict): + if "traj" in traj: + # 支持 {"traj": [[...]]} 格式 + traj_list = traj["traj"] + if traj_list and isinstance(traj_list[0], list): + messages_list = traj_list[0] + else: + messages_list = traj_list + else: + messages_list = traj.get("messages", []) + elif isinstance(traj, list): + messages_list = traj + else: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: traj must be list or dict with 'messages'/'traj'", + ) + + if not messages_list: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty trajectory", + ) + + # 2. 构建 Prompt + # 使用新的 System Prompt 和 User Template + user_prompt = construct_reward_prompt(messages_list, CITATION_INTEGRITY_USER_TEMPLATE) + + messages = [ + {"role": "system", "content": CITATION_INTEGRITY_PROMPT_COT}, + {"role": "user", "content": user_prompt} + ] + + # 3. 模型推理 + try: + resp = await self.model.achat(messages) + raw_text = getattr(resp, "content", str(resp)) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + + # 4. 
JSON 解析与验证 + obj, jerr = strict_load_json(raw_text) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"ParseError: {jerr}; raw[:200]={snippet}", + ) + + # 使用新的验证逻辑 validate_integrity_shape + obj, serr = validate_integrity_shape(obj) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"SchemaError: {serr}; raw[:200]={snippet}", + ) + + # 5. 计算分数与生成理由 + score, reason = self._compute_scores(obj) + return GraderScore(name=self.name, score=score, reason=reason) + + def _compute_scores(self, obj: Dict[str, Any]) -> Tuple[float, str]: + """ + 基于 audit_trail 和 integrity_score 计算最终结果 + """ + # 直接获取模型计算的 integrity_score,若缺失则手动计算 + audit_trail = obj.get("audit_trail", []) + total_citations = len(audit_trail) + + # 统计各Verdict数量 + verdict_counts = { + "Supported": 0, + "Overstated": 0, + "Contradicted": 0, + "Hallucinated": 0, + "Irrelevant": 0 + } + + for item in audit_trail: + v = item.get("verdict", "Irrelevant") + if v in verdict_counts: + verdict_counts[v] += 1 + else: + verdict_counts["Irrelevant"] += 1 + + supported_count = verdict_counts["Supported"] + + # 优先使用模型输出的 score,如果有误则回退到手动计算 + # model_score = obj.get("integrity_score") + # if isinstance(model_score, (float, int)) and 0.0 <= model_score <= 1.0: + # final_score = float(model_score) + # else: + final_score = supported_count / total_citations if total_citations > 0 else 0.0 + + # 构建 Reason + # 格式: Score: 0.80 | Total: 10 | Supp: 8, Over: 1, Hallu: 1 | Summary: ... 
+ stats_parts = [] + for k, v in verdict_counts.items(): + if v > 0: + stats_parts.append(f"{k[:4]}:{v}") # 缩写 Verdict + + stats_str = ", ".join(stats_parts) + qualitative = obj.get("qualitative_summary", "No summary provided.") + + # 截取主要错误示例 (如果有) + errors = [x for x in audit_trail if x.get("verdict") != "Supported"] + error_msg = "" + if errors: + first_err = errors[0] + error_msg = f" | Example Error ([{first_err.get('citation_id')}]) {first_err.get('verdict')}: {first_err.get('logic_analysis')}" + + reason = ( + f"Score: {final_score:.2f} | Total: {total_citations} | {stats_str} | " + f"Summary: {qualitative}{error_msg}" + ) + + return round(final_score, 4), reason[:1000] \ No newline at end of file diff --git a/tutorial/example_deep_finance/judge/audit/json_utils.py b/tutorial/example_deep_finance/judge/audit/json_utils.py new file mode 100644 index 00000000..11e157ae --- /dev/null +++ b/tutorial/example_deep_finance/judge/audit/json_utils.py @@ -0,0 +1,262 @@ +"""JSON Utilities for Audit Grader""" +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Tuple + +_JSON_RE = re.compile(r"\{.*\}", re.DOTALL) + +def extract_first_json_object(text: str) -> str | None: + if not text: + return None + m = _JSON_RE.search(text.strip()) + if not m: + return None + return m.group(0) + + +def _repair_json(js: str) -> str: + """ + 尝试修复常见的JSON格式错误 + 1. 修复字符串中未转义的换行符 + 2. 修复trailing comma + 3. 修复缺少的逗号 + 4. 修复不完整的JSON(截断) + """ + # 1. 
替换字符串值中的未转义换行符 + # 这是最常见的问题:LLM在字符串中直接输出换行而非 \n + def escape_newlines_in_strings(s: str) -> str: + result = [] + in_string = False + escape_next = False + i = 0 + while i < len(s): + c = s[i] + if escape_next: + result.append(c) + escape_next = False + elif c == '\\': + result.append(c) + escape_next = True + elif c == '"': + result.append(c) + in_string = not in_string + elif in_string and c == '\n': + result.append('\\n') + elif in_string and c == '\r': + result.append('\\r') + elif in_string and c == '\t': + result.append('\\t') + else: + result.append(c) + i += 1 + return ''.join(result) + + js = escape_newlines_in_strings(js) + + # 2. 移除trailing comma: ",}" -> "}" 和 ",]" -> "]" + js = re.sub(r',\s*}', '}', js) + js = re.sub(r',\s*]', ']', js) + + # 3. 尝试修复截断的JSON - 补全缺失的括号 + # 统计括号数量 + open_braces = js.count('{') + close_braces = js.count('}') + open_brackets = js.count('[') + close_brackets = js.count(']') + + # 如果括号不匹配,尝试补全 + if open_braces > close_braces: + # 先关闭可能未闭合的字符串 + # 检查最后是否在字符串中 + in_string = False + escape_next = False + for c in js: + if escape_next: + escape_next = False + elif c == '\\': + escape_next = True + elif c == '"': + in_string = not in_string + if in_string: + js += '"' + + # 补全缺失的括号 + js += ']' * (open_brackets - close_brackets) + js += '}' * (open_braces - close_braces) + + return js + + +def strict_load_json(text: str) -> Tuple[Dict[str, Any] | None, str | None]: + js = extract_first_json_object(text) + if js is None: + return None, "No JSON object found" + + # 第一次尝试:直接解析 + try: + obj = json.loads(js) + if not isinstance(obj, dict): + return None, f"Root is not dict: {type(obj)}" + return obj, None + except json.JSONDecodeError: + pass # 继续尝试修复 + + # 第二次尝试:修复后解析 + try: + repaired = _repair_json(js) + obj = json.loads(repaired) + if not isinstance(obj, dict): + return None, f"Root is not dict: {type(obj)}" + return obj, None + except json.JSONDecodeError as e: + return None, f"JSONDecodeError: {str(e)}" + +def 
validate_integrity_shape(obj: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]: + """ + 验证 Evidence Logic Analyst 的输出结构 + Schema: + { + "audit_trail": [ + {"citation_id": int, "verdict": str, ...}, ... + ], + "qualitative_summary": str, + "integrity_score": float + } + """ + # 1. Check Top-level fields + required_fields = ["audit_trail", "qualitative_summary", "integrity_score"] + for f in required_fields: + if f not in obj: + return None, f"Missing field: {f}" + + # 2. Validate integrity_score + try: + score = float(obj["integrity_score"]) + if not (0.0 <= score <= 1.0): + # 容错:稍微越界归一化 + score = max(0.0, min(1.0, score)) + obj["integrity_score"] = score + except ValueError: + return None, "integrity_score must be a float" + + # 3. Validate audit_trail + if not isinstance(obj["audit_trail"], list): + return None, "audit_trail must be a list" + + valid_verdicts = {"Supported", "Overstated", "Contradicted", "Hallucinated", "Irrelevant"} + + for idx, item in enumerate(obj["audit_trail"]): + if not isinstance(item, dict): + return None, f"audit_trail[{idx}] is not a dict" + + # Check required item fields + if "citation_id" not in item: + return None, f"audit_trail[{idx}] missing 'citation_id'" + if "verdict" not in item: + return None, f"audit_trail[{idx}] missing 'verdict'" + + # Normalize verdict + v = str(item["verdict"]).strip() + # 简单的大小写兼容 + v_cap = v.capitalize() + if v not in valid_verdicts and v_cap in valid_verdicts: + item["verdict"] = v_cap + elif v not in valid_verdicts: + # 如果模型输出了奇奇怪怪的verdict,降级为Irrelevant或报错,这里选择报错以保证严谨 + return None, f"Invalid verdict '{v}' in item {idx}" + + return obj, None + + +# ============================================================================= +# Trajectory Helpers +# ============================================================================= + +def _extract_text_content(content) -> str: + if content is None: return "" + if isinstance(content, str): return content + if isinstance(content, list): + # Handle 
OpenAI multi-part content + parts = [] + for p in content: + if isinstance(p, dict) and p.get("type") == "text": + parts.append(p.get("text", "")) + elif isinstance(p, str): + parts.append(p) + return "\n".join(parts) + return str(content) + +def _strip_think(text: str) -> str: + return re.sub(r".*?\s*", "", text, flags=re.S).strip() + +def _strip_markdown_fences(text: str) -> str: + text = text.strip() + text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE) + text = re.sub(r'\n?```\s*$', '', text) + return text.strip() + +def _extract_tool_call_json(text: str) -> str: + # 尝试提取 ```json ... ``` + m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text) + if m: return m.group(1).strip() + # 简单的 fallback + if text.strip().startswith("[") and text.strip().endswith("]"): + return text.strip() + return "" + +def construct_reward_prompt(trajectory: List[Dict[str, Any]], template: str) -> str: + """ + 提取 User Query, Evidence (Tool Outputs), Final Report + """ + user_query = "" + evidence_parts = [] + final_report = "" + + # Helper to clean text + def clean(c): return _strip_think(_extract_text_content(c)) + + # 1. 
Identify components + # 倒序查找 Final Report (包含 References 或 TASK_COMPLETED 的 Assistant 消息) + for i in range(len(trajectory) - 1, -1, -1): + msg = trajectory[i] + if msg.get("role") == "assistant": + txt = clean(msg.get("content")) + # 宽松判定:通常最后的长文本是报告 + if "References" in txt or "[TASK_COMPLETED]" in txt or len(txt) > 600: + final_report = _strip_markdown_fences(txt) + break + + # 找不到显式报告时,取最后一条 Assistant + if not final_report and trajectory: + last = trajectory[-1] + if last.get("role") == "assistant": + final_report = _strip_markdown_fences(clean(last.get("content"))) + + for idx, msg in enumerate(trajectory): + role = msg.get("role") + content_raw = clean(msg.get("content")) + + # User Query: First user message + if role == "user" and not user_query: + user_query = content_raw + continue # 不要把 query 当作 evidence + + # Evidence: Tool calls and Tool outputs + if role == "assistant": + # Check for tool calls + tool_json = _extract_tool_call_json(content_raw) + if tool_json: + evidence_parts.append(f"--- Step {idx} Tool Call ---\n{tool_json}") + + elif role == "tool": + evidence_parts.append(f"--- Step {idx} Tool Result ---\n{content_raw}") + + evidence_text = "\n\n".join(evidence_parts) + + return template.format( + user_query=user_query, + evidence_text=evidence_text, + final_report=final_report + ) \ No newline at end of file diff --git a/tutorial/example_deep_finance/judge/audit/prompt.py b/tutorial/example_deep_finance/judge/audit/prompt.py new file mode 100644 index 00000000..f045b6f4 --- /dev/null +++ b/tutorial/example_deep_finance/judge/audit/prompt.py @@ -0,0 +1,67 @@ +"""Audit Grader Prompt - 引用逻辑审计 (Logic Analyst)""" + +# ============================================================================= +# System Prompt (Evidence Logic Analyst) +# ============================================================================= + +CITATION_INTEGRITY_PROMPT_COT = """ +你是一位 **"证据逻辑分析师" (Evidence Logic Analyst)**。你的任务是审计 AI 研究报告中的引用是否严格符合"逻辑蕴含 (Logical Entailment)"原则。 
+ +## 核心任务 +不要预设结论。你必须像法官判案一样,先罗列证据,再进行逻辑推导,最后下达判决。 +你需要对报告中出现的每一个引用标记 `[n]` 进行独立的"三步验证"。 + +## 验证逻辑 (必须严格遵守的思维顺序) + +1. **提取 (Extract)**: 锁定报告中由 `[n]` 支撑的陈述片段 (Claim)。 +2. **溯源 (Trace)**: 在 Reference 列表中找到 `[n]` 对应的原始文本,并摘录出核心证据句 (Source Quote)。 + - 注意:Reference 列表可能包含 URL 或 工具调用信息,你需要根据这些信息去上文提供的 **Evidence** 中寻找对应的内容。 +3. **比对 (Compare)**: 分析 Claim 是否被 Source Quote 严格支撑。 + * Check: 数字/事实是否一致? + * Check: 语气是否一致(有没有把"可能"改成"确定")? + * Check: 因果关系是否存在? + +## 判决标准 (Verdict Criteria) +* **Supported**: 证据充分,逻辑闭环。允许合理的概括,但禁止添加细节。 +* **Overstated**: 夸大其词。证据只说了 A,报告却写成了 A+ (如:去掉了"据报道"、"约"等限定词,或强加了因果关系)。 +* **Contradicted**: 事实冲突。报告内容与证据相反。 +* **Hallucinated**: 无中生有。报告中的关键细节(人名、数据、事件)在证据中找不到,或者引用编号在 References 中不存在。 +* **Irrelevant**: 引用无效。证据内容真实,但与报告所述主题无关。 + +## 输出格式 (JSON Only) +只输出 JSON,严禁输出 Markdown 或其他文字。字段顺序代表你的思考顺序,**不可乱序**: + +{ + "audit_trail": [ + { + "citation_id": 1, + "claim_excerpt": "报告中声称的片段...", + "evidence_quote": "从Evidence中摘录的原话...", + "logic_analysis": "分析:证据说的是X,报告写的是Y。二者是否一致?有没有夸大?(简短分析)", + "verdict": "Supported" | "Overstated" | "Contradicted" | "Hallucinated" | "Irrelevant", + "correction": "如果非Supported,基于证据的正确表述应该是..." + }, + ... 
+ ], + "qualitative_summary": "基于上述审计,用一句话总结该报告的引用可信度(如:引用大多准确,但在具体数据上存在夸大嫌疑)。", + "integrity_score": <0.0 到 1.0 的浮点数,计算公式:Supported数量 / 总引用数> +} +""" + +# ============================================================================= +# User Prompt Template +# ============================================================================= + +CITATION_INTEGRITY_USER_TEMPLATE = """请作为逻辑分析师,对以下 AI 研究报告进行引用审计。 + +### User Query +{user_query} + +### Evidence (工具调用与返回结果) +{evidence_text} + +### AI Report (待审计报告) +{final_report} + +请严格遵守 JSON 输出格式,对报告中的所有 [n] 引用进行逐一核查。 +""" \ No newline at end of file diff --git a/tutorial/example_deep_finance/judge/cgcv/__init__.py b/tutorial/example_deep_finance/judge/cgcv/__init__.py new file mode 100644 index 00000000..b67a705f --- /dev/null +++ b/tutorial/example_deep_finance/judge/cgcv/__init__.py @@ -0,0 +1,7 @@ +""" +CGCV (Citation-Grounded Claim Verification) Grader +引用锚定的断言验证框架 +""" +from .grader import CGCVGrader + +__all__ = ["CGCVGrader"] diff --git a/tutorial/example_deep_finance/judge/cgcv/grader.py b/tutorial/example_deep_finance/judge/cgcv/grader.py new file mode 100644 index 00000000..cae97eb4 --- /dev/null +++ b/tutorial/example_deep_finance/judge/cgcv/grader.py @@ -0,0 +1,362 @@ +""" +CGCV Grader - Citation-Grounded Claim Verification +引用锚定的断言验证评分器 +""" +from __future__ import annotations + +import os +from typing import Any, Dict, List, Optional, Tuple + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +# import path 兼容两种写法 +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import ( + CGCV_SYSTEM_PROMPT_ZH, + CGCV_SYSTEM_PROMPT_EN, + CGCV_USER_PROMPT_TEMPLATE_ZH, + CGCV_USER_PROMPT_TEMPLATE_EN, + get_cgcv_prompts +) +from .json_utils import ( + strict_load_json, + validate_cgcv_schema, + parse_cgcv_result, + construct_cgcv_prompt, + 
class CGCVGrader(BaseGrader):
    """Citation-Grounded Claim Verification (CGCV) grader.

    Treats a citation as the anchor between a claim in the report and the
    evidence gathered by tools, and scores a trajectory as
    ``verified_claims / total_claims`` in [0, 1].

    Pipeline: claim extraction -> citation checking -> source tracing ->
    content-alignment verification (the judge prompt in ``prompt.py``
    defines the per-claim statuses: verified, citation_missing,
    citation_broken, subject/predicate/object/qualifier_misalign).

    Input: a full conversation trajectory; output: GraderScore(name, score,
    reason).
    """

    def __init__(
        self,
        model: OpenAIChatModel,
        name: str = "cgcv",
        language: str = "zh",
        **kwargs: Any,
    ):
        """
        Args:
            model: OpenAI-compatible chat model used as the judge.
            name: Grader name reported in GraderScore.
            language: Prompt language, "zh" or "en".
            **kwargs: Forwarded to BaseGrader.
        """
        super().__init__(name=name, **kwargs)
        self.model = model
        self.language = language.lower()
        # Select the system / user prompt pair for the chosen language.
        self.system_prompt, self.user_prompt_template = get_cgcv_prompts(self.language)

    @staticmethod
    def create_default_model(
        model_name: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        deterministic: bool = True,
        enable_thinking: bool = False,
        seed: int = 0,
    ) -> OpenAIChatModel:
        """Build an OpenAIChatModel.

        Credentials default to the OPENAI_API_KEY / OPENAI_BASE_URL
        environment variables.

        Args:
            model_name: Model identifier.
            api_key: API key (env fallback).
            base_url: API base URL (env fallback).
            deterministic: Pin sampling so repeated judgments agree.
            enable_thinking: Whether to enable the model's thinking mode.
            seed: Sampling seed used when deterministic.
        """
        api_key = api_key or os.getenv("OPENAI_API_KEY")
        base_url = base_url or os.getenv("OPENAI_BASE_URL")

        extra_body: Dict[str, Any] = {}
        if deterministic:
            extra_body.update({
                "temperature": 0,
                "top_p": 1,
                "seed": seed,
                "presence_penalty": 0,
                "frequency_penalty": 0,
            })
        if enable_thinking is False:
            extra_body["enable_thinking"] = False

        kwargs: Dict[str, Any] = {"model": model_name}
        if api_key:
            kwargs["api_key"] = api_key
        if base_url:
            kwargs["base_url"] = base_url
        if extra_body:
            kwargs["extra_body"] = extra_body

        return OpenAIChatModel(**kwargs)

    async def _evaluate_core(self, traj: Any) -> Tuple[GraderScore, Optional[CGCVResult]]:
        """Shared pipeline: prompt -> model -> parse -> validate -> score.

        Previously this logic was duplicated verbatim between _aevaluate and
        get_detailed_result; both now delegate here.
        """
        # 1. Extract messages (accept both bare-list and {"messages": ...}).
        if isinstance(traj, dict):
            messages_list = traj.get("messages", [])
        elif isinstance(traj, list):
            messages_list = traj
        else:
            return GraderScore(
                name=self.name,
                score=0.0,
                reason="BadInput: traj must be list or dict with 'messages'",
            ), None
        if not messages_list:
            return GraderScore(
                name=self.name,
                score=0.0,
                reason="BadInput: empty trajectory",
            ), None

        # 2. Build the judge prompt.
        user_prompt = construct_cgcv_prompt(messages_list, self.user_prompt_template)
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        # 3. Call the judge model.
        try:
            resp = await self.model.achat(messages)
            raw_text = getattr(resp, "content", None)
            if raw_text is None:
                raw_text = str(resp)
        except Exception as e:
            return GraderScore(
                name=self.name,
                score=0.0,
                reason=f"ModelCallError: {type(e).__name__}: {e}",
            ), None

        # 4. Parse the JSON verdict.
        obj, jerr = strict_load_json(str(raw_text))
        if obj is None:
            snippet = str(raw_text)[:200].replace("\n", " ")
            return GraderScore(
                name=self.name,
                score=0.0,
                reason=f"ParseError: {jerr}; raw[:200]={snippet}",
            ), None

        # 5. Validate the schema.
        obj, serr = validate_cgcv_schema(obj)
        if obj is None:
            snippet = str(raw_text)[:200].replace("\n", " ")
            return GraderScore(
                name=self.name,
                score=0.0,
                reason=f"SchemaError: {serr}; raw[:200]={snippet}",
            ), None

        # 6. Parse per-claim results and compute the score.
        result = parse_cgcv_result(obj)
        score, reason = compute_cgcv_score(result)
        return GraderScore(name=self.name, score=score, reason=reason), result

    async def _aevaluate(
        self,
        traj: Any,
        **_: Any,
    ) -> GraderScore:
        """Async entry point.

        Args:
            traj: [{"role": ..., "content": ...}, ...] or {"messages": [...]}.

        Returns:
            GraderScore(name, score, reason).
        """
        score, _result = await self._evaluate_core(traj)
        return score

    def evaluate(
        self,
        traj: Any,
        **kwargs: Any,
    ) -> GraderScore:
        """Sync wrapper around _aevaluate.

        BUG FIX: the original used asyncio.get_event_loop() +
        run_until_complete, which is deprecated outside a running loop and
        raises if a loop IS running; asyncio.run() creates and tears down a
        fresh loop instead.
        """
        import asyncio  # local import: keep the module import-light

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self._aevaluate(traj, **kwargs))
        raise RuntimeError(
            "CGCVGrader.evaluate() called from a running event loop; "
            "await _aevaluate() instead"
        )

    def get_detailed_result(
        self,
        traj: Any,
    ) -> Tuple[GraderScore, Optional[CGCVResult]]:
        """Like evaluate(), but also returns the per-claim CGCVResult."""
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self._evaluate_core(traj))
        raise RuntimeError(
            "CGCVGrader.get_detailed_result() called from a running event "
            "loop; await _evaluate_core() instead"
        )


# =============================================================================
# Convenience Functions
# =============================================================================

def create_cgcv_grader(
    model_name: str,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    language: str = "zh",
    **kwargs
) -> CGCVGrader:
    """Create a CGCVGrader with a default model.

    Args:
        model_name: Model identifier.
        api_key: API key (env fallback inside create_default_model).
        base_url: API base URL (env fallback).
        language: Prompt language, "zh" or "en".
        **kwargs: Extra model parameters for create_default_model.

    Example:
        >>> grader = create_cgcv_grader("gpt-4o", language="zh")
        >>> result = await grader.aevaluate(trajectory)
        >>> print(f"Score: {result.score}, Reason: {result.reason}")
    """
    model = CGCVGrader.create_default_model(
        model_name=model_name,
        api_key=api_key,
        base_url=base_url,
        **kwargs
    )
    return CGCVGrader(model=model, language=language)
class ClaimStatus(str, Enum):
    """Verification status for a single extracted claim."""
    VERIFIED = "verified"
    CITATION_MISSING = "citation_missing"
    CITATION_BROKEN = "citation_broken"
    SUBJECT_MISALIGN = "subject_misalign"
    PREDICATE_MISALIGN = "predicate_misalign"
    OBJECT_MISALIGN = "object_misalign"
    QUALIFIER_MISALIGN = "qualifier_misalign"


# Every accepted status string.
VALID_STATUSES = {s.value for s in ClaimStatus}

# Greedy match of the outermost {...} span.
_JSON_RE = re.compile(r"\{.*\}", re.DOTALL)


# =============================================================================
# JSON Repair Helper
# =============================================================================

def _repair_json(js: str) -> str:
    """Best-effort repair of common LLM JSON mistakes.

    1. Escape raw newlines/tabs inside string literals.
    2. Drop trailing commas before ``}`` / ``]``.
    3. Close truncated output (unterminated string, missing brackets).

    NOTE: duplicated from the audit grader's json_utils; kept in sync.
    """

    def escape_control_chars(s: str) -> str:
        out: List[str] = []
        in_string = False
        escaped = False
        for ch in s:
            if escaped:
                out.append(ch)
                escaped = False
            elif ch == '\\':
                out.append(ch)
                escaped = True
            elif ch == '"':
                out.append(ch)
                in_string = not in_string
            elif in_string and ch == '\n':
                out.append('\\n')
            elif in_string and ch == '\r':
                out.append('\\r')
            elif in_string and ch == '\t':
                out.append('\\t')
            else:
                out.append(ch)
        return ''.join(out)

    js = escape_control_chars(js)

    # Trailing commas: ",}" -> "}" and ",]" -> "]".
    js = re.sub(r',\s*}', '}', js)
    js = re.sub(r',\s*]', ']', js)

    # Truncated output: balance braces/brackets, closing an open string first.
    open_braces, close_braces = js.count('{'), js.count('}')
    open_brackets, close_brackets = js.count('['), js.count(']')
    if open_braces > close_braces:
        in_string = False
        escaped = False
        for ch in js:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = not in_string
        if in_string:
            js += '"'
        js += ']' * (open_brackets - close_brackets)
        js += '}' * (open_braces - close_braces)

    return js


# =============================================================================
# Data Classes
# =============================================================================

@dataclass
class ClaimVerification:
    """Verification outcome for one claim extracted from the report."""
    subject: str
    predicate: str
    object: str  # the claimed value/conclusion; name mirrors the JSON schema
    qualifier: str
    citation: Optional[str]
    status: str
    source_id: Optional[str]
    note: str

    def is_verified(self) -> bool:
        return self.status == ClaimStatus.VERIFIED.value

    def is_citation_issue(self) -> bool:
        return self.status in {
            ClaimStatus.CITATION_MISSING.value,
            ClaimStatus.CITATION_BROKEN.value,
        }

    def is_alignment_issue(self) -> bool:
        return self.status in {
            ClaimStatus.SUBJECT_MISALIGN.value,
            ClaimStatus.PREDICATE_MISALIGN.value,
            ClaimStatus.OBJECT_MISALIGN.value,
            ClaimStatus.QUALIFIER_MISALIGN.value,
        }


@dataclass
class CGCVResult:
    """Aggregated CGCV verification result."""
    claims: List[ClaimVerification]
    total: int
    verified: int
    citation_missing: int
    citation_broken: int
    alignment_issues: int

    @property
    def score(self) -> float:
        """Verified fraction; 0.0 when no claims were extracted."""
        if self.total == 0:
            return 0.0
        return self.verified / self.total

    def get_summary(self) -> Dict[str, int]:
        """Return the per-category counts as a plain dict."""
        return {
            "total": self.total,
            "verified": self.verified,
            "citation_missing": self.citation_missing,
            "citation_broken": self.citation_broken,
            "alignment_issues": self.alignment_issues,
        }


# =============================================================================
# JSON Parsing Functions
# =============================================================================

def extract_first_json_object(text: str) -> Optional[str]:
    """Return the first JSON object in *text*.

    A ```json fenced block takes priority; otherwise the first {...} span.
    """
    if not text:
        return None
    fenced = re.search(r"```json\s*(\{[\s\S]*?\})\s*```", text)
    if fenced:
        return fenced.group(1).strip()
    m = _JSON_RE.search(text.strip())
    if not m:
        return None
    return m.group(0)


def strict_load_json(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Parse the first JSON object in *text*, repairing on failure.

    Returns:
        (obj, None) on success, (None, error_message) on failure.
    """
    js = extract_first_json_object(text)
    if js is None:
        return None, "No JSON object found in model output"

    # First attempt: parse as-is.
    try:
        obj = json.loads(js)
        if not isinstance(obj, dict):
            return None, f"Top-level JSON is not an object: {type(obj).__name__}"
        return obj, None
    except json.JSONDecodeError:
        pass  # fall through to the repair attempt

    # Second attempt: parse after best-effort repair.
    try:
        obj = json.loads(_repair_json(js))
        if not isinstance(obj, dict):
            return None, f"Top-level JSON is not an object: {type(obj).__name__}"
        return obj, None
    except json.JSONDecodeError as e:
        return None, f"JSONDecodeError: {e}"
    except Exception as e:
        return None, f"{type(e).__name__}: {e}"


def validate_cgcv_schema(obj: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Validate and normalize the CGCV JSON structure.

    Expected shape::

        {"claims": [{"subject": str, "predicate": str, "object": str,
                     "qualifier": str, "citation": str | None,
                     "status": str, "source_id": str | None,
                     "note": str}, ...]}

    Returns:
        (normalized_obj, None) on success, (None, error_message) on failure.
    """
    if "claims" not in obj:
        return None, "Missing field: claims"

    claims = obj["claims"]
    if not isinstance(claims, list):
        return None, f"Field 'claims' must be list, got {type(claims).__name__}"

    normalized_claims = []
    for idx, claim in enumerate(claims):
        if not isinstance(claim, dict):
            continue  # silently drop non-dict entries

        # Extract, stringify and truncate each field.
        normalized = {
            "subject": str(claim.get("subject", "未明确"))[:200],
            "predicate": str(claim.get("predicate", "未明确"))[:200],
            "object": str(claim.get("object", "未明确"))[:500],
            "qualifier": str(claim.get("qualifier", "未明确"))[:200],
            "citation": claim.get("citation"),
            "status": str(claim.get("status", "")).lower(),
            "source_id": claim.get("source_id"),
            "note": str(claim.get("note", ""))[:500],
        }

        # Normalize citation: textual null markers become real None.
        if normalized["citation"] is not None:
            normalized["citation"] = str(normalized["citation"])
            if normalized["citation"].lower() in ("null", "none", ""):
                normalized["citation"] = None

        # Normalize source_id the same way.
        if normalized["source_id"] is not None:
            normalized["source_id"] = str(normalized["source_id"])
            if normalized["source_id"].lower() in ("null", "none", ""):
                normalized["source_id"] = None

        # Validate status; fuzzy-match near-misses, else default.
        if normalized["status"] not in VALID_STATUSES:
            status_lower = normalized["status"]
            matched = False
            # BUG FIX: guard against an empty status — '"" in s' is always
            # True, so a claim with no status used to inherit an arbitrary
            # (hash-seed-dependent) member of the set instead of the
            # intended citation_missing default.
            if status_lower:
                for valid_status in VALID_STATUSES:
                    if valid_status in status_lower or status_lower in valid_status:
                        normalized["status"] = valid_status
                        matched = True
                        break
            if not matched:
                normalized["status"] = ClaimStatus.CITATION_MISSING.value

        normalized_claims.append(normalized)

    obj["claims"] = normalized_claims
    return obj, None


def parse_cgcv_result(obj: Dict[str, Any]) -> CGCVResult:
    """Convert a validated CGCV JSON object into a CGCVResult.

    Args:
        obj: output of validate_cgcv_schema.
    """
    claims = []
    verified_count = 0
    citation_missing_count = 0
    citation_broken_count = 0
    alignment_issues_count = 0

    for claim_dict in obj.get("claims", []):
        claim = ClaimVerification(
            subject=claim_dict.get("subject", ""),
            predicate=claim_dict.get("predicate", ""),
            object=claim_dict.get("object", ""),
            qualifier=claim_dict.get("qualifier", ""),
            citation=claim_dict.get("citation"),
            status=claim_dict.get("status", ""),
            source_id=claim_dict.get("source_id"),
            note=claim_dict.get("note", ""),
        )
        claims.append(claim)

        # Tally per-category counts for the summary/score.
        if claim.is_verified():
            verified_count += 1
        elif claim.status == ClaimStatus.CITATION_MISSING.value:
            citation_missing_count += 1
        elif claim.status == ClaimStatus.CITATION_BROKEN.value:
            citation_broken_count += 1
        elif claim.is_alignment_issue():
            alignment_issues_count += 1

    return CGCVResult(
        claims=claims,
        total=len(claims),
        verified=verified_count,
        citation_missing=citation_missing_count,
        citation_broken=citation_broken_count,
        alignment_issues=alignment_issues_count,
    )
subject=claim_dict.get("subject", ""), + predicate=claim_dict.get("predicate", ""), + object=claim_dict.get("object", ""), + qualifier=claim_dict.get("qualifier", ""), + citation=claim_dict.get("citation"), + status=claim_dict.get("status", ""), + source_id=claim_dict.get("source_id"), + note=claim_dict.get("note", "") + ) + claims.append(claim) + + # 统计 + if claim.is_verified(): + verified_count += 1 + elif claim.status == ClaimStatus.CITATION_MISSING.value: + citation_missing_count += 1 + elif claim.status == ClaimStatus.CITATION_BROKEN.value: + citation_broken_count += 1 + elif claim.is_alignment_issue(): + alignment_issues_count += 1 + + return CGCVResult( + claims=claims, + total=len(claims), + verified=verified_count, + citation_missing=citation_missing_count, + citation_broken=citation_broken_count, + alignment_issues=alignment_issues_count + ) + + +# ============================================================================= +# Trajectory 处理辅助函数 +# ============================================================================= + +def _extract_text_content(content) -> str: + """统一提取纯文本内容""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + out = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + out.append(item.get("text", "")) + elif isinstance(item, str): + out.append(item) + return "\n".join(out) + return str(content) + + +def _strip_think(text: str) -> str: + """去除 ... 
标签""" + return re.sub(r".*?\s*", "", text, flags=re.S).strip() + + +def _strip_markdown_fences(text: str) -> str: + """ + 清理 markdown 代码块标记 + - 移除开头的 ```markdown / ```md / ``` 等 + - 移除结尾的 ``` + """ + text = text.strip() + # 移除开头的 ```xxx + text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE) + # 移除结尾的 ``` + text = re.sub(r'\n?```\s*$', '', text) + return text.strip() + + +def _normalize_traj(trajectory): + """兼容 [[...]] 格式""" + if isinstance(trajectory, list) and trajectory and isinstance(trajectory[0], list): + return trajectory[0] + return trajectory + + +def _extract_tool_call_json(text: str) -> str: + """提取工具调用 JSON""" + m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text) + if m: + return m.group(1).strip() + l, r = text.find("["), text.rfind("]") + if l != -1 and r != -1 and r > l: + cand = text[l:r+1].strip() + if ("tool_name" in cand) and ("tool_args" in cand): + return cand + return "" + + +def _looks_like_tool_result(text: str) -> bool: + """判断是否为工具返回结果""" + t = text.strip() + # 匹配常见的工具返回格式 + if t.startswith("Tool:") or t.startswith("Result:"): + return True + # 匹配 [Tool: xxx] 格式 + if t.startswith("[Tool:"): + return True + # 匹配 格式 + if "" in t or "" in t: + return True + # 匹配 dashscope_search 等工具的返回结果 + if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t): + return True + # 匹配爬取工具返回的结构化数据 + if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "): + return True + # 匹配同花顺工具返回的来源标记 + if "> 以下内容来自:" in t: + return True + return False + + +def _is_probably_final_report(text: str) -> bool: + """判断是否为最终报告""" + t = text.strip() + return ("## References" in t) or ("[TASK_COMPLETED]" in t) or t.lstrip().startswith("# ") + + +def _split_tool_responses(text: str) -> List[str]: + """ + 分割多个工具响应 + + 处理格式如: + [Tool: xxx] +... + + +[Tool: yyy] +... 
+ """ + # 先尝试按 \n 分割 + if "" in text and "" in text: + parts = re.split(r'\s*', text) + # 清理每个部分的标签 + cleaned = [] + for p in parts: + p = re.sub(r'^\s*\s*', '', p) + p = re.sub(r'\s*\s*$', '', p) + p = p.strip() + if p: + cleaned.append(p) + if cleaned: + return cleaned + + # 尝试按 [Tool: xxx] 分割 + tool_pattern = r'(?=\[Tool:\s*[^\]]+\])' + parts = re.split(tool_pattern, text) + parts = [p.strip() for p in parts if p.strip()] + if len(parts) > 1: + return parts + + # 无法分割,返回原文本 + return [text.strip()] if text.strip() else [] + + +def construct_cgcv_prompt( + trajectory: List[Dict[str, Any]], + user_prompt_template: str +) -> str: + """ + 从 trajectory 构建 CGCV 评估 prompt + + Args: + trajectory: 对话轨迹 [{"role": ..., "content": ...}, ...] + user_prompt_template: 用户 prompt 模板 + + Returns: + 构建好的 user prompt 字符串 + """ + traj = _normalize_traj(trajectory) + if not traj: + traj = [] + + user_query = "" + tool_calls: List[str] = [] + evidence: List[str] = [] + final_report = "" + + # 找到 final report(从后往前找第一个符合条件的 assistant 消息) + for i in range(len(traj) - 1, -1, -1): + step = traj[i] + if step.get("role") == "assistant": + txt = _strip_think(_extract_text_content(step.get("content"))) + if _is_probably_final_report(txt): + final_report = txt + break + + if not final_report: + for i in range(len(traj) - 1, -1, -1): + if traj[i].get("role") == "assistant": + final_report = _strip_think(_extract_text_content(traj[i].get("content"))) + break + + # 清理 markdown 代码块标记 + final_report = _strip_markdown_fences(final_report) + + # 遍历提取 user_query, tool_calls, evidence + evidence_idx = 0 + for idx, step in enumerate(traj): + role = step.get("role") + raw = _extract_text_content(step.get("content")) + txt = _strip_think(raw) + if not raw: + continue + + # 跳过 system 消息 + if role == "system": + continue + + if role == "user" and not user_query and (not _looks_like_tool_result(raw)): + user_query = txt + continue + + if role == "assistant": + call_json = _extract_tool_call_json(raw) + if 
call_json: + tool_calls.append(f"【工具调用 {len(tool_calls) + 1}】\n{call_json}") + + if role == "tool": + # 处理多工具响应的情况 + tool_parts = _split_tool_responses(raw) + for part in tool_parts: + if part: + evidence_idx += 1 + evidence.append(f"【Evidence {evidence_idx}】\n{part}") + elif role == "user" and user_query and _looks_like_tool_result(raw): + # 某些情况下工具结果可能在 user 消息中 + evidence_idx += 1 + evidence.append(f"【Evidence {evidence_idx}】\n{raw}") + + # 构建 evidence_text,使用更清晰的分隔 + evidence_parts = [] + if evidence: + evidence_parts.append("\n\n".join(evidence)) + + evidence_text = "\n\n".join(evidence_parts) if evidence_parts else "(无可用证据)" + + return user_prompt_template.format( + user_query=user_query, + evidence_text=evidence_text, + report=final_report + ).strip() + + +# ============================================================================= +# Score Computation +# ============================================================================= + +def compute_cgcv_score( + result: CGCVResult, + citation_weight: float = 0.3, + alignment_weight: float = 0.7 +) -> Tuple[float, str]: + """ + 计算 CGCV 评分 + + 评分策略: + 1. 基础分:verified / total + 2. 
可选:分层评分 + - citation_score: 有引用且可追溯的比例 + - alignment_score: 内容对齐的比例(在有有效引用的前提下) + + Args: + result: CGCVResult 对象 + citation_weight: 引用分数权重(默认 0.3) + alignment_weight: 对齐分数权重(默认 0.7) + + Returns: + (score, reason) 元组 + """ + total = result.total + + if total == 0: + return 0.0, "no_claims_detected" + + # 简单评分:verified / total + base_score = result.verified / total + + # 分层统计 + citation_issues = result.citation_missing + result.citation_broken + claims_with_valid_citation = total - citation_issues + + # 引用有效率 + citation_valid_rate = claims_with_valid_citation / total if total > 0 else 0.0 + + # 对齐正确率(在有效引用中) + if claims_with_valid_citation > 0: + alignment_correct_rate = result.verified / claims_with_valid_citation + else: + alignment_correct_rate = 0.0 + + # 加权分数 + weighted_score = ( + citation_weight * citation_valid_rate + + alignment_weight * alignment_correct_rate + ) + + # 最终使用基础分数(更直观) + final_score = base_score + + # 构建 reason + reason_parts = [ + f"total={total}", + f"verified={result.verified}", + f"citation_missing={result.citation_missing}", + f"citation_broken={result.citation_broken}", + f"alignment_issues={result.alignment_issues}", + f"score={final_score:.4f}", + ] + + # 添加错误摘要 + if result.alignment_issues > 0: + # 统计各类对齐错误 + error_counts = {} + for claim in result.claims: + if claim.is_alignment_issue(): + error_counts[claim.status] = error_counts.get(claim.status, 0) + 1 + error_summary = ", ".join(f"{k}:{v}" for k, v in error_counts.items()) + reason_parts.append(f"errors=[{error_summary}]") + + reason = " | ".join(reason_parts) + return round(final_score, 6), reason[:800] diff --git a/tutorial/example_deep_finance/judge/cgcv/prompt.py b/tutorial/example_deep_finance/judge/cgcv/prompt.py new file mode 100644 index 00000000..a98a98b8 --- /dev/null +++ b/tutorial/example_deep_finance/judge/cgcv/prompt.py @@ -0,0 +1,378 @@ +""" +Citation-Grounded Claim Verification (CGCV) Prompt +引用锚定的断言验证框架 + +核心理念:引用是断言与证据之间的"锚点",验证引用的有效性和内容的一致性。 +""" + +# 
============================================================================= +# System Prompt - 中文版 +# ============================================================================= + +CGCV_SYSTEM_PROMPT_ZH = """你是一位"引用核查专家",负责审计研究报告中的断言是否有正确的引用支撑,并验证断言内容与来源是否一致。 + +重要说明:这是一个事后评估任务,用于评估已完成的报告质量。报告中通过工具调用获取的信息是正确的研究方式,你的任务是验证这些信息在最终报告中是否被正确引用和准确呈现。 + +## 输入说明 + +你会收到三部分内容: +1. **用户问题**:用户的原始查询 +2. **Evidence**:工具调用返回的原始数据(如搜索结果、爬取的网页内容等) +3. **研究报告**:待核查的报告,包含: + - 正文:包含带引用标记 `[n]` 的断言 + - References 区块:报告末尾的 `## References` 部分,格式通常为: + `[n] 标题描述, 工具: tool_name, 参数:xxx, 数据日期/报告期: xxx, 来源 - URL 或 (no-url)` + +## 验证流程 + +### Stage 1: 断言提取 +从报告**正文**(不含 References 区块)中识别所有包含具体信息的可验证断言,提取四个要素: +- **Subject**:断言涉及的对象(公司、产品、指数、人物等) +- **Predicate**:描述的属性或关系(收入、增长率、排名、状态等) +- **Object**:具体的值、数量或结论 +- **Qualifier**:限定条件(时间、范围、前提条件等) + +**可验证断言的识别标准**: +- 包含具体数值(金额、比例、增速、排名等) +- 包含具体日期或时间段 +- 包含可被证据支持或反驳的明确事实陈述 +- 一句话包含多个数值时,按一条断言计数 + +### Stage 2: 引用检查 +检查每个断言是否有引用标记 `[n]`: +- 有引用 → 继续下一阶段 +- 无引用 → 标记为 `citation_missing` + +### Stage 3: 来源追溯 +追溯引用 `[n]` 的验证路径:**报告正文 [n] → References 中的 [n] 条目 → Evidence 中的对应数据** +- 若 References 中存在 `[n]` 条目,且能在 Evidence 中找到对应数据 → 继续下一阶段 +- 若 References 中无 `[n]` 条目,或条目无效(如 URL 为 javascript:void(0)) → 标记为 `citation_broken` + +### Stage 4: 内容对齐验证 +将报告中的断言与 Evidence 中的原始数据进行比对,验证四个要素是否一致: +- Subject 不一致 → `subject_misalign` +- Predicate 不一致 → `predicate_misalign` +- Object 不一致 → `object_misalign` +- Qualifier 不一致 → `qualifier_misalign` +- 全部一致 → `verified` + +## 验证状态说明 + +| 状态 | 含义 | +|-----|------| +| `verified` | 验证通过:有引用、可追溯、内容与 Evidence 一致 | +| `citation_missing` | 引用缺失:可验证断言无引用标记 | +| `citation_broken` | 引用断裂:引用在 References 中不存在或无效 | +| `subject_misalign` | 对象错位:断言对象与 Evidence 不一致 | +| `predicate_misalign` | 属性错位:属性或关系与 Evidence 不匹配 | +| `object_misalign` | 值错位:数值或结论与 Evidence 不一致 | +| `qualifier_misalign` | 限定错位:时间或条件与 Evidence 不一致 | + +## 内容对齐规则 + +### Subject 对齐规则 +- ✓ 完全一致或已知别名等价(如:腾讯 = 腾讯控股 = Tencent) +- ✓ 股票代码与公司名对应(如:600745 = 闻泰科技) 
+- ✗ 不同实体混淆(A公司数据误标为B公司) +- ✗ 范围混淆(子公司/渠道数据误标为集团整体,如:i茅台营收 ≠ 贵州茅台总营收) + +### Predicate 对齐规则 +- ✓ 完全一致或语义等价(如:ROE = 净资产收益率、营收 = 营业收入 = 总收入) +- ✗ 概念混淆(净利润 ≠ 营业收入、毛利率 ≠ 净利率) +- ✗ 口径混淆(日收益率 ≠ 周收益率、同比 ≠ 环比) + +### Object 对齐规则 +- ✓ 精确一致(454.03亿 = 454.03亿) +- ✓ 等价形式(18.60% = 18.6%,末尾零可省) +- ✓ 单位换算等价(45403百万 = 454.03亿) +- ✓ 表述等价(下降8% = 增长-8% = 同比-8%) +- ✓ 合理近似:使用"约/大约/左右"修饰时,允许5%以内误差 +- ✗ 精度丢失:未使用"约"等修饰词时,不允许省略有效数字(454.03亿 → 454亿) +- ✗ 超出容差:即使有"约"修饰,误差超过5% +- ✗ 数值无据:Evidence 中找不到该数值 + +### Qualifier 对齐规则 +- ✓ 完全一致或语义等价(2025年Q2 = 2025年第二季度 = 2025年4-6月) +- ✓ 报告期等价(2025年三季报 = 截至2025年9月30日 = 2025年前三季度) +- ✗ 年份错位(2024年 ≠ 2025年) +- ✗ 周期错位(Q2 ≠ Q3、上半年 ≠ 前三季度) +- ✗ 时点混淆(发布日期 ≠ 数据截止日期) + +## 输出格式 + +请直接输出 JSON,格式如下: +```json +{ + "claims": [ + { + "subject": "断言对象", + "predicate": "属性/关系", + "object": "值/结论", + "qualifier": "限定条件(无则填'未明确')", + "citation": "引用标记如[1],无则填null", + "status": "verified/citation_missing/citation_broken/subject_misalign/predicate_misalign/object_misalign/qualifier_misalign", + "source_id": "来源编号(如有)", + "note": "说明(verified时为空字符串)" + } + ] +} +``` + +只输出 JSON,不要输出其他解释文字。 + +## 示例 + +### 示例1:验证通过 (verified) + +**Report正文片段**:闻泰科技2025年三季报净利润为15.13亿元,同比增长265.09%[5] +**Report References**:[5] 闻泰科技2025年三季报财务分析, 工具: crawl_ths_finance, 参数:code=600745, 数据日期/报告期: 2025-09-30, 来源 - https://basic.10jqka.com.cn/600745/finance.html +**Evidence**:...闻泰科技...净利润15.13亿元...同比增长265.09%... + +分析: +- Subject: 闻泰科技 ✓ +- Predicate: 净利润、同比增长 ✓ +- Object: 15.13亿元、265.09% ✓ +- Qualifier: 2025年三季报 ↔ 2025-09-30 ✓(语义等价) +- 引用[5]存在于References,可追溯到Evidence ✓ + +输出: +{"subject": "闻泰科技", "predicate": "净利润同比增长", "object": "15.13亿元,265.09%", "qualifier": "2025年三季报", "citation": "[5]", "status": "verified", "source_id": "5", "note": ""} + +--- + +### 示例2:引用缺失 (citation_missing) + +**Report正文片段**:该公司毛利率达到16.98%,同比提升6.97个百分点 +**Evidence**:...毛利率16.98%...同比提升6.97个百分点... 
+ +分析: +- 断言包含具体数值(16.98%、6.97个百分点),属于可验证断言 +- 但断言末尾无引用标记 [n] + +输出: +{"subject": "该公司", "predicate": "毛利率", "object": "16.98%,同比提升6.97个百分点", "qualifier": "未明确", "citation": null, "status": "citation_missing", "source_id": null, "note": "可验证断言缺少引用标记"} + +--- + +### 示例3:引用断裂 (citation_broken) + +**Report正文片段**:市场份额达到23%[9] +**Report References**:(无[9]条目,或[9]条目的URL为 javascript:void(0)) + +分析: +- 有引用标记[9] +- 但References中无有效的[9]条目 + +输出: +{"subject": "未明确", "predicate": "市场份额", "object": "23%", "qualifier": "未明确", "citation": "[9]", "status": "citation_broken", "source_id": null, "note": "引用[9]在References中不存在或无效"} + +--- + +### 示例4:对象错位 (subject_misalign) + +**Report正文片段**:赛腾股份2025年三季报净利润为15.13亿元[5] +**Report References**:[5] 闻泰科技2025年三季报财务分析, 工具: crawl_ths_finance, 参数:code=600745... +**Evidence**:...闻泰科技...净利润15.13亿元... + +分析: +- Subject: 赛腾股份 ↔ 闻泰科技 ✗ +- 15.13亿元是闻泰科技的数据,被错误归属给赛腾股份 + +输出: +{"subject": "赛腾股份", "predicate": "净利润", "object": "15.13亿元", "qualifier": "2025年三季报", "citation": "[5]", "status": "subject_misalign", "source_id": "5", "note": "来源[5]中15.13亿元属于闻泰科技,非赛腾股份"} + +--- + +### 示例5:值错位-精度丢失 (object_misalign) + +**Report正文片段**:净利润15亿元[5] +**Evidence**:...净利润15.13亿元... + +分析: +- Object: 15亿 ↔ 15.13亿 ✗ +- 报告未使用"约"修饰,但省略了小数部分(0.13亿 = 1300万,精度损失明显) + +输出: +{"subject": "未明确", "predicate": "净利润", "object": "15亿元", "qualifier": "未明确", "citation": "[5]", "status": "object_misalign", "source_id": "5", "note": "Evidence为15.13亿元,报告省略为15亿元,存在精度丢失"} + +--- + +### 示例6:限定错位 (qualifier_misalign) + +**Report正文片段**:2025年Q2净利润为15.13亿元[5] +**Report References**:[5] ...数据日期/报告期: 2025-09-30... +**Evidence**:...2025年三季报...净利润15.13亿元... 
+ +分析: +- Qualifier: Q2(截至6月30日) ↔ 2025-09-30(三季报,截至9月30日) ✗ +- 报告期不一致 + +输出: +{"subject": "未明确", "predicate": "净利润", "object": "15.13亿元", "qualifier": "2025年Q2", "citation": "[5]", "status": "qualifier_misalign", "source_id": "5", "note": "来源[5]为2025年三季报数据(截至9月30日),非Q2数据"}""" + +# ============================================================================= +# System Prompt - English Version +# ============================================================================= + +CGCV_SYSTEM_PROMPT_EN = """You are a "Citation Verification Expert" responsible for auditing whether claims in research reports have proper citation support and whether the claim content is consistent with the evidence sources. + +Important Note: This is a post-hoc evaluation task for assessing completed report quality. Information obtained through tool calls in the report is a correct research approach. Your task is to verify whether this information is correctly cited and accurately presented in the final report. + +## Input Description + +You will receive three parts: +1. **User Query**: The original user question +2. **Evidence**: Raw data returned from tool calls (search results, crawled web content, etc.) +3. **Research Report**: The report to be verified, containing: + - Body: Contains claims with citation markers `[n]` + - References section: The `## References` part at the end, typically in format: + `[n] Title description, Tool: tool_name, Params:xxx, Data date/Report period: xxx, Source - URL or (no-url)` + +## Verification Process + +### Stage 1: Claim Extraction +Identify all verifiable claims containing specific information from the report **body** (excluding References section), extracting four elements: +- **Subject**: The entity the claim is about (company, product, index, person, etc.) +- **Predicate**: The attribute or relationship described (revenue, growth rate, ranking, status, etc.) 
+- **Object**: The specific value, quantity, or conclusion +- **Qualifier**: Limiting conditions (time, scope, prerequisites, etc.) + +**Criteria for verifiable claims**: +- Contains specific numbers (amounts, ratios, growth rates, rankings, etc.) +- Contains specific dates or time periods +- Contains definitive factual statements that can be supported or refuted by evidence +- Multiple values in one sentence count as one claim + +### Stage 2: Citation Checking +Check whether each claim has a citation marker `[n]`: +- Has citation → proceed to next stage +- No citation → mark as `citation_missing` + +### Stage 3: Source Tracing +Trace citation `[n]` verification path: **Report body [n] → [n] entry in References → Corresponding data in Evidence** +- If `[n]` entry exists in References and corresponding data can be found in Evidence → proceed to next stage +- If `[n]` entry doesn't exist in References, or entry is invalid (e.g., URL is javascript:void(0)) → mark as `citation_broken` + +### Stage 4: Content Alignment Verification +Compare claims in report with original data in Evidence, verify if four elements are consistent: +- Subject inconsistent → `subject_misalign` +- Predicate inconsistent → `predicate_misalign` +- Object inconsistent → `object_misalign` +- Qualifier inconsistent → `qualifier_misalign` +- All consistent → `verified` + +## Verification Status Description + +| Status | Meaning | +|--------|--------| +| `verified` | Verified: has citation, traceable, content matches Evidence | +| `citation_missing` | Missing citation: verifiable claim has no citation marker | +| `citation_broken` | Broken citation: citation doesn't exist or is invalid in References | +| `subject_misalign` | Subject misaligned: claim subject inconsistent with Evidence | +| `predicate_misalign` | Predicate misaligned: attribute or relationship doesn't match Evidence | +| `object_misalign` | Object misaligned: value or conclusion inconsistent with Evidence | +| `qualifier_misalign` | 
Qualifier misaligned: time or condition inconsistent with Evidence | + +## Content Alignment Rules + +### Subject Alignment Rules +- ✓ Exact match or known alias equivalence (e.g., Tencent = Tencent Holdings) +- ✓ Stock code corresponds to company name (e.g., 600745 = Wingtech) +- ✗ Different entity confusion (Company A data mislabeled as Company B) +- ✗ Scope confusion (subsidiary/channel data mislabeled as group total) + +### Predicate Alignment Rules +- ✓ Exact match or semantic equivalence (e.g., ROE = Return on Equity, Revenue = Operating Income = Total Revenue) +- ✗ Concept confusion (Net profit ≠ Operating revenue, Gross margin ≠ Net margin) +- ✗ Scope confusion (Daily return rate ≠ Weekly return rate, YoY ≠ MoM) + +### Object Alignment Rules +- ✓ Exact match (45.403B = 45.403B) +- ✓ Equivalent forms (18.60% = 18.6%, trailing zeros can be omitted) +- ✓ Unit conversion equivalence (45403 million ≈ 454.03 billion) +- ✓ Expression equivalence (down 8% = growth -8% = YoY -8%) +- ✓ Reasonable approximation: when using "approx/about/around" modifier, allow up to 5% error +- ✗ Precision loss: without "approx" modifier, cannot omit significant digits (454.03B → 454B) +- ✗ Exceeds tolerance: even with "approx" modifier, error exceeds 5% +- ✗ Value not found: cannot find this value in Evidence + +### Qualifier Alignment Rules +- ✓ Exact match or semantic equivalence (2025 Q2 = Q2 2025 = Apr-Jun 2025) +- ✓ Report period equivalence (Q3 2025 report = as of Sep 30, 2025 = first three quarters of 2025) +- ✗ Year misalignment (2024 ≠ 2025) +- ✗ Period misalignment (Q2 ≠ Q3, H1 ≠ first three quarters) +- ✗ Time point confusion (publication date ≠ data cutoff date) + +## Output Format + +Please output JSON directly in the following format: +```json +{ + "claims": [ + { + "subject": "claim subject", + "predicate": "attribute/relationship", + "object": "value/conclusion", + "qualifier": "limiting condition (use 'unspecified' if none)", + "citation": "citation marker like [1], 
null if none", + "status": "verified/citation_missing/citation_broken/subject_misalign/predicate_misalign/object_misalign/qualifier_misalign", + "source_id": "source number (if available)", + "note": "explanation (empty string when verified)" + } + ] +} +``` + +Output JSON only, no other explanatory text. +""" + +# ============================================================================= +# User Prompt Template +# ============================================================================= + +CGCV_USER_PROMPT_TEMPLATE_ZH = """请对以下研究报告进行引用核查,验证每个可验证断言的引用有效性和内容一致性。 + +### 用户问题 +{user_query} + +### Evidence(工具调用获取的信息) +{evidence_text} + +### 研究报告(待核查) +{report} + +请按照验证流程逐一检查报告中的可验证断言,只输出 JSON 结果。 +""" + +CGCV_USER_PROMPT_TEMPLATE_EN = """Please perform citation verification on the following research report, validating citation validity and content consistency for each verifiable claim. + +### User Query +{user_query} + +### Evidence (Information obtained through tool calls) +{evidence_text} + +### Research Report (To be verified) +{report} + +Please check each verifiable claim in the report according to the verification process, output JSON result only. +""" + +# ============================================================================= +# Utility: Get prompts by language +# ============================================================================= + +def get_cgcv_prompts(language: str = "zh"): + """ + Get CGCV prompts based on language. 
+ + Args: + language: "zh" for Chinese, "en" for English + + Returns: + Tuple of (system_prompt, user_prompt_template) + """ + if language.lower() in ["zh", "chinese", "中文"]: + return CGCV_SYSTEM_PROMPT_ZH, CGCV_USER_PROMPT_TEMPLATE_ZH + else: + return CGCV_SYSTEM_PROMPT_EN, CGCV_USER_PROMPT_TEMPLATE_EN diff --git a/tutorial/example_deep_finance/judge/ebtu/__init__.py b/tutorial/example_deep_finance/judge/ebtu/__init__.py new file mode 100644 index 00000000..86ba0083 --- /dev/null +++ b/tutorial/example_deep_finance/judge/ebtu/__init__.py @@ -0,0 +1 @@ +# ebtu_reward package diff --git a/tutorial/example_deep_finance/judge/ebtu/grader.py b/tutorial/example_deep_finance/judge/ebtu/grader.py new file mode 100644 index 00000000..b5ee1380 --- /dev/null +++ b/tutorial/example_deep_finance/judge/ebtu/grader.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import Any, Dict, Optional + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import EBTU_SYSTEM_PROMPT, EBTU_USER_PROMPT_TEMPLATE +from .json_utils import ( + strict_load_json, + validate_shape, + coerce_to_messages_list, + construct_ebtu_prompt, + count_digit_tokens, +) + + +class EBTUTraceabilityGrader(BaseGrader): + """ + Evidence-Backed Trace Units (EBTU) Grader + + Input: + - traj or record JSON that contains trajectory messages + + Output: + - GraderScore(score in [0,1], reason with compact stats) + """ + + def __init__( + self, + model: Optional[OpenAIChatModel] = None, + name: str = "ebtu_traceability", + temperature: float = 0.0, + max_tokens: int = 2600, + model_name: str = "qwen-flash", + ) -> None: + super().__init__(name=name) + self.model = model or OpenAIChatModel( + model_name=model_name, + temperature=temperature, + 
max_tokens=max_tokens, + response_format={"type": "json_object"}, + ) + + async def _aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: + messages = coerce_to_messages_list(traj) + + # 输入有效性检查 + if not messages: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty or invalid trajectory", + ) + + user_prompt, report_plain = construct_ebtu_prompt(messages, EBTU_USER_PROMPT_TEMPLATE) + + judge_messages = [ + {"role": "system", "content": EBTU_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + + # 模型调用(带异常保护) + try: + resp = await self.model.achat(judge_messages) + raw_text = getattr(resp, "content", None) + if raw_text is None: + raw_text = str(resp) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + + try: + obj = strict_load_json(str(raw_text)) + norm = validate_shape(obj) + + score = self._compute_score(norm, report_plain) + reason = self._build_reason(norm, report_plain, score) + return GraderScore(name=self.name, score=score, reason=reason) + except Exception as e: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore(name=self.name, score=0.0, reason=f"EBTU parse error: {e}; raw[:200]={snippet}") + + def _compute_score(self, norm: Dict[str, Any], report_plain: str) -> float: + stats = norm["stats"] + units = norm["units"] + + hard_total = max(1, int(stats.get("hard_units", 0))) + supported = int(stats.get("supported", 0)) + contradicted = int(stats.get("contradicted", 0)) + no_evidence = int(stats.get("no_evidence", 0)) + unclear = int(stats.get("unclear", 0)) + misattrib = int(stats.get("misattrib", 0)) + anchored_hard = max(1, int(stats.get("anchored_hard_units", 0))) + + # Base: reward supported; penalize contradicted/no_evidence strongly; unclear mildly + base = (supported - 1.4 * contradicted - 0.9 * no_evidence - 0.4 * unclear) / hard_total + base = max(0.0, min(1.0, base)) + + # Misattribution penalty: 
anchors exist but not supported (wrong anchor / wrong use) + misattrib_rate = misattrib / anchored_hard + misattrib_factor = max(0.0, 1.0 - 0.7 * misattrib_rate) + + # Deterministic coverage heuristics based on report digit tokens + digit_tokens = count_digit_tokens(report_plain) + expected_min_units = min(25, max(6, digit_tokens // 2)) + extracted_units = max(1, len(units)) + selection_factor = min(1.0, extracted_units / expected_min_units) if expected_min_units > 0 else 1.0 + + # Optional judge-reported digit/date coverage (soft) + reported_total = int(stats.get("report_digit_date_tokens", 0)) + reported_cov = int(stats.get("covered_digit_date_tokens", 0)) + if reported_total > 0: + cov_ratio = max(0.0, min(1.0, reported_cov / reported_total)) + else: + cov_ratio = 1.0 + cov_factor = 0.65 + 0.35 * cov_ratio # [0.65, 1.0] + + score = base * misattrib_factor * selection_factor * cov_factor + return float(max(0.0, min(1.0, score))) + + def _build_reason(self, norm: Dict[str, Any], report_plain: str, score: float) -> str: + s = norm["stats"] + ex = norm.get("examples", {}) + best = ex.get("best_supported", []) + worst = ex.get("worst_failed", []) + digit_tokens = count_digit_tokens(report_plain) + + parts = [ + f"score={score:.3f}", + f"units={s['total_units']}", + f"hard={s['hard_units']}", + f"sup={s['supported']}", + f"ctr={s['contradicted']}", + f"noev={s['no_evidence']}", + f"unc={s['unclear']}", + f"anch_hard={s['anchored_hard_units']}", + f"misattrib={s['misattrib']}", + f"report_digits≈{digit_tokens}", + ] + if best: + parts.append(f"best={best[:1]}") + if worst: + parts.append(f"worst={worst[:1]}") + return " | ".join(parts) diff --git a/tutorial/example_deep_finance/judge/ebtu/json_utils.py b/tutorial/example_deep_finance/judge/ebtu/json_utils.py new file mode 100644 index 00000000..69b22b13 --- /dev/null +++ b/tutorial/example_deep_finance/judge/ebtu/json_utils.py @@ -0,0 +1,455 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import json 
+import re +from typing import Any, Dict, List, Tuple + + +# ============================================================================= +# JSON Repair Helper +# ============================================================================= + +def _repair_json(js: str) -> str: + """ + 尝试修复常见的JSON格式错误 + 1. 修复字符串中未转义的换行符 + 2. 修复trailing comma + 3. 修复不完整的JSON(截断) + """ + # 1. 替换字符串值中的未转义换行符 + def escape_newlines_in_strings(s: str) -> str: + result = [] + in_string = False + escape_next = False + i = 0 + while i < len(s): + c = s[i] + if escape_next: + result.append(c) + escape_next = False + elif c == '\\': + result.append(c) + escape_next = True + elif c == '"': + result.append(c) + in_string = not in_string + elif in_string and c == '\n': + result.append('\\n') + elif in_string and c == '\r': + result.append('\\r') + elif in_string and c == '\t': + result.append('\\t') + else: + result.append(c) + i += 1 + return ''.join(result) + + js = escape_newlines_in_strings(js) + + # 2. 移除trailing comma: ",}" -> "}" 和 ",]" -> "]" + js = re.sub(r',\s*}', '}', js) + js = re.sub(r',\s*]', ']', js) + + # 3. 尝试修复截断的JSON - 补全缺失的括号 + open_braces = js.count('{') + close_braces = js.count('}') + open_brackets = js.count('[') + close_brackets = js.count(']') + + if open_braces > close_braces: + # 先关闭可能未闭合的字符串 + in_string = False + escape_next = False + for c in js: + if escape_next: + escape_next = False + elif c == '\\': + escape_next = True + elif c == '"': + in_string = not in_string + if in_string: + js += '"' + + # 补全缺失的括号 + js += ']' * (open_brackets - close_brackets) + js += '}' * (open_braces - close_braces) + + return js + + +def strict_load_json(text: str) -> Dict[str, Any]: + """Parse a JSON object from model output; extract first {...} block if needed. 
带容错修复。""" + text = (text or "").strip() + + # 第一次尝试:直接解析 + try: + obj = json.loads(text) + if isinstance(obj, dict): + return obj + except Exception: + pass + + # 尝试提取 {...} 片段 + start = text.find("{") + end = text.rfind("}") + if start != -1 and end != -1 and end > start: + snippet = text[start : end + 1] + # 第二次尝试:直接解析提取的片段 + try: + obj = json.loads(snippet) + if isinstance(obj, dict): + return obj + except Exception: + pass + + # 第三次尝试:修复后解析 + try: + repaired = _repair_json(snippet) + obj = json.loads(repaired) + if isinstance(obj, dict): + return obj + except Exception: + pass + + raise ValueError("Invalid JSON output") + + +def _clip(s: str, n: int) -> str: + s = s or "" + s = s.replace("\u0000", "") + return s[:n] + + +def _as_int(x: Any, default: int = 0) -> int: + try: + if x is None: + return default + if isinstance(x, bool): + return int(x) + if isinstance(x, (int, float)): + return int(x) + if isinstance(x, str) and x.strip(): + return int(float(x.strip())) + except Exception: + return default + return default + + +def _as_list_str(x: Any, max_items: int = 10, max_len: int = 60) -> List[str]: + if not isinstance(x, list): + return [] + out: List[str] = [] + for item in x[:max_items]: + if isinstance(item, str): + out.append(_clip(item, max_len)) + else: + out.append(_clip(str(item), max_len)) + return out + + +def validate_shape(obj: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate and normalize model output for EBTU. 
+ + Returns: + {"units": [...], "stats": {...}, "examples": {...}} + """ + if not isinstance(obj, dict): + raise ValueError("Output is not a JSON object") + + units_raw = obj.get("units", []) + if not isinstance(units_raw, list): + units_raw = [] + + units: List[Dict[str, Any]] = [] + for u in units_raw[:30]: + if not isinstance(u, dict): + continue + + claim = _clip(str(u.get("claim", "")), 280) + hardness = _clip(str(u.get("hardness", "hard")), 8) + if hardness not in {"hard", "soft"}: + hardness = "hard" + + utype = _clip(str(u.get("type", "other")), 24) + + sig = u.get("signature", {}) + if not isinstance(sig, dict): + sig = {} + entities = _as_list_str(sig.get("entities", []), max_items=10, max_len=60) + numbers = _as_list_str(sig.get("numbers", []), max_items=10, max_len=40) + times = _as_list_str(sig.get("times", []), max_items=10, max_len=40) + + ev = u.get("evidence", {}) + if not isinstance(ev, dict): + ev = {} + anchors_raw = ev.get("anchors", []) + anchors: List[Dict[str, Any]] = [] + if isinstance(anchors_raw, list): + for a in anchors_raw[:2]: + if not isinstance(a, dict): + continue + step = _as_int(a.get("step", -1), default=-1) + quote = _clip(str(a.get("quote", "")), 120) + if step >= 0 and quote: + anchors.append({"step": step, "quote": quote}) + anchor_note = _clip(str(ev.get("anchor_note", "")), 60) + + ver = u.get("verification", {}) + if not isinstance(ver, dict): + ver = {} + + verdict = _clip(str(ver.get("verdict", "unclear")), 20) + if verdict not in {"supported", "contradicted", "no_evidence", "speculative_ok", "unclear"}: + verdict = "unclear" + + issue = _clip(str(ver.get("issue", "none")), 20) + allowed_issues = { + "none", "entity_mismatch", "time_mismatch", "value_mismatch", "scope_mismatch", + "logic_leap", "over_precision", "missing_anchor" + } + if issue not in allowed_issues: + issue = "none" + + note = _clip(str(ver.get("note", "")), 80) + + units.append({ + "claim": claim, + "hardness": hardness, + "type": utype, + "signature": 
{"entities": entities, "numbers": numbers, "times": times}, + "evidence": {"anchors": anchors, "anchor_note": anchor_note}, + "verification": {"verdict": verdict, "issue": issue, "note": note}, + }) + + # Recompute counts (anti-gaming) + verdict_counts = {k: 0 for k in ["supported", "contradicted", "no_evidence", "speculative_ok", "unclear"]} + hard_units = 0 + anchored_hard_units = 0 + misattrib = 0 + for u in units: + v = u["verification"]["verdict"] + verdict_counts[v] += 1 + if u["hardness"] == "hard": + hard_units += 1 + if u["evidence"]["anchors"]: + anchored_hard_units += 1 + if v != "supported": + misattrib += 1 + + stats_raw = obj.get("stats", {}) + if not isinstance(stats_raw, dict): + stats_raw = {} + report_digit_date_tokens = max(0, _as_int(stats_raw.get("report_digit_date_tokens", 0), default=0)) + covered_digit_date_tokens = max(0, _as_int(stats_raw.get("covered_digit_date_tokens", 0), default=0)) + + stats = { + "total_units": len(units), + "hard_units": hard_units, + "supported": verdict_counts["supported"], + "contradicted": verdict_counts["contradicted"], + "no_evidence": verdict_counts["no_evidence"], + "speculative_ok": verdict_counts["speculative_ok"], + "unclear": verdict_counts["unclear"], + "report_digit_date_tokens": report_digit_date_tokens, + "covered_digit_date_tokens": covered_digit_date_tokens, + "anchored_hard_units": anchored_hard_units, + "misattrib": misattrib, + } + + examples_raw = obj.get("examples", {}) + if not isinstance(examples_raw, dict): + examples_raw = {} + + def _norm_list(x: Any, max_items: int = 2) -> List[Dict[str, Any]]: + if not isinstance(x, list): + return [] + out: List[Dict[str, Any]] = [] + for it in x[:max_items]: + if isinstance(it, dict): + out.append({k: _clip(str(v), 160) for k, v in list(it.items())[:4]}) + elif isinstance(it, str): + out.append({"text": _clip(it, 160)}) + return out + + examples = { + "best_supported": _norm_list(examples_raw.get("best_supported", []), 2), + "worst_failed": 
_norm_list(examples_raw.get("worst_failed", []), 2), + } + + return {"units": units, "stats": stats, "examples": examples} + + +def coerce_to_messages_list(traj: Any) -> List[Dict[str, Any]]: + """Accept list[dict], list[list[dict]], or dict wrapper.""" + if traj is None: + return [] + if isinstance(traj, dict): + for key in ("traj", "messages", "conversation", "steps"): + if key in traj: + return coerce_to_messages_list(traj[key]) + return [] + if isinstance(traj, list): + if not traj: + return [] + if isinstance(traj[0], list): + for inner in traj: + if isinstance(inner, list) and inner and isinstance(inner[0], dict): + return inner + return [] + if isinstance(traj[0], dict): + return traj + return [] + + +def _extract_text_content(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif isinstance(item, str): + parts.append(item) + return "\n".join([p for p in parts if p]) + return str(content) + + +def strip_references(markdown: str) -> str: + if not isinstance(markdown, str): + return "" + m = re.search(r"\n#+\s*References\b", markdown, flags=re.IGNORECASE) + if m: + return markdown[: m.start()].strip() + return markdown.strip() + + +def count_digit_tokens(text: str) -> int: + if not text: + return 0 + pats = [ + r"\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", + r"\b\d+(?:\.\d+)?%?\b", + ] + tokens: List[str] = [] + for p in pats: + tokens.extend(re.findall(p, text)) + return len(tokens) + + +def _strip_think(text: str) -> str: + """去除 ... 
标签"""
+    if not text:
+        return ""
+    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.S).strip()
+
+
+def _looks_like_tool_result(text: str) -> bool:
+    """判断是否为工具返回结果"""
+    t = (text or "").strip()
+    if t.startswith("Tool:") or t.startswith("Result:"):
+        return True
+    if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t):
+        return True
+    if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "):
+        return True
+    return False
+
+
+def _is_probably_final_report(text: str) -> bool:
+    """判断是否为最终报告"""
+    if not text:
+        return False
+    t = text.strip()
+    # 放宽条件:任一条件满足即可
+    if "## References" in t or "[TASK_COMPLETED]" in t:
+        return True
+    if t.lstrip().startswith("# "):
+        return True
+    # 兼容原有逻辑
+    has_markdown = ("#" in t) or ("|---" in t) or ("## " in t)
+    has_refs = re.search(r"#+\s*References\b", t, flags=re.IGNORECASE) is not None
+    return has_markdown and has_refs
+
+
+def _extract_tool_calls_and_results(trajectory: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    items: List[Dict[str, Any]] = []
+    for i, msg in enumerate(trajectory):
+        role = msg.get("role", "")
+        content = _extract_text_content(msg.get("content", ""))
+        if role == "assistant":
+            if "```json" in content and ("tool_name" in content or "tool_args" in content):
+                items.append({"step": i, "kind": "tool_call", "text": content})
+        elif role == "tool":
+            items.append({"step": i, "kind": "tool_result", "text": content})
+    return items
+
+
+def construct_reward_prompt(trajectory: List[Dict[str, Any]], user_prompt_template: str) -> str:
+    trajectory = coerce_to_messages_list(trajectory)
+
+    # 提取 user_query(第一个非工具结果的 user 消息)
+    user_query = ""
+    for msg in trajectory:
+        if msg.get("role") == "user":
+            raw = _extract_text_content(msg.get("content", ""))
+            if not _looks_like_tool_result(raw):
+                user_query = _strip_think(raw)
+                break
+
+    # 提取 final_report(从后往前找第一个符合条件的 assistant 消息)
+    final_report = ""
+    for msg in reversed(trajectory):
+        if msg.get("role") == "assistant":
+            raw = 
_extract_text_content(msg.get("content", "")) + t = _strip_think(raw) + if _is_probably_final_report(t): + final_report = t + break + if not final_report: + for msg in reversed(trajectory): + if msg.get("role") == "assistant": + raw = _extract_text_content(msg.get("content", "")) + final_report = _strip_think(raw) + break + + evidence_items = _extract_tool_calls_and_results(trajectory) + evidence_lines: List[str] = [] + for it in evidence_items: + step = it["step"] + prefix = "CALL" if it["kind"] == "tool_call" else "RESULT" + evidence_lines.append(f"[{prefix} step={step}]\n{it['text']}".strip()) + evidence_text = "\n\n".join(evidence_lines).strip() + + return user_prompt_template.format( + user_query=user_query, + evidence_text=evidence_text, + final_report=final_report, + ) + + +def construct_ebtu_prompt( + trajectory: List[Dict[str, Any]], + user_prompt_template: str, +) -> Tuple[str, str]: + """ + Returns: + - user_prompt (for judge) + - report_plain (final report without References) for deterministic coverage checks + """ + user_prompt = construct_reward_prompt(trajectory, user_prompt_template) + + report_plain = "" + for marker in ("\n## Report\n", "\n## AI Report\n"): + if marker in user_prompt: + report_plain = user_prompt.split(marker, 1)[1] + break + if not report_plain: + report_plain = user_prompt + + report_plain = strip_references(report_plain) + return user_prompt, report_plain diff --git a/tutorial/example_deep_finance/judge/ebtu/prompt.py b/tutorial/example_deep_finance/judge/ebtu/prompt.py new file mode 100644 index 00000000..af1fd4ba --- /dev/null +++ b/tutorial/example_deep_finance/judge/ebtu/prompt.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +""" +EBTU Reward: Evidence-Backed Trace Units (Evidence-first Traceability) + +设计目标: +- 用“证据优先(先证据锚点、后裁决)”的审计输出,支撑可计算的 faithful / FACT-like reward; +- 不绑定金融六元组:以通用的 Trace Unit(原子断言)为核心; +- 结构化输出便于后续确定性打分,避免“先给分再圆”。 + +本文件仅包含 Prompt(System + User Template)。 +打分逻辑在 grader.py 中实现。 +""" + +EBTU_SYSTEM_PROMPT 
= """
+# 你的身份
+你是一名【证据优先审计官(Evidence-first Auditor)】。
+
+# 输入
+你将收到三部分:
+1) User Question:用户问题
+2) Evidence:证据区(工具调用与工具返回的原文集合,按 step 编号)
+3) Report:需要审计的最终报告
+
+# 你的目标
+对 Report 做“可追溯性/可核验性审计”:判断 Report 中的【原子断言】是否能在 Evidence 中找到明确证据锚点。
+
+# 核心原则(硬约束)
+1) Evidence 是唯一事实来源:不得使用外部常识/训练记忆补全缺失证据。
+2) 证据优先:必须先给出 evidence.anchors(step+quote),再给 verification(verdict/issue/note)。
+   - 严禁先输出分数或先下结论再找证据。
+3) 仅审计 Report 正文:忽略 “## References” 及其之后内容。
+4) 覆盖要求:必须覆盖 Report 正文里出现的每一个【数字/日期 token】(近似即可)。
+   - 数字/日期 token 示例:13.7%、2025-09-30、1330亿美元 各算 1 个。
+5) 锚点要求:
+   - 对于 hardness=hard 的断言(尤其含数字/日期),必须提供 1–2 个 anchors,除非 verdict=no_evidence。
+   - quote 必须来自 Evidence 原文,可截断;长度 ≤120 字。
+6) 输出必须是严格 JSON(不含 Markdown,不含额外文本);不得新增顶层字段。
+7) 不要输出 score。只输出 units + stats + examples(用于外部确定性计算 reward)。
+
+# 断言类型与硬度
+- type 可选:numeric|temporal|event|definition|comparison|causal|recommendation|other
+- hardness:
+  - hard:确定性事实断言(尤其含数字/日期/明确比较/明确事实)
+  - soft:明确标注推测/假设/情景分析(可能/预计/推测/假设/大概率等)且不伪装成事实
+
+# verdict(只能从以下5类选)
+- supported:anchors 足以直接支持断言(关键要素匹配)
+- contradicted:anchors 明确与断言冲突(主体/时间/数值/方向相反)
+- no_evidence:Evidence 中找不到支撑锚点,且断言是确定性表述(hard)
+- speculative_ok:断言明确为推测/假设/情景分析(soft)且未伪装成事实
+- unclear:Evidence 有相关但不足以支持/反驳(口径/范围/条件缺失等)
+
+# issue(只能从以下枚举选)
+none | entity_mismatch | time_mismatch | value_mismatch | scope_mismatch | logic_leap | over_precision | missing_anchor
+
+# JSON 输出模板(字段顺序必须严格一致:先证据后裁决)
+{
+  "units": [
+    {
+      "claim": "<报告中的原子断言>",
+      "hardness": "<hard|soft>",
+      "type": "<numeric|temporal|event|definition|comparison|causal|recommendation|other>",
+      "signature": {
+        "entities": ["<涉及的实体>"],
+        "numbers": ["<涉及的数字>"],
+        "times": ["<涉及的时间>"]
+      },
+      "evidence": {
+        "anchors": [
+          { "step": <int>, "quote": "<来自Evidence的原文片段,≤120字>" }
+        ],
+        "anchor_note": "<≤60字,说明为何这些anchors相关>"
+      },
+      "verification": {
+        "verdict": "<supported|contradicted|no_evidence|speculative_ok|unclear>",
+        "issue": "<none|entity_mismatch|time_mismatch|value_mismatch|scope_mismatch|logic_leap|over_precision|missing_anchor>",
+        "note": "<≤80字,指出支持点/冲突点/缺失点>"
+      }
+    }
+  ],
+  "stats": {
+    "total_units": <int>,
+    "hard_units": <int>,
+    "supported": <int>,
+    "contradicted": <int>,
+    "no_evidence": <int>,
+    "speculative_ok": <int>,
+    "unclear": <int>,
+    "report_digit_date_tokens": <int>,
"covered_digit_date_tokens": <被 units 覆盖的token数>,
+    "anchored_hard_units": <int>,
+    "misattrib": <有锚点但verdict不是supported的条数>
+  },
+  "examples": {
+    "best_supported": [{ "claim": "...", "anchor": { "step": 0, "quote": "..." }, "why": "<≤60字>" }],
+    "worst_failed": [{ "claim": "...", "verdict": "...", "why": "<≤60字>" }]
+  }
+}
+
+# 示例(展示完整输出格式)
+{
+  "units": [
+    {
+      "claim": "2024年Q3营收同比增长15.2%",
+      "hardness": "hard",
+      "type": "numeric",
+      "signature": { "entities": ["营收"], "numbers": ["15.2%"], "times": ["2024年Q3"] },
+      "evidence": {
+        "anchors": [{ "step": 5, "quote": "Q3营收同比+15.2%,达到88.5亿元" }],
+        "anchor_note": "来自财报工具返回的原始数据"
+      },
+      "verification": { "verdict": "supported", "issue": "none", "note": "数值完全匹配,时间范围一致" }
+    },
+    {
+      "claim": "预计2025年净利润将达到50亿元",
+      "hardness": "soft",
+      "type": "numeric",
+      "signature": { "entities": ["净利润"], "numbers": ["50亿元"], "times": ["2025年"] },
+      "evidence": {
+        "anchors": [],
+        "anchor_note": "分析师预测,非硬性事实"
+      },
+      "verification": { "verdict": "speculative_ok", "issue": "none", "note": "明确标注为预测,未伪装成事实" }
+    }
+  ],
+  "stats": {
+    "total_units": 2, "hard_units": 1,
+    "supported": 1, "contradicted": 0, "no_evidence": 0, "speculative_ok": 1, "unclear": 0,
+    "report_digit_date_tokens": 8, "covered_digit_date_tokens": 6,
+    "anchored_hard_units": 1, "misattrib": 0
+  },
+  "examples": {
+    "best_supported": [{ "claim": "2024年Q3营收同比增长15.2%", "anchor": { "step": 5, "quote": "Q3营收同比+15.2%" }, "why": "数值精确匹配" }],
+    "worst_failed": []
+  }
+}
+
+# 统计口径(必须一致)
+- total_units = units 的条数
+- hard_units = hardness=hard 的条数
+- supported/contradicted/no_evidence/speculative_ok/unclear 必须与 units[*].verification.verdict 统计一致
+- report_digit_date_tokens:你在 Report 正文中识别到的数字/日期 token 数(近似)
+- covered_digit_date_tokens:这些 token 中,有多少被包含在 units[*].signature.numbers 或 units[*].signature.times 中(近似)
+- anchored_hard_units:hard_units 中 anchors 非空的条数
+- misattrib:hard_units 中 anchors 非空,但 verdict 不是 supported 的条数(“有锚点但不支持/矛盾/不清楚”)
+"""
+
+EBTU_USER_PROMPT_TEMPLATE = 
""" +## User Question +{user_query} + +## Evidence +{evidence_text} + +## Report +{final_report} +""" diff --git a/tutorial/example_deep_finance/judge/grounding/grader.py b/tutorial/example_deep_finance/judge/grounding/grader.py index 599ccc9c..86fc87a3 100644 --- a/tutorial/example_deep_finance/judge/grounding/grader.py +++ b/tutorial/example_deep_finance/judge/grounding/grader.py @@ -80,7 +80,7 @@ def create_default_model( return OpenAIChatModel(**kwargs) - async def aevaluate( + async def _aevaluate( self, traj: Any, **_: Any, @@ -195,11 +195,12 @@ def _compute_scores(self, obj: Dict[str, Any]) -> Tuple[float, str]: # 轻量惩罚:存在 invalid refs 会降低 reward # 每个 invalid 号扣 0.1,最多扣 0.5 - invalid_penalty = min(0.1 * invalid_ref_count, 0.5) + # invalid_penalty = min(0.1 * invalid_ref_count, 0.5) + invalid_penalty = 0 # final_reward: 综合分数(权重 0.5:0.5),再叠加 invalid 惩罚 final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score - final_reward = max(0.0, final_reward - invalid_penalty) + # final_reward = max(0.0, final_reward - invalid_penalty) # 构建 reason good_citations = obj.get('good_citations', []) diff --git a/tutorial/example_deep_finance/judge/grounding/prompt.py b/tutorial/example_deep_finance/judge/grounding/prompt.py index 24bea134..337cf4bc 100644 --- a/tutorial/example_deep_finance/judge/grounding/prompt.py +++ b/tutorial/example_deep_finance/judge/grounding/prompt.py @@ -1,103 +1,152 @@ """Grounding Grader Prompt - 引用规范性评估""" -GROUNDING_SYSTEM_PROMPT = """你是一位"引用审计员",负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 - -======================== -一、引用规范(以此为准) -======================== -1) 关键事实句必须引用: - - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。 - - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。 - -2) 引用位置规则(严格执行): - - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。 - - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。 - -3) References 必须存在且可追溯: - - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。 - - 正文出现的每个 [n] 
必须能在 References 中找到对应条目。 - -4) References 条目两种合法形式(必须满足其一): - A) URL 形式:`[n] 标题或简述 - https://...` - - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。 - B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)` - - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。 - - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。 - -======================== -二、输入 -======================== +# GROUNDING_SYSTEM_PROMPT = """你是一位"引用审计员",负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 + +# ======================== +# 一、引用规范(以此为准) +# ======================== +# 1) 关键事实句必须引用: +# - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。 +# - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。 + +# 2) 引用位置规则(严格执行): +# - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。 +# - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。 + +# 3) References 必须存在且可追溯: +# - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。 +# - 正文出现的每个 [n] 必须能在 References 中找到对应条目。 + +# 4) References 条目两种合法形式(必须满足其一): +# A) URL 形式:`[n] 标题或简述 - https://...` +# - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。 +# B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)` +# - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。 +# - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。 + +# ======================== +# 二、输入 +# ======================== +# 你会收到: +# - User Query +# - Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息) +# - AI Report(待审计报告,含正文与 References) + +# 真实性核对原则: +# - 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。 +# - 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。 + +# ======================== +# 三、统计与判定口径(严格遵守) +# ======================== +# 【文本范围】 +# - 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。 +# - References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。 + +# 【句子/条目如何计数】 +# - “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。 +# - 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。 +# - 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。 + +# 
【关键事实句识别(务求稳定)】 +# - 满足任一条件可视为关键事实句: +# (a) 含具体数值/比例/排名/区间/估值倍数/财务指标; +# (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”); +# (c) 对具体公司/行业/政策做了可验证的确定性陈述; +# (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。 + +# 【引用是否“句末”】【重要】 +# - 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如: +# - “……增长 20%[3]” +# - “……增长 20% [3][4]” +# - 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。 + +# 【invalid_reference_nums 的定义】 +# - 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid: +# (a) References 中不存在该编号条目; +# (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等); +# (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。 +# - invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。 + +# 【missing_count 的定义】 +# - 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。 + +# 【cited_key_facts 的定义】 +# - 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。 + +# 【fake_count 的定义(只在明显时计数)】 +# - 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。 +# - 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。 + +# 【good_citations 的定义】 +# - 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足: +# - 是关键事实句; +# - 句末有 [n]; +# - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。 +# - good_citations 是原文截取,不要加解释;最多 2 条,超出截断。 + +# ======================== +# 四、输出(只输出 JSON,字段固定) +# ======================== +# { +# "total_key_facts": , +# "cited_key_facts": , +# "good_citations": ["...", "..."], +# "missing_count": , +# "fake_count": , +# "invalid_reference_nums": [, ...] 
+# } + +# 只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。 +# """ + +GROUNDING_SYSTEM_PROMPT = """ +你是一位“引用审计员”,负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 + +## 引用规范(以此为准) +- 关键事实句必须引用:关键事实句包括数字/同比环比/日期/财务指标/估值倍数/明确事实结论/具体事件/具体公司或行业陈述/政策条款。 +- 关键事实句句末必须出现引用编号:[1] 或 [1][2]。 +- 报告末尾必须包含 `## References`。 +- 正文出现的每个 [n] 必须能在 References 中找到对应条目。 +- References 条目两种合法形式: + A) URL 形式:`[n] 标题或简述 - https://...` + B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)` +- `javascript:void(0)` 等无效链接不算 URL,应按 no-url 形式记录来源信息。 +- 禁止伪造来源;没有证据支撑的只能写“推测/假设”,不能用引用把推测包装成事实。 + +## 输入 你会收到: - User Query - Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息) - AI Report(待审计报告,含正文与 References) -真实性核对原则: -- 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。 -- 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。 - -======================== -三、统计与判定口径(严格遵守) -======================== -【文本范围】 -- 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。 -- References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。 - -【句子/条目如何计数】 -- “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。 -- 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。 -- 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。 - -【关键事实句识别(务求稳定)】 -- 满足任一条件可视为关键事实句: - (a) 含具体数值/比例/排名/区间/估值倍数/财务指标; - (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”); - (c) 对具体公司/行业/政策做了可验证的确定性陈述; - (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。 - -【引用是否“句末”】【重要】 -- 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如: - - “……增长 20%[3]” - - “……增长 20% [3][4]” -- 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。 - -【invalid_reference_nums 的定义】 -- 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid: - (a) References 中不存在该编号条目; - (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等); - (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。 -- invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。 - -【missing_count 的定义】 -- 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。 - -【cited_key_facts 的定义】 -- 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。 - -【fake_count 的定义(只在明显时计数)】 -- 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。 -- 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。 - 
-【good_citations 的定义】 -- 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足: - - 是关键事实句; - - 句末有 [n]; - - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。 -- good_citations 是原文截取,不要加解释;最多 2 条,超出截断。 - -======================== -四、输出(只输出 JSON,字段固定) -======================== +核对真实性时,以 Evidence 为准:只有在“明显矛盾/明显找不到依据”时才判 fake;无法确认则不要判 fake。 + +## 输出(只输出 JSON,字段固定) { "total_key_facts": , "cited_key_facts": , - "good_citations": ["...", "..."], + "good_citations": ["从报告原文截取的:关键事实句 + 句末 [n],且 References 可追溯(最多 5 条)", ...] "missing_count": , "fake_count": , - "invalid_reference_nums": [, ...] + "invalid_reference_nums": [, ...], } -只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。 +统计口径(为保证稳定,严格遵守): +- total_key_facts:正文中关键事实句的总数(按句子/条目计;一句多个数字也算 1 条即可,不要过度拆分)。 +- cited_key_facts:关键事实句中,句末包含至少一个 [n] 的数量(不要求该引用一定有效)。 +- invalid_reference_nums:正文出现过、但满足任一条件的编号: + (a) References 中不存在该编号条目; + (b) URL 形式但 URL 无效(空或 javascript:void(0) 等); + (c) no-url 形式但缺少“工具名/参数/日期(报告期)”之一。 +- missing_count:关键事实句中“句末没有 [n]”的数量。 +- fake_count:关键事实句“带引用但与 Evidence 明显矛盾/明显无支撑”的数量(仅明显时计数)。 +- good_citations:从报告原文中选取最多 5 条“引用做得正确”的关键事实句(句末有 [n],且 [n] 在 References 中合法)。 + +长度约束(必须): +- invalid_reference_nums 最多 5 个,多余截断。 +- good_citations 最多 2 条,多余截断。 +只输出 JSON,不要输出解释文字或 Markdown。 """ # ============================================================================= diff --git a/tutorial/example_deep_finance/judge/presentation_quality/grader.py b/tutorial/example_deep_finance/judge/presentation_quality/grader.py index c440c3e4..80de5740 100644 --- a/tutorial/example_deep_finance/judge/presentation_quality/grader.py +++ b/tutorial/example_deep_finance/judge/presentation_quality/grader.py @@ -83,7 +83,7 @@ def create_default_model( return OpenAIChatModel(**kwargs) - async def aevaluate( + async def _aevaluate( self, report_content: str, user_query: str | None = None, diff --git a/tutorial/example_deep_finance/judge/traceability/__init__.py b/tutorial/example_deep_finance/judge/traceability/__init__.py new 
file mode 100644 index 00000000..18845402 --- /dev/null +++ b/tutorial/example_deep_finance/judge/traceability/__init__.py @@ -0,0 +1,7 @@ +""" +Traceability & Verifiability Reward (TVR) Grader +可追溯性/可核验性审计框架 +""" +from .grader import TraceabilityRewardGrader + +__all__ = ["TraceabilityRewardGrader"] diff --git a/tutorial/example_deep_finance/judge/traceability/grader.py b/tutorial/example_deep_finance/judge/traceability/grader.py new file mode 100644 index 00000000..42beee5b --- /dev/null +++ b/tutorial/example_deep_finance/judge/traceability/grader.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import Any, Dict, List + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import TRACEABILITY_SYSTEM_PROMPT, TRACEABILITY_USER_PROMPT_TEMPLATE +from .json_utils import strict_load_json, validate_shape, coerce_to_messages_list, construct_traceability_prompt, count_digit_tokens + + +class TraceabilityRewardGrader(BaseGrader): + """ + Traceability & Verifiability Reward (TVR) + + Input: traj (trajectory / record) - supports: + - list[dict] + - list[list[dict]] + - dict with {"traj": ...} etc.
+ + Output: GraderScore(name="traceability", score in [0,1], reason includes stats + brief examples) + """ + + def __init__( + self, + model: OpenAIChatModel, + name: str = "traceability", + **kwargs: Any, + ) -> None: + super().__init__(name=name, **kwargs) + self.model = model + + async def _aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: + messages = coerce_to_messages_list(traj) + + if not messages: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty trajectory", + ) + + user_prompt, report_plain = construct_traceability_prompt( + messages, + TRACEABILITY_USER_PROMPT_TEMPLATE, + ) + + judge_messages = [ + {"role": "system", "content": TRACEABILITY_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + + # 调用模型(带异常捕获) + try: + resp = await self.model.achat(judge_messages) + raw_text = getattr(resp, "content", None) + if raw_text is None: + raw_text = str(resp) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + + # 解析 JSON 并计算分数 + try: + obj = strict_load_json(str(raw_text)) + norm = validate_shape(obj) + score = self._compute_score(norm, report_plain) + reason = self._build_reason(norm, report_plain, score) + return GraderScore(name=self.name, score=score, reason=reason) + except Exception as e: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"TVR ParseError: {e}; raw[:200]={snippet}", + ) + + def _compute_score(self, norm: Dict[str, Any], report_plain: str) -> float: + stats = norm["stats"] + total = max(1, int(stats.get("total_claims", 0))) + + supported = int(stats.get("supported", 0)) + contradicted = int(stats.get("contradicted", 0)) + no_evidence = int(stats.get("no_evidence", 0)) + speculative_ok = int(stats.get("speculative_ok", 0)) + unclear = int(stats.get("unclear", 0)) + + # Positive contribution + pos = supported + 0.6 * speculative_ok + 0.3 * unclear + 
# Negative contribution (contradiction is harsh) + neg = 1.0 * contradicted + 0.8 * no_evidence + + base = (pos - neg) / total # can be negative + base = max(0.0, min(1.0, base)) + + # Coverage factor (deterministic) based on digits/dates in report body + real_digit_tokens = count_digit_tokens(report_plain) + expected_min_claims = min(25, max(6, real_digit_tokens // 2)) + claim_count = int(stats.get("total_claims", total)) + + selection_factor = min(1.0, claim_count / expected_min_claims) if expected_min_claims > 0 else 1.0 + + # If the judge reports digit coverage, blend it in (but keep deterministic as the main) + reported_total_digits = int(stats.get("report_digit_tokens", 0)) + reported_covered_digits = int(stats.get("covered_digit_tokens", 0)) + if reported_total_digits > 0: + reported_cov = min(1.0, max(0.0, reported_covered_digits / reported_total_digits)) + else: + reported_cov = 1.0 + + cov_factor = 0.7 + 0.3 * reported_cov # [0.7, 1.0] + + score = base * selection_factor * cov_factor + score = max(0.0, min(1.0, score)) + return float(score) + + def _build_reason(self, norm: Dict[str, Any], report_plain: str, score: float) -> str: + stats = norm["stats"] + ex = norm.get("examples", {}) + best = ex.get("best_supported", []) + worst = ex.get("worst_failed", []) + + real_digit_tokens = count_digit_tokens(report_plain) + + parts = [] + parts.append( + f"score={score:.3f}; " + f"claims={stats['total_claims']}; " + f"supported={stats['supported']}; " + f"spec_ok={stats['speculative_ok']}; " + f"unclear={stats['unclear']}; " + f"no_ev={stats['no_evidence']}; " + f"contradicted={stats['contradicted']}; " + f"report_digits≈{real_digit_tokens}" + ) + + if best: + parts.append(f"best_supported={best[:1]}") + if worst: + parts.append(f"worst_failed={worst[:1]}") + return " | ".join(parts) diff --git a/tutorial/example_deep_finance/judge/traceability/json_utils.py b/tutorial/example_deep_finance/judge/traceability/json_utils.py new file mode 100644 index 
00000000..7a005b62 --- /dev/null +++ b/tutorial/example_deep_finance/judge/traceability/json_utils.py @@ -0,0 +1,462 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Tuple + + +# ============================================================================= +# JSON Repair Helper +# ============================================================================= + +def _repair_json(js: str) -> str: + """ + 尝试修复常见的JSON格式错误 + 1. 修复字符串中未转义的换行符 + 2. 修复trailing comma + 3. 修复不完整的JSON(截断) + """ + # 1. 替换字符串值中的未转义换行符 + def escape_newlines_in_strings(s: str) -> str: + result = [] + in_string = False + escape_next = False + i = 0 + while i < len(s): + c = s[i] + if escape_next: + result.append(c) + escape_next = False + elif c == '\\': + result.append(c) + escape_next = True + elif c == '"': + result.append(c) + in_string = not in_string + elif in_string and c == '\n': + result.append('\\n') + elif in_string and c == '\r': + result.append('\\r') + elif in_string and c == '\t': + result.append('\\t') + else: + result.append(c) + i += 1 + return ''.join(result) + + js = escape_newlines_in_strings(js) + + # 2. 移除trailing comma: ",}" -> "}" 和 ",]" -> "]" + js = re.sub(r',\s*}', '}', js) + js = re.sub(r',\s*]', ']', js) + + # 3. 
尝试修复截断的JSON - 补全缺失的括号 + open_braces = js.count('{') + close_braces = js.count('}') + open_brackets = js.count('[') + close_brackets = js.count(']') + + if open_braces > close_braces: + # 先关闭可能未闭合的字符串 + in_string = False + escape_next = False + for c in js: + if escape_next: + escape_next = False + elif c == '\\': + escape_next = True + elif c == '"': + in_string = not in_string + if in_string: + js += '"' + + # 补全缺失的括号 + js += ']' * (open_brackets - close_brackets) + js += '}' * (open_braces - close_braces) + + return js + + +# -------------------------- +# JSON parsing helpers +# -------------------------- + +def strict_load_json(text: str) -> Dict[str, Any]: + """ + Parse a JSON object from model output. 带容错修复。 + + - Accept plain JSON. + - If extra text exists, extract the first {...} block. + - If parsing fails, attempt to repair common JSON errors. + """ + text = (text or "").strip() + + # 第一次尝试:直接解析 + try: + obj = json.loads(text) + if isinstance(obj, dict): + return obj + except Exception: + pass + + start = text.find("{") + end = text.rfind("}") + if start != -1 and end != -1 and end > start: + snippet = text[start : end + 1] + # 第二次尝试:直接解析提取的片段 + try: + obj = json.loads(snippet) + if isinstance(obj, dict): + return obj + except Exception: + pass + + # 第三次尝试:修复后解析 + try: + repaired = _repair_json(snippet) + obj = json.loads(repaired) + if isinstance(obj, dict): + return obj + except Exception: + pass + + raise ValueError("Invalid JSON output") + + +def _clip(s: str, n: int) -> str: + s = s or "" + s = s.replace("\u0000", "") + return s[:n] + + +def _as_int(x: Any, default: int = 0) -> int: + try: + if x is None: + return default + if isinstance(x, bool): + return int(x) + if isinstance(x, (int, float)): + return int(x) + if isinstance(x, str) and x.strip(): + return int(float(x.strip())) + except Exception: + return default + return default + + +def _as_list_str(x: Any, max_items: int = 10, max_len: int = 60) -> List[str]: + if not isinstance(x, list): + 
return [] + out: List[str] = [] + for item in x[:max_items]: + if isinstance(item, str): + out.append(_clip(item, max_len)) + else: + out.append(_clip(str(item), max_len)) + return out + + +def validate_shape(obj: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate and normalize model output for TVR. + + Returns: + { + "claims": [...], + "stats": {...}, + "examples": {...} + } + """ + if not isinstance(obj, dict): + raise ValueError("Output is not a JSON object") + + claims_raw = obj.get("claims", []) + if not isinstance(claims_raw, list): + claims_raw = [] + + claims: List[Dict[str, Any]] = [] + for c in claims_raw[:25]: + if not isinstance(c, dict): + continue + + claim = _clip(str(c.get("claim", "")), 240) + ctype = _clip(str(c.get("type", "other")), 24) + + sig = c.get("signature", {}) + if not isinstance(sig, dict): + sig = {} + entities = _as_list_str(sig.get("entities", []), max_items=10, max_len=50) + numbers = _as_list_str(sig.get("numbers", []), max_items=10, max_len=40) + times = _as_list_str(sig.get("times", []), max_items=10, max_len=40) + + anchors_raw = c.get("anchors", []) + anchors: List[Dict[str, Any]] = [] + if isinstance(anchors_raw, list): + for a in anchors_raw[:2]: + if not isinstance(a, dict): + continue + step = _as_int(a.get("step", -1), default=-1) + quote = _clip(str(a.get("quote", "")), 120) + if step >= 0 and quote: + anchors.append({"step": step, "quote": quote}) + + verdict = _clip(str(c.get("verdict", "unclear")), 20) + if verdict not in {"supported", "contradicted", "no_evidence", "speculative_ok", "unclear"}: + verdict = "unclear" + + issue = _clip(str(c.get("issue", "none")), 20) + allowed_issues = { + "none", "entity_mismatch", "time_mismatch", "value_mismatch", "scope_mismatch", + "logic_leap", "over_precision", "missing_anchor" + } + if issue not in allowed_issues: + issue = "none" + + note = _clip(str(c.get("note", "")), 80) + + claims.append({ + "claim": claim, + "type": ctype, + "signature": {"entities": entities, 
"numbers": numbers, "times": times}, + "anchors": anchors, + "verdict": verdict, + "issue": issue, + "note": note, + }) + + # stats + stats_raw = obj.get("stats", {}) + if not isinstance(stats_raw, dict): + stats_raw = {} + + # always re-count to avoid mismatch / gaming + verdict_counts = { + "supported": 0, + "contradicted": 0, + "no_evidence": 0, + "speculative_ok": 0, + "unclear": 0, + } + for c in claims: + verdict_counts[c["verdict"]] += 1 + + report_digit_tokens = max(0, _as_int(stats_raw.get("report_digit_tokens", 0), default=0)) + covered_digit_tokens = max(0, _as_int(stats_raw.get("covered_digit_tokens", 0), default=0)) + + stats = { + "total_claims": len(claims), + "supported": verdict_counts["supported"], + "contradicted": verdict_counts["contradicted"], + "no_evidence": verdict_counts["no_evidence"], + "speculative_ok": verdict_counts["speculative_ok"], + "unclear": verdict_counts["unclear"], + "report_digit_tokens": report_digit_tokens, + "covered_digit_tokens": covered_digit_tokens, + } + + # examples (small) + examples_raw = obj.get("examples", {}) + if not isinstance(examples_raw, dict): + examples_raw = {} + + def _normalize_example_list(x: Any, max_items: int = 2) -> List[Dict[str, Any]]: + if not isinstance(x, list): + return [] + out: List[Dict[str, Any]] = [] + for it in x[:max_items]: + if isinstance(it, dict): + out.append({k: _clip(str(v), 140) for k, v in list(it.items())[:3]}) + elif isinstance(it, str): + out.append({"text": _clip(it, 140)}) + return out + + examples = { + "best_supported": _normalize_example_list(examples_raw.get("best_supported", []), 2), + "worst_failed": _normalize_example_list(examples_raw.get("worst_failed", []), 2), + } + + return {"claims": claims, "stats": stats, "examples": examples} + + +# -------------------------- +# Trajectory helpers +# -------------------------- + +def coerce_to_messages_list(traj: Any) -> List[Dict[str, Any]]: + """ + Accepts: + - list[dict] + - list[list[dict]] (take first non-empty 
inner list) + - dict with keys: traj / messages / conversation / steps (best-effort) + + Returns list[dict] message objects. + """ + if traj is None: + return [] + + if isinstance(traj, dict): + for key in ("traj", "messages", "conversation", "steps"): + if key in traj: + return coerce_to_messages_list(traj[key]) + return [] + + if isinstance(traj, list): + if not traj: + return [] + if isinstance(traj[0], list): + for inner in traj: + if isinstance(inner, list) and inner and isinstance(inner[0], dict): + return inner + return [] + if isinstance(traj[0], dict): + return traj + + return [] + + +def _extract_text_content(content: Any) -> str: + """ + Extract textual content from different possible message formats. + """ + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif isinstance(item, str): + parts.append(item) + return "\n".join([p for p in parts if p]) + return str(content) + + +def strip_references(markdown: str) -> str: + """ + Remove References section and anything after it (common Markdown headings). + """ + if not isinstance(markdown, str): + return "" + m = re.search(r"\n#+\s*References\b", markdown, flags=re.IGNORECASE) + if m: + return markdown[: m.start()].strip() + return markdown.strip() + + +def count_digit_tokens(text: str) -> int: + """ + Rough count for digit/date tokens in text. + """ + if not text: + return 0 + pats = [ + r"\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", # ISO-ish date + r"\b\d+(?:\.\d+)?%?\b", # number / percent + ] + tokens: List[str] = [] + for p in pats: + tokens.extend(re.findall(p, text)) + return len(tokens) + + +def _is_probably_final_report(text: str) -> bool: + """ + Heuristic: final report is usually markdown-ish and contains References / TASK_COMPLETED etc. 
+ """ + if not text: + return False + # allow either TASK_COMPLETED or markdown headings + References + has_markdown = ("#" in text) or ("|---" in text) or ("## " in text) + has_refs = re.search(r"#+\s*References\b", text, flags=re.IGNORECASE) is not None + has_done = "[TASK_COMPLETED]" in text + return has_done or (has_markdown and has_refs) + + +def _extract_tool_calls_and_results(trajectory: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Extract tool call and tool output blocks, loosely following the format in your existing data. + """ + items: List[Dict[str, Any]] = [] + for i, msg in enumerate(trajectory): + role = msg.get("role", "") + content = _extract_text_content(msg.get("content", "")) + + if role == "assistant": + # look for JSON code block that indicates tool calls + if "```json" in content and ("tool_name" in content or "tool_args" in content): + items.append({"step": i, "kind": "tool_call", "text": content}) + elif role == "tool": + items.append({"step": i, "kind": "tool_result", "text": content}) + return items + + +def construct_reward_prompt(trajectory: List[Dict[str, Any]], user_prompt_template: str) -> str: + """ + Build a user prompt with: + - user_query: last user message + - evidence_text: concatenated tool calls/results with step index + - final_report: last assistant message that looks like final report + """ + trajectory = coerce_to_messages_list(trajectory) + + user_query = "" + for msg in reversed(trajectory): + if msg.get("role") == "user": + user_query = _extract_text_content(msg.get("content", "")) + break + + final_report = "" + for msg in reversed(trajectory): + if msg.get("role") == "assistant": + t = _extract_text_content(msg.get("content", "")) + if _is_probably_final_report(t): + final_report = t + break + if not final_report: + # fallback to last assistant msg + for msg in reversed(trajectory): + if msg.get("role") == "assistant": + final_report = _extract_text_content(msg.get("content", "")) + break + + evidence_items 
= _extract_tool_calls_and_results(trajectory) + evidence_lines: List[str] = [] + for it in evidence_items: + step = it["step"] + kind = it["kind"] + prefix = "CALL" if kind == "tool_call" else "RESULT" + evidence_lines.append(f"[{prefix} step={step}]\n{it['text']}".strip()) + evidence_text = "\n\n".join(evidence_lines).strip() + + return user_prompt_template.format( + user_query=user_query, + evidence_text=evidence_text, + final_report=final_report, + ) + + +def construct_traceability_prompt( + trajectory: List[Dict[str, Any]], + user_prompt_template: str, +) -> Tuple[str, str]: + """ + Returns: + - user_prompt (for the judge model) + - report_plain (final report without References) for deterministic coverage checks + """ + user_prompt = construct_reward_prompt(trajectory, user_prompt_template) + + final_report = "" + marker = "\n## AI Report\n" + if marker in user_prompt: + final_report = user_prompt.split(marker, 1)[1] + # Cut at "\n\n### 审计流程" if present. + cut = "\n\n### 审计流程" + if cut in final_report: + final_report = final_report.split(cut, 1)[0] + + report_plain = strip_references(final_report) + return user_prompt, report_plain diff --git a/tutorial/example_deep_finance/judge/traceability/prompt.py b/tutorial/example_deep_finance/judge/traceability/prompt.py new file mode 100644 index 00000000..e8d7d1bd --- /dev/null +++ b/tutorial/example_deep_finance/judge/traceability/prompt.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +""" +Traceability & Verifiability Reward (TVR) + +目标: +- 用“可追溯性/可核验性”替代“引用是否存在”的Reference类reward; +- 避免强金融领域绑定:面向任何深度研究报告,只要有“证据(工具结果/对话上下文)+ 报告文本”即可工作; +- 通过“断言-证据锚点”审计,奖励:事实陈述可在证据中找到锚点、或明确标注为推测;惩罚:无证据的硬断言、与证据矛盾、过度精确的数值。 + +注意:该文件仅包含 prompt,不包含打分逻辑。打分由 grader.py 依据模型输出的结构化审计结果计算。 +""" + +TRACEABILITY_SYSTEM_PROMPT = r""" +# 你的身份 +你是一名“可追溯性/可核验性审计官(Traceability Auditor)”。 + +# 你的目标 +给定: +- 用户问题(User Question) +- 证据区(Evidence):包含对话中工具调用与工具返回的原文片段(视为“可用证据全集”) +- 待审计报告(AI Report):模型写出的最终 Markdown 报告 + +你需要评估:报告中的“可核验断言”是否能在 Evidence 
中找到明确锚点(traceable),或是否被正确标注为“推测/假设”。 + +# 核心原则(非常重要) +1) **Evidence 是唯一事实来源**:不得使用外部常识/训练记忆补全缺失证据。 +2) **先举证再下结论**:输出结构中必须先给出断言与证据锚点/不匹配点,再汇总统计;不要先给分数再找理由。 +3) **惩罚“硬断言无证据”**:越具体(数字、日期、比例、排名、同比环比、绝对结论)的断言越需要证据锚点。 +4) **允许“推测/假设”**:若报告明确使用“可能/预计/推测/假设/大概率”等表述,并且没有把它包装成确定事实,则可以判为 speculative_ok(弱奖励/不惩罚)。 +5) **优先覆盖“数字/日期/实体”断言**:必须覆盖报告正文中出现的每一个数字或日期(含表格);因为这是最容易出现“编造”的区域。 +6) **不要评估写作质量**(结构/文风/可读性等不在本任务范围),只评估“可追溯/可核验”。 + +# 你要产出的 JSON(严格 JSON,不要 markdown,不要多余文本) +输出 JSON 需要包含: +- claims:断言列表,每条必须包含断言原文、锚点要素(实体/数值/时间)、证据锚点(step+quote)、判定与原因 +- stats:统计汇总(先统计,再由外部计算分数) +- examples:最多各2条“最好的支持案例”和“最差的失败案例”(用于调试) + +断言(claim)的判定(verdict)只能是: +- supported:Evidence 中有明确锚点支撑(实体/时间/数值关键点对应) +- contradicted:Evidence 中存在明确冲突(数值/时间/事实相反) +- no_evidence:找不到相关证据锚点,且该断言是硬断言 +- speculative_ok:断言被明确标注为推测/假设,且未伪装成事实 +- unclear:Evidence 有相关但不足以确定支持/反驳(模糊、缺关键字段、只部分匹配) + +issue(主要问题)建议从下面枚举中选择一个: +- none | entity_mismatch | time_mismatch | value_mismatch | scope_mismatch | logic_leap | over_precision | missing_anchor + +额外要求: +- 每条 claim 的 note ≤ 80 字(给出关键理由即可) +- evidence_quote ≤ 120 字,必须是 Evidence 中的原文片段(可截断) +""" + +# NOTE: 该模板会被 json_utils.construct_reward_prompt 填充 {user_query} {evidence_text} {final_report} +TRACEABILITY_USER_PROMPT_TEMPLATE = r""" +请对下面的 AI Report 做“可追溯性/可核验性审计”,并严格按要求输出 JSON。 + +## User Question +{user_query} + +## Evidence +{evidence_text} + +## AI Report +{final_report} + +### 审计流程(必须执行) +1) 仅审计 **AI Report 正文**(忽略其 `## References` 及之后内容)。 +2) 抽取“可核验断言”: + - 必须包含:所有出现“数字/日期”的句子或表格行(逐条拆成原子断言) + - 另外补充:3–8条非数字但可核验的硬事实(涉及具体实体/事件/定义/比较/因果的断言) +3) 对每条断言: + - 提取锚点要素:entities / numbers / times(可以为空列表,但含数字/日期的断言不得为空) + - 在 Evidence 中找到最相关的 1–2 个锚点(用 step 序号 + 原文 quote 表示) + - 给出 verdict + issue + note(简短指出匹配/不匹配的关键点) +4) 最后汇总 stats 与 examples(不要给分数)。 + +### 输出 JSON 结构(严格遵守字段名;不要新增顶层字段) +{{ + "claims": [ + {{ + "claim": "从报告中复制的原句或原子断言(尽量短)", + "type": "quant|event|definition|comparison|causal|recommendation|other", + "signature": {{ + "entities": ["..."], + 
"numbers": ["..."], + "times": ["..."] + }}, + "anchors": [ + {{"step": 12, "quote": "Evidence 原文片段..."}}, + {{"step": 13, "quote": "Evidence 原文片段..."}} + ], + "verdict": "supported|contradicted|no_evidence|speculative_ok|unclear", + "issue": "none|entity_mismatch|time_mismatch|value_mismatch|scope_mismatch|logic_leap|over_precision|missing_anchor", + "note": "≤80字,说明为何这样判定" + }} + ], + "stats": {{ + "total_claims": 0, + "supported": 0, + "contradicted": 0, + "no_evidence": 0, + "speculative_ok": 0, + "unclear": 0, + "report_digit_tokens": 0, + "covered_digit_tokens": 0 + }}, + "examples": {{ + "best_supported": [ + {{"claim": "...", "anchor": {{"step": 0, "quote": "..."}}}} + ], + "worst_failed": [ + {{"claim": "...", "why": "..." }} + ] + }} +}} + +### 统计口径(必须一致) +- report_digit_tokens:你在报告正文中识别到的“数字/日期 token”的数量(近似即可;如 1330亿美元、13.7%、2025-09-30 各算 1 个 token) +- covered_digit_tokens:这些 token 中,有多少出现在你提取的 claims 的 signature.numbers 或 signature.times 里(近似即可) +- total_claims 必须等于 claims 的条数;其余计数必须与 claims 中 verdict 的统计一致 +""" diff --git a/tutorial/example_deep_finance/prompt/tool_prompt_builder.py b/tutorial/example_deep_finance/prompt/tool_prompt_builder.py index 5c940fd7..0345f2c9 100644 --- a/tutorial/example_deep_finance/prompt/tool_prompt_builder.py +++ b/tutorial/example_deep_finance/prompt/tool_prompt_builder.py @@ -60,11 +60,6 @@ def get_tool_prompt_template() -> str: **参数**: - `query` (必填, string): 搜索关键词 -#### ✅ crawl_url -**功能**: 网页内容解析工具,获取并格式化指定URL的网页内容。 -**参数**: - - `url` (必填, string): 目标网页URL - # --- ### 📈 同花顺专项数据工具 (Crawl THS) diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 166381da..38aa82ed 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -12,6 +12,10 @@ ajet: # OpenJudge 权重配置 presentation_quality_weight: 
{{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 + cgcv_weight: {{CGCV_WEIGHT}} # Citation-Grounded Claim Verification + audit_weight: {{AUDIT_WEIGHT}} # 引用逻辑审计 + traceability_weight: {{TRACEABILITY_WEIGHT}} # 可追溯性/可核验性审计 (TVR) + ebtu_weight: {{EBTU_WEIGHT}} # Audit Grader: audit reward EBTU证据优先可追溯性审计 rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml new file mode 100644 index 00000000..0ddd541c --- /dev/null +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml @@ -0,0 +1,91 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: "{{PREFIX}}" + experiment_name: "{{SUFFIX}}" + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 + rm_llm: {{RM_LLM}} # RM Gallery 模型 + concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 + train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 + val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 + # OpenJudge 权重配置 + presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 + grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 + cgcv_weight: {{CGCV_WEIGHT}} # Citation-Grounded Claim Verification + audit_weight: {{AUDIT_WEIGHT}} # 引用逻辑审计 + traceability_weight: {{TRACEABILITY_WEIGHT}} # 可追溯性/可核验性审计 (TVR) + ebtu_weight: {{EBTU_WEIGHT}} # Audit Grader: audit reward EBTU证据优先可追溯性审计 + rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 + task_judge: + # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: {{MODEL_PATH}} + trainer_common: + nnodes: {{NNODES}} + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 8 + save_freq: 10 + 
test_freq: 2 + total_epochs: 200 + save_trajectory_as_json_file: True + rollout: + # ✨✨✨✨ 编写并选择Agent + user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol + force_disable_toolcalls: False + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: {{NUM_REPEAT}} + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_response_length_in_one_turn: 8000 + max_model_len: {{MAX_MODEL_LEN}} + agent_madness_reward: 0.0 + compute_madness_checklist: None + multi_turn: + max_steps: {{NUM_STEPS}} + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + debug: + debug_max_parallel: 1 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + train_batch_size: {{TRAIN_BATCH_SIZE}} + max_prompt_length: 20000 + max_response_length: 45000 + + task_reader: + type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + deep_finance: + training: + file_path: {{TRAIN_DATA_PATH}} + validation: + file_path: {{VAL_DATA_PATH}} + # env_service 仍需配置(用于工具调用) + env_service: + env_type: "finworld" + env_url: {{ENV_SERVICE_URL}} + env_action_preference: code +trainer: + default_local_dir: "{{CKPT_SAVE_PATH}}/{{PREFIX}}/{{SUFFIX}}" + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_deep_finance/yaml_template/infer.yaml b/tutorial/example_deep_finance/yaml_template/infer.yaml new file mode 100644 index 00000000..5e9d400e --- /dev/null +++ b/tutorial/example_deep_finance/yaml_template/infer.yaml @@ -0,0 
+1,91 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: "{{PREFIX}}" + experiment_name: "{{SUFFIX}}" + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 + rm_llm: {{RM_LLM}} # RM Gallery 模型 + concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 + train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 + val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 + # OpenJudge 权重配置 + presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 + grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 + cgcv_weight: {{CGCV_WEIGHT}} # Citation-Grounded Claim Verification + audit_weight: {{AUDIT_WEIGHT}} # 引用逻辑审计 + traceability_weight: {{TRACEABILITY_WEIGHT}} # 可追溯性/可核验性审计 (TVR) + ebtu_weight: {{EBTU_WEIGHT}} # Audit Grader: audit reward EBTU证据优先可追溯性审计 + rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 + task_judge: + # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: {{MODEL_PATH}} + trainer_common: + nnodes: {{NNODES}} + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 2 + save_freq: 10 + test_freq: 2 + total_epochs: {{TOTAL_EPOCHS}} + save_trajectory_as_json_file: True + rollout: + # ✨✨✨✨ 编写并选择Agent + user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol + force_disable_toolcalls: False + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: {{NUM_REPEAT}} + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_response_length_in_one_turn: 8000 + max_model_len: {{MAX_MODEL_LEN}} + agent_madness_reward: 0.0 + compute_madness_checklist: None + multi_turn: + max_steps: {{NUM_STEPS}} + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + debug: + debug_max_parallel: 1 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + 
train_batch_size: {{TRAIN_BATCH_SIZE}} + max_prompt_length: 8000 + max_response_length: 41000 + + task_reader: + type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + deep_finance: + training: + file_path: {{TRAIN_DATA_PATH}} + validation: + file_path: {{VAL_DATA_PATH}} + # env_service 仍需配置(用于工具调用) + env_service: + env_type: "finworld" + env_url: {{ENV_SERVICE_URL}} + env_action_preference: code +trainer: + default_local_dir: "{{CKPT_SAVE_PATH}}/{{PREFIX}}/{{SUFFIX}}" + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_