From bac05b57c846fcdb3d96fe00d2c4363a7876ca8a Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Fri, 16 Jan 2026 14:53:33 +0800 Subject: [PATCH 01/56] feat(finworld): Added AgentScope learning protocol and OpenJudge evaluation functionality to the FinWorld task. - Added the ExampleAgentScopeLearnProtocol class to implement the AgentScope execution flow for multi-turn interactions. - Integrated semaphore control to manage the parallelism of environment calls, improving environment stepping performance. - Implemented a mechanism for detecting context overflows and quickly terminating during environment interactions to prevent blocking. - Added a finworld.yaml configuration file to define project training and rollout parameters. - Added the FinWorldJudgeByOpenJudge class, integrating multiple evaluators including RM Gallery and OpenJudge (@haoran). - Implemented a mechanism for converting task output, asynchronous calls, and retrying to ensure evaluation stability. - Weight normalization manages the contributions of each evaluator, merging them to calculate the final reward and success determination. --- tutorial/example_finworld/finworld.py | 234 ++++++ tutorial/example_finworld/finworld.yaml | 79 ++ tutorial/example_finworld/finworld_judge.py | 767 ++++++++++++++++++ .../prompt/finworld_prompt.md | 0 4 files changed, 1080 insertions(+) create mode 100644 tutorial/example_finworld/finworld.py create mode 100644 tutorial/example_finworld/finworld.yaml create mode 100644 tutorial/example_finworld/finworld_judge.py create mode 100644 tutorial/example_finworld/prompt/finworld_prompt.md diff --git a/tutorial/example_finworld/finworld.py b/tutorial/example_finworld/finworld.py new file mode 100644 index 00000000..778e3439 --- /dev/null +++ b/tutorial/example_finworld/finworld.py @@ -0,0 +1,234 @@ +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from agentscope.message import Msg +from pydantic import Field +import logging +import threading +import time +import copy +from loguru import logger + + +# 创建信号量,允许同时12个线程运行 +sem = threading.Semaphore(30) + +class ExampleAgentScopeLearnProtocol(Workflow): + + trainer: str = Field(default="astune-trinity") + + async def agentscope_execute( + self, workflow_task: WorkflowTask, model_tuner: AjetTuner + ) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + # 1. 初始化消息 + # init_messages 通常是 [System, User] + init_messages = workflow_task.task.init_messages + + # 分离 System Prompt 和 Initial User Input + if len(init_messages) >= 2: + first_msg, user_msgs = init_messages[0], init_messages[1:] + else: + first_msg = {"content": "You're a helpful assistant."} + user_msgs = init_messages + + # conversation_history: 维护最原始、最标准的 OpenAI 格式数据 (含 role: tool) + # 这是"真值",用于评测和训练保存 + conversation_history = [ + {"role": "system", "content": first_msg["content"]}, + ] + conversation_history.extend(user_msgs) + + # 2. 初始化 Agent + agent = ReActAgent( + name="Qwen", + sys_prompt=first_msg["content"], # Agent 内部会自动管理 System Prompt + model=model_tuner, + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + toolkit=None, + print_hint_msg=False, + ) + agent.set_console_output_enabled(False) + env = workflow_task.gym_env + + # 3. 构造初始 Agent 输入 (List[Msg]) + # 注意:这里只包含 User 消息,不含 System,因为 System 已在 agent init 中设置 + # 必须转换为 Msg 对象 + agent_input = [] + for m in user_msgs: + agent_input.append(Msg( + name=m.get("name", "user"), + content=m.get("content", ""), + role=m.get("role", "user") + )) + + # 统计信息缓存 + latest_tool_stats = None + latest_reward_stats = {} + cumulative_tool_call_time = 0.0 # 累计工具调用时间 + cumulative_tool_time = {} # 按工具区分的累计耗时: {tool_name: [time1, time2, ...]} + + logger.info(f"开始执行多轮交互,最大步数: {model_tuner.config.astune.rollout.multi_turn.max_steps}") + + step = 0 + for step in range(model_tuner.config.astune.rollout.multi_turn.max_steps): + logger.info(f"=== 步骤 {step + 1} ===") + + # === Agent 推理 === + _llm_start = time.time() + # 传入增量消息 (agent_input),Agent 会将其添加到内存并生成回复 + reply_message = await agent(agent_input) + _llm_elapsed = time.time() - _llm_start + # 提取纯文本 content(兼容多模态格式) + if isinstance(reply_message.content, list): + # 多模态格式: [{'type': 'text', 'text': '...'}] + content_text = ''.join(item.get('text', '') for item in reply_message.content if isinstance(item, dict) and item.get('type') == 'text') + else: + content_text = reply_message.content + + content_preview = content_text[:100].replace('\n', ' ') + # logger.info(f"Agent回复 ({_llm_elapsed:.2f}s): {content_preview}...") + + # === 早期终止检查:在调用 env.step() 前检查 context_overflow === + # 修复问题:避免 token_overflow 后还继续调用工具导致阻塞 + if model_tuner.get_context_tracker().context_overflow: + logger.warning(f"上下文溢出,跳过 env.step(),在第 {step + 1} 步立即结束") + # 构造一个默认的结束响应 + conversation_history.append({ + "role": "assistant", + "content": content_text + }) + break + + # === Env 执行 === + _env_start = time.time() + with sem: + obs, reward, terminate, info = env.step( + action={"content": content_text, "role": "assistant"} + ) + _env_elapsed = time.time() - _env_start + logger.info(f"环境执行 ({_env_elapsed:.2f}s)") + # === 3. 更新 conversation_history (Full History) === + # A. 添加 Assistant 消息 (补全 tool_calls) + current_assistant_msg = { + "role": "assistant", + "content": content_text + } + if info and 'generated_tool_calls' in info and info['generated_tool_calls']: + current_assistant_msg['tool_calls'] = info['generated_tool_calls'] + conversation_history.append(current_assistant_msg) + + # B. 添加 Tool 消息 (直接使用 obs) + # 注意:obs 可能是 [tool_results_msgs] 套了一层,需要解包 + if isinstance(obs, list): + actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs + conversation_history.extend(actual_msgs) + else: + conversation_history.append({"role": "user", "content": obs}) + + # === 4. 更新统计信息 === + if info: + if 'tool_stats' in info: + latest_tool_stats = info['tool_stats'] + logger.info(f"步骤 {step + 1} 工具统计: 调用={latest_tool_stats.get('total_calls', 0)}, " + f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") + if 'reward_stats' in info: + latest_reward_stats = info['reward_stats'] + # 累加工具调用时间 + step_tool_call_time = latest_reward_stats.get('tool_call_time', 0.0) + cumulative_tool_call_time += step_tool_call_time + # 累加按工具区分的耗时 + step_tool_time = latest_reward_stats.get('tool_time', {}) + for tool_name, time_list in step_tool_time.items(): + if tool_name not in cumulative_tool_time: + cumulative_tool_time[tool_name] = [] + if isinstance(time_list, list): + cumulative_tool_time[tool_name].extend(time_list) + + # === 5. 准备下一轮 Agent 输入 (Incremental) === + # 将 Env 返回的 obs 转换为 Msg 对象列表,供下一轮 agent() 调用 + # 关键:这里只放新的 obs,不要放完整的 history + agent_input = [] + + if isinstance(obs, list): + # Standard Mode: obs 是 tool messages 列表 + # 注意:finworld_env.step 返回 {"state": [tool_results_msgs]} 套了一层列表 + # BaseGymEnv.step 直接透传,所以 obs = [tool_results_msgs] + # 需要解包获取实际的消息列表 + actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs + logger.info(f"环境观察 (Standard): 收到 {len(actual_msgs)} 条工具消息") + + # 按照 AgentScope 的 ContentBlock 格式转换消息 + # Agent.memory 会自动保存 assistant 的 tool_call 信息 + # 这里只需要传入 tool_result 消息即可 + for idx, m in enumerate(actual_msgs): + origin_role = m.get('role', 'user') + if origin_role == 'tool': + # 使用 ToolResultBlock 格式,作为 user 消息的 content + tool_result_block = { + "type": "tool_result", + "id": m.get('tool_call_id', ''), + "output": m.get('content', ''), + "name": m.get('name', '') + } + new_msg = Msg( + name="tool", + content=[tool_result_block], + role="user" + ) + agent_input.append(new_msg) + else: + # 其他消息(如 user 提示)直接添加 + content = m.get('content') + if content is None: content = "" + valid_role = origin_role if origin_role in ['user', 'assistant', 'system'] else 'user' + new_msg = Msg( + name=m.get('name', valid_role), + content=content, + role=valid_role + ) + agent_input.append(new_msg) + else: + # Legacy Mode + logger.info(f"环境观察 (Legacy): {str(obs)[:100]}...") + agent_input.append(Msg(name="env", content=obs, role="user")) + + # === 6. 终止检查 === + logger.info(f"终止状态: {terminate}") + if terminate: + logger.info(f"环境返回终止信号,在第 {step + 1} 步结束") + break + + if model_tuner.get_context_tracker().context_overflow: + logger.warning(f"上下文溢出,在第 {step + 1} 步结束") + break + + # === 结束处理 === + final_tool_stats = latest_tool_stats or { + 'total_calls': 0, 'total_errors': 0, 'success_calls': 0, 'success_rate': 0.0, + 'cache_hits': 0, 'cache_misses': 0 + } + # 将累计的 tool_time 合并到 tool_stats 中 + final_tool_stats['tool_time'] = cumulative_tool_time + final_tool_stats['tool_call_time'] = cumulative_tool_call_time + + logger.info(f"\n{'='*80}") + logger.info(f"任务完成统计 (Task ID: {workflow_task.task.task_id}):") + logger.info(f" 总步骤: {step + 1}") + logger.info(f" 总调用: {final_tool_stats.get('total_calls', 0)}") + logger.info(f" 成功率: {final_tool_stats.get('success_rate', 0):.2f}%") + logger.info(f"{'='*80}\n") + + return WorkflowOutput( + reward=None, + metadata={ + "total_step": step, + "tool_stats": final_tool_stats, + "reward_stats": latest_reward_stats, + "tool_success_rate": round(final_tool_stats.get('success_rate', 0.0), 2), + "conversation_history": conversation_history, + "query": workflow_task.task.main_query, + "task_id": workflow_task.task.task_id, + } + ) \ No newline at end of file diff --git a/tutorial/example_finworld/finworld.yaml b/tutorial/example_finworld/finworld.yaml new file mode 100644 index 00000000..80ba8188 --- /dev/null +++ b/tutorial/example_finworld/finworld.yaml @@ -0,0 +1,79 @@ +# ------------------ 主要配置 ------------------ +astune: + project_name: astune_finprompt + experiment_name: "cc_rm4_res2cit2fai2_30b" + judge_llm: qwen-flash + judge_concurrency: 10 + # OpenJudge 权重配置 + report_resolution_weight: 0.2 # 报告质量评估 + trajectory_faithfulness_weight: 0.2 # 事实准确性评估 + citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) + rm_weight: 0.4 # RM Gallery 权重 + task_judge: + # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_finworld.finworld_judge_by_openjudge->FinWorldJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507 + trainer_common: + nnodes: 8 + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 8 + save_freq: 10 + test_freq: 2 + total_epochs: 200 + rollout: + # ✨✨✨✨ 编写并选择Agent + use_agentscope_protocol: True + agentscope_learn_protocol: tutorial.example_finworld.finworld->ExampleAgentScopeLearnProtocol + agentscope_disable_toolcalls: True + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: 4 + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_env_len: 10000 + max_response_length_in_one_turn: 8000 + max_model_len: 50000 + agent_madness_reward: 0.0 + multi_turn: + max_steps: 6 + debug: + debug_max_parallel: 64 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + train_batch_size: 32 # 增加批次大小,适配8卡并行 + max_prompt_length: 8000 + max_response_length: 41000 + + task_reader: + type: env_service # `env_service` or `dataset_file` or `huggingface_dat_repo` + env_service: + env_type: "finworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: val +trainer: + default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/astune/checkpoints/example_finworld//localths/cc_rm4_res2cit2fai2_30b" + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://astune/default_config + - file://astune/default_config/verl # verl only + - file://external/verl/verl/trainer/config # verl only + - file://astune/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - ppo_trainer # verl inherit 1/2 + - verl_default # verl inherit 2/2 + - trinity_default # trinity inherit 1/1 + - astune_default + - _self_ diff --git a/tutorial/example_finworld/finworld_judge.py b/tutorial/example_finworld/finworld_judge.py new file mode 100644 index 00000000..9c9518a1 --- /dev/null +++ b/tutorial/example_finworld/finworld_judge.py @@ -0,0 +1,767 @@ +"""FinWorld Task Judge - OpenJudge 版本 +集成: RM Gallery, OpenJudge Graders (含 CitationAudit) +""" + +import os +import json +import asyncio +import time +from datetime import datetime +from typing import Dict, Any, Optional, Tuple, List + +from ajet.task_judge.base_judge import BaseJudge +from ajet.workflow import WorkflowOutput, WorkflowTask +# RewardStats 不再使用,OpenJudge 版本直接使用字典存储 +# from tutorial.example_finworld.reward.reward_schema import RewardStats + +# 环境变量配置 (RM Gallery) +TRAIN_REF_ANS_PATH = os.environ.get("FINWORLD_TRAIN_REF_ANS_PATH", "") +VAL_REF_ANS_PATH = os.environ.get("FINWORLD_VAL_REF_ANS_PATH", "") + +# OpenJudge imports +from openjudge.graders.agent.action.action_loop import ActionLoopDetectionGrader +from openjudge.graders.agent.observation.observation_information_gain import ( + ObservationInformationGainGrader, +) +from openjudge.graders.agent.trajectory.trajectory_comprehensive import ( + TrajectoryComprehensiveGrader, +) +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.models.schema.prompt_template import LanguageEnum +from openjudge.runner.grading_runner import GraderConfig, GradingRunner +from openjudge.scenarios.deep_research.graders.financial_report_resolution import ( + FinancialReportResolutionGrader, +) +from openjudge.scenarios.deep_research.graders.financial_trajectory_faithfulness import ( + FinancialTrajectoryFaithfulGrader, +) +from openjudge.scenarios.deep_research.graders.rubrics_based_trajectory_performance import ( + RubricsBasedTrajectoryPerformance, +) +from openjudge.scenarios.deep_research.graders.financial_report_citation_audit import ( + FinancialReportCitationAuditGrader, +) + + +# ============================================================================= +# 全局辅助函数 +# ============================================================================= + +def extract_text_content(content) -> str: + """统一提取纯文本内容""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + texts.append(item.get("text", "")) + elif isinstance(item, str): + texts.append(item) + return "".join(texts) + return str(content) + + +def load_reference_answers_from_file(file_path: str) -> Tuple[Dict[str, str], Dict[str, str]]: + """加载参考答案 (RM Gallery 需要)""" + if not os.path.exists(file_path): + raise FileNotFoundError(f"Reference answers file not found: {file_path}") + try: + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + ref_answers, ref_domains = {}, {} + for item in data: + task_id = item.get("task", {}).get("task_id") + if not task_id or "answer" not in item: continue + ref_answers[task_id] = item["answer"] + domain = item.get("task", {}).get("metadata", {}).get("domain") + if domain: ref_domains[task_id] = domain + return ref_answers, ref_domains + except Exception as e: + raise ValueError(f"Error loading reference answers: {e}") + + +# ============================================================================= +# FinWorldJudgeByOpenJudge 类 +# ============================================================================= + +class FinWorldJudgeByOpenJudge(BaseJudge): + """ + 使用 OpenJudge 框架的 FinWorld Judge + 集成: RM Gallery, OpenJudge Graders (含 CitationAudit) + + 分析: + - compute_reward 每次处理 **一条采样**(单个 workflow_output) + - 输入:workflow_task, workflow_output + - 输出:(final_reward: float, is_success: bool) + - 副作用:更新 workflow_output.metadata["reward_stats"] + + 注意:GradingRunner 不能使用单例模式,因为其内部 Semaphore 会绑定到创建时的事件循环 + """ + + _model_instance = None # Model 可以复用 + _rm_evaluator_instance = None # RM Gallery Evaluator (单例) + _ref_answers_cache: Dict[str, Dict[str, str]] = {} # 参考答案缓存 + _ref_domains_cache: Dict[str, Dict[str, str]] = {} # 领域缓存 + + def __init__(self, config): + super().__init__(config) + self._setup_weights() + self._init_model() # 只初始化 model,runner 在每次调用时创建 + self._init_rm_components() # 初始化 RM Gallery 组件 + self._init_reference_answers() # 初始化参考答案 + + def _setup_weights(self): + """ + 配置 OpenJudge 各 grader 的权重并归一化 + + graders 对应关系: + - financial_report_resolution: 报告质量和问题解决能力 + - financial_trajectory_faithfulness: 事实准确性(忠实度) + - citation_audit: 引用审计(覆盖率 + 真实性) + - rubrics_based_trajectory_performance: 基于 rubrics 的评估 + - trajectory_comprehensive: 轨迹综合评估 + - observation_information_gain: 信息增益(去重) + - action_loop_detection: 动作循环检测(惩罚项) + """ + cfg = getattr(self.config, "ajet", None) + + # 定义各 grader 的权重(可从 config 中读取)- 与 finworld_judge.py 对齐 + self.w = { + "rm": getattr(cfg, "rm_weight", 1.0) if cfg else 1.0, # RM Gallery 权重 + "citation_audit": getattr(cfg, "citation_audit_weight", 0.0) if cfg else 0.0, # CitationAudit 权重 + "report_resolution": getattr(cfg, "report_resolution_weight", 0.0) if cfg else 0.0, + "trajectory_faithfulness": getattr(cfg, "trajectory_faithfulness_weight", 0.0) if cfg else 0.0, + # "rubrics_performance": getattr(cfg, "rubrics_performance_weight", 0.2) if cfg else 0.2, + # "trajectory_comprehensive": getattr(cfg, "trajectory_comprehensive_weight", 0.2) if cfg else 0.2, + # "information_gain": getattr(cfg, "information_gain_weight", 0.1) if cfg else 0.1, + # "action_loop": getattr(cfg, "action_loop_weight", 0.1) if cfg else 0.1 + } + + # 归一化(注意:action_loop 是惩罚项,不参与归一化;rm 需要参与归一化) + positive_weights = {k: v for k, v in self.w.items() if k != "action_loop" and v > 0} + total = sum(positive_weights.values()) + if total > 0: + for k in positive_weights: + self.w[k] = self.w[k] / total + + + def _init_rm_components(self): + """初始化 RM Gallery Evaluator(仅当 rm_weight > 0 时)""" + self._rm_enabled = (self.w.get("rm", 0) > 0) + if self._rm_enabled: + if FinWorldJudgeByOpenJudge._rm_evaluator_instance is None: + self._init_rm_evaluator() + FinWorldJudgeByOpenJudge._rm_evaluator_instance = self.rm_evaluator + else: + self.rm_evaluator = FinWorldJudgeByOpenJudge._rm_evaluator_instance + else: + self.rm_evaluator = None + + def _init_rm_evaluator(self): + """初始化 RM Gallery Evaluator""" + try: + # Monkey patch OpenAI client timeout (RM Gallery 默认只有60s,对于30B模型不够用) + import openai + _original_openai_init = openai.OpenAI.__init__ + def _patched_openai_init(self, *args, **kwargs): + kwargs.setdefault('timeout', 600.0) # 增大到600秒 + return _original_openai_init(self, *args, **kwargs) + openai.OpenAI.__init__ = _patched_openai_init + + from rm_gallery.core.reward.registry import RewardRegistry + import logging + logging.getLogger("rm_gallery").setLevel(logging.WARNING) + api_key = os.environ.get("DASHSCOPE_API_KEY") or os.environ.get("API_KEY") + base_url = os.environ.get("BASE_URL") or "https://dashscope.aliyuncs.com/compatible-mode/v1" + llm_name = os.environ.get("RM_LLM") + + rm_params = {"is_parallel": True, "enable_thinking": False, "base_url": base_url} # is_parallel=True 让子评估器并行调用LLM + if api_key: rm_params["api_key"] = api_key + + self.rm_evaluator = RewardRegistry.get("finance_composition")( + llm=llm_name, name="finance_composition", params=rm_params + ) + print(f"✓ RM evaluator initialized: {llm_name} {base_url} (timeout=600s)") + except Exception as e: + print(f"✗ Failed to initialize RM evaluator: {e}") + self.rm_evaluator = None + + def _init_reference_answers(self): + """初始化参考答案缓存""" + def _load(path, key): + if path and key not in FinWorldJudgeByOpenJudge._ref_answers_cache: + try: + ans, dom = load_reference_answers_from_file(path) + FinWorldJudgeByOpenJudge._ref_answers_cache[key], FinWorldJudgeByOpenJudge._ref_domains_cache[key] = ans, dom + except Exception: + FinWorldJudgeByOpenJudge._ref_answers_cache[key], FinWorldJudgeByOpenJudge._ref_domains_cache[key] = {}, {} + _load(TRAIN_REF_ANS_PATH, "train") + _load(VAL_REF_ANS_PATH, "val") + + def _get_reference_data(self, task_id: str) -> Tuple[str, str]: + """获取任务的参考答案和领域""" + cache_key = "val" if task_id.startswith("val_") else "train" + ans = FinWorldJudgeByOpenJudge._ref_answers_cache.get(cache_key, {}).get(task_id, "") + dom = FinWorldJudgeByOpenJudge._ref_domains_cache.get(cache_key, {}).get(task_id) + return ans, dom + + def _init_model(self): + """初始化 OpenJudge LLM Model(单例模式,可复用)""" + if FinWorldJudgeByOpenJudge._model_instance is None: + try: + model_name = getattr(self.config.ajet, "judge_llm", "qwen-flash") if hasattr(self.config, "ajet") else "qwen-flash" + base_url = os.environ.get("JUDGE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + api_key = os.environ.get("JUDGE_API_KEY", os.environ.get("DASHSCOPE_API_KEY", None)) + FinWorldJudgeByOpenJudge._model_instance = OpenAIChatModel( + model=model_name, + temperature=0.0, + base_url=base_url, + api_key=api_key + ) + print(f"✓ OpenJudge Model initialized: {model_name} @ {base_url}: {api_key}") + except Exception as e: + print(f"✗ Failed to initialize OpenJudge Model: {e}") + import traceback + traceback.print_exc() + raise + + self.model = FinWorldJudgeByOpenJudge._model_instance + self.max_concurrency = getattr(self.config.ajet, "judge_concurrency", 6) if hasattr(self.config, "ajet") else 6 + + def _create_runner_in_loop(self) -> GradingRunner: + """ + 在当前事件循环中创建 GradingRunner + + 注意:GradingRunner 内部的 Semaphore 会绑定到创建时的事件循环, + 因此不能使用单例模式,必须在每次调用的事件循环中创建新实例。 + """ + language = LanguageEnum.ZH + grader_configs = self._create_grader_configs(self.model, language) + return GradingRunner( + grader_configs=grader_configs, + max_concurrency=self.max_concurrency, + show_progress=False + ) + + def _create_grader_configs(self, model: OpenAIChatModel, language: LanguageEnum) -> Dict[str, GraderConfig]: + """ + 创建所有 grader 的配置 + + 返回:Dict[str, GraderConfig] + - key: grader 名称 + - value: GraderConfig(grader=..., mapper=...) + """ + return { + # 1. 报告质量评估 - 需要 messages 和 chat_date + "report_resolution": GraderConfig( + grader=FinancialReportResolutionGrader(model=model, language=language), + mapper=lambda data: { + "messages": data["messages"], + "chat_date": data.get("chat_date") + }, + ), + + # 2. 事实准确性评估 - 需要 messages + "trajectory_faithfulness": GraderConfig( + grader=FinancialTrajectoryFaithfulGrader(model=model, language=language), + mapper=lambda data: {"messages": data["messages"]}, + ), + + # 3. 引用审计评估 - 需要 messages + "citation_audit": GraderConfig( + grader=FinancialReportCitationAuditGrader(model=model, language=language), + mapper=lambda data: {"messages": data["messages"]}, + ), + + # 4. Rubrics 评估 - 需要 messages 和 rubrics + # "rubrics_performance": GraderConfig( + # grader=RubricsBasedTrajectoryPerformance(model=model, language=language), + # mapper=lambda data: { + # "messages": data["messages"], + # "rubrics": data.get("rubrics", []) + # }, + # ), + + # 5. 轨迹综合评估 - 需要 messages + # "trajectory_comprehensive": GraderConfig( + # grader=TrajectoryComprehensiveGrader(model=model, language=language), + # mapper=lambda data: {"messages": data["messages"]}, + # ), + + # 6. 信息增益评估 - 需要 messages(非 LLM grader) + # "information_gain": GraderConfig( + # grader=ObservationInformationGainGrader(similarity_threshold=0.5), + # mapper=lambda data: {"messages": data["messages"]}, + # ), + + # 7. 动作循环检测 - 需要 messages(非 LLM grader) + # "action_loop": GraderConfig( + # grader=ActionLoopDetectionGrader(similarity_threshold=1.0), + # mapper=lambda data: {"messages": data["messages"]}, + # ), + } + + def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> Tuple[float, bool]: + """ + 主计算逻辑:使用 OpenJudge Runner.arun 计算 reward + + 流程: + 1. 从 workflow_output.metadata 提取 conversation_history、query、rubrics 等 + 2. 转换为 OpenJudge 的输入格式 (messages, chat_date, rubrics) + 3. 调用 Runner.arun([sample]) 获取所有 graders 的评分 + 4. 加权融合各 grader 分数 + 5. 计算惩罚项(tool_calls) + 6. 更新 metadata["reward_stats"] + 7. 返回 (final_reward, is_success) + """ + judge_start_time = time.time() + + try: + metadata = workflow_output.metadata + + # 1. 提取输入数据 + history = metadata.get("conversation_history", []) + query = metadata.get("query") or getattr(workflow_task.task, "main_query", "") + task_id = metadata.get("task_id") or getattr(workflow_task.task, "task_id", "") + rubrics = metadata.get("rubrics") # 可能是 None 或 list of dicts + step_reward = metadata.get("reward_stats", {}).get("step_reward", 0.0) + chat_date = metadata.get("chat_date") if metadata else datetime.now().strftime("%Y-%m-%d") + + if not history: + print(f"⚠️ Empty conversation history for task_id={task_id}") + return 0.0, False + + # 1.5 RM Gallery 评估(如果启用) + ref_ans, domain = self._get_reference_data(task_id) + assistants = [extract_text_content(m["content"]) for m in history if m["role"] == "assistant"] + + # RM Gallery 耗时记录 + rm_start_time = time.time() + if self._rm_enabled and self.rm_evaluator: + rm_raw = self._evaluate_with_rm_gallery(query, assistants[-1] if assistants else "", ref_ans, task_id, domain) + else: + rm_raw = 0.0 + rm_time = time.time() - rm_start_time + + # 2. 转换为 OpenJudge 输入格式 + openjudge_sample = self._convert_to_openjudge_format( + history=history, + query=query, + task_id=task_id, + rubrics=rubrics, + chat_date=chat_date + ) + + # 3. 调用 OpenJudge Runner.arun(异步) + grading_start_time = time.time() + grader_results = self._run_openjudge_evaluation([openjudge_sample]) + grading_time = time.time() - grading_start_time + + # 4. 提取各 grader 分数(arun 返回 Dict[str, List[GraderScore]],这里取第一条) + grader_scores, quota_exceeded_flags = self._extract_grader_scores(grader_results) + + # 5. 加权融合(包含 RM Gallery 和 OpenJudge Graders) + fused_reward, contributions = self._fuse_grader_scores(grader_scores, rm_raw) + + # 6. 计算惩罚项(保留原有的 tool_calls 惩罚逻辑) + tool_calls = metadata.get("tool_stats", {}).get("total_calls", 0) + penalty = self._compute_penalty(tool_calls) + + # 7. 汇总 + final_reward = fused_reward + step_reward + penalty + + judge_total_time = time.time() - judge_start_time + + # 8. 更新元数据(实例化 RewardStats) + time_stats = { + "rm_time": rm_time, + "grading_time": grading_time, + "judge_total_time": judge_total_time, + } + self._update_metadata_stats( + metadata=metadata, + final_reward=final_reward, + fused_reward=fused_reward, + penalty=penalty, + step_reward=step_reward, + grader_scores=grader_scores, + contributions=contributions, + time_stats=time_stats, + rm_raw=rm_raw, + quota_exceeded_flags=quota_exceeded_flags + ) + + print(f"FinWorldJudgeByOpenJudge: task_id={task_id}, fused={fused_reward:.4f}, final={final_reward:.4f}, rm_time={rm_time:.2f}s, grading_time={grading_time:.2f}s, total={judge_total_time:.2f}s") + + # 9. 判断是否成功(可根据实际需求调整阈值) + is_success = final_reward >= 0.7 + + return final_reward, is_success + + except Exception as e: + print(f"✗ Error in OpenJudge compute_reward: {e}") + import traceback + traceback.print_exc() + return 0.0, False + + def _convert_to_openjudge_format( + self, + history: List[Dict], + query: str, + task_id: str, + rubrics: Optional[Any], + chat_date: Optional[str] + ) -> Dict[str, Any]: + """ + 将训练框架的 conversation_history 转换为 OpenJudge 的输入格式 + + 输入: + - history: [{"role": "user/assistant/tool", "content": ..., "tool_calls": ...}, ...] + + 输出: + - { + "messages": [...], # OpenJudge 格式 + "chat_date": "YYYY-MM-DD", + "rubrics": [...] + } + """ + # 1. 规范化 messages + messages = [] + for msg in history: + content = extract_text_content(msg.get("content", "")) + normalized_msg = { + "role": msg.get("role", "user"), + "content": content + } + + # 透传 tool_calls 等字段(OpenJudge 需要) + for field in ["tool_calls", "tool_call_id", "name"]: + if field in msg: + normalized_msg[field] = msg[field] + + messages.append(normalized_msg) + + + # 3. 转换 rubrics 格式(如果存在) + # OpenJudge 期望的格式:[{"dimension": ..., "description": ..., "check_points": [...]}, ...] + openjudge_rubrics = [] + if rubrics: + if isinstance(rubrics, list): + openjudge_rubrics = rubrics + elif isinstance(rubrics, dict): + # 如果 rubrics 是 dict,尝试转换 + # 假设格式类似 {"criteria": [...], "scoring_dimensions": [...]} + if "criteria" in rubrics: + for criterion in rubrics.get("criteria", []): + openjudge_rubrics.append({ + "dimension": criterion.get("name", ""), + "description": criterion.get("description", ""), + "check_points": criterion.get("check_points", []) + }) + + return { + "messages": messages, + "chat_date": chat_date, + "rubrics": openjudge_rubrics + } + + def _run_openjudge_evaluation(self, dataset: List[Dict[str, Any]]) -> Dict[str, List[Any]]: + """ + 调用 OpenJudge Runner.arun 进行评估(带重试机制) + + 输入: + - dataset: List[Dict] - OpenJudge 格式的样本列表 + + 输出: + - Dict[str, List[GraderScore]] - 每个 grader 的评分结果 + + 注意:GradingRunner 必须在当前事件循环中创建,因为其内部 Semaphore 会绑定事件循环 + """ + result = {} + judge_instance = self # 保存引用以便在 async 函数中访问 + max_retries = 3 # 最大重试次数 + + async def run_with_retry(): + nonlocal result + last_exception = None + + for attempt in range(max_retries): + try: + # 在当前事件循环中创建 Runner(避免 Semaphore 绑定错误的事件循环) + runner = judge_instance._create_runner_in_loop() + result = await runner.arun(dataset) + return # 成功则直接返回 + except Exception as e: + last_exception = e + error_str = str(e) + + # 判断是否为可重试的连接错误 + is_connection_error = any(keyword in error_str for keyword in [ + "Connection", "connection", "TCPTransport", + "SSLWantReadError", "BrokenPipe", "timeout", + "closed", "APIConnectionError" + ]) + + if is_connection_error and attempt < max_retries - 1: + wait_time = 2 ** attempt # 指数退避: 1s, 2s, 4s + print(f"⚠️ OpenJudge connection error (attempt {attempt+1}/{max_retries}), retrying in {wait_time}s... Error: {error_str[:100]}") + await asyncio.sleep(wait_time) + continue + else: + # 非连接错误或已达最大重试次数 + raise last_exception + + # 所有重试都失败 + if last_exception: + raise last_exception + + try: + # 创建新的标准 asyncio 事件循环,并设置为当前线程的事件循环 + # 这样可以避免 Semaphore 绑定到不同事件循环的问题 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) # 关键:将新循环设置为当前线程的事件循环 + try: + loop.run_until_complete(run_with_retry()) + finally: + loop.close() + asyncio.set_event_loop(None) # 清理:避免引用已关闭的循环 + except Exception as e: + print(f"✗ OpenJudge Runner.arun failed after {max_retries} attempts: {e}") + import traceback + traceback.print_exc() + + return result + + def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[Dict[str, float], Dict[str, bool]]: + """ + 从 Runner.arun 结果中提取各 grader 的分数 + + 输入: + - grader_results: Dict[str, List[GraderScore]] + { + "report_resolution": [GraderScore(score=0.88, reason="...", metadata={...})], + "trajectory_faithfulness": [GraderScore(score=1.0, ...)], + ... + } + + 输出: + - Tuple[Dict[str, float], Dict[str, bool]] + - scores: 每个 grader 的分数(取第一条采样的分数) + - quota_exceeded_flags: 每个 grader 是否发生 429 quota exceeded + """ + scores = {} + quota_exceeded_flags = {} + + for grader_name, score_list in grader_results.items(): + quota_exceeded_flags[grader_name] = False + if score_list and len(score_list) > 0: + # 取第一条采样的分数(因为每次只评估一条) + grader_score = score_list[0] + if hasattr(grader_score, "score"): + scores[grader_name] = grader_score.score + # 检测错误类型:分数为0且有错误信息 + if grader_score.score == 0.0 and hasattr(grader_score, "reason"): + reason = str(grader_score.reason) if grader_score.reason else "" + # 检测 429 quota exceeded + if "429" in reason or "insufficient_quota" in reason or "exceeded your current quota" in reason: + quota_exceeded_flags[grader_name] = True + else: + # 如果出错,设为 0 + scores[grader_name] = 0.0 + else: + scores[grader_name] = 0.0 + + print(f" [OpenJudge Scores] {scores}") + if any(quota_exceeded_flags.values()): + quota_graders = [k for k, v in quota_exceeded_flags.items() if v] + print(f" [OpenJudge QuotaExceeded] {quota_graders}") + return scores, quota_exceeded_flags + + def _fuse_grader_scores(self, grader_scores: Dict[str, float], rm_raw: float = 0.0) -> Tuple[float, Dict[str, float]]: + """ + 加权融合各 grader 的分数(包含 RM Gallery 和 OpenJudge Graders) + + 输入: + - grader_scores: Dict[str, float] - 各 grader 的原始分数 + - rm_raw: float - RM Gallery 原始分数 + + 输出: + - (fused_reward, contributions) + - fused_reward: 加权后的总分 + - contributions: Dict[str, float] - 各 grader 的贡献分数 + """ + contributions = {} + + # 添加 RM Gallery 贡献 + contributions["rm_contribution"] = self.w.get("rm", 0.0) * rm_raw + + # 添加 OpenJudge Graders 贡献(包括 citation_audit) + for grader_name, weight in self.w.items(): + if grader_name == "rm": + continue # 已单独处理 + score = grader_scores.get(grader_name, 0.0) + contributions[grader_name] = weight * score + + fused_reward = sum(contributions.values()) + + return fused_reward, contributions + + def _evaluate_with_rm_gallery(self, query: str, current: str, reference: str, task_id: str, domain: str) -> float: + """使用 RM Gallery 评估""" + if not self.rm_evaluator or not domain or not reference: + return 0.0 + try: + from rm_gallery.core.data.schema import DataSample + sample = DataSample( + unique_id=task_id, + input=[{"role": "user", "content": query}], + output=[ + {"answer": {"role": "assistant", "content": current, "label": {"model_name": "training"}}, "steps": None}, + {"answer": {"role": "assistant", "content": reference, "label": {"model_name": "reference"}}, "steps": None}, + ], + task_category="financial_analysis", source="finance_samples", metadata={"domain": domain} + ) + result = self.rm_evaluator.evaluate(sample) + self._save_rm_log(result, query, task_id) + return result.metadata["dimension_scores"]["overall_score"]["training"] + except Exception as e: + print(f"✗ RM Gallery evaluation failed: {e}") + return 0.0 + + def _save_rm_log(self, result, query: str, task_id: str): + """保存 RM Gallery 评估日志""" + try: + log = { + "task_id": task_id, + "query": query, + "timestamp": datetime.now().isoformat(), + "scores": result.metadata.get("dimension_scores", {}) + } + save_dir = "/mnt/data_cpfs/taoshuchang.tsc/deepresearch/ajet/outputs/rm_evaluation_logs" + os.makedirs(save_dir, exist_ok=True) + with open(os.path.join(save_dir, f"rmeval_{datetime.now().strftime('%Y%m%d')}.json"), "a") as f: + f.write(json.dumps(log, ensure_ascii=False) + "\n") + except Exception: + pass + + def _compute_penalty(self, tool_calls: int) -> float: + """ + 计算工具调用惩罚(保留原有逻辑) + + - 0 次调用:-1.0 + - 1-2 次:-0.5 + - 3+ 次:0.0 + """ + if tool_calls == 0: + return -1.0 + elif tool_calls <= 2: + return -0.5 + else: + return 0.0 + + def _update_metadata_stats( + self, + metadata: Dict[str, Any], + final_reward: float, + fused_reward: float, + penalty: float, + step_reward: float, + grader_scores: Dict[str, float], + contributions: Dict[str, float], + time_stats: Dict[str, float], + rm_raw: float = 0.0, + quota_exceeded_flags: Optional[Dict[str, bool]] = None + ): + """ + 更新 metadata["reward_stats"] - 直接使用 OpenJudge 原始字段 + + OpenJudge graders(按实际启用情况): + - report_resolution: 报告质量和问题解决能力 + - trajectory_faithfulness: 事实准确性(忠实度) + - citation_audit: 引用审计(覆盖率 + 真实性) + - rubrics_performance: 基于 rubrics 的评估(可选) + - trajectory_comprehensive: 轨迹综合评估(可选) + - information_gain: 信息增益/去重(可选) + - action_loop: 动作循环检测(惩罚项,可选) + + 注意:不再硬套 RewardStats 的字段名,直接使用 openjudge_ 前缀 + """ + quota_exceeded_flags = quota_exceeded_flags or {} + + # 计算 quota exceeded 统计 + quota_exceeded_count = sum(1 for v in quota_exceeded_flags.values() if v) + quota_exceeded_any = quota_exceeded_count > 0 + + # 基础分数 + stats_dict = { + "final_reward": final_reward, + "fused_reward": fused_reward, + "penalty": penalty, + "step_reward": step_reward, + "openjudge_enabled": True, + # Quota exceeded (429) 统计 + "quota_exceeded_any": quota_exceeded_any, # 是否有任何 grader 超额 + "quota_exceeded_count": quota_exceeded_count, # 超额的 grader 数量 + "quota_exceeded_graders": quota_exceeded_flags, # 各 grader 的超额标记 + # RM Gallery 相关 + "rm_enabled": self._rm_enabled, + "rm_raw": rm_raw, + "rm_weight": self.w.get("rm", 0.0), + "rm_contribution": contributions.get("rm_contribution", 0.0), + } + + # OpenJudge grader 原始分数(dimensions) + for grader_name, score in grader_scores.items(): + stats_dict[f"openjudge_{grader_name}_raw"] = score + stats_dict[f"openjudge_{grader_name}_weight"] = self.w.get(grader_name, 0.0) + + # OpenJudge grader 加权贡献(contribution) + for grader_name, contrib in contributions.items(): + stats_dict[f"openjudge_{grader_name}_contribution"] = contrib + + # 保留原始字典便于调试 + stats_dict["openjudge_grader_scores"] = grader_scores + stats_dict["openjudge_contributions"] = contributions + + # 注入耗时统计 + if time_stats: + stats_dict.update(time_stats) + + metadata["reward_stats"] = stats_dict + + def _save_evaluation_log(self, task_id: str, grader_results: Dict[str, List[Any]], query: str): + """ + 保存 OpenJudge 评估日志(可选) + """ + try: + log = { + "task_id": task_id, + "query": query, + "timestamp": datetime.now().isoformat(), + "grader_results": {} + } + + # 简化 grader_results 以便序列化 + for grader_name, score_list in grader_results.items(): + log["grader_results"][grader_name] = [] + for score in score_list: + if hasattr(score, "score"): + log["grader_results"][grader_name].append({ + "score": score.score, + "reason": score.reason[:200] if hasattr(score, "reason") else "", + }) + + save_dir = "/mnt/data_cpfs/taoshuchang.tsc/deepresearch/ajet/outputs/openjudge_logs" + os.makedirs(save_dir, exist_ok=True) + + log_file = os.path.join(save_dir, f"openjudge_{datetime.now().strftime('%Y%m%d')}.json") + with open(log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(log, ensure_ascii=False) + "\n") + + except Exception as e: + print(f"⚠️ Failed to save evaluation log: {e}") + pass + diff --git a/tutorial/example_finworld/prompt/finworld_prompt.md b/tutorial/example_finworld/prompt/finworld_prompt.md new file mode 100644 index 00000000..e69de29b From c7ca8c7cb471fed3dac0938df7e22b6b06bda8ef Mon Sep 17 00:00:00 2001 From: binary-husky <96192199+binary-husky@users.noreply.github.com> Date: Fri, 16 Jan 2026 18:23:25 +0800 Subject: [PATCH 02/56] Precommit fix (#4) * fix end of files * autoflake import fix * add mypy check --- .github/workflows/doc.yaml | 2 +- .pre-commit-config.yaml | 29 ++----- README.md | 2 +- ajet/backbone/main_vllm.py | 5 +- ajet/backbone/trainer_trinity.py | 11 +-- ajet/backbone/warm_up.py | 2 +- ajet/context_tracker/base_tracker.py | 2 +- .../timeline_merging/timeline_merging.py | 1 - ajet/schema/convertion.py | 4 +- ajet/schema/logprob.py | 2 +- ajet/task_reader/__init__.py | 2 +- ajet/task_rollout/async_llm_bridge.py | 3 +- ajet/task_runner/base_runner.py | 2 - ajet/task_runner/general_runner.py | 3 +- ajet/tuner.py | 2 +- ajet/tuner_lib/weight_tuner/__init__.py | 1 - .../weight_tuner/as_agentscope_model.py | 7 +- .../weight_tuner/as_oai_baseurl_apikey.py | 13 +-- .../weight_tuner/as_oai_sdk_model.py | 15 +--- .../experimental/as_oai_model_server.py | 4 +- ajet/utils/async_utils.py | 2 +- ajet/utils/lowlevel_hook.py | 2 +- ajet/utils/metric_helper/__init__.py | 2 +- .../metric_helper/reward_metric_helper.py | 82 +++++++++---------- .../metric_helper/save_trajectory_as_json.py | 2 +- ajet/utils/msg_converter.py | 3 +- ajet/utils/networking.py | 2 +- ajet/utils/testing_utils.py | 3 - ajet/utils/thread_executors.py | 2 +- docs/_toc.yml | 15 ++-- docs/en/debugging_guide.md | 1 - docs/en/example_countdown.md | 1 - docs/en/example_learning_to_ask.md | 8 +- docs/en/hardware_related_solution.md | 2 +- docs/en/support_agentscope.md | 1 - docs/en/support_http.md | 2 - docs/en/support_langchain.md | 2 - docs/en/support_oaisdk.md | 3 - docs/index.md | 1 - docs/javascripts/animations.js | 1 - docs/javascripts/code-zoom.js | 1 - docs/javascripts/responsive.js | 1 - docs/javascripts/search-fix.js | 1 - docs/javascripts/tabbed-code.js | 1 - docs/requirements.txt | 1 - docs/stylesheets/animations.css | 1 - docs/stylesheets/feature-cards.css | 1 - docs/stylesheets/flowchart.css | 1 - docs/stylesheets/jupyter-simple.css | 1 - docs/stylesheets/syntax-highlight.css | 1 - docs/stylesheets/tuner_v2.md | 2 +- install.sh | 8 +- mkdocs.yml | 1 - pyproject.toml | 2 +- scripts/display_dataset.py | 5 -- tests/bench/benchmark_math/benchmark_math.py | 1 - tests/test_networking.py | 56 ------------- tutorial/README.md | 2 +- tutorial/example_appworld/appworld.py | 1 - tutorial/example_appworld/appworld_oai_sdk.py | 1 - .../data_preprocess/llm_info_extraction.py | 2 +- .../data_preprocess/message_splitter.py | 2 +- .../data_preprocess/step1.py | 34 ++++---- .../data_preprocess/step2.py | 4 +- tutorial/example_learn2ask/learn2ask.md | 2 +- .../example_learn2ask/learn2ask_langchain.py | 11 ++- .../ma_deepresearch.py | 5 -- .../math_agent_langchain.py | 16 ++-- .../example_math_agent/math_agent_oai_sdk.py | 1 - .../example_math_agent/math_agent_raw_http.py | 10 --- tutorial/example_werewolves/start.py | 2 +- 71 files changed, 132 insertions(+), 295 deletions(-) delete mode 100644 tests/test_networking.py diff --git a/.github/workflows/doc.yaml b/.github/workflows/doc.yaml index 98e2fc9d..fba9b693 100644 --- a/.github/workflows/doc.yaml +++ b/.github/workflows/doc.yaml @@ -59,4 +59,4 @@ jobs: steps: - name: Deploy to GitHub Pages id: deployment - uses: actions/deploy-pages@v4 \ No newline at end of file + uses: actions/deploy-pages@v4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0e78eeb7..6f5736a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,30 +11,17 @@ repos: - id: check-merge-conflict - id: detect-private-key - - repo: https://github.com/psf/black - rev: 23.7.0 - hooks: - - id: black - language_version: python3.10 - args: [--line-length=100] - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", "--filter-files"] - - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 + - repo: https://github.com/myint/autoflake + rev: v2.2.0 hooks: - - id: flake8 - additional_dependencies: [flake8-docstrings] - args: [ - "--max-line-length=100", - "--max-complexity=20", - "--select=C,E,F,W,B,B950", - "--ignore=E203,E266,E501,W503", - ] + - id: autoflake + args: [ + --in-place, + --remove-all-unused-imports, + --ignore-init-module-imports + ] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.7.0 diff --git a/README.md b/README.md index f2b520f5..b01cff25 100644 --- a/README.md +++ b/README.md @@ -152,4 +152,4 @@ If you use AgentJet in your research, please cite:
[⭐ Star Us](https://github.com/modelscope/AgentJet) · [Report Bug](https://github.com/modelscope/AgentJet/issues) · [Request Feature](https://github.com/modelscope/AgentJet/issues) -
\ No newline at end of file + diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index 3f0a724c..686a35cd 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -1,4 +1,3 @@ -import atexit import os import sys from types import SimpleNamespace @@ -83,7 +82,7 @@ def submit_chat_completions(self, messages, sampling_params, request_id, tools=[ "content": message["content"], "tool_calls": message.get("tool_calls", None), "tokens": [ - TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content # type: ignore + TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content # type: ignore ], } ) @@ -131,7 +130,7 @@ async def submit_chat_completions_async(self, messages, sampling_params, request "content": message["content"], "tool_calls": message.get("tool_calls", None), "tokens": [ - TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content # type: ignore + TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content # type: ignore ], } ) diff --git a/ajet/backbone/trainer_trinity.py b/ajet/backbone/trainer_trinity.py index 1a75a1bc..8000a636 100644 --- a/ajet/backbone/trainer_trinity.py +++ b/ajet/backbone/trainer_trinity.py @@ -1,12 +1,12 @@ -import asyncio import os -from typing import Dict, List, Literal, Optional, cast - +import asyncio import datasets import openai import swanlab + from loguru import logger from transformers import AutoTokenizer +from typing import Dict, List, Literal, Optional, cast from trinity.buffer.reader import READER from trinity.buffer.reader.file_reader import TaskFileReader, _HFBatchReader from trinity.buffer.schema import FORMATTER @@ -19,9 +19,7 @@ from trinity.utils.monitor import MONITOR, Monitor from ajet.backbone.warm_up import warm_up_process -from ajet.context_tracker.multiagent_tracking import ( - MultiAgentContextTracker, -) +from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker from ajet.schema.trajectory import Sample from ajet.task_reader import dict_to_ajet_task from ajet.task_rollout.native_parallel_worker import DynamicRolloutManager @@ -65,7 +63,6 @@ def __init__( ) def convert_task(self, task: TrinityTask): - from ajet.schema.task import Task assert isinstance(task.raw_task, dict) return dict_to_ajet_task(task.raw_task) diff --git a/ajet/backbone/warm_up.py b/ajet/backbone/warm_up.py index f4c2973e..fcae673f 100644 --- a/ajet/backbone/warm_up.py +++ b/ajet/backbone/warm_up.py @@ -101,4 +101,4 @@ def warm_up_process(config): experiment_name = config.ajet.experiment_name init_parallel_rollout_logger(experiment_name) warm_up_task_judge_when_needed(config) - clean_up_tmp_ajet_dir(config) \ No newline at end of file + clean_up_tmp_ajet_dir(config) diff --git a/ajet/context_tracker/base_tracker.py b/ajet/context_tracker/base_tracker.py index 0ff706fa..948aee3e 100644 --- a/ajet/context_tracker/base_tracker.py +++ b/ajet/context_tracker/base_tracker.py @@ -1,5 +1,5 @@ from typing import List, Tuple, Union -from typing import List, Union, Tuple, Dict, Optional, Any +from typing import List, Union, Tuple, Dict, Optional from ajet.schema.task import WorkflowTask from ajet.schema.extended_msg import ( diff --git a/ajet/context_tracker/timeline_merging/timeline_merging.py b/ajet/context_tracker/timeline_merging/timeline_merging.py index e81475dd..4fb19baa 100644 --- a/ajet/context_tracker/timeline_merging/timeline_merging.py +++ b/ajet/context_tracker/timeline_merging/timeline_merging.py @@ -1,6 +1,5 @@ from typing import List -from beast_logger import print_listofdict from ajet.context_tracker.basic_tracker import ExtendedMessage diff --git a/ajet/schema/convertion.py b/ajet/schema/convertion.py index e2a6a2c0..408bbcdb 100644 --- a/ajet/schema/convertion.py +++ b/ajet/schema/convertion.py @@ -4,11 +4,10 @@ from openai.types.chat.chat_completion_message import ChatCompletionMessage from agentscope.model import ChatResponse as AgentScopeChatResponse from openai.types.completion_usage import CompletionUsage -from typing import Any, Callable, Dict, List, Literal, Type, Union +from typing import List, Type from agentscope.message import TextBlock, ToolUseBlock from agentscope._utils._common import _json_loads_with_repair from pydantic import BaseModel -from agentscope.model import ChatResponse def convert_llm_proxy_response_to_oai_response(llm_proxy_response): @@ -106,4 +105,3 @@ def convert_llm_proxy_response_to_agentscope_response( ) return parsed_response - diff --git a/ajet/schema/logprob.py b/ajet/schema/logprob.py index 42d2c572..dc736fb8 100644 --- a/ajet/schema/logprob.py +++ b/ajet/schema/logprob.py @@ -11,4 +11,4 @@ class TokenAndProb(BaseModel): token_id: int logprob: float - decoded_string: str \ No newline at end of file + decoded_string: str diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py index 19a1a8e3..2d7d7322 100644 --- a/ajet/task_reader/__init__.py +++ b/ajet/task_reader/__init__.py @@ -123,4 +123,4 @@ def dict_to_ajet_task(task_dict: dict) -> Task: task_id=task_dict.get("task_id", ""), env_type=task_dict.get("env_type", ""), metadata=task_dict.get("metadata", {}), - ) \ No newline at end of file + ) diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py index f43ba1c8..ff494844 100644 --- a/ajet/task_rollout/async_llm_bridge.py +++ b/ajet/task_rollout/async_llm_bridge.py @@ -3,14 +3,13 @@ import json import time import uuid -from typing import Any, Callable, Dict, List, Literal, Type, Union +from typing import Any, Callable, Dict, List, Literal, Union from loguru import logger from omegaconf import DictConfig from pydantic import BaseModel -from transformers.tokenization_utils import PreTrainedTokenizer from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.outputs import RequestOutput as VerlVllmRequestOutput diff --git a/ajet/task_runner/base_runner.py b/ajet/task_runner/base_runner.py index 65aa5c13..d8c15492 100644 --- a/ajet/task_runner/base_runner.py +++ b/ajet/task_runner/base_runner.py @@ -3,7 +3,6 @@ from threading import Lock from typing import Any, Callable, Union, Type from multiprocessing import Process, Queue -from unittest import result from ajet.context_tracker.basic_tracker import BaseContextTracker from ajet.schema.task import WorkflowOutput, WorkflowTask @@ -117,4 +116,3 @@ def run_user_workflow( else: raise ValueError(f"Unsupported wrapper type: {self.wrapper_type}") - diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index 2904cfae..7ea76710 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -1,7 +1,6 @@ -from venv import logger from ajet import AjetTuner -from ajet import Workflow, WorkflowOutput +from ajet import WorkflowOutput from ajet.context_tracker.multiagent_tracking import ( MultiAgentContextTracker, ) diff --git a/ajet/tuner.py b/ajet/tuner.py index 93602d05..aacc3ab9 100644 --- a/ajet/tuner.py +++ b/ajet/tuner.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Literal, Callable, Union, Type +from typing import TYPE_CHECKING, Callable, Union, Type from ajet.context_tracker.multiagent_tracking import ( MultiAgentContextTracker, diff --git a/ajet/tuner_lib/weight_tuner/__init__.py b/ajet/tuner_lib/weight_tuner/__init__.py index 317e8699..abb540c1 100644 --- a/ajet/tuner_lib/weight_tuner/__init__.py +++ b/ajet/tuner_lib/weight_tuner/__init__.py @@ -1,4 +1,3 @@ from ajet.tuner_lib.weight_tuner.as_agentscope_model import AgentScopeModelTuner from ajet.tuner_lib.weight_tuner.as_oai_sdk_model import OpenaiClientModelTuner - diff --git a/ajet/tuner_lib/weight_tuner/as_agentscope_model.py b/ajet/tuner_lib/weight_tuner/as_agentscope_model.py index 4af1754c..67a5ef8b 100644 --- a/ajet/tuner_lib/weight_tuner/as_agentscope_model.py +++ b/ajet/tuner_lib/weight_tuner/as_agentscope_model.py @@ -1,7 +1,7 @@ -from typing import TYPE_CHECKING, Any, Literal, Type +from typing import Any, Literal, Type from agentscope._utils._common import _create_tool_from_base_model -from agentscope.model import ChatModelBase, ChatResponse, DashScopeChatModel +from agentscope.model import ChatResponse, DashScopeChatModel from loguru import logger from pydantic import BaseModel @@ -10,9 +10,6 @@ ) from ajet.task_rollout.async_llm_bridge import AgentScopeLlmProxyWithTracker -if TYPE_CHECKING: - from ajet import Workflow - class AgentScopeModelTuner(DashScopeChatModel): """ diff --git a/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py index ba3e9693..90c2cc72 100644 --- a/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py +++ b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py @@ -1,22 +1,13 @@ import os -import asyncio -from typing import TYPE_CHECKING, Any, List, Callable, Literal, Type, Union -from loguru import logger +from typing import Any from pydantic import BaseModel, Field from ajet.context_tracker.multiagent_tracking import ( MultiAgentContextTracker, ) -from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker -from ajet.utils.magic_mock import SpecialMagicMock -from openai.types.chat.chat_completion import ChatCompletion -from openai.resources.chat.chat import Chat, AsyncChat +from openai.resources.chat.chat import AsyncChat from openai.resources.completions import AsyncCompletions -from openai import OpenAI, AsyncOpenAI -from ajet.utils.networking import find_free_port from .experimental.as_oai_model_client import generate_auth_token -if TYPE_CHECKING: - from ajet import Workflow class MockAsyncCompletions(AsyncCompletions): async def create(self, *args, **kwargs) -> Any: # type: ignore diff --git a/ajet/tuner_lib/weight_tuner/as_oai_sdk_model.py b/ajet/tuner_lib/weight_tuner/as_oai_sdk_model.py index 943d5c2c..59248fee 100644 --- a/ajet/tuner_lib/weight_tuner/as_oai_sdk_model.py +++ b/ajet/tuner_lib/weight_tuner/as_oai_sdk_model.py @@ -1,19 +1,12 @@ -import asyncio -from typing import TYPE_CHECKING, Any, List, Callable, Literal, Type, Union -from loguru import logger -from pydantic import BaseModel +from typing import Any, List, Callable from ajet.context_tracker.multiagent_tracking import ( MultiAgentContextTracker, ) from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker -from ajet.utils.magic_mock import SpecialMagicMock from openai.types.chat.chat_completion import ChatCompletion -from openai.resources.chat.chat import Chat, AsyncChat +from openai.resources.chat.chat import AsyncChat from openai.resources.completions import AsyncCompletions -from openai import OpenAI, AsyncOpenAI - -if TYPE_CHECKING: - from ajet import Workflow +from openai import AsyncOpenAI class MockAsyncCompletions(AsyncCompletions): @@ -80,5 +73,3 @@ async def create( ) assert isinstance(response_gen, ChatCompletion) return response_gen - - diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 0c652c69..089d11eb 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -24,7 +24,7 @@ from loguru import logger from pydantic import BaseModel -from fastapi import FastAPI, Header, HTTPException, Request, Body +from fastapi import FastAPI, Header, HTTPException, Request from contextlib import asynccontextmanager from multiprocessing import Process from concurrent.futures import ThreadPoolExecutor @@ -239,5 +239,3 @@ def start_interchange_server(config) -> int: # return port return port - - diff --git a/ajet/utils/async_utils.py b/ajet/utils/async_utils.py index 3bb1b67e..219aba9c 100644 --- a/ajet/utils/async_utils.py +++ b/ajet/utils/async_utils.py @@ -67,4 +67,4 @@ def _patched_del(self) -> None: AsyncHttpxClientWrapper.__del__ = _patched_del print("Applied httpx aclose patch.") except ImportError: - pass \ No newline at end of file + pass diff --git a/ajet/utils/lowlevel_hook.py b/ajet/utils/lowlevel_hook.py index bdd536d0..006f17b9 100644 --- a/ajet/utils/lowlevel_hook.py +++ b/ajet/utils/lowlevel_hook.py @@ -44,4 +44,4 @@ def debug_task_init(self, coro, loop=None, name=None, context=None): asyncio.create_task = debug_create_task asyncio.AbstractEventLoop.create_task = debug_loop_create_task -patch_task_creation() \ No newline at end of file +patch_task_creation() diff --git a/ajet/utils/metric_helper/__init__.py b/ajet/utils/metric_helper/__init__.py index a9702d5d..70ce2818 100644 --- a/ajet/utils/metric_helper/__init__.py +++ b/ajet/utils/metric_helper/__init__.py @@ -14,4 +14,4 @@ def update_metrics(context_tracker_arr, metrics:dict): metrics.update(tool_metrics) if reward_metrics: metrics.update(reward_metrics) - return \ No newline at end of file + return diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index b6cf5918..49e069bf 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -11,17 +11,17 @@ - judge_time/ Judge time consumption statistics """ -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any import numpy as np def extract_reward_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[str, Any]]: """ Extract reward_stats from trajectories list. - + Args: trajectories: List of trajectory objects containing workflow_metadata - + Returns: List of reward_stats dictionaries """ @@ -36,10 +36,10 @@ def extract_reward_stats_from_trajectories(trajectories: List[Any]) -> List[Dict def extract_reward_stats_from_cmts(cmts: List[Any]) -> tuple[List[Dict[str, Any]], Dict[str, int]]: """ Extract reward_stats from cmts list and return debug statistics. - + Args: cmts: List of cmt objects containing workflow_metadata - + Returns: Tuple of (reward_stats_list, debug_stats) """ @@ -49,47 +49,47 @@ def extract_reward_stats_from_cmts(cmts: List[Any]) -> tuple[List[Dict[str, Any] 'has_workflow_metadata': 0, 'has_reward_stats': 0, } - + for _cmt in cmts: if hasattr(_cmt, 'workflow_metadata') and _cmt.workflow_metadata: debug_stats['has_workflow_metadata'] += 1 if 'reward_stats' in _cmt.workflow_metadata: debug_stats['has_reward_stats'] += 1 reward_stats_list.append(_cmt.workflow_metadata['reward_stats']) - + return reward_stats_list, debug_stats def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str = "") -> Dict[str, float]: """ Compute SwanLab metrics from reward_stats list. - + Supports two data sources: 1. RM Gallery RewardStats fields (rm_raw, etc.) 2. OpenJudge fields (openjudge_xxx_raw, openjudge_xxx_contribution, etc.) - + Args: reward_stats_list: List of reward_stats dictionaries prefix: Metric name prefix (e.g., "val/" for validation phase) - + Returns: Formatted metrics dictionary ready for SwanLab reporting """ if not reward_stats_list: return {} - + n = len(reward_stats_list) metrics = {} - + # ========== Top-level Scores (General) ========== final_reward_list = [rs.get('final_reward', 0.0) for rs in reward_stats_list] fused_reward_list = [rs.get('fused_reward', 0.0) for rs in reward_stats_list] penalty_list = [rs.get('penalty', 0.0) for rs in reward_stats_list] step_reward_list = [rs.get('step_reward', 0.0) for rs in reward_stats_list] - + # Penalty statistics non_zero_penalties = [p for p in penalty_list if p != 0.0] - + # Top-level metrics metrics[f"{prefix}rewards/final_reward_mean"] = float(np.mean(final_reward_list)) metrics[f"{prefix}rewards/fused_reward_mean"] = float(np.mean(fused_reward_list)) @@ -97,110 +97,110 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str metrics[f"{prefix}rewards/step_reward_mean"] = float(np.mean(step_reward_list)) metrics[f"{prefix}rewards/penalty_count"] = len(non_zero_penalties) metrics[f"{prefix}rewards/penalty_rate"] = len(non_zero_penalties) / n * 100 if n > 0 else 0.0 - + # ========== Detect OpenJudge Usage ========== openjudge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('openjudge_enabled', False)) - + if openjudge_enabled_count > 0: # ========== OpenJudge Metrics ========== metrics[f"{prefix}rewards/openjudge_enabled_rate"] = openjudge_enabled_count / n * 100 - + # Dynamically extract OpenJudge grader fields - # Currently supported graders: report_resolution, trajectory_faithfulness, + # Currently supported graders: report_resolution, trajectory_faithfulness, # rubrics_performance, trajectory_comprehensive, information_gain, action_loop openjudge_graders = [ "report_resolution", - "trajectory_faithfulness", + "trajectory_faithfulness", "rubrics_performance", "trajectory_comprehensive", "information_gain", "action_loop", ] - + for grader_name in openjudge_graders: raw_key = f"openjudge_{grader_name}_raw" contrib_key = f"openjudge_{grader_name}_contribution" - + raw_list = [rs.get(raw_key, 0.0) for rs in reward_stats_list] contrib_list = [rs.get(contrib_key, 0.0) for rs in reward_stats_list] - + # Only report when non-zero values exist if any(v != 0.0 for v in raw_list): metrics[f"{prefix}rewards/openjudge/{grader_name}_raw_mean"] = float(np.mean(raw_list)) if any(v != 0.0 for v in contrib_list): metrics[f"{prefix}rewards/openjudge/{grader_name}_contribution_mean"] = float(np.mean(contrib_list)) - + # OpenJudge time consumption statistics grading_time_list = [rs.get('grading_time', 0.0) for rs in reward_stats_list] if any(v != 0.0 for v in grading_time_list): metrics[f"{prefix}judge_time/openjudge_grading_time_mean"] = float(np.mean(grading_time_list)) metrics[f"{prefix}judge_time/openjudge_grading_time_max"] = float(np.max(grading_time_list)) - + # ========== RM Gallery Metrics ========== # RM Gallery rm_raw_list = [rs.get('rm_raw', 0.0) for rs in reward_stats_list] rm_contribution_list = [rs.get('rm_contribution', 0.0) for rs in reward_stats_list] - + # RefJudge ref_final_raw_list = [rs.get('ref_final_raw', 0.0) for rs in reward_stats_list] ref_citation_raw_list = [rs.get('ref_citation_raw', 0.0) for rs in reward_stats_list] ref_grounding_raw_list = [rs.get('ref_grounding_raw', 0.0) for rs in reward_stats_list] ref_contribution_list = [rs.get('ref_contribution', 0.0) for rs in reward_stats_list] - + # StructureJudge structure_raw_list = [rs.get('structure_raw', 0.0) for rs in reward_stats_list] structure_contribution_list = [rs.get('structure_contribution', 0.0) for rs in reward_stats_list] - + # dimensions/ raw scores metrics[f"{prefix}rewards/dimensions/rm_raw_mean"] = float(np.mean(rm_raw_list)) metrics[f"{prefix}rewards/dimensions/ref_final_raw_mean"] = float(np.mean(ref_final_raw_list)) metrics[f"{prefix}rewards/dimensions/ref_citation_raw_mean"] = float(np.mean(ref_citation_raw_list)) metrics[f"{prefix}rewards/dimensions/ref_grounding_raw_mean"] = float(np.mean(ref_grounding_raw_list)) metrics[f"{prefix}rewards/dimensions/structure_raw_mean"] = float(np.mean(structure_raw_list)) - + # contribution/ weighted contributions metrics[f"{prefix}rewards/contribution/rm_contribution_mean"] = float(np.mean(rm_contribution_list)) metrics[f"{prefix}rewards/contribution/ref_contribution_mean"] = float(np.mean(ref_contribution_list)) metrics[f"{prefix}rewards/contribution/structure_contribution_mean"] = float(np.mean(structure_contribution_list)) - + # Enabled state statistics ref_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('ref_judge_enabled', False)) if ref_judge_enabled_count > 0: metrics[f"{prefix}rewards/ref_judge_enabled_rate"] = ref_judge_enabled_count / n * 100 - + structure_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('structure_judge_enabled', False)) if structure_judge_enabled_count > 0: metrics[f"{prefix}rewards/structure_judge_enabled_rate"] = structure_judge_enabled_count / n * 100 - + # Time consumption statistics rm_time_list = [rs.get('rm_time', 0.0) for rs in reward_stats_list] refstruc_time_list = [rs.get('refstruc_time', 0.0) for rs in reward_stats_list] - + metrics[f"{prefix}judge_time/rm_time_mean"] = float(np.mean(rm_time_list)) metrics[f"{prefix}judge_time/refstruc_time_mean"] = float(np.mean(refstruc_time_list)) - + if rm_time_list: metrics[f"{prefix}judge_time/rm_time_max"] = float(np.max(rm_time_list)) if refstruc_time_list: metrics[f"{prefix}judge_time/refstruc_time_max"] = float(np.max(refstruc_time_list)) - + # ========== General Time Consumption Statistics ========== judge_total_time_list = [rs.get('judge_total_time', 0.0) for rs in reward_stats_list] if any(v != 0.0 for v in judge_total_time_list): metrics[f"{prefix}judge_time/judge_total_time_mean"] = float(np.mean(judge_total_time_list)) metrics[f"{prefix}judge_time/judge_total_time_max"] = float(np.max(judge_total_time_list)) - + return metrics def compute_reward_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, float]: """ Training phase: Extract reward_stats from trajectories and compute metrics. - + Args: trajectories: List of trajectory objects - + Returns: Formatted metrics dictionary """ @@ -211,21 +211,21 @@ def compute_reward_metrics_from_trajectories(trajectories: List[Any]) -> Dict[st def compute_reward_metrics_from_cmts(cmts: List[Any], print_debug: bool = True) -> Dict[str, float]: """ Validation phase: Extract reward_stats from cmts and compute metrics. - + Args: cmts: List of cmt objects print_debug: Whether to print debug information - + Returns: Formatted metrics dictionary (with "val_reward/" prefix) """ reward_stats_list, debug_stats = extract_reward_stats_from_cmts(cmts) - + if print_debug: print(f"\n[DEBUG eval_dataset()] reward_stats statistics:") print(f" - Total cmts count: {debug_stats['total_cmts']}") print(f" - Has workflow_metadata: {debug_stats['has_workflow_metadata']}") print(f" - Has reward_stats: {debug_stats['has_reward_stats']}") print(f" - Extracted samples count: {len(reward_stats_list)}") - + return compute_reward_metrics(reward_stats_list, prefix="val_") diff --git a/ajet/utils/metric_helper/save_trajectory_as_json.py b/ajet/utils/metric_helper/save_trajectory_as_json.py index 0e380abc..344a6ab4 100644 --- a/ajet/utils/metric_helper/save_trajectory_as_json.py +++ b/ajet/utils/metric_helper/save_trajectory_as_json.py @@ -53,4 +53,4 @@ def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): # Print confirmation for evaluation trajectories if prefix != "train": - print(f"Saved trajectory to {traj_file_path}") \ No newline at end of file + print(f"Saved trajectory to {traj_file_path}") diff --git a/ajet/utils/msg_converter.py b/ajet/utils/msg_converter.py index 0437f5ca..46c02128 100644 --- a/ajet/utils/msg_converter.py +++ b/ajet/utils/msg_converter.py @@ -21,8 +21,7 @@ {"role": "user/assistant/system", "content": "..."} """ -import json -from typing import List, Dict, Any, Union +from typing import List, Dict, Any diff --git a/ajet/utils/networking.py b/ajet/utils/networking.py index 9ed29c74..f2fed5ac 100644 --- a/ajet/utils/networking.py +++ b/ajet/utils/networking.py @@ -34,4 +34,4 @@ def get_host_ip(interface=None): except Exception: - return "127.0.0.1" \ No newline at end of file + return "127.0.0.1" diff --git a/ajet/utils/testing_utils.py b/ajet/utils/testing_utils.py index 22be6092..31f006c2 100644 --- a/ajet/utils/testing_utils.py +++ b/ajet/utils/testing_utils.py @@ -11,7 +11,6 @@ from loguru import logger from ajet.utils.dynamic_import import dynamic_import -from ajet.utils.sington import singleton class TestSuccessException(Exception): @@ -19,7 +18,6 @@ class TestSuccessException(Exception): All test is done, end the program early with exception. """ - pass class TestFailException(Exception): @@ -27,7 +25,6 @@ class TestFailException(Exception): Test has failed, end the program early with exception. """ - pass class BaseProbe(object): diff --git a/ajet/utils/thread_executors.py b/ajet/utils/thread_executors.py index 9c8ea634..1ab02baf 100644 --- a/ajet/utils/thread_executors.py +++ b/ajet/utils/thread_executors.py @@ -19,4 +19,4 @@ def __init__(self, max_workers=64): self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor: - return self.executor \ No newline at end of file + return self.executor diff --git a/docs/_toc.yml b/docs/_toc.yml index ffa745f4..7eb76610 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -8,7 +8,7 @@ parts: - file: en/installation.md - file: en/quickstart.md - file: en/tune_your_first_agent.md - + - caption: Example chapters: - file: en/example_math_agent.md @@ -17,13 +17,13 @@ parts: - file: en/example_learning_to_ask.md - file: en/example_frozenlake.md - file: en/example_countdown.md - + - caption: Component chapters: - file: en/workflow.md - file: en/data_pipeline.md - file: en/task_judger.md - + - caption: Deep Dive chapters: - file: en/configuration.md @@ -31,7 +31,7 @@ parts: - file: en/beast_logger.md - file: en/data_generation.md - file: en/example_tracing_feedback_loop.md - + # --- 中文部分 --- - caption: 教程 @@ -40,7 +40,7 @@ parts: - file: zh/installation.md - file: zh/quickstart.md - file: zh/tune_your_first_agent.md - + - caption: 示例 chapters: - file: zh/example_math_agent.md @@ -49,13 +49,13 @@ parts: - file: zh/example_learning_to_ask.md - file: zh/example_frozenlake.md - file: zh/example_countdown.md - + - caption: 组件 chapters: - file: zh/workflow.md - file: zh/data_pipeline.md - file: zh/task_judger.md - + - caption: 深入探索 chapters: - file: zh/configuration.md @@ -63,4 +63,3 @@ parts: - file: zh/beast_logger.md - file: zh/data_generation.md - file: zh/example_tracing_feedback_loop.md - diff --git a/docs/en/debugging_guide.md b/docs/en/debugging_guide.md index 0a938004..ff7563a2 100644 --- a/docs/en/debugging_guide.md +++ b/docs/en/debugging_guide.md @@ -104,4 +104,3 @@ Then, the modified launch.json will be | **VSCode Extension** | Python | Python + Ray Distributed Debugger | | **Launch Mode** | `F5` standard launch (via `launch.json`) | Command line execution with `ajet ... --debug="TAG"` | | **Commandline** | `--backbone=debug` | `--debug="TAG1\|TAG2\|TAG3"` | - diff --git a/docs/en/example_countdown.md b/docs/en/example_countdown.md index e214e4d5..ff8ec4e3 100644 --- a/docs/en/example_countdown.md +++ b/docs/en/example_countdown.md @@ -201,4 +201,3 @@ However, tuning resolves these issues, as shown in the example below: ![After tuning](https://img.alicdn.com/imgextra/i4/O1CN01C3kUnV221zjPi30rd_!!6000000007061-2-tps-1650-730.png) > **Token-level Visualization:** These detailed logs are generated by Beast-Logger. See [Beast-Logger Usage](./beast_logger.md) for more details. - diff --git a/docs/en/example_learning_to_ask.md b/docs/en/example_learning_to_ask.md index c3d4bcf9..d5a17abe 100644 --- a/docs/en/example_learning_to_ask.md +++ b/docs/en/example_learning_to_ask.md @@ -135,7 +135,7 @@ We provide two implmentations of the agent based on AgentScope and langchain: ```python # get the trainable llm llm_info=tuner.as_oai_baseurl_apikey() - + # create the langchain agent llm=ChatOpenAI( base_url=llm_info.base_url, @@ -145,7 +145,7 @@ We provide two implmentations of the agent based on AgentScope and langchain: model=llm, system_prompt=system_prompt, ) - + # build messages and send to the agent msg=[ {"role": x["role"], "content": x["content"]} for x in messages @@ -153,7 +153,7 @@ We provide two implmentations of the agent based on AgentScope and langchain: result = agent.invoke({ "messages": msg, # type: ignore }) - + response = result["messages"][-1].content reward = await reward_fn_with_semaphore(msg, response, truth_action, truth_info) return WorkflowOutput(reward=reward) @@ -221,4 +221,4 @@ Agent: Has itching or reddening appeared around this bite site recently without The question becomes more precise and informative, guiding the user to provide clinically relevant details. -> To learn more about the task and results on larger models, refer to [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441). \ No newline at end of file +> To learn more about the task and results on larger models, refer to [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441). diff --git a/docs/en/hardware_related_solution.md b/docs/en/hardware_related_solution.md index c2cad4a3..9743d384 100644 --- a/docs/en/hardware_related_solution.md +++ b/docs/en/hardware_related_solution.md @@ -17,4 +17,4 @@ This document records a list of **Hardware Related** issues for future reference ```bash export NCCL_NVLS_ENABLE=0 - ``` \ No newline at end of file + ``` diff --git a/docs/en/support_agentscope.md b/docs/en/support_agentscope.md index b3129191..e551e4d9 100644 --- a/docs/en/support_agentscope.md +++ b/docs/en/support_agentscope.md @@ -223,4 +223,3 @@ This article introduce the way to convert different types of ways to convert you else: is_success = False return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer}) ``` - diff --git a/docs/en/support_http.md b/docs/en/support_http.md index 32474904..0bf3ab3d 100644 --- a/docs/en/support_http.md +++ b/docs/en/support_http.md @@ -93,5 +93,3 @@ in this AI era, you can always start from scratch and build your own "high-scrap ... ``` - - diff --git a/docs/en/support_langchain.md b/docs/en/support_langchain.md index 6e645dcc..d1e12890 100644 --- a/docs/en/support_langchain.md +++ b/docs/en/support_langchain.md @@ -84,5 +84,3 @@ This article introduce the way to convert different types of ways to convert you ... ``` - - diff --git a/docs/en/support_oaisdk.md b/docs/en/support_oaisdk.md index 5268ab42..b60b03e3 100644 --- a/docs/en/support_oaisdk.md +++ b/docs/en/support_oaisdk.md @@ -88,6 +88,3 @@ This article introduce the way to convert different types of ways to convert you ... ``` - - - diff --git a/docs/index.md b/docs/index.md index ba98cd7f..5583fa69 100644 --- a/docs/index.md +++ b/docs/index.md @@ -170,4 +170,3 @@ The internal system orchestrates several specialized modules to handle the compl

查看中文文档

完整的中文教程和指南。

--> - diff --git a/docs/javascripts/animations.js b/docs/javascripts/animations.js index 00e3603b..a5dc584a 100644 --- a/docs/javascripts/animations.js +++ b/docs/javascripts/animations.js @@ -399,4 +399,3 @@ }; })(); - diff --git a/docs/javascripts/code-zoom.js b/docs/javascripts/code-zoom.js index e2a08f6d..22d3d624 100644 --- a/docs/javascripts/code-zoom.js +++ b/docs/javascripts/code-zoom.js @@ -1,2 +1 @@ /* Code zoom - placeholder */ - diff --git a/docs/javascripts/responsive.js b/docs/javascripts/responsive.js index 663e371f..d57c4db2 100644 --- a/docs/javascripts/responsive.js +++ b/docs/javascripts/responsive.js @@ -353,4 +353,3 @@ }; })(); - diff --git a/docs/javascripts/search-fix.js b/docs/javascripts/search-fix.js index e8436240..444f2af9 100644 --- a/docs/javascripts/search-fix.js +++ b/docs/javascripts/search-fix.js @@ -1,2 +1 @@ /* Search fix - placeholder */ - diff --git a/docs/javascripts/tabbed-code.js b/docs/javascripts/tabbed-code.js index 880ba944..cfd19559 100644 --- a/docs/javascripts/tabbed-code.js +++ b/docs/javascripts/tabbed-code.js @@ -174,4 +174,3 @@ // Export for manual re-initialization if needed window.initTabbedSets = initTabbedSets; })(); - diff --git a/docs/requirements.txt b/docs/requirements.txt index 968bb898..db4f637c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -11,4 +11,3 @@ pymdown-extensions==10.16.1 # Syntax highlighting Pygments>=2.18.0 - diff --git a/docs/stylesheets/animations.css b/docs/stylesheets/animations.css index 2129b6d8..9d390ff7 100644 --- a/docs/stylesheets/animations.css +++ b/docs/stylesheets/animations.css @@ -875,4 +875,3 @@ img { .duration-fast { animation-duration: var(--rm-transition-fast); } .duration-normal { animation-duration: var(--rm-transition-normal); } .duration-slow { animation-duration: var(--rm-transition-slow); } - diff --git a/docs/stylesheets/feature-cards.css b/docs/stylesheets/feature-cards.css index 03fe0464..5865ca73 100644 --- a/docs/stylesheets/feature-cards.css +++ b/docs/stylesheets/feature-cards.css @@ -540,4 +540,3 @@ .dark { --inline-icon-filter: invert(1) hue-rotate(180deg); } - diff --git a/docs/stylesheets/flowchart.css b/docs/stylesheets/flowchart.css index 175dc123..345b94f1 100644 --- a/docs/stylesheets/flowchart.css +++ b/docs/stylesheets/flowchart.css @@ -400,4 +400,3 @@ font-size: 0.875rem; margin-bottom: 0.5rem; } - diff --git a/docs/stylesheets/jupyter-simple.css b/docs/stylesheets/jupyter-simple.css index 401abf67..864c59bd 100644 --- a/docs/stylesheets/jupyter-simple.css +++ b/docs/stylesheets/jupyter-simple.css @@ -256,4 +256,3 @@ article .cell.markdown ol:last-child { top: 0.75rem; } } - diff --git a/docs/stylesheets/syntax-highlight.css b/docs/stylesheets/syntax-highlight.css index 3c651185..7cfcf6ba 100644 --- a/docs/stylesheets/syntax-highlight.css +++ b/docs/stylesheets/syntax-highlight.css @@ -303,4 +303,3 @@ .dark .codehilite .language-json .nd { color: #79c0ff; } - diff --git a/docs/stylesheets/tuner_v2.md b/docs/stylesheets/tuner_v2.md index c8766e31..c19509cd 100644 --- a/docs/stylesheets/tuner_v2.md +++ b/docs/stylesheets/tuner_v2.md @@ -78,4 +78,4 @@ response = client.chat.completions.create( ) -``` \ No newline at end of file +``` diff --git a/install.sh b/install.sh index bf0400b6..2306bad0 100755 --- a/install.sh +++ b/install.sh @@ -203,7 +203,7 @@ download_binary_and_run_installer() { local _checksum_value # destructure selected archive info into locals - case "$_artifact_name" in + case "$_artifact_name" in "uv-aarch64-apple-darwin.tar.gz") _arch="aarch64-apple-darwin" _zip_ext=".tar.gz" @@ -529,7 +529,7 @@ replace_home() { json_binary_aliases() { local _arch="$1" - case "$_arch" in + case "$_arch" in "aarch64-apple-darwin") echo '{}' ;; @@ -612,7 +612,7 @@ aliases_for_binary() { local _bin="$1" local _arch="$2" - case "$_arch" in + case "$_arch" in "aarch64-apple-darwin") case "$_bin" in *) @@ -793,7 +793,7 @@ select_archive_for_arch() { # try each archive, checking runtime conditions like libc versions # accepting the first one that matches, as it's the best match - case "$_true_arch" in + case "$_true_arch" in "aarch64-apple-darwin") _archive="uv-aarch64-apple-darwin.tar.gz" if [ -n "$_archive" ]; then diff --git a/mkdocs.yml b/mkdocs.yml index 6a06d4ad..a6fa0585 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -147,4 +147,3 @@ extra_javascript: - javascripts/nav-scroll-fix.js - javascripts/animations.js - javascripts/responsive.js - diff --git a/pyproject.toml b/pyproject.toml index aee28b2b..856cddca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,4 +113,4 @@ known_third_party = ["wandb"] [project.urls] -"Homepage" = "https://github.com/modelscope/AgentJet" \ No newline at end of file +"Homepage" = "https://github.com/modelscope/AgentJet" diff --git a/scripts/display_dataset.py b/scripts/display_dataset.py index e3132bc4..6d125e5c 100644 --- a/scripts/display_dataset.py +++ b/scripts/display_dataset.py @@ -1,10 +1,5 @@ import argparse -import glob -import os -import time -from beast_logger import print_list -from huggingface_hub import snapshot_download parser = argparse.ArgumentParser(description="download Hugging Face dataset") parser.add_argument("--target", default="openai/gsm8k", type=str, help="HuggingFace dataset name") diff --git a/tests/bench/benchmark_math/benchmark_math.py b/tests/bench/benchmark_math/benchmark_math.py index 9d8397ca..973f9ea2 100644 --- a/tests/bench/benchmark_math/benchmark_math.py +++ b/tests/bench/benchmark_math/benchmark_math.py @@ -1,5 +1,4 @@ # flake8: noqa -import os import time from ajet.utils.testing_utils import BenchmarkProbe, singleton diff --git a/tests/test_networking.py b/tests/test_networking.py deleted file mode 100644 index 913fc341..00000000 --- a/tests/test_networking.py +++ /dev/null @@ -1,56 +0,0 @@ -import socket -import unittest -import sys -import os -import importlib.util - -# Load the module directly to avoid top-level package import issues -# caused by broken dependencies in other parts of the codebase. -# We are testing a standalone utility, so we don't need the whole app context. -module_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'ajet', 'utils', 'networking.py')) -spec = importlib.util.spec_from_file_location("networking", module_path) -networking = importlib.util.module_from_spec(spec) -spec.loader.exec_module(networking) - -find_free_port = networking.find_free_port -get_host_ip = networking.get_host_ip - -class TestNetworking(unittest.TestCase): - def test_find_free_port(self): - """Test that find_free_port returns a valid integer port.""" - port = find_free_port() - self.assertIsInstance(port, int) - self.assertGreater(port, 0) - self.assertLess(port, 65536) - - # Verify the port is valid to bind to (it should have been released) - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - try: - s.bind(('', port)) - except OSError: - # It's possible the port was taken immediately by another process - # but unlikely in a test environment. - pass - - def test_get_host_ip(self): - """Test that get_host_ip returns a valid IP string.""" - ip = get_host_ip() - self.assertIsInstance(ip, str) - parts = ip.split('.') - self.assertEqual(len(parts), 4) - for part in parts: - if part == 'localhost': - continue - self.assertTrue(part.isdigit(), f"Part {part} is not a digit") - self.assertTrue(0 <= int(part) <= 255) - - def test_get_host_ip_with_interface(self): - """Test get_host_ip with a non-existent interface falls back to default behavior.""" - # This will likely fail the interface specific block and fall back to the connect method - ip = get_host_ip(interface="invalid_interface_XYZ") - self.assertIsInstance(ip, str) - parts = ip.split('.') - self.assertEqual(len(parts), 4) - -if __name__ == '__main__': - unittest.main() diff --git a/tutorial/README.md b/tutorial/README.md index e5811d8d..8e5288a9 100644 --- a/tutorial/README.md +++ b/tutorial/README.md @@ -8,4 +8,4 @@ Explore our rich library of examples to kickstart your journey. - Example Benchmark Tracking System: - https://benchmark.agent-matrix.com/examples \ No newline at end of file + https://benchmark.agent-matrix.com/examples diff --git a/tutorial/example_appworld/appworld.py b/tutorial/example_appworld/appworld.py index d8b647e7..01816e67 100644 --- a/tutorial/example_appworld/appworld.py +++ b/tutorial/example_appworld/appworld.py @@ -1,5 +1,4 @@ from agentscope.message import Msg -from pydantic import Field from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask diff --git a/tutorial/example_appworld/appworld_oai_sdk.py b/tutorial/example_appworld/appworld_oai_sdk.py index 534ec00b..dc18db34 100644 --- a/tutorial/example_appworld/appworld_oai_sdk.py +++ b/tutorial/example_appworld/appworld_oai_sdk.py @@ -1,5 +1,4 @@ from agentscope.message import Msg -from pydantic import Field from ajet import Workflow, WorkflowOutput, WorkflowTask from ajet import AjetTuner diff --git a/tutorial/example_learn2ask/data_preprocess/llm_info_extraction.py b/tutorial/example_learn2ask/data_preprocess/llm_info_extraction.py index 75e76c87..070b1612 100644 --- a/tutorial/example_learn2ask/data_preprocess/llm_info_extraction.py +++ b/tutorial/example_learn2ask/data_preprocess/llm_info_extraction.py @@ -145,4 +145,4 @@ def parse_llm_output(output_str): return result except Exception as e: - return f"Error parsing output: [{repr(output_str)}] error = {str(e)}" \ No newline at end of file + return f"Error parsing output: [{repr(output_str)}] error = {str(e)}" diff --git a/tutorial/example_learn2ask/data_preprocess/message_splitter.py b/tutorial/example_learn2ask/data_preprocess/message_splitter.py index a82506a4..06362b05 100644 --- a/tutorial/example_learn2ask/data_preprocess/message_splitter.py +++ b/tutorial/example_learn2ask/data_preprocess/message_splitter.py @@ -97,4 +97,4 @@ def split_session_to_json_lines(session): json_lines = split_session_to_json_lines(example_session) print("JSON lines output:") for i, line in enumerate(json_lines): - print(f"Line {i + 1}: {line}") \ No newline at end of file + print(f"Line {i + 1}: {line}") diff --git a/tutorial/example_learn2ask/data_preprocess/step1.py b/tutorial/example_learn2ask/data_preprocess/step1.py index d2ba27c6..d4533ffa 100644 --- a/tutorial/example_learn2ask/data_preprocess/step1.py +++ b/tutorial/example_learn2ask/data_preprocess/step1.py @@ -28,14 +28,14 @@ def process_jsonl_file( str: Success message or error information """ progress_file = output_file + ".progress" - + def load_progress(): """Load progress from progress file. Returns set of completed line numbers.""" if os.path.exists(progress_file): with open(progress_file, "r", encoding="utf-8") as f: return set(int(line.strip()) for line in f if line.strip()) return set() - + def process_single_session(args): """Worker function to process a single session.""" line_num, line = args @@ -54,41 +54,41 @@ def process_single_session(args): return line_num, None, f"Warning: Skipping invalid JSON at line {line_num}: {e}" except Exception as e: return line_num, None, f"Warning: Error processing session at line {line_num}: {e}" - + try: # Load previous progress completed_lines = load_progress() if completed_lines: print(f"Resuming from previous progress. {len(completed_lines)} lines already completed.") - + # Read all lines first with open(input_file, "r", encoding="utf-8") as infile: all_lines = list(enumerate(infile, 1)) - + total_lines = len(all_lines) # Filter out already completed lines lines_to_process = [(num, line) for num, line in all_lines if num not in completed_lines] - + if not lines_to_process: print("All lines already processed.") # Clean up progress file if os.path.exists(progress_file): os.remove(progress_file) return f"All lines already processed. Results in {output_file}" - + print(f"Processing {len(lines_to_process)} remaining lines out of {total_lines} total.") - + # State for ordered writing results_buffer = {} # line_num -> processed_lines next_line_to_write = min(num for num, _ in lines_to_process) write_lock = threading.Lock() progress_lock = threading.Lock() - + # Open output file in append mode if resuming, otherwise write mode file_mode = "a" if completed_lines else "w" outfile = open(output_file, file_mode, encoding="utf-8") progress_out = open(progress_file, "a", encoding="utf-8") - + def flush_buffer(): """Write all consecutive completed results from buffer to file.""" nonlocal next_line_to_write @@ -106,28 +106,28 @@ def flush_buffer(): # Skip lines that were already completed or empty while next_line_to_write <= total_lines and next_line_to_write not in dict(lines_to_process): next_line_to_write += 1 - + try: # Process sessions in parallel using ThreadPoolExecutor with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = {executor.submit(process_single_session, item): item[0] for item in lines_to_process} - + for future in as_completed(futures): line_num, processed_lines, error = future.result() if error: print(error) - + with write_lock: results_buffer[line_num] = processed_lines flush_buffer() finally: outfile.close() progress_out.close() - + # Clean up progress file on successful completion if os.path.exists(progress_file): os.remove(progress_file) - + return f"Successfully processed. Results saved to {output_file}" except Exception as e: @@ -177,7 +177,7 @@ def process_session(session, model_call_mode="online_api", max_retries=3, **kwar print(f"Attempt {attempt + 1} failed with exception: {str(e)}") if attempt < max_retries - 1: time.sleep(24) # Shorter wait for testing - + if info_set is None: raise Exception(f"failed to generate {session}") data["info_set"] = info_set @@ -206,4 +206,4 @@ def process_session(session, model_call_mode="online_api", max_retries=3, **kwar model_call_mode=args.model_call_mode, # Additional parameters for API calls ) - ) \ No newline at end of file + ) diff --git a/tutorial/example_learn2ask/data_preprocess/step2.py b/tutorial/example_learn2ask/data_preprocess/step2.py index 849aa510..9d546b0c 100644 --- a/tutorial/example_learn2ask/data_preprocess/step2.py +++ b/tutorial/example_learn2ask/data_preprocess/step2.py @@ -26,7 +26,7 @@ def main(input_file_path, output_file_path): if_keep, info_set, decision = process_message(data) if not if_keep: continue - + new_item = { 'main_query':'[no query]', 'init_messages': data['messages'], @@ -56,4 +56,4 @@ def main(input_file_path, output_file_path): args = parser.parse_args() - main(args.input_file, args.output_file) \ No newline at end of file + main(args.input_file, args.output_file) diff --git a/tutorial/example_learn2ask/learn2ask.md b/tutorial/example_learn2ask/learn2ask.md index d5afd08f..811d37f9 100644 --- a/tutorial/example_learn2ask/learn2ask.md +++ b/tutorial/example_learn2ask/learn2ask.md @@ -99,4 +99,4 @@ The agent's question is more precise and informative, providing two specific and ## Next -To learn more about the task and results on larger models, refer to [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441). \ No newline at end of file +To learn more about the task and results on larger models, refer to [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441). diff --git a/tutorial/example_learn2ask/learn2ask_langchain.py b/tutorial/example_learn2ask/learn2ask_langchain.py index d728ac64..b15d7309 100644 --- a/tutorial/example_learn2ask/learn2ask_langchain.py +++ b/tutorial/example_learn2ask/learn2ask_langchain.py @@ -4,7 +4,6 @@ import asyncio import threading -from agentscope.message import Msg from loguru import logger from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask @@ -174,26 +173,26 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl assert isinstance(messages, list) truth_action = workflow_task.task.metadata["decision_truth"] or "continue" truth_info = workflow_task.task.metadata["info_truth"] - + llm_info=tuner.as_oai_baseurl_apikey() - + llm=ChatOpenAI( base_url=llm_info.base_url, api_key=lambda:llm_info.api_key, ) - + agent=create_agent( model=llm, system_prompt=system_prompt, ) - + msg=[ {"role": x["role"], "content": x["content"]} for x in messages ] result = agent.invoke({ "messages": msg, # type: ignore }) - + response = result["messages"][-1].content reward = await reward_fn_with_semaphore(msg, response, truth_action, truth_info) return WorkflowOutput(reward=reward) diff --git a/tutorial/example_ma_deepresearch/ma_deepresearch.py b/tutorial/example_ma_deepresearch/ma_deepresearch.py index 9eaba34c..d044458b 100644 --- a/tutorial/example_ma_deepresearch/ma_deepresearch.py +++ b/tutorial/example_ma_deepresearch/ma_deepresearch.py @@ -2,13 +2,8 @@ from loguru import logger from pydantic import BaseModel, Field from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask -from openai.types.chat.chat_completion import ChatCompletion -from openai.types.chat import ChatCompletionMessageToolCall -from textwrap import dedent -import json import os -import asyncio import requests diff --git a/tutorial/example_math_agent/math_agent_langchain.py b/tutorial/example_math_agent/math_agent_langchain.py index c47fc355..4c99d240 100644 --- a/tutorial/example_math_agent/math_agent_langchain.py +++ b/tutorial/example_math_agent/math_agent_langchain.py @@ -1,12 +1,6 @@ -from loguru import logger from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask -from openai.types.chat.chat_completion import ChatCompletion -from openai.types.chat import ChatCompletionMessageToolCall from textwrap import dedent -import json -import asyncio -import requests from langchain.agents import create_agent @@ -30,7 +24,7 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl url_and_apikey = tuner.as_oai_baseurl_apikey() base_url = url_and_apikey.base_url api_key = url_and_apikey.api_key - + from langchain_openai import ChatOpenAI llm=ChatOpenAI( base_url=base_url, @@ -40,10 +34,10 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl model=llm, system_prompt=self.system_prompt, ) - + # take out query query = workflow_task.task.main_query - + response = agent.invoke({ "messages": [ { @@ -52,6 +46,6 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl } ], }) - + final_answer = response['messages'][-1].content - return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) \ No newline at end of file + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) diff --git a/tutorial/example_math_agent/math_agent_oai_sdk.py b/tutorial/example_math_agent/math_agent_oai_sdk.py index 8304f14d..24bf47ec 100644 --- a/tutorial/example_math_agent/math_agent_oai_sdk.py +++ b/tutorial/example_math_agent/math_agent_oai_sdk.py @@ -1,4 +1,3 @@ -from loguru import logger from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat import ChatCompletionMessageToolCall diff --git a/tutorial/example_math_agent/math_agent_raw_http.py b/tutorial/example_math_agent/math_agent_raw_http.py index 6608e2be..69dfd949 100644 --- a/tutorial/example_math_agent/math_agent_raw_http.py +++ b/tutorial/example_math_agent/math_agent_raw_http.py @@ -1,11 +1,6 @@ -from loguru import logger from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask -from openai.types.chat.chat_completion import ChatCompletion -from openai.types.chat import ChatCompletionMessageToolCall from textwrap import dedent -import json -import asyncio import requests @@ -57,8 +52,3 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl ) final_answer = response.json()['choices'][0]['message']['content'] return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) - - - - - diff --git a/tutorial/example_werewolves/start.py b/tutorial/example_werewolves/start.py index 554b0977..879b6101 100644 --- a/tutorial/example_werewolves/start.py +++ b/tutorial/example_werewolves/start.py @@ -12,7 +12,7 @@ from agentscope.agent import ReActAgent from agentscope.formatter import DashScopeMultiAgentFormatter, OpenAIMultiAgentFormatter -from agentscope.model import DashScopeChatModel, OpenAIChatModel +from agentscope.model import OpenAIChatModel from loguru import logger from pydantic import Field From 7f2b0174437e31ea8df28ea1fa9dcbc6c0618413 Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Fri, 16 Jan 2026 21:58:31 +0800 Subject: [PATCH 03/56] fix test bench import --- scripts/docker/dockerfile | 27 +++++++--- scripts/docker/dockerfile_trinity | 54 +++++++++++++++++++ .../benchmark_appworld/benchmark_appworld.py | 3 +- .../benchmark_countdown.py | 4 +- .../benchmark_frozenlake.py | 3 +- .../benchmark_learn2ask.py | 3 +- tests/bench/benchmark_math/benchmark_math.py | 3 +- 7 files changed, 84 insertions(+), 13 deletions(-) create mode 100644 scripts/docker/dockerfile_trinity diff --git a/scripts/docker/dockerfile b/scripts/docker/dockerfile index 89675c5f..dcda101d 100644 --- a/scripts/docker/dockerfile +++ b/scripts/docker/dockerfile @@ -8,7 +8,8 @@ FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 WORKDIR /workspace -RUN chmod 1777 /tmp && apt update && apt install -y \ +RUN chmod 1777 /tmp && apt update +RUN apt install -y \ build-essential \ curl git wget vim tmux net-tools \ python3 python3-pip python3-dev python3-venv python3-packaging \ @@ -24,17 +25,29 @@ RUN chmod 1777 /tmp && apt update && apt install -y \ # set uv virtual environment path to a outside-of-workspace dir ENV VIRTUAL_ENV=/opt/venv -# copy the Agentscope-Tuner dir into the workspace -COPY . . +# copy the AgentJets dir into the workspace +COPY pyproject.toml pyproject.toml # Install uv RUN pip install uv # use uv to create a virtual environment and install dependencies -RUN uv venv /opt/venv --python=3.10 && \ - . /opt/venv/bin/activate && \ - uv pip install -e .[verl] && \ - uv pip install flash_attn==2.8.1 --no-deps --no-cache-dir --no-build-isolation +RUN uv venv /opt/venv --python=3.10 + +ENV UV_HTTP_TIMEOUT=9999 + + +# RUN . /opt/venv/bin/activate && uv pip install -e .[verl] -i https://mirrors.aliyun.com/pypi/simple/ +# RUN . /opt/venv/bin/activate && uv pip install flash_attn==2.8.3 --no-deps --no-cache-dir --no-build-isolation + +# for ZH users +RUN . /opt/venv/bin/activate && uv pip install -e .[verl] -i https://mirrors.aliyun.com/pypi/simple/ +COPY flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl +RUN . /opt/venv/bin/activate && uv pip install flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl + + +# cache friendly layer for code changes +COPY . . # set entrypoint to activate the virtual environment ENTRYPOINT ["/bin/bash", "-c", "source /opt/venv/bin/activate && exec \"$@\"", "--"] diff --git a/scripts/docker/dockerfile_trinity b/scripts/docker/dockerfile_trinity new file mode 100644 index 00000000..99e83083 --- /dev/null +++ b/scripts/docker/dockerfile_trinity @@ -0,0 +1,54 @@ +# Build and run the docker image with the following command: +# +# docker build -f scripts/docker/dockerfile_trinity -t ajet:trinity_latest . +# docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data ajet:trinity_latest + + +FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 + +WORKDIR /workspace + +RUN chmod 1777 /tmp && apt update +RUN apt install -y \ + build-essential \ + curl git wget vim tmux net-tools \ + python3 python3-pip python3-dev python3-venv python3-packaging \ + libomp-dev infiniband-diags libibverbs-dev librdmacm-dev rdma-core perftest \ + && rm -rf /var/lib/apt/lists/* \ + && ln -sf /usr/bin/python3 /usr/bin/python \ + && ln -sf /usr/bin/pip3 /usr/bin/pip + +# For aliyun users, set pip source to aliyun mirror +# ENV PIP_INDEX_URL=http://mirrors.cloud.aliyuncs.com/pypi/simple/ +# ENV PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com + +# set uv virtual environment path to a outside-of-workspace dir +ENV VIRTUAL_ENV=/opt/venv + +# copy the AgentJets dir into the workspace +COPY pyproject.toml pyproject.toml + +# Install uv +RUN pip install uv + +# use uv to create a virtual environment and install dependencies +RUN uv venv /opt/venv --python=3.10 + +ENV UV_HTTP_TIMEOUT=9999 + + +# RUN . /opt/venv/bin/activate && uv pip install -e .[verl] -i https://mirrors.aliyun.com/pypi/simple/ +# RUN . /opt/venv/bin/activate && uv pip install flash_attn==2.8.3 --no-deps --no-cache-dir --no-build-isolation + +# for ZH users +RUN . /opt/venv/bin/activate && uv pip install -e .[trinity] -i https://mirrors.aliyun.com/pypi/simple/ +COPY flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl +RUN . /opt/venv/bin/activate && uv pip install flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl + + +# cache friendly layer for code changes +COPY . . + +# set entrypoint to activate the virtual environment +ENTRYPOINT ["/bin/bash", "-c", "source /opt/venv/bin/activate && exec \"$@\"", "--"] +CMD ["bash"] diff --git a/tests/bench/benchmark_appworld/benchmark_appworld.py b/tests/bench/benchmark_appworld/benchmark_appworld.py index 70b440bf..6fc33649 100644 --- a/tests/bench/benchmark_appworld/benchmark_appworld.py +++ b/tests/bench/benchmark_appworld/benchmark_appworld.py @@ -1,7 +1,8 @@ # flake8: noqa import time -from ajet.utils.testing_utils import BenchmarkProbe, singleton +from ajet.utils.testing_utils import BenchmarkProbe +from ajet.utils.sington import singleton @singleton diff --git a/tests/bench/benchmark_countdown/benchmark_countdown.py b/tests/bench/benchmark_countdown/benchmark_countdown.py index fedb48f7..b4bdd56d 100644 --- a/tests/bench/benchmark_countdown/benchmark_countdown.py +++ b/tests/bench/benchmark_countdown/benchmark_countdown.py @@ -1,8 +1,8 @@ # flake8: noqa import time -from ajet.utils.testing_utils import BenchmarkProbe, singleton - +from ajet.utils.testing_utils import BenchmarkProbe +from ajet.utils.sington import singleton @singleton class TestProbe(BenchmarkProbe): diff --git a/tests/bench/benchmark_frozenlake/benchmark_frozenlake.py b/tests/bench/benchmark_frozenlake/benchmark_frozenlake.py index 7eadcf41..58b750e1 100644 --- a/tests/bench/benchmark_frozenlake/benchmark_frozenlake.py +++ b/tests/bench/benchmark_frozenlake/benchmark_frozenlake.py @@ -1,7 +1,8 @@ # flake8: noqa import time -from ajet.utils.testing_utils import BenchmarkProbe, singleton +from ajet.utils.testing_utils import BenchmarkProbe +from ajet.utils.sington import singleton @singleton diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py index fc26b776..7b35631c 100644 --- a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py +++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py @@ -1,7 +1,8 @@ # flake8: noqa import time -from ajet.utils.testing_utils import BenchmarkProbe, singleton +from ajet.utils.testing_utils import BenchmarkProbe +from ajet.utils.sington import singleton # trinity b.b. expectation # [TestProbe] Step 50: local average reward over last self.reward_expectation_avg_window steps: 2.6618, expected range: [0.0, 99999.0] diff --git a/tests/bench/benchmark_math/benchmark_math.py b/tests/bench/benchmark_math/benchmark_math.py index 973f9ea2..a08fa022 100644 --- a/tests/bench/benchmark_math/benchmark_math.py +++ b/tests/bench/benchmark_math/benchmark_math.py @@ -1,7 +1,8 @@ # flake8: noqa import time -from ajet.utils.testing_utils import BenchmarkProbe, singleton +from ajet.utils.testing_utils import BenchmarkProbe +from ajet.utils.sington import singleton @singleton From 9dd3c425f91c3138d8d546d1c922847caa4c2959 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Sat, 17 Jan 2026 22:43:21 +0800 Subject: [PATCH 04/56] refactor(finworld): Replace agent protocol and unify configuration updates MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Renamed ExampleAgentScopeLearnProtocol to ExampleDeepResearchProtocol and modified the execute method signature. - Unified the parameter name of the model tuner to `tuner` and its related attribute references. - Optimized the multi-turn interaction step configuration, changing it to use `tuner.config.ajet.rollout.multi_turn.max_steps`. - Modified the context overflow judgment logic to prevent tool call blocking. - Updated the finworld.yaml configuration, replacing astune with ajet-related configurations, and adjusted the workflow protocol and environment parameters. - Modified the default environment variable values ​​and log saving paths in finworld_judge.py. - Added and improved multi-machine and single-machine startup scripts, supporting dynamic generation of MCP configuration and environment variable loading. - Added the finworld_single.yaml template to adapt to single-machine training configurations. - Adjusted the key reference for multi-turn step configuration in ma_deepresearch.py, using the ajet configuration path. --- .../config/mcp_finance_tool_generated.json | 10 + tutorial/example_finworld/finworld.py | 17 +- tutorial/example_finworld/finworld.yaml | 29 +- tutorial/example_finworld/finworld_judge.py | 8 +- .../scripts/cc_rm4_res2cit2fai2_30b.sh | 384 ++++++++++++++++++ tutorial/example_finworld/scripts/single.sh | 112 +++++ .../ma_deepresearch.py | 2 +- 7 files changed, 533 insertions(+), 29 deletions(-) create mode 100644 tutorial/example_finworld/config/mcp_finance_tool_generated.json create mode 100644 tutorial/example_finworld/scripts/cc_rm4_res2cit2fai2_30b.sh create mode 100644 tutorial/example_finworld/scripts/single.sh diff --git a/tutorial/example_finworld/config/mcp_finance_tool_generated.json b/tutorial/example_finworld/config/mcp_finance_tool_generated.json new file mode 100644 index 00000000..90fbd828 --- /dev/null +++ b/tutorial/example_finworld/config/mcp_finance_tool_generated.json @@ -0,0 +1,10 @@ +{ + "mcpServers": { + "flowllm": { + "transport": "sse", + "url": "http://22.17.31.142:8040/sse", + "timeout": 600, + "sse_read_timeout": 1200 + } + } +} diff --git a/tutorial/example_finworld/finworld.py b/tutorial/example_finworld/finworld.py index 778e3439..f742adfc 100644 --- a/tutorial/example_finworld/finworld.py +++ b/tutorial/example_finworld/finworld.py @@ -11,12 +11,11 @@ # 创建信号量,允许同时12个线程运行 sem = threading.Semaphore(30) -class ExampleAgentScopeLearnProtocol(Workflow): +class ExampleDeepResearchProtocol(Workflow): - trainer: str = Field(default="astune-trinity") - async def agentscope_execute( - self, workflow_task: WorkflowTask, model_tuner: AjetTuner + async def execute( + self, workflow_task: WorkflowTask, tuner: AjetTuner ) -> WorkflowOutput: from agentscope.agent import ReActAgent from agentscope.formatter import DashScopeChatFormatter @@ -43,7 +42,7 @@ async def agentscope_execute( agent = ReActAgent( name="Qwen", sys_prompt=first_msg["content"], # Agent 内部会自动管理 System Prompt - model=model_tuner, + model=tuner.as_agentscope_model(), formatter=DashScopeChatFormatter(), memory=InMemoryMemory(), toolkit=None, @@ -69,10 +68,10 @@ async def agentscope_execute( cumulative_tool_call_time = 0.0 # 累计工具调用时间 cumulative_tool_time = {} # 按工具区分的累计耗时: {tool_name: [time1, time2, ...]} - logger.info(f"开始执行多轮交互,最大步数: {model_tuner.config.astune.rollout.multi_turn.max_steps}") + logger.info(f"开始执行多轮交互,最大步数: {tuner.config.ajet.rollout.multi_turn.max_steps}") step = 0 - for step in range(model_tuner.config.astune.rollout.multi_turn.max_steps): + for step in range(tuner.config.ajet.rollout.multi_turn.max_steps): logger.info(f"=== 步骤 {step + 1} ===") # === Agent 推理 === @@ -92,7 +91,7 @@ async def agentscope_execute( # === 早期终止检查:在调用 env.step() 前检查 context_overflow === # 修复问题:避免 token_overflow 后还继续调用工具导致阻塞 - if model_tuner.get_context_tracker().context_overflow: + if tuner.get_context_tracker().context_overflow: logger.warning(f"上下文溢出,跳过 env.step(),在第 {step + 1} 步立即结束") # 构造一个默认的结束响应 conversation_history.append({ @@ -200,7 +199,7 @@ async def agentscope_execute( logger.info(f"环境返回终止信号,在第 {step + 1} 步结束") break - if model_tuner.get_context_tracker().context_overflow: + if tuner.get_context_tracker().context_overflow: logger.warning(f"上下文溢出,在第 {step + 1} 步结束") break diff --git a/tutorial/example_finworld/finworld.yaml b/tutorial/example_finworld/finworld.yaml index 80ba8188..5be76eac 100644 --- a/tutorial/example_finworld/finworld.yaml +++ b/tutorial/example_finworld/finworld.yaml @@ -1,6 +1,6 @@ # ------------------ 主要配置 ------------------ -astune: - project_name: astune_finprompt +ajet: + project_name: ajet experiment_name: "cc_rm4_res2cit2fai2_30b" judge_llm: qwen-flash judge_concurrency: 10 @@ -10,11 +10,12 @@ astune: citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) rm_weight: 0.4 # RM Gallery 权重 task_judge: - # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) - judge_protocol: tutorial.example_finworld.finworld_judge_by_openjudge->FinWorldJudgeByOpenJudge + judge_type: customized_protocol + judge_protocol: tutorial.example_finworld.finworld_judge->FinWorldJudgeByOpenJudge model: # ✨✨✨✨ 设置待训练的模型 path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507 + # path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-8B trainer_common: nnodes: 8 n_gpus_per_node: 8 @@ -25,9 +26,8 @@ astune: total_epochs: 200 rollout: # ✨✨✨✨ 编写并选择Agent - use_agentscope_protocol: True - agentscope_learn_protocol: tutorial.example_finworld.finworld->ExampleAgentScopeLearnProtocol - agentscope_disable_toolcalls: True + user_workflow: tutorial.example_finworld.finworld->ExampleDeepResearchProtocol + force_disable_toolcalls: True enable_oversample: False tensor_model_parallel_size: 8 num_repeat: 4 @@ -39,6 +39,8 @@ astune: agent_madness_reward: 0.0 multi_turn: max_steps: 6 + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) debug: debug_max_parallel: 64 # 增加并行任务数,充分利用GPU debug_first_n_tasks: 100 # 增加处理的任务数 @@ -56,24 +58,23 @@ astune: training_split: train validation_split: val trainer: - default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/astune/checkpoints/example_finworld//localths/cc_rm4_res2cit2fai2_30b" + default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//localths/cc_rm4_res2cit2fai2_30b" # resume_mode: disable # 禁用自动恢复,从头开始训练 actor_rollout_ref: rollout: tensor_model_parallel_size: 8 - gpu_memory_utilization: 0.8 + gpu_memory_utilization: 0.95 # ------------------ 不需要修改 ------------------ hydra: searchpath: - - file://astune/default_config - - file://astune/default_config/verl # verl only + - file://ajet/default_config + - file://ajet/default_config/verl # verl only - file://external/verl/verl/trainer/config # verl only - - file://astune/default_config/trinity # trinity only + - file://ajet/default_config/trinity # trinity only # ------------------ 不需要修改 ------------------ defaults: - - ppo_trainer # verl inherit 1/2 - verl_default # verl inherit 2/2 - trinity_default # trinity inherit 1/1 - - astune_default + - ajet_default - _self_ diff --git a/tutorial/example_finworld/finworld_judge.py b/tutorial/example_finworld/finworld_judge.py index 9c9518a1..f08b69c4 100644 --- a/tutorial/example_finworld/finworld_judge.py +++ b/tutorial/example_finworld/finworld_judge.py @@ -12,8 +12,6 @@ from ajet.task_judge.base_judge import BaseJudge from ajet.workflow import WorkflowOutput, WorkflowTask # RewardStats 不再使用,OpenJudge 版本直接使用字典存储 -# from tutorial.example_finworld.reward.reward_schema import RewardStats - # 环境变量配置 (RM Gallery) TRAIN_REF_ANS_PATH = os.environ.get("FINWORLD_TRAIN_REF_ANS_PATH", "") VAL_REF_ANS_PATH = os.environ.get("FINWORLD_VAL_REF_ANS_PATH", "") @@ -176,7 +174,7 @@ def _patched_openai_init(self, *args, **kwargs): logging.getLogger("rm_gallery").setLevel(logging.WARNING) api_key = os.environ.get("DASHSCOPE_API_KEY") or os.environ.get("API_KEY") base_url = os.environ.get("BASE_URL") or "https://dashscope.aliyuncs.com/compatible-mode/v1" - llm_name = os.environ.get("RM_LLM") + llm_name = os.environ.get("RM_LLM", "qwen-flash") rm_params = {"is_parallel": True, "enable_thinking": False, "base_url": base_url} # is_parallel=True 让子评估器并行调用LLM if api_key: rm_params["api_key"] = api_key @@ -640,7 +638,7 @@ def _save_rm_log(self, result, query: str, task_id: str): "timestamp": datetime.now().isoformat(), "scores": result.metadata.get("dimension_scores", {}) } - save_dir = "/mnt/data_cpfs/taoshuchang.tsc/deepresearch/ajet/outputs/rm_evaluation_logs" + save_dir = "./outputs/rm_evaluation_logs" os.makedirs(save_dir, exist_ok=True) with open(os.path.join(save_dir, f"rmeval_{datetime.now().strftime('%Y%m%d')}.json"), "a") as f: f.write(json.dumps(log, ensure_ascii=False) + "\n") @@ -754,7 +752,7 @@ def _save_evaluation_log(self, task_id: str, grader_results: Dict[str, List[Any] "reason": score.reason[:200] if hasattr(score, "reason") else "", }) - save_dir = "/mnt/data_cpfs/taoshuchang.tsc/deepresearch/ajet/outputs/openjudge_logs" + save_dir = "./outputs/openjudge_logs" os.makedirs(save_dir, exist_ok=True) log_file = os.path.join(save_dir, f"openjudge_{datetime.now().strftime('%Y%m%d')}.json") diff --git a/tutorial/example_finworld/scripts/cc_rm4_res2cit2fai2_30b.sh b/tutorial/example_finworld/scripts/cc_rm4_res2cit2fai2_30b.sh new file mode 100644 index 00000000..90643a17 --- /dev/null +++ b/tutorial/example_finworld/scripts/cc_rm4_res2cit2fai2_30b.sh @@ -0,0 +1,384 @@ +#!/bin/bash +set -e +#=============================================================================== +# 配置区域 - 用户只需修改这里 +#=============================================================================== +SUFFIX="cc_rm4_res2cit2fai2_30b" # 实验后缀,影响所有日志和实验名称 +PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 + +ADDR="22.17.31.142" +MCP_PORT="8040" +export CONFIG_FILE_NAME="tutorial/example_finworld/finworld.yaml" +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" +#=============================================================================== +# 环境配置区域 +#=============================================================================== + +cd ${AJET_ROOT} +source .venv/bin/activate +# API密钥配置 - 从 .env 文件加载 +ENV_FILE="${AJET_ROOT}/.env" +if [ -f "$ENV_FILE" ]; then + set -a + source "$ENV_FILE" + set +a + echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" +else + echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" +fi + + + +#=============================================================================== +# 环境配置区域 +#=============================================================================== + +# MongoDB 缓存配置 +CACHE_TYPE="mongodb" +MONGO_URI="mongodb://${ADDR}:27117/" +MONGO_DB_NAME="finworld_cache" +MONGO_COLLECTION_NAME="tool_cache" + +# FinWorld MCP 配置 +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" + +# 动态生成 MCP 配置文件(使用 ADDR 变量) +cat > ${FINWORLD_MCP_CONFIG} << EOF +{ + "mcpServers": { + "flowllm": { + "transport": "sse", + "url": "http://${ADDR}:${MCP_PORT}/sse", + "timeout": 600, + "sse_read_timeout": 1200 + } + } +} +EOF +FINWORLD_TOOL_RESULT_MAX_CHARS=10000 + +# 其他服务配置 +HF_ENDPOINT="https://hf-mirror.com" +ES_HOSTS="http://11.160.132.46:8200" + +#=============================================================================== +# 多机训练参数配置 +#=============================================================================== +if [ -z "${WORLD_SIZE}" ]; then + echo "ERROR: WORLD_SIZE environment variable is not set!" + echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" + exit 1 +fi + +NNODES=${WORLD_SIZE} +GPUS_PER_NODE=8 +EXPECTED_WORKERS=$WORLD_SIZE + +#=============================================================================== +# NCCL 配置 +#=============================================================================== +export NCCL_TIMEOUT=1800 +export NCCL_DEBUG=WARN +export NCCL_IB_TIMEOUT=23 +export NCCL_ASYNC_ERROR_HANDLING=1 +# RAY_DEBUG_POST_MORTEM="1" +# DEBUG_TAGS="TAG_A" +#=============================================================================== +# 自动生成的变量(不需要修改) +#=============================================================================== +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") +CONFIG_FILE="${AJET_ROOT}/${CONFIG_FILE_NAME}" + +MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" +ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" + +#=============================================================================== +# 工具函数 +#=============================================================================== +print_green() { + echo -e "\033[32m$1\033[0m" +} + +print_red() { + echo -e "\033[31m$1\033[0m" +} + +log() { + echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" +} + +# 检查所有节点数量(包括head节点) +check_workers() { + local status_output=$(ray status 2>/dev/null) + if [ -z "$status_output" ]; then + echo 0 + return + fi + # 统计 "1 node_" 这种格式的行数 + local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) + if [ "$node_count" -gt 0 ]; then + echo $node_count + return + fi + # 如果方法1失败,尝试统计包含node_的唯一ID + node_count=$(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) + echo $node_count +} + +# 检查GPU资源是否完全就绪 +check_gpu_resources() { + gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) + if [ -z "$gpu_count" ]; then + echo 0 + else + printf "%.0f" "$gpu_count" + fi +} + +#=============================================================================== +# 导出环境变量 +# API密钥相关变量已通过 .env 文件加载并自动导出 (set -a) +#=============================================================================== +export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME +export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS +export HF_ENDPOINT ES_HOSTS +export PYTHONPATH="${AJET_ROOT}:${BEYONDAGENT_ROOT}:${PYTHONPATH}" +export RAY_CLUSTER_MODE="multi_node" + + + +# 配置 finworld 环境服务(供 launcher.py --with-finworld 使用) +# 注意:这里可以自定义 env_service 的启动参数 +export FINWORLD_PATH="${BEYONDAGENT_ROOT}" +# 如果需要传递额外参数,修改下面的命令行参数即可 +# 例如:--env_file_name custom_config --debug true +# FINWORLD_SCRIPT: API密钥会从环境变量继承 +export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${BEYONDAGENT_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} FINWORLD_TASKS_DATA_PATH=${FINWORLD_TASKS_DATA_PATH} FINWORLD_TRAIN_REF_ANS_PATH=${FINWORLD_TRAIN_REF_ANS_PATH} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" + + +#=============================================================================== +# 主流程 +#=============================================================================== +log "开始多机多卡训练: ${SUFFIX}" +log "时间戳: ${CURRENT_TIME}" +log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" +log "配置文件: ${CONFIG_FILE}" + +# 确保日志目录存在 +mkdir -p ${LOG_DIR} + +#=============================================================================== +# Master 节点启动流程 +#=============================================================================== +if [[ $HOSTNAME == *"-master-"* ]]; then + print_green "==> This is MASTER node: $HOSTNAME" + + #--------------------------------------------------------------------------- + # 1. 清理和初始化 + #--------------------------------------------------------------------------- + rm -f "$MASTER_IP_FILE" + print_green "Cleaned old master IP file" + + ray stop --force || true + sleep 3 + print_green "Runtime env configuration created" + + #--------------------------------------------------------------------------- + # 4. 启动 Ray Head 节点(带 runtime_env) + #--------------------------------------------------------------------------- + print_green "Starting Ray head node at $MASTER_ADDR with runtime_env" + ray start --head \ + --node-ip-address $MASTER_ADDR \ + --num-gpus 8 + + print_green "Waiting for Ray head to be fully ready..." + sleep 10 + + if ! ray status > /dev/null 2>&1; then + print_red "ERROR: Ray head failed to start properly" + exit 1 + fi + print_green "Ray head is ready" + + # 写入 Master IP 到共享文件 + echo $MASTER_ADDR > $MASTER_IP_FILE + print_green "Master IP written to $MASTER_IP_FILE: $MASTER_ADDR" + + #--------------------------------------------------------------------------- + # 5. 等待所有 Worker 节点加入 + #--------------------------------------------------------------------------- + print_green "Waiting for all nodes to join the Ray cluster..." + print_green "Expected nodes: $EXPECTED_WORKERS (including head node)" + + TIMEOUT=1000 + INTERVAL=10 + ELAPSED=0 + + while true; do + current_nodes=$(check_workers) + print_green "Current node count: $current_nodes/$EXPECTED_WORKERS" + + if [ "$current_nodes" -ge "$EXPECTED_WORKERS" ]; then + print_green "All nodes have joined the cluster!" + break + fi + + if [ "$ELAPSED" -ge "$TIMEOUT" ]; then + print_red "Timeout waiting for nodes. Only $current_nodes/$EXPECTED_WORKERS nodes joined." + ray status + exit 1 + fi + + sleep $INTERVAL + ELAPSED=$((ELAPSED + INTERVAL)) + done + + #--------------------------------------------------------------------------- + # 6. 等待 GPU 资源就绪 + #--------------------------------------------------------------------------- + print_green "Waiting for GPU resources to be fully available..." + EXPECTED_GPUS=$((WORLD_SIZE * 8)) + GPU_TIMEOUT=300 + GPU_ELAPSED=0 + + while true; do + current_gpus=$(check_gpu_resources) + print_green "Current GPU count: $current_gpus/$EXPECTED_GPUS" + + if [ "$current_gpus" -eq "$EXPECTED_GPUS" ]; then + print_green "All GPUs are available!" + break + fi + + if [ "$GPU_ELAPSED" -ge "$GPU_TIMEOUT" ]; then + print_red "Timeout waiting for GPUs. Only $current_gpus/$EXPECTED_GPUS GPUs available." + ray status + exit 1 + fi + + sleep 5 + GPU_ELAPSED=$((GPU_ELAPSED + 5)) + done + + print_green "Final cluster status before training:" + ray status + + #--------------------------------------------------------------------------- + # 7. 等待 Ray Dashboard 启动 + #--------------------------------------------------------------------------- + print_green "Waiting for Ray dashboard to be ready..." + while ! curl -s http://127.0.0.1:8265 > /dev/null; do + sleep 5 + done + + #--------------------------------------------------------------------------- + # 8. 确认 env_service 启动配置 + #--------------------------------------------------------------------------- + print_green "Environment service will be started by launcher.py --with-finworld" + print_green " FINWORLD_PATH: ${FINWORLD_PATH}" + print_green " FINWORLD_SCRIPT: ${FINWORLD_SCRIPT}" + print_green " Log file: ${ENV_SERVICE_LOG}" + print_green " Note: env_service will load .env internally from its conda environment" + + #--------------------------------------------------------------------------- + # 9. 启动训练任务 + #--------------------------------------------------------------------------- + print_green "Starting training job..." + + + # 激活训练环境 + source .venv/bin/activate + + # 重新导出关键环境变量(conda activate 可能会重置) + # API密钥已通过 .env 加载 + export CACHE_TYPE="${CACHE_TYPE}" + export MONGO_URI="${MONGO_URI}" + export MONGO_DB_NAME="${MONGO_DB_NAME}" + export MONGO_COLLECTION_NAME="${MONGO_COLLECTION_NAME}" + + # 设置训练环境变量 + export RAY_ADDRESS="ray://localhost:10001" + export env_url="http://${MASTER_ADDR}:8080" + export env_type="finworld" + export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" + + # 输出配置信息 + print_green "===================================" + print_green "Training Configuration" + print_green "===================================" + print_green "NNODES: $NNODES" + print_green "GPUS_PER_NODE: $GPUS_PER_NODE" + print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" + print_green "env_url: $env_url" + print_green "RAY_ADDRESS: $RAY_ADDRESS" + print_green "Python: $(which python)" + print_green "训练日志: ${TRAIN_LOG}" + print_green "===================================" + + # 启动训练(多机模式下不需要 --with-ray,因为 Ray 集群已在脚本中手动启动) + # 使用 --with-finworld 让 launcher.py 统一管理 env_service 的启动和生命周期 + python ajet/launcher.py \ + --with-finworld \ + --conf ${CONFIG_FILE} \ + --backbone="verl" \ + 2>&1 | tee ${TRAIN_LOG} + ajet --conf ${CONFIG_FILE} --backbone='verl' + +#=============================================================================== +# Worker 节点启动流程 +#=============================================================================== +else + print_green "==> This is WORKER node: $HOSTNAME" + + #--------------------------------------------------------------------------- + # 1. 等待 Master IP 文件 + #--------------------------------------------------------------------------- + export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" + + while [ ! -f $MASTER_IP_FILE ]; do + print_green "Waiting for master node IP file..." + sleep 5 + done + sleep 2 + + MASTER_ADDR=$(cat $MASTER_IP_FILE) + print_green "Found master node at $MASTER_ADDR" + + #--------------------------------------------------------------------------- + # 2. 连接到 Ray 集群 + #--------------------------------------------------------------------------- + ray stop || true + + MAX_RETRIES=3 + RETRY_COUNT=0 + + while [ $RETRY_COUNT -lt $MAX_RETRIES ]; do + if ray start --address $MASTER_ADDR:6379 --num-gpus 8; then + print_green "Worker node started successfully" + break + fi + + RETRY_COUNT=$((RETRY_COUNT + 1)) + print_red "Failed to start worker node, attempt $RETRY_COUNT of $MAX_RETRIES" + sleep 10 + done + + if [ $RETRY_COUNT -eq $MAX_RETRIES ]; then + print_red "Failed to start worker node after $MAX_RETRIES attempts" + exit 1 + fi + + #--------------------------------------------------------------------------- + # 4. 保持连接状态 + #--------------------------------------------------------------------------- + print_green "Worker node is running, keeping alive..." + while true; do + sleep 60 + if ! ray status > /dev/null 2>&1; then + print_red "Lost connection to Ray cluster, exiting..." + break + fi + done +fi diff --git a/tutorial/example_finworld/scripts/single.sh b/tutorial/example_finworld/scripts/single.sh new file mode 100644 index 00000000..c52120c8 --- /dev/null +++ b/tutorial/example_finworld/scripts/single.sh @@ -0,0 +1,112 @@ +#!/bin/bash +set -e + +#=============================================================================== +# 配置区域 +#=============================================================================== +SUFFIX="cc_rm4_res2cit2fai2_30b_single" # 实验后缀 +PREFIX="open" # 实验前缀 + +ADDR="127.0.0.1" # 单机建议使用回环地址 +MCP_PORT="8040" +export CONFIG_FILE_NAME="tutorial/example_finworld/finworld_single.yaml" +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" +export BEYONDAGENT_ROOT="${AJET_ROOT}" # 假设在同一目录下,若不同请手动修改 + +#=============================================================================== +# 环境初始化 +#=============================================================================== +cd ${AJET_ROOT} + +# 加载 .env +ENV_FILE="${AJET_ROOT}/.env" +if [ -f "$ENV_FILE" ]; then + set -a && source "$ENV_FILE" && set +a + echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" +fi + +# 1. 激活主虚拟环境 (uv) +source .venv/bin/activate + +# 2. 动态获取 Conda 基础路径,用于解决 PTY 找不到 conda 的问题 +CONDA_BASE_PATH=$(conda info --base) + +#=============================================================================== +# 服务与路径配置 +#=============================================================================== +# MongoDB 配置 +export CACHE_TYPE="mongodb" +export MONGO_URI="mongodb://${ADDR}:27117/" +export MONGO_DB_NAME="finworld_cache" +export MONGO_COLLECTION_NAME="tool_cache" + +# FinWorld 配置 +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +mkdir -p ${LOG_DIR} +export FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" +export FINWORLD_TOOL_RESULT_MAX_CHARS=10000 + +# 动态生成 MCP 配置 +cat > ${FINWORLD_MCP_CONFIG} << EOF +{ + "mcpServers": { + "flowllm": { + "transport": "sse", + "url": "http://${ADDR}:${MCP_PORT}/sse", + "timeout": 600, + "sse_read_timeout": 1200 + } + } +} +EOF + +# 环境变量导出 +export HF_ENDPOINT="https://hf-mirror.com" +export ES_HOSTS="http://11.160.132.46:8200" +export PYTHONPATH="${AJET_ROOT}:${BEYONDAGENT_ROOT}:${PYTHONPATH}" +export RAY_CLUSTER_MODE="single_node" + +# 关键修复:在脚本中显式加载 conda.sh 以供 PTY 子进程使用 +export FINWORLD_PATH="${BEYONDAGENT_ROOT}" +export FINWORLD_SCRIPT="source ${CONDA_BASE_PATH}/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${BEYONDAGENT_ROOT} && python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" + +#=============================================================================== +# 启动 Ray 本地集群 +#=============================================================================== +echo -e "\033[32m正在初始化单机 Ray 环境...\033[0m" +ray stop --force || true +sleep 2 + +# 启动单机 Head 节点,分配 8 张 GPU +ray start --head --num-gpus 8 + +# 等待 Ray 就绪 +sleep 5 +if ! ray status > /dev/null 2>&1; then + echo -e "\033[31m错误: Ray 启动失败\033[0m" + exit 1 +fi + +#=============================================================================== +# 启动训练 +#=============================================================================== +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") +CONFIG_FILE="${AJET_ROOT}/${CONFIG_FILE_NAME}" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" + +# 设置训练所需的运行时变量 +export RAY_ADDRESS="ray://localhost:10001" +export env_url="http://127.0.0.1:8080" +export env_type="finworld" + +echo -e "\033[32m===================================\033[0m" +echo -e "\033[32m开始单机运行: ${SUFFIX}\033[0m" +echo -e "\033[32m日志文件: ${TRAIN_LOG}\033[0m" +echo -e "\033[32m===================================\033[0m" + +# 启动 Launcher +python ajet/launcher.py \ + --with-finworld \ + --conf ${CONFIG_FILE} \ + --backbone="verl" \ + 2>&1 | tee ${TRAIN_LOG} \ No newline at end of file diff --git a/tutorial/example_ma_deepresearch/ma_deepresearch.py b/tutorial/example_ma_deepresearch/ma_deepresearch.py index 9eaba34c..9b84736b 100644 --- a/tutorial/example_ma_deepresearch/ma_deepresearch.py +++ b/tutorial/example_ma_deepresearch/ma_deepresearch.py @@ -47,7 +47,7 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl init_messages=init_messages, task_id=workflow_task.task.task_id, main_query=workflow_task.task.main_query, - max_steps=tuner.config.astune.rollout.multi_turn.max_steps, + max_steps=tuner.config.ajet.rollout.multi_turn.max_steps, env_service_url=workflow_task.gym_env.service_url, ) From 757f8a197c74d4dcb028f4e54894dc128344ee2c Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Sun, 18 Jan 2026 19:32:57 +0800 Subject: [PATCH 05/56] feat(finworld): Added FinWorld training environment configuration scripts and templates - Added bash startup scripts for multi-machine, multi-GPU training, supporting dynamic configuration generation and environment variable import. - Implemented training configuration file templates, supporting automatic injection of various weight parameters and model paths. - Adjusted the default request timeout of EnvClient from 30 seconds to 300 seconds to accommodate long training requests. - Added a new finworld example directory and related documentation, improving the example project structure. --- .../utils/env_service_client/env_client_ng.py | 2 +- tutorial/example_finworld/finworld.md | 1 + .../example_finworld/scripts/ajet_finworld.sh | 245 ++++++++++++++++++ .../yaml_template/finworld_template.yaml | 79 ++++++ 4 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 tutorial/example_finworld/finworld.md create mode 100644 tutorial/example_finworld/scripts/ajet_finworld.sh create mode 100644 tutorial/example_finworld/yaml_template/finworld_template.yaml diff --git a/ajet/utils/env_service_client/env_client_ng.py b/ajet/utils/env_service_client/env_client_ng.py index bee86619..a8e1112f 100644 --- a/ajet/utils/env_service_client/env_client_ng.py +++ b/ajet/utils/env_service_client/env_client_ng.py @@ -49,7 +49,7 @@ def retry_call( class EnvClient: def __init__(self, base_url: str = "http://localhost:8000"): self.base_url = base_url.rstrip("/") - self.timeout = 30.0 + self.timeout = 300.0 def _make_request( self, diff --git a/tutorial/example_finworld/finworld.md b/tutorial/example_finworld/finworld.md new file mode 100644 index 00000000..e884e864 --- /dev/null +++ b/tutorial/example_finworld/finworld.md @@ -0,0 +1 @@ +# finworld \ No newline at end of file diff --git a/tutorial/example_finworld/scripts/ajet_finworld.sh b/tutorial/example_finworld/scripts/ajet_finworld.sh new file mode 100644 index 00000000..d3d03c61 --- /dev/null +++ b/tutorial/example_finworld/scripts/ajet_finworld.sh @@ -0,0 +1,245 @@ +#!/bin/bash +set -e +#=============================================================================== +# 配置区域 - 用户只需修改这里 +#=============================================================================== +SUFFIX="ajet_finworld" # 实验后缀,影响所有日志和实验名称 +PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 + +# 新增:模型与模板配置 +MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507" +CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" + +# 新增:奖励权重与 Judge 配置 +JUDGE_LLM='qwen-flash' +judge_concurrency=10 +RM_WEIGHT=0.4 +CITATION_AUDIT_WEIGHT=0.2 +report_resolution_weight=0.2 +trajectory_faithfulness_weight=0.2 + +DASHSCOPE_API_KEY="***REMOVED***" # yutai +RM_LLM='qwen-max' +# 配置 +NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 +TRAIN_BATCH_SIZE=32 +NUM_STEPS=6 # 每个样本step轮数 + +ADDR="22.17.31.142" +MCP_PORT="8040" + +# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" +CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" + +#=============================================================================== +# 环境配置区域 +#=============================================================================== + +cd ${AJET_ROOT} +source .venv/bin/activate +# API密钥配置 - 从 .env 文件加载 +ENV_FILE="${AJET_ROOT}/.env" +if [ -f "$ENV_FILE" ]; then + set -a + source "$ENV_FILE" + set +a + echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" +else + echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" +fi + +# MongoDB 缓存配置 +CACHE_TYPE="mongodb" +MONGO_URI="mongodb://${ADDR}:27117/" +MONGO_DB_NAME="finworld_cache" +MONGO_COLLECTION_NAME="tool_cache" + +# FinWorld MCP 配置 +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" + +# 动态生成 MCP 配置文件 +mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) +cat > ${FINWORLD_MCP_CONFIG} << EOF +{ + "mcpServers": { + "flowllm": { + "transport": "sse", + "url": "http://${ADDR}:${MCP_PORT}/sse", + "timeout": 600, + "sse_read_timeout": 1200 + } + } +} +EOF +FINWORLD_TOOL_RESULT_MAX_CHARS=10000 + +# 其他服务配置 +HF_ENDPOINT="https://hf-mirror.com" +ES_HOSTS="http://11.160.132.46:8200" + +#=============================================================================== +# 多机训练参数配置 +#=============================================================================== +if [ -z "${WORLD_SIZE}" ]; then + echo "ERROR: WORLD_SIZE environment variable is not set!" + echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" + exit 1 +fi + +NNODES=${WORLD_SIZE} +GPUS_PER_NODE=8 +EXPECTED_WORKERS=$WORLD_SIZE + +#=============================================================================== +# NCCL 配置 +#=============================================================================== +export NCCL_TIMEOUT=1800 +export NCCL_DEBUG=WARN +export NCCL_IB_TIMEOUT=23 +export NCCL_ASYNC_ERROR_HANDLING=1 + +#=============================================================================== +# 自动生成的变量 +#=============================================================================== +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") + +MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" +ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" + +#=============================================================================== +# 工具函数 +#=============================================================================== +print_green() { + echo -e "\033[32m$1\033[0m" +} + +print_red() { + echo -e "\033[31m$1\033[0m" +} + +log() { + echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" +} + +check_workers() { + local status_output=$(ray status 2>/dev/null) + if [ -z "$status_output" ]; then echo 0; return; fi + local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) + if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi + echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) +} + +check_gpu_resources() { + gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) + if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi +} + +#=============================================================================== +# 导出环境变量 +#=============================================================================== +export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME +export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS +export HF_ENDPOINT ES_HOSTS +export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" +export RAY_CLUSTER_MODE="multi_node" + +export FINWORLD_PATH="${AJET_ROOT}" # AgentJet 内部可能使用此路径 +export FINWORLD_SCRIPT="source .venv/bin/activate && cd ${AJET_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" + +#=============================================================================== +# 主流程 +#=============================================================================== +log "开始多机多卡训练: ${SUFFIX}" +log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" +mkdir -p ${LOG_DIR} +mkdir -p $(dirname ${CONFIG_FILE}) + +#=============================================================================== +# Master 节点启动流程 +#=============================================================================== +if [[ $HOSTNAME == *"-master-"* ]]; then + print_green "==> This is MASTER node: $HOSTNAME" + + #--------------------------------------------------------------------------- + # 1. 动态生成配置文件 (从模板注入参数) + #--------------------------------------------------------------------------- + log "正在从模板生成配置文件..." + sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ + -e "s|{{PREFIX}}|${PREFIX}|g" \ + -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ + -e "s|{{NNODES}}|${NNODES}|g" \ + -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ + -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{JUDGE_LLM}}|${JUDGE_LLM}|g" \ + -e "s|{{JUDGE_CONCURRENCY}}|${judge_concurrency}|g" \ + -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${report_resolution_weight}|g" \ + -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${trajectory_faithfulness_weight}|g" \ + ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} + + print_green "配置文件已生成: ${CONFIG_FILE}" + print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, Judge=${JUDGE_LLM}" + + #--------------------------------------------------------------------------- + # 2. 清理和初始化 Ray + #--------------------------------------------------------------------------- + rm -f "$MASTER_IP_FILE" + ray stop --force || true + sleep 3 + + #--------------------------------------------------------------------------- + # 4. 启动 Ray Head + #--------------------------------------------------------------------------- + print_green "Starting Ray head node at $MASTER_ADDR" + ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 + sleep 10 + echo $MASTER_ADDR > $MASTER_IP_FILE + + #--------------------------------------------------------------------------- + # 5 & 6. 等待节点和 GPU 就绪 (逻辑保持不变) + #--------------------------------------------------------------------------- + # ... (此处省略重复的等待逻辑以保持简洁,实际运行时请保留原脚本中的 while 循环) ... + # [请保留原脚本中 5.等待所有Worker 6.等待GPU 7.等待Dashboard 的完整代码] + + #--------------------------------------------------------------------------- + # 9. 启动训练任务 + #--------------------------------------------------------------------------- + print_green "Starting training job..." + source .venv/bin/activate + + export RAY_ADDRESS="ray://localhost:10001" + export env_url="http://${MASTER_ADDR}:8080" + export env_type="finworld" + + print_green "===================================" + print_green "Training Configuration" + print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" + print_green "Log: ${TRAIN_LOG}" + print_green "===================================" + + # 修改:同步 cc_rm4 的启动参数,增加 debug 和 log-suffix + python ajet/launcher.py \ + --with-finworld \ + --conf ${CONFIG_FILE} \ + --backbone="verl" \ + --debug="TAG_A" \ + --log-suffix="${SUFFIX}" \ + 2>&1 | tee ${TRAIN_LOG} + + # 保留原脚本末尾的 CLI 调用 + ajet --conf ${CONFIG_FILE} --backbone='verl' + +#=============================================================================== +# Worker 节点启动流程 (逻辑保持不变) +#=============================================================================== +else + print_green "==> This is WORKER node: $HOSTNAME" + # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] + while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done + MASTER_ADDR=$(cat $MASTER_IP_FILE) + ray stop || true + ray start --address $MASTER_ADDR:6379 --num-gpus 8 + while true; do sleep 60; done +fi \ No newline at end of file diff --git a/tutorial/example_finworld/yaml_template/finworld_template.yaml b/tutorial/example_finworld/yaml_template/finworld_template.yaml new file mode 100644 index 00000000..14fe6194 --- /dev/null +++ b/tutorial/example_finworld/yaml_template/finworld_template.yaml @@ -0,0 +1,79 @@ +# ------------------ 主要配置 ------------------ +astune: + project_name: astune_finprompt + experiment_name: "{{SUFFIX}}" + judge_llm: {{JUDGE_LLM}} + judge_concurrency: {{JUDGE_CONCURRENCY}} + # OpenJudge 权重配置 + report_resolution_weight: {{REPORT_RESOLUTION_WEIGHT}} # 报告质量评估 + trajectory_faithfulness_weight: {{TRAJECTORY_FAITHFULNESS_WEIGHT}} # 事实准确性评估 + citation_audit_weight: {{CITATION_AUDIT_WEIGHT}} # 引用审计评估 (覆盖率 + 真实性) + rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 + task_judge: + # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_finworld.finworld_judge_by_openjudge->FinWorldJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: {{MODEL_PATH}} + trainer_common: + nnodes: {{NNODES}} + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 8 + save_freq: 10 + test_freq: 2 + total_epochs: 200 + rollout: + # ✨✨✨✨ 编写并选择Agent + use_agentscope_protocol: True + agentscope_learn_protocol: tutorial.example_finworld.finworld->ExampleAgentScopeLearnProtocol + agentscope_disable_toolcalls: True + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: {{NUM_REPEAT}} + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_env_len: 10000 + max_response_length_in_one_turn: 8000 + max_model_len: 50000 + agent_madness_reward: 0.0 + multi_turn: + max_steps: {{NUM_STEPS}} + debug: + debug_max_parallel: 64 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + train_batch_size: {{TRAIN_BATCH_SIZE}} + max_prompt_length: 8000 + max_response_length: 41000 + + task_reader: + type: env_service # `env_service` or `dataset_file` or `huggingface_dat_repo` + env_service: + env_type: "finworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: val +trainer: + default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/astune/checkpoints/example_finworld//{{PREFIX}}/{{SUFFIX}}" + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://astune/default_config + - file://astune/default_config/verl # verl only + - file://external/verl/verl/trainer/config # verl only + - file://astune/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - ppo_trainer # verl inherit 1/2 + - verl_default # verl inherit 2/2 + - trinity_default # trinity inherit 1/1 + - astune_default + - _self_ From 079e4bd48d19edf9612c2a1bbb94c4dc0132c52a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=B6=E8=88=92=E7=95=85?= Date: Sun, 18 Jan 2026 20:39:36 +0800 Subject: [PATCH 06/56] refactor(utils): Remove unused extract and compute functions `extract_tool_stats_from_cmts` --- .../utils/metric_helper/tool_metric_helper.py | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py index e9c7728d..51a488b8 100644 --- a/ajet/utils/metric_helper/tool_metric_helper.py +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -33,23 +33,6 @@ def extract_tool_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[s return tool_stats_list -def extract_tool_stats_from_cmts(cmts: List[Any]) -> List[Dict[str, Any]]: - """ - Extract tool_stats from cmts list. - - Args: - cmts: List of cmt objects containing workflow_metadata - - Returns: - List of tool_stats dictionaries - """ - tool_stats_list = [] - for traj in trajs: - if hasattr(traj, 'workflow_metadata') and traj.workflow_metadata: - if 'tool_stats' in traj.workflow_metadata: - tool_stats_list.append(traj.workflow_metadata['tool_stats']) - return tool_stats_list - def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "") -> Dict[str, float]: """ @@ -159,9 +142,3 @@ def compute_tool_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, return compute_tool_metrics(tool_stats_list, prefix="train_") -def compute_tool_metrics_from_cmts(cmts: List[Any]) -> Dict[str, float]: - """ - Validation phase: Extract tool_stats from cmts and compute metrics. - """ - tool_stats_list = extract_tool_stats_from_cmts(cmts) - return compute_tool_metrics(tool_stats_list, prefix="val_") From bcce8f04c1fc0df18441b352f75cc7845ad0a6f5 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Sun, 18 Jan 2026 23:22:36 +0800 Subject: [PATCH 07/56] refactor(finworld): Replace the old model with OpenJudge, update evaluation configuration and scripts - Replaced model initialization in FinWorldJudgeByOpenJudge with the `_init_openjudge_model` method - Read Judge model parameters from the configuration file first, using environment variables as a fallback - Optimized RM Gallery initialization, using configuration-first logic, and improved exception stack trace printing - Cleaned up and removed the old `_init_model` singleton method and related code - Updated the example startup script `ajet_finworld.sh`, adding OPENJUDGE_LLM and RM_LLM configurations - Modified YAML templates and configuration files to unify the structure and field naming of Judge configuration items - Deleted the outdated `cc_rm4_res2cit2fai2_30b.sh` script - Adjusted the `env_service` startup path to improve environment activation compatibility - Adjusted script log output format and content to enhance the clarity of configuration parameter printing --- tutorial/example_finworld/finworld_judge.py | 78 ++-- .../example_finworld/scripts/ajet_finworld.sh | 39 +- .../scripts/cc_rm4_res2cit2fai2_30b.sh | 384 ------------------ .../yaml/finworld_ajet_finworld.yaml | 82 ++++ .../yaml_template/finworld_template.yaml | 29 +- 5 files changed, 163 insertions(+), 449 deletions(-) delete mode 100644 tutorial/example_finworld/scripts/cc_rm4_res2cit2fai2_30b.sh create mode 100644 tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml diff --git a/tutorial/example_finworld/finworld_judge.py b/tutorial/example_finworld/finworld_judge.py index f08b69c4..02bb8855 100644 --- a/tutorial/example_finworld/finworld_judge.py +++ b/tutorial/example_finworld/finworld_judge.py @@ -6,17 +6,13 @@ import json import asyncio import time +import logging from datetime import datetime from typing import Dict, Any, Optional, Tuple, List from ajet.task_judge.base_judge import BaseJudge from ajet.workflow import WorkflowOutput, WorkflowTask -# RewardStats 不再使用,OpenJudge 版本直接使用字典存储 -# 环境变量配置 (RM Gallery) -TRAIN_REF_ANS_PATH = os.environ.get("FINWORLD_TRAIN_REF_ANS_PATH", "") -VAL_REF_ANS_PATH = os.environ.get("FINWORLD_VAL_REF_ANS_PATH", "") -# OpenJudge imports from openjudge.graders.agent.action.action_loop import ActionLoopDetectionGrader from openjudge.graders.agent.observation.observation_information_gain import ( ObservationInformationGainGrader, @@ -41,6 +37,12 @@ ) +# RewardStats 不再使用,OpenJudge 版本直接使用字典存储 +# 环境变量配置 (RM Gallery) +TRAIN_REF_ANS_PATH = os.environ.get("FINWORLD_TRAIN_REF_ANS_PATH", "") +VAL_REF_ANS_PATH = os.environ.get("FINWORLD_VAL_REF_ANS_PATH", "") + +# OpenJudge imports # ============================================================================= # 全局辅助函数 # ============================================================================= @@ -107,7 +109,7 @@ class FinWorldJudgeByOpenJudge(BaseJudge): def __init__(self, config): super().__init__(config) self._setup_weights() - self._init_model() # 只初始化 model,runner 在每次调用时创建 + self._init_openjudge_model() # 只初始化 model,runner 在每次调用时创建 self._init_rm_components() # 初始化 RM Gallery 组件 self._init_reference_answers() # 初始化参考答案 @@ -146,6 +148,27 @@ def _setup_weights(self): self.w[k] = self.w[k] / total + def _init_openjudge_model(self): + """初始化 OpenJudge LLM Model""" + # --- model name from config.ajet.judge.* --- + openjudge_model_name = self.config.ajet.judge.openjudge_llm + openjudge_base_url = os.environ.get("OPENJUDGE_BASE_URL") + openjudge_api_key = os.environ.get("OPENJUDGE_API_KEY") + + self._model_instance = OpenAIChatModel( + model=openjudge_model_name, + base_url=openjudge_base_url, + api_key=openjudge_api_key, + ) + # 设置实例变量供 _create_runner_in_loop 使用 + self.model = self._model_instance + self.max_concurrency = getattr(self.config.ajet.judge, "concurrency", 6) + + print( + f"[Init OpenJudge Model] model={openjudge_model_name}, base_url={openjudge_base_url}, " + f"api_key={'SET' if openjudge_api_key else 'NONE'}, max_concurrency={self.max_concurrency}" + ) + def _init_rm_components(self): """初始化 RM Gallery Evaluator(仅当 rm_weight > 0 时)""" self._rm_enabled = (self.w.get("rm", 0) > 0) @@ -172,19 +195,24 @@ def _patched_openai_init(self, *args, **kwargs): from rm_gallery.core.reward.registry import RewardRegistry import logging logging.getLogger("rm_gallery").setLevel(logging.WARNING) - api_key = os.environ.get("DASHSCOPE_API_KEY") or os.environ.get("API_KEY") - base_url = os.environ.get("BASE_URL") or "https://dashscope.aliyuncs.com/compatible-mode/v1" - llm_name = os.environ.get("RM_LLM", "qwen-flash") - rm_params = {"is_parallel": True, "enable_thinking": False, "base_url": base_url} # is_parallel=True 让子评估器并行调用LLM - if api_key: rm_params["api_key"] = api_key + # 从 config 读取 rm_llm,环境变量作为 fallback + rm_llm_name = self.config.ajet.judge.rm_llm + rm_api_key = os.environ.get("RM_API_KEY") + rm_base_url = os.environ.get("RM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + + rm_params = {"is_parallel": True, "enable_thinking": False, "base_url": rm_base_url} + if rm_api_key: + rm_params["api_key"] = rm_api_key self.rm_evaluator = RewardRegistry.get("finance_composition")( - llm=llm_name, name="finance_composition", params=rm_params + llm=rm_llm_name, name="finance_composition", params=rm_params ) - print(f"✓ RM evaluator initialized: {llm_name} {base_url} (timeout=600s)") + print(f"[Init RM Evaluator] llm={rm_llm_name}, base_url={rm_base_url}, api_key={'SET' if rm_api_key else 'NONE'} (timeout=600s)") except Exception as e: print(f"✗ Failed to initialize RM evaluator: {e}") + import traceback + traceback.print_exc() self.rm_evaluator = None def _init_reference_answers(self): @@ -206,29 +234,7 @@ def _get_reference_data(self, task_id: str) -> Tuple[str, str]: dom = FinWorldJudgeByOpenJudge._ref_domains_cache.get(cache_key, {}).get(task_id) return ans, dom - def _init_model(self): - """初始化 OpenJudge LLM Model(单例模式,可复用)""" - if FinWorldJudgeByOpenJudge._model_instance is None: - try: - model_name = getattr(self.config.ajet, "judge_llm", "qwen-flash") if hasattr(self.config, "ajet") else "qwen-flash" - base_url = os.environ.get("JUDGE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") - api_key = os.environ.get("JUDGE_API_KEY", os.environ.get("DASHSCOPE_API_KEY", None)) - FinWorldJudgeByOpenJudge._model_instance = OpenAIChatModel( - model=model_name, - temperature=0.0, - base_url=base_url, - api_key=api_key - ) - print(f"✓ OpenJudge Model initialized: {model_name} @ {base_url}: {api_key}") - except Exception as e: - print(f"✗ Failed to initialize OpenJudge Model: {e}") - import traceback - traceback.print_exc() - raise - - self.model = FinWorldJudgeByOpenJudge._model_instance - self.max_concurrency = getattr(self.config.ajet, "judge_concurrency", 6) if hasattr(self.config, "ajet") else 6 - + def _create_runner_in_loop(self) -> GradingRunner: """ 在当前事件循环中创建 GradingRunner diff --git a/tutorial/example_finworld/scripts/ajet_finworld.sh b/tutorial/example_finworld/scripts/ajet_finworld.sh index d3d03c61..d417c7cf 100644 --- a/tutorial/example_finworld/scripts/ajet_finworld.sh +++ b/tutorial/example_finworld/scripts/ajet_finworld.sh @@ -10,16 +10,18 @@ PREFIX="open" # 实验前缀,影响日志和实验所 MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507" CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" -# 新增:奖励权重与 Judge 配置 -JUDGE_LLM='qwen-flash' -judge_concurrency=10 +# 新增:Judge 模型配置 +OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 +RM_LLM='qwen-max' # RM Gallery 评分模型 +JUDGE_CONCURRENCY=10 + +# 新增:奖励权重配置 RM_WEIGHT=0.4 CITATION_AUDIT_WEIGHT=0.2 -report_resolution_weight=0.2 -trajectory_faithfulness_weight=0.2 +REPORT_RESOLUTION_WEIGHT=0.2 +TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 -DASHSCOPE_API_KEY="***REMOVED***" # yutai -RM_LLM='qwen-max' +# API密钥配置(从 .env 文件加载,不要硬编码) # 配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 TRAIN_BATCH_SIZE=32 @@ -145,9 +147,11 @@ export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS export HF_ENDPOINT ES_HOSTS export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" export RAY_CLUSTER_MODE="multi_node" +# Directory paths +export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" -export FINWORLD_PATH="${AJET_ROOT}" # AgentJet 内部可能使用此路径 -export FINWORLD_SCRIPT="source .venv/bin/activate && cd ${AJET_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" +export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 +export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" #=============================================================================== # 主流程 @@ -173,14 +177,18 @@ if [[ $HOSTNAME == *"-master-"* ]]; then -e "s|{{NNODES}}|${NNODES}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ - -e "s|{{JUDGE_LLM}}|${JUDGE_LLM}|g" \ - -e "s|{{JUDGE_CONCURRENCY}}|${judge_concurrency}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${report_resolution_weight}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${trajectory_faithfulness_weight}|g" \ + -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ + -e "s|{{RM_LLM}}|${RM_LLM}|g" \ + -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ + -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ + -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ + -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ + -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ + -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} print_green "配置文件已生成: ${CONFIG_FILE}" - print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, Judge=${JUDGE_LLM}" + print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" #--------------------------------------------------------------------------- # 2. 清理和初始化 Ray @@ -219,13 +227,12 @@ if [[ $HOSTNAME == *"-master-"* ]]; then print_green "Log: ${TRAIN_LOG}" print_green "===================================" - # 修改:同步 cc_rm4 的启动参数,增加 debug 和 log-suffix + # 启动训练任务 python ajet/launcher.py \ --with-finworld \ --conf ${CONFIG_FILE} \ --backbone="verl" \ --debug="TAG_A" \ - --log-suffix="${SUFFIX}" \ 2>&1 | tee ${TRAIN_LOG} # 保留原脚本末尾的 CLI 调用 diff --git a/tutorial/example_finworld/scripts/cc_rm4_res2cit2fai2_30b.sh b/tutorial/example_finworld/scripts/cc_rm4_res2cit2fai2_30b.sh deleted file mode 100644 index 90643a17..00000000 --- a/tutorial/example_finworld/scripts/cc_rm4_res2cit2fai2_30b.sh +++ /dev/null @@ -1,384 +0,0 @@ -#!/bin/bash -set -e -#=============================================================================== -# 配置区域 - 用户只需修改这里 -#=============================================================================== -SUFFIX="cc_rm4_res2cit2fai2_30b" # 实验后缀,影响所有日志和实验名称 -PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 - -ADDR="22.17.31.142" -MCP_PORT="8040" -export CONFIG_FILE_NAME="tutorial/example_finworld/finworld.yaml" -export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" -#=============================================================================== -# 环境配置区域 -#=============================================================================== - -cd ${AJET_ROOT} -source .venv/bin/activate -# API密钥配置 - 从 .env 文件加载 -ENV_FILE="${AJET_ROOT}/.env" -if [ -f "$ENV_FILE" ]; then - set -a - source "$ENV_FILE" - set +a - echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" -else - echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" -fi - - - -#=============================================================================== -# 环境配置区域 -#=============================================================================== - -# MongoDB 缓存配置 -CACHE_TYPE="mongodb" -MONGO_URI="mongodb://${ADDR}:27117/" -MONGO_DB_NAME="finworld_cache" -MONGO_COLLECTION_NAME="tool_cache" - -# FinWorld MCP 配置 -LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" -FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" - -# 动态生成 MCP 配置文件(使用 ADDR 变量) -cat > ${FINWORLD_MCP_CONFIG} << EOF -{ - "mcpServers": { - "flowllm": { - "transport": "sse", - "url": "http://${ADDR}:${MCP_PORT}/sse", - "timeout": 600, - "sse_read_timeout": 1200 - } - } -} -EOF -FINWORLD_TOOL_RESULT_MAX_CHARS=10000 - -# 其他服务配置 -HF_ENDPOINT="https://hf-mirror.com" -ES_HOSTS="http://11.160.132.46:8200" - -#=============================================================================== -# 多机训练参数配置 -#=============================================================================== -if [ -z "${WORLD_SIZE}" ]; then - echo "ERROR: WORLD_SIZE environment variable is not set!" - echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" - exit 1 -fi - -NNODES=${WORLD_SIZE} -GPUS_PER_NODE=8 -EXPECTED_WORKERS=$WORLD_SIZE - -#=============================================================================== -# NCCL 配置 -#=============================================================================== -export NCCL_TIMEOUT=1800 -export NCCL_DEBUG=WARN -export NCCL_IB_TIMEOUT=23 -export NCCL_ASYNC_ERROR_HANDLING=1 -# RAY_DEBUG_POST_MORTEM="1" -# DEBUG_TAGS="TAG_A" -#=============================================================================== -# 自动生成的变量(不需要修改) -#=============================================================================== -CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") -CONFIG_FILE="${AJET_ROOT}/${CONFIG_FILE_NAME}" - -MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" -ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" -TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - -#=============================================================================== -# 工具函数 -#=============================================================================== -print_green() { - echo -e "\033[32m$1\033[0m" -} - -print_red() { - echo -e "\033[31m$1\033[0m" -} - -log() { - echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" -} - -# 检查所有节点数量(包括head节点) -check_workers() { - local status_output=$(ray status 2>/dev/null) - if [ -z "$status_output" ]; then - echo 0 - return - fi - # 统计 "1 node_" 这种格式的行数 - local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) - if [ "$node_count" -gt 0 ]; then - echo $node_count - return - fi - # 如果方法1失败,尝试统计包含node_的唯一ID - node_count=$(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) - echo $node_count -} - -# 检查GPU资源是否完全就绪 -check_gpu_resources() { - gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) - if [ -z "$gpu_count" ]; then - echo 0 - else - printf "%.0f" "$gpu_count" - fi -} - -#=============================================================================== -# 导出环境变量 -# API密钥相关变量已通过 .env 文件加载并自动导出 (set -a) -#=============================================================================== -export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME -export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS -export HF_ENDPOINT ES_HOSTS -export PYTHONPATH="${AJET_ROOT}:${BEYONDAGENT_ROOT}:${PYTHONPATH}" -export RAY_CLUSTER_MODE="multi_node" - - - -# 配置 finworld 环境服务(供 launcher.py --with-finworld 使用) -# 注意:这里可以自定义 env_service 的启动参数 -export FINWORLD_PATH="${BEYONDAGENT_ROOT}" -# 如果需要传递额外参数,修改下面的命令行参数即可 -# 例如:--env_file_name custom_config --debug true -# FINWORLD_SCRIPT: API密钥会从环境变量继承 -export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${BEYONDAGENT_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} FINWORLD_TASKS_DATA_PATH=${FINWORLD_TASKS_DATA_PATH} FINWORLD_TRAIN_REF_ANS_PATH=${FINWORLD_TRAIN_REF_ANS_PATH} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" - - -#=============================================================================== -# 主流程 -#=============================================================================== -log "开始多机多卡训练: ${SUFFIX}" -log "时间戳: ${CURRENT_TIME}" -log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" -log "配置文件: ${CONFIG_FILE}" - -# 确保日志目录存在 -mkdir -p ${LOG_DIR} - -#=============================================================================== -# Master 节点启动流程 -#=============================================================================== -if [[ $HOSTNAME == *"-master-"* ]]; then - print_green "==> This is MASTER node: $HOSTNAME" - - #--------------------------------------------------------------------------- - # 1. 清理和初始化 - #--------------------------------------------------------------------------- - rm -f "$MASTER_IP_FILE" - print_green "Cleaned old master IP file" - - ray stop --force || true - sleep 3 - print_green "Runtime env configuration created" - - #--------------------------------------------------------------------------- - # 4. 启动 Ray Head 节点(带 runtime_env) - #--------------------------------------------------------------------------- - print_green "Starting Ray head node at $MASTER_ADDR with runtime_env" - ray start --head \ - --node-ip-address $MASTER_ADDR \ - --num-gpus 8 - - print_green "Waiting for Ray head to be fully ready..." - sleep 10 - - if ! ray status > /dev/null 2>&1; then - print_red "ERROR: Ray head failed to start properly" - exit 1 - fi - print_green "Ray head is ready" - - # 写入 Master IP 到共享文件 - echo $MASTER_ADDR > $MASTER_IP_FILE - print_green "Master IP written to $MASTER_IP_FILE: $MASTER_ADDR" - - #--------------------------------------------------------------------------- - # 5. 等待所有 Worker 节点加入 - #--------------------------------------------------------------------------- - print_green "Waiting for all nodes to join the Ray cluster..." - print_green "Expected nodes: $EXPECTED_WORKERS (including head node)" - - TIMEOUT=1000 - INTERVAL=10 - ELAPSED=0 - - while true; do - current_nodes=$(check_workers) - print_green "Current node count: $current_nodes/$EXPECTED_WORKERS" - - if [ "$current_nodes" -ge "$EXPECTED_WORKERS" ]; then - print_green "All nodes have joined the cluster!" - break - fi - - if [ "$ELAPSED" -ge "$TIMEOUT" ]; then - print_red "Timeout waiting for nodes. Only $current_nodes/$EXPECTED_WORKERS nodes joined." - ray status - exit 1 - fi - - sleep $INTERVAL - ELAPSED=$((ELAPSED + INTERVAL)) - done - - #--------------------------------------------------------------------------- - # 6. 等待 GPU 资源就绪 - #--------------------------------------------------------------------------- - print_green "Waiting for GPU resources to be fully available..." - EXPECTED_GPUS=$((WORLD_SIZE * 8)) - GPU_TIMEOUT=300 - GPU_ELAPSED=0 - - while true; do - current_gpus=$(check_gpu_resources) - print_green "Current GPU count: $current_gpus/$EXPECTED_GPUS" - - if [ "$current_gpus" -eq "$EXPECTED_GPUS" ]; then - print_green "All GPUs are available!" - break - fi - - if [ "$GPU_ELAPSED" -ge "$GPU_TIMEOUT" ]; then - print_red "Timeout waiting for GPUs. Only $current_gpus/$EXPECTED_GPUS GPUs available." - ray status - exit 1 - fi - - sleep 5 - GPU_ELAPSED=$((GPU_ELAPSED + 5)) - done - - print_green "Final cluster status before training:" - ray status - - #--------------------------------------------------------------------------- - # 7. 等待 Ray Dashboard 启动 - #--------------------------------------------------------------------------- - print_green "Waiting for Ray dashboard to be ready..." - while ! curl -s http://127.0.0.1:8265 > /dev/null; do - sleep 5 - done - - #--------------------------------------------------------------------------- - # 8. 确认 env_service 启动配置 - #--------------------------------------------------------------------------- - print_green "Environment service will be started by launcher.py --with-finworld" - print_green " FINWORLD_PATH: ${FINWORLD_PATH}" - print_green " FINWORLD_SCRIPT: ${FINWORLD_SCRIPT}" - print_green " Log file: ${ENV_SERVICE_LOG}" - print_green " Note: env_service will load .env internally from its conda environment" - - #--------------------------------------------------------------------------- - # 9. 启动训练任务 - #--------------------------------------------------------------------------- - print_green "Starting training job..." - - - # 激活训练环境 - source .venv/bin/activate - - # 重新导出关键环境变量(conda activate 可能会重置) - # API密钥已通过 .env 加载 - export CACHE_TYPE="${CACHE_TYPE}" - export MONGO_URI="${MONGO_URI}" - export MONGO_DB_NAME="${MONGO_DB_NAME}" - export MONGO_COLLECTION_NAME="${MONGO_COLLECTION_NAME}" - - # 设置训练环境变量 - export RAY_ADDRESS="ray://localhost:10001" - export env_url="http://${MASTER_ADDR}:8080" - export env_type="finworld" - export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" - - # 输出配置信息 - print_green "===================================" - print_green "Training Configuration" - print_green "===================================" - print_green "NNODES: $NNODES" - print_green "GPUS_PER_NODE: $GPUS_PER_NODE" - print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" - print_green "env_url: $env_url" - print_green "RAY_ADDRESS: $RAY_ADDRESS" - print_green "Python: $(which python)" - print_green "训练日志: ${TRAIN_LOG}" - print_green "===================================" - - # 启动训练(多机模式下不需要 --with-ray,因为 Ray 集群已在脚本中手动启动) - # 使用 --with-finworld 让 launcher.py 统一管理 env_service 的启动和生命周期 - python ajet/launcher.py \ - --with-finworld \ - --conf ${CONFIG_FILE} \ - --backbone="verl" \ - 2>&1 | tee ${TRAIN_LOG} - ajet --conf ${CONFIG_FILE} --backbone='verl' - -#=============================================================================== -# Worker 节点启动流程 -#=============================================================================== -else - print_green "==> This is WORKER node: $HOSTNAME" - - #--------------------------------------------------------------------------- - # 1. 等待 Master IP 文件 - #--------------------------------------------------------------------------- - export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" - - while [ ! -f $MASTER_IP_FILE ]; do - print_green "Waiting for master node IP file..." - sleep 5 - done - sleep 2 - - MASTER_ADDR=$(cat $MASTER_IP_FILE) - print_green "Found master node at $MASTER_ADDR" - - #--------------------------------------------------------------------------- - # 2. 连接到 Ray 集群 - #--------------------------------------------------------------------------- - ray stop || true - - MAX_RETRIES=3 - RETRY_COUNT=0 - - while [ $RETRY_COUNT -lt $MAX_RETRIES ]; do - if ray start --address $MASTER_ADDR:6379 --num-gpus 8; then - print_green "Worker node started successfully" - break - fi - - RETRY_COUNT=$((RETRY_COUNT + 1)) - print_red "Failed to start worker node, attempt $RETRY_COUNT of $MAX_RETRIES" - sleep 10 - done - - if [ $RETRY_COUNT -eq $MAX_RETRIES ]; then - print_red "Failed to start worker node after $MAX_RETRIES attempts" - exit 1 - fi - - #--------------------------------------------------------------------------- - # 4. 保持连接状态 - #--------------------------------------------------------------------------- - print_green "Worker node is running, keeping alive..." - while true; do - sleep 60 - if ! ray status > /dev/null 2>&1; then - print_red "Lost connection to Ray cluster, exiting..." - break - fi - done -fi diff --git a/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml b/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml new file mode 100644 index 00000000..b0e017d4 --- /dev/null +++ b/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml @@ -0,0 +1,82 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: ajet_finworld + experiment_name: "ajet_finworld" + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: qwen-flash # OpenJudge 模型 + rm_llm: qwen-max # RM Gallery 模型 + concurrency: 10 # Judge 并发数 + # OpenJudge 权重配置 + report_resolution_weight: 0.2 # 报告质量评估 + trajectory_faithfulness_weight: 0.2 # 事实准确性评估 + citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) + rm_weight: 0.4 # RM Gallery 权重 + task_judge: + # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_finworld.finworld_judge->FinWorldJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507 + trainer_common: + nnodes: 8 + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 8 + save_freq: 10 + test_freq: 2 + total_epochs: 200 + rollout: + # ✨✨✨✨ 编写并选择Agent + use_agentscope_protocol: True + agentscope_learn_protocol: tutorial.example_finworld.finworld->ExampleAgentScopeLearnProtocol + agentscope_disable_toolcalls: True + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: 4 + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_env_len: 10000 + max_response_length_in_one_turn: 8000 + max_model_len: 50000 + agent_madness_reward: 0.0 + multi_turn: + max_steps: 6 + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + debug: + debug_max_parallel: 64 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + train_batch_size: 32 + max_prompt_length: 8000 + max_response_length: 41000 + + task_reader: + type: env_service # `env_service` or `dataset_file` or `huggingface_dat_repo` + env_service: + env_type: "finworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: val +trainer: + default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//open/ajet_finworld" + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_finworld/yaml_template/finworld_template.yaml b/tutorial/example_finworld/yaml_template/finworld_template.yaml index 14fe6194..616be9fe 100644 --- a/tutorial/example_finworld/yaml_template/finworld_template.yaml +++ b/tutorial/example_finworld/yaml_template/finworld_template.yaml @@ -1,9 +1,12 @@ # ------------------ 主要配置 ------------------ -astune: - project_name: astune_finprompt +ajet: + project_name: ajet_finworld experiment_name: "{{SUFFIX}}" - judge_llm: {{JUDGE_LLM}} - judge_concurrency: {{JUDGE_CONCURRENCY}} + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 + rm_llm: {{RM_LLM}} # RM Gallery 模型 + concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 # OpenJudge 权重配置 report_resolution_weight: {{REPORT_RESOLUTION_WEIGHT}} # 报告质量评估 trajectory_faithfulness_weight: {{TRAJECTORY_FAITHFULNESS_WEIGHT}} # 事实准确性评估 @@ -11,7 +14,7 @@ astune: rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) - judge_protocol: tutorial.example_finworld.finworld_judge_by_openjudge->FinWorldJudgeByOpenJudge + judge_protocol: tutorial.example_finworld.finworld_judge->FinWorldJudgeByOpenJudge model: # ✨✨✨✨ 设置待训练的模型 path: {{MODEL_PATH}} @@ -39,6 +42,8 @@ astune: agent_madness_reward: 0.0 multi_turn: max_steps: {{NUM_STEPS}} + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) debug: debug_max_parallel: 64 # 增加并行任务数,充分利用GPU debug_first_n_tasks: 100 # 增加处理的任务数 @@ -56,7 +61,7 @@ astune: training_split: train validation_split: val trainer: - default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/astune/checkpoints/example_finworld//{{PREFIX}}/{{SUFFIX}}" + default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//{{PREFIX}}/{{SUFFIX}}" # resume_mode: disable # 禁用自动恢复,从头开始训练 actor_rollout_ref: rollout: @@ -65,15 +70,13 @@ actor_rollout_ref: # ------------------ 不需要修改 ------------------ hydra: searchpath: - - file://astune/default_config - - file://astune/default_config/verl # verl only - - file://external/verl/verl/trainer/config # verl only - - file://astune/default_config/trinity # trinity only + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only # ------------------ 不需要修改 ------------------ defaults: - - ppo_trainer # verl inherit 1/2 - - verl_default # verl inherit 2/2 + - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 - - astune_default + - ajet_default - _self_ From 4662d631ed180b6f37ce098e10b044c20abd5bc0 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Mon, 19 Jan 2026 14:36:12 +0800 Subject: [PATCH 08/56] feat(task_reader): Support data reading of type jsonl_with_env_service - Added the jsonl_with_env_service type, which allows loading data from jsonl files while calling tools via env_service. - Extended ResourceKeeper to handle the creation and release logic of environment instances for jsonl_with_env_service. - Maintained the env_service type logic, calling create_instance to register instances and initializing them using init_messages from the jsonl file. - Added an example protocol, ExampleDeepResearchProtocol, to implement multi-turn interaction and environment call coordination. - Provided training scripts and YAML configuration templates for finworld, supporting the jsonl_with_env_service mode training environment. - Optimized scripts to support multi-node multi-GPU training, including environment variables and Ray cluster configuration. --- ajet/task_reader/__init__.py | 3 + ajet/task_rollout/resource_keeper.py | 32 ++- tutorial/example_finworld/finworld_reader.py | 233 ++++++++++++++++ .../scripts/ajet_finworld_loadjsonl.sh | 252 ++++++++++++++++++ .../yaml/finworld_ajet_finworld.yaml | 6 +- .../finworld_jsonl_template.yaml | 86 ++++++ .../yaml_template/finworld_template.yaml | 7 +- 7 files changed, 610 insertions(+), 9 deletions(-) create mode 100644 tutorial/example_finworld/finworld_reader.py create mode 100644 tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh create mode 100644 tutorial/example_finworld/yaml_template/finworld_jsonl_template.yaml diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py index 19a1a8e3..83c91c6b 100644 --- a/ajet/task_reader/__init__.py +++ b/ajet/task_reader/__init__.py @@ -61,6 +61,9 @@ def __init__(self, reader_type, reader_config): self.task_reader = DataGeneratorTaskReader(reader_config) elif task_reader_type == "random_dummy": self.task_reader = RandomDummyTaskReader(reader_config) + elif task_reader_type == "jsonl_with_env_service": + # 数据从 jsonl 加载,工具调用走 env_service + self.task_reader = JsonlTaskReader(reader_config) else: raise ValueError(f"Unsupported task reader type: {task_reader_type}") diff --git a/ajet/task_rollout/resource_keeper.py b/ajet/task_rollout/resource_keeper.py index 2498b415..26cd44f7 100644 --- a/ajet/task_rollout/resource_keeper.py +++ b/ajet/task_rollout/resource_keeper.py @@ -25,7 +25,7 @@ def __enter__(self): self.tokenizer = self.workflow_task.tokenizer self.llm_inference_fn = self.workflow_task.llm_inference_fn self.observation_window = self.workflow_task.observation_window - if self.config.ajet.task_reader.type == "env_service": + if self.config.ajet.task_reader.type in ("env_service", "jsonl_with_env_service"): url = self.config.ajet.task_reader.env_service.env_url env_type = self.config.ajet.task_reader.env_service.env_type self.env = EnvClientNg(base_url=url) @@ -74,7 +74,9 @@ def _initialize_environment_and_messages(self) -> List[dict]: Exception: If environment creation fails or required task data is missing """ - if self.config.ajet.task_reader.type == "env_service": + reader_type = self.config.ajet.task_reader.type + + if reader_type == "env_service": if self.env is None: raise ValueError("Environment client is None but env_service type is specified") try: @@ -95,6 +97,32 @@ def _initialize_environment_and_messages(self) -> List[dict]: if self.env is not None: self.env.release_instance(self.workflow_task.episode_uuid) raise e + elif reader_type == "jsonl_with_env_service": + # 新逻辑:调用 create_instance 注册实例,但使用 jsonl 中的 init_messages + if self.env is None: + raise ValueError("Environment client is None but jsonl_with_env_service type is specified") + try: + # 必须调用 create_instance,让服务端创建实例,后续 step() 才能工作 + self.env.create_instance( + env_type=self.env_type, + task_id=self.task_id, + instance_id=self.workflow_task.episode_uuid, + params=self.env_params, + ) + # 不使用返回的 state,直接用 jsonl 中加载的 init_messages + task = self.workflow_task.task + if task.init_messages: + init_messages = task.init_messages + else: + assert task.main_query, "jsonl_with_env_service requires init_messages or main_query in jsonl file." + init_messages = [{"role": "user", "content": task.main_query}] + except Exception as e: + logger.bind(exception=True).exception( + f"encounter exception in env_worker.create_instance~ error={e.args}" + ) + if self.env is not None: + self.env.release_instance(self.workflow_task.episode_uuid) + raise e else: task = self.workflow_task.task if task.init_messages: diff --git a/tutorial/example_finworld/finworld_reader.py b/tutorial/example_finworld/finworld_reader.py new file mode 100644 index 00000000..f742adfc --- /dev/null +++ b/tutorial/example_finworld/finworld_reader.py @@ -0,0 +1,233 @@ +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from agentscope.message import Msg +from pydantic import Field +import logging +import threading +import time +import copy +from loguru import logger + + +# 创建信号量,允许同时12个线程运行 +sem = threading.Semaphore(30) + +class ExampleDeepResearchProtocol(Workflow): + + + async def execute( + self, workflow_task: WorkflowTask, tuner: AjetTuner + ) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + # 1. 初始化消息 + # init_messages 通常是 [System, User] + init_messages = workflow_task.task.init_messages + + # 分离 System Prompt 和 Initial User Input + if len(init_messages) >= 2: + first_msg, user_msgs = init_messages[0], init_messages[1:] + else: + first_msg = {"content": "You're a helpful assistant."} + user_msgs = init_messages + + # conversation_history: 维护最原始、最标准的 OpenAI 格式数据 (含 role: tool) + # 这是"真值",用于评测和训练保存 + conversation_history = [ + {"role": "system", "content": first_msg["content"]}, + ] + conversation_history.extend(user_msgs) + + # 2. 初始化 Agent + agent = ReActAgent( + name="Qwen", + sys_prompt=first_msg["content"], # Agent 内部会自动管理 System Prompt + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + toolkit=None, + print_hint_msg=False, + ) + agent.set_console_output_enabled(False) + env = workflow_task.gym_env + + # 3. 构造初始 Agent 输入 (List[Msg]) + # 注意:这里只包含 User 消息,不含 System,因为 System 已在 agent init 中设置 + # 必须转换为 Msg 对象 + agent_input = [] + for m in user_msgs: + agent_input.append(Msg( + name=m.get("name", "user"), + content=m.get("content", ""), + role=m.get("role", "user") + )) + + # 统计信息缓存 + latest_tool_stats = None + latest_reward_stats = {} + cumulative_tool_call_time = 0.0 # 累计工具调用时间 + cumulative_tool_time = {} # 按工具区分的累计耗时: {tool_name: [time1, time2, ...]} + + logger.info(f"开始执行多轮交互,最大步数: {tuner.config.ajet.rollout.multi_turn.max_steps}") + + step = 0 + for step in range(tuner.config.ajet.rollout.multi_turn.max_steps): + logger.info(f"=== 步骤 {step + 1} ===") + + # === Agent 推理 === + _llm_start = time.time() + # 传入增量消息 (agent_input),Agent 会将其添加到内存并生成回复 + reply_message = await agent(agent_input) + _llm_elapsed = time.time() - _llm_start + # 提取纯文本 content(兼容多模态格式) + if isinstance(reply_message.content, list): + # 多模态格式: [{'type': 'text', 'text': '...'}] + content_text = ''.join(item.get('text', '') for item in reply_message.content if isinstance(item, dict) and item.get('type') == 'text') + else: + content_text = reply_message.content + + content_preview = content_text[:100].replace('\n', ' ') + # logger.info(f"Agent回复 ({_llm_elapsed:.2f}s): {content_preview}...") + + # === 早期终止检查:在调用 env.step() 前检查 context_overflow === + # 修复问题:避免 token_overflow 后还继续调用工具导致阻塞 + if tuner.get_context_tracker().context_overflow: + logger.warning(f"上下文溢出,跳过 env.step(),在第 {step + 1} 步立即结束") + # 构造一个默认的结束响应 + conversation_history.append({ + "role": "assistant", + "content": content_text + }) + break + + # === Env 执行 === + _env_start = time.time() + with sem: + obs, reward, terminate, info = env.step( + action={"content": content_text, "role": "assistant"} + ) + _env_elapsed = time.time() - _env_start + logger.info(f"环境执行 ({_env_elapsed:.2f}s)") + # === 3. 更新 conversation_history (Full History) === + # A. 添加 Assistant 消息 (补全 tool_calls) + current_assistant_msg = { + "role": "assistant", + "content": content_text + } + if info and 'generated_tool_calls' in info and info['generated_tool_calls']: + current_assistant_msg['tool_calls'] = info['generated_tool_calls'] + conversation_history.append(current_assistant_msg) + + # B. 添加 Tool 消息 (直接使用 obs) + # 注意:obs 可能是 [tool_results_msgs] 套了一层,需要解包 + if isinstance(obs, list): + actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs + conversation_history.extend(actual_msgs) + else: + conversation_history.append({"role": "user", "content": obs}) + + # === 4. 更新统计信息 === + if info: + if 'tool_stats' in info: + latest_tool_stats = info['tool_stats'] + logger.info(f"步骤 {step + 1} 工具统计: 调用={latest_tool_stats.get('total_calls', 0)}, " + f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") + if 'reward_stats' in info: + latest_reward_stats = info['reward_stats'] + # 累加工具调用时间 + step_tool_call_time = latest_reward_stats.get('tool_call_time', 0.0) + cumulative_tool_call_time += step_tool_call_time + # 累加按工具区分的耗时 + step_tool_time = latest_reward_stats.get('tool_time', {}) + for tool_name, time_list in step_tool_time.items(): + if tool_name not in cumulative_tool_time: + cumulative_tool_time[tool_name] = [] + if isinstance(time_list, list): + cumulative_tool_time[tool_name].extend(time_list) + + # === 5. 准备下一轮 Agent 输入 (Incremental) === + # 将 Env 返回的 obs 转换为 Msg 对象列表,供下一轮 agent() 调用 + # 关键:这里只放新的 obs,不要放完整的 history + agent_input = [] + + if isinstance(obs, list): + # Standard Mode: obs 是 tool messages 列表 + # 注意:finworld_env.step 返回 {"state": [tool_results_msgs]} 套了一层列表 + # BaseGymEnv.step 直接透传,所以 obs = [tool_results_msgs] + # 需要解包获取实际的消息列表 + actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs + logger.info(f"环境观察 (Standard): 收到 {len(actual_msgs)} 条工具消息") + + # 按照 AgentScope 的 ContentBlock 格式转换消息 + # Agent.memory 会自动保存 assistant 的 tool_call 信息 + # 这里只需要传入 tool_result 消息即可 + for idx, m in enumerate(actual_msgs): + origin_role = m.get('role', 'user') + if origin_role == 'tool': + # 使用 ToolResultBlock 格式,作为 user 消息的 content + tool_result_block = { + "type": "tool_result", + "id": m.get('tool_call_id', ''), + "output": m.get('content', ''), + "name": m.get('name', '') + } + new_msg = Msg( + name="tool", + content=[tool_result_block], + role="user" + ) + agent_input.append(new_msg) + else: + # 其他消息(如 user 提示)直接添加 + content = m.get('content') + if content is None: content = "" + valid_role = origin_role if origin_role in ['user', 'assistant', 'system'] else 'user' + new_msg = Msg( + name=m.get('name', valid_role), + content=content, + role=valid_role + ) + agent_input.append(new_msg) + else: + # Legacy Mode + logger.info(f"环境观察 (Legacy): {str(obs)[:100]}...") + agent_input.append(Msg(name="env", content=obs, role="user")) + + # === 6. 终止检查 === + logger.info(f"终止状态: {terminate}") + if terminate: + logger.info(f"环境返回终止信号,在第 {step + 1} 步结束") + break + + if tuner.get_context_tracker().context_overflow: + logger.warning(f"上下文溢出,在第 {step + 1} 步结束") + break + + # === 结束处理 === + final_tool_stats = latest_tool_stats or { + 'total_calls': 0, 'total_errors': 0, 'success_calls': 0, 'success_rate': 0.0, + 'cache_hits': 0, 'cache_misses': 0 + } + # 将累计的 tool_time 合并到 tool_stats 中 + final_tool_stats['tool_time'] = cumulative_tool_time + final_tool_stats['tool_call_time'] = cumulative_tool_call_time + + logger.info(f"\n{'='*80}") + logger.info(f"任务完成统计 (Task ID: {workflow_task.task.task_id}):") + logger.info(f" 总步骤: {step + 1}") + logger.info(f" 总调用: {final_tool_stats.get('total_calls', 0)}") + logger.info(f" 成功率: {final_tool_stats.get('success_rate', 0):.2f}%") + logger.info(f"{'='*80}\n") + + return WorkflowOutput( + reward=None, + metadata={ + "total_step": step, + "tool_stats": final_tool_stats, + "reward_stats": latest_reward_stats, + "tool_success_rate": round(final_tool_stats.get('success_rate', 0.0), 2), + "conversation_history": conversation_history, + "query": workflow_task.task.main_query, + "task_id": workflow_task.task.task_id, + } + ) \ No newline at end of file diff --git a/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh b/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh new file mode 100644 index 00000000..a5550ba8 --- /dev/null +++ b/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh @@ -0,0 +1,252 @@ +#!/bin/bash +set -e +#=============================================================================== +# 配置区域 - 用户只需修改这里 +#=============================================================================== +SUFFIX="ajet_finworld_loadjsonl" # 实验后缀,影响所有日志和实验名称 +PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 + +# 新增:模型与模板配置 +MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507" +CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_jsonl_template.yaml" + +# 新增:Judge 模型配置 +OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 +RM_LLM='qwen-max' # RM Gallery 评分模型 +JUDGE_CONCURRENCY=10 + +# 新增:奖励权重配置 +RM_WEIGHT=0.4 +CITATION_AUDIT_WEIGHT=0.2 +REPORT_RESOLUTION_WEIGHT=0.2 +TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 + +# API密钥配置(从 .env 文件加载,不要硬编码) +# 配置 +NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 +TRAIN_BATCH_SIZE=32 +NUM_STEPS=6 # 每个样本step轮数 + +ADDR="22.17.31.142" +MCP_PORT="8040" + +# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" +CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" + +#=============================================================================== +# 环境配置区域 +#=============================================================================== + +cd ${AJET_ROOT} +source .venv/bin/activate +# API密钥配置 - 从 .env 文件加载 +ENV_FILE="${AJET_ROOT}/.env" +if [ -f "$ENV_FILE" ]; then + set -a + source "$ENV_FILE" + set +a + echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" +else + echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" +fi + +# MongoDB 缓存配置 +CACHE_TYPE="mongodb" +MONGO_URI="mongodb://${ADDR}:27117/" +MONGO_DB_NAME="finworld_cache" +MONGO_COLLECTION_NAME="tool_cache" + +# FinWorld MCP 配置 +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" + +# 动态生成 MCP 配置文件 +mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) +cat > ${FINWORLD_MCP_CONFIG} << EOF +{ + "mcpServers": { + "flowllm": { + "transport": "sse", + "url": "http://${ADDR}:${MCP_PORT}/sse", + "timeout": 600, + "sse_read_timeout": 1200 + } + } +} +EOF +FINWORLD_TOOL_RESULT_MAX_CHARS=10000 + +# 其他服务配置 +HF_ENDPOINT="https://hf-mirror.com" +ES_HOSTS="http://11.160.132.46:8200" + +#=============================================================================== +# 多机训练参数配置 +#=============================================================================== +if [ -z "${WORLD_SIZE}" ]; then + echo "ERROR: WORLD_SIZE environment variable is not set!" + echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" + exit 1 +fi + +NNODES=${WORLD_SIZE} +GPUS_PER_NODE=8 +EXPECTED_WORKERS=$WORLD_SIZE + +#=============================================================================== +# NCCL 配置 +#=============================================================================== +export NCCL_TIMEOUT=1800 +export NCCL_DEBUG=WARN +export NCCL_IB_TIMEOUT=23 +export NCCL_ASYNC_ERROR_HANDLING=1 + +#=============================================================================== +# 自动生成的变量 +#=============================================================================== +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") + +MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" +ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" + +#=============================================================================== +# 工具函数 +#=============================================================================== +print_green() { + echo -e "\033[32m$1\033[0m" +} + +print_red() { + echo -e "\033[31m$1\033[0m" +} + +log() { + echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" +} + +check_workers() { + local status_output=$(ray status 2>/dev/null) + if [ -z "$status_output" ]; then echo 0; return; fi + local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) + if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi + echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) +} + +check_gpu_resources() { + gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) + if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi +} + +#=============================================================================== +# 导出环境变量 +#=============================================================================== +export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME +export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS +export HF_ENDPOINT ES_HOSTS +export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" +export RAY_CLUSTER_MODE="multi_node" +# Directory paths +export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" + +export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 +export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" + +#=============================================================================== +# 主流程 +#=============================================================================== +log "开始多机多卡训练: ${SUFFIX}" +log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" +mkdir -p ${LOG_DIR} +mkdir -p $(dirname ${CONFIG_FILE}) + +#=============================================================================== +# Master 节点启动流程 +#=============================================================================== +if [[ $HOSTNAME == *"-master-"* ]]; then + print_green "==> This is MASTER node: $HOSTNAME" + + #--------------------------------------------------------------------------- + # 1. 动态生成配置文件 (从模板注入参数) + #--------------------------------------------------------------------------- + log "正在从模板生成配置文件..." + sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ + -e "s|{{PREFIX}}|${PREFIX}|g" \ + -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ + -e "s|{{NNODES}}|${NNODES}|g" \ + -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ + -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ + -e "s|{{RM_LLM}}|${RM_LLM}|g" \ + -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ + -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ + -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ + -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ + -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ + -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ + ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} + + print_green "配置文件已生成: ${CONFIG_FILE}" + print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" + + #--------------------------------------------------------------------------- + # 2. 清理和初始化 Ray + #--------------------------------------------------------------------------- + rm -f "$MASTER_IP_FILE" + ray stop --force || true + sleep 3 + + #--------------------------------------------------------------------------- + # 4. 启动 Ray Head + #--------------------------------------------------------------------------- + print_green "Starting Ray head node at $MASTER_ADDR" + ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 + sleep 10 + echo $MASTER_ADDR > $MASTER_IP_FILE + + #--------------------------------------------------------------------------- + # 5 & 6. 等待节点和 GPU 就绪 (逻辑保持不变) + #--------------------------------------------------------------------------- + # ... (此处省略重复的等待逻辑以保持简洁,实际运行时请保留原脚本中的 while 循环) ... + # [请保留原脚本中 5.等待所有Worker 6.等待GPU 7.等待Dashboard 的完整代码] + + #--------------------------------------------------------------------------- + # 9. 启动训练任务 + #--------------------------------------------------------------------------- + print_green "Starting training job..." + source .venv/bin/activate + + export RAY_ADDRESS="ray://localhost:10001" + export env_url="http://${MASTER_ADDR}:8080" + export env_type="finworld" + + print_green "===================================" + print_green "Training Configuration" + print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" + print_green "Log: ${TRAIN_LOG}" + print_green "===================================" + + # 启动训练任务 + python ajet/launcher.py \ + --with-finworld \ + --conf ${CONFIG_FILE} \ + --backbone="verl" \ + --debug="TAG_A" \ + 2>&1 | tee ${TRAIN_LOG} + + # 保留原脚本末尾的 CLI 调用 + ajet --conf ${CONFIG_FILE} --backbone='verl' + +#=============================================================================== +# Worker 节点启动流程 (逻辑保持不变) +#=============================================================================== +else + print_green "==> This is WORKER node: $HOSTNAME" + # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] + while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done + MASTER_ADDR=$(cat $MASTER_IP_FILE) + ray stop || true + ray start --address $MASTER_ADDR:6379 --num-gpus 8 + while true; do sleep 60; done +fi \ No newline at end of file diff --git a/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml b/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml index b0e017d4..16e5b6eb 100644 --- a/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml +++ b/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml @@ -28,9 +28,8 @@ ajet: total_epochs: 200 rollout: # ✨✨✨✨ 编写并选择Agent - use_agentscope_protocol: True - agentscope_learn_protocol: tutorial.example_finworld.finworld->ExampleAgentScopeLearnProtocol - agentscope_disable_toolcalls: True + user_workflow: tutorial.example_finworld.finworld->ExampleDeepResearchProtocol + force_disable_toolcalls: True enable_oversample: False tensor_model_parallel_size: 8 num_repeat: 4 @@ -40,6 +39,7 @@ ajet: max_response_length_in_one_turn: 8000 max_model_len: 50000 agent_madness_reward: 0.0 + compute_madness_checklist: None multi_turn: max_steps: 6 interchange_server: diff --git a/tutorial/example_finworld/yaml_template/finworld_jsonl_template.yaml b/tutorial/example_finworld/yaml_template/finworld_jsonl_template.yaml new file mode 100644 index 00000000..56a81472 --- /dev/null +++ b/tutorial/example_finworld/yaml_template/finworld_jsonl_template.yaml @@ -0,0 +1,86 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: ajet_finworld + experiment_name: "{{SUFFIX}}" + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 + rm_llm: {{RM_LLM}} # RM Gallery 模型 + concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 + # OpenJudge 权重配置 + report_resolution_weight: {{REPORT_RESOLUTION_WEIGHT}} # 报告质量评估 + trajectory_faithfulness_weight: {{TRAJECTORY_FAITHFULNESS_WEIGHT}} # 事实准确性评估 + citation_audit_weight: {{CITATION_AUDIT_WEIGHT}} # 引用审计评估 (覆盖率 + 真实性) + rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 + task_judge: + # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_finworld.finworld_judge->FinWorldJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: {{MODEL_PATH}} + trainer_common: + nnodes: {{NNODES}} + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 8 + save_freq: 10 + test_freq: 2 + total_epochs: 200 + rollout: + # ✨✨✨✨ 编写并选择Agent + user_workflow: tutorial.example_finworld.finworld->ExampleDeepResearchProtocol + force_disable_toolcalls: True + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: {{NUM_REPEAT}} + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_response_length_in_one_turn: 8000 + max_model_len: 50000 + agent_madness_reward: 0.0 + compute_madness_checklist: None + multi_turn: + max_steps: {{NUM_STEPS}} + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + debug: + debug_max_parallel: 64 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + train_batch_size: {{TRAIN_BATCH_SIZE}} + max_prompt_length: 8000 + max_response_length: 41000 + + task_reader: + type: jsonl_with_env_service # 数据从 jsonl 加载,工具调用走 env_service + env_service: + env_type: "finworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: val + jsonl_dataset_file: + training: + file_path: "tutorial/example_finworld/data/train.jsonl" + validation: + file_path: "tutorial/example_finworld/data/val.jsonl" +trainer: + default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//{{PREFIX}}/{{SUFFIX}}" + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_finworld/yaml_template/finworld_template.yaml b/tutorial/example_finworld/yaml_template/finworld_template.yaml index 616be9fe..9a7078c8 100644 --- a/tutorial/example_finworld/yaml_template/finworld_template.yaml +++ b/tutorial/example_finworld/yaml_template/finworld_template.yaml @@ -28,18 +28,17 @@ ajet: total_epochs: 200 rollout: # ✨✨✨✨ 编写并选择Agent - use_agentscope_protocol: True - agentscope_learn_protocol: tutorial.example_finworld.finworld->ExampleAgentScopeLearnProtocol - agentscope_disable_toolcalls: True + user_workflow: tutorial.example_finworld.finworld->ExampleDeepResearchProtocol + force_disable_toolcalls: True enable_oversample: False tensor_model_parallel_size: 8 num_repeat: {{NUM_REPEAT}} max_env_worker: 64 # 增加环境并行数 max_num_seqs: 64 # 增加VLLM并发序列数 - max_env_len: 10000 max_response_length_in_one_turn: 8000 max_model_len: 50000 agent_madness_reward: 0.0 + compute_madness_checklist: None multi_turn: max_steps: {{NUM_STEPS}} interchange_server: From de81c1d58901df469990275e6c10ad700b16d644 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Mon, 19 Jan 2026 17:08:33 +0800 Subject: [PATCH 09/56] feat(core): add finworld task reader support to framework --- ajet/task_reader/__init__.py | 7 ++++--- ajet/task_rollout/resource_keeper.py | 12 ++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py index 83c91c6b..4d448ac8 100644 --- a/ajet/task_reader/__init__.py +++ b/ajet/task_reader/__init__.py @@ -61,9 +61,10 @@ def __init__(self, reader_type, reader_config): self.task_reader = DataGeneratorTaskReader(reader_config) elif task_reader_type == "random_dummy": self.task_reader = RandomDummyTaskReader(reader_config) - elif task_reader_type == "jsonl_with_env_service": - # 数据从 jsonl 加载,工具调用走 env_service - self.task_reader = JsonlTaskReader(reader_config) + elif task_reader_type == "finworld": + # FinWorld 专用: 数据从 JSON 文件加载并组装 init_messages,工具调用走 env_service + from tutorial.example_finworld.finworld_reader import FinworldReader + self.task_reader = FinworldReader(reader_config) else: raise ValueError(f"Unsupported task reader type: {task_reader_type}") diff --git a/ajet/task_rollout/resource_keeper.py b/ajet/task_rollout/resource_keeper.py index 26cd44f7..069f715d 100644 --- a/ajet/task_rollout/resource_keeper.py +++ b/ajet/task_rollout/resource_keeper.py @@ -25,7 +25,7 @@ def __enter__(self): self.tokenizer = self.workflow_task.tokenizer self.llm_inference_fn = self.workflow_task.llm_inference_fn self.observation_window = self.workflow_task.observation_window - if self.config.ajet.task_reader.type in ("env_service", "jsonl_with_env_service"): + if self.config.ajet.task_reader.type in ("env_service", "finworld"): url = self.config.ajet.task_reader.env_service.env_url env_type = self.config.ajet.task_reader.env_service.env_type self.env = EnvClientNg(base_url=url) @@ -97,10 +97,10 @@ def _initialize_environment_and_messages(self) -> List[dict]: if self.env is not None: self.env.release_instance(self.workflow_task.episode_uuid) raise e - elif reader_type == "jsonl_with_env_service": - # 新逻辑:调用 create_instance 注册实例,但使用 jsonl 中的 init_messages + elif reader_type == "finworld": + # finworld: 调用 create_instance 注册实例,但使用 reader 组装的 init_messages if self.env is None: - raise ValueError("Environment client is None but jsonl_with_env_service type is specified") + raise ValueError("Environment client is None but finworld type is specified") try: # 必须调用 create_instance,让服务端创建实例,后续 step() 才能工作 self.env.create_instance( @@ -109,12 +109,12 @@ def _initialize_environment_and_messages(self) -> List[dict]: instance_id=self.workflow_task.episode_uuid, params=self.env_params, ) - # 不使用返回的 state,直接用 jsonl 中加载的 init_messages + # 不使用返回的 state,直接用 reader 组装的 init_messages task = self.workflow_task.task if task.init_messages: init_messages = task.init_messages else: - assert task.main_query, "jsonl_with_env_service requires init_messages or main_query in jsonl file." + assert task.main_query, "finworld requires init_messages or main_query." init_messages = [{"role": "user", "content": task.main_query}] except Exception as e: logger.bind(exception=True).exception( From 248acc4884c5f7f7223a95b6e2500b8b301c4303 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Mon, 19 Jan 2026 17:08:49 +0800 Subject: [PATCH 10/56] feat(finworld): implement specialized data reader and openjudge-based grading logic --- tutorial/example_finworld/finworld_judge.py | 14 +- tutorial/example_finworld/finworld_reader.py | 469 ++++++++++--------- 2 files changed, 257 insertions(+), 226 deletions(-) diff --git a/tutorial/example_finworld/finworld_judge.py b/tutorial/example_finworld/finworld_judge.py index 02bb8855..5cdaf3f3 100644 --- a/tutorial/example_finworld/finworld_judge.py +++ b/tutorial/example_finworld/finworld_judge.py @@ -38,9 +38,7 @@ # RewardStats 不再使用,OpenJudge 版本直接使用字典存储 -# 环境变量配置 (RM Gallery) -TRAIN_REF_ANS_PATH = os.environ.get("FINWORLD_TRAIN_REF_ANS_PATH", "") -VAL_REF_ANS_PATH = os.environ.get("FINWORLD_VAL_REF_ANS_PATH", "") +# Reference Answer 路径现在从 config 中读取,见 _init_reference_answers 方法 # OpenJudge imports # ============================================================================= @@ -216,7 +214,11 @@ def _patched_openai_init(self, *args, **kwargs): self.rm_evaluator = None def _init_reference_answers(self): - """初始化参考答案缓存""" + """初始化参考答案缓存,从 config 中读取路径""" + # 从 config 中获取 reference answer 路径 + train_ref_ans_path = getattr(self.config.ajet.judge, "train_ref_ans_path", "") + val_ref_ans_path = getattr(self.config.ajet.judge, "val_ref_ans_path", "") + def _load(path, key): if path and key not in FinWorldJudgeByOpenJudge._ref_answers_cache: try: @@ -224,8 +226,8 @@ def _load(path, key): FinWorldJudgeByOpenJudge._ref_answers_cache[key], FinWorldJudgeByOpenJudge._ref_domains_cache[key] = ans, dom except Exception: FinWorldJudgeByOpenJudge._ref_answers_cache[key], FinWorldJudgeByOpenJudge._ref_domains_cache[key] = {}, {} - _load(TRAIN_REF_ANS_PATH, "train") - _load(VAL_REF_ANS_PATH, "val") + _load(train_ref_ans_path, "train") + _load(val_ref_ans_path, "val") def _get_reference_data(self, task_id: str) -> Tuple[str, str]: """获取任务的参考答案和领域""" diff --git a/tutorial/example_finworld/finworld_reader.py b/tutorial/example_finworld/finworld_reader.py index f742adfc..44d8a330 100644 --- a/tutorial/example_finworld/finworld_reader.py +++ b/tutorial/example_finworld/finworld_reader.py @@ -1,233 +1,262 @@ -from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask -from agentscope.message import Msg -from pydantic import Field -import logging -import threading -import time -import copy -from loguru import logger - +"""FinWorld Reader -# 创建信号量,允许同时12个线程运行 -sem = threading.Semaphore(30) +从 JSON 文件加载任务数据,并现场组装 init_messages。 +- 数据来源:训练集/测试集 JSON 文件 +- 消息组装:加载 prompt 模板 + query +- 工具调用:仍走 env_service +""" +import os +import json +import logging +from typing import List, Dict, Any +from datetime import datetime -class ExampleDeepResearchProtocol(Workflow): +from ajet.schema.task import Task +from ajet.task_reader.task_reader_base import BaseTaskReader +# 配置 logger +logger = logging.getLogger(__name__) - async def execute( - self, workflow_task: WorkflowTask, tuner: AjetTuner - ) -> WorkflowOutput: - from agentscope.agent import ReActAgent - from agentscope.formatter import DashScopeChatFormatter - from agentscope.memory import InMemoryMemory - # 1. 初始化消息 - # init_messages 通常是 [System, User] - init_messages = workflow_task.task.init_messages - - # 分离 System Prompt 和 Initial User Input - if len(init_messages) >= 2: - first_msg, user_msgs = init_messages[0], init_messages[1:] - else: - first_msg = {"content": "You're a helpful assistant."} - user_msgs = init_messages +# 控制 debug 输出的开关(可通过环境变量控制) +DEBUG_ENABLED = os.environ.get("FINWORLD_DEBUG", "0") == "1" - # conversation_history: 维护最原始、最标准的 OpenAI 格式数据 (含 role: tool) - # 这是"真值",用于评测和训练保存 - conversation_history = [ - {"role": "system", "content": first_msg["content"]}, - ] - conversation_history.extend(user_msgs) +def _debug_log(msg: str): + """统一的 debug 日志输出""" + if DEBUG_ENABLED: + print(f"[DEBUG][FinworldReader] {msg}") + logger.debug(msg) - # 2. 初始化 Agent - agent = ReActAgent( - name="Qwen", - sys_prompt=first_msg["content"], # Agent 内部会自动管理 System Prompt - model=tuner.as_agentscope_model(), - formatter=DashScopeChatFormatter(), - memory=InMemoryMemory(), - toolkit=None, - print_hint_msg=False, - ) - agent.set_console_output_enabled(False) - env = workflow_task.gym_env - - # 3. 构造初始 Agent 输入 (List[Msg]) - # 注意:这里只包含 User 消息,不含 System,因为 System 已在 agent init 中设置 - # 必须转换为 Msg 对象 - agent_input = [] - for m in user_msgs: - agent_input.append(Msg( - name=m.get("name", "user"), - content=m.get("content", ""), - role=m.get("role", "user") - )) - # 统计信息缓存 - latest_tool_stats = None - latest_reward_stats = {} - cumulative_tool_call_time = 0.0 # 累计工具调用时间 - cumulative_tool_time = {} # 按工具区分的累计耗时: {tool_name: [time1, time2, ...]} +class FinworldReader(BaseTaskReader): + """ + FinWorld 专用的数据加载器 + + 特点: + 1. 从 JSON 文件加载任务数据(支持 list 和 dict 格式) + 2. 现场组装 init_messages(system_prompt + user_query) + 3. env_type 固定为 "finworld",由 env_service 负责工具调用 + """ + + # 类级别缓存 + _prompt_template_cache = None + _tool_prompt_cache = None + + def __init__(self, reader_config): + super().__init__(reader_config) + self.reader_config = reader_config - logger.info(f"开始执行多轮交互,最大步数: {tuner.config.ajet.rollout.multi_turn.max_steps}") + _debug_log(f"Initializing FinworldReader...") + _debug_log(f"reader_config type: {type(reader_config).__name__}") - step = 0 - for step in range(tuner.config.ajet.rollout.multi_turn.max_steps): - logger.info(f"=== 步骤 {step + 1} ===") - - # === Agent 推理 === - _llm_start = time.time() - # 传入增量消息 (agent_input),Agent 会将其添加到内存并生成回复 - reply_message = await agent(agent_input) - _llm_elapsed = time.time() - _llm_start - # 提取纯文本 content(兼容多模态格式) - if isinstance(reply_message.content, list): - # 多模态格式: [{'type': 'text', 'text': '...'}] - content_text = ''.join(item.get('text', '') for item in reply_message.content if isinstance(item, dict) and item.get('type') == 'text') - else: - content_text = reply_message.content - - content_preview = content_text[:100].replace('\n', ' ') - # logger.info(f"Agent回复 ({_llm_elapsed:.2f}s): {content_preview}...") - - # === 早期终止检查:在调用 env.step() 前检查 context_overflow === - # 修复问题:避免 token_overflow 后还继续调用工具导致阻塞 - if tuner.get_context_tracker().context_overflow: - logger.warning(f"上下文溢出,跳过 env.step(),在第 {step + 1} 步立即结束") - # 构造一个默认的结束响应 - conversation_history.append({ - "role": "assistant", - "content": content_text - }) - break - - # === Env 执行 === - _env_start = time.time() - with sem: - obs, reward, terminate, info = env.step( - action={"content": content_text, "role": "assistant"} - ) - _env_elapsed = time.time() - _env_start - logger.info(f"环境执行 ({_env_elapsed:.2f}s)") - # === 3. 更新 conversation_history (Full History) === - # A. 添加 Assistant 消息 (补全 tool_calls) - current_assistant_msg = { - "role": "assistant", - "content": content_text - } - if info and 'generated_tool_calls' in info and info['generated_tool_calls']: - current_assistant_msg['tool_calls'] = info['generated_tool_calls'] - conversation_history.append(current_assistant_msg) - - # B. 添加 Tool 消息 (直接使用 obs) - # 注意:obs 可能是 [tool_results_msgs] 套了一层,需要解包 - if isinstance(obs, list): - actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs - conversation_history.extend(actual_msgs) - else: - conversation_history.append({"role": "user", "content": obs}) - - # === 4. 更新统计信息 === - if info: - if 'tool_stats' in info: - latest_tool_stats = info['tool_stats'] - logger.info(f"步骤 {step + 1} 工具统计: 调用={latest_tool_stats.get('total_calls', 0)}, " - f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") - if 'reward_stats' in info: - latest_reward_stats = info['reward_stats'] - # 累加工具调用时间 - step_tool_call_time = latest_reward_stats.get('tool_call_time', 0.0) - cumulative_tool_call_time += step_tool_call_time - # 累加按工具区分的耗时 - step_tool_time = latest_reward_stats.get('tool_time', {}) - for tool_name, time_list in step_tool_time.items(): - if tool_name not in cumulative_tool_time: - cumulative_tool_time[tool_name] = [] - if isinstance(time_list, list): - cumulative_tool_time[tool_name].extend(time_list) + # 获取 prompt 目录路径 + self.local_path = os.path.dirname(os.path.abspath(__file__)) + _debug_log(f"local_path: {self.local_path}") + + # 初始化 prompt 缓存 + self._init_prompt_templates() + _debug_log(f"Initialization complete.") + + def _init_prompt_templates(self): + """初始化 prompt 模板缓存""" + if FinworldReader._prompt_template_cache is None: + prompt_file = os.path.join(self.local_path, 'prompt', 'finance_analyst_prompt.md') + _debug_log(f"Loading prompt template from: {prompt_file}") + with open(prompt_file, 'r', encoding='utf-8') as f: + FinworldReader._prompt_template_cache = f.read() + _debug_log(f"Prompt template loaded, length: {len(FinworldReader._prompt_template_cache)} chars") + else: + _debug_log(f"Using cached prompt template, length: {len(FinworldReader._prompt_template_cache)} chars") + + if FinworldReader._tool_prompt_cache is None: + # 使用 tool_prompt_builder.py 中的静态模板 + _debug_log(f"Loading tool prompt template...") + from tutorial.example_finworld.prompt.tool_prompt_builder import get_tool_prompt_template + FinworldReader._tool_prompt_cache = get_tool_prompt_template() + _debug_log(f"Tool prompt template loaded, length: {len(FinworldReader._tool_prompt_cache)} chars") + else: + _debug_log(f"Using cached tool prompt template, length: {len(FinworldReader._tool_prompt_cache)} chars") + + def _build_system_prompt(self) -> str: + """构建 system prompt""" + current_date = datetime.now().strftime('%Y-%m-%d') + _debug_log(f"Building system prompt with date: {current_date}") + + # 替换日期占位符 + system_prompt = FinworldReader._prompt_template_cache.replace( + '{current_date}', + current_date + ) + # 替换工具列表占位符 + system_prompt = system_prompt.replace( + '{tool_list}', + FinworldReader._tool_prompt_cache + ) + _debug_log(f"System prompt built, final length: {len(system_prompt)} chars") + return system_prompt + + def _build_init_messages(self, query: str) -> List[Dict[str, Any]]: + """ + 构建 init_messages + + Args: + query: 用户问题 - # === 5. 准备下一轮 Agent 输入 (Incremental) === - # 将 Env 返回的 obs 转换为 Msg 对象列表,供下一轮 agent() 调用 - # 关键:这里只放新的 obs,不要放完整的 history - agent_input = [] + Returns: + [{"role": "system", "content": ...}, {"role": "user", "content": ...}] + """ + _debug_log(f"Building init_messages for query (len={len(query)}): {query[:100]}..." if len(query) > 100 else f"Building init_messages for query: {query}") + system_prompt = self._build_system_prompt() + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": query} + ] + _debug_log(f"init_messages built: {len(messages)} messages, system_prompt_len={len(system_prompt)}") + return messages + + def _read_json_file(self, file_path: str, split: str = "train") -> List[Task]: + """ + 从 JSON 文件读取任务列表 + + 支持的数据格式: + 1. List 格式: [{"task": {"task_id": ..., "query": ...}, ...}, ...] + 2. Dict 格式: {"task_id_1": {"task": {...}, ...}, "task_id_2": {...}, ...} + + Args: + file_path: JSON 文件路径 + split: 数据集划分(train/val) - if isinstance(obs, list): - # Standard Mode: obs 是 tool messages 列表 - # 注意:finworld_env.step 返回 {"state": [tool_results_msgs]} 套了一层列表 - # BaseGymEnv.step 直接透传,所以 obs = [tool_results_msgs] - # 需要解包获取实际的消息列表 - actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs - logger.info(f"环境观察 (Standard): 收到 {len(actual_msgs)} 条工具消息") + Returns: + List[Task]: 任务列表 + """ + _debug_log(f"Reading JSON file: {file_path}, split={split}") + + if not os.path.exists(file_path): + _debug_log(f"ERROR: File not found: {file_path}") + raise FileNotFoundError(f"JSON file not found: {file_path}") + + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + _debug_log(f"JSON data loaded, type: {type(data).__name__}, size: {len(data) if isinstance(data, (list, dict)) else 'N/A'}") + + tasks = [] + skipped_count = 0 + split_filtered_count = 0 + + # 解析数据 + if isinstance(data, list): + # List 格式 + _debug_log(f"Parsing List format data, total items: {len(data)}") + for idx, item in enumerate(data): + task_info = item.get('task', {}) + task_id = task_info.get('task_id', '') + query = task_info.get('query', '') - # 按照 AgentScope 的 ContentBlock 格式转换消息 - # Agent.memory 会自动保存 assistant 的 tool_call 信息 - # 这里只需要传入 tool_result 消息即可 - for idx, m in enumerate(actual_msgs): - origin_role = m.get('role', 'user') - if origin_role == 'tool': - # 使用 ToolResultBlock 格式,作为 user 消息的 content - tool_result_block = { - "type": "tool_result", - "id": m.get('tool_call_id', ''), - "output": m.get('content', ''), - "name": m.get('name', '') - } - new_msg = Msg( - name="tool", - content=[tool_result_block], - role="user" - ) - agent_input.append(new_msg) - else: - # 其他消息(如 user 提示)直接添加 - content = m.get('content') - if content is None: content = "" - valid_role = origin_role if origin_role in ['user', 'assistant', 'system'] else 'user' - new_msg = Msg( - name=m.get('name', valid_role), - content=content, - role=valid_role - ) - agent_input.append(new_msg) - else: - # Legacy Mode - logger.info(f"环境观察 (Legacy): {str(obs)[:100]}...") - agent_input.append(Msg(name="env", content=obs, role="user")) - - # === 6. 终止检查 === - logger.info(f"终止状态: {terminate}") - if terminate: - logger.info(f"环境返回终止信号,在第 {step + 1} 步结束") - break + if not task_id or not query: + skipped_count += 1 + _debug_log(f" Item {idx}: SKIPPED (missing task_id or query)") + continue - if tuner.get_context_tracker().context_overflow: - logger.warning(f"上下文溢出,在第 {step + 1} 步结束") - break - - # === 结束处理 === - final_tool_stats = latest_tool_stats or { - 'total_calls': 0, 'total_errors': 0, 'success_calls': 0, 'success_rate': 0.0, - 'cache_hits': 0, 'cache_misses': 0 - } - # 将累计的 tool_time 合并到 tool_stats 中 - final_tool_stats['tool_time'] = cumulative_tool_time - final_tool_stats['tool_call_time'] = cumulative_tool_call_time - - logger.info(f"\n{'='*80}") - logger.info(f"任务完成统计 (Task ID: {workflow_task.task.task_id}):") - logger.info(f" 总步骤: {step + 1}") - logger.info(f" 总调用: {final_tool_stats.get('total_calls', 0)}") - logger.info(f" 成功率: {final_tool_stats.get('success_rate', 0):.2f}%") - logger.info(f"{'='*80}\n") - - return WorkflowOutput( - reward=None, - metadata={ - "total_step": step, - "tool_stats": final_tool_stats, - "reward_stats": latest_reward_stats, - "tool_success_rate": round(final_tool_stats.get('success_rate', 0.0), 2), - "conversation_history": conversation_history, - "query": workflow_task.task.main_query, - "task_id": workflow_task.task.task_id, - } - ) \ No newline at end of file + # 过滤 split + item_split = task_info.get('metadata', {}).get('split', split) + if item_split != split: + split_filtered_count += 1 + _debug_log(f" Item {idx} ({task_id}): FILTERED by split (item_split={item_split}, expected={split})") + continue + + # 构建 Task + _debug_log(f" Item {idx} ({task_id}): Creating task...") + task = self._create_task(task_id, query, item) + tasks.append(task) + + elif isinstance(data, dict): + # Dict 格式 + _debug_log(f"Parsing Dict format data, total keys: {len(data)}") + for idx, (task_id, item) in enumerate(data.items()): + task_info = item.get('task', {}) + query = task_info.get('query', '') + + if not query: + skipped_count += 1 + _debug_log(f" Key {idx} ({task_id}): SKIPPED (missing query)") + continue + + # 过滤 split + item_split = task_info.get('metadata', {}).get('split', split) + if item_split != split: + split_filtered_count += 1 + _debug_log(f" Key {idx} ({task_id}): FILTERED by split (item_split={item_split}, expected={split})") + continue + + # 构建 Task(使用 dict key 作为 task_id) + _debug_log(f" Key {idx} ({task_id}): Creating task...") + task = self._create_task(task_id, query, item) + tasks.append(task) + + _debug_log(f"Summary: loaded={len(tasks)}, skipped={skipped_count}, split_filtered={split_filtered_count}") + print(f"[FinworldReader] Loaded {len(tasks)} tasks from {file_path} (split={split})") + + if len(tasks) == 0: + raise ValueError(f"No tasks found in file: {file_path} for split={split}") + + return tasks + + def _create_task(self, task_id: str, query: str, raw_item: Dict[str, Any]) -> Task: + """ + 创建 Task 对象 + + Args: + task_id: 任务 ID + query: 用户问题 + raw_item: 原始数据项 + + Returns: + Task: 任务对象 + """ + _debug_log(f"Creating Task: task_id={task_id}") + + # 现场组装 init_messages + init_messages = self._build_init_messages(query) + + # 提取 metadata + task_info = raw_item.get('task', {}) + metadata = task_info.get('metadata', {}) + + # 将原始数据存入 metadata,供 env 和 judge 使用 + # 注意:序列化为 JSON 字符串,避免嵌套字典导致 PyArrow 序列化时递归深度超限 + metadata['raw_task_data'] = json.dumps(raw_item, ensure_ascii=False) + metadata['query'] = query + metadata['confidence'] = raw_item.get('confidence', 1.0) + metadata['rubrics'] = raw_item.get('rubrics', None) + metadata['ground_truth'] = task_info.get('ground_truth', '') + + _debug_log(f" Task metadata: confidence={metadata['confidence']}, has_rubrics={metadata['rubrics'] is not None}, has_ground_truth={bool(metadata['ground_truth'])}") + _debug_log(f" Task init_messages: {len(init_messages)} messages") + + task = Task( + main_query=query, + init_messages=init_messages, + task_id=task_id, + env_type="finworld", # 固定为 finworld,由 env_service 处理 + metadata=metadata + ) + _debug_log(f" Task created successfully: {task_id}") + return task + + def get_training_tasks(self) -> List[Task]: + """获取训练任务""" + _debug_log(f"get_training_tasks() called") + file_path = self.reader_config.finworld.training.file_path + _debug_log(f"Training file path: {file_path}") + tasks = self._read_json_file(file_path, split="train") + _debug_log(f"get_training_tasks() returning {len(tasks)} tasks") + return tasks + + def get_validation_tasks(self) -> List[Task]: + """获取验证任务""" + _debug_log(f"get_validation_tasks() called") + file_path = self.reader_config.finworld.validation.file_path + _debug_log(f"Validation file path: {file_path}") + tasks = self._read_json_file(file_path, split="val") + _debug_log(f"get_validation_tasks() returning {len(tasks)} tasks") + return tasks From 9d651fd4eafa199379a505a646a5c0f0cf7f4445 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Mon, 19 Jan 2026 17:09:20 +0800 Subject: [PATCH 11/56] refactor(finworld): optimize configuration templates and prompt engineering --- tutorial/example_finworld/finworld.yaml | 22 +- .../prompt/finance_analyst_prompt.md | 189 ++++++++++++++++++ .../prompt/finworld_prompt.md | 0 .../prompt/tool_prompt_builder.py | 150 ++++++++++++++ .../finworld_ajet_finworld_loadjsonl_8b.yaml} | 47 ++--- .../yaml_template/finworld_template.yaml | 14 +- 6 files changed, 391 insertions(+), 31 deletions(-) create mode 100644 tutorial/example_finworld/prompt/finance_analyst_prompt.md delete mode 100644 tutorial/example_finworld/prompt/finworld_prompt.md create mode 100644 tutorial/example_finworld/prompt/tool_prompt_builder.py rename tutorial/example_finworld/{yaml_template/finworld_jsonl_template.yaml => yaml/finworld_ajet_finworld_loadjsonl_8b.yaml} (58%) diff --git a/tutorial/example_finworld/finworld.yaml b/tutorial/example_finworld/finworld.yaml index 5be76eac..344120a5 100644 --- a/tutorial/example_finworld/finworld.yaml +++ b/tutorial/example_finworld/finworld.yaml @@ -50,13 +50,27 @@ ajet: max_response_length: 41000 task_reader: - type: env_service # `env_service` or `dataset_file` or `huggingface_dat_repo` + # type: env_service # `env_service` or `dataset_file` or `huggingface_dat_repo` or `finworld` + # === 方案 A: 传统 env_service 模式 === + # env_service: + # env_type: "finworld" + # env_url: "http://127.0.0.1:8080" + # env_action_preference: code + # training_split: train + # validation_split: val + + # === 方案 B: FinWorld Reader 模式 (数据从 JSON 加载,工具调用走 env_service) === + type: finworld + finworld: + training: + file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/finworld_tasks_11171143_cc.json + validation: + file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/AgentEvolver_query_val.json + # env_service 仍然需要配置(用于工具调用) env_service: env_type: "finworld" env_url: "http://127.0.0.1:8080" - env_action_preference: code # code, text, box - training_split: train - validation_split: val + env_action_preference: code trainer: default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//localths/cc_rm4_res2cit2fai2_30b" # resume_mode: disable # 禁用自动恢复,从头开始训练 diff --git a/tutorial/example_finworld/prompt/finance_analyst_prompt.md b/tutorial/example_finworld/prompt/finance_analyst_prompt.md new file mode 100644 index 00000000..f3dd2bad --- /dev/null +++ b/tutorial/example_finworld/prompt/finance_analyst_prompt.md @@ -0,0 +1,189 @@ +你是一位专业的金融研究分析师。你的任务是通过工具收集信息,进行深度研究,并最终输出一份结构化的 Markdown 格式研究报告。 + +当前日期: {current_date} + +## 研究流程 + +你必须采用两阶段深度研究方法: + +### 第一阶段:先大纲、后调研(必须执行) + +**你必须先输出研究大纲,再通过工具收集信息;禁止在没有数据支撑的情况下直接生成完整报告。** + +1. **理解需求**:分析用户问题的类型(个股分析/行业研究/事件解读/宏观分析/股票检索等),明确用户关注的核心结论与评估维度。 +2. **先写研究大纲(必须先做,且此时不要调用工具)**: + - 输出一个“报告大纲”,包含:一级/二级标题 + 每节要回答的关键问题(Key Questions)。 + - 大纲应明确:每一部分需要哪些证据类型(财务/估值/新闻/政策/行业对比/同业数据等)、需要哪些关键表格或对比指标。 + - 注意:此步骤只写结构与问题清单,不要在没有数据前给出确定性的数字结论。 +3. **按大纲逐段调研(必须执行)**: + - 以大纲为索引,一节一节地收集证据与数据,逐步补全每个 Key Question。 + - **必须使用工具收集数据** - 不要凭空猜测或使用过时信息 + - **分批调用工具** - 每次最多调用3个相关工具,避免一次性调用过多 + - 每轮调用工具后先做小结:本轮获得了哪些可用证据、还缺什么,再决定下一轮 + - 多维度交叉验证,确保数据全面性和准确性 + +### 第二阶段:深度分析与报告生成 + +**仅当你已经通过工具获取了充分的数据后,才能进入此阶段。** + +- 当收集到充分信息后,进入写作阶段,基于真实数据输出完整的 Markdown 格式研究报告,并在报告末尾添加 `[TASK_COMPLETED]` 标记。 +- 写作过程中如果发现关键结论缺少证据支撑,允许**追加 1-2 轮工具调用**进行补充取证,然后继续完成报告(不要为刷工具而调用)。 + +## 工具调用格式 + +当需要使用工具时,**必须严格按照以下标准JSON格式输出**(注意:使用单花括号`{`,不要使用双花括号`{{`): + +```json +[ + { + "tool_name": "工具名称", + "tool_args": { + "参数名": "参数值" + } + } +] +```` + +**重要限制:每次最多调用3个工具** + +* 先调用核心工具获取关键信息,分析后再调用补充工具 +* 例如:先查股票代码 → 再查财务数据 → 最后查行业新闻 + +**工具调用示例(每次1-3个工具)**: + +```json +[ + { + "tool_name": "dashscope_search", + "tool_args": { + "query": "茅台股票代码" + } + } +] +``` + +**重要提醒**: + +* ✓ 使用标准JSON格式,单花括号 `{` 和 `}` +* ✗ 不要使用双花括号 `{{` 或 `}}` +* ✓ 所有字符串必须用双引号包裹 +* ✓ 确保JSON格式可被 `json.loads()` 正确解析 + +## 可用工具 + +以下是当前环境中可用的工具列表。**带✅标记的工具是推荐优先使用的,它们经过验证更加稳定和可靠。** + +{tool_list} + +## 工具使用准则 + +1. **数量限制(重要)**:每次最多调用3个工具,采用多轮次渐进式调研,避免单次调用过多工具 +2. **优先使用推荐工具**:带✅标记的工具经过验证,更加稳定可靠,应优先考虑使用 +3. **先搜索后查询**:不确定的信息(如股票代码、行业分类)必须先用所提供的搜索工具查询,不要臆测 +4. **逐步推进**:每次调用工具后分析结果,再决定下一步调研方向,不要一次性规划所有调用 +5. **多源验证**:使用多个工具源交叉验证关键数据 +6. **全面覆盖**:根据研究主题,通过多轮调用逐步覆盖相关的多个维度(基本面、财务、估值、行业、新闻等) +7. **仅调用存在的工具**:只能使用上述"可用工具"列表中的工具,不要调用不存在的工具名称 + + +## 引用规范(必须遵守) + +你必须使用学术论文风格的引用标注,保证读者可追溯到工具获取的信息来源。 + +### 1) 何时必须引用(关键事实句) + +“关键事实句”指:包含数字/同比环比/日期/财务指标/估值倍数/明确事实结论/具体事件/具体公司或行业陈述/政策条款的句子。 + +* 所有关键事实句句末必须添加引用编号:**[1]** 或 **[1][2]**。 +* **同一来源务必在全文重复使用同一编号**。 + +### 2) References(必须,包含内容与格式要求) + +* 报告末尾必须包含 `## References` 小节。 +* 每条引用一行,编号从 `[1]` 开始连续;同一来源务必重复使用同一编号。 +* **URL 优先**:若工具返回有可用 `url`,References 中应填写该 URL,以及来源 +* **URL 提取指南**: + - 对于 `dashscope_search`:从 `search_results` 的 `"url"` 字段提取 + - 对于 `crawl_ths_*` 系列工具:从返回内容的 `"以下内容来自:"` 后提取 URL +* 禁止伪造链接/来源;无法证据支撑的只能写“推测/假设”,不要用引用包装成事实。 +* 正文出现的每个 `[n]` 必须在 References 中有对应条目;References 不得包含正文未使用的编号。 + +**行格式模板(URL 可选)**: + +* `[n] 标题或简述,来源 - URL` +* `[n] 标题或简述,工具:,参数:,数据日期/报告期: ,来源 - URL` + +### 3) 输出前自检(必须) + +输出前检查: + +* 所有关键事实句是否都有 `[n]`; +* `## References` 是否覆盖全部编号。 + + +## 最终报告要求 + +当信息收集完成后,必须输出 **Markdown 格式的结构化研究报告**。 + +### 报告结构说明 + +根据用户问题类型,选择合适的报告结构: + +**个股分析**:包含公司概况、财务分析、估值分析、行业地位、最新动态、投资建议等 +**行业研究**:包含行业概况、发展趋势、政策环境、竞争格局、龙头企业、投资机会等 +**事件解读**:包含事件背景、影响分析、相关标的、投资策略等 +**宏观分析**:包含宏观环境、政策分析、市场影响、配置建议等 +**股票检索**:包含筛选标准、候选标的、对比分析、推荐排序等 + +### 报告格式要求 + +1. **使用 Markdown 语法**:标题(#)、列表(-)、表格(|)、加粗(**)等 +2. **结构清晰**:使用多级标题组织内容 +3. **数据可视化**:适当使用表格展示关键数据对比(表格中的关键数据同样需要引用 [n]) +4. **逻辑完整**:包含执行摘要、详细分析、结论建议 +5. **引用与参考文献(必须)**:正文关键事实句使用 [n] 引用;文末提供 `## References` + +### 报告示例框架 + +```markdown +# [研究主题] + +## 摘要 +[核心观点和结论,3-5条,每条如包含关键事实也要加引用 [n]] + +## [主体部分 - 根据主题自适应] +### [二级标题] +[具体分析内容...关键事实句末尾加 [n]] + +## 结论与建议 +[明确的结论和操作建议...关键事实句末尾加 [n]] + +## References +[1] 标题或简要描述 - https://... +[2] 贵州茅台历史股价分析(报告期2025-09-30),工具:history_calculate,参数:code=600519,query=过去一周涨跌情况 - https:// +--- +*本报告基于公开信息整理分析,仅供参考,不构成投资建议。投资有风险,入市需谨慎。* + +[TASK_COMPLETED] +``` + +## 何时停止调用工具并输出报告 + +**必须满足以下所有条件后,才能输出最终报告:** + +1. ✓ **已实际调用工具获取足够证据**(通常至少 2-4 轮;以“信息充分支撑结论”为准,而非强制轮数) +2. ✓ **已获取核心数据**:财务数据、市场数据、新闻动态等关键信息 +3. ✓ **已交叉验证**:从多个数据源验证了关键结论(至少对关键数字/事件做到交叉验证) +4. ✓ **数据完整性**:具备足够信息支撑每一个分析结论和投资建议(无法支撑的必须标注为推测/假设) + +**输出格式要求:** + +* 输出完整的 Markdown 格式研究报告(包含标题、摘要、分析、结论、References) +* 报告必须基于真实调用工具获取的数据,不能是空洞的框架 +* **在报告的最后一行单独输出** `[TASK_COMPLETED]` 标记 + +**警告:禁止在没有调用工具、没有真实数据的情况下直接输出报告框架+`[TASK_COMPLETED]`,这是无效的研究。** + +--- + +现在开始深度研究用户的问题。 \ No newline at end of file diff --git a/tutorial/example_finworld/prompt/finworld_prompt.md b/tutorial/example_finworld/prompt/finworld_prompt.md deleted file mode 100644 index e69de29b..00000000 diff --git a/tutorial/example_finworld/prompt/tool_prompt_builder.py b/tutorial/example_finworld/prompt/tool_prompt_builder.py new file mode 100644 index 00000000..5c940fd7 --- /dev/null +++ b/tutorial/example_finworld/prompt/tool_prompt_builder.py @@ -0,0 +1,150 @@ +""" +工具信息Prompt构建模块 +用于生成清晰、结构化的工具使用说明 +""" + +def get_tool_prompt_template() -> str: + """ + 获取工具prompt模板(静态版本) + 基于实际探测到的19个工具进行配置 + + Returns: + 预定义的工具说明文本 + """ + + return """## 可用工具列表 + +### ⚠️ 重要说明 +**股票代码格式规范**: +- 涉及A股代码时,通常使用 **6位纯数字** 格式(如 `000001`、`600000`)。 +- **注意**: 用户输入股票名称时,必须先使用 `extract_entities_code` 转换为对应的代码。 + +--- + +### 🔍 实体与数据计算工具 + +#### ✅ extract_entities_code +**功能**: 从查询中提取金融实体(股票、债券、基金、加密货币、指数、商品、ETF等),并查找对应的代码。最后返回查询中出现的金融实体及其类型和代码。 +**参数**: + - `query` (必填, string): 关于金融实体的自然语言查询文本 + +#### ✅ history_calculate +**功能**: 获取指定A股股票的历史股价数据,并根据用户问题进行分析。 +**数据结构**: 工具内部包含以下字段的历史数据: + - `ts_code`(代码), `trade_date`(交易日期) + - `open`(开), `high`(高), `low`(低), `close`(收), `pre_close`(昨收) + - `change`(涨跌额), `pct_chg`(涨跌幅) + - `vol`(成交量), `amount`(成交额) +**使用说明**: 你无需编写任何代码——只需直接提问即可,例如:“过去一周涨了多少,有没有出现顶背离?”、“MACD是否形成了金叉?”。 +**参数**: + - `code` (必填, string): A股代码 (如 '600000' 或 '000001') + - `query` (必填, string): 关于股票历史表现的具体问题 + +--- + +### 💻 代码与通用网络工具 + +#### ✅ execute_code +**功能**: 执行 Python 代码,适用于复杂分析或计算场景。最终结果请使用 `print` 函数输出。 +**参数**: + - `code` (必填, string): 需要执行的代码 + +#### ✅ execute_shell +**功能**: 执行 Shell 命令 (如 `ls`, `pwd`, 运行脚本)。 +**注意**: 每次调用起始目录相同。如需多步操作,请在一条命令中使用 `&&` 连接 (例如: `cd aa/bb && bash xxx`)。 +**参数**: + - `command` (必填, string): 需要执行的命令 + +#### ✅ dashscope_search +**功能**: 使用搜索关键词从互联网检索相关信息。如果有多个关键词,请分开多次调用。 +**参数**: + - `query` (必填, string): 搜索关键词 + +#### ✅ crawl_url +**功能**: 网页内容解析工具,获取并格式化指定URL的网页内容。 +**参数**: + - `url` (必填, string): 目标网页URL + +# --- + +### 📈 同花顺专项数据工具 (Crawl THS) +*以下工具用于获取特定维度的深度金融数据,请根据用户意图选择最匹配的工具* + +#### ✅ crawl_ths_company +**功能**: 获取上市公司基本资料。 +**数据范围**: 详细情况、高管介绍、发行相关、参控股公司。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_holder +**功能**: 获取股东研究信息。 +**数据范围**: 股东人数、十大流通股东、十大股东、十大债券持有人、控股层级关系。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_operate +**功能**: 获取经营分析信息。 +**数据范围**: 主营介绍、运营业务数据、主营构成分析、主要客户及供应商、董事会经营评述、产品价格。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_equity +**功能**: 获取股本结构信息。 +**数据范围**: 解禁时间表、总股本构成、A股结构图、历次股本变动。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_capital +**功能**: 获取资本运作信息。 +**数据范围**: 募集资金来源、项目投资、收购兼并、股权投资、参股IPO、股权转让、关联交易、质押解冻。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_finance +**功能**: 获取财务分析信息。 +**数据范围**: 财务诊断、财务指标、指标变动说明、资产负债构成、财务报告、杜邦分析。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_worth +**功能**: 获取盈利预测信息。 +**数据范围**: 业绩预测、业绩预测详表、研报评级。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_news +**功能**: 获取新闻公告信息。 +**数据范围**: 新闻与股价联动、公告列表、热点新闻列表、研报列表。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_concept +**功能**: 获取概念题材信息。 +**数据范围**: 常规概念、其他概念、题材要点、概念对比。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_position +**功能**: 获取主力持仓信息。 +**数据范围**: 机构持股汇总、机构持股明细、被举牌情况、IPO获配机构。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_bonus +**功能**: 获取分红融资信息。 +**数据范围**: 分红诊断、分红情况、增发机构获配明细、增发概况、配股概况。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_event +**功能**: 获取公司大事信息。 +**数据范围**: 高管持股变动、股东持股变动、担保明细、违规处理、机构调研、投资者互动。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_field +**功能**: 获取行业对比信息。 +**数据范围**: 行业地位、行业新闻。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) +""" \ No newline at end of file diff --git a/tutorial/example_finworld/yaml_template/finworld_jsonl_template.yaml b/tutorial/example_finworld/yaml/finworld_ajet_finworld_loadjsonl_8b.yaml similarity index 58% rename from tutorial/example_finworld/yaml_template/finworld_jsonl_template.yaml rename to tutorial/example_finworld/yaml/finworld_ajet_finworld_loadjsonl_8b.yaml index 56a81472..1736d138 100644 --- a/tutorial/example_finworld/yaml_template/finworld_jsonl_template.yaml +++ b/tutorial/example_finworld/yaml/finworld_ajet_finworld_loadjsonl_8b.yaml @@ -1,25 +1,27 @@ # ------------------ 主要配置 ------------------ ajet: project_name: ajet_finworld - experiment_name: "{{SUFFIX}}" + experiment_name: "ajet_finworld_loadjsonl_8b" # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) judge: - openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 - rm_llm: {{RM_LLM}} # RM Gallery 模型 - concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 + openjudge_llm: qwen-flash # OpenJudge 模型 + rm_llm: qwen-max # RM Gallery 模型 + concurrency: 10 # Judge 并发数 + train_ref_ans_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json # 训练集 Reference Answer 路径 + val_ref_ans_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json # 验证集 Reference Answer 路径 # OpenJudge 权重配置 - report_resolution_weight: {{REPORT_RESOLUTION_WEIGHT}} # 报告质量评估 - trajectory_faithfulness_weight: {{TRAJECTORY_FAITHFULNESS_WEIGHT}} # 事实准确性评估 - citation_audit_weight: {{CITATION_AUDIT_WEIGHT}} # 引用审计评估 (覆盖率 + 真实性) - rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 + report_resolution_weight: 0.2 # 报告质量评估 + trajectory_faithfulness_weight: 0.2 # 事实准确性评估 + citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) + rm_weight: 0.4 # RM Gallery 权重 task_judge: # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) judge_protocol: tutorial.example_finworld.finworld_judge->FinWorldJudgeByOpenJudge model: # ✨✨✨✨ 设置待训练的模型 - path: {{MODEL_PATH}} + path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507 trainer_common: - nnodes: {{NNODES}} + nnodes: 2 n_gpus_per_node: 8 val_before_train: True val_pass_n: 8 @@ -32,7 +34,7 @@ ajet: force_disable_toolcalls: True enable_oversample: False tensor_model_parallel_size: 8 - num_repeat: {{NUM_REPEAT}} + num_repeat: 4 max_env_worker: 64 # 增加环境并行数 max_num_seqs: 64 # 增加VLLM并发序列数 max_response_length_in_one_turn: 8000 @@ -40,32 +42,31 @@ ajet: agent_madness_reward: 0.0 compute_madness_checklist: None multi_turn: - max_steps: {{NUM_STEPS}} + max_steps: 6 interchange_server: interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) debug: debug_max_parallel: 64 # 增加并行任务数,充分利用GPU debug_first_n_tasks: 100 # 增加处理的任务数 data: - train_batch_size: {{TRAIN_BATCH_SIZE}} + train_batch_size: 32 max_prompt_length: 8000 max_response_length: 41000 task_reader: - type: jsonl_with_env_service # 数据从 jsonl 加载,工具调用走 env_service + type: finworld # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + finworld: + training: + file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json + validation: + file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json + # env_service 仍需配置(用于工具调用) env_service: env_type: "finworld" env_url: "http://127.0.0.1:8080" - env_action_preference: code # code, text, box - training_split: train - validation_split: val - jsonl_dataset_file: - training: - file_path: "tutorial/example_finworld/data/train.jsonl" - validation: - file_path: "tutorial/example_finworld/data/val.jsonl" + env_action_preference: code trainer: - default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//{{PREFIX}}/{{SUFFIX}}" + default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//open/ajet_finworld_loadjsonl_8b" # resume_mode: disable # 禁用自动恢复,从头开始训练 actor_rollout_ref: rollout: diff --git a/tutorial/example_finworld/yaml_template/finworld_template.yaml b/tutorial/example_finworld/yaml_template/finworld_template.yaml index 9a7078c8..70b379f0 100644 --- a/tutorial/example_finworld/yaml_template/finworld_template.yaml +++ b/tutorial/example_finworld/yaml_template/finworld_template.yaml @@ -7,6 +7,8 @@ ajet: openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 rm_llm: {{RM_LLM}} # RM Gallery 模型 concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 + train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 + val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 # OpenJudge 权重配置 report_resolution_weight: {{REPORT_RESOLUTION_WEIGHT}} # 报告质量评估 trajectory_faithfulness_weight: {{TRAJECTORY_FAITHFULNESS_WEIGHT}} # 事实准确性评估 @@ -52,13 +54,17 @@ ajet: max_response_length: 41000 task_reader: - type: env_service # `env_service` or `dataset_file` or `huggingface_dat_repo` + type: finworld # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + finworld: + training: + file_path: {{TRAIN_DATA_PATH}} + validation: + file_path: {{VAL_DATA_PATH}} + # env_service 仍需配置(用于工具调用) env_service: env_type: "finworld" env_url: "http://127.0.0.1:8080" - env_action_preference: code # code, text, box - training_split: train - validation_split: val + env_action_preference: code trainer: default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//{{PREFIX}}/{{SUFFIX}}" # resume_mode: disable # 禁用自动恢复,从头开始训练 From 7475ecc0c516d92f53596ffb2aa8d450b0fcd15e Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Mon, 19 Jan 2026 17:09:30 +0800 Subject: [PATCH 12/56] chore(finworld): update launch scripts and add variant experiment scripts --- ...ajet_finworld.sh => ajet_finworld_cc1k.sh} | 2 +- .../scripts/ajet_finworld_loadjsonl.sh | 16 +- .../scripts/ajet_finworld_loadjsonl_8b.sh | 264 ++++++++++++++++++ 3 files changed, 279 insertions(+), 3 deletions(-) rename tutorial/example_finworld/scripts/{ajet_finworld.sh => ajet_finworld_cc1k.sh} (99%) create mode 100644 tutorial/example_finworld/scripts/ajet_finworld_loadjsonl_8b.sh diff --git a/tutorial/example_finworld/scripts/ajet_finworld.sh b/tutorial/example_finworld/scripts/ajet_finworld_cc1k.sh similarity index 99% rename from tutorial/example_finworld/scripts/ajet_finworld.sh rename to tutorial/example_finworld/scripts/ajet_finworld_cc1k.sh index d417c7cf..a0c8895f 100644 --- a/tutorial/example_finworld/scripts/ajet_finworld.sh +++ b/tutorial/example_finworld/scripts/ajet_finworld_cc1k.sh @@ -3,7 +3,7 @@ set -e #=============================================================================== # 配置区域 - 用户只需修改这里 #=============================================================================== -SUFFIX="ajet_finworld" # 实验后缀,影响所有日志和实验名称 +SUFFIX="ajet_finworld_cc1k" # 实验后缀,影响所有日志和实验名称 PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 # 新增:模型与模板配置 diff --git a/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh b/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh index a5550ba8..1abde8a0 100644 --- a/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh +++ b/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh @@ -3,12 +3,20 @@ set -e #=============================================================================== # 配置区域 - 用户只需修改这里 #=============================================================================== -SUFFIX="ajet_finworld_loadjsonl" # 实验后缀,影响所有日志和实验名称 +SUFFIX="ajet_finworld_loadjsonl_7b" # 实验后缀,影响所有日志和实验名称 PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 # 新增:模型与模板配置 MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507" -CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_jsonl_template.yaml" +CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" + +# 新增:数据文件路径配置 +TRAIN_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json" +VAL_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json" + +# 新增:Reference Answer 文件路径配置(RM Gallery 需要) +TRAIN_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json" +VAL_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json" # 新增:Judge 模型配置 OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 @@ -185,6 +193,10 @@ if [[ $HOSTNAME == *"-master-"* ]]; then -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ + -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ + -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ + -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ + -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} print_green "配置文件已生成: ${CONFIG_FILE}" diff --git a/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl_8b.sh b/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl_8b.sh new file mode 100644 index 00000000..c7a13048 --- /dev/null +++ b/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl_8b.sh @@ -0,0 +1,264 @@ +#!/bin/bash +set -e +#=============================================================================== +# 配置区域 - 用户只需修改这里 +#=============================================================================== +SUFFIX="ajet_finworld_loadjsonl_8b" # 实验后缀,影响所有日志和实验名称 +PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 + +# 新增:模型与模板配置 +MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-8B" +CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" + +# 新增:数据文件路径配置 +TRAIN_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json" +VAL_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json" + +# 新增:Reference Answer 文件路径配置(RM Gallery 需要) +TRAIN_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json" +VAL_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json" + +# 新增:Judge 模型配置 +OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 +RM_LLM='qwen-max' # RM Gallery 评分模型 +JUDGE_CONCURRENCY=10 + +# 新增:奖励权重配置 +RM_WEIGHT=0.4 +CITATION_AUDIT_WEIGHT=0.2 +REPORT_RESOLUTION_WEIGHT=0.2 +TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 + +# API密钥配置(从 .env 文件加载,不要硬编码) +# 配置 +NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 +TRAIN_BATCH_SIZE=32 +NUM_STEPS=6 # 每个样本step轮数 + +ADDR="22.17.31.142" +MCP_PORT="8040" + +# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" +CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" + +#=============================================================================== +# 环境配置区域 +#=============================================================================== + +cd ${AJET_ROOT} +source .venv/bin/activate +# API密钥配置 - 从 .env 文件加载 +ENV_FILE="${AJET_ROOT}/.env" +if [ -f "$ENV_FILE" ]; then + set -a + source "$ENV_FILE" + set +a + echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" +else + echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" +fi + +# MongoDB 缓存配置 +CACHE_TYPE="mongodb" +MONGO_URI="mongodb://${ADDR}:27117/" +MONGO_DB_NAME="finworld_cache" +MONGO_COLLECTION_NAME="tool_cache" + +# FinWorld MCP 配置 +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" + +# 动态生成 MCP 配置文件 +mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) +cat > ${FINWORLD_MCP_CONFIG} << EOF +{ + "mcpServers": { + "flowllm": { + "transport": "sse", + "url": "http://${ADDR}:${MCP_PORT}/sse", + "timeout": 600, + "sse_read_timeout": 1200 + } + } +} +EOF +FINWORLD_TOOL_RESULT_MAX_CHARS=10000 + +# 其他服务配置 +HF_ENDPOINT="https://hf-mirror.com" +ES_HOSTS="http://11.160.132.46:8200" + +#=============================================================================== +# 多机训练参数配置 +#=============================================================================== +if [ -z "${WORLD_SIZE}" ]; then + echo "ERROR: WORLD_SIZE environment variable is not set!" + echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" + exit 1 +fi + +NNODES=${WORLD_SIZE} +GPUS_PER_NODE=8 +EXPECTED_WORKERS=$WORLD_SIZE + +#=============================================================================== +# NCCL 配置 +#=============================================================================== +export NCCL_TIMEOUT=1800 +export NCCL_DEBUG=WARN +export NCCL_IB_TIMEOUT=23 +export NCCL_ASYNC_ERROR_HANDLING=1 + +#=============================================================================== +# 自动生成的变量 +#=============================================================================== +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") + +MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" +ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" + +#=============================================================================== +# 工具函数 +#=============================================================================== +print_green() { + echo -e "\033[32m$1\033[0m" +} + +print_red() { + echo -e "\033[31m$1\033[0m" +} + +log() { + echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" +} + +check_workers() { + local status_output=$(ray status 2>/dev/null) + if [ -z "$status_output" ]; then echo 0; return; fi + local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) + if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi + echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) +} + +check_gpu_resources() { + gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) + if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi +} + +#=============================================================================== +# 导出环境变量 +#=============================================================================== +export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME +export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS +export HF_ENDPOINT ES_HOSTS +export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" +export RAY_CLUSTER_MODE="multi_node" +# Directory paths +export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" + +export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 +export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" + +#=============================================================================== +# 主流程 +#=============================================================================== +log "开始多机多卡训练: ${SUFFIX}" +log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" +mkdir -p ${LOG_DIR} +mkdir -p $(dirname ${CONFIG_FILE}) + +#=============================================================================== +# Master 节点启动流程 +#=============================================================================== +if [[ $HOSTNAME == *"-master-"* ]]; then + print_green "==> This is MASTER node: $HOSTNAME" + + #--------------------------------------------------------------------------- + # 1. 动态生成配置文件 (从模板注入参数) + #--------------------------------------------------------------------------- + log "正在从模板生成配置文件..." + sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ + -e "s|{{PREFIX}}|${PREFIX}|g" \ + -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ + -e "s|{{NNODES}}|${NNODES}|g" \ + -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ + -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ + -e "s|{{RM_LLM}}|${RM_LLM}|g" \ + -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ + -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ + -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ + -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ + -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ + -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ + -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ + -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ + -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ + -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ + ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} + + print_green "配置文件已生成: ${CONFIG_FILE}" + print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" + + #--------------------------------------------------------------------------- + # 2. 清理和初始化 Ray + #--------------------------------------------------------------------------- + rm -f "$MASTER_IP_FILE" + ray stop --force || true + sleep 3 + + #--------------------------------------------------------------------------- + # 4. 启动 Ray Head + #--------------------------------------------------------------------------- + print_green "Starting Ray head node at $MASTER_ADDR" + ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 + sleep 10 + echo $MASTER_ADDR > $MASTER_IP_FILE + + #--------------------------------------------------------------------------- + # 5 & 6. 等待节点和 GPU 就绪 (逻辑保持不变) + #--------------------------------------------------------------------------- + # ... (此处省略重复的等待逻辑以保持简洁,实际运行时请保留原脚本中的 while 循环) ... + # [请保留原脚本中 5.等待所有Worker 6.等待GPU 7.等待Dashboard 的完整代码] + + #--------------------------------------------------------------------------- + # 9. 启动训练任务 + #--------------------------------------------------------------------------- + print_green "Starting training job..." + source .venv/bin/activate + + export RAY_ADDRESS="ray://localhost:10001" + export env_url="http://${MASTER_ADDR}:8080" + export env_type="finworld" + + print_green "===================================" + print_green "Training Configuration" + print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" + print_green "Log: ${TRAIN_LOG}" + print_green "===================================" + + # 启动训练任务 + python ajet/launcher.py \ + --with-finworld \ + --conf ${CONFIG_FILE} \ + --backbone="verl" \ + --debug="TAG_A" \ + 2>&1 | tee ${TRAIN_LOG} + + # 保留原脚本末尾的 CLI 调用 + ajet --conf ${CONFIG_FILE} --backbone='verl' + +#=============================================================================== +# Worker 节点启动流程 (逻辑保持不变) +#=============================================================================== +else + print_green "==> This is WORKER node: $HOSTNAME" + # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] + while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done + MASTER_ADDR=$(cat $MASTER_IP_FILE) + ray stop || true + ray start --address $MASTER_ADDR:6379 --num-gpus 8 + while true; do sleep 60; done +fi \ No newline at end of file From f20ab91a001c2eb564fff2e155ad91cdf387a0a8 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Mon, 19 Jan 2026 18:11:54 +0800 Subject: [PATCH 13/56] feat(finworld): Added support for multi-machine, multi-GPU training scripts and configuration templates: --- .../example_finworld/scripts/ajet_finworld.sh | 264 ++++++++++++++++++ .../yaml/finworld_ajet_finworld.yaml | 17 +- 2 files changed, 275 insertions(+), 6 deletions(-) create mode 100644 tutorial/example_finworld/scripts/ajet_finworld.sh diff --git a/tutorial/example_finworld/scripts/ajet_finworld.sh b/tutorial/example_finworld/scripts/ajet_finworld.sh new file mode 100644 index 00000000..5a427e52 --- /dev/null +++ b/tutorial/example_finworld/scripts/ajet_finworld.sh @@ -0,0 +1,264 @@ +#!/bin/bash +set -e +#=============================================================================== +# 配置区域 - 用户只需修改这里 +#=============================================================================== +SUFFIX="ajet_finworld" # 实验后缀,影响所有日志和实验名称 +PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 + +# 新增:模型与模板配置 +MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507" +CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" + +# 新增:数据文件路径配置 +TRAIN_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json" +VAL_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json" + +# 新增:Reference Answer 文件路径配置(RM Gallery 需要) +TRAIN_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json" +VAL_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json" + +# 新增:Judge 模型配置 +OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 +RM_LLM='qwen-max' # RM Gallery 评分模型 +JUDGE_CONCURRENCY=10 + +# 新增:奖励权重配置 +RM_WEIGHT=0.4 +CITATION_AUDIT_WEIGHT=0.2 +REPORT_RESOLUTION_WEIGHT=0.2 +TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 + +# API密钥配置(从 .env 文件加载,不要硬编码) +# 配置 +NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 +TRAIN_BATCH_SIZE=32 +NUM_STEPS=6 # 每个样本step轮数 + +ADDR="22.17.31.142" +MCP_PORT="8040" + +# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" +CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" + +#=============================================================================== +# 环境配置区域 +#=============================================================================== + +cd ${AJET_ROOT} +source .venv/bin/activate +# API密钥配置 - 从 .env 文件加载 +ENV_FILE="${AJET_ROOT}/.env" +if [ -f "$ENV_FILE" ]; then + set -a + source "$ENV_FILE" + set +a + echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" +else + echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" +fi + +# MongoDB 缓存配置 +CACHE_TYPE="mongodb" +MONGO_URI="mongodb://${ADDR}:27117/" +MONGO_DB_NAME="finworld_cache" +MONGO_COLLECTION_NAME="tool_cache" + +# FinWorld MCP 配置 +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" + +# 动态生成 MCP 配置文件 +mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) +cat > ${FINWORLD_MCP_CONFIG} << EOF +{ + "mcpServers": { + "flowllm": { + "transport": "sse", + "url": "http://${ADDR}:${MCP_PORT}/sse", + "timeout": 600, + "sse_read_timeout": 1200 + } + } +} +EOF +FINWORLD_TOOL_RESULT_MAX_CHARS=10000 + +# 其他服务配置 +HF_ENDPOINT="https://hf-mirror.com" +ES_HOSTS="http://11.160.132.46:8200" + +#=============================================================================== +# 多机训练参数配置 +#=============================================================================== +if [ -z "${WORLD_SIZE}" ]; then + echo "ERROR: WORLD_SIZE environment variable is not set!" + echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" + exit 1 +fi + +NNODES=${WORLD_SIZE} +GPUS_PER_NODE=8 +EXPECTED_WORKERS=$WORLD_SIZE + +#=============================================================================== +# NCCL 配置 +#=============================================================================== +export NCCL_TIMEOUT=1800 +export NCCL_DEBUG=WARN +export NCCL_IB_TIMEOUT=23 +export NCCL_ASYNC_ERROR_HANDLING=1 + +#=============================================================================== +# 自动生成的变量 +#=============================================================================== +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") + +MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" +ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" + +#=============================================================================== +# 工具函数 +#=============================================================================== +print_green() { + echo -e "\033[32m$1\033[0m" +} + +print_red() { + echo -e "\033[31m$1\033[0m" +} + +log() { + echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" +} + +check_workers() { + local status_output=$(ray status 2>/dev/null) + if [ -z "$status_output" ]; then echo 0; return; fi + local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) + if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi + echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) +} + +check_gpu_resources() { + gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) + if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi +} + +#=============================================================================== +# 导出环境变量 +#=============================================================================== +export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME +export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS +export HF_ENDPOINT ES_HOSTS +export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" +export RAY_CLUSTER_MODE="multi_node" +# Directory paths +export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" + +export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 +export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" + +#=============================================================================== +# 主流程 +#=============================================================================== +log "开始多机多卡训练: ${SUFFIX}" +log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" +mkdir -p ${LOG_DIR} +mkdir -p $(dirname ${CONFIG_FILE}) + +#=============================================================================== +# Master 节点启动流程 +#=============================================================================== +if [[ $HOSTNAME == *"-master-"* ]]; then + print_green "==> This is MASTER node: $HOSTNAME" + + #--------------------------------------------------------------------------- + # 1. 动态生成配置文件 (从模板注入参数) + #--------------------------------------------------------------------------- + log "正在从模板生成配置文件..." + sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ + -e "s|{{PREFIX}}|${PREFIX}|g" \ + -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ + -e "s|{{NNODES}}|${NNODES}|g" \ + -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ + -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ + -e "s|{{RM_LLM}}|${RM_LLM}|g" \ + -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ + -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ + -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ + -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ + -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ + -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ + -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ + -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ + -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ + -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ + ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} + + print_green "配置文件已生成: ${CONFIG_FILE}" + print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" + + #--------------------------------------------------------------------------- + # 2. 清理和初始化 Ray + #--------------------------------------------------------------------------- + rm -f "$MASTER_IP_FILE" + ray stop --force || true + sleep 3 + + #--------------------------------------------------------------------------- + # 4. 启动 Ray Head + #--------------------------------------------------------------------------- + print_green "Starting Ray head node at $MASTER_ADDR" + ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 + sleep 10 + echo $MASTER_ADDR > $MASTER_IP_FILE + + #--------------------------------------------------------------------------- + # 5 & 6. 等待节点和 GPU 就绪 (逻辑保持不变) + #--------------------------------------------------------------------------- + # ... (此处省略重复的等待逻辑以保持简洁,实际运行时请保留原脚本中的 while 循环) ... + # [请保留原脚本中 5.等待所有Worker 6.等待GPU 7.等待Dashboard 的完整代码] + + #--------------------------------------------------------------------------- + # 9. 启动训练任务 + #--------------------------------------------------------------------------- + print_green "Starting training job..." + source .venv/bin/activate + + export RAY_ADDRESS="ray://localhost:10001" + export env_url="http://${MASTER_ADDR}:8080" + export env_type="finworld" + + print_green "===================================" + print_green "Training Configuration" + print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" + print_green "Log: ${TRAIN_LOG}" + print_green "===================================" + + # 启动训练任务 + python ajet/launcher.py \ + --with-finworld \ + --conf ${CONFIG_FILE} \ + --backbone="verl" \ + --debug="TAG_A" \ + 2>&1 | tee ${TRAIN_LOG} + + # 保留原脚本末尾的 CLI 调用 + ajet --conf ${CONFIG_FILE} --backbone='verl' + +#=============================================================================== +# Worker 节点启动流程 (逻辑保持不变) +#=============================================================================== +else + print_green "==> This is WORKER node: $HOSTNAME" + # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] + while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done + MASTER_ADDR=$(cat $MASTER_IP_FILE) + ray stop || true + ray start --address $MASTER_ADDR:6379 --num-gpus 8 + while true; do sleep 60; done +fi \ No newline at end of file diff --git a/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml b/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml index 16e5b6eb..08e17a12 100644 --- a/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml +++ b/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml @@ -7,6 +7,8 @@ ajet: openjudge_llm: qwen-flash # OpenJudge 模型 rm_llm: qwen-max # RM Gallery 模型 concurrency: 10 # Judge 并发数 + train_ref_ans_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json # 训练集 Reference Answer 路径 + val_ref_ans_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json # 验证集 Reference Answer 路径 # OpenJudge 权重配置 report_resolution_weight: 0.2 # 报告质量评估 trajectory_faithfulness_weight: 0.2 # 事实准确性评估 @@ -19,7 +21,7 @@ ajet: # ✨✨✨✨ 设置待训练的模型 path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507 trainer_common: - nnodes: 8 + nnodes: 4 n_gpus_per_node: 8 val_before_train: True val_pass_n: 8 @@ -35,7 +37,6 @@ ajet: num_repeat: 4 max_env_worker: 64 # 增加环境并行数 max_num_seqs: 64 # 增加VLLM并发序列数 - max_env_len: 10000 max_response_length_in_one_turn: 8000 max_model_len: 50000 agent_madness_reward: 0.0 @@ -53,13 +54,17 @@ ajet: max_response_length: 41000 task_reader: - type: env_service # `env_service` or `dataset_file` or `huggingface_dat_repo` + type: finworld # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + finworld: + training: + file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json + validation: + file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json + # env_service 仍需配置(用于工具调用) env_service: env_type: "finworld" env_url: "http://127.0.0.1:8080" - env_action_preference: code # code, text, box - training_split: train - validation_split: val + env_action_preference: code trainer: default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//open/ajet_finworld" # resume_mode: disable # 禁用自动恢复,从头开始训练 From ea87d4b5fdece169d7628b1791da0e3386b14c00 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 10:51:14 +0800 Subject: [PATCH 14/56] chore(git): ignore finworld/yaml/* --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index c16d08c4..c63a9c4d 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,6 @@ datasets tutorial2 site dump.rdb + + +tutorial/example_finworld/yaml/* \ No newline at end of file From 3082bca93a3a977ea177ebd0c7c1b9a49c1f3d6e Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 14:14:48 +0800 Subject: [PATCH 15/56] fix(metrics): Fix and enhance the compatibility and debugging output of the metrics update logic - Modified the `update_metrics` function, adding a `prefix` parameter to distinguish between training and validation metrics. - Adjusted the data source for extracting `reward_stats` and `tool_stats`, migrating from `workflow_metadata` to `log_metrics`. - Added debug printing to output the `log_metrics` content and metric key names at key steps for easier troubleshooting. - Used the appropriate prefix when calling `update_metrics` in `trainer_verl.py`, and added multiple debug prints. - Modified `WorkflowOutput` to place `tool_stats` and `reward_stats` into the `log_metrics` field. - Removed redundant and deprecated code for extracting `reward_stats` and calculation functions. - Added debug information output to the `finworld` and `finworld_judge` modules to track log metrics and scoring data. --- ajet/backbone/trainer_verl.py | 11 +++- ajet/schema/task.py | 2 +- ajet/utils/metric_helper/__init__.py | 17 +++++- .../metric_helper/reward_metric_helper.py | 60 ++----------------- .../utils/metric_helper/tool_metric_helper.py | 14 ++--- tutorial/example_finworld/finworld.py | 11 +++- tutorial/example_finworld/finworld_judge.py | 3 + 7 files changed, 49 insertions(+), 69 deletions(-) diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 5b9d0853..13b7a204 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -603,7 +603,7 @@ def fit(self): # noqa: C901 } ) save_trajectory_as_json_file(context_tracker_arr, self.global_steps, self.config, prefix="train") - update_metrics(context_tracker_arr, metrics) + update_metrics(context_tracker_arr, metrics, prefix="train_") if self.config.ajet.execute_test: # apply a test probe from swanlab.data.run.main import get_run @@ -1047,7 +1047,14 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): "mean_reward": sum(rewards) / len(rewards) if rewards else 0, } save_trajectory_as_json_file(ctx_trackers, self.global_steps, self.config, prefix="eval") - update_metrics(ctx_trackers, val_metrics) + print(f"[DEBUG trainer_verl] Before update_metrics: num_ctx_trackers={len(ctx_trackers)}") + for i, ct in enumerate(ctx_trackers[:3]): + has_lm = hasattr(ct, 'log_metrics') and ct.log_metrics + print(f"[DEBUG trainer_verl] ctx_trackers[{i}].log_metrics exists: {has_lm}") + if has_lm: + print(f"[DEBUG trainer_verl] ctx_trackers[{i}].log_metrics keys: {list(ct.log_metrics.keys())}") + update_metrics(ctx_trackers, val_metrics, prefix="eval_") + print(f"[DEBUG trainer_verl] After update_metrics: val_metrics keys containing 'tool_' or 'reward': {[k for k in val_metrics.keys() if 'tool_' in k or 'reward' in k][:10]}") print_dict( val_metrics, narrow=True, diff --git a/ajet/schema/task.py b/ajet/schema/task.py index 6d94796c..a20a4b59 100644 --- a/ajet/schema/task.py +++ b/ajet/schema/task.py @@ -43,4 +43,4 @@ class WorkflowOutput(BaseModel): reward: Union[float, List[float], None] = Field(default=None) is_success: Union[bool, None] = Field(default=None) metadata: Dict[str, Any] = Field(default_factory=dict) - log_metrics: Dict[str, Union[float, List[float]]] = Field(default_factory=dict) + log_metrics: Dict[str, Union[float, List[float], Dict[str, Any]]] = Field(default_factory=dict) diff --git a/ajet/utils/metric_helper/__init__.py b/ajet/utils/metric_helper/__init__.py index 70ce2818..e3253220 100644 --- a/ajet/utils/metric_helper/__init__.py +++ b/ajet/utils/metric_helper/__init__.py @@ -7,9 +7,20 @@ def save_trajectory_as_json_file(ctx_trackers, global_steps, config, prefix): if config.ajet.trainer_common.save_trajectory_as_json_file: save_trajectory_as_json(ctx_trackers, global_steps, prefix) -def update_metrics(context_tracker_arr, metrics:dict): - tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr) - reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr) +def update_metrics(context_tracker_arr, metrics:dict, prefix): + # Debug: Check log_metrics content + print(f"[update_metrics] called with prefix={prefix}, num_trackers={len(context_tracker_arr)}") + for i, traj in enumerate(context_tracker_arr[:3]): # Check first 3 + has_log_metrics = hasattr(traj, 'log_metrics') and traj.log_metrics + print(f"[update_metrics] traj[{i}] has log_metrics: {has_log_metrics}") + if has_log_metrics: + print(f"[update_metrics] traj[{i}].log_metrics keys: {list(traj.log_metrics.keys())}") + + tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr, prefix) + reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr, prefix) + + print(f"[update_metrics] tool_metrics count: {len(tool_metrics)}, reward_metrics count: {len(reward_metrics)}") + if tool_metrics: metrics.update(tool_metrics) if reward_metrics: diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index 49e069bf..31e1f95a 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -20,45 +20,19 @@ def extract_reward_stats_from_trajectories(trajectories: List[Any]) -> List[Dict Extract reward_stats from trajectories list. Args: - trajectories: List of trajectory objects containing workflow_metadata + trajectories: List of trajectory objects containing log_metrics Returns: List of reward_stats dictionaries """ reward_stats_list = [] for traj in trajectories: - if hasattr(traj, 'workflow_metadata') and traj.workflow_metadata: - if 'reward_stats' in traj.workflow_metadata: - reward_stats_list.append(traj.workflow_metadata['reward_stats']) + if hasattr(traj, 'log_metrics') and traj.log_metrics: + if 'reward_stats' in traj.log_metrics: + reward_stats_list.append(traj.log_metrics['reward_stats']) return reward_stats_list -def extract_reward_stats_from_cmts(cmts: List[Any]) -> tuple[List[Dict[str, Any]], Dict[str, int]]: - """ - Extract reward_stats from cmts list and return debug statistics. - - Args: - cmts: List of cmt objects containing workflow_metadata - - Returns: - Tuple of (reward_stats_list, debug_stats) - """ - reward_stats_list = [] - debug_stats = { - 'total_cmts': len(cmts), - 'has_workflow_metadata': 0, - 'has_reward_stats': 0, - } - - for _cmt in cmts: - if hasattr(_cmt, 'workflow_metadata') and _cmt.workflow_metadata: - debug_stats['has_workflow_metadata'] += 1 - if 'reward_stats' in _cmt.workflow_metadata: - debug_stats['has_reward_stats'] += 1 - reward_stats_list.append(_cmt.workflow_metadata['reward_stats']) - - return reward_stats_list, debug_stats - def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str = "") -> Dict[str, float]: """ @@ -194,7 +168,7 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str return metrics -def compute_reward_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, float]: +def compute_reward_metrics_from_trajectories(trajectories: List[Any], prefix: str = "") -> Dict[str, float]: """ Training phase: Extract reward_stats from trajectories and compute metrics. @@ -205,27 +179,5 @@ def compute_reward_metrics_from_trajectories(trajectories: List[Any]) -> Dict[st Formatted metrics dictionary """ reward_stats_list = extract_reward_stats_from_trajectories(trajectories) - return compute_reward_metrics(reward_stats_list, prefix="train_") - - -def compute_reward_metrics_from_cmts(cmts: List[Any], print_debug: bool = True) -> Dict[str, float]: - """ - Validation phase: Extract reward_stats from cmts and compute metrics. - - Args: - cmts: List of cmt objects - print_debug: Whether to print debug information - - Returns: - Formatted metrics dictionary (with "val_reward/" prefix) - """ - reward_stats_list, debug_stats = extract_reward_stats_from_cmts(cmts) - - if print_debug: - print(f"\n[DEBUG eval_dataset()] reward_stats statistics:") - print(f" - Total cmts count: {debug_stats['total_cmts']}") - print(f" - Has workflow_metadata: {debug_stats['has_workflow_metadata']}") - print(f" - Has reward_stats: {debug_stats['has_reward_stats']}") - print(f" - Extracted samples count: {len(reward_stats_list)}") + return compute_reward_metrics(reward_stats_list, prefix=prefix) - return compute_reward_metrics(reward_stats_list, prefix="val_") diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py index 51a488b8..03b3ed01 100644 --- a/ajet/utils/metric_helper/tool_metric_helper.py +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -2,7 +2,7 @@ FinWorld Tool Metrics Helper Specialized module for extracting tool-related statistics and formatting SwanLab reports. -Extracts data from workflow_metadata['tool_stats']. +Extracts data from log_metrics['tool_stats']. SwanLab metrics directory structure: - tool_stats/ Overall statistics (success rate, cache hit rate, etc.) @@ -20,16 +20,16 @@ def extract_tool_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[s Extract tool_stats from trajectories list. Args: - trajectories: List of trajectory objects containing workflow_metadata + trajectories: List of trajectory objects containing log_metrics Returns: List of tool_stats dictionaries """ tool_stats_list = [] for traj in trajectories: - if hasattr(traj, 'workflow_metadata') and traj.workflow_metadata: - if 'tool_stats' in traj.workflow_metadata: - tool_stats_list.append(traj.workflow_metadata['tool_stats']) + if hasattr(traj, 'log_metrics') and traj.log_metrics: + if 'tool_stats' in traj.log_metrics: + tool_stats_list.append(traj.log_metrics['tool_stats']) return tool_stats_list @@ -134,11 +134,11 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" return metrics -def compute_tool_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, float]: +def compute_tool_metrics_from_trajectories(trajectories: List[Any], prefix: str = "") -> Dict[str, float]: """ Training phase: Extract tool_stats from trajectories and compute metrics. """ tool_stats_list = extract_tool_stats_from_trajectories(trajectories) - return compute_tool_metrics(tool_stats_list, prefix="train_") + return compute_tool_metrics(tool_stats_list, prefix=prefix) diff --git a/tutorial/example_finworld/finworld.py b/tutorial/example_finworld/finworld.py index f742adfc..1d2d1b8a 100644 --- a/tutorial/example_finworld/finworld.py +++ b/tutorial/example_finworld/finworld.py @@ -219,15 +219,22 @@ async def execute( logger.info(f" 成功率: {final_tool_stats.get('success_rate', 0):.2f}%") logger.info(f"{'='*80}\n") + # Debug: print log_metrics before return + print(f"[DEBUG finworld.py] Returning WorkflowOutput with log_metrics keys: {list({'tool_stats': final_tool_stats, 'reward_stats': latest_reward_stats}.keys())}") + print(f"[DEBUG finworld.py] tool_stats keys: {list(final_tool_stats.keys()) if final_tool_stats else 'None'}") + print(f"[DEBUG finworld.py] reward_stats keys: {list(latest_reward_stats.keys()) if latest_reward_stats else 'None'}") + return WorkflowOutput( reward=None, metadata={ "total_step": step, - "tool_stats": final_tool_stats, - "reward_stats": latest_reward_stats, "tool_success_rate": round(final_tool_stats.get('success_rate', 0.0), 2), "conversation_history": conversation_history, "query": workflow_task.task.main_query, "task_id": workflow_task.task.task_id, + }, + log_metrics={ + "tool_stats": final_tool_stats, + "reward_stats": latest_reward_stats, } ) \ No newline at end of file diff --git a/tutorial/example_finworld/finworld_judge.py b/tutorial/example_finworld/finworld_judge.py index 5cdaf3f3..42632e53 100644 --- a/tutorial/example_finworld/finworld_judge.py +++ b/tutorial/example_finworld/finworld_judge.py @@ -387,6 +387,9 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO "grading_time": grading_time, "judge_total_time": judge_total_time, } + print(f"[DEBUG finworld_judge] Before _update_metadata_stats: task_id={task_id}, final_reward={final_reward:.4f}") + print(f"[DEBUG finworld_judge] grader_scores: {grader_scores}") + print(f"[DEBUG finworld_judge] contributions: {contributions}") self._update_metadata_stats( metadata=metadata, final_reward=final_reward, From ef44b63a8f50e87fe11c582d9169848e47a58fa9 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 15:16:26 +0800 Subject: [PATCH 16/56] fix(metrics): Remove debug prints and synchronize reward statistics - Removed debug print statements before and after the `update_metrics` call in `trainer_verl.py` - Removed debug print statements related to the `log_metrics` key in `finworld.py` - Removed debug print statements before updating `metadata_stats` in `finworld_judge.py` - Added logic in `general_runner.py` to synchronize `reward_stats` from `metadata` to `log_metrics` after the judge calculation - Cleaned up debug print statements within `update_metrics` in `metric_helper`, improving code readability. --- ajet/backbone/trainer_verl.py | 7 ------- ajet/task_runner/general_runner.py | 6 ++++++ ajet/utils/metric_helper/__init__.py | 11 ----------- tutorial/example_finworld/finworld.py | 5 ----- tutorial/example_finworld/finworld_judge.py | 3 --- 5 files changed, 6 insertions(+), 26 deletions(-) diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 13b7a204..cb573457 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -1047,14 +1047,7 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): "mean_reward": sum(rewards) / len(rewards) if rewards else 0, } save_trajectory_as_json_file(ctx_trackers, self.global_steps, self.config, prefix="eval") - print(f"[DEBUG trainer_verl] Before update_metrics: num_ctx_trackers={len(ctx_trackers)}") - for i, ct in enumerate(ctx_trackers[:3]): - has_lm = hasattr(ct, 'log_metrics') and ct.log_metrics - print(f"[DEBUG trainer_verl] ctx_trackers[{i}].log_metrics exists: {has_lm}") - if has_lm: - print(f"[DEBUG trainer_verl] ctx_trackers[{i}].log_metrics keys: {list(ct.log_metrics.keys())}") update_metrics(ctx_trackers, val_metrics, prefix="eval_") - print(f"[DEBUG trainer_verl] After update_metrics: val_metrics keys containing 'tool_' or 'reward': {[k for k in val_metrics.keys() if 'tool_' in k or 'reward' in k][:10]}") print_dict( val_metrics, narrow=True, diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index 7ea76710..ef6d9f64 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -54,6 +54,12 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: ) else: raw_reward, is_success = self.get_judge().compute_reward(workflow_task, workflow_output) + # Sync reward_stats from metadata to log_metrics after judge computation + print(f"[DEBUG general_runner] After judge: metadata has 'reward_stats': {'reward_stats' in workflow_output.metadata}") + if "reward_stats" in workflow_output.metadata: + print(f"[DEBUG general_runner] metadata['reward_stats'] keys: {list(workflow_output.metadata['reward_stats'].keys())[:5]}") + workflow_output.log_metrics["reward_stats"] = workflow_output.metadata["reward_stats"] + print(f"[DEBUG general_runner] Synced to log_metrics successfully") workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue diff --git a/ajet/utils/metric_helper/__init__.py b/ajet/utils/metric_helper/__init__.py index e3253220..a0475743 100644 --- a/ajet/utils/metric_helper/__init__.py +++ b/ajet/utils/metric_helper/__init__.py @@ -8,19 +8,8 @@ def save_trajectory_as_json_file(ctx_trackers, global_steps, config, prefix): save_trajectory_as_json(ctx_trackers, global_steps, prefix) def update_metrics(context_tracker_arr, metrics:dict, prefix): - # Debug: Check log_metrics content - print(f"[update_metrics] called with prefix={prefix}, num_trackers={len(context_tracker_arr)}") - for i, traj in enumerate(context_tracker_arr[:3]): # Check first 3 - has_log_metrics = hasattr(traj, 'log_metrics') and traj.log_metrics - print(f"[update_metrics] traj[{i}] has log_metrics: {has_log_metrics}") - if has_log_metrics: - print(f"[update_metrics] traj[{i}].log_metrics keys: {list(traj.log_metrics.keys())}") - tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr, prefix) reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr, prefix) - - print(f"[update_metrics] tool_metrics count: {len(tool_metrics)}, reward_metrics count: {len(reward_metrics)}") - if tool_metrics: metrics.update(tool_metrics) if reward_metrics: diff --git a/tutorial/example_finworld/finworld.py b/tutorial/example_finworld/finworld.py index 1d2d1b8a..a911c5fd 100644 --- a/tutorial/example_finworld/finworld.py +++ b/tutorial/example_finworld/finworld.py @@ -219,11 +219,6 @@ async def execute( logger.info(f" 成功率: {final_tool_stats.get('success_rate', 0):.2f}%") logger.info(f"{'='*80}\n") - # Debug: print log_metrics before return - print(f"[DEBUG finworld.py] Returning WorkflowOutput with log_metrics keys: {list({'tool_stats': final_tool_stats, 'reward_stats': latest_reward_stats}.keys())}") - print(f"[DEBUG finworld.py] tool_stats keys: {list(final_tool_stats.keys()) if final_tool_stats else 'None'}") - print(f"[DEBUG finworld.py] reward_stats keys: {list(latest_reward_stats.keys()) if latest_reward_stats else 'None'}") - return WorkflowOutput( reward=None, metadata={ diff --git a/tutorial/example_finworld/finworld_judge.py b/tutorial/example_finworld/finworld_judge.py index 42632e53..5cdaf3f3 100644 --- a/tutorial/example_finworld/finworld_judge.py +++ b/tutorial/example_finworld/finworld_judge.py @@ -387,9 +387,6 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO "grading_time": grading_time, "judge_total_time": judge_total_time, } - print(f"[DEBUG finworld_judge] Before _update_metadata_stats: task_id={task_id}, final_reward={final_reward:.4f}") - print(f"[DEBUG finworld_judge] grader_scores: {grader_scores}") - print(f"[DEBUG finworld_judge] contributions: {contributions}") self._update_metadata_stats( metadata=metadata, final_reward=final_reward, From 088948320f22fd7015bc154c23de06009b939a92 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 16:58:22 +0800 Subject: [PATCH 17/56] chore: "Stop tracking existing yaml files in tutorial directory" --- .../yaml/finworld_ajet_finworld.yaml | 87 ------------------- .../finworld_ajet_finworld_loadjsonl_8b.yaml | 87 ------------------- 2 files changed, 174 deletions(-) delete mode 100644 tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml delete mode 100644 tutorial/example_finworld/yaml/finworld_ajet_finworld_loadjsonl_8b.yaml diff --git a/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml b/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml deleted file mode 100644 index 08e17a12..00000000 --- a/tutorial/example_finworld/yaml/finworld_ajet_finworld.yaml +++ /dev/null @@ -1,87 +0,0 @@ -# ------------------ 主要配置 ------------------ -ajet: - project_name: ajet_finworld - experiment_name: "ajet_finworld" - # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) - judge: - openjudge_llm: qwen-flash # OpenJudge 模型 - rm_llm: qwen-max # RM Gallery 模型 - concurrency: 10 # Judge 并发数 - train_ref_ans_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json # 训练集 Reference Answer 路径 - val_ref_ans_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json # 验证集 Reference Answer 路径 - # OpenJudge 权重配置 - report_resolution_weight: 0.2 # 报告质量评估 - trajectory_faithfulness_weight: 0.2 # 事实准确性评估 - citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) - rm_weight: 0.4 # RM Gallery 权重 - task_judge: - # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) - judge_protocol: tutorial.example_finworld.finworld_judge->FinWorldJudgeByOpenJudge - model: - # ✨✨✨✨ 设置待训练的模型 - path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507 - trainer_common: - nnodes: 4 - n_gpus_per_node: 8 - val_before_train: True - val_pass_n: 8 - save_freq: 10 - test_freq: 2 - total_epochs: 200 - rollout: - # ✨✨✨✨ 编写并选择Agent - user_workflow: tutorial.example_finworld.finworld->ExampleDeepResearchProtocol - force_disable_toolcalls: True - enable_oversample: False - tensor_model_parallel_size: 8 - num_repeat: 4 - max_env_worker: 64 # 增加环境并行数 - max_num_seqs: 64 # 增加VLLM并发序列数 - max_response_length_in_one_turn: 8000 - max_model_len: 50000 - agent_madness_reward: 0.0 - compute_madness_checklist: None - multi_turn: - max_steps: 6 - interchange_server: - interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) - debug: - debug_max_parallel: 64 # 增加并行任务数,充分利用GPU - debug_first_n_tasks: 100 # 增加处理的任务数 - data: - train_batch_size: 32 - max_prompt_length: 8000 - max_response_length: 41000 - - task_reader: - type: finworld # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service - finworld: - training: - file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json - validation: - file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json - # env_service 仍需配置(用于工具调用) - env_service: - env_type: "finworld" - env_url: "http://127.0.0.1:8080" - env_action_preference: code -trainer: - default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//open/ajet_finworld" - # resume_mode: disable # 禁用自动恢复,从头开始训练 -actor_rollout_ref: - rollout: - tensor_model_parallel_size: 8 - gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ -hydra: - searchpath: - - file://ajet/default_config - - file://ajet/default_config/verl # verl only - - file://ajet/default_config/trinity # trinity only - -# ------------------ 不需要修改 ------------------ -defaults: - - verl_default # verl inherit 1/1 - - trinity_default # trinity inherit 1/1 - - ajet_default - - _self_ diff --git a/tutorial/example_finworld/yaml/finworld_ajet_finworld_loadjsonl_8b.yaml b/tutorial/example_finworld/yaml/finworld_ajet_finworld_loadjsonl_8b.yaml deleted file mode 100644 index 1736d138..00000000 --- a/tutorial/example_finworld/yaml/finworld_ajet_finworld_loadjsonl_8b.yaml +++ /dev/null @@ -1,87 +0,0 @@ -# ------------------ 主要配置 ------------------ -ajet: - project_name: ajet_finworld - experiment_name: "ajet_finworld_loadjsonl_8b" - # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) - judge: - openjudge_llm: qwen-flash # OpenJudge 模型 - rm_llm: qwen-max # RM Gallery 模型 - concurrency: 10 # Judge 并发数 - train_ref_ans_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json # 训练集 Reference Answer 路径 - val_ref_ans_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json # 验证集 Reference Answer 路径 - # OpenJudge 权重配置 - report_resolution_weight: 0.2 # 报告质量评估 - trajectory_faithfulness_weight: 0.2 # 事实准确性评估 - citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) - rm_weight: 0.4 # RM Gallery 权重 - task_judge: - # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) - judge_protocol: tutorial.example_finworld.finworld_judge->FinWorldJudgeByOpenJudge - model: - # ✨✨✨✨ 设置待训练的模型 - path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507 - trainer_common: - nnodes: 2 - n_gpus_per_node: 8 - val_before_train: True - val_pass_n: 8 - save_freq: 10 - test_freq: 2 - total_epochs: 200 - rollout: - # ✨✨✨✨ 编写并选择Agent - user_workflow: tutorial.example_finworld.finworld->ExampleDeepResearchProtocol - force_disable_toolcalls: True - enable_oversample: False - tensor_model_parallel_size: 8 - num_repeat: 4 - max_env_worker: 64 # 增加环境并行数 - max_num_seqs: 64 # 增加VLLM并发序列数 - max_response_length_in_one_turn: 8000 - max_model_len: 50000 - agent_madness_reward: 0.0 - compute_madness_checklist: None - multi_turn: - max_steps: 6 - interchange_server: - interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) - debug: - debug_max_parallel: 64 # 增加并行任务数,充分利用GPU - debug_first_n_tasks: 100 # 增加处理的任务数 - data: - train_batch_size: 32 - max_prompt_length: 8000 - max_response_length: 41000 - - task_reader: - type: finworld # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service - finworld: - training: - file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json - validation: - file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json - # env_service 仍需配置(用于工具调用) - env_service: - env_type: "finworld" - env_url: "http://127.0.0.1:8080" - env_action_preference: code -trainer: - default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//open/ajet_finworld_loadjsonl_8b" - # resume_mode: disable # 禁用自动恢复,从头开始训练 -actor_rollout_ref: - rollout: - tensor_model_parallel_size: 8 - gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ -hydra: - searchpath: - - file://ajet/default_config - - file://ajet/default_config/verl # verl only - - file://ajet/default_config/trinity # trinity only - -# ------------------ 不需要修改 ------------------ -defaults: - - verl_default # verl inherit 1/1 - - trinity_default # trinity inherit 1/1 - - ajet_default - - _self_ From db7114c711123f7ed3036d36bb6e4b454e33471d Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 16:58:46 +0800 Subject: [PATCH 18/56] fix(task_runner): Synchronize reward_stats to log_metrics feat(tutorial): Added FinWorld multi-machine multi-GPU training startup script --- ajet/task_runner/general_runner.py | 7 +- tutorial/example_finworld/finworld.sh | 247 ++++++++++++++++++++++++++ 2 files changed, 250 insertions(+), 4 deletions(-) create mode 100644 tutorial/example_finworld/finworld.sh diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index ef6d9f64..91136b51 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -55,12 +55,11 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: else: raw_reward, is_success = self.get_judge().compute_reward(workflow_task, workflow_output) # Sync reward_stats from metadata to log_metrics after judge computation - print(f"[DEBUG general_runner] After judge: metadata has 'reward_stats': {'reward_stats' in workflow_output.metadata}") + if "reward_stats" in workflow_output.metadata: - print(f"[DEBUG general_runner] metadata['reward_stats'] keys: {list(workflow_output.metadata['reward_stats'].keys())[:5]}") - workflow_output.log_metrics["reward_stats"] = workflow_output.metadata["reward_stats"] - print(f"[DEBUG general_runner] Synced to log_metrics successfully") + workflow_output.log_metrics["reward_stats"] = workflow_output.metadata["reward_stats"] + workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue assert not isinstance( diff --git a/tutorial/example_finworld/finworld.sh b/tutorial/example_finworld/finworld.sh new file mode 100644 index 00000000..5a0d2661 --- /dev/null +++ b/tutorial/example_finworld/finworld.sh @@ -0,0 +1,247 @@ +#!/bin/bash +set -e +#=============================================================================== +# 配置区域 - 用户只需修改这里 +#=============================================================================== +SUFFIX="ajet_finworld" # 实验后缀,影响所有日志和实验名称 +PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 + + +# OpenJudge 模型配置 +OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 +RM_LLM='qwen-max' # RM Gallery 评分模型 +JUDGE_CONCURRENCY=10 + +# 奖励权重配置 +RM_WEIGHT=0.4 +CITATION_AUDIT_WEIGHT=0.2 +REPORT_RESOLUTION_WEIGHT=0.2 +TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 + +# API密钥配置(从 .env 文件加载,不要硬编码) +# 配置 +NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 +TRAIN_BATCH_SIZE=32 +NUM_STEPS=6 # 每个样本step轮数 + +ADDR="22.17.31.142" +MCP_PORT="8040" + +# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" +CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" +CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" +#=============================================================================== +# 环境配置区域 +#=============================================================================== + +cd ${AJET_ROOT} +source .venv/bin/activate +# API密钥配置 - 从 .env 文件加载 +ENV_FILE="${AJET_ROOT}/.env" +if [ -f "$ENV_FILE" ]; then + set -a + source "$ENV_FILE" + set +a + echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" +else + echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" +fi + +# MongoDB 缓存配置 +CACHE_TYPE="mongodb" +MONGO_URI="mongodb://${ADDR}:27117/" +MONGO_DB_NAME="finworld_cache" +MONGO_COLLECTION_NAME="tool_cache" + +# FinWorld MCP 配置 +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" + +# 动态生成 MCP 配置文件 +mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) +cat > ${FINWORLD_MCP_CONFIG} << EOF +{ + "mcpServers": { + "flowllm": { + "transport": "sse", + "url": "http://${ADDR}:${MCP_PORT}/sse", + "timeout": 600, + "sse_read_timeout": 1200 + } + } +} +EOF +FINWORLD_TOOL_RESULT_MAX_CHARS=10000 + +# 其他服务配置 +HF_ENDPOINT="https://hf-mirror.com" +ES_HOSTS="http://11.160.132.46:8200" + +#=============================================================================== +# 多机训练参数配置 +#=============================================================================== +if [ -z "${WORLD_SIZE}" ]; then + echo "ERROR: WORLD_SIZE environment variable is not set!" + echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" + exit 1 +fi + +NNODES=${WORLD_SIZE} +GPUS_PER_NODE=8 +EXPECTED_WORKERS=$WORLD_SIZE + +#=============================================================================== +# NCCL 配置 +#=============================================================================== +export NCCL_TIMEOUT=1800 +export NCCL_DEBUG=WARN +export NCCL_IB_TIMEOUT=23 +export NCCL_ASYNC_ERROR_HANDLING=1 + +#=============================================================================== +# 自动生成的变量 +#=============================================================================== +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") + +MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" +ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" + +#=============================================================================== +# 工具函数 +#=============================================================================== +print_green() { + echo -e "\033[32m$1\033[0m" +} + +print_red() { + echo -e "\033[31m$1\033[0m" +} + +log() { + echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" +} + +check_workers() { + local status_output=$(ray status 2>/dev/null) + if [ -z "$status_output" ]; then echo 0; return; fi + local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) + if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi + echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) +} + +check_gpu_resources() { + gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) + if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi +} + +#=============================================================================== +# 导出环境变量 +#=============================================================================== +export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME +export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS +export HF_ENDPOINT ES_HOSTS +export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" +export RAY_CLUSTER_MODE="multi_node" +# Directory paths +export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" + +export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 +export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" + +#=============================================================================== +# 主流程 +#=============================================================================== +log "开始多机多卡训练: ${SUFFIX}" +log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" +mkdir -p ${LOG_DIR} +mkdir -p $(dirname ${CONFIG_FILE}) + +#=============================================================================== +# Master 节点启动流程 +#=============================================================================== +if [[ $HOSTNAME == *"-master-"* ]]; then + print_green "==> This is MASTER node: $HOSTNAME" + + #--------------------------------------------------------------------------- + # 1. 动态生成配置文件 (从模板注入参数) + #--------------------------------------------------------------------------- + log "正在从模板生成配置文件..." + sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ + -e "s|{{PREFIX}}|${PREFIX}|g" \ + -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ + -e "s|{{NNODES}}|${NNODES}|g" \ + -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ + -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ + -e "s|{{RM_LLM}}|${RM_LLM}|g" \ + -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ + -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ + -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ + -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ + -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ + -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ + -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ + -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ + -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ + -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ + ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} + + print_green "配置文件已生成: ${CONFIG_FILE}" + print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" + + #--------------------------------------------------------------------------- + # 2. 清理和初始化 Ray + #--------------------------------------------------------------------------- + rm -f "$MASTER_IP_FILE" + ray stop --force || true + sleep 3 + + #--------------------------------------------------------------------------- + # 4. 启动 Ray Head + #--------------------------------------------------------------------------- + print_green "Starting Ray head node at $MASTER_ADDR" + ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 + sleep 10 + echo $MASTER_ADDR > $MASTER_IP_FILE + + #--------------------------------------------------------------------------- + # 9. 启动训练任务 + #--------------------------------------------------------------------------- + print_green "Starting training job..." + source .venv/bin/activate + + export RAY_ADDRESS="ray://localhost:10001" + export env_url="http://${MASTER_ADDR}:8080" + export env_type="finworld" + + print_green "===================================" + print_green "Training Configuration" + print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" + print_green "Log: ${TRAIN_LOG}" + print_green "===================================" + + # 启动训练任务 + python ajet/launcher.py \ + --with-finworld \ + --conf ${CONFIG_FILE} \ + --backbone="verl" \ + --debug="TAG_A" \ + 2>&1 | tee ${TRAIN_LOG} + + # 保留原脚本末尾的 CLI 调用 + ajet --conf ${CONFIG_FILE} --backbone='verl' + +#=============================================================================== +# Worker 节点启动流程 (逻辑保持不变) +#=============================================================================== +else + print_green "==> This is WORKER node: $HOSTNAME" + # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] + while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done + MASTER_ADDR=$(cat $MASTER_IP_FILE) + ray stop || true + ray start --address $MASTER_ADDR:6379 --num-gpus 8 + while true; do sleep 60; done +fi \ No newline at end of file From 5a25550047709845ee0f5d0f54386d5bc4ceadb2 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 17:45:04 +0800 Subject: [PATCH 19/56] refactor(script): Refactored the finworld training script, integrating configuration and startup processes. --- tutorial/example_finworld/finworld.sh | 147 +++++----- .../example_finworld/scripts/ajet_finworld.sh | 264 ------------------ .../scripts/ajet_finworld_cc1k.sh | 252 ----------------- .../scripts/ajet_finworld_loadjsonl.sh | 264 ------------------ .../scripts/ajet_finworld_loadjsonl_8b.sh | 264 ------------------ tutorial/example_finworld/scripts/single.sh | 112 -------- .../yaml_template/finworld_template.yaml | 2 +- 7 files changed, 64 insertions(+), 1241 deletions(-) delete mode 100644 tutorial/example_finworld/scripts/ajet_finworld.sh delete mode 100644 tutorial/example_finworld/scripts/ajet_finworld_cc1k.sh delete mode 100644 tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh delete mode 100644 tutorial/example_finworld/scripts/ajet_finworld_loadjsonl_8b.sh delete mode 100644 tutorial/example_finworld/scripts/single.sh diff --git a/tutorial/example_finworld/finworld.sh b/tutorial/example_finworld/finworld.sh index 5a0d2661..904ac4c1 100644 --- a/tutorial/example_finworld/finworld.sh +++ b/tutorial/example_finworld/finworld.sh @@ -1,12 +1,11 @@ #!/bin/bash set -e #=============================================================================== -# 配置区域 - 用户只需修改这里 +# 1. 配置区域 - 用户只需修改这里 #=============================================================================== SUFFIX="ajet_finworld" # 实验后缀,影响所有日志和实验名称 PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 - # OpenJudge 模型配置 OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 RM_LLM='qwen-max' # RM Gallery 评分模型 @@ -18,23 +17,17 @@ CITATION_AUDIT_WEIGHT=0.2 REPORT_RESOLUTION_WEIGHT=0.2 TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 -# API密钥配置(从 .env 文件加载,不要硬编码) -# 配置 +# 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 -TRAIN_BATCH_SIZE=32 +TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 - -ADDR="22.17.31.142" -MCP_PORT="8040" - +FINWORLD_TOOL_RESULT_MAX_CHARS=10000 # 修改:配置文件生成路径,现在动态生成到 yaml 目录下 export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" -#=============================================================================== -# 环境配置区域 -#=============================================================================== +# 涉密的配置(API_KEY以及模型、数据位置)从.env读取 cd ${AJET_ROOT} source .venv/bin/activate # API密钥配置 - 从 .env 文件加载 @@ -48,14 +41,45 @@ else echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" fi +#=============================================================================== +# 2. 动态生成配置文件 (从yaml template生成yaml) +#=============================================================================== + +sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ + -e "s|{{PREFIX}}|${PREFIX}|g" \ + -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ + -e "s|{{NNODES}}|${NNODES}|g" \ + -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ + -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ + -e "s|{{RM_LLM}}|${RM_LLM}|g" \ + -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ + -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ + -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ + -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ + -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ + -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ + -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ + -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ + -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ + -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ + -e "s|{{CKPT_SAVE_PATH}}|${CKPT_SAVE_PATH}|g" \ + ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} + +echo "配置文件已生成: ${CONFIG_FILE}" +echo "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" + +#=============================================================================== +# 3. 环境配置 +#=============================================================================== # MongoDB 缓存配置 CACHE_TYPE="mongodb" MONGO_URI="mongodb://${ADDR}:27117/" MONGO_DB_NAME="finworld_cache" MONGO_COLLECTION_NAME="tool_cache" +export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME # FinWorld MCP 配置 -LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" # 动态生成 MCP 配置文件 @@ -72,53 +96,38 @@ cat > ${FINWORLD_MCP_CONFIG} << EOF } } EOF -FINWORLD_TOOL_RESULT_MAX_CHARS=10000 +export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS # 其他服务配置 HF_ENDPOINT="https://hf-mirror.com" ES_HOSTS="http://11.160.132.46:8200" +export HF_ENDPOINT ES_HOSTS + +# log 文件位置 +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" +ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" -#=============================================================================== # 多机训练参数配置 -#=============================================================================== if [ -z "${WORLD_SIZE}" ]; then echo "ERROR: WORLD_SIZE environment variable is not set!" echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" exit 1 fi - NNODES=${WORLD_SIZE} GPUS_PER_NODE=8 EXPECTED_WORKERS=$WORLD_SIZE -#=============================================================================== -# NCCL 配置 -#=============================================================================== -export NCCL_TIMEOUT=1800 -export NCCL_DEBUG=WARN -export NCCL_IB_TIMEOUT=23 -export NCCL_ASYNC_ERROR_HANDLING=1 - -#=============================================================================== -# 自动生成的变量 -#=============================================================================== -CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") - -MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" -ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" -TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" #=============================================================================== -# 工具函数 +# 4. 工具函数 以及 NCCL 配置(固定) #=============================================================================== print_green() { echo -e "\033[32m$1\033[0m" } -print_red() { - echo -e "\033[31m$1\033[0m" -} - log() { echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" } @@ -136,22 +145,24 @@ check_gpu_resources() { if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi } + +export NCCL_TIMEOUT=1800 +export NCCL_DEBUG=WARN +export NCCL_IB_TIMEOUT=23 +export NCCL_ASYNC_ERROR_HANDLING=1 + #=============================================================================== -# 导出环境变量 +# 5. 工具envservice 环境变量 #=============================================================================== -export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME -export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS -export HF_ENDPOINT ES_HOSTS + export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" export RAY_CLUSTER_MODE="multi_node" -# Directory paths -export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" - export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" + #=============================================================================== -# 主流程 +# 6. 主流程 #=============================================================================== log "开始多机多卡训练: ${SUFFIX}" log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" @@ -159,47 +170,20 @@ mkdir -p ${LOG_DIR} mkdir -p $(dirname ${CONFIG_FILE}) #=============================================================================== -# Master 节点启动流程 +# 6.1 Master 节点启动流程 #=============================================================================== if [[ $HOSTNAME == *"-master-"* ]]; then print_green "==> This is MASTER node: $HOSTNAME" #--------------------------------------------------------------------------- - # 1. 动态生成配置文件 (从模板注入参数) - #--------------------------------------------------------------------------- - log "正在从模板生成配置文件..." - sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ - -e "s|{{PREFIX}}|${PREFIX}|g" \ - -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ - -e "s|{{NNODES}}|${NNODES}|g" \ - -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ - -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ - -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ - -e "s|{{RM_LLM}}|${RM_LLM}|g" \ - -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ - -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ - -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ - -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ - -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ - -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ - -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ - -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ - ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} - - print_green "配置文件已生成: ${CONFIG_FILE}" - print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" - - #--------------------------------------------------------------------------- - # 2. 清理和初始化 Ray + # 6.1.1 清理和初始化 Ray #--------------------------------------------------------------------------- rm -f "$MASTER_IP_FILE" ray stop --force || true sleep 3 #--------------------------------------------------------------------------- - # 4. 启动 Ray Head + # 6.1.2 启动 Ray Head #--------------------------------------------------------------------------- print_green "Starting Ray head node at $MASTER_ADDR" ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 @@ -207,11 +191,10 @@ if [[ $HOSTNAME == *"-master-"* ]]; then echo $MASTER_ADDR > $MASTER_IP_FILE #--------------------------------------------------------------------------- - # 9. 启动训练任务 + # 6.1.3 启动训练任务 #--------------------------------------------------------------------------- print_green "Starting training job..." source .venv/bin/activate - export RAY_ADDRESS="ray://localhost:10001" export env_url="http://${MASTER_ADDR}:8080" export env_type="finworld" @@ -222,23 +205,19 @@ if [[ $HOSTNAME == *"-master-"* ]]; then print_green "Log: ${TRAIN_LOG}" print_green "===================================" - # 启动训练任务 + # 启动训练任务(最核心) python ajet/launcher.py \ --with-finworld \ --conf ${CONFIG_FILE} \ --backbone="verl" \ - --debug="TAG_A" \ 2>&1 | tee ${TRAIN_LOG} - # 保留原脚本末尾的 CLI 调用 - ajet --conf ${CONFIG_FILE} --backbone='verl' #=============================================================================== -# Worker 节点启动流程 (逻辑保持不变) +# 6.2 Worker 节点启动流程 #=============================================================================== else print_green "==> This is WORKER node: $HOSTNAME" - # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done MASTER_ADDR=$(cat $MASTER_IP_FILE) ray stop || true diff --git a/tutorial/example_finworld/scripts/ajet_finworld.sh b/tutorial/example_finworld/scripts/ajet_finworld.sh deleted file mode 100644 index 5a427e52..00000000 --- a/tutorial/example_finworld/scripts/ajet_finworld.sh +++ /dev/null @@ -1,264 +0,0 @@ -#!/bin/bash -set -e -#=============================================================================== -# 配置区域 - 用户只需修改这里 -#=============================================================================== -SUFFIX="ajet_finworld" # 实验后缀,影响所有日志和实验名称 -PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 - -# 新增:模型与模板配置 -MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507" -CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" - -# 新增:数据文件路径配置 -TRAIN_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json" -VAL_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json" - -# 新增:Reference Answer 文件路径配置(RM Gallery 需要) -TRAIN_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json" -VAL_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json" - -# 新增:Judge 模型配置 -OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 -RM_LLM='qwen-max' # RM Gallery 评分模型 -JUDGE_CONCURRENCY=10 - -# 新增:奖励权重配置 -RM_WEIGHT=0.4 -CITATION_AUDIT_WEIGHT=0.2 -REPORT_RESOLUTION_WEIGHT=0.2 -TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 - -# API密钥配置(从 .env 文件加载,不要硬编码) -# 配置 -NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 -TRAIN_BATCH_SIZE=32 -NUM_STEPS=6 # 每个样本step轮数 - -ADDR="22.17.31.142" -MCP_PORT="8040" - -# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 -export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" -CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" - -#=============================================================================== -# 环境配置区域 -#=============================================================================== - -cd ${AJET_ROOT} -source .venv/bin/activate -# API密钥配置 - 从 .env 文件加载 -ENV_FILE="${AJET_ROOT}/.env" -if [ -f "$ENV_FILE" ]; then - set -a - source "$ENV_FILE" - set +a - echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" -else - echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" -fi - -# MongoDB 缓存配置 -CACHE_TYPE="mongodb" -MONGO_URI="mongodb://${ADDR}:27117/" -MONGO_DB_NAME="finworld_cache" -MONGO_COLLECTION_NAME="tool_cache" - -# FinWorld MCP 配置 -LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" -FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" - -# 动态生成 MCP 配置文件 -mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) -cat > ${FINWORLD_MCP_CONFIG} << EOF -{ - "mcpServers": { - "flowllm": { - "transport": "sse", - "url": "http://${ADDR}:${MCP_PORT}/sse", - "timeout": 600, - "sse_read_timeout": 1200 - } - } -} -EOF -FINWORLD_TOOL_RESULT_MAX_CHARS=10000 - -# 其他服务配置 -HF_ENDPOINT="https://hf-mirror.com" -ES_HOSTS="http://11.160.132.46:8200" - -#=============================================================================== -# 多机训练参数配置 -#=============================================================================== -if [ -z "${WORLD_SIZE}" ]; then - echo "ERROR: WORLD_SIZE environment variable is not set!" - echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" - exit 1 -fi - -NNODES=${WORLD_SIZE} -GPUS_PER_NODE=8 -EXPECTED_WORKERS=$WORLD_SIZE - -#=============================================================================== -# NCCL 配置 -#=============================================================================== -export NCCL_TIMEOUT=1800 -export NCCL_DEBUG=WARN -export NCCL_IB_TIMEOUT=23 -export NCCL_ASYNC_ERROR_HANDLING=1 - -#=============================================================================== -# 自动生成的变量 -#=============================================================================== -CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") - -MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" -ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" -TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - -#=============================================================================== -# 工具函数 -#=============================================================================== -print_green() { - echo -e "\033[32m$1\033[0m" -} - -print_red() { - echo -e "\033[31m$1\033[0m" -} - -log() { - echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" -} - -check_workers() { - local status_output=$(ray status 2>/dev/null) - if [ -z "$status_output" ]; then echo 0; return; fi - local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) - if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi - echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) -} - -check_gpu_resources() { - gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) - if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi -} - -#=============================================================================== -# 导出环境变量 -#=============================================================================== -export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME -export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS -export HF_ENDPOINT ES_HOSTS -export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" -export RAY_CLUSTER_MODE="multi_node" -# Directory paths -export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" - -export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 -export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" - -#=============================================================================== -# 主流程 -#=============================================================================== -log "开始多机多卡训练: ${SUFFIX}" -log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" -mkdir -p ${LOG_DIR} -mkdir -p $(dirname ${CONFIG_FILE}) - -#=============================================================================== -# Master 节点启动流程 -#=============================================================================== -if [[ $HOSTNAME == *"-master-"* ]]; then - print_green "==> This is MASTER node: $HOSTNAME" - - #--------------------------------------------------------------------------- - # 1. 动态生成配置文件 (从模板注入参数) - #--------------------------------------------------------------------------- - log "正在从模板生成配置文件..." - sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ - -e "s|{{PREFIX}}|${PREFIX}|g" \ - -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ - -e "s|{{NNODES}}|${NNODES}|g" \ - -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ - -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ - -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ - -e "s|{{RM_LLM}}|${RM_LLM}|g" \ - -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ - -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ - -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ - -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ - -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ - -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ - -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ - -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ - ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} - - print_green "配置文件已生成: ${CONFIG_FILE}" - print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" - - #--------------------------------------------------------------------------- - # 2. 清理和初始化 Ray - #--------------------------------------------------------------------------- - rm -f "$MASTER_IP_FILE" - ray stop --force || true - sleep 3 - - #--------------------------------------------------------------------------- - # 4. 启动 Ray Head - #--------------------------------------------------------------------------- - print_green "Starting Ray head node at $MASTER_ADDR" - ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 - sleep 10 - echo $MASTER_ADDR > $MASTER_IP_FILE - - #--------------------------------------------------------------------------- - # 5 & 6. 等待节点和 GPU 就绪 (逻辑保持不变) - #--------------------------------------------------------------------------- - # ... (此处省略重复的等待逻辑以保持简洁,实际运行时请保留原脚本中的 while 循环) ... - # [请保留原脚本中 5.等待所有Worker 6.等待GPU 7.等待Dashboard 的完整代码] - - #--------------------------------------------------------------------------- - # 9. 启动训练任务 - #--------------------------------------------------------------------------- - print_green "Starting training job..." - source .venv/bin/activate - - export RAY_ADDRESS="ray://localhost:10001" - export env_url="http://${MASTER_ADDR}:8080" - export env_type="finworld" - - print_green "===================================" - print_green "Training Configuration" - print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" - print_green "Log: ${TRAIN_LOG}" - print_green "===================================" - - # 启动训练任务 - python ajet/launcher.py \ - --with-finworld \ - --conf ${CONFIG_FILE} \ - --backbone="verl" \ - --debug="TAG_A" \ - 2>&1 | tee ${TRAIN_LOG} - - # 保留原脚本末尾的 CLI 调用 - ajet --conf ${CONFIG_FILE} --backbone='verl' - -#=============================================================================== -# Worker 节点启动流程 (逻辑保持不变) -#=============================================================================== -else - print_green "==> This is WORKER node: $HOSTNAME" - # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] - while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done - MASTER_ADDR=$(cat $MASTER_IP_FILE) - ray stop || true - ray start --address $MASTER_ADDR:6379 --num-gpus 8 - while true; do sleep 60; done -fi \ No newline at end of file diff --git a/tutorial/example_finworld/scripts/ajet_finworld_cc1k.sh b/tutorial/example_finworld/scripts/ajet_finworld_cc1k.sh deleted file mode 100644 index a0c8895f..00000000 --- a/tutorial/example_finworld/scripts/ajet_finworld_cc1k.sh +++ /dev/null @@ -1,252 +0,0 @@ -#!/bin/bash -set -e -#=============================================================================== -# 配置区域 - 用户只需修改这里 -#=============================================================================== -SUFFIX="ajet_finworld_cc1k" # 实验后缀,影响所有日志和实验名称 -PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 - -# 新增:模型与模板配置 -MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507" -CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" - -# 新增:Judge 模型配置 -OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 -RM_LLM='qwen-max' # RM Gallery 评分模型 -JUDGE_CONCURRENCY=10 - -# 新增:奖励权重配置 -RM_WEIGHT=0.4 -CITATION_AUDIT_WEIGHT=0.2 -REPORT_RESOLUTION_WEIGHT=0.2 -TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 - -# API密钥配置(从 .env 文件加载,不要硬编码) -# 配置 -NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 -TRAIN_BATCH_SIZE=32 -NUM_STEPS=6 # 每个样本step轮数 - -ADDR="22.17.31.142" -MCP_PORT="8040" - -# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 -export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" -CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" - -#=============================================================================== -# 环境配置区域 -#=============================================================================== - -cd ${AJET_ROOT} -source .venv/bin/activate -# API密钥配置 - 从 .env 文件加载 -ENV_FILE="${AJET_ROOT}/.env" -if [ -f "$ENV_FILE" ]; then - set -a - source "$ENV_FILE" - set +a - echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" -else - echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" -fi - -# MongoDB 缓存配置 -CACHE_TYPE="mongodb" -MONGO_URI="mongodb://${ADDR}:27117/" -MONGO_DB_NAME="finworld_cache" -MONGO_COLLECTION_NAME="tool_cache" - -# FinWorld MCP 配置 -LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" -FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" - -# 动态生成 MCP 配置文件 -mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) -cat > ${FINWORLD_MCP_CONFIG} << EOF -{ - "mcpServers": { - "flowllm": { - "transport": "sse", - "url": "http://${ADDR}:${MCP_PORT}/sse", - "timeout": 600, - "sse_read_timeout": 1200 - } - } -} -EOF -FINWORLD_TOOL_RESULT_MAX_CHARS=10000 - -# 其他服务配置 -HF_ENDPOINT="https://hf-mirror.com" -ES_HOSTS="http://11.160.132.46:8200" - -#=============================================================================== -# 多机训练参数配置 -#=============================================================================== -if [ -z "${WORLD_SIZE}" ]; then - echo "ERROR: WORLD_SIZE environment variable is not set!" - echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" - exit 1 -fi - -NNODES=${WORLD_SIZE} -GPUS_PER_NODE=8 -EXPECTED_WORKERS=$WORLD_SIZE - -#=============================================================================== -# NCCL 配置 -#=============================================================================== -export NCCL_TIMEOUT=1800 -export NCCL_DEBUG=WARN -export NCCL_IB_TIMEOUT=23 -export NCCL_ASYNC_ERROR_HANDLING=1 - -#=============================================================================== -# 自动生成的变量 -#=============================================================================== -CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") - -MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" -ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" -TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - -#=============================================================================== -# 工具函数 -#=============================================================================== -print_green() { - echo -e "\033[32m$1\033[0m" -} - -print_red() { - echo -e "\033[31m$1\033[0m" -} - -log() { - echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" -} - -check_workers() { - local status_output=$(ray status 2>/dev/null) - if [ -z "$status_output" ]; then echo 0; return; fi - local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) - if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi - echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) -} - -check_gpu_resources() { - gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) - if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi -} - -#=============================================================================== -# 导出环境变量 -#=============================================================================== -export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME -export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS -export HF_ENDPOINT ES_HOSTS -export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" -export RAY_CLUSTER_MODE="multi_node" -# Directory paths -export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" - -export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 -export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" - -#=============================================================================== -# 主流程 -#=============================================================================== -log "开始多机多卡训练: ${SUFFIX}" -log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" -mkdir -p ${LOG_DIR} -mkdir -p $(dirname ${CONFIG_FILE}) - -#=============================================================================== -# Master 节点启动流程 -#=============================================================================== -if [[ $HOSTNAME == *"-master-"* ]]; then - print_green "==> This is MASTER node: $HOSTNAME" - - #--------------------------------------------------------------------------- - # 1. 动态生成配置文件 (从模板注入参数) - #--------------------------------------------------------------------------- - log "正在从模板生成配置文件..." - sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ - -e "s|{{PREFIX}}|${PREFIX}|g" \ - -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ - -e "s|{{NNODES}}|${NNODES}|g" \ - -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ - -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ - -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ - -e "s|{{RM_LLM}}|${RM_LLM}|g" \ - -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ - -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ - -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ - -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ - ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} - - print_green "配置文件已生成: ${CONFIG_FILE}" - print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" - - #--------------------------------------------------------------------------- - # 2. 清理和初始化 Ray - #--------------------------------------------------------------------------- - rm -f "$MASTER_IP_FILE" - ray stop --force || true - sleep 3 - - #--------------------------------------------------------------------------- - # 4. 启动 Ray Head - #--------------------------------------------------------------------------- - print_green "Starting Ray head node at $MASTER_ADDR" - ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 - sleep 10 - echo $MASTER_ADDR > $MASTER_IP_FILE - - #--------------------------------------------------------------------------- - # 5 & 6. 等待节点和 GPU 就绪 (逻辑保持不变) - #--------------------------------------------------------------------------- - # ... (此处省略重复的等待逻辑以保持简洁,实际运行时请保留原脚本中的 while 循环) ... - # [请保留原脚本中 5.等待所有Worker 6.等待GPU 7.等待Dashboard 的完整代码] - - #--------------------------------------------------------------------------- - # 9. 启动训练任务 - #--------------------------------------------------------------------------- - print_green "Starting training job..." - source .venv/bin/activate - - export RAY_ADDRESS="ray://localhost:10001" - export env_url="http://${MASTER_ADDR}:8080" - export env_type="finworld" - - print_green "===================================" - print_green "Training Configuration" - print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" - print_green "Log: ${TRAIN_LOG}" - print_green "===================================" - - # 启动训练任务 - python ajet/launcher.py \ - --with-finworld \ - --conf ${CONFIG_FILE} \ - --backbone="verl" \ - --debug="TAG_A" \ - 2>&1 | tee ${TRAIN_LOG} - - # 保留原脚本末尾的 CLI 调用 - ajet --conf ${CONFIG_FILE} --backbone='verl' - -#=============================================================================== -# Worker 节点启动流程 (逻辑保持不变) -#=============================================================================== -else - print_green "==> This is WORKER node: $HOSTNAME" - # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] - while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done - MASTER_ADDR=$(cat $MASTER_IP_FILE) - ray stop || true - ray start --address $MASTER_ADDR:6379 --num-gpus 8 - while true; do sleep 60; done -fi \ No newline at end of file diff --git a/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh b/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh deleted file mode 100644 index 1abde8a0..00000000 --- a/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl.sh +++ /dev/null @@ -1,264 +0,0 @@ -#!/bin/bash -set -e -#=============================================================================== -# 配置区域 - 用户只需修改这里 -#=============================================================================== -SUFFIX="ajet_finworld_loadjsonl_7b" # 实验后缀,影响所有日志和实验名称 -PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 - -# 新增:模型与模板配置 -MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507" -CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" - -# 新增:数据文件路径配置 -TRAIN_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json" -VAL_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json" - -# 新增:Reference Answer 文件路径配置(RM Gallery 需要) -TRAIN_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json" -VAL_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json" - -# 新增:Judge 模型配置 -OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 -RM_LLM='qwen-max' # RM Gallery 评分模型 -JUDGE_CONCURRENCY=10 - -# 新增:奖励权重配置 -RM_WEIGHT=0.4 -CITATION_AUDIT_WEIGHT=0.2 -REPORT_RESOLUTION_WEIGHT=0.2 -TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 - -# API密钥配置(从 .env 文件加载,不要硬编码) -# 配置 -NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 -TRAIN_BATCH_SIZE=32 -NUM_STEPS=6 # 每个样本step轮数 - -ADDR="22.17.31.142" -MCP_PORT="8040" - -# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 -export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" -CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" - -#=============================================================================== -# 环境配置区域 -#=============================================================================== - -cd ${AJET_ROOT} -source .venv/bin/activate -# API密钥配置 - 从 .env 文件加载 -ENV_FILE="${AJET_ROOT}/.env" -if [ -f "$ENV_FILE" ]; then - set -a - source "$ENV_FILE" - set +a - echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" -else - echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" -fi - -# MongoDB 缓存配置 -CACHE_TYPE="mongodb" -MONGO_URI="mongodb://${ADDR}:27117/" -MONGO_DB_NAME="finworld_cache" -MONGO_COLLECTION_NAME="tool_cache" - -# FinWorld MCP 配置 -LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" -FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" - -# 动态生成 MCP 配置文件 -mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) -cat > ${FINWORLD_MCP_CONFIG} << EOF -{ - "mcpServers": { - "flowllm": { - "transport": "sse", - "url": "http://${ADDR}:${MCP_PORT}/sse", - "timeout": 600, - "sse_read_timeout": 1200 - } - } -} -EOF -FINWORLD_TOOL_RESULT_MAX_CHARS=10000 - -# 其他服务配置 -HF_ENDPOINT="https://hf-mirror.com" -ES_HOSTS="http://11.160.132.46:8200" - -#=============================================================================== -# 多机训练参数配置 -#=============================================================================== -if [ -z "${WORLD_SIZE}" ]; then - echo "ERROR: WORLD_SIZE environment variable is not set!" - echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" - exit 1 -fi - -NNODES=${WORLD_SIZE} -GPUS_PER_NODE=8 -EXPECTED_WORKERS=$WORLD_SIZE - -#=============================================================================== -# NCCL 配置 -#=============================================================================== -export NCCL_TIMEOUT=1800 -export NCCL_DEBUG=WARN -export NCCL_IB_TIMEOUT=23 -export NCCL_ASYNC_ERROR_HANDLING=1 - -#=============================================================================== -# 自动生成的变量 -#=============================================================================== -CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") - -MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" -ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" -TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - -#=============================================================================== -# 工具函数 -#=============================================================================== -print_green() { - echo -e "\033[32m$1\033[0m" -} - -print_red() { - echo -e "\033[31m$1\033[0m" -} - -log() { - echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" -} - -check_workers() { - local status_output=$(ray status 2>/dev/null) - if [ -z "$status_output" ]; then echo 0; return; fi - local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) - if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi - echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) -} - -check_gpu_resources() { - gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) - if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi -} - -#=============================================================================== -# 导出环境变量 -#=============================================================================== -export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME -export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS -export HF_ENDPOINT ES_HOSTS -export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" -export RAY_CLUSTER_MODE="multi_node" -# Directory paths -export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" - -export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 -export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" - -#=============================================================================== -# 主流程 -#=============================================================================== -log "开始多机多卡训练: ${SUFFIX}" -log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" -mkdir -p ${LOG_DIR} -mkdir -p $(dirname ${CONFIG_FILE}) - -#=============================================================================== -# Master 节点启动流程 -#=============================================================================== -if [[ $HOSTNAME == *"-master-"* ]]; then - print_green "==> This is MASTER node: $HOSTNAME" - - #--------------------------------------------------------------------------- - # 1. 动态生成配置文件 (从模板注入参数) - #--------------------------------------------------------------------------- - log "正在从模板生成配置文件..." - sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ - -e "s|{{PREFIX}}|${PREFIX}|g" \ - -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ - -e "s|{{NNODES}}|${NNODES}|g" \ - -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ - -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ - -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ - -e "s|{{RM_LLM}}|${RM_LLM}|g" \ - -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ - -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ - -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ - -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ - -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ - -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ - -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ - -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ - ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} - - print_green "配置文件已生成: ${CONFIG_FILE}" - print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" - - #--------------------------------------------------------------------------- - # 2. 清理和初始化 Ray - #--------------------------------------------------------------------------- - rm -f "$MASTER_IP_FILE" - ray stop --force || true - sleep 3 - - #--------------------------------------------------------------------------- - # 4. 启动 Ray Head - #--------------------------------------------------------------------------- - print_green "Starting Ray head node at $MASTER_ADDR" - ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 - sleep 10 - echo $MASTER_ADDR > $MASTER_IP_FILE - - #--------------------------------------------------------------------------- - # 5 & 6. 等待节点和 GPU 就绪 (逻辑保持不变) - #--------------------------------------------------------------------------- - # ... (此处省略重复的等待逻辑以保持简洁,实际运行时请保留原脚本中的 while 循环) ... - # [请保留原脚本中 5.等待所有Worker 6.等待GPU 7.等待Dashboard 的完整代码] - - #--------------------------------------------------------------------------- - # 9. 启动训练任务 - #--------------------------------------------------------------------------- - print_green "Starting training job..." - source .venv/bin/activate - - export RAY_ADDRESS="ray://localhost:10001" - export env_url="http://${MASTER_ADDR}:8080" - export env_type="finworld" - - print_green "===================================" - print_green "Training Configuration" - print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" - print_green "Log: ${TRAIN_LOG}" - print_green "===================================" - - # 启动训练任务 - python ajet/launcher.py \ - --with-finworld \ - --conf ${CONFIG_FILE} \ - --backbone="verl" \ - --debug="TAG_A" \ - 2>&1 | tee ${TRAIN_LOG} - - # 保留原脚本末尾的 CLI 调用 - ajet --conf ${CONFIG_FILE} --backbone='verl' - -#=============================================================================== -# Worker 节点启动流程 (逻辑保持不变) -#=============================================================================== -else - print_green "==> This is WORKER node: $HOSTNAME" - # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] - while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done - MASTER_ADDR=$(cat $MASTER_IP_FILE) - ray stop || true - ray start --address $MASTER_ADDR:6379 --num-gpus 8 - while true; do sleep 60; done -fi \ No newline at end of file diff --git a/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl_8b.sh b/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl_8b.sh deleted file mode 100644 index c7a13048..00000000 --- a/tutorial/example_finworld/scripts/ajet_finworld_loadjsonl_8b.sh +++ /dev/null @@ -1,264 +0,0 @@ -#!/bin/bash -set -e -#=============================================================================== -# 配置区域 - 用户只需修改这里 -#=============================================================================== -SUFFIX="ajet_finworld_loadjsonl_8b" # 实验后缀,影响所有日志和实验名称 -PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 - -# 新增:模型与模板配置 -MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-8B" -CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" - -# 新增:数据文件路径配置 -TRAIN_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/train_cc423_11171143_tasks.json" -VAL_DATA_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/val_30_tasks.json" - -# 新增:Reference Answer 文件路径配置(RM Gallery 需要) -TRAIN_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_DR_11171143_cc.json" -VAL_REF_ANS_PATH="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/Reference_ans_val.json" - -# 新增:Judge 模型配置 -OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 -RM_LLM='qwen-max' # RM Gallery 评分模型 -JUDGE_CONCURRENCY=10 - -# 新增:奖励权重配置 -RM_WEIGHT=0.4 -CITATION_AUDIT_WEIGHT=0.2 -REPORT_RESOLUTION_WEIGHT=0.2 -TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 - -# API密钥配置(从 .env 文件加载,不要硬编码) -# 配置 -NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 -TRAIN_BATCH_SIZE=32 -NUM_STEPS=6 # 每个样本step轮数 - -ADDR="22.17.31.142" -MCP_PORT="8040" - -# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 -export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" -CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" - -#=============================================================================== -# 环境配置区域 -#=============================================================================== - -cd ${AJET_ROOT} -source .venv/bin/activate -# API密钥配置 - 从 .env 文件加载 -ENV_FILE="${AJET_ROOT}/.env" -if [ -f "$ENV_FILE" ]; then - set -a - source "$ENV_FILE" - set +a - echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" -else - echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" -fi - -# MongoDB 缓存配置 -CACHE_TYPE="mongodb" -MONGO_URI="mongodb://${ADDR}:27117/" -MONGO_DB_NAME="finworld_cache" -MONGO_COLLECTION_NAME="tool_cache" - -# FinWorld MCP 配置 -LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" -FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" - -# 动态生成 MCP 配置文件 -mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) -cat > ${FINWORLD_MCP_CONFIG} << EOF -{ - "mcpServers": { - "flowllm": { - "transport": "sse", - "url": "http://${ADDR}:${MCP_PORT}/sse", - "timeout": 600, - "sse_read_timeout": 1200 - } - } -} -EOF -FINWORLD_TOOL_RESULT_MAX_CHARS=10000 - -# 其他服务配置 -HF_ENDPOINT="https://hf-mirror.com" -ES_HOSTS="http://11.160.132.46:8200" - -#=============================================================================== -# 多机训练参数配置 -#=============================================================================== -if [ -z "${WORLD_SIZE}" ]; then - echo "ERROR: WORLD_SIZE environment variable is not set!" - echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" - exit 1 -fi - -NNODES=${WORLD_SIZE} -GPUS_PER_NODE=8 -EXPECTED_WORKERS=$WORLD_SIZE - -#=============================================================================== -# NCCL 配置 -#=============================================================================== -export NCCL_TIMEOUT=1800 -export NCCL_DEBUG=WARN -export NCCL_IB_TIMEOUT=23 -export NCCL_ASYNC_ERROR_HANDLING=1 - -#=============================================================================== -# 自动生成的变量 -#=============================================================================== -CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") - -MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" -ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" -TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - -#=============================================================================== -# 工具函数 -#=============================================================================== -print_green() { - echo -e "\033[32m$1\033[0m" -} - -print_red() { - echo -e "\033[31m$1\033[0m" -} - -log() { - echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" -} - -check_workers() { - local status_output=$(ray status 2>/dev/null) - if [ -z "$status_output" ]; then echo 0; return; fi - local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) - if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi - echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) -} - -check_gpu_resources() { - gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) - if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi -} - -#=============================================================================== -# 导出环境变量 -#=============================================================================== -export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME -export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS -export HF_ENDPOINT ES_HOSTS -export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" -export RAY_CLUSTER_MODE="multi_node" -# Directory paths -export ENV_SERVICE_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/mongodb/BeyondAgent_env" - -export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 -export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" - -#=============================================================================== -# 主流程 -#=============================================================================== -log "开始多机多卡训练: ${SUFFIX}" -log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" -mkdir -p ${LOG_DIR} -mkdir -p $(dirname ${CONFIG_FILE}) - -#=============================================================================== -# Master 节点启动流程 -#=============================================================================== -if [[ $HOSTNAME == *"-master-"* ]]; then - print_green "==> This is MASTER node: $HOSTNAME" - - #--------------------------------------------------------------------------- - # 1. 动态生成配置文件 (从模板注入参数) - #--------------------------------------------------------------------------- - log "正在从模板生成配置文件..." - sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ - -e "s|{{PREFIX}}|${PREFIX}|g" \ - -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ - -e "s|{{NNODES}}|${NNODES}|g" \ - -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ - -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ - -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ - -e "s|{{RM_LLM}}|${RM_LLM}|g" \ - -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ - -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ - -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ - -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ - -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ - -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ - -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ - -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ - ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} - - print_green "配置文件已生成: ${CONFIG_FILE}" - print_green "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" - - #--------------------------------------------------------------------------- - # 2. 清理和初始化 Ray - #--------------------------------------------------------------------------- - rm -f "$MASTER_IP_FILE" - ray stop --force || true - sleep 3 - - #--------------------------------------------------------------------------- - # 4. 启动 Ray Head - #--------------------------------------------------------------------------- - print_green "Starting Ray head node at $MASTER_ADDR" - ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 - sleep 10 - echo $MASTER_ADDR > $MASTER_IP_FILE - - #--------------------------------------------------------------------------- - # 5 & 6. 等待节点和 GPU 就绪 (逻辑保持不变) - #--------------------------------------------------------------------------- - # ... (此处省略重复的等待逻辑以保持简洁,实际运行时请保留原脚本中的 while 循环) ... - # [请保留原脚本中 5.等待所有Worker 6.等待GPU 7.等待Dashboard 的完整代码] - - #--------------------------------------------------------------------------- - # 9. 启动训练任务 - #--------------------------------------------------------------------------- - print_green "Starting training job..." - source .venv/bin/activate - - export RAY_ADDRESS="ray://localhost:10001" - export env_url="http://${MASTER_ADDR}:8080" - export env_type="finworld" - - print_green "===================================" - print_green "Training Configuration" - print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" - print_green "Log: ${TRAIN_LOG}" - print_green "===================================" - - # 启动训练任务 - python ajet/launcher.py \ - --with-finworld \ - --conf ${CONFIG_FILE} \ - --backbone="verl" \ - --debug="TAG_A" \ - 2>&1 | tee ${TRAIN_LOG} - - # 保留原脚本末尾的 CLI 调用 - ajet --conf ${CONFIG_FILE} --backbone='verl' - -#=============================================================================== -# Worker 节点启动流程 (逻辑保持不变) -#=============================================================================== -else - print_green "==> This is WORKER node: $HOSTNAME" - # [此处保留原脚本中 Worker 节点等待 Master IP 和连接 Ray 的完整逻辑] - while [ ! -f $MASTER_IP_FILE ]; do sleep 5; done - MASTER_ADDR=$(cat $MASTER_IP_FILE) - ray stop || true - ray start --address $MASTER_ADDR:6379 --num-gpus 8 - while true; do sleep 60; done -fi \ No newline at end of file diff --git a/tutorial/example_finworld/scripts/single.sh b/tutorial/example_finworld/scripts/single.sh deleted file mode 100644 index c52120c8..00000000 --- a/tutorial/example_finworld/scripts/single.sh +++ /dev/null @@ -1,112 +0,0 @@ -#!/bin/bash -set -e - -#=============================================================================== -# 配置区域 -#=============================================================================== -SUFFIX="cc_rm4_res2cit2fai2_30b_single" # 实验后缀 -PREFIX="open" # 实验前缀 - -ADDR="127.0.0.1" # 单机建议使用回环地址 -MCP_PORT="8040" -export CONFIG_FILE_NAME="tutorial/example_finworld/finworld_single.yaml" -export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" -export BEYONDAGENT_ROOT="${AJET_ROOT}" # 假设在同一目录下,若不同请手动修改 - -#=============================================================================== -# 环境初始化 -#=============================================================================== -cd ${AJET_ROOT} - -# 加载 .env -ENV_FILE="${AJET_ROOT}/.env" -if [ -f "$ENV_FILE" ]; then - set -a && source "$ENV_FILE" && set +a - echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" -fi - -# 1. 激活主虚拟环境 (uv) -source .venv/bin/activate - -# 2. 动态获取 Conda 基础路径,用于解决 PTY 找不到 conda 的问题 -CONDA_BASE_PATH=$(conda info --base) - -#=============================================================================== -# 服务与路径配置 -#=============================================================================== -# MongoDB 配置 -export CACHE_TYPE="mongodb" -export MONGO_URI="mongodb://${ADDR}:27117/" -export MONGO_DB_NAME="finworld_cache" -export MONGO_COLLECTION_NAME="tool_cache" - -# FinWorld 配置 -LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" -mkdir -p ${LOG_DIR} -export FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" -export FINWORLD_TOOL_RESULT_MAX_CHARS=10000 - -# 动态生成 MCP 配置 -cat > ${FINWORLD_MCP_CONFIG} << EOF -{ - "mcpServers": { - "flowllm": { - "transport": "sse", - "url": "http://${ADDR}:${MCP_PORT}/sse", - "timeout": 600, - "sse_read_timeout": 1200 - } - } -} -EOF - -# 环境变量导出 -export HF_ENDPOINT="https://hf-mirror.com" -export ES_HOSTS="http://11.160.132.46:8200" -export PYTHONPATH="${AJET_ROOT}:${BEYONDAGENT_ROOT}:${PYTHONPATH}" -export RAY_CLUSTER_MODE="single_node" - -# 关键修复:在脚本中显式加载 conda.sh 以供 PTY 子进程使用 -export FINWORLD_PATH="${BEYONDAGENT_ROOT}" -export FINWORLD_SCRIPT="source ${CONDA_BASE_PATH}/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${BEYONDAGENT_ROOT} && python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" - -#=============================================================================== -# 启动 Ray 本地集群 -#=============================================================================== -echo -e "\033[32m正在初始化单机 Ray 环境...\033[0m" -ray stop --force || true -sleep 2 - -# 启动单机 Head 节点,分配 8 张 GPU -ray start --head --num-gpus 8 - -# 等待 Ray 就绪 -sleep 5 -if ! ray status > /dev/null 2>&1; then - echo -e "\033[31m错误: Ray 启动失败\033[0m" - exit 1 -fi - -#=============================================================================== -# 启动训练 -#=============================================================================== -CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") -CONFIG_FILE="${AJET_ROOT}/${CONFIG_FILE_NAME}" -TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - -# 设置训练所需的运行时变量 -export RAY_ADDRESS="ray://localhost:10001" -export env_url="http://127.0.0.1:8080" -export env_type="finworld" - -echo -e "\033[32m===================================\033[0m" -echo -e "\033[32m开始单机运行: ${SUFFIX}\033[0m" -echo -e "\033[32m日志文件: ${TRAIN_LOG}\033[0m" -echo -e "\033[32m===================================\033[0m" - -# 启动 Launcher -python ajet/launcher.py \ - --with-finworld \ - --conf ${CONFIG_FILE} \ - --backbone="verl" \ - 2>&1 | tee ${TRAIN_LOG} \ No newline at end of file diff --git a/tutorial/example_finworld/yaml_template/finworld_template.yaml b/tutorial/example_finworld/yaml_template/finworld_template.yaml index 70b379f0..6a801053 100644 --- a/tutorial/example_finworld/yaml_template/finworld_template.yaml +++ b/tutorial/example_finworld/yaml_template/finworld_template.yaml @@ -66,7 +66,7 @@ ajet: env_url: "http://127.0.0.1:8080" env_action_preference: code trainer: - default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//{{PREFIX}}/{{SUFFIX}}" + default_local_dir: "{{CKPT_SAVE_PATH}}/{{PREFIX}}/{{SUFFIX}}" # resume_mode: disable # 禁用自动恢复,从头开始训练 actor_rollout_ref: rollout: From 623b7d91213d9c6152e157d5b1094a79e5838332 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 17:57:01 +0800 Subject: [PATCH 20/56] Refactor(deep_finance): Replace and remove finworld-related implementations - Switched the example directory from example_finworld to example_deep_finance - Modified startup parameters and logic to support deep_finance, replacing the finworld option - Replaced finworld_reader with deep_finance_reader in the task reader - Adjusted environment client configuration in resource management, using deep_finance instead of finworld-related checks - Updated reward metric tool documentation to support deep_finance - Deleted finworld-related configuration files, scripts, code, and evaluation modules, cleaning up leftover files and scripts - Replaced the keyword "finworld" with "deep_finance" in comments and logs --- .gitignore | 2 +- ajet/launcher.py | 8 ++++---- ajet/task_reader/__init__.py | 8 ++++---- ajet/task_rollout/resource_keeper.py | 10 +++++----- ajet/utils/metric_helper/reward_metric_helper.py | 4 ++-- .../config/mcp_finance_tool_generated.json | 0 tutorial/example_deep_finance/deep_finance.md | 1 + .../deep_finance.py} | 2 +- .../deep_finance.sh} | 12 +++++------- .../deep_finance.yaml} | 0 .../deep_finance_judge.py} | 2 +- .../deep_finance_reader.py} | 10 +++++----- .../prompt/finance_analyst_prompt.md | 0 .../prompt/tool_prompt_builder.py | 0 .../yaml_template/deep_finance_template.yaml} | 10 +++++----- tutorial/example_finworld/finworld.md | 1 - 16 files changed, 34 insertions(+), 36 deletions(-) rename tutorial/{example_finworld => example_deep_finance}/config/mcp_finance_tool_generated.json (100%) create mode 100644 tutorial/example_deep_finance/deep_finance.md rename tutorial/{example_finworld/finworld.py => example_deep_finance/deep_finance.py} (99%) rename tutorial/{example_finworld/finworld.sh => example_deep_finance/deep_finance.sh} (95%) rename tutorial/{example_finworld/finworld.yaml => example_deep_finance/deep_finance.yaml} (100%) rename tutorial/{example_finworld/finworld_judge.py => example_deep_finance/deep_finance_judge.py} (99%) rename tutorial/{example_finworld/finworld_reader.py => example_deep_finance/deep_finance_reader.py} (96%) rename tutorial/{example_finworld => example_deep_finance}/prompt/finance_analyst_prompt.md (100%) rename tutorial/{example_finworld => example_deep_finance}/prompt/tool_prompt_builder.py (100%) rename tutorial/{example_finworld/yaml_template/finworld_template.yaml => example_deep_finance/yaml_template/deep_finance_template.yaml} (89%) delete mode 100644 tutorial/example_finworld/finworld.md diff --git a/.gitignore b/.gitignore index c63a9c4d..5add9fac 100644 --- a/.gitignore +++ b/.gitignore @@ -154,4 +154,4 @@ site dump.rdb -tutorial/example_finworld/yaml/* \ No newline at end of file +tutorial/example_deep_finance/yaml/* \ No newline at end of file diff --git a/ajet/launcher.py b/ajet/launcher.py index 73a347aa..10af0d8e 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -60,10 +60,10 @@ def parse_args(): help="Launch appworld", ) parser.add_argument( - "--with-finworld", + "--with-deep_finance", action="store_true", default=False, - help="Launch finworld", + help="Launch deep_finance", ) parser.add_argument( "--with-webshop", @@ -303,8 +303,8 @@ def main(): if args.with_appworld: pty_launch("appworld") - if args.with_finworld: - pty_launch("finworld") + if args.with_deep_finance: + pty_launch("deep_finance") if args.with_crafters: pty_launch("crafters") diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py index d0baf43a..431291d7 100644 --- a/ajet/task_reader/__init__.py +++ b/ajet/task_reader/__init__.py @@ -61,10 +61,10 @@ def __init__(self, reader_type, reader_config): self.task_reader = DataGeneratorTaskReader(reader_config) elif task_reader_type == "random_dummy": self.task_reader = RandomDummyTaskReader(reader_config) - elif task_reader_type == "finworld": - # FinWorld 专用: 数据从 JSON 文件加载并组装 init_messages,工具调用走 env_service - from tutorial.example_finworld.finworld_reader import FinworldReader - self.task_reader = FinworldReader(reader_config) + elif task_reader_type == "deep_finance": + # deep_finance 专用: 数据从 JSON 文件加载并组装 init_messages,工具调用走 env_service + from tutorial.example_deep_finance.deep_finance_reader import deep_financeReader + self.task_reader = deep_financeReader(reader_config) else: raise ValueError(f"Unsupported task reader type: {task_reader_type}") diff --git a/ajet/task_rollout/resource_keeper.py b/ajet/task_rollout/resource_keeper.py index 069f715d..6d4045d0 100644 --- a/ajet/task_rollout/resource_keeper.py +++ b/ajet/task_rollout/resource_keeper.py @@ -25,7 +25,7 @@ def __enter__(self): self.tokenizer = self.workflow_task.tokenizer self.llm_inference_fn = self.workflow_task.llm_inference_fn self.observation_window = self.workflow_task.observation_window - if self.config.ajet.task_reader.type in ("env_service", "finworld"): + if self.config.ajet.task_reader.type in ("env_service", "deep_finance"): url = self.config.ajet.task_reader.env_service.env_url env_type = self.config.ajet.task_reader.env_service.env_type self.env = EnvClientNg(base_url=url) @@ -97,10 +97,10 @@ def _initialize_environment_and_messages(self) -> List[dict]: if self.env is not None: self.env.release_instance(self.workflow_task.episode_uuid) raise e - elif reader_type == "finworld": - # finworld: 调用 create_instance 注册实例,但使用 reader 组装的 init_messages + elif reader_type == "deep_finance": + # deep_finance: 调用 create_instance 注册实例,但使用 reader 组装的 init_messages if self.env is None: - raise ValueError("Environment client is None but finworld type is specified") + raise ValueError("Environment client is None but deep_finance type is specified") try: # 必须调用 create_instance,让服务端创建实例,后续 step() 才能工作 self.env.create_instance( @@ -114,7 +114,7 @@ def _initialize_environment_and_messages(self) -> List[dict]: if task.init_messages: init_messages = task.init_messages else: - assert task.main_query, "finworld requires init_messages or main_query." + assert task.main_query, "deep_finance requires init_messages or main_query." init_messages = [{"role": "user", "content": task.main_query}] except Exception as e: logger.bind(exception=True).exception( diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index 31e1f95a..bfe12e4f 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -1,8 +1,8 @@ """ -FinWorld Reward Metrics Helper +deep_finance Reward Metrics Helper Provides standalone utility functions for reward_stats extraction and SwanLab metrics formatting. -Decouples finworld-specific logic from core code, reducing intrusion into native_compat_trainer. +Decouples deep_finance-specific logic from core code, reducing intrusion into native_compat_trainer. SwanLab metrics directory structure: - rewards/ Top-level aggregated scores diff --git a/tutorial/example_finworld/config/mcp_finance_tool_generated.json b/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json similarity index 100% rename from tutorial/example_finworld/config/mcp_finance_tool_generated.json rename to tutorial/example_deep_finance/config/mcp_finance_tool_generated.json diff --git a/tutorial/example_deep_finance/deep_finance.md b/tutorial/example_deep_finance/deep_finance.md new file mode 100644 index 00000000..1ac6d0c0 --- /dev/null +++ b/tutorial/example_deep_finance/deep_finance.md @@ -0,0 +1 @@ +# deep_finance \ No newline at end of file diff --git a/tutorial/example_finworld/finworld.py b/tutorial/example_deep_finance/deep_finance.py similarity index 99% rename from tutorial/example_finworld/finworld.py rename to tutorial/example_deep_finance/deep_finance.py index a911c5fd..f3ceae9e 100644 --- a/tutorial/example_finworld/finworld.py +++ b/tutorial/example_deep_finance/deep_finance.py @@ -152,7 +152,7 @@ async def execute( if isinstance(obs, list): # Standard Mode: obs 是 tool messages 列表 - # 注意:finworld_env.step 返回 {"state": [tool_results_msgs]} 套了一层列表 + # 注意:deep_finance_env.step 返回 {"state": [tool_results_msgs]} 套了一层列表 # BaseGymEnv.step 直接透传,所以 obs = [tool_results_msgs] # 需要解包获取实际的消息列表 actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs diff --git a/tutorial/example_finworld/finworld.sh b/tutorial/example_deep_finance/deep_finance.sh similarity index 95% rename from tutorial/example_finworld/finworld.sh rename to tutorial/example_deep_finance/deep_finance.sh index 904ac4c1..5d79ded7 100644 --- a/tutorial/example_finworld/finworld.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -3,7 +3,7 @@ set -e #=============================================================================== # 1. 配置区域 - 用户只需修改这里 #=============================================================================== -SUFFIX="ajet_finworld" # 实验后缀,影响所有日志和实验名称 +SUFFIX="ajet_deep_finance" # 实验后缀,影响所有日志和实验名称 PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 # OpenJudge 模型配置 @@ -24,8 +24,8 @@ NUM_STEPS=6 # 每个样本step轮数 FINWORLD_TOOL_RESULT_MAX_CHARS=10000 # 修改:配置文件生成路径,现在动态生成到 yaml 目录下 export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" -CONFIG_FILE="${AJET_ROOT}/tutorial/example_finworld/yaml/finworld_${SUFFIX}.yaml" -CONFIG_TEMPLATE="tutorial/example_finworld/yaml_template/finworld_template.yaml" +CONFIG_FILE="${AJET_ROOT}/tutorial/example_deep_finance/yaml/deep_finance_${SUFFIX}.yaml" +CONFIG_TEMPLATE="tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml" # 涉密的配置(API_KEY以及模型、数据位置)从.env读取 cd ${AJET_ROOT} @@ -80,7 +80,7 @@ MONGO_COLLECTION_NAME="tool_cache" export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME # FinWorld MCP 配置 -FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_finworld/config/mcp_finance_tool_generated.json" +FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json" # 动态生成 MCP 配置文件 mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) @@ -196,8 +196,6 @@ if [[ $HOSTNAME == *"-master-"* ]]; then print_green "Starting training job..." source .venv/bin/activate export RAY_ADDRESS="ray://localhost:10001" - export env_url="http://${MASTER_ADDR}:8080" - export env_type="finworld" print_green "===================================" print_green "Training Configuration" @@ -207,7 +205,7 @@ if [[ $HOSTNAME == *"-master-"* ]]; then # 启动训练任务(最核心) python ajet/launcher.py \ - --with-finworld \ + --with-deep_finance \ --conf ${CONFIG_FILE} \ --backbone="verl" \ 2>&1 | tee ${TRAIN_LOG} diff --git a/tutorial/example_finworld/finworld.yaml b/tutorial/example_deep_finance/deep_finance.yaml similarity index 100% rename from tutorial/example_finworld/finworld.yaml rename to tutorial/example_deep_finance/deep_finance.yaml diff --git a/tutorial/example_finworld/finworld_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py similarity index 99% rename from tutorial/example_finworld/finworld_judge.py rename to tutorial/example_deep_finance/deep_finance_judge.py index 5cdaf3f3..5bbee7c9 100644 --- a/tutorial/example_finworld/finworld_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -126,7 +126,7 @@ def _setup_weights(self): """ cfg = getattr(self.config, "ajet", None) - # 定义各 grader 的权重(可从 config 中读取)- 与 finworld_judge.py 对齐 + # 定义各 grader 的权重(可从 config 中读取)- 与 deep_finance_judge.py 对齐 self.w = { "rm": getattr(cfg, "rm_weight", 1.0) if cfg else 1.0, # RM Gallery 权重 "citation_audit": getattr(cfg, "citation_audit_weight", 0.0) if cfg else 0.0, # CitationAudit 权重 diff --git a/tutorial/example_finworld/finworld_reader.py b/tutorial/example_deep_finance/deep_finance_reader.py similarity index 96% rename from tutorial/example_finworld/finworld_reader.py rename to tutorial/example_deep_finance/deep_finance_reader.py index 44d8a330..ad94ea89 100644 --- a/tutorial/example_finworld/finworld_reader.py +++ b/tutorial/example_deep_finance/deep_finance_reader.py @@ -34,7 +34,7 @@ class FinworldReader(BaseTaskReader): 特点: 1. 从 JSON 文件加载任务数据(支持 list 和 dict 格式) 2. 现场组装 init_messages(system_prompt + user_query) - 3. env_type 固定为 "finworld",由 env_service 负责工具调用 + 3. env_type 固定为 "deep_finance",由 env_service 负责工具调用 """ # 类级别缓存 @@ -70,7 +70,7 @@ def _init_prompt_templates(self): if FinworldReader._tool_prompt_cache is None: # 使用 tool_prompt_builder.py 中的静态模板 _debug_log(f"Loading tool prompt template...") - from tutorial.example_finworld.prompt.tool_prompt_builder import get_tool_prompt_template + from tutorial.example_deep_finance.prompt.tool_prompt_builder import get_tool_prompt_template FinworldReader._tool_prompt_cache = get_tool_prompt_template() _debug_log(f"Tool prompt template loaded, length: {len(FinworldReader._tool_prompt_cache)} chars") else: @@ -237,7 +237,7 @@ def _create_task(self, task_id: str, query: str, raw_item: Dict[str, Any]) -> Ta main_query=query, init_messages=init_messages, task_id=task_id, - env_type="finworld", # 固定为 finworld,由 env_service 处理 + env_type="deep_finance", # 固定为 deep_finance,由 env_service 处理 metadata=metadata ) _debug_log(f" Task created successfully: {task_id}") @@ -246,7 +246,7 @@ def _create_task(self, task_id: str, query: str, raw_item: Dict[str, Any]) -> Ta def get_training_tasks(self) -> List[Task]: """获取训练任务""" _debug_log(f"get_training_tasks() called") - file_path = self.reader_config.finworld.training.file_path + file_path = self.reader_config.deep_finance.training.file_path _debug_log(f"Training file path: {file_path}") tasks = self._read_json_file(file_path, split="train") _debug_log(f"get_training_tasks() returning {len(tasks)} tasks") @@ -255,7 +255,7 @@ def get_training_tasks(self) -> List[Task]: def get_validation_tasks(self) -> List[Task]: """获取验证任务""" _debug_log(f"get_validation_tasks() called") - file_path = self.reader_config.finworld.validation.file_path + file_path = self.reader_config.deep_finance.validation.file_path _debug_log(f"Validation file path: {file_path}") tasks = self._read_json_file(file_path, split="val") _debug_log(f"get_validation_tasks() returning {len(tasks)} tasks") diff --git a/tutorial/example_finworld/prompt/finance_analyst_prompt.md b/tutorial/example_deep_finance/prompt/finance_analyst_prompt.md similarity index 100% rename from tutorial/example_finworld/prompt/finance_analyst_prompt.md rename to tutorial/example_deep_finance/prompt/finance_analyst_prompt.md diff --git a/tutorial/example_finworld/prompt/tool_prompt_builder.py b/tutorial/example_deep_finance/prompt/tool_prompt_builder.py similarity index 100% rename from tutorial/example_finworld/prompt/tool_prompt_builder.py rename to tutorial/example_deep_finance/prompt/tool_prompt_builder.py diff --git a/tutorial/example_finworld/yaml_template/finworld_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml similarity index 89% rename from tutorial/example_finworld/yaml_template/finworld_template.yaml rename to tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 6a801053..37089d04 100644 --- a/tutorial/example_finworld/yaml_template/finworld_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -1,6 +1,6 @@ # ------------------ 主要配置 ------------------ ajet: - project_name: ajet_finworld + project_name: ajet_deep_finance experiment_name: "{{SUFFIX}}" # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) judge: @@ -16,7 +16,7 @@ ajet: rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) - judge_protocol: tutorial.example_finworld.finworld_judge->FinWorldJudgeByOpenJudge + judge_protocol: tutorial.example_deep_finance.deep_finance_judge->FinWorldJudgeByOpenJudge model: # ✨✨✨✨ 设置待训练的模型 path: {{MODEL_PATH}} @@ -30,7 +30,7 @@ ajet: total_epochs: 200 rollout: # ✨✨✨✨ 编写并选择Agent - user_workflow: tutorial.example_finworld.finworld->ExampleDeepResearchProtocol + user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol force_disable_toolcalls: True enable_oversample: False tensor_model_parallel_size: 8 @@ -54,8 +54,8 @@ ajet: max_response_length: 41000 task_reader: - type: finworld # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service - finworld: + type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + deep_finance: training: file_path: {{TRAIN_DATA_PATH}} validation: diff --git a/tutorial/example_finworld/finworld.md b/tutorial/example_finworld/finworld.md deleted file mode 100644 index e884e864..00000000 --- a/tutorial/example_finworld/finworld.md +++ /dev/null @@ -1 +0,0 @@ -# finworld \ No newline at end of file From 0aaab86c776c97eb7d7fd9aa7a71967f8f863284 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 18:05:35 +0800 Subject: [PATCH 21/56] refactor(deepfinance): Rename and unify DeepFinance module and config references - Replace all "finworld" and "deep_finance" names with the unified "deepfinance" format. - Modify command-line arguments to `--with-deepfinance` for consistency. - Adjust the class name in `task_reader` from `deep_financeReader` to `DeepFinanceReader`. - Update the documentation description and file name of the `metric_helper` module to DeepFinance. - Modify environment variables and configuration paths in the example script `deep_finance.sh` to use the `DEEPFINANCE` prefix. - Update `judge_protocol` to `DeepFinanceJudgeByOpenJudge` in the `deep_finance.yaml` configuration. - Refactor the `FinWorldJudgeByOpenJudge` class in `deep_finance_judge.py` to `DeepFinanceJudgeByOpenJudge`. - Rename the `FinworldReader` class in `deep_finance_reader.py` to `DeepFinanceReader`. - Modify the debug log identifier and corresponding environment variable name to `DEEPFINANCE_DEBUG`. - Update the evaluation protocol in the `deep_finance_template.yaml` template to `DeepFinanceJudgeByOpenJudge`. - Ensure that internal references and comments in all modules are updated to use DeepFinance and deepfinance-related names. --- ajet/launcher.py | 8 ++--- ajet/task_reader/__init__.py | 4 +-- .../utils/metric_helper/tool_metric_helper.py | 2 +- tutorial/example_deep_finance/deep_finance.sh | 16 ++++----- .../example_deep_finance/deep_finance.yaml | 4 +-- .../deep_finance_judge.py | 26 +++++++------- .../deep_finance_reader.py | 34 +++++++++---------- .../yaml_template/deep_finance_template.yaml | 4 +-- 8 files changed, 49 insertions(+), 49 deletions(-) diff --git a/ajet/launcher.py b/ajet/launcher.py index 10af0d8e..47345ce2 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -60,10 +60,10 @@ def parse_args(): help="Launch appworld", ) parser.add_argument( - "--with-deep_finance", + "--with-deepfinance", action="store_true", default=False, - help="Launch deep_finance", + help="Launch deepfinance", ) parser.add_argument( "--with-webshop", @@ -303,8 +303,8 @@ def main(): if args.with_appworld: pty_launch("appworld") - if args.with_deep_finance: - pty_launch("deep_finance") + if args.with_deepfinance: + pty_launch("deepfinance") if args.with_crafters: pty_launch("crafters") diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py index 431291d7..d3bbb1d7 100644 --- a/ajet/task_reader/__init__.py +++ b/ajet/task_reader/__init__.py @@ -63,8 +63,8 @@ def __init__(self, reader_type, reader_config): self.task_reader = RandomDummyTaskReader(reader_config) elif task_reader_type == "deep_finance": # deep_finance 专用: 数据从 JSON 文件加载并组装 init_messages,工具调用走 env_service - from tutorial.example_deep_finance.deep_finance_reader import deep_financeReader - self.task_reader = deep_financeReader(reader_config) + from tutorial.example_deep_finance.deep_finance_reader import DeepFinanceReader + self.task_reader = DeepFinanceReader(reader_config) else: raise ValueError(f"Unsupported task reader type: {task_reader_type}") diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py index 03b3ed01..f1ed5d70 100644 --- a/ajet/utils/metric_helper/tool_metric_helper.py +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -1,5 +1,5 @@ """ -FinWorld Tool Metrics Helper +DeepFinance Tool Metrics Helper Specialized module for extracting tool-related statistics and formatting SwanLab reports. Extracts data from log_metrics['tool_stats']. diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index 5d79ded7..02620620 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -21,7 +21,7 @@ TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 -FINWORLD_TOOL_RESULT_MAX_CHARS=10000 +DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 # 修改:配置文件生成路径,现在动态生成到 yaml 目录下 export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" CONFIG_FILE="${AJET_ROOT}/tutorial/example_deep_finance/yaml/deep_finance_${SUFFIX}.yaml" @@ -79,12 +79,12 @@ MONGO_DB_NAME="finworld_cache" MONGO_COLLECTION_NAME="tool_cache" export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME -# FinWorld MCP 配置 -FINWORLD_MCP_CONFIG="${AJET_ROOT}/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json" +# DeepFinance MCP 配置 +DEEPFINANCE_MCP_CONFIG="${AJET_ROOT}/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json" # 动态生成 MCP 配置文件 -mkdir -p $(dirname ${FINWORLD_MCP_CONFIG}) -cat > ${FINWORLD_MCP_CONFIG} << EOF +mkdir -p $(dirname ${DEEPFINANCE_MCP_CONFIG}) +cat > ${DEEPFINANCE_MCP_CONFIG} << EOF { "mcpServers": { "flowllm": { @@ -96,7 +96,7 @@ cat > ${FINWORLD_MCP_CONFIG} << EOF } } EOF -export FINWORLD_MCP_CONFIG FINWORLD_TOOL_RESULT_MAX_CHARS +export DEEPFINANCE_MCP_CONFIG DEEPFINANCE_TOOL_RESULT_MAX_CHARS # 其他服务配置 HF_ENDPOINT="https://hf-mirror.com" @@ -157,8 +157,8 @@ export NCCL_ASYNC_ERROR_HANDLING=1 export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" export RAY_CLUSTER_MODE="multi_node" -export FINWORLD_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 -export FINWORLD_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && FINWORLD_TOOL_RESULT_MAX_CHARS=${FINWORLD_TOOL_RESULT_MAX_CHARS} FINWORLD_MCP_CONFIG=${FINWORLD_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" +export DEEPFINANCE_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 +export DEEPFINANCE_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && DEEPFINANCE_TOOL_RESULT_MAX_CHARS=${DEEPFINANCE_TOOL_RESULT_MAX_CHARS} DEEPFINANCE_MCP_CONFIG=${DEEPFINANCE_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" #=============================================================================== diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml index 344120a5..c4da3438 100644 --- a/tutorial/example_deep_finance/deep_finance.yaml +++ b/tutorial/example_deep_finance/deep_finance.yaml @@ -11,7 +11,7 @@ ajet: rm_weight: 0.4 # RM Gallery 权重 task_judge: judge_type: customized_protocol - judge_protocol: tutorial.example_finworld.finworld_judge->FinWorldJudgeByOpenJudge + judge_protocol: tutorial.example_finworld.finworld_judge->DeepFinanceJudgeByOpenJudge model: # ✨✨✨✨ 设置待训练的模型 path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507 @@ -59,7 +59,7 @@ ajet: # training_split: train # validation_split: val - # === 方案 B: FinWorld Reader 模式 (数据从 JSON 加载,工具调用走 env_service) === + # === 方案 B: DeepFinance Reader 模式 (数据从 JSON 加载,工具调用走 env_service) === type: finworld finworld: training: diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index 5bbee7c9..f49d88d3 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -1,4 +1,4 @@ -"""FinWorld Task Judge - OpenJudge 版本 +"""DeepFinance Task Judge - OpenJudge 版本 集成: RM Gallery, OpenJudge Graders (含 CitationAudit) """ @@ -82,12 +82,12 @@ def load_reference_answers_from_file(file_path: str) -> Tuple[Dict[str, str], Di # ============================================================================= -# FinWorldJudgeByOpenJudge 类 +# DeepFinanceJudgeByOpenJudge 类 # ============================================================================= -class FinWorldJudgeByOpenJudge(BaseJudge): +class DeepFinanceJudgeByOpenJudge(BaseJudge): """ - 使用 OpenJudge 框架的 FinWorld Judge + 使用 OpenJudge 框架的 DeepFinance Judge 集成: RM Gallery, OpenJudge Graders (含 CitationAudit) 分析: @@ -171,11 +171,11 @@ def _init_rm_components(self): """初始化 RM Gallery Evaluator(仅当 rm_weight > 0 时)""" self._rm_enabled = (self.w.get("rm", 0) > 0) if self._rm_enabled: - if FinWorldJudgeByOpenJudge._rm_evaluator_instance is None: + if DeepFinanceJudgeByOpenJudge._rm_evaluator_instance is None: self._init_rm_evaluator() - FinWorldJudgeByOpenJudge._rm_evaluator_instance = self.rm_evaluator + DeepFinanceJudgeByOpenJudge._rm_evaluator_instance = self.rm_evaluator else: - self.rm_evaluator = FinWorldJudgeByOpenJudge._rm_evaluator_instance + self.rm_evaluator = DeepFinanceJudgeByOpenJudge._rm_evaluator_instance else: self.rm_evaluator = None @@ -220,20 +220,20 @@ def _init_reference_answers(self): val_ref_ans_path = getattr(self.config.ajet.judge, "val_ref_ans_path", "") def _load(path, key): - if path and key not in FinWorldJudgeByOpenJudge._ref_answers_cache: + if path and key not in DeepFinanceJudgeByOpenJudge._ref_answers_cache: try: ans, dom = load_reference_answers_from_file(path) - FinWorldJudgeByOpenJudge._ref_answers_cache[key], FinWorldJudgeByOpenJudge._ref_domains_cache[key] = ans, dom + DeepFinanceJudgeByOpenJudge._ref_answers_cache[key], DeepFinanceJudgeByOpenJudge._ref_domains_cache[key] = ans, dom except Exception: - FinWorldJudgeByOpenJudge._ref_answers_cache[key], FinWorldJudgeByOpenJudge._ref_domains_cache[key] = {}, {} + DeepFinanceJudgeByOpenJudge._ref_answers_cache[key], DeepFinanceJudgeByOpenJudge._ref_domains_cache[key] = {}, {} _load(train_ref_ans_path, "train") _load(val_ref_ans_path, "val") def _get_reference_data(self, task_id: str) -> Tuple[str, str]: """获取任务的参考答案和领域""" cache_key = "val" if task_id.startswith("val_") else "train" - ans = FinWorldJudgeByOpenJudge._ref_answers_cache.get(cache_key, {}).get(task_id, "") - dom = FinWorldJudgeByOpenJudge._ref_domains_cache.get(cache_key, {}).get(task_id) + ans = DeepFinanceJudgeByOpenJudge._ref_answers_cache.get(cache_key, {}).get(task_id, "") + dom = DeepFinanceJudgeByOpenJudge._ref_domains_cache.get(cache_key, {}).get(task_id) return ans, dom @@ -400,7 +400,7 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO quota_exceeded_flags=quota_exceeded_flags ) - print(f"FinWorldJudgeByOpenJudge: task_id={task_id}, fused={fused_reward:.4f}, final={final_reward:.4f}, rm_time={rm_time:.2f}s, grading_time={grading_time:.2f}s, total={judge_total_time:.2f}s") + print(f"DeepFinanceJudgeByOpenJudge: task_id={task_id}, fused={fused_reward:.4f}, final={final_reward:.4f}, rm_time={rm_time:.2f}s, grading_time={grading_time:.2f}s, total={judge_total_time:.2f}s") # 9. 判断是否成功(可根据实际需求调整阈值) is_success = final_reward >= 0.7 diff --git a/tutorial/example_deep_finance/deep_finance_reader.py b/tutorial/example_deep_finance/deep_finance_reader.py index ad94ea89..1752bcd0 100644 --- a/tutorial/example_deep_finance/deep_finance_reader.py +++ b/tutorial/example_deep_finance/deep_finance_reader.py @@ -1,4 +1,4 @@ -"""FinWorld Reader +"""DeepFinance Reader 从 JSON 文件加载任务数据,并现场组装 init_messages。 - 数据来源:训练集/测试集 JSON 文件 @@ -18,18 +18,18 @@ logger = logging.getLogger(__name__) # 控制 debug 输出的开关(可通过环境变量控制) -DEBUG_ENABLED = os.environ.get("FINWORLD_DEBUG", "0") == "1" +DEBUG_ENABLED = os.environ.get("DEEPFINANCE_DEBUG", "0") == "1" def _debug_log(msg: str): """统一的 debug 日志输出""" if DEBUG_ENABLED: - print(f"[DEBUG][FinworldReader] {msg}") + print(f"[DEBUG][DeepFinanceReader] {msg}") logger.debug(msg) -class FinworldReader(BaseTaskReader): +class DeepFinanceReader(BaseTaskReader): """ - FinWorld 专用的数据加载器 + DeepFinance 专用的数据加载器 特点: 1. 从 JSON 文件加载任务数据(支持 list 和 dict 格式) @@ -45,7 +45,7 @@ def __init__(self, reader_config): super().__init__(reader_config) self.reader_config = reader_config - _debug_log(f"Initializing FinworldReader...") + _debug_log(f"Initializing DeepFinanceReader...") _debug_log(f"reader_config type: {type(reader_config).__name__}") # 获取 prompt 目录路径 @@ -58,23 +58,23 @@ def __init__(self, reader_config): def _init_prompt_templates(self): """初始化 prompt 模板缓存""" - if FinworldReader._prompt_template_cache is None: + if DeepFinanceReader._prompt_template_cache is None: prompt_file = os.path.join(self.local_path, 'prompt', 'finance_analyst_prompt.md') _debug_log(f"Loading prompt template from: {prompt_file}") with open(prompt_file, 'r', encoding='utf-8') as f: - FinworldReader._prompt_template_cache = f.read() - _debug_log(f"Prompt template loaded, length: {len(FinworldReader._prompt_template_cache)} chars") + DeepFinanceReader._prompt_template_cache = f.read() + _debug_log(f"Prompt template loaded, length: {len(DeepFinanceReader._prompt_template_cache)} chars") else: - _debug_log(f"Using cached prompt template, length: {len(FinworldReader._prompt_template_cache)} chars") + _debug_log(f"Using cached prompt template, length: {len(DeepFinanceReader._prompt_template_cache)} chars") - if FinworldReader._tool_prompt_cache is None: + if DeepFinanceReader._tool_prompt_cache is None: # 使用 tool_prompt_builder.py 中的静态模板 _debug_log(f"Loading tool prompt template...") from tutorial.example_deep_finance.prompt.tool_prompt_builder import get_tool_prompt_template - FinworldReader._tool_prompt_cache = get_tool_prompt_template() - _debug_log(f"Tool prompt template loaded, length: {len(FinworldReader._tool_prompt_cache)} chars") + DeepFinanceReader._tool_prompt_cache = get_tool_prompt_template() + _debug_log(f"Tool prompt template loaded, length: {len(DeepFinanceReader._tool_prompt_cache)} chars") else: - _debug_log(f"Using cached tool prompt template, length: {len(FinworldReader._tool_prompt_cache)} chars") + _debug_log(f"Using cached tool prompt template, length: {len(DeepFinanceReader._tool_prompt_cache)} chars") def _build_system_prompt(self) -> str: """构建 system prompt""" @@ -82,14 +82,14 @@ def _build_system_prompt(self) -> str: _debug_log(f"Building system prompt with date: {current_date}") # 替换日期占位符 - system_prompt = FinworldReader._prompt_template_cache.replace( + system_prompt = DeepFinanceReader._prompt_template_cache.replace( '{current_date}', current_date ) # 替换工具列表占位符 system_prompt = system_prompt.replace( '{tool_list}', - FinworldReader._tool_prompt_cache + DeepFinanceReader._tool_prompt_cache ) _debug_log(f"System prompt built, final length: {len(system_prompt)} chars") return system_prompt @@ -194,7 +194,7 @@ def _read_json_file(self, file_path: str, split: str = "train") -> List[Task]: tasks.append(task) _debug_log(f"Summary: loaded={len(tasks)}, skipped={skipped_count}, split_filtered={split_filtered_count}") - print(f"[FinworldReader] Loaded {len(tasks)} tasks from {file_path} (split={split})") + print(f"[DeepFinanceReader] Loaded {len(tasks)} tasks from {file_path} (split={split})") if len(tasks) == 0: raise ValueError(f"No tasks found in file: {file_path} for split={split}") diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 37089d04..869e6c03 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -15,8 +15,8 @@ ajet: citation_audit_weight: {{CITATION_AUDIT_WEIGHT}} # 引用审计评估 (覆盖率 + 真实性) rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: - # 使用本地 FinWorldJudge 进行评估(解耦远程 env_service) - judge_protocol: tutorial.example_deep_finance.deep_finance_judge->FinWorldJudgeByOpenJudge + # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge model: # ✨✨✨✨ 设置待训练的模型 path: {{MODEL_PATH}} From 04f49592b217f2b01552693ba8242518132870bf Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 18:20:35 +0800 Subject: [PATCH 22/56] refactor(tutorial): Optimize dynamic generation logic for configuration file paths --- tutorial/example_deep_finance/deep_finance.sh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index 02620620..82fd76cf 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -22,11 +22,8 @@ NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 -# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 +# 主目录 export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" -CONFIG_FILE="${AJET_ROOT}/tutorial/example_deep_finance/yaml/deep_finance_${SUFFIX}.yaml" -CONFIG_TEMPLATE="tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml" - # 涉密的配置(API_KEY以及模型、数据位置)从.env读取 cd ${AJET_ROOT} source .venv/bin/activate @@ -44,6 +41,10 @@ fi #=============================================================================== # 2. 动态生成配置文件 (从yaml template生成yaml) #=============================================================================== +# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 +CONFIG_TEMPLATE="tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml" +CONFIG_FILE="${AJET_ROOT}/tutorial/example_deep_finance/yaml/${SUFFIX}.yaml" +mkdir -p $(dirname ${CONFIG_FILE}) sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{PREFIX}}|${PREFIX}|g" \ From d0ff68b63f682c6b55b21a4b9fc3d48ec7c57300 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 18:23:20 +0800 Subject: [PATCH 23/56] fix(deep_finance): argparse: with-deepfinance --- tutorial/example_deep_finance/deep_finance.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index 82fd76cf..f16f417e 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -206,7 +206,7 @@ if [[ $HOSTNAME == *"-master-"* ]]; then # 启动训练任务(最核心) python ajet/launcher.py \ - --with-deep_finance \ + --with-deepfinance \ --conf ${CONFIG_FILE} \ --backbone="verl" \ 2>&1 | tee ${TRAIN_LOG} From 37dcbcc6c46a7d8fb7cd64f79bbd610774d290ce Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 18:34:59 +0800 Subject: [PATCH 24/56] fix(tutorial): Fixed issues with multi-machine training environment variable settings --- tutorial/example_deep_finance/deep_finance.sh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index f16f417e..6fd46f45 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -22,11 +22,16 @@ NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 + # 主目录 export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" + +NNODES=${WORLD_SIZE} + # 涉密的配置(API_KEY以及模型、数据位置)从.env读取 cd ${AJET_ROOT} source .venv/bin/activate + # API密钥配置 - 从 .env 文件加载 ENV_FILE="${AJET_ROOT}/.env" if [ -f "$ENV_FILE" ]; then @@ -112,12 +117,6 @@ ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" # 多机训练参数配置 -if [ -z "${WORLD_SIZE}" ]; then - echo "ERROR: WORLD_SIZE environment variable is not set!" - echo "Please ensure this script is run in a multi-node environment (e.g., PAI-DLC, SLURM)" - exit 1 -fi -NNODES=${WORLD_SIZE} GPUS_PER_NODE=8 EXPECTED_WORKERS=$WORLD_SIZE From 529ae7e8e5b80d0e888155039975175821e730dc Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 20:12:37 +0800 Subject: [PATCH 25/56] fix(env): Corrected the assignment logic for reward and info when returning environment state - Corrected the `env_output` return value structure in `BaseGymEnv` to ensure correct assignment of `reward` and `info` fields. - Removed `RefJudge` and `StructureJudge` related metric calculations and statistics from `reward_metric_helper`. - Cleaned up redundant code in `reward_metric_helper`, removing invalid comments and statistical items. - Modified `save_trajectory_as_json` to always print trajectory saving confirmation information. - Corrected log comments in `example_deep_finance` to avoid meaningless log output. - Added the `save_trajectory_as_json_file` configuration item to `deep_finance_template.yaml` to support trajectory saving functionality. --- ajet/task_rollout/resource_keeper.py | 6 ++-- .../metric_helper/reward_metric_helper.py | 30 ------------------- .../metric_helper/save_trajectory_as_json.py | 5 ++-- tutorial/example_deep_finance/deep_finance.py | 2 +- .../yaml_template/deep_finance_template.yaml | 1 + 5 files changed, 8 insertions(+), 36 deletions(-) diff --git a/ajet/task_rollout/resource_keeper.py b/ajet/task_rollout/resource_keeper.py index 6d4045d0..8a205f29 100644 --- a/ajet/task_rollout/resource_keeper.py +++ b/ajet/task_rollout/resource_keeper.py @@ -205,11 +205,15 @@ def step(self, action: dict) -> Tuple[str, float, bool, dict]: action=action, ) obs = "" + reward = 0 + info = {} assert isinstance(env_output, dict) if isinstance(env_output["state"], list): # 1. If state is a list (new standard format), pass through directly obs = env_output["state"] + reward = env_output["reward"] + info = env_output["info"] else: # 2. If state is a dict (old format or error) if ("content" not in env_output["state"]) and ("error" in env_output["state"]): @@ -219,8 +223,6 @@ def step(self, action: dict) -> Tuple[str, float, bool, dict]: else: obs = env_output["state"]["content"] - reward = 0 - info = {} terminate = env_output["is_terminated"] return obs, reward, terminate, info # type: ignore diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index bfe12e4f..76d034bf 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -77,7 +77,6 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str if openjudge_enabled_count > 0: # ========== OpenJudge Metrics ========== - metrics[f"{prefix}rewards/openjudge_enabled_rate"] = openjudge_enabled_count / n * 100 # Dynamically extract OpenJudge grader fields # Currently supported graders: report_resolution, trajectory_faithfulness, @@ -116,48 +115,19 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str rm_raw_list = [rs.get('rm_raw', 0.0) for rs in reward_stats_list] rm_contribution_list = [rs.get('rm_contribution', 0.0) for rs in reward_stats_list] - # RefJudge - ref_final_raw_list = [rs.get('ref_final_raw', 0.0) for rs in reward_stats_list] - ref_citation_raw_list = [rs.get('ref_citation_raw', 0.0) for rs in reward_stats_list] - ref_grounding_raw_list = [rs.get('ref_grounding_raw', 0.0) for rs in reward_stats_list] - ref_contribution_list = [rs.get('ref_contribution', 0.0) for rs in reward_stats_list] - - # StructureJudge - structure_raw_list = [rs.get('structure_raw', 0.0) for rs in reward_stats_list] - structure_contribution_list = [rs.get('structure_contribution', 0.0) for rs in reward_stats_list] - # dimensions/ raw scores metrics[f"{prefix}rewards/dimensions/rm_raw_mean"] = float(np.mean(rm_raw_list)) - metrics[f"{prefix}rewards/dimensions/ref_final_raw_mean"] = float(np.mean(ref_final_raw_list)) - metrics[f"{prefix}rewards/dimensions/ref_citation_raw_mean"] = float(np.mean(ref_citation_raw_list)) - metrics[f"{prefix}rewards/dimensions/ref_grounding_raw_mean"] = float(np.mean(ref_grounding_raw_list)) - metrics[f"{prefix}rewards/dimensions/structure_raw_mean"] = float(np.mean(structure_raw_list)) # contribution/ weighted contributions metrics[f"{prefix}rewards/contribution/rm_contribution_mean"] = float(np.mean(rm_contribution_list)) - metrics[f"{prefix}rewards/contribution/ref_contribution_mean"] = float(np.mean(ref_contribution_list)) - metrics[f"{prefix}rewards/contribution/structure_contribution_mean"] = float(np.mean(structure_contribution_list)) - # Enabled state statistics - ref_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('ref_judge_enabled', False)) - if ref_judge_enabled_count > 0: - metrics[f"{prefix}rewards/ref_judge_enabled_rate"] = ref_judge_enabled_count / n * 100 - - structure_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('structure_judge_enabled', False)) - if structure_judge_enabled_count > 0: - metrics[f"{prefix}rewards/structure_judge_enabled_rate"] = structure_judge_enabled_count / n * 100 # Time consumption statistics rm_time_list = [rs.get('rm_time', 0.0) for rs in reward_stats_list] - refstruc_time_list = [rs.get('refstruc_time', 0.0) for rs in reward_stats_list] - metrics[f"{prefix}judge_time/rm_time_mean"] = float(np.mean(rm_time_list)) - metrics[f"{prefix}judge_time/refstruc_time_mean"] = float(np.mean(refstruc_time_list)) if rm_time_list: metrics[f"{prefix}judge_time/rm_time_max"] = float(np.max(rm_time_list)) - if refstruc_time_list: - metrics[f"{prefix}judge_time/refstruc_time_max"] = float(np.max(refstruc_time_list)) # ========== General Time Consumption Statistics ========== judge_total_time_list = [rs.get('judge_total_time', 0.0) for rs in reward_stats_list] diff --git a/ajet/utils/metric_helper/save_trajectory_as_json.py b/ajet/utils/metric_helper/save_trajectory_as_json.py index 344a6ab4..91d3f95b 100644 --- a/ajet/utils/metric_helper/save_trajectory_as_json.py +++ b/ajet/utils/metric_helper/save_trajectory_as_json.py @@ -51,6 +51,5 @@ def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): with open(traj_file_path, "w", encoding="utf-8") as f: json.dump(traj_data, f, ensure_ascii=False, indent=2) - # Print confirmation for evaluation trajectories - if prefix != "train": - print(f"Saved trajectory to {traj_file_path}") + + print(f"Saved trajectory to {traj_file_path}") diff --git a/tutorial/example_deep_finance/deep_finance.py b/tutorial/example_deep_finance/deep_finance.py index f3ceae9e..1d81fe72 100644 --- a/tutorial/example_deep_finance/deep_finance.py +++ b/tutorial/example_deep_finance/deep_finance.py @@ -107,7 +107,7 @@ async def execute( action={"content": content_text, "role": "assistant"} ) _env_elapsed = time.time() - _env_start - logger.info(f"环境执行 ({_env_elapsed:.2f}s)") + # === 3. 更新 conversation_history (Full History) === # A. 添加 Assistant 消息 (补全 tool_calls) current_assistant_msg = { diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 869e6c03..a2d2cd73 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -28,6 +28,7 @@ ajet: save_freq: 10 test_freq: 2 total_epochs: 200 + save_trajectory_as_json_file: True rollout: # ✨✨✨✨ 编写并选择Agent user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol From f4eb231fd32fe614320b49a15aac71f8189e43aa Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 20:23:06 +0800 Subject: [PATCH 26/56] chore(config): Update example_deep_finance configuration and clean up files - Added a new ignore rule for config file paths in .gitignore - Deleted the automatically generated mcp_finance_tool_generated.json file in example_deep_finance - Refactored the deep_finance.yaml configuration file, adjusting project and experiment names - Reorganized Judge configuration, clarifying openjudge_llm and rm_llm models - Optimized model paths and training parameter configurations, adding parallel and batch processing settings - Adjusted data reading methods and training/validation set path placeholders - Reduced GPU memory usage ratio for rollout to 0.8 - Updated the default save directory path for the trainer to a placeholder variable - Cleaned up unused and commented-out code to improve configuration file conciseness --- .gitignore | 3 +- .../config/mcp_finance_tool_generated.json | 10 ---- .../example_deep_finance/deep_finance.yaml | 54 +++++++++---------- 3 files changed, 26 insertions(+), 41 deletions(-) delete mode 100644 tutorial/example_deep_finance/config/mcp_finance_tool_generated.json diff --git a/.gitignore b/.gitignore index 5add9fac..6a45c135 100644 --- a/.gitignore +++ b/.gitignore @@ -154,4 +154,5 @@ site dump.rdb -tutorial/example_deep_finance/yaml/* \ No newline at end of file +tutorial/example_deep_finance/yaml/* +tutorial/example_deep_finance/config/* \ No newline at end of file diff --git a/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json b/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json deleted file mode 100644 index 90fbd828..00000000 --- a/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "mcpServers": { - "flowllm": { - "transport": "sse", - "url": "http://22.17.31.142:8040/sse", - "timeout": 600, - "sse_read_timeout": 1200 - } - } -} diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml index c4da3438..f67d5a8b 100644 --- a/tutorial/example_deep_finance/deep_finance.yaml +++ b/tutorial/example_deep_finance/deep_finance.yaml @@ -1,21 +1,25 @@ # ------------------ 主要配置 ------------------ ajet: - project_name: ajet - experiment_name: "cc_rm4_res2cit2fai2_30b" - judge_llm: qwen-flash - judge_concurrency: 10 + project_name: ajet_deep_finance + experiment_name: "ajet_deep_finance" + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: qwen-flash # OpenJudge 模型 + rm_llm: qwen-max # RM Gallery 模型 + concurrency: 10 # Judge 并发数 + train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 + val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 # OpenJudge 权重配置 report_resolution_weight: 0.2 # 报告质量评估 trajectory_faithfulness_weight: 0.2 # 事实准确性评估 citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) rm_weight: 0.4 # RM Gallery 权重 task_judge: - judge_type: customized_protocol - judge_protocol: tutorial.example_finworld.finworld_judge->DeepFinanceJudgeByOpenJudge + # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge model: # ✨✨✨✨ 设置待训练的模型 - path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-30B-A3B-Instruct-2507 - # path: /mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-8B + path: {{MODEL_PATH}} trainer_common: nnodes: 8 n_gpus_per_node: 8 @@ -24,19 +28,20 @@ ajet: save_freq: 10 test_freq: 2 total_epochs: 200 + save_trajectory_as_json_file: True rollout: # ✨✨✨✨ 编写并选择Agent - user_workflow: tutorial.example_finworld.finworld->ExampleDeepResearchProtocol + user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol force_disable_toolcalls: True enable_oversample: False tensor_model_parallel_size: 8 num_repeat: 4 max_env_worker: 64 # 增加环境并行数 max_num_seqs: 64 # 增加VLLM并发序列数 - max_env_len: 10000 max_response_length_in_one_turn: 8000 max_model_len: 50000 agent_madness_reward: 0.0 + compute_madness_checklist: None multi_turn: max_steps: 6 interchange_server: @@ -45,50 +50,39 @@ ajet: debug_max_parallel: 64 # 增加并行任务数,充分利用GPU debug_first_n_tasks: 100 # 增加处理的任务数 data: - train_batch_size: 32 # 增加批次大小,适配8卡并行 + train_batch_size: 32 max_prompt_length: 8000 max_response_length: 41000 task_reader: - # type: env_service # `env_service` or `dataset_file` or `huggingface_dat_repo` or `finworld` - # === 方案 A: 传统 env_service 模式 === - # env_service: - # env_type: "finworld" - # env_url: "http://127.0.0.1:8080" - # env_action_preference: code - # training_split: train - # validation_split: val - - # === 方案 B: DeepFinance Reader 模式 (数据从 JSON 加载,工具调用走 env_service) === - type: finworld - finworld: + type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + deep_finance: training: - file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/finworld_tasks_11171143_cc.json + file_path: {{TRAIN_PATH}} validation: - file_path: /mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet/tutorial/example_finworld/data/AgentEvolver_query_val.json - # env_service 仍然需要配置(用于工具调用) + file_path: {{VAL_PATH}} + # env_service 仍需配置(用于工具调用) env_service: env_type: "finworld" env_url: "http://127.0.0.1:8080" env_action_preference: code trainer: - default_local_dir: "/mnt/data/taoshuchang.tsc/deepresearch/ajet/checkpoints/example_finworld//localths/cc_rm4_res2cit2fai2_30b" + default_local_dir: {{CKPT_SAVE_PATH}} # resume_mode: disable # 禁用自动恢复,从头开始训练 actor_rollout_ref: rollout: tensor_model_parallel_size: 8 - gpu_memory_utilization: 0.95 + gpu_memory_utilization: 0.8 # ------------------ 不需要修改 ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - - file://external/verl/verl/trainer/config # verl only - file://ajet/default_config/trinity # trinity only # ------------------ 不需要修改 ------------------ defaults: - - verl_default # verl inherit 2/2 + - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 - ajet_default - _self_ From 1e0751553d5f68b8b7a41bdfb94fa673af83cc4e Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 20 Jan 2026 23:59:57 +0800 Subject: [PATCH 27/56] Refactor(metric): Optimize tool metric calculation and data saving logic - Corrected the data source field for timeline data used during trajectory saving. - Removed redundant fields in tool execution time, cache hit rate, and error rate statistics. - Updated .gitignore to add ignore rules for the example script directory. - Removed unnecessary debugging information from logs to reduce log noise. - Adjusted log printing in the multi-round interaction execution process to simplify output content. - Streamlined log code for environment observation and termination checks to improve code readability. --- .gitignore | 3 ++- .../metric_helper/save_trajectory_as_json.py | 2 +- ajet/utils/metric_helper/tool_metric_helper.py | 7 +------ tutorial/example_deep_finance/deep_finance.py | 16 +++------------- 4 files changed, 7 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index 6a45c135..95add49e 100644 --- a/.gitignore +++ b/.gitignore @@ -155,4 +155,5 @@ dump.rdb tutorial/example_deep_finance/yaml/* -tutorial/example_deep_finance/config/* \ No newline at end of file +tutorial/example_deep_finance/config/* +tutorial/example_deep_finance/scripts/* \ No newline at end of file diff --git a/ajet/utils/metric_helper/save_trajectory_as_json.py b/ajet/utils/metric_helper/save_trajectory_as_json.py index 91d3f95b..9dd51868 100644 --- a/ajet/utils/metric_helper/save_trajectory_as_json.py +++ b/ajet/utils/metric_helper/save_trajectory_as_json.py @@ -22,7 +22,7 @@ def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): else: ctx_tracker.tag = "half_success" - formatted_traj = convert_grouped_steps_to_openai_format(ctx_tracker.timeline_cache) + formatted_traj = convert_grouped_steps_to_openai_format(ctx_tracker.saved_timelines) # Prepare trajectory data traj_data = { diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py index f1ed5d70..fc460029 100644 --- a/ajet/utils/metric_helper/tool_metric_helper.py +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -90,7 +90,6 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" if time_list: metrics[f"{prefix}tool_time/{tool_name}/mean"] = float(np.mean(time_list)) metrics[f"{prefix}tool_time/{tool_name}/max"] = float(np.max(time_list)) - metrics[f"{prefix}tool_time/{tool_name}/count"] = len(time_list) # ========== 3. Cache Hit Rate by Tool ========== tool_cache_by_name = {} @@ -100,7 +99,6 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" if tool_name not in tool_cache_by_name: tool_cache_by_name[tool_name] = {'hits': 0, 'misses': 0} tool_cache_by_name[tool_name]['hits'] += cache_info.get('hits', 0) - tool_cache_by_name[tool_name]['misses'] += cache_info.get('misses', 0) for tool_name, cache_info in tool_cache_by_name.items(): hits = cache_info['hits'] @@ -109,8 +107,6 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" if total > 0: hit_rate = hits / total * 100 metrics[f"{prefix}tool_cache/{tool_name}/hit_rate"] = round(hit_rate, 2) - metrics[f"{prefix}tool_cache/{tool_name}/hits"] = hits - metrics[f"{prefix}tool_cache/{tool_name}/misses"] = misses # ========== 4. Error Rate by Tool ========== tool_error_by_name = {} @@ -128,8 +124,7 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" if calls > 0: error_rate = errors / calls * 100 metrics[f"{prefix}tool_error/{tool_name}/error_rate"] = round(error_rate, 2) - metrics[f"{prefix}tool_error/{tool_name}/calls"] = calls - metrics[f"{prefix}tool_error/{tool_name}/errors"] = errors + return metrics diff --git a/tutorial/example_deep_finance/deep_finance.py b/tutorial/example_deep_finance/deep_finance.py index 1d81fe72..cbd92ad8 100644 --- a/tutorial/example_deep_finance/deep_finance.py +++ b/tutorial/example_deep_finance/deep_finance.py @@ -67,12 +67,8 @@ async def execute( latest_reward_stats = {} cumulative_tool_call_time = 0.0 # 累计工具调用时间 cumulative_tool_time = {} # 按工具区分的累计耗时: {tool_name: [time1, time2, ...]} - - logger.info(f"开始执行多轮交互,最大步数: {tuner.config.ajet.rollout.multi_turn.max_steps}") - step = 0 for step in range(tuner.config.ajet.rollout.multi_turn.max_steps): - logger.info(f"=== 步骤 {step + 1} ===") # === Agent 推理 === _llm_start = time.time() @@ -87,7 +83,6 @@ async def execute( content_text = reply_message.content content_preview = content_text[:100].replace('\n', ' ') - # logger.info(f"Agent回复 ({_llm_elapsed:.2f}s): {content_preview}...") # === 早期终止检查:在调用 env.step() 前检查 context_overflow === # 修复问题:避免 token_overflow 后还继续调用工具导致阻塞 @@ -130,8 +125,9 @@ async def execute( if info: if 'tool_stats' in info: latest_tool_stats = info['tool_stats'] - logger.info(f"步骤 {step + 1} 工具统计: 调用={latest_tool_stats.get('total_calls', 0)}, " - f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") + if latest_tool_stats.get('total_calls', 0) == 0: + logger.info(f"步骤 {step + 1} 工具统计: 调用={}, " + f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") if 'reward_stats' in info: latest_reward_stats = info['reward_stats'] # 累加工具调用时间 @@ -156,7 +152,6 @@ async def execute( # BaseGymEnv.step 直接透传,所以 obs = [tool_results_msgs] # 需要解包获取实际的消息列表 actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs - logger.info(f"环境观察 (Standard): 收到 {len(actual_msgs)} 条工具消息") # 按照 AgentScope 的 ContentBlock 格式转换消息 # Agent.memory 会自动保存 assistant 的 tool_call 信息 @@ -190,13 +185,10 @@ async def execute( agent_input.append(new_msg) else: # Legacy Mode - logger.info(f"环境观察 (Legacy): {str(obs)[:100]}...") agent_input.append(Msg(name="env", content=obs, role="user")) # === 6. 终止检查 === - logger.info(f"终止状态: {terminate}") if terminate: - logger.info(f"环境返回终止信号,在第 {step + 1} 步结束") break if tuner.get_context_tracker().context_overflow: @@ -212,12 +204,10 @@ async def execute( final_tool_stats['tool_time'] = cumulative_tool_time final_tool_stats['tool_call_time'] = cumulative_tool_call_time - logger.info(f"\n{'='*80}") logger.info(f"任务完成统计 (Task ID: {workflow_task.task.task_id}):") logger.info(f" 总步骤: {step + 1}") logger.info(f" 总调用: {final_tool_stats.get('total_calls', 0)}") logger.info(f" 成功率: {final_tool_stats.get('success_rate', 0):.2f}%") - logger.info(f"{'='*80}\n") return WorkflowOutput( reward=None, From 08ba18427c85139d2942b1fc3045c0592aaaf2c8 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Wed, 21 Jan 2026 00:09:04 +0800 Subject: [PATCH 28/56] fix(metric_helper): fix tool cache metric --- ajet/utils/metric_helper/tool_metric_helper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py index fc460029..3ce5da21 100644 --- a/ajet/utils/metric_helper/tool_metric_helper.py +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -99,6 +99,7 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" if tool_name not in tool_cache_by_name: tool_cache_by_name[tool_name] = {'hits': 0, 'misses': 0} tool_cache_by_name[tool_name]['hits'] += cache_info.get('hits', 0) + tool_cache_by_name[tool_name]['misses'] += cache_info.get('misses', 0) for tool_name, cache_info in tool_cache_by_name.items(): hits = cache_info['hits'] From 3d556920fce8d42d9b81cbe5df1f76302307f005 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Wed, 21 Jan 2026 09:59:56 +0800 Subject: [PATCH 29/56] fix little bug --- tutorial/example_deep_finance/deep_finance.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tutorial/example_deep_finance/deep_finance.py b/tutorial/example_deep_finance/deep_finance.py index cbd92ad8..470e6225 100644 --- a/tutorial/example_deep_finance/deep_finance.py +++ b/tutorial/example_deep_finance/deep_finance.py @@ -125,9 +125,9 @@ async def execute( if info: if 'tool_stats' in info: latest_tool_stats = info['tool_stats'] - if latest_tool_stats.get('total_calls', 0) == 0: - logger.info(f"步骤 {step + 1} 工具统计: 调用={}, " - f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") + if latest_tool_stats.get('total_calls', 0) > 0: + logger.info(f"步骤 {step + 1} 工具统计: 调用={latest_tool_stats.get('total_calls', 0)}, " + f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") if 'reward_stats' in info: latest_reward_stats = info['reward_stats'] # 累加工具调用时间 From a478827089e0eddc831a4b1b88d465c11af3f79d Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Wed, 21 Jan 2026 10:22:44 +0800 Subject: [PATCH 30/56] fix(utils): Suppress httpx AsyncClient.aclose() exception warnings --- ajet/backbone/warm_up.py | 3 ++- ajet/utils/async_utils.py | 48 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/ajet/backbone/warm_up.py b/ajet/backbone/warm_up.py index fcae673f..c7505c49 100644 --- a/ajet/backbone/warm_up.py +++ b/ajet/backbone/warm_up.py @@ -6,8 +6,9 @@ import asyncio import logging import os -from ajet.utils.async_utils import apply_httpx_aclose_patch +from ajet.utils.async_utils import apply_httpx_aclose_patch, suppress_httpx_aclose_exception apply_httpx_aclose_patch() +suppress_httpx_aclose_exception() def init_parallel_rollout_logger(experiment_name): diff --git a/ajet/utils/async_utils.py b/ajet/utils/async_utils.py index 219aba9c..c5869c1e 100644 --- a/ajet/utils/async_utils.py +++ b/ajet/utils/async_utils.py @@ -1,5 +1,6 @@ import asyncio import concurrent.futures +import logging from typing import Any def run_async_coroutine_with_timeout(coro, timeout: int = 3600) -> Any: @@ -68,3 +69,50 @@ def _patched_del(self) -> None: print("Applied httpx aclose patch.") except ImportError: pass + + +def suppress_httpx_aclose_exception(): + """ + Suppress the 'Task exception was never retrieved' error from httpx AsyncClient.aclose(). + This error occurs when the event loop is closed before the AsyncClient is properly closed. + """ + # Custom exception handler for asyncio + def custom_exception_handler(loop, context): + exception = context.get('exception') + message = context.get('message', '') + + # Check if this is the specific httpx aclose RuntimeError we want to suppress + if exception is not None: + if isinstance(exception, RuntimeError): + exc_str = str(exception) + if 'unable to perform operation on' in exc_str and 'the handler is closed' in exc_str: + return # Suppress this specific error + if 'TCPTransport' in exc_str and 'closed' in exc_str: + return # Suppress this specific error + + # For other exceptions, use the default handler + loop.default_exception_handler(context) + + # Apply custom exception handler to current or new event loop + try: + loop = asyncio.get_running_loop() + loop.set_exception_handler(custom_exception_handler) + except RuntimeError: + # No running loop, will be applied when loop starts + pass + + # Also filter the logging output for this specific error + class HttpxAcloseFilter(logging.Filter): + def filter(self, record): + msg = record.getMessage() + if 'Task exception was never retrieved' in msg and 'aclose' in msg: + return False + if 'unable to perform operation on' in msg and 'the handler is closed' in msg: + return False + if 'TCPTransport' in msg and 'closed' in msg: + return False + return True + + # Apply filter to root logger and asyncio logger + logging.getLogger().addFilter(HttpxAcloseFilter()) + logging.getLogger('asyncio').addFilter(HttpxAcloseFilter()) From 88be3e4c1782aace100d3c1d079bc3522b0f3682 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 21 Jan 2026 11:09:12 +0800 Subject: [PATCH 31/56] comments to english --- ajet/context_tracker/base_tracker.py | 5 ++--- ajet/task_reader/__init__.py | 2 +- ajet/task_rollout/resource_keeper.py | 6 +++--- ajet/task_runner/general_runner.py | 6 ++---- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/ajet/context_tracker/base_tracker.py b/ajet/context_tracker/base_tracker.py index 948aee3e..856cd89c 100644 --- a/ajet/context_tracker/base_tracker.py +++ b/ajet/context_tracker/base_tracker.py @@ -1,5 +1,4 @@ -from typing import List, Tuple, Union -from typing import List, Union, Tuple, Dict, Optional +from typing import Any, Dict, List, Optional, Tuple, Union from ajet.schema.task import WorkflowTask from ajet.schema.extended_msg import ( @@ -141,7 +140,7 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): self.already_mad_flag: bool = False self.round_cnt = 0 self.generation_prompt_token = None - self.log_metrics: Optional[Dict[str, Union[float, List[float]]]] = None # Initialize workflow_metadata to store tool statistics + self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None # Initialize workflow_metadata to store tool statistics assert ( self.config.ajet.data.max_prompt_length diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py index d3bbb1d7..b431456f 100644 --- a/ajet/task_reader/__init__.py +++ b/ajet/task_reader/__init__.py @@ -62,7 +62,7 @@ def __init__(self, reader_type, reader_config): elif task_reader_type == "random_dummy": self.task_reader = RandomDummyTaskReader(reader_config) elif task_reader_type == "deep_finance": - # deep_finance 专用: 数据从 JSON 文件加载并组装 init_messages,工具调用走 env_service + # deep_finance: load message from JSON file and assemble init_messages, tool calls go through env_service from tutorial.example_deep_finance.deep_finance_reader import DeepFinanceReader self.task_reader = DeepFinanceReader(reader_config) else: diff --git a/ajet/task_rollout/resource_keeper.py b/ajet/task_rollout/resource_keeper.py index 8a205f29..5e23389e 100644 --- a/ajet/task_rollout/resource_keeper.py +++ b/ajet/task_rollout/resource_keeper.py @@ -98,18 +98,18 @@ def _initialize_environment_and_messages(self) -> List[dict]: self.env.release_instance(self.workflow_task.episode_uuid) raise e elif reader_type == "deep_finance": - # deep_finance: 调用 create_instance 注册实例,但使用 reader 组装的 init_messages + # deep_finance: call create_instance to register instance, but use init_messages assembled by the reader if self.env is None: raise ValueError("Environment client is None but deep_finance type is specified") try: - # 必须调用 create_instance,让服务端创建实例,后续 step() 才能工作 + # call create_instance, let the server create an instance, so that subsequent step() can work self.env.create_instance( env_type=self.env_type, task_id=self.task_id, instance_id=self.workflow_task.episode_uuid, params=self.env_params, ) - # 不使用返回的 state,直接用 reader 组装的 init_messages + # Do not use the returned state, directly use the init_messages assembled by the reader task = self.workflow_task.task if task.init_messages: init_messages = task.init_messages diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index 91136b51..88f9ab11 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -54,12 +54,10 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: ) else: raw_reward, is_success = self.get_judge().compute_reward(workflow_task, workflow_output) - # Sync reward_stats from metadata to log_metrics after judge computation - if "reward_stats" in workflow_output.metadata: + if "reward_stats" in workflow_output.metadata: + workflow_output.log_metrics["reward_stats"] = workflow_output.metadata["reward_stats"] - workflow_output.log_metrics["reward_stats"] = workflow_output.metadata["reward_stats"] - workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue assert not isinstance( From fb41962bc7073385eb07330c9f3b5e4555ecd385 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Wed, 21 Jan 2026 20:58:26 +0800 Subject: [PATCH 32/56] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=90=8D=E7=A7=B0=E5=89=8D=E7=BC=80=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 launcher 中添加 --prefix 参数支持 - 在 pty_launch 函数中实现前缀逻辑 - 更新 deep_finance.sh 脚本以使用前缀功能 - 允许在同一环境中运行多个服务实例 --- ajet/launcher.py | 3 ++- ajet/utils/pty.py | 4 +++- tutorial/example_deep_finance/deep_finance.sh | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ajet/launcher.py b/ajet/launcher.py index 47345ce2..3bb5925e 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -99,6 +99,7 @@ def parse_args(): default=False, help="Kill system processes (ray + vllm + python) that may block the current experiment", ) + parser.add_argument("--prefix", type=str, default="", required=False, help="Prefix for service names") return parser.parse_args() @@ -304,7 +305,7 @@ def main(): pty_launch("appworld") if args.with_deepfinance: - pty_launch("deepfinance") + pty_launch("deepfinance", prefix=args.prefix) if args.with_crafters: pty_launch("crafters") diff --git a/ajet/utils/pty.py b/ajet/utils/pty.py index 6d859ae1..e6756114 100644 --- a/ajet/utils/pty.py +++ b/ajet/utils/pty.py @@ -96,13 +96,15 @@ def pty_wrapper_final(human_cmd, dir, env_dict): pty_wrapper(["/bin/bash", "-c", human_cmd], dir, env_dict) -def pty_launch(service_name: str, success_std_string="Starting server on"): +def pty_launch(service_name: str, success_std_string="Starting server on", prefix: str=""): from ajet.utils.smart_daemon import LaunchCommandWhenAbsent service_path = os.environ.get(f"{service_name.upper()}_PATH") service_script = os.environ.get(f"{service_name.upper()}_SCRIPT") if service_path is None or service_script is None: raise ValueError(f"Environment variables for {service_name} not properly set.") + if prefix != "": + service_name = prefix + "_" + service_name companion = LaunchCommandWhenAbsent( full_argument_list=[service_script], dir=service_path, diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index 6fd46f45..dcddb7cc 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -208,6 +208,7 @@ if [[ $HOSTNAME == *"-master-"* ]]; then --with-deepfinance \ --conf ${CONFIG_FILE} \ --backbone="verl" \ + --prefix=${SUFFIX} \ 2>&1 | tee ${TRAIN_LOG} From a1f909bca41840460f5dbfc4ec911bae031c00c1 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Wed, 21 Jan 2026 20:59:31 +0800 Subject: [PATCH 33/56] =?UTF-8?q?fix:=20=E6=94=B9=E8=BF=9B=20MultiAgent=20?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E5=86=85=E5=AE=B9=E8=A7=A3=E6=9E=90=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 支持 tool_result 格式的消息内容块 - 改进非文本内容的处理逻辑,继续处理其他项而非跳过整个消息 - 添加 tool_use 类型的处理(跳过,因为已通过 tool_calls 字段处理) - 优化代码结构和注释,提高可读性 --- ajet/context_tracker/multiagent_tracking.py | 47 +++++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index e13982c5..6501cc7c 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -82,27 +82,46 @@ def extract_text_content_from_content_dict(self, msg): # }, # ], # } + # or tool_result format: + # msg = { + # "role": "tool", + # "content": [ + # { + # "type": "tool_result", + # "id": "call_xxx", + # "output": "tool output content", + # "name": "tool_name" + # }, + # ], + # } str_content = "" for item in msg["content"]: - # item = { - # "type": "text", - # "text": "some text" - # }, - assert isinstance(item, dict), f"Unsupported non-dict item in message content: {item}. Full message: {msg}" - if ("text" not in item): + item_type = item.get("type", "") + + # Handle text content block + if "text" in item: + if isinstance(item["text"], str): + str_content += item["text"] + # Handle tool_result content block (AgentScope format) + elif item_type == "tool_result" and "output" in item: + output = item["output"] + if isinstance(output, str): + str_content += output + else: + str_content += str(output) + # Handle tool_use content block (for completeness) + elif item_type == "tool_use": + # tool_use blocks are handled via tool_calls field, skip content extraction + continue + else: logger.warning( - f"Non-text content in message content detected: {item}. Ignoring." + f"Non-text content in message content detected: {item}. Ignoring this item." ) - should_skip_message = True - return str_content, should_skip_message - - if isinstance(item["text"], str): - str_content += str(item["text"]) - else: - str_content = "" + # Continue processing other items instead of skipping the entire message + continue should_skip_message = False return str_content, should_skip_message From 8d2e5d7d7a8f5c020a6cdd9b24689c8cfb69cc3f Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Wed, 21 Jan 2026 21:01:13 +0800 Subject: [PATCH 34/56] =?UTF-8?q?fix:=20=E4=BC=98=E5=8C=96=20DeepFinance?= =?UTF-8?q?=20=E5=88=A4=E6=96=AD=E9=80=BB=E8=BE=91=E5=92=8C=E9=85=8D?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复 tool_stats 提取逻辑,从 log_metrics 中正确获取数据 - 添加惩罚项调试信息输出 - 启用 tool calls 功能(force_disable_toolcalls: False) - 确保奖励计算准确性 --- tutorial/example_deep_finance/deep_finance_judge.py | 6 +++++- .../yaml_template/deep_finance_template.yaml | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index f49d88d3..31e4be01 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -373,8 +373,12 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO fused_reward, contributions = self._fuse_grader_scores(grader_scores, rm_raw) # 6. 计算惩罚项(保留原有的 tool_calls 惩罚逻辑) - tool_calls = metadata.get("tool_stats", {}).get("total_calls", 0) + # 从 log_metrics 中提取 tool_stats(deep_finance.py 将其放在 log_metrics 而非 metadata) + tool_stats = workflow_output.log_metrics.get("tool_stats", {}) + tool_calls = tool_stats.get("total_calls", 0) penalty = self._compute_penalty(tool_calls) + if penalty < 0: + print(f"⚠️ Penalty applied: penalty={penalty}, tool_calls={tool_stats}") # 7. 汇总 final_reward = fused_reward + step_reward + penalty diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index a2d2cd73..8e6065d3 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -32,7 +32,7 @@ ajet: rollout: # ✨✨✨✨ 编写并选择Agent user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol - force_disable_toolcalls: True + force_disable_toolcalls: False enable_oversample: False tensor_model_parallel_size: 8 num_repeat: {{NUM_REPEAT}} From 3c85960902c8bd623d15771ed6ce16e242a66341 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Thu, 22 Jan 2026 18:09:09 +0800 Subject: [PATCH 35/56] chore(deps): bump agentscope from 1.0.7 to 1.0.8 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 856cddca..474e9024 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ ] requires-python = ">=3.10,<3.13" dependencies = [ - "agentscope==1.0.7", + "agentscope==1.0.8", "chromadb", "httpx", "tenacity", From 9b541c59951aa12f4b5a9fa8de556c821e10b7a1 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Thu, 22 Jan 2026 18:09:52 +0800 Subject: [PATCH 36/56] fix(metric_helper): correct trajectory save path and add tool call metric - Change trajectory save directory from "ctx_trackers" to "trajectory" to organize files better - Add recording of tool call counts alongside error rates in tool metrics - Update experiment suffix in deep finance example script for clearer naming convention --- ajet/utils/metric_helper/save_trajectory_as_json.py | 2 +- ajet/utils/metric_helper/tool_metric_helper.py | 1 + tutorial/example_deep_finance/deep_finance.sh | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ajet/utils/metric_helper/save_trajectory_as_json.py b/ajet/utils/metric_helper/save_trajectory_as_json.py index 9dd51868..4ab263f4 100644 --- a/ajet/utils/metric_helper/save_trajectory_as_json.py +++ b/ajet/utils/metric_helper/save_trajectory_as_json.py @@ -40,7 +40,7 @@ def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): # Define save directory and file path traj_save_dir = os.path.join( os.environ.get("BEST_LOGGER_PATH", "launcher_record"), - "ctx_trackers", + "trajectory", prefix, f"step_{global_steps}" ) diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py index 3ce5da21..a656a078 100644 --- a/ajet/utils/metric_helper/tool_metric_helper.py +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -125,6 +125,7 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" if calls > 0: error_rate = errors / calls * 100 metrics[f"{prefix}tool_error/{tool_name}/error_rate"] = round(error_rate, 2) + metrics[f"{prefix}tool_error/{tool_name}/calls"] = calls return metrics diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index dcddb7cc..d9f624a2 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -3,7 +3,7 @@ set -e #=============================================================================== # 1. 配置区域 - 用户只需修改这里 #=============================================================================== -SUFFIX="ajet_deep_finance" # 实验后缀,影响所有日志和实验名称 +SUFFIX="deep_finance" # 实验后缀,影响所有日志和实验名称 PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 # OpenJudge 模型配置 From c9b87ac3b1854b4fddd05a2a8c0f8102ff1e558e Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Fri, 23 Jan 2026 15:17:50 +0800 Subject: [PATCH 37/56] revise message parsing --- ajet/context_tracker/multiagent_tracking.py | 40 ++++++++++----------- ajet/launcher.py | 2 +- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index 6501cc7c..74e7cf70 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -82,7 +82,7 @@ def extract_text_content_from_content_dict(self, msg): # }, # ], # } - # or tool_result format: + # or tool_result format?? not observed yet: # msg = { # "role": "tool", # "content": [ @@ -97,31 +97,27 @@ def extract_text_content_from_content_dict(self, msg): str_content = "" for item in msg["content"]: - assert isinstance(item, dict), f"Unsupported non-dict item in message content: {item}. Full message: {msg}" - + # item = { + # "type": "text", + # "text": "some text" + # }, item_type = item.get("type", "") + assert not item_type == "tool_use", f"never observed such protocal yet" + assert not item_type == "tool_result", f"never observed such protocal yet" - # Handle text content block - if "text" in item: - if isinstance(item["text"], str): - str_content += item["text"] - # Handle tool_result content block (AgentScope format) - elif item_type == "tool_result" and "output" in item: - output = item["output"] - if isinstance(output, str): - str_content += output - else: - str_content += str(output) - # Handle tool_use content block (for completeness) - elif item_type == "tool_use": - # tool_use blocks are handled via tool_calls field, skip content extraction - continue - else: + assert isinstance(item, dict), f"Unsupported non-dict item in message content: {item}. Full message: {msg}" + + if ("text" not in item): logger.warning( - f"Non-text content in message content detected: {item}. Ignoring this item." + f"Non-text content in message content detected: {item}. Ignoring." ) - # Continue processing other items instead of skipping the entire message - continue + should_skip_message = True + return str_content, should_skip_message + + if isinstance(item["text"], str): + str_content += str(item["text"]) + else: + str_content = "" should_skip_message = False return str_content, should_skip_message diff --git a/ajet/launcher.py b/ajet/launcher.py index 3bb5925e..40557137 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -99,7 +99,7 @@ def parse_args(): default=False, help="Kill system processes (ray + vllm + python) that may block the current experiment", ) - parser.add_argument("--prefix", type=str, default="", required=False, help="Prefix for service names") + parser.add_argument("--prefix", type=str, default="", required=False, help="Prefix for deepfinance service names") return parser.parse_args() From 3bd4c7d0e63d0ad1f99c37123a914d4fe8f6ca4b Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Sun, 25 Jan 2026 10:26:28 +0800 Subject: [PATCH 38/56] fix(metric_helper): update openjudge graders list in reward metric helper --- ajet/utils/metric_helper/reward_metric_helper.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index 76d034bf..d476a81f 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -84,10 +84,7 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str openjudge_graders = [ "report_resolution", "trajectory_faithfulness", - "rubrics_performance", - "trajectory_comprehensive", - "information_gain", - "action_loop", + "citation_audit", ] for grader_name in openjudge_graders: From 8a18d40509dc862ea1cd0341f353f7559fe9af9a Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Mon, 26 Jan 2026 17:19:23 +0800 Subject: [PATCH 39/56] feat(deep_finance): replace OpenJudge graders with PresentationQualityGrader - Remove legacy graders and integrate PresentationQualityGrader and GroundingGrader - Update grader weights and disable unused graders in config and code - Simplify grader configuration creation with new mappers for report content and traj - Refactor DeepFinanceJudgeByOpenJudge to support new grading scheme - Add PresentationQualityGrader implementation with strict JSON output format - Include utilities for JSON parsing and validation in presentation quality grader - Add prompt templates for presentation quality grading criteria and instructions - Provide example script to run PresentationQualityGrader with OpenAIChatModel - Add traj_adapter utilities to normalize and extract user query and final report - Update YAML template to replace old grader weights with presentation quality weight - Create init files to expose PresentationQualityGrader in judge package --- tutorial/example_deep_finance/__init__.py | 1 + .../deep_finance_judge.py | 128 ++++-------- .../example_deep_finance/judge/__init__.py | 11 ++ .../judge/presentation_quality/grader.py | 187 ++++++++++++++++++ .../judge/presentation_quality/json_utils.py | 85 ++++++++ .../judge/presentation_quality/prompt.py | 94 +++++++++ .../judge/scripts/run_presentation_quality.py | 44 +++++ .../judge/traj_adapter.py | 61 ++++++ .../yaml_template/deep_finance_template.yaml | 4 +- 9 files changed, 520 insertions(+), 95 deletions(-) create mode 100644 tutorial/example_deep_finance/__init__.py create mode 100644 tutorial/example_deep_finance/judge/__init__.py create mode 100644 tutorial/example_deep_finance/judge/presentation_quality/grader.py create mode 100644 tutorial/example_deep_finance/judge/presentation_quality/json_utils.py create mode 100644 tutorial/example_deep_finance/judge/presentation_quality/prompt.py create mode 100644 tutorial/example_deep_finance/judge/scripts/run_presentation_quality.py create mode 100644 tutorial/example_deep_finance/judge/traj_adapter.py diff --git a/tutorial/example_deep_finance/__init__.py b/tutorial/example_deep_finance/__init__.py new file mode 100644 index 00000000..36e084c4 --- /dev/null +++ b/tutorial/example_deep_finance/__init__.py @@ -0,0 +1 @@ +# tutorial/example_deep_finance package diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index 31e4be01..9a6c72f2 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -1,5 +1,5 @@ """DeepFinance Task Judge - OpenJudge 版本 -集成: RM Gallery, OpenJudge Graders (含 CitationAudit) +集成: RM Gallery, PresentationQualityGrader """ import os @@ -13,28 +13,10 @@ from ajet.task_judge.base_judge import BaseJudge from ajet.workflow import WorkflowOutput, WorkflowTask -from openjudge.graders.agent.action.action_loop import ActionLoopDetectionGrader -from openjudge.graders.agent.observation.observation_information_gain import ( - ObservationInformationGainGrader, -) -from openjudge.graders.agent.trajectory.trajectory_comprehensive import ( - TrajectoryComprehensiveGrader, -) from openjudge.models.openai_chat_model import OpenAIChatModel -from openjudge.models.schema.prompt_template import LanguageEnum from openjudge.runner.grading_runner import GraderConfig, GradingRunner -from openjudge.scenarios.deep_research.graders.financial_report_resolution import ( - FinancialReportResolutionGrader, -) -from openjudge.scenarios.deep_research.graders.financial_trajectory_faithfulness import ( - FinancialTrajectoryFaithfulGrader, -) -from openjudge.scenarios.deep_research.graders.rubrics_based_trajectory_performance import ( - RubricsBasedTrajectoryPerformance, -) -from openjudge.scenarios.deep_research.graders.financial_report_citation_audit import ( - FinancialReportCitationAuditGrader, -) +from tutorial.example_deep_finance.judge import PresentationQualityGrader +from tutorial.example_deep_finance.judge.grounding import GroundingGrader # RewardStats 不再使用,OpenJudge 版本直接使用字典存储 @@ -88,7 +70,7 @@ def load_reference_answers_from_file(file_path: str) -> Tuple[Dict[str, str], Di class DeepFinanceJudgeByOpenJudge(BaseJudge): """ 使用 OpenJudge 框架的 DeepFinance Judge - 集成: RM Gallery, OpenJudge Graders (含 CitationAudit) + 集成: RM Gallery, PresentationQualityGrader 分析: - compute_reward 每次处理 **一条采样**(单个 workflow_output) @@ -116,26 +98,15 @@ def _setup_weights(self): 配置 OpenJudge 各 grader 的权重并归一化 graders 对应关系: - - financial_report_resolution: 报告质量和问题解决能力 - - financial_trajectory_faithfulness: 事实准确性(忠实度) - - citation_audit: 引用审计(覆盖率 + 真实性) - - rubrics_based_trajectory_performance: 基于 rubrics 的评估 - - trajectory_comprehensive: 轨迹综合评估 - - observation_information_gain: 信息增益(去重) - - action_loop_detection: 动作循环检测(惩罚项) + - presentation_quality: 报告呈现质量评估 """ cfg = getattr(self.config, "ajet", None) - # 定义各 grader 的权重(可从 config 中读取)- 与 deep_finance_judge.py 对齐 + # 定义各 grader 的权重(可从 config 中读取) self.w = { "rm": getattr(cfg, "rm_weight", 1.0) if cfg else 1.0, # RM Gallery 权重 - "citation_audit": getattr(cfg, "citation_audit_weight", 0.0) if cfg else 0.0, # CitationAudit 权重 - "report_resolution": getattr(cfg, "report_resolution_weight", 0.0) if cfg else 0.0, - "trajectory_faithfulness": getattr(cfg, "trajectory_faithfulness_weight", 0.0) if cfg else 0.0, - # "rubrics_performance": getattr(cfg, "rubrics_performance_weight", 0.2) if cfg else 0.2, - # "trajectory_comprehensive": getattr(cfg, "trajectory_comprehensive_weight", 0.2) if cfg else 0.2, - # "information_gain": getattr(cfg, "information_gain_weight", 0.1) if cfg else 0.1, - # "action_loop": getattr(cfg, "action_loop_weight", 0.1) if cfg else 0.1 + "presentation_quality": getattr(cfg, "presentation_quality_weight", 0.25) if cfg else 0.25, + "grounding": getattr(cfg, "grounding_weight", 0.25) if cfg else 0.25, } # 归一化(注意:action_loop 是惩罚项,不参与归一化;rm 需要参与归一化) @@ -244,15 +215,14 @@ def _create_runner_in_loop(self) -> GradingRunner: 注意:GradingRunner 内部的 Semaphore 会绑定到创建时的事件循环, 因此不能使用单例模式,必须在每次调用的事件循环中创建新实例。 """ - language = LanguageEnum.ZH - grader_configs = self._create_grader_configs(self.model, language) + grader_configs = self._create_grader_configs(self.model) return GradingRunner( grader_configs=grader_configs, max_concurrency=self.max_concurrency, show_progress=False ) - def _create_grader_configs(self, model: OpenAIChatModel, language: LanguageEnum) -> Dict[str, GraderConfig]: + def _create_grader_configs(self, model: OpenAIChatModel) -> Dict[str, GraderConfig]: """ 创建所有 grader 的配置 @@ -260,54 +230,35 @@ def _create_grader_configs(self, model: OpenAIChatModel, language: LanguageEnum) - key: grader 名称 - value: GraderConfig(grader=..., mapper=...) """ + + def extract_user_query(data: Dict) -> str: + """从 messages 中提取第一条 user 消息的 content""" + for msg in data.get("messages", []): + if msg.get("role") == "user": + return msg.get("content", "") + return "" + + def extract_report_content(data: Dict) -> str: + """从 messages 中提取最后一条 assistant 消息的 content""" + for msg in reversed(data.get("messages", [])): + if msg.get("role") == "assistant": + return msg.get("content", "") + return "" + return { - # 1. 报告质量评估 - 需要 messages 和 chat_date - "report_resolution": GraderConfig( - grader=FinancialReportResolutionGrader(model=model, language=language), + # 报告呈现质量评估 - 需要 user_query 和 report_content + "presentation_quality": GraderConfig( + grader=PresentationQualityGrader(model=model), mapper=lambda data: { - "messages": data["messages"], - "chat_date": data.get("chat_date") + "user_query": extract_user_query(data), + "report_content": extract_report_content(data), }, ), - - # 2. 事实准确性评估 - 需要 messages - "trajectory_faithfulness": GraderConfig( - grader=FinancialTrajectoryFaithfulGrader(model=model, language=language), - mapper=lambda data: {"messages": data["messages"]}, - ), - - # 3. 引用审计评估 - 需要 messages - "citation_audit": GraderConfig( - grader=FinancialReportCitationAuditGrader(model=model, language=language), - mapper=lambda data: {"messages": data["messages"]}, + # 引用规范性评估 - 需要完整的 traj + "grounding": GraderConfig( + grader=GroundingGrader(model=model), + mapper=lambda data: {"traj": data}, ), - - # 4. Rubrics 评估 - 需要 messages 和 rubrics - # "rubrics_performance": GraderConfig( - # grader=RubricsBasedTrajectoryPerformance(model=model, language=language), - # mapper=lambda data: { - # "messages": data["messages"], - # "rubrics": data.get("rubrics", []) - # }, - # ), - - # 5. 轨迹综合评估 - 需要 messages - # "trajectory_comprehensive": GraderConfig( - # grader=TrajectoryComprehensiveGrader(model=model, language=language), - # mapper=lambda data: {"messages": data["messages"]}, - # ), - - # 6. 信息增益评估 - 需要 messages(非 LLM grader) - # "information_gain": GraderConfig( - # grader=ObservationInformationGainGrader(similarity_threshold=0.5), - # mapper=lambda data: {"messages": data["messages"]}, - # ), - - # 7. 动作循环检测 - 需要 messages(非 LLM grader) - # "action_loop": GraderConfig( - # grader=ActionLoopDetectionGrader(similarity_threshold=1.0), - # mapper=lambda data: {"messages": data["messages"]}, - # ), } def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> Tuple[float, bool]: @@ -552,8 +503,7 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[ 输入: - grader_results: Dict[str, List[GraderScore]] { - "report_resolution": [GraderScore(score=0.88, reason="...", metadata={...})], - "trajectory_faithfulness": [GraderScore(score=1.0, ...)], + "presentation_quality": [GraderScore(score=0.88, reason="...", metadata={...})], ... } @@ -689,13 +639,7 @@ def _update_metadata_stats( 更新 metadata["reward_stats"] - 直接使用 OpenJudge 原始字段 OpenJudge graders(按实际启用情况): - - report_resolution: 报告质量和问题解决能力 - - trajectory_faithfulness: 事实准确性(忠实度) - - citation_audit: 引用审计(覆盖率 + 真实性) - - rubrics_performance: 基于 rubrics 的评估(可选) - - trajectory_comprehensive: 轨迹综合评估(可选) - - information_gain: 信息增益/去重(可选) - - action_loop: 动作循环检测(惩罚项,可选) + - presentation_quality: 报告呈现质量评估 注意:不再硬套 RewardStats 的字段名,直接使用 openjudge_ 前缀 """ diff --git a/tutorial/example_deep_finance/judge/__init__.py b/tutorial/example_deep_finance/judge/__init__.py new file mode 100644 index 00000000..6c9fdfbe --- /dev/null +++ b/tutorial/example_deep_finance/judge/__init__.py @@ -0,0 +1,11 @@ +# 使得可以通过 from judge import PresentationQualityGrader 直接引用 +# from tutorial.example_deep_finance.judge.grounding.grader import GroundingGrader +from tutorial.example_deep_finance.judge.presentation_quality.grader import PresentationQualityGrader +# from tutorial.example_deep_finance.judge.research_depth.grader import ResearchDepthGrader +# from tutorial.example_deep_finance.judge.research_breadth.grader import ResearchBreadthGrader + +# 以后添加了其他 grader 也可以加在这里 +# from .grounding.grader import GroundingGrader +# from .research_breadth.grader import ResearchBreadthGrader +# __all__ = ["PresentationQualityGrader", "GroundingGrader", "ResearchDepthGrader", "ResearchBreadthGrader"] +__all__ = ["PresentationQualityGrader"] diff --git a/tutorial/example_deep_finance/judge/presentation_quality/grader.py b/tutorial/example_deep_finance/judge/presentation_quality/grader.py new file mode 100644 index 00000000..ac3eb349 --- /dev/null +++ b/tutorial/example_deep_finance/judge/presentation_quality/grader.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import os +from typing import Any, Dict, List, Tuple + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +# import path 兼容两种写法(文档里两种都出现过) +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import ( + QUALITY_SYSTEM_PROMPT, + USER_PROMPT_TEMPLATE, + ALL_KEYS, + A_KEYS, + B_KEYS, + C_KEYS, +) +from .json_utils import strict_load_json, validate_shape, get_bool_pass, get_note + + +class PresentationQualityGrader(BaseGrader): + """ + - 输入:report_content(研究报告文本) + - 输出:GraderScore(name, score, reason) + - score:8项(pass)均分,范围[0,1] + - determinism:建议用 temperature=0 + disable thinking 等(见 create_default_model) + - 解析失败:score=0,并在 reason 显示报错 + """ + + def __init__( + self, + model: OpenAIChatModel, + name: str = "presentation_quality", + **kwargs: Any, + ): + super().__init__(name=name, **kwargs) + self.model = model + + @staticmethod + def create_default_model( + model_name: str, + api_key: str | None = None, + base_url: str | None = None, + deterministic: bool = True, + enable_thinking: bool = False, + seed: int = 0, + ) -> OpenAIChatModel: + """ + 你也可以不调用这个工厂,自己在外面 new OpenAIChatModel。 + QuickStart 文档确认 OpenAIChatModel 会从 OPENAI_API_KEY/OPENAI_BASE_URL 读取。 + """ + api_key = api_key or os.getenv("OPENAI_API_KEY") + base_url = base_url or os.getenv("OPENAI_BASE_URL") + + extra_body: Dict[str, Any] = {} + if deterministic: + # OpenAI兼容接口常见字段;DashScope/Qwen 常用 enable_thinking + extra_body.update( + { + "temperature": 0, + "top_p": 1, + "seed": seed, + "presence_penalty": 0, + "frequency_penalty": 0, + } + ) + if enable_thinking is False: + extra_body["enable_thinking"] = False + + kwargs: Dict[str, Any] = {"model": model_name} + if api_key: + kwargs["api_key"] = api_key + if base_url: + kwargs["base_url"] = base_url + if extra_body: + kwargs["extra_body"] = extra_body + + return OpenAIChatModel(**kwargs) + + async def aevaluate( + self, + report_content: str, + user_query: str | None = None, + **_: Any, + ) -> GraderScore: + """ + 入口:直接喂 report_content(研究报告文本) + - user_query 可选:用于填充 prompt;不提供则用 "(unknown)" + """ + report = (report_content or "").strip() + if not report: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty report_content", + ) + + uq = (user_query or "").strip() or "(unknown)" + + user_content = USER_PROMPT_TEMPLATE.format( + user_query=uq, + report_content=report, + ) + messages = [ + {"role": "system", "content": QUALITY_SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + # 核心:OpenJudge 的 OpenAIChatModel 支持 await model.achat([...]),并返回 .content + try: + resp = await self.model.achat(messages) + raw_text = getattr(resp, "content", None) + if raw_text is None: + raw_text = str(resp) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + + obj, jerr = strict_load_json(str(raw_text)) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"ParseError: {jerr}; raw[:200]={snippet}", + ) + + obj, serr = validate_shape(obj) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"SchemaError: {serr}; raw[:200]={snippet}", + ) + + score, reason = self._score_and_reason(obj) + return GraderScore(name=self.name, score=score, reason=reason) + + def _score_and_reason(self, obj: Dict[str, Any]) -> Tuple[float, str]: + scan = obj["scan"] + structuring = obj["structuring"] + editorial = obj["editorial"] + top_fixes = obj.get("top_fixes", []) + + # 8项均分(强确定性:完全由Python算) + pass_map: Dict[str, bool] = {} + note_map: Dict[str, str] = {} + + def take(section: Dict[str, Any], key: str): + item = section.get(key) + pass_map[key] = get_bool_pass(item) + note_map[key] = get_note(item) + + for k in A_KEYS: + take(scan, k) + for k in B_KEYS: + take(structuring, k) + for k in C_KEYS: + take(editorial, k) + + passed = sum(1 for k in ALL_KEYS if pass_map.get(k) is True) + total = len(ALL_KEYS) # 8 + score = passed / float(total) + + # reason:不加额外字段,只给紧凑总结 + failed_items = [k for k in ALL_KEYS if not pass_map.get(k, False)] + failed_str = ", ".join(f"{k}({note_map.get(k,'')})" for k in failed_items[:4]) + fixes_str = " | ".join(str(x) for x in (top_fixes or [])[:3]) + + parts: List[str] = [] + parts.append(f"Pass {passed}/{total}") + if failed_items: + parts.append(f"Fail: {failed_str}") + if fixes_str: + parts.append(f"TopFixes: {fixes_str}") + + reason = " ; ".join(parts) + return round(score, 6), reason[:800] diff --git a/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py b/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py new file mode 100644 index 00000000..92794ca4 --- /dev/null +++ b/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import json +import re +from typing import Any, Dict, Tuple + + +_JSON_RE = re.compile(r"\{.*\}", re.DOTALL) + + +def extract_first_json_object(text: str) -> str | None: + """ + Best-effort: extract the first {...} block. + If none found, return None. + """ + if not text: + return None + m = _JSON_RE.search(text.strip()) + if not m: + return None + return m.group(0) + + +def strict_load_json(text: str) -> Tuple[Dict[str, Any] | None, str | None]: + """ + Return (obj, error). Any parse failure => (None, error_msg) + """ + js = extract_first_json_object(text) + if js is None: + return None, "No JSON object found in model output" + try: + obj = json.loads(js) + if not isinstance(obj, dict): + return None, f"Top-level JSON is not an object: {type(obj).__name__}" + return obj, None + except Exception as e: + return None, f"{type(e).__name__}: {e}" + + +def get_bool_pass(item: Any) -> bool: + if isinstance(item, dict): + v = item.get("pass") + else: + v = item + if isinstance(v, bool): + return v + if isinstance(v, (int, float)): + return bool(v) + if isinstance(v, str): + return v.strip().lower() in {"true", "1", "yes", "y"} + return False + + +def get_note(item: Any) -> str: + if isinstance(item, dict): + note = item.get("note", "") + else: + note = "" + note = "" if note is None else str(note) + note = note.strip() + # 最多给点余量,避免reason爆长 + return note[:120] + + +def validate_shape(obj: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]: + """ + Ensure required sections exist and are dicts; ensure top_fixes is list or str. + If missing required field => error. + """ + for sec in ("scan", "structuring", "editorial"): + if sec not in obj: + return None, f"Missing field: {sec}" + if not isinstance(obj[sec], dict): + return None, f"Field '{sec}' is not an object" + if "top_fixes" not in obj: + return None, "Missing field: top_fixes" + # normalize top_fixes + tf = obj.get("top_fixes") + if isinstance(tf, list): + obj["top_fixes"] = [str(x) for x in tf][:3] + elif tf is None: + obj["top_fixes"] = [] + else: + obj["top_fixes"] = [str(tf)][:3] + return obj, None diff --git a/tutorial/example_deep_finance/judge/presentation_quality/prompt.py b/tutorial/example_deep_finance/judge/presentation_quality/prompt.py new file mode 100644 index 00000000..0abaf1ee --- /dev/null +++ b/tutorial/example_deep_finance/judge/presentation_quality/prompt.py @@ -0,0 +1,94 @@ +# 8项呈现质量检查:A(3)+B(3)+C(2)=8 +QUALITY_SYSTEM_PROMPT = """ +你是一位“呈现质量评审官”。你只评估报告的**呈现与表达质量 (Presentation & Editorial Quality)**,用于奖励信号。 +严禁评估:事实真伪/引用支持(Grounding 负责)、内容覆盖广度(Breadth 负责)、分析深度与洞察(Depth 负责)、观点是否正确。 +核心关注:**可扫描性**、**信息结构化**、**逻辑链条的可视化呈现**、**表达清晰与可用性**。 + +======================== +评分标准(仅判定 pass=true/false) +======================== +对以下 8 个检查项分别给出 pass/fail,并给一句 note(≤25字,需指出“位置或症状”,避免空泛)。 + +A) Scan & Navigation(可扫描性) +A1 结论先行(Key Takeaways Top) +- Pass:开头可见“摘要/要点/核心结论”块(短段或列表均可),读者无需通读即可抓到主结论。 +- Fail:开头直接进入细节/材料堆叠,无概括性要点。 + +A2 结构导航(Navigable Structure) +- Pass:正文有清晰分节(标题层级或明显分段),读者能快速定位主要部分(分析/风险/结论等)。 +- Fail:无结构或结构混乱,像长篇流水账,难以导航。 + +A3 视觉重点(Visual Hierarchy) +- Pass:重点信息对“扫读友好”(要点化/短句分行/适度强调等),且重点承载信息而非装饰。 +- Fail:全文平铺直叙;或存在明显“格式堆砌”但不增信息。 + +B) Information Structuring(信息结构化) +B1 密集信息解构(Dense Info Structured) +- Pass:数字/多条件/多点信息密集处被列表/分组/表格等拆解,易读易取。 +- Fail:关键数据淹没在长难句或长段落(典型:数字长句串联)。 + +B2 对比对齐(Comparisons Aligned) +- Pass:涉及横向对比(A vs B/同行对比/情景对比)时,用表格或对齐结构呈现,使维度一眼可比(不强制表格)。 +- Fail:对比点散落在不同段落,维度不对齐,无法直观对照。 + +B3 一致性(Consistency) +- Pass:单位/口径/标点/小标题/列表风格整体统一,专业感稳定。 +- Fail:格式与表述明显混乱,增加阅读负担。 + +C) Editorial Clarity(编辑清晰度) +C1 论证链可视化(Argument Chain Presented) +- Pass:在呈现上能跟随“主张→依据→解释→影响/结论”的链条(例如用分段或 bullet 串联/对齐呈现),不是只堆材料。 +- Fail:大量材料堆砌,但缺少可视化的逻辑线索(读者难跟随)。 + +C2 风险与行动(Risk & Actionability Clear) +- Pass:以清晰形式列出风险/边界/不确定性,并给出可执行的下一步关注点(只看表达是否清楚存在,不评全面与正确)。 +- Fail:未提及风险/边界/下一步,或表述极度含糊不可操作。 + +反刷分原则(必须执行): +- 空标题占位、空表格/无意义表格、重复 bullet 但不增加信息 → 相关项直接判 fail,并在 note 标注“形式堆砌”。 + +======================== +输出要求(Strict JSON) +======================== +必须输出可解析 JSON;pass 必须为 boolean。 +不要输出 Markdown;不要添加额外字段;不得省略字段。 + +JSON 模板(字段必须齐全): +{ + "scan": { + "A1_key_takeaways_top": {"pass": true, "note": "≤25字定位理由"}, + "A2_navigable_structure": {"pass": true, "note": "≤25字定位理由"}, + "A3_visual_hierarchy": {"pass": true, "note": "≤25字定位理由"} + }, + "structuring": { + "B1_dense_info_structured": {"pass": false, "note": "≤25字定位理由"}, + "B2_comparisons_aligned": {"pass": true, "note": "≤25字定位理由"}, + "B3_consistency": {"pass": true, "note": "≤25字定位理由"} + }, + "editorial": { + "C1_argument_chain_presented": {"pass": false, "note": "≤25字定位理由"}, + "C2_risk_and_actionability_clear": {"pass": true, "note": "≤25字定位理由"} + }, + "top_fixes": ["最多3条,仅谈呈现层面改进"] +} +""" + +USER_PROMPT_TEMPLATE = """ +请审计以下研究报告的【呈现质量】(只谈呈现/排版/结构,不谈事实对错/引用支持/覆盖/深度)。 + +### User Query +{{user_query}} + +### AI Report +{{report_content}} + +----- +请严格按 System Prompt 的锚点输出 JSON;不要输出 Markdown;不要添加额外字段。 +""".strip() + +# 8个检查项key(用于Python均分,强确定性) +A_KEYS = ["A1_key_takeaways_top", "A2_navigable_structure", "A3_visual_hierarchy"] +B_KEYS = ["B1_dense_info_structured", "B2_comparisons_aligned", "B3_consistency"] +C_KEYS = ["C1_argument_chain_presented", "C2_risk_and_actionability_clear"] + +ALL_KEYS = A_KEYS + B_KEYS + C_KEYS diff --git a/tutorial/example_deep_finance/judge/scripts/run_presentation_quality.py b/tutorial/example_deep_finance/judge/scripts/run_presentation_quality.py new file mode 100644 index 00000000..840076ed --- /dev/null +++ b/tutorial/example_deep_finance/judge/scripts/run_presentation_quality.py @@ -0,0 +1,44 @@ +import asyncio +import sys +import os + +# 添加项目根目录到 Python 路径 +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")) +sys.path.insert(0, PROJECT_ROOT) +print(f"PROJECT_ROOT: {PROJECT_ROOT}") + +from openjudge.models import OpenAIChatModel +from tutorial.example_deep_finance.judge import PresentationQualityGrader + + +async def main(): + # 你也可以只写:model = OpenAIChatModel(model="qwen3-32b") + # 并用环境变量 OPENAI_API_KEY / OPENAI_BASE_URL(QuickStart里推荐这种方式) + model = OpenAIChatModel( + model="qwen-flash", + extra_body={"enable_thinking": False, "temperature": 0, "top_p": 1, "seed": 0}, + ) + + grader = PresentationQualityGrader(model=model) + + report = """ + # 藏格矿业分析报告 + + ## 执行摘要 + - 核心结论:... + + ## 财务对比 + | 公司 | 营收 | 净利 | + |---|---:|---:| + | A | 20 | 5 | + + ## 风险与下一步 + - 风险:... + - 下一步:... + """ + res = await grader.aevaluate(report_content=report, user_query="分析藏格矿业的财务状况") + print(res) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tutorial/example_deep_finance/judge/traj_adapter.py b/tutorial/example_deep_finance/judge/traj_adapter.py new file mode 100644 index 00000000..66df53f8 --- /dev/null +++ b/tutorial/example_deep_finance/judge/traj_adapter.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Tuple + + +def extract_text_content(content: Any) -> str: + """Extract plain text from common message schemas.""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + texts: List[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + texts.append(str(item.get("text", ""))) + elif isinstance(item, str): + texts.append(item) + return "\n".join(texts) + return str(content) + + +def normalize_traj(traj: Any) -> List[Dict[str, Any]]: + """ + Accept common traj shapes: + - list[{"role":..., "content":...}, ...] + - {"trajectory": [...]} + - {"messages": [...]} + """ + if isinstance(traj, list): + return traj + if isinstance(traj, dict): + if isinstance(traj.get("trajectory"), list): + return traj["trajectory"] + if isinstance(traj.get("messages"), list): + return traj["messages"] + return [] + + +def infer_user_query(trajectory: List[Dict[str, Any]]) -> str: + for step in trajectory: + if step.get("role") == "user": + txt = extract_text_content(step.get("content")) + if txt.strip(): + return txt.strip() + return "" + + +def find_final_report(trajectory: List[Dict[str, Any]]) -> str: + """ + Heuristic: last assistant long text or markdown-like content. + """ + for step in reversed(trajectory): + if step.get("role") == "assistant": + txt = extract_text_content(step.get("content", "")) + if len(txt) > 120 or "#" in txt: + return txt + return "" + + + diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 8e6065d3..ed82cb25 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -10,9 +10,7 @@ ajet: train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 # OpenJudge 权重配置 - report_resolution_weight: {{REPORT_RESOLUTION_WEIGHT}} # 报告质量评估 - trajectory_faithfulness_weight: {{TRAJECTORY_FAITHFULNESS_WEIGHT}} # 事实准确性评估 - citation_audit_weight: {{CITATION_AUDIT_WEIGHT}} # 引用审计评估 (覆盖率 + 真实性) + presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) From 835bdd82a0d71fe275b82a849ec9427be9b76a0a Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Mon, 26 Jan 2026 17:32:00 +0800 Subject: [PATCH 40/56] feat(grounding): implement grounding grader for citation compliance evaluation - add GroundingGrader class to evaluate citation coverage and truthfulness based on dialogue traj - provide default OpenAIChatModel creation with deterministic options - implement prompt construction and JSON parsing utilities for model interaction - calculate scores including coverage, grounding, and invalid citation penalties - add detailed json_utils module for strict JSON extraction and validation - introduce prompt templates defining citation auditing rules and user prompts - supply reference.py with related grounding evaluation logic and RefJudgeEvaluator class - create __init__.py to expose GroundingGrader module - add presentation_quality module __init__.py with PresentationQualityGrader export --- .../judge/grounding/__init__.py | 4 + .../judge/grounding/grader.py | 222 +++++++++++ .../judge/grounding/json_utils.py | 250 ++++++++++++ .../judge/grounding/prompt.py | 117 ++++++ .../judge/grounding/reference.py | 363 ++++++++++++++++++ .../judge/presentation_quality/__init__.py | 4 + 6 files changed, 960 insertions(+) create mode 100644 tutorial/example_deep_finance/judge/grounding/__init__.py create mode 100644 tutorial/example_deep_finance/judge/grounding/grader.py create mode 100644 tutorial/example_deep_finance/judge/grounding/json_utils.py create mode 100644 tutorial/example_deep_finance/judge/grounding/prompt.py create mode 100644 tutorial/example_deep_finance/judge/grounding/reference.py create mode 100644 tutorial/example_deep_finance/judge/presentation_quality/__init__.py diff --git a/tutorial/example_deep_finance/judge/grounding/__init__.py b/tutorial/example_deep_finance/judge/grounding/__init__.py new file mode 100644 index 00000000..1123382d --- /dev/null +++ b/tutorial/example_deep_finance/judge/grounding/__init__.py @@ -0,0 +1,4 @@ +"""Grounding Grader - 引用规范性评估""" +from .grader import GroundingGrader + +__all__ = ["GroundingGrader"] diff --git a/tutorial/example_deep_finance/judge/grounding/grader.py b/tutorial/example_deep_finance/judge/grounding/grader.py new file mode 100644 index 00000000..599ccc9c --- /dev/null +++ b/tutorial/example_deep_finance/judge/grounding/grader.py @@ -0,0 +1,222 @@ +"""Grounding Grader - 引用规范性评估 (OpenJudge 版本)""" +from __future__ import annotations + +import os +from typing import Any, Dict, List, Tuple + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +# import path 兼容两种写法 +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import GROUNDING_SYSTEM_PROMPT, GROUNDING_USER_PROMPT_TEMPLATE +from .json_utils import strict_load_json, validate_shape, construct_reward_prompt + + +class GroundingGrader(BaseGrader): + """ + 引用规范性评估 Grader + + - 输入:traj(完整对话轨迹) + - 输出:GraderScore(name, score, reason) + - score:综合分数,范围[0,1] + - citation_coverage_score: 引用覆盖率(0.5 权重) + - grounding_score: 引用真实性(0.5 权重) + - invalid_penalty: 无效引用惩罚(最多扣 0.5) + - determinism:建议用 temperature=0 + disable thinking + - 解析失败:score=0,并在 reason 显示报错 + """ + + def __init__( + self, + model: OpenAIChatModel, + name: str = "grounding", + **kwargs: Any, + ): + super().__init__(name=name, **kwargs) + self.model = model + + @staticmethod + def create_default_model( + model_name: str, + api_key: str | None = None, + base_url: str | None = None, + deterministic: bool = True, + enable_thinking: bool = False, + seed: int = 0, + ) -> OpenAIChatModel: + """ + 创建默认模型 + 也可以不调用这个工厂,自己在外面 new OpenAIChatModel + """ + api_key = api_key or os.getenv("OPENAI_API_KEY") + base_url = base_url or os.getenv("OPENAI_BASE_URL") + + extra_body: Dict[str, Any] = {} + if deterministic: + extra_body.update( + { + "temperature": 0, + "top_p": 1, + "seed": seed, + "presence_penalty": 0, + "frequency_penalty": 0, + } + ) + if enable_thinking is False: + extra_body["enable_thinking"] = False + + kwargs: Dict[str, Any] = {"model": model_name} + if api_key: + kwargs["api_key"] = api_key + if base_url: + kwargs["base_url"] = base_url + if extra_body: + kwargs["extra_body"] = extra_body + + return OpenAIChatModel(**kwargs) + + async def aevaluate( + self, + traj: Any, + **_: Any, + ) -> GraderScore: + """ + 入口:必须喂 traj(完整对话轨迹) + + Args: + traj: 对话轨迹,格式为 [{"role": ..., "content": ...}, ...] + 或者 {"messages": [...]} 格式 + + Returns: + GraderScore(name, score, reason) + """ + # 1. 提取 messages(兼容两种格式) + if isinstance(traj, dict): + messages_list = traj.get("messages", []) + elif isinstance(traj, list): + messages_list = traj + else: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: traj must be list or dict with 'messages'", + ) + + if not messages_list: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty trajectory", + ) + + # 2. 构建 prompt + user_prompt = construct_reward_prompt(messages_list, GROUNDING_USER_PROMPT_TEMPLATE) + + messages = [ + {"role": "system", "content": GROUNDING_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt} + ] + + # 3. 调用模型 + try: + resp = await self.model.achat(messages) + raw_text = getattr(resp, "content", None) + if raw_text is None: + raw_text = str(resp) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + + # 4. 解析 JSON + obj, jerr = strict_load_json(str(raw_text)) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"ParseError: {jerr}; raw[:200]={snippet}", + ) + + obj, serr = validate_shape(obj) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"SchemaError: {serr}; raw[:200]={snippet}", + ) + + # 5. 计算分数 + score, reason = self._compute_scores(obj) + return GraderScore(name=self.name, score=score, reason=reason) + + def _compute_scores(self, obj: Dict[str, Any]) -> Tuple[float, str]: + """ + 根据 LLM 返回的结果计算评分 + + Args: + obj: LLM 返回的 JSON,包含 total_key_facts, cited_key_facts, fake_count 等 + + Returns: + (score, reason) 元组 + """ + total_key_facts = obj.get('total_key_facts', 0) + cited_key_facts = obj.get('cited_key_facts', 0) + fake_count = obj.get('fake_count', 0) + missing_count = obj.get('missing_count', 0) + + # invalid refs: 结构化/可追溯性问题 + invalid_reference_nums = obj.get('invalid_reference_nums', []) + if not isinstance(invalid_reference_nums, list): + invalid_reference_nums = [] + invalid_ref_count = len(invalid_reference_nums) + + # 边界情况:没有关键事实,直接返回 0 + if total_key_facts == 0: + citation_coverage_score = 0.0 + grounding_score = 0.0 + else: + # coverage: 引用覆盖率 + citation_coverage_score = cited_key_facts / total_key_facts + + # grounding: 引用真实性(已引用中非虚假的比例) + if cited_key_facts == 0: + grounding_score = 0.0 + else: + grounding_score = max(0.0, 1 - fake_count / cited_key_facts) + + # 轻量惩罚:存在 invalid refs 会降低 reward + # 每个 invalid 号扣 0.1,最多扣 0.5 + invalid_penalty = min(0.1 * invalid_ref_count, 0.5) + + # final_reward: 综合分数(权重 0.5:0.5),再叠加 invalid 惩罚 + final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score + final_reward = max(0.0, final_reward - invalid_penalty) + + # 构建 reason + good_citations = obj.get('good_citations', []) + good_str = "; ".join(str(x)[:50] for x in good_citations[:2]) if good_citations else "" + + parts: List[str] = [ + f"total={total_key_facts}", + f"cited={cited_key_facts}", + f"missing={missing_count}", + f"fake={fake_count}", + f"invalid={invalid_ref_count}", + f"coverage={citation_coverage_score:.3f}", + f"grounding={grounding_score:.3f}", + f"penalty={invalid_penalty:.2f}", + ] + if good_str: + parts.append(f"good:[{good_str}]") + + reason = " | ".join(parts) + return round(final_reward, 6), reason[:800] diff --git a/tutorial/example_deep_finance/judge/grounding/json_utils.py b/tutorial/example_deep_finance/judge/grounding/json_utils.py new file mode 100644 index 00000000..03d62895 --- /dev/null +++ b/tutorial/example_deep_finance/judge/grounding/json_utils.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Tuple + +_JSON_RE = re.compile(r"\{.*\}", re.DOTALL) + + +def extract_first_json_object(text: str) -> str | None: + """ + Best-effort: extract the first {...} block. + If none found, return None. + """ + if not text: + return None + m = _JSON_RE.search(text.strip()) + if not m: + return None + return m.group(0) + + +def strict_load_json(text: str) -> Tuple[Dict[str, Any] | None, str | None]: + """ + Return (obj, error). Any parse failure => (None, error_msg) + """ + js = extract_first_json_object(text) + if js is None: + return None, "No JSON object found in model output" + try: + obj = json.loads(js) + if not isinstance(obj, dict): + return None, f"Top-level JSON is not an object: {type(obj).__name__}" + return obj, None + except Exception as e: + return None, f"{type(e).__name__}: {e}" + + +def get_bool_pass(item: Any) -> bool: + if isinstance(item, dict): + v = item.get("pass") + else: + v = item + if isinstance(v, bool): + return v + if isinstance(v, (int, float)): + return bool(v) + if isinstance(v, str): + return v.strip().lower() in {"true", "1", "yes", "y"} + return False + + +def get_note(item: Any) -> str: + if isinstance(item, dict): + note = item.get("note", "") + else: + note = "" + note = "" if note is None else str(note) + note = note.strip() + # 最多给点余量,避免reason爆长 + return note[:120] + + +def validate_shape(obj: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]: + """ + 验证 grounding JSON 结构 + + 必需字段: + - total_key_facts: int + - cited_key_facts: int + - missing_count: int + - fake_count: int + - good_citations: list + - invalid_reference_nums: list + """ + # 必需的 int 字段 + int_fields = ["total_key_facts", "cited_key_facts", "missing_count", "fake_count"] + for field in int_fields: + if field not in obj: + return None, f"Missing field: {field}" + val = obj[field] + # 尝试转换为 int + if isinstance(val, (int, float)): + obj[field] = int(val) + elif isinstance(val, str) and val.isdigit(): + obj[field] = int(val) + elif not isinstance(val, int): + return None, f"Field '{field}' must be int, got {type(val).__name__}" + + # good_citations 必须是 list + if "good_citations" not in obj: + obj["good_citations"] = [] + elif not isinstance(obj["good_citations"], list): + obj["good_citations"] = [] + else: + # 确保每个元素是字符串,最多保留 2 条 + obj["good_citations"] = [str(x) for x in obj["good_citations"][:2]] + + # invalid_reference_nums 必须是 list + if "invalid_reference_nums" not in obj: + obj["invalid_reference_nums"] = [] + elif not isinstance(obj["invalid_reference_nums"], list): + obj["invalid_reference_nums"] = [] + else: + # 确保每个元素是 int,最多保留 5 个 + nums = [] + for x in obj["invalid_reference_nums"][:5]: + if isinstance(x, int): + nums.append(x) + elif isinstance(x, (float, str)): + try: + nums.append(int(x)) + except ValueError: + pass + obj["invalid_reference_nums"] = sorted(nums) + + return obj, None + + + + +# ============================================================================= +# Trajectory 处理辅助函数 +# ============================================================================= + +def _extract_text_content(content) -> str: + """统一提取纯文本内容""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + out = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + out.append(item.get("text", "")) + elif isinstance(item, str): + out.append(item) + return "\n".join(out) + return str(content) + + +def _strip_think(text: str) -> str: + """去除 ... 标签""" + return re.sub(r".*?\s*", "", text, flags=re.S).strip() + + +def _normalize_traj(trajectory): + """兼容 [[...]] 格式""" + if isinstance(trajectory, list) and trajectory and isinstance(trajectory[0], list): + return trajectory[0] + return trajectory + + +def _extract_tool_call_json(text: str) -> str: + """提取工具调用 JSON""" + m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text) + if m: + return m.group(1).strip() + l, r = text.find("["), text.rfind("]") + if l != -1 and r != -1 and r > l: + cand = text[l:r+1].strip() + if ("tool_name" in cand) and ("tool_args" in cand): + return cand + return "" + + +def _looks_like_tool_result(text: str) -> bool: + """判断是否为工具返回结果""" + t = text.strip() + if t.startswith("Tool:") or t.startswith("Result:"): + return True + if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t): + return True + if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "): + return True + return False + + +def _is_probably_final_report(text: str) -> bool: + """判断是否为最终报告""" + t = text.strip() + return ("## References" in t) or ("[TASK_COMPLETED]" in t) or t.lstrip().startswith("# ") + + +def construct_reward_prompt(trajectory: List[Dict[str, Any]], user_prompt_template: str) -> str: + """ + 从 trajectory 构建 reward prompt + + Args: + trajectory: 对话轨迹 [{"role": ..., "content": ...}, ...] + + Returns: + 构建好的 user prompt 字符串 + """ + traj = _normalize_traj(trajectory) + if not traj: + traj = [] + + user_query = "" + tool_calls: List[str] = [] + evidence: List[str] = [] + final_report = "" + + # 找到 final report(从后往前找第一个符合条件的 assistant 消息) + for i in range(len(traj) - 1, -1, -1): + step = traj[i] + if step.get("role") == "assistant": + txt = _strip_think(_extract_text_content(step.get("content"))) + if _is_probably_final_report(txt): + final_report = txt + break + if not final_report: + for i in range(len(traj) - 1, -1, -1): + if traj[i].get("role") == "assistant": + final_report = _strip_think(_extract_text_content(traj[i].get("content"))) + break + + # 遍历提取 user_query, tool_calls, evidence + for idx, step in enumerate(traj): + role = step.get("role") + raw = _extract_text_content(step.get("content")) + txt = _strip_think(raw) + if not raw: + continue + + if role == "user" and not user_query and (not _looks_like_tool_result(raw)): + user_query = txt + continue + + if role == "assistant": + call_json = _extract_tool_call_json(raw) + if call_json: + tool_calls.append(f"[Step {idx}] TOOL_CALL:\n{call_json}") + + if role in ("tool", "user"): + if _looks_like_tool_result(raw): + evidence.append(f"[Step {idx}] EVIDENCE_TOOL_RESULT:\n{raw}") + else: + # query 之后的用户补充也保留为 evidence + if user_query: + evidence.append(f"[Step {idx}] EVIDENCE_USER_CONTEXT:\n{txt}") + + evidence_text = "\n\n".join(tool_calls + evidence) + + return user_prompt_template.format( + user_query=user_query, + evidence_text=evidence_text, + final_report=final_report + ).strip() diff --git a/tutorial/example_deep_finance/judge/grounding/prompt.py b/tutorial/example_deep_finance/judge/grounding/prompt.py new file mode 100644 index 00000000..24bea134 --- /dev/null +++ b/tutorial/example_deep_finance/judge/grounding/prompt.py @@ -0,0 +1,117 @@ +"""Grounding Grader Prompt - 引用规范性评估""" + +GROUNDING_SYSTEM_PROMPT = """你是一位"引用审计员",负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 + +======================== +一、引用规范(以此为准) +======================== +1) 关键事实句必须引用: + - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。 + - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。 + +2) 引用位置规则(严格执行): + - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。 + - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。 + +3) References 必须存在且可追溯: + - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。 + - 正文出现的每个 [n] 必须能在 References 中找到对应条目。 + +4) References 条目两种合法形式(必须满足其一): + A) URL 形式:`[n] 标题或简述 - https://...` + - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。 + B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)` + - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。 + - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。 + +======================== +二、输入 +======================== +你会收到: +- User Query +- Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息) +- AI Report(待审计报告,含正文与 References) + +真实性核对原则: +- 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。 +- 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。 + +======================== +三、统计与判定口径(严格遵守) +======================== +【文本范围】 +- 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。 +- References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。 + +【句子/条目如何计数】 +- “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。 +- 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。 +- 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。 + +【关键事实句识别(务求稳定)】 +- 满足任一条件可视为关键事实句: + (a) 含具体数值/比例/排名/区间/估值倍数/财务指标; + (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”); + (c) 对具体公司/行业/政策做了可验证的确定性陈述; + (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。 + +【引用是否“句末”】【重要】 +- 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如: + - “……增长 20%[3]” + - “……增长 20% [3][4]” +- 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。 + +【invalid_reference_nums 的定义】 +- 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid: + (a) References 中不存在该编号条目; + (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等); + (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。 +- invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。 + +【missing_count 的定义】 +- 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。 + +【cited_key_facts 的定义】 +- 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。 + +【fake_count 的定义(只在明显时计数)】 +- 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。 +- 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。 + +【good_citations 的定义】 +- 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足: + - 是关键事实句; + - 句末有 [n]; + - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。 +- good_citations 是原文截取,不要加解释;最多 2 条,超出截断。 + +======================== +四、输出(只输出 JSON,字段固定) +======================== +{ + "total_key_facts": , + "cited_key_facts": , + "good_citations": ["...", "..."], + "missing_count": , + "fake_count": , + "invalid_reference_nums": [, ...] +} + +只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。 +""" + +# ============================================================================= +# User Prompt Template +# ============================================================================= + +GROUNDING_USER_PROMPT_TEMPLATE = """请审计以下 AI 研究报告的引用规范性,只输出 JSON。 + +### User Query +{user_query} + +### Evidence +{evidence_text} + +### AI Report(待审计报告) +{final_report} +""" diff --git a/tutorial/example_deep_finance/judge/grounding/reference.py b/tutorial/example_deep_finance/judge/grounding/reference.py new file mode 100644 index 00000000..6e67a382 --- /dev/null +++ b/tutorial/example_deep_finance/judge/grounding/reference.py @@ -0,0 +1,363 @@ +GROUNDING_SYSTEM_PROMPT = """你是一位“引用审计员”,负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 + +======================== +一、引用规范(以此为准) +======================== +1) 关键事实句必须引用: + - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。 + - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。 + +2) 引用位置规则(严格执行): + - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。 + - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。 + +3) References 必须存在且可追溯: + - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。 + - 正文出现的每个 [n] 必须能在 References 中找到对应条目。 + +4) References 条目两种合法形式(必须满足其一): + A) URL 形式:`[n] 标题或简述 - https://...` + - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。 + B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)` + - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。 + - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。 + +======================== +二、输入 +======================== +你会收到: +- User Query +- Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息) +- AI Report(待审计报告,含正文与 References) + +真实性核对原则: +- 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。 +- 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。 + +======================== +三、统计与判定口径(严格遵守) +======================== +【文本范围】 +- 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。 +- References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。 + +【句子/条目如何计数】 +- “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。 +- 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。 +- 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。 + +【关键事实句识别(务求稳定)】 +- 满足任一条件可视为关键事实句: + (a) 含具体数值/比例/排名/区间/估值倍数/财务指标; + (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”); + (c) 对具体公司/行业/政策做了可验证的确定性陈述; + (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。 + +【引用是否“句末”】【重要】 +- 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如: + - “……增长 20%[3]” + - “……增长 20% [3][4]” +- 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。 + +【invalid_reference_nums 的定义】 +- 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid: + (a) References 中不存在该编号条目; + (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等); + (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。 +- invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。 + +【missing_count 的定义】 +- 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。 + +【cited_key_facts 的定义】 +- 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。 + +【fake_count 的定义(只在明显时计数)】 +- 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。 +- 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。 + +【good_citations 的定义】 +- 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足: + - 是关键事实句; + - 句末有 [n]; + - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。 +- good_citations 是原文截取,不要加解释;最多 2 条,超出截断。 + +======================== +四、输出(只输出 JSON,字段固定) +======================== +{ + "total_key_facts": , + "cited_key_facts": , + "good_citations": ["...", "..."], + "missing_count": , + "fake_count": , + "invalid_reference_nums": [, ...] +} + +只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。 +""" + + + +import json +import re +from typing import Dict, Any, List + + +def _extract_text_content(content) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + out = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + out.append(item.get("text", "")) + elif isinstance(item, str): + out.append(item) + return "\n".join(out) + return str(content) + +def _strip_think(text: str) -> str: + return re.sub(r".*?\s*", "", text, flags=re.S).strip() + +def _normalize_traj(trajectory): + # 兼容 [[...]] :contentReference[oaicite:1]{index=1} + if isinstance(trajectory, list) and trajectory and isinstance(trajectory[0], list): + return trajectory[0] + return trajectory + +def _extract_tool_call_json(text: str) -> str: + m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text) + if m: + return m.group(1).strip() + l, r = text.find("["), text.rfind("]") + if l != -1 and r != -1 and r > l: + cand = text[l:r+1].strip() + if ("tool_name" in cand) and ("tool_args" in cand): + return cand + return "" + +def _looks_like_tool_result(text: str) -> bool: + t = text.strip() + if t.startswith("Tool:") or t.startswith("Result:"): + return True + if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t): + return True + if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "): + return True + return False + +def _is_probably_final_report(text: str) -> bool: + t = text.strip() + return ("## References" in t) or ("[TASK_COMPLETED]" in t) or t.lstrip().startswith("# ") + +def construct_reward_prompt(trajectory: List[Dict[str, Any]]) -> str: + traj = _normalize_traj(trajectory) + + user_query = "" + tool_calls: List[str] = [] + evidence: List[str] = [] + final_report = "" + + # final report + for i in range(len(traj) - 1, -1, -1): + step = traj[i] + if step.get("role") == "assistant": + txt = _strip_think(_extract_text_content(step.get("content"))) + if _is_probably_final_report(txt): + final_report = txt + break + if not final_report: + for i in range(len(traj) - 1, -1, -1): + if traj[i].get("role") == "assistant": + final_report = _strip_think(_extract_text_content(traj[i].get("content"))) + break + + # iterate + for idx, step in enumerate(traj): + role = step.get("role") + raw = _extract_text_content(step.get("content")) + txt = _strip_think(raw) + if not raw: + continue + + if role == "user" and not user_query and (not _looks_like_tool_result(raw)): + user_query = txt + continue + + if role == "assistant": + call_json = _extract_tool_call_json(raw) + if call_json: + tool_calls.append(f"[Step {idx}] TOOL_CALL:\n{call_json}") + + if role in ("tool", "user"): + if _looks_like_tool_result(raw): + evidence.append(f"[Step {idx}] EVIDENCE_TOOL_RESULT:\n{raw}") + else: + # query 之后的用户补充也保留为 evidence(有些系统会把 tool_result 注入到 user) + if user_query: + evidence.append(f"[Step {idx}] EVIDENCE_USER_CONTEXT:\n{txt}") + + evidence_text = "\n\n".join(tool_calls + evidence) + + return f"""请审计以下 AI 研究报告的引用规范性,只输出 JSON。 + +### User Query +{user_query} + +### Evidence(来自完整 trajectory) +{evidence_text} + +### AI Report(待审计报告) +{final_report} +""".strip() + + +class RefJudgeEvaluator: + """ + 引用规范性评估器 + + 使用 LLM 评估报告的引用覆盖率和引用真实性。 + """ + + def __init__(self, llm_client): + """ + 初始化评估器 + + Args: + llm_client: LLMJudgeClient 实例 + """ + self.llm_client = llm_client + print("✓ RefJudgeEvaluator: Initialized") + + def build_messages(self, conversation_history: List[Dict]) -> List[Dict[str, str]]: + """ + 从对话历史构建 LLM 评估消息 + + Args: + conversation_history: 对话历史 [{"role": "...", "content": "..."}] + + Returns: + LLM 消息列表 + """ + print(f"\n[RefJudgeEvaluator] 构建评估消息...") + print(f" - 对话历史轮数: {len(conversation_history)}") + + # 调用现有的 prompt 构建函数 + user_prompt = construct_reward_prompt(conversation_history) + + messages = [ + {"role": "system", "content": GROUNDING_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt} + ] + + print(f" ✓ 消息构建完成,system prompt 长度: {len(GROUNDING_SYSTEM_PROMPT)}") + print(f" ✓ user prompt 长度: {len(user_prompt)}") + + return messages + + def _compute_scores(self, raw_result: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据 LLM 返回的原始结果计算评分 + + Args: + raw_result: LLM 返回的 JSON,包含 total_key_facts, cited_key_facts, fake_count 等 + + Returns: + 包含 citation_coverage_score, grounding_score, final_reward 的字典 + """ + total_key_facts = raw_result.get('total_key_facts', 0) + cited_key_facts = raw_result.get('cited_key_facts', 0) + fake_count = raw_result.get('fake_count', 0) + + # invalid refs: 结构化/可追溯性问题(来自 prompt 的 invalid_reference_nums) + invalid_reference_nums = raw_result.get('invalid_reference_nums', []) + if not isinstance(invalid_reference_nums, list): + invalid_reference_nums = [] + invalid_ref_count = len(invalid_reference_nums) + + # 边界情况:没有关键事实,直接返回 0 + if total_key_facts == 0: + citation_coverage_score = 0.0 + grounding_score = 0.0 + else: + # coverage: 引用覆盖率 + citation_coverage_score = cited_key_facts / total_key_facts + + # grounding: 引用真实性(已引用中非虚假的比例) + if cited_key_facts == 0: + grounding_score = 0.0 + else: + grounding_score = max(0.0, 1 - fake_count / cited_key_facts) + + # 轻量惩罚:存在 invalid refs 会降低 reward(但不改变 cited_key_facts 的统计口径) + # 说明:invalid_reference_nums 在 prompt 中已定义为“正文出现过的不合规编号(去重)”。 + # 这里采用简单、确定性的惩罚:每个 invalid 号扣 0.1,最多扣 0.5。 + invalid_penalty = min(0.1 * invalid_ref_count, 0.5) + + # final_reward: 综合分数(代码计算,权重 0.5:0.5),再叠加 invalid 惩罚 + final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score + final_reward = max(0.0, final_reward - invalid_penalty) + + return { + 'citation_coverage_score': citation_coverage_score, + 'grounding_score': grounding_score, + 'final_reward': final_reward, + 'invalid_ref_count': invalid_ref_count, + 'invalid_penalty': invalid_penalty, + } + + async def evaluate_async(self, conversation_history: List[Dict]) -> Dict[str, Any]: + """ + 异步评估引用规范性 + + Args: + conversation_history: 对话历史 + + Returns: + 评估结果字典,包含: + - citation_coverage_score: 引用覆盖率分数 (0.0-1.0) + - grounding_score: 引用真实性分数 (0.0-1.0) + - final_reward: 最终奖励分数 (0.0-1.0) + - total_key_facts, cited_key_facts, fake_count 等原始字段 + """ + # print(f"\n开始评估引用规范性...") + + messages = self.build_messages(conversation_history) + raw_result = await self.llm_client.evaluate_async(messages) + + # 计算评分 + scores = self._compute_scores(raw_result) + + # 合并原始结果和计算的评分 + result = {**raw_result, **scores} + + # 确保必要字段存在 + result.setdefault('total_key_facts', 0) + result.setdefault('cited_key_facts', 0) + result.setdefault('missing_count', 0) + result.setdefault('fake_count', 0) + result.setdefault('invalid_reference_nums', []) + result.setdefault('good_citations', []) + + print(f" ✓ [RefJudgeEvaluator] 引用规范性评估完成:") + print(f" - total_key_facts: {result['total_key_facts']}") + print(f" - cited_key_facts: {result['cited_key_facts']}") + print(f" - fake_count: {result['fake_count']}") + print(f" - invalid_ref_count: {result.get('invalid_ref_count', 0)}") + print(f" - invalid_penalty: {result.get('invalid_penalty', 0.0):.4f}") + print(f" - citation_coverage_score: {result['citation_coverage_score']:.4f}") + print(f" - grounding_score: {result['grounding_score']:.4f}") + print(f" - final_reward: {result['final_reward']:.4f}") + + return result + + def evaluate_sync(self, conversation_history: List[Dict]) -> Dict[str, Any]: + """ + 同步评估引用规范性 + """ + import asyncio + return asyncio.run(self.evaluate_async(conversation_history)) diff --git a/tutorial/example_deep_finance/judge/presentation_quality/__init__.py b/tutorial/example_deep_finance/judge/presentation_quality/__init__.py new file mode 100644 index 00000000..2db690fa --- /dev/null +++ b/tutorial/example_deep_finance/judge/presentation_quality/__init__.py @@ -0,0 +1,4 @@ +"""Grounding Grader - 引用规范性评估""" +from .grader import PresentationQualityGrader + +__all__ = ["PresentationQualityGrader"] From 11ed325b877bab487dc70b70b1f2b33e129b83a5 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Mon, 26 Jan 2026 17:32:11 +0800 Subject: [PATCH 41/56] fix(deep_finance_judge): add debug logging for OpenJudge evaluation process --- .../example_deep_finance/deep_finance_judge.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index 9a6c72f2..b1e96a61 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -312,11 +312,22 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO chat_date=chat_date ) + # DEBUG: 记录转换后的样本结构 + print(f" [DEBUG] task_id={task_id}, messages_count={len(openjudge_sample.get('messages', []))}") + if openjudge_sample.get('messages'): + last_msg = openjudge_sample['messages'][-1] + print(f" [DEBUG] last_msg role={last_msg.get('role')}, content_len={len(last_msg.get('content', ''))}") + # 3. 调用 OpenJudge Runner.arun(异步) grading_start_time = time.time() grader_results = self._run_openjudge_evaluation([openjudge_sample]) grading_time = time.time() - grading_start_time + # DEBUG: 记录原始 grader 结果 + print(f" [DEBUG] grader_results keys: {list(grader_results.keys())}") + for gname, glist in grader_results.items(): + print(f" [DEBUG] {gname}: count={len(glist)}, type={type(glist[0]) if glist else 'empty'}") + # 4. 提取各 grader 分数(arun 返回 Dict[str, List[GraderScore]],这里取第一条) grader_scores, quota_exceeded_flags = self._extract_grader_scores(grader_results) @@ -520,6 +531,11 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[ if score_list and len(score_list) > 0: # 取第一条采样的分数(因为每次只评估一条) grader_score = score_list[0] + + # DEBUG: 记录详细信息 + reason_str = getattr(grader_score, 'reason', None) + print(f" [DEBUG] {grader_name}: score={getattr(grader_score, 'score', 'N/A')}, reason={str(reason_str)[:300] if reason_str else 'N/A'}") + if hasattr(grader_score, "score"): scores[grader_name] = grader_score.score # 检测错误类型:分数为0且有错误信息 @@ -531,8 +547,10 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[ else: # 如果出错,设为 0 scores[grader_name] = 0.0 + print(f" [DEBUG] {grader_name}: no 'score' attr, grader_score={grader_score}") else: scores[grader_name] = 0.0 + print(f" [DEBUG] {grader_name}: empty score_list") print(f" [OpenJudge Scores] {scores}") if any(quota_exceeded_flags.values()): From a500e90661707aab149970c7363f12a4565e22b8 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 27 Jan 2026 19:09:14 +0800 Subject: [PATCH 42/56] feat(deep_finance): enhance reward metadata and zero score debugging - Add populate_reward_metadata_from_stats to copy reward stats into reward metadata - Populate reward metadata in GeneralRunner if reward_stats present in workflow output - Refine compute_reward_metrics with updated OpenJudge graders: presentation_quality, grounding, planning - Add _save_zero_score_debug method in DeepFinanceJudgeByOpenJudge to save debug info for zero grader scores - Remove deprecated RewardStats usage in deep_finance_judge - Update judge __init__ to export GroundingGrader alongside PresentationQualityGrader - Clean up debug print statements and logging in deep_finance_judge.py - Update .gitignore to exclude prepare_data and judge/analytical_sufficiency folders in example_deep_finance tutorial --- .gitignore | 2 + ajet/task_runner/general_runner.py | 5 + .../metric_helper/reward_metric_helper.py | 34 +++++-- .../deep_finance_judge.py | 93 +++++++++++++++---- .../example_deep_finance/judge/__init__.py | 10 +- 5 files changed, 111 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 5fdcd249..08bed932 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,5 @@ tutorial/example_deep_finance/yaml/* tutorial/example_deep_finance/config/* tutorial/example_deep_finance/scripts/* flash_attn-2.8.*.whl +tutorial/example_deep_finance/prepare_data/* +tutorial/example_deep_finance/judge/analytical_sufficiency/* \ No newline at end of file diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index 88f9ab11..c2610564 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -9,6 +9,7 @@ from ajet.schema.trajectory import Reward from ajet.task_runner.base_runner import BaseAgentRunner from ajet.utils.dynamic_import import dynamic_import +from ajet.utils.metric_helper.reward_metric_helper import populate_reward_metadata_from_stats class GeneralRunner(BaseAgentRunner): @@ -73,6 +74,10 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: madness=0, description="", ) + + # Populate reward metadata with deep_finance reward stats if available + if "reward_stats" in workflow_output.metadata: + populate_reward_metadata_from_stats(reward, workflow_output.metadata["reward_stats"]) context_tracker.process_reward(reward) # generate token before merging context_tracker.group_merge() diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index d476a81f..ea951d5a 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -11,9 +11,12 @@ - judge_time/ Judge time consumption statistics """ -from typing import List, Dict, Any +from typing import List, Dict, Any, TYPE_CHECKING import numpy as np +if TYPE_CHECKING: + from ajet.schema.trajectory import Reward + def extract_reward_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[str, Any]]: """ @@ -72,19 +75,15 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str metrics[f"{prefix}rewards/penalty_count"] = len(non_zero_penalties) metrics[f"{prefix}rewards/penalty_rate"] = len(non_zero_penalties) / n * 100 if n > 0 else 0.0 - # ========== Detect OpenJudge Usage ========== + # ========== OpenJudge Metrics (PresentationQualityGrader, GroundingGrader) ========== openjudge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('openjudge_enabled', False)) if openjudge_enabled_count > 0: - # ========== OpenJudge Metrics ========== - - # Dynamically extract OpenJudge grader fields - # Currently supported graders: report_resolution, trajectory_faithfulness, - # rubrics_performance, trajectory_comprehensive, information_gain, action_loop + # OpenJudge graders: presentation_quality, grounding openjudge_graders = [ - "report_resolution", - "trajectory_faithfulness", - "citation_audit", + "presentation_quality", + "grounding", + "planning" ] for grader_name in openjudge_graders: @@ -148,3 +147,18 @@ def compute_reward_metrics_from_trajectories(trajectories: List[Any], prefix: st reward_stats_list = extract_reward_stats_from_trajectories(trajectories) return compute_reward_metrics(reward_stats_list, prefix=prefix) + +def populate_reward_metadata_from_stats(reward: "Reward", reward_stats: Dict[str, Any]) -> None: + """ + Populate Reward.metadata with all reward statistics. + + Args: + reward: The Reward object to populate + reward_stats: The reward_stats dictionary from judge + """ + if not reward_stats: + return + + # Directly copy all reward_stats into metadata + reward.metadata.update(reward_stats) + diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index b1e96a61..03f10130 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -15,12 +15,9 @@ from openjudge.models.openai_chat_model import OpenAIChatModel from openjudge.runner.grading_runner import GraderConfig, GradingRunner -from tutorial.example_deep_finance.judge import PresentationQualityGrader -from tutorial.example_deep_finance.judge.grounding import GroundingGrader +from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader -# RewardStats 不再使用,OpenJudge 版本直接使用字典存储 -# Reference Answer 路径现在从 config 中读取,见 _init_reference_answers 方法 # OpenJudge imports # ============================================================================= @@ -312,25 +309,28 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO chat_date=chat_date ) - # DEBUG: 记录转换后的样本结构 - print(f" [DEBUG] task_id={task_id}, messages_count={len(openjudge_sample.get('messages', []))}") if openjudge_sample.get('messages'): last_msg = openjudge_sample['messages'][-1] - print(f" [DEBUG] last_msg role={last_msg.get('role')}, content_len={len(last_msg.get('content', ''))}") # 3. 调用 OpenJudge Runner.arun(异步) grading_start_time = time.time() grader_results = self._run_openjudge_evaluation([openjudge_sample]) grading_time = time.time() - grading_start_time - - # DEBUG: 记录原始 grader 结果 - print(f" [DEBUG] grader_results keys: {list(grader_results.keys())}") - for gname, glist in grader_results.items(): - print(f" [DEBUG] {gname}: count={len(glist)}, type={type(glist[0]) if glist else 'empty'}") + # 4. 提取各 grader 分数(arun 返回 Dict[str, List[GraderScore]],这里取第一条) grader_scores, quota_exceeded_flags = self._extract_grader_scores(grader_results) + # 4.5 如果有分数为0的grader,保存调试信息到单独文件 + self._save_zero_score_debug( + grader_scores=grader_scores, + grader_results=grader_results, + query=query, + history=history, + report=assistants[-1] if assistants else "", + task_id=task_id + ) + # 5. 加权融合(包含 RM Gallery 和 OpenJudge Graders) fused_reward, contributions = self._fuse_grader_scores(grader_scores, rm_raw) @@ -535,7 +535,6 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[ # DEBUG: 记录详细信息 reason_str = getattr(grader_score, 'reason', None) print(f" [DEBUG] {grader_name}: score={getattr(grader_score, 'score', 'N/A')}, reason={str(reason_str)[:300] if reason_str else 'N/A'}") - if hasattr(grader_score, "score"): scores[grader_name] = grader_score.score # 检测错误类型:分数为0且有错误信息 @@ -550,7 +549,6 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[ print(f" [DEBUG] {grader_name}: no 'score' attr, grader_score={grader_score}") else: scores[grader_name] = 0.0 - print(f" [DEBUG] {grader_name}: empty score_list") print(f" [OpenJudge Scores] {scores}") if any(quota_exceeded_flags.values()): @@ -625,6 +623,69 @@ def _save_rm_log(self, result, query: str, task_id: str): except Exception: pass + def _save_zero_score_debug( + self, + grader_scores: Dict[str, float], + grader_results: Dict[str, List[Any]], + query: str, + history: List[Dict], + report: str, + task_id: str + ): + """ + 当有 grader 分数为 0 时,保存详细调试信息到单独文件 + + 保存内容包括: + - query: 用户查询 + - traj: 对话历史 + - report: 最终报告(前500字) + - zero_score_reasons: 得 0 分的原因 + """ + try: + # 检查是否有分数为 0 的 grader + zero_score_graders = [name for name, score in grader_scores.items() if score == 0.0] + if not zero_score_graders: + return + + # 提取得 0 分的原因 + zero_score_reasons = {} + for grader_name in zero_score_graders: + if grader_name in grader_results: + score_list = grader_results[grader_name] + if score_list and len(score_list) > 0: + grader_score = score_list[0] + reason = getattr(grader_score, 'reason', None) + zero_score_reasons[grader_name] = str(reason) if reason else "N/A" + else: + zero_score_reasons[grader_name] = "empty score_list" + else: + zero_score_reasons[grader_name] = "grader not in results" + + # 构建调试日志 + debug_log = { + "task_id": task_id, + "timestamp": datetime.now().isoformat(), + "query": query, + "report": report if report else "", + "trajectory": history, + "grader_scores": grader_scores, + "zero_score_graders": zero_score_graders, + "zero_score_reasons": zero_score_reasons + } + + # 保存到单独文件 + save_dir = "/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new/tutorial/example_deep_finance/outputs/reward_zero_debug" + os.makedirs(save_dir, exist_ok=True) + log_file = os.path.join(save_dir, f"zeroscore_{datetime.now().strftime('%Y%m%d')}.jsonl") + with open(log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(debug_log, ensure_ascii=False) + "\n") + + print(f" [ZERO SCORE DEBUG] task_id={task_id}, zero_graders={zero_score_graders}, saved to {log_file}") + + except Exception as e: + print(f"⚠️ Failed to save zero score debug: {e}") + pass + def _compute_penalty(self, tool_calls: int) -> float: """ 计算工具调用惩罚(保留原有逻辑) @@ -674,10 +735,6 @@ def _update_metadata_stats( "penalty": penalty, "step_reward": step_reward, "openjudge_enabled": True, - # Quota exceeded (429) 统计 - "quota_exceeded_any": quota_exceeded_any, # 是否有任何 grader 超额 - "quota_exceeded_count": quota_exceeded_count, # 超额的 grader 数量 - "quota_exceeded_graders": quota_exceeded_flags, # 各 grader 的超额标记 # RM Gallery 相关 "rm_enabled": self._rm_enabled, "rm_raw": rm_raw, diff --git a/tutorial/example_deep_finance/judge/__init__.py b/tutorial/example_deep_finance/judge/__init__.py index 6c9fdfbe..75c8ceff 100644 --- a/tutorial/example_deep_finance/judge/__init__.py +++ b/tutorial/example_deep_finance/judge/__init__.py @@ -1,11 +1,11 @@ # 使得可以通过 from judge import PresentationQualityGrader 直接引用 -# from tutorial.example_deep_finance.judge.grounding.grader import GroundingGrader -from tutorial.example_deep_finance.judge.presentation_quality.grader import PresentationQualityGrader -# from tutorial.example_deep_finance.judge.research_depth.grader import ResearchDepthGrader -# from tutorial.example_deep_finance.judge.research_breadth.grader import ResearchBreadthGrader +from .grounding.grader import GroundingGrader +from .presentation_quality.grader import PresentationQualityGrader +# from .research_depth.grader import ResearchDepthGrader +# from .research_breadth.grader import ResearchBreadthGrader # 以后添加了其他 grader 也可以加在这里 # from .grounding.grader import GroundingGrader # from .research_breadth.grader import ResearchBreadthGrader # __all__ = ["PresentationQualityGrader", "GroundingGrader", "ResearchDepthGrader", "ResearchBreadthGrader"] -__all__ = ["PresentationQualityGrader"] +__all__ = ["PresentationQualityGrader", "GroundingGrader"] From d9cbdc0cd650cbe98a24e8c5e1c9f039bffe745a Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 27 Jan 2026 19:09:27 +0800 Subject: [PATCH 43/56] feat(presentation_quality): upgrade grading to 1/3/5 scoring system with markdown cleanup - Add function to strip markdown code block fences in grounding and presentation_quality modules - Change presentation quality grader to score each of 8 criteria on a 1/3/5 scale instead of pass/fail - Normalize total score by dividing sum of item scores by max (40), improving granularity - Update reasoning output to list lowest scoring items with notes for focused feedback - Revise presentation quality prompt to reflect new 1/3/5 scoring rubric with detailed instructions - Adjust JSON output schema accordingly, replacing boolean pass with numeric score fields - Add get_score utility in JSON utils to extract and validate scores from graded items - Clean report input by removing markdown fences before grading to avoid markup noise - Add grounding weight configuration in YAML template for improved modular judge weighting --- .../judge/grounding/json_utils.py | 17 +++ .../judge/presentation_quality/grader.py | 52 +++++-- .../judge/presentation_quality/json_utils.py | 22 +++ .../judge/presentation_quality/prompt.py | 130 ++++++++++-------- .../yaml_template/deep_finance_template.yaml | 3 +- 5 files changed, 151 insertions(+), 73 deletions(-) diff --git a/tutorial/example_deep_finance/judge/grounding/json_utils.py b/tutorial/example_deep_finance/judge/grounding/json_utils.py index 03d62895..a3f793ad 100644 --- a/tutorial/example_deep_finance/judge/grounding/json_utils.py +++ b/tutorial/example_deep_finance/judge/grounding/json_utils.py @@ -145,6 +145,20 @@ def _strip_think(text: str) -> str: return re.sub(r".*?\s*", "", text, flags=re.S).strip() +def _strip_markdown_fences(text: str) -> str: + """ + 清理 markdown 代码块标记 + - 移除开头的 ```markdown / ```md / ``` 等 + - 移除结尾的 ``` + """ + text = text.strip() + # 移除开头的 ```xxx + text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE) + # 移除结尾的 ``` + text = re.sub(r'\n?```\s*$', '', text) + return text.strip() + + def _normalize_traj(trajectory): """兼容 [[...]] 格式""" if isinstance(trajectory, list) and trajectory and isinstance(trajectory[0], list): @@ -216,6 +230,9 @@ def construct_reward_prompt(trajectory: List[Dict[str, Any]], user_prompt_templa final_report = _strip_think(_extract_text_content(traj[i].get("content"))) break + # 清理 markdown 代码块标记 + final_report = _strip_markdown_fences(final_report) + # 遍历提取 user_query, tool_calls, evidence for idx, step in enumerate(traj): role = step.get("role") diff --git a/tutorial/example_deep_finance/judge/presentation_quality/grader.py b/tutorial/example_deep_finance/judge/presentation_quality/grader.py index ac3eb349..c440c3e4 100644 --- a/tutorial/example_deep_finance/judge/presentation_quality/grader.py +++ b/tutorial/example_deep_finance/judge/presentation_quality/grader.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import re from typing import Any, Dict, List, Tuple from openjudge.graders.base_grader import BaseGrader @@ -20,14 +21,14 @@ B_KEYS, C_KEYS, ) -from .json_utils import strict_load_json, validate_shape, get_bool_pass, get_note +from .json_utils import strict_load_json, validate_shape, get_score, get_note class PresentationQualityGrader(BaseGrader): """ - 输入:report_content(研究报告文本) - 输出:GraderScore(name, score, reason) - - score:8项(pass)均分,范围[0,1] + - score:8项按1/3/5分制评分,总分归一化到[0,1](总分/40) - determinism:建议用 temperature=0 + disable thinking 等(见 create_default_model) - 解析失败:score=0,并在 reason 显示报错 """ @@ -92,7 +93,13 @@ async def aevaluate( 入口:直接喂 report_content(研究报告文本) - user_query 可选:用于填充 prompt;不提供则用 "(unknown)" """ + + report = (report_content or "").strip() + + # 清理 markdown 代码块标记 + report = self._strip_markdown_fences(report) + if not report: return GraderScore( name=self.name, @@ -143,6 +150,7 @@ async def aevaluate( ) score, reason = self._score_and_reason(obj) + return GraderScore(name=self.name, score=score, reason=reason) def _score_and_reason(self, obj: Dict[str, Any]) -> Tuple[float, str]: @@ -151,13 +159,13 @@ def _score_and_reason(self, obj: Dict[str, Any]) -> Tuple[float, str]: editorial = obj["editorial"] top_fixes = obj.get("top_fixes", []) - # 8项均分(强确定性:完全由Python算) - pass_map: Dict[str, bool] = {} + # 8项按1/3/5分制计分(强确定性:完全由Python算) + score_map: Dict[str, int] = {} note_map: Dict[str, str] = {} def take(section: Dict[str, Any], key: str): item = section.get(key) - pass_map[key] = get_bool_pass(item) + score_map[key] = get_score(item) note_map[key] = get_note(item) for k in A_KEYS: @@ -167,21 +175,37 @@ def take(section: Dict[str, Any], key: str): for k in C_KEYS: take(editorial, k) - passed = sum(1 for k in ALL_KEYS if pass_map.get(k) is True) - total = len(ALL_KEYS) # 8 - score = passed / float(total) + # 总分 = 各项得分之和 / 最高可能分 (8*5=40),归一化到[0,1] + total_score = sum(score_map.get(k, 1) for k in ALL_KEYS) + max_score = len(ALL_KEYS) * 5 # 8 * 5 = 40 + score = total_score / float(max_score) - # reason:不加额外字段,只给紧凑总结 - failed_items = [k for k in ALL_KEYS if not pass_map.get(k, False)] - failed_str = ", ".join(f"{k}({note_map.get(k,'')})" for k in failed_items[:4]) + # reason:按分数排序,列出低分项 + low_items = [(k, score_map.get(k, 1)) for k in ALL_KEYS if score_map.get(k, 1) < 5] + low_items.sort(key=lambda x: x[1]) # 从低到高 + low_str = ", ".join(f"{k}={s}({note_map.get(k,'')})" for k, s in low_items[:4]) fixes_str = " | ".join(str(x) for x in (top_fixes or [])[:3]) parts: List[str] = [] - parts.append(f"Pass {passed}/{total}") - if failed_items: - parts.append(f"Fail: {failed_str}") + parts.append(f"Score {total_score}/{max_score}") + if low_items: + parts.append(f"Low: {low_str}") if fixes_str: parts.append(f"TopFixes: {fixes_str}") reason = " ; ".join(parts) return round(score, 6), reason[:800] + + @staticmethod + def _strip_markdown_fences(text: str) -> str: + """ + 清理 markdown 代码块标记 + - 移除开头的 ```markdown / ```md / ``` 等 + - 移除结尾的 ``` + """ + text = text.strip() + # 移除开头的 ```xxx + text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE) + # 移除结尾的 ``` + text = re.sub(r'\n?```\s*$', '', text) + return text.strip() diff --git a/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py b/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py index 92794ca4..2852ff8d 100644 --- a/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py +++ b/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py @@ -51,6 +51,28 @@ def get_bool_pass(item: Any) -> bool: return False +def get_score(item: Any) -> int: + """ + Extract numeric score (1, 3, 5) from item. + Returns 1 as default if invalid. + """ + if isinstance(item, dict): + v = item.get("score") + else: + v = item + if isinstance(v, (int, float)): + v = int(v) + if v in (1, 3, 5): + return v + # clamp to valid range + if v <= 1: + return 1 + if v >= 5: + return 5 + return 3 + return 1 + + def get_note(item: Any) -> str: if isinstance(item, dict): note = item.get("note", "") diff --git a/tutorial/example_deep_finance/judge/presentation_quality/prompt.py b/tutorial/example_deep_finance/judge/presentation_quality/prompt.py index 0abaf1ee..5e945bf3 100644 --- a/tutorial/example_deep_finance/judge/presentation_quality/prompt.py +++ b/tutorial/example_deep_finance/judge/presentation_quality/prompt.py @@ -1,75 +1,89 @@ # 8项呈现质量检查:A(3)+B(3)+C(2)=8 QUALITY_SYSTEM_PROMPT = """ -你是一位“呈现质量评审官”。你只评估报告的**呈现与表达质量 (Presentation & Editorial Quality)**,用于奖励信号。 -严禁评估:事实真伪/引用支持(Grounding 负责)、内容覆盖广度(Breadth 负责)、分析深度与洞察(Depth 负责)、观点是否正确。 -核心关注:**可扫描性**、**信息结构化**、**逻辑链条的可视化呈现**、**表达清晰与可用性**。 +你是一位“深度研究报告呈现评审官”。你的任务是评估报告的 **用户体验与信息架构 (Presentation & UX)**,为强化学习提供奖励信号。 + +**严禁评估**:事实真伪、引用准确性(由 Grounding 模型负责)、内容广度与深度。 +**核心关注**:**认知负荷管理**、**信息的可扫读性**、**逻辑的可视化**、**Markdown 渲染质量**。 ======================== -评分标准(仅判定 pass=true/false) +评分标准 (1/3/5 分制) ======================== -对以下 8 个检查项分别给出 pass/fail,并给一句 note(≤25字,需指出“位置或症状”,避免空泛)。 - -A) Scan & Navigation(可扫描性) -A1 结论先行(Key Takeaways Top) -- Pass:开头可见“摘要/要点/核心结论”块(短段或列表均可),读者无需通读即可抓到主结论。 -- Fail:开头直接进入细节/材料堆叠,无概括性要点。 - -A2 结构导航(Navigable Structure) -- Pass:正文有清晰分节(标题层级或明显分段),读者能快速定位主要部分(分析/风险/结论等)。 -- Fail:无结构或结构混乱,像长篇流水账,难以导航。 - -A3 视觉重点(Visual Hierarchy) -- Pass:重点信息对“扫读友好”(要点化/短句分行/适度强调等),且重点承载信息而非装饰。 -- Fail:全文平铺直叙;或存在明显“格式堆砌”但不增信息。 - -B) Information Structuring(信息结构化) -B1 密集信息解构(Dense Info Structured) -- Pass:数字/多条件/多点信息密集处被列表/分组/表格等拆解,易读易取。 -- Fail:关键数据淹没在长难句或长段落(典型:数字长句串联)。 - -B2 对比对齐(Comparisons Aligned) -- Pass:涉及横向对比(A vs B/同行对比/情景对比)时,用表格或对齐结构呈现,使维度一眼可比(不强制表格)。 -- Fail:对比点散落在不同段落,维度不对齐,无法直观对照。 - -B3 一致性(Consistency) -- Pass:单位/口径/标点/小标题/列表风格整体统一,专业感稳定。 -- Fail:格式与表述明显混乱,增加阅读负担。 - -C) Editorial Clarity(编辑清晰度) -C1 论证链可视化(Argument Chain Presented) -- Pass:在呈现上能跟随“主张→依据→解释→影响/结论”的链条(例如用分段或 bullet 串联/对齐呈现),不是只堆材料。 -- Fail:大量材料堆砌,但缺少可视化的逻辑线索(读者难跟随)。 - -C2 风险与行动(Risk & Actionability Clear) -- Pass:以清晰形式列出风险/边界/不确定性,并给出可执行的下一步关注点(只看表达是否清楚存在,不评全面与正确)。 -- Fail:未提及风险/边界/下一步,或表述极度含糊不可操作。 - -反刷分原则(必须执行): -- 空标题占位、空表格/无意义表格、重复 bullet 但不增加信息 → 相关项直接判 fail,并在 note 标注“形式堆砌”。 +对以下 8 个维度进行打分。 +- **1分 (Fail)**:严重阻碍阅读,格式混乱或缺失。 +- **3分 (Pass)**:甚至及格,有基本结构,但平庸、啰嗦或不够直观。 +- **5分 (Excellent)**:出版级质量,结构极佳,一眼能抓取核心,降低了读者的认知成本。 + +请针对每个子项给出分数(1, 3, 5)及 Note(≤25字,指出具体位置或症状)。 + +### A) Scan & Navigation(可扫描性) +**A1 结论先行 (Key Takeaways Top)** +- 5分:开头有独立的“核心摘要/TL;DR”块,且要点清晰,读者无需滚动即可获取主结论。 +- 3分:有摘要,但写成了流水账段落,或混杂在正文中不够醒目。 +- 1分:无摘要,开篇即陷入细节或背景介绍。 + +**A2 结构导航 (Navigable Structure)** +- 5分:层级分明 (H1/H2/H3),长文有清晰的“路标”(小标题),支持快速跳读定位。 +- 3分:有分节,但段落过长(Wall of text),缺乏内部视觉引导。 +- 1分:结构混乱,标题层级错误或缺失,难以导航。 + +**A3 视觉重点 (Visual Hierarchy)** +- 5分:利用 **加粗**、*斜体* 或 `代码块` 精准强调核心洞察,信噪比高。 +- 3分:有强调,但过度使用(满篇加粗)或重点不突出(强调了无关词)。 +- 1分:全文平铺直叙,无任何视觉重点。 + +### B) Information Structuring(信息结构化) +**B1 密集信息解构 (Dense Info Structured)** +- 5分:复杂数据/多条件逻辑被转化为 Markdown **表格** 或 **嵌套列表**,一目了然。 +- 3分:使用了列表,但内容仍是长难句堆砌,未真正拆解信息。 +- 1分:关键数字或复杂参数淹没在长段落文本中。 + +**B2 对比对齐 (Comparisons Aligned)** +- 5分:涉及对比(方案A vs B / 历史 vs 现状)时,使用表格或对齐结构,维度横向可比。 +- 3分:有对比意图,但分散在不同段落,读者需来回对照。 +- 1分:对比维度混乱或缺失,无法直观比较。 + +**B3 一致性与渲染 (Consistency & Rendering)** +- 5分:格式统一(符号/单位),Markdown 渲染完美(表格无断裂、公式无乱码)。 +- 3分:存在少量格式不统一,或轻微的渲染瑕疵但不影响理解。 +- 1分:表格错位、公式未闭合、列表层级混乱,严重影响阅读。 + +### C) Editorial Clarity(编辑清晰度) +**C1 论证链可视化 (Argument Chain Presented)** +- 5分:逻辑链条可视(如使用 `主张 -> 证据 -> 结论` 的结构),引用锚点清晰 `[1]`。 +- 3分:逻辑存在,但淹没在文字中,缺乏连接词或视觉引导。 +- 1分:材料堆砌,缺乏清晰的推导线索。 + +**C2 风险与行动 (Risk & Actionability Clear)** +- 5分:独立板块清晰列出“风险/局限性”及“下一步建议”,具有极高的可操作性。 +- 3分:提到了风险或建议,但含糊其辞,或混杂在结论中。 +- 1分:完全未提及风险边界或下一步行动。 + +**反刷分原则 (Anti-Gaming)**: +- 空表格、无意义的重复列表、为了格式而格式(如把一句简单的话硬拆成列表) -> 直接判 **1分**,Note 标注“过度格式化”。 ======================== -输出要求(Strict JSON) +输出要求 (Strict JSON) ======================== -必须输出可解析 JSON;pass 必须为 boolean。 -不要输出 Markdown;不要添加额外字段;不得省略字段。 +必须输出可解析 JSON。 +**注意**:为了提供梯度信号,字段由 `pass` 改为 `score`,值必须为 1, 3, or 5。 -JSON 模板(字段必须齐全): +JSON 模板: { "scan": { - "A1_key_takeaways_top": {"pass": true, "note": "≤25字定位理由"}, - "A2_navigable_structure": {"pass": true, "note": "≤25字定位理由"}, - "A3_visual_hierarchy": {"pass": true, "note": "≤25字定位理由"} + "A1_key_takeaways_top": {"score": 0, "note": "≤25字定位理由"}, + "A2_navigable_structure": {"score": 0, "note": "≤25字定位理由"}, + "A3_visual_hierarchy": {"score": 0, "note": "≤25字定位理由"} }, "structuring": { - "B1_dense_info_structured": {"pass": false, "note": "≤25字定位理由"}, - "B2_comparisons_aligned": {"pass": true, "note": "≤25字定位理由"}, - "B3_consistency": {"pass": true, "note": "≤25字定位理由"} + "B1_dense_info_structured": {"score": 0, "note": "≤25字定位理由"}, + "B2_comparisons_aligned": {"score": 0, "note": "≤25字定位理由"}, + "B3_consistency": {"score": 0, "note": "≤25字定位理由"} }, "editorial": { - "C1_argument_chain_presented": {"pass": false, "note": "≤25字定位理由"}, - "C2_risk_and_actionability_clear": {"pass": true, "note": "≤25字定位理由"} + "C1_argument_chain_presented": {"score": 0, "note": "≤25字定位理由"}, + "C2_risk_and_actionability_clear": {"score": 0, "note": "≤25字定位理由"} }, - "top_fixes": ["最多3条,仅谈呈现层面改进"] + "top_fixes": ["最多3条,仅谈呈现层面改进,针对最低分项"] } """ @@ -77,10 +91,10 @@ 请审计以下研究报告的【呈现质量】(只谈呈现/排版/结构,不谈事实对错/引用支持/覆盖/深度)。 ### User Query -{{user_query}} +{user_query} ### AI Report -{{report_content}} +{report_content} ----- 请严格按 System Prompt 的锚点输出 JSON;不要输出 Markdown;不要添加额外字段。 diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index ed82cb25..8b5426fa 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -1,6 +1,6 @@ # ------------------ 主要配置 ------------------ ajet: - project_name: ajet_deep_finance + project_name: "{{PREFIX}}" experiment_name: "{{SUFFIX}}" # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) judge: @@ -11,6 +11,7 @@ ajet: val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 # OpenJudge 权重配置 presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 + grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) From 4538f5a6f18ce9b7a7b28345df8fa3fe0ec791ef Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 27 Jan 2026 19:25:02 +0800 Subject: [PATCH 44/56] chore(config): update experiment suffix, prefix and reward weights in deep_finance.sh --- tutorial/example_deep_finance/deep_finance.sh | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index d9f624a2..f8ce664a 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -3,8 +3,8 @@ set -e #=============================================================================== # 1. 配置区域 - 用户只需修改这里 #=============================================================================== -SUFFIX="deep_finance" # 实验后缀,影响所有日志和实验名称 -PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 +SUFFIX="newjudge" # 实验后缀,影响所有日志和实验名称 +PREFIX="ajet_newjudge" # 实验前缀,影响日志和实验所在文件夹 # OpenJudge 模型配置 OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 @@ -12,10 +12,9 @@ RM_LLM='qwen-max' # RM Gallery 评分模型 JUDGE_CONCURRENCY=10 # 奖励权重配置 -RM_WEIGHT=0.4 -CITATION_AUDIT_WEIGHT=0.2 -REPORT_RESOLUTION_WEIGHT=0.2 -TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 +RM_WEIGHT=0.5 +PRESENTATION_QUALITY_WEIGHT=0.25 +GROUNDING_WEIGHT=0.25 # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -24,7 +23,7 @@ NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 # 主目录 -export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" NNODES=${WORLD_SIZE} @@ -56,12 +55,11 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ -e "s|{{NNODES}}|${NNODES}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ - -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ + -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ @@ -73,7 +71,7 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -echo "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" +echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" #=============================================================================== # 3. 环境配置 @@ -115,7 +113,7 @@ LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - +env_log_prefix="${SUFFIX}__${CURRENT_TIME}" # 多机训练参数配置 GPUS_PER_NODE=8 EXPECTED_WORKERS=$WORLD_SIZE @@ -208,7 +206,7 @@ if [[ $HOSTNAME == *"-master-"* ]]; then --with-deepfinance \ --conf ${CONFIG_FILE} \ --backbone="verl" \ - --prefix=${SUFFIX} \ + --prefix=${env_log_prefix} \ 2>&1 | tee ${TRAIN_LOG} From 818a4f7b8cad4ee9fa6e90b8ca2b711e9dc187cf Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 27 Jan 2026 19:29:30 +0800 Subject: [PATCH 45/56] fix(deep_finance): update environment variables and training launch options --- tutorial/example_deep_finance/deep_finance.sh | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index 2328adc1..240c315b 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -1,5 +1,5 @@ #!/bin/bash -set -e +set -e #=============================================================================== # 1. 配置区域 - 用户只需修改这里 #=============================================================================== @@ -22,7 +22,7 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 -# 主目录 (按需更改) +# 主目录(需要更改) export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" NNODES=${WORLD_SIZE} @@ -105,7 +105,7 @@ export DEEPFINANCE_MCP_CONFIG DEEPFINANCE_TOOL_RESULT_MAX_CHARS # 其他服务配置 HF_ENDPOINT="https://hf-mirror.com" ES_HOSTS="http://11.160.132.46:8200" -export HF_ENDPOINT ES_HOSTS +export HF_ENDPOINT ES_HOSTS # log 文件位置 CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") @@ -155,6 +155,8 @@ export NCCL_ASYNC_ERROR_HANDLING=1 export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" export RAY_CLUSTER_MODE="multi_node" +export DEEPFINANCE_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 +export DEEPFINANCE_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && DEEPFINANCE_TOOL_RESULT_MAX_CHARS=${DEEPFINANCE_TOOL_RESULT_MAX_CHARS} DEEPFINANCE_MCP_CONFIG=${DEEPFINANCE_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" #=============================================================================== @@ -201,11 +203,12 @@ if [[ $HOSTNAME == *"-master-"* ]]; then # 启动训练任务(最核心) python ajet/launcher.py \ + --with-deepfinance \ --conf ${CONFIG_FILE} \ --backbone="verl" \ --prefix=${env_log_prefix} \ 2>&1 | tee ${TRAIN_LOG} - + #=============================================================================== # 6.2 Worker 节点启动流程 @@ -217,4 +220,4 @@ else ray stop || true ray start --address $MASTER_ADDR:6379 --num-gpus 8 while true; do sleep 60; done -fi +fi \ No newline at end of file From 1bb7f6097dfdbb31b69ecea78094f3b814b45c96 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 27 Jan 2026 19:30:30 +0800 Subject: [PATCH 46/56] chore(config): parameterize deep finance training configuration --- .../example_deep_finance/deep_finance.yaml | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml index 15dd5665..33103fe3 100644 --- a/tutorial/example_deep_finance/deep_finance.yaml +++ b/tutorial/example_deep_finance/deep_finance.yaml @@ -1,19 +1,18 @@ # ------------------ 主要配置 ------------------ ajet: - project_name: ajet_deep_finance - experiment_name: "ajet_deep_finance" + project_name: "{{PREFIX}}" + experiment_name: "{{SUFFIX}}" # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) judge: - openjudge_llm: qwen-flash # OpenJudge 模型 - rm_llm: qwen-max # RM Gallery 模型 - concurrency: 10 # Judge 并发数 + openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 + rm_llm: {{RM_LLM}} # RM Gallery 模型 + concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 # OpenJudge 权重配置 - report_resolution_weight: 0.2 # 报告质量评估 - trajectory_faithfulness_weight: 0.2 # 事实准确性评估 - citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) - rm_weight: 0.4 # RM Gallery 权重 + presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 + grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 + rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge @@ -21,7 +20,7 @@ ajet: # ✨✨✨✨ 设置待训练的模型 path: {{MODEL_PATH}} trainer_common: - nnodes: 8 + nnodes: {{NNODES}} n_gpus_per_node: 8 val_before_train: True val_pass_n: 8 @@ -32,10 +31,10 @@ ajet: rollout: # ✨✨✨✨ 编写并选择Agent user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol - force_disable_toolcalls: True + force_disable_toolcalls: False enable_oversample: False tensor_model_parallel_size: 8 - num_repeat: 4 + num_repeat: {{NUM_REPEAT}} max_env_worker: 64 # 增加环境并行数 max_num_seqs: 64 # 增加VLLM并发序列数 max_response_length_in_one_turn: 8000 @@ -43,14 +42,14 @@ ajet: agent_madness_reward: 0.0 compute_madness_checklist: None multi_turn: - max_steps: 6 + max_steps: {{NUM_STEPS}} interchange_server: interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) debug: debug_max_parallel: 1 # 增加并行任务数,充分利用GPU debug_first_n_tasks: 100 # 增加处理的任务数 data: - train_batch_size: 32 + train_batch_size: {{TRAIN_BATCH_SIZE}} max_prompt_length: 8000 max_response_length: 41000 @@ -58,18 +57,16 @@ ajet: type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service deep_finance: training: - file_path: {{TRAIN_PATH}} + file_path: {{TRAIN_DATA_PATH}} validation: - file_path: {{VAL_PATH}} + file_path: {{VAL_DATA_PATH}} # env_service 仍需配置(用于工具调用) env_service: env_type: "finworld" env_url: {{ENV_SERVICE_URL}} env_action_preference: code - - trainer: - default_local_dir: {{CKPT_SAVE_PATH}} + default_local_dir: "{{CKPT_SAVE_PATH}}/{{PREFIX}}/{{SUFFIX}}" # resume_mode: disable # 禁用自动恢复,从头开始训练 actor_rollout_ref: rollout: From 460318f462e46ceb362e522163bc1404445b3080 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 27 Jan 2026 19:31:36 +0800 Subject: [PATCH 47/56] chore(config): update experiment suffix, prefix, and weight parameters --- .../deep_finance_single.sh | 68 +++---------------- 1 file changed, 10 insertions(+), 58 deletions(-) diff --git a/tutorial/example_deep_finance/deep_finance_single.sh b/tutorial/example_deep_finance/deep_finance_single.sh index 6b27d0f3..e794dff0 100644 --- a/tutorial/example_deep_finance/deep_finance_single.sh +++ b/tutorial/example_deep_finance/deep_finance_single.sh @@ -3,8 +3,8 @@ set -e #=============================================================================== # 1. 配置区域 - 用户只需修改这里 #=============================================================================== -SUFFIX="ajet_deep_finance" # 实验后缀,影响所有日志和实验名称 -PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 +SUFFIX="newjudge" # 实验后缀,影响所有日志和实验名称 +PREFIX="ajet_newjudge" # 实验前缀,影响日志和实验所在文件夹 # OpenJudge 模型配置 OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 @@ -12,10 +12,9 @@ RM_LLM='qwen-max' # RM Gallery 评分模型 JUDGE_CONCURRENCY=10 # 奖励权重配置 -RM_WEIGHT=0.4 -CITATION_AUDIT_WEIGHT=0.2 -REPORT_RESOLUTION_WEIGHT=0.2 -TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 +RM_WEIGHT=0.5 +PRESENTATION_QUALITY_WEIGHT=0.25 +GROUNDING_WEIGHT=0.25 # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -23,7 +22,8 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 -# 主目录 +# 主目录(需要更改) +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" NNODES=${WORLD_SIZE} @@ -55,70 +55,23 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ -e "s|{{NNODES}}|${NNODES}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ - -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ + -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ - -e "s|{{ENV_SERVICE_URL}}|${ENV_SERVICE_URL}|g" \ -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ -e "s|{{CKPT_SAVE_PATH}}|${CKPT_SAVE_PATH}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -echo "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" - -#=============================================================================== -# 3. 环境配置 -#=============================================================================== -# MongoDB 缓存配置 -CACHE_TYPE="mongodb" -MONGO_URI="mongodb://${ADDR}:27117/" -MONGO_DB_NAME="finworld_cache" -MONGO_COLLECTION_NAME="tool_cache" -export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME - -# DeepFinance MCP 配置 -DEEPFINANCE_MCP_CONFIG="${AJET_ROOT}/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json" - -# 动态生成 MCP 配置文件 -mkdir -p $(dirname ${DEEPFINANCE_MCP_CONFIG}) -cat > ${DEEPFINANCE_MCP_CONFIG} << EOF -{ - "mcpServers": { - "flowllm": { - "transport": "sse", - "url": "http://${ADDR}:${MCP_PORT}/sse", - "timeout": 600, - "sse_read_timeout": 1200 - } - } -} -EOF -export DEEPFINANCE_MCP_CONFIG DEEPFINANCE_TOOL_RESULT_MAX_CHARS - -# 其他服务配置 -HF_ENDPOINT="https://hf-mirror.com" -ES_HOSTS="http://11.160.132.46:8200" -export HF_ENDPOINT ES_HOSTS - -# log 文件位置 -CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") -LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" -MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" -ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" -TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - -# 多机训练参数配置 -GPUS_PER_NODE=8 -EXPECTED_WORKERS=$WORLD_SIZE +echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" #=============================================================================== @@ -162,7 +115,6 @@ export RAY_CLUSTER_MODE="multi_node" #=============================================================================== # 6. 主流程 #=============================================================================== -log "开始多机多卡训练: ${SUFFIX}" log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" mkdir -p ${LOG_DIR} mkdir -p $(dirname ${CONFIG_FILE}) From 57a3a544e5ea94308e263523d8cedc8bc37d2546 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 27 Jan 2026 19:33:53 +0800 Subject: [PATCH 48/56] fix(example_deep_finance): update dynamic config file generation path --- tutorial/example_deep_finance/deep_finance.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index 240c315b..bee02ac2 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -46,7 +46,7 @@ fi # 2. 动态生成配置文件 (从yaml template生成yaml) #=============================================================================== # 修改:配置文件生成路径,现在动态生成到 yaml 目录下 -CONFIG_TEMPLATE="tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml" +CONFIG_TEMPLATE="tutorial/example_deep_finance/deep_finance.yaml" CONFIG_FILE="${AJET_ROOT}/tutorial/example_deep_finance/yaml/${SUFFIX}.yaml" mkdir -p $(dirname ${CONFIG_FILE}) From beaa54041847c578f7617df6c911115640386df8 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 27 Jan 2026 20:29:55 +0800 Subject: [PATCH 49/56] refactor(judge): remove deprecated presentation quality script --- .../judge/scripts/run_presentation_quality.py | 44 ------------------- 1 file changed, 44 deletions(-) delete mode 100644 tutorial/example_deep_finance/judge/scripts/run_presentation_quality.py diff --git a/tutorial/example_deep_finance/judge/scripts/run_presentation_quality.py b/tutorial/example_deep_finance/judge/scripts/run_presentation_quality.py deleted file mode 100644 index 840076ed..00000000 --- a/tutorial/example_deep_finance/judge/scripts/run_presentation_quality.py +++ /dev/null @@ -1,44 +0,0 @@ -import asyncio -import sys -import os - -# 添加项目根目录到 Python 路径 -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")) -sys.path.insert(0, PROJECT_ROOT) -print(f"PROJECT_ROOT: {PROJECT_ROOT}") - -from openjudge.models import OpenAIChatModel -from tutorial.example_deep_finance.judge import PresentationQualityGrader - - -async def main(): - # 你也可以只写:model = OpenAIChatModel(model="qwen3-32b") - # 并用环境变量 OPENAI_API_KEY / OPENAI_BASE_URL(QuickStart里推荐这种方式) - model = OpenAIChatModel( - model="qwen-flash", - extra_body={"enable_thinking": False, "temperature": 0, "top_p": 1, "seed": 0}, - ) - - grader = PresentationQualityGrader(model=model) - - report = """ - # 藏格矿业分析报告 - - ## 执行摘要 - - 核心结论:... - - ## 财务对比 - | 公司 | 营收 | 净利 | - |---|---:|---:| - | A | 20 | 5 | - - ## 风险与下一步 - - 风险:... - - 下一步:... - """ - res = await grader.aevaluate(report_content=report, user_query="分析藏格矿业的财务状况") - print(res) - - -if __name__ == "__main__": - asyncio.run(main()) From 13d7d823ed99e3ee07c423af182b95ef2d0f6020 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Wed, 28 Jan 2026 20:19:20 +0800 Subject: [PATCH 50/56] chore(config): add LFS patterns and Env Service URL configuration - Added various file extensions to .gitattributes for Git LFS tracking - Added dataset_gsm8k/ to .gitignore to exclude dataset files - Introduced ENV_SERVICE_URL variable in deep_finance.sh and deep_finance_single.sh - Updated configuration file generation to include ENV_SERVICE_URL substitution - Commented out invalid reference penalty calculation in grounding grader logic --- .gitattributes | 38 +++++++++++++++++++ .gitignore | 1 + tutorial/example_deep_finance/deep_finance.sh | 4 ++ .../deep_finance_single.sh | 4 ++ .../judge/grounding/grader.py | 5 ++- 5 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..227449ad --- /dev/null +++ b/.gitattributes @@ -0,0 +1,38 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +# Audio files - uncompressed +*.pcm filter=lfs diff=lfs merge=lfs -text +*.sam filter=lfs diff=lfs merge=lfs -text +*.raw filter=lfs diff=lfs merge=lfs -text +# Audio files - compressed +*.aac filter=lfs diff=lfs merge=lfs -text +*.flac filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.ogg filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index 1698ba2e..35397baa 100644 --- a/.gitignore +++ b/.gitignore @@ -160,6 +160,7 @@ tutorial/example_deep_finance/scripts/* flash_attn-2.8.*.whl tutorial/example_deep_finance/prepare_data/* tutorial/example_deep_finance/judge/analytical_sufficiency/* +dataset_gsm8k/* .dockerignore benchmark_datasets diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index bee02ac2..de1aa061 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -22,6 +22,9 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 +# Env Service URL 配置 +ENV_SERVICE_URL="http://127.0.0.1:8080" # 环境服务地址 + # 主目录(需要更改) export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" @@ -68,6 +71,7 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ -e "s|{{CKPT_SAVE_PATH}}|${CKPT_SAVE_PATH}|g" \ + -e "s|{{ENV_SERVICE_URL}}|${ENV_SERVICE_URL}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" diff --git a/tutorial/example_deep_finance/deep_finance_single.sh b/tutorial/example_deep_finance/deep_finance_single.sh index e794dff0..cc5c8a00 100644 --- a/tutorial/example_deep_finance/deep_finance_single.sh +++ b/tutorial/example_deep_finance/deep_finance_single.sh @@ -22,6 +22,9 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 +# Env Service URL 配置 +ENV_SERVICE_URL="http://127.0.0.1:8080" # 环境服务地址 + # 主目录(需要更改) export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" @@ -68,6 +71,7 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ -e "s|{{CKPT_SAVE_PATH}}|${CKPT_SAVE_PATH}|g" \ + -e "s|{{ENV_SERVICE_URL}}|${ENV_SERVICE_URL}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" diff --git a/tutorial/example_deep_finance/judge/grounding/grader.py b/tutorial/example_deep_finance/judge/grounding/grader.py index 599ccc9c..42f5d141 100644 --- a/tutorial/example_deep_finance/judge/grounding/grader.py +++ b/tutorial/example_deep_finance/judge/grounding/grader.py @@ -195,11 +195,12 @@ def _compute_scores(self, obj: Dict[str, Any]) -> Tuple[float, str]: # 轻量惩罚:存在 invalid refs 会降低 reward # 每个 invalid 号扣 0.1,最多扣 0.5 - invalid_penalty = min(0.1 * invalid_ref_count, 0.5) + # invalid_penalty = min(0.1 * invalid_ref_count, 0.5) + invalid_penalty = 0 # final_reward: 综合分数(权重 0.5:0.5),再叠加 invalid 惩罚 final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score - final_reward = max(0.0, final_reward - invalid_penalty) + # final_reward = max(0.0, final_reward - invalid_penalty) # 构建 reason good_citations = obj.get('good_citations', []) From eb6e2aff53dc50a5052a5f63acd2f4a9f25ed7ae Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 3 Feb 2026 10:39:48 +0800 Subject: [PATCH 51/56] =?UTF-8?q?feat(deep=5Ffinance):=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E5=BC=95=E7=94=A8=E5=AE=A1=E8=AE=A1=E3=80=81CGCV=20?= =?UTF-8?q?=E5=92=8C=E5=8F=AF=E8=BF=BD=E6=BA=AF=E6=80=A7=E8=AF=84=E5=88=86?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 AuditGrader 引用逻辑审计模块,实现对引用合规性的严格验证 - 新增 CGCVGrader,支持引用锚定断言的详细验证及评分 - 新增 TraceabilityRewardGrader,实现报告断言的证据锚点可追溯性检查 - 在 DeepFinanceJudgeByOpenJudge 中集成以上三种新评分器,并配置对应权重 - 扩展 reward_metric_helper,增加 'audit', 'traceability', 'cgcv' 等评分项 - 更新依赖和导入,支持通过 judge 包直接访问三种评分器 - 新增相关 JSON 工具和 Prompt 模板以支持评分器的准确评估与结果解析 - .gitignore 增加 example_deep_finance/output_report 路径忽略规则 --- .gitignore | 1 + .../metric_helper/reward_metric_helper.py | 5 +- .../deep_finance_judge.py | 22 +- .../example_deep_finance/judge/__init__.py | 5 +- .../judge/audit/__init__.py | 4 + .../judge/audit/grader.py | 215 +++++++ .../judge/audit/json_utils.py | 176 ++++++ .../judge/audit/prompt.py | 67 ++ .../judge/cgcv/__init__.py | 7 + .../example_deep_finance/judge/cgcv/grader.py | 362 +++++++++++ .../judge/cgcv/json_utils.py | 578 ++++++++++++++++++ .../example_deep_finance/judge/cgcv/prompt.py | 378 ++++++++++++ .../judge/grounding/prompt.py | 221 ++++--- .../judge/traceability/__init__.py | 7 + .../judge/traceability/grader.py | 137 +++++ .../judge/traceability/json_utils.py | 374 ++++++++++++ .../judge/traceability/prompt.py | 122 ++++ .../yaml_template/deep_finance_template.yaml | 3 + .../yaml_template/infer.yaml | 87 +++ 19 files changed, 2681 insertions(+), 90 deletions(-) create mode 100644 tutorial/example_deep_finance/judge/audit/__init__.py create mode 100644 tutorial/example_deep_finance/judge/audit/grader.py create mode 100644 tutorial/example_deep_finance/judge/audit/json_utils.py create mode 100644 tutorial/example_deep_finance/judge/audit/prompt.py create mode 100644 tutorial/example_deep_finance/judge/cgcv/__init__.py create mode 100644 tutorial/example_deep_finance/judge/cgcv/grader.py create mode 100644 tutorial/example_deep_finance/judge/cgcv/json_utils.py create mode 100644 tutorial/example_deep_finance/judge/cgcv/prompt.py create mode 100644 tutorial/example_deep_finance/judge/traceability/__init__.py create mode 100644 tutorial/example_deep_finance/judge/traceability/grader.py create mode 100644 tutorial/example_deep_finance/judge/traceability/json_utils.py create mode 100644 tutorial/example_deep_finance/judge/traceability/prompt.py create mode 100644 tutorial/example_deep_finance/yaml_template/infer.yaml diff --git a/.gitignore b/.gitignore index 35397baa..7d7c0a4a 100644 --- a/.gitignore +++ b/.gitignore @@ -160,6 +160,7 @@ tutorial/example_deep_finance/scripts/* flash_attn-2.8.*.whl tutorial/example_deep_finance/prepare_data/* tutorial/example_deep_finance/judge/analytical_sufficiency/* +tutorial/example_deep_finance/output_report/* dataset_gsm8k/* .dockerignore diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index ea951d5a..685798f3 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -83,7 +83,10 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str openjudge_graders = [ "presentation_quality", "grounding", - "planning" + "planning", + "audit", + "traceability", + "cgcv" ] for grader_name in openjudge_graders: diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index 03f10130..b4c9c96f 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -15,7 +15,7 @@ from openjudge.models.openai_chat_model import OpenAIChatModel from openjudge.runner.grading_runner import GraderConfig, GradingRunner -from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader +from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader, CGCVGrader, AuditGrader, TraceabilityRewardGrader @@ -103,7 +103,10 @@ def _setup_weights(self): self.w = { "rm": getattr(cfg, "rm_weight", 1.0) if cfg else 1.0, # RM Gallery 权重 "presentation_quality": getattr(cfg, "presentation_quality_weight", 0.25) if cfg else 0.25, - "grounding": getattr(cfg, "grounding_weight", 0.25) if cfg else 0.25, + "grounding": getattr(cfg, "grounding_weight", 0.0) if cfg else 0.0, # 引用规范性评估 + "cgcv": getattr(cfg, "cgcv_weight", 0.25) if cfg else 0.25, # Citation-Grounded Claim Verification + "audit": getattr(cfg, "audit_weight", 0.0) if cfg else 0.0, # 引用逻辑审计 + "traceability": getattr(cfg, "traceability_weight", 0.0) if cfg else 0.0, # 可追溯性/可核验性审计 (TVR) } # 归一化(注意:action_loop 是惩罚项,不参与归一化;rm 需要参与归一化) @@ -256,6 +259,21 @@ def extract_report_content(data: Dict) -> str: grader=GroundingGrader(model=model), mapper=lambda data: {"traj": data}, ), + # CGCV: Citation-Grounded Claim Verification - 引用锤定的断言验证 + "cgcv": GraderConfig( + grader=CGCVGrader(model=model), + mapper=lambda data: {"traj": data}, + ), + # Audit: 引用逻辑审计 - 验证引用是否严格符合逻辑蕴含原则 + "audit": GraderConfig( + grader=AuditGrader(model=model), + mapper=lambda data: {"traj": data}, + ), + # Traceability: 可追溯性/可核验性审计 - 验证报告断言是否有证据锚点支撑 + "traceability": GraderConfig( + grader=TraceabilityRewardGrader(model=model), + mapper=lambda data: {"traj": data}, + ), } def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> Tuple[float, bool]: diff --git a/tutorial/example_deep_finance/judge/__init__.py b/tutorial/example_deep_finance/judge/__init__.py index 75c8ceff..c2aee2be 100644 --- a/tutorial/example_deep_finance/judge/__init__.py +++ b/tutorial/example_deep_finance/judge/__init__.py @@ -1,6 +1,9 @@ # 使得可以通过 from judge import PresentationQualityGrader 直接引用 from .grounding.grader import GroundingGrader from .presentation_quality.grader import PresentationQualityGrader +from .cgcv.grader import CGCVGrader +from .audit.grader import AuditGrader +from .traceability.grader import TraceabilityRewardGrader # from .research_depth.grader import ResearchDepthGrader # from .research_breadth.grader import ResearchBreadthGrader @@ -8,4 +11,4 @@ # from .grounding.grader import GroundingGrader # from .research_breadth.grader import ResearchBreadthGrader # __all__ = ["PresentationQualityGrader", "GroundingGrader", "ResearchDepthGrader", "ResearchBreadthGrader"] -__all__ = ["PresentationQualityGrader", "GroundingGrader"] +__all__ = ["PresentationQualityGrader", "GroundingGrader", "CGCVGrader", "AuditGrader", "TraceabilityRewardGrader"] diff --git a/tutorial/example_deep_finance/judge/audit/__init__.py b/tutorial/example_deep_finance/judge/audit/__init__.py new file mode 100644 index 00000000..7e4d05c3 --- /dev/null +++ b/tutorial/example_deep_finance/judge/audit/__init__.py @@ -0,0 +1,4 @@ +"""Grounding Grader - 引用逻辑审计""" +from .grader import AuditGrader + +__all__ = ["AuditGrader"] \ No newline at end of file diff --git a/tutorial/example_deep_finance/judge/audit/grader.py b/tutorial/example_deep_finance/judge/audit/grader.py new file mode 100644 index 00000000..18c0e397 --- /dev/null +++ b/tutorial/example_deep_finance/judge/audit/grader.py @@ -0,0 +1,215 @@ +"""Audit Grader - 引用逻辑审计 (OpenJudge logic version)""" +from __future__ import annotations + +import os +from typing import Any, Dict, List, Tuple + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +try: + from openjudge.models import OpenAIChatModel +except Exception: + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import CITATION_INTEGRITY_PROMPT_COT, CITATION_INTEGRITY_USER_TEMPLATE +from .json_utils import strict_load_json, validate_integrity_shape, construct_reward_prompt + + +class AuditGrader(BaseGrader): + """ + 引用逻辑审计 Grader + + - 输入:traj (完整对话轨迹) + - 输出:GraderScore(score, reason) + - score: integrity_score (Supported / Total) + - reason: 审计摘要,包括错误分布和定性总结 + """ + + def __init__( + self, + model: OpenAIChatModel, + name: str = "citation_integrity", + **kwargs: Any, + ): + super().__init__(name=name, **kwargs) + self.model = model + + @staticmethod + def create_default_model( + model_name: str, + api_key: str | None = None, + base_url: str | None = None, + deterministic: bool = True, + enable_thinking: bool = False, + seed: int = 42, + ) -> OpenAIChatModel: + api_key = api_key or os.getenv("OPENAI_API_KEY") + base_url = base_url or os.getenv("OPENAI_BASE_URL") + + extra_body: Dict[str, Any] = {} + if deterministic: + extra_body.update( + { + "temperature": 0.0, + "top_p": 1.0, + "seed": seed, + } + ) + if enable_thinking is False: + extra_body["enable_thinking"] = False + + kwargs: Dict[str, Any] = {"model": model_name} + if api_key: + kwargs["api_key"] = api_key + if base_url: + kwargs["base_url"] = base_url + if extra_body: + kwargs["extra_body"] = extra_body + + return OpenAIChatModel(**kwargs) + + async def aevaluate( + self, + traj: Any, + **_: Any, + ) -> GraderScore: + """ + 入口:必须喂 traj(完整对话轨迹) + + Args: + traj: 对话轨迹,支持以下格式: + - [{"role": ..., "content": ...}, ...] 直接消息列表 + - {"messages": [...]} 包含 messages 字段的 dict + - {"traj": [[...]]} 包含 traj 字段的 dict(双重嵌套) + + Returns: + GraderScore(name, score, reason) + """ + # 1. 提取 messages(兼容多种格式) + if isinstance(traj, dict): + if "traj" in traj: + # 支持 {"traj": [[...]]} 格式 + traj_list = traj["traj"] + if traj_list and isinstance(traj_list[0], list): + messages_list = traj_list[0] + else: + messages_list = traj_list + else: + messages_list = traj.get("messages", []) + elif isinstance(traj, list): + messages_list = traj + else: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: traj must be list or dict with 'messages'/'traj'", + ) + + if not messages_list: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty trajectory", + ) + + # 2. 构建 Prompt + # 使用新的 System Prompt 和 User Template + user_prompt = construct_reward_prompt(messages_list, CITATION_INTEGRITY_USER_TEMPLATE) + + messages = [ + {"role": "system", "content": CITATION_INTEGRITY_PROMPT_COT}, + {"role": "user", "content": user_prompt} + ] + + # 3. 模型推理 + try: + resp = await self.model.achat(messages) + raw_text = getattr(resp, "content", str(resp)) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + + # 4. JSON 解析与验证 + obj, jerr = strict_load_json(raw_text) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"ParseError: {jerr}; raw[:200]={snippet}", + ) + + # 使用新的验证逻辑 validate_integrity_shape + obj, serr = validate_integrity_shape(obj) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"SchemaError: {serr}; raw[:200]={snippet}", + ) + + # 5. 计算分数与生成理由 + score, reason = self._compute_scores(obj) + return GraderScore(name=self.name, score=score, reason=reason) + + def _compute_scores(self, obj: Dict[str, Any]) -> Tuple[float, str]: + """ + 基于 audit_trail 和 integrity_score 计算最终结果 + """ + # 直接获取模型计算的 integrity_score,若缺失则手动计算 + audit_trail = obj.get("audit_trail", []) + total_citations = len(audit_trail) + + # 统计各Verdict数量 + verdict_counts = { + "Supported": 0, + "Overstated": 0, + "Contradicted": 0, + "Hallucinated": 0, + "Irrelevant": 0 + } + + for item in audit_trail: + v = item.get("verdict", "Irrelevant") + if v in verdict_counts: + verdict_counts[v] += 1 + else: + verdict_counts["Irrelevant"] += 1 + + supported_count = verdict_counts["Supported"] + + # 优先使用模型输出的 score,如果有误则回退到手动计算 + model_score = obj.get("integrity_score") + if isinstance(model_score, (float, int)) and 0.0 <= model_score <= 1.0: + final_score = float(model_score) + else: + final_score = supported_count / total_citations if total_citations > 0 else 0.0 + + # 构建 Reason + # 格式: Score: 0.80 | Total: 10 | Supp: 8, Over: 1, Hallu: 1 | Summary: ... + stats_parts = [] + for k, v in verdict_counts.items(): + if v > 0: + stats_parts.append(f"{k[:4]}:{v}") # 缩写 Verdict + + stats_str = ", ".join(stats_parts) + qualitative = obj.get("qualitative_summary", "No summary provided.") + + # 截取主要错误示例 (如果有) + errors = [x for x in audit_trail if x.get("verdict") != "Supported"] + error_msg = "" + if errors: + first_err = errors[0] + error_msg = f" | Example Error ([{first_err.get('citation_id')}]) {first_err.get('verdict')}: {first_err.get('logic_analysis')}" + + reason = ( + f"Score: {final_score:.2f} | Total: {total_citations} | {stats_str} | " + f"Summary: {qualitative}{error_msg}" + ) + + return round(final_score, 4), reason[:1000] \ No newline at end of file diff --git a/tutorial/example_deep_finance/judge/audit/json_utils.py b/tutorial/example_deep_finance/judge/audit/json_utils.py new file mode 100644 index 00000000..394aaefc --- /dev/null +++ b/tutorial/example_deep_finance/judge/audit/json_utils.py @@ -0,0 +1,176 @@ +"""JSON Utilities for Audit Grader""" +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Tuple + +_JSON_RE = re.compile(r"\{.*\}", re.DOTALL) + +def extract_first_json_object(text: str) -> str | None: + if not text: + return None + m = _JSON_RE.search(text.strip()) + if not m: + return None + return m.group(0) + +def strict_load_json(text: str) -> Tuple[Dict[str, Any] | None, str | None]: + js = extract_first_json_object(text) + if js is None: + return None, "No JSON object found" + try: + obj = json.loads(js) + if not isinstance(obj, dict): + return None, f"Root is not dict: {type(obj)}" + return obj, None + except Exception as e: + return None, f"JSONDecodeError: {str(e)}" + +def validate_integrity_shape(obj: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]: + """ + 验证 Evidence Logic Analyst 的输出结构 + Schema: + { + "audit_trail": [ + {"citation_id": int, "verdict": str, ...}, ... + ], + "qualitative_summary": str, + "integrity_score": float + } + """ + # 1. Check Top-level fields + required_fields = ["audit_trail", "qualitative_summary", "integrity_score"] + for f in required_fields: + if f not in obj: + return None, f"Missing field: {f}" + + # 2. Validate integrity_score + try: + score = float(obj["integrity_score"]) + if not (0.0 <= score <= 1.0): + # 容错:稍微越界归一化 + score = max(0.0, min(1.0, score)) + obj["integrity_score"] = score + except ValueError: + return None, "integrity_score must be a float" + + # 3. Validate audit_trail + if not isinstance(obj["audit_trail"], list): + return None, "audit_trail must be a list" + + valid_verdicts = {"Supported", "Overstated", "Contradicted", "Hallucinated", "Irrelevant"} + + for idx, item in enumerate(obj["audit_trail"]): + if not isinstance(item, dict): + return None, f"audit_trail[{idx}] is not a dict" + + # Check required item fields + if "citation_id" not in item: + return None, f"audit_trail[{idx}] missing 'citation_id'" + if "verdict" not in item: + return None, f"audit_trail[{idx}] missing 'verdict'" + + # Normalize verdict + v = str(item["verdict"]).strip() + # 简单的大小写兼容 + v_cap = v.capitalize() + if v not in valid_verdicts and v_cap in valid_verdicts: + item["verdict"] = v_cap + elif v not in valid_verdicts: + # 如果模型输出了奇奇怪怪的verdict,降级为Irrelevant或报错,这里选择报错以保证严谨 + return None, f"Invalid verdict '{v}' in item {idx}" + + return obj, None + + +# ============================================================================= +# Trajectory Helpers +# ============================================================================= + +def _extract_text_content(content) -> str: + if content is None: return "" + if isinstance(content, str): return content + if isinstance(content, list): + # Handle OpenAI multi-part content + parts = [] + for p in content: + if isinstance(p, dict) and p.get("type") == "text": + parts.append(p.get("text", "")) + elif isinstance(p, str): + parts.append(p) + return "\n".join(parts) + return str(content) + +def _strip_think(text: str) -> str: + return re.sub(r".*?\s*", "", text, flags=re.S).strip() + +def _strip_markdown_fences(text: str) -> str: + text = text.strip() + text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE) + text = re.sub(r'\n?```\s*$', '', text) + return text.strip() + +def _extract_tool_call_json(text: str) -> str: + # 尝试提取 ```json ... ``` + m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text) + if m: return m.group(1).strip() + # 简单的 fallback + if text.strip().startswith("[") and text.strip().endswith("]"): + return text.strip() + return "" + +def construct_reward_prompt(trajectory: List[Dict[str, Any]], template: str) -> str: + """ + 提取 User Query, Evidence (Tool Outputs), Final Report + """ + user_query = "" + evidence_parts = [] + final_report = "" + + # Helper to clean text + def clean(c): return _strip_think(_extract_text_content(c)) + + # 1. Identify components + # 倒序查找 Final Report (包含 References 或 TASK_COMPLETED 的 Assistant 消息) + for i in range(len(trajectory) - 1, -1, -1): + msg = trajectory[i] + if msg.get("role") == "assistant": + txt = clean(msg.get("content")) + # 宽松判定:通常最后的长文本是报告 + if "References" in txt or "[TASK_COMPLETED]" in txt or len(txt) > 600: + final_report = _strip_markdown_fences(txt) + break + + # 找不到显式报告时,取最后一条 Assistant + if not final_report and trajectory: + last = trajectory[-1] + if last.get("role") == "assistant": + final_report = _strip_markdown_fences(clean(last.get("content"))) + + for idx, msg in enumerate(trajectory): + role = msg.get("role") + content_raw = clean(msg.get("content")) + + # User Query: First user message + if role == "user" and not user_query: + user_query = content_raw + continue # 不要把 query 当作 evidence + + # Evidence: Tool calls and Tool outputs + if role == "assistant": + # Check for tool calls + tool_json = _extract_tool_call_json(content_raw) + if tool_json: + evidence_parts.append(f"--- Step {idx} Tool Call ---\n{tool_json}") + + elif role == "tool": + evidence_parts.append(f"--- Step {idx} Tool Result ---\n{content_raw}") + + evidence_text = "\n\n".join(evidence_parts) + + return template.format( + user_query=user_query, + evidence_text=evidence_text, + final_report=final_report + ) \ No newline at end of file diff --git a/tutorial/example_deep_finance/judge/audit/prompt.py b/tutorial/example_deep_finance/judge/audit/prompt.py new file mode 100644 index 00000000..f045b6f4 --- /dev/null +++ b/tutorial/example_deep_finance/judge/audit/prompt.py @@ -0,0 +1,67 @@ +"""Audit Grader Prompt - 引用逻辑审计 (Logic Analyst)""" + +# ============================================================================= +# System Prompt (Evidence Logic Analyst) +# ============================================================================= + +CITATION_INTEGRITY_PROMPT_COT = """ +你是一位 **"证据逻辑分析师" (Evidence Logic Analyst)**。你的任务是审计 AI 研究报告中的引用是否严格符合"逻辑蕴含 (Logical Entailment)"原则。 + +## 核心任务 +不要预设结论。你必须像法官判案一样,先罗列证据,再进行逻辑推导,最后下达判决。 +你需要对报告中出现的每一个引用标记 `[n]` 进行独立的"三步验证"。 + +## 验证逻辑 (必须严格遵守的思维顺序) + +1. **提取 (Extract)**: 锁定报告中由 `[n]` 支撑的陈述片段 (Claim)。 +2. **溯源 (Trace)**: 在 Reference 列表中找到 `[n]` 对应的原始文本,并摘录出核心证据句 (Source Quote)。 + - 注意:Reference 列表可能包含 URL 或 工具调用信息,你需要根据这些信息去上文提供的 **Evidence** 中寻找对应的内容。 +3. **比对 (Compare)**: 分析 Claim 是否被 Source Quote 严格支撑。 + * Check: 数字/事实是否一致? + * Check: 语气是否一致(有没有把"可能"改成"确定")? + * Check: 因果关系是否存在? + +## 判决标准 (Verdict Criteria) +* **Supported**: 证据充分,逻辑闭环。允许合理的概括,但禁止添加细节。 +* **Overstated**: 夸大其词。证据只说了 A,报告却写成了 A+ (如:去掉了"据报道"、"约"等限定词,或强加了因果关系)。 +* **Contradicted**: 事实冲突。报告内容与证据相反。 +* **Hallucinated**: 无中生有。报告中的关键细节(人名、数据、事件)在证据中找不到,或者引用编号在 References 中不存在。 +* **Irrelevant**: 引用无效。证据内容真实,但与报告所述主题无关。 + +## 输出格式 (JSON Only) +只输出 JSON,严禁输出 Markdown 或其他文字。字段顺序代表你的思考顺序,**不可乱序**: + +{ + "audit_trail": [ + { + "citation_id": 1, + "claim_excerpt": "报告中声称的片段...", + "evidence_quote": "从Evidence中摘录的原话...", + "logic_analysis": "分析:证据说的是X,报告写的是Y。二者是否一致?有没有夸大?(简短分析)", + "verdict": "Supported" | "Overstated" | "Contradicted" | "Hallucinated" | "Irrelevant", + "correction": "如果非Supported,基于证据的正确表述应该是..." + }, + ... + ], + "qualitative_summary": "基于上述审计,用一句话总结该报告的引用可信度(如:引用大多准确,但在具体数据上存在夸大嫌疑)。", + "integrity_score": <0.0 到 1.0 的浮点数,计算公式:Supported数量 / 总引用数> +} +""" + +# ============================================================================= +# User Prompt Template +# ============================================================================= + +CITATION_INTEGRITY_USER_TEMPLATE = """请作为逻辑分析师,对以下 AI 研究报告进行引用审计。 + +### User Query +{user_query} + +### Evidence (工具调用与返回结果) +{evidence_text} + +### AI Report (待审计报告) +{final_report} + +请严格遵守 JSON 输出格式,对报告中的所有 [n] 引用进行逐一核查。 +""" \ No newline at end of file diff --git a/tutorial/example_deep_finance/judge/cgcv/__init__.py b/tutorial/example_deep_finance/judge/cgcv/__init__.py new file mode 100644 index 00000000..b67a705f --- /dev/null +++ b/tutorial/example_deep_finance/judge/cgcv/__init__.py @@ -0,0 +1,7 @@ +""" +CGCV (Citation-Grounded Claim Verification) Grader +引用锚定的断言验证框架 +""" +from .grader import CGCVGrader + +__all__ = ["CGCVGrader"] diff --git a/tutorial/example_deep_finance/judge/cgcv/grader.py b/tutorial/example_deep_finance/judge/cgcv/grader.py new file mode 100644 index 00000000..2f65c6eb --- /dev/null +++ b/tutorial/example_deep_finance/judge/cgcv/grader.py @@ -0,0 +1,362 @@ +""" +CGCV Grader - Citation-Grounded Claim Verification +引用锚定的断言验证评分器 +""" +from __future__ import annotations + +import os +from typing import Any, Dict, List, Optional, Tuple + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +# import path 兼容两种写法 +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import ( + CGCV_SYSTEM_PROMPT_ZH, + CGCV_SYSTEM_PROMPT_EN, + CGCV_USER_PROMPT_TEMPLATE_ZH, + CGCV_USER_PROMPT_TEMPLATE_EN, + get_cgcv_prompts +) +from .json_utils import ( + strict_load_json, + validate_cgcv_schema, + parse_cgcv_result, + construct_cgcv_prompt, + compute_cgcv_score, + CGCVResult, + ClaimStatus +) + + +class CGCVGrader(BaseGrader): + """ + Citation-Grounded Claim Verification (CGCV) Grader + 引用锚定的断言验证评分器 + + 核心理念:引用是断言与证据之间的"锚点" + + 验证流程: + 1. 断言提取 (Claim Extraction) + 2. 引用检查 (Citation Checking) + 3. 来源追溯 (Source Tracing) + 4. 内容对齐验证 (Content Alignment) + + 验证状态: + - verified: 验证通过 + - citation_missing: 引用缺失 + - citation_broken: 引用断裂 + - subject_misalign: 对象错位 + - predicate_misalign: 属性错位 + - object_misalign: 值错位 + - qualifier_misalign: 限定错位 + + 评分机制: + - score = verified_claims / total_claims + - 范围: [0, 1] + + 输入:traj(完整对话轨迹) + 输出:GraderScore(name, score, reason) + """ + + def __init__( + self, + model: OpenAIChatModel, + name: str = "cgcv", + language: str = "zh", + **kwargs: Any, + ): + """ + 初始化 CGCV Grader + + Args: + model: OpenAI 兼容的聊天模型 + name: Grader 名称 + language: 语言选择,"zh" 或 "en" + **kwargs: 其他参数传递给 BaseGrader + """ + super().__init__(name=name, **kwargs) + self.model = model + self.language = language.lower() + + # 根据语言选择 prompt + self.system_prompt, self.user_prompt_template = get_cgcv_prompts(self.language) + + @staticmethod + def create_default_model( + model_name: str, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + deterministic: bool = True, + enable_thinking: bool = False, + seed: int = 0, + ) -> OpenAIChatModel: + """ + 创建默认模型 + + Args: + model_name: 模型名称 + api_key: API Key,默认从环境变量读取 + base_url: API Base URL,默认从环境变量读取 + deterministic: 是否使用确定性配置 + enable_thinking: 是否启用思考模式 + seed: 随机种子 + + Returns: + OpenAIChatModel 实例 + """ + api_key = api_key or os.getenv("OPENAI_API_KEY") + base_url = base_url or os.getenv("OPENAI_BASE_URL") + + extra_body: Dict[str, Any] = {} + if deterministic: + extra_body.update({ + "temperature": 0, + "top_p": 1, + "seed": seed, + "presence_penalty": 0, + "frequency_penalty": 0, + }) + if enable_thinking is False: + extra_body["enable_thinking"] = False + + kwargs: Dict[str, Any] = {"model": model_name} + if api_key: + kwargs["api_key"] = api_key + if base_url: + kwargs["base_url"] = base_url + if extra_body: + kwargs["extra_body"] = extra_body + + return OpenAIChatModel(**kwargs) + + async def aevaluate( + self, + traj: Any, + **_: Any, + ) -> GraderScore: + """ + 异步评估入口 + + Args: + traj: 对话轨迹,格式为 [{"role": ..., "content": ...}, ...] + 或者 {"messages": [...]} 格式 + + Returns: + GraderScore(name, score, reason) + """ + # 1. 提取 messages(兼容两种格式) + if isinstance(traj, dict): + messages_list = traj.get("messages", []) + elif isinstance(traj, list): + messages_list = traj + else: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: traj must be list or dict with 'messages'", + ) + + if not messages_list: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty trajectory", + ) + + # 2. 构建 prompt + user_prompt = construct_cgcv_prompt(messages_list, self.user_prompt_template) + + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": user_prompt} + ] + + # 3. 调用模型 + try: + resp = await self.model.achat(messages) + raw_text = getattr(resp, "content", None) + if raw_text is None: + raw_text = str(resp) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + + # 4. 解析 JSON + obj, jerr = strict_load_json(str(raw_text)) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"ParseError: {jerr}; raw[:200]={snippet}", + ) + + # 5. 验证 schema + obj, serr = validate_cgcv_schema(obj) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"SchemaError: {serr}; raw[:200]={snippet}", + ) + + # 6. 解析结果并计算分数 + result = parse_cgcv_result(obj) + score, reason = compute_cgcv_score(result) + + return GraderScore(name=self.name, score=score, reason=reason) + + def evaluate( + self, + traj: Any, + **kwargs: Any, + ) -> GraderScore: + """ + 同步评估入口(通过 asyncio 包装异步方法) + + Args: + traj: 对话轨迹 + **kwargs: 其他参数 + + Returns: + GraderScore + """ + import asyncio + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete(self.aevaluate(traj, **kwargs)) + + def get_detailed_result( + self, + traj: Any, + ) -> Tuple[GraderScore, Optional[CGCVResult]]: + """ + 获取详细评估结果(包含每个断言的验证详情) + + Args: + traj: 对话轨迹 + + Returns: + (GraderScore, CGCVResult) 元组 + """ + import asyncio + + async def _detailed_evaluate(): + # 复用主流程逻辑 + if isinstance(traj, dict): + messages_list = traj.get("messages", []) + elif isinstance(traj, list): + messages_list = traj + else: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: traj must be list or dict with 'messages'", + ), None + + if not messages_list: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty trajectory", + ), None + + user_prompt = construct_cgcv_prompt(messages_list, self.user_prompt_template) + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": user_prompt} + ] + + try: + resp = await self.model.achat(messages) + raw_text = getattr(resp, "content", None) + if raw_text is None: + raw_text = str(resp) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ), None + + obj, jerr = strict_load_json(str(raw_text)) + if obj is None: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ParseError: {jerr}", + ), None + + obj, serr = validate_cgcv_schema(obj) + if obj is None: + return GraderScore( + name=self.name, + score=0.0, + reason=f"SchemaError: {serr}", + ), None + + result = parse_cgcv_result(obj) + score, reason = compute_cgcv_score(result) + + return GraderScore(name=self.name, score=score, reason=reason), result + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete(_detailed_evaluate()) + + +# ============================================================================= +# Convenience Functions +# ============================================================================= + +def create_cgcv_grader( + model_name: str, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + language: str = "zh", + **kwargs +) -> CGCVGrader: + """ + 便捷函数:创建 CGCV Grader + + Args: + model_name: 模型名称 + api_key: API Key + base_url: API Base URL + language: 语言 ("zh" 或 "en") + **kwargs: 其他模型参数 + + Returns: + CGCVGrader 实例 + + Example: + >>> grader = create_cgcv_grader("gpt-4o", language="zh") + >>> result = await grader.aevaluate(trajectory) + >>> print(f"Score: {result.score}, Reason: {result.reason}") + """ + model = CGCVGrader.create_default_model( + model_name=model_name, + api_key=api_key, + base_url=base_url, + **kwargs + ) + return CGCVGrader(model=model, language=language) diff --git a/tutorial/example_deep_finance/judge/cgcv/json_utils.py b/tutorial/example_deep_finance/judge/cgcv/json_utils.py new file mode 100644 index 00000000..965bf301 --- /dev/null +++ b/tutorial/example_deep_finance/judge/cgcv/json_utils.py @@ -0,0 +1,578 @@ +""" +CGCV JSON Utilities +JSON 解析和验证工具 +""" +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Tuple, Optional +from dataclasses import dataclass +from enum import Enum + + +# ============================================================================= +# Constants +# ============================================================================= + +class ClaimStatus(str, Enum): + """断言验证状态枚举""" + VERIFIED = "verified" + CITATION_MISSING = "citation_missing" + CITATION_BROKEN = "citation_broken" + SUBJECT_MISALIGN = "subject_misalign" + PREDICATE_MISALIGN = "predicate_misalign" + OBJECT_MISALIGN = "object_misalign" + QUALIFIER_MISALIGN = "qualifier_misalign" + + +# 所有有效的 status 值 +VALID_STATUSES = {s.value for s in ClaimStatus} + +# JSON 提取正则 +_JSON_RE = re.compile(r"\{.*\}", re.DOTALL) + + +# ============================================================================= +# Data Classes +# ============================================================================= + +@dataclass +class ClaimVerification: + """单个断言的验证结果""" + subject: str + predicate: str + object: str + qualifier: str + citation: Optional[str] + status: str + source_id: Optional[str] + note: str + + def is_verified(self) -> bool: + return self.status == ClaimStatus.VERIFIED.value + + def is_citation_issue(self) -> bool: + return self.status in { + ClaimStatus.CITATION_MISSING.value, + ClaimStatus.CITATION_BROKEN.value + } + + def is_alignment_issue(self) -> bool: + return self.status in { + ClaimStatus.SUBJECT_MISALIGN.value, + ClaimStatus.PREDICATE_MISALIGN.value, + ClaimStatus.OBJECT_MISALIGN.value, + ClaimStatus.QUALIFIER_MISALIGN.value + } + + +@dataclass +class CGCVResult: + """CGCV 验证结果汇总""" + claims: List[ClaimVerification] + total: int + verified: int + citation_missing: int + citation_broken: int + alignment_issues: int + + @property + def score(self) -> float: + """计算验证通过率""" + if self.total == 0: + return 0.0 + return self.verified / self.total + + def get_summary(self) -> Dict[str, int]: + """获取统计摘要""" + return { + "total": self.total, + "verified": self.verified, + "citation_missing": self.citation_missing, + "citation_broken": self.citation_broken, + "alignment_issues": self.alignment_issues + } + + +# ============================================================================= +# JSON Parsing Functions +# ============================================================================= + +def extract_first_json_object(text: str) -> Optional[str]: + """ + 从文本中提取第一个 JSON 对象 + + Args: + text: 原始文本 + + Returns: + JSON 字符串,如果未找到返回 None + """ + if not text: + return None + + # 先尝试找 ```json ... ``` 代码块 + json_block_match = re.search(r"```json\s*(\{[\s\S]*?\})\s*```", text) + if json_block_match: + return json_block_match.group(1).strip() + + # 再尝试找第一个 {...} + m = _JSON_RE.search(text.strip()) + if not m: + return None + return m.group(0) + + +def strict_load_json(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + """ + 严格解析 JSON + + Args: + text: 原始文本 + + Returns: + (解析结果, 错误信息) 元组 + """ + js = extract_first_json_object(text) + if js is None: + return None, "No JSON object found in model output" + + try: + obj = json.loads(js) + if not isinstance(obj, dict): + return None, f"Top-level JSON is not an object: {type(obj).__name__}" + return obj, None + except json.JSONDecodeError as e: + return None, f"JSONDecodeError: {e}" + except Exception as e: + return None, f"{type(e).__name__}: {e}" + + +def validate_cgcv_schema(obj: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + """ + 验证 CGCV JSON 结构 + + 期望格式: + { + "claims": [ + { + "subject": str, + "predicate": str, + "object": str, + "qualifier": str, + "citation": str | null, + "status": str (one of VALID_STATUSES), + "source_id": str | null, + "note": str + } + ] + } + + Args: + obj: JSON 对象 + + Returns: + (规范化后的对象, 错误信息) 元组 + """ + # claims 必须存在且为 list + if "claims" not in obj: + return None, "Missing field: claims" + + claims = obj["claims"] + if not isinstance(claims, list): + return None, f"Field 'claims' must be list, got {type(claims).__name__}" + + # 验证并规范化每个 claim + normalized_claims = [] + for idx, claim in enumerate(claims): + if not isinstance(claim, dict): + continue # 跳过非字典项 + + # 提取并规范化字段 + normalized = { + "subject": str(claim.get("subject", "未明确"))[:200], + "predicate": str(claim.get("predicate", "未明确"))[:200], + "object": str(claim.get("object", "未明确"))[:500], + "qualifier": str(claim.get("qualifier", "未明确"))[:200], + "citation": claim.get("citation"), + "status": str(claim.get("status", "")).lower(), + "source_id": claim.get("source_id"), + "note": str(claim.get("note", ""))[:500] + } + + # 规范化 citation + if normalized["citation"] is not None: + normalized["citation"] = str(normalized["citation"]) + if normalized["citation"].lower() in ("null", "none", ""): + normalized["citation"] = None + + # 规范化 source_id + if normalized["source_id"] is not None: + normalized["source_id"] = str(normalized["source_id"]) + if normalized["source_id"].lower() in ("null", "none", ""): + normalized["source_id"] = None + + # 验证 status + if normalized["status"] not in VALID_STATUSES: + # 尝试模糊匹配 + status_lower = normalized["status"] + matched = False + for valid_status in VALID_STATUSES: + if valid_status in status_lower or status_lower in valid_status: + normalized["status"] = valid_status + matched = True + break + if not matched: + # 默认标记为 citation_missing + normalized["status"] = ClaimStatus.CITATION_MISSING.value + + normalized_claims.append(normalized) + + obj["claims"] = normalized_claims + return obj, None + + +def parse_cgcv_result(obj: Dict[str, Any]) -> CGCVResult: + """ + 解析 CGCV 结果为结构化对象 + + Args: + obj: 经过 validate_cgcv_schema 验证的 JSON 对象 + + Returns: + CGCVResult 对象 + """ + claims = [] + verified_count = 0 + citation_missing_count = 0 + citation_broken_count = 0 + alignment_issues_count = 0 + + for claim_dict in obj.get("claims", []): + claim = ClaimVerification( + subject=claim_dict.get("subject", ""), + predicate=claim_dict.get("predicate", ""), + object=claim_dict.get("object", ""), + qualifier=claim_dict.get("qualifier", ""), + citation=claim_dict.get("citation"), + status=claim_dict.get("status", ""), + source_id=claim_dict.get("source_id"), + note=claim_dict.get("note", "") + ) + claims.append(claim) + + # 统计 + if claim.is_verified(): + verified_count += 1 + elif claim.status == ClaimStatus.CITATION_MISSING.value: + citation_missing_count += 1 + elif claim.status == ClaimStatus.CITATION_BROKEN.value: + citation_broken_count += 1 + elif claim.is_alignment_issue(): + alignment_issues_count += 1 + + return CGCVResult( + claims=claims, + total=len(claims), + verified=verified_count, + citation_missing=citation_missing_count, + citation_broken=citation_broken_count, + alignment_issues=alignment_issues_count + ) + + +# ============================================================================= +# Trajectory 处理辅助函数 +# ============================================================================= + +def _extract_text_content(content) -> str: + """统一提取纯文本内容""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + out = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + out.append(item.get("text", "")) + elif isinstance(item, str): + out.append(item) + return "\n".join(out) + return str(content) + + +def _strip_think(text: str) -> str: + """去除 ... 标签""" + return re.sub(r".*?\s*", "", text, flags=re.S).strip() + + +def _strip_markdown_fences(text: str) -> str: + """ + 清理 markdown 代码块标记 + - 移除开头的 ```markdown / ```md / ``` 等 + - 移除结尾的 ``` + """ + text = text.strip() + # 移除开头的 ```xxx + text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE) + # 移除结尾的 ``` + text = re.sub(r'\n?```\s*$', '', text) + return text.strip() + + +def _normalize_traj(trajectory): + """兼容 [[...]] 格式""" + if isinstance(trajectory, list) and trajectory and isinstance(trajectory[0], list): + return trajectory[0] + return trajectory + + +def _extract_tool_call_json(text: str) -> str: + """提取工具调用 JSON""" + m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text) + if m: + return m.group(1).strip() + l, r = text.find("["), text.rfind("]") + if l != -1 and r != -1 and r > l: + cand = text[l:r+1].strip() + if ("tool_name" in cand) and ("tool_args" in cand): + return cand + return "" + + +def _looks_like_tool_result(text: str) -> bool: + """判断是否为工具返回结果""" + t = text.strip() + # 匹配常见的工具返回格式 + if t.startswith("Tool:") or t.startswith("Result:"): + return True + # 匹配 [Tool: xxx] 格式 + if t.startswith("[Tool:"): + return True + # 匹配 格式 + if "" in t or "" in t: + return True + # 匹配 dashscope_search 等工具的返回结果 + if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t): + return True + # 匹配爬取工具返回的结构化数据 + if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "): + return True + # 匹配同花顺工具返回的来源标记 + if "> 以下内容来自:" in t: + return True + return False + + +def _is_probably_final_report(text: str) -> bool: + """判断是否为最终报告""" + t = text.strip() + return ("## References" in t) or ("[TASK_COMPLETED]" in t) or t.lstrip().startswith("# ") + + +def _split_tool_responses(text: str) -> List[str]: + """ + 分割多个工具响应 + + 处理格式如: + [Tool: xxx] +... + + +[Tool: yyy] +... + """ + # 先尝试按 \n 分割 + if "" in text and "" in text: + parts = re.split(r'\s*', text) + # 清理每个部分的标签 + cleaned = [] + for p in parts: + p = re.sub(r'^\s*\s*', '', p) + p = re.sub(r'\s*\s*$', '', p) + p = p.strip() + if p: + cleaned.append(p) + if cleaned: + return cleaned + + # 尝试按 [Tool: xxx] 分割 + tool_pattern = r'(?=\[Tool:\s*[^\]]+\])' + parts = re.split(tool_pattern, text) + parts = [p.strip() for p in parts if p.strip()] + if len(parts) > 1: + return parts + + # 无法分割,返回原文本 + return [text.strip()] if text.strip() else [] + + +def construct_cgcv_prompt( + trajectory: List[Dict[str, Any]], + user_prompt_template: str +) -> str: + """ + 从 trajectory 构建 CGCV 评估 prompt + + Args: + trajectory: 对话轨迹 [{"role": ..., "content": ...}, ...] + user_prompt_template: 用户 prompt 模板 + + Returns: + 构建好的 user prompt 字符串 + """ + traj = _normalize_traj(trajectory) + if not traj: + traj = [] + + user_query = "" + tool_calls: List[str] = [] + evidence: List[str] = [] + final_report = "" + + # 找到 final report(从后往前找第一个符合条件的 assistant 消息) + for i in range(len(traj) - 1, -1, -1): + step = traj[i] + if step.get("role") == "assistant": + txt = _strip_think(_extract_text_content(step.get("content"))) + if _is_probably_final_report(txt): + final_report = txt + break + + if not final_report: + for i in range(len(traj) - 1, -1, -1): + if traj[i].get("role") == "assistant": + final_report = _strip_think(_extract_text_content(traj[i].get("content"))) + break + + # 清理 markdown 代码块标记 + final_report = _strip_markdown_fences(final_report) + + # 遍历提取 user_query, tool_calls, evidence + evidence_idx = 0 + for idx, step in enumerate(traj): + role = step.get("role") + raw = _extract_text_content(step.get("content")) + txt = _strip_think(raw) + if not raw: + continue + + # 跳过 system 消息 + if role == "system": + continue + + if role == "user" and not user_query and (not _looks_like_tool_result(raw)): + user_query = txt + continue + + if role == "assistant": + call_json = _extract_tool_call_json(raw) + if call_json: + tool_calls.append(f"【工具调用 {len(tool_calls) + 1}】\n{call_json}") + + if role == "tool": + # 处理多工具响应的情况 + tool_parts = _split_tool_responses(raw) + for part in tool_parts: + if part: + evidence_idx += 1 + evidence.append(f"【Evidence {evidence_idx}】\n{part}") + elif role == "user" and user_query and _looks_like_tool_result(raw): + # 某些情况下工具结果可能在 user 消息中 + evidence_idx += 1 + evidence.append(f"【Evidence {evidence_idx}】\n{raw}") + + # 构建 evidence_text,使用更清晰的分隔 + evidence_parts = [] + if evidence: + evidence_parts.append("\n\n".join(evidence)) + + evidence_text = "\n\n".join(evidence_parts) if evidence_parts else "(无可用证据)" + + return user_prompt_template.format( + user_query=user_query, + evidence_text=evidence_text, + report=final_report + ).strip() + + +# ============================================================================= +# Score Computation +# ============================================================================= + +def compute_cgcv_score( + result: CGCVResult, + citation_weight: float = 0.3, + alignment_weight: float = 0.7 +) -> Tuple[float, str]: + """ + 计算 CGCV 评分 + + 评分策略: + 1. 基础分:verified / total + 2. 可选:分层评分 + - citation_score: 有引用且可追溯的比例 + - alignment_score: 内容对齐的比例(在有有效引用的前提下) + + Args: + result: CGCVResult 对象 + citation_weight: 引用分数权重(默认 0.3) + alignment_weight: 对齐分数权重(默认 0.7) + + Returns: + (score, reason) 元组 + """ + total = result.total + + if total == 0: + return 0.0, "no_claims_detected" + + # 简单评分:verified / total + base_score = result.verified / total + + # 分层统计 + citation_issues = result.citation_missing + result.citation_broken + claims_with_valid_citation = total - citation_issues + + # 引用有效率 + citation_valid_rate = claims_with_valid_citation / total if total > 0 else 0.0 + + # 对齐正确率(在有效引用中) + if claims_with_valid_citation > 0: + alignment_correct_rate = result.verified / claims_with_valid_citation + else: + alignment_correct_rate = 0.0 + + # 加权分数 + weighted_score = ( + citation_weight * citation_valid_rate + + alignment_weight * alignment_correct_rate + ) + + # 最终使用基础分数(更直观) + final_score = base_score + + # 构建 reason + reason_parts = [ + f"total={total}", + f"verified={result.verified}", + f"citation_missing={result.citation_missing}", + f"citation_broken={result.citation_broken}", + f"alignment_issues={result.alignment_issues}", + f"score={final_score:.4f}", + ] + + # 添加错误摘要 + if result.alignment_issues > 0: + # 统计各类对齐错误 + error_counts = {} + for claim in result.claims: + if claim.is_alignment_issue(): + error_counts[claim.status] = error_counts.get(claim.status, 0) + 1 + error_summary = ", ".join(f"{k}:{v}" for k, v in error_counts.items()) + reason_parts.append(f"errors=[{error_summary}]") + + reason = " | ".join(reason_parts) + return round(final_score, 6), reason[:800] diff --git a/tutorial/example_deep_finance/judge/cgcv/prompt.py b/tutorial/example_deep_finance/judge/cgcv/prompt.py new file mode 100644 index 00000000..a98a98b8 --- /dev/null +++ b/tutorial/example_deep_finance/judge/cgcv/prompt.py @@ -0,0 +1,378 @@ +""" +Citation-Grounded Claim Verification (CGCV) Prompt +引用锚定的断言验证框架 + +核心理念:引用是断言与证据之间的"锚点",验证引用的有效性和内容的一致性。 +""" + +# ============================================================================= +# System Prompt - 中文版 +# ============================================================================= + +CGCV_SYSTEM_PROMPT_ZH = """你是一位"引用核查专家",负责审计研究报告中的断言是否有正确的引用支撑,并验证断言内容与来源是否一致。 + +重要说明:这是一个事后评估任务,用于评估已完成的报告质量。报告中通过工具调用获取的信息是正确的研究方式,你的任务是验证这些信息在最终报告中是否被正确引用和准确呈现。 + +## 输入说明 + +你会收到三部分内容: +1. **用户问题**:用户的原始查询 +2. **Evidence**:工具调用返回的原始数据(如搜索结果、爬取的网页内容等) +3. **研究报告**:待核查的报告,包含: + - 正文:包含带引用标记 `[n]` 的断言 + - References 区块:报告末尾的 `## References` 部分,格式通常为: + `[n] 标题描述, 工具: tool_name, 参数:xxx, 数据日期/报告期: xxx, 来源 - URL 或 (no-url)` + +## 验证流程 + +### Stage 1: 断言提取 +从报告**正文**(不含 References 区块)中识别所有包含具体信息的可验证断言,提取四个要素: +- **Subject**:断言涉及的对象(公司、产品、指数、人物等) +- **Predicate**:描述的属性或关系(收入、增长率、排名、状态等) +- **Object**:具体的值、数量或结论 +- **Qualifier**:限定条件(时间、范围、前提条件等) + +**可验证断言的识别标准**: +- 包含具体数值(金额、比例、增速、排名等) +- 包含具体日期或时间段 +- 包含可被证据支持或反驳的明确事实陈述 +- 一句话包含多个数值时,按一条断言计数 + +### Stage 2: 引用检查 +检查每个断言是否有引用标记 `[n]`: +- 有引用 → 继续下一阶段 +- 无引用 → 标记为 `citation_missing` + +### Stage 3: 来源追溯 +追溯引用 `[n]` 的验证路径:**报告正文 [n] → References 中的 [n] 条目 → Evidence 中的对应数据** +- 若 References 中存在 `[n]` 条目,且能在 Evidence 中找到对应数据 → 继续下一阶段 +- 若 References 中无 `[n]` 条目,或条目无效(如 URL 为 javascript:void(0)) → 标记为 `citation_broken` + +### Stage 4: 内容对齐验证 +将报告中的断言与 Evidence 中的原始数据进行比对,验证四个要素是否一致: +- Subject 不一致 → `subject_misalign` +- Predicate 不一致 → `predicate_misalign` +- Object 不一致 → `object_misalign` +- Qualifier 不一致 → `qualifier_misalign` +- 全部一致 → `verified` + +## 验证状态说明 + +| 状态 | 含义 | +|-----|------| +| `verified` | 验证通过:有引用、可追溯、内容与 Evidence 一致 | +| `citation_missing` | 引用缺失:可验证断言无引用标记 | +| `citation_broken` | 引用断裂:引用在 References 中不存在或无效 | +| `subject_misalign` | 对象错位:断言对象与 Evidence 不一致 | +| `predicate_misalign` | 属性错位:属性或关系与 Evidence 不匹配 | +| `object_misalign` | 值错位:数值或结论与 Evidence 不一致 | +| `qualifier_misalign` | 限定错位:时间或条件与 Evidence 不一致 | + +## 内容对齐规则 + +### Subject 对齐规则 +- ✓ 完全一致或已知别名等价(如:腾讯 = 腾讯控股 = Tencent) +- ✓ 股票代码与公司名对应(如:600745 = 闻泰科技) +- ✗ 不同实体混淆(A公司数据误标为B公司) +- ✗ 范围混淆(子公司/渠道数据误标为集团整体,如:i茅台营收 ≠ 贵州茅台总营收) + +### Predicate 对齐规则 +- ✓ 完全一致或语义等价(如:ROE = 净资产收益率、营收 = 营业收入 = 总收入) +- ✗ 概念混淆(净利润 ≠ 营业收入、毛利率 ≠ 净利率) +- ✗ 口径混淆(日收益率 ≠ 周收益率、同比 ≠ 环比) + +### Object 对齐规则 +- ✓ 精确一致(454.03亿 = 454.03亿) +- ✓ 等价形式(18.60% = 18.6%,末尾零可省) +- ✓ 单位换算等价(45403百万 = 454.03亿) +- ✓ 表述等价(下降8% = 增长-8% = 同比-8%) +- ✓ 合理近似:使用"约/大约/左右"修饰时,允许5%以内误差 +- ✗ 精度丢失:未使用"约"等修饰词时,不允许省略有效数字(454.03亿 → 454亿) +- ✗ 超出容差:即使有"约"修饰,误差超过5% +- ✗ 数值无据:Evidence 中找不到该数值 + +### Qualifier 对齐规则 +- ✓ 完全一致或语义等价(2025年Q2 = 2025年第二季度 = 2025年4-6月) +- ✓ 报告期等价(2025年三季报 = 截至2025年9月30日 = 2025年前三季度) +- ✗ 年份错位(2024年 ≠ 2025年) +- ✗ 周期错位(Q2 ≠ Q3、上半年 ≠ 前三季度) +- ✗ 时点混淆(发布日期 ≠ 数据截止日期) + +## 输出格式 + +请直接输出 JSON,格式如下: +```json +{ + "claims": [ + { + "subject": "断言对象", + "predicate": "属性/关系", + "object": "值/结论", + "qualifier": "限定条件(无则填'未明确')", + "citation": "引用标记如[1],无则填null", + "status": "verified/citation_missing/citation_broken/subject_misalign/predicate_misalign/object_misalign/qualifier_misalign", + "source_id": "来源编号(如有)", + "note": "说明(verified时为空字符串)" + } + ] +} +``` + +只输出 JSON,不要输出其他解释文字。 + +## 示例 + +### 示例1:验证通过 (verified) + +**Report正文片段**:闻泰科技2025年三季报净利润为15.13亿元,同比增长265.09%[5] +**Report References**:[5] 闻泰科技2025年三季报财务分析, 工具: crawl_ths_finance, 参数:code=600745, 数据日期/报告期: 2025-09-30, 来源 - https://basic.10jqka.com.cn/600745/finance.html +**Evidence**:...闻泰科技...净利润15.13亿元...同比增长265.09%... + +分析: +- Subject: 闻泰科技 ✓ +- Predicate: 净利润、同比增长 ✓ +- Object: 15.13亿元、265.09% ✓ +- Qualifier: 2025年三季报 ↔ 2025-09-30 ✓(语义等价) +- 引用[5]存在于References,可追溯到Evidence ✓ + +输出: +{"subject": "闻泰科技", "predicate": "净利润同比增长", "object": "15.13亿元,265.09%", "qualifier": "2025年三季报", "citation": "[5]", "status": "verified", "source_id": "5", "note": ""} + +--- + +### 示例2:引用缺失 (citation_missing) + +**Report正文片段**:该公司毛利率达到16.98%,同比提升6.97个百分点 +**Evidence**:...毛利率16.98%...同比提升6.97个百分点... + +分析: +- 断言包含具体数值(16.98%、6.97个百分点),属于可验证断言 +- 但断言末尾无引用标记 [n] + +输出: +{"subject": "该公司", "predicate": "毛利率", "object": "16.98%,同比提升6.97个百分点", "qualifier": "未明确", "citation": null, "status": "citation_missing", "source_id": null, "note": "可验证断言缺少引用标记"} + +--- + +### 示例3:引用断裂 (citation_broken) + +**Report正文片段**:市场份额达到23%[9] +**Report References**:(无[9]条目,或[9]条目的URL为 javascript:void(0)) + +分析: +- 有引用标记[9] +- 但References中无有效的[9]条目 + +输出: +{"subject": "未明确", "predicate": "市场份额", "object": "23%", "qualifier": "未明确", "citation": "[9]", "status": "citation_broken", "source_id": null, "note": "引用[9]在References中不存在或无效"} + +--- + +### 示例4:对象错位 (subject_misalign) + +**Report正文片段**:赛腾股份2025年三季报净利润为15.13亿元[5] +**Report References**:[5] 闻泰科技2025年三季报财务分析, 工具: crawl_ths_finance, 参数:code=600745... +**Evidence**:...闻泰科技...净利润15.13亿元... + +分析: +- Subject: 赛腾股份 ↔ 闻泰科技 ✗ +- 15.13亿元是闻泰科技的数据,被错误归属给赛腾股份 + +输出: +{"subject": "赛腾股份", "predicate": "净利润", "object": "15.13亿元", "qualifier": "2025年三季报", "citation": "[5]", "status": "subject_misalign", "source_id": "5", "note": "来源[5]中15.13亿元属于闻泰科技,非赛腾股份"} + +--- + +### 示例5:值错位-精度丢失 (object_misalign) + +**Report正文片段**:净利润15亿元[5] +**Evidence**:...净利润15.13亿元... + +分析: +- Object: 15亿 ↔ 15.13亿 ✗ +- 报告未使用"约"修饰,但省略了小数部分(0.13亿 = 1300万,精度损失明显) + +输出: +{"subject": "未明确", "predicate": "净利润", "object": "15亿元", "qualifier": "未明确", "citation": "[5]", "status": "object_misalign", "source_id": "5", "note": "Evidence为15.13亿元,报告省略为15亿元,存在精度丢失"} + +--- + +### 示例6:限定错位 (qualifier_misalign) + +**Report正文片段**:2025年Q2净利润为15.13亿元[5] +**Report References**:[5] ...数据日期/报告期: 2025-09-30... +**Evidence**:...2025年三季报...净利润15.13亿元... + +分析: +- Qualifier: Q2(截至6月30日) ↔ 2025-09-30(三季报,截至9月30日) ✗ +- 报告期不一致 + +输出: +{"subject": "未明确", "predicate": "净利润", "object": "15.13亿元", "qualifier": "2025年Q2", "citation": "[5]", "status": "qualifier_misalign", "source_id": "5", "note": "来源[5]为2025年三季报数据(截至9月30日),非Q2数据"}""" + +# ============================================================================= +# System Prompt - English Version +# ============================================================================= + +CGCV_SYSTEM_PROMPT_EN = """You are a "Citation Verification Expert" responsible for auditing whether claims in research reports have proper citation support and whether the claim content is consistent with the evidence sources. + +Important Note: This is a post-hoc evaluation task for assessing completed report quality. Information obtained through tool calls in the report is a correct research approach. Your task is to verify whether this information is correctly cited and accurately presented in the final report. + +## Input Description + +You will receive three parts: +1. **User Query**: The original user question +2. **Evidence**: Raw data returned from tool calls (search results, crawled web content, etc.) +3. **Research Report**: The report to be verified, containing: + - Body: Contains claims with citation markers `[n]` + - References section: The `## References` part at the end, typically in format: + `[n] Title description, Tool: tool_name, Params:xxx, Data date/Report period: xxx, Source - URL or (no-url)` + +## Verification Process + +### Stage 1: Claim Extraction +Identify all verifiable claims containing specific information from the report **body** (excluding References section), extracting four elements: +- **Subject**: The entity the claim is about (company, product, index, person, etc.) +- **Predicate**: The attribute or relationship described (revenue, growth rate, ranking, status, etc.) +- **Object**: The specific value, quantity, or conclusion +- **Qualifier**: Limiting conditions (time, scope, prerequisites, etc.) + +**Criteria for verifiable claims**: +- Contains specific numbers (amounts, ratios, growth rates, rankings, etc.) +- Contains specific dates or time periods +- Contains definitive factual statements that can be supported or refuted by evidence +- Multiple values in one sentence count as one claim + +### Stage 2: Citation Checking +Check whether each claim has a citation marker `[n]`: +- Has citation → proceed to next stage +- No citation → mark as `citation_missing` + +### Stage 3: Source Tracing +Trace citation `[n]` verification path: **Report body [n] → [n] entry in References → Corresponding data in Evidence** +- If `[n]` entry exists in References and corresponding data can be found in Evidence → proceed to next stage +- If `[n]` entry doesn't exist in References, or entry is invalid (e.g., URL is javascript:void(0)) → mark as `citation_broken` + +### Stage 4: Content Alignment Verification +Compare claims in report with original data in Evidence, verify if four elements are consistent: +- Subject inconsistent → `subject_misalign` +- Predicate inconsistent → `predicate_misalign` +- Object inconsistent → `object_misalign` +- Qualifier inconsistent → `qualifier_misalign` +- All consistent → `verified` + +## Verification Status Description + +| Status | Meaning | +|--------|--------| +| `verified` | Verified: has citation, traceable, content matches Evidence | +| `citation_missing` | Missing citation: verifiable claim has no citation marker | +| `citation_broken` | Broken citation: citation doesn't exist or is invalid in References | +| `subject_misalign` | Subject misaligned: claim subject inconsistent with Evidence | +| `predicate_misalign` | Predicate misaligned: attribute or relationship doesn't match Evidence | +| `object_misalign` | Object misaligned: value or conclusion inconsistent with Evidence | +| `qualifier_misalign` | Qualifier misaligned: time or condition inconsistent with Evidence | + +## Content Alignment Rules + +### Subject Alignment Rules +- ✓ Exact match or known alias equivalence (e.g., Tencent = Tencent Holdings) +- ✓ Stock code corresponds to company name (e.g., 600745 = Wingtech) +- ✗ Different entity confusion (Company A data mislabeled as Company B) +- ✗ Scope confusion (subsidiary/channel data mislabeled as group total) + +### Predicate Alignment Rules +- ✓ Exact match or semantic equivalence (e.g., ROE = Return on Equity, Revenue = Operating Income = Total Revenue) +- ✗ Concept confusion (Net profit ≠ Operating revenue, Gross margin ≠ Net margin) +- ✗ Scope confusion (Daily return rate ≠ Weekly return rate, YoY ≠ MoM) + +### Object Alignment Rules +- ✓ Exact match (45.403B = 45.403B) +- ✓ Equivalent forms (18.60% = 18.6%, trailing zeros can be omitted) +- ✓ Unit conversion equivalence (45403 million ≈ 454.03 billion) +- ✓ Expression equivalence (down 8% = growth -8% = YoY -8%) +- ✓ Reasonable approximation: when using "approx/about/around" modifier, allow up to 5% error +- ✗ Precision loss: without "approx" modifier, cannot omit significant digits (454.03B → 454B) +- ✗ Exceeds tolerance: even with "approx" modifier, error exceeds 5% +- ✗ Value not found: cannot find this value in Evidence + +### Qualifier Alignment Rules +- ✓ Exact match or semantic equivalence (2025 Q2 = Q2 2025 = Apr-Jun 2025) +- ✓ Report period equivalence (Q3 2025 report = as of Sep 30, 2025 = first three quarters of 2025) +- ✗ Year misalignment (2024 ≠ 2025) +- ✗ Period misalignment (Q2 ≠ Q3, H1 ≠ first three quarters) +- ✗ Time point confusion (publication date ≠ data cutoff date) + +## Output Format + +Please output JSON directly in the following format: +```json +{ + "claims": [ + { + "subject": "claim subject", + "predicate": "attribute/relationship", + "object": "value/conclusion", + "qualifier": "limiting condition (use 'unspecified' if none)", + "citation": "citation marker like [1], null if none", + "status": "verified/citation_missing/citation_broken/subject_misalign/predicate_misalign/object_misalign/qualifier_misalign", + "source_id": "source number (if available)", + "note": "explanation (empty string when verified)" + } + ] +} +``` + +Output JSON only, no other explanatory text. +""" + +# ============================================================================= +# User Prompt Template +# ============================================================================= + +CGCV_USER_PROMPT_TEMPLATE_ZH = """请对以下研究报告进行引用核查,验证每个可验证断言的引用有效性和内容一致性。 + +### 用户问题 +{user_query} + +### Evidence(工具调用获取的信息) +{evidence_text} + +### 研究报告(待核查) +{report} + +请按照验证流程逐一检查报告中的可验证断言,只输出 JSON 结果。 +""" + +CGCV_USER_PROMPT_TEMPLATE_EN = """Please perform citation verification on the following research report, validating citation validity and content consistency for each verifiable claim. + +### User Query +{user_query} + +### Evidence (Information obtained through tool calls) +{evidence_text} + +### Research Report (To be verified) +{report} + +Please check each verifiable claim in the report according to the verification process, output JSON result only. +""" + +# ============================================================================= +# Utility: Get prompts by language +# ============================================================================= + +def get_cgcv_prompts(language: str = "zh"): + """ + Get CGCV prompts based on language. + + Args: + language: "zh" for Chinese, "en" for English + + Returns: + Tuple of (system_prompt, user_prompt_template) + """ + if language.lower() in ["zh", "chinese", "中文"]: + return CGCV_SYSTEM_PROMPT_ZH, CGCV_USER_PROMPT_TEMPLATE_ZH + else: + return CGCV_SYSTEM_PROMPT_EN, CGCV_USER_PROMPT_TEMPLATE_EN diff --git a/tutorial/example_deep_finance/judge/grounding/prompt.py b/tutorial/example_deep_finance/judge/grounding/prompt.py index 24bea134..337cf4bc 100644 --- a/tutorial/example_deep_finance/judge/grounding/prompt.py +++ b/tutorial/example_deep_finance/judge/grounding/prompt.py @@ -1,103 +1,152 @@ """Grounding Grader Prompt - 引用规范性评估""" -GROUNDING_SYSTEM_PROMPT = """你是一位"引用审计员",负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 - -======================== -一、引用规范(以此为准) -======================== -1) 关键事实句必须引用: - - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。 - - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。 - -2) 引用位置规则(严格执行): - - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。 - - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。 - -3) References 必须存在且可追溯: - - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。 - - 正文出现的每个 [n] 必须能在 References 中找到对应条目。 - -4) References 条目两种合法形式(必须满足其一): - A) URL 形式:`[n] 标题或简述 - https://...` - - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。 - B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)` - - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。 - - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。 - -======================== -二、输入 -======================== +# GROUNDING_SYSTEM_PROMPT = """你是一位"引用审计员",负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 + +# ======================== +# 一、引用规范(以此为准) +# ======================== +# 1) 关键事实句必须引用: +# - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。 +# - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。 + +# 2) 引用位置规则(严格执行): +# - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。 +# - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。 + +# 3) References 必须存在且可追溯: +# - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。 +# - 正文出现的每个 [n] 必须能在 References 中找到对应条目。 + +# 4) References 条目两种合法形式(必须满足其一): +# A) URL 形式:`[n] 标题或简述 - https://...` +# - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。 +# B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)` +# - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。 +# - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。 + +# ======================== +# 二、输入 +# ======================== +# 你会收到: +# - User Query +# - Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息) +# - AI Report(待审计报告,含正文与 References) + +# 真实性核对原则: +# - 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。 +# - 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。 + +# ======================== +# 三、统计与判定口径(严格遵守) +# ======================== +# 【文本范围】 +# - 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。 +# - References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。 + +# 【句子/条目如何计数】 +# - “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。 +# - 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。 +# - 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。 + +# 【关键事实句识别(务求稳定)】 +# - 满足任一条件可视为关键事实句: +# (a) 含具体数值/比例/排名/区间/估值倍数/财务指标; +# (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”); +# (c) 对具体公司/行业/政策做了可验证的确定性陈述; +# (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。 + +# 【引用是否“句末”】【重要】 +# - 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如: +# - “……增长 20%[3]” +# - “……增长 20% [3][4]” +# - 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。 + +# 【invalid_reference_nums 的定义】 +# - 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid: +# (a) References 中不存在该编号条目; +# (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等); +# (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。 +# - invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。 + +# 【missing_count 的定义】 +# - 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。 + +# 【cited_key_facts 的定义】 +# - 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。 + +# 【fake_count 的定义(只在明显时计数)】 +# - 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。 +# - 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。 + +# 【good_citations 的定义】 +# - 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足: +# - 是关键事实句; +# - 句末有 [n]; +# - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。 +# - good_citations 是原文截取,不要加解释;最多 2 条,超出截断。 + +# ======================== +# 四、输出(只输出 JSON,字段固定) +# ======================== +# { +# "total_key_facts": , +# "cited_key_facts": , +# "good_citations": ["...", "..."], +# "missing_count": , +# "fake_count": , +# "invalid_reference_nums": [, ...] +# } + +# 只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。 +# """ + +GROUNDING_SYSTEM_PROMPT = """ +你是一位“引用审计员”,负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 + +## 引用规范(以此为准) +- 关键事实句必须引用:关键事实句包括数字/同比环比/日期/财务指标/估值倍数/明确事实结论/具体事件/具体公司或行业陈述/政策条款。 +- 关键事实句句末必须出现引用编号:[1] 或 [1][2]。 +- 报告末尾必须包含 `## References`。 +- 正文出现的每个 [n] 必须能在 References 中找到对应条目。 +- References 条目两种合法形式: + A) URL 形式:`[n] 标题或简述 - https://...` + B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)` +- `javascript:void(0)` 等无效链接不算 URL,应按 no-url 形式记录来源信息。 +- 禁止伪造来源;没有证据支撑的只能写“推测/假设”,不能用引用把推测包装成事实。 + +## 输入 你会收到: - User Query - Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息) - AI Report(待审计报告,含正文与 References) -真实性核对原则: -- 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。 -- 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。 - -======================== -三、统计与判定口径(严格遵守) -======================== -【文本范围】 -- 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。 -- References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。 - -【句子/条目如何计数】 -- “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。 -- 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。 -- 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。 - -【关键事实句识别(务求稳定)】 -- 满足任一条件可视为关键事实句: - (a) 含具体数值/比例/排名/区间/估值倍数/财务指标; - (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”); - (c) 对具体公司/行业/政策做了可验证的确定性陈述; - (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。 - -【引用是否“句末”】【重要】 -- 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如: - - “……增长 20%[3]” - - “……增长 20% [3][4]” -- 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。 - -【invalid_reference_nums 的定义】 -- 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid: - (a) References 中不存在该编号条目; - (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等); - (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。 -- invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。 - -【missing_count 的定义】 -- 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。 - -【cited_key_facts 的定义】 -- 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。 - -【fake_count 的定义(只在明显时计数)】 -- 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。 -- 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。 - -【good_citations 的定义】 -- 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足: - - 是关键事实句; - - 句末有 [n]; - - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。 -- good_citations 是原文截取,不要加解释;最多 2 条,超出截断。 - -======================== -四、输出(只输出 JSON,字段固定) -======================== +核对真实性时,以 Evidence 为准:只有在“明显矛盾/明显找不到依据”时才判 fake;无法确认则不要判 fake。 + +## 输出(只输出 JSON,字段固定) { "total_key_facts": , "cited_key_facts": , - "good_citations": ["...", "..."], + "good_citations": ["从报告原文截取的:关键事实句 + 句末 [n],且 References 可追溯(最多 5 条)", ...] "missing_count": , "fake_count": , - "invalid_reference_nums": [, ...] + "invalid_reference_nums": [, ...], } -只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。 +统计口径(为保证稳定,严格遵守): +- total_key_facts:正文中关键事实句的总数(按句子/条目计;一句多个数字也算 1 条即可,不要过度拆分)。 +- cited_key_facts:关键事实句中,句末包含至少一个 [n] 的数量(不要求该引用一定有效)。 +- invalid_reference_nums:正文出现过、但满足任一条件的编号: + (a) References 中不存在该编号条目; + (b) URL 形式但 URL 无效(空或 javascript:void(0) 等); + (c) no-url 形式但缺少“工具名/参数/日期(报告期)”之一。 +- missing_count:关键事实句中“句末没有 [n]”的数量。 +- fake_count:关键事实句“带引用但与 Evidence 明显矛盾/明显无支撑”的数量(仅明显时计数)。 +- good_citations:从报告原文中选取最多 5 条“引用做得正确”的关键事实句(句末有 [n],且 [n] 在 References 中合法)。 + +长度约束(必须): +- invalid_reference_nums 最多 5 个,多余截断。 +- good_citations 最多 2 条,多余截断。 +只输出 JSON,不要输出解释文字或 Markdown。 """ # ============================================================================= diff --git a/tutorial/example_deep_finance/judge/traceability/__init__.py b/tutorial/example_deep_finance/judge/traceability/__init__.py new file mode 100644 index 00000000..18845402 --- /dev/null +++ b/tutorial/example_deep_finance/judge/traceability/__init__.py @@ -0,0 +1,7 @@ +""" +CGCV (Citation-Grounded Claim Verification) Grader +引用锚定的断言验证框架 +""" +from .grader import TraceabilityRewardGrader + +__all__ = ["TraceabilityRewardGrader"] diff --git a/tutorial/example_deep_finance/judge/traceability/grader.py b/tutorial/example_deep_finance/judge/traceability/grader.py new file mode 100644 index 00000000..75206f6b --- /dev/null +++ b/tutorial/example_deep_finance/judge/traceability/grader.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import TRACEABILITY_SYSTEM_PROMPT, TRACEABILITY_USER_PROMPT_TEMPLATE +from .json_utils import strict_load_json, validate_shape, coerce_to_messages_list, construct_traceability_prompt, count_digit_tokens + + +class TraceabilityRewardGrader(BaseGrader): + """ + Traceability & Verifiability Reward (TVR) + + Input: traj (trajectory / record) - supports: + - list[dict] + - list[list[dict]] + - dict with {"traj": ...} etc. + + Output: GraderScore(name="traceability", score in [0,1], reason includes stats + brief examples) + """ + + def __init__( + self, + model: Optional[OpenAIChatModel] = None, + name: str = "traceability", + temperature: float = 0.0, + max_tokens: int = 2200, + ) -> None: + super().__init__(name=name) + self.model = model or OpenAIChatModel( + model_name="gpt-4.1-mini", + temperature=temperature, + max_tokens=max_tokens, + response_format={"type": "json_object"}, + ) + + async def aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: + messages = coerce_to_messages_list(traj) + + user_prompt, report_plain = construct_traceability_prompt( + messages, + TRACEABILITY_USER_PROMPT_TEMPLATE, + ) + + judge_messages = [ + {"role": "system", "content": TRACEABILITY_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + + resp = await self.model.achat(judge_messages) + text = resp.get("content", "") + + try: + obj = strict_load_json(text) + norm = validate_shape(obj) + score = self._compute_score(norm, report_plain) + reason = self._build_reason(norm, report_plain, score) + return GraderScore(name=self.name, score=score, reason=reason) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"TVR judge output invalid: {e}", + ) + + def _compute_score(self, norm: Dict[str, Any], report_plain: str) -> float: + stats = norm["stats"] + total = max(1, int(stats.get("total_claims", 0))) + + supported = int(stats.get("supported", 0)) + contradicted = int(stats.get("contradicted", 0)) + no_evidence = int(stats.get("no_evidence", 0)) + speculative_ok = int(stats.get("speculative_ok", 0)) + unclear = int(stats.get("unclear", 0)) + + # Positive contribution + pos = supported + 0.6 * speculative_ok + 0.3 * unclear + # Negative contribution (contradiction is harsh) + neg = 1.0 * contradicted + 0.8 * no_evidence + + base = (pos - neg) / total # can be negative + base = max(0.0, min(1.0, base)) + + # Coverage factor (deterministic) based on digits/dates in report body + real_digit_tokens = count_digit_tokens(report_plain) + expected_min_claims = min(25, max(6, real_digit_tokens // 2)) + claim_count = int(stats.get("total_claims", total)) + + selection_factor = min(1.0, claim_count / expected_min_claims) if expected_min_claims > 0 else 1.0 + + # If the judge reports digit coverage, blend it in (but keep deterministic as the main) + reported_total_digits = int(stats.get("report_digit_tokens", 0)) + reported_covered_digits = int(stats.get("covered_digit_tokens", 0)) + if reported_total_digits > 0: + reported_cov = min(1.0, max(0.0, reported_covered_digits / reported_total_digits)) + else: + reported_cov = 1.0 + + cov_factor = 0.7 + 0.3 * reported_cov # [0.7, 1.0] + + score = base * selection_factor * cov_factor + score = max(0.0, min(1.0, score)) + return float(score) + + def _build_reason(self, norm: Dict[str, Any], report_plain: str, score: float) -> str: + stats = norm["stats"] + ex = norm.get("examples", {}) + best = ex.get("best_supported", []) + worst = ex.get("worst_failed", []) + + real_digit_tokens = count_digit_tokens(report_plain) + + parts = [] + parts.append( + f"score={score:.3f}; " + f"claims={stats['total_claims']}; " + f"supported={stats['supported']}; " + f"spec_ok={stats['speculative_ok']}; " + f"unclear={stats['unclear']}; " + f"no_ev={stats['no_evidence']}; " + f"contradicted={stats['contradicted']}; " + f"report_digits≈{real_digit_tokens}" + ) + + if best: + parts.append(f"best_supported={best[:1]}") + if worst: + parts.append(f"worst_failed={worst[:1]}") + return " | ".join(parts) diff --git a/tutorial/example_deep_finance/judge/traceability/json_utils.py b/tutorial/example_deep_finance/judge/traceability/json_utils.py new file mode 100644 index 00000000..de6fb3cc --- /dev/null +++ b/tutorial/example_deep_finance/judge/traceability/json_utils.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Tuple + + +# -------------------------- +# JSON parsing helpers +# -------------------------- + +def strict_load_json(text: str) -> Dict[str, Any]: + """ + Parse a JSON object from model output. + + - Accept plain JSON. + - If extra text exists, extract the first {...} block. + """ + text = (text or "").strip() + try: + obj = json.loads(text) + if isinstance(obj, dict): + return obj + except Exception: + pass + + start = text.find("{") + end = text.rfind("}") + if start != -1 and end != -1 and end > start: + snippet = text[start : end + 1] + obj = json.loads(snippet) + if isinstance(obj, dict): + return obj + + raise ValueError("Invalid JSON output") + + +def _clip(s: str, n: int) -> str: + s = s or "" + s = s.replace("\u0000", "") + return s[:n] + + +def _as_int(x: Any, default: int = 0) -> int: + try: + if x is None: + return default + if isinstance(x, bool): + return int(x) + if isinstance(x, (int, float)): + return int(x) + if isinstance(x, str) and x.strip(): + return int(float(x.strip())) + except Exception: + return default + return default + + +def _as_list_str(x: Any, max_items: int = 10, max_len: int = 60) -> List[str]: + if not isinstance(x, list): + return [] + out: List[str] = [] + for item in x[:max_items]: + if isinstance(item, str): + out.append(_clip(item, max_len)) + else: + out.append(_clip(str(item), max_len)) + return out + + +def validate_shape(obj: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate and normalize model output for TVR. + + Returns: + { + "claims": [...], + "stats": {...}, + "examples": {...} + } + """ + if not isinstance(obj, dict): + raise ValueError("Output is not a JSON object") + + claims_raw = obj.get("claims", []) + if not isinstance(claims_raw, list): + claims_raw = [] + + claims: List[Dict[str, Any]] = [] + for c in claims_raw[:25]: + if not isinstance(c, dict): + continue + + claim = _clip(str(c.get("claim", "")), 240) + ctype = _clip(str(c.get("type", "other")), 24) + + sig = c.get("signature", {}) + if not isinstance(sig, dict): + sig = {} + entities = _as_list_str(sig.get("entities", []), max_items=10, max_len=50) + numbers = _as_list_str(sig.get("numbers", []), max_items=10, max_len=40) + times = _as_list_str(sig.get("times", []), max_items=10, max_len=40) + + anchors_raw = c.get("anchors", []) + anchors: List[Dict[str, Any]] = [] + if isinstance(anchors_raw, list): + for a in anchors_raw[:2]: + if not isinstance(a, dict): + continue + step = _as_int(a.get("step", -1), default=-1) + quote = _clip(str(a.get("quote", "")), 120) + if step >= 0 and quote: + anchors.append({"step": step, "quote": quote}) + + verdict = _clip(str(c.get("verdict", "unclear")), 20) + if verdict not in {"supported", "contradicted", "no_evidence", "speculative_ok", "unclear"}: + verdict = "unclear" + + issue = _clip(str(c.get("issue", "none")), 20) + allowed_issues = { + "none", "entity_mismatch", "time_mismatch", "value_mismatch", "scope_mismatch", + "logic_leap", "over_precision", "missing_anchor" + } + if issue not in allowed_issues: + issue = "none" + + note = _clip(str(c.get("note", "")), 80) + + claims.append({ + "claim": claim, + "type": ctype, + "signature": {"entities": entities, "numbers": numbers, "times": times}, + "anchors": anchors, + "verdict": verdict, + "issue": issue, + "note": note, + }) + + # stats + stats_raw = obj.get("stats", {}) + if not isinstance(stats_raw, dict): + stats_raw = {} + + # always re-count to avoid mismatch / gaming + verdict_counts = { + "supported": 0, + "contradicted": 0, + "no_evidence": 0, + "speculative_ok": 0, + "unclear": 0, + } + for c in claims: + verdict_counts[c["verdict"]] += 1 + + report_digit_tokens = max(0, _as_int(stats_raw.get("report_digit_tokens", 0), default=0)) + covered_digit_tokens = max(0, _as_int(stats_raw.get("covered_digit_tokens", 0), default=0)) + + stats = { + "total_claims": len(claims), + "supported": verdict_counts["supported"], + "contradicted": verdict_counts["contradicted"], + "no_evidence": verdict_counts["no_evidence"], + "speculative_ok": verdict_counts["speculative_ok"], + "unclear": verdict_counts["unclear"], + "report_digit_tokens": report_digit_tokens, + "covered_digit_tokens": covered_digit_tokens, + } + + # examples (small) + examples_raw = obj.get("examples", {}) + if not isinstance(examples_raw, dict): + examples_raw = {} + + def _normalize_example_list(x: Any, max_items: int = 2) -> List[Dict[str, Any]]: + if not isinstance(x, list): + return [] + out: List[Dict[str, Any]] = [] + for it in x[:max_items]: + if isinstance(it, dict): + out.append({k: _clip(str(v), 140) for k, v in list(it.items())[:3]}) + elif isinstance(it, str): + out.append({"text": _clip(it, 140)}) + return out + + examples = { + "best_supported": _normalize_example_list(examples_raw.get("best_supported", []), 2), + "worst_failed": _normalize_example_list(examples_raw.get("worst_failed", []), 2), + } + + return {"claims": claims, "stats": stats, "examples": examples} + + +# -------------------------- +# Trajectory helpers +# -------------------------- + +def coerce_to_messages_list(traj: Any) -> List[Dict[str, Any]]: + """ + Accepts: + - list[dict] + - list[list[dict]] (take first non-empty inner list) + - dict with keys: traj / messages / conversation / steps (best-effort) + + Returns list[dict] message objects. + """ + if traj is None: + return [] + + if isinstance(traj, dict): + for key in ("traj", "messages", "conversation", "steps"): + if key in traj: + return coerce_to_messages_list(traj[key]) + return [] + + if isinstance(traj, list): + if not traj: + return [] + if isinstance(traj[0], list): + for inner in traj: + if isinstance(inner, list) and inner and isinstance(inner[0], dict): + return inner + return [] + if isinstance(traj[0], dict): + return traj + + return [] + + +def _extract_text_content(content: Any) -> str: + """ + Extract textual content from different possible message formats. + """ + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif isinstance(item, str): + parts.append(item) + return "\n".join([p for p in parts if p]) + return str(content) + + +def strip_references(markdown: str) -> str: + """ + Remove References section and anything after it (common Markdown headings). + """ + if not isinstance(markdown, str): + return "" + m = re.search(r"\n#+\s*References\b", markdown, flags=re.IGNORECASE) + if m: + return markdown[: m.start()].strip() + return markdown.strip() + + +def count_digit_tokens(text: str) -> int: + """ + Rough count for digit/date tokens in text. + """ + if not text: + return 0 + pats = [ + r"\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", # ISO-ish date + r"\b\d+(?:\.\d+)?%?\b", # number / percent + ] + tokens: List[str] = [] + for p in pats: + tokens.extend(re.findall(p, text)) + return len(tokens) + + +def _is_probably_final_report(text: str) -> bool: + """ + Heuristic: final report is usually markdown-ish and contains References / TASK_COMPLETED etc. + """ + if not text: + return False + # allow either TASK_COMPLETED or markdown headings + References + has_markdown = ("#" in text) or ("|---" in text) or ("## " in text) + has_refs = re.search(r"#+\s*References\b", text, flags=re.IGNORECASE) is not None + has_done = "[TASK_COMPLETED]" in text + return has_done or (has_markdown and has_refs) + + +def _extract_tool_calls_and_results(trajectory: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Extract tool call and tool output blocks, loosely following the format in your existing data. + """ + items: List[Dict[str, Any]] = [] + for i, msg in enumerate(trajectory): + role = msg.get("role", "") + content = _extract_text_content(msg.get("content", "")) + + if role == "assistant": + # look for JSON code block that indicates tool calls + if "```json" in content and ("tool_name" in content or "tool_args" in content): + items.append({"step": i, "kind": "tool_call", "text": content}) + elif role == "tool": + items.append({"step": i, "kind": "tool_result", "text": content}) + return items + + +def construct_reward_prompt(trajectory: List[Dict[str, Any]], user_prompt_template: str) -> str: + """ + Build a user prompt with: + - user_query: last user message + - evidence_text: concatenated tool calls/results with step index + - final_report: last assistant message that looks like final report + """ + trajectory = coerce_to_messages_list(trajectory) + + user_query = "" + for msg in reversed(trajectory): + if msg.get("role") == "user": + user_query = _extract_text_content(msg.get("content", "")) + break + + final_report = "" + for msg in reversed(trajectory): + if msg.get("role") == "assistant": + t = _extract_text_content(msg.get("content", "")) + if _is_probably_final_report(t): + final_report = t + break + if not final_report: + # fallback to last assistant msg + for msg in reversed(trajectory): + if msg.get("role") == "assistant": + final_report = _extract_text_content(msg.get("content", "")) + break + + evidence_items = _extract_tool_calls_and_results(trajectory) + evidence_lines: List[str] = [] + for it in evidence_items: + step = it["step"] + kind = it["kind"] + prefix = "CALL" if kind == "tool_call" else "RESULT" + evidence_lines.append(f"[{prefix} step={step}]\n{it['text']}".strip()) + evidence_text = "\n\n".join(evidence_lines).strip() + + return user_prompt_template.format( + user_query=user_query, + evidence_text=evidence_text, + final_report=final_report, + ) + + +def construct_traceability_prompt( + trajectory: List[Dict[str, Any]], + user_prompt_template: str, +) -> Tuple[str, str]: + """ + Returns: + - user_prompt (for the judge model) + - report_plain (final report without References) for deterministic coverage checks + """ + user_prompt = construct_reward_prompt(trajectory, user_prompt_template) + + final_report = "" + marker = "\n## AI Report\n" + if marker in user_prompt: + final_report = user_prompt.split(marker, 1)[1] + # Cut at "\n\n### 审计流程" if present. + cut = "\n\n### 审计流程" + if cut in final_report: + final_report = final_report.split(cut, 1)[0] + + report_plain = strip_references(final_report) + return user_prompt, report_plain diff --git a/tutorial/example_deep_finance/judge/traceability/prompt.py b/tutorial/example_deep_finance/judge/traceability/prompt.py new file mode 100644 index 00000000..e8d7d1bd --- /dev/null +++ b/tutorial/example_deep_finance/judge/traceability/prompt.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +""" +Traceability & Verifiability Reward (TVR) + +目标: +- 用“可追溯性/可核验性”替代“引用是否存在”的Reference类reward; +- 避免强金融领域绑定:面向任何深度研究报告,只要有“证据(工具结果/对话上下文)+ 报告文本”即可工作; +- 通过“断言-证据锚点”审计,奖励:事实陈述可在证据中找到锚点、或明确标注为推测;惩罚:无证据的硬断言、与证据矛盾、过度精确的数值。 + +注意:该文件仅包含 prompt,不包含打分逻辑。打分由 grader.py 依据模型输出的结构化审计结果计算。 +""" + +TRACEABILITY_SYSTEM_PROMPT = r""" +# 你的身份 +你是一名“可追溯性/可核验性审计官(Traceability Auditor)”。 + +# 你的目标 +给定: +- 用户问题(User Question) +- 证据区(Evidence):包含对话中工具调用与工具返回的原文片段(视为“可用证据全集”) +- 待审计报告(AI Report):模型写出的最终 Markdown 报告 + +你需要评估:报告中的“可核验断言”是否能在 Evidence 中找到明确锚点(traceable),或是否被正确标注为“推测/假设”。 + +# 核心原则(非常重要) +1) **Evidence 是唯一事实来源**:不得使用外部常识/训练记忆补全缺失证据。 +2) **先举证再下结论**:输出结构中必须先给出断言与证据锚点/不匹配点,再汇总统计;不要先给分数再找理由。 +3) **惩罚“硬断言无证据”**:越具体(数字、日期、比例、排名、同比环比、绝对结论)的断言越需要证据锚点。 +4) **允许“推测/假设”**:若报告明确使用“可能/预计/推测/假设/大概率”等表述,并且没有把它包装成确定事实,则可以判为 speculative_ok(弱奖励/不惩罚)。 +5) **优先覆盖“数字/日期/实体”断言**:必须覆盖报告正文中出现的每一个数字或日期(含表格);因为这是最容易出现“编造”的区域。 +6) **不要评估写作质量**(结构/文风/可读性等不在本任务范围),只评估“可追溯/可核验”。 + +# 你要产出的 JSON(严格 JSON,不要 markdown,不要多余文本) +输出 JSON 需要包含: +- claims:断言列表,每条必须包含断言原文、锚点要素(实体/数值/时间)、证据锚点(step+quote)、判定与原因 +- stats:统计汇总(先统计,再由外部计算分数) +- examples:最多各2条“最好的支持案例”和“最差的失败案例”(用于调试) + +断言(claim)的判定(verdict)只能是: +- supported:Evidence 中有明确锚点支撑(实体/时间/数值关键点对应) +- contradicted:Evidence 中存在明确冲突(数值/时间/事实相反) +- no_evidence:找不到相关证据锚点,且该断言是硬断言 +- speculative_ok:断言被明确标注为推测/假设,且未伪装成事实 +- unclear:Evidence 有相关但不足以确定支持/反驳(模糊、缺关键字段、只部分匹配) + +issue(主要问题)建议从下面枚举中选择一个: +- none | entity_mismatch | time_mismatch | value_mismatch | scope_mismatch | logic_leap | over_precision | missing_anchor + +额外要求: +- 每条 claim 的 note ≤ 80 字(给出关键理由即可) +- evidence_quote ≤ 120 字,必须是 Evidence 中的原文片段(可截断) +""" + +# NOTE: 该模板会被 json_utils.construct_reward_prompt 填充 {user_query} {evidence_text} {final_report} +TRACEABILITY_USER_PROMPT_TEMPLATE = r""" +请对下面的 AI Report 做“可追溯性/可核验性审计”,并严格按要求输出 JSON。 + +## User Question +{user_query} + +## Evidence +{evidence_text} + +## AI Report +{final_report} + +### 审计流程(必须执行) +1) 仅审计 **AI Report 正文**(忽略其 `## References` 及之后内容)。 +2) 抽取“可核验断言”: + - 必须包含:所有出现“数字/日期”的句子或表格行(逐条拆成原子断言) + - 另外补充:3–8条非数字但可核验的硬事实(涉及具体实体/事件/定义/比较/因果的断言) +3) 对每条断言: + - 提取锚点要素:entities / numbers / times(可以为空列表,但含数字/日期的断言不得为空) + - 在 Evidence 中找到最相关的 1–2 个锚点(用 step 序号 + 原文 quote 表示) + - 给出 verdict + issue + note(简短指出匹配/不匹配的关键点) +4) 最后汇总 stats 与 examples(不要给分数)。 + +### 输出 JSON 结构(严格遵守字段名;不要新增顶层字段) +{{ + "claims": [ + {{ + "claim": "从报告中复制的原句或原子断言(尽量短)", + "type": "quant|event|definition|comparison|causal|recommendation|other", + "signature": {{ + "entities": ["..."], + "numbers": ["..."], + "times": ["..."] + }}, + "anchors": [ + {{"step": 12, "quote": "Evidence 原文片段..."}}, + {{"step": 13, "quote": "Evidence 原文片段..."}} + ], + "verdict": "supported|contradicted|no_evidence|speculative_ok|unclear", + "issue": "none|entity_mismatch|time_mismatch|value_mismatch|scope_mismatch|logic_leap|over_precision|missing_anchor", + "note": "≤80字,说明为何这样判定" + }} + ], + "stats": {{ + "total_claims": 0, + "supported": 0, + "contradicted": 0, + "no_evidence": 0, + "speculative_ok": 0, + "unclear": 0, + "report_digit_tokens": 0, + "covered_digit_tokens": 0 + }}, + "examples": {{ + "best_supported": [ + {{"claim": "...", "anchor": {{"step": 0, "quote": "..."}}}} + ], + "worst_failed": [ + {{"claim": "...", "why": "..." }} + ] + }} +}} + +### 统计口径(必须一致) +- report_digit_tokens:你在报告正文中识别到的“数字/日期 token”的数量(近似即可;如 1330亿美元、13.7%、2025-09-30 各算 1 个 token) +- covered_digit_tokens:这些 token 中,有多少出现在你提取的 claims 的 signature.numbers 或 signature.times 里(近似即可) +- total_claims 必须等于 claims 的条数;其余计数必须与 claims 中 verdict 的统计一致 +""" diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 33103fe3..c4b950f5 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -12,6 +12,9 @@ ajet: # OpenJudge 权重配置 presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 + cgcv_weight: {{CGCV_WEIGHT}} # Citation-Grounded Claim Verification + audit_weight: {{AUDIT_WEIGHT}} # 引用逻辑审计 + traceability_weight: {{TRACEABILITY_WEIGHT}} # 可追溯性/可核验性审计 (TVR) rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) diff --git a/tutorial/example_deep_finance/yaml_template/infer.yaml b/tutorial/example_deep_finance/yaml_template/infer.yaml new file mode 100644 index 00000000..c86832e0 --- /dev/null +++ b/tutorial/example_deep_finance/yaml_template/infer.yaml @@ -0,0 +1,87 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: "{{PREFIX}}" + experiment_name: "{{SUFFIX}}" + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 + rm_llm: {{RM_LLM}} # RM Gallery 模型 + concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 + train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 + val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 + # OpenJudge 权重配置 + presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 + grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 + rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 + task_judge: + # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: {{MODEL_PATH}} + trainer_common: + nnodes: {{NNODES}} + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 4 + save_freq: 10 + test_freq: 2 + total_epochs: {{TOTAL_EPOCHS}} + save_trajectory_as_json_file: True + rollout: + # ✨✨✨✨ 编写并选择Agent + user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol + force_disable_toolcalls: False + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: {{NUM_REPEAT}} + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_response_length_in_one_turn: 8000 + max_model_len: 50000 + agent_madness_reward: 0.0 + compute_madness_checklist: None + multi_turn: + max_steps: {{NUM_STEPS}} + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + debug: + debug_max_parallel: 1 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + train_batch_size: {{TRAIN_BATCH_SIZE}} + max_prompt_length: 8000 + max_response_length: 41000 + + task_reader: + type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + deep_finance: + training: + file_path: {{TRAIN_DATA_PATH}} + validation: + file_path: {{VAL_DATA_PATH}} + # env_service 仍需配置(用于工具调用) + env_service: + env_type: "finworld" + env_url: {{ENV_SERVICE_URL}} + env_action_preference: code +trainer: + default_local_dir: "{{CKPT_SAVE_PATH}}/{{PREFIX}}/{{SUFFIX}}" + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ From 4722a799e17ce154bd565839773c099bf56ddf77 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Wed, 11 Feb 2026 10:28:13 +0800 Subject: [PATCH 52/56] =?UTF-8?q?chore(deepfinance):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=20EBTU=20=E8=AF=81=E6=8D=AE=E4=BC=98=E5=85=88=E5=8F=AF?= =?UTF-8?q?=E8=BF=BD=E6=BA=AF=E6=80=A7=E5=AE=A1=E8=AE=A1=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 EBTUTraceabilityGrader 并集成入 DeepFinanceJudge 权重配置 - deep_finance.yaml 配置最大模型长度调整为 40960 - 脚本 deep_finance.sh 和 deep_finance_single.sh 中增加 EBTU 及相关权重配置 - 完善 deep_finance_single.sh 单机调试日志及目录结构 - 深度完善 audit、cgcv、traceability json 解析,增加对常见 JSON 格式错误的自动修复 - audit grader 中移除对模型输出 integrity_score 的依赖,采用手动计算方式 - 禁用 ExampleDeepResearchProtocol 中部分工具统计日志输出,增加线程信号量限制 - 调整提示和 yaml 模板,新增 EBTU 权重占位符,完善配置文件生成日志显示 --- tutorial/example_deep_finance/deep_finance.py | 8 +- tutorial/example_deep_finance/deep_finance.sh | 10 +- .../example_deep_finance/deep_finance.yaml | 2 +- .../deep_finance_judge.py | 10 +- .../deep_finance_single.sh | 28 +- .../example_deep_finance/judge/__init__.py | 3 +- .../judge/audit/grader.py | 10 +- .../judge/audit/json_utils.py | 88 +++- .../judge/cgcv/json_utils.py | 85 +++- .../judge/ebtu/__init__.py | 1 + .../example_deep_finance/judge/ebtu/grader.py | 154 ++++++ .../judge/ebtu/json_utils.py | 455 ++++++++++++++++++ .../example_deep_finance/judge/ebtu/prompt.py | 157 ++++++ .../judge/traceability/grader.py | 43 +- .../judge/traceability/json_utils.py | 96 +++- .../prompt/tool_prompt_builder.py | 5 - .../yaml_template/deep_finance_template.yaml | 1 + .../deep_finance_template_maxlen.yaml | 91 ++++ .../yaml_template/infer.yaml | 8 +- 19 files changed, 1208 insertions(+), 47 deletions(-) create mode 100644 tutorial/example_deep_finance/judge/ebtu/__init__.py create mode 100644 tutorial/example_deep_finance/judge/ebtu/grader.py create mode 100644 tutorial/example_deep_finance/judge/ebtu/json_utils.py create mode 100644 tutorial/example_deep_finance/judge/ebtu/prompt.py create mode 100644 tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml diff --git a/tutorial/example_deep_finance/deep_finance.py b/tutorial/example_deep_finance/deep_finance.py index 470e6225..baffb0b3 100644 --- a/tutorial/example_deep_finance/deep_finance.py +++ b/tutorial/example_deep_finance/deep_finance.py @@ -9,7 +9,7 @@ # 创建信号量,允许同时12个线程运行 -sem = threading.Semaphore(30) +sem = threading.Semaphore(60) class ExampleDeepResearchProtocol(Workflow): @@ -125,9 +125,9 @@ async def execute( if info: if 'tool_stats' in info: latest_tool_stats = info['tool_stats'] - if latest_tool_stats.get('total_calls', 0) > 0: - logger.info(f"步骤 {step + 1} 工具统计: 调用={latest_tool_stats.get('total_calls', 0)}, " - f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") + # if latest_tool_stats.get('total_calls', 0) > 0: + # logger.info(f"步骤 {step + 1} 工具统计: 调用={latest_tool_stats.get('total_calls', 0)}, " + # f"成功率={latest_tool_stats.get('success_rate', 0):.1f}%") if 'reward_stats' in info: latest_reward_stats = info['reward_stats'] # 累加工具调用时间 diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index de1aa061..f6121655 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -15,6 +15,10 @@ JUDGE_CONCURRENCY=10 RM_WEIGHT=0.5 PRESENTATION_QUALITY_WEIGHT=0.25 GROUNDING_WEIGHT=0.25 +CGCV_WEIGHT=0.0 # 不使用 CGCV,设为 0 +AUDIT_WEIGHT=0.0 # 不使用 Audit,设为 0 +TRACEABILITY_WEIGHT=0.0 # 不使用 Traceability,设为 0 +EBTU_WEIGHT=0.0 # 不使用 EBTU,设为 0 # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -60,6 +64,10 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ + -e "s|{{CGCV_WEIGHT}}|${CGCV_WEIGHT}|g" \ + -e "s|{{AUDIT_WEIGHT}}|${AUDIT_WEIGHT}|g" \ + -e "s|{{TRACEABILITY_WEIGHT}}|${TRACEABILITY_WEIGHT}|g" \ + -e "s|{{EBTU_WEIGHT}}|${EBTU_WEIGHT}|g" \ -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ @@ -75,7 +83,7 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" +echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, CGCV=${CGCV_WEIGHT}, Audit=${AUDIT_WEIGHT}, Traceability=${TRACEABILITY_WEIGHT}, EBTU=${EBTU_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" #=============================================================================== # 3. 环境配置 diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml index 33103fe3..1f2122ba 100644 --- a/tutorial/example_deep_finance/deep_finance.yaml +++ b/tutorial/example_deep_finance/deep_finance.yaml @@ -38,7 +38,7 @@ ajet: max_env_worker: 64 # 增加环境并行数 max_num_seqs: 64 # 增加VLLM并发序列数 max_response_length_in_one_turn: 8000 - max_model_len: 50000 + max_model_len: 40960 agent_madness_reward: 0.0 compute_madness_checklist: None multi_turn: diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index b4c9c96f..0548b385 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -15,7 +15,7 @@ from openjudge.models.openai_chat_model import OpenAIChatModel from openjudge.runner.grading_runner import GraderConfig, GradingRunner -from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader, CGCVGrader, AuditGrader, TraceabilityRewardGrader +from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader, CGCVGrader, AuditGrader, TraceabilityRewardGrader, EBTUTraceabilityGrader @@ -105,8 +105,9 @@ def _setup_weights(self): "presentation_quality": getattr(cfg, "presentation_quality_weight", 0.25) if cfg else 0.25, "grounding": getattr(cfg, "grounding_weight", 0.0) if cfg else 0.0, # 引用规范性评估 "cgcv": getattr(cfg, "cgcv_weight", 0.25) if cfg else 0.25, # Citation-Grounded Claim Verification - "audit": getattr(cfg, "audit_weight", 0.0) if cfg else 0.0, # 引用逻辑审计 + "audit": getattr(cfg, "audit_weight", 0.0) if cfg else 0.0, # Audit Grader: audit reward 引用逻辑审计 "traceability": getattr(cfg, "traceability_weight", 0.0) if cfg else 0.0, # 可追溯性/可核验性审计 (TVR) + "ebtu": getattr(cfg, "ebtu_weight", 0.0) if cfg else 0.0, # Audit Grader: audit reward EBTU证据优先可追溯性审计 } # 归一化(注意:action_loop 是惩罚项,不参与归一化;rm 需要参与归一化) @@ -274,6 +275,11 @@ def extract_report_content(data: Dict) -> str: grader=TraceabilityRewardGrader(model=model), mapper=lambda data: {"traj": data}, ), + # Audit Grader: audit reward EBTU证据优先可追溯性审计 - Evidence-Backed Trace Units + "ebtu": GraderConfig( + grader=EBTUTraceabilityGrader(model=model), + mapper=lambda data: {"traj": data}, + ), } def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> Tuple[float, bool]: diff --git a/tutorial/example_deep_finance/deep_finance_single.sh b/tutorial/example_deep_finance/deep_finance_single.sh index cc5c8a00..67de294d 100644 --- a/tutorial/example_deep_finance/deep_finance_single.sh +++ b/tutorial/example_deep_finance/deep_finance_single.sh @@ -15,6 +15,10 @@ JUDGE_CONCURRENCY=10 RM_WEIGHT=0.5 PRESENTATION_QUALITY_WEIGHT=0.25 GROUNDING_WEIGHT=0.25 +CGCV_WEIGHT=0.0 # 不使用 CGCV,设为 0 +AUDIT_WEIGHT=0.0 # 不使用 Audit,设为 0 +TRACEABILITY_WEIGHT=0.0 # 不使用 Traceability,设为 0 +EBTU_WEIGHT=0.0 # 不使用 EBTU,设为 0 # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -28,7 +32,13 @@ ENV_SERVICE_URL="http://127.0.0.1:8080" # 环境服务地址 # 主目录(需要更改) export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" -NNODES=${WORLD_SIZE} +# 单机调试配置(默认值) +NNODES=${WORLD_SIZE:-1} +GPUS_PER_NODE=8 +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" +mkdir -p ${LOG_DIR} # 涉密的配置(API_KEY以及模型、数据位置)从.env读取 cd ${AJET_ROOT} @@ -45,6 +55,9 @@ else echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" fi +export MODEL_PATH="/mnt/data_cpfs/taoshuchang.tsc/models/Qwen3-8B" + + #=============================================================================== # 2. 动态生成配置文件 (从yaml template生成yaml) #=============================================================================== @@ -60,6 +73,10 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ + -e "s|{{CGCV_WEIGHT}}|${CGCV_WEIGHT}|g" \ + -e "s|{{AUDIT_WEIGHT}}|${AUDIT_WEIGHT}|g" \ + -e "s|{{TRACEABILITY_WEIGHT}}|${TRACEABILITY_WEIGHT}|g" \ + -e "s|{{EBTU_WEIGHT}}|${EBTU_WEIGHT}|g" \ -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ @@ -75,7 +92,7 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" +echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, CGCV=${CGCV_WEIGHT}, Audit=${AUDIT_WEIGHT}, Traceability=${TRACEABILITY_WEIGHT}, EBTU=${EBTU_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" #=============================================================================== @@ -119,15 +136,16 @@ export RAY_CLUSTER_MODE="multi_node" #=============================================================================== # 6. 主流程 #=============================================================================== -log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" -mkdir -p ${LOG_DIR} -mkdir -p $(dirname ${CONFIG_FILE}) +log "单机调试模式: NNODES=${NNODES}, GPUS_PER_NODE=${GPUS_PER_NODE}" #=============================================================================== # 6.1 Master 节点启动流程 #=============================================================================== # 启动训练任务(最核心) +# 请注意只有单节点需要--with-ray 多节点应该删除 python ajet/launcher.py \ --conf ${CONFIG_FILE} \ + --with-deepfinance \ + --with-ray \ --backbone="debug" \ 2>&1 | tee ${TRAIN_LOG} diff --git a/tutorial/example_deep_finance/judge/__init__.py b/tutorial/example_deep_finance/judge/__init__.py index c2aee2be..235247f9 100644 --- a/tutorial/example_deep_finance/judge/__init__.py +++ b/tutorial/example_deep_finance/judge/__init__.py @@ -4,6 +4,7 @@ from .cgcv.grader import CGCVGrader from .audit.grader import AuditGrader from .traceability.grader import TraceabilityRewardGrader +from .ebtu.grader import EBTUTraceabilityGrader # from .research_depth.grader import ResearchDepthGrader # from .research_breadth.grader import ResearchBreadthGrader @@ -11,4 +12,4 @@ # from .grounding.grader import GroundingGrader # from .research_breadth.grader import ResearchBreadthGrader # __all__ = ["PresentationQualityGrader", "GroundingGrader", "ResearchDepthGrader", "ResearchBreadthGrader"] -__all__ = ["PresentationQualityGrader", "GroundingGrader", "CGCVGrader", "AuditGrader", "TraceabilityRewardGrader"] +__all__ = ["PresentationQualityGrader", "GroundingGrader", "CGCVGrader", "AuditGrader", "TraceabilityRewardGrader", "EBTUTraceabilityGrader"] diff --git a/tutorial/example_deep_finance/judge/audit/grader.py b/tutorial/example_deep_finance/judge/audit/grader.py index 18c0e397..e7d7f5c9 100644 --- a/tutorial/example_deep_finance/judge/audit/grader.py +++ b/tutorial/example_deep_finance/judge/audit/grader.py @@ -184,11 +184,11 @@ def _compute_scores(self, obj: Dict[str, Any]) -> Tuple[float, str]: supported_count = verdict_counts["Supported"] # 优先使用模型输出的 score,如果有误则回退到手动计算 - model_score = obj.get("integrity_score") - if isinstance(model_score, (float, int)) and 0.0 <= model_score <= 1.0: - final_score = float(model_score) - else: - final_score = supported_count / total_citations if total_citations > 0 else 0.0 + # model_score = obj.get("integrity_score") + # if isinstance(model_score, (float, int)) and 0.0 <= model_score <= 1.0: + # final_score = float(model_score) + # else: + final_score = supported_count / total_citations if total_citations > 0 else 0.0 # 构建 Reason # 格式: Score: 0.80 | Total: 10 | Supp: 8, Over: 1, Hallu: 1 | Summary: ... diff --git a/tutorial/example_deep_finance/judge/audit/json_utils.py b/tutorial/example_deep_finance/judge/audit/json_utils.py index 394aaefc..11e157ae 100644 --- a/tutorial/example_deep_finance/judge/audit/json_utils.py +++ b/tutorial/example_deep_finance/judge/audit/json_utils.py @@ -15,16 +15,102 @@ def extract_first_json_object(text: str) -> str | None: return None return m.group(0) + +def _repair_json(js: str) -> str: + """ + 尝试修复常见的JSON格式错误 + 1. 修复字符串中未转义的换行符 + 2. 修复trailing comma + 3. 修复缺少的逗号 + 4. 修复不完整的JSON(截断) + """ + # 1. 替换字符串值中的未转义换行符 + # 这是最常见的问题:LLM在字符串中直接输出换行而非 \n + def escape_newlines_in_strings(s: str) -> str: + result = [] + in_string = False + escape_next = False + i = 0 + while i < len(s): + c = s[i] + if escape_next: + result.append(c) + escape_next = False + elif c == '\\': + result.append(c) + escape_next = True + elif c == '"': + result.append(c) + in_string = not in_string + elif in_string and c == '\n': + result.append('\\n') + elif in_string and c == '\r': + result.append('\\r') + elif in_string and c == '\t': + result.append('\\t') + else: + result.append(c) + i += 1 + return ''.join(result) + + js = escape_newlines_in_strings(js) + + # 2. 移除trailing comma: ",}" -> "}" 和 ",]" -> "]" + js = re.sub(r',\s*}', '}', js) + js = re.sub(r',\s*]', ']', js) + + # 3. 尝试修复截断的JSON - 补全缺失的括号 + # 统计括号数量 + open_braces = js.count('{') + close_braces = js.count('}') + open_brackets = js.count('[') + close_brackets = js.count(']') + + # 如果括号不匹配,尝试补全 + if open_braces > close_braces: + # 先关闭可能未闭合的字符串 + # 检查最后是否在字符串中 + in_string = False + escape_next = False + for c in js: + if escape_next: + escape_next = False + elif c == '\\': + escape_next = True + elif c == '"': + in_string = not in_string + if in_string: + js += '"' + + # 补全缺失的括号 + js += ']' * (open_brackets - close_brackets) + js += '}' * (open_braces - close_braces) + + return js + + def strict_load_json(text: str) -> Tuple[Dict[str, Any] | None, str | None]: js = extract_first_json_object(text) if js is None: return None, "No JSON object found" + + # 第一次尝试:直接解析 try: obj = json.loads(js) if not isinstance(obj, dict): return None, f"Root is not dict: {type(obj)}" return obj, None - except Exception as e: + except json.JSONDecodeError: + pass # 继续尝试修复 + + # 第二次尝试:修复后解析 + try: + repaired = _repair_json(js) + obj = json.loads(repaired) + if not isinstance(obj, dict): + return None, f"Root is not dict: {type(obj)}" + return obj, None + except json.JSONDecodeError as e: return None, f"JSONDecodeError: {str(e)}" def validate_integrity_shape(obj: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]: diff --git a/tutorial/example_deep_finance/judge/cgcv/json_utils.py b/tutorial/example_deep_finance/judge/cgcv/json_utils.py index 965bf301..48cb59aa 100644 --- a/tutorial/example_deep_finance/judge/cgcv/json_utils.py +++ b/tutorial/example_deep_finance/judge/cgcv/json_utils.py @@ -33,6 +33,78 @@ class ClaimStatus(str, Enum): _JSON_RE = re.compile(r"\{.*\}", re.DOTALL) +# ============================================================================= +# JSON Repair Helper +# ============================================================================= + +def _repair_json(js: str) -> str: + """ + 尝试修复常见的JSON格式错误 + 1. 修复字符串中未转义的换行符 + 2. 修复trailing comma + 3. 修复不完整的JSON(截断) + """ + # 1. 替换字符串值中的未转义换行符 + def escape_newlines_in_strings(s: str) -> str: + result = [] + in_string = False + escape_next = False + i = 0 + while i < len(s): + c = s[i] + if escape_next: + result.append(c) + escape_next = False + elif c == '\\': + result.append(c) + escape_next = True + elif c == '"': + result.append(c) + in_string = not in_string + elif in_string and c == '\n': + result.append('\\n') + elif in_string and c == '\r': + result.append('\\r') + elif in_string and c == '\t': + result.append('\\t') + else: + result.append(c) + i += 1 + return ''.join(result) + + js = escape_newlines_in_strings(js) + + # 2. 移除trailing comma: ",}" -> "}" 和 ",]" -> "]" + js = re.sub(r',\s*}', '}', js) + js = re.sub(r',\s*]', ']', js) + + # 3. 尝试修复截断的JSON - 补全缺失的括号 + open_braces = js.count('{') + close_braces = js.count('}') + open_brackets = js.count('[') + close_brackets = js.count(']') + + if open_braces > close_braces: + # 先关闭可能未闭合的字符串 + in_string = False + escape_next = False + for c in js: + if escape_next: + escape_next = False + elif c == '\\': + escape_next = True + elif c == '"': + in_string = not in_string + if in_string: + js += '"' + + # 补全缺失的括号 + js += ']' * (open_brackets - close_brackets) + js += '}' * (open_braces - close_braces) + + return js + + # ============================================================================= # Data Classes # ============================================================================= @@ -126,7 +198,7 @@ def extract_first_json_object(text: str) -> Optional[str]: def strict_load_json(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: """ - 严格解析 JSON + 严格解析 JSON(带容错修复) Args: text: 原始文本 @@ -138,11 +210,22 @@ def strict_load_json(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str] if js is None: return None, "No JSON object found in model output" + # 第一次尝试:直接解析 try: obj = json.loads(js) if not isinstance(obj, dict): return None, f"Top-level JSON is not an object: {type(obj).__name__}" return obj, None + except json.JSONDecodeError: + pass # 继续尝试修复 + + # 第二次尝试:修复后解析 + try: + repaired = _repair_json(js) + obj = json.loads(repaired) + if not isinstance(obj, dict): + return None, f"Top-level JSON is not an object: {type(obj).__name__}" + return obj, None except json.JSONDecodeError as e: return None, f"JSONDecodeError: {e}" except Exception as e: diff --git a/tutorial/example_deep_finance/judge/ebtu/__init__.py b/tutorial/example_deep_finance/judge/ebtu/__init__.py new file mode 100644 index 00000000..86ba0083 --- /dev/null +++ b/tutorial/example_deep_finance/judge/ebtu/__init__.py @@ -0,0 +1 @@ +# ebtu_reward package diff --git a/tutorial/example_deep_finance/judge/ebtu/grader.py b/tutorial/example_deep_finance/judge/ebtu/grader.py new file mode 100644 index 00000000..6ecea50a --- /dev/null +++ b/tutorial/example_deep_finance/judge/ebtu/grader.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import Any, Dict, Optional + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import EBTU_SYSTEM_PROMPT, EBTU_USER_PROMPT_TEMPLATE +from .json_utils import ( + strict_load_json, + validate_shape, + coerce_to_messages_list, + construct_ebtu_prompt, + count_digit_tokens, +) + + +class EBTUTraceabilityGrader(BaseGrader): + """ + Evidence-Backed Trace Units (EBTU) Grader + + Input: + - traj or record JSON that contains trajectory messages + + Output: + - GraderScore(score in [0,1], reason with compact stats) + """ + + def __init__( + self, + model: Optional[OpenAIChatModel] = None, + name: str = "ebtu_traceability", + temperature: float = 0.0, + max_tokens: int = 2600, + model_name: str = "qwen-flash", + ) -> None: + super().__init__(name=name) + self.model = model or OpenAIChatModel( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + response_format={"type": "json_object"}, + ) + + async def aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: + messages = coerce_to_messages_list(traj) + + # 输入有效性检查 + if not messages: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty or invalid trajectory", + ) + + user_prompt, report_plain = construct_ebtu_prompt(messages, EBTU_USER_PROMPT_TEMPLATE) + + judge_messages = [ + {"role": "system", "content": EBTU_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + + # 模型调用(带异常保护) + try: + resp = await self.model.achat(judge_messages) + raw_text = getattr(resp, "content", None) + if raw_text is None: + raw_text = str(resp) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + + try: + obj = strict_load_json(str(raw_text)) + norm = validate_shape(obj) + + score = self._compute_score(norm, report_plain) + reason = self._build_reason(norm, report_plain, score) + return GraderScore(name=self.name, score=score, reason=reason) + except Exception as e: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore(name=self.name, score=0.0, reason=f"EBTU parse error: {e}; raw[:200]={snippet}") + + def _compute_score(self, norm: Dict[str, Any], report_plain: str) -> float: + stats = norm["stats"] + units = norm["units"] + + hard_total = max(1, int(stats.get("hard_units", 0))) + supported = int(stats.get("supported", 0)) + contradicted = int(stats.get("contradicted", 0)) + no_evidence = int(stats.get("no_evidence", 0)) + unclear = int(stats.get("unclear", 0)) + misattrib = int(stats.get("misattrib", 0)) + anchored_hard = max(1, int(stats.get("anchored_hard_units", 0))) + + # Base: reward supported; penalize contradicted/no_evidence strongly; unclear mildly + base = (supported - 1.4 * contradicted - 0.9 * no_evidence - 0.4 * unclear) / hard_total + base = max(0.0, min(1.0, base)) + + # Misattribution penalty: anchors exist but not supported (wrong anchor / wrong use) + misattrib_rate = misattrib / anchored_hard + misattrib_factor = max(0.0, 1.0 - 0.7 * misattrib_rate) + + # Deterministic coverage heuristics based on report digit tokens + digit_tokens = count_digit_tokens(report_plain) + expected_min_units = min(25, max(6, digit_tokens // 2)) + extracted_units = max(1, len(units)) + selection_factor = min(1.0, extracted_units / expected_min_units) if expected_min_units > 0 else 1.0 + + # Optional judge-reported digit/date coverage (soft) + reported_total = int(stats.get("report_digit_date_tokens", 0)) + reported_cov = int(stats.get("covered_digit_date_tokens", 0)) + if reported_total > 0: + cov_ratio = max(0.0, min(1.0, reported_cov / reported_total)) + else: + cov_ratio = 1.0 + cov_factor = 0.65 + 0.35 * cov_ratio # [0.65, 1.0] + + score = base * misattrib_factor * selection_factor * cov_factor + return float(max(0.0, min(1.0, score))) + + def _build_reason(self, norm: Dict[str, Any], report_plain: str, score: float) -> str: + s = norm["stats"] + ex = norm.get("examples", {}) + best = ex.get("best_supported", []) + worst = ex.get("worst_failed", []) + digit_tokens = count_digit_tokens(report_plain) + + parts = [ + f"score={score:.3f}", + f"units={s['total_units']}", + f"hard={s['hard_units']}", + f"sup={s['supported']}", + f"ctr={s['contradicted']}", + f"noev={s['no_evidence']}", + f"unc={s['unclear']}", + f"anch_hard={s['anchored_hard_units']}", + f"misattrib={s['misattrib']}", + f"report_digits≈{digit_tokens}", + ] + if best: + parts.append(f"best={best[:1]}") + if worst: + parts.append(f"worst={worst[:1]}") + return " | ".join(parts) diff --git a/tutorial/example_deep_finance/judge/ebtu/json_utils.py b/tutorial/example_deep_finance/judge/ebtu/json_utils.py new file mode 100644 index 00000000..69b22b13 --- /dev/null +++ b/tutorial/example_deep_finance/judge/ebtu/json_utils.py @@ -0,0 +1,455 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Tuple + + +# ============================================================================= +# JSON Repair Helper +# ============================================================================= + +def _repair_json(js: str) -> str: + """ + 尝试修复常见的JSON格式错误 + 1. 修复字符串中未转义的换行符 + 2. 修复trailing comma + 3. 修复不完整的JSON(截断) + """ + # 1. 替换字符串值中的未转义换行符 + def escape_newlines_in_strings(s: str) -> str: + result = [] + in_string = False + escape_next = False + i = 0 + while i < len(s): + c = s[i] + if escape_next: + result.append(c) + escape_next = False + elif c == '\\': + result.append(c) + escape_next = True + elif c == '"': + result.append(c) + in_string = not in_string + elif in_string and c == '\n': + result.append('\\n') + elif in_string and c == '\r': + result.append('\\r') + elif in_string and c == '\t': + result.append('\\t') + else: + result.append(c) + i += 1 + return ''.join(result) + + js = escape_newlines_in_strings(js) + + # 2. 移除trailing comma: ",}" -> "}" 和 ",]" -> "]" + js = re.sub(r',\s*}', '}', js) + js = re.sub(r',\s*]', ']', js) + + # 3. 尝试修复截断的JSON - 补全缺失的括号 + open_braces = js.count('{') + close_braces = js.count('}') + open_brackets = js.count('[') + close_brackets = js.count(']') + + if open_braces > close_braces: + # 先关闭可能未闭合的字符串 + in_string = False + escape_next = False + for c in js: + if escape_next: + escape_next = False + elif c == '\\': + escape_next = True + elif c == '"': + in_string = not in_string + if in_string: + js += '"' + + # 补全缺失的括号 + js += ']' * (open_brackets - close_brackets) + js += '}' * (open_braces - close_braces) + + return js + + +def strict_load_json(text: str) -> Dict[str, Any]: + """Parse a JSON object from model output; extract first {...} block if needed. 带容错修复。""" + text = (text or "").strip() + + # 第一次尝试:直接解析 + try: + obj = json.loads(text) + if isinstance(obj, dict): + return obj + except Exception: + pass + + # 尝试提取 {...} 片段 + start = text.find("{") + end = text.rfind("}") + if start != -1 and end != -1 and end > start: + snippet = text[start : end + 1] + # 第二次尝试:直接解析提取的片段 + try: + obj = json.loads(snippet) + if isinstance(obj, dict): + return obj + except Exception: + pass + + # 第三次尝试:修复后解析 + try: + repaired = _repair_json(snippet) + obj = json.loads(repaired) + if isinstance(obj, dict): + return obj + except Exception: + pass + + raise ValueError("Invalid JSON output") + + +def _clip(s: str, n: int) -> str: + s = s or "" + s = s.replace("\u0000", "") + return s[:n] + + +def _as_int(x: Any, default: int = 0) -> int: + try: + if x is None: + return default + if isinstance(x, bool): + return int(x) + if isinstance(x, (int, float)): + return int(x) + if isinstance(x, str) and x.strip(): + return int(float(x.strip())) + except Exception: + return default + return default + + +def _as_list_str(x: Any, max_items: int = 10, max_len: int = 60) -> List[str]: + if not isinstance(x, list): + return [] + out: List[str] = [] + for item in x[:max_items]: + if isinstance(item, str): + out.append(_clip(item, max_len)) + else: + out.append(_clip(str(item), max_len)) + return out + + +def validate_shape(obj: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate and normalize model output for EBTU. + + Returns: + {"units": [...], "stats": {...}, "examples": {...}} + """ + if not isinstance(obj, dict): + raise ValueError("Output is not a JSON object") + + units_raw = obj.get("units", []) + if not isinstance(units_raw, list): + units_raw = [] + + units: List[Dict[str, Any]] = [] + for u in units_raw[:30]: + if not isinstance(u, dict): + continue + + claim = _clip(str(u.get("claim", "")), 280) + hardness = _clip(str(u.get("hardness", "hard")), 8) + if hardness not in {"hard", "soft"}: + hardness = "hard" + + utype = _clip(str(u.get("type", "other")), 24) + + sig = u.get("signature", {}) + if not isinstance(sig, dict): + sig = {} + entities = _as_list_str(sig.get("entities", []), max_items=10, max_len=60) + numbers = _as_list_str(sig.get("numbers", []), max_items=10, max_len=40) + times = _as_list_str(sig.get("times", []), max_items=10, max_len=40) + + ev = u.get("evidence", {}) + if not isinstance(ev, dict): + ev = {} + anchors_raw = ev.get("anchors", []) + anchors: List[Dict[str, Any]] = [] + if isinstance(anchors_raw, list): + for a in anchors_raw[:2]: + if not isinstance(a, dict): + continue + step = _as_int(a.get("step", -1), default=-1) + quote = _clip(str(a.get("quote", "")), 120) + if step >= 0 and quote: + anchors.append({"step": step, "quote": quote}) + anchor_note = _clip(str(ev.get("anchor_note", "")), 60) + + ver = u.get("verification", {}) + if not isinstance(ver, dict): + ver = {} + + verdict = _clip(str(ver.get("verdict", "unclear")), 20) + if verdict not in {"supported", "contradicted", "no_evidence", "speculative_ok", "unclear"}: + verdict = "unclear" + + issue = _clip(str(ver.get("issue", "none")), 20) + allowed_issues = { + "none", "entity_mismatch", "time_mismatch", "value_mismatch", "scope_mismatch", + "logic_leap", "over_precision", "missing_anchor" + } + if issue not in allowed_issues: + issue = "none" + + note = _clip(str(ver.get("note", "")), 80) + + units.append({ + "claim": claim, + "hardness": hardness, + "type": utype, + "signature": {"entities": entities, "numbers": numbers, "times": times}, + "evidence": {"anchors": anchors, "anchor_note": anchor_note}, + "verification": {"verdict": verdict, "issue": issue, "note": note}, + }) + + # Recompute counts (anti-gaming) + verdict_counts = {k: 0 for k in ["supported", "contradicted", "no_evidence", "speculative_ok", "unclear"]} + hard_units = 0 + anchored_hard_units = 0 + misattrib = 0 + for u in units: + v = u["verification"]["verdict"] + verdict_counts[v] += 1 + if u["hardness"] == "hard": + hard_units += 1 + if u["evidence"]["anchors"]: + anchored_hard_units += 1 + if v != "supported": + misattrib += 1 + + stats_raw = obj.get("stats", {}) + if not isinstance(stats_raw, dict): + stats_raw = {} + report_digit_date_tokens = max(0, _as_int(stats_raw.get("report_digit_date_tokens", 0), default=0)) + covered_digit_date_tokens = max(0, _as_int(stats_raw.get("covered_digit_date_tokens", 0), default=0)) + + stats = { + "total_units": len(units), + "hard_units": hard_units, + "supported": verdict_counts["supported"], + "contradicted": verdict_counts["contradicted"], + "no_evidence": verdict_counts["no_evidence"], + "speculative_ok": verdict_counts["speculative_ok"], + "unclear": verdict_counts["unclear"], + "report_digit_date_tokens": report_digit_date_tokens, + "covered_digit_date_tokens": covered_digit_date_tokens, + "anchored_hard_units": anchored_hard_units, + "misattrib": misattrib, + } + + examples_raw = obj.get("examples", {}) + if not isinstance(examples_raw, dict): + examples_raw = {} + + def _norm_list(x: Any, max_items: int = 2) -> List[Dict[str, Any]]: + if not isinstance(x, list): + return [] + out: List[Dict[str, Any]] = [] + for it in x[:max_items]: + if isinstance(it, dict): + out.append({k: _clip(str(v), 160) for k, v in list(it.items())[:4]}) + elif isinstance(it, str): + out.append({"text": _clip(it, 160)}) + return out + + examples = { + "best_supported": _norm_list(examples_raw.get("best_supported", []), 2), + "worst_failed": _norm_list(examples_raw.get("worst_failed", []), 2), + } + + return {"units": units, "stats": stats, "examples": examples} + + +def coerce_to_messages_list(traj: Any) -> List[Dict[str, Any]]: + """Accept list[dict], list[list[dict]], or dict wrapper.""" + if traj is None: + return [] + if isinstance(traj, dict): + for key in ("traj", "messages", "conversation", "steps"): + if key in traj: + return coerce_to_messages_list(traj[key]) + return [] + if isinstance(traj, list): + if not traj: + return [] + if isinstance(traj[0], list): + for inner in traj: + if isinstance(inner, list) and inner and isinstance(inner[0], dict): + return inner + return [] + if isinstance(traj[0], dict): + return traj + return [] + + +def _extract_text_content(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif isinstance(item, str): + parts.append(item) + return "\n".join([p for p in parts if p]) + return str(content) + + +def strip_references(markdown: str) -> str: + if not isinstance(markdown, str): + return "" + m = re.search(r"\n#+\s*References\b", markdown, flags=re.IGNORECASE) + if m: + return markdown[: m.start()].strip() + return markdown.strip() + + +def count_digit_tokens(text: str) -> int: + if not text: + return 0 + pats = [ + r"\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", + r"\b\d+(?:\.\d+)?%?\b", + ] + tokens: List[str] = [] + for p in pats: + tokens.extend(re.findall(p, text)) + return len(tokens) + + +def _strip_think(text: str) -> str: + """去除 ... 标签""" + if not text: + return "" + return re.sub(r".*?\s*", "", text, flags=re.S).strip() + + +def _looks_like_tool_result(text: str) -> bool: + """判断是否为工具返回结果""" + t = (text or "").strip() + if t.startswith("Tool:") or t.startswith("Result:"): + return True + if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t): + return True + if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "): + return True + return False + + +def _is_probably_final_report(text: str) -> bool: + """判断是否为最终报告""" + if not text: + return False + t = text.strip() + # 放宽条件:任一条件满足即可 + if "## References" in t or "[TASK_COMPLETED]" in t: + return True + if t.lstrip().startswith("# "): + return True + # 兼容原有逻辑 + has_markdown = ("#" in t) or ("|---" in t) or ("## " in t) + has_refs = re.search(r"#+\s*References\b", t, flags=re.IGNORECASE) is not None + return has_markdown and has_refs + + +def _extract_tool_calls_and_results(trajectory: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + items: List[Dict[str, Any]] = [] + for i, msg in enumerate(trajectory): + role = msg.get("role", "") + content = _extract_text_content(msg.get("content", "")) + if role == "assistant": + if "```json" in content and ("tool_name" in content or "tool_args" in content): + items.append({"step": i, "kind": "tool_call", "text": content}) + elif role == "tool": + items.append({"step": i, "kind": "tool_result", "text": content}) + return items + + +def construct_reward_prompt(trajectory: List[Dict[str, Any]], user_prompt_template: str) -> str: + trajectory = coerce_to_messages_list(trajectory) + + # 提取 user_query(第一个非工具结果的 user 消息) + user_query = "" + for msg in trajectory: + if msg.get("role") == "user": + raw = _extract_text_content(msg.get("content", "")) + if not _looks_like_tool_result(raw): + user_query = _strip_think(raw) + break + + # 提取 final_report(从后往前找第一个符合条件的 assistant 消息) + final_report = "" + for msg in reversed(trajectory): + if msg.get("role") == "assistant": + raw = _extract_text_content(msg.get("content", "")) + t = _strip_think(raw) + if _is_probably_final_report(t): + final_report = t + break + if not final_report: + for msg in reversed(trajectory): + if msg.get("role") == "assistant": + raw = _extract_text_content(msg.get("content", "")) + final_report = _strip_think(raw) + break + + evidence_items = _extract_tool_calls_and_results(trajectory) + evidence_lines: List[str] = [] + for it in evidence_items: + step = it["step"] + prefix = "CALL" if it["kind"] == "tool_call" else "RESULT" + evidence_lines.append(f"[{prefix} step={step}]\n{it['text']}".strip()) + evidence_text = "\n\n".join(evidence_lines).strip() + + return user_prompt_template.format( + user_query=user_query, + evidence_text=evidence_text, + final_report=final_report, + ) + + +def construct_ebtu_prompt( + trajectory: List[Dict[str, Any]], + user_prompt_template: str, +) -> Tuple[str, str]: + """ + Returns: + - user_prompt (for judge) + - report_plain (final report without References) for deterministic coverage checks + """ + user_prompt = construct_reward_prompt(trajectory, user_prompt_template) + + report_plain = "" + for marker in ("\n## Report\n", "\n## AI Report\n"): + if marker in user_prompt: + report_plain = user_prompt.split(marker, 1)[1] + break + if not report_plain: + report_plain = user_prompt + + report_plain = strip_references(report_plain) + return user_prompt, report_plain diff --git a/tutorial/example_deep_finance/judge/ebtu/prompt.py b/tutorial/example_deep_finance/judge/ebtu/prompt.py new file mode 100644 index 00000000..af1fd4ba --- /dev/null +++ b/tutorial/example_deep_finance/judge/ebtu/prompt.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +""" +EBTU Reward: Evidence-Backed Trace Units (Evidence-first Traceability) + +设计目标: +- 用“证据优先(先证据锚点、后裁决)”的审计输出,支撑可计算的 faithful / FACT-like reward; +- 不绑定金融六元组:以通用的 Trace Unit(原子断言)为核心; +- 结构化输出便于后续确定性打分,避免“先给分再圆”。 + +本文件仅包含 Prompt(System + User Template)。 +打分逻辑在 grader.py 中实现。 +""" + +EBTU_SYSTEM_PROMPT = """ +# 你的身份 +你是一名【证据优先审计官(Evidence-first Auditor)】。 + +# 输入 +你将收到三部分: +1) User Question:用户问题 +2) Evidence:证据区(工具调用与工具返回的原文集合,按 step 编号) +3) Report:需要审计的最终报告 + +# 你的目标 +对 Report 做“可追溯性/可核验性审计”:判断 Report 中的【原子断言】是否能在 Evidence 中找到明确证据锚点。 + +# 核心原则(硬约束) +1) Evidence 是唯一事实来源:不得使用外部常识/训练记忆补全缺失证据。 +2) 证据优先:必须先给出 evidence.anchors(step+quote),再给 verification(verdict/issue/note)。 + - 严禁先输出分数或先下结论再找证据。 +3) 仅审计 Report 正文:忽略 “## References” 及其之后内容。 +4) 覆盖要求:必须覆盖 Report 正文里出现的每一个【数字/日期 token】(近似即可)。 + - 数字/日期 token 示例:13.7%、2025-09-30、1330亿美元 各算 1 个。 +5) 锚点要求: + - 对于 hardness=hard 的断言(尤其含数字/日期),必须提供 1–2 个 anchors,除非 verdict=no_evidence。 + - quote 必须来自 Evidence 原文,可截断;长度 ≤120 字。 +6) 输出必须是严格 JSON(不含 Markdown,不含额外文本);不得新增顶层字段。 +7) 不要输出 score。只输出 units + stats + examples(用于外部确定性计算 reward)。 + +# 断言类型与硬度 +- type 可选:numeric|temporal|event|definition|comparison|causal|recommendation|other +- hardness: + - hard:确定性事实断言(尤其含数字/日期/明确比较/明确事实) + - soft:明确标注推测/假设/情景分析(可能/预计/推测/假设/大概率等)且不伪装成事实 + +# verdict(只能从以下5类选) +- supported:anchors 足以直接支持断言(关键要素匹配) +- contradicted:anchors 明确与断言冲突(主体/时间/数值/方向相反) +- no_evidence:Evidence 中找不到支撑锚点,且断言是确定性表述(hard) +- speculative_ok:断言明确为推测/假设/情景分析(soft)且未伪装成事实 +- unclear:Evidence 有相关但不足以支持/反驳(口径/范围/条件缺失等) + +# issue(只能从以下枚举选) +none | entity_mismatch | time_mismatch | value_mismatch | scope_mismatch | logic_leap | over_precision | missing_anchor + +# JSON 输出模板(字段顺序必须严格一致:先证据后裁决) +{ + "units": [ + { + "claim": "<报告中的原子断言>", + "hardness": "", + "type": "", + "signature": { + "entities": ["<涉及的实体>"], + "numbers": ["<涉及的数字>"], + "times": ["<涉及的时间>"] + }, + "evidence": { + "anchors": [ + { "step": , "quote": "<来自Evidence的原文刦段,≠12字>" } + ], + "anchor_note": "<≤60字,说明为何这些anchors相关>" + }, + "verification": { + "verdict": "", + "issue": "", + "note": "<≤80字,指出支持点/冲突点/缺失点>" + } + } + ], + "stats": { + "total_units": , + "hard_units": , + "supported": , + "contradicted": , + "no_evidence": , + "speculative_ok": , + "unclear": , + "report_digit_date_tokens": , + "covered_digit_date_tokens": <被 units 覆盖的token数>, + "anchored_hard_units": , + "misattrib": <有锚点但verdict不是supported的条数> + }, + "examples": { + "best_supported": [{ "claim": "...", "anchor": { "step": 0, "quote": "..." }, "why": "<≤60字>" }], + "worst_failed": [{ "claim": "...", "verdict": "...", "why": "<≤60字>" }] + } +} + +# 示例(展示完整输出格式) +{ + "units": [ + { + "claim": "2024年Q3营收同比增长15.2%", + "hardness": "hard", + "type": "numeric", + "signature": { "entities": ["营收"], "numbers": ["15.2%"], "times": ["2024年Q3"] }, + "evidence": { + "anchors": [{ "step": 5, "quote": "Q3营收同比+15.2%,达到88.5亿元" }], + "anchor_note": "来自财报工具返回的原始数据" + }, + "verification": { "verdict": "supported", "issue": "none", "note": "数值完全匹配,时间范围一致" } + }, + { + "claim": "预计2025年净利润将达到50亿元", + "hardness": "soft", + "type": "numeric", + "signature": { "entities": ["净利润"], "numbers": ["50亿元"], "times": ["2025年"] }, + "evidence": { + "anchors": [], + "anchor_note": "分析师预测,非硬性事实" + }, + "verification": { "verdict": "speculative_ok", "issue": "none", "note": "明确标注为预测,未伪装成事实" } + } + ], + "stats": { + "total_units": 2, "hard_units": 1, + "supported": 1, "contradicted": 0, "no_evidence": 0, "speculative_ok": 1, "unclear": 0, + "report_digit_date_tokens": 8, "covered_digit_date_tokens": 6, + "anchored_hard_units": 1, "misattrib": 0 + }, + "examples": { + "best_supported": [{ "claim": "2024年Q3营收同比增长15.2%", "anchor": { "step": 5, "quote": "Q3营收同比+15.2%" }, "why": "数值精确匹配" }], + "worst_failed": [] + } +} + +# 统计口径(必须一致) +- total_units = units 的条数 +- hard_units = hardness=hard 的条数 +- supported/contradicted/no_evidence/speculative_ok/unclear 必须与 units[*].verification.verdict 统计一致 +- report_digit_date_tokens:你在 Report 正文中识别到的数字/日期 token 数(近似) +- covered_digit_date_tokens:这些 token 中,有多少被包含在 units[*].signature.numbers 或 units[*].signature.times 中(近似) +- anchored_hard_units:hard_units 中 anchors 非空的条数 +- misattrib:hard_units 中 anchors 非空,但 verdict 不是 supported 的条数(“有锚点但不支持/矛盾/不清楚”) +""" + +EBTU_USER_PROMPT_TEMPLATE = """ +## User Question +{user_query} + +## Evidence +{evidence_text} + +## Report +{final_report} +""" diff --git a/tutorial/example_deep_finance/judge/traceability/grader.py b/tutorial/example_deep_finance/judge/traceability/grader.py index 75206f6b..d8c8312c 100644 --- a/tutorial/example_deep_finance/judge/traceability/grader.py +++ b/tutorial/example_deep_finance/judge/traceability/grader.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from openjudge.graders.base_grader import BaseGrader from openjudge.graders.schema import GraderScore @@ -29,21 +29,22 @@ class TraceabilityRewardGrader(BaseGrader): def __init__( self, - model: Optional[OpenAIChatModel] = None, + model: OpenAIChatModel, name: str = "traceability", - temperature: float = 0.0, - max_tokens: int = 2200, + **kwargs: Any, ) -> None: - super().__init__(name=name) - self.model = model or OpenAIChatModel( - model_name="gpt-4.1-mini", - temperature=temperature, - max_tokens=max_tokens, - response_format={"type": "json_object"}, - ) + super().__init__(name=name, **kwargs) + self.model = model async def aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: messages = coerce_to_messages_list(traj) + + if not messages: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty trajectory", + ) user_prompt, report_plain = construct_traceability_prompt( messages, @@ -55,20 +56,32 @@ async def aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: {"role": "user", "content": user_prompt}, ] - resp = await self.model.achat(judge_messages) - text = resp.get("content", "") + # 调用模型(带异常捕获) + try: + resp = await self.model.achat(judge_messages) + raw_text = getattr(resp, "content", None) + if raw_text is None: + raw_text = str(resp) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + # 解析 JSON 并计算分数 try: - obj = strict_load_json(text) + obj = strict_load_json(str(raw_text)) norm = validate_shape(obj) score = self._compute_score(norm, report_plain) reason = self._build_reason(norm, report_plain, score) return GraderScore(name=self.name, score=score, reason=reason) except Exception as e: + snippet = str(raw_text)[:200].replace("\n", " ") return GraderScore( name=self.name, score=0.0, - reason=f"TVR judge output invalid: {e}", + reason=f"TVR ParseError: {e}; raw[:200]={snippet}", ) def _compute_score(self, norm: Dict[str, Any], report_plain: str) -> float: diff --git a/tutorial/example_deep_finance/judge/traceability/json_utils.py b/tutorial/example_deep_finance/judge/traceability/json_utils.py index de6fb3cc..7a005b62 100644 --- a/tutorial/example_deep_finance/judge/traceability/json_utils.py +++ b/tutorial/example_deep_finance/judge/traceability/json_utils.py @@ -6,18 +6,93 @@ from typing import Any, Dict, List, Tuple +# ============================================================================= +# JSON Repair Helper +# ============================================================================= + +def _repair_json(js: str) -> str: + """ + 尝试修复常见的JSON格式错误 + 1. 修复字符串中未转义的换行符 + 2. 修复trailing comma + 3. 修复不完整的JSON(截断) + """ + # 1. 替换字符串值中的未转义换行符 + def escape_newlines_in_strings(s: str) -> str: + result = [] + in_string = False + escape_next = False + i = 0 + while i < len(s): + c = s[i] + if escape_next: + result.append(c) + escape_next = False + elif c == '\\': + result.append(c) + escape_next = True + elif c == '"': + result.append(c) + in_string = not in_string + elif in_string and c == '\n': + result.append('\\n') + elif in_string and c == '\r': + result.append('\\r') + elif in_string and c == '\t': + result.append('\\t') + else: + result.append(c) + i += 1 + return ''.join(result) + + js = escape_newlines_in_strings(js) + + # 2. 移除trailing comma: ",}" -> "}" 和 ",]" -> "]" + js = re.sub(r',\s*}', '}', js) + js = re.sub(r',\s*]', ']', js) + + # 3. 尝试修复截断的JSON - 补全缺失的括号 + open_braces = js.count('{') + close_braces = js.count('}') + open_brackets = js.count('[') + close_brackets = js.count(']') + + if open_braces > close_braces: + # 先关闭可能未闭合的字符串 + in_string = False + escape_next = False + for c in js: + if escape_next: + escape_next = False + elif c == '\\': + escape_next = True + elif c == '"': + in_string = not in_string + if in_string: + js += '"' + + # 补全缺失的括号 + js += ']' * (open_brackets - close_brackets) + js += '}' * (open_braces - close_braces) + + return js + + # -------------------------- # JSON parsing helpers # -------------------------- def strict_load_json(text: str) -> Dict[str, Any]: """ - Parse a JSON object from model output. + Parse a JSON object from model output. 带容错修复。 - Accept plain JSON. - If extra text exists, extract the first {...} block. + - If parsing fails, attempt to repair common JSON errors. """ text = (text or "").strip() + + # 第一次尝试:直接解析 try: obj = json.loads(text) if isinstance(obj, dict): @@ -29,9 +104,22 @@ def strict_load_json(text: str) -> Dict[str, Any]: end = text.rfind("}") if start != -1 and end != -1 and end > start: snippet = text[start : end + 1] - obj = json.loads(snippet) - if isinstance(obj, dict): - return obj + # 第二次尝试:直接解析提取的片段 + try: + obj = json.loads(snippet) + if isinstance(obj, dict): + return obj + except Exception: + pass + + # 第三次尝试:修复后解析 + try: + repaired = _repair_json(snippet) + obj = json.loads(repaired) + if isinstance(obj, dict): + return obj + except Exception: + pass raise ValueError("Invalid JSON output") diff --git a/tutorial/example_deep_finance/prompt/tool_prompt_builder.py b/tutorial/example_deep_finance/prompt/tool_prompt_builder.py index 5c940fd7..0345f2c9 100644 --- a/tutorial/example_deep_finance/prompt/tool_prompt_builder.py +++ b/tutorial/example_deep_finance/prompt/tool_prompt_builder.py @@ -60,11 +60,6 @@ def get_tool_prompt_template() -> str: **参数**: - `query` (必填, string): 搜索关键词 -#### ✅ crawl_url -**功能**: 网页内容解析工具,获取并格式化指定URL的网页内容。 -**参数**: - - `url` (必填, string): 目标网页URL - # --- ### 📈 同花顺专项数据工具 (Crawl THS) diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index c4b950f5..29347b39 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -15,6 +15,7 @@ ajet: cgcv_weight: {{CGCV_WEIGHT}} # Citation-Grounded Claim Verification audit_weight: {{AUDIT_WEIGHT}} # 引用逻辑审计 traceability_weight: {{TRACEABILITY_WEIGHT}} # 可追溯性/可核验性审计 (TVR) + ebtu_weight: {{EBTU_WEIGHT}} # Audit Grader: audit reward EBTU证据优先可追溯性审计 rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml new file mode 100644 index 00000000..0ddd541c --- /dev/null +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml @@ -0,0 +1,91 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: "{{PREFIX}}" + experiment_name: "{{SUFFIX}}" + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 + rm_llm: {{RM_LLM}} # RM Gallery 模型 + concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 + train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 + val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 + # OpenJudge 权重配置 + presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 + grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 + cgcv_weight: {{CGCV_WEIGHT}} # Citation-Grounded Claim Verification + audit_weight: {{AUDIT_WEIGHT}} # 引用逻辑审计 + traceability_weight: {{TRACEABILITY_WEIGHT}} # 可追溯性/可核验性审计 (TVR) + ebtu_weight: {{EBTU_WEIGHT}} # Audit Grader: audit reward EBTU证据优先可追溯性审计 + rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 + task_judge: + # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: {{MODEL_PATH}} + trainer_common: + nnodes: {{NNODES}} + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 8 + save_freq: 10 + test_freq: 2 + total_epochs: 200 + save_trajectory_as_json_file: True + rollout: + # ✨✨✨✨ 编写并选择Agent + user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol + force_disable_toolcalls: False + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: {{NUM_REPEAT}} + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_response_length_in_one_turn: 8000 + max_model_len: {{MAX_MODEL_LEN}} + agent_madness_reward: 0.0 + compute_madness_checklist: None + multi_turn: + max_steps: {{NUM_STEPS}} + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + debug: + debug_max_parallel: 1 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + train_batch_size: {{TRAIN_BATCH_SIZE}} + max_prompt_length: 20000 + max_response_length: 45000 + + task_reader: + type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + deep_finance: + training: + file_path: {{TRAIN_DATA_PATH}} + validation: + file_path: {{VAL_DATA_PATH}} + # env_service 仍需配置(用于工具调用) + env_service: + env_type: "finworld" + env_url: {{ENV_SERVICE_URL}} + env_action_preference: code +trainer: + default_local_dir: "{{CKPT_SAVE_PATH}}/{{PREFIX}}/{{SUFFIX}}" + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_deep_finance/yaml_template/infer.yaml b/tutorial/example_deep_finance/yaml_template/infer.yaml index c86832e0..5e9d400e 100644 --- a/tutorial/example_deep_finance/yaml_template/infer.yaml +++ b/tutorial/example_deep_finance/yaml_template/infer.yaml @@ -12,6 +12,10 @@ ajet: # OpenJudge 权重配置 presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 + cgcv_weight: {{CGCV_WEIGHT}} # Citation-Grounded Claim Verification + audit_weight: {{AUDIT_WEIGHT}} # 引用逻辑审计 + traceability_weight: {{TRACEABILITY_WEIGHT}} # 可追溯性/可核验性审计 (TVR) + ebtu_weight: {{EBTU_WEIGHT}} # Audit Grader: audit reward EBTU证据优先可追溯性审计 rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) @@ -23,7 +27,7 @@ ajet: nnodes: {{NNODES}} n_gpus_per_node: 8 val_before_train: True - val_pass_n: 4 + val_pass_n: 2 save_freq: 10 test_freq: 2 total_epochs: {{TOTAL_EPOCHS}} @@ -38,7 +42,7 @@ ajet: max_env_worker: 64 # 增加环境并行数 max_num_seqs: 64 # 增加VLLM并发序列数 max_response_length_in_one_turn: 8000 - max_model_len: 50000 + max_model_len: {{MAX_MODEL_LEN}} agent_madness_reward: 0.0 compute_madness_checklist: None multi_turn: From 68e25aeef06dec743bcc4a73f86aaf54e5e3058c Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Fri, 13 Feb 2026 12:36:27 +0800 Subject: [PATCH 53/56] Merge remote-tracking branch 'origin/main' into dev/shuchang_newjudge --- tutorial/example_deep_finance/deep_finance_judge.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index 36327925..a5b109b5 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -15,11 +15,7 @@ from openjudge.models.openai_chat_model import OpenAIChatModel from openjudge.runner.grading_runner import GraderConfig, GradingRunner -<<<<<<< HEAD from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader, CGCVGrader, AuditGrader, TraceabilityRewardGrader, EBTUTraceabilityGrader -======= -from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader ->>>>>>> origin/main From 6b9eb508a39d7c07ab83f231b326a56d3f508a9e Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Sun, 15 Feb 2026 20:16:36 +0800 Subject: [PATCH 54/56] "fix: resolve remaining merge conflicts" --- .gitignore | 3 - .../metric_helper/reward_metric_helper.py | 4 - tutorial/example_deep_finance/deep_finance.sh | 13 --- .../deep_finance_judge.py | 7 -- .../deep_finance_single.sh | 21 ---- .../example_deep_finance/judge/__init__.py | 7 -- .../judge/grounding/grader.py | 8 -- .../judge/grounding/prompt.py | 99 ------------------- .../yaml_template/deep_finance_template.yaml | 3 - 9 files changed, 165 deletions(-) diff --git a/.gitignore b/.gitignore index 030a7660..ed84d2a1 100644 --- a/.gitignore +++ b/.gitignore @@ -160,11 +160,8 @@ tutorial/example_deep_finance/scripts/* flash_attn-2.8.*.whl tutorial/example_deep_finance/prepare_data/* tutorial/example_deep_finance/judge/analytical_sufficiency/* -<<<<<<< HEAD tutorial/example_deep_finance/output_report/* dataset_gsm8k/* -======= ->>>>>>> origin/main .dockerignore benchmark_datasets diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index 85b5706a..685798f3 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -83,14 +83,10 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str openjudge_graders = [ "presentation_quality", "grounding", -<<<<<<< HEAD "planning", "audit", "traceability", "cgcv" -======= - "planning" ->>>>>>> origin/main ] for grader_name in openjudge_graders: diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index 20ce984e..f6121655 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -15,13 +15,10 @@ JUDGE_CONCURRENCY=10 RM_WEIGHT=0.5 PRESENTATION_QUALITY_WEIGHT=0.25 GROUNDING_WEIGHT=0.25 -<<<<<<< HEAD CGCV_WEIGHT=0.0 # 不使用 CGCV,设为 0 AUDIT_WEIGHT=0.0 # 不使用 Audit,设为 0 TRACEABILITY_WEIGHT=0.0 # 不使用 Traceability,设为 0 EBTU_WEIGHT=0.0 # 不使用 EBTU,设为 0 -======= ->>>>>>> origin/main # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -29,12 +26,9 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 -<<<<<<< HEAD # Env Service URL 配置 ENV_SERVICE_URL="http://127.0.0.1:8080" # 环境服务地址 -======= ->>>>>>> origin/main # 主目录(需要更改) export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" @@ -70,13 +64,10 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ -<<<<<<< HEAD -e "s|{{CGCV_WEIGHT}}|${CGCV_WEIGHT}|g" \ -e "s|{{AUDIT_WEIGHT}}|${AUDIT_WEIGHT}|g" \ -e "s|{{TRACEABILITY_WEIGHT}}|${TRACEABILITY_WEIGHT}|g" \ -e "s|{{EBTU_WEIGHT}}|${EBTU_WEIGHT}|g" \ -======= ->>>>>>> origin/main -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ @@ -92,11 +83,7 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -<<<<<<< HEAD echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, CGCV=${CGCV_WEIGHT}, Audit=${AUDIT_WEIGHT}, Traceability=${TRACEABILITY_WEIGHT}, EBTU=${EBTU_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" -======= -echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" ->>>>>>> origin/main #=============================================================================== # 3. 环境配置 diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index a5b109b5..0548b385 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -103,15 +103,11 @@ def _setup_weights(self): self.w = { "rm": getattr(cfg, "rm_weight", 1.0) if cfg else 1.0, # RM Gallery 权重 "presentation_quality": getattr(cfg, "presentation_quality_weight", 0.25) if cfg else 0.25, -<<<<<<< HEAD "grounding": getattr(cfg, "grounding_weight", 0.0) if cfg else 0.0, # 引用规范性评估 "cgcv": getattr(cfg, "cgcv_weight", 0.25) if cfg else 0.25, # Citation-Grounded Claim Verification "audit": getattr(cfg, "audit_weight", 0.0) if cfg else 0.0, # Audit Grader: audit reward 引用逻辑审计 "traceability": getattr(cfg, "traceability_weight", 0.0) if cfg else 0.0, # 可追溯性/可核验性审计 (TVR) "ebtu": getattr(cfg, "ebtu_weight", 0.0) if cfg else 0.0, # Audit Grader: audit reward EBTU证据优先可追溯性审计 -======= - "grounding": getattr(cfg, "grounding_weight", 0.25) if cfg else 0.25, ->>>>>>> origin/main } # 归一化(注意:action_loop 是惩罚项,不参与归一化;rm 需要参与归一化) @@ -264,7 +260,6 @@ def extract_report_content(data: Dict) -> str: grader=GroundingGrader(model=model), mapper=lambda data: {"traj": data}, ), -<<<<<<< HEAD # CGCV: Citation-Grounded Claim Verification - 引用锤定的断言验证 "cgcv": GraderConfig( grader=CGCVGrader(model=model), @@ -285,8 +280,6 @@ def extract_report_content(data: Dict) -> str: grader=EBTUTraceabilityGrader(model=model), mapper=lambda data: {"traj": data}, ), -======= ->>>>>>> origin/main } def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> Tuple[float, bool]: diff --git a/tutorial/example_deep_finance/deep_finance_single.sh b/tutorial/example_deep_finance/deep_finance_single.sh index d663120a..67de294d 100644 --- a/tutorial/example_deep_finance/deep_finance_single.sh +++ b/tutorial/example_deep_finance/deep_finance_single.sh @@ -15,13 +15,10 @@ JUDGE_CONCURRENCY=10 RM_WEIGHT=0.5 PRESENTATION_QUALITY_WEIGHT=0.25 GROUNDING_WEIGHT=0.25 -<<<<<<< HEAD CGCV_WEIGHT=0.0 # 不使用 CGCV,设为 0 AUDIT_WEIGHT=0.0 # 不使用 Audit,设为 0 TRACEABILITY_WEIGHT=0.0 # 不使用 Traceability,设为 0 EBTU_WEIGHT=0.0 # 不使用 EBTU,设为 0 -======= ->>>>>>> origin/main # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -29,13 +26,8 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 -<<<<<<< HEAD # Env Service URL 配置 ENV_SERVICE_URL="http://127.0.0.1:8080" # 环境服务地址 -======= -# 主目录(需要更改) -export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" ->>>>>>> origin/main # 主目录(需要更改) export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" @@ -81,13 +73,10 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ -<<<<<<< HEAD -e "s|{{CGCV_WEIGHT}}|${CGCV_WEIGHT}|g" \ -e "s|{{AUDIT_WEIGHT}}|${AUDIT_WEIGHT}|g" \ -e "s|{{TRACEABILITY_WEIGHT}}|${TRACEABILITY_WEIGHT}|g" \ -e "s|{{EBTU_WEIGHT}}|${EBTU_WEIGHT}|g" \ -======= ->>>>>>> origin/main -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ @@ -103,11 +92,7 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -<<<<<<< HEAD echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, CGCV=${CGCV_WEIGHT}, Audit=${AUDIT_WEIGHT}, Traceability=${TRACEABILITY_WEIGHT}, EBTU=${EBTU_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" -======= -echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" ->>>>>>> origin/main #=============================================================================== @@ -151,13 +136,7 @@ export RAY_CLUSTER_MODE="multi_node" #=============================================================================== # 6. 主流程 #=============================================================================== -<<<<<<< HEAD log "单机调试模式: NNODES=${NNODES}, GPUS_PER_NODE=${GPUS_PER_NODE}" -======= -log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" -mkdir -p ${LOG_DIR} -mkdir -p $(dirname ${CONFIG_FILE}) ->>>>>>> origin/main #=============================================================================== # 6.1 Master 节点启动流程 diff --git a/tutorial/example_deep_finance/judge/__init__.py b/tutorial/example_deep_finance/judge/__init__.py index df18f7fd..235247f9 100644 --- a/tutorial/example_deep_finance/judge/__init__.py +++ b/tutorial/example_deep_finance/judge/__init__.py @@ -1,13 +1,10 @@ # 使得可以通过 from judge import PresentationQualityGrader 直接引用 from .grounding.grader import GroundingGrader from .presentation_quality.grader import PresentationQualityGrader -<<<<<<< HEAD from .cgcv.grader import CGCVGrader from .audit.grader import AuditGrader from .traceability.grader import TraceabilityRewardGrader from .ebtu.grader import EBTUTraceabilityGrader -======= ->>>>>>> origin/main # from .research_depth.grader import ResearchDepthGrader # from .research_breadth.grader import ResearchBreadthGrader @@ -15,8 +12,4 @@ # from .grounding.grader import GroundingGrader # from .research_breadth.grader import ResearchBreadthGrader # __all__ = ["PresentationQualityGrader", "GroundingGrader", "ResearchDepthGrader", "ResearchBreadthGrader"] -<<<<<<< HEAD __all__ = ["PresentationQualityGrader", "GroundingGrader", "CGCVGrader", "AuditGrader", "TraceabilityRewardGrader", "EBTUTraceabilityGrader"] -======= -__all__ = ["PresentationQualityGrader", "GroundingGrader"] ->>>>>>> origin/main diff --git a/tutorial/example_deep_finance/judge/grounding/grader.py b/tutorial/example_deep_finance/judge/grounding/grader.py index 66646dfc..42f5d141 100644 --- a/tutorial/example_deep_finance/judge/grounding/grader.py +++ b/tutorial/example_deep_finance/judge/grounding/grader.py @@ -195,20 +195,12 @@ def _compute_scores(self, obj: Dict[str, Any]) -> Tuple[float, str]: # 轻量惩罚:存在 invalid refs 会降低 reward # 每个 invalid 号扣 0.1,最多扣 0.5 -<<<<<<< HEAD # invalid_penalty = min(0.1 * invalid_ref_count, 0.5) invalid_penalty = 0 # final_reward: 综合分数(权重 0.5:0.5),再叠加 invalid 惩罚 final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score # final_reward = max(0.0, final_reward - invalid_penalty) -======= - invalid_penalty = min(0.1 * invalid_ref_count, 0.5) - - # final_reward: 综合分数(权重 0.5:0.5),再叠加 invalid 惩罚 - final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score - final_reward = max(0.0, final_reward - invalid_penalty) ->>>>>>> origin/main # 构建 reason good_citations = obj.get('good_citations', []) diff --git a/tutorial/example_deep_finance/judge/grounding/prompt.py b/tutorial/example_deep_finance/judge/grounding/prompt.py index 9d316146..337cf4bc 100644 --- a/tutorial/example_deep_finance/judge/grounding/prompt.py +++ b/tutorial/example_deep_finance/judge/grounding/prompt.py @@ -1,6 +1,5 @@ """Grounding Grader Prompt - 引用规范性评估""" -<<<<<<< HEAD # GROUNDING_SYSTEM_PROMPT = """你是一位"引用审计员",负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 # ======================== @@ -116,41 +115,11 @@ - 禁止伪造来源;没有证据支撑的只能写“推测/假设”,不能用引用把推测包装成事实。 ## 输入 -======= -GROUNDING_SYSTEM_PROMPT = """你是一位"引用审计员",负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 - -======================== -一、引用规范(以此为准) -======================== -1) 关键事实句必须引用: - - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。 - - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。 - -2) 引用位置规则(严格执行): - - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。 - - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。 - -3) References 必须存在且可追溯: - - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。 - - 正文出现的每个 [n] 必须能在 References 中找到对应条目。 - -4) References 条目两种合法形式(必须满足其一): - A) URL 形式:`[n] 标题或简述 - https://...` - - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。 - B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)` - - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。 - - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。 - -======================== -二、输入 -======================== ->>>>>>> origin/main 你会收到: - User Query - Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息) - AI Report(待审计报告,含正文与 References) -<<<<<<< HEAD 核对真实性时,以 Evidence 为准:只有在“明显矛盾/明显找不到依据”时才判 fake;无法确认则不要判 fake。 ## 输出(只输出 JSON,字段固定) @@ -178,74 +147,6 @@ - invalid_reference_nums 最多 5 个,多余截断。 - good_citations 最多 2 条,多余截断。 只输出 JSON,不要输出解释文字或 Markdown。 -======= -真实性核对原则: -- 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。 -- 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。 - -======================== -三、统计与判定口径(严格遵守) -======================== -【文本范围】 -- 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。 -- References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。 - -【句子/条目如何计数】 -- “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。 -- 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。 -- 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。 - -【关键事实句识别(务求稳定)】 -- 满足任一条件可视为关键事实句: - (a) 含具体数值/比例/排名/区间/估值倍数/财务指标; - (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”); - (c) 对具体公司/行业/政策做了可验证的确定性陈述; - (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。 - -【引用是否“句末”】【重要】 -- 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如: - - “……增长 20%[3]” - - “……增长 20% [3][4]” -- 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。 - -【invalid_reference_nums 的定义】 -- 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid: - (a) References 中不存在该编号条目; - (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等); - (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。 -- invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。 - -【missing_count 的定义】 -- 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。 - -【cited_key_facts 的定义】 -- 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。 - -【fake_count 的定义(只在明显时计数)】 -- 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。 -- 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。 - -【good_citations 的定义】 -- 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足: - - 是关键事实句; - - 句末有 [n]; - - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。 -- good_citations 是原文截取,不要加解释;最多 2 条,超出截断。 - -======================== -四、输出(只输出 JSON,字段固定) -======================== -{ - "total_key_facts": , - "cited_key_facts": , - "good_citations": ["...", "..."], - "missing_count": , - "fake_count": , - "invalid_reference_nums": [, ...] -} - -只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。 ->>>>>>> origin/main """ # ============================================================================= diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 18a20de4..38aa82ed 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -12,13 +12,10 @@ ajet: # OpenJudge 权重配置 presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 -<<<<<<< HEAD cgcv_weight: {{CGCV_WEIGHT}} # Citation-Grounded Claim Verification audit_weight: {{AUDIT_WEIGHT}} # 引用逻辑审计 traceability_weight: {{TRACEABILITY_WEIGHT}} # 可追溯性/可核验性审计 (TVR) ebtu_weight: {{EBTU_WEIGHT}} # Audit Grader: audit reward EBTU证据优先可追溯性审计 -======= ->>>>>>> origin/main rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) From efa7fac37cac2846c7561e81bb798db69406c3d2 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 24 Feb 2026 14:14:00 +0800 Subject: [PATCH 55/56] `refactor(graders)`: Renames the `aevaluate` method to the internally used `_aevaluate`. - Renames the `aevaluate` async method in multiple gradient classes to `_aevaluate` to identify it as an internal method. - Updates the method call names in `CGCVGrader` to match the renamed `_aevaluate`. - Maintains asynchronous call logic, enhancing code encapsulation. - Affected gradients include the `audit`, `cgcv`, `ebtu`, `grounding`, `presentation_quality`, and `traceability` modules. --- tutorial/example_deep_finance/judge/audit/grader.py | 2 +- tutorial/example_deep_finance/judge/cgcv/grader.py | 4 ++-- tutorial/example_deep_finance/judge/ebtu/grader.py | 2 +- tutorial/example_deep_finance/judge/grounding/grader.py | 2 +- .../example_deep_finance/judge/presentation_quality/grader.py | 2 +- tutorial/example_deep_finance/judge/traceability/grader.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tutorial/example_deep_finance/judge/audit/grader.py b/tutorial/example_deep_finance/judge/audit/grader.py index e7d7f5c9..3b8a9806 100644 --- a/tutorial/example_deep_finance/judge/audit/grader.py +++ b/tutorial/example_deep_finance/judge/audit/grader.py @@ -69,7 +69,7 @@ def create_default_model( return OpenAIChatModel(**kwargs) - async def aevaluate( + async def _aevaluate( self, traj: Any, **_: Any, diff --git a/tutorial/example_deep_finance/judge/cgcv/grader.py b/tutorial/example_deep_finance/judge/cgcv/grader.py index 2f65c6eb..cae97eb4 100644 --- a/tutorial/example_deep_finance/judge/cgcv/grader.py +++ b/tutorial/example_deep_finance/judge/cgcv/grader.py @@ -135,7 +135,7 @@ def create_default_model( return OpenAIChatModel(**kwargs) - async def aevaluate( + async def _aevaluate( self, traj: Any, **_: Any, @@ -239,7 +239,7 @@ def evaluate( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - return loop.run_until_complete(self.aevaluate(traj, **kwargs)) + return loop.run_until_complete(self._aevaluate(traj, **kwargs)) def get_detailed_result( self, diff --git a/tutorial/example_deep_finance/judge/ebtu/grader.py b/tutorial/example_deep_finance/judge/ebtu/grader.py index 6ecea50a..b5ee1380 100644 --- a/tutorial/example_deep_finance/judge/ebtu/grader.py +++ b/tutorial/example_deep_finance/judge/ebtu/grader.py @@ -48,7 +48,7 @@ def __init__( response_format={"type": "json_object"}, ) - async def aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: + async def _aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: messages = coerce_to_messages_list(traj) # 输入有效性检查 diff --git a/tutorial/example_deep_finance/judge/grounding/grader.py b/tutorial/example_deep_finance/judge/grounding/grader.py index 42f5d141..86fc87a3 100644 --- a/tutorial/example_deep_finance/judge/grounding/grader.py +++ b/tutorial/example_deep_finance/judge/grounding/grader.py @@ -80,7 +80,7 @@ def create_default_model( return OpenAIChatModel(**kwargs) - async def aevaluate( + async def _aevaluate( self, traj: Any, **_: Any, diff --git a/tutorial/example_deep_finance/judge/presentation_quality/grader.py b/tutorial/example_deep_finance/judge/presentation_quality/grader.py index c440c3e4..80de5740 100644 --- a/tutorial/example_deep_finance/judge/presentation_quality/grader.py +++ b/tutorial/example_deep_finance/judge/presentation_quality/grader.py @@ -83,7 +83,7 @@ def create_default_model( return OpenAIChatModel(**kwargs) - async def aevaluate( + async def _aevaluate( self, report_content: str, user_query: str | None = None, diff --git a/tutorial/example_deep_finance/judge/traceability/grader.py b/tutorial/example_deep_finance/judge/traceability/grader.py index d8c8312c..42beee5b 100644 --- a/tutorial/example_deep_finance/judge/traceability/grader.py +++ b/tutorial/example_deep_finance/judge/traceability/grader.py @@ -36,7 +36,7 @@ def __init__( super().__init__(name=name, **kwargs) self.model = model - async def aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: + async def _aevaluate(self, traj: Any, **kwargs: Any) -> GraderScore: messages = coerce_to_messages_list(traj) if not messages: From f785b224a05dbf8b16d9294ae8473b6bbe6ede00 Mon Sep 17 00:00:00 2001 From: "taoshuchang.tsc" Date: Tue, 24 Feb 2026 14:26:02 +0800 Subject: [PATCH 56/56] Refactor (docs): Refactoring the Financial Deep Research Agent Training Tutorial Documentation - Removed the original detailed tutorial file "How to Train a Finance Deep Research Agent" to avoid redundancy. - Added a new document, deep_finance.md, with a clearer structure and more complete content. - Reorganized the system architecture and pipeline module descriptions to improve readability. - Provided detailed explanations of the two-stage deep research workflow design and citation specifications. - Systematically listed and categorized 19 financial tools and their functions. - Standardized reward design specifications, refining multi-dimensional scoring indicators and weight adjustments. - Supplemented key technical points such as training loop, engineering details, tool calls, and caching design. - Standardized document format, using tables and code blocks to display information, improving the document's professionalism and usability. --- tutorial/example_deep_finance/blog.md | 482 ------------------ tutorial/example_deep_finance/deep_finance.md | 319 ++++++------ 2 files changed, 155 insertions(+), 646 deletions(-) delete mode 100644 tutorial/example_deep_finance/blog.md diff --git a/tutorial/example_deep_finance/blog.md b/tutorial/example_deep_finance/blog.md deleted file mode 100644 index 4e93b474..00000000 --- a/tutorial/example_deep_finance/blog.md +++ /dev/null @@ -1,482 +0,0 @@ -# How to Train a Finance Deep Research Agent from Scratch - -> 本文介绍我们如何从零构建一个金融深度研究 Agent 的训练方案——包括 benchmark、工具环境、reward 体系和训练工程。代码和配置已开源。 - -## 1. 金融深度研究为什么难 - -在我们正式开始训练一个金融 Agent 之前,先聊聊为什么这件事不好做。 - -金融研究报告不是普通的文本生成任务。一份合格的投研报告至少要满足三个硬要求: - -- **证据可追溯**:报告里的每个数字、每个结论,读者都应该能找到数据来源。写"营收同比增长 15%",你得告诉我这个 15% 从哪来的。 -- **推理链完整**:不是把数据搬过来就行,要有归因分析、横向对比、逻辑推演。投资者要的是"为什么",不只是"是什么"。 -- **输出可用**:最终报告要结构清晰、结论先行、表格得当,拿到手就能用,而不是一大段需要人工整理的文字。 - -问题是,现在的大语言模型在这三个方面都有典型的失败模式: - -**失败一:看起来像真的,但证据不支持。** 模型写"该公司 2024 年 Q3 营收同比增长 23.5%",引用了一个编号 [3],你去查 References——发现 [3] 根本不存在,或者 [3] 对应的数据里写的是 18.7%。这不是模型不够聪明,而是它根本没学过"引用必须对得上"这件事。 - -**失败二:有数据但不分析。** 模型调了一堆工具、拿到了财报数据,然后在报告里几乎原样搬运——"2024 Q3 营收 XX 亿,净利润 XX 亿,毛利率 XX%"。这是摘抄,不是分析。研究员需要的是"毛利率为什么下降了?和行业趋势比怎样?未来会怎么走?"。 - -**失败三:写了分析但不可用。** 模型洋洋洒洒写了三千字,但没有摘要,没有结构化的对比表格,关键结论埋在第七段。一个忙碌的基金经理不会花时间去找结论在哪。 - -这三类问题的根源是一样的:**模型缺少一个"做研究 → 得到反馈 → 改进"的训练闭环。** SFT 的方式是找人标注"标准报告"让模型模仿,但金融研究没有唯一正确答案,标注成本极高,而且人工标注会限制模型的探索空间——它只能学到标注者的写法,学不到更好的策略。 - -我们的思路是:**用强化学习替代 SFT,通过自动化的多维度 reward 来驱动模型改进。** 但这带来了一个新问题——没有标准答案,reward 从哪来? - -后面几节将围绕这个核心问题展开。 - ---- - -## 2. 任务定义:把"研究"变成可训练的任务 - -在讲具体方案之前,有必要先定义清楚我们到底在训练什么。 - -**任务形式**很直接: - -- 输入:一个金融研究问题(比如"分析贵州茅台 2024 年的基本面和估值水平") -- 输出:一份 Markdown 格式的结构化研究报告,必须带引用 -- 过程:Agent 可以多轮调用金融工具来收集数据 - -关键区别在于,这不只是一个写作任务。Agent 的**动作空间**很大——它需要决定:先查什么数据?用哪个工具?查到的数据够不够?什么时候开始写?哪些信息需要交叉验证?报告的结构怎么组织? - -这些决策传统上由研究员的经验驱动,我们希望通过 RL 让模型自己学会。 - -整个训练闭环长这样: - -``` -金融问题 → Agent 多轮调用工具 → 生成研究报告 → 多维度自动评分 → GRPO 策略更新 → 下一轮 -``` - -接下来分别讲这个闭环的每个组件。 - ---- - -## 3. 评测基准:什么叫"好研究" - -没有评测基准,你不知道自己训练出来的是"更会写"还是"更会编"。 - -### 分 domain 设计 - -金融研究不是一个同质化的任务。分析一只个股和解读一个宏观政策,需要的能力完全不同。我们的 benchmark 覆盖五类真实场景: - -| Domain | 典型问题 | 考察重点 | -|--------|---------|---------| -| 宏观分析 | "分析当前货币政策对 A 股市场的影响" | 政策理解、传导机制推演 | -| 行业研究 | "新能源汽车行业 2024 年竞争格局分析" | 横向对比、趋势判断 | -| 事件解读 | "XX 公司被 ST 的影响分析" | 事件因果链、影响范围评估 | -| 个股分析 | "贵州茅台基本面和估值分析" | 多维数据整合、综合判断 | -| 公司研究 | "XX 公司治理结构和经营质量" | 深度挖掘、风险识别 | - -分 domain 的好处是能看到细粒度的进步——可能模型在个股分析上进步明显,但在宏观分析上还差火候。单一指标会掩盖这些差异。 - -### 评测维度 - -benchmark 的评估沿三条线展开,和后面的训练 reward 保持对齐: - -- **事实性**:引用是否真实,关键结论是否有据可查 -- **分析充分性**:分析的深度、覆盖面、逻辑链条是否完整 -- **写作可用性**:报告结构是否清晰、信息是否易获取 - ---- - -## 4. 工具与工作流 - -### 4.1 金融工具环境 - -没有工具,DeepResearch 就只是个写作任务。Agent 必须能主动获取数据。 - -我们的工具环境集成了 19 个金融工具,通过 MCP 协议与 Agent 交互。设计思路不是"给模型越多工具越好",而是按照研究流程来组织: - -**第一层:找到目标** -- `extract_entities_code`:从自然语言里识别金融实体并查找代码(比如"茅台"→ 600519) -- `dashscope_search`:互联网搜索,用于确认信息或补充背景 - -**第二层:获取结构化数据** - -这是工具的主体。我们接入了同花顺的 14 个维度的专项数据接口: - -| 工具 | 数据范围 | -|------|---------| -| `crawl_ths_finance` | 财务分析(财务指标、杜邦分析、资产负债构成) | -| `crawl_ths_worth` | 盈利预测(业绩预测、研报评级) | -| `crawl_ths_operate` | 经营分析(主营构成、客户供应商、董事会评述) | -| `crawl_ths_holder` | 股东研究(十大股东、控股层级) | -| `crawl_ths_news` | 新闻公告(新闻与股价联动、研报列表) | -| `crawl_ths_field` | 行业对比(行业地位、行业新闻) | -| ... | 还有股本结构、资本运作、分红融资、大事件、概念题材等 | - -另外还有 `history_calculate`(A 股历史股价分析,支持自然语言提问)、`execute_code`、`execute_shell` 等通用能力。 - -**为什么不只给一个搜索引擎?** 因为金融研究需要的是精确的结构化数据。"2024Q3 毛利率"这种信息,搜索引擎未必能给出准确答案,但财务分析接口可以直接返回。工具的设计原则是:让 Agent 能拿到和人类分析师一样精度的数据。 - -**MCP 协议**的选择也值得说一下:所有工具通过标准化的 MCP 接口调用,工具的输入输出格式统一,Agent 不需要为每个工具学习不同的调用方式。更重要的是,工具返回的内容可以直接作为引用来源——每个工具调用都有明确的时间戳和参数记录,报告中的引用可以追溯到具体是哪次工具调用返回了什么数据。 - -### 4.2 Agent 工作流 - -Agent 本身基于 AgentScope 的 ReActAgent 实现,核心逻辑在 `deep_finance.py` 里: - -```python -for step in range(max_steps): - # Agent 推理:基于当前上下文生成回复(可能包含 tool_calls) - reply_message = await agent(agent_input) - - # 环境执行:工具调用并返回结果 - obs, reward, terminate, info = env.step(action={"content": content_text, "role": "assistant"}) - - # 更新对话历史(完整保留 assistant + tool_calls + tool_results) - conversation_history.append(current_assistant_msg) - conversation_history.extend(tool_result_msgs) - - # 检查终止条件 - if terminate or context_overflow: - break -``` - -每一轮,Agent 生成回复(可能包含工具调用请求),环境执行工具并返回结果,Agent 再基于新信息决定下一步。整个对话历史完整保留,包括每次工具调用的输入和输出。 - -**两阶段研究流程**是通过 system prompt 引导的: - -1. **先大纲后调研**:Agent 先输出研究大纲(一级/二级标题 + 每节要回答的关键问题),这时不调用工具。大纲确定后,按大纲逐段调研,每轮最多调 3 个工具,调完做小结再决定下一步。 -2. **分析与报告生成**:数据充分后写完整报告。如果写的时候发现某个结论缺证据,允许追加 1-2 轮工具调用补充取证。 - -这个流程在 RL 训练前是 prompt 硬编码的。一个有意思的观察是:训练之后,Agent 是否真的会严格遵循"先大纲后调研"?还是会自己发展出更高效的策略?这个问题我们在结果部分会回来讨论。 - -*上下文管理**也是个实际问题。多轮工具交互会迅速消耗 context window,代码里专门做了 `context_overflow` 检测——一旦发现上下文快溢出,立即结束交互,避免工具调用被截断导致阻塞。 - ---- - -## 5. 训练设计 - -这一章是整篇文章的核心。我们把训练拆成三件事:数据从哪来、reward 怎么算、工程上怎么跑得起来。 - -### 5.1 训练数据:只需要问题,不需要答案 - -这是 RL 训练和 SFT 在数据需求上的根本差异。 - -SFT 需要 (question, gold_answer) 配对。在金融研究场景下,这意味着你得请专业研究员针对每个问题写一份完整的标准报告。成本高不说,更大的问题是标注的答案会成为模型的天花板——模型只能学到标注者的写法,但这不一定是最好的研究策略。 - -RL 只需要 question 本身。模型自己去探索怎么用工具、怎么组织分析、怎么写报告,quality 由 reward 来判断。这意味着模型有可能发展出比人类标注更高效的研究策略。 - -当然,问题的设计本身也很讲究: - -- **domain 分布**要均衡:五类场景(宏观/行业/事件/个股/公司)都要覆盖到,不能让模型只会分析茅台。 -- **难度梯度**要合理:从"XX 股票最近一周涨了多少"这种简单事实查询,到"分析 XX 行业的竞争格局和投资机会"这种需要多维度综合分析的复杂问题,都要有。 -- **多样性**要够:避免同类问题重复出现导致模型在特定模式上过拟合。实际操作中,我们对训练集做了去重和相似度过滤。 - -另外一个实用的做法:我们把训练集和验证集的参考答案单独维护(用于 RM Gallery 的分析充分性评估),但这些参考答案不直接用于 SFT——它们只是 reward 计算的输入之一。 - -### 5.2 Reward:三条主线驱动 Agent 进化 - -这是整个方案里最核心的设计决策。 - -先说为什么不能只用一个 reward。 - -如果只优化"写作质量",模型会学到一个捷径:生成格式完美但内容空洞的报告。标题层级分明、表格排版精美、结论清晰——但数据是编的。如果只优化"事实性",模型会走向另一个极端:大段摘抄工具返回的原始数据,不做任何分析。正确,但没用。 - -**单一 reward 最大的风险是 reward hacking**——模型会找到一个你没预料到的捷径来刷分,而不是真正变好。多维度 reward 从不同角度约束和激励 Agent,让"走捷径"变得更难。 - -我们的 reward 分三条主线:事实性、分析充分性、写作能力。下面逐一展开。 - -#### 主线一:事实性——约束幻觉 - -这是金融研究最底线的要求。我们从三个粒度来保障事实性,分别对应三个 Grader。 - -**Grounding(引用规范性)**——宏观层面的引用审计。 - -思路很直接:把报告里所有"关键事实句"找出来(包含数字、日期、财务指标、确定性陈述的句子),检查两件事: - -1. 这些句子有没有引用标记 `[n]`?(引用覆盖率) -2. 引用指向的内容和工具实际返回的证据一致吗?(引用真实性) - -最终分数很简单: - -``` -score = 0.5 × 引用覆盖率 + 0.5 × 引用真实性 -``` - -**Evidence-Backed Trace Units(证据溯源)**——原子断言级别的审计。 - -Grounding 看的是"有没有引用",EBTU 看的是"引用能不能真正支撑这句话"。它的工作方式更细致: - -1. 把报告拆成一个个原子断言(Trace Units)——每个包含一个具体事实的最小陈述单元 -2. 每个断言标记"硬度":`hard`(确定性事实)或 `soft`(明确标注为推测的) -3. 对每个 hard 断言,在工具返回的 Evidence 里寻找锚点——要求精确到 step 编号和原文引用 -4. 给出裁决:`supported`(锚点支持)/ `contradicted`(锚点矛盾)/ `no_evidence`(找不到)/ `unclear`(证据不足)/ `speculative_ok`(是推测但标注正确) - -EBTU 有一个核心设计原则叫**证据优先**(Evidence-first):审计官必须先给出证据锚点(哪个 step、原文引了什么),再下裁决。严禁先下结论再找证据。这个顺序很重要,它避免了 LLM 做 judge 时常见的"先入为主"偏差。 - -评分公式是确定性的 Python 代码: - -```python -base = (supported - 1.4 * contradicted - 0.9 * no_evidence - 0.4 * unclear) / hard_total -misattrib_factor = max(0, 1 - 0.7 * misattrib_rate) # 错误归因惩罚 -selection_factor = min(1, extracted_units / expected) # 覆盖率因子 -cov_factor = 0.65 + 0.35 * digit_coverage # 数字覆盖 -score = base * misattrib_factor * selection_factor * cov_factor -``` - -注意 contradicted 的惩罚系数是 1.4——比 no_evidence 的 0.9 高很多。这是有意为之的:在金融研究里,"说了但说错"比"没说"严重得多。 - -**Audit(引用逻辑审计)**——逻辑蕴含级别的精审。 - -Audit 关注的是一个更微妙的问题:即使引用存在、数据也对得上,引用的**逻辑关系**是否成立? - -它对报告中每个引用标记 `[n]` 做三步验证: - -1. **提取**:锁定 `[n]` 支撑的那句陈述(Claim) -2. **溯源**:在 Evidence 中找到 `[n]` 对应的原始内容 -3. **比对**:数字是否一致?语气是否一致?因果关系是否成立? - -判决分五级: - -- `Supported`:证据充分,逻辑闭环 -- `Overstated`:夸大其词——证据说"约 15%",报告写成"高达 15.2%" -- `Contradicted`:事实冲突——证据说下降,报告写上升 -- `Hallucinated`:无中生有——报告里的数据在证据中根本找不到 -- `Irrelevant`:引用了,但引用的内容和陈述无关 - -其中 Overstated 是我们特别关注的。相比直接捏造,"夸大"更隐蔽,也更常见。把"可能增长"改成"确定增长",把"约 15%"精确到"15.2%"——这些细微的语气偷换在金融报告里可能导致误导性的投资决策。 - -评分很克制:`score = Supported 数量 / 总引用数`。 - -**三个 Grader 各看什么?** - -简单总结:Grounding 看"有没有引用",EBTU 看"证据能不能对上",Audit 看"逻辑关系对不对"。三个粒度从粗到细,互不重叠。 - -#### 主线二:分析充分性——鼓励深度 - -事实性解决了"不能瞎编"的问题,但光不编是不够的——模型可以选择只写最容易验证的事实,回避所有需要分析判断的内容。 - -分析充分性的评估使用 RM Gallery 的 `finance_composition` 评估器。它的工作方式是:拿到模型生成的报告和一份参考答案(不是标准答案,而是一个"可接受的分析"作为锚点),通过独立的 Judge LLM 对比评估。 - -评估会按金融 domain 分域进行——个股分析的"充分"和行业研究的"充分"标准不同。核心看三个维度: - -- **分析深度**:对核心问题的挖掘是否足够深入,有没有只停留在表面 -- **覆盖面**:是否覆盖了问题涉及的多个分析维度(基本面、财务、估值、行业、新闻等) -- **逻辑性**:推理链条是否完整,结论是否有据可依 - -输出是一个 [0, 1] 的归一化分数。 - -#### 主线三:写作能力——让报告可用 - -这条线关注的是:报告拿到手,能不能直接用? - -PresentationQuality Grader 评估 8 项子指标,按 1/3/5 分制打分: - -| 类别 | 指标 | 5 分标准 | -|------|------|----------| -| 可扫描性 | 结论先行 | 开头有独立摘要,不滚动就能获取主结论 | -| | 结构导航 | 层级分明(H1/H2/H3),长文有清晰小标题 | -| | 视觉重点 | 精准使用加粗/斜体强调核心洞察 | -| 信息结构化 | 密集信息解构 | 复杂数据用表格/列表呈现 | -| | 对比对齐 | A vs B 用表格,维度横向可比 | -| | 一致性与渲染 | 格式统一,Markdown 渲染正确 | -| 编辑清晰度 | 论证链可视化 | 主张→证据→结论的链条清晰 | -| | 风险与行动 | 独立板块列出风险和下一步建议 | - -总分 = 8 项之和 / 40,归一化到 [0, 1]。 - -这个 Grader 有一个重要约束:**严格不评估内容对错**。事实准不准、分析深不深,那是前两条主线的事。PresentationQuality 只看呈现。这个边界必须守住,否则维度之间会产生耦合,模型很容易找到"跨维度套利"的空间。 - -另外,我们专门加了反刷分机制:空表格、无意义重复列表、为格式而格式的行为,直接判最低分。 - -#### 关键架构决策:LLM 提取 + 代码打分 - -在实现上,我们做了一个对 RL 训练至关重要的决策:**LLM 只负责结构化提取,分数由 Python 代码确定性计算。** - -为什么?因为 LLM 直接出分的方差太大。同一份报告让同一个 Judge LLM 评两次,分数可能差 0.2。这在做阅读理解评测时还能接受(多次取平均),但在 RL 训练中是灾难性的——reward 信号不稳定意味着梯度方向不确定,模型很难学到正确的策略。 - -我们的做法是: - -1. LLM(如 qwen-max)负责"理解"——提取原子断言、标注锚点、做裁决分类、识别引用关系 -2. Python 代码负责"算分"——基于 LLM 输出的结构化 JSON,用确定性的公式计算最终 score - -以 EBTU 为例,LLM 输出的是每个 Trace Unit 的 verdict 和 anchor 信息(结构化 JSON),最终分数完全由 `_compute_score` 函数用上面那个公式算出来。LLM 对分数没有任何直接影响力——它只是一个"结构化提取器"。 - -这个设计的另一个好处是:你可以方便地调整评分公式的参数(比如 contradicted 的惩罚系数),而不需要重新设计 prompt 或重新跑 LLM。 - -#### 工具调用惩罚 - -除了三条主线的 reward,我们额外加了一个工具调用惩罚项: - -| 工具调用次数 | 惩罚 | -|-------------|------| -| 0 次 | -1.0 | -| 1-2 次 | -0.5 | -| ≥3 次 | 0.0(无惩罚) | - -逻辑很简单:你是来做研究的,不调工具就是没在研究。这个惩罚在训练早期很有用,能快速让模型学会"先查资料再写报告",而不是直接凭空生成。 - -#### 权重融合与维度博弈 - -最终 reward 是各维度加权求和: - -``` -final_reward = Σ(weight_i × grader_i_score) + tool_penalty -``` - -默认配置是:分析充分性 0.5,呈现质量 0.25,引用规范性 0.25。EBTU 和 Audit 默认权重 0.0(可选启用)。所有正权重归一化到 sum=1。 - -这个权重分配不是拍脑袋定的,而是迭代调出来的。几个观察: - -- 分析充分性权重不能太低,否则模型倾向于写"正确的废话"——每个数字都有引用,但不做任何有价值的分析 -- 呈现质量权重不能太高,否则模型会花大量精力在格式上,而不是在内容上 -- Grounding 的权重对"引用习惯"的建立很重要——权重为 0 的时候,模型几乎不会主动加引用 - -维度之间确实存在张力。"呈现好但事实差"= 漂亮的废话,"事实好但分析浅"= 正确的摘抄。理想状态是各维度同步提升,但实际训练中往往会看到此消彼长的阶段。消融实验能帮你理解每个维度的贡献,我们在结果部分会给出具体数据。 - -### 5.3 训练工程:不做这些你根本跑不起来 - -如果你只看前面两节,可能觉得"数据有了、reward 设计好了,开训就完事了"。现实会教你做人。 - -Agent RL 训练和普通文本 RL 的最大区别是:**每个 rollout 都需要真实的工具交互**。这带来三个硬工程问题。 - -**成本问题:工具调用又贵又慢。** - -每个样本需要 4-8 轮工具交互,每轮可能调 1-3 个工具,每个工具调用耗时几秒到十几秒。一个 batch 32 个样本、每个样本 rollout 4 次(GRPO 的 group size),就是 128 条轨迹、可能上千次工具调用。如果串行执行,一个 batch 要跑几个小时。 - -解决方案有两层: - -- **EnvService 解耦**:工具执行从训练进程中独立出来,作为单独的服务运行。训练进程只需要发请求、收结果,不需要等待工具执行完成。多个 rollout 可以并发进行。 -- **MongoDB 工具缓存**:同一个 query 在不同 rollout 中调用相同工具时,直接从缓存返回。这一条就把训练速度提了几倍——因为 GRPO 的 group 内多条轨迹经常会调相同的工具。 - -**不确定性问题:外部数据会变。** - -今天搜"茅台最新消息"和明天搜的结果不一样。如果不做缓存,同一个训练任务在不同时间点跑出来的 reward 就不同,这让实验无法复现,也让训练信号变得更嘈杂。 - -工具缓存同时解决了成本和确定性两个问题——同一个 (tool_name, args) 组合只会真正执行一次,后续全部命中缓存。 - -**鲁棒性问题:真实工具环境充满意外。** - -这一点怎么强调都不过分。实际跑起来你会遇到: - -- 工具 API 超时(网络抖动、服务端负载高) -- API 限流(429 Too Many Requests) -- 返回格式异常(字段缺失、编码错误) -- JSON 解析失败(LLM 输出的 JSON 不合法——多了个逗号、少了个引号、中间截断了) -- Judge LLM 调用也会出错(同样的网络问题 + 额度问题) - -我们的应对是层层兜底: - -```python -# Judge 评估带指数退避重试 -for attempt in range(max_retries): - try: - runner = self._create_runner_in_loop() - result = await runner.arun(dataset) - return - except Exception as e: - if is_connection_error and attempt < max_retries - 1: - wait_time = 2 ** attempt # 指数退避: 1s, 2s, 4s - await asyncio.sleep(wait_time) - continue - raise -``` - -JSON 解析方面,每个 Grader 都有自己的 `strict_load_json` 实现,能处理常见的 LLM 输出问题:未转义的换行符、trailing comma、JSON 被截断等。解析失败不会让整个训练挂掉——Grader 会返回 score=0 并在 reason 里记录错误信息,方便事后排查。 - -**多机多卡训练**基于 Ray 集群。训练脚本会在 Master 节点启动 Ray Head 和训练任务,Worker 节点自动加入集群。实际跑的时候,NCCL 超时是最常见的坑——我们把相关超时配置放大到了合理的范围,同时做了异步错误处理,避免单点故障拖垮整个集群。 - -最后说一个经常被忽视的问题:**GradingRunner 的事件循环绑定**。OpenJudge 的 GradingRunner 内部使用 Semaphore 做并发控制,而 Semaphore 会绑定到创建时的事件循环。在多线程训练环境中,如果复用 Runner 实例,会因为事件循环不匹配而报错。代码里的解法是每次评分都创建新的 Runner 和事件循环: - -```python -loop = asyncio.new_event_loop() -asyncio.set_event_loop(loop) -try: - loop.run_until_complete(run_with_retry()) -finally: - loop.close() - asyncio.set_event_loop(None) -``` - -这类工程细节不会出现在论文里,但不处理好训练就跑不起来。 - ---- - -## 6. 实验结果 - -> TODO: 训练完成后补充具体数据。以下是我们预设的实验维度。 - -### 分 domain 评测 - -我们在五类金融场景上分别评测训练前后的表现,报告各 reward 维度的分数变化。重点观察: - -- 哪些 domain 提升最大?(通常工具使用密集的个股分析提升明显) -- 哪些 domain 进步有限?(宏观分析更依赖推理能力,工具帮助有限) -- 事实性和分析充分性是否同步提升?还是存在此消彼长? - -### 消融实验 - -逐一关闭每个 reward 维度,观察模型行为的变化: - -- 关闭 Grounding → 模型是否不再主动加引用? -- 关闭分析充分性 → 模型是否退化为数据摘抄? -- 关闭 PresentationQuality → 报告的可读性下降多少? - -这些消融能帮助理解每个维度的实际贡献,也能指导权重调优。 - -### Reward 曲线 - -训练过程中各维度 reward 的变化趋势。几个值得关注的现象: - -- 工具调用惩罚通常在前几步就收敛到 0(模型很快学会"先查再写") -- Grounding 和 PresentationQuality 的提升通常领先于分析充分性 -- 各维度之间是否存在阶段性的"抢跑"和"追赶" - -### Case Study - -训练前后的行为对比,是最直观的结果展示: - -- **工具调用策略**:训练前可能一上来就调 5 个工具,训练后变成有目标的分批调用 -- **证据利用率**:训练前大量工具返回被忽略,训练后更多数据被整合进报告 -- **引用习惯**:训练前很少加引用标记,训练后关键事实句基本都有出处 -- **报告结构**:训练前结构松散,训练后更倾向于先给结论再展开 - -### 泛化能力 - -一个重要的问题:在金融数据上训练的 Agent,能不能在其他深度研究场景上也有提升?比如技术调研、政策分析、学术文献综述。如果 reward 设计的维度足够通用(事实性、分析深度、写作质量本身不限于金融),泛化是有可能的。 - ---- - -## 7. 经验教训:如果你也要从零训练一个金融 DeepResearch Agent - -### 最希望别人少踩的 5 个坑 - -**1. 没有闭环评测就会自嗨。** - -训练曲线在涨,不代表模型真的在变好。我们早期试过只看平均 reward,结果发现模型学会了一种"安全模式"——生成短报告、少引用、不做复杂分析,反而总分还行。直到我们把 benchmark 分 domain 跑了一遍才发现问题。 - -**2. Reward 太粗会促成投机。** - -单一 reward 维度的模型极其擅长找捷径。我们见过:只优化写作分的模型会生成格式完美但内容空洞的报告;只优化引用覆盖率的模型会在每句话后面都加 [1],但所有引用指向同一个来源。多维度 reward 不是"更好",而是"必须"。 - -**3. 数据不清洗会放大坏习惯。** - -如果训练数据里有大量"无效 query"(太模糊、无法用工具回答的),模型会学到"遇到难题就跳过工具直接编"的策略。query 的质量控制和 reward 设计同等重要。 - -**4. 不做确定性保障,训练不可复现。** - -外部工具返回不同结果 + LLM Judge 给分有波动 → 同一个实验跑两遍结果完全不同。MongoDB 缓存解决了工具端的确定性,LLM 提取 + 代码打分解决了 Judge 端的确定性。这两个决策不是优化,是前提。 - -**5. 只追指标不看行为,会"训练出怪物"。** - -Reward 数字好看但实际报告不可用的情况真的存在。比如模型学会了在报告末尾疯狂堆叠 References 来刷 Grounding 分——形式上每句话都有引用,但引用的质量和相关性很差。定期做 Case Study、人工审查生成的报告,和看指标一样重要。 - -### 最小可复现配方 - -如果你想从零开始搭一个类似的系统,最小的组件清单是: - -| 组件 | 最简版本 | 对应文件 | -|------|---------|----------| -| 数据 | 50-100 个覆盖不同 domain 的金融问题 | `deep_finance_reader.py` | -| 工具 | 搜索 + 至少 3-5 个结构化数据工具 | `tool_prompt_builder.py` | -| Workflow | ReAct 循环 + 上下文管理 | `deep_finance.py` | -| Reward | 至少 2 个维度(事实性 + 分析充分性) | `deep_finance_judge.py` | -| 训练 | GRPO + 单机调试脚本 | `deep_finance_single.sh` | - -先跑通单机调试模式(`--backbone="debug"`),确认 workflow → judge → reward 闭环没问题后,再上多机训练。不要一上来就搞分布式。 - ---- - -*如果这篇文章对你有帮助,欢迎 Star 我们的 [AgentJet 项目](https://github.com/modelscope/AgentJet)。相关代码在 `tutorial/example_deep_finance/` 目录下。* diff --git a/tutorial/example_deep_finance/deep_finance.md b/tutorial/example_deep_finance/deep_finance.md index 6d503263..33820bff 100644 --- a/tutorial/example_deep_finance/deep_finance.md +++ b/tutorial/example_deep_finance/deep_finance.md @@ -8,17 +8,24 @@ DeepFinance 是基于 AgentJet 框架构建的金融深度研究 Agent 训练方 **训练闭环**: -``` +```plain 金融问题 → Agent 调用工具收集数据 → 生成研究报告 → 多维度 Judge 评分 → GRPO 策略更新 → 下一轮生成 ``` ---- +------ -## 系统架构 +## Pipeline 整个训练流水线由 4 个核心模块组成: -``` +| 模块 | 文件 | 职责 | +| ------------ | ---------------------------------- | --------------------------------------------------- | +| **Reader** | `deep_finance_reader.py` | 加载 JSON 训练数据,组装 System Prompt + User Query | +| **Workflow** | `deep_finance.py` | 定义 ReAct Agent 的多轮交互逻辑,维护对话历史 | +| **Judge** | `deep_finance_judge.py` + `judge/` | 多维度奖励评分(核心创新) | +| **配置** | `deep_finance.yaml` / `*.sh` | 训练参数、奖励权重、环境配置 | + +```plain ┌─────────────────────────────────────────────────────────────┐ │ AgentJet 训练框架 │ │ │ @@ -38,7 +45,7 @@ DeepFinance 是基于 AgentJet 框架构建的金融深度研究 Agent 训练方 │ v │ │ ┌────────────────────────┐ │ │ │ DeepFinanceJudge │ │ -│ │ 5 维 Reward 评分 │ │ +│ │ 多 维 Reward 评分 │ │ │ │ (基于 OpenJudge) │ │ │ └────────────┬───────────┘ │ │ │ │ @@ -50,14 +57,68 @@ DeepFinance 是基于 AgentJet 框架构建的金融深度研究 Agent 训练方 └─────────────────────────────────────────────────────────────┘ ``` -| 模块 | 文件 | 职责 | -|------|------|------| -| **Reader** | `deep_finance_reader.py` | 加载 JSON 训练数据,组装 System Prompt + User Query | -| **Workflow** | `deep_finance.py` | 定义 ReAct Agent 的多轮交互逻辑,维护对话历史 | -| **Judge** | `deep_finance_judge.py` + `judge/` | 多维度奖励评分(核心创新) | -| **配置** | `deep_finance.yaml` / `*.sh` | 训练参数、奖励权重、环境配置 | +------ + +## Workflow设计 + +### 两阶段深度研究流程 + +Agent 的 System Prompt(`prompt/finance_analyst_prompt.md`)要求遵循两阶段研究方法: + +**第一阶段:先大纲后调研** + +1. 理解用户问题类型(个股分析/行业研究/事件解读/宏观分析/股票检索) +2. **先输出研究大纲**(一级/二级标题 + 每节的 Key Questions),此阶段不调用工具 +3. 按大纲逐段调研,每轮调用工具后做小结 ---- +**第二阶段:深度分析与报告生成** + +1. 当数据充分后,基于真实数据生成 Markdown 格式研究报告 +2. 写作中发现证据不足时允许追加 1-2 轮工具调用补充取证 +3. 报告末尾添加 `[TASK_COMPLETED]` 标记 + +### 引用规范 + +Agent 被要求使用学术论文风格的引用标注: + +- 所有关键事实句句末必须添加引用编号 `[n]` +- 报告末尾必须包含 `## References` 小节 +- 引用必须可追溯到实际工具返回的数据,禁止伪造 + +------ + +## 工具体系 + +DeepFinance 集成了 **19 个金融工具**,通过 MCP(Model Context Protocol)协议与 EnvService 交互,覆盖金融研究的完整数据需求。 + +| 类别 | 工具 | 功能 | +| ------------------ | ----------------------- | ----------------------------------- | +| **实体与计算** | `extract_entities_code` | 从自然语言中提取金融实体并查找代码 | +| | `history_calculate` | A股历史股价分析(支持自然语言提问) | +| **通用能力** | `dashscope_search` | 互联网搜索 | +| | `execute_code` | Python 代码执行 | +| | `execute_shell` | Shell 命令执行 | +| **同花顺专项数据** | `crawl_ths_company` | 上市公司基本资料 | +| | `crawl_ths_holder` | 股东研究信息 | +| | `crawl_ths_operate` | 经营分析信息 | +| | `crawl_ths_finance` | 财务分析信息 | +| | `crawl_ths_worth` | 盈利预测信息 | +| | `crawl_ths_news` | 新闻公告信息 | +| | `crawl_ths_concept` | 概念题材信息 | +| | `crawl_ths_equity` | 股本结构信息 | +| | `crawl_ths_capital` | 资本运作信息 | +| | `crawl_ths_position` | 主力持仓信息 | +| | `crawl_ths_bonus` | 分红融资信息 | +| | `crawl_ths_event` | 公司大事信息 | +| | `crawl_ths_field` | 行业对比信息 | + +工具调用规范: + +- 每次最多调用 **3 个工具**,采用多轮次渐进式调研 +- Agent 必须先搜索确认信息(如股票代码),再进行深度查询 +- 每轮工具调用后先做小结,再决定下一步调研方向 + +------ ## 奖励设计(Reward Design) @@ -65,7 +126,7 @@ DeepFinance 是基于 AgentJet 框架构建的金融深度研究 Agent 训练方 ### 总体公式 -``` +```plain final_reward = Σ(w_i × grader_i_score) + tool_penalty ``` @@ -73,42 +134,44 @@ final_reward = Σ(w_i × grader_i_score) + tool_penalty ### 5 个评分维度总览 -| 维度 | 名称 | 评估对象 | 核心问题 | -|------|------|---------|----------| -| **分析充分性** | RM Gallery | 报告整体质量 | 分析是否充分?逻辑是否合理? | -| **呈现质量** | PresentationQuality | 报告排版与结构 | 读者体验好不好?信息是否易获取? | -| **引用规范性** | Grounding | 引用的覆盖与真实性 | 关键事实是否都有引用?引用是否真实? | -| **证据溯源** | EBTU | 原子断言的证据锚定 | 每个数字/事实能否追溯到工具返回的原始数据? | -| **引用逻辑审计** | Audit | 引用的逻辑蕴含关系 | 引用是否真正支撑了对应的陈述?有没有夸大或捏造? | +| 维度 | 名称 | 评估对象 | 核心问题 | +| ---------------- | ------------------- | ------------------ | ------------------------------------------------ | +| **分析充分性** | RM Gallery | 报告整体质量 | 分析是否充分?逻辑是否合理? | +| **呈现质量** | PresentationQuality | 报告排版与结构 | 读者体验好不好?信息是否易获取? | +| **引用规范性** | Grounding | 引用的覆盖与真实性 | 关键事实是否都有引用?引用是否真实? | +| **证据溯源** | EBTU | 原子断言的证据锚定 | 每个数字/事实能否追溯到工具返回的原始数据? | +| **引用逻辑审计** | Audit | 引用的逻辑蕴含关系 | 引用是否真正支撑了对应的陈述?有没有夸大或捏造? | 默认权重配置(可在 shell 脚本中调整): ```bash RM_WEIGHT=0.5 # 分析充分性 -PRESENTATION_QUALITY_WEIGHT=0.25 # 呈现质量 -GROUNDING_WEIGHT=0.25 # 引用规范性 -EBTU_WEIGHT=0.0 # 证据溯源(可选启用) +PRESENTATION_QUALITY_WEIGHT=0.2 # 呈现质量 +GROUNDING_WEIGHT=0.1 # 引用规范性 +EBTU_WEIGHT=0.2 # 证据溯源(可选启用) AUDIT_WEIGHT=0.0 # 引用逻辑审计(可选启用) ``` ---- +------ ### 1) 分析充分性(RM Gallery) **目标**:评估报告的分析深度、覆盖面和逻辑性——回答「分析得好不好」。 -**机制**:使用 [RM Gallery](https://github.com/modelscope/rm_gallery) 的 `finance_composition` 评估器,通过独立的 Judge LLM(如 `qwen-max`)对生成报告与参考答案进行对比评估。 +**机制**:使用 `finance_composition` 评估器,通过独立的 Judge LLM( `qwen-max`)对生成报告与参考答案进行对比评估。 **评估维度(按金融 domain 分域)**: + - 分析深度:对核心问题的挖掘是否足够深入 - 覆盖面:是否覆盖了问题涉及的多个分析维度(基本面、财务、估值、行业、新闻等) - 逻辑性:分析推理链条是否完整、结论是否有据可依 **输入输出**: + - 输入:用户 Query + Agent 生成的报告 + 参考答案 - 输出:`[0, 1]` 归一化分数 ---- +------ ### 2) 呈现质量(Presentation Quality) @@ -118,31 +181,33 @@ AUDIT_WEIGHT=0.0 # 引用逻辑审计(可选启用) **8 项子指标(1/3/5 分制)**: -| 分类 | 指标 | 5分标准 | -|------|------|--------| -| **Scan 可扫描性** | A1 结论先行 | 开头有独立摘要/TL;DR,读者无需滚动即可获取主结论 | -| | A2 结构导航 | 层级分明(H1/H2/H3),长文有清晰小标题路标 | -| | A3 视觉重点 | 精准使用加粗/斜体强调核心洞察,信噪比高 | -| **Structuring 信息结构化** | B1 密集信息解构 | 复杂数据用表格/嵌套列表呈现,一目了然 | -| | B2 对比对齐 | 方案A vs B / 历史 vs 现状使用表格,维度横向可比 | -| | B3 一致性与渲染 | 格式统一,Markdown 渲染完美 | -| **Editorial 编辑清晰度** | C1 论证链可视化 | 逻辑链条可视(主张→证据→结论),引用锚点清晰 | -| | C2 风险与行动 | 独立板块列出风险/局限性及下一步建议 | +| 分类 | 指标 | 5分标准 | +| -------------------------- | --------------- | ------------------------------------------------ | +| **Scan 可扫描性** | A1 结论先行 | 开头有独立摘要/TL;DR,读者无需滚动即可获取主结论 | +| | A2 结构导航 | 层级分明(H1/H2/H3),长文有清晰小标题路标 | +| | A3 视觉重点 | 精准使用加粗/斜体强调核心洞察,信噪比高 | +| **Structuring 信息结构化** | B1 密集信息解构 | 复杂数据用表格/嵌套列表呈现,一目了然 | +| | B2 对比对齐 | 方案A vs B / 历史 vs 现状使用表格,维度横向可比 | +| | B3 一致性与渲染 | 格式统一,Markdown 渲染完美 | +| **Editorial 编辑清晰度** | C1 论证链可视化 | 逻辑链条可视(主张→证据→结论),引用锚点清晰 | +| | C2 风险与行动 | 独立板块列出风险/局限性及下一步建议 | **评分计算**: -``` + +```plain score = Σ(8项得分) / 40 # 归一化到 [0, 1] ``` **反刷分机制**:空表格、无意义重复列表、为格式而格式 → 直接判 1 分。 ---- +------ ### 3) 引用规范性(Grounding) **目标**:评估报告的引用覆盖率和引用真实性——回答「关键事实都有出处吗?引用是真的吗?」 **评估流程**: + 1. 从对话轨迹中提取 User Query、Evidence(工具调用与返回)、最终报告 2. LLM 审计员识别报告中的所有「关键事实句」(含数字/日期/财务指标/确定性陈述) 3. 检查每个关键事实句句末是否有引用标记 `[n]` @@ -150,6 +215,7 @@ score = Σ(8项得分) / 40 # 归一化到 [0, 1] 5. 检查引用内容与 Evidence 是否一致(检测虚假引用) **输出字段**: + - `total_key_facts`:关键事实句总数 - `cited_key_facts`:句末有引用的关键事实句数 - `fake_count`:引用内容与证据明显矛盾的数量 @@ -157,13 +223,14 @@ score = Σ(8项得分) / 40 # 归一化到 [0, 1] - `invalid_reference_nums`:不合规的引用编号 **评分计算**: -``` + +```plain citation_coverage = cited_key_facts / total_key_facts # 引用覆盖率 grounding_score = 1 - fake_count / cited_key_facts # 引用真实性 final_score = 0.5 × coverage + 0.5 × grounding # 综合分数 ``` ---- +------ ### 4) 证据溯源(EBTU - Evidence-Backed Trace Units) @@ -172,25 +239,29 @@ final_score = 0.5 × coverage + 0.5 × grounding # 综合分数 **核心理念:证据优先(Evidence-first)**。审计官必须先给出证据锚点(step + quote),再下裁决,严禁先下结论再找证据。 **审计流程**: + 1. 从报告中提取所有原子断言(Trace Units),标记类型(numeric/temporal/event/comparison/causal 等) 2. 标记硬度:`hard`(确定性事实) / `soft`(明确标注为推测/假设) 3. 对每个断言在 Evidence 中寻找锚点(anchors),要求: - - 精确到 step 编号和原文引用(quote ≤ 120 字) - - 数字/日期必须能在 Evidence 原文中找到对应 -4. 给出裁决(verdict): -| Verdict | 含义 | -|---------|------| -| `supported` | 锚点直接支持断言 | -| `contradicted` | 锚点与断言明确冲突 | -| `no_evidence` | Evidence 中找不到支撑,且断言是确定性表述 | -| `speculative_ok` | 断言明确为推测/假设,未伪装成事实 | -| `unclear` | Evidence 相关但不足以支持或反驳 | +- - 精确到 step 编号和原文引用(quote ≤ 120 字) + - 数字/日期必须能在 Evidence 原文中找到对应 -5. 标记问题类型(issue):`entity_mismatch` / `time_mismatch` / `value_mismatch` / `scope_mismatch` / `logic_leap` / `over_precision` / `missing_anchor` +1. 给出裁决(verdict): + +| Verdict | 含义 | +| ---------------- | ----------------------------------------- | +| `supported` | 锚点直接支持断言 | +| `contradicted` | 锚点与断言明确冲突 | +| `no_evidence` | Evidence 中找不到支撑,且断言是确定性表述 | +| `speculative_ok` | 断言明确为推测/假设,未伪装成事实 | +| `unclear` | Evidence 相关但不足以支持或反驳 | + +1. 标记问题类型(issue):`entity_mismatch` / `time_mismatch` / `value_mismatch` / `scope_mismatch` / `logic_leap` / `over_precision` / `missing_anchor` **评分计算**(确定性打分,由 Python 代码计算,非 LLM 输出): -``` + +```plain base = (supported - 1.4×contradicted - 0.9×no_evidence - 0.4×unclear) / hard_units misattrib_factor = max(0, 1 - 0.7 × misattrib_rate) # 错误归因惩罚 selection_factor = min(1, extracted_units / expected) # 覆盖率因子 @@ -200,119 +271,33 @@ score = base × misattrib_factor × selection_factor × cov_factor 关键设计:LLM 只负责结构化输出(断言提取 + 锚点标注 + 裁决),分数完全由代码确定性计算,避免 LLM 自评分的不稳定性。 ---- - -### 5) 引用逻辑审计(Audit) - -**目标**:对报告中每个引用标记 `[n]` 做逻辑蕴含验证——回答「引用真的支撑了这句话吗?有没有夸大、捏造或断章取义?」 - -**角色**:证据逻辑分析师(Evidence Logic Analyst),像法官判案一样:先罗列证据,再逻辑推导,最后下判决。 - -**三步验证**: -1. **提取(Extract)**:锁定报告中由 `[n]` 支撑的陈述片段(Claim) -2. **溯源(Trace)**:在 Evidence 中找到 `[n]` 对应的原始内容(Source Quote) -3. **比对(Compare)**: - - 数字/事实是否一致? - - 语气是否一致?(有没有把「可能」改成「确定」) - - 因果关系是否成立? - -**判决标准(5 级)**: - -| Verdict | 含义 | 典型案例 | -|---------|------|----------| -| `Supported` | 证据充分,逻辑闭环 | 证据说增长 15%,报告写增长 15% | -| `Overstated` | 夸大其词 | 证据说「约 15%」,报告写成「高达 15.2%」 | -| `Contradicted` | 事实冲突 | 证据说下降,报告写上升 | -| `Hallucinated` | 无中生有 | 报告中的数据在证据中完全找不到 | -| `Irrelevant` | 引用无关 | 引用的内容与被支撑的陈述无关 | - -**评分计算**: -``` -score = Supported 数量 / 总引用数 -``` - ---- +------ ### 工具调用惩罚 在加权融合分数之外,额外施加工具调用惩罚,鼓励 Agent 积极使用工具收集数据: -| 工具调用次数 | 惩罚 | -|-------------|------| -| 0 次 | -1.0 | -| 1-2 次 | -0.5 | -| ≥3 次 | 0.0(无惩罚) | - ---- +| 工具调用次数 | 惩罚 | +| ------------ | ------------- | +| 0 次 | -1.0 | +| 1-2 次 | -0.5 | +| ≥3 次 | 0.0(无惩罚) | -## 工具体系 +------ -DeepFinance 集成了 **19 个金融工具**,通过 MCP(Model Context Protocol)协议与 EnvService 交互,覆盖金融研究的完整数据需求。 - -| 类别 | 工具 | 功能 | -|------|------|------| -| **实体与计算** | `extract_entities_code` | 从自然语言中提取金融实体并查找代码 | -| | `history_calculate` | A股历史股价分析(支持自然语言提问) | -| **通用能力** | `dashscope_search` | 互联网搜索 | -| | `execute_code` | Python 代码执行 | -| | `execute_shell` | Shell 命令执行 | -| **同花顺专项数据** | `crawl_ths_company` | 上市公司基本资料 | -| | `crawl_ths_holder` | 股东研究信息 | -| | `crawl_ths_operate` | 经营分析信息 | -| | `crawl_ths_finance` | 财务分析信息 | -| | `crawl_ths_worth` | 盈利预测信息 | -| | `crawl_ths_news` | 新闻公告信息 | -| | `crawl_ths_concept` | 概念题材信息 | -| | `crawl_ths_equity` | 股本结构信息 | -| | `crawl_ths_capital` | 资本运作信息 | -| | `crawl_ths_position` | 主力持仓信息 | -| | `crawl_ths_bonus` | 分红融资信息 | -| | `crawl_ths_event` | 公司大事信息 | -| | `crawl_ths_field` | 行业对比信息 | - -工具调用规范: -- 每次最多调用 **3 个工具**,采用多轮次渐进式调研 -- Agent 必须先搜索确认信息(如股票代码),再进行深度查询 -- 每轮工具调用后先做小结,再决定下一步调研方向 - ---- - -## Prompt 设计 - -### 两阶段深度研究流程 - -Agent 的 System Prompt(`prompt/finance_analyst_prompt.md`)要求遵循两阶段研究方法: - -**第一阶段:先大纲后调研** -1. 理解用户问题类型(个股分析/行业研究/事件解读/宏观分析/股票检索) -2. **先输出研究大纲**(一级/二级标题 + 每节的 Key Questions),此阶段不调用工具 -3. 按大纲逐段调研,每轮调用工具后做小结 - -**第二阶段:深度分析与报告生成** -1. 当数据充分后,基于真实数据生成 Markdown 格式研究报告 -2. 写作中发现证据不足时允许追加 1-2 轮工具调用补充取证 -3. 报告末尾添加 `[TASK_COMPLETED]` 标记 - -### 引用规范 - -Agent 被要求使用学术论文风格的引用标注: -- 所有关键事实句句末必须添加引用编号 `[n]` -- 报告末尾必须包含 `## References` 小节 -- 引用必须可追溯到实际工具返回的数据,禁止伪造 - ---- - -## 快速开始 +## Quick Start ### 环境准备 -1. 安装 AgentJet 及依赖: +1. 安装 AgentJet 及依赖 + ```bash cd /path/to/AgentJet -bash install.sh +bash install.sh # TODO:把这部分缩减到一个install:https://yuque.alibaba-inc.com/bayotg/wxz7sb/qdesuu33621x2yhi ``` -2. 配置 `.env` 文件(API 密钥、模型路径、数据路径等): +1. 配置 `.env` 文件(API 密钥、模型路径、数据路径等): + ```bash # .env 示例 MODEL_PATH=/path/to/Qwen3-8B @@ -325,7 +310,7 @@ OPENJUDGE_API_KEY=your_api_key RM_API_KEY=your_api_key ``` -3. 启动 EnvService(金融工具服务) +1. 启动 EnvService(金融工具服务) ### 单机调试模式 @@ -343,25 +328,31 @@ bash tutorial/example_deep_finance/deep_finance.sh ``` 该脚本会: + 1. 从 YAML 模板动态生成配置文件 2. 在 Master 节点启动 Ray Head + 训练任务 3. Worker 节点自动加入 Ray 集群 ### 关键参数说明 -| 参数 | 默认值 | 说明 | -|------|--------|------| -| `NUM_REPEAT` | 4 | Group size,每个 query rollout 的次数 | -| `NUM_STEPS` | 6 | 每个样本的最大交互轮数 | -| `TRAIN_BATCH_SIZE` | 32 | 训练 batch size | -| `RM_WEIGHT` | 0.5 | 分析充分性权重 | -| `PRESENTATION_QUALITY_WEIGHT` | 0.25 | 呈现质量权重 | -| `GROUNDING_WEIGHT` | 0.25 | 引用规范性权重 | -| `EBTU_WEIGHT` | 0.0 | 证据溯源权重(可选启用) | -| `AUDIT_WEIGHT` | 0.0 | 引用逻辑审计权重(可选启用) | +| 参数 | 默认值 | 说明 | +| ----------------------------- | ------ | ------------------------------------- | +| `NUM_REPEAT` | 4 | Group size,每个 query rollout 的次数 | +| `NUM_STEPS` | 6 | 每个样本的最大交互轮数 | +| `TRAIN_BATCH_SIZE` | 32 | 训练 batch size | +| `RM_WEIGHT` | 0.5 | 分析充分性权重 | +| `PRESENTATION_QUALITY_WEIGHT` | 0.25 | 呈现质量权重 | +| `GROUNDING_WEIGHT` | 0.25 | 引用规范性权重 | +| `EBTU_WEIGHT` | 0.0 | 证据溯源权重(可选启用) | +| `AUDIT_WEIGHT` | 0.0 | 引用逻辑审计权重(可选启用) | ---- +------ ## 实验结果 -> TODO: 补充训练曲线、各维度 Grader 分数变化、生成报告质量对比等实验数据。 \ No newline at end of file + +![img](https://intranetproxy.alipay.com/skylark/lark/0/2026/png/107756372/1771843906200-9dd35ac4-f71e-40dc-b130-f03e3e6bae6a.png) + +![img](https://intranetproxy.alipay.com/skylark/lark/0/2026/png/107756372/1771843940824-4e3637d7-a16e-4994-8878-242effc2c0d7.png)![img](https://intranetproxy.alipay.com/skylark/lark/0/2026/png/107756372/1771843950142-09def779-5521-41f0-a457-a7715a819cc7.png) + +