From a4370b90009d33f334fe646ebe9e03df73a3de43 Mon Sep 17 00:00:00 2001 From: Tony Date: Sun, 3 May 2026 23:51:00 -0400 Subject: [PATCH] Add FSDP2 parallel strategy with per-FSDP-unit torch.compile and DTensor-aware Muon Replaces --distributed (DDP-only) with --parallel_strategy {ddp, fsdp}, introduces a 3D-aware ParallelContext (DeviceMesh sized (1, world, 1) today with names "pp", "dp", "tp" so adding TP/PP later is API-stable). Key pieces: - ParallelContext singleton in src/utils/parallel_context.py with shim helpers preserved in gpu_manager.py for backward compatibility. - GPUSetup dispatches strategy: DDP (existing path), FSDP2 via fully_shard with MixedPrecisionPolicy(param=bf16, reduce=fp32) and Megatron-style fp32 master weights (model cast to fp32 before wrap, MP downcasts for forward). - Per-FSDP-unit torch.compile inside _wrap_fsdp so each unit's gather hook fires at the unit boundary instead of being skipped by an outer compile. - MuonDistributed (src/optimizers/muon_distributed.py): subclasses Muon and routes DTensor params through full_tensor() -> Newton-Schulz -> distribute_tensor so FSDP-sharded weights work. - CheckpointManager moved to torch.distributed.checkpoint.state_dict APIs; same single .pt file format as before so eval/chat consumers are unchanged. MuonAdamW state dict split into "muon"/"adamw" halves to round-trip cleanly. - Per-LLM fsdp_wrap_modules() hooks (Qwen25/Llama3/Gemma2) returning decoder blocks via shared elms/llms/_wrap.get_decoder_layers helper. - ST-MEM hardcoded .to(float32) replaced with .to(next(self.parameters()).dtype) to honor the wrapped model's actual dtype (matches MTAE's pattern). - LLM wrappers pass use_cache=False during training to keep torch.compile from hitting Dynamo's recompile limit on per-layer KV cache init guards. - Trainer/RL-trainer/main_trainer no longer rank-0-gate save_checkpoint. 
Under FSDP, get_model_state_dict is a collective that needs all ranks; the old DDP-era is_main() gate caused the gather to deadlock and corrupt state when --save_step or save_epoch fired. Save decisions stay consistent across ranks (save_step is a deterministic function of the step, so every rank computes it; save_epoch is decided on rank 0 and broadcast); the save itself now runs collectively, and only rank 0 writes the file (gated inside CheckpointManager.save_checkpoint). Existing scripts (train.sh, train2.sh, st_mem_full_training.sh) and the README swap --distributed for --parallel_strategy ddp; behavior unchanged for DDP. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 10 +- scripts/st_mem_full_training.sh | 2 +- scripts/train.sh | 8 +- scripts/train2.sh | 6 +- src/configs/config.py | 3 +- src/dataloaders/build_dataloader.py | 9 +- src/elms/build_encoder.py | 3 +- src/elms/ecg_encoders/merl/merl.py | 1 - src/elms/ecg_encoders/st_mem/st_mem.py | 2 +- src/elms/llm_encoders/base_elf.py | 3 + src/elms/llm_encoders/llava.py | 3 + src/elms/llms/_wrap.py | 19 ++++ src/elms/llms/gemma2/gemma2.py | 7 +- src/elms/llms/llama3/llama3.py | 7 +- src/elms/llms/qwen25/qwen25.py | 7 +- src/main_trainer.py | 15 ++- src/optimizers/muon_distributed.py | 96 +++++++++++++++++ src/optimizers/optimizer_setup.py | 12 +-- src/runners/rl_trainer.py | 11 +- src/runners/trainer.py | 8 +- src/utils/checkpoint_manager.py | 70 ++++++++++--- src/utils/gpu_manager.py | 125 ++++++++++++++-------- src/utils/parallel_context.py | 138 +++++++++++++++++++++++++ src/utils/wandb_manager.py | 2 +- 24 files changed, 468 insertions(+), 99 deletions(-) create mode 100644 src/elms/llms/_wrap.py create mode 100644 src/optimizers/muon_distributed.py create mode 100644 src/utils/parallel_context.py diff --git a/README.md b/README.md index ff7488f..5024736 100644 --- a/README.md +++ b/README.md @@ -147,9 +147,11 @@ uv run torchrun --standalone --nproc_per_node=4 \ --llm qwen2.5-1.5b-instruct \ --encoder $ECG_ENCODER or $VISION_ENCODER \ --elm mlp_llava \ - --distributed + 
--parallel_strategy ddp ``` +`--parallel_strategy ddp` uses `DistributedDataParallel` (full model replica per rank). Swap in `--parallel_strategy fsdp` for FSDP2 per-parameter sharding (each rank holds a slice of the LLM's transformer blocks); use this when the LLM is too large to replicate per GPU. + For ECG Encoders, you will have to pretrain your own ECG Encoder using [ecg_nn](https://github.com/ELM-Research/ecg_nn). We plan to release pretrained encoders soon! To load in the pretrained encoder during ELM training run the following: ```bash @@ -227,7 +229,7 @@ uv run torchrun --standalone --nproc_per_node=$NPROC \ --optimizer adamw \ --lr 5e-4 \ --encoder_ckpt $ENCODER_CHECKPOINT.pt \ - --distributed + --parallel_strategy ddp ``` ### SFT @@ -246,7 +248,7 @@ uv run torchrun --standalone --nproc_per_node=$NPROC \ --optimizer adamw \ --lr 1e-4 \ --elm_ckpt $PRETRAIN_CKPT.pt \ - --distributed + --parallel_strategy ddp ``` ### RL @@ -270,7 +272,7 @@ uv run torchrun --standalone --nproc_per_node=$NPROC \ --rl_tau_pos 1.0 \ --rl_tau_neg 1.05 \ --elm_ckpt $SFT_CKPT.pt \ - --distributed + --parallel_strategy ddp ``` See `scripts/st_mem_full_training.sh` for an end-to-end pretrain → SFT → RL example. 
diff --git a/scripts/st_mem_full_training.sh b/scripts/st_mem_full_training.sh index e121f45..37571be 100644 --- a/scripts/st_mem_full_training.sh +++ b/scripts/st_mem_full_training.sh @@ -16,7 +16,7 @@ COMMON_FLAGS=( --grad_clip 1.0 --llm_input_len 2048 --num_encoder_tokens 50 \ - --distributed + --parallel_strategy ddp --system_prompt "$SYSTEM_PROMPT" --llm qwen2.5-3b-instruct --gradient_checkpointing diff --git a/scripts/train.sh b/scripts/train.sh index 49cfd2b..f1deade 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -76,7 +76,7 @@ src/main_trainer.py \ --epochs 10 \ --grad_clip 1.0 \ --num_workers 16 \ ---distributed \ +--parallel_strategy ddp \ --peft \ --torch_compile \ --wandb @@ -101,7 +101,7 @@ src/main_trainer.py \ --epochs 10 \ --grad_clip 1.0 \ --num_workers 16 \ ---distributed \ +--parallel_strategy ddp \ --peft \ --torch_compile \ --wandb @@ -126,7 +126,7 @@ src/main_trainer.py \ --epochs 10 \ --grad_clip 1.0 \ --num_workers 16 \ ---distributed \ +--parallel_strategy ddp \ --peft \ --torch_compile \ --wandb @@ -152,7 +152,7 @@ src/main_trainer.py \ --epochs 10 \ --grad_clip 1.0 \ --num_workers 16 \ ---distributed \ +--parallel_strategy ddp \ --peft \ --torch_compile \ --wandb diff --git a/scripts/train2.sh b/scripts/train2.sh index 8f82e7f..540fb09 100644 --- a/scripts/train2.sh +++ b/scripts/train2.sh @@ -75,7 +75,7 @@ src/main_trainer.py \ --epochs 10 \ --grad_clip 1.0 \ --num_workers 16 \ ---distributed \ +--parallel_strategy ddp \ --peft \ --torch_compile \ --wandb @@ -100,7 +100,7 @@ src/main_trainer.py \ --epochs 10 \ --grad_clip 1.0 \ --num_workers 16 \ ---distributed \ +--parallel_strategy ddp \ --peft \ --torch_compile \ --wandb @@ -123,7 +123,7 @@ src/main_trainer.py \ --llm_input_len 1024 \ --epochs 10 \ --num_workers 16 \ ---distributed \ +--parallel_strategy ddp \ --peft \ --torch_compile \ --wandb diff --git a/src/configs/config.py b/src/configs/config.py index 1b16501..d2ccc51 100644 --- a/src/configs/config.py +++ 
b/src/configs/config.py @@ -47,7 +47,8 @@ def get_args(mode: Mode) -> argparse.Namespace: parser.add_argument("--num_workers", type=int, default=0, help="Please choose the num works for the dataloader") parser.add_argument("--wandb", action="store_true", default=None, help="Enable logging") parser.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)") - parser.add_argument("--distributed", action="store_true", default=None, help="Enable distributed training") + parser.add_argument("--parallel_strategy", type=str, default=None, choices=["ddp", "fsdp"], + help="Parallel strategy when launched via torchrun. ddp = DistributedDataParallel; fsdp = FSDP2 (per-parameter sharding). Omit for single-device.") parser.add_argument("--torch_compile", action="store_true", default=None, help="Torch compile the model (should really only be used during pretraining or large finetuning.)") parser.add_argument("--gradient_checkpointing", action="store_true", default=False, diff --git a/src/dataloaders/build_dataloader.py b/src/dataloaders/build_dataloader.py index af229d3..cefc5cd 100644 --- a/src/dataloaders/build_dataloader.py +++ b/src/dataloaders/build_dataloader.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader from collections.abc import Mapping, Sequence -from utils.gpu_manager import get_world_size, get_rank +from utils.parallel_context import get_parallel_context from dataloaders.dataset_mixer import DatasetMixer @@ -53,9 +53,10 @@ def get_torch_dataloader_sampler( self, torch_dataset, ): - if self.args.distributed: - sampler = DistributedSampler(torch_dataset, num_replicas=get_world_size(), - rank=get_rank(), seed=self.args.seed, shuffle=True) + if self.args.parallel_strategy: + ctx = get_parallel_context() + sampler = DistributedSampler(torch_dataset, num_replicas=ctx.dp_size, + rank=ctx.dp_rank, seed=self.args.seed, shuffle=True) else: sampler = None return sampler diff --git a/src/elms/build_encoder.py b/src/elms/build_encoder.py index 
f1fd288..724d809 100644 --- a/src/elms/build_encoder.py +++ b/src/elms/build_encoder.py @@ -67,8 +67,7 @@ def prepare_hf_siglip(self,): def prepare_merl(self,): from elms.ecg_encoders.merl.merl import MerlConfig, Merl - cfg = MerlConfig(distributed=self.args.distributed, - num_encoder_tokens=self.args.num_encoder_tokens) + cfg = MerlConfig(num_encoder_tokens=self.args.num_encoder_tokens) model = Merl(cfg) return {"encoder": model} diff --git a/src/elms/ecg_encoders/merl/merl.py b/src/elms/ecg_encoders/merl/merl.py index 67ae7ff..aef6d99 100644 --- a/src/elms/ecg_encoders/merl/merl.py +++ b/src/elms/ecg_encoders/merl/merl.py @@ -13,7 +13,6 @@ class MerlConfig: seq_len: int = 2500 lm: str = "ncbi/MedCPT-Query-Encoder" resnet_type: str = "resnet101" - distributed: bool = False spacial_dim: int = None d_model: int = 2048 num_encoder_tokens: int = 1 diff --git a/src/elms/ecg_encoders/st_mem/st_mem.py b/src/elms/ecg_encoders/st_mem/st_mem.py index 76fe5af..7202154 100644 --- a/src/elms/ecg_encoders/st_mem/st_mem.py +++ b/src/elms/ecg_encoders/st_mem/st_mem.py @@ -410,7 +410,7 @@ def forward_encoder(self, x,): return x def get_encoder_embeddings(self, ecg_signal): - x_latents = self.forward_encoder(ecg_signal.to(torch.float32)) + x_latents = self.forward_encoder(ecg_signal.to(next(self.parameters()).dtype)) out = rearrange(x_latents, 'b c n d -> b (c n) d') out = out.transpose(1, 2) out = self.avgpool(out) diff --git a/src/elms/llm_encoders/base_elf.py b/src/elms/llm_encoders/base_elf.py index a0afd20..1495bea 100644 --- a/src/elms/llm_encoders/base_elf.py +++ b/src/elms/llm_encoders/base_elf.py @@ -20,6 +20,9 @@ def train(self, mode: bool = True): module.train(mode if name in self.update else False) return self + def fsdp_wrap_modules(self): + return self.llm.fsdp_wrap_modules() + def forward(self, elm_input_ids, encoder_tokenizer_out, elm_attention_mask, elm_labels, signal_id_indices): projected_embeds = self.get_projections(**encoder_tokenizer_out) diff --git 
a/src/elms/llm_encoders/llava.py b/src/elms/llm_encoders/llava.py index 3236b8f..e96173a 100644 --- a/src/elms/llm_encoders/llava.py +++ b/src/elms/llm_encoders/llava.py @@ -22,6 +22,9 @@ def train(self, mode: bool = True): module.train(mode if name in self.update else False) return self + def fsdp_wrap_modules(self): + return self.llm.fsdp_wrap_modules() + def forward(self, elm_input_ids, elm_attention_mask, elm_labels, signal_id_indices, encoder_tokenizer_out): projected_embeds = self.get_projections(encoder_tokenizer_out) diff --git a/src/elms/llms/_wrap.py b/src/elms/llms/_wrap.py new file mode 100644 index 0000000..93f8dc2 --- /dev/null +++ b/src/elms/llms/_wrap.py @@ -0,0 +1,19 @@ +"""Shared helpers for FSDP wrapping of LLM wrapper modules.""" + +from torch import nn + + +def get_decoder_layers(hf_model: nn.Module) -> list[nn.Module]: + """Return the transformer decoder block ModuleList inside an HF causal LM. + + Handles PEFT (PeftModel -> base_model -> model -> model.layers) and the + bare HF (model.model.layers) layouts. 
+ """ + node = hf_model + if hasattr(node, "base_model") and hasattr(node.base_model, "model"): + node = node.base_model.model + if hasattr(node, "model") and hasattr(node.model, "layers"): + return list(node.model.layers) + if hasattr(node, "layers"): + return list(node.layers) + raise RuntimeError(f"Could not locate decoder layers inside {type(hf_model).__name__}") diff --git a/src/elms/llms/gemma2/gemma2.py b/src/elms/llms/gemma2/gemma2.py index f075d25..04ba5f6 100644 --- a/src/elms/llms/gemma2/gemma2.py +++ b/src/elms/llms/gemma2/gemma2.py @@ -14,12 +14,17 @@ def forward(self, elm_input_ids, elm_attention_mask, inputs_embeds = elm_inputs_embeds, attention_mask = elm_attention_mask, labels = elm_labels, - output_hidden_states = self.output_hidden_states) + output_hidden_states = self.output_hidden_states, + use_cache = False) def get_llm_embeddings(self, elm_input_ids): out = self.llm.get_input_embeddings()(elm_input_ids.to(self.llm.device)) return out + def fsdp_wrap_modules(self): + from elms.llms._wrap import get_decoder_layers + return get_decoder_layers(self.llm) + def generate(self, elm_input_ids, elm_attention_mask, elm_inputs_embeds= None, max_new_tokens=128, **gen_kwargs): return self.llm.generate( diff --git a/src/elms/llms/llama3/llama3.py b/src/elms/llms/llama3/llama3.py index 278ef81..0634c17 100644 --- a/src/elms/llms/llama3/llama3.py +++ b/src/elms/llms/llama3/llama3.py @@ -14,12 +14,17 @@ def forward(self, elm_input_ids, elm_attention_mask, inputs_embeds = elm_inputs_embeds, attention_mask = elm_attention_mask, labels = elm_labels, - output_hidden_states = self.output_hidden_states) + output_hidden_states = self.output_hidden_states, + use_cache = False) def get_llm_embeddings(self, elm_input_ids): out = self.llm.get_input_embeddings()(elm_input_ids.to(self.llm.device)) return out + def fsdp_wrap_modules(self): + from elms.llms._wrap import get_decoder_layers + return get_decoder_layers(self.llm) + def generate(self, elm_input_ids, 
elm_attention_mask, elm_inputs_embeds= None, max_new_tokens=128, **gen_kwargs): return self.llm.generate( diff --git a/src/elms/llms/qwen25/qwen25.py b/src/elms/llms/qwen25/qwen25.py index cb21ada..62510a7 100644 --- a/src/elms/llms/qwen25/qwen25.py +++ b/src/elms/llms/qwen25/qwen25.py @@ -14,12 +14,17 @@ def forward(self, elm_input_ids, elm_attention_mask, inputs_embeds = elm_inputs_embeds, attention_mask = elm_attention_mask, labels = elm_labels, - output_hidden_states = self.output_hidden_states) + output_hidden_states = self.output_hidden_states, + use_cache = False) def get_llm_embeddings(self, elm_input_ids): out = self.llm.get_input_embeddings()(elm_input_ids.to(self.llm.device)) return out + def fsdp_wrap_modules(self): + from elms.llms._wrap import get_decoder_layers + return get_decoder_layers(self.llm) + def generate(self, elm_input_ids, elm_attention_mask, elm_inputs_embeds= None, max_new_tokens=128, **gen_kwargs): return self.llm.generate( diff --git a/src/main_trainer.py b/src/main_trainer.py index 3973cfa..3d196eb 100644 --- a/src/main_trainer.py +++ b/src/main_trainer.py @@ -30,8 +30,8 @@ def main(): args.mode = mode args.task = "train" - if args.distributed: - init_dist() + if args.parallel_strategy: + init_dist(args.parallel_strategy) gc.collect() torch.cuda.empty_cache() @@ -70,12 +70,17 @@ def main(): for epoch in range(start_epoch, args.epochs): train_result = runner(elm, optimizer, dataloader, epoch, args, checkpoint_manager) should_stop = False + should_save = False if checkpoint_manager and is_main(): - if checkpoint_manager.save_epoch(train_result["average_loss"]): - checkpoint_manager.save_checkpoint(elm, optimizer, epoch, -1, is_best=True, prefix="epoch_") + should_save = checkpoint_manager.save_epoch(train_result["average_loss"]) if args.early_stopping and checkpoint_manager.stop_early(): print(f"Early stopping at epoch {epoch}") should_stop = True + # Decision is rank-0-only (best_loss tracking lives there); save + # itself must be 
collective for FSDP get_model_state_dict. + should_save = broadcast_value(should_save, src=0) + if checkpoint_manager and should_save: + checkpoint_manager.save_checkpoint(elm, optimizer, epoch, -1, is_best=True, prefix="epoch_") should_stop = broadcast_value(should_stop, src=0) if should_stop: break @@ -84,7 +89,7 @@ def main(): with open(f"{run_folder}/DONE.txt", "w") as _: pass finally: - if args.distributed: + if args.parallel_strategy: cleanup() if is_main() and args.wandb: cleanup_wandb() diff --git a/src/optimizers/muon_distributed.py b/src/optimizers/muon_distributed.py new file mode 100644 index 0000000..77a5b6c --- /dev/null +++ b/src/optimizers/muon_distributed.py @@ -0,0 +1,96 @@ +"""FSDP-aware Muon: handles DTensor (sharded) parameters. + +Newton-Schulz orthogonalization is not shardable, so for each DTensor +parameter we all-gather to the full matrix, run NS on every rank (duplicated +compute, deterministic), then write the local slice back. Non-DTensor +parameters fall through to upstream Muon's functional path unchanged. + +Momentum buffers stay sharded (DTensor) to match the parameter — momentum +updates are pointwise and so are local-shard correct. 
+""" + +import torch +from torch.optim import Muon +from torch.optim._muon import _adjust_lr, _zeropower_via_newtonschulz, muon as _muon_fn + + +class MuonDistributed(Muon): + + @torch.no_grad() + def step(self, closure=None): + from torch.distributed.tensor import DTensor + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + local_params, local_grads, local_bufs = [], [], [] + dtensor_params = [] + + for p in group["params"]: + if p.grad is None: + continue + if torch.is_complex(p): + raise RuntimeError("Muon does not support complex parameters") + if p.grad.is_sparse: + raise RuntimeError("Muon does not support sparse gradients") + + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like( + p.grad, memory_format=torch.preserve_format + ) + + if isinstance(p, DTensor) or isinstance(p.grad, DTensor): + dtensor_params.append(p) + else: + local_params.append(p) + local_grads.append(p.grad) + local_bufs.append(state["momentum_buffer"]) + + if local_params: + _muon_fn( + local_params, local_grads, local_bufs, + lr=group["lr"], weight_decay=group["weight_decay"], + momentum=group["momentum"], nesterov=group["nesterov"], + ns_coefficients=group["ns_coefficients"], + eps=group["eps"], ns_steps=group["ns_steps"], + adjust_lr_fn=group["adjust_lr_fn"], has_complex=False, + ) + + for p in dtensor_params: + self._dtensor_step(p, group) + + return loss + + @torch.no_grad() + def _dtensor_step(self, p, group): + from torch.distributed.tensor import DTensor, distribute_tensor + + lr = group["lr"] + wd = group["weight_decay"] + mom = group["momentum"] + state = self.state[p] + buf = state["momentum_buffer"] + grad = p.grad + + # Match upstream: buf <- mom*buf + (1-mom)*grad (pointwise, shard-local). 
+ buf.lerp_(grad, 1 - mom) + update = grad.lerp(buf, mom) if group["nesterov"] else buf + + update_full = update.full_tensor() if isinstance(update, DTensor) else update + update_full = _zeropower_via_newtonschulz( + update_full, group["ns_coefficients"], group["ns_steps"], group["eps"] + ) + adjusted_lr = _adjust_lr(lr, group["adjust_lr_fn"], p.shape) + + p_full = p.full_tensor() if isinstance(p, DTensor) else p + p_full = p_full * (1 - lr * wd) - adjusted_lr * update_full + + if isinstance(p, DTensor): + new_dt = distribute_tensor(p_full, p.device_mesh, p.placements) + p.to_local().copy_(new_dt.to_local()) + else: + p.copy_(p_full) diff --git a/src/optimizers/optimizer_setup.py b/src/optimizers/optimizer_setup.py index 8b952ac..e7932de 100644 --- a/src/optimizers/optimizer_setup.py +++ b/src/optimizers/optimizer_setup.py @@ -1,7 +1,9 @@ import torch import numpy as np from torch.optim import Adam, AdamW -from utils.gpu_manager import get_world_size, is_main +from utils.gpu_manager import is_main +from utils.parallel_context import get_parallel_context +from optimizers.muon_distributed import MuonDistributed OPTIMIZERS = {"adam": Adam, "adamw": AdamW} @@ -64,8 +66,8 @@ def __init__(self, model, args): self._log_config() def _world_size(self): - ws = get_world_size() - if ws == 1 and self.args.distributed: + ws = get_parallel_context().dp_size + if ws == 1 and self.args.parallel_strategy: return max(1, torch.cuda.device_count()) return ws @@ -99,8 +101,6 @@ def _build_optimizer(self, model): ) def _build_muon_optimizer(self, model): - from torch.optim import Muon - muon_params = [] adamw_params = [] @@ -126,7 +126,7 @@ def _build_muon_optimizer(self, model): self._adamw_lr = adamw_lr self._adamw_lr_ratio = adamw_lr_ratio - muon_opt = Muon( + muon_opt = MuonDistributed( muon_params, lr=self.peak_lr, momentum=muon_momentum, diff --git a/src/runners/rl_trainer.py b/src/runners/rl_trainer.py index 1ea64b2..c270c9b 100644 --- a/src/runners/rl_trainer.py +++ 
b/src/runners/rl_trainer.py @@ -3,14 +3,15 @@ from tqdm import tqdm import wandb -from utils.gpu_manager import is_main, get_world_size, train_dev_break +from utils.gpu_manager import is_main, train_dev_break +from utils.parallel_context import get_parallel_context from runners.helper import batch_to_device from rl.rl_loss import get_rl_loss, get_loss_kwargs from rl.rollout import rollout_group, current_log_prob def run_rl_train(nn, optimizer, dataloader, epoch, args, checkpoint_manager=None): - if getattr(args, "distributed", False) and hasattr(getattr(dataloader, "sampler", None), "set_epoch"): + if getattr(args, "parallel_strategy", None) and hasattr(getattr(dataloader, "sampler", None), "set_epoch"): dataloader.sampler.set_epoch(epoch) show_progress = is_main() @@ -23,7 +24,7 @@ def run_rl_train(nn, optimizer, dataloader, epoch, args, checkpoint_manager=None total_steps_per_epoch = len(dataloader) loss_fn = get_rl_loss(args.rl_algo) algo_kw = get_loss_kwargs(args.rl_algo, args) - dp_size = get_world_size() + dp_size = get_parallel_context().dp_size tokenizer = dataloader.dataset.llm_tokenizer optimizer.zero_grad() @@ -62,7 +63,9 @@ def run_rl_train(nn, optimizer, dataloader, epoch, args, checkpoint_manager=None "epoch": epoch, **{f"train/{k}": v for k, v in last_metrics.items()}}) accum_loss_for_log, accum_reward_for_log = 0.0, 0.0 - if args.save_step and checkpoint_manager and is_main(): + if args.save_step and checkpoint_manager: + # save_checkpoint must be collective under FSDP (get_model_state_dict + # gathers across ranks); only rank 0 writes the file (gated inside). 
if checkpoint_manager.save_step(step, total_steps_per_epoch): checkpoint_manager.save_checkpoint(nn, optimizer, epoch, step, prefix="step_") diff --git a/src/runners/trainer.py b/src/runners/trainer.py index 8b1a3fb..494cb3c 100644 --- a/src/runners/trainer.py +++ b/src/runners/trainer.py @@ -14,7 +14,7 @@ def run_train( args, checkpoint_manager=None, ): - if getattr(args, "distributed", False) and hasattr(getattr(dataloader, "sampler", None), "set_epoch"): + if getattr(args, "parallel_strategy", None) and hasattr(getattr(dataloader, "sampler", None), "set_epoch"): dataloader.sampler.set_epoch(epoch) show_progress = is_main() @@ -67,7 +67,11 @@ def run_train( accum_loss_for_log = 0.0 - if args.save_step and checkpoint_manager and is_main(): + if args.save_step and checkpoint_manager: + # save_step is deterministic (pure function of step + total) so all + # ranks reach the same decision. save_checkpoint must be called on + # all ranks because get_model_state_dict is an FSDP collective; + # only rank 0 actually writes the file (gated inside). if checkpoint_manager.save_step(step, total_steps_per_epoch): checkpoint_manager.save_checkpoint(nn, optimizer, epoch, step, prefix="step_") diff --git a/src/utils/checkpoint_manager.py b/src/utils/checkpoint_manager.py index fc852b0..f87ea13 100644 --- a/src/utils/checkpoint_manager.py +++ b/src/utils/checkpoint_manager.py @@ -1,8 +1,32 @@ -import torch import os + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + set_model_state_dict, + set_optimizer_state_dict, +) + from utils.gpu_manager import is_main +_SAVE_OPTS = StateDictOptions(full_state_dict=True, cpu_offload=True) + + +def _load_opts() -> StateDictOptions: + # broadcast_from_rank0/cpu_offload route through the default process group; + # disable them when running single-process so DCP doesn't require dist init. 
+ is_dist = dist.is_available() and dist.is_initialized() + return StateDictOptions( + full_state_dict=True, + broadcast_from_rank0=is_dist, + cpu_offload=is_dist, + ) + + class CheckpointManager: def __init__(self, run_dir, args): self.run_dir = run_dir @@ -14,22 +38,20 @@ def __init__(self, run_dir, args): os.makedirs(self.checkpoint_dir, exist_ok=True) def save_checkpoint(self, model, optimizer, epoch, step, is_best=False, prefix=""): + # All ranks participate in the gather; only rank 0 writes. + model_sd = get_model_state_dict(model, options=_SAVE_OPTS) + optimizer_sd = self._get_optimizer_state_dict(model, optimizer) + if not is_main(): return + filename = f"{prefix}epoch_{epoch}_step_{step}.pt" filepath = os.path.join(self.checkpoint_dir, filename) - - # Handle DDP-wrapped models - if self.args.distributed: - model_state_dict = model.module.state_dict() - else: - model_state_dict = model.state_dict() - checkpoint = { "epoch": epoch, "step": step, - "model_state_dict": model_state_dict, - "optimizer_state_dict": optimizer.optimizer.state_dict(), + "model_state_dict": model_sd, + "optimizer_state_dict": optimizer_sd, "n_current_steps": optimizer.n_current_steps, "best_loss": self.best_loss, } @@ -60,14 +82,32 @@ def stop_early(self): return current_loss > best_loss - self.args.patience_delta def resume_checkpoint(self, path, model, optimizer): - device = next(model.parameters()).device - ckpt = torch.load(path, map_location=device, weights_only=False) - (model.module if self.args.distributed else model).load_state_dict(ckpt["model_state_dict"]) - optimizer.optimizer.load_state_dict(ckpt["optimizer_state_dict"]) + ckpt = torch.load(path, map_location="cpu", weights_only=False) + set_model_state_dict(model, model_state_dict=ckpt["model_state_dict"], options=_load_opts()) + self._load_optimizer_state_dict(model, optimizer, ckpt["optimizer_state_dict"]) optimizer.n_current_steps = ckpt["n_current_steps"] self.best_loss = ckpt.get("best_loss", float("inf")) 
start_epoch = ckpt["epoch"] + 1 if is_main(): print(f"Resumed from {path} | epoch {start_epoch} | step {optimizer.n_current_steps}") del ckpt - return start_epoch \ No newline at end of file + return start_epoch + + def _get_optimizer_state_dict(self, model, optimizer): + from optimizers.optimizer_setup import MuonAdamW + inner = optimizer.optimizer + if isinstance(inner, MuonAdamW): + return { + "muon": get_optimizer_state_dict(model, inner.muon, options=_SAVE_OPTS), + "adamw": get_optimizer_state_dict(model, inner.adamw, options=_SAVE_OPTS), + } + return get_optimizer_state_dict(model, inner, options=_SAVE_OPTS) + + def _load_optimizer_state_dict(self, model, optimizer, optimizer_sd): + from optimizers.optimizer_setup import MuonAdamW + inner = optimizer.optimizer + if isinstance(inner, MuonAdamW): + set_optimizer_state_dict(model, inner.muon, optim_state_dict=optimizer_sd["muon"], options=_load_opts()) + set_optimizer_state_dict(model, inner.adamw, optim_state_dict=optimizer_sd["adamw"], options=_load_opts()) + else: + set_optimizer_state_dict(model, inner, optim_state_dict=optimizer_sd, options=_load_opts()) diff --git a/src/utils/gpu_manager.py b/src/utils/gpu_manager.py index 3320863..ad73ea0 100644 --- a/src/utils/gpu_manager.py +++ b/src/utils/gpu_manager.py @@ -1,52 +1,53 @@ -import torch, argparse, os, torch.distributed as dist +"""GPU placement and parallel-strategy dispatch. + +Free helpers (`is_main`, `get_world_size`, `init_dist`, ...) are thin shims over +the global ParallelContext, kept for backward compatibility with call sites +across the codebase. 
# --- src/utils/gpu_manager.py: post-patch form of the section covered here ---

import argparse
from typing import Iterable

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from utils.parallel_context import (
    ParallelContext,
    get_parallel_context,
    init_parallel_context,
)


# The module-level helpers below are thin shims: each forwards to the
# ParallelContext returned by get_parallel_context() so that existing call
# sites keep their old function-style API.


def init_dist(strategy: str = "ddp") -> ParallelContext:
    """Initialise the parallel context for ``strategy`` and return it.

    Args:
        strategy: Strategy name forwarded to ``init_parallel_context``.

    Returns:
        The context object returned by ``init_parallel_context``.
    """
    return init_parallel_context(strategy=strategy)


def get_local_rank() -> int:
    """Return ``local_rank`` of the current parallel context."""
    return get_parallel_context().local_rank


def get_rank() -> int:
    """Return ``global_rank`` of the current parallel context."""
    return get_parallel_context().global_rank


def get_world_size() -> int:
    """Return ``world_size`` of the current parallel context."""
    return get_parallel_context().world_size


def is_main() -> bool:
    """Return ``is_main`` of the current parallel context."""
    return get_parallel_context().is_main


def barrier() -> None:
    """Forward to ``barrier()`` on the current parallel context."""
    get_parallel_context().barrier()


def cleanup() -> None:
    """Forward to ``cleanup()`` on the current parallel context."""
    get_parallel_context().cleanup()


def broadcast_value(val, src: int = 0):
    """Forward to ``broadcast_value`` on the current parallel context.

    Args:
        val: Value to broadcast.
        src: Source rank passed through unchanged.

    Returns:
        Whatever the context's ``broadcast_value`` returns.
    """
    return get_parallel_context().broadcast_value(val, src=src)


# NOTE: ``train_dev_break(enabled: bool, batch: dict, loss_value: float) -> bool``
# is defined at this point of gpu_manager.py. The patch leaves it untouched,
# so only its signature (not its body) is visible in this section.


class GPUSetup:
    """Place a model on its device and apply the active parallel strategy.

    The strategy is read from the parallel context captured in ``__init__``:
    ``"ddp"`` wraps the model in DistributedDataParallel, ``"fsdp"`` shards it
    via ``_wrap_fsdp``, and any other value leaves the model unwrapped.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        # Captured once; the strategy is read from this object in setup_gpu.
        self.ctx = get_parallel_context()

    def setup_gpu(self, model: torch.nn.Module, find_unused_parameters: bool) -> torch.nn.Module:
        """Move ``model`` to its device, wrap it, and optionally compile it.

        Args:
            model: Module to prepare.
            find_unused_parameters: Forwarded to DDP; not used by the other
                strategies.

        Returns:
            The wrapped (and possibly compiled) module.
        """
        device = self.get_device()
        model = model.to(device)
        strategy = self.ctx.strategy
        # getattr mirrors how ``device`` is read in get_device, so a namespace
        # without the flag behaves as "compile disabled".
        compile_requested = getattr(self.args, "torch_compile", False)

        if strategy == "ddp":
            # get_device() can fall back to a CPU device; its index is None and
            # DDP expects device_ids/output_device to be None for CPU modules.
            on_cuda = device.type == "cuda"
            model = DDP(
                model,
                device_ids=[device.index] if on_cuda else None,
                output_device=device.index if on_cuda else None,
                find_unused_parameters=find_unused_parameters,
            )
            if compile_requested:
                model = torch.compile(model)
        elif strategy == "fsdp":
            # FSDP path applies torch.compile per-unit inside _wrap_fsdp.
            # Compiling the outer FSDP-wrapped model traces through the
            # pre-forward all-gather hooks and crashes on DTensor params.
            model = self._wrap_fsdp(model)
        elif compile_requested:
            model = torch.compile(model)

        if is_main():
            print(f"[GPUSetup] strategy={strategy} | find_unused_parameters={find_unused_parameters} | torch_compile={compile_requested}")
        return model

    def _wrap_fsdp(self, model: torch.nn.Module) -> torch.nn.Module:
        """Shard ``model`` with ``fully_shard``: wrap units first, then the root.

        Args:
            model: Module to shard; it is cast to fp32 before wrapping.

        Returns:
            The same module object after ``fully_shard`` has been applied.
        """
        from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        )
        # Megatron / torchtitan / nanotron recipe: keep fp32 master weights
        # and fp32 optimizer state, run forward + backward in bf16 via the
        # MixedPrecisionPolicy, and reduce gradients in fp32.
        # Casting up to fp32 also satisfies FSDP2's "uniform original dtype
        # per unit" rule: HF LLMs load in bf16 while the encoder/connector
        # default to fp32, so a single dtype must be chosen — fp32 preserves
        # the encoder's precision and gives the LLM a master copy to update.
        model = model.to(torch.float32)

        shard_kwargs = dict(
            mesh=self.ctx.dp_mesh,
            mp_policy=mp_policy,
            reshard_after_forward=True,
        )
        compile_units = getattr(self.args, "torch_compile", False)
        units = list(self._collect_wrap_modules(model))
        for module in units:
            fully_shard(module, **shard_kwargs)
            if compile_units:
                # Per-unit compile: each FSDP unit's forward is a fresh Dynamo
                # graph and FSDP's hooks are honoured at the unit boundary.
                module.compile()
        if compile_units and not units and is_main():
            # Only the units are compiled on this path, so with no units the
            # flag would otherwise be dropped without any trace.
            print("[GPUSetup] torch_compile requested under fsdp, but the model exposes no fsdp_wrap_modules(); nothing was compiled")
        fully_shard(model, **shard_kwargs)
        return model

    def _collect_wrap_modules(self, model: torch.nn.Module) -> Iterable[torch.nn.Module]:
        """Return the modules ``model.fsdp_wrap_modules()`` yields, or ``[]``.

        Models without an ``fsdp_wrap_modules`` hook contribute no units.
        """
        hook = getattr(model, "fsdp_wrap_modules", None)
        if hook is None:
            return []
        return list(hook())

    def get_device(self) -> torch.device:
        """Pick the device for this process.

        A distributed context with CUDA available uses ``cuda:<local_rank>``.
        Otherwise ``args.device`` is used if set, else "cuda" when available,
        else "cpu".
        """
        if self.ctx.is_distributed and torch.cuda.is_available():
            return torch.device(f"cuda:{self.ctx.local_rank}")
        dev = getattr(self.args, "device", None)
        return torch.device(dev or ("cuda" if torch.cuda.is_available() else "cpu"))

    def print_model_device(self, model: torch.nn.Module, name: str) -> None:
        """On the main rank, print the device of ``model``'s first parameter."""
        if is_main():
            print(f"{name} device:", next(model.parameters()).device)


# --- src/utils/parallel_context.py (new file, 138 lines) begins here ---
# Module docstring, first line:
#   ParallelContext: 3D-aware (pp, dp, tp) device mesh and distributed state.
+ +Today only the dp axis carries size > 1; tp/pp are always 1. The 3D mesh shape +is allocated up-front so that adding TP/PP later requires no API changes +elsewhere — call sites can reference ctx.dp_mesh / ctx.tp_mesh / ctx.pp_mesh +without knowing what's wired through them. +""" + +import os +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh + + +class ParallelContext: + + def __init__(self, strategy: str = "none"): + if strategy not in {"none", "ddp", "fsdp"}: + raise ValueError(f"unknown parallel strategy: {strategy}") + self.strategy = strategy + self._mesh: Optional[DeviceMesh] = None + self._local_rank = 0 + self._init() + + def _init(self) -> None: + if self.strategy == "none": + return + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", init_method="env://") + self._local_rank = int(os.environ.get("LOCAL_RANK", 0)) + if torch.cuda.is_available(): + torch.cuda.set_device(self._local_rank) + world = dist.get_world_size() + self._mesh = init_device_mesh( + "cuda", (1, world, 1), mesh_dim_names=("pp", "dp", "tp") + ) + + @property + def is_distributed(self) -> bool: + return self.strategy != "none" + + @property + def world_size(self) -> int: + if self._mesh is not None and dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + return 1 + + @property + def global_rank(self) -> int: + if self._mesh is not None and dist.is_available() and dist.is_initialized(): + return dist.get_rank() + return 0 + + @property + def local_rank(self) -> int: + return self._local_rank + + @property + def is_main(self) -> bool: + return self.global_rank == 0 + + @property + def dp_mesh(self) -> Optional[DeviceMesh]: + return self._mesh["dp"] if self._mesh is not None else None + + @property + def tp_mesh(self) -> Optional[DeviceMesh]: + return self._mesh["tp"] if self._mesh is not None else None + + @property + def pp_mesh(self) -> 
Optional[DeviceMesh]: + return self._mesh["pp"] if self._mesh is not None else None + + @property + def dp_size(self) -> int: + return self.dp_mesh.size() if self.dp_mesh is not None else 1 + + @property + def tp_size(self) -> int: + return self.tp_mesh.size() if self.tp_mesh is not None else 1 + + @property + def pp_size(self) -> int: + return self.pp_mesh.size() if self.pp_mesh is not None else 1 + + @property + def dp_rank(self) -> int: + return self.dp_mesh.get_local_rank() if self.dp_mesh is not None else 0 + + @property + def tp_rank(self) -> int: + return self.tp_mesh.get_local_rank() if self.tp_mesh is not None else 0 + + @property + def pp_rank(self) -> int: + return self.pp_mesh.get_local_rank() if self.pp_mesh is not None else 0 + + def barrier(self) -> None: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + def cleanup(self) -> None: + if dist.is_available() and dist.is_initialized(): + try: + dist.destroy_process_group() + except OSError: + pass + + def broadcast_value(self, val, src: int = 0): + if not (dist.is_available() and dist.is_initialized()): + return val + obj = [val] + dist.broadcast_object_list(obj, src=src) + return obj[0] + + +_CTX: Optional[ParallelContext] = None + + +def init_parallel_context(strategy: str = "none") -> ParallelContext: + global _CTX + if _CTX is None: + _CTX = ParallelContext(strategy=strategy) + return _CTX + + +def get_parallel_context() -> ParallelContext: + global _CTX + if _CTX is None: + _CTX = ParallelContext(strategy="none") + return _CTX + + +def reset_parallel_context() -> None: + global _CTX + _CTX = None diff --git a/src/utils/wandb_manager.py b/src/utils/wandb_manager.py index 9f6f39a..328ed24 100644 --- a/src/utils/wandb_manager.py +++ b/src/utils/wandb_manager.py @@ -14,4 +14,4 @@ def cleanup_wandb(): def log_wandb(metrics, prefix = None): if prefix: metrics = {f"{prefix}/{k}" : v for k, v in metrics.items()} - wandb.log(metrics) \ No newline at end of file + wandb.log(metrics)