Add MPerBlock=128 tile size for blockscale FP8 MoE kernels#2314
ChuanLi1101 wants to merge 5 commits into main
Conversation
Force-pushed from 7a13969 to 7124fa1
need some results
Enable MPerBlock=128 for a8w8 blockscale stage1 and stage2 kernel lists. This tile size was previously commented out and is needed for DeepSeek V3 EP8 prefill, where tokens_per_expert is large. The same tile configuration (256,128,128,128,1,4,V3) is already proven to work in the non-blockscale a8w8 kernel lists.

Changes:
- Uncomment and enable MPerBlock=128 in a8w8_gemm1_blockscale_kernels_list
- Add MPerBlock=128 to a8w8_gemm2_blockscale_kernels_list with MWaves=1, NWaves=4 (consistent with non-blockscale) instead of the original 2,2
- Add block_m=128 branches to the blockscale heuristic dispatch for both stage1 and stage2
- Add a minimal DSV3 EP8 FP8 blockscale shape for tuning verification

Made-with: Cursor
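The change above can be sketched roughly as follows. The list name and the tuple layout (BlockSize, MPerBlock, NPerBlock, KPerBlock, MWaves, NWaves, pipeline) mirror the (256,128,128,128,1,4,V3) config quoted in the commit message, but the surrounding entries and the helper are illustrative placeholders, not the actual aiter kernel-list format.

```python
# Hypothetical sketch of a blockscale kernel list. Each entry is
# (BlockSize, MPerBlock, NPerBlock, KPerBlock, MWaves, NWaves, pipeline).
# Only the M=128 entry's values come from the PR; the rest are placeholders.
a8w8_gemm1_blockscale_kernels_list = [
    (256, 32, 128, 128, 2, 2, "V3"),    # placeholder existing entry
    (256, 64, 128, 128, 2, 2, "V3"),    # placeholder existing entry
    # Previously commented out; re-enabled for large tokens_per_expert:
    (256, 128, 128, 128, 1, 4, "V3"),
]

def kernels_for_block_m(block_m):
    """Return configs whose MPerBlock matches the dispatched tile size."""
    return [k for k in a8w8_gemm1_blockscale_kernels_list if k[1] == block_m]
```

A `block_m=128` branch in the heuristic dispatch would then have at least one candidate to select instead of falling through to the M=64 tile.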
Stage2 with MWaves=1, NWaves=4 causes an AGPR spill (MXDLPerWave=8 is too high), resulting in 15.1% error. Revert to the original MWaves=2, NWaves=2, which has MXDLPerWave=4, matching the original author's intent to reduce register pressure. Stage1 remains at 1,4 since it showed only 0.1% error.

Made-with: Cursor
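The MXDLPerWave numbers in the commit message are consistent with the per-wave M work scaling inversely with MWaves. The relation below, including MPerXDL=16, is an assumption used to reproduce the 8-vs-4 arithmetic, not something stated in the source.

```python
# Assumed relation: MXDLPerWave = MPerBlock // (MWaves * MPerXDL),
# with MPerXDL = 16 (assumption; not confirmed by the PR).
def mxdl_per_wave(m_per_block, m_waves, m_per_xdl=16):
    return m_per_block // (m_waves * m_per_xdl)

# MWaves=1 doubles the per-wave MXDL workload relative to MWaves=2:
print(mxdl_per_wave(128, 1))  # 8 -> AGPR spill observed in Stage2
print(mxdl_per_wave(128, 2))  # 4 -> original, lower register pressure
```

Halving the per-wave workload is exactly why the 2,2 config avoids the accumulator-register spill that produced the 15.1% error.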
Force-pushed from f9b79b5 to 2c56264
Blockscale FP8 had only one kernel config per MPerBlock while non-blockscale had 4+, which limits the tuner's search space. Add V1 pipeline and K=256 variants for M=32/64, plus a 2,2 wave config for M=128.

Stage1 additions:
- M=32, K=256, V1: better prefetch for memory-bound cases
- M=64, V1: lower-overhead pipeline alternative to V3
- M=64, K=256, V3: bigger K for better pipeline utilization
- M=128, MWaves=2, NWaves=2, V3: lower-AGPR-pressure alternative

Stage2 additions (same pattern, excluding the M=128 1,4 config due to AGPR spill):
- M=32, K=256, V1
- M=64, V1
- M=64, K=256, V3

Made-with: Cursor
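Grouping the Stage1 additions by MPerBlock shows how the per-tile search space widens from one candidate toward the 4+ that non-blockscale already has. K=128 for the plain M=64 V1 entry is an assumption (the commit message does not state its K); the tuple shape is illustrative.

```python
from collections import defaultdict

# Stage1 additions from the commit message as (MPerBlock, KPerBlock,
# pipeline) tuples. K=128 for the (64, "V1") entry is an assumption.
stage1_additions = [
    (32, 256, "V1"),
    (64, 128, "V1"),
    (64, 256, "V3"),
    (128, 128, "V3"),
]

# Group by MPerBlock: the tuner now has multiple candidates per tile size.
by_m = defaultdict(list)
for m, k, pipeline in stage1_additions:
    by_m[m].append((k, pipeline))
```

With these entries the tuner can trade pipeline version (V1 vs. V3) and K depth per tile size rather than being locked to a single config.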
@ChuanLi1101 Two issues came up in this PR during manual testing on MI355; two potential changes are:
…support

Address two issues reported by JohnQinAMD during MI355X testing:
1. KPerBlock=256 blockscale configs used the V3 pipeline, which requires scaleblocksliceK=1, but K=256/blockscale128=2. Changed the M=64, K=256 variants from V3 to V1 in both Stage1 and Stage2.
2. Add both torch.float8_e4m3fn (gfx950) and torch.float8_e4m3fnuz (gfx942) to dtype2str_dict so the tuner works with verify CSVs on either MI355X or MI300X without dtype lookup failures.

Made-with: Cursor
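The dtype2str_dict extension in item 2 can be sketched as below. In the real tuner the keys are torch dtype objects; they are shown here as name strings so the sketch runs without torch installed, and the non-FP8 entry is a placeholder.

```python
# Sketch of the dtype -> short-name mapping used by the tuner. Real keys
# are torch dtype objects (e.g. torch.float8_e4m3fn); strings stand in
# here so the example has no torch dependency.
dtype2str_dict = {
    "torch.bfloat16": "bf16",  # placeholder existing entry
    # gfx950 (MI355X) FP8 flavor:
    "torch.float8_e4m3fn": "fp8",
    # gfx942 (MI300X) FP8 flavor -- added so verify CSVs from either
    # architecture resolve without a dtype lookup failure:
    "torch.float8_e4m3fnuz": "fp8",
}
```

Mapping both FP8 flavors to the same short name is what lets one set of tuning CSVs serve both architectures.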
@JohnQinAMD Thanks for catching these! Both issues are fixed in the latest push:
Could you re-test on MI355X when you get a chance?
@ChuanLi1101 If doweight_stage1=0 is set, Stage2 blockscale will have 15% error due to BF16 AtomicAdd accumulation error; a quick workaround is to set doweight_stage1=1 in the CSV file. The TFLOPS will not be impacted.
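The suggested CSV workaround could be applied mechanically with a small script like the one below. The column name `doweight_stage1` comes from the comment; the rest of the CSV layout and the helper name are illustrative.

```python
import csv
import io

def force_doweight_stage1(csv_text):
    """Rewrite a tuning CSV so every row has doweight_stage1=1.

    Hedged sketch of the workaround for the BF16 AtomicAdd accumulation
    error in Stage2 blockscale; only the column name is from the PR.
    """
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    for row in rows:
        row["doweight_stage1"] = "1"
    out = io.StringIO()
    writer = csv.DictWriter(out, fieldnames=rows[0].keys())
    writer.writeheader()
    writer.writerows(rows)
    return out.getvalue()
```

Per the comment, this only changes where the routing weight is applied, so accuracy recovers while TFLOPS is unaffected.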
Two more issues were identified during testing.
Summary
Motivation
For DeepSeek V3 EP8 FP8 blockscale prefill (ISL=8K, concurrency=128), the current maximum blockscale tile (MPerBlock=64) limits performance. Adding MPerBlock=128 should improve compute utilization for large token-per-expert counts.
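A back-of-envelope view of why the larger tile helps: the number of M-tiles launched per expert is ceil(tokens_per_expert / MPerBlock), so doubling MPerBlock halves the tile count when tokens_per_expert is large. The tokens_per_expert value below is illustrative, not a measured figure from this workload.

```python
import math

def m_tiles(tokens_per_expert, m_per_block):
    """Number of M-dimension tiles per expert (illustrative arithmetic)."""
    return math.ceil(tokens_per_expert / m_per_block)

# Doubling MPerBlock from 64 to 128 halves the tile count, improving
# compute utilization per tile for large tokens_per_expert:
print(m_tiles(8192, 64), m_tiles(8192, 128))  # 128 64
```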
Test plan