Add FSDP2 parallelism #38

Open
TonyChen06 wants to merge 1 commit into ELM-Research:main from TonyChen06:feat/fsdp-support

@TonyChen06

Contributor

Notes:

  1. --distributed is now --parallel_strategy {ddp, fsdp}, so make sure to update your scripts.
  2. Our A6000s have no NVLink, so communication costs make FSDP about 5x slower than DDP with Muon and around 40% slower than DDP with Adam. I haven't measured the gap on H100s, but it should be much smaller.
  3. FSDP adds slight overhead and doesn't help much unless the model has 3B+ parameters, so still use DDP when possible.
  4. I didn't implement tensor parallelism or pipeline parallelism because those are overkill for now (but the way we did FSDP is friendly to adding them later if needed).

Summary of changes:

Adds FSDP2 as a parallel strategy alongside DDP. The CLI flag --distributed is replaced with --parallel_strategy {ddp, fsdp}. The DDP path is unchanged; FSDP is the new option, and you opt in by swapping that flag.
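
A minimal sketch of how the flag could be declared with argparse. The `build_parser` helper is hypothetical and not the repo's actual code; only the flag name, its two choices, and the unset default come from this PR.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper; the real training entry point may be structured differently.
    parser = argparse.ArgumentParser(description="training entry point (sketch)")
    parser.add_argument(
        "--parallel_strategy",
        choices=["ddp", "fsdp"],
        default=None,  # unset means a single-device run
        help="Parallel strategy; omit for single-device training.",
    )
    return parser
```

With this shape, `build_parser().parse_args(["--parallel_strategy", "fsdp"])` selects FSDP, and a script that still passes `--distributed` fails fast because argparse rejects the unrecognized flag.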

The plumbing is built behind a new ParallelContext that wraps a DeviceMesh sized (pp=1, dp=world, tp=1) with axis names ("pp", "dp", "tp"). Only the dp axis is in use today, but the 3D shape is in place so adding tensor or pipeline parallelism later is wiring rather than an API break. Free helpers like is_main() and get_world_size() become thin shims over the context, so existing call sites continue to work.
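
To make the mesh layout concrete, here is a torch-free sketch of the idea. `ParallelContextSketch`, `set_context`, and the row-major rank layout are assumptions for illustration; the real ParallelContext wraps a DeviceMesh and may differ in every detail except the axis names and the (pp, dp, tp) shape.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

AXIS_NAMES = ("pp", "dp", "tp")


@dataclass(frozen=True)
class ParallelContextSketch:
    """Hypothetical stand-in for the PR's ParallelContext (no torch needed)."""

    rank: int
    world_size: int
    pp: int = 1  # pipeline-parallel degree, fixed at 1 today
    tp: int = 1  # tensor-parallel degree, fixed at 1 today

    @property
    def dp(self) -> int:
        # Whatever is not claimed by pp/tp goes to data parallelism.
        assert self.world_size % (self.pp * self.tp) == 0
        return self.world_size // (self.pp * self.tp)

    @property
    def mesh_shape(self) -> Tuple[int, int, int]:
        return (self.pp, self.dp, self.tp)

    @property
    def coords(self) -> Tuple[int, int, int]:
        # Assumed row-major layout: tp varies fastest, then dp, then pp.
        return (
            self.rank // (self.dp * self.tp),
            (self.rank // self.tp) % self.dp,
            self.rank % self.tp,
        )


_CTX: Optional[ParallelContextSketch] = None


def set_context(ctx: ParallelContextSketch) -> None:
    global _CTX
    _CTX = ctx


# Free helpers stay as thin shims so older call sites keep working.
def is_main() -> bool:
    return _CTX is None or _CTX.rank == 0


def get_world_size() -> int:
    return 1 if _CTX is None else _CTX.world_size
```

With 4 ranks and pp = tp = 1 the shape is (1, 4, 1) and rank 3 sits at (0, 3, 0). Raising `tp` to 2 on 8 ranks changes only the constructor arguments, which is the sense in which later TP/PP support would be wiring rather than an API change.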

The FSDP path uses the standard Megatron / torchtitan mixed-precision recipe: fp32 master weights and fp32 optimizer state, bf16 forward and backward, fp32 cross-rank gradient reduce. Each transformer block becomes its own FSDP unit, and if --torch_compile is set we compile each unit individually after sharding it (compiling the outer model traces straight through FSDP2's pre-forward gather hooks and crashes).
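
Why keep fp32 master weights when forward and backward run in bf16? The standalone sketch below emulates bf16 rounding with the standard library (it is not the PR's code and does not use torch) and applies many small updates to a weight of 1.0. The function names and the toy update size are illustrative assumptions.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to bfloat16 (round-to-nearest-even), returned as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for the 16 dropped bits
    bits &= 0xFFFF0000                   # keep sign, exponent, top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def train(steps: int, update: float, fp32_master: bool) -> float:
    """Apply `steps` identical updates; return the bf16 weight the forward pass sees."""
    master = 1.0
    for _ in range(steps):
        if fp32_master:
            master = to_fp32(master + update)  # accumulate in fp32
        else:
            master = to_bf16(master + update)  # accumulate directly in bf16
    return to_bf16(master)  # forward/backward always use a bf16 copy
```

Near 1.0 the spacing between adjacent bf16 values is 2^-7, so an update of 1e-3 rounds away every step when the weight itself is stored in bf16, while the fp32 master accumulates the updates and the bf16 copy eventually moves.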

Muon needs a wrapper because Newton-Schulz operates on the full weight matrix and doesn't shard. The new MuonDistributed subclass detects DTensor params, all-gathers them for the orthogonalization step, then writes the local slice back. Non-DTensor params fall through to upstream Muon unchanged. In practice Muon under FSDP is correct but bandwidth-heavy on PCIe-only / cross-NUMA setups; AdamW under FSDP doesn't have this overhead because its update is element-wise.
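
The gather, orthogonalize, slice-back flow can be sketched in plain Python. This is a single-process simulation under stated assumptions: shards are row blocks held in a list, the "all-gather" is list concatenation, and the classic cubic Newton-Schulz iteration stands in for Muon's tuned variant. `sharded_orthogonalize` is a hypothetical name, not the MuonDistributed API.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def newton_schulz(g, steps=20):
    """Cubic Newton-Schulz: X <- 1.5 X - 0.5 (X X^T) X, after Frobenius scaling."""
    norm = sum(v * v for row in g for v in row) ** 0.5
    x = [[v / norm for v in row] for row in g]  # puts singular values in (0, 1]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * a - 0.5 * b for a, b in zip(ra, rb)] for ra, rb in zip(x, xxtx)]
    return x


def sharded_orthogonalize(shards):
    """shards[r] is rank r's row block of one weight (or update) matrix."""
    full = [row for shard in shards for row in shard]  # simulated all-gather
    ortho = newton_schulz(full)                        # needs the full matrix
    out, start = [], 0
    for shard in shards:                               # each rank keeps its slice
        out.append(ortho[start:start + len(shard)])
        start += len(shard)
    return out
```

The iteration cannot run on a row block alone because every step multiplies by X X^T, which mixes all rows. That is the source of the extra communication described above, and an element-wise update like AdamW's has no equivalent step.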

Checkpointing is now routed through torch.distributed.checkpoint's full-state-dict APIs with gather to rank 0. The on-disk .pt format is unchanged, so main_evaluator.py and main_chat.py consume FSDP-saved checkpoints identically to DDP-saved ones — verified end-to-end (FSDP train → save → single-process eval load → inference produces sensible metrics). A latent bug surfaced here: the trainer's old rank-0 gate around save_checkpoint was correct under DDP but deadlocked under FSDP, where the gather is a collective. Fixed — decision logic stays rank-0, save is collective, only rank 0 writes the file.
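
The deadlock is easy to reproduce in miniature. In this sketch, threads play the role of ranks and a `threading.Barrier` with a timeout stands in for the collective gather; all names are illustrative and none of this is the CheckpointManager code.

```python
import threading


def save_checkpoint(rank, shard, barrier, gathered, written):
    """Collective save: every rank contributes, only rank 0 'writes'."""
    gathered[rank] = shard
    barrier.wait(timeout=1.0)  # stand-in for the collective gather
    if rank == 0:
        written.append([gathered[r] for r in sorted(gathered)])


def run(world_size, rank0_gate):
    barrier = threading.Barrier(world_size)
    gathered, written, timed_out = {}, [], []

    def worker(rank):
        if rank0_gate and rank != 0:
            return  # old behavior: non-main ranks skip the save entirely
        try:
            save_checkpoint(rank, f"shard{rank}", barrier, gathered, written)
        except threading.BrokenBarrierError:
            timed_out.append(rank)  # a real collective would hang here instead

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return written, timed_out
```

Calling `run(4, rank0_gate=True)` leaves rank 0 waiting on peers that never arrive, so nothing is written; `run(4, rank0_gate=False)` lets every rank enter the gather and rank 0 writes one combined result.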

Two small dtype/cache fixes also landed because they blocked the FSDP path. ST-MEM had a hardcoded .to(torch.float32) on its encoder input that fought the bf16 forward; replaced with .to(next(self.parameters()).dtype) (the pattern MTAE already used). And the LLM wrappers now pass use_cache=False during training — HF's per-layer KV-cache init guards otherwise blow past Dynamo's recompile limit and silently degrade torch.compile to eager. Generation paths unchanged.

For users: existing --distributed-style scripts and the README have been migrated to --parallel_strategy ddp with no behavior change. Eval and chat scripts work as-is regardless of which strategy was used to train. The default for --parallel_strategy is unset (single-device); ddp gives you exactly what you got before, and fsdp is the new addition.

…sor-aware Muon

Replaces --distributed (DDP-only) with --parallel_strategy {ddp, fsdp} and
introduces a 3D-aware ParallelContext (DeviceMesh sized (1, world, 1) today
with names "pp", "dp", "tp" so adding TP/PP later is API-stable).

Key pieces:
- ParallelContext singleton in src/utils/parallel_context.py with shim
  helpers preserved in gpu_manager.py for backward compatibility.
- GPUSetup dispatches strategy: DDP (existing path), FSDP2 via fully_shard
  with MixedPrecisionPolicy(param=bf16, reduce=fp32) and Megatron-style fp32
  master weights (model cast to fp32 before wrap, MP downcasts for forward).
- Per-FSDP-unit torch.compile inside _wrap_fsdp so each unit's gather hook
  fires at the unit boundary instead of being skipped by an outer compile.
- MuonDistributed (src/optimizers/muon_distributed.py): subclasses Muon and
  routes DTensor params through full_tensor() -> Newton-Schulz -> distribute_tensor
  so FSDP-sharded weights work.
- CheckpointManager moved to torch.distributed.checkpoint.state_dict APIs;
  same single .pt file format as before so eval/chat consumers are unchanged.
  MuonAdamW state dict split into "muon"/"adamw" halves to round-trip cleanly.
- Per-LLM fsdp_wrap_modules() hooks (Qwen25/Llama3/Gemma2) returning decoder
  blocks via shared elms/llms/_wrap.get_decoder_layers helper.
- ST-MEM hardcoded .to(float32) replaced with .to(next(self.parameters()).dtype)
  to honor the wrapped model's actual dtype (matches MTAE's pattern).
- LLM wrappers pass use_cache=False during training to keep torch.compile from
  hitting Dynamo's recompile limit on per-layer KV cache init guards.
- Trainer/RL-trainer/main_trainer no longer rank-0-gate save_checkpoint.
  Under FSDP, get_model_state_dict is a collective that needs all ranks;
  the old DDP-era is_main() gate caused the gather to deadlock and corrupt
  state when --save_step or save_epoch fired. Decisions stay rank-0
  (save_step is deterministic; save_epoch is broadcast); the save itself
  now runs collectively, and only rank 0 writes the file (gated inside
  CheckpointManager.save_checkpoint).

Existing scripts (train.sh, train2.sh, st_mem_full_training.sh) and the README
swap --distributed for --parallel_strategy ddp; behavior unchanged for DDP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

willxxy commented May 4, 2026

Member

Implementation-wise it seems okay, but I want to split this into smaller PRs. I also want us to start thinking about how we can refactor our own codebase to make these changes easier.

@willxxy willxxy May 4, 2026

Member

I would rather this be one Muon object with a clean, intuitive separation between distributed and non-distributed settings. Ideally, though, the following refactor would apply not only to Muon but to other parts of the repo:

What do people usually do about distributed and non-distributed settings? Should we support both, and if so, why? For example, with DDP we run on a single device by setting --device, and on multiple GPUs with --distributed and torchrun. However, we can also use just a single device in the distributed setting with CUDA_VISIBLE_DEVICES=0 uv run torchrun --standalone --nproc_per_node=1. I feel we have a lot of code paths split between "single device specific code" and "multi device specific code".
