
Add AITER fused RoPE dispatch to FusedRoPEFunc#489

Open
sarthak-amd wants to merge 1 commit into dev from feature/aiter-fused-rope

Conversation

@sarthak-amd
Collaborator

Summary

  • Dispatch FusedRoPEFunc fwd/bwd to AITER HIP RoPE kernels on ROCm when available, controlled by NVTE_USE_AITER_FUSED_ROPE (default 1).
  • Falls back to TE-native kernels for unsupported configs (non-sbhd, interleaved, CP>1, cu_seqlens, start_positions) or when AITER is absent.


When the AITER library is available on ROCm, dispatch FusedRoPEFunc
forward and backward to aiter.ops.rope HIP kernels for higher
throughput on AMD GPUs.

Guarded by NVTE_USE_AITER_FUSED_ROPE env var (default "1"). Only
activates for sbhd format, non-interleaved, cp_size==1, and when
cu_seqlens/start_positions are absent. Falls back to TE-native
tex.fused_rope_{forward,backward} otherwise.
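
The gating described above can be sketched as a small predicate. This is a hypothetical standalone version (names mirror the PR, but the function here is illustrative): AITER is used only when the import succeeded and the call uses the sbhd layout with none of the unsupported features.

```python
import os

# Env-var gate as described in the PR: default "1" = enabled.
USE_AITER = os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1"

def can_use_aiter(have_aiter, tensor_format, interleaved,
                  cu_seqlens, cp_size, start_positions):
    """True when the AITER fused RoPE kernels can handle this configuration."""
    return (
        have_aiter                       # aiter.ops.rope imported successfully
        and tensor_format == "sbhd"      # only the sbhd layout is supported
        and not interleaved              # interleaved rotation unsupported
        and cu_seqlens is None           # no packed/variable-length sequences
        and cp_size == 1                 # no context parallelism
        and start_positions is None      # no per-sequence start offsets
    )
```

Any single unsupported feature flips the result to False, which routes the call to the TE-native kernels.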

Made-with: Cursor
Contributor

Copilot AI left a comment


Pull request overview

This PR adds an optional dispatch path so FusedRoPEFunc forward/backward can call AITER’s fused HIP RoPE kernels on ROCm when available, controlled by NVTE_USE_AITER_FUSED_ROPE, and otherwise fall back to Transformer Engine’s native kernels.

Changes:

  • Add optional import of aiter.ops.rope and a feature flag (NVTE_USE_AITER_FUSED_ROPE) to enable/disable AITER dispatch.
  • Introduce _can_use_aiter(...) gating logic and route FusedRoPEFunc fwd/bwd to AITER when the config is compatible.
  • Save whether AITER was used on ctx to ensure backward follows the same path.
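
The ctx round-trip in the last bullet can be sketched with a minimal autograd-style class (hypothetical, not the PR's actual code): the forward records which path it took on `ctx`, and the backward reads that flag instead of re-deriving the gate, since inputs like `cu_seqlens` are no longer available there.

```python
from types import SimpleNamespace

class RoPEDispatchSketch:
    """Minimal sketch: forward records the chosen path on ctx so that
    backward is guaranteed to follow the same kernel family."""

    @staticmethod
    def forward(ctx, use_aiter):
        ctx.use_aiter = use_aiter  # saved for backward
        return "aiter_fwd" if use_aiter else "native_fwd"

    @staticmethod
    def backward(ctx):
        # No re-checking of the config here: the forward's decision wins.
        return "aiter_bwd" if ctx.use_aiter else "native_bwd"

# Stand-in for torch.autograd's ctx object in this sketch.
ctx = SimpleNamespace()
out = RoPEDispatchSketch.forward(ctx, use_aiter=True)
grad = RoPEDispatchSketch.backward(ctx)
```

This keeps the forward and backward paths consistent even if the environment variable changes between the two calls.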


Comment on lines +16 to +32
# Optional AITER fused RoPE dispatch (ROCm).
# Controlled by NVTE_USE_AITER_FUSED_ROPE (default "1" = enabled).
_aiter_rope_fwd = None
_aiter_rope_bwd = None
_HAVE_AITER_ROPE = False
_USE_AITER_FUSED_ROPE = os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1"

if _USE_AITER_FUSED_ROPE:
    try:
        from aiter.ops.rope import rope_fwd as _aiter_rope_fwd, rope_bwd as _aiter_rope_bwd

        _HAVE_AITER_ROPE = True
    except Exception as e:
        warnings.warn(
            f"AITER fused RoPE import failed ({type(e).__name__}: {e}). "
            f"Falling back to TE native kernels. "
            f"Set NVTE_USE_AITER_FUSED_ROPE=0 to suppress this warning."
        )
Comment on lines +145 to +155
@staticmethod
def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions):
    """Return True when AITER fused rope kernels can handle this configuration."""
    return (
        _HAVE_AITER_ROPE
        and tensor_format == "sbhd"
        and not interleaved
        and cu_seqlens is None
        and cp_size == 1
        and start_positions is None
    )
Comment on lines +179 to +200
use_aiter = FusedRoPEFunc._can_use_aiter(
    tensor_format, interleaved, cu_seqlens, cp_size, start_positions
)

if use_aiter:
    rotate_style = 1 if interleaved else 0
    output = _aiter_rope_fwd(
        t, freqs, rotate_style,
        False,  # reuse_freqs_front_part
        False,  # nope_first
    )
else:
    output = tex.fused_rope_forward(
        t,
        freqs,
        start_positions,
        QKVFormat[tensor_format],
        interleaved,
        cu_seqlens,
        cp_size,
        cp_rank,
    )
@wangye805 (Collaborator) left a comment


In addition to my comments and Copilot's comments, note that it is currently tricky to maintain standalone AITER and TE together: you would need to carefully pin the AITER commit to match the one in TE's 3rdparty directory, and also pre-build libmha_fwd/bwd.so in the standalone AITER.

Generally, I don't think this change can be merged into our TE dev branch for now, until our TE FA is decoupled from the AITER shared library.

"""
Rotary Position Embedding implementation of different types along with helper functions
"""
import os
Collaborator

ROCm-specific guards are needed here.

_aiter_rope_fwd = None
_aiter_rope_bwd = None
_HAVE_AITER_ROPE = False
_USE_AITER_FUSED_ROPE = os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1"
Collaborator

This env var and its usage also need to be documented in the README.

"""

@staticmethod
def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions):
Collaborator

In TE's PyTorch code, function names usually do not start with an underscore.
