From eca6749b9f112e5045cbc7d400f8e37f13cfc62f Mon Sep 17 00:00:00 2001
From: sararora
Date: Mon, 16 Mar 2026 23:11:34 -0500
Subject: [PATCH] Add AITER fused RoPE dispatch to FusedRoPEFunc

When the AITER library is available on ROCm, dispatch FusedRoPEFunc
forward and backward to aiter.ops.rope HIP kernels for higher throughput
on AMD GPUs.

Guarded by NVTE_USE_AITER_FUSED_ROPE env var (default "1"). Only
activates for sbhd format, non-interleaved, cp_size==1, and when
cu_seqlens/start_positions are absent. Falls back to TE-native
tex.fused_rope_{forward,backward} otherwise.

AITER is an optional dependency: when the aiter package is simply not
installed the fallback is silent; a warning is emitted only when AITER
is present but fails to import.

Made-with: Cursor
---
 transformer_engine/pytorch/attention/rope.py | 103 ++++++++++++++++----
 1 file changed, 84 insertions(+), 19 deletions(-)

diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py
index 613ecb1a7..b9102f22f 100644
--- a/transformer_engine/pytorch/attention/rope.py
+++ b/transformer_engine/pytorch/attention/rope.py
@@ -5,12 +5,36 @@
 """
 Rotary Position Embedding implementation of different types along with helper functions
 """
+import os
+import warnings
 from typing import Optional, Tuple, Union, List
 import torch
 
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
 
+# Optional AITER fused RoPE dispatch (ROCm).
+# Controlled by NVTE_USE_AITER_FUSED_ROPE (default "1" = enabled).
+_aiter_rope_fwd = None
+_aiter_rope_bwd = None
+_HAVE_AITER_ROPE = False
+_USE_AITER_FUSED_ROPE = os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1"
+
+if _USE_AITER_FUSED_ROPE:
+    try:
+        from aiter.ops.rope import rope_fwd as _aiter_rope_fwd, rope_bwd as _aiter_rope_bwd
+        _HAVE_AITER_ROPE = True
+    except Exception as e:
+        # AITER is an optional dependency: a missing top-level `aiter` package is
+        # the expected state on installs without it, so fall back silently there
+        # and warn only when AITER is present but fails to import.
+        if not (isinstance(e, ModuleNotFoundError) and e.name == "aiter"):
+            warnings.warn(
+                f"AITER fused RoPE import failed ({type(e).__name__}: {e}). "
+                f"Falling back to TE native kernels. "
+                f"Set NVTE_USE_AITER_FUSED_ROPE=0 to suppress this warning."
+            )
+
 
 __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]
 
@@ -116,8 +140,24 @@ class FusedRoPEFunc(torch.autograd.Function):
     This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
     the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
     the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
+
+    When AITER is available and the configuration is compatible (sbhd, non-interleaved,
+    cp_size==1, no cu_seqlens/start_positions), the forward and backward are dispatched
+    to AITER's fused HIP kernels for higher throughput on AMD GPUs.
     """
 
+    @staticmethod
+    def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions):
+        """Return True when AITER fused rope kernels can handle this configuration."""
+        return (
+            _HAVE_AITER_ROPE
+            and tensor_format == "sbhd"
+            and not interleaved
+            and cu_seqlens is None
+            and cp_size == 1
+            and start_positions is None
+        )
+
     @staticmethod
     def forward(
         ctx,
@@ -139,21 +179,37 @@ def forward(
             "bshd",
             "thd",
         ), f"Unsupported tensor_format: {tensor_format}."
-        output = tex.fused_rope_forward(
-            t,
-            freqs,
-            start_positions,
-            QKVFormat[tensor_format],
-            interleaved,
-            cu_seqlens,
-            cp_size,
-            cp_rank,
+
+        use_aiter = FusedRoPEFunc._can_use_aiter(
+            tensor_format, interleaved, cu_seqlens, cp_size, start_positions
         )
+
+        if use_aiter:
+            rotate_style = 1 if interleaved else 0
+            output = _aiter_rope_fwd(
+                t, freqs, rotate_style,
+                False,  # reuse_freqs_front_part
+                False,  # nope_first
+            )
+        else:
+            output = tex.fused_rope_forward(
+                t,
+                freqs,
+                start_positions,
+                QKVFormat[tensor_format],
+                interleaved,
+                cu_seqlens,
+                cp_size,
+                cp_rank,
+            )
+
         ctx.save_for_backward(freqs, cu_seqlens, start_positions)
         ctx.tensor_format = tensor_format
         ctx.cp_size = cp_size
         ctx.cp_rank = cp_rank
         ctx.interleaved = interleaved
+        # Remember the backend so backward dispatches to the matching kernel.
+        ctx.use_aiter = use_aiter
 
         return output
 
@@ -161,15 +217,24 @@ def forward(
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         """Fused RoPE backward."""
         freqs, cu_seqlens, start_positions = ctx.saved_tensors
-        grad_input = tex.fused_rope_backward(
-            grad_output,
-            freqs,
-            start_positions,
-            QKVFormat[ctx.tensor_format],
-            ctx.interleaved,
-            cu_seqlens,
-            ctx.cp_size,
-            ctx.cp_rank,
-        )
+
+        if ctx.use_aiter:
+            rotate_style = 1 if ctx.interleaved else 0
+            grad_input = _aiter_rope_bwd(
+                grad_output, freqs, rotate_style,
+                False,  # reuse_freqs_front_part
+                False,  # nope_first
+            )
+        else:
+            grad_input = tex.fused_rope_backward(
+                grad_output,
+                freqs,
+                start_positions,
+                QKVFormat[ctx.tensor_format],
+                ctx.interleaved,
+                cu_seqlens,
+                ctx.cp_size,
+                ctx.cp_rank,
+            )
 
         return grad_input, None, None, None, None, None, None, None, None