Add W4A8/W4A_FP8 MoE support with groupwise scale #202
Open · ClementLinCF wants to merge 6 commits into main
Conversation
Contributor
Pull request overview
This PR adds W4A8 (int4) and W4A_FP8 (int4_fp8) MoE support with groupwise scaling (group_size=32) to the FlyDSL fused MoE 2-stage kernel. It extends the existing int4_bf16 (W4A16) path with new load/unpack helpers and per-K32 group accumulation logic.
Changes:
- New `int4_fp8` dtype support with FP8 activations + packed int4 weights, using `mfma_f32_16x16x32_fp8_fp8` and in-kernel int4→fp8 conversion via `cvt_pk_fp8_f32`.
- Groupwise scale (group_size=32) for all int4 weight variants (int4, int4_bf16, int4_fp8) with a per-K32 fresh-accumulator + scale + running-f32-accumulator pattern.
- New load/unpack helpers (`load_b_raw_w4a8_k64`, `load_b_raw_w4a8_groupwise_k64`, `unpack_b_w4a8`, `unpack_b_w4a_fp8`, etc.) in `mfma_preshuffle_pipeline.py`.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| kernels/mfma_preshuffle_pipeline.py | New load/unpack helpers for W4A8, W4A_FP8, and groupwise scale variants |
| kernels/moe_gemm_2stage.py | Extended stage1/stage2 compile functions with int4_fp8 dtype and groupwise scale paths |
| tests/kernels/test_moe_gemm.py | Added int4_fp8 to test parameterization and corresponding quantization/routing logic |
| tests/test_common.py | Minor whitespace cleanup |
coderfeli reviewed Mar 17, 2026
…tions

- Extract _unpack_int4_to_int8_pair(): shared 7-op int4->int8 bit manipulation used by unpack_b_w4a16, unpack_b_w4a8, unpack_b_w4a_fp8, and load_b_pack_k32 (was copy-pasted in 4 places)
- Extract _pack_i32_pair_to_i64(): shared (even, odd) -> i64 packing
- Extract _load_groupwise_scale(): shared scale address calculation and buffer_load for W4A16 and W4A8 groupwise paths
- Have load_b_raw_w4a8_groupwise_k64 delegate weight load to load_b_raw_w4a8_k64 (matching W4A16 groupwise pattern)
- Replace ir.IntegerType.get_signless(32) / ir.F32Type.get() with T.i32 / T.f32 to follow project conventions
- Replace arith.constant(..., index=True) with fx.Index(...) throughout
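For intuition, the int4→int8 unpack that these helpers share can be illustrated outside the kernel. A minimal pure-Python sketch (function name hypothetical, and assuming low nibble = even element, high nibble = odd element; the kernel's IR-level lane layout may differ):

```python
def unpack_int4_pair(byte: int) -> tuple[int, int]:
    """Unpack one byte holding two signed int4 values into two small ints.
    Low nibble is taken as the even element, high nibble as the odd one."""
    def sext4(n: int) -> int:
        # Sign-extend a 4-bit value: 0..7 stay positive, 8..15 map to -8..-1.
        return n - 16 if n & 0x8 else n
    return sext4(byte & 0x0F), sext4(byte >> 4)

pairs = [unpack_int4_pair(b) for b in (0x00, 0xF1, 0x7F, 0x88)]
```

The kernel performs the same sign extension with shift/mask ops on packed i32 registers rather than per byte, which is why factoring the 7-op sequence into one helper removes real duplication.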
Motivation
The existing fused MoE 2-stage kernel supports fp8, fp16, bf16, int8, and W4A16 (int4_bf16) data types. This PR extends it with W4A8 (int4) and W4A_FP8 (int4_fp8) support, and adds groupwise scale (group_size=32) for all three int4 weight variants — enabling lower-precision MoE inference paths that are critical for production deployment of large MoE models (e.g., Kimi K2.5).
Technical Details
New dtype: int4_fp8 (W4A_FP8)
FP8 activations + packed int4 weights, using mfma_f32_16x16x32_fp8_fp8.
In-kernel int4→fp8 unpack via cvt_pk_fp8_f32 ROCDL intrinsic.
8-byte K64 weight loads (buffer_load_dwordx2) for improved memory efficiency.
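For reference, the MFMA instruction named above accumulates a 16×16 f32 tile over K=32. A NumPy sketch of just the math (ignoring register layout and fp8 rounding, which happen in hardware; the function name is ours, not an API):

```python
import numpy as np

def mfma_f32_16x16x32_ref(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Numerical reference for one mfma_f32_16x16x32 tile:
    c[16,16] += a[16,32] @ b[32,16], with inputs upcast to f32.
    In the real instruction a and b are fp8 register operands; here we
    model only the accumulation shape and the f32 result precision."""
    assert a.shape == (16, 32) and b.shape == (32, 16) and c.shape == (16, 16)
    return c + a.astype(np.float32) @ b.astype(np.float32)
```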
Groupwise scale (group_size=32) for W4A8/W4A16/W4A_FP8
Per-K32 groupwise accumulation: fresh MFMA accumulator + per-group scale + running f32 accumulator, for both stage1 and stage2.
Groupwise scale address formula using [E, num_groups, N] layout with preshuffled scale tensors.
Epilogue correctly skips sitofp for groupwise accumulators (already f32 from per-K32 accumulation).
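The per-K32 pattern described above is mathematically equivalent to dequantizing the weights and doing one big matmul, since the scale is constant within each group. A NumPy sketch of the accumulation structure (names hypothetical; scales assumed shaped [num_groups, N] for a single expert, i.e. one slice of the [E, num_groups, N] layout):

```python
import numpy as np

def groupwise_gemm_ref(a, b_q, scales, group_size=32):
    """Mimics the kernel's groupwise accumulation: for each K-group,
    accumulate integer products into a fresh accumulator, apply the
    per-group scale, and add into a running f32 accumulator."""
    M, K = a.shape
    _, N = b_q.shape
    acc = np.zeros((M, N), dtype=np.float32)
    for g in range(K // group_size):
        sl = slice(g * group_size, (g + 1) * group_size)
        # Fresh accumulator per group (stands in for the fresh MFMA accumulator).
        fresh = a[:, sl].astype(np.int32) @ b_q[sl, :].astype(np.int32)
        # Scale once per group, then fold into the running f32 accumulator.
        acc += fresh.astype(np.float32) * scales[g]
    return acc
```

Deferring the scale to once per group is what makes the fresh-accumulator + running-f32 structure cheap: the inner loop stays in integer MFMA, and only one multiply-add per group touches float scales. It also explains why the epilogue must skip sitofp for groupwise accumulators: the running accumulator is already f32.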
Test Plan
pytest tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage -k "(g32) or (perrow and fp8)" covering:
Groupwise (g32) + out_f16: shapes S/M/L × int4, int4_bf16, int4_fp8 × atomic/reduce × eager/graph
Groupwise (g32) + out_f32: shapes S/M/L × int4, int4_bf16, int4_fp8 × atomic × eager/graph
Non-groupwise (perrow) + fp8: shapes S/M/L × fp8, int4_fp8 × f16/f32 × atomic/reduce × eager/graph
Each test verifies correctness against torch reference (for S/M shapes) and runs perf timing.
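As a rough illustration of how groupwise test inputs can be produced, here is a hypothetical symmetric per-group int4 quantizer in NumPy (not the test suite's actual helper, which is torch-based):

```python
import numpy as np

def quantize_groupwise_int4(w: np.ndarray, group_size: int = 32):
    """Symmetric per-group int4 quantization along K: for each group,
    scale = max|w| / 7 and q = round(w / scale), clipped to [-8, 7].
    Returns q[K, N] (int8 storage) and scales[num_groups, N] (f32)."""
    K, N = w.shape
    num_groups = K // group_size
    q = np.empty((K, N), dtype=np.int8)
    scales = np.empty((num_groups, N), dtype=np.float32)
    for g in range(num_groups):
        sl = slice(g * group_size, (g + 1) * group_size)
        s = np.abs(w[sl]).max(axis=0) / 7.0 + 1e-12  # avoid divide-by-zero
        q[sl] = np.clip(np.round(w[sl] / s), -8, 7).astype(np.int8)
        scales[g] = s
    return q, scales
```

Dequantizing `q * scales` and comparing against the original float matmul is the same shape of check the tests run against the torch reference.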
Test Result
120 passed, 360 skipped, 288 deselected in 35.39s
Groupwise f16 (S/M/L): all PASSED (int4, int4_bf16, int4_fp8)
Groupwise f32 (S/M/L): all PASSED (int4, int4_bf16, int4_fp8)
Non-groupwise perrow fp8 (S/M/L): all PASSED
0 failures
E2E test
Kimi-K2.5 W4A16, W4A8 on MI308
Submission Checklist