
Add MPerBlock=128 tile size for blockscale FP8 MoE kernels#2314

Open
ChuanLi1101 wants to merge 5 commits into main from chuan/add-blockscale-mperblock128

Conversation


@ChuanLi1101 ChuanLi1101 commented Mar 17, 2026

Summary

  • Enable MPerBlock=128 for a8w8 blockscale stage1 and stage2 CK MoE kernel lists
  • This tile size was previously commented out since the initial commit and is needed for DeepSeek V3 EP8 prefill scenarios where tokens_per_expert is large (~4096)
  • The same tile configuration (256, 128, 128, 128, 1, 4, V3) is already proven working in non-blockscale a8w8 kernel lists
  • Add block_m=128 branches in blockscale heuristic dispatch for both stage1 and stage2
  • Include minimal DSV3 EP8 FP8 blockscale shape CSV for tuning verification
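To make the tile tuple referenced above concrete, here is a hedged sketch of what (256, 128, 128, 128, 1, 4, V3) encodes; the field names and order are an assumption based on the PR discussion, not the actual CK type definition:

```python
from typing import NamedTuple

# Hypothetical field breakdown of the tile tuple; names/order are assumed.
class TileConfig(NamedTuple):
    block_size: int   # threads per workgroup
    m_per_block: int  # rows of the output tile
    n_per_block: int  # columns of the output tile
    k_per_block: int  # K slice consumed per main-loop iteration
    m_waves: int      # wavefronts tiled along M
    n_waves: int      # wavefronts tiled along N
    pipeline: str     # CK pipeline version, e.g. "V3"

cfg = TileConfig(256, 128, 128, 128, 1, 4, "V3")
# Sanity check: on CDNA a wavefront is 64 lanes, so waves * 64 == block_size.
assert cfg.m_waves * cfg.n_waves * 64 == cfg.block_size
```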

Motivation

For DeepSeek V3 EP8 FP8 blockscale prefill (ISL=8K, concurrency=128), the current maximum blockscale tile (MPerBlock=64) limits performance. Adding MPerBlock=128 should improve compute utilization for large token-per-expert counts.
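Rough tile-count arithmetic for the scenario above (a sketch, not a performance model): doubling MPerBlock halves the number of M tiles each expert's GEMM must launch when tokens_per_expert is large.

```python
tokens_per_expert = 4096  # approximate DSV3 EP8 prefill figure from the PR

def m_tiles(m_per_block: int) -> int:
    # Ceiling division: M tiles needed to cover one expert's tokens.
    return -(-tokens_per_expert // m_per_block)

assert m_tiles(64) == 64    # current blockscale maximum tile size
assert m_tiles(128) == 32   # proposed: fewer, fatter workgroups
```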

Test plan

  • Run tuner with the included verify CSV on MI355X to confirm correctness (errRatio < 0.5%) and performance improvement:

        -i aiter/configs/untuned_dsv3_ep8_verify.csv \
        -o /tmp/tuned_dsv3_ep8_verify.csv --last -v

  • Compare TFLOPS of MPerBlock=128 vs MPerBlock=64 for EP8 FP8 blockscale
  • Verify no regression on existing tuned shapes

@ChuanLi1101 ChuanLi1101 requested a review from a team March 17, 2026 18:13
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

  Label       Tests
  ci:sglang   SGLang integration tests
  ci:atom     ATOM benchmark (DeepSeek-R1 + GPT-OSS)
  ci:vllm     vLLM benchmark
  ci:all      All of the above

Add labels via the sidebar or gh pr edit 2314 --add-label <label>

@ChuanLi1101 ChuanLi1101 force-pushed the chuan/add-blockscale-mperblock128 branch from 7a13969 to 7124fa1 on March 17, 2026 18:18
@valarLip
Collaborator

need some results

Enable MPerBlock=128 for a8w8 blockscale stage1 and stage2 kernel lists.
This tile size was previously commented out and is needed for DeepSeek V3
EP8 prefill where tokens_per_expert is large. The same tile configuration
(256,128,128,128,1,4,V3) is already proven working in non-blockscale a8w8
kernel lists.

Changes:
- Uncomment and enable MPerBlock=128 in a8w8_gemm1_blockscale_kernels_list
- Add MPerBlock=128 to a8w8_gemm2_blockscale_kernels_list with MWaves=1,
  NWaves=4, consistent with the non-blockscale list rather than the original 2,2
- Add block_m=128 branches in blockscale heuristic dispatch for both
  stage1 and stage2
- Add minimal DSV3 EP8 FP8 blockscale shape for tuning verification

Made-with: Cursor
Stage2 with MWaves=1,NWaves=4 causes AGPR spill (MXDLPerWave=8 too high),
resulting in 15.1% error. Revert to original MWaves=2,NWaves=2 which has
MXDLPerWave=4, matching the original author's intent to reduce register
pressure. Stage1 remains at 1,4 as it showed only 0.1% error.

Made-with: Cursor
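The register-pressure reasoning in the commit above can be checked with back-of-envelope arithmetic; this assumes the CK relation MXdlPerWave = MPerBlock / (MWaves * MPerXdl) with MPerXdl = 16, which is an assumption consistent with the numbers quoted in the commit message:

```python
def mxdl_per_wave(m_per_block: int, m_waves: int, m_per_xdl: int = 16) -> int:
    # XDL (MFMA) tiles each wave accumulates along M; assumed relation.
    return m_per_block // (m_waves * m_per_xdl)

# MWaves=1, NWaves=4 -> MXdlPerWave=8: the accumulator footprint that
# spilled AGPRs in Stage2 (15.1% error per the commit).
assert mxdl_per_wave(128, 1) == 8
# MWaves=2, NWaves=2 -> MXdlPerWave=4: half the accumulator footprint.
assert mxdl_per_wave(128, 2) == 4
```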
@ChuanLi1101 ChuanLi1101 force-pushed the chuan/add-blockscale-mperblock128 branch from f9b79b5 to 2c56264 on March 18, 2026 19:04
Blockscale FP8 had only 1 kernel config per MPerBlock while non-blockscale
had 4+. This limits the tuner's search space. Add V1 pipeline and K=256
variants for M=32/64, plus a 2,2 wave config for M=128 Stage1:

Stage1 additions:
- M=32, K=256, V1: better prefetch for memory-bound cases
- M=64, V1: lower-overhead pipeline alternative to V3
- M=64, K=256, V3: bigger K for better pipeline utilization
- M=128, MWaves=2,NWaves=2, V3: lower AGPR pressure alternative

Stage2 additions (same pattern, excluding M=128 1,4 due to AGPR spill):
- M=32, K=256, V1
- M=64, V1
- M=64, K=256, V3

Made-with: Cursor
@JohnQinAMD

@ChuanLi1101 There are two issues in this PR found during manual testing on MI355; two potential changes are:

  • Change KPerBlock=256 blockscale configs from V3 to V1 in both Stage1
    and Stage2 kernel lists. V3 pipeline requires scaleblocksliceK=1, but
    KPerBlock=256 with blockscale=128 gives scaleblocksliceK=2.
  • Need to adapt to MI355: add both torch.float8_e4m3fn (gfx950) and torch.float8_e4m3fnuz (gfx942)
    to dtype2str_dict so the tuner works with verify CSVs on either platform.
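The first point can be expressed as a small compatibility check; the rule "V3 requires scaleblocksliceK=1" is taken from the comment above, but the helper itself is illustrative, not the actual kernel-list code:

```python
SCALE_BLOCK_K = 128  # blockscale granularity along K, per the comment above

def pick_pipeline(k_per_block: int) -> str:
    # V3 needs exactly one scale slice per K block; otherwise fall back to V1.
    scale_block_slice_k = k_per_block // SCALE_BLOCK_K
    return "V3" if scale_block_slice_k == 1 else "V1"

assert pick_pipeline(128) == "V3"
assert pick_pipeline(256) == "V1"  # K=256 / 128 = 2 slices -> must be V1
```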

…support

Address two issues reported by JohnQinAMD during MI355X testing:

1. KPerBlock=256 blockscale configs used V3 pipeline which requires
   scaleblocksliceK=1, but K=256/blockscale128=2. Changed M=64 K=256
   variants from V3 to V1 in both Stage1 and Stage2.

2. Add both torch.float8_e4m3fn (gfx950) and torch.float8_e4m3fnuz
   (gfx942) to dtype2str_dict so the tuner works with verify CSVs
   on either MI355X or MI300X without dtype lookup failures.

Made-with: Cursor
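A hedged sketch of the dtype2str_dict change described above: map whichever FP8 dtypes the installed torch build exposes to the same tuner string, so a verify CSV resolves on both gfx950 (float8_e4m3fn) and gfx942 (float8_e4m3fnuz). The dtype names mirror the commit message; the helper function is illustrative, not the actual moe_op.py code:

```python
def build_fp8_entries(torch_module) -> dict:
    # Collect the FP8 dtype objects present in this torch build; older
    # builds may lack one variant, hence the getattr guard.
    entries = {}
    for name in ("float8_e4m3fn", "float8_e4m3fnuz"):
        dtype = getattr(torch_module, name, None)
        if dtype is not None:
            entries[dtype] = "fp8"
    return entries
```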
@ChuanLi1101
Author

@JohnQinAMD Thanks for catching these! Both issues are fixed in the latest push:

  1. KPerBlock=256 V3->V1: Changed M=64 K=256 blockscale configs from V3 to V1 in both Stage1 and Stage2. V3 requires scaleblocksliceK=1 but K=256/blockscale128=2. The M=16 and M=32 K=256 variants were already V1 so no issue there.

  2. Cross-platform FP8 dtype: Added both torch.float8_e4m3fn (gfx950) and torch.float8_e4m3fnuz (gfx942) to dtype2str_dict in moe_op.py, so the tuner works on either MI355X or MI300X regardless of which dtype the CSV uses.

Could you re-test on MI355X when you get a chance?

@JohnQinAMD

@ChuanLi1101 If doweight_stage1=0 is set, Stage2 blockscale will have 15% error due to BF16 AtomicAdd accumulation error; a quick workaround is to set doweight_stage1=1 in the CSV file. The TFLOPS will not be impacted.

@JohnQinAMD

Two more issues were identified during testing.

  1. gemm_moe_ck2stages_common.py — need to disable the AGPR-spilling Stage1 kernel:

         4: kernelInstanceGEMM1(256, 128, 128, 128, 1, 4, 3,), # AGPR spill: MXDLPerWave=8, 16.7% error

     Commented out Stage1 MPerBlock=128 with MWaves=1, NWaves=4. This config has MXDLPerWave=8, causing AGPR register spill → 16.7% error. The safe alternative (id=8, MWaves=2, NWaves=2, 0.1% error) remains.

  2. gen_instances.py — heuristic dispatch for block_m=128:

         return ck_moe_stage1_gemm<..., 1, 4, ...>;
         return ck_moe_stage1_gemm<..., 2, 2, ...>;

     The blockscale Stage1 heuristic for block_m==128 was hardcoded to MWaves=1, NWaves=4 (the AGPR-spilling config); it can be changed to MWaves=2, NWaves=2.
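Since gen_instances.py is a generator script, the dispatch fix can be sketched as a Python function that emits the instantiation line; the wave-config map and template string below are stand-ins for illustration, not the real generator code:

```python
def stage1_dispatch_line(block_m: int) -> str:
    # Hypothetical block_m -> (MWaves, NWaves) map; 128 previously emitted
    # (1, 4), the AGPR-spilling config, and is changed to (2, 2) here.
    waves = {128: (2, 2)}
    m_waves, n_waves = waves[block_m]
    return f"return ck_moe_stage1_gemm<..., {m_waves}, {n_waves}, ...>;"

assert stage1_dispatch_line(128) == "return ck_moe_stage1_gemm<..., 2, 2, ...>;"
```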

