Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 79 additions & 19 deletions transformer_engine/pytorch/attention/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,32 @@
"""
Rotary Position Embedding implementation of different types along with helper functions
"""
import os
import warnings
from typing import Optional, Tuple, Union, List

import torch

import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat

# Optional AITER fused RoPE dispatch (ROCm).
# Controlled by NVTE_USE_AITER_FUSED_ROPE (default "1" = enabled); set it to
# "0" to force the TE-native fused kernels even when AITER is installed.
# The sentinels below stay at their defaults when the import is disabled or
# fails, so every call site can safely consult _HAVE_AITER_ROPE first.
_aiter_rope_fwd = None
_aiter_rope_bwd = None
_HAVE_AITER_ROPE = False
_USE_AITER_FUSED_ROPE = os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1"

if _USE_AITER_FUSED_ROPE:
    try:
        from aiter.ops.rope import rope_fwd as _aiter_rope_fwd, rope_bwd as _aiter_rope_bwd

        _HAVE_AITER_ROPE = True
    except Exception as e:  # broad on purpose: any import failure means "fall back"
        warnings.warn(
            f"AITER fused RoPE import failed ({type(e).__name__}: {e}). "
            f"Falling back to TE native kernels. "
            f"Set NVTE_USE_AITER_FUSED_ROPE=0 to suppress this warning."
        )

__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]

Expand Down Expand Up @@ -116,8 +136,24 @@ class FusedRoPEFunc(torch.autograd.Function):
This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.

When AITER is available and the configuration is compatible (sbhd, non-interleaved,
cp_size==1, no cu_seqlens/start_positions), the forward and backward are dispatched
to AITER's fused HIP kernels for higher throughput on AMD GPUs.
"""

@staticmethod
def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In TE, function name in pytorch does not start with _ usually

"""Return True when AITER fused rope kernels can handle this configuration."""
return (
_HAVE_AITER_ROPE
and tensor_format == "sbhd"
and not interleaved
and cu_seqlens is None
and cp_size == 1
and start_positions is None
)
Comment on lines +145 to +155

@staticmethod
def forward(
ctx,
Expand All @@ -139,38 +175,62 @@ def forward(
"bshd",
"thd",
), f"Unsupported tensor_format: {tensor_format}."
output = tex.fused_rope_forward(
t,
freqs,
start_positions,
QKVFormat[tensor_format],
interleaved,
cu_seqlens,
cp_size,
cp_rank,

use_aiter = FusedRoPEFunc._can_use_aiter(
tensor_format, interleaved, cu_seqlens, cp_size, start_positions
)

if use_aiter:
rotate_style = 1 if interleaved else 0
output = _aiter_rope_fwd(
t, freqs, rotate_style,
False, # reuse_freqs_front_part
False, # nope_first
)
else:
output = tex.fused_rope_forward(
t,
freqs,
start_positions,
QKVFormat[tensor_format],
interleaved,
cu_seqlens,
cp_size,
cp_rank,
)
Comment on lines +179 to +200

ctx.save_for_backward(freqs, cu_seqlens, start_positions)
ctx.tensor_format = tensor_format
ctx.cp_size = cp_size
ctx.cp_rank = cp_rank
ctx.interleaved = interleaved
ctx.use_aiter = use_aiter

return output

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused RoPE backward."""
freqs, cu_seqlens, start_positions = ctx.saved_tensors
grad_input = tex.fused_rope_backward(
grad_output,
freqs,
start_positions,
QKVFormat[ctx.tensor_format],
ctx.interleaved,
cu_seqlens,
ctx.cp_size,
ctx.cp_rank,
)

if ctx.use_aiter:
rotate_style = 1 if ctx.interleaved else 0
grad_input = _aiter_rope_bwd(
grad_output, freqs, rotate_style,
False, # reuse_freqs_front_part
False, # nope_first
)
else:
grad_input = tex.fused_rope_backward(
grad_output,
freqs,
start_positions,
QKVFormat[ctx.tensor_format],
ctx.interleaved,
cu_seqlens,
ctx.cp_size,
ctx.cp_rank,
)

return grad_input, None, None, None, None, None, None, None, None

Expand Down
Loading