diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py
index 613ecb1a7..b9102f22f 100644
--- a/transformer_engine/pytorch/attention/rope.py
+++ b/transformer_engine/pytorch/attention/rope.py
@@ -5,12 +5,32 @@
 """
 Rotary Position Embedding implementation of different types along with helper functions
 """
+import os
+import warnings
 from typing import Optional, Tuple, Union, List
 
 import torch
 
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
 
+# Optional AITER fused RoPE dispatch (ROCm).
+# Controlled by NVTE_USE_AITER_FUSED_ROPE (default "1" = enabled).
+_aiter_rope_fwd = None
+_aiter_rope_bwd = None
+_HAVE_AITER_ROPE = False
+_USE_AITER_FUSED_ROPE = os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1"
+
+if _USE_AITER_FUSED_ROPE:
+    try:
+        from aiter.ops.rope import rope_fwd as _aiter_rope_fwd, rope_bwd as _aiter_rope_bwd
+        _HAVE_AITER_ROPE = True
+    except Exception as e:
+        warnings.warn(
+            f"AITER fused RoPE import failed ({type(e).__name__}: {e}). "
+            f"Falling back to TE native kernels. "
+            f"Set NVTE_USE_AITER_FUSED_ROPE=0 to suppress this warning."
+        )
+
 __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]
 
@@ -116,8 +136,24 @@ class FusedRoPEFunc(torch.autograd.Function):
     This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
     the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
     the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
+
+    When AITER is available and the configuration is compatible (sbhd, non-interleaved,
+    cp_size==1, no cu_seqlens/start_positions), the forward and backward are dispatched
+    to AITER's fused HIP kernels for higher throughput on AMD GPUs.
     """
 
+    @staticmethod
+    def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions):
+        """Return True when AITER fused rope kernels can handle this configuration."""
+        return (
+            _HAVE_AITER_ROPE
+            and tensor_format == "sbhd"
+            and not interleaved
+            and cu_seqlens is None
+            and cp_size == 1
+            and start_positions is None
+        )
+
     @staticmethod
     def forward(
         ctx,
@@ -139,21 +175,36 @@ def forward(
             "bshd",
             "thd",
         ), f"Unsupported tensor_format: {tensor_format}."
-        output = tex.fused_rope_forward(
-            t,
-            freqs,
-            start_positions,
-            QKVFormat[tensor_format],
-            interleaved,
-            cu_seqlens,
-            cp_size,
-            cp_rank,
+
+        use_aiter = FusedRoPEFunc._can_use_aiter(
+            tensor_format, interleaved, cu_seqlens, cp_size, start_positions
         )
+
+        if use_aiter:
+            rotate_style = 1 if interleaved else 0
+            output = _aiter_rope_fwd(
+                t, freqs, rotate_style,
+                False,  # reuse_freqs_front_part
+                False,  # nope_first
+            )
+        else:
+            output = tex.fused_rope_forward(
+                t,
+                freqs,
+                start_positions,
+                QKVFormat[tensor_format],
+                interleaved,
+                cu_seqlens,
+                cp_size,
+                cp_rank,
+            )
+
         ctx.save_for_backward(freqs, cu_seqlens, start_positions)
         ctx.tensor_format = tensor_format
         ctx.cp_size = cp_size
         ctx.cp_rank = cp_rank
         ctx.interleaved = interleaved
+        ctx.use_aiter = use_aiter
 
         return output
 
@@ -161,16 +212,25 @@ def forward(
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         """Fused RoPE backward."""
         freqs, cu_seqlens, start_positions = ctx.saved_tensors
-        grad_input = tex.fused_rope_backward(
-            grad_output,
-            freqs,
-            start_positions,
-            QKVFormat[ctx.tensor_format],
-            ctx.interleaved,
-            cu_seqlens,
-            ctx.cp_size,
-            ctx.cp_rank,
-        )
+
+        if ctx.use_aiter:
+            rotate_style = 1 if ctx.interleaved else 0
+            grad_input = _aiter_rope_bwd(
+                grad_output, freqs, rotate_style,
+                False,  # reuse_freqs_front_part
+                False,  # nope_first
+            )
+        else:
+            grad_input = tex.fused_rope_backward(
+                grad_output,
+                freqs,
+                start_positions,
+                QKVFormat[ctx.tensor_format],
+                ctx.interleaved,
+                cu_seqlens,
+                ctx.cp_size,
+                ctx.cp_rank,
+            )
 
         return grad_input, None, None, None, None, None, None, None, None