[Perf] Fuse RoPE + KV cache update for MLA backends (#35879): eliminate extra Inductor copy + enable on ROCm #163
Follow-up to the cache+rope fusion in PR vllm-project#35879. Eliminates an extra Triton copy kernel that Inductor inserts on every MLA layer when the fused custom op mutates a strided view of `k_pe` that is also consumed by a second opaque custom op in a downstream partition.

Problem: with the fusion enabled, the post-fusion FX graph consumes a strided slice of the qkv_a_proj output in both `fused_concat_and_cache_mla_rope` (which mutates it in-place) and the downstream `unified_mla_attention_with_output`. Inductor partitions at the second op and, because `k_pe` is a strided alias mutated then re-read across the partition boundary by an opaque consumer, it materialises a fresh contiguous buffer to keep aliasing safe. That materialisation shows up once per MLA layer (~4.7us * 200 steps on Kimi-K2-Thinking-MXFP4, MI355x), nearly doubling the per-layer fused-op cost.

Fix: make the C++ kernel write the rotated `k_pe` into a fresh contiguous output tensor `k_pe_out` (shape `[num_tokens, 1, rot_dim]`) instead of mutating the strided input view. The kv_cache write inside the kernel still sees the rotated value (rotation stays in registers; only the final store target changes).

Changes:
* C++ kernel + ABI: `concat_and_cache_mla_rope_fused` gains a trailing `Tensor! k_pe_out` output.
* Python op wrapper (`vllm/_custom_ops.py`): allocates `k_pe_out` contiguous like `k_pe` and returns it; drops `k_pe` from `mutates_args`.
* Fusion pass (`kv_cache_mla_rope_fusion.py`): rewires downstream consumers to the new contiguous `k_pe_out`.
* `fix_functionalization.py`: updated to match the new op signature (3 outputs; only the third, `k_pe_out`, is wired to the new contiguous tensor; the first two get rank-1 `new_empty(0)` slots).

Based on AMD-MLPerf/workloads-inference@ee50ce9.

Made-with: Cursor
Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
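To make the "only the final store target changes" point concrete, here is a minimal pure-Python sketch of an out-of-place rotation. It is not the C++ kernel: the function name, the flat-list buffer, the `angles` layout, and the interleaved-pair rotation style are all assumptions made for illustration. It reads the strided `k_pe` slice, rotates it, and stores the result into a fresh contiguous buffer while leaving the source untouched.

```python
import math


def rope_k_pe_out_of_place(buf, num_tokens, row_stride, offset, rot_dim, angles):
    """Rotate the strided k_pe slice of `buf` into a new contiguous list.

    `buf` is a flat list standing in for the qkv_a_proj output. Token t's
    k_pe occupies buf[offset + t*row_stride : offset + t*row_stride + rot_dim].
    `angles[t][i]` is the rotation angle for pair i of token t (hypothetical
    layout; interleaved pairs are an assumption of this sketch).
    """
    k_pe_out = [0.0] * (num_tokens * rot_dim)  # fresh contiguous output
    for t in range(num_tokens):
        base = offset + t * row_stride  # strided read position
        for i in range(rot_dim // 2):
            x, y = buf[base + 2 * i], buf[base + 2 * i + 1]
            c, s = math.cos(angles[t][i]), math.sin(angles[t][i])
            # The rotated values go only to the contiguous output;
            # the strided source view is never written.
            k_pe_out[t * rot_dim + 2 * i] = x * c - y * s
            k_pe_out[t * rot_dim + 2 * i + 1] = x * s + y * c
    return k_pe_out


# Toy sizes: 2 tokens, 10 elements per row, k_pe is the last 4 of each row.
buf = [float(i) for i in range(20)]
snapshot = list(buf)
angles = [[0.0, math.pi / 2], [0.0, math.pi / 2]]
out = rope_k_pe_out_of_place(buf, 2, 10, 6, 4, angles)
```

Because the source buffer is read-only here, a downstream consumer of `out` has no alias relationship with it, which is the property the fix relies on.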
|
@rbrugaro-amd can you provide more information on the rematerialization? We should file a pytorch issue with a minimal repro if we can. |
|
In
What Inductor does with it

After fusion, the post-grad FX graph has this per-layer pattern (layer 0 shown):

# fused op mutates k_pe (unsqueeze_1) in-place, returns bf16[0]
fused_concat_and_cache_mla_rope_default_60: "bf16[0]" = torch.ops.vllm.fused_concat_and_cache_mla_rope.default(
    arg9_1, slice_tensor_60, unsqueeze_1, getitem_1523, ...)
# attention op reads the SAME strided view across the Inductor partition boundary
auto_functionalized = torch.ops.higher_order.auto_functionalized(
    torch.ops.vllm.unified_mla_attention_with_output.default,
    q=reshape_default_60, kv_c_normed=getitem_1523, k_pe=unsqueeze_1, ...)

Inductor partitions at

@triton.jit
def triton_poi_fused_fused_concat_and_cache_mla_rope_slice_split_with_sizes_unsqueeze_view_10(
        in_ptr0, out_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 64)  # rot_dim index
    x1 = xindex // 64  # token index
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (2048 + x0 + 2112*x1), xmask).to(tl.float32)
    tl.store(out_ptr0 + (2048 + x0 + 2112*x1), tmp0, xmask)  # identity write-back (no-op)
    tl.store(out_ptr1 + (x2), tmp0, xmask)  # strided → contiguous copy

The stride pattern

How this PR fixes it

The fused op now returns a fresh contiguous

# fused op returns contiguous k_pe_out
fused_concat_and_cache_mla_rope_default_60: "bf16[s72, 1, 64]" = torch.ops.vllm.fused_concat_and_cache_mla_rope.default(
    arg9_1, slice_tensor_60, unsqueeze_1, getitem_1523, ...)
# attention op receives the contiguous output: no alias conflict
auto_functionalized = torch.ops.higher_order.auto_functionalized(
    torch.ops.vllm.unified_mla_attention_with_output.default,
    q=reshape_default_60, kv_c_normed=getitem_1523, k_pe=fused_concat_and_cache_mla_rope_default_60, ...)

No strided alias crosses the partition boundary, so Inductor has nothing to |
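For readers who want to check the index arithmetic of the generated copy kernel quoted in this comment, the sketch below replays it in plain Python. The constants 64, 2112 and 2048 are taken from the kernel text; the function name and the flat-list stand-ins for the buffers are assumptions of the sketch, and the `.to(tl.float32)` cast is ignored.

```python
ROT_DIM = 64        # from `xindex % 64` in the kernel
ROW_STRIDE = 2112   # elements per token row in the source buffer
K_PE_OFFSET = 2048  # start of the k_pe slice within each row


def emulate_copy_kernel(in_buf, num_tokens):
    """Replay the kernel's loads and stores element by element."""
    xnumel = num_tokens * ROT_DIM
    out1 = [None] * xnumel  # stands in for the contiguous out_ptr1 buffer
    for xindex in range(xnumel):
        x0 = xindex % ROT_DIM    # rot_dim index
        x1 = xindex // ROT_DIM   # token index
        src = K_PE_OFFSET + x0 + ROW_STRIDE * x1
        tmp0 = in_buf[src]
        in_buf[src] = tmp0       # out_ptr0 store: same address, identity
        out1[xindex] = tmp0      # out_ptr1 store: strided -> contiguous
    return out1


num_tokens = 3
in_buf = list(range(num_tokens * ROW_STRIDE))  # value == flat offset
before = list(in_buf)
staged = emulate_copy_kernel(in_buf, num_tokens)
```

Since each element's value equals its flat offset, `staged` directly lists which source offsets the kernel gathers: 64 consecutive elements per token, each run starting 2048 elements into a 2112-element row.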
|
Thank you for the detailed description. cc @zou3519 @BoyuanFeng @eellison is this a fundamental limitation of Inductor, or is this something we can address? |
|
Great fix, I'm looking forward to the benchmark results! 😄 |
As I understand it, this is a known limitation of Inductor with |
|
@attila-dusnoki-htec but we have the custom |
|
@ElizaWszola @rbrugaro-amd let me look at this a bit more closely to understand if we can avoid the fix altogether |
I had to double-check it, and I was wrong. |
|
@attila-dusnoki-htec can we just remove the |
|
Just a brief update, I was able to reproduce this with Full command: Graph: Turn into the following partitioned code: For some reason |
|
Without inductor partition (just Triton kernel: |
|
cc @ProExpertProg would it be possible to share either the fx graph runnable or the tlparse ? that would help make it easier to investigate |
|
I'll send one over once I'm back at my laptop, repro should also be easy from the commands above on the base branch of this PR |
|
I tried to add the

Kernel signature + meta

triton_meta={'signature': {
    'in_ptr0': '*bf16',
    'out_ptr0': '*bf16',
    'out_ptr1': '*bf16',  # <- contiguous staging buffer
    'xnumel': 'i32', 'XBLOCK': 'constexpr'}, ...},
inductor_meta={..., 'mutated_arg_names': ['in_ptr0', 'out_ptr0'], ...},
def triton_poi_fused_..._1(in_ptr0, out_ptr0, out_ptr1, xnumel, XBLOCK):

With flexible_layout (kernel_74.py):

triton_meta={'signature': {
    'in_ptr0': '*bf16',
    'out_ptr0': '*bf16',  # <- only two buffers now
    'xnumel': 'i32', 'XBLOCK': 'constexpr'}, ...},
inductor_meta={..., 'mutated_arg_names': ['in_ptr0', 'out_ptr0'], ...},
def triton_poi_fused_..._1(in_ptr0, out_ptr0, xnumel, XBLOCK):

Orchestrator call site

Baseline:

buf13 = torch.ops.vllm.fused_concat_and_cache_mla_rope.default(
    arg9_1,
    reinterpret_tensor(buf11, (s72, 16, 64), (3072, 192, 1), 128),
    reinterpret_tensor(buf6, (s72, 1, 64), (2112, 64, 1), 2048),
    buf9, ...)
buf18 = empty_strided_cuda((s72, 1, 64), (64, 64, 1), torch.bfloat16)  # <- per-layer alloc
triton_poi_fused_..._1.run(buf6, buf6, buf18, ..., stream=stream0)
return (buf2, buf9, buf11, buf12, buf16, buf18, ...)  # <- buf18 crosses partition
# next partition
torch.ops.vllm.unified_mla_attention_with_output.default(..., buf18, ...)  # <- contiguous k_pe

With flexible_layout:

buf13 = torch.ops.vllm.fused_concat_and_cache_mla_rope.default(
    arg9_1,
    reinterpret_tensor(buf11, (s72, 16, 64), (3072, 192, 1), 128),
    reinterpret_tensor(buf6, (s72, 1, 64), (2112, 64, 1), 2048),
    buf9, ...)
# <- no alloc
triton_poi_fused_..._1.run(buf6, buf6, ..., stream=stream0)
return (buf2, buf6, buf9, buf11, buf12, buf16, ...)  # <- buf6 itself crosses
# next partition
torch.ops.vllm.unified_mla_attention_with_output.default(
    ...,
    reinterpret_tensor(buf6, (s72, 1, 64), (2112, 0, 1), 2048),  # <- strided view, free
    ...)

Kernel

Baseline (load + store-back + store-to-contiguous):

tmp0 = tl.load (in_ptr0 + (2048 + x0 + 2112*x1), xmask)
tl.store(out_ptr0 + (2048 + x0 + 2112*x1), tmp0, xmask)  # buf6 -> buf6
tl.store(out_ptr1 + (x2), tmp0, xmask)  # -> contiguous buf18

With flexible_layout (load + store-back only, out_ptr0 == in_ptr0 == buf6):

tmp0 = tl.load (in_ptr0 + (2048 + x0 + 2112*x1), xmask)
tl.store(out_ptr0 + (2048 + x0 + 2112*x1), tmp0, xmask)  # buf6 -> buf6 (self-copy)

Inductor honored the tag, so the kernel is no longer there for layout reasons it seems, but the same-named realize-marker survives :/ |
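The difference between the two call sites comes down to whether the `(size, stride)` pair handed to the attention op describes a contiguous layout. A small sketch of that check in plain Python follows; the helper names are hypothetical, `8` stands in for the symbolic `s72`, and the rule that size-1 dimensions are ignored mirrors how PyTorch treats contiguity.

```python
def contiguous_strides(shape):
    """Row-major strides for `shape`, e.g. (8, 1, 64) -> (64, 64, 1)."""
    strides = []
    expected = 1
    for size in reversed(shape):
        strides.append(expected)
        expected *= size
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """True if (shape, strides) is a dense row-major layout.

    Dimensions of size 1 are skipped, since their stride never
    contributes to an element's address.
    """
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size == 1:
            continue
        if stride != expected:
            return False
        expected *= size
    return True


shape = (8, 1, 64)  # 8 is a stand-in for the symbolic s72
staging_layout = contiguous_strides(shape)       # layout of buf18
staging_ok = is_contiguous(shape, staging_layout)
view_ok = is_contiguous(shape, (2112, 64, 1))    # view of buf6 (baseline input)
flex_ok = is_contiguous(shape, (2112, 0, 1))     # view of buf6 (flexible_layout)
```

Under this check the `buf18` layout is contiguous and both `buf6` views are not, which matches the baseline needing a staging buffer and the tagged variant passing the strided view straight through.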
|
@attila-dusnoki-htec could you collect a tlparse with and without the tag change? Also with and without inductor partition would be good (without we can use splitting_ops=[]) |
|
I'm actually not seeing the identity copy, and the extra kernel is successfully eliminated for me with
All were run with the base command: Dynamo partition ( Inductor partition ( Dynamo partition ( Inductor partition ( |
|
With these two changes, I also no longer see the Triton kernel:

diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -1009,6 +1009,7 @@ direct_register_custom_op(
     op_name="unified_mla_attention_with_output",
     op_func=unified_mla_attention_with_output,
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_mla_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
+    tags=(torch.Tag.flexible_layout,),
 )

diff --git a/vllm/compilation/passes/fusion/kv_cache_mla_rope_fusion.py b/vllm/compilation/passes/fusion/kv_cache_mla_rope_fusion.py
--- a/vllm/compilation/passes/fusion/kv_cache_mla_rope_fusion.py
+++ b/vllm/compilation/passes/fusion/kv_cache_mla_rope_fusion.py
@@ -86,7 +86,6 @@ direct_register_custom_op(
     op_name="fused_concat_and_cache_mla_rope",
     op_func=fused_concat_and_cache_mla_rope_impl,
     fake_impl=fused_concat_and_cache_mla_rope_fake,
-    mutates_args=["q_pe", "k_pe"],
 )

Adding only the tag was not enough. The accuracy seems correct for my use case. |
|
@attila-dusnoki-htec the mutates_args change is probably not correct :/ |
Of course not; a better solution would be the actual PR change. |
|
Looks like inductor partition + |
Follow-up to vllm-project#35879. Two changes on top of the fused RoPE + KV cache update for MLA backends:

1. Return contiguous `k_pe` from fused kernel (eliminate extra Triton copy)

Problem: With the fusion enabled, the post-fusion FX graph has a strided slice of the `qkv_a_proj` output consumed back-to-back by two opaque custom ops: `fused_concat_and_cache_mla_rope` (which mutates it in-place) and the downstream `unified_mla_attention_with_output`.

Because `k_pe` is a strided alias mutated by the first op and then re-read across the Inductor partition boundary by the second, Inductor materialises it into a fresh contiguous buffer to keep aliasing safe. That materialisation shows up once per MLA layer (~4.7 µs × 200 steps on Kimi-K2-Thinking-MXFP4/MI355x), nearly doubling the per-layer fused-op cost.

Fix: Make the C++ kernel write the rotated `k_pe` into a fresh contiguous output tensor `k_pe_out` (`[num_tokens, 1, rot_dim]`) instead of mutating the strided input view. The kv_cache write inside the kernel still sees the rotated value (rotation stays in registers; only the final store target changes).

Changes:
* C++ kernel + ABI: `concat_and_cache_mla_rope_fused` gains a trailing `Tensor! k_pe_out` output.
* Python op wrapper (`vllm/_custom_ops.py`): allocates `k_pe_out` contiguous like `k_pe` and returns it; drops `k_pe` from `mutates_args`.
* Fusion pass (`kv_cache_mla_rope_fusion.py`): rewires downstream consumers to the new contiguous `k_pe_out`.
* `fix_functionalization.py`: updated to match the new op signature (3 outputs; only the third, `k_pe_out`, is wired to the new contiguous tensor; the first two get rank-1 `new_empty(0)` slots).

2. Enable fusion on ROCm (CUDA-alike platforms)

The config guard in `PassConfig.__post_init__` disabled `enable_cache_mla_rope_fusion` when `current_platform.is_cuda()` returned `False`, which silently prevented the fusion on ROCm even though the pass itself is imported under an `is_cuda_alike()` gate in `pass_manager.py`. Relaxed the guard to `is_cuda_alike()` so ROCm (HIP) gets the same fused path.

before and after this fusion
before (PR35789)
after (this PR)

Accuracy validation
Tested with Kimi-K2-Thinking-MXFP4 on MI355x with fusion enabled (`enable_cache_mla_rope_fusion: True`):

Test plan

Acknowledgements
Thanks @attila-dusnoki-htec for the contiguous `k_pe_out` kernel change.
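As an illustration of the guard change in item 2, here is a minimal sketch of why an `is_cuda()` check silently drops the fusion on ROCm while an `is_cuda_alike()` check keeps it. `FakePlatform` and `resolve_fusion_flag` are hypothetical stand-ins, not vLLM code, and the sketch assumes that `is_cuda_alike()` is true for both CUDA and ROCm.

```python
from dataclasses import dataclass


@dataclass
class FakePlatform:
    """Hypothetical stand-in for a platform object (not the vLLM class)."""
    name: str

    def is_cuda(self) -> bool:
        return self.name == "cuda"

    def is_rocm(self) -> bool:
        return self.name == "rocm"

    def is_cuda_alike(self) -> bool:
        # Assumption of this sketch: CUDA-alike means CUDA or ROCm.
        return self.is_cuda() or self.is_rocm()


def resolve_fusion_flag(requested: bool, platform: FakePlatform,
                        relaxed_guard: bool) -> bool:
    """Return the effective value of the fusion flag after the guard runs."""
    if relaxed_guard:
        supported = platform.is_cuda_alike()  # guard after this PR
    else:
        supported = platform.is_cuda()        # guard before this PR
    return requested and supported


rocm = FakePlatform("rocm")
cuda = FakePlatform("cuda")
old_on_rocm = resolve_fusion_flag(True, rocm, relaxed_guard=False)
new_on_rocm = resolve_fusion_flag(True, rocm, relaxed_guard=True)
new_on_cuda = resolve_fusion_flag(True, cuda, relaxed_guard=True)
```

With the old guard the user's request is overridden on ROCm without any error, which is the "silently prevented" behaviour described above; the relaxed guard leaves CUDA behaviour unchanged.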