This reverts commit 86fbbac.
Align GemmRowColTensorQuantPipelineProblem with ck_tile V3 requirements by using AccType for intermediate C results. Specific to TensorQuant (per-tensor scaling); limited to e4m3/e5m2 FNUZ formats. Updates test_numerics.py to exercise FP8 inputs in the grouped linear accuracy suite.
… param_types in test_numerics
Enable mixed FP8/BF8 grouped GEMM for the CK backend used by GroupedLinear backward.

Certain mixed-type combinations normalize to (AType=bf8_t, BType=fp8_t), but CK currently lacks a corresponding warp GEMM specialization (WarpGemmMfma_f32_32x32x32_bf8_fp8), so the default FP8 tile configuration (K_Warp_Tile=32) cannot compile or dispatch correctly. To address this, a fallback tile policy is introduced that routes the (bf8_t, fp8_t) case to a supported kernel configuration using K_Warp_Tile=16. This preserves correct GEMM operand ordering and avoids unsafe operand-swapping workarounds.

Notes:
- Only tensor quantization mode is currently supported.
- Implementation targets MI300X (CDNA3) FP8/BF8 kernels.
- Additional kernel coverage may be required for MI350X (CDNA4).

With this change, mixed FP8/BF8 backprop paths are supported and all parametrized unit tests in test_grouped_linear_accuracy_cutlass() pass successfully.
…e shared runner abstraction

This commit restructures the CK grouped GEMM implementation to improve maintainability and better separate datatype-specific logic. Key changes:

• Split the original single-source implementation into separate files for FP16 and FP8 grouped GEMM kernels.
• Introduced a shared header defining a common abstraction for grouped GEMM runners. The design is similar in spirit to the Primus Turbo dispatch.
• Added an abstract parent class that encapsulates the common interface and provides an overloaded operator() / run() entrypoint for launching kernels. Concrete runners implement datatype-specific behavior while sharing the same invocation path.
• Introduced a GroupedGemmRunContext structure that carries runtime configuration (layout, splits, pointers, etc.) through the dispatch pipeline. This removes large argument lists and centralizes execution state.
• Refactored dispatch code to construct the appropriate runner and invoke it through the unified interface.
• Added documentation comments explaining the new structure and the responsibilities of each component (context, runner base class, and datatype-specific implementations).

Functional behavior is unchanged. The refactor preserves the previous execution paths and continues to pass all existing Transformer Engine tests that exercised the original implementation.
Introduce extern template declarations and dedicated instantiation translation units for GroupedGemmRunner and QuantGroupedGemmRunner. This moves template instantiation for the supported dtype/layout/tile combinations into separate compilation units to reduce duplicate template instantiation across translation units and better isolate kernel codegen. No functional changes; this is purely a build/compile-time refactor.
Break up the previous monolithic templated header into separate FP16 and FP8 implementation paths. Key changes:

- Introduce a lightweight common header (ck_grouped_gemm_common.h) containing shared utilities, the run context, and the RunnerInterface abstraction.
- Move heavy template definitions into dtype-specific implementation headers: ck_grouped_gemm_fp16_impl.h and ck_grouped_gemm_fp8_impl.h.
- Add FP16/FP8 factory source files responsible for constructing the correct runner instances based on dtype/layout/tile configuration.
- Keep dispatch entry points thin and dependent only on the lightweight header.

This isolates the heavy CK template code to a smaller number of translation units and prevents unnecessary template parsing across the codebase, improving build scalability without changing runtime behavior.
Refactor the FP8 and FP16 explicit template instantiations into smaller, more manageable translation units. Key changes:

- Move explicit instantiations into a new subdirectory: gemm/instantiations/.
- Split FP8 and FP16 instantiations across multiple source files organized by operand data type combinations.
- Reduce the size of individual translation units to improve build parallelism and avoid long single-TU compile bottlenecks.

No functional changes; this is a build-structure refactor only.
…950)

Introduce runtime GPU architecture detection and split CK FP8 grouped GEMM factories into arch-specific implementations for gfx942 and gfx950. Key changes:

- Added a GPUArch detection helper using hipGetDeviceProperties.
- Introduced a common factory dispatcher that selects the correct arch-specific runner factory at runtime.
- Split FP8 runner factories into gfx942 and gfx950 implementations.
- Added arch-specific kernel instantiation translation units for each arch.
- Updated the build to compile both arch implementations into the same library while selecting the correct one at runtime.

This enables a single TransformerEngine build to support both MI300 (gfx942) and MI350 (gfx950) while avoiding invalid tile configurations during kernel instantiation.
9f308a9 to e9cd6b8
c834302 to 32f2ac3
…Also tweaked tile configs slightly, and gave them clearer names for gfx942
Description

Fixes https://github.com/ROCm/frameworks-internal/issues/15787