
Add FP8 Support For CK Tile Group GEMM #475

Draft
aris134 wants to merge 70 commits into dev from amartin/ck-grouped-gemm-fp8

Conversation

@aris134 aris134 commented Mar 6, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes https://github.com/ROCm/frameworks-internal/issues/15787

TODO:

  • Add support for other architectures (e.g., MI350X)
  • Add support for other quantization modes
  • Performance analysis and tuning

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Enables mixed precision (fp8/bf8 FNUZ variants) support for CK tile grouped GEMM with tensor quantization

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

matthiasdiener and others added 18 commits February 23, 2026 13:10
Align GemmRowColTensorQuantPipelineProblem with ck_tile V3 requirements by using AccType for intermediate C results. Specific to TensorQuant (per-tensor scaling); limited to e4m3/e5m2 FNUZ formats. Updates test_numerics.py to exercise FP8 inputs in the grouped linear accuracy suite.
Enable mixed FP8/BF8 grouped GEMM for the CK backend used by GroupedLinear backward.

Certain mixed-type combinations normalize to (AType=bf8_t, BType=fp8_t), but CK currently lacks a corresponding warp GEMM specialization for WarpGemmMfma_f32_32x32x32_bf8_fp8. This prevents the default FP8 tile configuration (K_Warp_Tile=32) from compiling or dispatching correctly.

To address this, a fallback tile policy is introduced that routes the (bf8_t, fp8_t) case to a supported kernel configuration using K_Warp_Tile=16. This preserves correct GEMM operand ordering and avoids unsafe operand-swapping workarounds.

Notes:
- Only tensor quantization mode is currently supported.
- Implementation targets MI300X (CDNA3) FP8/BF8 kernels.
- Additional kernel coverage may be required for MI350X (CDNA4).

With this change, mixed FP8/BF8 backprop paths are supported and all parametrized unit tests in test_grouped_linear_accuracy_cutlass() pass successfully.
@aris134 aris134 self-assigned this Mar 6, 2026
aris134 and others added 6 commits March 6, 2026 15:31
…e shared runner abstraction

This commit restructures the CK grouped GEMM implementation to improve
maintainability and better separate datatype-specific logic.

Key changes:

• Split the original single-source implementation into separate files for
  FP16 and FP8 grouped GEMM kernels.

• Introduced a shared header defining a common abstraction for grouped GEMM
  runners. The design is similar in spirit to the Primus Turbo dispatch.

• Added an abstract parent class that encapsulates the common interface and
  provides an overloaded operator() / run() entrypoint for launching
  kernels. Concrete runners implement datatype-specific behavior while
  sharing the same invocation path.

• Introduced a GroupedGemmRunContext structure that carries runtime
  configuration (layout, splits, pointers, etc.) through the dispatch
  pipeline. This removes large argument lists and centralizes execution
  state.

• Refactored dispatch code to construct the appropriate runner and invoke it
  through the unified interface.

• Added documentation comments explaining the new structure and the
  responsibilities of each component (context, runner base class, and
  datatype-specific implementations).

Functional behavior is unchanged. The refactor preserves the previous
execution paths and continues to pass all existing Transformer Engine
tests that exercised the original implementation.
Introduce extern template declarations and dedicated instantiation translation
units for GroupedGemmRunner and QuantGroupedGemmRunner.

This moves template instantiation for the supported dtype/layout/tile
combinations into separate compilation units to reduce duplicate template
instantiation across translation units and better isolate kernel codegen.

No functional changes; this is purely a build/compile-time refactor.
Break up the previous monolithic templated header into separate FP16 and FP8
implementation paths.

Key changes:
- Introduce a lightweight common header (ck_grouped_gemm_common.h) containing
  shared utilities, the run context, and the RunnerInterface abstraction.
- Move heavy template definitions into dtype-specific implementation headers:
  ck_grouped_gemm_fp16_impl.h and ck_grouped_gemm_fp8_impl.h.
- Add FP16/FP8 factory source files responsible for constructing the correct
  runner instances based on dtype/layout/tile configuration.
- Keep dispatch entry points thin and dependent only on the lightweight header.

This isolates the heavy CK template code to a smaller number of translation
units and prevents unnecessary template parsing across the codebase, improving
build scalability without changing runtime behavior.
Refactor the FP8 and FP16 explicit template instantiations into smaller,
more manageable translation units.

Key changes:
- Move explicit instantiations into a new subdirectory: gemm/instantiations/.
- Split FP8 and FP16 instantiations across multiple source files organized
  by operand data type combinations.
- Reduce the size of individual translation units to improve build
  parallelism and avoid long single-TU compile bottlenecks.

No functional changes; this is a build-structure refactor only.
…950)

Introduce runtime GPU architecture detection and split CK FP8 grouped GEMM
factories into arch-specific implementations for gfx942 and gfx950.

Key changes:
- Added GPUArch detection helper using hipGetDeviceProperties.
- Introduced common factory dispatcher that selects the correct arch-specific
  runner factory at runtime.
- Split FP8 runner factories into gfx942 and gfx950 implementations.
- Added arch-specific kernel instantiation translation units for each arch.
- Updated build to compile both arch implementations into the same library
  while selecting the correct one at runtime.

This enables a single TransformerEngine build to support both MI300 (gfx942)
and MI350 (gfx950) while avoiding invalid tile configurations during kernel
instantiation.
@aris134 aris134 force-pushed the amartin/ck-grouped-gemm-fp8 branch from 9f308a9 to e9cd6b8 Compare March 11, 2026 15:37
@aris134 aris134 force-pushed the amartin/ck-grouped-gemm-fp8 branch from c834302 to 32f2ac3 Compare March 11, 2026 15:51