
Add W4A8/W4A_FP8 MoE support with groupwise scale #202

Open
ClementLinCF wants to merge 6 commits into main from feature/w4a8-moe-port

Conversation

ClementLinCF (Contributor) commented Mar 12, 2026

Motivation

The existing fused MoE 2-stage kernel supports fp8, fp16, bf16, int8, and W4A16 (int4_bf16) data types. This PR extends it with W4A8 (int4) and W4A_FP8 (int4_fp8) support, and adds groupwise scale (group_size=32) for all three int4 weight variants — enabling lower-precision MoE inference paths that are critical for production deployment of large MoE models (e.g., Kimi K2.5).

Technical Details

New dtype: int4_fp8 (W4A_FP8)

  • FP8 activations + packed int4 weights, using mfma_f32_16x16x32_fp8_fp8.

  • In-kernel int4→fp8 unpack via cvt_pk_fp8_f32 ROCDL intrinsic.

  • 8-byte K64 weight loads (buffer_load_dwordx2) for improved memory efficiency.
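The in-kernel unpack is done with the `cvt_pk_fp8_f32` ROCDL intrinsic on device; as a host-side sanity reference, the int4 half of that conversion amounts to splitting each packed byte into two sign-extended nibbles. The sketch below is a hypothetical pure-Python reference (the low-nibble-first ordering is an assumption, not taken from the kernel source):

```python
def unpack_int4_pairs(packed: bytes) -> list:
    """Reference unpack: each byte holds two 4-bit two's-complement weights.

    Assumes low nibble comes first; the real kernel does this in-register
    before converting to fp8 via cvt_pk_fp8_f32.
    """
    out = []
    for b in packed:
        lo = b & 0xF
        hi = (b >> 4) & 0xF
        # sign-extend 4-bit two's complement: values >= 8 are negative
        out.append(lo - 16 if lo >= 8 else lo)
        out.append(hi - 16 if hi >= 8 else hi)
    return out
```

For example, the byte `0xF1` unpacks to `[1, -1]`: the low nibble `0x1` is +1 and the high nibble `0xF` sign-extends to -1.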

Groupwise scale (group_size=32) for W4A8/W4A16/W4A_FP8

  • Per-K32 groupwise accumulation: fresh MFMA accumulator + per-group scale + running f32 accumulator, for both stage1 and stage2.

  • Groupwise scale address formula using [E, num_groups, N] layout with preshuffled scale tensors.

  • Epilogue correctly skips sitofp for groupwise accumulators (already f32 from per-K32 accumulation).
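The per-K32 pattern above can be sketched as a scalar reference: a fresh integer accumulator per K-group, multiplied by that group's f32 scale, summed into a running f32 total. This is a simplified single-expert sketch (the kernel indexes a preshuffled `[E, num_groups, N]` scale tensor and uses MFMA accumulators); the function name and layout here are illustrative only:

```python
def moe_gemm_groupwise_ref(a, w, scales, group_size=32):
    """Scalar reference for per-K32 groupwise accumulation.

    a:      [M][K] int activations (stands in for int8 / fp8 inputs)
    w:      [N][K] int weights (already unpacked from int4)
    scales: [num_groups][N] per-group f32 scales for one expert
    """
    M, K = len(a), len(a[0])
    N = len(w)
    num_groups = K // group_size
    out = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0  # running f32 accumulator: already float, so no sitofp in the epilogue
            for g in range(num_groups):
                part = 0  # fresh per-group accumulator (mirrors a fresh MFMA accumulator)
                for k in range(g * group_size, (g + 1) * group_size):
                    part += a[m][k] * w[n][k]
                acc += part * scales[g][n]
            out[m][n] = acc
    return out
```

With K=64 and two groups scaled by 2.0 and 0.5, all-ones inputs give 32·2.0 + 32·0.5 = 80.0, which also illustrates why the epilogue must skip the integer-to-float conversion: the groupwise result is f32 from the start.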

Test Plan

`pytest tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage -k "(g32) or (perrow and fp8)"` covering:

  • Groupwise (g32) + out_f16: shapes S/M/L × int4, int4_bf16, int4_fp8 × atomic/reduce × eager/graph

  • Groupwise (g32) + out_f32: shapes S/M/L × int4, int4_bf16, int4_fp8 × atomic × eager/graph

  • Non-groupwise (perrow) + fp8: shapes S/M/L × fp8, int4_fp8 × f16/f32 × atomic/reduce × eager/graph

Each test verifies correctness against torch reference (for S/M shapes) and runs perf timing.

Test Result

120 passed, 360 skipped, 288 deselected in 35.39s
Groupwise f16 (S/M/L): all PASSED (int4, int4_bf16, int4_fp8)
Groupwise f32 (S/M/L): all PASSED (int4, int4_bf16, int4_fp8)
Non-groupwise perrow fp8 (S/M/L): all PASSED
0 failures

E2E test

Kimi-K2.5 W4A16, W4A8 on MI308

| Metric | W4A8 con=2 | W4A8 con=40 | W4A16 con=2 | W4A16 con=40 |
|---|---|---|---|---|
| Output throughput (tok/s) | 61.37 | 261.02 | 24.48 | 104.97 |
| Peak output throughput (tok/s) | 82.00 | 800.00 | 60.00 | 442.00 |
| Total throughput (tok/s) | 1288.80 | 5481.35 | 514.06 | 2204.34 |
| Mean TPOT (ms) | 25.82 | 102.51 | 35.80 | 223.64 |
| Median TPOT (ms) | 25.83 | 102.98 | 35.51 | 231.12 |
| Mean TTFT (ms) | 3485.58 | 26063.61 | 23530.88 | 68158.30 |
| Median ITL (ms) | 24.89 | 52.07 | 34.08 | 83.65 |
| Mean E2E latency (ms) | 16679.06 | 78443.72 | 41823.70 | 182436.60 |

Submission Checklist

Copilot AI review requested due to automatic review settings March 12, 2026 15:53
Copilot AI left a comment


Pull request overview

This PR adds W4A8 (int4) and W4A_FP8 (int4_fp8) MoE support with groupwise scaling (group_size=32) to the FlyDSL fused MoE 2-stage kernel. It extends the existing int4_bf16 (W4A16) path with new load/unpack helpers and per-K32 group accumulation logic.

Changes:

  • New int4_fp8 dtype support with FP8 activations + packed int4 weights, using mfma_f32_16x16x32_fp8_fp8 and in-kernel int4→fp8 conversion via cvt_pk_fp8_f32.
  • Groupwise scale (group_size=32) for all int4 weight variants (int4, int4_bf16, int4_fp8) with per-K32 fresh-accumulator + scale + running f32 accumulator pattern.
  • New load/unpack helpers (load_b_raw_w4a8_k64, load_b_raw_w4a8_groupwise_k64, unpack_b_w4a8, unpack_b_w4a_fp8, etc.) in mfma_preshuffle_pipeline.py.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
|---|---|
| kernels/mfma_preshuffle_pipeline.py | New load/unpack helpers for W4A8, W4A_FP8, and groupwise scale variants |
| kernels/moe_gemm_2stage.py | Extended stage1/stage2 compile functions with int4_fp8 dtype and groupwise scale paths |
| tests/kernels/test_moe_gemm.py | Added int4_fp8 to test parameterization and corresponding quantization/routing logic |
| tests/test_common.py | Minor whitespace cleanup |


@coderfeli coderfeli requested a review from yadaish March 17, 2026 01:14
…tions

- Extract _unpack_int4_to_int8_pair(): shared 7-op int4->int8 bit
  manipulation used by unpack_b_w4a16, unpack_b_w4a8, unpack_b_w4a_fp8,
  and load_b_pack_k32 (was copy-pasted in 4 places)
- Extract _pack_i32_pair_to_i64(): shared (even, odd) -> i64 packing
- Extract _load_groupwise_scale(): shared scale address calculation and
  buffer_load for W4A16 and W4A8 groupwise paths
- Have load_b_raw_w4a8_groupwise_k64 delegate weight load to
  load_b_raw_w4a8_k64 (matching W4A16 groupwise pattern)
- Replace ir.IntegerType.get_signless(32) / ir.F32Type.get() with
  T.i32 / T.f32 to follow project conventions
- Replace arith.constant(..., index=True) with fx.Index(...) throughout
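The extracted `_pack_i32_pair_to_i64()` helper packs an `(even, odd)` pair of 32-bit values into one 64-bit value. A hypothetical host-side sketch of that packing (the even-in-low-bits word order is an assumption about the helper, not taken from its source):

```python
def pack_i32_pair_to_i64(even: int, odd: int) -> int:
    """Pack two 32-bit words into one 64-bit word.

    Assumed layout: `even` in bits 0-31, `odd` in bits 32-63, with each
    input masked to 32 bits so negative Python ints pack cleanly.
    """
    return ((odd & 0xFFFFFFFF) << 32) | (even & 0xFFFFFFFF)
```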
