Add AITER fused RoPE dispatch to FusedRoPEFunc #489
Conversation
When the AITER library is available on ROCm, dispatch FusedRoPEFunc
forward and backward to aiter.ops.rope HIP kernels for higher
throughput on AMD GPUs.
Guarded by NVTE_USE_AITER_FUSED_ROPE env var (default "1"). Only
activates for sbhd format, non-interleaved, cp_size==1, and when
cu_seqlens/start_positions are absent. Falls back to TE-native
tex.fused_rope_{forward,backward} otherwise.
Made-with: Cursor
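The fallback behavior described above can be sketched as an env-gated optional import. Names mirror the PR (NVTE_USE_AITER_FUSED_ROPE, aiter.ops.rope, the rope_fwd argument order); native_rope is a hypothetical stand-in for TE's tex.fused_rope_forward, not the real kernel.

```python
import os
import warnings

_HAVE_AITER_ROPE = False
if os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1":
    try:
        # aiter is a ROCm-only package; the import fails elsewhere
        from aiter.ops.rope import rope_fwd as _aiter_rope_fwd
        _HAVE_AITER_ROPE = True
    except Exception as e:
        warnings.warn(
            f"AITER unavailable ({type(e).__name__}); using TE-native kernels."
        )

def native_rope(t, freqs):
    # Hypothetical stand-in for tex.fused_rope_forward
    return ("native", t, freqs)

def apply_rope(t, freqs):
    # Dispatch to the AITER kernel when the import succeeded, else fall back
    if _HAVE_AITER_ROPE:
        return _aiter_rope_fwd(t, freqs, 0, False, False)
    return native_rope(t, freqs)
```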
Pull request overview
This PR adds an optional dispatch path so FusedRoPEFunc forward/backward can call AITER’s fused HIP RoPE kernels on ROCm when available, controlled by NVTE_USE_AITER_FUSED_ROPE, and otherwise fall back to Transformer Engine’s native kernels.
Changes:
- Add an optional import of aiter.ops.rope and a feature flag (NVTE_USE_AITER_FUSED_ROPE) to enable/disable AITER dispatch.
- Introduce _can_use_aiter(...) gating logic and route FusedRoPEFunc fwd/bwd to AITER when the config is compatible.
- Save whether AITER was used on ctx to ensure backward follows the same path.
# Optional AITER fused RoPE dispatch (ROCm).
# Controlled by NVTE_USE_AITER_FUSED_ROPE (default "1" = enabled).
_aiter_rope_fwd = None
_aiter_rope_bwd = None
_HAVE_AITER_ROPE = False
_USE_AITER_FUSED_ROPE = os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1"

if _USE_AITER_FUSED_ROPE:
    try:
        from aiter.ops.rope import rope_fwd as _aiter_rope_fwd, rope_bwd as _aiter_rope_bwd
        _HAVE_AITER_ROPE = True
    except Exception as e:
        warnings.warn(
            f"AITER fused RoPE import failed ({type(e).__name__}: {e}). "
            "Falling back to TE native kernels. "
            "Set NVTE_USE_AITER_FUSED_ROPE=0 to suppress this warning."
        )
@staticmethod
def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions):
    """Return True when AITER fused rope kernels can handle this configuration."""
    return (
        _HAVE_AITER_ROPE
        and tensor_format == "sbhd"
        and not interleaved
        and cu_seqlens is None
        and cp_size == 1
        and start_positions is None
    )
use_aiter = FusedRoPEFunc._can_use_aiter(
    tensor_format, interleaved, cu_seqlens, cp_size, start_positions
)

if use_aiter:
    rotate_style = 1 if interleaved else 0
    output = _aiter_rope_fwd(
        t, freqs, rotate_style,
        False,  # reuse_freqs_front_part
        False,  # nope_first
    )
else:
    output = tex.fused_rope_forward(
        t,
        freqs,
        start_positions,
        QKVFormat[tensor_format],
        interleaved,
        cu_seqlens,
        cp_size,
        cp_rank,
    )
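Backward presumably mirrors this dispatch via the flag saved on ctx. A sketch with stubbed kernels follows; the real aiter rope_bwd and tex.fused_rope_backward signatures may differ, so the stubs only illustrate the branch structure.

```python
def _aiter_rope_bwd(grad, freqs, rotate_style, reuse_freqs_front_part, nope_first):
    # Stub for aiter.ops.rope.rope_bwd; argument order assumed to mirror rope_fwd
    return ("aiter", grad)

class tex:
    # Stand-in for Transformer Engine's C++ extension module
    @staticmethod
    def fused_rope_backward(grad, freqs):
        return ("native", grad)

def rope_backward(use_aiter, grad_output, freqs):
    # use_aiter was stored on ctx during forward; reuse it rather than re-gating
    if use_aiter:
        return _aiter_rope_bwd(grad_output, freqs, 0, False, False)
    return tex.fused_rope_backward(grad_output, freqs)
```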
wangye805
left a comment
In addition to my comments and Copilot's comments, notice that it is currently tricky to maintain standalone aiter and TE together: you would need to carefully pin the aiter commit to the same one used in TE's 3rdparty directory, and also pre-build libmha_fwd/bwd.so in standalone aiter.
Generally I don't think this change can be merged into our TE dev branch for now, not until our TE FA is decoupled from the aiter shared lib.
"""
Rotary Position Embedding implementation of different types along with helper functions
"""
import os
ROCm specific guards needed
_aiter_rope_fwd = None
_aiter_rope_bwd = None
_HAVE_AITER_ROPE = False
_USE_AITER_FUSED_ROPE = os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1"
Also need to mention this env and its usage in README
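A README entry for the flag could show the toggle roughly like this (exact wording is up to the authors; the AITER path is on by default on ROCm):

```shell
# Disable the AITER fused RoPE dispatch and force the TE-native
# tex.fused_rope_{forward,backward} kernels
export NVTE_USE_AITER_FUSED_ROPE=0
```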
"""

@staticmethod
def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions):
In TE, PyTorch function names usually do not start with an underscore
Summary
Route FusedRoPEFunc fwd/bwd to AITER HIP RoPE kernels on ROCm when available, controlled by NVTE_USE_AITER_FUSED_ROPE (default 1).
Made with Cursor