Add FSDP2 parallellism#38
Conversation
…sor-aware Muon
Replaces --distributed (DDP-only) with --parallel_strategy {ddp, fsdp},
introduces a 3D-aware ParallelContext (DeviceMesh sized (1, world, 1) today
with names "pp", "dp", "tp" so adding TP/PP later is API-stable).
Key pieces:
- ParallelContext singleton in src/utils/parallel_context.py with shim
helpers preserved in gpu_manager.py for backward compatibility.
- GPUSetup dispatches strategy: DDP (existing path), FSDP2 via fully_shard
with MixedPrecisionPolicy(param=bf16, reduce=fp32) and Megatron-style fp32
master weights (model cast to fp32 before wrap, MP downcasts for forward).
- Per-FSDP-unit torch.compile inside _wrap_fsdp so each unit's gather hook
fires at the unit boundary instead of being skipped by an outer compile.
- MuonDistributed (src/optimizers/muon_distributed.py): subclasses Muon and
routes DTensor params through full_tensor() -> Newton-Schulz -> distribute_tensor
so FSDP-sharded weights work.
- CheckpointManager moved to torch.distributed.checkpoint.state_dict APIs;
same single .pt file format as before so eval/chat consumers are unchanged.
MuonAdamW state dict split into "muon"/"adamw" halves to round-trip cleanly.
- Per-LLM fsdp_wrap_modules() hooks (Qwen25/Llama3/Gemma2) returning decoder
blocks via shared elms/llms/_wrap.get_decoder_layers helper.
- ST-MEM hardcoded .to(float32) replaced with .to(next(self.parameters()).dtype)
to honor the wrapped model's actual dtype (matches MTAE's pattern).
- LLM wrappers pass use_cache=False during training to keep torch.compile from
hitting Dynamo's recompile limit on per-layer KV cache init guards.
- Trainer/RL-trainer/main_trainer no longer rank-0-gate save_checkpoint.
Under FSDP, get_model_state_dict is a collective that needs all ranks;
the old DDP-era is_main() gate caused the gather to deadlock and corrupt
state when --save_step or save_epoch fired. Decisions stay rank-0
(save_step is deterministic; save_epoch is broadcast); the save itself
now runs collectively, and only rank 0 writes the file (gated inside
CheckpointManager.save_checkpoint).
Existing scripts (train.sh, train2.sh, st_mem_full_training.sh) and the README
swap --distributed for --parallel_strategy ddp; behavior unchanged for DDP.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Implementation wise it seems okay, but I want to split this into smaller PRs. Also I want us to start thinking about how we can refactor our own codebase to make these changes easier. |
There was a problem hiding this comment.
I would rather want this to be one Muon object with a clean, intuitive separation between distributed and non distributed settings. However, the most ideal would be the following refactor to not only Muon but other things in the REPO:
What do people do in terms of distributed and non distributed settings? should we support both? Why do we need to support both? For example for DDP, we can run single device runs by setting --device and we run multiple GPU devices with --distributed and torchrun. However, we can use solely a single device in the distributed setting with CUDA_VISIBLE_DEVICES=0 uv run torchrun --standalone --nproc_per_node=1. I feel we have a lot of split paths of code between "single device specific code" and "multi device specific code".
Notes:
Summary of changes:
Adds FSDP2 as a parallel strategy alongside DDP. The CLI flag --distributed is replaced with --parallel_strategy {ddp, fsdp}. The DDP path is unchanged; FSDP is the new option, and you opt in by swapping that flag.
The plumbing is built behind a new ParallelContext that wraps a DeviceMesh sized (pp=1, dp=world, tp=1) with axis names ("pp", "dp", "tp"). Only the dp axis is in use today, but the 3D shape is in place so adding tensor or pipeline parallelism later is wiring rather than an API break. Free helpers like is_main() and get_world_size() become thin shims over the context, so existing call sites continue to work.
The FSDP path uses the standard Megatron / torchtitan mixed-precision recipe: fp32 master weights and fp32 optimizer state, bf16 forward and backward, fp32 cross-rank gradient reduce. Each transformer block becomes its own FSDP unit, and if --torch_compile is set we compile each unit individually after sharding it (compiling the outer model traces straight through FSDP2's pre-forward gather hooks and crashes).
Muon needs a wrapper because Newton-Schulz operates on the full weight matrix and doesn't shard. The new MuonDistributed subclass detects DTensor params, all-gathers them for the orthogonalization step, then writes the local slice back. Non-DTensor params fall through to upstream Muon unchanged. In practice Muon under FSDP is correct but bandwidth-heavy on PCIe-only / cross-NUMA setups; AdamW under FSDP doesn't have this overhead because its update is element-wise.
Checkpointing is now routed through torch.distributed.checkpoint's full-state-dict APIs with gather to rank 0. The on-disk .pt format is unchanged, so main_evaluator.py and main_chat.py consume FSDP-saved checkpoints identically to DDP-saved ones — verified end-to-end (FSDP train → save → single-process eval load → inference produces sensible metrics). A latent bug surfaced here: the trainer's old rank-0 gate around save_checkpoint was correct under DDP but deadlocked under FSDP, where the gather is a collective. Fixed — decision logic stays rank-0, save is collective, only rank 0 writes the file.
Two small dtype/cache fixes also landed because they blocked the FSDP path. ST-MEM had a hardcoded .to(torch.float32) on its encoder input that fought the bf16 forward; replaced with .to(next(self.parameters()).dtype) (the pattern MTAE already used). And the LLM wrappers now pass use_cache=False during training — HF's per-layer KV-cache init guards otherwise blow past Dynamo's recompile limit and silently degrade torch.compile to eager. Generation paths unchanged.
For users: existing --distributed-style scripts and the README have been migrated to --parallel_strategy ddp with no behavior change. Eval and chat scripts work as-is regardless of which strategy was used to train. Default for --parallel_strategy is unset (single-device), DDP gives you exactly what you got before, and FSDP is the new addition.