IFU 2.12 PyTorch merge conflict resolution #495
Open
Micky774 wants to merge 3 commits into yewang12/IFU-dev-v2.12-99df88-20260116
Conversation
ipanfilo requested changes on Mar 18, 2026
Description
This PR resolves the merge conflicts in the PyTorch integration introduced as part of the main IFU branch. The merge resolution is split across separate PRs to limit the scope of each PR and simplify review, while preserving commit history and conversation readability.

This is one of three PRs targeting the common components of TE as well as the JAX and PyTorch integrations. Once all three PRs are resolved and merged into the IFU branch, a new IFU PR will be created to ensure a successful full build and CI validation. In the meantime, the PRs will be partially validated via local testing.
Type of change
Changes
Please list the changes introduced in this PR:
The changelog is generated with assistance from Claude.
The large negative delta (-590 lines) is because upstream refactored autograd `forward()` signatures to bundle arguments into a `non_tensor_args` tuple, removing the old per-parameter style that ROCm had conflict markers against.

C++ / CSRC changes:

- `DeviceGuard` pattern — added a `HIPGuardMasqueradingAsCUDA` conditional typedef to `attention.cpp`, `gemm.cpp`, and `normalization.cpp`, replacing `at::cuda::CUDAGuard` with the platform-aware `DeviceGuard`. This is necessary because the standard `at::cuda::CUDAGuard` cannot be properly hipified to the masquerading HIP device guard.
- `cast.cpp` — guarded NVFP4-specific enum values (`BULK_NVFP4`, `FUSED_NVFP4`) and code paths with `#ifndef USE_ROCM`; moved the `#endif` for the bulk allocation function to the correct location.
- `quantizer.cpp` — resolved scale-shape computation conflicts: adopted upstream's `ceildiv()` helper while preserving ROCm's guard to skip `roundup(..., 4)` padding (ROCm doesn't need 4-aligned scale dimensions).
- `common.cpp` / `common.h` — kept both ROCm's atomic amax functions and upstream's new `ceildiv` utility; resolved the `kFloat8E8M0` dtype mapping conflict, keeping ROCm's FNUZ variant.
- `extensions.h` / `util.h` — moved `#ifndef USE_ROCM` guards to the correct positions, keeping the `inplace_swizzle_scale_for_gemm` declaration visible while guarding NVSHMEM APIs.

Python module changes:

- `gemm.py` — added ROCm workspace sizing to `get_cublas_workspace_size_bytes()` (64 MiB for gfx950, 32 MiB for other ROCm GPUs) to avoid a hipBLASLt error.
- `grouped_linear.py`, `layernorm_linear.py`, `layernorm_mlp.py`, `linear.py` — adopted upstream's `non_tensor_args` tuple pattern for autograd forward, integrating ROCm-specific fields (`m_splits_tensor`, `use_grouped_gemm_triton`, `keep_fp8_weight_transpose_cache`, `use_fsdp2`) into the tuple. Removed the old per-parameter signatures and their corresponding `None` returns in backward.
- `base.py` — kept the old `get_cublas_workspace_size_bytes()` (still used by legacy code paths) while integrating upstream's `NVFP4TensorStorage` import and `get_nvtx_range_context` utility.
- `jit.py`, `distributed.py` — resolved trivial conflicts (import merges, trailing whitespace).
- `float8_blockwise_tensor.py` — resolved conflicts in scale-shape computation, following the same pattern as `quantizer.cpp` (skip roundup padding on ROCm).
- Tests — `test_attention.py` skips GQA with packed QKV layouts; `run_attention_with_cp.py` uses `IS_HIP_EXTENSION` to branch `cu_seqlens` handling; `test_numerics.py` keeps both the ROCm wgrad skip and upstream's debug-mode skip.

Checklist:
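For reviewers unfamiliar with the `non_tensor_args` refactor mentioned above, here is a minimal, torch-free sketch of the pattern. All names (`NonTensorArgs`, `LinearFunc`, the two fields) are illustrative stand-ins, not the actual TE symbols; the real modules subclass `torch.autograd.Function` and carry many more fields.

```python
from collections import namedtuple
from types import SimpleNamespace

# Hypothetical bundle of non-tensor configuration. Upstream's refactor passes
# one tuple like this instead of a long list of per-parameter flags.
NonTensorArgs = namedtuple("NonTensorArgs", ["use_fsdp2", "use_grouped_gemm_triton"])

class LinearFunc:  # stand-in for a torch.autograd.Function subclass
    @staticmethod
    def forward(ctx, inp, weight, non_tensor_args):
        # Non-tensor config rides along in a single tuple argument.
        ctx.saved = (inp, weight)
        ctx.use_fsdp2 = non_tensor_args.use_fsdp2
        return inp * weight

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved
        # A single None matches the single non_tensor_args slot, replacing the
        # old style of returning one None per non-tensor parameter.
        return grad_out * weight, grad_out * inp, None

ctx = SimpleNamespace()
out = LinearFunc.forward(ctx, 3.0, 2.0,
                         NonTensorArgs(use_fsdp2=False, use_grouped_gemm_triton=False))
grads = LinearFunc.backward(ctx, 1.0)
```

This is why the backward `None` returns shrank along with the forward signatures: conflict markers against the old per-parameter style simply disappear once everything funnels through one tuple.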
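The scale-shape divergence resolved in `quantizer.cpp` and `float8_blockwise_tensor.py` can be sketched as follows. This is a hedged illustration only: `ceildiv` and `roundup` are named in the changelog, but `scale_shape`, the block size of 128, and the exact padding rule are assumptions for the example.

```python
def ceildiv(a: int, b: int) -> int:
    # Ceiling division, as in upstream's helper.
    return -(-a // b)

def roundup(a: int, b: int) -> int:
    # Round a up to the next multiple of b.
    return ceildiv(a, b) * b

def scale_shape(rows: int, cols: int, block: int = 128, use_rocm: bool = False):
    # One scale per (block x block) tile; block=128 is an assumed example value.
    m = ceildiv(rows, block)
    n = ceildiv(cols, block)
    if not use_rocm:
        # CUDA path pads scale dimensions to multiples of 4; ROCm skips this,
        # which is the guard the resolution preserves.
        m = roundup(m, 4)
        n = roundup(n, 4)
    return m, n
```

For a 300x300 tensor this yields (4, 4) on CUDA but (3, 3) on ROCm, which is exactly the kind of shape mismatch the preserved guard avoids.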
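The ROCm workspace sizing added to `gemm.py` amounts to an architecture-keyed lookup; a minimal sketch, assuming a string arch name is available (the real function queries the device at runtime, and the CUDA-branch value here is a placeholder, not the actual upstream default):

```python
def get_cublas_workspace_size_bytes(arch: str = "", is_rocm: bool = False) -> int:
    """Return the GEMM workspace size in bytes.

    On ROCm, gfx950 gets 64 MiB and other GPUs 32 MiB, matching the sizing
    described in the changelog to avoid a hipBLASLt error.
    """
    if is_rocm:
        return (64 if "gfx950" in arch else 32) * 1024 * 1024
    # Placeholder CUDA-path value for illustration only.
    return 32 * 1024 * 1024
```

Undersized workspaces can make hipBLASLt fail to find a valid algorithm, which is why the gfx950 case is bumped to 64 MiB.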