
Fuse qk norm group quant #2344

Draft · yzhou103 wants to merge 6 commits into main from fuse_qk_norm_group_quant

Conversation

@yzhou103
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

yzhou103 and others added 6 commits March 19, 2026 15:02
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…ze ops

- Add __launch_bounds__(64, 8) for better register allocation
- Replace int64_t with int32_t for token/slot indices to halve SGPR usage
- Move block_size from template param to MlaKernelParams as log2, use
  bitwise shift/mask instead of division/modulo (power-of-2 guaranteed)
- Promote runtime vars to constexpr: oob sizes, head_size, nope_offset,
  kv_lora_vec, reduce_thread_size
- Remove unused variables (qH_per_kH, num_kv_vecs, group_size shadow)
- Pre-compute token-level base offsets to avoid repeated stride multiplies
- Hoist group_id and inv_scale out of inner quantization loops
- Use bitshift for group_id: (tid * vec_size_i) >> 6 instead of / 64
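The last bullet's bitshift form can be sanity-checked in plain Python; this is a standalone sketch of the equivalence (not the kernel code itself), valid because `group_size` is fixed at 64, a power of two:

```python
# Sanity check for the group_id bitshift listed above: for a group size
# of 64, (tid * vec_size) >> 6 equals integer division by 64 for every
# non-negative operand.
def group_id_shift(tid: int, vec_size: int) -> int:
    return (tid * vec_size) >> 6   # kernel's bitshift form

def group_id_div(tid: int, vec_size: int) -> int:
    return (tid * vec_size) // 64  # original division form

assert all(
    group_id_shift(t, v) == group_id_div(t, v)
    for t in range(256)
    for v in (1, 2, 4, 8)
)
```

On AMD GPUs the shift compiles to a single VALU instruction, whereas integer division expands to a multi-instruction sequence, which is why the commit trades one for the other.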

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@yzhou103 yzhou103 requested review from a team and Copilot March 19, 2026 07:03
@yzhou103 yzhou103 marked this pull request as draft March 19, 2026 07:03
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label       Tests
ci:sglang   SGLang integration tests
ci:atom     ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm     vLLM benchmark
ci:all      All of the above

Add labels via the sidebar or gh pr edit 2344 --add-label <label>

else:
    err_kv = checkAllclose(kv_cache, ref_kv_cache, msg="bf16 kv result compared with ref")
    err_q_out = checkAllclose(q_out[..., :kv_lora_rank], ref_q_out[..., :kv_lora_rank], msg="bf16 q_nope result compared with ref")
    err_q_pe = checkAllclose(

⚠️ [ruff] <F841> reported by reviewdog 🐶
Local variable err_q_pe is assigned to but never used

Suggested change
err_q_pe = checkAllclose(
checkAllclose(

Copilot AI left a comment

Pull request overview

This PR introduces a new fused ROCm cache kernel variant that combines RMSNorm + RoPE with group-quantized KV (and optionally group-quantized Q), and extends the Python bindings/tests to exercise the new path.

Changes:

  • Add a new HIP kernel + C++ entrypoint for fused (Q/K) norm + RoPE + group-quant + cache write for MLA prefill (GQA).
  • Expose the new op via pybind and aiter Python stubs.
  • Add new op-tests benchmarks + torch reference implementations for correctness/perf comparisons.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.

Show a summary per file
File                              Description
csrc/kernels/cache_kernels.cu     Adds the new optimized fused kernel and a new C++ interface; also adjusts the existing prefill launch config.
csrc/include/cache.h              Declares the new C++ API.
csrc/include/rocm_ops.hpp         Exposes the new op via pybind.
aiter/ops/cache.py                Adds the Python compile_ops stub for the new op.
op_tests/test_concat_cache_mla.py Adds torch references + benchmark/correctness harnesses for the new fused path(s).



if (use_optimized_prefill) {
-    dim3 grid(num_tokens);
+    dim3 grid(num_tokens, num_kv_heads); // Use 2D grid: x=tokens, y=kv_heads
Comment on lines +2510 to +2547
const float inverted_DTYPE_MAX =
    std::is_same_v<query_t, ck_tile::fp4x2_t>
        ? 0.25f
        : (1.f / ck_tile::type_convert<float>(ck_tile::numeric<cache_t>::max()));
const float q_group_scale = thread_max * inverted_DTYPE_MAX;

query_t* q_nope_out_base = reinterpret_cast<query_t*>(
    q_out + token_qout_base + q_head_idx * params.q_out_stride_1 + (is_nope_first ? 0 : pe_dim));

float q_inv_scale;
if constexpr (std::is_same_v<query_t, ck_tile::fp8_t>) {
    uint32_t u32 = __builtin_bit_cast(uint32_t, q_group_scale);
    uint32_t exponent = (u32 >> 23) & 0xFF;
    if (u32 & 0x7FFFFF) exponent += 1;
    if (tid % reduce_thread_size == 0) {
        const int group_id = (tid * vec_size_i) >> 6;
        auto* tmp = reinterpret_cast<uint8_t*>(q_nope_out_base + kv_lora_dim);
        if (tid < kv_lora_vec) {
            tmp[group_id] = static_cast<uint8_t>(exponent);
        } else {
            tmp[group_id] = 0;
        }
    }
    uint32_t e8m0_u32 = exponent << 23;
    q_inv_scale = has_q_data ? (1.0f / __builtin_bit_cast(float, e8m0_u32)) : 0.0f;
} else {
    q_inv_scale = has_q_data ? (1.0f / q_group_scale) : 0.0f;
}

// Store quantized Q data
if (has_q_data) {
    const uint32_t q_offset = tid * vec_size_i;
    opus_vec_q vec_out;
#pragma unroll
    for (int i = 0; i < vec_size_i; i++) {
        vec_out[i] = ck_tile::type_convert<cache_t>(q_fp32[i] * q_inv_scale);
    }
    q_buffer_o.template store<vec_size_o>(vec_out, q_offset);
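The non-fp8 branch of the snippet above is ordinary absmax group quantization. A minimal Python model of that arithmetic follows; `DTYPE_MAX = 448.0` is an assumption standing in for `ck_tile::numeric<cache_t>::max()` (it matches fp8 e4m3), not a value taken from the kernel:

```python
# Minimal numeric model of the group-quant path: scale each group so its
# absolute maximum maps onto DTYPE_MAX, store values times the inverse
# scale, and recover them by multiplying the scale back in.
DTYPE_MAX = 448.0  # assumed stand-in for ck_tile::numeric<cache_t>::max()

def group_quant(values):
    """Return (quantized values, group scale), mirroring q_group_scale above."""
    thread_max = max(abs(v) for v in values)
    group_scale = thread_max / DTYPE_MAX          # thread_max * inverted_DTYPE_MAX
    inv_scale = 1.0 / group_scale if group_scale else 0.0
    return [v * inv_scale for v in values], group_scale

def dequant(quantized, group_scale):
    return [q * group_scale for q in quantized]

vals = [0.5, -1.25, 3.0, -0.75]
q, s = group_quant(vals)
assert abs(max(abs(v) for v in q) - DTYPE_MAX) < 1e-9  # absmax lands on DTYPE_MAX
restored = dequant(q, s)
assert all(abs(a - b) < 1e-6 for a, b in zip(vals, restored))
```

The fp8 path differs only in that the scale is first rounded up to a power of two (the e8m0 ceil step), which makes dequantization a pure exponent adjustment.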
Comment on lines +4115 to +4132
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(-1);
int pe_dim = k_pe.size(-1);
int block_size = kv_cache.size(1);
int num_heads = q_nope.size(1);
int qk_lora_rank = q_nope.size(-1);
int rot_dim = cos_cache.size(-1) * 2;

// Validate dimensions
TORCH_CHECK(q_nope.dim() == q_pe.dim());
TORCH_CHECK(q_nope.size(1) == q_pe.size(1));
TORCH_CHECK(q_out.size(2) == qk_lora_rank + pe_dim);
TORCH_CHECK(kv_lora_rank == qk_lora_rank, "kv_lora_rank and qk_lora_rank must be the same");
TORCH_CHECK(k_weight.size(0) == kv_lora_rank, "k_weight size must match kv_lora_rank");
TORCH_CHECK(kv_c.dim() == 3 && k_pe.dim() == 3, "Only prefill mode with GQA is supported");
TORCH_CHECK(kv_c.stride(-1) == 1, "kv_c stride(-1) must be equal to 1");
TORCH_CHECK(k_pe.stride(-1) == 1, "k_pe stride(-1) must be equal to 1");

Comment on lines +4235 to +4240
                ". Supported: kv_lora_rank <= 512");
    }
} else {
    TORCH_CHECK(false,
                "fused_qk_norm_rope_concat_and_cache_mla currently only supports "
                "kv_lora_rank<=512 and rot_dim=64. Got kv_lora_rank=", kv_lora_rank,
Comment on lines +2313 to +2315
const int32_t s = static_cast<int32_t>(slot_idx);
const int32_t bs_log2 = params.block_size_log2;
const int32_t kv_cache_offset = (s >> bs_log2) * params.block_stride + (s & ((1 << bs_log2) - 1)) * params.entry_stride;
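This offset computation is the shift/mask replacement for slot division and modulo mentioned in the commit notes; since `block_size` is guaranteed to be a power of two, both forms agree. A small Python model (the strides are illustrative values, not the kernel's actual ones):

```python
# Python model of the kv_cache offset computation above. Strides are
# illustrative; the point is the shift/mask vs. div/mod equivalence.
def offset_shift(slot, bs_log2, block_stride, entry_stride):
    # kernel form: (s >> bs_log2) * block_stride + (s & mask) * entry_stride
    return (slot >> bs_log2) * block_stride + (slot & ((1 << bs_log2) - 1)) * entry_stride

def offset_divmod(slot, block_size, block_stride, entry_stride):
    # conventional form using division and modulo
    return (slot // block_size) * block_stride + (slot % block_size) * entry_stride

bs_log2 = 4                          # block_size = 16 (power of two, as required)
block_stride, entry_stride = 16 * 576, 576  # e.g. 576 = kv_lora_rank 512 + pe_dim 64
assert all(
    offset_shift(s, bs_log2, block_stride, entry_stride)
    == offset_divmod(s, 1 << bs_log2, block_stride, entry_stride)
    for s in range(1024)
)
```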
Comment on lines +378 to 412
m.def("fused_qk_norm_rope_group_quant_concat_and_cache_mla", \
&aiter::fused_qk_norm_rope_group_quant_concat_and_cache_mla, \
"fused_qk_norm_rope_group_quant_concat_and_cache_mla(" \
" Tensor q_nope, Tensor q_pe," \
" Tensor kv_c, Tensor! k_pe," \
" Tensor k_weight," \
" Tensor? q_weight," \
" Tensor! kv_cache," \
" Tensor! q_out," \
" Tensor slot_mapping," \
" Tensor q_scale," \
" Tensor positions," \
" Tensor cos_cache," \
" Tensor sin_cache," \
" float eps," \
" int group_size," \
" bool is_neox," \
" bool is_nope_first)->()", \
py::arg("q_nope"), \
py::arg("q_pe"), \
py::arg("kv_c"), \
py::arg("k_pe"), \
py::arg("k_weight"), \
py::arg("q_weight"), \
py::arg("kv_cache"), \
py::arg("q_out"), \
py::arg("slot_mapping"), \
py::arg("q_scale"), \
py::arg("positions"), \
py::arg("cos_cache"), \
py::arg("sin_cache"), \
py::arg("eps"), \
py::arg("group_size"), \
py::arg("is_neox"), \
py::arg("is_nope_first"));
Comment on lines +106 to +123
void fused_qk_norm_rope_group_quant_concat_and_cache_mla(
torch::Tensor& q_nope, // [num_tokens, num_heads, qk_lora_rank]
torch::Tensor& q_pe, // [num_tokens, num_heads, pe_dim]
torch::Tensor& kv_c, // [num_tokens, k_num_heads, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, k_num_heads, pe_dim]
torch::Tensor& k_weight, // [kv_lora_rank] RMSNorm weights for K
std::optional<torch::Tensor> q_weight_opt, // [kv_lora_rank] RMSNorm weights for Q (optional)
torch::Tensor& kv_cache, // [num_blocks, block_size, k_num_heads, kv_lora_rank + pe_dim]
torch::Tensor& q_out, // [num_tokens, num_heads, qk_lora_rank+pe_dim]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
torch::Tensor& q_scale, // scale for q
torch::Tensor& positions, // [num_tokens]
torch::Tensor& cos_cache, // [max_position, rot_dim//2]
torch::Tensor& sin_cache, // [max_position, rot_dim//2]
double eps, // epsilon for RMS norm
int group_size, // group size for group quantization (default 64)
bool is_neox,
bool is_nope_first);
Comment on lines +1046 to +1052
def _f32_to_e8m0_ceil_q(x):
"""Match kernel's ceil-style e8m0."""
u32 = x.view(torch.int32).item()
exponent = (u32 >> 23) & 0xFF
if u32 & 0x7FFFFF:
exponent += 1
return exponent
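For readers without torch at hand, the same ceil-style e8m0 extraction can be restated with the standard library's `struct`; this is a torch-free sketch of the reference helper above, plus the decode direction, showing that the rounded-up scale always bounds the input from above:

```python
import struct

def f32_to_e8m0_ceil(x: float) -> int:
    """torch-free restatement of _f32_to_e8m0_ceil_q above."""
    u32 = struct.unpack("<I", struct.pack("<f", x))[0]
    exponent = (u32 >> 23) & 0xFF
    if u32 & 0x7FFFFF:           # any mantissa bits -> round the exponent up
        exponent += 1
    return exponent

def e8m0_to_f32(exponent: int) -> float:
    # e8m0 encodes a pure power of two: bit pattern is just the exponent field
    return struct.unpack("<f", struct.pack("<I", exponent << 23))[0]

# Ceil behaviour: the decoded power-of-two scale is >= the original value.
for x in (0.75, 1.0, 1.5, 2.0, 3.14, 100.0):
    e = f32_to_e8m0_ceil(x)
    assert e8m0_to_f32(e) >= x
```

This is why the kernel stores only one `uint8` exponent per group: dequantization becomes a multiply by an exact power of two, reconstructed by shifting the byte into the float exponent field.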