From a8a0a7c16f4edecbeb77a25b3163d88165b2bda6 Mon Sep 17 00:00:00 2001 From: Haizhong Date: Wed, 20 May 2026 09:26:33 -0400 Subject: [PATCH 1/3] refactor: group config, weight_manager, workflow under astraflow.core Move three cross-cutting subpackages into a shared astraflow/core/ namespace to distinguish them from the engine packages (raas, train_worker, dataflow). Pure rename + mechanical import sweep, no logic changes: - astraflow/{config,weight_manager,workflow}/ -> astraflow/core// - import paths updated across the package, docs, and AGENTS.md - new astraflow/core/__init__.py documents the group's role - astraflow/__init__.py comment clarifies the dataflow re-export --- AGENTS.md | 30 ++++++++--------- astraflow/__init__.py | 2 +- astraflow/core/__init__.py | 9 ++++++ astraflow/{ => core}/config/__init__.py | 2 +- astraflow/{ => core}/config/loader.py | 0 .../{ => core}/weight_manager/__init__.py | 4 +-- astraflow/{ => core}/weight_manager/config.py | 2 +- .../weight_manager/transfer/__init__.py | 0 .../weight_manager/transfer/config.py | 0 .../weight_manager/transfer/receiver_agent.py | 0 .../weight_manager/transfer/sender_agent.py | 0 .../tests/test_chunked_delta_equiv.py | 0 .../transfer/transfer_engine.py | 0 .../weight_manager/weight_manager.py | 4 +-- astraflow/core/workflow/__init__.py | 32 +++++++++++++++++++ astraflow/{ => core}/workflow/api/__init__.py | 0 astraflow/{ => core}/workflow/api/cli_args.py | 0 .../{ => core}/workflow/api/engine_api.py | 2 +- .../{ => core}/workflow/api/io_struct.py | 2 +- .../{ => core}/workflow/api/reward_api.py | 2 +- .../{ => core}/workflow/api/workflow_api.py | 2 +- .../{ => core}/workflow/impl/__init__.py | 0 .../workflow/impl/actor_and_verify.py | 20 ++++++------ .../impl/agentbench/alfworld_task_server.py | 4 +-- .../workflow/impl/agentbench/task_server.py | 14 ++++---- .../agentbench/webshop_checker_workflow.py | 16 +++++----- .../impl/agentbench/webshop_task_server.py | 4 +-- .../workflow/impl/asearcher/__init__.py | 0 .../workflow/impl/asearcher/agent.py | 0 .../workflow/impl/asearcher/prompts.py | 0 .../workflow/impl/asearcher/reward.py | 0 .../workflow/impl/asearcher/search.py | 2 +- .../workflow/impl/asearcher/workflow.py | 12 +++---- .../workflow/impl/code_actor_and_verify.py | 18 +++++------ .../workflow/impl/code_actor_and_verify_v2.py | 18 +++++------ .../workflow/impl/code_actor_and_verify_v3.py | 18 +++++------ .../workflow/impl/code_solve_and_select.py | 18 +++++------ .../impl/livecodebench_single_turn.py | 4 +-- .../{ => core}/workflow/impl/multi_turn.py | 16 +++++----- .../workflow/impl/plan_and_solve.py | 22 ++++++------- astraflow/{ => core}/workflow/impl/rlvr.py | 22 ++++++------- .../workflow/impl/sep_solve_and_check.py | 20 ++++++------ .../{ => core}/workflow/impl/sm_lg_router.py | 18 +++++------ .../workflow/impl/solve_and_check.py | 20 ++++++------ .../workflow/impl/solve_and_verify.py | 20 ++++++------ .../{ => core}/workflow/impl/vision_rlvr.py | 18 +++++------ astraflow/{ => core}/workflow/registry.py | 0 .../{ => core}/workflow/reward/__init__.py | 2 +- .../workflow/reward/clevr_count_70k.py | 2 +- .../{ => core}/workflow/reward/geometry3k.py | 6 ++-- .../workflow/reward/human_eval_reward.py | 6 ++-- .../workflow/reward/livecodebench_reward.py | 8 ++--- .../{ => core}/workflow/reward/math_verify.py | 6 ++-- .../{ => core}/workflow/utils/__init__.py | 0 .../workflow/utils/code_execution_mraas.py | 4 +-- astraflow/{ => core}/workflow/utils/data.py | 0 .../workflow/utils/dynamic_import.py | 0 .../{ => core}/workflow/utils/hf_utils.py | 2 +- astraflow/{ => core}/workflow/utils/image.py | 0 .../{ => core}/workflow/utils/logging.py | 0 .../{ => core}/workflow/utils/perf_tracer.py | 0 .../workflow/utils/stats_tracker.py | 0 .../{ => core}/workflow/utils/testing_util.py | 0 .../workflow/utils/testing_util_mraas.py | 0 astraflow/dataflow/__main__.py | 2 +- astraflow/dataflow/data_acquisition.py | 2 +- .../dataflow/dataset/deepcoder_preview.py | 2 +- astraflow/dataflow/dataset/human_eval.py | 2 +- astraflow/dataflow/dataset/livecodebench.py | 2 +- astraflow/dataflow/prompt_curators.py | 2 +- astraflow/raas/api/cli_args.py | 2 +- astraflow/raas/server/manager.py | 8 ++--- astraflow/raas/server/tcp_receiver.py | 6 ++-- astraflow/train_worker/api/cli_args.py | 2 +- astraflow/train_worker/api/engine_api.py | 8 ++--- astraflow/train_worker/trainer/ppo_trainer.py | 4 +-- astraflow/workflow/__init__.py | 32 ------------------- docs/en/architecture/custom-raas.md | 8 ++--- .../multi-agent-weight-transfer.md | 4 +-- docs/en/architecture/trainer.md | 2 +- docs/en/architecture/weight-manager.md | 4 +-- 81 files changed, 267 insertions(+), 258 deletions(-) create mode 100644 astraflow/core/__init__.py rename astraflow/{ => core}/config/__init__.py (91%) rename astraflow/{ => core}/config/loader.py (100%) rename astraflow/{ => core}/weight_manager/__init__.py (66%) rename astraflow/{ => core}/weight_manager/config.py (89%) rename astraflow/{ => core}/weight_manager/transfer/__init__.py (100%) rename astraflow/{ => core}/weight_manager/transfer/config.py (100%) rename astraflow/{ => core}/weight_manager/transfer/receiver_agent.py (100%) rename astraflow/{ => core}/weight_manager/transfer/sender_agent.py (100%) rename astraflow/{ => core}/weight_manager/transfer/tests/test_chunked_delta_equiv.py (100%) rename astraflow/{ => core}/weight_manager/transfer/transfer_engine.py (100%) rename astraflow/{ => core}/weight_manager/weight_manager.py (99%) create mode 100644 astraflow/core/workflow/__init__.py rename astraflow/{ => core}/workflow/api/__init__.py (100%) rename astraflow/{ => core}/workflow/api/cli_args.py (100%) rename astraflow/{ => core}/workflow/api/engine_api.py (97%) rename astraflow/{ => core}/workflow/api/io_struct.py (96%) rename astraflow/{ => core}/workflow/api/reward_api.py (99%) rename astraflow/{ => core}/workflow/api/workflow_api.py (92%) rename astraflow/{ => core}/workflow/impl/__init__.py (100%) rename astraflow/{ => core}/workflow/impl/actor_and_verify.py (96%) rename astraflow/{ => core}/workflow/impl/agentbench/alfworld_task_server.py (88%) rename astraflow/{ => core}/workflow/impl/agentbench/task_server.py (96%) rename astraflow/{ => core}/workflow/impl/agentbench/webshop_checker_workflow.py (97%) rename astraflow/{ => core}/workflow/impl/agentbench/webshop_task_server.py (89%) rename astraflow/{ => core}/workflow/impl/asearcher/__init__.py (100%) rename astraflow/{ => core}/workflow/impl/asearcher/agent.py (100%) rename astraflow/{ => core}/workflow/impl/asearcher/prompts.py (100%) rename astraflow/{ => core}/workflow/impl/asearcher/reward.py (100%) rename astraflow/{ => core}/workflow/impl/asearcher/search.py (99%) rename astraflow/{ => core}/workflow/impl/asearcher/workflow.py (96%) rename astraflow/{ => core}/workflow/impl/code_actor_and_verify.py (98%) rename astraflow/{ => core}/workflow/impl/code_actor_and_verify_v2.py (98%) rename astraflow/{ => core}/workflow/impl/code_actor_and_verify_v3.py (98%) rename astraflow/{ => core}/workflow/impl/code_solve_and_select.py (97%) rename astraflow/{ => core}/workflow/impl/livecodebench_single_turn.py (82%) rename astraflow/{ => core}/workflow/impl/multi_turn.py (92%) rename astraflow/{ => core}/workflow/impl/plan_and_solve.py (95%) rename astraflow/{ => core}/workflow/impl/rlvr.py (90%) rename astraflow/{ => core}/workflow/impl/sep_solve_and_check.py (94%) rename astraflow/{ => core}/workflow/impl/sm_lg_router.py (97%) rename astraflow/{ => core}/workflow/impl/solve_and_check.py (95%) rename astraflow/{ => core}/workflow/impl/solve_and_verify.py (96%) rename astraflow/{ => core}/workflow/impl/vision_rlvr.py (90%) rename astraflow/{ => core}/workflow/registry.py (100%) rename astraflow/{ => core}/workflow/reward/__init__.py (98%) rename astraflow/{ => core}/workflow/reward/clevr_count_70k.py (90%) rename astraflow/{ => core}/workflow/reward/geometry3k.py (84%) rename astraflow/{ => core}/workflow/reward/human_eval_reward.py (95%) rename astraflow/{ => core}/workflow/reward/livecodebench_reward.py (96%) rename astraflow/{ => core}/workflow/reward/math_verify.py (76%) rename astraflow/{ => core}/workflow/utils/__init__.py (100%) rename astraflow/{ => core}/workflow/utils/code_execution_mraas.py (98%) rename astraflow/{ => core}/workflow/utils/data.py (100%) rename astraflow/{ => core}/workflow/utils/dynamic_import.py (100%) rename astraflow/{ => core}/workflow/utils/hf_utils.py (96%) rename astraflow/{ => core}/workflow/utils/image.py (100%) rename astraflow/{ => core}/workflow/utils/logging.py (100%) rename astraflow/{ => core}/workflow/utils/perf_tracer.py (100%) rename astraflow/{ => core}/workflow/utils/stats_tracker.py (100%) rename astraflow/{ => core}/workflow/utils/testing_util.py (100%) rename astraflow/{ => core}/workflow/utils/testing_util_mraas.py (100%) delete mode 100644 astraflow/workflow/__init__.py diff --git a/AGENTS.md b/AGENTS.md index 76ae4dc..4d75f36 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -66,7 +66,7 @@ your response. - `server/` — Manager, TCP receiver, FastAPI app. - `engine/` — Remote inference engine adapters. - `api/cli_args.py` — RaaS-side dataclass configs. - - `astraflow/workflow/` — Rollout workflows and reward functions + - `astraflow/core/workflow/` — Rollout workflows and reward functions (swappable). - `api/` — Base interfaces: `RolloutWorkflow` (`workflow_api.py`), `AsyncRewardWrapper` (`reward_api.py`). @@ -82,9 +82,9 @@ your response. and rewards (`WORKFLOW_REGISTRY`, `REWARD_REGISTRY`). - `__init__.py` — Imports every `impl/` and `reward/` module so the registration decorators run at import time. - - `astraflow/weight_manager/` — Weight transport (TCP/ZMQ) + - `astraflow/core/weight_manager/` — Weight transport (TCP/ZMQ) between trainer and RaaS. `transfer/tests/` holds its unit tests. - - `astraflow/config/` — Hydra/OmegaConf config loader and merging. + - `astraflow/core/config/` — Hydra/OmegaConf config loader and merging. - `astraEnv/` — Vendored environment code (not part of the package): `AgentBench` (alfworld, webshop), `ASearcher` (retrieval-augmented search), `human-eval`. Carries upstream licenses; treat as a @@ -119,7 +119,7 @@ your response. vLLM / SGLang inference servers and exposes `/availability`, `/eval_pull`, weight-update, and rollout endpoints. See `docs/en/architecture/raas.md` and `custom-raas.md`. -4. **Workflow** (`astraflow/workflow/`) — Rollout workflows +4. **Workflow** (`astraflow/core/workflow/`) — Rollout workflows (subclasses of `RolloutWorkflow` with `arun_episode`) and reward callables. New workflows/rewards plug in via decorator registration in `registry.py`. @@ -129,7 +129,7 @@ your response. Conventions for all code under `astraflow/`: - **Logging**: `astraflow.train_worker.utils.logging.getLogger(__name__)` - (workflow code uses `astraflow.workflow.utils.logging`) — never + (workflow code uses `astraflow.core.workflow.utils.logging`) — never `print`. Emit metrics through `stats_tracker.get("scope")` and `StatsLogger`; the latter pushes to W&B / SwanLab. - **Async**: Rollout workflows are non-blocking. `await` with @@ -150,7 +150,7 @@ Conventions for all code under `astraflow/`: consistent. Place heavy optional deps inside function bodies to avoid import-time side effects (Megatron, flash-attn, etc.). - **Rewards**: Wrap blocking reward code with `AsyncRewardWrapper` - (`astraflow/workflow/api/reward_api.py`). Standard signature is + (`astraflow/core/workflow/api/reward_api.py`). Standard signature is `(prompt, completions, prompt_ids, completion_ids, **data)` where `**data` carries dataset-specific fields (e.g. `answer`); return a `float`. @@ -164,9 +164,9 @@ Conventions for all code under `astraflow/`: ### Add a rollout workflow -- Create a new module under `astraflow/workflow/impl/.py`. +- Create a new module under `astraflow/core/workflow/impl/.py`. - Subclass `RolloutWorkflow` - (`astraflow/workflow/api/workflow_api.py`) and implement async + (`astraflow/core/workflow/api/workflow_api.py`) and implement async `arun_episode`. - Thread `GenerationHyperparameters`, tokenizer, reward callable, stat scope, and optional `dump_dir` through `__init__`. Wrap the @@ -176,20 +176,20 @@ Conventions for all code under `astraflow/`: - Persist transcripts to `{dump_dir}/{engine.get_version()}/` when debugging. - Decorate the class with `@register_workflow("")` from - `astraflow/workflow/registry.py`, then add an - `import astraflow.workflow.impl.` line to - `astraflow/workflow/__init__.py` so the decorator runs at import + `astraflow/core/workflow/registry.py`, then add an + `import astraflow.core.workflow.impl.` line to + `astraflow/core/workflow/__init__.py` so the decorator runs at import time. Reference `""` from `workflow_spec.workflow_cls` in the recipe YAML. ### Add a reward function -- Create `astraflow/workflow/reward/.py` with a callable +- Create `astraflow/core/workflow/reward/.py` with a callable matching the reward API contract (see "Rewards" above). - Decorate it with `@register_reward("")` from - `astraflow/workflow/registry.py`, then add an - `import astraflow.workflow.reward.` line to - `astraflow/workflow/__init__.py`. Reference `""` from + `astraflow/core/workflow/registry.py`, then add an + `import astraflow.core.workflow.reward.` line to + `astraflow/core/workflow/__init__.py`. Reference `""` from `workflow_spec.reward_fn` in the recipe YAML. - Keep slow or external-service logic in the reward module; let the calling workflow wrap it with `AsyncRewardWrapper`. diff --git a/astraflow/__init__.py b/astraflow/__init__.py index 781ca03..a228317 100644 --- a/astraflow/__init__.py +++ b/astraflow/__init__.py @@ -2,7 +2,7 @@ from .version import __version__ # noqa -# Re-export core orchestration symbols for convenience. +# Re-export dataflow orchestration symbols for convenience. # Allows ``from astraflow import AstraFlow`` instead of # ``from astraflow.dataflow import AstraFlow``. from .dataflow import ( # noqa: F401 diff --git a/astraflow/core/__init__.py b/astraflow/core/__init__.py new file mode 100644 index 0000000..bf02c65 --- /dev/null +++ b/astraflow/core/__init__.py @@ -0,0 +1,9 @@ +"""Shared building blocks for AstraFlow. + +Contains cross-cutting components used by the engine packages +(``raas``, ``train_worker``, ``dataflow``): + +- ``config`` — experiment/raas/dataflow/trainer YAML loading +- ``weight_manager`` — weight transport between trainer and RaaS +- ``workflow`` — rollout workflows and reward function registry +""" diff --git a/astraflow/config/__init__.py b/astraflow/core/config/__init__.py similarity index 91% rename from astraflow/config/__init__.py rename to astraflow/core/config/__init__.py index f601b56..f504c51 100644 --- a/astraflow/config/__init__.py +++ b/astraflow/core/config/__init__.py @@ -5,7 +5,7 @@ raas.yaml files for per-instance hardware configuration. """ -from astraflow.config.loader import ( +from astraflow.core.config.loader import ( load_and_merge_configs, load_dataflow_config, load_raas_config, diff --git a/astraflow/config/loader.py b/astraflow/core/config/loader.py similarity index 100% rename from astraflow/config/loader.py rename to astraflow/core/config/loader.py diff --git a/astraflow/weight_manager/__init__.py b/astraflow/core/weight_manager/__init__.py similarity index 66% rename from astraflow/weight_manager/__init__.py rename to astraflow/core/weight_manager/__init__.py index f934c23..2bcfdc7 100644 --- a/astraflow/weight_manager/__init__.py +++ b/astraflow/core/weight_manager/__init__.py @@ -4,8 +4,8 @@ Both trainer and RaaS import from this package; neither imports from the other. """ -from astraflow.weight_manager.config import WeightManagerConfig -from astraflow.weight_manager.weight_manager import WeightManager +from astraflow.core.weight_manager.config import WeightManagerConfig +from astraflow.core.weight_manager.weight_manager import WeightManager __all__ = [ "WeightManager", diff --git a/astraflow/weight_manager/config.py b/astraflow/core/weight_manager/config.py similarity index 89% rename from astraflow/weight_manager/config.py rename to astraflow/core/weight_manager/config.py index d3eb860..cf7c16e 100644 --- a/astraflow/weight_manager/config.py +++ b/astraflow/core/weight_manager/config.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import List -from astraflow.weight_manager.transfer.config import SenderAgentConfig +from astraflow.core.weight_manager.transfer.config import SenderAgentConfig @dataclass diff --git a/astraflow/weight_manager/transfer/__init__.py b/astraflow/core/weight_manager/transfer/__init__.py similarity index 100% rename from astraflow/weight_manager/transfer/__init__.py rename to astraflow/core/weight_manager/transfer/__init__.py diff --git a/astraflow/weight_manager/transfer/config.py b/astraflow/core/weight_manager/transfer/config.py similarity index 100% rename from astraflow/weight_manager/transfer/config.py rename to astraflow/core/weight_manager/transfer/config.py diff --git a/astraflow/weight_manager/transfer/receiver_agent.py b/astraflow/core/weight_manager/transfer/receiver_agent.py similarity index 100% rename from astraflow/weight_manager/transfer/receiver_agent.py rename to astraflow/core/weight_manager/transfer/receiver_agent.py diff --git a/astraflow/weight_manager/transfer/sender_agent.py b/astraflow/core/weight_manager/transfer/sender_agent.py similarity index 100% rename from astraflow/weight_manager/transfer/sender_agent.py rename to astraflow/core/weight_manager/transfer/sender_agent.py diff --git a/astraflow/weight_manager/transfer/tests/test_chunked_delta_equiv.py b/astraflow/core/weight_manager/transfer/tests/test_chunked_delta_equiv.py similarity index 100% rename from astraflow/weight_manager/transfer/tests/test_chunked_delta_equiv.py rename to astraflow/core/weight_manager/transfer/tests/test_chunked_delta_equiv.py diff --git a/astraflow/weight_manager/transfer/transfer_engine.py b/astraflow/core/weight_manager/transfer/transfer_engine.py similarity index 100% rename from astraflow/weight_manager/transfer/transfer_engine.py rename to astraflow/core/weight_manager/transfer/transfer_engine.py diff --git a/astraflow/weight_manager/weight_manager.py b/astraflow/core/weight_manager/weight_manager.py similarity index 99% rename from astraflow/weight_manager/weight_manager.py rename to astraflow/core/weight_manager/weight_manager.py index 53ad7ca..f8772a7 100644 --- a/astraflow/weight_manager/weight_manager.py +++ b/astraflow/core/weight_manager/weight_manager.py @@ -26,8 +26,8 @@ import torch.multiprocessing as mp from torch.distributed._tensor import DTensor -from astraflow.weight_manager.config import WeightManagerConfig -from astraflow.weight_manager.transfer.sender_agent import ( +from astraflow.core.weight_manager.config import WeightManagerConfig +from astraflow.core.weight_manager.transfer.sender_agent import ( create_tensor_from_shared_memory, start_transfer_agent, ) diff --git a/astraflow/core/workflow/__init__.py b/astraflow/core/workflow/__init__.py new file mode 100644 index 0000000..0c7a5a8 --- /dev/null +++ b/astraflow/core/workflow/__init__.py @@ -0,0 +1,32 @@ +"""Standalone workflow package for rollout workflows and reward functions. + +Importing this package triggers auto-registration of all built-in +workflows and reward functions via their @register_workflow / @register_reward +decorators. +""" + +# Auto-import implementations to trigger registry decorators +import astraflow.core.workflow.impl.agentbench.alfworld_task_server +import astraflow.core.workflow.impl.agentbench.task_server +import astraflow.core.workflow.impl.agentbench.webshop_task_server +import astraflow.core.workflow.impl.agentbench.webshop_checker_workflow +import astraflow.core.workflow.impl.asearcher +import astraflow.core.workflow.impl.code_actor_and_verify +import astraflow.core.workflow.impl.code_actor_and_verify_v2 +import astraflow.core.workflow.impl.code_actor_and_verify_v3 +import astraflow.core.workflow.impl.code_solve_and_select +import astraflow.core.workflow.impl.livecodebench_single_turn +import astraflow.core.workflow.impl.multi_turn +import astraflow.core.workflow.impl.plan_and_solve +import astraflow.core.workflow.impl.solve_and_check +import astraflow.core.workflow.impl.sep_solve_and_check +import astraflow.core.workflow.impl.solve_and_verify +import astraflow.core.workflow.impl.actor_and_verify +import astraflow.core.workflow.impl.rlvr +import astraflow.core.workflow.impl.sm_lg_router +import astraflow.core.workflow.impl.vision_rlvr +import astraflow.core.workflow.reward.clevr_count_70k +import astraflow.core.workflow.reward.geometry3k +import astraflow.core.workflow.reward.math_verify +import astraflow.core.workflow.reward.human_eval_reward +import astraflow.core.workflow.reward.livecodebench_reward diff --git a/astraflow/workflow/api/__init__.py b/astraflow/core/workflow/api/__init__.py similarity index 100% rename from astraflow/workflow/api/__init__.py rename to astraflow/core/workflow/api/__init__.py diff --git a/astraflow/workflow/api/cli_args.py b/astraflow/core/workflow/api/cli_args.py similarity index 100% rename from astraflow/workflow/api/cli_args.py rename to astraflow/core/workflow/api/cli_args.py diff --git a/astraflow/workflow/api/engine_api.py b/astraflow/core/workflow/api/engine_api.py similarity index 97% rename from astraflow/workflow/api/engine_api.py rename to astraflow/core/workflow/api/engine_api.py index e18444a..1759368 100644 --- a/astraflow/workflow/api/engine_api.py +++ b/astraflow/core/workflow/api/engine_api.py @@ -14,7 +14,7 @@ from contextlib import asynccontextmanager from typing import Any, Protocol, runtime_checkable -from astraflow.workflow.api.io_struct import ModelRequest, ModelResponse +from astraflow.core.workflow.api.io_struct import ModelRequest, ModelResponse @runtime_checkable diff --git a/astraflow/workflow/api/io_struct.py b/astraflow/core/workflow/api/io_struct.py similarity index 96% rename from astraflow/workflow/api/io_struct.py rename to astraflow/core/workflow/api/io_struct.py index bf2cd6a..4aa8aba 100644 --- a/astraflow/workflow/api/io_struct.py +++ b/astraflow/core/workflow/api/io_struct.py @@ -9,7 +9,7 @@ from PIL.Image import Image as ImageObject from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters if TYPE_CHECKING: from transformers import AutoProcessor diff --git a/astraflow/workflow/api/reward_api.py b/astraflow/core/workflow/api/reward_api.py similarity index 99% rename from astraflow/workflow/api/reward_api.py rename to astraflow/core/workflow/api/reward_api.py index a5ade35..d39d471 100644 --- a/astraflow/workflow/api/reward_api.py +++ b/astraflow/core/workflow/api/reward_api.py @@ -8,7 +8,7 @@ from concurrent.futures.process import BrokenProcessPool from functools import partial -from astraflow.workflow.utils import logging +from astraflow.core.workflow.utils import logging logger = logging.getLogger("Reward API") diff --git a/astraflow/workflow/api/workflow_api.py b/astraflow/core/workflow/api/workflow_api.py similarity index 92% rename from astraflow/workflow/api/workflow_api.py rename to astraflow/core/workflow/api/workflow_api.py index 73f49f1..2dcb3dc 100644 --- a/astraflow/workflow/api/workflow_api.py +++ b/astraflow/core/workflow/api/workflow_api.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from astraflow.workflow.api.engine_api import InferenceEngine + from astraflow.core.workflow.api.engine_api import InferenceEngine class RolloutWorkflow(ABC): diff --git a/astraflow/workflow/impl/__init__.py b/astraflow/core/workflow/impl/__init__.py similarity index 100% rename from astraflow/workflow/impl/__init__.py rename to astraflow/core/workflow/impl/__init__.py diff --git a/astraflow/workflow/impl/actor_and_verify.py b/astraflow/core/workflow/impl/actor_and_verify.py similarity index 96% rename from astraflow/workflow/impl/actor_and_verify.py rename to astraflow/core/workflow/impl/actor_and_verify.py index c432093..686c68b 100644 --- a/astraflow/workflow/impl/actor_and_verify.py +++ b/astraflow/core/workflow/impl/actor_and_verify.py @@ -43,15 +43,15 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.reward_api import AsyncRewardWrapper -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id -from astraflow.workflow.utils.dynamic_import import import_from_string +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.reward_api import AsyncRewardWrapper +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.utils.dynamic_import import import_from_string logger = logging.getLogger("ActorAndVerify workflow") @@ -200,7 +200,7 @@ def __init__( ): self.reward_fn = reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/agentbench/alfworld_task_server.py b/astraflow/core/workflow/impl/agentbench/alfworld_task_server.py similarity index 88% rename from astraflow/workflow/impl/agentbench/alfworld_task_server.py rename to astraflow/core/workflow/impl/agentbench/alfworld_task_server.py index c4f1407..6d6bf98 100644 --- a/astraflow/workflow/impl/agentbench/alfworld_task_server.py +++ b/astraflow/core/workflow/impl/agentbench/alfworld_task_server.py @@ -4,8 +4,8 @@ from typing import Any -from astraflow.workflow.impl.agentbench.task_server import TaskServerWorkflow -from astraflow.workflow.registry import register_workflow +from astraflow.core.workflow.impl.agentbench.task_server import TaskServerWorkflow +from astraflow.core.workflow.registry import register_workflow @register_workflow("alfworld_task_server") diff --git a/astraflow/workflow/impl/agentbench/task_server.py b/astraflow/core/workflow/impl/agentbench/task_server.py similarity index 96% rename from astraflow/workflow/impl/agentbench/task_server.py rename to astraflow/core/workflow/impl/agentbench/task_server.py index 1159696..6d16dd2 100644 --- a/astraflow/workflow/impl/agentbench/task_server.py +++ b/astraflow/core/workflow/impl/agentbench/task_server.py @@ -17,13 +17,13 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id, results_to_structured +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id, results_to_structured logger = logging.getLogger("TaskServerWorkflow") diff --git a/astraflow/workflow/impl/agentbench/webshop_checker_workflow.py b/astraflow/core/workflow/impl/agentbench/webshop_checker_workflow.py similarity index 97% rename from astraflow/workflow/impl/agentbench/webshop_checker_workflow.py rename to astraflow/core/workflow/impl/agentbench/webshop_checker_workflow.py index bcbdbe6..8685f22 100644 --- a/astraflow/workflow/impl/agentbench/webshop_checker_workflow.py +++ b/astraflow/core/workflow/impl/agentbench/webshop_checker_workflow.py @@ -32,15 +32,15 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.impl.agentbench.webshop_task_server import ( +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.impl.agentbench.webshop_task_server import ( WebshopTaskServerWorkflow, ) -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils.data import resolve_prompt_id -from astraflow.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.utils import logging, stats_tracker logger = logging.getLogger("WebShopCheckerWorkflow") @@ -300,7 +300,7 @@ async def _run_one_episode( raise ValueError("Data must contain 'task_id' or 'index' field") # Resolve actor/checker engines and tokenizers. - from astraflow.workflow.api.engine_api import EngineGroup + from astraflow.core.workflow.api.engine_api import EngineGroup multi_model = isinstance(engine, EngineGroup) and "model1" in engine if multi_model: diff --git a/astraflow/workflow/impl/agentbench/webshop_task_server.py b/astraflow/core/workflow/impl/agentbench/webshop_task_server.py similarity index 89% rename from astraflow/workflow/impl/agentbench/webshop_task_server.py rename to astraflow/core/workflow/impl/agentbench/webshop_task_server.py index 19b45d4..354c316 100644 --- a/astraflow/workflow/impl/agentbench/webshop_task_server.py +++ b/astraflow/core/workflow/impl/agentbench/webshop_task_server.py @@ -4,8 +4,8 @@ from typing import Any -from astraflow.workflow.impl.agentbench.task_server import TaskServerWorkflow -from astraflow.workflow.registry import register_workflow +from astraflow.core.workflow.impl.agentbench.task_server import TaskServerWorkflow +from astraflow.core.workflow.registry import register_workflow @register_workflow("webshop_task_server") diff --git a/astraflow/workflow/impl/asearcher/__init__.py b/astraflow/core/workflow/impl/asearcher/__init__.py similarity index 100% rename from astraflow/workflow/impl/asearcher/__init__.py rename to astraflow/core/workflow/impl/asearcher/__init__.py diff --git a/astraflow/workflow/impl/asearcher/agent.py b/astraflow/core/workflow/impl/asearcher/agent.py similarity index 100% rename from astraflow/workflow/impl/asearcher/agent.py rename to astraflow/core/workflow/impl/asearcher/agent.py diff --git a/astraflow/workflow/impl/asearcher/prompts.py b/astraflow/core/workflow/impl/asearcher/prompts.py similarity index 100% rename from astraflow/workflow/impl/asearcher/prompts.py rename to astraflow/core/workflow/impl/asearcher/prompts.py diff --git a/astraflow/workflow/impl/asearcher/reward.py b/astraflow/core/workflow/impl/asearcher/reward.py similarity index 100% rename from astraflow/workflow/impl/asearcher/reward.py rename to astraflow/core/workflow/impl/asearcher/reward.py diff --git a/astraflow/workflow/impl/asearcher/search.py b/astraflow/core/workflow/impl/asearcher/search.py similarity index 99% rename from astraflow/workflow/impl/asearcher/search.py rename to astraflow/core/workflow/impl/asearcher/search.py index 3ad4a00..0a29299 100644 --- a/astraflow/workflow/impl/asearcher/search.py +++ b/astraflow/core/workflow/impl/asearcher/search.py @@ -16,7 +16,7 @@ import aiohttp -from astraflow.workflow.utils import logging +from astraflow.core.workflow.utils import logging from .reward import compute_score_em, compute_score_f1 diff --git a/astraflow/workflow/impl/asearcher/workflow.py b/astraflow/core/workflow/impl/asearcher/workflow.py similarity index 96% rename from astraflow/workflow/impl/asearcher/workflow.py rename to astraflow/core/workflow/impl/asearcher/workflow.py index 8116e1b..81b5252 100644 --- a/astraflow/workflow/impl/asearcher/workflow.py +++ b/astraflow/core/workflow/impl/asearcher/workflow.py @@ -13,12 +13,12 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging -from astraflow.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging +from astraflow.core.workflow.utils.data import resolve_prompt_id from .agent import SearchAgent from .prompts import ( diff --git a/astraflow/workflow/impl/code_actor_and_verify.py b/astraflow/core/workflow/impl/code_actor_and_verify.py similarity index 98% rename from astraflow/workflow/impl/code_actor_and_verify.py rename to astraflow/core/workflow/impl/code_actor_and_verify.py index 8ee0493..357a6a3 100644 --- a/astraflow/workflow/impl/code_actor_and_verify.py +++ b/astraflow/core/workflow/impl/code_actor_and_verify.py @@ -16,14 +16,14 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id -from astraflow.workflow.utils.code_execution_mraas import ( +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.utils.code_execution_mraas import ( SINGLE_CASE_EXEC_TIMEOUT, call_verify_collect_all, extract_python_code, @@ -473,7 +473,7 @@ def __init__( ): del reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/code_actor_and_verify_v2.py b/astraflow/core/workflow/impl/code_actor_and_verify_v2.py similarity index 98% rename from astraflow/workflow/impl/code_actor_and_verify_v2.py rename to astraflow/core/workflow/impl/code_actor_and_verify_v2.py index 8ab54c1..df2f13b 100644 --- a/astraflow/workflow/impl/code_actor_and_verify_v2.py +++ b/astraflow/core/workflow/impl/code_actor_and_verify_v2.py @@ -17,14 +17,14 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id -from astraflow.workflow.utils.code_execution_mraas import ( +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.utils.code_execution_mraas import ( SINGLE_CASE_EXEC_TIMEOUT, call_verify_collect_all, extract_python_code, @@ -551,7 +551,7 @@ def __init__( ): del reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/code_actor_and_verify_v3.py b/astraflow/core/workflow/impl/code_actor_and_verify_v3.py similarity index 98% rename from astraflow/workflow/impl/code_actor_and_verify_v3.py rename to astraflow/core/workflow/impl/code_actor_and_verify_v3.py index d388e77..df7d8e8 100644 --- a/astraflow/workflow/impl/code_actor_and_verify_v3.py +++ b/astraflow/core/workflow/impl/code_actor_and_verify_v3.py @@ -70,14 +70,14 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id -from astraflow.workflow.utils.code_execution_mraas import ( +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.utils.code_execution_mraas import ( SINGLE_CASE_EXEC_TIMEOUT, call_verify_collect_all, extract_python_code, @@ -653,7 +653,7 @@ def __init__( ): del reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/code_solve_and_select.py b/astraflow/core/workflow/impl/code_solve_and_select.py similarity index 97% rename from astraflow/workflow/impl/code_solve_and_select.py rename to astraflow/core/workflow/impl/code_solve_and_select.py index f06cb39..354b2dc 100644 --- a/astraflow/workflow/impl/code_solve_and_select.py +++ b/astraflow/core/workflow/impl/code_solve_and_select.py @@ -44,14 +44,14 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id -from astraflow.workflow.utils.code_execution_mraas import ( +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.utils.code_execution_mraas import ( SINGLE_CASE_EXEC_TIMEOUT, call_verify_collect_all, extract_python_code, @@ -285,7 +285,7 @@ def __init__( ): del reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/livecodebench_single_turn.py b/astraflow/core/workflow/impl/livecodebench_single_turn.py similarity index 82% rename from astraflow/workflow/impl/livecodebench_single_turn.py rename to astraflow/core/workflow/impl/livecodebench_single_turn.py index 7722b70..f58ecc2 100644 --- a/astraflow/workflow/impl/livecodebench_single_turn.py +++ b/astraflow/core/workflow/impl/livecodebench_single_turn.py @@ -2,8 +2,8 @@ from __future__ import annotations -from astraflow.workflow.impl.rlvr import RLVRWorkflow -from astraflow.workflow.registry import register_workflow +from astraflow.core.workflow.impl.rlvr import RLVRWorkflow +from astraflow.core.workflow.registry import register_workflow def _identity_prompt_extractor(data: dict): diff --git a/astraflow/workflow/impl/multi_turn.py b/astraflow/core/workflow/impl/multi_turn.py similarity index 92% rename from astraflow/workflow/impl/multi_turn.py rename to astraflow/core/workflow/impl/multi_turn.py index b8a4c3b..b25d3b0 100644 --- a/astraflow/workflow/impl/multi_turn.py +++ b/astraflow/core/workflow/impl/multi_turn.py @@ -10,14 +10,14 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.reward_api import AsyncRewardWrapper -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id, results_to_structured +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.reward_api import AsyncRewardWrapper +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id, results_to_structured logger = logging.getLogger("Multi-Turn workflow") diff --git a/astraflow/workflow/impl/plan_and_solve.py b/astraflow/core/workflow/impl/plan_and_solve.py similarity index 95% rename from astraflow/workflow/impl/plan_and_solve.py rename to astraflow/core/workflow/impl/plan_and_solve.py index 67209e7..ce7a32d 100644 --- a/astraflow/workflow/impl/plan_and_solve.py +++ b/astraflow/core/workflow/impl/plan_and_solve.py @@ -27,15 +27,15 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.reward_api import AsyncRewardWrapper -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id, results_to_structured -from astraflow.workflow.utils.dynamic_import import import_from_string +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.reward_api import AsyncRewardWrapper +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id, results_to_structured +from astraflow.core.workflow.utils.dynamic_import import import_from_string logger = logging.getLogger("PlanAndSolve workflow") @@ -120,7 +120,7 @@ def __init__( ): self.reward_fn = reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer @@ -128,7 +128,7 @@ def __init__( # Separate tokenizer for planner if specified (different chat template) if planner_tokenizer is not None: if isinstance(planner_tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer planner_tokenizer = load_hf_tokenizer(planner_tokenizer) self.planner_tokenizer = planner_tokenizer diff --git a/astraflow/workflow/impl/rlvr.py b/astraflow/core/workflow/impl/rlvr.py similarity index 90% rename from astraflow/workflow/impl/rlvr.py rename to astraflow/core/workflow/impl/rlvr.py index 31536fe..7fd78c0 100644 --- a/astraflow/workflow/impl/rlvr.py +++ b/astraflow/core/workflow/impl/rlvr.py @@ -11,16 +11,16 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest, ModelResponse -from astraflow.workflow.api.reward_api import AsyncRewardWrapper -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id, results_to_structured -from astraflow.workflow.utils.dynamic_import import import_from_string -from astraflow.workflow.utils.perf_tracer import ( +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest, ModelResponse +from astraflow.core.workflow.api.reward_api import AsyncRewardWrapper +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id, results_to_structured +from astraflow.core.workflow.utils.dynamic_import import import_from_string +from astraflow.core.workflow.utils.perf_tracer import ( atrace_session_phase, session_context, trace_session, @@ -80,7 +80,7 @@ def __init__( self.reward_fn = reward_fn self.tokenizer = tokenizer if isinstance(self.tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(self.tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/sep_solve_and_check.py b/astraflow/core/workflow/impl/sep_solve_and_check.py similarity index 94% rename from astraflow/workflow/impl/sep_solve_and_check.py rename to astraflow/core/workflow/impl/sep_solve_and_check.py index ba97a34..caedf03 100644 --- a/astraflow/workflow/impl/sep_solve_and_check.py +++ b/astraflow/core/workflow/impl/sep_solve_and_check.py @@ -38,15 +38,15 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.reward_api import AsyncRewardWrapper -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id -from astraflow.workflow.utils.dynamic_import import import_from_string +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.reward_api import AsyncRewardWrapper +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.utils.dynamic_import import import_from_string logger = logging.getLogger("SolveAndCheckV2 workflow") @@ -153,7 +153,7 @@ def __init__( ): self.reward_fn = reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/sm_lg_router.py b/astraflow/core/workflow/impl/sm_lg_router.py similarity index 97% rename from astraflow/workflow/impl/sm_lg_router.py rename to astraflow/core/workflow/impl/sm_lg_router.py index be35734..35ab671 100644 --- a/astraflow/workflow/impl/sm_lg_router.py +++ b/astraflow/core/workflow/impl/sm_lg_router.py @@ -51,14 +51,14 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.reward_api import AsyncRewardWrapper -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.reward_api import AsyncRewardWrapper +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id logger = logging.getLogger(__name__) @@ -201,7 +201,7 @@ def __init__( ): self.reward_fn = reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/solve_and_check.py b/astraflow/core/workflow/impl/solve_and_check.py similarity index 95% rename from astraflow/workflow/impl/solve_and_check.py rename to astraflow/core/workflow/impl/solve_and_check.py index b7deb71..e55716e 100644 --- a/astraflow/workflow/impl/solve_and_check.py +++ b/astraflow/core/workflow/impl/solve_and_check.py @@ -28,15 +28,15 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.reward_api import AsyncRewardWrapper -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id, results_to_structured -from astraflow.workflow.utils.dynamic_import import import_from_string +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.reward_api import AsyncRewardWrapper +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id, results_to_structured +from astraflow.core.workflow.utils.dynamic_import import import_from_string logger = logging.getLogger("SolveAndCheck workflow") @@ -111,7 +111,7 @@ def __init__( ): self.reward_fn = reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/solve_and_verify.py b/astraflow/core/workflow/impl/solve_and_verify.py similarity index 96% rename from astraflow/workflow/impl/solve_and_verify.py rename to astraflow/core/workflow/impl/solve_and_verify.py index 98be42c..e85be73 100644 --- a/astraflow/workflow/impl/solve_and_verify.py +++ b/astraflow/core/workflow/impl/solve_and_verify.py @@ -45,15 +45,15 @@ import torch from transformers import PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import EngineGroup, InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.reward_api import AsyncRewardWrapper -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id -from astraflow.workflow.utils.dynamic_import import import_from_string +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import EngineGroup, InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest +from astraflow.core.workflow.api.reward_api import AsyncRewardWrapper +from astraflow.core.workflow.api.workflow_api import RolloutWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id +from astraflow.core.workflow.utils.dynamic_import import import_from_string logger = logging.getLogger("SolveAndVerify workflow") @@ -210,7 +210,7 @@ def __init__( ): self.reward_fn = reward_fn if isinstance(tokenizer, str): - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer tokenizer = load_hf_tokenizer(tokenizer) self.tokenizer = tokenizer diff --git a/astraflow/workflow/impl/vision_rlvr.py b/astraflow/core/workflow/impl/vision_rlvr.py similarity index 90% rename from astraflow/workflow/impl/vision_rlvr.py rename to astraflow/core/workflow/impl/vision_rlvr.py index 2b08fac..9d75e7a 100644 --- a/astraflow/workflow/impl/vision_rlvr.py +++ b/astraflow/core/workflow/impl/vision_rlvr.py @@ -10,15 +10,15 @@ import torch from transformers import AutoProcessor, PreTrainedTokenizerFast -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.engine_api import InferenceEngine -from astraflow.workflow.api.io_struct import ModelRequest, ModelResponse -from astraflow.workflow.impl.rlvr import RLVRWorkflow -from astraflow.workflow.registry import register_workflow -from astraflow.workflow.utils import logging, stats_tracker -from astraflow.workflow.utils.data import resolve_prompt_id, results_to_structured -from astraflow.workflow.utils.image import image2base64 -from astraflow.workflow.utils.perf_tracer import ( +from astraflow.core.workflow.api.cli_args import GenerationHyperparameters +from astraflow.core.workflow.api.engine_api import InferenceEngine +from astraflow.core.workflow.api.io_struct import ModelRequest, ModelResponse +from astraflow.core.workflow.impl.rlvr import RLVRWorkflow +from astraflow.core.workflow.registry import register_workflow +from astraflow.core.workflow.utils import logging, stats_tracker +from astraflow.core.workflow.utils.data import resolve_prompt_id, results_to_structured +from astraflow.core.workflow.utils.image import image2base64 +from astraflow.core.workflow.utils.perf_tracer import ( atrace_session_phase, session_context, trace_session, diff --git a/astraflow/workflow/registry.py b/astraflow/core/workflow/registry.py similarity index 100% rename from astraflow/workflow/registry.py rename to astraflow/core/workflow/registry.py diff --git a/astraflow/workflow/reward/__init__.py b/astraflow/core/workflow/reward/__init__.py similarity index 98% rename from astraflow/workflow/reward/__init__.py rename to astraflow/core/workflow/reward/__init__.py index 10d612f..48a9153 100644 --- a/astraflow/workflow/reward/__init__.py +++ b/astraflow/core/workflow/reward/__init__.py @@ -2,7 +2,7 @@ from math_verify.metric import math_metric from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig -from astraflow.workflow.utils import logging +from astraflow.core.workflow.utils import logging logger = logging.getLogger(__name__) diff --git a/astraflow/workflow/reward/clevr_count_70k.py b/astraflow/core/workflow/reward/clevr_count_70k.py similarity index 90% rename from astraflow/workflow/reward/clevr_count_70k.py rename to astraflow/core/workflow/reward/clevr_count_70k.py index 402a836..b66b9b7 100644 --- a/astraflow/workflow/reward/clevr_count_70k.py +++ b/astraflow/core/workflow/reward/clevr_count_70k.py @@ -1,6 +1,6 @@ import re -from astraflow.workflow.registry import register_reward +from astraflow.core.workflow.registry import register_reward def extract_answer(pred_str, data_name, use_last_number=True): diff --git a/astraflow/workflow/reward/geometry3k.py b/astraflow/core/workflow/reward/geometry3k.py similarity index 84% rename from astraflow/workflow/reward/geometry3k.py rename to astraflow/core/workflow/reward/geometry3k.py index 9da1170..2c92628 100644 --- a/astraflow/workflow/reward/geometry3k.py +++ b/astraflow/core/workflow/reward/geometry3k.py @@ -1,8 +1,8 @@ import re -from astraflow.workflow.registry import register_reward -from astraflow.workflow.reward import get_math_verify_worker -from astraflow.workflow.utils import logging +from astraflow.core.workflow.registry import register_reward +from astraflow.core.workflow.reward import get_math_verify_worker +from astraflow.core.workflow.utils import logging logger = logging.getLogger(__name__) diff --git a/astraflow/workflow/reward/human_eval_reward.py b/astraflow/core/workflow/reward/human_eval_reward.py similarity index 95% rename from astraflow/workflow/reward/human_eval_reward.py rename to astraflow/core/workflow/reward/human_eval_reward.py index 8f5b889..9cf190a 100644 --- a/astraflow/workflow/reward/human_eval_reward.py +++ b/astraflow/core/workflow/reward/human_eval_reward.py @@ -6,9 +6,9 @@ import sys from pathlib import Path -from astraflow.workflow.registry import register_reward -from astraflow.workflow.utils import logging -from astraflow.workflow.utils.code_execution_mraas import ( +from astraflow.core.workflow.registry import register_reward +from astraflow.core.workflow.utils import logging +from astraflow.core.workflow.utils.code_execution_mraas import ( SINGLE_CASE_EXEC_TIMEOUT, extract_python_code, ) diff --git a/astraflow/workflow/reward/livecodebench_reward.py b/astraflow/core/workflow/reward/livecodebench_reward.py similarity index 96% rename from astraflow/workflow/reward/livecodebench_reward.py rename to astraflow/core/workflow/reward/livecodebench_reward.py index 57a52ff..2ae42e0 100644 --- a/astraflow/workflow/reward/livecodebench_reward.py +++ b/astraflow/core/workflow/reward/livecodebench_reward.py @@ -13,9 +13,9 @@ from pathlib import Path from typing import Any -from astraflow.workflow.registry import register_reward -from astraflow.workflow.utils import logging -from astraflow.workflow.utils.code_execution_mraas import verifier_work_dir +from astraflow.core.workflow.registry import register_reward +from astraflow.core.workflow.utils import logging +from astraflow.core.workflow.utils.code_execution_mraas import verifier_work_dir logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ def _repo_root() -> Path: def _verifier_module() -> str: - return "astraflow.workflow.utils.testing_util" + return "astraflow.core.workflow.utils.testing_util" def _verifier_script_path() -> str: diff --git a/astraflow/workflow/reward/math_verify.py b/astraflow/core/workflow/reward/math_verify.py similarity index 76% rename from astraflow/workflow/reward/math_verify.py rename to astraflow/core/workflow/reward/math_verify.py index a7dd5de..098e8bd 100644 --- a/astraflow/workflow/reward/math_verify.py +++ b/astraflow/core/workflow/reward/math_verify.py @@ -1,6 +1,6 @@ -from astraflow.workflow.registry import register_reward -from astraflow.workflow.reward import get_math_verify_worker -from astraflow.workflow.utils import logging +from astraflow.core.workflow.registry import register_reward +from astraflow.core.workflow.reward import get_math_verify_worker +from astraflow.core.workflow.utils import logging logger = logging.getLogger(__name__) diff --git a/astraflow/workflow/utils/__init__.py b/astraflow/core/workflow/utils/__init__.py similarity index 100% rename from astraflow/workflow/utils/__init__.py rename to astraflow/core/workflow/utils/__init__.py diff --git a/astraflow/workflow/utils/code_execution_mraas.py b/astraflow/core/workflow/utils/code_execution_mraas.py similarity index 98% rename from astraflow/workflow/utils/code_execution_mraas.py rename to astraflow/core/workflow/utils/code_execution_mraas.py index ea3a87e..8d1378e 100644 --- a/astraflow/workflow/utils/code_execution_mraas.py +++ b/astraflow/core/workflow/utils/code_execution_mraas.py @@ -15,7 +15,7 @@ from pathlib import Path from typing import Any -from astraflow.workflow.utils import logging +from astraflow.core.workflow.utils import logging logger = logging.getLogger(__name__) @@ -35,7 +35,7 @@ def _repo_root() -> Path: def _verifier_module() -> str: - return "astraflow.workflow.utils.testing_util_mraas" + return "astraflow.core.workflow.utils.testing_util_mraas" def _verifier_script_path() -> str: diff --git a/astraflow/workflow/utils/data.py b/astraflow/core/workflow/utils/data.py similarity index 100% rename from astraflow/workflow/utils/data.py rename to astraflow/core/workflow/utils/data.py diff --git a/astraflow/workflow/utils/dynamic_import.py b/astraflow/core/workflow/utils/dynamic_import.py similarity index 100% rename from astraflow/workflow/utils/dynamic_import.py rename to astraflow/core/workflow/utils/dynamic_import.py diff --git a/astraflow/workflow/utils/hf_utils.py b/astraflow/core/workflow/utils/hf_utils.py similarity index 96% rename from astraflow/workflow/utils/hf_utils.py rename to astraflow/core/workflow/utils/hf_utils.py index 9c66bf7..7b830b6 100644 --- a/astraflow/workflow/utils/hf_utils.py +++ b/astraflow/core/workflow/utils/hf_utils.py @@ -2,7 +2,7 @@ import transformers -from astraflow.workflow.utils import logging +from astraflow.core.workflow.utils import logging logger = logging.getLogger("HF Utility") diff --git a/astraflow/workflow/utils/image.py b/astraflow/core/workflow/utils/image.py similarity index 100% rename from astraflow/workflow/utils/image.py rename to astraflow/core/workflow/utils/image.py diff --git a/astraflow/workflow/utils/logging.py b/astraflow/core/workflow/utils/logging.py similarity index 100% rename from astraflow/workflow/utils/logging.py rename to astraflow/core/workflow/utils/logging.py diff --git a/astraflow/workflow/utils/perf_tracer.py b/astraflow/core/workflow/utils/perf_tracer.py similarity index 100% rename from astraflow/workflow/utils/perf_tracer.py rename to astraflow/core/workflow/utils/perf_tracer.py diff --git a/astraflow/workflow/utils/stats_tracker.py b/astraflow/core/workflow/utils/stats_tracker.py similarity index 100% rename from astraflow/workflow/utils/stats_tracker.py rename to astraflow/core/workflow/utils/stats_tracker.py diff --git a/astraflow/workflow/utils/testing_util.py b/astraflow/core/workflow/utils/testing_util.py similarity index 100% rename from astraflow/workflow/utils/testing_util.py rename to astraflow/core/workflow/utils/testing_util.py diff --git a/astraflow/workflow/utils/testing_util_mraas.py b/astraflow/core/workflow/utils/testing_util_mraas.py similarity index 100% rename from astraflow/workflow/utils/testing_util_mraas.py rename to astraflow/core/workflow/utils/testing_util_mraas.py diff --git a/astraflow/dataflow/__main__.py b/astraflow/dataflow/__main__.py index 52063ce..ba6924a 100644 --- a/astraflow/dataflow/__main__.py +++ b/astraflow/dataflow/__main__.py @@ -17,7 +17,7 @@ def _parse_config(config_path: str) -> ServiceConfig: """Parse an experiment YAML into a ServiceConfig.""" - from astraflow.config.loader import load_and_merge_configs, load_dataflow_config + from astraflow.core.config.loader import load_and_merge_configs, load_dataflow_config raw = load_and_merge_configs([config_path]) af = load_dataflow_config(raw) diff --git a/astraflow/dataflow/data_acquisition.py b/astraflow/dataflow/data_acquisition.py index dbabf3f..9dc644e 100644 --- a/astraflow/dataflow/data_acquisition.py +++ b/astraflow/dataflow/data_acquisition.py @@ -814,7 +814,7 @@ def _submit_tick_debug(self, max_submit_per_tick: int): with self._stats_lock: self._curator_stats["selected"] += 1 if _DEBUG_PRODUCER: - from astraflow.workflow.utils.data import resolve_prompt_id as _rpi + from astraflow.core.workflow.utils.data import resolve_prompt_id as _rpi _qid = _rpi(data) or "" global _DEBUG_SUBMIT_COUNTER with _DEBUG_SUBMIT_LOCK: diff --git a/astraflow/dataflow/dataset/deepcoder_preview.py b/astraflow/dataflow/dataset/deepcoder_preview.py index b8c0607..12754a7 100644 --- a/astraflow/dataflow/dataset/deepcoder_preview.py +++ b/astraflow/dataflow/dataset/deepcoder_preview.py @@ -12,7 +12,7 @@ SINGLE_TURN_LCB_PROMPT_TEMPLATE, ) from astraflow.dataflow.dataset.utils import attach_query_ids -from astraflow.workflow.utils import logging +from astraflow.core.workflow.utils import logging logger = logging.getLogger(__name__) diff --git a/astraflow/dataflow/dataset/human_eval.py b/astraflow/dataflow/dataset/human_eval.py index 7f67291..e96d986 100644 --- a/astraflow/dataflow/dataset/human_eval.py +++ b/astraflow/dataflow/dataset/human_eval.py @@ -7,7 +7,7 @@ from datasets import load_dataset from astraflow.dataflow.dataset.utils import attach_query_ids -from astraflow.workflow.utils import logging +from astraflow.core.workflow.utils import logging logger = logging.getLogger(__name__) HF_DATASETS_CACHE_DIR = "/tmp/hf-datasets" diff --git a/astraflow/dataflow/dataset/livecodebench.py b/astraflow/dataflow/dataset/livecodebench.py index 89ade31..a8c19a8 100644 --- a/astraflow/dataflow/dataset/livecodebench.py +++ b/astraflow/dataflow/dataset/livecodebench.py @@ -9,7 +9,7 @@ from datasets import load_dataset from astraflow.dataflow.dataset.utils import attach_query_ids -from astraflow.workflow.utils import logging +from astraflow.core.workflow.utils import logging logger = logging.getLogger(__name__) diff --git a/astraflow/dataflow/prompt_curators.py b/astraflow/dataflow/prompt_curators.py index 7a17bd7..3f1ccdb 100644 --- a/astraflow/dataflow/prompt_curators.py +++ b/astraflow/dataflow/prompt_curators.py @@ -276,7 +276,7 @@ def should_submit(self, data, *, version): # Resolve qid via the shared helper so the streak table built by # update() (keyed on the workflow-stamped prompt_id) matches the # lookup here. Both sides MUST go through resolve_prompt_id. - from astraflow.workflow.utils.data import resolve_prompt_id + from astraflow.core.workflow.utils.data import resolve_prompt_id qid = resolve_prompt_id(data) if qid is None: diff --git a/astraflow/raas/api/cli_args.py b/astraflow/raas/api/cli_args.py index b28ac71..b79af81 100644 --- a/astraflow/raas/api/cli_args.py +++ b/astraflow/raas/api/cli_args.py @@ -899,7 +899,7 @@ def parse_cli_args(argv: list[str]): config_paths = args.config # list of paths due to action="append" - from astraflow.config.loader import load_and_merge_configs, load_raas_config + from astraflow.core.config.loader import load_and_merge_configs, load_raas_config raw = load_and_merge_configs(config_paths) raas_dict = load_raas_config(raw) diff --git a/astraflow/raas/server/manager.py b/astraflow/raas/server/manager.py index 765ad79..e985f2c 100644 --- a/astraflow/raas/server/manager.py +++ b/astraflow/raas/server/manager.py @@ -16,8 +16,8 @@ from astraflow.raas.platforms import current_platform from astraflow.raas.utils import logging from astraflow.raas.utils.network import find_free_ports, gethostip -from astraflow.workflow.api.engine_api import EngineGroup -from astraflow.workflow.registry import get_reward, get_workflow +from astraflow.core.workflow.api.engine_api import EngineGroup +from astraflow.core.workflow.registry import get_reward, get_workflow _base_logger = logging.getLogger(__name__) logger = _base_logger # replaced with adapter after engine_id is known @@ -307,7 +307,7 @@ async def bootstrap( self._gconfig = deepcopy(config.gconfig) tokenizer_path = config.tokenizer_path if tokenizer_path: - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer self._tokenizer = load_hf_tokenizer(tokenizer_path) logger.info( @@ -391,7 +391,7 @@ async def bootstrap_multi_model( in *allocation_mode* and produces its own engine. """ from astraflow.raas.api.cli_args import ModelSpec - from astraflow.workflow.utils.hf_utils import load_hf_tokenizer + from astraflow.core.workflow.utils.hf_utils import load_hf_tokenizer self._ensure_async_state() if self._status == "ready": diff --git a/astraflow/raas/server/tcp_receiver.py b/astraflow/raas/server/tcp_receiver.py index 4f2314c..a7d399c 100644 --- a/astraflow/raas/server/tcp_receiver.py +++ b/astraflow/raas/server/tcp_receiver.py @@ -25,12 +25,12 @@ import torch import zmq -from astraflow.weight_manager.transfer.config import ( +from astraflow.core.weight_manager.transfer.config import ( TransferEngineConfig, TransferStatus, ) -from astraflow.weight_manager.transfer.receiver_agent import TransferBuffer -from astraflow.weight_manager.transfer.transfer_engine import TCPTransferEngine +from astraflow.core.weight_manager.transfer.receiver_agent import TransferBuffer +from astraflow.core.weight_manager.transfer.transfer_engine import TCPTransferEngine logger = logging.getLogger(__name__) diff --git a/astraflow/train_worker/api/cli_args.py b/astraflow/train_worker/api/cli_args.py index 6b1a6a7..8ff1be0 100644 --- a/astraflow/train_worker/api/cli_args.py +++ b/astraflow/train_worker/api/cli_args.py @@ -1642,7 +1642,7 @@ def parse_cli_args(argv: list[str]): config_file = Path(args.config).absolute() assert config_file.exists(), f"Config file {config_file} does not exist." - from astraflow.config.loader import load_and_merge_configs, load_trainer_config + from astraflow.core.config.loader import load_and_merge_configs, load_trainer_config raw = load_and_merge_configs([str(config_file)]) trainer_key = args.trainer or "trainer" diff --git a/astraflow/train_worker/api/engine_api.py b/astraflow/train_worker/api/engine_api.py index 59e8efa..823fcff 100644 --- a/astraflow/train_worker/api/engine_api.py +++ b/astraflow/train_worker/api/engine_api.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from astraflow.train_worker.utils.data import MicroBatchList - from astraflow.workflow.api.workflow_api import RolloutWorkflow + from astraflow.core.workflow.api.workflow_api import RolloutWorkflow class TrainEngine(abc.ABC): @@ -518,7 +518,7 @@ def submit( - An instance of RolloutWorkflow (for sharing resources between rollouts) - A RolloutWorkflow class type (will be instantiated with workflow_kwargs) - - A string module path like "astraflow.workflow.impl.rlvr.RLVRWorkflow" (will be imported + - A string module path like "astraflow.core.workflow.impl.rlvr.RLVRWorkflow" (will be imported and instantiated with workflow_kwargs) workflow_kwargs : dict[str, Any], optional Keyword arguments to pass to the workflow constructor when workflow is a type or string. @@ -614,7 +614,7 @@ def rollout_batch( - An instance of RolloutWorkflow (for sharing resources between rollouts) - A RolloutWorkflow class type (will be instantiated with workflow_kwargs) - - A string module path like "astraflow.workflow.impl.rlvr.RLVRWorkflow" (will be imported + - A string module path like "astraflow.core.workflow.impl.rlvr.RLVRWorkflow" (will be imported and instantiated with workflow_kwargs) workflow_kwargs : dict[str, Any], optional Keyword arguments to pass to the workflow constructor when workflow is a type or string. @@ -661,7 +661,7 @@ def prepare_batch( - An instance of RolloutWorkflow (for sharing resources between rollouts) - A RolloutWorkflow class type (will be instantiated with workflow_kwargs) - - A string module path like "astraflow.workflow.impl.rlvr.RLVRWorkflow" (will be imported + - A string module path like "astraflow.core.workflow.impl.rlvr.RLVRWorkflow" (will be imported and instantiated with workflow_kwargs) workflow_kwargs : dict[str, Any], optional Keyword arguments to pass to the workflow constructor when workflow is a type or string. diff --git a/astraflow/train_worker/trainer/ppo_trainer.py b/astraflow/train_worker/trainer/ppo_trainer.py index a504d3e..b90af8c 100644 --- a/astraflow/train_worker/trainer/ppo_trainer.py +++ b/astraflow/train_worker/trainer/ppo_trainer.py @@ -190,8 +190,8 @@ def _init_weight_manager(self) -> None: """ import socket - from astraflow.weight_manager import WeightManager, WeightManagerConfig - from astraflow.weight_manager.transfer.config import ( + from astraflow.core.weight_manager import WeightManager, WeightManagerConfig + from astraflow.core.weight_manager.transfer.config import ( SenderAgentConfig, TransferEngineConfig, ) diff --git a/astraflow/workflow/__init__.py b/astraflow/workflow/__init__.py deleted file mode 100644 index e409aa8..0000000 --- a/astraflow/workflow/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Standalone workflow package for rollout workflows and reward functions. - -Importing this package triggers auto-registration of all built-in -workflows and reward functions via their @register_workflow / @register_reward -decorators. -""" - -# Auto-import implementations to trigger registry decorators -import astraflow.workflow.impl.agentbench.alfworld_task_server -import astraflow.workflow.impl.agentbench.task_server -import astraflow.workflow.impl.agentbench.webshop_task_server -import astraflow.workflow.impl.agentbench.webshop_checker_workflow -import astraflow.workflow.impl.asearcher -import astraflow.workflow.impl.code_actor_and_verify -import astraflow.workflow.impl.code_actor_and_verify_v2 -import astraflow.workflow.impl.code_actor_and_verify_v3 -import astraflow.workflow.impl.code_solve_and_select -import astraflow.workflow.impl.livecodebench_single_turn -import astraflow.workflow.impl.multi_turn -import astraflow.workflow.impl.plan_and_solve -import astraflow.workflow.impl.solve_and_check -import astraflow.workflow.impl.sep_solve_and_check -import astraflow.workflow.impl.solve_and_verify -import astraflow.workflow.impl.actor_and_verify -import astraflow.workflow.impl.rlvr -import astraflow.workflow.impl.sm_lg_router -import astraflow.workflow.impl.vision_rlvr -import astraflow.workflow.reward.clevr_count_70k -import astraflow.workflow.reward.geometry3k -import astraflow.workflow.reward.math_verify -import astraflow.workflow.reward.human_eval_reward -import astraflow.workflow.reward.livecodebench_reward diff --git a/docs/en/architecture/custom-raas.md b/docs/en/architecture/custom-raas.md index aeeb4b2..9bf2138 100644 --- a/docs/en/architecture/custom-raas.md +++ b/docs/en/architecture/custom-raas.md @@ -611,7 +611,7 @@ streams, and ZMQ completion signaling. Your RaaS just calls A workflow is user code that RaaS must execute correctly. Contract: ```python -# astraflow/workflow/api/workflow_api.py +# astraflow/core/workflow/api/workflow_api.py class RolloutWorkflow(ABC): @abstractmethod async def arun_episode( @@ -631,7 +631,7 @@ class RolloutWorkflow(ABC): workflow can look it up. The `InferenceEngine` protocol your RaaS passes to workflows -(`astraflow/workflow/api/engine_api.py`): +(`astraflow/core/workflow/api/engine_api.py`): ```python class InferenceEngine(Protocol): @@ -850,8 +850,8 @@ If you're starting from scratch, read these in order: 3. `astraflow/raas/server/__main__.py` — launcher and self-registration. 4. `astraflow/raas/server/tcp_receiver.py` — `RaaSWeightReceiver` (reuse this). -5. `astraflow/workflow/api/workflow_api.py` and - `astraflow/workflow/api/engine_api.py` — the contracts your RaaS +5. `astraflow/core/workflow/api/workflow_api.py` and + `astraflow/core/workflow/api/engine_api.py` — the contracts your RaaS must honor for workflows. 6. `astraflow/dataflow/raas2_engine.py` — the client AstraFlow uses to talk to you; matching its method signatures is the diff --git a/docs/en/architecture/multi-agent-weight-transfer.md b/docs/en/architecture/multi-agent-weight-transfer.md index 582b900..62871de 100644 --- a/docs/en/architecture/multi-agent-weight-transfer.md +++ b/docs/en/architecture/multi-agent-weight-transfer.md @@ -324,6 +324,6 @@ trainer_model1: | RaaS pool fan-out (one call per model) | `astraflow/dataflow/raas_pool.py` | `RaaSPool.notify_version()` (line 442), `_notify_one_model()` (line 410) | | RaaS manager pull + load | `astraflow/raas/server/manager.py` | `notify_version()` (line 1556), `_do_weight_update()` (line 1612), `_pull_weights_to_disk()` | | TCP receiver | `astraflow/raas/server/tcp_receiver.py` | `RaaSWeightReceiver` | -| Sender agent | `astraflow/weight_manager/transfer/sender_agent.py` | `SenderAgent` | -| WeightManager offload | `astraflow/weight_manager/weight_manager.py` | `offload()` | +| Sender agent | `astraflow/core/weight_manager/transfer/sender_agent.py` | `SenderAgent` | +| WeightManager offload | `astraflow/core/weight_manager/weight_manager.py` | `offload()` | | Trainer integration | `astraflow/train_worker/trainer/ppo_trainer.py` | `AstraFlowPPOTrainer` | diff --git a/docs/en/architecture/trainer.md b/docs/en/architecture/trainer.md index 906b544..b94f502 100644 --- a/docs/en/architecture/trainer.md +++ b/docs/en/architecture/trainer.md @@ -57,7 +57,7 @@ The trainer interacts with two components: | Weights | `POST` | `/request_transfer` | Every step (pull weights over TCP) | The weight sender is provided as a reusable library -(`astraflow.weight_manager.transfer.sender_agent`) so custom trainers +(`astraflow.core.weight_manager.transfer.sender_agent`) so custom trainers don't need to reimplement TCP/ZMQ machinery. See [WeightManager](weight-manager.md) for details. diff --git a/docs/en/architecture/weight-manager.md b/docs/en/architecture/weight-manager.md index d36de2f..73746ce 100644 --- a/docs/en/architecture/weight-manager.md +++ b/docs/en/architecture/weight-manager.md @@ -1,6 +1,6 @@ # WeightManager -The WeightManager (`astraflow/weight_manager/`) is an independent component +The WeightManager (`astraflow/core/weight_manager/`) is an independent component that handles all weight transfer between Trainer and RaaS. ## Design Principle: Independent Transport Layer @@ -240,7 +240,7 @@ The buffer index swap is a single Python int assignment (atomic under GIL). ## Project Structure ``` -astraflow/weight_manager/ +astraflow/core/weight_manager/ __init__.py ← exports WeightManager, WeightManagerConfig weight_manager.py ← main class: buffer mgmt, GPU→CPU copy, sender lifecycle config.py ← WeightManagerConfig From 6939aaf0b70cb933aaf7de9e684bd722203d51e3 Mon Sep 17 00:00:00 2001 From: Haizhong Date: Wed, 20 May 2026 09:26:46 -0400 Subject: [PATCH 2/3] fix: correct path-relative computations in moved verifier helpers The reorg of workflow/ under core/ added one directory level. Three helpers that compute paths from __file__ still assumed the pre-reorg depth: - _repo_root() in livecodebench_reward.py, human_eval_reward.py, and code_execution_mraas.py used parents[3], which now lands on the package root instead of the repo root; fix to parents[4] - _verifier_script_path() in livecodebench_reward.py and code_execution_mraas.py hardcoded "workflow" path-segments; fix to "core"/"workflow" Symptom: code/livecodebench/human_eval reward subprocesses launched with a non-existent script path, returned rc=2, and the reward fn caught FileNotFoundError and returned [-4]. All rewards came out 0.0, the m-1 filter dropped every rollout group, and the trainer sat at step 0. Other recipes (math, multi-agent, greso, alfworld) weren't affected because they don't go through the subprocess verifier. --- astraflow/core/workflow/reward/human_eval_reward.py | 4 +++- astraflow/core/workflow/reward/livecodebench_reward.py | 6 ++++-- astraflow/core/workflow/utils/code_execution_mraas.py | 6 ++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/astraflow/core/workflow/reward/human_eval_reward.py b/astraflow/core/workflow/reward/human_eval_reward.py index 9cf190a..b29317a 100644 --- a/astraflow/core/workflow/reward/human_eval_reward.py +++ b/astraflow/core/workflow/reward/human_eval_reward.py @@ -17,7 +17,9 @@ def _ensure_human_eval_importable() -> None: - repo_root = Path(__file__).resolve().parents[3] + # __file__ is astraflow/core/workflow/reward/human_eval_reward.py; + # parents[4] = repo root (parents[3] = package root since the reorg). + repo_root = Path(__file__).resolve().parents[4] human_eval_root = repo_root / "astraEnv" / "human-eval" human_eval_path = str(human_eval_root) if human_eval_path not in sys.path: diff --git a/astraflow/core/workflow/reward/livecodebench_reward.py b/astraflow/core/workflow/reward/livecodebench_reward.py index 2ae42e0..5747959 100644 --- a/astraflow/core/workflow/reward/livecodebench_reward.py +++ b/astraflow/core/workflow/reward/livecodebench_reward.py @@ -28,7 +28,9 @@ def _repo_root() -> Path: - return Path(__file__).resolve().parents[3] + # __file__ is astraflow/core/workflow/reward/livecodebench_reward.py; + # parents[4] = repo root (parents[3] = package root since the reorg). + return Path(__file__).resolve().parents[4] def _verifier_module() -> str: @@ -42,7 +44,7 @@ def _verifier_script_path() -> str: automatic script-dir sys.path injection, avoiding the 14 s astraflow package import while still loading the standalone script. """ - return str(_repo_root() / "astraflow" / "workflow" / "utils" / "testing_util.py") + return str(_repo_root() / "astraflow" / "core" / "workflow" / "utils" / "testing_util.py") def _extract_python_code(text: str, min_length: int = 20) -> str | None: diff --git a/astraflow/core/workflow/utils/code_execution_mraas.py b/astraflow/core/workflow/utils/code_execution_mraas.py index 8d1378e..42cdc59 100644 --- a/astraflow/core/workflow/utils/code_execution_mraas.py +++ b/astraflow/core/workflow/utils/code_execution_mraas.py @@ -31,7 +31,9 @@ def _repo_root() -> Path: - return Path(__file__).resolve().parents[3] + # __file__ is astraflow/core/workflow/utils/code_execution_mraas.py; + # parents[4] = repo root (parents[3] = package root since the reorg). + return Path(__file__).resolve().parents[4] def _verifier_module() -> str: @@ -42,7 +44,7 @@ def _verifier_script_path() -> str: """Absolute path to testing_util_mraas.py for direct invocation via ``python -P `` (skips script-dir sys.path injection, avoiding the 14 s astraflow package import).""" - return str(_repo_root() / "astraflow" / "workflow" / "utils" / "testing_util_mraas.py") + return str(_repo_root() / "astraflow" / "core" / "workflow" / "utils" / "testing_util_mraas.py") def _verifier_work_root() -> Path | None: From ccae8a92866fd3ee5c40e626a19f2260d3226477 Mon Sep 17 00:00:00 2001 From: Haizhong Date: Wed, 20 May 2026 09:29:17 -0400 Subject: [PATCH 3/3] chore: delete dead Python tree under astraEnv/ASearcher The astraEnv/ASearcher/ASearcher/ subtree and astraEnv/ASearcher/agent/ were vendored upstream snapshots that duplicated logic now living in astraflow/core/workflow/impl/asearcher/. Nothing in the live code path (search recipe, examples, in-package workflows, RAG server scripts) loaded them, but their unconditional astraflow imports created a backwards layering edge from astraEnv into astraflow. Kept (verified live via examples/search/ and astraEnv/ASearcher/scripts/): - scripts/ - RAG server launchers - tools/ - local_retrieval_server.py invoked by scripts - utils/ - index_builder.py invoked by build_index.sh - evaluation/ - config_loader imported by tools/search_utils.py - configs/, demo/, docs/, assets/, qa_synthesis/ - user-facing assets astraEnv is now a pure scripts/assets directory with no Python reaching back into astraflow. --- astraEnv/ASearcher/ASearcher/__init__.py | 0 .../ASearcher/configs/asearcher_local.yaml | 162 ---- .../configs/asearcher_local_1.5b_example.yaml | 153 ---- .../configs/asearcher_local_16nodes.yaml | 165 ---- .../ASearcher/configs/asearcher_web.yaml | 162 ---- .../configs/asearcher_web_16nodes.yaml | 165 ---- .../ASearcher/configs/asearcher_web_qwq.yaml | 167 ----- .../ASearcher/ASearcher/train/asearcher.py | 553 -------------- .../ASearcher/train/asearcher_reasoning.py | 467 ------------ .../ASearcher/train/asearcher_train.py | 106 --- astraEnv/ASearcher/ASearcher/train/prompts.py | 4 - .../ASearcher/train/reasoning_agent.py | 703 ------------------ .../ASearcher/ASearcher/train/search_agent.py | 187 ----- astraEnv/ASearcher/ASearcher/utils/rewards.py | 272 ------- .../ASearcher/ASearcher/utils/search_tool.py | 173 ----- .../ASearcher/ASearcher/utils/search_utils.py | 433 ----------- .../ASearcher/ASearcher/utils/web_browser.py | 164 ---- astraEnv/ASearcher/agent/__init__.py | 13 - astraEnv/ASearcher/agent/asearcher.py | 192 ----- .../ASearcher/agent/asearcher_reasoning.py | 579 --------------- astraEnv/ASearcher/agent/search_r1.py | 297 -------- 21 files changed, 5117 deletions(-) delete mode 100644 astraEnv/ASearcher/ASearcher/__init__.py delete mode 100644 astraEnv/ASearcher/ASearcher/configs/asearcher_local.yaml delete mode 100644 astraEnv/ASearcher/ASearcher/configs/asearcher_local_1.5b_example.yaml delete mode 100644 astraEnv/ASearcher/ASearcher/configs/asearcher_local_16nodes.yaml delete mode 100644 astraEnv/ASearcher/ASearcher/configs/asearcher_web.yaml delete mode 100644 astraEnv/ASearcher/ASearcher/configs/asearcher_web_16nodes.yaml delete mode 100644 astraEnv/ASearcher/ASearcher/configs/asearcher_web_qwq.yaml delete mode 100644 astraEnv/ASearcher/ASearcher/train/asearcher.py delete mode 100644 astraEnv/ASearcher/ASearcher/train/asearcher_reasoning.py delete mode 100644 astraEnv/ASearcher/ASearcher/train/asearcher_train.py delete mode 100644 astraEnv/ASearcher/ASearcher/train/prompts.py delete mode 100644 astraEnv/ASearcher/ASearcher/train/reasoning_agent.py delete mode 100644 astraEnv/ASearcher/ASearcher/train/search_agent.py delete mode 100644 astraEnv/ASearcher/ASearcher/utils/rewards.py delete mode 100644 astraEnv/ASearcher/ASearcher/utils/search_tool.py delete mode 100644 astraEnv/ASearcher/ASearcher/utils/search_utils.py delete mode 100644 astraEnv/ASearcher/ASearcher/utils/web_browser.py delete mode 100644 astraEnv/ASearcher/agent/__init__.py delete mode 100644 astraEnv/ASearcher/agent/asearcher.py delete mode 100644 astraEnv/ASearcher/agent/asearcher_reasoning.py delete mode 100644 astraEnv/ASearcher/agent/search_r1.py diff --git a/astraEnv/ASearcher/ASearcher/__init__.py b/astraEnv/ASearcher/ASearcher/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/astraEnv/ASearcher/ASearcher/configs/asearcher_local.yaml b/astraEnv/ASearcher/ASearcher/configs/asearcher_local.yaml deleted file mode 100644 index 174a5b2..0000000 --- a/astraEnv/ASearcher/ASearcher/configs/asearcher_local.yaml +++ /dev/null @@ -1,162 +0,0 @@ -experiment_name: asearcher-7b-zero-local -trial_name: run1 - -cluster: - fileroot: /tmp/areal/experiments - n_nodes: 1 - n_gpus_per_node: 8 - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/name_resolve -seed: 1 -total_train_epochs: 10 -total_train_steps: null -tokenizer_path: ${actor.path} -allocation_mode: sglang.d4p1t1+d1c4 -async_training: true - -rollout: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 32 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 4 - enable_rollout_tracing: true - -gconfig: - n_samples: 16 - min_new_tokens: 0 - max_new_tokens: 1024 - greedy: false - temperature: 1.0 - -actor: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: Qwen/Qwen2.5-7B - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: true - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 16000 - pad_to_maximum: true - optimizer: - type: adam - lr: 5e-6 - weight_decay: 0.01 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 0.2 - warmup_steps_proportion: 0.001 - backend: fsdp - - group_size: ${gconfig.n_samples} - group_reward_norm: false - eps_clip: 0.4 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 - recompute_logprob: true - use_decoupled_loss: true - behav_imp_weight_cap: 5.0 - adv_norm: - mean_level: batch - std_level: batch - -ref: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 32768 - optimizer: null - backend: fsdp - -# SGLang -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: false - dtype: ${actor.dtype} - max_running_requests: null - context_length: 32768 - mem_fraction_static: 0.6 - attention_backend: fa3 - -# datasets -train_dataset: - batch_size: 128 - shuffle: true - pin_memory: true - path: path_to_training_data - -# Utilities -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: 10 - freq_secs: 3600 - -recover: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: null - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: disabled - -# Launcher -launcher: - inference_server_cpus_per_gpu: 15 - inference_server_mem_per_gpu: 153600 - trainer_cpus_per_gpu: 15 - trainer_mem_per_gpu: 153600 - trainer_env_vars: PYTHONPATH=path_to_asearcher:path_to_areal,WANDB_API_KEY=your_wandb_api_key,RAG_SERVER_ADDR_DIR=directory_of_rag_server_addrs - - -max_turns: 32 -n_trajs: 16 -search_client_type: async-search-access -reward_type: F1 -topk: 5 -valid_inst_ratio: 0.3 -log_agent_stats: true -log_agent_stats_keys: - - num_input_tokens - - num_output_tokens - - num_llm_gens - - num_search_queries - - num_success_search_queries - - num_failed_search_queries - - num_pages - - num_success_url_accesses - - num_failed_url_accesses - - score - - judge_q_invalid - - format_reward \ No newline at end of file diff --git a/astraEnv/ASearcher/ASearcher/configs/asearcher_local_1.5b_example.yaml b/astraEnv/ASearcher/ASearcher/configs/asearcher_local_1.5b_example.yaml deleted file mode 100644 index 5d67587..0000000 --- a/astraEnv/ASearcher/ASearcher/configs/asearcher_local_1.5b_example.yaml +++ /dev/null @@ -1,153 +0,0 @@ -experiment_name: asearcher-1.5b-example -trial_name: run1 - -cluster: - fileroot: /tmp/areal/experiments - n_nodes: 1 - n_gpus_per_node: 8 - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/name_resolve -seed: 1 -total_train_epochs: 10 -total_train_steps: null -tokenizer_path: ${actor.path} -allocation_mode: sglang.d4p1t1+d4p1t1 -async_training: true - -rollout: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 128 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 4 - enable_rollout_tracing: true - -gconfig: - n_samples: 16 - min_new_tokens: 0 - max_new_tokens: 1024 - greedy: false - temperature: 1.0 - -actor: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: Qwen/Qwen2.5-1.5B - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: true - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 8000 - optimizer: - type: adam - lr: 5e-6 - weight_decay: 0.01 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 0.2 - warmup_steps_proportion: 0.001 - backend: fsdp - - group_size: ${gconfig.n_samples} - group_adv_norm: false - eps_clip: 0.4 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 - recompute_logprob: true - use_decoupled_loss: true - behav_imp_weight_cap: 5.0 - log_agent_stats: true - log_agent_stats_keys: - - num_input_tokens - - num_output_tokens - - num_llm_gens - - num_search_queries - - num_success_search_queries - - num_failed_search_queries - - num_pages - - num_success_url_accesses - - num_failed_url_accesses - - score - - judge_q_invalid - - format_reward - -ref: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 32768 - optimizer: null - backend: fsdp - -# SGLang -server_only: false -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: false - dtype: ${actor.dtype} - max_running_requests: null - context_length: 32768 - mem_fraction_static: 0.9 - attention_backend: fa3 - -# datasets -train_dataset: - batch_size: 128 - shuffle: true - pin_memory: true - path: path_to_training_data - -# Utilities -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: 10 - freq_secs: 3600 - -checkpointer: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: null - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: disabled - -# Launcher -launcher: - inference_server_cpus_per_gpu: 15 - inference_server_mem_per_gpu: 153600 - trainer_cpus_per_gpu: 15 - trainer_mem_per_gpu: 153600 - - -max_turns: 32 \ No newline at end of file diff --git a/astraEnv/ASearcher/ASearcher/configs/asearcher_local_16nodes.yaml b/astraEnv/ASearcher/ASearcher/configs/asearcher_local_16nodes.yaml deleted file mode 100644 index ea5cd19..0000000 --- a/astraEnv/ASearcher/ASearcher/configs/asearcher_local_16nodes.yaml +++ /dev/null @@ -1,165 +0,0 @@ -experiment_name: asearcher-7b-zero-local-16nodes -trial_name: run1 - -cluster: - fileroot: /tmp/areal/experiments - n_nodes: 16 - n_gpus_per_node: 8 - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/name_resolve -seed: 1 -total_train_epochs: 10 -total_train_steps: null -tokenizer_path: ${actor.path} -allocation_mode: sglang.d96p1t1+d8c4 -async_training: true - -rollout: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 32 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 4 - enable_rollout_tracing: true - -gconfig: - n_samples: 16 - min_new_tokens: 0 - max_new_tokens: 1024 - greedy: false - temperature: 1.0 - -actor: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: Qwen/Qwen2.5-7B - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: true - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 12000 - pad_to_maximum: true - optimizer: - type: adam - lr: 5e-6 - weight_decay: 0.01 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 0.2 - warmup_steps_proportion: 0.001 - backend: fsdp - - group_size: ${gconfig.n_samples} - group_reward_norm: false - eps_clip: 0.4 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 - recompute_logprob: true - use_decoupled_loss: true - behav_imp_weight_cap: 5.0 - adv_norm: - mean_level: batch - std_level: batch - -ref: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 32768 - optimizer: null - backend: fsdp - -# SGLang -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: false - dtype: ${actor.dtype} - max_running_requests: null - context_length: 32768 - mem_fraction_static: 0.6 - attention_backend: fa3 - -# datasets -train_dataset: - batch_size: 128 - shuffle: true - pin_memory: true - path: path_to_training_data - -# Utilities -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: 10 - freq_secs: 3600 - -recover: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: null - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: disabled - -# Launcher -launcher: - inference_server_cpus_per_gpu: 15 - inference_server_mem_per_gpu: 153600 - trainer_cpus_per_gpu: 15 - trainer_mem_per_gpu: 153600 - slurm: - mount: /storage:/storage # mount share storage if necessary - trainer_image: AReaL_Lite_Image - inference_server_image: AReaL_Lite_Image - trainer_env_vars: PYTHONPATH=path_to_asearcher:path_to_areal,WANDB_API_KEY=your_wandb_api_key,RAG_SERVER_ADDR_DIR=Directory_to_RAG_Server_Address - -max_turns: 32 -n_trajs: 16 -search_client_type: async-search-access -reward_type: F1 -topk: 5 -valid_inst_ratio: 0.3 -log_agent_stats: true -log_agent_stats_keys: - - num_input_tokens - - num_output_tokens - - num_llm_gens - - num_search_queries - - num_success_search_queries - - num_failed_search_queries - - num_pages - - num_success_url_accesses - - num_failed_url_accesses - - score - - judge_q_invalid - - format_reward \ No newline at end of file diff --git a/astraEnv/ASearcher/ASearcher/configs/asearcher_web.yaml b/astraEnv/ASearcher/ASearcher/configs/asearcher_web.yaml deleted file mode 100644 index 35a6fdc..0000000 --- a/astraEnv/ASearcher/ASearcher/configs/asearcher_web.yaml +++ /dev/null @@ -1,162 +0,0 @@ -experiment_name: asearcher-7b-zero-web -trial_name: run1 - -cluster: - fileroot: /tmp/areal/experiments - n_nodes: 1 - n_gpus_per_node: 8 - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/name_resolve -seed: 1 -total_train_epochs: 10 -total_train_steps: null -tokenizer_path: ${actor.path} -allocation_mode: sglang.d4p1t1+d1c4 -async_training: true - -rollout: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 32 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 4 - enable_rollout_tracing: true - -gconfig: - n_samples: 16 - min_new_tokens: 0 - max_new_tokens: 1024 - greedy: false - temperature: 1.0 - -actor: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: Qwen/Qwen2.5-7B - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: true - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 16000 - pad_to_maximum: true - optimizer: - type: adam - lr: 5e-6 - weight_decay: 0.01 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 0.2 - warmup_steps_proportion: 0.001 - backend: fsdp - - group_size: ${gconfig.n_samples} - group_reward_norm: false - eps_clip: 0.4 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 - recompute_logprob: true - use_decoupled_loss: true - behav_imp_weight_cap: 5.0 - adv_norm: - mean_level: batch - std_level: batch - -ref: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 32768 - optimizer: null - backend: fsdp - -# SGLang -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: false - dtype: ${actor.dtype} - max_running_requests: null - context_length: 32768 - mem_fraction_static: 0.6 - attention_backend: fa3 - -# datasets -train_dataset: - batch_size: 128 - shuffle: true - pin_memory: true - path: path_to_training_data - -# Utilities -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: 10 - freq_secs: 3600 - -recover: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: null - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: disabled - -# Launcher -launcher: - inference_server_cpus_per_gpu: 15 - inference_server_mem_per_gpu: 153600 - trainer_cpus_per_gpu: 15 - trainer_mem_per_gpu: 153600 - trainer_env_vars: PYTHONPATH=path_to_asearcher:path_to_areal,WANDB_API_KEY=your_wandb_api_key,RAG_SERVER_ADDR_DIR=directory_of_rag_server_addrs - - -max_turns: 32 -n_trajs: 16 -search_client_type: async-online-search-access -reward_type: F1 -topk: 5 -valid_inst_ratio: 0.3 -log_agent_stats: true -log_agent_stats_keys: - - num_input_tokens - - num_output_tokens - - num_llm_gens - - num_search_queries - - num_success_search_queries - - num_failed_search_queries - - num_pages - - num_success_url_accesses - - num_failed_url_accesses - - score - - judge_q_invalid - - format_reward \ No newline at end of file diff --git a/astraEnv/ASearcher/ASearcher/configs/asearcher_web_16nodes.yaml b/astraEnv/ASearcher/ASearcher/configs/asearcher_web_16nodes.yaml deleted file mode 100644 index d6df032..0000000 --- a/astraEnv/ASearcher/ASearcher/configs/asearcher_web_16nodes.yaml +++ /dev/null @@ -1,165 +0,0 @@ -experiment_name: asearcher-7b-zero-web-16nodes -trial_name: run1 - -cluster: - fileroot: /tmp/areal/experiments - n_nodes: 16 - n_gpus_per_node: 8 - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/name_resolve -seed: 1 -total_train_epochs: 10 -total_train_steps: null -tokenizer_path: ${actor.path} -allocation_mode: sglang.d96p1t1+d8c4 -async_training: true - -rollout: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 32 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 4 - enable_rollout_tracing: true - -gconfig: - n_samples: 16 - min_new_tokens: 0 - max_new_tokens: 1024 - greedy: false - temperature: 1.0 - -actor: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: Qwen/Qwen2.5-7B - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: true - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 12000 - pad_to_maximum: true - optimizer: - type: adam - lr: 5e-6 - weight_decay: 0.01 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 0.2 - warmup_steps_proportion: 0.001 - backend: fsdp - - group_size: ${gconfig.n_samples} - group_reward_norm: false - eps_clip: 0.4 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 - recompute_logprob: true - use_decoupled_loss: true - behav_imp_weight_cap: 5.0 - adv_norm: - mean_level: batch - std_level: batch - -ref: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 32768 - optimizer: null - backend: fsdp - -# SGLang -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: false - dtype: ${actor.dtype} - max_running_requests: null - context_length: 32768 - mem_fraction_static: 0.6 - attention_backend: fa3 - -# datasets -train_dataset: - batch_size: 128 - shuffle: true - pin_memory: true - path: path_to_training_data - -# Utilities -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: 10 - freq_secs: 3600 - -recover: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: null - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: disabled - -# Launcher -launcher: - inference_server_cpus_per_gpu: 15 - inference_server_mem_per_gpu: 153600 - trainer_cpus_per_gpu: 15 - trainer_mem_per_gpu: 153600 - slurm: - mount: /storage:/storage # mount share storage if necessary - trainer_image: AReaL_Lite_Image - inference_server_image: AReaL_Lite_Image - trainer_env_vars: PYTHONPATH=path_to_asearcher:path_to_areal,WANDB_API_KEY=your_wandb_api_key,SERPER_API_KEY=your_serper_api_key - -max_turns: 32 -n_trajs: 16 -search_client_type: async-online-search-access -reward_type: F1 -topk: 5 -valid_inst_ratio: 0.3 -log_agent_stats: true -log_agent_stats_keys: - - num_input_tokens - - num_output_tokens - - num_llm_gens - - num_search_queries - - num_success_search_queries - - num_failed_search_queries - - num_pages - - num_success_url_accesses - - num_failed_url_accesses - - score - - judge_q_invalid - - format_reward \ No newline at end of file diff --git a/astraEnv/ASearcher/ASearcher/configs/asearcher_web_qwq.yaml b/astraEnv/ASearcher/ASearcher/configs/asearcher_web_qwq.yaml deleted file mode 100644 index 0bafac2..0000000 --- a/astraEnv/ASearcher/ASearcher/configs/asearcher_web_qwq.yaml +++ /dev/null @@ -1,167 +0,0 @@ -experiment_name: asearcher-qwq-web-8nodes -trial_name: run1 - -cluster: - fileroot: /tmp/areal/experiments - n_nodes: 6 - n_gpus_per_node: 8 - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/experiments -seed: 1 -total_train_epochs: 10 -total_train_steps: null -tokenizer_path: ${actor.path} -allocation_mode: sglang.d2t8+d4t8 -async_training: true - -rollout: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 160 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 4 - enable_rollout_tracing: true - -gconfig: - n_samples: 16 - min_new_tokens: 0 - max_new_tokens: 30000 - greedy: false - temperature: 1.0 - -actor: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: Qwen/QwQ-32B - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: true - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 32000 - pad_to_maximum: true - optimizer: - type: adam - lr: 5e-5 - weight_decay: 0.01 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 1.0 - warmup_steps_proportion: 0.001 - backend: fsdp - - group_size: ${gconfig.n_samples} - adv_norm: - group_size: ${gconfig.n_samples} - eps_clip: 0.4 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 - recompute_logprob: true - use_decoupled_loss: true - behav_imp_weight_cap: 5.0 - -ref: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 32768 - optimizer: null - backend: fsdp - -# SGLang -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: false - dtype: ${actor.dtype} - max_running_requests: null - context_length: 32768 - mem_fraction_static: 0.8 - attention_backend: fa3 - -# datasets -train_dataset: - batch_size: 64 - shuffle: true - pin_memory: true - path: path_to_training_data - -# Utilities -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: 10 - freq_secs: 3600 - - -recover: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: null - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: online - -# Launcher -launcher: - inference_server_cpus_per_gpu: 15 - inference_server_mem_per_gpu: 153600 - trainer_cpus_per_gpu: 15 - trainer_mem_per_gpu: 153600 - slurm: - mount: /storage:/storage # mount share storage if necessary - trainer_image: AReaL_Lite_Image - inference_server_image: AReaL_Lite_Image - - trainer_env_vars: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,WANDB_API_KEY=your_wandb_api_key,SERPER_API_KEY=your_serper_api_key,JINA_API_KEY=your_jina_api_key - -judge_engine: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 1 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 4 - enable_rollout_tracing: false - -max_turns: 128 -force_turns: 4 -n_trajs: 16 -search_client_type: async-online-search-access -topk: 10 -log_agent_stats: true -log_agent_stats_keys: - - turns - - num_search - - num_access - - score - - num_input_tokens - - num_output_tokens \ No newline at end of file diff --git a/astraEnv/ASearcher/ASearcher/train/asearcher.py b/astraEnv/ASearcher/ASearcher/train/asearcher.py deleted file mode 100644 index c0daffe..0000000 --- a/astraEnv/ASearcher/ASearcher/train/asearcher.py +++ /dev/null @@ -1,553 +0,0 @@ -import asyncio -import os -import random -import sys -import uuid -import json -import gc -import torch -import torch.distributed as dist -import numpy as np -from datasets import load_dataset -from datasets.distributed import split_dataset_by_node -from torchdata.stateful_dataloader import StatefulDataLoader -from transformers import PreTrainedTokenizerFast -from dataclasses import dataclass, field -from typing import List - -import hashlib - -from astraflow.workflow.api.cli_args import GenerationHyperparameters -from astraflow.workflow.api.io_struct import ModelRequest -from astraflow.workflow.api.workflow_api import RolloutWorkflow -from astraflow.workflow.utils import logging -try: - from astraflow.train_worker.api.cli_args import GRPOConfig, load_expr_config - from astraflow.train_worker.utils.hf_utils import load_hf_tokenizer -except Exception: - from areal.api.cli_args import GRPOConfig, load_expr_config # type: ignore - from areal.utils.hf_utils import load_hf_tokenizer # type: ignore - -from astraEnv.ASearcher.ASearcher.train.prompts import ( - INVALID_PROMPT, - SEARCH_ACCESS_PROMPT_TEMPLATE, - SEARCH_ONLY_PROMPT_TEMPLATE, - VALID_PROMPT, -) -from astraEnv.ASearcher.ASearcher.train.search_agent import SearchAgent -from astraEnv.ASearcher.ASearcher.utils.rewards import correct_format_fn -from astraEnv.ASearcher.ASearcher.utils.search_tool import SearchToolBox - -worker_id = uuid.uuid4().hex[:4] - -logger = logging.getLogger(f"ASearcher @ {worker_id}") - -def hash(numbers): - """Hash an entire list of integers as a single string""" - # Convert list to string representation - list_str = json.dumps(numbers, sort_keys=True) # sort_keys for consistency - return hashlib.sha256(list_str.encode()).hexdigest() - - -class ASearcherWorkflow(RolloutWorkflow): - def __init__( - self, - gconfig: GenerationHyperparameters, - tokenizer: PreTrainedTokenizerFast, - dataset_path: str, - dump_dir: str | None = None, - max_turns: int = 128, - n_trajs: int = 1, - search_client_type: str = "async-online-search-access", - reward_type: str = "F1", - topk: int = 5, - valid_inst_ratio: float = 1.0, - max_tokens: int = 32000, - search_only: bool = True, - ): - self.gconfig = gconfig - self.gconfig.n_samples = 1 - self.tokenizer = tokenizer - self.dump_dir = dump_dir - self.max_tokens = max_tokens - self.search_only = search_only - if self.dump_dir is not None and not os.path.exists(self.dump_dir): - os.makedirs(self.dump_dir, exist_ok=True) - - # Search hyper-parameters - self.max_turns = max_turns - self.n_trajs = n_trajs - self.reward_type = reward_type - self.topk = topk - self.valid_inst_ratio = valid_inst_ratio - self.search_client_type = search_client_type - - self.toolbox = SearchToolBox(dataset_path=dataset_path, reward_type=self.reward_type, topk=self.topk, search_client_type=self.search_client_type) - - async def collect_agent_trajectory(self, valid_inst, qid, prompt, prompt_token_ids, engine): - agent = SearchAgent(prompt, prompt_token_ids) - score = 0 - ground_truth = None - # a unique trajectory rid to ensure all requests goes to the same sglang server - traj_rid = uuid.uuid4().hex - while agent.num_turns < self.max_turns and not agent.is_finished: - # The agent prepares the prompt and sampling params for LLM generation - input_ids, sampling_params = agent.prepare_llm_query(self.tokenizer) - - # Send request to inference engine and get response - req = ModelRequest( - rid=traj_rid, - input_ids=input_ids, - gconfig=self.gconfig.new(n_samples=1), - ) - if "stop" in sampling_params: - req.gconfig.stop = sampling_params["stop"] - if len(input_ids) + self.gconfig.max_new_tokens >= self.max_tokens: - break - resp = await engine.agenerate(req) - completion_str = self.tokenizer.decode(resp.output_tokens) - - # agent extracts tool callings from the llm response - tool_calls = agent.consume_llm_response(resp, completion_str) - - # call tool and compute reward - if tool_calls is not None and len(tool_calls) > 0: - tool_call = tool_calls[0] - res = (await self.toolbox.step((qid, [tool_call])))[0] - - agent.consume_tool_response(res, topk=self.topk) - - if "score" in res: - score = res["score"] - if "ground_truth" in res: - ground_truth = res["ground_truth"] - - if resp.output_tokens[-1] in [self.tokenizer.eos_token_id, self.tokenizer.pad_token_id]: - break - - llm_gen_records = agent.memory.filter_records("llm_gen") - format_reward = float(all([correct_format_fn(i, r.text) for i, r in enumerate(llm_gen_records)])) - - # compute rewards - score = (score or 0) * format_reward - pred_answer = agent.get_answer() - judge_q_invalid = False - if pred_answer is not None: - judge_q_invalid = any([_c in pred_answer for _c in ["question", "invalid", "appropriate", "valid"]]) - if valid_inst and judge_q_invalid: - score = -0.5 - - stats = agent.memory.logging_stats() - stats.update(dict( - score=score, - judge_q_invalid = judge_q_invalid, - format_reward=format_reward, - )) - - return ground_truth, score, agent.memory, stats - - async def arun_episode(self, engine, data): - # Get the unique identifier for this prompt - qid = None - for key in ["query_id", "id", "qid"]: - qid = data.get(key, None) - if qid is not None: - break - qid = str(qid) if qid is not None else uuid.uuid4().hex - - # Decide whether to dump this rollout (1/128 probability). - _should_dump = self.dump_dir is not None and random.random() < 1 / 128 - - # Initialize and Prepare the prompt - version = engine.get_version() - prompt_template = SEARCH_ONLY_PROMPT_TEMPLATE if self.search_only else SEARCH_ACCESS_PROMPT_TEMPLATE - prompt = prompt_template.format(question=data["question"]) - valid_inst: bool = np.random.uniform(0, 1) <= self.valid_inst_ratio - if valid_inst: - prompt = prompt.replace(INVALID_PROMPT, VALID_PROMPT) - prompt_token_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] - - # Collect trajectories - trajs = await asyncio.gather(*[self.collect_agent_trajectory(valid_inst, qid, prompt, prompt_token_ids, engine) for _ in range(self.n_trajs)]) - - ground_truth, scores, stats = None, [], [] - for gt, score, traj, traj_stats in trajs: - if gt is not None: - ground_truth = gt - scores.append(score) - stats.append(traj_stats) - - # Build structured result: prompt → trajectories → sequences - # Normalization and zero-adv filtering are handled by AstraFlow. - traj_memories = [traj for _, _, traj, _ in trajs] - trajectories = [] - for i, traj_memory in enumerate(traj_memories): - seqs = [] - for j, record in enumerate(traj_memory.memory): - if record.type != "llm_gen": - continue - - # Check whether any previous seq is equivalent to input tokens - success = False - for seq in seqs: - if record.input_len < len(seq["input_ids"]): - continue - h_cur = hash(record.input_tokens[:len(seq["input_ids"])]) - h_seq = hash(seq["input_ids"]) - if h_cur == h_seq: - seq_len = len(seq["input_ids"]) - seq["input_ids"] = record.input_tokens + record.output_tokens - seq["logprobs"] += [0.0] * (record.input_len - seq_len) + record.output_logprobs - seq["loss_mask"] += [0] * (record.input_len - seq_len) + [1] * record.output_len - seq["versions"] += [-1] * (record.input_len - seq_len) + record.output_versions - success = True - break - if not success: - seq = dict( - input_ids = record.input_tokens + record.output_tokens, - logprobs = [0.0] * record.input_len + record.output_logprobs, - loss_mask = [0] * record.input_len + [1] * record.output_len, - versions = [-1] * record.input_len + record.output_versions, - ) - seqs.append(seq) - - traj_stats_i = stats[i] - first_llm_gen = True - seq_dicts = [] - for seq in seqs: - res = dict( - input_ids=torch.tensor(seq["input_ids"]).unsqueeze(0), - loss_mask=torch.tensor(seq["loss_mask"]).unsqueeze(0), - logprobs=torch.tensor(seq["logprobs"]).unsqueeze(0), - versions=torch.tensor(seq["versions"]).unsqueeze(0), - attention_mask=torch.ones(len(seq["input_ids"]), dtype=torch.bool).unsqueeze(0), - begin_of_trajectory=torch.tensor([int(first_llm_gen)]), - ) - res.update({k: torch.tensor([v]) for k, v in traj_stats_i.items()}) - first_llm_gen = False - seq_dicts.append(res) - - trajectories.append({ - "sequences": seq_dicts, - "stats": traj_stats_i, - }) - - if _should_dump: - os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True) - with open( - os.path.join(self.dump_dir, str(version), f"{qid}.jsonl"), "w" - ) as f: - for i, (traj_memory, score) in enumerate(zip(traj_memories, scores)): - f.write(json.dumps(dict(memory=traj_memory.to_dict(), reward=score, ground_truth=ground_truth, traj_idx=i)) + "\n") - - return { - "n_trajs": self.n_trajs, - "rewards": torch.tensor(scores, dtype=torch.float32), - "trajectories": trajectories, - } - -@dataclass -class AgentRLConfig(GRPOConfig): - max_turns: int = field( - default=128, - metadata={ - "help": "maximum number of turns for search agent" - } - ) - n_trajs: int = field( - default=1, - metadata={ - "help": "We could collect multiple trajectories for a single query. By default n_trajs=1." - } - ) - search_client_type: str = field( - default="async-online-search-access", - metadata={ - "help": "Type of tool (async-online-search-access/async-search-access). By default we use 'async-online-search-access'" - } - ) - reward_type: str = field( - default="F1", - metadata={ - "help": "The type of reward function" - } - ) - topk: int = field( - default=5, - metadata={ - "help": "search returns the top-k results. Default top_k=5" - } - ) - valid_inst_ratio: float = field( - default=1.0, - metadata={ - "help": "We randomly force a ratio of queries to produce valid anwers. By default valid_inst_ratio=1.0" - } - ) - # Logging Agent Trajectories - log_agent_stats: bool = field( - default=False, - metadata={ - "help": "Log stats for agent trajectories" - }, - ) - log_agent_stats_keys: List[str] = field( - default_factory=lambda: ["num_llm_gens"], - metadata={ - "help": "Keys of log stats for agent trajectories" - }, - ) - - -def get_search_dataset(dataset_path, tokenizer, rank, world_size): - dataset = load_dataset( - path="json", - split="train", - data_files=dataset_path, - ) - # dataset = dataset.filter(lambda x: len(tokenizer.encode(x["question"])) <= 1024) - return split_dataset_by_node(dataset, rank=rank, world_size=world_size) - -def main(args): - try: - from areal.api.io_struct import ( - AllocationMode, - FinetuneSpec, - StepInfo, - WeightUpdateMeta, - ) - from areal.engine.ppo.actor import FSDPPPOActor - from areal.engine.sglang_remote import RemoteSGLangEngine - from areal.platforms import current_platform - from areal.utils import seeding, stats_tracker - from areal.utils.data import broadcast_tensor_container, cycle_dataloader - from areal.utils.device import log_gpu_stats - from areal.utils.evaluator import Evaluator - from areal.utils.recover import RecoverHandler - from areal.utils.saver import Saver - from areal.utils.stats_logger import StatsLogger - except Exception as exc: - raise RuntimeError( - "Legacy ASearcher standalone trainer dependencies are unavailable. " - "Use AstraFlow v2 recipe entrypoints for training." - ) from exc - - config, _ = load_expr_config(args, AgentRLConfig) - config: AgentRLConfig - - rank = int(os.getenv("RANK")) - world_size = int(os.getenv("WORLD_SIZE")) - tokenizer = load_hf_tokenizer(config.tokenizer_path) - - seeding.set_random_seed(config.seed, key=f"trainer{rank}") - allocation_mode = AllocationMode.from_str(config.allocation_mode) - parallel_strategy = allocation_mode.train - - # Initialize train engine - actor = FSDPPPOActor(config=config.actor) - actor.create_process_group(parallel_strategy=parallel_strategy) - ref = None - - # Create dataset and dataloaders - worker_batch_size = config.train_dataset.batch_size // world_size - train_dataloader = StatefulDataLoader( - get_search_dataset(config.train_dataset.path, tokenizer, rank, world_size), - batch_size=config.train_dataset.batch_size // world_size, - shuffle=config.train_dataset.shuffle, - num_workers=config.train_dataset.num_workers, - collate_fn=lambda x: x, - drop_last=config.train_dataset.drop_last, - ) - ft_spec = FinetuneSpec( - total_train_epochs=config.total_train_epochs, - dataset_size=len(train_dataloader) * config.train_dataset.batch_size, - train_batch_size=config.train_dataset.batch_size, - ) - - # Initialize inference engine - rollout = RemoteSGLangEngine(config.rollout) - rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size) - - - - actor.initialize(None, ft_spec) - ref = None - - # NOTE: Weight update meta only requires address and free port of rank 0, - # but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks - # due to `engine.get_param_specs()`. - # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0. - weight_update_meta = [ - # WeightUpdateMeta.from_disk(config.experiment_name, config.trial_name, config.cluster.fileroot, "default") - WeightUpdateMeta.from_fsdp_nccl(AllocationMode.from_str(config.allocation_mode), actor) - ] - dist.broadcast_object_list(weight_update_meta, src=0) - weight_update_meta = weight_update_meta[0] - - # Create rollout workflow - if tokenizer.pad_token_id not in config.gconfig.stop_token_ids: - config.gconfig.stop_token_ids.append(tokenizer.pad_token_id) - if tokenizer.eos_token_id not in config.gconfig.stop_token_ids: - config.gconfig.stop_token_ids.append(tokenizer.eos_token_id) - workflow = ASearcherWorkflow( - gconfig=config.gconfig, - tokenizer=tokenizer, - dump_dir=os.path.join( - StatsLogger.get_log_path(config.stats_logger), "generated" - ), - dataset_path=config.train_dataset.path, - max_turns=config.max_turns, - n_trajs=config.n_trajs, - search_client_type=config.search_client_type, - reward_type=config.reward_type, - topk=config.topk, - valid_inst_ratio=config.valid_inst_ratio, - max_tokens=config.actor.mb_spec.max_tokens_per_mb, - ) - - # Run training. - saver = Saver(config.saver, ft_spec) - stats_logger = StatsLogger(config.stats_logger, ft_spec) - evaluator = Evaluator(config.evaluator, ft_spec) - - # Recover - recover_handler = RecoverHandler(config.recover, ft_spec) - recover_info = recover_handler.load( - actor, - saver, - evaluator, - stats_logger, - train_dataloader, - inference_engine=rollout, - weight_update_meta=weight_update_meta, - ) - start_step = ( - recover_info.last_step_info.next().global_step - if recover_info is not None - else 0 - ) - - total_epochs = config.total_train_epochs - steps_per_epoch = len(train_dataloader) - max_steps = total_epochs * steps_per_epoch - - data_generator = cycle_dataloader(train_dataloader) - for global_step in range(start_step, max_steps): - epoch = global_step // steps_per_epoch - step = global_step % steps_per_epoch - step_info = StepInfo( - global_step=global_step, - epoch=epoch, - epoch_step=step, - steps_per_epoch=steps_per_epoch, - ) - - print(f"Epoch {epoch}. Step: {step}/{steps_per_epoch}") - - with stats_tracker.record_timing("rollout"): - if config.async_training: - batch = rollout.prepare_batch(train_dataloader, workflow=workflow) - else: - try: - data = next(data_generator) - except StopIteration: - data_generator = iter(train_dataloader) - data = next(data_generator) - batch = rollout.rollout_batch(data, workflow=workflow) - batch = batch.to(actor.device) - batch = broadcast_tensor_container( - batch, - src_rank=actor.current_data_parallel_head(), - group=actor.context_and_model_parallel_group, - ) - - # Create barrier to synchronize all rollout processes. - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - - if config.actor.recompute_logprob or config.actor.use_decoupled_loss: - with stats_tracker.record_timing("recompute_logp"): - logp = actor.compute_logp(batch) - batch["prox_logp"] = logp - log_gpu_stats("recompute logp") - - if ref is not None: - with stats_tracker.record_timing("ref_logp"): - batch["ref_logp"] = ref.compute_logp(batch) - log_gpu_stats("ref logp") - - with stats_tracker.record_timing("compute_advantage"): - actor.compute_advantages(batch) - log_gpu_stats("compute advantages") - - gc.collect() - torch.cuda.empty_cache() - gc.collect() - - with ( - stats_tracker.record_timing("train_step"), - stats_tracker.scope("grpo_actor"), - ): - if config.log_agent_stats: - agent_denominator = (batch["begin_of_trajectory"] > 0).bool() - stats_tracker.denominator(agent=agent_denominator) - stats_tracker.stat( - **{k: batch[k].float() for k in config.log_agent_stats_keys}, - denominator="agent", - ) - - stats = actor.ppo_update(batch) - actor.step_lr_scheduler() - log_gpu_stats("actor update") - - # pause inference for updating weights, save, and evaluation - rollout.pause() - - with stats_tracker.record_timing("update_weights"): - if dist.get_rank() == 0: - future = rollout.update_weights(weight_update_meta) - actor.upload_weights(weight_update_meta) - if dist.get_rank() == 0: - future.result() - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - - actor.set_version(global_step + 1) - rollout.set_version(global_step + 1) - - with stats_tracker.record_timing("save"): - saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) - - with stats_tracker.record_timing("checkpoint_for_recover"): - recover_handler.dump( - actor, - step_info, - saver, - evaluator, - stats_logger, - train_dataloader, - tokenizer=tokenizer, - ) - - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - - # Upload statistics to the logger (e.g., wandb) - stats[0].update(stats_tracker.export_all(reduce_group=actor.parallelism_group)) - stats_logger.commit(epoch, step, global_step, stats) - - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - - # Resume rollout - rollout.resume() - - stats_logger.close() - rollout.destroy() - if ref is not None: - ref.destroy() - actor.destroy() - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/astraEnv/ASearcher/ASearcher/train/asearcher_reasoning.py b/astraEnv/ASearcher/ASearcher/train/asearcher_reasoning.py deleted file mode 100644 index 3137459..0000000 --- a/astraEnv/ASearcher/ASearcher/train/asearcher_reasoning.py +++ /dev/null @@ -1,467 +0,0 @@ -import itertools -import asyncio -import os -import sys -import uuid -import json -import gc -import torch -import torch.distributed as dist -import numpy as np -from datasets import load_dataset -from datasets.distributed import split_dataset_by_node -from tensordict import TensorDict -from torchdata.stateful_dataloader import StatefulDataLoader -from transformers import PreTrainedTokenizerFast -from areal.platforms import current_platform -from areal.utils.evaluator import Evaluator -from areal.utils.hf_utils import load_hf_tokenizer -from areal.utils.recover import RecoverHandler -from dataclasses import dataclass, field -from typing import List - -import hashlib - -from areal.api.cli_args import ( - GenerationHyperparameters, - GRPOConfig, - load_expr_config, - InferenceEngineConfig, -) -from areal.api.io_struct import ( - AllocationMode, - FinetuneSpec, - ModelRequest, - WeightUpdateMeta, - StepInfo, -) -from areal.api.workflow_api import RolloutWorkflow -from areal.api.cli_args import GRPOConfig -from areal.engine.ppo.actor import FSDPPPOActor -from areal.engine.sglang_remote import RemoteSGLangEngine -from areal.utils.data import concat_padded_tensors, broadcast_tensor_container -from areal.utils.device import log_gpu_stats -from areal.utils.saver import Saver -from areal.utils.stats_logger import StatsLogger -from areal.utils import seeding, logging, stats_tracker -from areal.experimental.openai import ArealOpenAI -from areal.utils.redistributor import redistribute - -import sys -from pathlib import Path -# sys.path.append("/storage/openpsi/users/xushusheng.xss/projects/ASearcher-Lite@0908") -sys.path.append(str(Path(__file__).resolve().parents[2])) - -from astraEnv.ASearcher.ASearcher.train.reasoning_agent import run_agent -from astraEnv.ASearcher.ASearcher.utils.search_tool import SearchToolBox - -worker_id = uuid.uuid4().hex[:4] - -logger = logging.getLogger(f"ASearcher-Reasoning @ {worker_id}") - -def hash(numbers): - """Hash an entire list of integers as a single string""" - # Convert list to string representation - list_str = json.dumps(numbers, sort_keys=True) # sort_keys for consistency - return hashlib.sha256(list_str.encode()).hexdigest() - - -class ASearcherReasoningWorkflow(RolloutWorkflow): - def __init__( - self, - gconfig: GenerationHyperparameters, - tokenizer: PreTrainedTokenizerFast, - dataset_path: str, - dump_dir: str | None = None, - max_turns: int = 128, - force_turns: int = 4, - n_trajs: int = 1, - search_client_type: str = "async-online-search-access", - topk: int = 10, - max_tokens: int = 30000, - judge_engine: RemoteSGLangEngine | None = None, - ): - self.gconfig = gconfig - self.gconfig.n_samples = 1 - self.tokenizer = tokenizer - self.dump_dir = dump_dir - self.max_tokens = max_tokens - if self.dump_dir is not None and not os.path.exists(self.dump_dir): - os.makedirs(self.dump_dir, exist_ok=True) - - # Search hyper-parameters - self.force_turns = force_turns - self.max_turns = max_turns - self.n_trajs = n_trajs - self.topk = topk - self.search_client_type = search_client_type - - self.toolbox = SearchToolBox(dataset_path=dataset_path, reward_type="F1", topk=self.topk, search_client_type=self.search_client_type) - self.judge_client = ArealOpenAI(engine=judge_engine, tokenizer= tokenizer) - - async def arun_episode(self, engine, data): - # Get the unique identifier for this prompt - qid = None - for key in ["query_id", "id", "qid"]: - qid = data.get(key, None) - if qid is not None: - break - qid = str(qid) or uuid.uuid4().hex - data["qid"] = qid - - # check for generated qid when resuming - if self.dump_dir is not None: - import glob - _pattern = os.path.join(self.dump_dir, "*", f"{qid}.jsonl") - if len(glob.glob(_pattern)) > 0: - logger.info(f"{qid} is already trained on") - return None - - # path to save trajs - version = engine.get_version() - save_trajs_path = None - if self.dump_dir is not None: - os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True) - save_trajs_path = os.path.join(self.dump_dir, str(version), f"{qid}/ID.json") - - client = ArealOpenAI(engine=engine, tokenizer=self.tokenizer) - judge_client = self.judge_client - - # Collect trajectories - trajs = await asyncio.gather(*[run_agent(client=client, - judge_client=judge_client, - tokenizer=self.tokenizer, - data=data, - toolbox=self.toolbox, - max_turns=self.max_turns, - force_turns=self.force_turns, - topk=self.topk, - force_valid=True, - max_tokens=self.max_tokens, - save_path=save_trajs_path.replace("ID.json", f"{i}.json") if save_trajs_path is not None else None, - rank=i - ) - for i in range(self.n_trajs)]) - - all_completions = [r[0] for r in trajs] - rewards = np.asarray([r[1] for r in trajs]) - stats = [r[2] for r in trajs] - - logger.info(f"Qid={qid} rewards={rewards}") - - # Group Normalization - advantages = (rewards - rewards.mean()) - if abs(rewards.max() - rewards.mean()) > 1e-3: - advantages = advantages / advantages.std() - else: - return None - - # Set advantages to all completions - for completions, advantage in zip(all_completions, advantages): - for comp in completions: - client.set_reward(comp.id, advantage) - - completions_with_rewards = client.export_completions(turn_discount=0.0) - - results = [] - for i in range(self.n_trajs): - stats[i].update(dict( - num_output_tokens=0, - num_input_tokens=0, - )) - for comp in all_completions[i]: - resp = completions_with_rewards[comp.id].response - stats[i]["num_input_tokens"] += resp.input_len - stats[i]["num_output_tokens"] += resp.output_len - - first_completion = True - for comp in all_completions[i]: - res = completions_with_rewards[comp.id].to_tensor_dict() - - res["begin_of_trajectory"]=torch.tensor([int(first_completion)]) - for k, v in stats[i].items(): - res[k] = torch.tensor([v]) - first_completion = False - results.append(res) - results = concat_padded_tensors(results) - return results - -@dataclass -class AgentRLConfig(GRPOConfig): - max_turns: int = field( - default=128, - metadata={ - "help": "maximum number of turns for search agent" - } - ) - force_turns: int = field( - default=4, - metadata={ - "help": "minimum number of turns for search agent" - } - ) - n_trajs: int = field( - default=1, - metadata={ - "help": "We could collect multiple trajectories for a single query. By default n_trajs=1." - } - ) - search_client_type: str = field( - default="async-online-search-access", - metadata={ - "help": "Type of tool (async-online-search-access/async-search-access). By default we use 'async-online-search-access'" - } - ) - topk: int = field( - default=10, - metadata={ - "help": "search returns the top-k results. Default top_k=5" - } - ) - # Logging Agent Trajectories - log_agent_stats: bool = field( - default=False, - metadata={ - "help": "Log stats for agent trajectories" - }, - ) - log_agent_stats_keys: List[str] = field( - default_factory=lambda: [], - metadata={ - "help": "Keys of log stats for agent trajectories" - }, - ) - judge_engine: InferenceEngineConfig = field(default_factory=InferenceEngineConfig) - - -def get_search_dataset(dataset_path, tokenizer, rank, world_size): - dataset = load_dataset( - path="json", - split="train", - data_files=dataset_path, - ) - # dataset = dataset.filter(lambda x: len(tokenizer.encode(x["question"])) <= 1024) - return split_dataset_by_node(dataset, rank=rank, world_size=world_size) - -def main(args): - config, _ = load_expr_config(args, AgentRLConfig) - config: AgentRLConfig - - rank = int(os.getenv("RANK")) - tokenizer = load_hf_tokenizer(config.tokenizer_path) - - seeding.set_random_seed(config.seed, key=f"trainer{rank}") - allocation_mode = AllocationMode.from_str(config.allocation_mode) - parallel_strategy = allocation_mode.train - - # Initialize train engine - actor = FSDPPPOActor(config=config.actor) - actor.create_process_group(parallel_strategy=parallel_strategy) - - # Create dataset and dataloaders - train_dataloader = StatefulDataLoader( - get_search_dataset(config.train_dataset.path, tokenizer, actor.data_parallel_rank, actor.data_parallel_world_size), - batch_size=config.train_dataset.batch_size // actor.data_parallel_world_size, - shuffle=config.train_dataset.shuffle, - num_workers=config.train_dataset.num_workers, - collate_fn=lambda x: x, - drop_last=config.train_dataset.drop_last, - ) - ft_spec = FinetuneSpec( - total_train_epochs=config.total_train_epochs, - dataset_size=len(train_dataloader) * config.train_dataset.batch_size, - train_batch_size=config.train_dataset.batch_size, - ) - - # Initialize inference engine - rollout = RemoteSGLangEngine(config.rollout) - rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size) - - # Initialize judge inference engine - judge_engine = RemoteSGLangEngine(config.judge_engine) - judge_engine.initialize(train_data_parallel_size=parallel_strategy.dp_size) - - actor.initialize(None, ft_spec) - ref = None - - # NOTE: Weight update meta only requires address and free port of rank 0, - # but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks - # due to `engine.get_param_specs()`. - # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0. - - weight_update_meta = WeightUpdateMeta.from_disk( - config.experiment_name, - config.trial_name, - config.cluster.fileroot - ) - - # Create rollout workflow - if tokenizer.pad_token_id not in config.gconfig.stop_token_ids: - config.gconfig.stop_token_ids.append(tokenizer.pad_token_id) - if tokenizer.eos_token_id not in config.gconfig.stop_token_ids: - config.gconfig.stop_token_ids.append(tokenizer.eos_token_id) - workflow = ASearcherReasoningWorkflow( - gconfig=config.gconfig, - tokenizer=tokenizer, - dump_dir=os.path.join( - StatsLogger.get_log_path(config.stats_logger), "generated" - ), - dataset_path=config.train_dataset.path, - max_turns=config.max_turns, - force_turns=config.force_turns, - n_trajs=config.n_trajs, - search_client_type=config.search_client_type, - topk=config.topk, - max_tokens=config.gconfig.max_new_tokens, - judge_engine=judge_engine, - ) - - # Run training. - saver = Saver(config.saver, ft_spec) - stats_logger = StatsLogger(config.stats_logger, ft_spec) - evaluator = Evaluator(config.evaluator, ft_spec) - - # Recover - recover_handler = RecoverHandler(config.recover, ft_spec) - recover_info = recover_handler.load( - actor, - saver, - evaluator, - stats_logger, - train_dataloader, - inference_engine=rollout, - weight_update_meta=weight_update_meta, - ) - start_step = ( - recover_info.last_step_info.next().global_step - if recover_info is not None - else 0 - ) - - total_epochs = config.total_train_epochs - steps_per_epoch = len(train_dataloader) - max_steps = total_epochs * steps_per_epoch - - data_generator = itertools.cycle(train_dataloader) - for global_step in range(start_step, max_steps): - epoch = global_step // steps_per_epoch - step = global_step % steps_per_epoch - step_info = StepInfo( - global_step=global_step, - epoch=epoch, - epoch_step=step, - steps_per_epoch=steps_per_epoch, - ) - - print(f"Epoch {epoch}. Step: {step}/{steps_per_epoch}") - - with stats_tracker.record_timing("rollout"): - batch = None - if actor.is_data_parallel_head(): - if config.async_training: - batch = rollout.prepare_batch( - train_dataloader, - workflow=workflow, - should_accept=lambda sample: True, - ) - else: - batch = rollout.rollout_batch( - next(data_generator), - workflow=workflow, - should_accept=lambda sample: True, - ) - batch = batch.to(actor.device) - batch = redistribute(batch, group=actor.data_parallel_group).data - batch = broadcast_tensor_container( - batch, - src_rank=actor.current_data_parallel_head(), - group=actor.context_and_model_parallel_group, - ) - # Create barrier to synchronize all rollout processes. - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - - if config.actor.recompute_logprob or config.actor.use_decoupled_loss: - with stats_tracker.record_timing("recompute_logp"): - logp = actor.compute_logp(batch) - batch["prox_logp"] = logp - log_gpu_stats("recompute logp") - - if ref is not None: - with stats_tracker.record_timing("ref_logp"): - batch["ref_logp"] = ref.compute_logp(batch) - log_gpu_stats("ref logp") - - with stats_tracker.record_timing("compute_advantage"): - actor.compute_advantages(batch) - log_gpu_stats("compute advantages") - - with ( - stats_tracker.record_timing("train_step"), - stats_tracker.scope("grpo_actor"), - ): - if config.log_agent_stats: - agent_denominator = (batch["begin_of_trajectory"] > 0).bool() - stats_tracker.denominator(agent=agent_denominator) - stats_tracker.stat( - **{k: batch[k].float() for k in config.log_agent_stats_keys}, - denominator="agent", - ) - - stats = actor.ppo_update(batch) - actor.step_lr_scheduler() - log_gpu_stats("actor update") - - # pause inference for updating weights, save, and evaluation - rollout.pause() - - with stats_tracker.record_timing("update_weights"): - if dist.get_rank() == 0: - future = rollout.update_weights(weight_update_meta) - actor.upload_weights(weight_update_meta) - if dist.get_rank() == 0: - future.result() - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - - actor.set_version(global_step + 1) - rollout.set_version(global_step + 1) - - with stats_tracker.record_timing("save"): - saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) - - with stats_tracker.record_timing("checkpoint_for_recover"): - recover_handler.dump( - actor, - step_info, - saver, - evaluator, - stats_logger, - train_dataloader, - tokenizer=tokenizer, - ) - - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - - # Upload statistics to the logger (e.g., wandb) - stats[0].update( - stats_tracker.export_all(reduce_group=actor.data_parallel_group) - ) - stats_logger.commit(epoch, step, global_step, stats) - - dist.barrier(device_ids=[actor.device.index]) - current_platform.synchronize() - - # Resume rollout - rollout.resume() - - stats_logger.close() - rollout.destroy() - if ref is not None: - ref.destroy() - actor.destroy() - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/astraEnv/ASearcher/ASearcher/train/asearcher_train.py b/astraEnv/ASearcher/ASearcher/train/asearcher_train.py deleted file mode 100644 index 00971c7..0000000 --- a/astraEnv/ASearcher/ASearcher/train/asearcher_train.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Buffered trainer entrypoint for ASearcher.""" - -from __future__ import annotations - -import os -import sys -from dataclasses import dataclass, field - -from datasets import load_dataset -from omegaconf import OmegaConf - -from areal.api.cli_args import ( - TrainDatasetConfig, - parse_cli_args, - save_config, - to_structured_cfg, -) -from areal.experimental.trainer import BufferedPPOTrainer -from areal.utils import name_resolve -from areal.utils.stats_logger import StatsLogger - -from astraEnv.ASearcher.ASearcher.train.asearcher import ASearcherWorkflow, AgentRLConfig - -@dataclass -class AgentBufferedRLConfig(AgentRLConfig): - # ASearcher configs omit train_dataset.type. Default to RL here. - train_dataset: TrainDatasetConfig = field( - default_factory=lambda: TrainDatasetConfig(path="", type="rl") - ) - - -def get_search_dataset(dataset_path: str): - return load_dataset(path="json", split="train", data_files=dataset_path) - - -def load_agent_config(argv: list[str]) -> tuple[AgentBufferedRLConfig, str]: - cfg, config_file = parse_cli_args(argv) - - cfg_dict = OmegaConf.to_container(cfg, resolve=False) - assert isinstance(cfg_dict, dict) - - # Backward compatibility with legacy ASearcher YAML keys. - cfg_dict.pop("async_training", None) - if isinstance(cfg_dict.get("actor"), dict): - actor_cfg = cfg_dict["actor"] - actor_cfg.pop("backend", None) - if "group_reward_norm" in actor_cfg: - group_reward_norm = bool(actor_cfg.pop("group_reward_norm")) - if group_reward_norm and "reward_norm" not in actor_cfg: - gconfig_cfg = cfg_dict.get("gconfig", {}) - group_size = int(gconfig_cfg.get("n_samples", 1)) - actor_cfg["reward_norm"] = { - "mean_level": "group", - "std_level": "group", - "group_size": group_size, - } - if isinstance(cfg_dict.get("ref"), dict): - cfg_dict["ref"].pop("backend", None) - if isinstance(cfg_dict.get("train_dataset"), dict): - cfg_dict["train_dataset"].setdefault("type", "rl") - - cfg = OmegaConf.create(cfg_dict) - cfg = to_structured_cfg(cfg, AgentBufferedRLConfig) - cfg = OmegaConf.to_object(cfg) - assert isinstance(cfg, AgentBufferedRLConfig) - - name_resolve.reconfigure(cfg.cluster.name_resolve) - if os.getenv("RANK", "0") == "0": - save_config(cfg, StatsLogger.get_log_path(cfg.stats_logger)) - - return cfg, str(config_file) - - -def main(args: list[str]) -> None: - config, _ = load_agent_config(args) - train_dataset = get_search_dataset(config.train_dataset.path) - - with BufferedPPOTrainer( - config, - train_dataset=train_dataset, - valid_dataset=None, - buffer_size=65536, - rollout_batch_size=1, - ) as trainer: - workflow = ASearcherWorkflow( - gconfig=config.gconfig.new_with_stop_and_pad_token_ids(trainer.tokenizer), - tokenizer=trainer.tokenizer, - dump_dir=os.path.join( - StatsLogger.get_log_path(config.stats_logger), - "generated", - ), - dataset_path=config.train_dataset.path, - max_turns=config.max_turns, - n_trajs=config.n_trajs, - search_client_type=config.search_client_type, - reward_type=config.reward_type, - topk=config.topk, - valid_inst_ratio=config.valid_inst_ratio, - max_tokens=config.actor.mb_spec.max_tokens_per_mb, - ) - - trainer.train(workflow) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/astraEnv/ASearcher/ASearcher/train/prompts.py b/astraEnv/ASearcher/ASearcher/train/prompts.py deleted file mode 100644 index c7a767c..0000000 --- a/astraEnv/ASearcher/ASearcher/train/prompts.py +++ /dev/null @@ -1,4 +0,0 @@ -SEARCH_ACCESS_PROMPT_TEMPLATE="A conversation between User and Assistant. The user asks a question, and the Assistant answers it. The Assistant analyzes the given question and information in the mind, retains important relevant information, calls a search engine to find necessary information, accesses web pages with certain urls, and provides the user with the answer. The Assistant conducts search by query , access cerain url by url , and the top search results and url page will be returned between and . The reasoning processes are enclosed within . Finally, the Assistant provides answer inside and , i.e. answer here . If there are multiple queries, ensure all answers are enclosed within , seperated with comma. Note that when the Assistant finds the question is invalid, e.g. no answer could match all information in the question, the Assistant replies with ' the question is invalid. '. \n\nUser: \n\n{question}. \n\nThe language of your answer should align with the question. \n\nAssistant: \n\n" -SEARCH_ONLY_PROMPT_TEMPLATE="A conversation between User and Assistant. The user asks a question, and the Assistant answers it. The Assistant analyzes the given question and information in the mind, retains important relevant information, calls a search engine to find necessary information, accesses web pages with certain urls, and provides the user with the answer. The Assistant conducts search by query and the top search results will be returned between and . The reasoning processes are enclosed within . Finally, the Assistant provides answer inside and , i.e. answer here . If there are multiple queries, ensure all answers are enclosed within , seperated with comma. Note that when the Assistant finds the question is invalid, e.g. no answer could match all information in the question, the Assistant replies with ' the question is invalid. '. \n\nUser: \n\n{question}. \n\nThe language of your answer should align with the question. \n\nAssistant: \n\n" -INVALID_PROMPT="Note that when the Assistant finds the question is invalid, e.g. no answer could match all information in the question, the Assistant replies with ' the question is invalid. '." -VALID_PROMPT="You should try to find the most likely answer. " \ No newline at end of file diff --git a/astraEnv/ASearcher/ASearcher/train/reasoning_agent.py b/astraEnv/ASearcher/ASearcher/train/reasoning_agent.py deleted file mode 100644 index 73f5911..0000000 --- a/astraEnv/ASearcher/ASearcher/train/reasoning_agent.py +++ /dev/null @@ -1,703 +0,0 @@ -import re -import time -from typing import Dict, List, Any, Optional - -class ASearcherReasoningPrompts: - THINK_AND_ACT_PROMPT_v1 = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question and the history context, generate the thought as well as the next action (only one action). Tthe completed thought should contain analysis of available information and planning for future steps. Enclose the thought within tags. - -The next action could be one of the following three, each with specific tags: -1. Search w. a search engine, e.g. the search query - -2. Accessing some url found in prior history, e.g. the url to access - -3. Answering the question, e.g. the answer (usually in less than 10 words) (WARNING: Answer the question only after you double check the results with sufficient search!) - -Guidelines: -1. You should double check previous conclusions and identified facts using search from different perspectives. -3. You can try different directions to solve the question, such as using different search queries. -3. If you find related entries in the search results, it is usually useful to access the corresponding urls to find more information. -4. You should find the most likely answer. -5. The next action should follow after the thought. -6. Make sure you choose only one action. -7. Carefully select the type of language to conduct your search query (Chinese or English) - -Current Time: Today is 2025.07.21 - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Thought: ... // the thought to be completed - -Next Action: ... // the next action to be completed -""" - - THINK_AND_ACT_PROMPT = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question and the history context, generate the thought as well as the next action (only one action). The completed thought should contain a detailed analysis of current situation and a plan for future steps. The action is either a query to google search or accessing some URL. Enclose the thought within tags. - -The next action could be one of the following two, each with specific tags: -1. Search w. a search engine, e.g. the search query - -2. Accessing some url found in prior history to find more information, e.g. the url to access - -Guidelines: -1. You should double check previous conclusions and identified facts using search from different perspectives. -3. You can try different directions to solve the question, such as using different search queries. -3. If you find related entries in the search results, it is usually useful to access the corresponding urls to find more information. -4. The next action should follow after the thought. -5. Make sure you should choose only one action. - -Current Time: Today is 2025.07.21 - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Thought: ... // the thought to be completed - -Next Action: ... // the next action to be completed -""" - - THINK_AND_ANSWER_PROMPT = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question and the history context, generate the thought as well as the final answer. The completed thought should contain detailed analysis of available information. Enclose the thought within tags, and the answer within tags. - -Guideline: -1. Determine the answer based on the the available information. -2. Try to make your best guess if the found information is not enough. - - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Thought: ... // the thought to be completed - -Final Answer: ... // the final answer -""" - READ_PAGE_PROMPT = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question, the history context, and the current web page, generate a thought after reading the webpage. The completed thought should contain information found related to the question, relevant links from the current webpage, and detailed analysis of available information. Enclose the thought within tags. - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Current webpage: -```txt -{content} -``` - -Thought: ... // the thought to be completed -""" - READ_SEARCH_RESULTS_PROMPT = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question, the history context, and the search results of the latest query, generate a thought after reading the search results. The completed thought should contain information found related to the question, relevant links from the latest search results that may help solve the question, and detailed analysis of available information. Enclose the thought within tags. - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Latest search results: -```txt -{content} -``` - -Thought: ... // the thought to be completed -""" - -def process_webpage(content): - keys = [("title", "title"), ("p", "p"), ("li", "li", lambda c: "\n" not in c)] - content_list = [] - init_length = len(content) - while any([f"<{k[0]}" in content and f"" in content for k in keys]): - klr = [] - for k in keys: - start = 0 - # print(k) - while True: - ls = [content[start:].find(f"<{k[0]}{c}") for c in [">", " "]] - ls = [l for l in ls if l != -1] - l = -1 if len(ls) == 0 else min(ls) - # print(ls) - if l == -1: - break - l += start - r = content[l:].find(f"") - if r == -1: - break - if (len(k) <= 2) or (len(k) >= 3 and k[2](content[l:l+r])): - # print(k, l, l+r) - klr.append((k, l, l+r)) - break - start = l + r - - if len(klr) == 0: - break - klr = sorted(klr, key=lambda x:x[1]) - k, l, r = klr[0] - content_list.append(content[l:r+len(f"")]) - # print(content_list[-1]) - # input("stop...") - if k[0] == "p": - content_list[-1] += "\n\n" - elif k[0] == "li": - content_list[-1] += "\n" - content = content[r:] - content = "".join(content_list) - final_length = len(content) - print(f"process the webpage: {init_length} -> {final_length}. {content[:100]}") - return content - -class AReaLSearchReasoningAgentV1: - - def __init__(self, - max_turns: int = 128, - force_turns: int = 4, - topk: int = 10, - force_valid: bool = True): - - self.max_turns = max_turns - self.force_turns = force_turns - self.force_valid = force_valid - self.topk = topk - # 保持与原agent相同的属性名 - self.stop = ["<|im_end|>", "<|endoftext|>"] - self.stop_sequences = self.stop - - print(f"AReaLSearchAgentV1 初始化完成") - - def get_query_from_text(self, text: str) -> Optional[str]: - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return "" + matches[-1].strip() + "" - - return None - - def get_url_from_text(self, text: str) -> Optional[str]: - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return "" + matches[-1].strip() + "" - - return None - - def get_thought_from_text(self, text: str) -> Optional[str]: - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return "" + matches[-1].strip() + "" - # return "" + matches[-1].strip() + "" - - return None - - def get_answer_from_text(self, text: str) -> Optional[str]: - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return "" + matches[-1].strip() + "" - - return None - - def print_search_debug_info(self, text: str): - query_starts = text.count('') - query_ends = text.count('') - # print(f"搜索标签统计: {query_starts}个开始标签, {query_ends}个结束标签") - - def debug_generation_tags(self, text: str) -> Dict: - tags = { - 'query': {'open': text.count('<|begin_of_query|>'), 'close': text.count('<|end_of_query|>')}, - 'documents': {'open': text.count('<|begin_of_documents|>'), 'close': text.count('<|end_of_documents|>')}, - 'answer': {'open': text.count(''), 'close': text.count('')} - } - - for tag_name, counts in tags.items(): - tags[tag_name]['balanced'] = counts['open'] == counts['close'] - - return tags - - def all_finished(self, processes: List[Dict]) -> bool: - finished = [] - for process in processes: - finished.append(not process.get("running", True)) - return all(finished) - - def prepare_queries(self, tokenizer, processes: List[Dict]) -> List[Dict]: - queries = [] - for process in processes: - if "history" not in process: - assert "pred_answer" not in process - process["history"] = [dict(type="prompt", text=process["prompt"])] - process["running"] = True - process["phase"] = "search" - - if process["running"]: - if "text" not in process["history"][-1] and "info_str" in process["history"][-1]: - history = "" - for idx, h in enumerate(process["history"][:-1]): - history += h.get("short_info_str", h.get("text", "")) - if len(history) > 25000: - history = history[-25000:] - - if process["history"][-1]["type"] == "page": - prompt = ASearcherReasoningPrompts.READ_PAGE_PROMPT.format(question=process["question"], history=history, content=process["history"][-1]["info_str"]) - elif process["history"][-1]["type"] == "documents": - prompt = ASearcherReasoningPrompts.READ_SEARCH_RESULTS_PROMPT.format(question=process["question"], history=history, content=process["history"][-1]["info_str"]) - else: - raise RuntimeError(f"Not supported history type: {process['history'][-1]['type']}") - - input_text = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False) - query_len = tokenizer([input_text], return_length=True)['length'][0] - - if query_len <= 28000: - print(f"Reading @ Qid {process['id']}", len(tokenizer(input_text, add_special_tokens=False)["input_ids"]), len([h for h in process["history"] if h["type"] == "documents"]), len([h for h in process["history"] if h["type"] == "act"]), flush=True) - queries.append(dict( - type="llm", - sampling=dict(stop=self.stop, max_new_tokens=31000-query_len), - query_len=query_len, - prompt=prompt, - )) - continue - - if "cache_gen_text" in process: - process.pop("cache_gen_text") - - if "text" in process["history"][-1]: - last_text = process["history"][-1]["text"] - if ("" in last_text and - last_text.strip().endswith("")): - if True: - query_text = last_text.split("")[-1].split("")[0].strip() - queries.append(dict( - type="search", - query=[query_text.strip()], - search_params=dict(topk=self.topk) - )) - continue - elif ("" in last_text and - last_text.strip().endswith("")): - query_text = last_text.split("")[-1].split("")[0] - queries.append(dict( - type="access", - urls=[query_text.strip()], - # search_params=dict(topk=self.topk) - )) - continue - - # input_text = "".join([h["text"] for h in process["history"]]) - history = "" - for idx, h in enumerate(process["history"]): - history += h.get("short_info_str", h.get("text", "")) - if len(history) > 25000: - history = history[-25000:] - - prompt = ASearcherReasoningPrompts.THINK_AND_ACT_PROMPT.format(question=process["question"], history=history) - input_text = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False) + process.get("cache_gen_text", "") - # print(f"Generate Act for Qid {process['id']}", len(tokenizer(input_text, add_special_tokens=False)["input_ids"]), len([h for h in process["history"] if h["type"] == "documents"]), len([h for h in process["history"] if h["type"] == "act"]), flush=True) - - if any([ - len([h for h in process["history"] if h["type"] == "documents"]) >= 20, - len([h for h in process["history"] if h["type"] == "act"]) >= self.force_turns, - process.get("phase", "search") == "answer", - ]): - process["phase"] = "answer" - print(f"Direct Generate Answer for Qid {process['id']}", len(tokenizer(input_text, add_special_tokens=False)["input_ids"]), len([h for h in process["history"] if h["type"] == "documents"]), len([h for h in process["history"] if h["type"] == "act"]), flush=True) - prompt = ASearcherReasoningPrompts.THINK_AND_ACT_PROMPT_v1.format(question=process["question"], history=history) - if self.force_valid: - prompt = prompt.replace('4. If you find information contradicting context of the question, you should point out that the question is invalid and the incorrect information in the question.', "4. You should find the most likely answer even when conflicting information is founded.") - input_text = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False) + process.get("cache_gen_text", "") - - # print("Query Input Length (llm):", process["id"], len(tokenizer(input_text, add_special_tokens=False)["input_ids"]), len([h for h in process["history"] if h["type"] == "documents"]), len([h for h in process["history"] if h["type"] == "act"]), flush=True) - if len(tokenizer(input_text, add_special_tokens=False)["input_ids"]) > 32000 or self.get_answer_from_text(process["history"][-1].get("text", "")) is not None: - print("process is done (1)", process["id"]) - process["running"] = False - continue - - query_len = tokenizer([input_text], return_length=True)['length'][0] - process["max_new_tokens"] = max(0, 31000 - query_len) - queries.append(dict( - type="llm", - sampling=dict(stop=self.stop, max_new_tokens=process.get("max_new_tokens", 4096)), - query_len=query_len, - prompt=prompt, - )) - process.pop("max_new_tokens") - - return queries - - def consume_responses(self, processes: List[Dict], queries: List[Dict], responses: List[Any]) -> List[Dict]: - i = 0 - for process in processes: - if process["running"]: - q, r = queries[i], responses[i] - - # print("consume response", process["id"], q["type"]) - - if q["type"] == "search": - if isinstance(r, list) and len(r) == 1: - r = r[0] - if isinstance(r, list) and isinstance(r[0], list): - assert all([isinstance(_r, list) and len(_r) == 1 for _r in r]), ([(type(_r) , len(_r)) for _r in r]) - r = [_r[0] for _r in r] - assert all(["documents" in _r and "server_type" in _r for _r in r]) - full_r = dict( - documents = [], - urls = [], - server_type = [], - ) - for _r in r: - assert isinstance(_r["server_type"], str) - if "online" in _r["server_type"]: - _r["documents"] = ["Google Search Results: " + doc for doc in _r["documents"]] - full_r["documents"].extend(_r["documents"]) - full_r["urls"].extend(_r["urls"]) - full_r["server_type"].extend([_r["server_type"]] * len(_r["documents"])) - r = full_r - if isinstance(r, dict) and 'documents' in r: - documents = r["documents"] - urls = r["urls"] - # server_types = r["server_type"] - # print(f"SearchR1RAGServer响应: {len(documents)}个文档") - - else: - documents = [] - urls = [] - # server_types = [] - - print(f"搜索结果文档数量: {len(documents)}") - - if len(documents) > 0: - doc_id_template = "[Doc {doc_id}]({url}):\n" - info_str = "\n\n\n" + "\n\n".join([doc_id_template.format(doc_id=str(k+1), url=url) + doc for k, (doc, url) in enumerate(zip(documents, urls))]) + "\n\n\n" - short_info_str = "\n\n" + "\n\n".join([doc_id_template.format(doc_id=str(k+1), url=url) + doc + "..." for k, (doc, url) in enumerate(zip(documents, urls))]) + "\n\n\n" - - process["history"].append(dict( - type="documents", - info_str=info_str, - short_info_str=short_info_str - )) - else: - process["history"].append(dict( - type="documents", - info_str= "\n\n\n" + "No Results Found." + "\n\n\n", - short_info_str="\n\n\n" + "No Results Found." + "\n\n\n" - )) - elif q['type'] == "access": - if isinstance(r, list): - r = r[0] - # process the webpage - if isinstance(r, dict) and 'page' in r and isinstance(r["page"], str) and len(r["page"]) > 0: - page = r["page"] - page = page[:250000] - if "page_cache" not in process: - process["page_cache"] = [] - process["page_cache"] = [] - while len(page) > 0 and len(process["page_cache"]) < 10: - _len = min(10000, len(page)) - process["page_cache"].append(f">>>> Page {len(process["page_cache"]) + 1} >>>>\n\n" + page[:_len]) - page = page[_len:] - print("[DEBUG] add page", process["id"], len(r["page"]), len(process["page_cache"]), flush=True) - - if "page_cache" in process and len(process["page_cache"]) > 0: - page = process["page_cache"].pop(0) - info_str = "\n\n" + page + "\n\n\n" - short_info_str = "\n\n\n" + page[:100] + "...\n\n" + "\n\n" - - process["history"].append(dict( - type="page", - info_str=info_str, - short_info_str=short_info_str - )) - - else: - page = "" - process["page_cache"] = [] - info_str = "\n\n\nNo More Information is Found for this URL.\n\n\n" - short_info_str = "\n\n\nNo More Information is Found for this URL.\n\n\n" - - process["history"].append(dict( - type="page", - info_str=info_str, - short_info_str=short_info_str - )) - - elif q["type"] == "llm": - if hasattr(r, 'stop_reason') and hasattr(r, 'text'): - generated_text = r.text - elif isinstance(r, dict): - generated_text = r.get('text', str(r)) - else: - generated_text = r - - if generated_text is None: - generated_text = "" - - raw_generated_text = generated_text - generated_text = process.get("cache_gen_text", "") + generated_text - - self.print_search_debug_info(generated_text) - - extracted_thought = self.get_thought_from_text(generated_text) - extracted_answer = self.get_answer_from_text(generated_text) - extracted_query = self.get_query_from_text(generated_text) - extracted_url = self.get_url_from_text(generated_text) - - # if the prompt is not asking to answer - if "" not in q["prompt"] and extracted_answer is not None: - print(f"Not time for producing answer for {process['id']}", extracted_answer, flush=True) - extracted_answer = None - - think_and_act = "" - if extracted_thought is not None: - think_and_act = think_and_act + extracted_thought - for act in [extracted_query, extracted_url, extracted_answer]: - if act is not None: - think_and_act = think_and_act.strip() + "\n\n" + act - break - - ### print(">>> THINK & ACT >>>\n", think_and_act, flush=True) - - if extracted_thought is not None: - process["history"].append(dict( - type="act", - full_reasoning_text = generated_text, - text=think_and_act.strip() - )) - if "cache_gen_text" in process: - process.pop("cache_gen_text") - - if "page_cache" in process and len(process["page_cache"]) > 0: - page = process["page_cache"].pop(0) - print(f"{process['id']} pop page cache: {[page[:100]]}") - info_str = "\n\n" + page + "\n\n\n" - short_info_str = "\n\n\n" + page[:100] + "...\n\n" + "\n\n" - - process["history"].append(dict( - type="page", - info_str=info_str, - short_info_str=short_info_str - )) - elif len(raw_generated_text) == 0: - process["cache_gen_text"] = "" - process["llm_gen_fail"] = process.get("llm_gen_fail", 0) + 1 - if process["llm_gen_fail"] > 32: - print("process is done (2)", process["id"], process["llm_gen_fail"]) - process["running"] = False - else: - if process["history"][-1]["type"] in ["page", "documents"]: - process["cache_gen_text"] = "" - process["history"].append(dict( - type="act", - full_reasoning_text = generated_text, - text="\n\n" - )) - process["llm_gen_fail"] = process.get("llm_gen_fail", 0) + 1 - process["page_cache"] = [] - else: - process["cache_gen_text"] = generated_text - # process["max_new_tokens"] = process.get("max_new_tokens", 2048) + 1024 - action_count = len([h for h in process["history"] if h["type"] == "act"]) - if action_count >= self.max_turns + 20 or "" in think_and_act: - print("process is done (3)", process["id"], action_count, self.max_turns, "" in think_and_act, flush=True) - process["running"] = False - - # print("[DEBUG] history length", process["id"], process["history"][-1]["type"], len(process["history"]), len(process.get("page_cache", [])), "page_cache" in process, len([h for h in process["history"] if h["type"] == "act"])) - - - i += 1 - - return processes - - def answers(self, processes: List[Dict]) -> List[str]: - - answers = [] - for process in processes: - if "pred_answer" not in process: - full_text = "".join( - [h["text"] for h in process["history"] if h["type"] != "prompt" and "text" in h] - ) - - if "" in full_text and "" in full_text: - answer = full_text.split("")[-1].split("")[0].strip() - else: - reasoning_text = "\n\n".join([h["full_reasoning_text"] for h in process["history"] if "full_reasoning_text" in h] + [process.get("cache_gen_text", "")]) - # find the last line metioning 'answer' - lines = reasoning_text.split("\n") - lines = [l for l in lines if 'answer' in l.lower()] - if len(lines) > 0: - answer = lines[-1] - else: - answer = reasoning_text.strip().split("")[-1].strip() - - process["pred_answer"] = answer - - answers.append(process["pred_answer"]) - - return answers - -from areal.experimental.openai import ArealOpenAI - -def parse_judge_result(raw_response): - # parse results - import json, ast - mbe = None - for parse_fn in [json.loads, ast.literal_eval]: - try: - mbe = parse_fn(raw_response.split("```json")[-1].split("```")[0].strip()) - break - except: - print(f"[WARNING] Error parsing {[raw_response]}") - if mbe is None and '"judgement": "incorrect"' in raw_response: - mbe = dict(judgement="incorrect") - if mbe is None and '"judgement": "correct"' in raw_response: - mbe = dict(judgement="correct") - if mbe is None: - print(f"[WARNING] Unknown judge result: {[raw_response]}") - mbe = dict(judgement="unknown") - score = float("judgement" in mbe and mbe["judgement"] == "correct") - return score - - -async def run_agent( - client: ArealOpenAI, - judge_client: ArealOpenAI, - tokenizer, - data, - toolbox, - max_turns: int = 128, - force_turns: int = 4, - topk: int = 10, - force_valid: bool = True, - max_tokens: int = 30000, - save_path: str | None = None, - rank: int = -1): - # Create client with AReaL engine and tokenizer - # client = ArealOpenAI(engine=rollout_engine, tokenizer=tokenizer) - - # Create ASearcher Reasoning Agent - agent = AReaLSearchReasoningAgentV1(max_turns=max_turns, - force_turns=force_turns, - topk=topk, - force_valid=force_valid) - - qid = data["id"] - process = dict(id=data["id"], - question=data["question"], - prompt=data["question"], - gt=data["answer"]) - - completions = [] - stats = dict( - turns=0, - num_search=0, - num_access=0, - score=0.0, - ) - cnt = 0 - while not agent.all_finished([process]): - cnt += 1 - print(f"Agent Loop: Qid={qid} rank={rank} cnt={cnt}", flush=True) - - # Prepare query - query = agent.prepare_queries(tokenizer, [process])[0] - - if query is None: - break - - response = None - if query["type"] == "llm": - # Use like standard OpenAI client - completion = await client.chat.completions.create( - messages=[{"role": "user", "content": query["prompt"]}], - temperature=1.0, - max_tokens=max_tokens, - max_completion_tokens=max(0, min(max_tokens, max_tokens - query["query_len"])), - ) - response = completion.choices[0].message.content - # print(f"Qid={qid} rank={rank} cnt={cnt} llm gen response: {[response]} query_len={query['query_len']} max_completion_tokens={max(0, min(max_tokens, max_tokens - query['query_len']))}") - completions.append(completion) - stats["turns"] += 1 - elif query["type"] == "search": - # Search - tool_call = f"{query['query'][0]}" - response = (await toolbox.step((data["id"], [tool_call])))[0] - stats["num_search"] += 1 - elif query["type"] == "access": - # Browsing - tool_call = f"{query['urls'][0]}" - response = (await toolbox.step((data["id"], [tool_call])))[0] - stats["num_access"] += 1 - - process = agent.consume_responses([process], [query], [response])[0] - - # Compute reward with LLM-as-Judge - # judge_client = ArealOpenAI(engine=rollout_engine, tokenizer=tokenizer) - judge_prompt_template = "You are an evaluation assistant. Please determine if the predicted answer is equivalent to the labeled answer.\n" \ - "You should first give your rationale for the judgement, and then give your judgement result (i.e., correct or incorrect).\n\n" \ - "\n" \ - "question: {question}\n" \ - "ground truth answers: {gt_answer}\n" \ - "pred_answer: {pred_answer}\n\n" \ - "Did the model give an answer **equivalent** to the labeled answer? \n\nThe output should in the following json format:\n" \ - "```json\n" \ - "{{\n" \ - """ "rationale": "your rationale for the judgement, as a text",\n""" \ - """ "judgement": "your judgement result, can only be 'correct' or 'incorrect'\n""" \ - "}}\n" \ - "```\n" \ - "Your output:" - pred_answer = agent.answers([process])[0] - ground_truth = data["answer"] - if isinstance(ground_truth, list) and len(ground_truth) == 1: - ground_truth = str(ground_truth[0]) - judge_prompt = judge_prompt_template.format(question=data["question"], gt_answer=str(ground_truth), pred_answer=pred_answer[:200]) - judge_completion = await judge_client.chat.completions.create( - messages=[{"role": "user", "content": judge_prompt}], - temperature=1.0, - max_tokens=8192, - max_completion_tokens=8192, - ) - judge_response = judge_completion.choices[0].message.content - reward = parse_judge_result(judge_response) - stats["score"] = reward - - # client.set_reward(completion.id, reward) - - print("LLM as Judge for Qid={}. GT={}. Ans={}. Result: MBE={}. Raw Response={}".format(data["id"], ground_truth, pred_answer, reward, judge_response[:500])) - - if save_path is not None: - import os, json, sys - if not os.path.exists(os.path.dirname(save_path)): - os.makedirs(os.path.dirname(save_path)) - json.dump(process, open(save_path, "w")) - - return completions, reward, stats diff --git a/astraEnv/ASearcher/ASearcher/train/search_agent.py b/astraEnv/ASearcher/ASearcher/train/search_agent.py deleted file mode 100644 index cfae18f..0000000 --- a/astraEnv/ASearcher/ASearcher/train/search_agent.py +++ /dev/null @@ -1,187 +0,0 @@ -import queue -import re -from dataclasses import dataclass, asdict -from typing import Dict, List, Tuple, Optional - -@dataclass -class Record: - type: str # prompt/llm_gen/search_results/webpage - text: str - token_ids: List[int] - # for webpage and search results - short_text: str = "" - # RL data - input_len: Optional[int] = None - input_tokens: Optional[List[int]] = None - output_len: Optional[int] = None - full_token_ids: Optional[List[int]] = None - output_tokens: Optional[List[int]] = None - output_logprobs: Optional[List[float]] = None - output_versions: Optional[List[int]] = None - - def to_dict(self): - return asdict(self) - -class AgentMemory: - def __init__(self, prompt, prompt_token_ids): - self.memory = [Record(type="prompt", text=prompt, token_ids=prompt_token_ids)] - - def llm_gen_count(self): - return sum([r.type == "llm_gen" for r in self.memory]) - - def filter_records(self, record_type): - return [r for r in self.memory if r.type == record_type] - - def prepare_prompt(self): - prompt = "" - for r in self.memory: - if r.type == "prompt": - prompt = r.text - elif r.type in ["search_results", "webpage"]: - prompt = prompt + "\n\n" + r.short_text + "\n\n" - elif r.type == "llm_gen": - prompt = prompt + r.text - else: - raise RuntimeError(f"Unknown record type: {r.type}") - return prompt - - def prepare_prompt_token_ids(self): - prompt_token_ids = [] - for r in self.memory: - prompt_token_ids += r.token_ids - return prompt_token_ids - - def add_record(self, r: Record): - self.memory.append(r) - - def logging_stats(self) -> Dict: - llm_gens = self.filter_records(record_type="llm_gen") - search_results = self.filter_records(record_type="search_results") - webpages = self.filter_records(record_type="webpage") - ret = dict( - num_llm_gens=len(llm_gens), - num_input_tokens=sum([len(r.input_tokens) for r in llm_gens]), - num_output_tokens=sum([len(r.output_tokens) for r in llm_gens]), - num_search_queries=len(search_results), - num_success_search_queries=len([r for r in search_results if "No search results are found" not in r.text]), - num_failed_search_queries=len([r for r in search_results if "No search results are found" in r.text]), - num_pages=len(webpages), - num_success_url_accesses=len([r for r in webpages if ">>>> Page 1 >>>>" in r.text]), - num_failed_url_accesses=len([r for r in webpages if ">>>> Page 1 >>>>" in r.text]), - ) - return ret - - def to_dict(self): - return [r.to_dict() for r in self.memory] - -class SearchAgent: - def __init__(self, prompt, prompt_token_ids): - self.prompt = prompt - self.memory = AgentMemory(prompt=prompt, prompt_token_ids=prompt_token_ids) - self.summary_job_queue = queue.Queue(128) - - @property - def num_turns(self): - return self.memory.llm_gen_count() - - @property - def is_finished(self): - pattern = r'(.*?)' - return any([len(re.findall(pattern, r.text, re.DOTALL)) > 0 for r in self.memory.filter_records("llm_gen")]) - - def add_summary_jobs(self, summary_jobs): - if not isinstance(summary_jobs, list): - summary_jobs = [summary_jobs] - for summary_job in summary_jobs: - assert (summary_job.get("type", "unkown") in ["search_results", "webpage"]), ("Unknown summary_job type: " + summary_job.get("type", "unknown")) - self.summary_job_queue.put_nowait(summary_job) - - def prepare_llm_query(self, tokenizer): - prompt_token_ids = self.memory.prepare_prompt_token_ids() - sampling_params = dict(stop=["", "", ""]) - if not self.summary_job_queue.empty(): - summary_job = self.summary_job_queue.get_nowait() - if summary_job["type"] in ["search_results", "webpage"]: - full_text = "\n\n" + summary_job["text"] + "\n\n" - short_text = "\n\n" + summary_job.get("short_text", summary_job["text"]) + "\n\n" - full_token_ids, short_token_ids = tokenizer([full_text, short_text], add_special_tokens=False)["input_ids"] - new_record = Record( - type=summary_job["type"], - text=full_text, - short_text=short_text, - token_ids=short_token_ids, - full_token_ids=full_token_ids, - ) - prompt_token_ids += full_token_ids - self.memory.add_record(new_record) - sampling_params["stop"] = [""] - return prompt_token_ids, sampling_params - - def consume_llm_response(self, resp, completion_text): - new_record = Record( - type="llm_gen", - text=completion_text, - token_ids=resp.output_tokens, - input_len=resp.input_len, - input_tokens=resp.input_tokens, - output_len=resp.output_len, - output_tokens=resp.output_tokens, - output_logprobs=resp.output_logprobs, - output_versions=resp.output_versions - ) - self.memory.add_record(new_record) - - tool_calls = [] - for pattern in [r'(.*?)', r'(.*?)', r'(.*?)']: - matches = re.findall(pattern, completion_text, re.DOTALL) - if matches: - match = matches[-1] - tool_calls.append(str(pattern.replace('(.*?)', match))) - - return tool_calls - - def consume_tool_response(self, res, topk=5): - # process the search results - if res["type"] == "search": - summary_job = dict(type="search_results") - - documents = res["documents"][:topk] - urls = res["urls"][:topk] - - if len(documents) > 0: - doc_id_template = "[Doc {doc_id}]({url}):\n" - text = "\n" + "\n\n".join([doc_id_template.format(doc_id=str(k+1), url=url) + doc[:5000] for k, (doc, url) in enumerate(zip(documents, urls))]) + "\n" - else: - text = "\nNo search results are found.\n" - - summary_job["text"] = text - self.add_summary_jobs(summary_job) - - # process the webpage - elif res["type"] == "access": - summary_jobs = [] - page = res["page"] - if page is not None and page.strip() != "": - page = page[:250000] - while len(page) > 0 and len(summary_jobs) < 10: - _len = min(25000, len(page)) - summary_jobs.append(dict( - type="webpage", - text=f"\n>>>> Page {len(summary_jobs) + 1} >>>>\n\n" + page[:_len] + "\n", - short_text=f"\n>>>> Page {len(summary_jobs) + 1} >>>>\n\n" + page[:100] + "\n", - )) - page = page[_len:] - else: - summary_jobs.append(dict( - type="webpage", - text="\nNo More Information is Found for this URL.\n", - )) - self.add_summary_jobs(summary_jobs) - - def get_answer(self): - text = self.memory.prepare_prompt() - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return matches[-1].strip() - return None \ No newline at end of file diff --git a/astraEnv/ASearcher/ASearcher/utils/rewards.py b/astraEnv/ASearcher/ASearcher/utils/rewards.py deleted file mode 100644 index 9d44482..0000000 --- a/astraEnv/ASearcher/ASearcher/utils/rewards.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import string -import random - -def normalize_answer(s): - def remove_articles(text): - return re.sub(r"\b(a|an|the)\b", " ", text) - - def white_space_fix(text): - return " ".join(text.split()) - - def remove_punc(text): - exclude = set(string.punctuation) - return "".join(ch for ch in text if ch not in exclude) - - def lower(text): - return text.lower() - - return white_space_fix(remove_articles(remove_punc(lower(s)))) - -def bool_mapping(s): - if s == "True": - return "yes" - elif s == "False": - return "no" - else: - return s - -def contains_chinese(text): - """ - Check if the given text contains Chinese characters. - Returns True if any Chinese character is found, False otherwise. - """ - for char in text: - # Check for common Chinese characters (CJK Unified Ideographs) - if '\u4e00' <= char <= '\u9fff': - return True - # Check for rare characters (CJK Unified Ideographs Extension A) - if '\u3400' <= char <= '\u4dbf': - return True - # Check for compatibility characters - if '\uf900' <= char <= '\ufaff': - return True - # Check for extensions B-F (requires surrogate pairs in Python) - # Note: This part handles supplementary characters (needed for Python < 3.3) - if len(char) > 1: # Surrogate pair - code = ord(char[0]) << 16 | ord(char[1]) - if (0x20000 <= code <= 0x2a6df or # Extension B - 0x2a700 <= code <= 0x2b73f or # Extension C - 0x2b740 <= code <= 0x2b81f or # Extension D - 0x2b820 <= code <= 0x2ceaf or # Extension E - 0x2ceb0 <= code <= 0x2ebef): # Extension F - return True - return False - -def em_check(prediction, golden_answers): - if isinstance(golden_answers, str): - golden_answers = [golden_answers] - normalized_prediction = normalize_answer(bool_mapping(prediction)) - score = 0 - for golden_answer in golden_answers: - golden_answer = normalize_answer(bool_mapping(golden_answer)) - if golden_answer == normalized_prediction: - score = 1 - break - return score - - -def subem_check(prediction, golden_answers): - if isinstance(golden_answers, str): - golden_answers = [golden_answers] - normalized_prediction = normalize_answer(bool_mapping(prediction)) - score = 0 - for golden_answer in golden_answers: - golden_answer = normalize_answer(bool_mapping(golden_answer)) - if golden_answer in normalized_prediction: - score = 1 - break - return score - - -def extract_solution(solution_str): - """Extract the equation from the solution string.""" - # Remove everything before the first "Assistant:" - # if "Assistant:" in solution_str: - # solution_str = solution_str.split("Assistant:", 1)[1] - # elif "<|im_start|>assistant" in solution_str: - # solution_str = solution_str.split("<|im_start|>assistant", 1)[1] - # else: - # return None - # solution_str = solution_str.split('\n')[-1] - - answer_pattern = r'(.*?)' - match = re.finditer(answer_pattern, solution_str, re.DOTALL) - matches = list(match) - - # If there are 0 or exactly 1 matches, return None - if len(matches) <= 0: #1: - return None - - # If there are 2 or more matches, return the last one - return matches[-1].group(1).strip() - - -def compute_score_em(solution_str, ground_truth, method='strict', format_score=0., score=1.): - - if isinstance(ground_truth, list): - answer = extract_solution(solution_str=solution_str) - return answer, max([compute_score_em(solution_str, g)[1] for g in ground_truth]) - - answer = extract_solution(solution_str=solution_str) - - if answer is None: - return None, 0 - else: - if em_check(answer, ground_truth): - return answer, score - else: - return answer, format_score - - -def compute_score_subem(solution_str, ground_truth, method='strict', format_score=0., score=1.): - """The scoring function for substring exact match (EM). - - Args: - solution_str: the solution text - ground_truth: the ground truth - method: the method to extract the solution, choices are 'strict' and 'flexible' - format_score: the score for the format - score: the score for the correct answer - """ - answer = extract_solution(solution_str=solution_str) - do_print = random.randint(1, 64) == 1 - - if do_print: - print(f"--------------------------------") - print(f"Golden answers: {ground_truth['target']}") - print(f"Extracted answer: {answer}") - print(f"Solution string: {solution_str}") - - if answer is None: - return 0 - else: - if subem_check(answer, ground_truth['target']): - return score - else: - return format_score - -def normalize_text(text: str) -> str: - """预处理文本,用于NQ数据集的评分 - - 处理步骤: - 1. 转换为小写 - 2. 移除标点符号 (.,!?;:'"()[]{}...) - 3. 去除多余空格 - """ - # 将标点符号替换为空格 - for punct in string.punctuation: - text = text.replace(punct, ' ') - - # 替换多个空格为单个空格 - text = re.sub(r'\s+', ' ', text) - - # 去除首尾空格 - text = text.strip().lower() - return text - -def f1_score(answer_content, gt): - answer_content = normalize_text(bool_mapping(answer_content)) - gt = normalize_text(bool_mapping(gt)) - - # 将答案和参考答案分词 - if contains_chinese(gt): - def parse_chinese_str(s): - # parse consecutive numbers - numbers = [] - for i, c in enumerate(s): - if c.isdigit(): - if i > 0 and s[i-1].isdigit(): - numbers[-1] = numbers[-1] + c - else: - numbers.append(c) - for c in "0123456789,。 ,.-": - s = s.replace(c, "") - s = set(list(s) + numbers) - return s - pred_tokens = parse_chinese_str(answer_content) - gt_tokens = parse_chinese_str(gt) - else: - pred_tokens = set(answer_content.split()) - gt_tokens = set(gt.split()) - - if not gt_tokens: # 避免除零错误 - return 0 - if not pred_tokens: - return 0 - - # 计算共同的词数 - common_tokens = pred_tokens & gt_tokens - - # 计算精确率和召回率 - precision = len(common_tokens) / len(pred_tokens) if pred_tokens else 0 - recall = len(common_tokens) / len(gt_tokens) if gt_tokens else 0 - - # 计算F1分数 - f1 = 0 - if precision + recall > 0: # 避免除零错误 - f1 = 2 * (precision * recall) / (precision + recall) - - return f1 - - -def compute_score_f1(solution_str, ground_truth, method='strict', format_score=0., score=1.): - if isinstance(ground_truth, list): - answer = extract_solution(solution_str=solution_str) - return answer, max([compute_score_f1(solution_str, g)[1] for g in ground_truth]) - - answer = extract_solution(solution_str=solution_str) - - if answer is None: - return None, 0 - else: - ret_score = f1_score(answer, ground_truth) - return answer, ret_score - -def cover_exact_match_score_1(solution_str, ground_truth): - if isinstance(ground_truth, list): - answer = extract_solution(solution_str=solution_str) - return answer, max([cover_exact_match_score_1(solution_str, g)[1] for g in ground_truth]) - - answer = extract_solution(solution_str=solution_str) - - if answer is None: - return None, 0 - - pre_list = normalize_answer(bool_mapping(answer)).split(" ") - ground_list = normalize_answer(bool_mapping(ground_truth)).split(" ") - # print("prediction: ",prediction) - # print("ground_truth: ",ground_truth) - # print("pre_list: ",pre_list) - # print("ground_list: ",ground_list) - # 不考虑顺序和连续 - return answer, float(all(ground in pre_list for ground in ground_list)) - -def correct_format_fn(idx, s): - correct = all( - [ - s.count("") == s.count(""), - s.count("") == s.count(""), - s.count("") == s.count(""), - s.count("") + s.count("") + s.count("") <= 1, - # s.count("") == s.count("") == s.count("<|begin_of_documents|>") == s.count("<|end_of_documents|>") == 0, - s.count("Assistant") == s.count("assistant") == 0, - s.count("") <= 1, - # (s.strip().endswith("") or s.strip().endswith("") or s.strip().endswith("") or s.strip().endswith("")), - ] - ) - return correct \ No newline at end of file diff --git a/astraEnv/ASearcher/ASearcher/utils/search_tool.py b/astraEnv/ASearcher/ASearcher/utils/search_tool.py deleted file mode 100644 index 069b8af..0000000 --- a/astraEnv/ASearcher/ASearcher/utils/search_tool.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2025 Ant Group Inc. -import json -from typing import List, Tuple - -try: - from realhf.base import logging -except Exception: - from astraflow.workflow.utils import logging -from astraEnv.ASearcher.ASearcher.utils.rewards import compute_score_em, compute_score_f1 -from astraEnv.ASearcher.ASearcher.utils.search_utils import make_search_client - -logger = logging.getLogger("Search ToolBox") - -def load_metadata(dataset_path): - data=[json.loads(ff) for ff in open(dataset_path)] - for i, d in enumerate(data): - if "idx" in d: - d["idx"] = str(d["idx"]) - elif "qid" in d: - d["idx"] = str(d["qid"]) - elif "id" in d: - d["idx"] = str(d["id"]) - elif "_id" in d: - d["idx"] = str(d["_id"]) - elif "query_id" in d: - d["idx"] = str(d["query_id"]) - else: - d["idx"] = str(i) - id2info = {d["idx"]: d for d in data} - return id2info - - -class SearchToolBox: - def __init__(self, dataset_path: str, reward_type: str = "F1", topk:int = 10, search_client_type: str = "async-online-search-access", use_jina=False): - self.id2info = load_metadata(dataset_path) - self.reward_type = reward_type - self.topk = topk - - # search server - self.use_jina = use_jina - self.search_client_type = search_client_type - self.search_client = make_search_client(search_client_type, use_jina=self.use_jina) - - async def step(self, qid_actions: Tuple[str, List[str]]): - qid, actions = qid_actions - - results = [] - for action in actions: - result = dict(documents=None, score=None, ground_truth=None, type=None) - - # tool calling - if "" in action and "" in action: - query = action.split("")[-1].split("")[0].strip() - req_meta = { - "queries": [query], - "topk": self.topk, - "return_scores": False - } - - # send search query to server - response = await self.search_client.query_async(req_meta) - - documents = response[0]["documents"] - urls = response[0]["urls"] - - result["documents"] = documents - result["urls"] = urls - result["type"] = "search" - elif "" in action and "" in action: - url = action.split("")[-1].split("")[0].strip() - - # send wepage access request - response = await self.search_client.access_async([url]) - - page = None - - if self.search_client_type == "async-online-search-access": - if self.use_jina: - page = response[0].get("page", "") - else: - # process webpage - page = self.process_webpage(response[0].get("page", "")) - elif self.search_client_type == "async-search-access": - if response["result"][0] is None: - page = None - else: - page = response["result"][0]["contents"] - - result["page"] = page - result["type"] = "access" - - # compute rewards - ground_truth = self.id2info[qid.split("@")[0]]["answer"] - if isinstance(ground_truth, list) or isinstance(ground_truth, tuple): - ground_truth = [str(gt) for gt in ground_truth] - else: - ground_truth = str(ground_truth) - - ground_truth_aug = None - if "aug_answer" in self.id2info[qid.split("@")[0]] and len(self.id2info[qid.split("@")[0]]["aug_answer"]) > 0: - ground_truth_aug = self.id2info[qid.split("@")[0]]["aug_answer"] - if isinstance(ground_truth_aug, list) or isinstance(ground_truth_aug, tuple): - ground_truth_aug = [str(gt) for gt in ground_truth_aug] - else: - ground_truth_aug = str(ground_truth_aug) - - if self.reward_type == "F1": - extracted, score = compute_score_f1(action, ground_truth, method="strict") - elif self.reward_type == "EM": - extracted, score = compute_score_em(action, ground_truth, method="strict") - if ground_truth_aug is not None: - if self.reward_type == "F1": - _, score_aug = compute_score_f1(action, ground_truth_aug, method="strict") - elif self.reward_type == "EM": - _, score_aug = compute_score_em(action, ground_truth_aug, method="strict") - - result["extracted"] = extracted - result["score"] = score - result["ground_truth"] = self.id2info[qid.split("@")[0]]["answer"] - - if ground_truth_aug is not None: - score_aug = max(score_aug, score) - result["score"] = score * 0.7 + score_aug * 0.3 - result["ground_truth_aug"] = ground_truth_aug - - # if extracted is not None: - # logger.info("F1 Score={:.2f}. Extracted='{}'. Ground Truth='{}'. Qid={}. Question='{}'".format(score, extracted, ground_truth, qid.split("@")[0], self.id2info[qid.split("@")[0]]["question"])) - - results.append(result) - return results - - def process_webpage(self, content): - keys = [("title", "title"), ("p", "p"), ("li", "li", lambda c: "\n" not in c), ("td", "td"), ("tr", "tr")] - content_list = [] - init_length = len(content) - while any([f"<{k[0]}" in content and f"" in content for k in keys]): - klr = [] - for k in keys: - start = 0 - # print(k) - while True: - ls = [content[start:].find(f"<{k[0]}{c}") for c in [">", " "]] - ls = [l for l in ls if l != -1] - l = -1 if len(ls) == 0 else min(ls) - # print(ls) - if l == -1: - break - l += start - r = content[l:].find(f"") - if r == -1: - break - if (len(k) <= 2) or (len(k) >= 3 and k[2](content[l:l+r])): - # print(k, l, l+r) - klr.append((k, l, l+r)) - break - start = l + r - - if len(klr) == 0: - break - klr = sorted(klr, key=lambda x:x[1]) - k, l, r = klr[0] - content_list.append(content[l:r+len(f"")]) - # print(content_list[-1]) - # input("stop...") - if k[0] == "p": - content_list[-1] += "\n\n" - elif k[0] == "li": - content_list[-1] += "\n" - content = content[r:] - content = "".join(content_list) - final_length = len(content) - logger.info(f"process the webpage: {init_length} -> {final_length}. {content[:100]}") - return content diff --git a/astraEnv/ASearcher/ASearcher/utils/search_utils.py b/astraEnv/ASearcher/ASearcher/utils/search_utils.py deleted file mode 100644 index 954bf60..0000000 --- a/astraEnv/ASearcher/ASearcher/utils/search_utils.py +++ /dev/null @@ -1,433 +0,0 @@ -import requests -import random -import time -import json -import asyncio -import html -import os -from typing import Dict, Any, List - -import aiohttp -import asyncio -from typing import Dict, List, Any - -try: - from astraEnv.ASearcher.tools.web_browser import WebPageCache - # from utils.web_browser import WebPageCache - WEBPAGECACHE_AVAILABLE = True -except ImportError as e: - print(f"[WARNING] import web browser error: {e}") - WEBPAGECACHE_AVAILABLE = False - WebPageCache = None - - - -SERPER_STATS = dict( - num_requests = 0 -) - -class AsyncSearchBrowserClient: - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.session = None - self.server_list = self.get_server_list() - if not self.server_list: - rag_dir = os.environ.get("RAG_SERVER_ADDR_DIR", "") - raise RuntimeError( - "No RAG servers found for async-search-access. " - f"RAG_SERVER_ADDR_DIR={rag_dir!r}, cwd={os.getcwd()!r}, " - f"pattern={rag_dir + '/Host*_IP*.txt'!r}" - ) - self.server_addr = random.choice(self.server_list) - # print(self.server_list) - - def get_server_list(self): - import glob - rag_server_addr_dir = os.environ.get("RAG_SERVER_ADDR_DIR", "") - - server_list = [] - - filenames = glob.glob(rag_server_addr_dir + "/Host*_IP*.txt") - for filename in filenames: - try: - with open(filename) as f: - server_list.extend( - [line.strip() for line in f.readlines() if line.strip()] - ) - except: - continue - return server_list - - async def query_async(self, req_meta: Dict[str, Any]) -> List[Dict]: - cnt = 0 - last_exception = None - while cnt < 10: - try: - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://{self.server_addr}/retrieve", - json=req_meta, - timeout=aiohttp.ClientTimeout(total=120, sock_connect=120) - ) as response: - response.raise_for_status() - res = await response.json() - return [ - dict( - documents=[r["contents"] for r in result], - urls=[r["url"] for r in result], - server_type="async-search-browser", - ) for result in res["result"] - ] - except Exception as e: - print("query_async", self.server_list, e.__class__.__name__, e.__cause__) - last_exception = e - self.server_list = self.get_server_list() - if not self.server_list: - raise RuntimeError( - "RAG server list became empty while retrying query_async. " - f"RAG_SERVER_ADDR_DIR={os.environ.get('RAG_SERVER_ADDR_DIR', '')!r}" - ) from e - self.server_addr = random.choice(self.server_list) - print(f"Search Engine switched to {self.server_addr}") - cnt += 1 - await asyncio.sleep(10) - - raise RuntimeError("Fail to post search query to RAG server") from last_exception - - async def access_async(self, urls: List[str]) -> List[Dict]: - cnt = 0 - last_exception = None - while cnt < 10: - try: - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://{self.server_addr}/access", - json={"urls": urls}, - timeout=aiohttp.ClientTimeout(total=120, sock_connect=120) - ) as response: - response.raise_for_status() - res = await response.json() - return [ - dict( - page=result["contents"] if result is not None else "", - type="access", - server_type="async-search-browser", - ) for result in res["result"] - ] - except Exception as e: - print("access_async", self.server_list, e) - last_exception = e - self.server_list = self.get_server_list() - if not self.server_list: - raise RuntimeError( - "RAG server list became empty while retrying access_async. " - f"RAG_SERVER_ADDR_DIR={os.environ.get('RAG_SERVER_ADDR_DIR', '')!r}" - ) from e - self.server_addr = random.choice(self.server_list) - print(f"Search Engine switched to {self.server_addr}") - cnt += 1 - await asyncio.sleep(10) - - raise RuntimeError("Fail to post access request to RAG server") from last_exception - -class AsyncOnlineSearchClient: - - _search_semaphore = None - _access_semaphore = None - - @classmethod - def _get_search_semaphore(cls): - - if cls._search_semaphore is None: - cls._search_semaphore = asyncio.Semaphore(20) - return cls._search_semaphore - - @classmethod - def _get_access_semaphore(cls): - - if cls._access_semaphore is None: - cls._access_semaphore = asyncio.Semaphore(10) - return cls._access_semaphore - - def __init__(self, enable_cache: bool = True, cache_size: int = 10000, cache_file: str = "../webpage_cache.json", - use_jina: bool = False, jina_api_key: str = None, wrapper_format: bool = True): - # Serper API - self.serper_server_addr = "https://google.serper.dev" - self.serper_api_key = os.environ.get('SERPER_API_KEY', '') - if not self.serper_api_key: - raise RuntimeError("Serper API key is not set. Please configure it in config.yaml or set the SERPER_API_KEY environment variable.") - self.serper_headers = { - 'X-API-KEY': self.serper_api_key, - 'Content-Type': 'application/json' - } - self.max_workers = 10 - - self.max_retries = 3 - self.retry_delay = 1.0 - self.backoff_factor = 2.0 - - self.wrapper_format = wrapper_format - - self.use_jina = use_jina - - self.jina_api_key = jina_api_key or os.environ.get('JINA_API_KEY', '') - if self.use_jina and not self.jina_api_key: - raise RuntimeError("Jina is enabled but the API key is not set. Please configure it in config.yaml or set the JINA_API_KEY environment variable.") - - if enable_cache and WEBPAGECACHE_AVAILABLE: - self.webpage_cache = WebPageCache(cache_size, cache_file, save_interval=5) - else: - self.webpage_cache = None - - async def _jina_readpage_async(self, session, url: str) -> str: - """ - Read webpage content using Jina service asynchronously. - - Args: - session: aiohttp ClientSession - url: The URL to read - - Returns: - str: The webpage content or error message - """ - try: - headers = { - 'Authorization': f'Bearer {self.jina_api_key}', - 'Content-Type': 'application/json', - } - - async with session.get(f'https://r.jina.ai/{url}', headers=headers, timeout=aiohttp.ClientTimeout(total=30)) as response: - if response.status == 200: - content = await response.text() - return content - else: - return f"[visit] Failed to read page. Status code: {response.status}" - - except Exception as e: - return f"[visit] Failed to read page. Error: {str(e)}" - - async def query_async(self, req_meta): - - import aiohttp - - queries = req_meta.get("queries", []) - topk = req_meta.get("topk", 5) - - if not queries: - return [] - - async def single_serper_query_async(session, query: str, topk: int) -> dict: - - query = query[:2000] - async with self._get_search_semaphore(): - payload = { - "q": query, - "num": topk - } - - for attempt in range(4): - try: - if attempt > 0: - delay = 1.0 * (2 ** (attempt - 1)) # 1s, 2s, 4s - await asyncio.sleep(delay) - - - await asyncio.sleep(0.1) - - SERPER_STATS["num_requests"] += 1 - print("Serper Stats: ", json.dumps(SERPER_STATS)) - - async with session.post( - f"{self.serper_server_addr}/search", - headers=self.serper_headers, - json=payload, - timeout=aiohttp.ClientTimeout(total=30) - ) as response: - if response.status == 200: - data = await response.json() - if attempt > 0: - print(f"[INFO] AsyncOnlineSearchClient: Query succeeded on retry {attempt}") - return { - "success": True, - "data": data - } - else: - - response_text = await response.text() - error_msg = f"HTTP {response.status}: {response_text[:100]}" - print(f"[WARNING] AsyncOnlineSearchClient: HTTP error (attempt {attempt + 1}): {error_msg}") - if attempt == 3: - return { - "success": False, - "error": error_msg - } - - except Exception as e: - error_msg = f"{type(e).__name__}: {str(e)[:100]}" - print(f"[WARNING] AsyncOnlineSearchClient: Error (attempt {attempt + 1}): {error_msg}") - if attempt == 3: - return { - "success": False, - "error": error_msg - } - - return { - "success": False, - "error": "Unknown error after all retries" - } - - async with aiohttp.ClientSession() as session: - tasks = [single_serper_query_async(session, query, topk) for query in queries] - serper_results = await asyncio.gather(*tasks) - - formatted_results = [] - for query, serper_result in zip(queries, serper_results): - query_results = [] - - if serper_result and serper_result.get("success", False): - data = serper_result.get("data", {}) - organic_results = data.get("organic", [])[:topk] - - for result in organic_results: - query_results.append({ - "title": result.get("title", ""), - "url": result.get("link", ""), - "snippet": result.get("snippet", ""), - "server_type": "async-online-search", - }) - else: - error = serper_result.get("error", "Unknown error") if serper_result else "No response" - print(f"[ERROR] AsyncOnlineSearchClient: Search failed for '{query}': {error}") - - formatted_results.append(query_results) - - if self.wrapper_format: - first_query_results = formatted_results[0] if formatted_results else [] - return [{ - "documents": [result.get("title", "") + " " + result.get("snippet", "") for result in first_query_results], - "urls": [result.get("url", "") for result in first_query_results], - "server_type": "async-online-search" - }] - else: - if len(queries) == 1: - return formatted_results[0] # return [{...}, {...}] rather than [[{...}, {...}]] - else: - return formatted_results # return [[...], [...]] - - async def access_async(self, urls): - - if not urls: - return [] - - results = [] - urls_to_fetch = [] - - for url in urls: - if self.webpage_cache and self.webpage_cache.has(url): - cached_content = self.webpage_cache.get(url) - if cached_content: - results.append(dict(page=cached_content, type="access")) - else: - urls_to_fetch.append(url) - results.append(None) - else: - urls_to_fetch.append(url) - results.append(None) - # print(results) - - if urls_to_fetch: - if self.use_jina and self.jina_api_key: - try: - async with self._get_access_semaphore(): - fetched_results = await self._access_urls_jina_async(urls_to_fetch) - - fetch_index = 0 - for i, result in enumerate(results): - if result is None: - if fetch_index < len(fetched_results): - fetched_result = fetched_results[fetch_index] - results[i] = fetched_result - - if self.webpage_cache and fetched_result.get("page"): - self.webpage_cache.put(urls[i], fetched_result["page"]) - - fetch_index += 1 - else: - results[i] = dict(page="", type="access") - - except Exception as e: - for i, result in enumerate(results): - if result is None: - results[i] = dict(page="", type="access") - else: - for i, result in enumerate(results): - if result is None: - results[i] = dict(page="", type="access") - - for result in results: - if result is not None: - result["server_type"] = "async-online-search" - # print(results) - return results - - async def _access_urls_jina_async(self, urls): - results = [] - - try: - import aiohttp - async with aiohttp.ClientSession() as session: - for url in urls: - content = await self._jina_readpage_async(session, url) - if content and not content.startswith("[visit] Failed"): - results.append(dict(page=content, type="access")) - else: - results.append(dict(page="", type="access")) - - except Exception as e: - results = [dict(page="", type="access") for _ in urls] - - for r in results: - if len(r["page"]) > 0: - r["type"] = "jina" - - return results - - - - def get_cache_stats(self): - if self.webpage_cache: - return self.webpage_cache.get_stats() - else: - return {"cache_disabled": True} - - def clear_cache(self): - if self.webpage_cache: - self.webpage_cache.clear() - - def force_save_cache(self): - if self.webpage_cache: - self.webpage_cache.force_save() - - - -SEARCH_CLIENTS = { - "async-search-access": AsyncSearchBrowserClient, - "async-online-search-access": AsyncOnlineSearchClient, -} - - -def make_search_client(search_client_type: str, use_jina: bool = False, jina_api_key: str = None): - if search_client_type == "async-online-search": - return SEARCH_CLIENTS[search_client_type](use_jina=use_jina, jina_api_key=jina_api_key) - elif search_client_type == "async-online-search-access": - return SEARCH_CLIENTS[search_client_type](use_jina=use_jina, jina_api_key=jina_api_key, wrapper_format=True) - else: - return SEARCH_CLIENTS[search_client_type]() - - -if __name__ == "__main__": - search_client = AsyncOnlineSearchClient(use_jina=True) - url_response = asyncio.run(search_client.access_async(["https://en.wikipedia.org/w/index.php?title=Beveridge,%20Victoria"])) - print(url_response[0]['page'][:1000]) - - exit(0) diff --git a/astraEnv/ASearcher/ASearcher/utils/web_browser.py b/astraEnv/ASearcher/ASearcher/utils/web_browser.py deleted file mode 100644 index 5ab971e..0000000 --- a/astraEnv/ASearcher/ASearcher/utils/web_browser.py +++ /dev/null @@ -1,164 +0,0 @@ -# Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource! -# https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py -import atexit -from collections import OrderedDict -import hashlib -import json -import os -import threading -import time -from typing import Any, Dict, Optional - - -class WebPageCache: - - def __init__(self, max_size: int = 100000, cache_file: str = "./webpage_cache.json", save_interval: int = 10): - self.max_size = max_size - self.cache_file = cache_file - self.cache = OrderedDict() - self.lock = threading.Lock() - self.stats = {"hits": 0, "misses": 0, "evictions": 0} - self.save_interval = save_interval - self.operations_since_save = 0 - - self.load_from_file() - - atexit.register(self.save_to_file) - - def _generate_cache_key(self, url: str) -> str: - return hashlib.md5(url.encode()).hexdigest() - - def put(self, url: str, content: str): - if not url or not content: - return - - cache_key = self._generate_cache_key(url) - - with self.lock: - if cache_key in self.cache: - del self.cache[cache_key] - - while len(self.cache) >= self.max_size: - self.cache.popitem(last=False) - self.stats["evictions"] += 1 - - self.cache[cache_key] = { - "url": url, - "content": content, - "timestamp": time.time() - } - - self.operations_since_save += 1 - if self.operations_since_save >= self.save_interval: - self.operations_since_save = 0 - import threading - threading.Thread(target=self._background_save, daemon=True).start() - - def get(self, url: str) -> Optional[str]: - cache_key = self._generate_cache_key(url) - - with self.lock: - if cache_key in self.cache: - # 移动到末尾(最近使用) - entry = self.cache.pop(cache_key) - self.cache[cache_key] = entry - self.stats["hits"] += 1 - return entry["content"] - else: - self.stats["misses"] += 1 - return None - - def has(self, url: str) -> bool: - cache_key = self._generate_cache_key(url) - with self.lock: - return cache_key in self.cache - - def clear(self): - with self.lock: - self.cache.clear() - self.stats = {"hits": 0, "misses": 0, "evictions": 0} - self.operations_since_save = 0 - - def force_save(self): - self.save_to_file() - self.operations_since_save = 0 - - def get_stats(self) -> Dict[str, Any]: - with self.lock: - total_requests = self.stats["hits"] + self.stats["misses"] - hit_rate = self.stats["hits"] / total_requests if total_requests > 0 else 0 - - return { - "cache_size": len(self.cache), - "max_size": self.max_size, - "hits": self.stats["hits"], - "misses": self.stats["misses"], - "evictions": self.stats["evictions"], - "hit_rate": hit_rate, - "total_requests": total_requests - } - - def _background_save(self): - try: - self.save_to_file() - except Exception as e: - print(f"[ERROR] WebPageCache: Background save failed: {e}") - - def save_to_file(self): - try: - with self.lock: - ordered_cache = [] - for key, value in self.cache.items(): - ordered_cache.append((key, value)) - - cache_data = { - "cache_ordered": ordered_cache, - "stats": self.stats, - "max_size": self.max_size, - "saved_at": time.time() - } - - with open(self.cache_file, 'w', encoding='utf-8') as f: - json.dump(cache_data, f, indent=2, ensure_ascii=False) - - print(f"[DEBUG] WebPageCache: Saved {len(self.cache)} entries to {self.cache_file}") - - except Exception as e: - print(f"[ERROR] WebPageCache: Failed to save cache to {self.cache_file}: {e}") - - def load_from_file(self): - """从JSON文件加载缓存""" - if not os.path.exists(self.cache_file): - print(f"[DEBUG] WebPageCache: No existing cache file {self.cache_file}, starting fresh") - return - - try: - with open(self.cache_file, 'r', encoding='utf-8') as f: - cache_data = json.load(f) - - with self.lock: - if "cache_ordered" in cache_data: - ordered_cache = cache_data["cache_ordered"] - self.cache = OrderedDict(ordered_cache) - print(f"[DEBUG] WebPageCache: Loaded ordered cache format") - else: - loaded_cache = cache_data.get("cache", {}) - self.cache = OrderedDict(loaded_cache) - print(f"[DEBUG] WebPageCache: Loaded legacy cache format (LRU order may be lost)") - - self.stats = cache_data.get("stats", {"hits": 0, "misses": 0, "evictions": 0}) - - while len(self.cache) > self.max_size: - self.cache.popitem(last=False) - self.stats["evictions"] += 1 - - saved_at = cache_data.get("saved_at", 0) - saved_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(saved_at)) - - print(f"[DEBUG] WebPageCache: Loaded {len(self.cache)} entries from {self.cache_file} (saved at {saved_time})") - - except Exception as e: - print(f"[ERROR] WebPageCache: Failed to load cache from {self.cache_file}: {e}") - with self.lock: - self.cache = OrderedDict() - self.stats = {"hits": 0, "misses": 0, "evictions": 0} diff --git a/astraEnv/ASearcher/agent/__init__.py b/astraEnv/ASearcher/agent/__init__.py deleted file mode 100644 index 7d2dd62..0000000 --- a/astraEnv/ASearcher/agent/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from agent.search_r1 import SearchR1Agent -from agent.asearcher_reasoning import AsearcherReasoningAgent -from agent.asearcher import AsearcherAgent - - -AGENT_CATEGORY = { - "search-r1": SearchR1Agent, - "asearcher-reasoning": AsearcherReasoningAgent, - "asearcher": AsearcherAgent, -} - -def make_agent(agent_type): - return AGENT_CATEGORY[agent_type]() \ No newline at end of file diff --git a/astraEnv/ASearcher/agent/asearcher.py b/astraEnv/ASearcher/agent/asearcher.py deleted file mode 100644 index f001f5a..0000000 --- a/astraEnv/ASearcher/agent/asearcher.py +++ /dev/null @@ -1,192 +0,0 @@ -import queue -import re -from dataclasses import dataclass, asdict -from typing import Dict, List, Tuple, Optional - -@dataclass -class Record: - type: str # prompt/llm_gen/search_results/webpage - text: str - # for webpage and search results - short_text: str = "" - # RL data - input_len: Optional[int] = None - input_tokens: Optional[List[int]] = None - output_len: Optional[int] = None - output_tokens: Optional[List[int]] = None - output_logprobs: Optional[List[float]] = None - output_versions: Optional[List[int]] = None - - def to_dict(self): - return asdict(self) - -class AgentMemory: - def __init__(self, prompt): - self.memory = [Record(type="prompt", text=prompt)] - - def llm_gen_count(self): - return sum([r.type == "llm_gen" for r in self.memory]) - - def filter_records(self, record_type): - return [r for r in self.memory if r.type == record_type] - - def prepare_prompt(self): - prompt = "" - for r in self.memory: - if r.type == "prompt": - prompt = r.text - elif r.type in ["search_results", "webpage"]: - prompt = prompt + "\n\n" + r.short_text + "\n\n" - elif r.type == "llm_gen": - prompt = prompt + r.text - else: - raise RuntimeError(f"Unknown record type: {r.type}") - return prompt - - def add_record(self, r: Record): - self.memory.append(r) - - def logging_stats(self) -> Dict: - llm_gens = self.filter_records(record_type="llm_gen") - search_results = self.filter_records(record_type="search_results") - webpages = self.filter_records(record_type="webpage") - ret = dict( - num_llm_gens=len(llm_gens), - num_input_tokens=sum([len(r.input_tokens) for r in llm_gens if r.input_tokens is not None]), - num_output_tokens=sum([len(r.output_tokens) for r in llm_gens if r.output_tokens is not None]), - num_search_queries=len(search_results), - num_success_search_queries=len([r for r in search_results if "No search results are found" not in r.text]), - num_failed_search_queries=len([r for r in search_results if "No search results are found" in r.text]), - num_pages=len(webpages), - num_success_url_accesses=len([r for r in webpages if ">>>> Page 1 >>>>" in r.text]), - num_failed_url_accesses=len([r for r in webpages if ">>>> Page 1 >>>>" not in r.text]), - ) - return ret - - def to_dict(self): - return [r.to_dict() for r in self.memory] - -class AsearcherAgent: - def __init__(self, prompt=None): - self.prompt = prompt - self.memory = AgentMemory(prompt=prompt) if prompt else None - self.job_queue = queue.Queue(128) - self.max_turns = 64 # Default max turns like other agents - - def initialize_with_prompt(self, process): - """Initialize or reset agent with a specific prompt""" - prompt = process["prompt"] - self.prompt = prompt - self.memory = AgentMemory(prompt=prompt) - self.job_queue = queue.Queue(128) - - @property - def num_turns(self): - return self.memory.llm_gen_count() if self.memory else 0 - - @property - def is_finished(self): - if not self.memory: - return False - pattern = r'(.*?)' - return any([len(re.findall(pattern, r.text, re.DOTALL)) > 0 for r in self.memory.filter_records("llm_gen")]) - - def add_jobs(self, jobs): - if not isinstance(jobs, list): - jobs = [jobs] - for job in jobs: - assert (job.get("type", "unkown") in ["search_results", "webpage"]), ("Unknown job type: " + job.get("type", "unknown")) - self.job_queue.put_nowait(job) - - def prepare_llm_query(self): - if not self.memory: - raise RuntimeError("Agent not initialized with prompt. Call initialize_with_prompt() first.") - - prompt = self.memory.prepare_prompt() - sampling_params = dict(stop=["", "", ""]) - if not self.job_queue.empty(): - job = self.job_queue.get_nowait() - if job["type"] in ["search_results", "webpage"]: - prompt = prompt + "\n\n" + job["text"] + "\n\n" - new_record = Record( - type=job["type"], - text=job["text"], - short_text=job.get("short_text", job["text"]), - ) - self.memory.add_record(new_record) - sampling_params["stop"] = [""] - return prompt, sampling_params - - def consume_llm_response(self, resp, completion_text): - if not self.memory: - raise RuntimeError("Agent not initialized with prompt. Call initialize_with_prompt() first.") - - new_record = Record( - type="llm_gen", - text=completion_text, - input_len=resp.input_len, - output_len=resp.output_len, - ) - self.memory.add_record(new_record) - - tool_calls = [] - for pattern in [r'(.*?)', r'(.*?)', r'(.*?)']: - matches = re.findall(pattern, completion_text, re.DOTALL) - if matches: - match = matches[-1] - tool_calls.append(str(pattern.replace('(.*?)', match))) - - return tool_calls - - def consume_tool_response(self, res, topk=5): - # process the search results - if res["type"] == "search": - job = dict(type="search_results") - - # Safely handle potentially None documents and urls - documents = res.get("documents") or [] - urls = res.get("urls") or [] - - # Ensure we slice safely - documents = documents[:topk] if documents else [] - urls = urls[:topk] if urls else [] - - if len(documents) > 0: - doc_id_template = "[Doc {doc_id}]({url}):\n" - text = "\n" + "\n\n".join([doc_id_template.format(doc_id=str(k+1), url=url) + doc[:5000] for k, (doc, url) in enumerate(zip(documents, urls))]) + "\n" - else: - text = "\nNo search results are found.\n" - - job["text"] = text - self.add_jobs(job) - - # process the webpage - elif res["type"] == "access": - jobs = [] - page = res["page"] - if page is not None and page.strip() != "": - page = page[:250000] - while len(page) > 0 and len(jobs) < 10: - _len = min(25000, len(page)) - jobs.append(dict( - type="webpage", - text=f"\n>>>> Page {len(jobs) + 1} >>>>\n\n" + page[:_len] + "\n", - short_text=f"\n>>>> Page {len(jobs) + 1} >>>>\n\n" + page[:100] + "\n", - )) - page = page[_len:] - else: - jobs.append(dict( - type="webpage", - text="\nNo More Information is Found for this URL.\n", - )) - self.add_jobs(jobs) - - def get_answer(self): - if not self.memory: - return None - text, _ = self.prepare_llm_query() - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return matches[-1].strip() - return None \ No newline at end of file diff --git a/astraEnv/ASearcher/agent/asearcher_reasoning.py b/astraEnv/ASearcher/agent/asearcher_reasoning.py deleted file mode 100644 index edf67dd..0000000 --- a/astraEnv/ASearcher/agent/asearcher_reasoning.py +++ /dev/null @@ -1,579 +0,0 @@ -import re -import time -import copy -from typing import Dict, List, Any, Optional -from datetime import datetime - -class ASearcherReasoningPrompts: - THINK_AND_ACT_PROMPT_v1 = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question and the history context, generate the thought as well as the next action (only one action). Tthe completed thought should contain analysis of available information and planning for future steps. Enclose the thought within tags. - -The next action could be one of the following three, each with specific tags: -1. Search w. a search engine, e.g. the search query - -2. Accessing some url found in prior history, e.g. the url to access - -3. Answering the question, e.g. the answer (usually in less than 10 words) (WARNING: Answer the question only after you double check the results with sufficient search!) - -Guidelines: -1. You should double check previous conclusions and identified facts using search from different perspectives. -3. You can try different directions to solve the question, such as using different search queries. -3. If you find related entries in the search results, it is usually useful to access the corresponding urls to find more information. -4. You should find the most likely answer. -5. The next action should follow after the thought. -6. Make sure you choose only one action. -7. Carefully select the type of language to conduct your search query (Chinese or English) - -Current Time: Today is {current_date} - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Thought: ... // the thought to be completed - -Next Action: ... // the next action to be completed -""" - - THINK_AND_ACT_PROMPT = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question and the history context, generate the thought as well as the next action (only one action). The completed thought should contain a detailed analysis of current situation and a plan for future steps. The action is either a query to google search or accessing some URL. Enclose the thought within tags. - -The next action could be one of the following two, each with specific tags: -1. Search w. a search engine, e.g. the search query - -2. Accessing some url found in prior history to find more information, e.g. the url to access - -Guidelines: -1. You should double check previous conclusions and identified facts using search from different perspectives. -3. You can try different directions to solve the question, such as using different search queries. -3. If you find related entries in the search results, it is usually useful to access the corresponding urls to find more information. -4. The next action should follow after the thought. -5. Make sure you should choose only one action. - -Current Time: Today is {current_date} - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Thought: ... // the thought to be completed - -Next Action: ... // the next action to be completed -""" - - THINK_AND_ANSWER_PROMPT = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question and the history context, generate the thought as well as the final answer. The completed thought should contain detailed analysis of available information. Enclose the thought within tags, and the answer within tags. - -Guideline: -1. Determine the answer based on the the available information. -2. Try to make your best guess if the found information is not enough. - - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Thought: ... // the thought to be completed - -Final Answer: ... // the final answer -""" - READ_PAGE_PROMPT = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question, the history context, and the current web page, generate a thought after reading the webpage. The completed thought should contain information found related to the question, relevant links from the current webpage, and detailed analysis of available information. Enclose the thought within tags. - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Current webpage: -```txt -{content} -``` - -Thought: ... // the thought to be completed -""" - READ_SEARCH_RESULTS_PROMPT = \ -"""Given a question, you are an autonomous agent trying to solve the question with web browser. Given the question, the history context, and the search results of the latest query, generate a thought after reading the search results. The completed thought should contain information found related to the question, relevant links from the latest search results that may help solve the question, and detailed analysis of available information. Enclose the thought within tags. - -Question: -```txt -{question} -``` - -Reasoning history: -```txt -{history} -``` - -Latest search results: -```txt -{content} -``` - -Thought: ... // the thought to be completed -""" - -def process_webpage(content): - keys = [("title", "title"), ("p", "p"), ("li", "li", lambda c: "\n" not in c)] - content_list = [] - init_length = len(content) - while any([f"<{k[0]}" in content and f"" in content for k in keys]): - klr = [] - for k in keys: - start = 0 - while True: - ls = [content[start:].find(f"<{k[0]}{c}") for c in [">", " "]] - ls = [l for l in ls if l != -1] - l = -1 if len(ls) == 0 else min(ls) - if l == -1: - break - l += start - r = content[l:].find(f"") - if r == -1: - break - if (len(k) <= 2) or (len(k) >= 3 and k[2](content[l:l+r])): - klr.append((k, l, l+r)) - break - start = l + r - - if len(klr) == 0: - break - klr = sorted(klr, key=lambda x:x[1]) - k, l, r = klr[0] - content_list.append(content[l:r+len(f"")]) - if k[0] == "p": - content_list[-1] += "\n\n" - elif k[0] == "li": - content_list[-1] += "\n" - content = content[r:] - content = "".join(content_list) - return content - -class AsearcherReasoningAgent: - - def __init__(self, - max_turns: int = 128, - force_turns: int = 32, - topk: int = 10, - force_valid: bool = True): - - self.max_turns = max_turns - self.force_turns = force_turns - self.force_valid = force_valid - self.topk = topk - - self.stop = ["<|im_end|>", "<|endoftext|>"] - self.stop_sequences = self.stop - - self.current_process = None - self.tokenizer = None - - # Agent initialized - - def get_query_from_text(self, text: str) -> Optional[str]: - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return "" + matches[-1].strip() + "" - - return None - - def get_url_from_text(self, text: str) -> Optional[str]: - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return "" + matches[-1].strip() + "" - - return None - - def get_thought_from_text(self, text: str) -> Optional[str]: - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return "" + matches[-1].strip() + "" - - return None - - def get_answer_from_text(self, text: str) -> Optional[str]: - pattern = r'(.*?)' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return "" + matches[-1].strip() + "" - - return None - - - def all_finished(self, processes: List[Dict]) -> bool: - finished = [] - for process in processes: - finished.append(not process.get("running", True)) - return all(finished) - - def initialize_with_prompt(self, process): - """Initialize agent with a specific prompt""" - if "question" not in process: - process["question"] = process["prompt"] - if "prompt" not in process: - process["prompt"] = process["question"] - if len(process["history"]) == 0: - process["history"] = [dict(type="prompt", text=process["prompt"])] - process["running"] = True - process["phase"] = "search" - self.current_process = copy.deepcopy(process) - - def set_tokenizer(self, tokenizer): - """Set tokenizer for the agent""" - self.tokenizer = tokenizer - - @property - def num_turns(self): - """Get current number of turns""" - if not self.current_process: - return 0 - return len([h for h in self.current_process["history"] if h["type"] == "act"]) - - @property - def is_finished(self): - """Check if agent is finished""" - if not self.current_process or not self.current_process.get("running", False): - return True - - # Check if we have an answer - full_text = "".join([h.get("text", "") for h in self.current_process["history"] if h["type"] != "prompt"]) - has_answer = "" in full_text and "" in full_text - - # Check action count limits - action_count = len([h for h in self.current_process["history"] if h["type"] == "act"]) - max_turns_exceeded = action_count >= self.max_turns + 20 - - # Check failure count - llm_gen_fail = self.current_process.get("llm_gen_fail", 0) - too_many_failures = llm_gen_fail > 32 - - return has_answer or max_turns_exceeded or too_many_failures - - def prepare_llm_query(self): - """Prepare LLM query for current process""" - if not self.current_process: - raise RuntimeError("Agent not initialized with prompt. Call initialize_with_prompt() first.") - - if not self.tokenizer: - raise RuntimeError("Tokenizer not set. Call set_tokenizer() first.") - - process = self.current_process - - if not process.get("running", False): - return "", {"stop": self.stop} - - # Handle reading mode - when we have info_str but no text - if "text" not in process["history"][-1] and "info_str" in process["history"][-1]: - history = "" - for idx, h in enumerate(process["history"][:-1]): - history += h.get("short_info_str", h.get("text", "")) - if len(history) > 25000: - history = history[-25000:] - - if process["history"][-1]["type"] == "page": - prompt = ASearcherReasoningPrompts.READ_PAGE_PROMPT.format( - question=process.get("question", process["prompt"]), - history=history, - content=process["history"][-1]["info_str"] - ) - elif process["history"][-1]["type"] == "documents": - prompt = ASearcherReasoningPrompts.READ_SEARCH_RESULTS_PROMPT.format( - question=process.get("question", process["prompt"]), - history=history, - content=process["history"][-1]["info_str"] - ) - else: - raise RuntimeError(f"Not supported history type: {process['history'][-1]['type']}") - - messages = [{"role": "user", "content": prompt}] - input_text = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False) - query_len = self.tokenizer([input_text], return_length=True)['length'][0] - - if query_len <= 28000: - print(f"Reading @ Qid {process['id']}", query_len, flush=True) - sampling_params = {"stop": self.stop, "max_new_tokens": 31000-query_len} - # sampling_params = {"max_completion_tokens": 31000 - query_len} - # return messages, sampling_params - return input_text, sampling_params - - if "cache_gen_text" in process: - process.pop("cache_gen_text") - - # Handle normal generation mode - building prompt from history - history = "" - for idx, h in enumerate(process["history"]): - history += h.get("short_info_str", h.get("text", "")) - if len(history) > 25000: - history = history[-25000:] - - # Determine if we should force answer generation - action_count = len([h for h in process["history"] if h["type"] == "act"]) - doc_count = len([h for h in process["history"] if h["type"] == "documents"]) - should_answer = any([ - doc_count >= 20, - action_count >= self.force_turns, - process.get("phase", "search") == "answer" - ]) - - if should_answer: - process["phase"] = "answer" - prompt = ASearcherReasoningPrompts.THINK_AND_ACT_PROMPT_v1.format( - question=process.get("question", process["prompt"]), - history=history, - current_date=datetime.now().strftime("%Y.%m.%d") - ) - else: - prompt = ASearcherReasoningPrompts.THINK_AND_ACT_PROMPT.format( - question=process.get("question", process["prompt"]), - history=history, - current_date=datetime.now().strftime("%Y.%m.%d") - ) - - input_text = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False) + process.get("cache_gen_text", "") - - # Apply force_valid logic - if self.force_valid: - input_text = input_text.replace( - '4. If you find information contradicting context of the question, you should point out that the question is invalid and the incorrect information in the question.', - "4. You should find the most likely answer even when conflicting information is founded." - ) - - # Check if process should be terminated - input_len = len(self.tokenizer(input_text, add_special_tokens=False)["input_ids"]) - if input_len > 32000 or self.get_answer_from_text(process["history"][-1].get("text", "")): - print(f"Process done (input too long or has answer): {process['id']}") - process["running"] = False - return "", {"stop": self.stop} - - query_len = self.tokenizer([input_text], return_length=True)['length'][0] - max_new_tokens = max(0, 31000 - query_len) - - print(f"Generate {'Answer' if should_answer else 'Act'} @ Qid {process['id']}", - input_len, doc_count, action_count, max_new_tokens, flush=True) - - sampling_params = {"stop": self.stop, "max_new_tokens": max_new_tokens} - return input_text, sampling_params - - def consume_llm_response(self, resp, completion_text): - """Consume LLM response and extract tool calls""" - if not self.current_process: - raise RuntimeError("Agent not initialized with prompt. Call initialize_with_prompt() first.") - - process = self.current_process - - # Handle different response formats - if hasattr(resp, 'stop_reason') and hasattr(resp, 'text'): - generated_text = resp.text - elif isinstance(resp, dict): - generated_text = resp.get('text', str(resp)) - else: - generated_text = completion_text - - if generated_text is None: - generated_text = "" - - raw_generated_text = generated_text - generated_text = process.get("cache_gen_text", "") + generated_text - - # Return tool calls for V2 interface - tool_calls = [] - - # Extract different components - extracted_thought = self.get_thought_from_text(generated_text) - extracted_answer = self.get_answer_from_text(generated_text) - extracted_query = self.get_query_from_text(generated_text) - extracted_url = self.get_url_from_text(generated_text) - - if process.get("phase", "unknown") != "answer" and extracted_answer is not None: - print(f"Not time for producing answer for {process['id']}", extracted_answer, flush=True) - extracted_answer = None - - # Build think_and_act text - think_and_act = "" - if extracted_thought is not None: - think_and_act = think_and_act + extracted_thought - for act in [extracted_query, extracted_url, extracted_answer]: - if act is not None: - think_and_act = think_and_act.strip() + "\n\n" + act - break - - # Update process history if we have a thought - if extracted_thought is not None: - process["history"].append(dict( - type="act", - full_reasoning_text=generated_text, - text=think_and_act.strip() - )) - if "cache_gen_text" in process: - process.pop("cache_gen_text") - - # tool calls - if extracted_query: - tool_calls.append(extracted_query) - if extracted_url: - tool_calls.append(extracted_url) - if extracted_answer: - tool_calls.append(extracted_answer) - - # Handle page cache - if "page_cache" in process and len(process["page_cache"]) > 0: - page = process["page_cache"].pop(0) - print(f"{process['id']} pop page cache: {[page[:100]]}") - info_str = "\n\n" + page + "\n\n\n" - short_info_str = "\n\n\n" + page[:100] + "...\n\n" + "\n\n" - - process["history"].append(dict( - type="page", - info_str=info_str, - short_info_str=short_info_str - )) - elif len(raw_generated_text) == 0: - process["cache_gen_text"] = "" - process["llm_gen_fail"] = process.get("llm_gen_fail", 0) + 1 - if process["llm_gen_fail"] > 16: - print("process is done (2)", process["id"], process["llm_gen_fail"]) - process["running"] = False - else: - if process["history"][-1]["type"] in ["page", "documents"]: - process["cache_gen_text"] = "" - process["history"].append(dict( - type="act", - full_reasoning_text=generated_text, - text="\n\n" - )) - process["llm_gen_fail"] = process.get("llm_gen_fail", 0) + 1 - process["page_cache"] = [] - else: - process["llm_gen_fail"] = process.get("llm_gen_fail", 0) + 1 - process["cache_gen_text"] = generated_text - - # Check termination conditions - action_count = len([h for h in process["history"] if h["type"] == "act"]) - if action_count >= self.max_turns + 20 or "" in think_and_act: - print("process is done (3)", process["id"], action_count, self.max_turns, "" in think_and_act, flush=True) - process["running"] = False - - return tool_calls - - def consume_tool_response(self, res, topk=5): - """Consume tool response (search or access)""" - if not self.current_process: - raise RuntimeError("Agent not initialized with prompt. Call initialize_with_prompt() first.") - - process = self.current_process - - if res["type"] == "search": - # Handle search results - documents = res.get("documents", [])[:topk] - urls = res.get("urls", [])[:topk] - - print(f"Count of Search documents: {len(documents)}") - - if len(documents) > 0: - doc_id_template = "[Doc {doc_id}]({url}):\n" - info_str = "\n\n\n" + "\n\n".join([doc_id_template.format(doc_id=str(k+1), url=url) + doc for k, (doc, url) in enumerate(zip(documents, urls))]) + "\n\n\n" - short_info_str = "\n\n" + "\n\n".join([doc_id_template.format(doc_id=str(k+1), url=url) + doc + "..." for k, (doc, url) in enumerate(zip(documents, urls))]) + "\n\n\n" - else: - info_str = "\n\n\n" + "No Results Found." + "\n\n\n" - short_info_str = info_str - - process["history"].append(dict( - type="documents", - info_str=info_str, - short_info_str=short_info_str - )) - - elif res["type"] == "access": - # Handle webpage access results - page = res.get("page", "") - - if page and len(page) > 0: - page = page[:250000] - if "page_cache" not in process: - process["page_cache"] = [] - process["page_cache"] = [] - - # Split page into chunks - while len(page) > 0 and len(process["page_cache"]) < 10: - _len = min(10000, len(page)) - process["page_cache"].append(f">>>> Page {len(process["page_cache"]) + 1} >>>>\n\n" + page[:_len]) - page = page[_len:] - - print("[DEBUG] add page", process["id"], len(res.get("page", "")), len(process["page_cache"]), flush=True) - - # Add first page immediately if available - if "page_cache" in process and len(process["page_cache"]) > 0: - page = process["page_cache"].pop(0) - info_str = "\n\n" + page + "\n\n\n" - short_info_str = "\n\n\n" + page[:100] + "...\n\n" + "\n\n" - - process["history"].append(dict( - type="page", - info_str=info_str, - short_info_str=short_info_str - )) - else: - # Empty or invalid page - process["page_cache"] = [] - info_str = "\n\n\nNo More Information is Found for this URL.\n\n\n" - short_info_str = info_str - - process["history"].append(dict( - type="page", - info_str=info_str, - short_info_str=short_info_str - )) - - def get_answer(self): - """Get final answer from current process""" - if not self.current_process: - return None - - process = self.current_process - - if "pred_answer" not in process: - full_text = "".join( - [h["text"] for h in process["history"] if h["type"] != "prompt" and "text" in h] - ) - - if "" in full_text and "" in full_text: - answer = full_text.split("")[-1].split("")[0].strip() - else: - reasoning_text = "\n\n".join([h["full_reasoning_text"] for h in process["history"] if "full_reasoning_text" in h] + [process.get("cache_gen_text", "")]) - # find the last line mentioning 'answer' - lines = reasoning_text.split("\n") - lines = [l for l in lines if 'answer' in l.lower()] - if len(lines) > 0: - answer = lines[-1] - else: - answer = reasoning_text.strip().split("")[-1].strip() - - process["pred_answer"] = answer - - return process["pred_answer"] diff --git a/astraEnv/ASearcher/agent/search_r1.py b/astraEnv/ASearcher/agent/search_r1.py deleted file mode 100644 index d0ea4b9..0000000 --- a/astraEnv/ASearcher/agent/search_r1.py +++ /dev/null @@ -1,297 +0,0 @@ -import re -from typing import Dict, List, Any, Optional -from tools.search_utils import AsyncSearchBrowserClient - - -class SearchR1Agent: - - def __init__(self, - max_turns: int = 10, - topk: int = 5): - - self.max_turns = max_turns - self.topk = topk - self.stop = ["<|im_end|>", "<|endoftext|>", "<|end_of_query|>", "", ""] - self.stop_sequences = self.stop - - self.current_process = None - self.tokenizer = None - - print(f"SearchR1Agent initialized.") - - def get_query_from_text(self, text: str) -> Optional[str]: - pattern = r'<\|begin_of_query\|>(.*?)<\|end_of_query\|>' - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return matches[-1].strip() - - if '<|begin_of_query|>' in text: - parts = text.split('<|begin_of_query|>') - if len(parts) > 1: - query_part = parts[-1] - if not query_part.strip().endswith('<|end_of_query|>'): - return query_part.strip() - - search_pattern = r'(.*?)' - search_matches = re.findall(search_pattern, text, re.DOTALL) - if search_matches: - return search_matches[-1].strip() - - if '' in text: - parts = text.split('') - if len(parts) > 1: - query_part = parts[-1] - if not query_part.strip().endswith(''): - return query_part.strip() - - return None - - def fix_incomplete_search_tag(self, text: str) -> str: - if '<|begin_of_query|>' in text and not text.strip().endswith('<|end_of_query|>'): - return text.strip() + '<|end_of_query|>' - - if '' in text and not text.strip().endswith(''): - return text.strip() + '' - - return text - - def all_finished(self, processes: List[Dict]) -> bool: - finished = [] - for process in processes: - finished.append(not process.get("running", True)) - return all(finished) - - def initialize_with_prompt(self, prompt): - """Initialize agent with a specific prompt""" - self.current_process = { - "prompt": prompt, - "history": [dict(type="prompt", text=prompt)], - "running": True, - "id": "0" - } - - def set_tokenizer(self, tokenizer): - """Set tokenizer for the agent""" - self.tokenizer = tokenizer - - @property - def num_turns(self): - """Get current number of turns""" - if not self.current_process: - return 0 - return len([h for h in self.current_process["history"] if h["type"] == "act"]) - - @property - def is_finished(self): - """Check if agent is finished""" - if not self.current_process or not self.current_process.get("running", False): - return True - - # Check if we have an answer - full_text = "".join([h.get("text", "") for h in self.current_process["history"] if h["type"] != "prompt"]) - has_answer = "" in full_text and "" in full_text - - # Check max turns - action_count = len([h for h in self.current_process["history"] if h["type"] == "act"]) - max_turns_exceeded = action_count >= self.max_turns - - return has_answer or max_turns_exceeded - - def prepare_llm_query(self): - """Prepare LLM query for current process""" - if not self.current_process: - raise RuntimeError("Agent not initialized with prompt. Call initialize_with_prompt() first.") - - if not self.tokenizer: - raise RuntimeError("Tokenizer not set. Call set_tokenizer() first.") - - process = self.current_process - - if not process.get("running", False): - return "", {"stop": self.stop} - - # Check if last text contains a search query - last_text = process["history"][-1]["text"] - - # Handle search query patterns - return empty to trigger tool calling - if (("<|begin_of_query|>" in last_text and last_text.strip().endswith("<|end_of_query|>")) or - ("" in last_text and last_text.strip().endswith(""))): - return "", {"stop": self.stop} - - # Normal LLM generation - input_text = "".join([h["text"] for h in process["history"]]) - query_len = self.tokenizer([input_text], return_length=True)['length'][0] - - sampling_params = {"stop": self.stop} - - return input_text, sampling_params - - def consume_llm_response(self, resp, completion_text): - """Consume LLM response and extract tool calls""" - if not self.current_process: - raise RuntimeError("Agent not initialized with prompt. Call initialize_with_prompt() first.") - - process = self.current_process - - # Handle different response formats - if hasattr(resp, 'stop_reason') and hasattr(resp, 'text'): - stop_reason = resp.stop_reason - generated_text = resp.text - elif isinstance(resp, dict): - stop_reason = resp.get('stop_reason', '') - generated_text = resp.get('text', str(resp)) - elif resp is None: - stop_reason = "" - generated_text = completion_text or "" - else: - stop_reason = "" if "" in str(resp) else "" - generated_text = completion_text or str(resp) - - # Fix incomplete search tags - fixed_text = self.fix_incomplete_search_tag(generated_text) - if fixed_text != generated_text: - generated_text = fixed_text - - # Extract query and check for actions - extracted_query = self.get_query_from_text(generated_text) - tool_calls = [] - - if extracted_query: - # This is a search action - process["history"].append(dict( - type="act", - text=generated_text.strip() - )) - # Create tool call for search - if ("<|begin_of_query|>" in generated_text and generated_text.strip().endswith("<|end_of_query|>")): - query_text = generated_text.split("<|begin_of_query|>")[-1].split("<|end_of_query|>")[0] - tool_calls.append(f"{query_text.strip()}") - elif ("" in generated_text and generated_text.strip().endswith("")): - query_text = generated_text.split("")[-1].split("")[0] - tool_calls.append(f"{query_text.strip()}") - - elif "" in generated_text and (stop_reason == "" or "" in generated_text): - # This is a final answer - if not generated_text.strip().endswith(""): - generated_text = generated_text.strip() + "" - process["history"].append(dict( - type="act", - text=generated_text - )) - process["running"] = False - # Extract answer for tool call - if "" in generated_text and "" in generated_text: - answer_text = generated_text.split("")[-1].split("")[0] - tool_calls.append(f"{answer_text.strip()}") - - elif (("" in generated_text and generated_text.strip().endswith("")) or - ("<|begin_of_query|>" in generated_text and generated_text.strip().endswith("<|end_of_query|>"))): - # This is a complete search query - process["history"].append(dict( - type="act", - text=generated_text.strip() + "\n\n" - )) - # Extract query for tool call - if ("<|begin_of_query|>" in generated_text and generated_text.strip().endswith("<|end_of_query|>")): - query_text = generated_text.split("<|begin_of_query|>")[-1].split("<|end_of_query|>")[0] - tool_calls.append(f"{query_text.strip()}") - elif ("" in generated_text and generated_text.strip().endswith("")): - query_text = generated_text.split("")[-1].split("")[0] - tool_calls.append(f"{query_text.strip()}") - else: - # Invalid action, add auxiliary message - process["history"].append(dict( - type="act", - text=generated_text.strip() + "\n\n" - )) - process["history"].append(dict( - type="auxilliary", - text="\nMy previous action is invalid. If I want to search, I should put the query between and . If I want to give the final answer, I should put the answer between and . Let me try again.\n" - )) - - # Check if max turns reached - action_count = len([h for h in process["history"] if h["type"] == "act"]) - if action_count >= self.max_turns: - process["running"] = False - - return tool_calls - - def consume_tool_response(self, res, topk=5): - """Consume tool response (search) - Updated for agent v2 compatibility""" - if not self.current_process: - raise RuntimeError("Agent not initialized with prompt. Call initialize_with_prompt() first.") - - process = self.current_process - - if res["type"] == "search": - if isinstance(res, list): - r = res[0] - else: - r = res - - if isinstance(r, dict) and 'documents' in r: - documents = r["documents"] - urls = r.get("urls", []) - else: - documents = [] - urls = [] - - # Add formatted content for the agent's internal use (LLM consumption) - if len(documents) > 0: - doc_content_list = [] - for j, doc in enumerate(documents): - if isinstance(doc, str): - doc_clean = re.sub(r'^\d+\s+', '', doc.strip()) - doc_content_list.append(f"{j+1}. {doc_clean}\n") - doc_content = '\n'.join(doc_content_list) - else: - doc_content = "" - - if doc_content: - formatted_content = "\n\n" + doc_content + "\n\n" - else: - formatted_content = "\n\nNo relevant documents found.\n\n" - - # Add formatted content for LLM consumption - process["history"].append({ - "type": "documents", # Keep for backward compatibility - "text": formatted_content - }) - - def get_answer(self): - """Get final answer from current process""" - if not self.current_process: - return None - - process = self.current_process - - if "pred_answer" not in process: - full_text = "".join([h["text"] for h in process["history"] if h["type"] != "prompt"]) - - if "" in full_text and "" in full_text: - answer = full_text.split("")[-1].split("")[0].strip() - else: - answer = full_text.strip() - - process["pred_answer"] = answer - - return process["pred_answer"] - - def fix_process_incomplete_tags(self, process: Dict) -> Dict: - fixed_count = 0 - history = process.get("history", []) - - for i, entry in enumerate(history): - if entry.get("type") == "act": - original_text = entry["text"] - fixed_text = self.fix_incomplete_search_tag(original_text) - - if fixed_text != original_text: - history[i]["text"] = fixed_text - fixed_count += 1 - - return { - "total_entries": len(history), - "fixed_entries": fixed_count, - "process_id": process.get("process_id", "unknown") - }