IFU 2.12 JAX merge conflict resolution #493
Closed
Micky774 wants to merge 4 commits into yewang12/IFU-dev-v2.12-99df88-20260116 from
Conversation
ipanfilo
requested changes
Mar 18, 2026
ipanfilo
approved these changes
Mar 19, 2026
Contributor
Author
Closing as completed -- no merge needed
Description
This PR resolves the merge conflicts in the JAX integration that were introduced by the main IFU branch. The merge resolution is split across separate PRs to limit the scope of each PR and simplify review, while preserving commit history and conversation readability.
This is one of three PRs, which target the common components of TE and the JAX and PyTorch integrations respectively. Once all three PRs are resolved and merged into the IFU branch, a new IFU PR will be created to ensure a successful full build and CI validation. In the meantime, the PRs will be partially validated via local testing.
Type of change
Changes
Please list the changes introduced in this PR:
The changelog is generated with assistance from Claude.
- `tests/jax/test_fused_attn.py`: four conflict resolutions:
  - Kept ROCm's new-style RNG key handling (`jax.random.key` vs `PRNGKey`) while adopting upstream's 6-way split (added `softmax_key`)
  - Kept ROCm's RNG partition spec logic while integrating upstream's new softmax offset sharding
  - Fixed test param ID for cross-attention from seqlen 512 to 2048 while keeping upstream's `RAGGED_KV_PACKED` layout addition
  - Added `pytest.skip` for `test_context_parallel_allgather_striped_attn` on ROCm (THD + ALL_GATHER + Striped attention is not yet supported)
- `tests/jax/utils.py`: resolved conflict keeping both ROCm's MXFP8 GEMM support-checking functions and upstream's new `is_devices_equal` utility
- `jax/csrc/extensions/attention.cpp`: fixed platform guard structure for fused attention aux tensors: ROCm enters the arbitrary-seqlen code path unconditionally (both aotriton and CK backends need it), while CUDA only enters for `NVTE_F16_arbitrary_seqlen`. Also integrated upstream's new softmax_offset aux tensor with proper `#ifndef USE_ROCM` guard
- `jax/csrc/extensions/gemm.cpp`: resolved conflict merging ROCm's NVFP4 guard ("ROCm TE does not support NVFP4 yet") with upstream's new `set_with_gemm_swizzled_scales` API. Adopted upstream's code structure while preserving the ROCm NVFP4 error
Checklist:
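For reviewers, the platform guard described for `jax/csrc/extensions/attention.cpp` under Changes can be modeled as a small truth table. This is only a Python sketch of the decision logic as worded in the changelog, not the C++ code itself: `Backend`, `aux_tensor_plan`, and the dictionary keys are hypothetical names introduced for illustration, and the sketch assumes the `#ifndef USE_ROCM` guard simply means the softmax_offset handling is compiled in on CUDA builds only.

```python
from enum import Enum, auto


class Backend(Enum):
    # Hypothetical stand-ins for the fused attention backend enum.
    # Only NVTE_F16_arbitrary_seqlen is named in this PR description.
    F16_ARBITRARY_SEQLEN = auto()
    OTHER = auto()


def aux_tensor_plan(use_rocm: bool, backend: Backend) -> dict:
    """Model of the guard structure described above (illustrative only)."""
    # ROCm takes the arbitrary-seqlen path unconditionally, because both the
    # aotriton and CK backends need it. CUDA takes it only for the
    # F16 arbitrary-seqlen backend.
    arbitrary_seqlen_path = use_rocm or backend is Backend.F16_ARBITRARY_SEQLEN

    # Assumption: `#ifndef USE_ROCM` means the softmax_offset aux tensor
    # handling exists only in non-ROCm builds.
    softmax_offset_compiled_in = not use_rocm

    return {
        "arbitrary_seqlen_path": arbitrary_seqlen_path,
        "softmax_offset_compiled_in": softmax_offset_compiled_in,
    }


if __name__ == "__main__":
    for use_rocm in (True, False):
        for backend in Backend:
            print(use_rocm, backend.name, aux_tensor_plan(use_rocm, backend))
```

Enumerating all four combinations this way makes it easy to check that the resolved guard does not regress the CUDA-only condition while still covering both ROCm backends.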