areal/api/cli_args.py (8 additions, 0 deletions)

@@ -1081,6 +1081,14 @@ class TrainEngineConfig:
     archon: ArchonEngineConfig = field(default_factory=ArchonEngineConfig)
     megatron: MegatronEngineConfig = field(default_factory=MegatronEngineConfig)
 
+    # offload
+    offload: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to offload model parameters and optimizer states to CPU. "
+        },
+    )
+
     # Lora
     use_lora: bool = field(
         default=False,
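A minimal usage sketch of the new flag. Constructing `TrainEngineConfig` directly is illustrative (it assumes the remaining fields all have workable defaults), and the link from `offload` to the engines' `is_offload` attribute seen in the engine diffs below is inferred from context:

```python
from areal.api.cli_args import TrainEngineConfig

# Enable CPU offload of parameters and optimizer states; every other field
# keeps its dataclass default (assumed sufficient for construction here).
cfg = TrainEngineConfig(offload=True)

# Downstream, the train engines presumably surface this as `self.is_offload`
# and offload between steps, onloading around offload-unsafe operations.
assert cfg.offload
```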
areal/api/scheduler_api.py (6 additions, 0 deletions)

@@ -226,6 +226,7 @@ def call_engine(
         method: str,
         engine_name: str | None = None,
         *args,
+        rpc_meta: dict[str, Any] | None = None,
         **kwargs,
     ) -> Any:
         """Call a method on an engine instance running on a worker (data plane operation).
@@ -243,6 +244,8 @@ def call_engine(
             Defaults to worker_id if not specified.
         *args
             Positional arguments to pass to the method
+        rpc_meta : dict[str, Any] | None, optional
+            RPC metadata, by default None
         **kwargs
             Keyword arguments to pass to the method
 
@@ -269,6 +272,7 @@ async def async_call_engine(
         method: str,
         engine_name: str | None = None,
         *args,
+        rpc_meta: dict[str, Any] | None = None,
         **kwargs,
     ) -> Any:
         """Async version of call_engine for calling engine methods asynchronously.
@@ -286,6 +290,8 @@ async def async_call_engine(
             Defaults to worker_id if not specified.
         *args
             Positional arguments to pass to the method
+        rpc_meta : dict[str, Any] | None, optional
+            RPC metadata, by default None
         **kwargs
             Keyword arguments to pass to the method
 
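A hedged sketch of a call site. Only the signature change above is from the diff; the scheduler object, method name, worker id, and metadata keys are illustrative. Note that `rpc_meta` sits after `*args`, so it must be passed by keyword:

```python
from typing import Any

def save_on_worker(scheduler, worker_id: str) -> Any:
    # `scheduler` is assumed to implement the scheduler API above.
    return scheduler.call_engine(
        worker_id,                    # positional, per the docstring
        "save",                       # engine method to invoke
        rpc_meta={"timeout": 600.0},  # keyword-only, opaque per-call metadata
    )
```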
areal/engine/fsdp_engine.py (53 additions, 36 deletions)

@@ -7,7 +7,7 @@
 import time
 from collections.abc import Callable, Iterator
 from concurrent.futures import Future
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from datetime import datetime
 from typing import TYPE_CHECKING, Any
 
@@ -485,20 +485,14 @@ def prepare_batch(
 
     def update_weights(self, meta: WeightUpdateMeta):
         self._check_rollout_engine_connected()
-        if meta.type == "xccl":
-            assert self.weight_update_group_initialized
-            # In offload mode, wakes up parameters as needed to perform the update.
-            tms_context = (
-                torch_memory_saver.disable()
-                if self.is_offload and not torch.version.hip
-                else nullcontext()
-            )
-            with tms_context:
+        with self._offload_aware_context():
+            if meta.type == "xccl":
+                assert self.weight_update_group_initialized
                 self._update_weights_from_distributed(meta)
-        elif meta.type == "disk":
-            self._update_weights_from_disk(meta)
-        else:
-            raise ValueError(f"Unknown weight update type {meta.type}")
+            elif meta.type == "disk":
+                self._update_weights_from_disk(meta)
+            else:
+                raise ValueError(f"Unknown weight update type {meta.type}")
 
     def set_version(self, version: int):
         self._version = version
@@ -507,31 +501,47 @@ def get_version(self) -> int:
         return self._version
 
     def save(self, meta: SaveLoadMeta):
-        if meta.weight_format == "hf":
-            self._save_model_to_hf(meta.path, meta.tokenizer, meta.processor)
-        elif meta.weight_format == "dcp":
-            self._save_to_dcp(meta.path, meta.with_optim)
-        else:
-            raise ValueError(f"Unknown weight format {meta.weight_format}. ")
-
-        if meta.with_optim and meta.weight_format == "hf":
-            self._save_optimizer_state(meta.path)
+        with self._offload_aware_context():
+            if meta.weight_format == "hf":
+                self._save_model_to_hf(meta.path, meta.tokenizer, meta.processor)
+            elif meta.weight_format == "dcp":
+                self._save_to_dcp(meta.path, meta.with_optim)
+            else:
+                raise ValueError(f"Unknown weight format {meta.weight_format}. ")
+
+            if meta.with_optim and meta.weight_format == "hf":
+                self._save_optimizer_state(meta.path)
 
     def load(self, meta: SaveLoadMeta):
-        if meta.weight_format == "hf":
-            self._load_model_from_hf(meta.path)
-        elif meta.weight_format == "dcp":
-            self._load_from_dcp(meta.path, meta.with_optim)
-        else:
-            raise ValueError(f"Unknown weight format {meta.weight_format}. ")
-
-        if meta.with_optim and meta.weight_format == "hf":
-            self._load_optimizer_state(meta.path)
+        with self._offload_aware_context():
+            if meta.weight_format == "hf":
+                self._load_model_from_hf(meta.path)
+            elif meta.weight_format == "dcp":
+                self._load_from_dcp(meta.path, meta.with_optim)
+            else:
+                raise ValueError(f"Unknown weight format {meta.weight_format}. ")
+
+            if meta.with_optim and meta.weight_format == "hf":
+                self._load_optimizer_state(meta.path)
 
         # Checkpoint load replaces optimizer state tensor objects, losing
         # pinning and normalization established by PerLayerOptimWrapper.__init__.
         if meta.with_optim and self._per_layer_optim_wrapper is not None:
             self._per_layer_optim_wrapper.refresh_states()
 
+    @contextmanager
+    def _offload_aware_context(self):
+        """Temporarily onload parameters for offload-unsafe operations."""
+        if not self.is_offload:
+            with nullcontext():
+                yield
+            return
+
+        self.onload()
+        try:
+            yield
+        finally:
+            self.offload()
+
     def optimizer_zero_grad(self):
         assert self.optimizer is not None
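Both engines add the same helper, so the pattern is worth seeing in isolation: a guard that is a no-op when offload is disabled, and otherwise onloads before the body and re-offloads afterwards, even if the body raises. A self-contained sketch (the `Engine` stub is hypothetical; the real engines back `onload()`/`offload()` with torch_memory_saver):

```python
from contextlib import contextmanager

class Engine:
    """Hypothetical stand-in for the FSDP/Megatron engines above."""

    def __init__(self, is_offload: bool):
        self.is_offload = is_offload

    def onload(self) -> None:
        print("onload: weights back on GPU")

    def offload(self) -> None:
        print("offload: weights back on CPU")

    @contextmanager
    def _offload_aware_context(self):
        if not self.is_offload:
            yield  # offload disabled: nothing to wrap
            return
        self.onload()
        try:
            yield
        finally:
            # Guarantees re-offload even if save/load/update raises.
            self.offload()

engine = Engine(is_offload=True)
with engine._offload_aware_context():
    print("offload-unsafe operation runs here with weights onloaded")
```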
@@ -749,13 +759,20 @@ def process_output(logits: torch.Tensor, ctx_dict: dict[str, Any]) -> None:
         return split_batch(result, meta)
 
     def export_stats(self) -> dict[str, float]:
-        return stats_tracker.export_all(reduce_group=self.data_parallel_group)
+        with self._offload_aware_context():
+            return stats_tracker.export_all(
+                reduce_group=self.data_parallel_group,
+            )
 
     def offload(self) -> None:
         """Offload model memory to CPU using torch_memory_saver.
 
         Ref: https://github.com/THUDM/slime/blob/main/slime/backends/fsdp_utils/actor.py
         """
+        if not is_tms_enabled():
+            raise RuntimeError(
+                "torch_memory_saver requires `enable_offload=True` in yaml config."
+            )
 
         self.get_device_stats().log("before offload model")
 
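The docstring above points at torch_memory_saver, which `onload()`/`offload()` build on. For orientation, a rough sketch of that library's pause/resume mechanism, based on its documented upstream API rather than this repo's wrapper code (`is_tms_enabled` presumably checks that the saver was activated at process launch):

```python
import torch
from torch_memory_saver import torch_memory_saver

# Allocations made inside a region are tracked by the memory saver.
with torch_memory_saver.region():
    weights = torch.randn(1024, 1024, device="cuda")

torch_memory_saver.pause()   # releases the region's GPU memory ("offload")
torch_memory_saver.resume()  # re-binds memory at the same virtual addresses ("onload")
```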
areal/engine/megatron_engine.py (70 additions, 49 deletions)

@@ -7,7 +7,7 @@
 import os
 from collections.abc import Callable, Iterator
 from concurrent.futures import Future
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from datetime import datetime
 from typing import TYPE_CHECKING, Any
 
@@ -569,20 +569,14 @@ def prepare_batch(
 
     def update_weights(self, meta: WeightUpdateMeta):
         self._check_rollout_engine_connected()
-        if meta.type == "xccl":
-            assert self.weight_update_group_initialized
-            # In offload mode, wakes up parameters as needed to perform the update.
-            tms_context = (
-                torch_memory_saver.disable()
-                if self.is_offload and not torch.version.hip
-                else nullcontext()
-            )
-            with tms_context:
+        with self._offload_aware_context():
+            if meta.type == "xccl":
+                assert self.weight_update_group_initialized
                 self._update_weights_from_distributed(meta)
-        elif meta.type == "disk":
-            self._update_weights_from_disk(meta)
-        else:
-            raise ValueError(f"Unknown weight update type {meta.type}")
+            elif meta.type == "disk":
+                self._update_weights_from_disk(meta)
+            else:
+                raise ValueError(f"Unknown weight update type {meta.type}")
 
     def set_version(self, version: int):
         self._version = version
@@ -591,45 +585,65 @@ def get_version(self) -> int:
         return self._version
 
     def save(self, meta: SaveLoadMeta):
-        if meta.weight_format == "hf":
-            if meta.with_optim:
-                raise ValueError(
-                    "HF format does not support optimizer state saving, please use DCP format instead."
-                )
-            self._save_model_to_hf(
-                meta.path,
-                tokenizer=meta.tokenizer,
-                processor=meta.processor,
-                base_model_path=meta.base_model_path,
-            )
-        elif meta.weight_format == "dcp":
-            if self.checkpointer is None:
-                raise NotImplementedError(
-                    "DCP checkpoint save is not available for this Megatron configuration "
-                    "(e.g., LoRA path without distributed optimizer support). "
-                    "Please use weight_format='hf' for adapter/full-model export."
-                )
-            self.checkpointer.save_checkpoint(meta.path, with_optimizer=meta.with_optim)
-        else:
-            raise ValueError(f"Unknown weight format {meta.weight_format}. ")
+        with self._offload_aware_context():
+            if meta.weight_format == "hf":
+                if meta.with_optim:
+                    raise ValueError(
+                        "HF format does not support optimizer state saving, please use DCP format instead."
+                    )
+                self._save_model_to_hf(
+                    meta.path,
+                    tokenizer=meta.tokenizer,
+                    processor=meta.processor,
+                    base_model_path=meta.base_model_path,
+                )
+            elif meta.weight_format == "dcp":
+                if self.checkpointer is None:
+                    raise NotImplementedError(
+                        "DCP checkpoint save is not available for this Megatron configuration "
+                        "(e.g., LoRA path without distributed optimizer support). "
+                        "Please use weight_format='hf' for adapter/full-model export."
+                    )
+                self.checkpointer.save_checkpoint(
+                    meta.path, with_optimizer=meta.with_optim
+                )
+            else:
+                raise ValueError(f"Unknown weight format {meta.weight_format}. ")
 
     def load(self, meta: SaveLoadMeta):
-        if meta.weight_format == "hf":
-            if meta.with_optim:
-                raise ValueError(
-                    "HF format does not support optimizer state loading, please use DCP format instead."
-                )
-            self._load_model_from_hf(meta.path)
-        elif meta.weight_format == "dcp":
-            if self.checkpointer is None:
-                raise NotImplementedError(
-                    "DCP checkpoint load is not available for this Megatron configuration "
-                    "(e.g., LoRA path without distributed optimizer support). "
-                    "Please use weight_format='hf' for adapter/full-model load."
-                )
-            self.checkpointer.load_checkpoint(meta.path, with_optimizer=meta.with_optim)
-        else:
-            raise ValueError(f"Unknown weight format {meta.weight_format}. ")
+        with self._offload_aware_context():
+            if meta.weight_format == "hf":
+                if meta.with_optim:
+                    raise ValueError(
+                        "HF format does not support optimizer state loading, please use DCP format instead."
+                    )
+                self._load_model_from_hf(meta.path)
+            elif meta.weight_format == "dcp":
+                if self.checkpointer is None:
+                    raise NotImplementedError(
+                        "DCP checkpoint load is not available for this Megatron configuration "
+                        "(e.g., LoRA path without distributed optimizer support). "
+                        "Please use weight_format='hf' for adapter/full-model load."
+                    )
+                self.checkpointer.load_checkpoint(
+                    meta.path, with_optimizer=meta.with_optim
+                )
+            else:
+                raise ValueError(f"Unknown weight format {meta.weight_format}. ")
+
+    @contextmanager
+    def _offload_aware_context(self):
+        """Temporarily onload parameters for offload-unsafe operations."""
+        if not self.is_offload:
+            with nullcontext():
+                yield
+            return
+
+        self.onload()
+        try:
+            yield
+        finally:
+            self.offload()
 
     def optimizer_zero_grad(self):
         assert self.optimizer is not None, "Optimizer is not initialized."
@@ -868,7 +882,10 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None:
         return split_batch(res, meta)
 
     def export_stats(self) -> dict[str, float]:
-        data = stats_tracker.export_all(reduce_group=self.data_parallel_group)
+        with self._offload_aware_context():
+            data = stats_tracker.export_all(
+                reduce_group=self.data_parallel_group,
+            )
         if mpu.get_pipeline_model_parallel_world_size() > 1:
             # Some log info only exist in last pipeline rank
             data_list = [data]
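The tail of this hunk is truncated in the diff view; the pattern it gestures at is sharing stats that only the last pipeline stage computes. A hedged sketch of that pattern (Megatron's `mpu` helpers and `broadcast_object_list` are real APIs, but the exact call sites here are abbreviated, so treat this as illustrative):

```python
import torch.distributed as dist

def sync_stats_from_last_pp_rank(data: dict[str, float], mpu) -> dict[str, float]:
    """Illustrative: broadcast stats owned by the last pipeline rank."""
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        data_list = [data]
        dist.broadcast_object_list(
            data_list,
            src=mpu.get_pipeline_model_parallel_last_rank(),
            group=mpu.get_pipeline_model_parallel_group(),
        )
        data = data_list[0]
    return data
```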
@@ -885,6 +902,10 @@ def offload(self) -> None:
 
         Ref: https://github.com/THUDM/slime/blob/main/slime/backends/megatron_utils/actor.py
         """
+        if not is_tms_enabled():
+            raise RuntimeError(
+                "torch_memory_saver requires `enable_offload=True` in yaml config."
+            )
 
         self.get_device_stats().log("before offload model")
         current_platform.clear_memory()