MXFP8 training fixes for Megatron-FSDP, Torch-FSDP, MoE, FP8 Parameter initialization#130
MXFP8 training fixes for Megatron-FSDP, Torch-FSDP, MoE, FP8 Parameter initialization#130sudhu2k wants to merge 1 commit into
Conversation
…P8 param-gather, and TE 2.12 - _BaseDataParallel.finish_grad_sync: accept force_all_reduce kwarg - torch_fully_sharded_data_parallel: TE 2.x _fp8_attrs flat-attr fallback; drop debug print - arguments: gate FSDP2 fp8_param_gather warning on TE >= 2.12, not 2.0 - fp8_utils: MXFP8 align size 32 -> 128 for HipBLASLt - llama2/llama3 train scripts: drop MXFP8 guards; auto-add --reuse-grad-buf-for-mxfp8-param-ag when MXFP8 + fp8-param-gather without FSDP - MLA: Ensure the proj value are aligned with mxfp8 align size.
|
|
||
| MODEL_SIZE="${MODEL_SIZE:-70}" | ||
| TP="${TP:-8}" | ||
| TP="${TP:-1}" |
There was a problem hiding this comment.
By default TP=8 in llama2 script. So, each gemm is now seeing K dim shape as
11008/8 = 1376.
When the K dim is 1376, it isn't aligned (not divisible by 128) according to hipblasLT and this throws " Unable to find any suitable algorithms"
| 'FSDP always requires CUDA_DEVICE_MAX_CONNECTIONS value large than one' | ||
|
|
||
| if args.fp8_param_gather and is_te_min_version("2.0.0"): | ||
| if args.fp8_param_gather and not is_te_min_version("2.12.0.dev0"): |
There was a problem hiding this comment.
Bug from upstream, fsdp2 fp8 param gather was introduced in 2.12
| """Get the alignment size required for fp8 GEMM.""" | ||
| if fp8_recipe == Fp8Recipe.mxfp8: | ||
| return 32 | ||
| # HipblasLT requires 128 aligned Tensors for MXFP8. |
There was a problem hiding this comment.
This is used by MoEs when requiring padding on number of tokens going to an expert when it's not divisible by 128. HipblasLT requires 128 padding for GEMM.
| if remainder != 0: | ||
| self.kv_down_proj_mxfp8_padding = align - remainder | ||
| kv_down_proj_out_size += self.kv_down_proj_mxfp8_padding | ||
|
|
There was a problem hiding this comment.
Changes in multi_latent_attention.py fixes issues where the projection dimension (kv_down_proj_out_size) isn't divisible by 128. In deepseekv3 example, the kv down projection dimension is 576 [512 + 64] -> not divisible by 128, hence needs padding here.
Changes:
What does this PR do?
Fix issues with MXFP8 training in Megatron-LM
Fixes: #15425
#15420