From c4f22f2041faf2fedf8437a70c93f9c7861609ad Mon Sep 17 00:00:00 2001 From: Wei Fu <36355462+garrett4wade@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:14:29 +0800 Subject: [PATCH] feat(infra): allow colocation with offloading and disk weight updates (#1157) * feat(infra): allow colocation with offloading and disk weight updates Enable scheduler-level colocation to run actor/critic and engine processes on shared GPU allocations while preserving async update correctness. Key changes: - Add colocate/offload/disk-update config fields and scheduler plumbing - Harden RPC server and engine blueprint coordination for colocated roles - Extend trainer paths and tests for colocated evaluation dispatch behavior * fix(infra): restore default train-engine RPC broadcast Keep initialized TrainEngine RPC calls backward compatible across the guard and Ray servers so non-head model-parallel ranks continue to receive controller payloads without every call site opting in manually. * fix: enforce offload prerequisites for colocated training Fail fast when colocated or explicit train-engine offload would run without TMS support, and provision Ray workers with the same offload environment as the local and Slurm schedulers. --------- Co-authored-by: Wentai Zhang --- areal/api/cli_args.py | 8 + areal/api/scheduler_api.py | 6 + areal/engine/fsdp_engine.py | 89 +++++--- areal/engine/megatron_engine.py | 119 ++++++---- areal/experimental/engine/archon_engine.py | 83 ++++--- areal/infra/controller/rollout_controller.py | 8 + areal/infra/controller/train_controller.py | 34 ++- areal/infra/rpc/guard/engine_blueprint.py | 63 +++--- areal/infra/rpc/ray_rpc_server.py | 55 +++-- areal/infra/scheduler/local.py | 11 + areal/infra/scheduler/ray.py | 20 +- areal/infra/scheduler/slurm.py | 4 + areal/infra/utils/concurrent.py | 16 ++ areal/trainer/ppo/actor.py | 12 +- areal/trainer/ppo/critic.py | 8 +- areal/trainer/rl_trainer.py | 223 +++++++++++++++++-- areal/trainer/rw/rw_engine.py | 12 +- areal/trainer/sft/lm_engine.py | 8 +- docs/en/cli_reference.md | 4 + docs/zh/cli_reference.md | 4 + tests/test_eval_dispatch.py | 2 +- tests/test_examples.py | 46 ++++ 22 files changed, 649 insertions(+), 186 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 5430f32aed..d039191fab 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1081,6 +1081,14 @@ class TrainEngineConfig: archon: ArchonEngineConfig = field(default_factory=ArchonEngineConfig) megatron: MegatronEngineConfig = field(default_factory=MegatronEngineConfig) + # offload + offload: bool = field( + default=False, + metadata={ + "help": "Whether to offload model parameters and optimizer states to CPU. " + }, + ) + # Lora use_lora: bool = field( default=False, diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py index 71f4751f23..6c760d55e9 100644 --- a/areal/api/scheduler_api.py +++ b/areal/api/scheduler_api.py @@ -226,6 +226,7 @@ def call_engine( method: str, engine_name: str | None = None, *args, + rpc_meta: dict[str, Any] | None = None, **kwargs, ) -> Any: """Call a method on an engine instance running on a worker (data plane operation). @@ -243,6 +244,8 @@ def call_engine( Defaults to worker_id if not specified. 
*args Positional arguments to pass to the method + rpc_meta : dict[str, Any] | None, optional + RPC metadata, by default None **kwargs Keyword arguments to pass to the method @@ -269,6 +272,7 @@ async def async_call_engine( method: str, engine_name: str | None = None, *args, + rpc_meta: dict[str, Any] | None = None, **kwargs, ) -> Any: """Async version of call_engine for calling engine methods asynchronously. @@ -286,6 +290,8 @@ async def async_call_engine( Defaults to worker_id if not specified. *args Positional arguments to pass to the method + rpc_meta : dict[str, Any] | None, optional + RPC metadata, by default None **kwargs Keyword arguments to pass to the method diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index d78c644e2f..89a235944f 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -7,7 +7,7 @@ import time from collections.abc import Callable, Iterator from concurrent.futures import Future -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from datetime import datetime from typing import TYPE_CHECKING, Any @@ -485,20 +485,14 @@ def prepare_batch( def update_weights(self, meta: WeightUpdateMeta): self._check_rollout_engine_connected() - if meta.type == "xccl": - assert self.weight_update_group_initialized - # In offload mode, wakes up parameters as needed to perform the update. - tms_context = ( - torch_memory_saver.disable() - if self.is_offload and not torch.version.hip - else nullcontext() - ) - with tms_context: + with self._offload_aware_context(): + if meta.type == "xccl": + assert self.weight_update_group_initialized self._update_weights_from_distributed(meta) - elif meta.type == "disk": - self._update_weights_from_disk(meta) - else: - raise ValueError(f"Unknown weight update type {meta.type}") + elif meta.type == "disk": + self._update_weights_from_disk(meta) + else: + raise ValueError(f"Unknown weight update type {meta.type}") def set_version(self, version: int): self._version = version @@ -507,31 +501,47 @@ def get_version(self) -> int: return self._version def save(self, meta: SaveLoadMeta): - if meta.weight_format == "hf": - self._save_model_to_hf(meta.path, meta.tokenizer, meta.processor) - elif meta.weight_format == "dcp": - self._save_to_dcp(meta.path, meta.with_optim) - else: - raise ValueError(f"Unknown weight format {meta.weight_format}. ") + with self._offload_aware_context(): + if meta.weight_format == "hf": + self._save_model_to_hf(meta.path, meta.tokenizer, meta.processor) + elif meta.weight_format == "dcp": + self._save_to_dcp(meta.path, meta.with_optim) + else: + raise ValueError(f"Unknown weight format {meta.weight_format}. ") - if meta.with_optim and meta.weight_format == "hf": - self._save_optimizer_state(meta.path) + if meta.with_optim and meta.weight_format == "hf": + self._save_optimizer_state(meta.path) def load(self, meta: SaveLoadMeta): - if meta.weight_format == "hf": - self._load_model_from_hf(meta.path) - elif meta.weight_format == "dcp": - self._load_from_dcp(meta.path, meta.with_optim) - else: - raise ValueError(f"Unknown weight format {meta.weight_format}. ") - - if meta.with_optim and meta.weight_format == "hf": - self._load_optimizer_state(meta.path) + with self._offload_aware_context(): + if meta.weight_format == "hf": + self._load_model_from_hf(meta.path) + elif meta.weight_format == "dcp": + self._load_from_dcp(meta.path, meta.with_optim) + else: + raise ValueError(f"Unknown weight format {meta.weight_format}. 
") + + if meta.with_optim and meta.weight_format == "hf": + self._load_optimizer_state(meta.path) + + # Checkpoint load replaces optimizer state tensor objects, losing + # pinning and normalization established by PerLayerOptimWrapper.__init__. + if meta.with_optim and self._per_layer_optim_wrapper is not None: + self._per_layer_optim_wrapper.refresh_states() + + @contextmanager + def _offload_aware_context(self): + """Temporarily onload parameters for offload-unsafe operations.""" + if not self.is_offload: + with nullcontext(): + yield + return - # Checkpoint load replaces optimizer state tensor objects, losing - # pinning and normalization established by PerLayerOptimWrapper.__init__. - if meta.with_optim and self._per_layer_optim_wrapper is not None: - self._per_layer_optim_wrapper.refresh_states() + self.onload() + try: + yield + finally: + self.offload() def optimizer_zero_grad(self): assert self.optimizer is not None @@ -749,13 +759,20 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None: return split_batch(result, meta) def export_stats(self) -> dict[str, float]: - return stats_tracker.export_all(reduce_group=self.data_parallel_group) + with self._offload_aware_context(): + return stats_tracker.export_all( + reduce_group=self.data_parallel_group, + ) def offload(self) -> None: """Offload model memory to CPU using torch_memory_saver. Ref: https://github.com/THUDM/slime/blob/main/slime/backends/fsdp_utils/actor.py """ + if not is_tms_enabled(): + raise RuntimeError( + "torch_memory_saver requires `enable_offload=True` in yaml config." + ) self.get_device_stats().log("before offload model") diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 9eb88cb0d6..914f8c1798 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -7,7 +7,7 @@ import os from collections.abc import Callable, Iterator from concurrent.futures import Future -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from datetime import datetime from typing import TYPE_CHECKING, Any @@ -569,20 +569,14 @@ def prepare_batch( def update_weights(self, meta: WeightUpdateMeta): self._check_rollout_engine_connected() - if meta.type == "xccl": - assert self.weight_update_group_initialized - # In offload mode, wakes up parameters as needed to perform the update. - tms_context = ( - torch_memory_saver.disable() - if self.is_offload and not torch.version.hip - else nullcontext() - ) - with tms_context: + with self._offload_aware_context(): + if meta.type == "xccl": + assert self.weight_update_group_initialized self._update_weights_from_distributed(meta) - elif meta.type == "disk": - self._update_weights_from_disk(meta) - else: - raise ValueError(f"Unknown weight update type {meta.type}") + elif meta.type == "disk": + self._update_weights_from_disk(meta) + else: + raise ValueError(f"Unknown weight update type {meta.type}") def set_version(self, version: int): self._version = version @@ -591,45 +585,65 @@ def get_version(self) -> int: return self._version def save(self, meta: SaveLoadMeta): - if meta.weight_format == "hf": - if meta.with_optim: - raise ValueError( - "HF format does not support optimizer state saving, please use DCP format instead." + with self._offload_aware_context(): + if meta.weight_format == "hf": + if meta.with_optim: + raise ValueError( + "HF format does not support optimizer state saving, please use DCP format instead." 
+ ) + self._save_model_to_hf( + meta.path, + tokenizer=meta.tokenizer, + processor=meta.processor, + base_model_path=meta.base_model_path, ) - self._save_model_to_hf( - meta.path, - tokenizer=meta.tokenizer, - processor=meta.processor, - base_model_path=meta.base_model_path, - ) - elif meta.weight_format == "dcp": - if self.checkpointer is None: - raise NotImplementedError( - "DCP checkpoint save is not available for this Megatron configuration " - "(e.g., LoRA path without distributed optimizer support). " - "Please use weight_format='hf' for adapter/full-model export." + elif meta.weight_format == "dcp": + if self.checkpointer is None: + raise NotImplementedError( + "DCP checkpoint save is not available for this Megatron configuration " + "(e.g., LoRA path without distributed optimizer support). " + "Please use weight_format='hf' for adapter/full-model export." + ) + self.checkpointer.save_checkpoint( + meta.path, with_optimizer=meta.with_optim ) - self.checkpointer.save_checkpoint(meta.path, with_optimizer=meta.with_optim) - else: - raise ValueError(f"Unknown weight format {meta.weight_format}. ") + else: + raise ValueError(f"Unknown weight format {meta.weight_format}. ") def load(self, meta: SaveLoadMeta): - if meta.weight_format == "hf": - if meta.with_optim: - raise ValueError( - "HF format does not support optimizer state loading, please use DCP format instead." - ) - self._load_model_from_hf(meta.path) - elif meta.weight_format == "dcp": - if self.checkpointer is None: - raise NotImplementedError( - "DCP checkpoint load is not available for this Megatron configuration " - "(e.g., LoRA path without distributed optimizer support). " - "Please use weight_format='hf' for adapter/full-model load." + with self._offload_aware_context(): + if meta.weight_format == "hf": + if meta.with_optim: + raise ValueError( + "HF format does not support optimizer state loading, please use DCP format instead." + ) + self._load_model_from_hf(meta.path) + elif meta.weight_format == "dcp": + if self.checkpointer is None: + raise NotImplementedError( + "DCP checkpoint load is not available for this Megatron configuration " + "(e.g., LoRA path without distributed optimizer support). " + "Please use weight_format='hf' for adapter/full-model load." + ) + self.checkpointer.load_checkpoint( + meta.path, with_optimizer=meta.with_optim ) - self.checkpointer.load_checkpoint(meta.path, with_optimizer=meta.with_optim) - else: - raise ValueError(f"Unknown weight format {meta.weight_format}. ") + else: + raise ValueError(f"Unknown weight format {meta.weight_format}. ") + + @contextmanager + def _offload_aware_context(self): + """Temporarily onload parameters for offload-unsafe operations.""" + if not self.is_offload: + with nullcontext(): + yield + return + + self.onload() + try: + yield + finally: + self.offload() def optimizer_zero_grad(self): assert self.optimizer is not None, "Optimizer is not initialized." 
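The same onload/offload guard appears in the FSDP and Megatron engines above (and again in the Archon engine later in this patch). As a minimal self-contained sketch of the shared pattern (the engine class and its onload/offload stubs below are illustrative placeholders, not code from this patch):

    from contextlib import contextmanager

    class OffloadAwareEngine:
        # Illustrative stand-in for the train engines above; is_offload is
        # true when parameters and optimizer state live on CPU between steps.
        is_offload = True

        def onload(self) -> None:
            # Stub: move parameters and optimizer state back onto the GPU.
            pass

        def offload(self) -> None:
            # Stub: return parameters and optimizer state to CPU memory.
            pass

        @contextmanager
        def _offload_aware_context(self):
            # Offload-unsafe operations (weight updates, save/load, stats
            # export) run with parameters resident on the device, and memory
            # is released again afterwards even if the wrapped block raises.
            if not self.is_offload:
                yield
                return
            self.onload()
            try:
                yield
            finally:
                self.offload()

    # Usage: call sites wrap each offload-unsafe body, so the same code path
    # works whether or not offloading is active, e.g.:
    #     with engine._offload_aware_context():
    #         ...  # update_weights / save / load / export_stats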
@@ -868,7 +882,10 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None: return split_batch(res, meta) def export_stats(self) -> dict[str, float]: - data = stats_tracker.export_all(reduce_group=self.data_parallel_group) + with self._offload_aware_context(): + data = stats_tracker.export_all( + reduce_group=self.data_parallel_group, + ) if mpu.get_pipeline_model_parallel_world_size() > 1: # Some log info only exist in last pipeline rank data_list = [data] @@ -885,6 +902,10 @@ def offload(self) -> None: Ref: https://github.com/THUDM/slime/blob/main/slime/backends/megatron_utils/actor.py """ + if not is_tms_enabled(): + raise RuntimeError( + "torch_memory_saver requires `enable_offload=True` in yaml config." + ) self.get_device_stats().log("before offload model") current_platform.clear_memory() diff --git a/areal/experimental/engine/archon_engine.py b/areal/experimental/engine/archon_engine.py index a246b607be..a9d80479ff 100644 --- a/areal/experimental/engine/archon_engine.py +++ b/areal/experimental/engine/archon_engine.py @@ -6,7 +6,7 @@ import os import time from collections.abc import Callable -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from dataclasses import dataclass from typing import TYPE_CHECKING, Any @@ -691,51 +691,69 @@ def clear_batches(self, *args): def update_weights(self, meta: WeightUpdateMeta): """Update weights to inference engine.""" self._check_rollout_engine_connected() - if meta.type == "xccl": - assert self._weight_sync_state.group_initialized - tms_context = ( - torch_memory_saver.disable() - if self.is_offload and not torch.version.hip - else nullcontext() - ) - with tms_context: + with self._offload_aware_context(): + if meta.type == "xccl": + assert self._weight_sync_state.group_initialized update_weights_from_distributed( state=self._weight_sync_state, meta=meta, engine=self, ) - elif meta.type == "disk": - update_weights_from_disk( - meta=meta, - engine=self, - ) + elif meta.type == "disk": + update_weights_from_disk( + meta=meta, + engine=self, + ) + else: + raise ValueError(f"Unknown weight update type {meta.type}") def save(self, meta: SaveLoadMeta): """Save model in HuggingFace or DCP format.""" - if meta.weight_format == "hf": - save_model_to_hf(self, meta.path, meta.tokenizer, meta.processor) - elif meta.weight_format == "dcp": - save_to_dcp(self, meta.path, meta.with_optim) - else: - raise ValueError(f"Unknown weight format {meta.weight_format}.") + with self._offload_aware_context(): + if meta.weight_format == "hf": + save_model_to_hf(self, meta.path, meta.tokenizer, meta.processor) + elif meta.weight_format == "dcp": + save_to_dcp(self, meta.path, meta.with_optim) + else: + raise ValueError(f"Unknown weight format {meta.weight_format}.") - if meta.with_optim and meta.weight_format == "hf": - save_optimizer_state(self, meta.path) + if meta.with_optim and meta.weight_format == "hf": + save_optimizer_state(self, meta.path) def load(self, meta: SaveLoadMeta): """Load model from HuggingFace or DCP format.""" - if meta.weight_format == "hf": - load_model_from_hf(self, meta.path) - elif meta.weight_format == "dcp": - load_from_dcp(self, meta.path, meta.with_optim) - else: - raise ValueError(f"Unknown weight format {meta.weight_format}.") + with self._offload_aware_context(): + if meta.weight_format == "hf": + load_model_from_hf(self, meta.path) + elif meta.weight_format == "dcp": + load_from_dcp(self, meta.path, meta.with_optim) + else: + raise ValueError(f"Unknown weight format 
{meta.weight_format}.") + + if meta.with_optim and meta.weight_format == "hf": + load_optimizer_state(self, meta.path) + + @contextmanager + def _offload_aware_context(self): + """Temporarily onload parameters for offload-unsafe operations.""" + if not self.is_offload: + with nullcontext(): + yield + return - if meta.with_optim and meta.weight_format == "hf": - load_optimizer_state(self, meta.path) + self.onload() + try: + yield + finally: + self.offload() def offload(self) -> None: """Offload model memory to CPU using torch_memory_saver.""" + if not is_tms_enabled(): + raise RuntimeError( + "torch_memory_saver requires `enable_offload=True` in yaml config." + ) + self.get_device_stats().log("before offload model") current_platform.clear_memory() @@ -759,7 +777,10 @@ def onload(self) -> None: def export_stats(self) -> dict[str, float]: assert self._initialized - data = stats_tracker.export_all(reduce_group=self.data_parallel_group) + with self._offload_aware_context(): + data = stats_tracker.export_all( + reduce_group=self.data_parallel_group, + ) if self.parallel_dims.pp_enabled: data_list = [data] dist.broadcast_object_list( diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 8fde8e5fdf..210d2e43fa 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -1037,6 +1037,14 @@ async def pause_generation(self): async def continue_generation(self): await self._collective_rpc_async("continue_generation") + def offload(self) -> None: + """Offload rollout model memory on all inference workers.""" + self._collective_rpc("offload") + + def onload(self, tags: list[str] | None = None) -> None: + """Onload rollout model memory on all inference workers.""" + self._collective_rpc("onload", tags=tags) + def set_version(self, version: int) -> None: with self._version_lock: self._version = version diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index 59bef0a85b..491951875d 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -442,16 +442,32 @@ async def _destroy_all_engines(): dist.destroy_process_group() logger.info("TrainController destroyed") - def _custom_function_call(self, method: str, *args, **kwargs): + def _custom_function_call( + self, + method: str, + *args, + rpc_meta: dict[str, Any] | None = None, + **kwargs, + ): """Dispatch method call to workers via the appropriate path.""" dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs) - results = run_async_task(self._call_workers, method, dp_args, dp_kwargs) + results = run_async_task( + self._call_workers, method, dp_args, dp_kwargs, rpc_meta=rpc_meta + ) return self._collect_results(results, group_indices) - async def _async_custom_function_call(self, method: str, *args, **kwargs): + async def _async_custom_function_call( + self, + method: str, + *args, + rpc_meta: dict[str, Any] | None = None, + **kwargs, + ): """Async version of _custom_function_call.""" dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs) - results = await self._call_workers(method, dp_args, dp_kwargs) + results = await self._call_workers( + method, dp_args, dp_kwargs, rpc_meta=rpc_meta + ) return self._collect_results(results, group_indices) def _pad_eval_dispatch_args( @@ -518,6 +534,7 @@ async def _call_workers( method: str, dp_split_args: list[list[Any]], dp_split_kwargs: dict[str, list[Any]], + rpc_meta: 
dict[str, Any] | None = None, ): """Send dispatched inputs to workers. DP heads get slices, others empty.""" tasks = [] @@ -539,6 +556,7 @@ async def _call_workers( method, self._engine_name(idx), *worker_args, + rpc_meta=rpc_meta, **worker_kwargs, ) ) @@ -669,6 +687,14 @@ def update_weights(self, meta: WeightUpdateMeta): self._check_rollout_engine_connected() self._custom_function_call("update_weights", meta=meta) + def offload(self) -> None: + """Offload model parameters to CPU across all train workers.""" + self._custom_function_call("offload") + + def onload(self) -> None: + """Onload model parameters to GPU across all train workers.""" + self._custom_function_call("onload") + def get_device_stats(self): return self._custom_function_call("get_device_stats") diff --git a/areal/infra/rpc/guard/engine_blueprint.py b/areal/infra/rpc/guard/engine_blueprint.py index 4dbd1c765d..b9fa03bbf2 100644 --- a/areal/infra/rpc/guard/engine_blueprint.py +++ b/areal/infra/rpc/guard/engine_blueprint.py @@ -37,6 +37,25 @@ logger = logging.getLogger("EngineBP") + +def _should_broadcast_payload( + engine: TrainEngine | InferenceEngine, rpc_meta: dict[str, Any] | None +) -> bool: + default_broadcast = isinstance(engine, TrainEngine) and engine.initialized + if rpc_meta is None: + return default_broadcast + if not isinstance(rpc_meta, dict): + raise ValueError( + f"Invalid rpc_meta: expected dict or None, got {type(rpc_meta)}" + ) + broadcast = rpc_meta.get("broadcast", default_broadcast) + if not isinstance(broadcast, bool): + raise ValueError( + f"Invalid rpc_meta.broadcast: expected bool, got {type(broadcast)}" + ) + return broadcast + + engine_bp = Blueprint("engine", __name__) # --------------------------------------------------------------------------- @@ -405,6 +424,7 @@ def call_engine_method(): engine_name = data.get("engine_name") raw_args = data.get("args", []) raw_kwargs = data.get("kwargs", {}) + rpc_meta = data.get("rpc_meta") if not method_name: return ( @@ -441,28 +461,13 @@ def call_engine_method(): def execute_in_engine_thread(): try: - # Broadcast args when engine is a TrainEngine and initialized - if isinstance(engine, TrainEngine) and engine.initialized: - logger.debug( - f"Broadcasting data for TrainEngine method: {method_name}" - ) - - nonlocal raw_args, raw_kwargs - raw_args = broadcast_tensor_container( - tensor_container_to( - raw_args, current_platform.current_device() - ), - src_rank=engine.current_data_parallel_head(), - group=engine.context_and_model_parallel_group, - ) - raw_kwargs = broadcast_tensor_container( - tensor_container_to( - raw_kwargs, current_platform.current_device() - ), - src_rank=engine.current_data_parallel_head(), - group=engine.context_and_model_parallel_group, - ) - + args_bcast = args + kwargs_bcast = kwargs + should_broadcast = _should_broadcast_payload( + engine=engine, rpc_meta=rpc_meta + ) + if should_broadcast: + logger.debug(f"Broadcasting RPC payload for method: {method_name}") args_bcast = tensor_container_to( args, current_platform.current_device() ) @@ -479,13 +484,19 @@ def execute_in_engine_thread(): src_rank=engine.current_data_parallel_head(), group=engine.context_and_model_parallel_group, ) - logger.debug("Broadcasting data done.") - else: - args_bcast = args - kwargs_bcast = kwargs + logger.debug("Broadcasting RPC payload done.") logger.debug(f"Calling engine '{engine_name}' method: {method_name}") + # Re-establish current device in RPC execution context before + # calling engine methods that may issue object collectives. 
+ if ( + isinstance(engine, TrainEngine) + and engine.initialized + and current_platform.device_type != "cpu" + ): + current_platform.set_device(current_platform.current_device()) + # Determine trace category based on method name category = "misc" # Default category method_lower = method_name.lower() diff --git a/areal/infra/rpc/ray_rpc_server.py b/areal/infra/rpc/ray_rpc_server.py index 09403eed42..55b5f9d07d 100644 --- a/areal/infra/rpc/ray_rpc_server.py +++ b/areal/infra/rpc/ray_rpc_server.py @@ -43,6 +43,25 @@ def _get_device(self): return current_platform.current_device() + def _should_broadcast_payload( + self, + engine: TrainEngine | InferenceEngine, + rpc_meta: dict[str, Any] | None, + ) -> bool: + default_broadcast = isinstance(engine, TrainEngine) and engine.initialized + if rpc_meta is None: + return default_broadcast + if not isinstance(rpc_meta, dict): + raise ValueError( + f"Invalid rpc_meta: expected dict or None, got {type(rpc_meta)}" + ) + broadcast = rpc_meta.get("broadcast", default_broadcast) + if not isinstance(broadcast, bool): + raise ValueError( + f"Invalid rpc_meta.broadcast: expected bool, got {type(broadcast)}" + ) + return broadcast + def ping(self) -> str: return "ok" @@ -99,7 +118,14 @@ def create_engine( ) raise - def call(self, method: str, *args, engine_name: str | None = None, **kwargs) -> Any: + def call( + self, + method: str, + *args, + engine_name: str | None = None, + rpc_meta: dict[str, Any] | None = None, + **kwargs, + ) -> Any: self.logger.debug( f"Calling {method} on engine '{engine_name}' with arguments {args=} {kwargs=}" ) @@ -121,21 +147,12 @@ def call(self, method: str, *args, engine_name: str | None = None, **kwargs) -> args = RTensor.localize(raw_args) kwargs = RTensor.localize(raw_kwargs) - # Broadcast args when engine is a TrainEngine and has been initialized try: - if isinstance(engine, TrainEngine) and engine.initialized: + should_broadcast = self._should_broadcast_payload( + engine=engine, rpc_meta=rpc_meta + ) + if should_broadcast: device = self._get_device() - - raw_args = broadcast_tensor_container( - tensor_container_to(raw_args, self._get_device()), - src_rank=engine.current_data_parallel_head(), - group=engine.context_and_model_parallel_group, - ) - raw_kwargs = broadcast_tensor_container( - tensor_container_to(raw_kwargs, self._get_device()), - src_rank=engine.current_data_parallel_head(), - group=engine.context_and_model_parallel_group, - ) args = tensor_container_to(args, device) args = broadcast_tensor_container( args, @@ -157,6 +174,16 @@ def call(self, method: str, *args, engine_name: str | None = None, **kwargs) -> try: fn = getattr(engine, method) + # Re-establish current device in RPC execution context before + # invoking engine methods that may issue object collectives. 
+ if ( + isinstance(engine, TrainEngine) + and engine.initialized + and self._get_device().type != "cpu" + ): + from areal.infra.platforms import current_platform + + current_platform.set_device(current_platform.current_device()) result = fn(*args, **kwargs) if isinstance(result, Future): result = result.result() diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index 8c1b9a7a35..f97361bbe2 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -50,6 +50,7 @@ format_hostport, gethostip, ) +from areal.utils.offload import get_tms_env_vars logger = logging.getLogger("LocalScheduler") @@ -102,6 +103,7 @@ def __init__( experiment_name: str | None = None, trial_name: str | None = None, fileroot: str | None = None, + enable_tms_offload: bool | None = None, name_resolve_type: str = "nfs", nfs_record_root: str = "/tmp/areal/name_resolve", etcd3_addr: str = "localhost:2379", @@ -113,10 +115,12 @@ def __init__( self.experiment_name = experiment_name self.trial_name = trial_name self.fileroot = fileroot + self.enable_tms_offload = bool(enable_tms_offload) if exp_config is not None: self.experiment_name = exp_config.experiment_name self.trial_name = exp_config.trial_name self.fileroot = exp_config.cluster.fileroot + self.enable_tms_offload = exp_config.enable_offload # name_resolve config (exp_config overwrites direct params) self.name_resolve_config = NameResolveConfig( @@ -727,6 +731,9 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: ) env.update(thread_env) + if self.enable_tms_offload: + env.update(get_tms_env_vars()) + if scheduling.env_vars: env.update(scheduling.env_vars) @@ -1251,6 +1258,7 @@ def call_engine( method: str, engine_name: str | None = None, *args, + rpc_meta: dict[str, Any] | None = None, http_timeout: float = 7200.0, max_retries: int = 3, retry_delay: float = 1.0, @@ -1308,6 +1316,7 @@ def call_engine( "engine_name": engine_name, "args": serialized_args, "kwargs": serialized_kwargs, + "rpc_meta": rpc_meta, } # Retry logic with exponential backoff @@ -1378,6 +1387,7 @@ async def async_call_engine( method: str, engine_name: str | None = None, *args, + rpc_meta: dict[str, Any] | None = None, http_timeout: float = 7200.0, max_retries: int = 3, retry_delay: float = 1.0, @@ -1437,6 +1447,7 @@ async def async_call_engine( "engine_name": engine_name, "args": serialized_args, "kwargs": serialized_kwargs, + "rpc_meta": rpc_meta, } last_error = None diff --git a/areal/infra/scheduler/ray.py b/areal/infra/scheduler/ray.py index 98a8d4b346..6ea45ed85c 100644 --- a/areal/infra/scheduler/ray.py +++ b/areal/infra/scheduler/ray.py @@ -37,6 +37,7 @@ ray_resource_type, ) from areal.utils import logging +from areal.utils.offload import get_tms_env_vars logger = logging.getLogger("RayScheduler") @@ -61,6 +62,9 @@ def __init__( ): self.exp_config = exp_config self.startup_timeout = startup_timeout + self.enable_tms_offload = False + if exp_config is not None: + self.enable_tms_offload = exp_config.enable_offload self._workers: dict[str, list[RayWorkerInfo]] = defaultdict(list) self._worker_info_by_id: dict[str, RayWorkerInfo] = {} @@ -120,6 +124,8 @@ def _build_env_vars(self, spec: SchedulingSpec) -> dict[str, str]: if spec.env_vars: additional_envs_str = ",".join(f"{k}={v}" for k, v in spec.env_vars.items()) env = get_env_vars(additional_envs_str) + if self.enable_tms_offload: + env.update(get_tms_env_vars()) thread_env = get_thread_env_vars( cpus_per_task=spec.cpu, existing_env_vars=spec.env_vars, @@ -632,6 +638,7 @@ def 
call_engine( method: str, engine_name: str | None = None, *args, + rpc_meta: dict[str, Any] | None = None, http_timeout: float = 7200.0, max_retries: int = 3, retry_delay: float = 1.0, @@ -647,7 +654,11 @@ def call_engine( try: # Pass engine_name to support multiple engines per worker (colocation) ref = wi.actor.call.remote( - method, *args, engine_name=engine_name, **kwargs + method, + *args, + engine_name=engine_name, + rpc_meta=rpc_meta, + **kwargs, ) result = ray.get(ref, timeout=http_timeout) if attempt > 1: @@ -687,6 +698,7 @@ async def async_call_engine( method: str, engine_name: str | None = None, *args, + rpc_meta: dict[str, Any] | None = None, http_timeout: float = 7200.0, max_retries: int = 3, retry_delay: float = 1.0, @@ -702,7 +714,11 @@ async def async_call_engine( try: # Pass engine_name to support multiple engines per worker (colocation) ref = wi.actor.call.remote( - method, *args, engine_name=engine_name, **kwargs + method, + *args, + engine_name=engine_name, + rpc_meta=rpc_meta, + **kwargs, ) result = await ref if attempt > 1: diff --git a/areal/infra/scheduler/slurm.py b/areal/infra/scheduler/slurm.py index 16ef5402d9..6c13151ab7 100644 --- a/areal/infra/scheduler/slurm.py +++ b/areal/infra/scheduler/slurm.py @@ -1447,6 +1447,7 @@ def call_engine( method: str, engine_name: str | None = None, *args, + rpc_meta: dict[str, Any] | None = None, http_timeout: float = 7200.0, max_retries: int = 3, retry_delay: float = 1.0, @@ -1502,6 +1503,7 @@ def call_engine( "engine_name": engine_name, "args": serialized_args, "kwargs": serialized_kwargs, + "rpc_meta": rpc_meta, } port = int(worker_info.worker.worker_ports[0]) @@ -1578,6 +1580,7 @@ async def async_call_engine( method: str, engine_name: str | None = None, *args, + rpc_meta: dict[str, Any] | None = None, http_timeout: float = 7200.0, max_retries: int = 3, retry_delay: float = 1.0, @@ -1633,6 +1636,7 @@ async def async_call_engine( "engine_name": engine_name, "args": serialized_args, "kwargs": serialized_kwargs, + "rpc_meta": rpc_meta, } port = int(worker_info.worker.worker_ports[0]) diff --git a/areal/infra/utils/concurrent.py b/areal/infra/utils/concurrent.py index a6268670b6..a175e6ffb5 100644 --- a/areal/infra/utils/concurrent.py +++ b/areal/infra/utils/concurrent.py @@ -5,6 +5,7 @@ import threading import weakref from functools import partial +from typing import Any from areal.utils import logging @@ -69,6 +70,21 @@ def run_async_task(func, *args, **kwargs): return asyncio.run(func(*args, **kwargs)) +def call_maybe_async(method, *args, **kwargs) -> None: + """Call a callable and await it if it returns an awaitable. + + This helper lets sync call-sites invoke methods that may be implemented as + either sync or async without branching at every call site. 
+ """ + result: Any = method(*args, **kwargs) + if inspect.isawaitable(result): + + async def _wait_result() -> None: + await result + + run_async_task(_wait_result) + + # ============================================================================ # Event Loop Cleanup Utilities # ============================================================================ diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 1c24a549c6..aba803c349 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -356,13 +356,19 @@ def _ppo_update(self, data: dict[str, Any]) -> None: class PPOActorController(TrainController): def compute_logp(self, *args, **kwargs): - return self._custom_function_call("compute_logp", *args, **kwargs) + return self._custom_function_call( + "compute_logp", *args, rpc_meta={"broadcast": True}, **kwargs + ) def compute_advantages(self, *args, **kwargs): - return self._custom_function_call("compute_advantages", *args, **kwargs) + return self._custom_function_call( + "compute_advantages", *args, rpc_meta={"broadcast": True}, **kwargs + ) def ppo_update(self, *args, **kwargs) -> None: - self._custom_function_call("ppo_update", *args, **kwargs) + self._custom_function_call( + "ppo_update", *args, rpc_meta={"broadcast": True}, **kwargs + ) def grpo_loss_fn( diff --git a/areal/trainer/ppo/critic.py b/areal/trainer/ppo/critic.py index e24d879b55..887c7b73c6 100644 --- a/areal/trainer/ppo/critic.py +++ b/areal/trainer/ppo/critic.py @@ -70,10 +70,14 @@ def _ppo_update(self, data: dict[str, Any]) -> None: class PPOCriticController(TrainController): def compute_values(self, *args, **kwargs): - return self._custom_function_call("compute_values", *args, **kwargs) + return self._custom_function_call( + "compute_values", *args, rpc_meta={"broadcast": True}, **kwargs + ) def ppo_update(self, *args, **kwargs): - self._custom_function_call("ppo_update", *args, **kwargs) + self._custom_function_call( + "ppo_update", *args, rpc_meta={"broadcast": True}, **kwargs + ) def ppo_loss_fn( diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 41b4524f6c..d6f8dc4720 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -43,6 +43,7 @@ from areal.infra.data_service import DataController from areal.infra.data_service.controller.config import DataServiceConfig from areal.infra.data_service.rdataset import RDataset +from areal.infra.utils.concurrent import call_maybe_async from areal.utils import logging, perf_tracer, seeding, stats_tracker from areal.utils.dataloader import create_dataloader from areal.utils.environ import is_single_controller @@ -127,6 +128,17 @@ def __init__( self.rollout_alloc = ModelAllocation.from_str( config.rollout.backend, name="rollout" ) + self._should_offload_rollout = self._is_actor_rollout_colocated(config) + self._should_offload_actor = ( + self._should_offload_rollout or config.actor.offload + ) + self._should_offload_critic = ( + config.critic is not None and config.critic.offload + ) + self._should_offload_ref = config.ref is not None and config.ref.offload + self._should_offload_teacher = ( + config.teacher is not None and config.teacher.offload + ) # Validate config before proceeding with weight initialization self._validate_cfg() @@ -352,6 +364,136 @@ def __init__( ) self._config_perf_tracer() + self._apply_initial_offload_policy() + + @staticmethod + def _is_colocation(strategy: SchedulingStrategy | None) -> bool: + if strategy is None: + return False + return strategy.type in ( + 
SchedulingStrategyType.colocation, + SchedulingStrategyType.colocation.value, + "colocation", + ) + + def _is_actor_rollout_colocated(self, config: PPOConfig) -> bool: + actor_s = config.actor.scheduling_strategy + rollout_s = config.rollout.scheduling_strategy + return (self._is_colocation(actor_s) and actor_s.target == "rollout") or ( + self._is_colocation(rollout_s) and rollout_s.target == "actor" + ) + + def _onload_model(self, engine, role: str) -> None: + with ( + stats_tracker.record_timing(f"{role}_onload"), + perf_tracer.trace_scope( + f"train.{role}_onload", + category=Category.IO, + ), + ): + engine.onload() + + def _offload_model(self, engine, role: str) -> None: + with ( + stats_tracker.record_timing(f"{role}_offload"), + perf_tracer.trace_scope( + f"train.{role}_offload", + category=Category.IO, + ), + ): + engine.offload() + + def _offload_rollout(self, is_eval: bool = False): + rollout = self.rollout if not is_eval else self.eval_rollout + if rollout is None: + return + + with ( + stats_tracker.record_timing("rollout_pause"), + perf_tracer.trace_scope( + "train.rollout_pause", + category=Category.INSTR, + ), + ): + rollout.pause() + + with ( + stats_tracker.record_timing("rollout_pause_generation"), + perf_tracer.trace_scope( + "train.rollout_pause_generation", + category=Category.INSTR, + ), + ): + call_maybe_async(rollout.pause_generation) + + with ( + stats_tracker.record_timing("rollout_offload"), + perf_tracer.trace_scope( + "train.rollout_offload", + category=Category.IO, + ), + ): + rollout.offload() + + def _onload_rollout(self, is_eval: bool = False) -> None: + cleanup_error: Exception | None = None + + rollout = self.rollout if not is_eval else self.eval_rollout + if rollout is None: + return + + try: + with ( + stats_tracker.record_timing("rollout_onload"), + perf_tracer.trace_scope( + "train.rollout_onload", + category=Category.IO, + ), + ): + rollout.onload() + except Exception as exc: # noqa: BLE001 + cleanup_error = exc + + try: + with ( + stats_tracker.record_timing("rollout_continue_generation"), + perf_tracer.trace_scope( + "train.rollout_continue_generation", + category=Category.INSTR, + ), + ): + call_maybe_async(rollout.continue_generation) + except Exception as exc: # noqa: BLE001 + if cleanup_error is None: + cleanup_error = exc + + try: + with ( + stats_tracker.record_timing("rollout_resume"), + perf_tracer.trace_scope( + "train.rollout_resume", + category=Category.INSTR, + ), + ): + rollout.resume() + except Exception as exc: # noqa: BLE001 + if cleanup_error is None: + cleanup_error = exc + + if cleanup_error is not None: + raise cleanup_error + + def _apply_initial_offload_policy(self) -> None: + if self._should_offload_rollout: + self._offload_rollout() + if self._should_offload_ref: + self._offload_model(self.ref, role="ref") + if self._should_offload_critic: + self._offload_model(self.critic, role="critic") + if self._should_offload_teacher: + self._offload_model(self.teacher, role="teacher") + if self._should_offload_actor: + self._offload_model(self.actor, role="actor") def train( self, @@ -399,6 +541,8 @@ def train( epoch = global_step // steps_per_epoch step = global_step % steps_per_epoch + if self._should_offload_rollout: + self._onload_rollout() with ( stats_tracker.record_timing("rollout"), perf_tracer.trace_scope( @@ -418,8 +562,12 @@ def train( group_size=config.gconfig.n_samples, dynamic_bs=self.config.dynamic_bs, ) + if self._should_offload_rollout: + self._offload_rollout() if self.critic is not None: + if 
self._should_offload_critic: + self._onload_model(self.critic, role="critic") with ( stats_tracker.record_timing("critic_values"), perf_tracer.trace_scope( @@ -432,22 +580,12 @@ def train( for traj, v in zip(rollout_batch, values): traj["values"] = v self.critic.get_device_stats().log("critic values") - - if config.actor.should_compute_prox_logp(): - with ( - stats_tracker.record_timing("recompute_logp"), - perf_tracer.trace_scope( - "train.recompute_logp", - category=Category.COMPUTE, - args={"global_step": global_step}, - ), - ): - prox_logps = self.actor.compute_logp(rollout_batch) - for traj, logp in zip(rollout_batch, prox_logps): - traj["prox_logp"] = logp - self.actor.get_device_stats().log("recompute logp") + if self._should_offload_critic: + self._offload_model(self.critic, role="critic") if self.ref is not None: + if self._should_offload_ref: + self._onload_model(self.ref, role="ref") with ( stats_tracker.record_timing("ref_logp"), perf_tracer.trace_scope( @@ -460,8 +598,12 @@ def train( for traj, logp in zip(rollout_batch, ref_logps): traj["ref_logp"] = logp self.ref.get_device_stats().log("ref logp") + if self._should_offload_ref: + self._offload_model(self.ref, role="ref") if self.teacher is not None: + if self._should_offload_teacher: + self._onload_model(self.teacher, role="teacher") with ( stats_tracker.record_timing("teacher_logp"), perf_tracer.trace_scope( @@ -478,6 +620,24 @@ def train( self.config.teacher.distill_loss_weight ) self.teacher.get_device_stats().log("teacher logp") + if self._should_offload_teacher: + self._offload_model(self.teacher, role="teacher") + + if self._should_offload_actor: + self._onload_model(self.actor, role="actor") + if config.actor.should_compute_prox_logp(): + with ( + stats_tracker.record_timing("recompute_logp"), + perf_tracer.trace_scope( + "train.recompute_logp", + category=Category.COMPUTE, + args={"global_step": global_step}, + ), + ): + prox_logps = self.actor.compute_logp(rollout_batch) + for traj, logp in zip(rollout_batch, prox_logps): + traj["prox_logp"] = logp + self.actor.get_device_stats().log("recompute logp") with ( stats_tracker.record_timing("compute_advantage"), @@ -504,8 +664,12 @@ def train( self.actor.ppo_update(adv_batch) self.actor.step_lr_scheduler() self.actor.get_device_stats().log("ppo update") + if self._should_offload_actor: + self._offload_model(self.actor, role="actor") if self.critic is not None: + if self._should_offload_critic: + self._onload_model(self.critic, role="critic") with ( stats_tracker.record_timing("critic_train_step"), perf_tracer.trace_scope( @@ -517,6 +681,8 @@ def train( self.critic.ppo_update(adv_batch) self.critic.step_lr_scheduler() self.critic.get_device_stats().log("ppo critic update") + if self._should_offload_critic: + self._offload_model(self.critic, role="critic") # pause inference for updating weights, save, and evaluation self.rollout.pause() @@ -563,6 +729,8 @@ def train( epoch=epoch, epoch_step=step, global_step=global_step ) + if self._should_offload_rollout: + self._onload_rollout(is_eval=True) with ( stats_tracker.record_timing("eval"), perf_tracer.trace_scope( @@ -578,6 +746,8 @@ def train( epoch_step=step, global_step=global_step, ) + if self._should_offload_rollout: + self._offload_rollout(is_eval=True) with ( stats_tracker.record_timing("clear_batches"), @@ -973,6 +1143,31 @@ def _validate_cfg(self): """validate config for incompatible settings before weight initialization, to avoid wasted resources on spawning workers and loading models.""" rollout_backend = 
self.rollout_alloc.backend actor_backend = self.actor_alloc.backend + requires_train_engine_offload = any( + ( + self._should_offload_rollout, + self._should_offload_actor, + self._should_offload_critic, + self._should_offload_ref, + self._should_offload_teacher, + ) + ) + + if requires_train_engine_offload and not self.config.enable_offload: + raise ValueError( + "enable_offload must be True when colocation scheduling or train-engine " + "offload is enabled. Please set enable_offload=True." + ) + + if ( + self._is_actor_rollout_colocated(self.config) + and self.config.actor.weight_update_mode != "disk" + ): + raise ValueError( + "weight_update_mode must be 'disk' when colocation scheduling is enabled. " + "Please set actor.weight_update_mode=disk." + ) + if rollout_backend == "vllm" and self.config.rollout.return_routed_experts: raise ValueError( "return_routed_experts is only supported with SGLang backend. " diff --git a/areal/trainer/rw/rw_engine.py b/areal/trainer/rw/rw_engine.py index 5331732082..2d43e8ac07 100644 --- a/areal/trainer/rw/rw_engine.py +++ b/areal/trainer/rw/rw_engine.py @@ -71,13 +71,21 @@ def _evaluate_rw(self, data: dict[str, Any]) -> None: class RWController(TrainController): def train_rw(self, *args, **kwargs): - self._custom_function_call("train_rw", *args, **kwargs) + self._custom_function_call( + "train_rw", *args, rpc_meta={"broadcast": True}, **kwargs + ) def evaluate_rw(self, *args, **kwargs): # rw_modeling_collate_fn produces 2 sequences (chosen + rejected) per # example; group_size=2 keeps each pair on the same DP rank. args, kwargs = self._pad_eval_dispatch_args(args, kwargs, group_size=2) - self._custom_function_call("evaluate_rw", *args, group_size=2, **kwargs) + self._custom_function_call( + "evaluate_rw", + *args, + group_size=2, + rpc_meta={"broadcast": True}, + **kwargs, + ) def compute_rw_loss(scores: torch.Tensor, input_: dict[str, Any]) -> torch.Tensor: diff --git a/areal/trainer/sft/lm_engine.py b/areal/trainer/sft/lm_engine.py index 10883c4a0c..fa9cd42e30 100644 --- a/areal/trainer/sft/lm_engine.py +++ b/areal/trainer/sft/lm_engine.py @@ -43,11 +43,15 @@ def _evaluate_lm(self, data: dict[str, Any]) -> None: class LMController(TrainController): def train_lm(self, *args, **kwargs): - self._custom_function_call("train_lm", *args, **kwargs) + self._custom_function_call( + "train_lm", *args, rpc_meta={"broadcast": True}, **kwargs + ) def evaluate_lm(self, *args, **kwargs): args, kwargs = self._pad_eval_dispatch_args(args, kwargs, group_size=1) - self._custom_function_call("evaluate_lm", *args, **kwargs) + self._custom_function_call( + "evaluate_lm", *args, rpc_meta={"broadcast": True}, **kwargs + ) def compute_packed_sft_loss( diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 0cd05c3452..f519635cf4 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -353,6 +353,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | | `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. 
| | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | @@ -420,6 +421,7 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | | `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | @@ -460,6 +462,7 @@ Core configuration for model training, including optimization and backend settin | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | | `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | @@ -1075,6 +1078,7 @@ Configuration class: TeacherConfig | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | | `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index d2b10d77c3..252830b7f8 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -351,6 +351,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | | `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | @@ -418,6 +419,7 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | | `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. 
| | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | @@ -458,6 +460,7 @@ Core configuration for model training, including optimization and backend settin | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | | `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | @@ -1073,6 +1076,7 @@ Configuration class: TeacherConfig | `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | | `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | | `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | | `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | | `lora_rank` | integer | `32` | lora rank | | `lora_alpha` | integer | `16` | lora alpha | diff --git a/tests/test_eval_dispatch.py b/tests/test_eval_dispatch.py index ec96af3659..30c40706fa 100644 --- a/tests/test_eval_dispatch.py +++ b/tests/test_eval_dispatch.py @@ -389,7 +389,7 @@ def _capture_call(self, method: str, *args, **kwargs): assert captured["method"] == "evaluate_lm" padded_items = cast(list[dict[str, object]], captured["args"][0]) assert len(padded_items) == 8 - assert captured["kwargs"] == {} + assert captured["kwargs"] == {"rpc_meta": {"broadcast": True}} def test_rw_controller_evaluate_rw_explicitly_pads_pairs(self): controller = RWController.__new__(RWController) diff --git a/tests/test_examples.py b/tests/test_examples.py index 1887af70b9..3685bd3822 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -385,6 +385,52 @@ def test_gsm8k_ppo(tmp_path_factory): assert success, "GSM8K PPO example failed" +@pytest.mark.sglang +@pytest.mark.gpu +def test_gsm8k_ppo_colocate(tmp_path_factory): + experiments_path = tmp_path_factory.mktemp("experiments") + name_resolve_path = tmp_path_factory.mktemp("name_resolve") + model_path = get_model_path( + "/storage/openpsi/models/Qwen__Qwen3-0.6B", "Qwen/Qwen3-0.6B" + ) + dataset_path = get_dataset_path("/storage/openpsi/data/gsm8k", "openai/gsm8k") + + example_file = "examples/math/gsm8k_rl.py" + config_name = "examples/math/gsm8k_ppo.yaml" + success = run_async_task( + run_example, + example_file, + config_name, + "rollout.backend=sglang:d2", + "actor.backend=fsdp:d2", + "+actor.weight_update_mode=disk", + "+rollout.scheduling_strategy.type=colocation", + "+rollout.scheduling_strategy.target=actor", + "critic.scheduling_strategy.type=colocation", + "critic.scheduling_strategy.target=actor", + "ref.scheduling_strategy.type=colocation", + "ref.scheduling_strategy.target=actor", + "enable_offload=True", + "gconfig.n_samples=2", + "gconfig.max_new_tokens=256", + "sglang.mem_fraction_static=0.3", + "vllm.gpu_memory_utilization=0.3", + "actor.mb_spec.max_tokens_per_mb=1024", + "critic.mb_spec.max_tokens_per_mb=1024", + "train_dataset.batch_size=16", + "valid_dataset.batch_size=16", + f"train_dataset.path={dataset_path}", + f"valid_dataset.path={dataset_path}", 
+ "cluster.n_gpus_per_node=1", + f"cluster.fileroot={str(experiments_path)}", + f"cluster.name_resolve.nfs_record_root={str(name_resolve_path)}", + f"actor.path={model_path}", + f"critic.path={model_path}", + "scheduler.type=local", + ) + assert success, "GSM8K PPO colocated example failed" + + @pytest.mark.ci @pytest.mark.parametrize( "rollout_backend,actor_backend",