Skip to content

IFU 2.12 JAX merge conflict resolution#493

Closed
Micky774 wants to merge 4 commits intoyewang12/IFU-dev-v2.12-99df88-20260116from
zain/ifu-2.12/jax
Closed

IFU 2.12 JAX merge conflict resolution#493
Micky774 wants to merge 4 commits intoyewang12/IFU-dev-v2.12-99df88-20260116from
zain/ifu-2.12/jax

Conversation

@Micky774
Copy link
Contributor

@Micky774 Micky774 commented Mar 17, 2026

Description

This PR focuses on the resolutions to merge conflicts located in the PyTorch integration, as introduced as part of the main IFU branch. This merge resolution is performed in separate PRs to limit the scope of each PR and simplify the review process, while preserving commit history and conversation readability.

This is one of three PRs which target the common components of TE, as well as the JAX and PyTorch integrations. Once all three PRs are resolved and merged into the IFU branch, a new IFU PR will be created with a focus on ensuring a successful total build as well as CI validation. In the meanwhile, the PRs will be partially validated via local testing.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

The changelog is generated with assistance from Claude.

  1. tests/jax/test_fused_attn.py — four conflict resolutions:
    - Kept ROCm's new-style RNG key handling (jax.random.key vs PRNGKey) while adopting upstream's 6-way split (added softmax_key)
    - Kept ROCm's RNG partition spec logic while integrating upstream's new softmax offset sharding
    - Fixed test param ID for cross-attention from seqlen 512 to 2048 while keeping upstream's RAGGED_KV_PACKED layout addition
    - Added pytest.skip for test_context_parallel_allgather_striped_attn on ROCm — THD + ALL_GATHER + Striped attention is not yet supported
  2. tests/jax/utils.py — resolved conflict keeping both ROCm's MXFP8 GEMM support-checking functions and upstream's new is_devices_equal utility
  3. jax/csrc/extensions/attention.cpp — fixed platform guard structure for fused attention aux tensors: ROCm enters the arbitrary-seqlen code path unconditionally (both aotriton and CK backends need it), while CUDA only enters for
    NVTE_F16_arbitrary_seqlen. Also integrated upstream's new softmax_offset aux tensor with proper #ifndef USE_ROCM guard
  4. jax/csrc/extensions/gemm.cpp — resolved conflict merging ROCm's NVFP4 guard ("ROCm TE does not support NVFP4 yet") with upstream's new set_with_gemm_swizzled_scales API. Adopted upstream's code structure while preserving the ROCm NVFP4 error

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Micky774 Micky774 mentioned this pull request Mar 18, 2026
13 tasks
@Micky774
Copy link
Contributor Author

Closing as completed -- no merge needed

@Micky774 Micky774 closed this Mar 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants