Draft
Conversation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…ze ops

- Add __launch_bounds__(64, 8) for better register allocation
- Replace int64_t with int32_t for token/slot indices to halve SGPR usage
- Move block_size from template param to MlaKernelParams as log2; use bitwise shift/mask instead of division/modulo (power-of-2 guaranteed)
- Promote runtime vars to constexpr: oob sizes, head_size, nope_offset, kv_lora_vec, reduce_thread_size
- Remove unused variables (qH_per_kH, num_kv_vecs, group_size shadow)
- Pre-compute token-level base offsets to avoid repeated stride multiplies
- Hoist group_id and inv_scale out of inner quantization loops
- Use bitshift for group_id: (tid * vec_size_i) >> 6 instead of / 64

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
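The bitshift in the last bullet is behavior-preserving because the group size is 64, a power of two, and the operands are non-negative thread/vector indices. A minimal Python check of the equivalence:

```python
# For non-negative operands, >> 6 is exactly integer division by 64 (2**6),
# so the kernel's (tid * vec_size_i) >> 6 matches the original / 64.
for tid in range(256):
    for vec_size_i in (1, 2, 4, 8, 16):
        assert (tid * vec_size_i) >> 6 == (tid * vec_size_i) // 64
print("bitshift matches division by 64")
```

Note this equivalence relies on non-negative values; in C++, right-shifting a negative signed integer would not match truncating division.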
Pull request overview
This PR introduces a new fused ROCm cache kernel variant that combines RMSNorm + RoPE with group-quantized KV (and optionally group-quantized Q), and extends the Python bindings/tests to exercise the new path.
Changes:
- Add a new HIP kernel + C++ entrypoint for fused (Q/K) norm + RoPE + group-quant + cache write for MLA prefill (GQA).
- Expose the new op via pybind and aiter Python stubs.
- Add new op-tests benchmarks + torch reference implementations for correctness/perf comparisons.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| csrc/kernels/cache_kernels.cu | Adds the new optimized fused kernel and a new C++ interface; also adjusts existing prefill launch config. |
| csrc/include/cache.h | Declares the new C++ API. |
| csrc/include/rocm_ops.hpp | Exposes the new op via pybind. |
| aiter/ops/cache.py | Adds the Python compile_ops stub for the new op. |
| op_tests/test_concat_cache_mla.py | Adds torch references + benchmark/correctness harnesses for the new fused path(s). |
```diff
  if (use_optimized_prefill) {
-   dim3 grid(num_tokens);
+   dim3 grid(num_tokens, num_kv_heads); // Use 2D grid: x=tokens, y=kv_heads
```
Comment on lines +2510 to +2547
```cpp
const float inverted_DTYPE_MAX =
    std::is_same_v<query_t, ck_tile::fp4x2_t>
        ? 0.25f
        : (1.f / ck_tile::type_convert<float>(ck_tile::numeric<cache_t>::max()));
const float q_group_scale = thread_max * inverted_DTYPE_MAX;

query_t* q_nope_out_base = reinterpret_cast<query_t*>(
    q_out + token_qout_base + q_head_idx * params.q_out_stride_1 + (is_nope_first ? 0 : pe_dim));

float q_inv_scale;
if constexpr (std::is_same_v<query_t, ck_tile::fp8_t>) {
    uint32_t u32      = __builtin_bit_cast(uint32_t, q_group_scale);
    uint32_t exponent = (u32 >> 23) & 0xFF;
    if (u32 & 0x7FFFFF) exponent += 1;
    if (tid % reduce_thread_size == 0) {
        const int group_id = (tid * vec_size_i) >> 6;
        auto* tmp = reinterpret_cast<uint8_t*>(q_nope_out_base + kv_lora_dim);
        if (tid < kv_lora_vec) {
            tmp[group_id] = static_cast<uint8_t>(exponent);
        } else {
            tmp[group_id] = 0;
        }
    }
    uint32_t e8m0_u32 = exponent << 23;
    q_inv_scale = has_q_data ? (1.0f / __builtin_bit_cast(float, e8m0_u32)) : 0.0f;
} else {
    q_inv_scale = has_q_data ? (1.0f / q_group_scale) : 0.0f;
}

// Store quantized Q data
if (has_q_data) {
    const uint32_t q_offset = tid * vec_size_i;
    opus_vec_q vec_out;
#pragma unroll
    for (int i = 0; i < vec_size_i; i++) {
        vec_out[i] = ck_tile::type_convert<cache_t>(q_fp32[i] * q_inv_scale);
    }
    q_buffer_o.template store<vec_size_o>(vec_out, q_offset);
```
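The ceil-style e8m0 rounding in the fp8 branch can be sketched in pure Python, assuming a positive, finite `q_group_scale` (function names here are illustrative): the IEEE-754 biased exponent is extracted from the float's bit pattern, and any nonzero mantissa bumps it up to the next power of two, exactly mirroring the kernel's `(u32 >> 23) & 0xFF` plus mantissa-carry step.

```python
import struct

def e8m0_ceil(x: float) -> int:
    """Biased-exponent (e8m0) ceiling of a positive float,
    mirroring the kernel's mantissa-carry rounding."""
    u32 = struct.unpack("<I", struct.pack("<f", x))[0]
    exponent = (u32 >> 23) & 0xFF
    if u32 & 0x7FFFFF:  # any mantissa bits set -> round up to next power of two
        exponent += 1
    return exponent

def e8m0_to_float(exponent: int) -> float:
    # Rebuild the power-of-two scale the same way the kernel does: exponent << 23
    return struct.unpack("<f", struct.pack("<I", exponent << 23))[0]

# 1.0 has a zero mantissa, so it maps to its own biased exponent (127)
print(e8m0_ceil(1.0))                  # 127
print(e8m0_to_float(127))              # 1.0
# 1.5 has mantissa bits set, so it rounds up to the next power of two
print(e8m0_to_float(e8m0_ceil(1.5)))   # 2.0
```

Rounding the scale *up* keeps quantized magnitudes at or below the representable maximum, at the cost of slightly coarser quantization for non-power-of-two scales.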
Comment on lines +4115 to +4132
```cpp
int num_tokens   = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(-1);
int pe_dim       = k_pe.size(-1);
int block_size   = kv_cache.size(1);
int num_heads    = q_nope.size(1);
int qk_lora_rank = q_nope.size(-1);
int rot_dim      = cos_cache.size(-1) * 2;

// Validate dimensions
TORCH_CHECK(q_nope.dim() == q_pe.dim());
TORCH_CHECK(q_nope.size(1) == q_pe.size(1));
TORCH_CHECK(q_out.size(2) == qk_lora_rank + pe_dim);
TORCH_CHECK(kv_lora_rank == qk_lora_rank, "kv_lora_rank and qk_lora_rank must be the same");
TORCH_CHECK(k_weight.size(0) == kv_lora_rank, "k_weight size must match kv_lora_rank");
TORCH_CHECK(kv_c.dim() == 3 && k_pe.dim() == 3, "Only prefill mode with GQA is supported");
TORCH_CHECK(kv_c.stride(-1) == 1, "kv_c stride(-1) must be equal to 1");
TORCH_CHECK(k_pe.stride(-1) == 1, "k_pe stride(-1) must be equal to 1");
```
Comment on lines +4235 to +4240
```cpp
                    ". Supported: kv_lora_rank <= 512");
    }
} else {
    TORCH_CHECK(false,
                "fused_qk_norm_rope_concat_and_cache_mla currently only supports "
                "kv_lora_rank<=512 and rot_dim=64. Got kv_lora_rank=", kv_lora_rank,
```
Comment on lines +2313 to +2315
```cpp
const int32_t s       = static_cast<int32_t>(slot_idx);
const int32_t bs_log2 = params.block_size_log2;
const int32_t kv_cache_offset =
    (s >> bs_log2) * params.block_stride + (s & ((1 << bs_log2) - 1)) * params.entry_stride;
```
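Because block_size is guaranteed to be a power of two, the shift/mask decomposition above is exactly equivalent to the division/modulo it replaces. A quick Python check (the stride values here are made-up illustration, not the kernel's actual strides):

```python
def offset_divmod(s, block_size, block_stride, entry_stride):
    # Reference form: explicit division for the block index, modulo for the slot
    return (s // block_size) * block_stride + (s % block_size) * entry_stride

def offset_shift_mask(s, bs_log2, block_stride, entry_stride):
    # Kernel-style form: shift for the block index, mask for the in-block slot
    return (s >> bs_log2) * block_stride + (s & ((1 << bs_log2) - 1)) * entry_stride

# block_size = 16 (bs_log2 = 4); block_stride/entry_stride are example values
for s in range(1000):
    assert offset_divmod(s, 16, 4096, 256) == offset_shift_mask(s, 4, 4096, 256)
print("shift/mask matches div/mod")
```

On GPU, the shift/mask form also avoids the integer divide, which is considerably more expensive than a shift on AMD hardware.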
Comment on lines +378 to 412
```cpp
m.def("fused_qk_norm_rope_group_quant_concat_and_cache_mla",       \
      &aiter::fused_qk_norm_rope_group_quant_concat_and_cache_mla, \
      "fused_qk_norm_rope_group_quant_concat_and_cache_mla("       \
      " Tensor q_nope, Tensor q_pe,"                               \
      " Tensor kv_c, Tensor! k_pe,"                                \
      " Tensor k_weight,"                                          \
      " Tensor? q_weight,"                                         \
      " Tensor! kv_cache,"                                         \
      " Tensor! q_out,"                                            \
      " Tensor slot_mapping,"                                      \
      " Tensor q_scale,"                                           \
      " Tensor positions,"                                         \
      " Tensor cos_cache,"                                         \
      " Tensor sin_cache,"                                         \
      " float eps,"                                                \
      " int group_size,"                                           \
      " bool is_neox,"                                             \
      " bool is_nope_first)->()",                                  \
      py::arg("q_nope"),                                           \
      py::arg("q_pe"),                                             \
      py::arg("kv_c"),                                             \
      py::arg("k_pe"),                                             \
      py::arg("k_weight"),                                         \
      py::arg("q_weight"),                                         \
      py::arg("kv_cache"),                                         \
      py::arg("q_out"),                                            \
      py::arg("slot_mapping"),                                     \
      py::arg("q_scale"),                                          \
      py::arg("positions"),                                        \
      py::arg("cos_cache"),                                        \
      py::arg("sin_cache"),                                        \
      py::arg("eps"),                                              \
      py::arg("group_size"),                                       \
      py::arg("is_neox"),                                          \
      py::arg("is_nope_first"));
```
Comment on lines +106 to +123
```cpp
void fused_qk_norm_rope_group_quant_concat_and_cache_mla(
    torch::Tensor& q_nope,       // [num_tokens, num_heads, qk_lora_rank]
    torch::Tensor& q_pe,         // [num_tokens, num_heads, pe_dim]
    torch::Tensor& kv_c,         // [num_tokens, k_num_heads, kv_lora_rank]
    torch::Tensor& k_pe,         // [num_tokens, k_num_heads, pe_dim]
    torch::Tensor& k_weight,     // [kv_lora_rank] RMSNorm weights for K
    std::optional<torch::Tensor> q_weight_opt, // [kv_lora_rank] RMSNorm weights for Q (optional)
    torch::Tensor& kv_cache,     // [num_blocks, block_size, k_num_heads, kv_lora_rank + pe_dim]
    torch::Tensor& q_out,        // [num_tokens, num_heads, qk_lora_rank + pe_dim]
    torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
    torch::Tensor& q_scale,      // scale for q
    torch::Tensor& positions,    // [num_tokens]
    torch::Tensor& cos_cache,    // [max_position, rot_dim // 2]
    torch::Tensor& sin_cache,    // [max_position, rot_dim // 2]
    double eps,                  // epsilon for RMS norm
    int group_size,              // group size for group quantization (default 64)
    bool is_neox,
    bool is_nope_first);
```
Comment on lines +1046 to +1052
```python
def _f32_to_e8m0_ceil_q(x):
    """Match kernel's ceil-style e8m0."""
    u32 = x.view(torch.int32).item()
    exponent = (u32 >> 23) & 0xFF
    if u32 & 0x7FFFFF:
        exponent += 1
    return exponent
```
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist