Skip to content

MXFP8 training fixes for Megatron-FSDP, Torch-FSDP, MoE, FP8 Parameter initialization#130

Open
sudhu2k wants to merge 1 commit into
rocm_devfrom
sudhu/mxfp8_bugfixes_moe_llama
Open

MXFP8 training fixes for Megatron-FSDP, Torch-FSDP, MoE, FP8 Parameter initialization#130
sudhu2k wants to merge 1 commit into
rocm_devfrom
sudhu/mxfp8_bugfixes_moe_llama

Conversation

@sudhu2k
Copy link
Copy Markdown
Collaborator

@sudhu2k sudhu2k commented May 15, 2026

Changes:

  • _BaseDataParallel.finish_grad_sync: accept force_all_reduce kwarg
  • torch_fully_sharded_data_parallel: TE 2.x _fp8_attrs flat-attr fallback; drop debug print
  • arguments: gate FSDP2 fp8_param_gather warning on TE >= 2.12, not 2.0
  • fp8_utils: MXFP8 align size 32 -> 128 for HipBLASLt
  • llama2/llama3 train scripts: drop MXFP8 guards; auto-add --reuse-grad-buf-for-mxfp8-param-ag when MXFP8 + fp8-param-gather without FSDP
  • MLA: Ensure the proj value are aligned with mxfp8 align size.

What does this PR do?

Fix issues with MXFP8 training in Megatron-LM

Fixes: #15425
#15420

…P8 param-gather, and TE 2.12

- _BaseDataParallel.finish_grad_sync: accept force_all_reduce kwarg
- torch_fully_sharded_data_parallel: TE 2.x _fp8_attrs flat-attr fallback; drop debug print
- arguments: gate FSDP2 fp8_param_gather warning on TE >= 2.12, not 2.0
- fp8_utils: MXFP8 align size 32 -> 128 for HipBLASLt
- llama2/llama3 train scripts: drop MXFP8 guards; auto-add --reuse-grad-buf-for-mxfp8-param-ag when MXFP8 + fp8-param-gather without FSDP
- MLA: Ensure the proj value are aligned with mxfp8 align size.
@sudhu2k sudhu2k requested a review from wenchenvincent May 15, 2026 17:14
@sudhu2k sudhu2k self-assigned this May 15, 2026

MODEL_SIZE="${MODEL_SIZE:-70}"
TP="${TP:-8}"
TP="${TP:-1}"
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default TP=8 in llama2 script. So, each gemm is now seeing K dim shape as
11008/8 = 1376.
When the K dim is 1376, it isn't aligned (not divisible by 128) according to hipblasLT and this throws " Unable to find any suitable algorithms"

'FSDP always requires CUDA_DEVICE_MAX_CONNECTIONS value large than one'

if args.fp8_param_gather and is_te_min_version("2.0.0"):
if args.fp8_param_gather and not is_te_min_version("2.12.0.dev0"):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug from upstream, fsdp2 fp8 param gather was introduced in 2.12

"""Get the alignment size required for fp8 GEMM."""
if fp8_recipe == Fp8Recipe.mxfp8:
return 32
# HipblasLT requires 128 aligned Tensors for MXFP8.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is used by MoEs when requiring padding on number of tokens going to an expert when it's not divisible by 128. HipblasLT requires 128 padding for GEMM.

if remainder != 0:
self.kv_down_proj_mxfp8_padding = align - remainder
kv_down_proj_out_size += self.kv_down_proj_mxfp8_padding

Copy link
Copy Markdown
Collaborator Author

@sudhu2k sudhu2k May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes in multi_latent_attention.py fixes issues where the projection dimension (kv_down_proj_out_size) isn't divisible by 128. In deepseekv3 example, the kv down projection dimension is 576 [512 + 64] -> not divisible by 128, hence needs padding here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant