-
Notifications
You must be signed in to change notification settings - Fork 24
Add AITER fused RoPE dispatch to FusedRoPEFunc #489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sarthak-amd
wants to merge
1
commit into
dev
Choose a base branch
from
feature/aiter-fused-rope
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+79
−19
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,12 +5,32 @@ | |
| """ | ||
| Rotary Position Embedding implementation of different types along with helper functions | ||
| """ | ||
| import os | ||
| import warnings | ||
| from typing import Optional, Tuple, Union, List | ||
| import torch | ||
|
|
||
| import transformer_engine_torch as tex | ||
| from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat | ||
|
|
||
| # Optional AITER fused RoPE dispatch (ROCm). | ||
| # Controlled by NVTE_USE_AITER_FUSED_ROPE (default "1" = enabled). | ||
| _aiter_rope_fwd = None | ||
| _aiter_rope_bwd = None | ||
| _HAVE_AITER_ROPE = False | ||
| _USE_AITER_FUSED_ROPE = os.environ.get("NVTE_USE_AITER_FUSED_ROPE", "1") == "1" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also need to mention this env and its usage in README |
||
|
|
||
| if _USE_AITER_FUSED_ROPE: | ||
| try: | ||
| from aiter.ops.rope import rope_fwd as _aiter_rope_fwd, rope_bwd as _aiter_rope_bwd | ||
| _HAVE_AITER_ROPE = True | ||
| except Exception as e: | ||
| warnings.warn( | ||
| f"AITER fused RoPE import failed ({type(e).__name__}: {e}). " | ||
| f"Falling back to TE native kernels. " | ||
| f"Set NVTE_USE_AITER_FUSED_ROPE=0 to suppress this warning." | ||
|
|
||
| ) | ||
|
Comment on lines
+16
to
+32
|
||
|
|
||
|
|
||
| __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] | ||
|
|
||
|
|
@@ -116,8 +136,24 @@ class FusedRoPEFunc(torch.autograd.Function): | |
| This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and | ||
| the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid | ||
| the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern. | ||
|
|
||
| When AITER is available and the configuration is compatible (sbhd, non-interleaved, | ||
| cp_size==1, no cu_seqlens/start_positions), the forward and backward are dispatched | ||
| to AITER's fused HIP kernels for higher throughput on AMD GPUs. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def _can_use_aiter(tensor_format, interleaved, cu_seqlens, cp_size, start_positions): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In TE, function name in pytorch does not start with _ usually |
||
| """Return True when AITER fused rope kernels can handle this configuration.""" | ||
| return ( | ||
| _HAVE_AITER_ROPE | ||
| and tensor_format == "sbhd" | ||
| and not interleaved | ||
| and cu_seqlens is None | ||
| and cp_size == 1 | ||
| and start_positions is None | ||
| ) | ||
|
Comment on lines
+145
to
+155
|
||
|
|
||
| @staticmethod | ||
| def forward( | ||
| ctx, | ||
|
|
@@ -139,38 +175,62 @@ def forward( | |
| "bshd", | ||
| "thd", | ||
| ), f"Unsupported tensor_format: {tensor_format}." | ||
| output = tex.fused_rope_forward( | ||
| t, | ||
| freqs, | ||
| start_positions, | ||
| QKVFormat[tensor_format], | ||
| interleaved, | ||
| cu_seqlens, | ||
| cp_size, | ||
| cp_rank, | ||
|
|
||
| use_aiter = FusedRoPEFunc._can_use_aiter( | ||
| tensor_format, interleaved, cu_seqlens, cp_size, start_positions | ||
| ) | ||
|
|
||
| if use_aiter: | ||
| rotate_style = 1 if interleaved else 0 | ||
| output = _aiter_rope_fwd( | ||
| t, freqs, rotate_style, | ||
| False, # reuse_freqs_front_part | ||
| False, # nope_first | ||
| ) | ||
| else: | ||
| output = tex.fused_rope_forward( | ||
| t, | ||
| freqs, | ||
| start_positions, | ||
| QKVFormat[tensor_format], | ||
| interleaved, | ||
| cu_seqlens, | ||
| cp_size, | ||
| cp_rank, | ||
| ) | ||
|
Comment on lines
+179
to
+200
|
||
|
|
||
| ctx.save_for_backward(freqs, cu_seqlens, start_positions) | ||
| ctx.tensor_format = tensor_format | ||
| ctx.cp_size = cp_size | ||
| ctx.cp_rank = cp_rank | ||
| ctx.interleaved = interleaved | ||
| ctx.use_aiter = use_aiter | ||
|
|
||
| return output | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: | ||
| """Fused RoPE backward.""" | ||
| freqs, cu_seqlens, start_positions = ctx.saved_tensors | ||
| grad_input = tex.fused_rope_backward( | ||
| grad_output, | ||
| freqs, | ||
| start_positions, | ||
| QKVFormat[ctx.tensor_format], | ||
| ctx.interleaved, | ||
| cu_seqlens, | ||
| ctx.cp_size, | ||
| ctx.cp_rank, | ||
| ) | ||
|
|
||
| if ctx.use_aiter: | ||
| rotate_style = 1 if ctx.interleaved else 0 | ||
| grad_input = _aiter_rope_bwd( | ||
| grad_output, freqs, rotate_style, | ||
| False, # reuse_freqs_front_part | ||
| False, # nope_first | ||
| ) | ||
| else: | ||
| grad_input = tex.fused_rope_backward( | ||
| grad_output, | ||
| freqs, | ||
| start_positions, | ||
| QKVFormat[ctx.tensor_format], | ||
| ctx.interleaved, | ||
| cu_seqlens, | ||
| ctx.cp_size, | ||
| ctx.cp_rank, | ||
| ) | ||
|
|
||
| return grad_input, None, None, None, None, None, None, None, None | ||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ROCm specific guards needed