Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,15 @@ class MegatronEngineConfig:
# FP8 Training Configuration
fp8_config: FP8EngineConfig | None = None

# Bridge backend used for HF<->Megatron conversion/model creation.
bridge_type: str = field(
default="mbridge",
metadata={
"help": "Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'.",
"choices": ["mbridge", "megatron-bridge"],
},
)


class SchedulingStrategyType(str, Enum):
separation = "separation"
Expand Down
138 changes: 104 additions & 34 deletions areal/engine/fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
)
from areal.utils.functional import gather_logprobs, gather_logprobs_entropy
from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer
from areal.utils.network import find_free_ports, gethostip
from areal.utils.network import find_free_ports, format_host_for_url, gethostip
from areal.utils.offload import is_tms_enabled, torch_memory_saver
from areal.utils.perf_tracer import trace_perf, trace_scope
from areal.utils.save_load import get_state_dict_from_repo_id_or_path
Expand Down Expand Up @@ -160,6 +160,14 @@ def to_dict(self) -> dict[str, Any]:
return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}


@dataclasses.dataclass
class _PendingWeightUpdateBucket:
    """In-flight state for one asynchronously broadcast bucket of weights.

    Produced by ``_update_bucket_weights_from_distributed_async`` and later
    drained by ``_wait_pending_weight_update_bucket``.
    """

    # Async work handles returned by dist.broadcast(..., async_op=True),
    # one per tensor in the bucket.
    handles: list[Any]
    # Future returned by rollout_engine.update_weights_from_distributed;
    # resolves when the rollout side has consumed the weights.
    fut: Future[None]
    # (name, full tensor) pairs being broadcast; cleared on completion so
    # the gathered full tensors can be freed promptly.
    named_tensors: list[tuple[str, torch.Tensor]]
    # Side stream the broadcasts were issued on, or None for the current
    # stream.
    stream: torch.cuda.Stream | None = None


class FSDPEngine(TrainEngine):
def __init__(self, config: TrainEngineConfig):
self.config = config
Expand Down Expand Up @@ -1031,14 +1039,15 @@ def _get_full_tensor(self, param: nn.Parameter) -> torch.Tensor:
tensor = tensor.to(current_platform.device_type)
return tensor

def _update_bucket_weights_from_distributed(
def _update_bucket_weights_from_distributed_async(
self,
meta: WeightUpdateMeta,
named_tensors: list[tuple[str, nn.Parameter | torch.Tensor]],
):
stream: torch.cuda.Stream | None = None,
) -> _PendingWeightUpdateBucket | None:
# Early exit when chunk size is relatively small
if not named_tensors:
return
return None

