
feat: MXFP8 epilogue fusion for SM100 (B200)#1

Open
wanyaworld wants to merge 4 commits into main from gemm_mxfp8_out_tma

Conversation


@wanyaworld wanyaworld commented Apr 10, 2026

Summary

  • Fuses MXFP8 quantization into the SM100 (B200) GEMM kernel epilogue: the kernel emits FP8 values plus E8M0 scales directly
  • 1.34-2.12x speedup over a torch.compile(fullgraph=True) baseline that compiles GEMM + quantize into a single fx graph
  • At the kernel level, the fused GEMM matches the baseline GEMM (~275us each); the baseline additionally spends ~47us in a Triton quantize kernel

New APIs

  • fp8_gemm_nt_mxfp8out(a, b, d_fp8, d_sf)
  • m_grouped_fp8_gemm_nt_contiguous_mxfp8out(a, b, d_fp8, d_sf, layout)
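A shape sketch for the two new output arguments (hypothetical helper, assuming `d_fp8` is an (M, N) FP8 tensor and `d_sf` holds one E8M0 byte per 32-wide block along N; the kernel's actual scale layout may be padded or swizzled):

```python
def mxfp8_out_shapes(m: int, n: int, block: int = 32):
    """Hypothetical: logical shapes the caller allocates for the fused outputs.

    d_fp8 holds the quantized (M, N) values; d_sf holds one E8M0 scale
    byte per `block` contiguous values along N (assumed layout).
    """
    assert n % block == 0, "N must be a multiple of the MXFP8 block size"
    return (m, n), (m, n // block)
```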

B200 performance (vs. the compiled-graph baseline, fullgraph=True)

Normal GEMM:

| M | N | K | Fused | Baseline | Speedup | Diff |
|---:|---:|---:|---:|---:|---:|---:|
| 1 | 7168 | 2048 | 0.149ms | 0.199ms | 1.34x | 0.00011 |
| 128 | 7168 | 2048 | 0.159ms | 0.214ms | 1.34x | 0.00013 |
| 256 | 7168 | 2048 | 0.155ms | 0.232ms | 1.50x | 0.00013 |
| 4096 | 7168 | 2048 | 0.191ms | 0.271ms | 1.42x | 0.00013 |
| 128 | 4096 | 7168 | 0.159ms | 0.247ms | 1.55x | 0.00013 |
| 256 | 4096 | 7168 | 0.168ms | 0.273ms | 1.62x | 0.00013 |

Grouped GEMM:

| G | M | N | K | Fused | Baseline | Speedup | Diff |
|---:|---:|---:|---:|---:|---:|---:|---:|
| 4 | 30K | 7168 | 2048 | 0.493ms | 0.661ms | 1.34x | 0.00013 |
| 8 | 37K | 7168 | 2048 | 0.576ms | 0.804ms | 1.40x | 0.00013 |
| 4 | 35K | 4096 | 7168 | 0.935ms | 1.620ms | 1.73x | 0.02003 |
| 8 | 32K | 4096 | 7168 | 0.888ms | 1.539ms | 1.73x | 0.02003 |
| 48 | 63K | 1280 | 4096 | 0.401ms | 0.851ms | 2.12x | 0.02578 |
  • M=63k trace:
    /mair/team-sys/jangwoong/DeepGEMM/traces/tma_compiled_1775813473/jangwoong-dg-mxfp8v7-node-0-0_628.1775813473554960843.pt.trace.json

The ~0.02 diff comes from the fused path quantizing directly from FP32 TMEM, while the baseline truncates to BF16 before quantizing. The fused output is the more accurate of the two.
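The effect can be reproduced with a small numpy sketch (illustrative only: it assumes the baseline truncates FP32 to BF16 by dropping the low 16 bits, and it models the E4M3 grid near 1.0 with a fixed 0.125 step for a block scale of 1; `bf16_trunc` and the chosen value are hypothetical):

```python
import numpy as np

def bf16_trunc(x):
    """Truncate FP32 to BF16 by zeroing the low 16 mantissa bits (no rounding)."""
    u = np.float32(x).view(np.uint32) & np.uint32(0xFFFF0000)
    return float(u.view(np.float32))

x = np.float32(1.07)   # FP32 accumulator value read from TMEM
step = 0.125           # E4M3 grid spacing near 1.0 (block scale = 1)

fused = float(np.round(x / step) * step)                  # quantize FP32 directly
baseline = float(np.round(bf16_trunc(x) / step) * step)   # BF16 first, then quantize

# bf16_trunc(1.07) == 1.0625, which lands exactly on a rounding midpoint:
# round-to-nearest-even sends it down to 1.0, while the FP32 path gets 1.125,
# so the fused result is strictly closer to the original value.
```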

Key design

  • Completely bypasses the existing BF16/FP32 epilogue's s-loop/TMA pipeline via if constexpr
  • MXFP8 epilogue: TMEM → registers (32 contiguous N values per thread) → quantize → SMEM → TMA store
  • Scales (E8M0) are written directly to global memory (too small to be worth routing through TMA)
  • The fused path quantizes directly from FP32 TMEM, skipping the baseline's intermediate BF16 truncation, so it is more accurate
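The quantize step in the pipeline above can be sketched as a numpy reference (a sketch under OCP MX assumptions: 32-element blocks, E4M3 elements with a max magnitude of 448, E8M0 stored as a biased power-of-two exponent byte; the kernel's actual rounding and layout may differ, and E4M3 subnormals are ignored here):

```python
import numpy as np

FP8_E4M3_MAX = 448.0   # largest finite E4M3 magnitude
BLOCK = 32             # values sharing one scale (one thread's N-slice)

def quantize_mxfp8_block(vals):
    """Quantize 32 FP32 values to E4M3-gridded values plus an E8M0 scale byte."""
    assert vals.shape == (BLOCK,)
    amax = float(np.abs(vals).max())
    # pick a power-of-two scale so that amax / scale fits the E4M3 range
    e = 0 if amax == 0.0 else int(np.ceil(np.log2(amax / FP8_E4M3_MAX)))
    e = max(-127, min(127, e))
    e8m0 = np.uint8(e + 127)                       # biased exponent byte
    scaled = np.clip(vals / 2.0 ** e, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    # emulate E4M3 rounding: snap each value to a 3-mantissa-bit grid
    q = np.zeros_like(scaled)
    nz = scaled != 0
    ex = np.floor(np.log2(np.abs(scaled[nz])))
    grid = 2.0 ** (ex - 3)
    q[nz] = np.round(scaled[nz] / grid) * grid
    return q, e8m0

def dequantize_block(q, e8m0):
    """Undo the block scale: value = element * 2**(e8m0 - 127)."""
    return q * 2.0 ** (int(e8m0) - 127)
```

With 3 mantissa bits the per-element relative rounding error is bounded by 2**-4, which is consistent with the ~0.01-0.03 max diffs in the tables above.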

Testing

  • Accuracy: diff vs. baseline < 0.03 on every shape
  • Performance: faster than the compiled-graph baseline on every shape
  • Profiler traces confirm kernel-level parity
  • Verified for both normal and grouped GEMM on B200

wanyaworld and others added 4 commits April 10, 2026 09:07
MXFP8 quantization fused into SM100 GEMM epilogue.
TMEM → register → quantize → global memory direct write.
TMA pipeline completely bypassed with if constexpr.

Co-Authored-By: Claude <noreply@anthropic.com>
Switch MXFP8 epilogue from thread-level global memory writes to
TMA store via SMEM. Uses independent s-loop with same pipeline
structure as BF16 path (wait → write → fence → sync → TMA → arrive).

Kernel time: 362us → 273us (matches baseline GEMM kernel).

Co-Authored-By: Claude <noreply@anthropic.com>
Wrap DeepGEMM GEMM calls with torch.library.custom_op so baseline
(GEMM + quantize) runs as a single compiled fx graph, eliminating
Python dispatch overhead between the two ops.

Fused still wins 1.34-2.12x vs compiled graph baseline.

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Claude <noreply@anthropic.com>