Add CK-free fallback for fused QKNorm+RoPE+Cache #279
Conversation
Wrap `fused_qk_norm_rope_cache_quant_shuffle` in try-except so that CK-free builds gracefully fall through to the non-fused Triton path (`rotary_emb` + `q`/`k_norm` + `reshape_and_cache`) instead of crashing.

Key safety measures:
- `qkv.clone()` backup before the fused kernel call, restored on failure (protects against partial in-place writes before the exception)
- log-once warning via a class attribute to avoid log spam
- `q_norm is None` guard on the middle path preserves the original `elif` invariant
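The fallback pattern described above can be sketched as follows. This is a minimal standalone sketch, not the actual implementation: `fused_kernel` and `triton_fallback` are stand-ins for the real HIP kernel and the non-fused Triton path, and plain lists stand in for tensors.

```python
import logging

logger = logging.getLogger("atom")

def fused_kernel(qkv):
    # Stand-in for the HIP fused kernel; raises on CK-free builds.
    raise RuntimeError("CK kernel unavailable")

def triton_fallback(qkv):
    # Stand-in for the non-fused rotary_emb + q/k_norm + reshape_and_cache path.
    return [x + 1 for x in qkv]

class PagedAttentionImpl:
    # Class attribute so the warning is logged once across all instances.
    _warned_fused_failure = False

    def rope_cache(self, qkv):
        qkv_backup = list(qkv)  # backup before the (in-place) fused call
        try:
            return fused_kernel(qkv)
        except Exception:
            qkv[:] = qkv_backup  # undo any partial in-place writes
            if not PagedAttentionImpl._warned_fused_failure:
                logger.warning("fused kernel failed; using non-fused Triton path")
                PagedAttentionImpl._warned_fused_failure = True
            return triton_fallback(qkv)
```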
@gyohuangxin @ZhiweiYan-96 @valarLip — requesting review on this CK-free fallback.

Known limitation: the full E2E test is blocked by a separate AITER-side issue. The ASM attention kernels still depend on CK-Tile headers, so their JIT compilation fails in CK-free builds.

Unit tests (import + mock fallback with 5 assertions) pass cleanly.
Pull request overview
Adds a runtime fallback in the server-mode paged-attention rope_cache() path so CK-free (no HIP fused kernel) builds can gracefully use the existing non-fused Triton/py path instead of failing.
Changes:
- Add module logger and warning-once behavior for fused-kernel failures.
- Wrap `fused_qk_norm_rope_cache_quant_shuffle(...)` in try/except and fall back to the non-fused path on failure.
- Add a `q_norm is None` guard on the Triton fused path to preserve the intended branch invariants.
```python
logger = logging.getLogger("atom")

from atom.plugin.prepare import is_plugin_mode, is_vllm
from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode
```
`logger = ...` is executed before the subsequent `from atom.plugin...` imports, which will trigger Ruff E402 (module-level import not at top of file). Move the logger initialization below all imports (or move these imports above the logger assignment) so lint passes.
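The E402-clean layout the reviewer is asking for would look roughly like this (the `atom.plugin` imports are shown commented out so the sketch is self-contained):

```python
# All module-level imports come first, so Ruff E402 does not fire.
import logging
# from atom.plugin.prepare import is_plugin_mode, is_vllm
# from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode

# Logger initialization follows every import.
logger = logging.getLogger("atom")
```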
```python
qkv_backup = qkv.clone()
try:
    fused_qk_norm_rope_cache_quant_shuffle(
        qkv,
```
On CK-free builds where the fused kernel consistently fails, this will still clone() and then raise/catch an exception on every rope_cache() call, which is very expensive on the decode hot path. Cache the failure (e.g., a class/instance flag) and skip the fused attempt entirely after the first failure; similarly, after the first successful fused call you can avoid taking a backup on subsequent calls.
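The caching suggested here could be sketched as follows. This is an illustrative standalone sketch of the suggestion, not the PR's code: `FusedPathCache`, `fused_kernel`, and `fallback` are hypothetical names, and plain lists stand in for tensors.

```python
class FusedPathCache:
    # None = untried; True/False after the first fused attempt.
    _fused_available = None

    def rope_cache(self, qkv, fused_kernel, fallback):
        if FusedPathCache._fused_available is False:
            # Known-bad fused path: skip clone() and the exception entirely.
            return fallback(qkv)
        # Only take a backup while the fused path is still unproven.
        backup = list(qkv) if FusedPathCache._fused_available is None else None
        try:
            out = fused_kernel(qkv)
            FusedPathCache._fused_available = True
            return out
        except Exception:
            FusedPathCache._fused_available = False
            if backup is not None:
                qkv[:] = backup  # undo partial in-place writes
            return fallback(qkv)
```

After the first failure, every later decode step goes straight to the fallback with no `clone()` and no raised exception on the hot path.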
```python
):
    fused_qk_norm_rope_cache_quant_shuffle(
        qkv,
        num_heads_q=self.num_heads,
        num_heads_k=self.num_kv_heads,
        num_heads_v=self.num_kv_heads,
        head_dim=self.head_dim,
        eps=self.q_norm.eps,
        qw=self.q_norm.weight,
        kw=self.k_norm.weight,
        cos_sin_cache=self.rotary_emb.cos_sin_cache,
        is_neox_style=self.rotary_emb.is_neox_style,
        pos_ids=position,
        k_cache=k_cache,
        v_cache=v_cache,
        slot_mapping=attn_metadata.slot_mapping,
        kv_cache_dtype=(
            "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype
        ),
        k_scale=k_scale,
        v_scale=v_scale,
    )
qkv_backup = qkv.clone()
try:
```
`qkv_backup = qkv.clone()` is outside the `try:` and assumes `qkv` is always a Tensor. Since `qkv` is an optional argument, a `None` (or a `clone()` failure) would raise before reaching the `except`, preventing the intended fallback. Consider adding `qkv is not None` to the fused-path condition and/or moving backup creation inside the `try:` with a guarded restore only when the backup exists.
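The guarded version suggested here could look roughly like this. A standalone sketch with hypothetical names (`rope_cache_fused_guarded`, `fused_kernel`, `fallback`) and lists standing in for tensors:

```python
def rope_cache_fused_guarded(qkv, fused_kernel, fallback):
    # qkv is optional: a None goes straight to the fallback path.
    if qkv is None:
        return fallback(qkv)
    qkv_backup = None
    try:
        qkv_backup = list(qkv)  # clone() analogue, now inside the try
        return fused_kernel(qkv)
    except Exception:
        if qkv_backup is not None:
            qkv[:] = qkv_backup  # restore only when a backup exists
        return fallback(qkv)
```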
Related AITER PR: ROCm/aiter#2227 — adds a pure-Triton implementation of the fused kernel.
Summary
- Wrap `fused_qk_norm_rope_cache_quant_shuffle` in `attention_mha.py:rope_cache()` with try-except
- `qkv.clone()` backup before the fused kernel call, restored on failure (protects against partial in-place writes)
- `q_norm is None` guard on the middle path preserves the original `elif` invariant
Test Results
`from atom.model_ops.attention_mha import PagedAttentionImpl`

Known Limitation
E2E tests are blocked by a separate AITER-side issue:
`module_fmha_v3_varlen_fwd` JIT compilation fails in CK-free builds because the ASM attention kernels still depend on CK-Tile headers (`fmha_fwd.hpp`). Our `rope_cache` fallback works correctly through model load + warmup (logs confirm the individual `module_rope_pos_fwd` and `module_cache` modules loaded successfully).

Shengnan's team is working on removing the CK header dependency from the ASM attention kernels in AITER. Once that lands, the full CK-free E2E path (this PR + the FMHA fix) will be unblocked.