param_specs = [
ParamSpec(
Expand Down Expand Up @@ -1067,18 +1076,48 @@ def _update_bucket_weights_from_distributed(
fut = self.rollout_engine.update_weights_from_distributed(meta, param_specs)

handles = []
for _, tensor in named_tensors:
handles.append(
dist.broadcast(
tensor, src=0, group=self.weight_update_group, async_op=True
if stream is not None:
stream.wait_stream(torch.cuda.current_stream())
context = torch.cuda.stream(stream)
else:
context = nullcontext()

with context:
for _, tensor in named_tensors:
handles.append(
dist.broadcast(
tensor, src=0, group=self.weight_update_group, async_op=True
)
)
)
for handle in handles:

return _PendingWeightUpdateBucket(
handles=handles,
fut=fut,
named_tensors=named_tensors,
stream=stream,
)

def _wait_pending_weight_update_bucket(
self, pending_bucket: _PendingWeightUpdateBucket | None
):
if pending_bucket is None:
return

for handle in pending_bucket.handles:
handle.wait()

fut.result()
pending_bucket.fut.result()
pending_bucket.named_tensors.clear()

named_tensors.clear()
def _update_bucket_weights_from_distributed(
    self,
    meta: WeightUpdateMeta,
    named_tensors: list[tuple[str, nn.Parameter | torch.Tensor]],
):
    """Synchronously push one bucket of weights to the rollout engine.

    Thin wrapper: kick off the async broadcast for this bucket, then
    immediately block until it has fully drained.
    """
    self._wait_pending_weight_update_bucket(
        self._update_bucket_weights_from_distributed_async(meta, named_tensors)
    )

def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):
assert meta.type == "xccl"
Expand All @@ -1097,15 +1136,16 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):
fut = self.rollout_engine.init_weights_update_group(meta)

gen_world_size = meta.gen_allocation.parallel.world_size
init_method = f"tcp://{format_host_for_url(meta.nccl_master_address)}:{meta.nccl_master_port}"
self.logger.info(
f"Initializing weight update group: type={meta.type} "
f"init_method=tcp://{meta.nccl_master_address}:{meta.nccl_master_port} "
f"init_method={init_method} "
f"group={meta.nccl_group_name}"
)
self.weight_update_group = init_custom_process_group(
backend=current_platform.communication_backend,
world_size=gen_world_size + 1,
init_method=f"tcp://{meta.nccl_master_address}:{meta.nccl_master_port}",
init_method=init_method,
rank=0,
group_name=meta.nccl_group_name,
timeout=DIST_GROUP_DEFAULT_TIMEOUT,
Expand All @@ -1115,23 +1155,32 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):

@trace_perf("fsdp_engine.update_weights_from_distributed", category="comm")
def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
"""Broadcast parameters (chunked) from rank 0 (FSDP2 compatible)."""
"""Broadcast parameters with single-pending-bucket pipelining."""

# Reset weight update meta with local info
meta.nccl_master_address = self.weight_update_master_addr
meta.nccl_master_port = self.weight_update_master_port
meta.nccl_group_name = self.weight_update_group_name

if dist.get_rank() == 0:
main_rank = dist.get_rank() == 0
if main_rank:
self.rollout_engine.pause_generation()

dist.barrier(group=self.cpu_group)

weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
main_rank = dist.get_rank() == 0
broadcast_stream = None

if (
main_rank
and current_platform.device_type == "cuda"
and torch.cuda.is_available()
):
broadcast_stream = torch.cuda.Stream()

buffer_size = 0
named_tensors: list[tuple[str, torch.Tensor]] = []
pending_bucket: _PendingWeightUpdateBucket | None = None

if self.config.use_lora:
# For LoRA, only iterate over trainable LoRA parameters
Expand All @@ -1144,29 +1193,50 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
# For full model, iterate over all parameters
param_iterator = self._get_model_name_parameters()

for name, param in param_iterator:
tensor = self._get_full_tensor(param)

# Ranks other than 0 only help to get the full tensor
if not main_rank:
continue
try:
for name, param in param_iterator:
# Ranks other than 0 only help to get the full tensor
tensor = self._get_full_tensor(param)
if not main_rank:
continue

tensor_size = tensor.numel() * tensor.element_size()
bucket_overflow = (
buffer_size > 0
and tensor_size + buffer_size > weight_chunked_mem_size
)
if bucket_overflow:
# Only middle buckets need drain+align before the next all-gather.
if pending_bucket is not None:
self._wait_pending_weight_update_bucket(pending_bucket)
pending_bucket = None

pending_bucket = self._update_bucket_weights_from_distributed_async(
meta,
named_tensors,
stream=broadcast_stream,
)

tensor_size = tensor.numel() * tensor.element_size()
named_tensors = []
buffer_size = 0

if tensor_size + buffer_size > weight_chunked_mem_size:
self._update_bucket_weights_from_distributed(meta, named_tensors)
buffer_size = 0
buffer_size += tensor_size
named_tensors.append((name, tensor))

named_tensors.append((name, tensor))
buffer_size += tensor_size
if pending_bucket:
self._wait_pending_weight_update_bucket(pending_bucket)
pending_bucket = None

# Process remaining parameters
if named_tensors:
self._update_bucket_weights_from_distributed(meta, named_tensors)
# Process remaining parameters
if buffer_size > 0:
self._update_bucket_weights_from_distributed(meta, named_tensors)
finally:
if main_rank and pending_bucket is not None:
self._wait_pending_weight_update_bucket(pending_bucket)
pending_bucket = None

dist.barrier(group=self.cpu_group)

if dist.get_rank() == 0:
if main_rank:
self.rollout_engine.continue_generation()

current_platform.synchronize()
Expand Down
22 changes: 11 additions & 11 deletions areal/engine/fsdp_utils/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from torch.nn import Parameter
from torch.optim.adam import adam as _adam_fn

from areal.infra.platforms import current_platform


def to_precision_dtype(dtype_str: str) -> torch.dtype:
"""
Expand Down Expand Up @@ -476,12 +478,11 @@ def refresh_states(self) -> None:

def _init_streams_and_events(self) -> None:
    """Pre-allocate streams and events for pipeline synchronization.

    The pasted diff span interleaved the removed ``torch.cuda`` lines with
    the added ``current_platform`` lines; only the final platform-abstracted
    version is kept here.
    """
    num_groups = len(self._layer_param_groups)
    # Dedicated H2D (prefetch) and D2H (offload) streams so state
    # transfers can overlap with optimizer compute on the default stream.
    self._h2d_stream = current_platform.Stream(device=self.device)
    self._d2h_stream = current_platform.Stream(device=self.device)
    # Per-layer events: compute-end events gate the D2H offload of a
    # layer; H2D-end events are recorded after each layer's prefetch.
    self._compute_end_events = [current_platform.Event() for _ in range(num_groups)]
    self._h2d_end_events = [current_platform.Event() for _ in range(num_groups)]

# ------------------------------------------------------------------
# Per-layer transfer helpers
Expand Down Expand Up @@ -583,8 +584,7 @@ def step(self) -> None:
"""Per-layer optimizer step with async prefetch pipeline."""
h2d_stream = self._h2d_stream
d2h_stream = self._d2h_stream
# TODO: abstract via current_platform for non-CUDA devices
compute_stream = torch.cuda.current_stream(self.device)
compute_stream = current_platform.current_stream(self.device)
num_groups = len(self._layer_param_groups)
layer_states: list[dict[int, ParamTransferState] | None] = [None] * num_groups

Expand All @@ -593,7 +593,7 @@ def step(self) -> None:

# Prefetch initial layers
for i in range(min(self.prefetch_layers + 1, num_groups)):
with torch.cuda.stream(h2d_stream):
with current_platform.stream(h2d_stream):
layer_states[i] = self._prefetch_layer(i)
h2d_stream.record_event(h2d_end_events[i])

Expand All @@ -609,13 +609,13 @@ def step(self) -> None:
# Prefetch next layer (overlaps with D2H below)
next_idx = i + self.prefetch_layers + 1
if next_idx < num_groups:
with torch.cuda.stream(h2d_stream):
with current_platform.stream(h2d_stream):
layer_states[next_idx] = self._prefetch_layer(next_idx)
h2d_stream.record_event(h2d_end_events[next_idx])

# Offload current layer (waits only for this layer's compute)
d2h_stream.wait_event(compute_end_events[i])
with torch.cuda.stream(d2h_stream):
with current_platform.stream(d2h_stream):
cur_states_offload = layer_states[i]
assert cur_states_offload is not None, f"Layer {i} already freed"
self._offload_layer(cur_states_offload)
Expand All @@ -628,4 +628,4 @@ def step(self) -> None:

# Prevent cross-phase cache pollution: return freed optimizer state
# blocks to driver so forward/backward can't repurpose them.
torch.cuda.empty_cache()
current_platform.empty_cache()
Loading
Loading