fix(gemm_a8w8_bpreshuffle): pass splitK/KBatch to CK kernels#2335
fix(gemm_a8w8_bpreshuffle): pass splitK/KBatch to CK kernels#2335AviralGoelAMD wants to merge 3 commits intomainfrom
Conversation
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
There was a problem hiding this comment.
Pull request overview
This PR introduces a KBatch (split-K batch) parameter plumbing into the FP8 a8w8 bpreshuffle GEMM implementations for both CK and CKTile backends, enabling the tuning path to run kernels with split-K.
Changes:
- Add
KBatchparameter to CK/CKTile GEMM implementation entrypoints and pass it into kernel args. - Update code generation (
gen_instances.py) so generated kernel wrappers acceptKBatchand forward it to the impl. - Update tune dispatchers to pass the computed
KBatchinto the selected kernel.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile_common.cuh | Add KBatch argument and plumb into args.k_batch. |
| csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py | Generate kernel wrappers/manifest declarations with KBatch. |
| csrc/cktile_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_cktile_tune.cu | Update dispatch function types/calls to pass KBatch. |
| csrc/ck_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_common.cuh | Add KBatch argument and plumb into CK MakeArgument. |
| csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py | Generate kernel wrappers/manifest declarations with KBatch. |
| csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.cu | Update dispatch function types/calls to pass KBatch. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile_common.cuh
Show resolved
Hide resolved
…tK, use integer shift
|
I;ve resolved all the copilot comments. |
|
I think the non-tune path should keep same with tuned path. If a splitK >1 pattern is the best config, the splitK parameter should also be passed to ck kernels |
Summary
splitKparameter was accepted by the tune entry points andKBatch = 2^splitKwas computed, but never forwarded to the CK/CKTile kernels — it was hardcoded to1in both pathsKBatchthrough the full dispatch chain:tune.cu → dispatch → generated kernel wrapper → impl → CK MakeArgument / CKTile args.k_batchint KBatch = 1default parameter