Add CK-free fallback for fused QKNorm+RoPE+Cache #279

Open
sunway513 wants to merge 1 commit into ROCm:main from sunway513:feat/ckfree-fused-rope-fallback

@sunway513
Collaborator

Summary

  • Wraps fused_qk_norm_rope_cache_quant_shuffle in attention_mha.py:rope_cache() with try-except
  • On CK-free builds where the HIP fused kernel is unavailable, gracefully falls through to the existing non-fused Triton path (rotary_emb + q/k_norm + reshape_and_cache)
  • qkv.clone() backup before fused kernel call, restored on failure (protects against partial in-place writes)
  • Log-once warning via class attribute to avoid log spam on hot path
  • q_norm is None guard on middle path preserves original elif invariant
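
The pattern in the bullets above (backup, attempt the fused kernel, restore on failure, warn once, fall through) can be sketched roughly as follows. This is a minimal illustration and not the code from this PR: `FakeTensor`, `RopeCacheSketch`, `broken_fused_kernel`, and `non_fused_path` are hypothetical stand-ins, and a plain Python list takes the place of the real `qkv` tensor.

```python
import logging

logger = logging.getLogger("atom")


class FakeTensor:
    """Hypothetical stand-in for a tensor: just clone() and copy_()."""

    def __init__(self, data):
        self.data = list(data)

    def clone(self):
        return FakeTensor(self.data)

    def copy_(self, other):
        # In-place restore, so every holder of this object sees the old values.
        self.data[:] = other.data
        return self


class RopeCacheSketch:
    # Class attribute: the warning is emitted once, not once per instance.
    _warned_fused_failure = False

    def __init__(self, fused_kernel, non_fused_path):
        self.fused_kernel = fused_kernel
        self.non_fused_path = non_fused_path

    def rope_cache(self, qkv):
        backup = qkv.clone()  # taken before the kernel can touch qkv
        try:
            self.fused_kernel(qkv)
            return "fused"
        except Exception as exc:
            qkv.copy_(backup)  # undo any partial in-place write
            if not RopeCacheSketch._warned_fused_failure:
                logger.warning("fused kernel failed (%s); using non-fused path", exc)
                RopeCacheSketch._warned_fused_failure = True
        self.non_fused_path(qkv)
        return "non_fused"


def broken_fused_kernel(qkv):
    qkv.data[0] = -1.0  # simulate a partial in-place write before the failure
    raise RuntimeError("fused kernel not built")


def non_fused_path(qkv):
    # Placeholder for rotary_emb + q/k_norm + reshape_and_cache.
    qkv.data = [x * 2 for x in qkv.data]
```

Without the restore step, the non-fused path in this sketch would start from the value corrupted by the failed kernel, which is the case the backup is meant to cover.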

Related

Test Results

| # | Test | Result |
|---|------|--------|
| 1 | Import: from atom.model_ops.attention_mha import PagedAttentionImpl | PASS |
| 2 | Mock: fused kernel raises → fallback executes (rotary_emb, q_norm, k_norm all called, qkv restored, log-once) | PASS |
| 3 | E2E: Qwen3-0.6B with fused kernel disabled | BLOCKED (separate FMHA CK-free JIT issue) |
| 4 | Numerical: fused vs non-fused cosine similarity | BLOCKED (same FMHA issue) |
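
A mock-based check along the lines of test 2 might look like the sketch below. It is an assumption about the shape of such a test, not the test from this PR: `rope_cache_with_fallback` is a hypothetical, simplified function under test, and only the call-order assertions are shown.

```python
from unittest.mock import MagicMock


def rope_cache_with_fallback(fused, rotary_emb, q_norm, k_norm, qkv):
    """Hypothetical function under test: try the fused kernel, else run each step."""
    try:
        fused(qkv)
        return "fused"
    except Exception:
        pass
    rotary_emb(qkv)
    q_norm(qkv)
    k_norm(qkv)
    return "non_fused"


def run_mock_fallback_check():
    # The fused kernel is mocked so that every call raises.
    fused = MagicMock(side_effect=RuntimeError("kernel unavailable"))
    rotary_emb, q_norm, k_norm = MagicMock(), MagicMock(), MagicMock()
    qkv = object()  # opaque placeholder; the mocks never inspect it

    path = rope_cache_with_fallback(fused, rotary_emb, q_norm, k_norm, qkv)

    # The fused kernel was attempted exactly once, then each fallback step ran.
    fused.assert_called_once_with(qkv)
    rotary_emb.assert_called_once_with(qkv)
    q_norm.assert_called_once_with(qkv)
    k_norm.assert_called_once_with(qkv)
    return path
```

Calling `run_mock_fallback_check()` raises `AssertionError` if any fallback step is skipped, and otherwise returns the name of the path taken.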

Known Limitation

E2E tests are blocked by a separate AITER-side issue: module_fmha_v3_varlen_fwd JIT compilation fails in CK-free builds because the ASM attention kernels still depend on CK-Tile headers (fmha_fwd.hpp). Our rope_cache fallback works correctly through model load + warmup (logs confirm individual module_rope_pos_fwd and module_cache loaded successfully).

Shengnan's team is working on removing the CK header dependency from ASM attention kernels in AITER. Once that lands, the full CK-free E2E path (this PR + FMHA fix) will be unblocked.

Wrap fused_qk_norm_rope_cache_quant_shuffle in try-except so that
CK-free builds gracefully fall through to the non-fused Triton path
(rotary_emb + q/k_norm + reshape_and_cache) instead of crashing.

Key safety measures:
- qkv.clone() backup before fused kernel call, restored on failure
  (protects against partial in-place writes before exception)
- log-once warning via class attribute to avoid log spam
- q_norm is None guard on middle path preserves original elif invariant
Copilot AI review requested due to automatic review settings March 9, 2026 04:33
@sunway513
Collaborator Author

@gyohuangxin @ZhiweiYan-96 @valarLip — requesting review on this CK-free fallback for rope_cache().

What this does: When fused_qk_norm_rope_cache_quant_shuffle (HIP fused kernel) is unavailable in CK-free builds, we catch the exception and fall through to the existing non-fused Triton path (individual rotary_emb → q/k_norm → reshape_and_cache). This is a companion fix to #278 (CK-free Docker builds).

Safety measures: qkv.clone() backup before the fused call (restored on failure), log-once warning, and q_norm is None guard to preserve the original elif routing invariant.

Known limitation: Full E2E test is blocked by a separate AITER-side issue — the ASM attention kernels (module_fmha_v3_varlen_fwd) still depend on CK-Tile headers (fmha_fwd.hpp) even in CK-free builds. Shengnan's team is working on removing this dependency. Once that lands, the full CK-free inference path will be unblocked.

Unit tests (import + mock fallback with 5 assertions) pass cleanly.

Contributor

Copilot AI left a comment

Pull request overview

Adds a runtime fallback in the server-mode paged-attention rope_cache() path so CK-free (no HIP fused kernel) builds can gracefully use the existing non-fused Triton/py path instead of failing.

Changes:

  • Add module logger and warning-once behavior for fused-kernel failures.
  • Wrap fused_qk_norm_rope_cache_quant_shuffle(...) in try/except and fall back to the non-fused path on failure.
  • Add q_norm is None guard on the Triton fused path to preserve the intended branch invariants.


Comment on lines +19 to 22
logger = logging.getLogger("atom")

from atom.plugin.prepare import is_plugin_mode, is_vllm
from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode
Copilot AI Mar 9, 2026

logger = ... is executed before the subsequent from atom.plugin... imports, which will trigger Ruff E402 (module level import not at top of file). Move the logger initialization below all imports (or move these imports above the logger assignment) so lint passes.
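
The ordering that avoids E402 can be sketched as below. The `sys` import is only a hypothetical stand-in for the `from atom.plugin...` imports, so that the snippet runs on its own.

```python
# All module-level imports come first.
import logging
import sys  # hypothetical stand-in for the `from atom.plugin...` imports

# Module-level assignments follow the import block, so no import
# appears after executable code and E402 is not triggered.
logger = logging.getLogger("atom")
logger.debug("running on Python %d.%d", *sys.version_info[:2])
```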

Comment on lines +134 to +137
qkv_backup = qkv.clone()
try:
    fused_qk_norm_rope_cache_quant_shuffle(
        qkv,
Copilot AI Mar 9, 2026

On CK-free builds where the fused kernel consistently fails, this will still clone() and then raise/catch an exception on every rope_cache() call, which is very expensive on the decode hot path. Cache the failure (e.g., a class/instance flag) and skip the fused attempt entirely after the first failure; similarly, after the first successful fused call you can avoid taking a backup on subsequent calls.
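
One way to cache the outcome, as suggested here, is a tri-state class flag. The sketch below is illustrative only: `FusedDispatchSketch` and its attribute names are hypothetical, and a Python list stands in for the `qkv` tensor.

```python
class FusedDispatchSketch:
    # None = not tried yet, True = fused kernel worked, False = it failed.
    _fused_available = None

    def __init__(self, fused_kernel, non_fused_path):
        self.fused_kernel = fused_kernel
        self.non_fused_path = non_fused_path

    def rope_cache(self, qkv):
        cls = type(self)
        if cls._fused_available is not False:
            # Pay for a backup only on the first, unproven attempt.
            backup = list(qkv) if cls._fused_available is None else None
            try:
                self.fused_kernel(qkv)
                cls._fused_available = True
                return "fused"
            except Exception:
                if backup is None:
                    # No backup exists, so qkv may be partially written:
                    # re-raise instead of falling back on corrupted data.
                    raise
                cls._fused_available = False
                qkv[:] = backup  # restore in place
        # The fused kernel is known to be unavailable: no clone, no exception.
        self.non_fused_path(qkv)
        return "non_fused"
```

After the first failure, later calls skip both the backup and the raise/catch cycle. The trade-off in this sketch is that a failure occurring after an earlier success is re-raised rather than recovered, because no backup was taken.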

Comment on lines 133 to +135
  ):
-     fused_qk_norm_rope_cache_quant_shuffle(
-         qkv,
-         num_heads_q=self.num_heads,
-         num_heads_k=self.num_kv_heads,
-         num_heads_v=self.num_kv_heads,
-         head_dim=self.head_dim,
-         eps=self.q_norm.eps,
-         qw=self.q_norm.weight,
-         kw=self.k_norm.weight,
-         cos_sin_cache=self.rotary_emb.cos_sin_cache,
-         is_neox_style=self.rotary_emb.is_neox_style,
-         pos_ids=position,
-         k_cache=k_cache,
-         v_cache=v_cache,
-         slot_mapping=attn_metadata.slot_mapping,
-         kv_cache_dtype=(
-             "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype
-         ),
-         k_scale=k_scale,
-         v_scale=v_scale,
-     )
+     qkv_backup = qkv.clone()
+     try:
Copilot AI Mar 9, 2026

qkv_backup = qkv.clone() is outside the try: and assumes qkv is always a Tensor. Since qkv is an optional argument, a None (or clone failure) would raise before reaching the except, preventing the intended fallback. Consider adding qkv is not None to the fused-path condition and/or moving backup creation inside the try: with a guarded restore only when the backup exists.
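
The guarded variant described in this comment could look roughly like the sketch below. The function name and signature are hypothetical, and `list(qkv)` stands in for `qkv.clone()`.

```python
def rope_cache_guarded(qkv, fused_kernel, non_fused_path):
    """Hypothetical sketch: only attempt the fused path when qkv exists."""
    backup = None
    if qkv is not None:
        try:
            # The backup is created inside the try block, so a failure while
            # copying is handled by the same except branch as a kernel failure.
            backup = list(qkv)
            fused_kernel(qkv)
            return "fused"
        except Exception:
            # Restore only if the backup was actually created.
            if backup is not None:
                qkv[:] = backup
    non_fused_path(qkv)
    return "non_fused"
```

With `qkv=None` the fused kernel is never called and control goes straight to the non-fused path, which in this sketch is assumed to cope with a missing `qkv`.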

@sunway513
Collaborator Author

Related AITER PR: ROCm/aiter#2227 — Adds a pure-Triton implementation of fused_rope_rms() (QKNorm + RoPE) to AITER. Once merged, RotaryEmbeddingFusedQKNorm will have a working Triton path even in CK-free builds, complementing the fallback in this PR.
