
IFU 2.12 PyTorch merge conflict resolution #495

Open
Micky774 wants to merge 3 commits into yewang12/IFU-dev-v2.12-99df88-20260116 from zain/ifu-2.12/pytorch

Conversation

@Micky774
Contributor

Description

This PR resolves the merge conflicts in the PyTorch integration introduced by the main IFU branch. The resolution is split into separate PRs to limit the scope of each PR and simplify the review process, while preserving commit history and keeping conversations readable.

This is one of three PRs targeting the common components of TE as well as the JAX and PyTorch integrations. Once all three PRs are resolved and merged into the IFU branch, a new IFU PR will be created, focused on ensuring a successful total build as well as CI validation. In the meantime, the PRs will be partially validated via local testing.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

The changelog is generated with assistance from Claude.

The large negative delta (-590 lines) comes from upstream's refactor of the autograd forward() signatures, which bundles arguments into a non_tensor_args tuple and removes the old per-parameter style that ROCm had conflict markers against.

C++ / CSRC changes:

  1. DeviceGuard pattern — added a HIPGuardMasqueradingAsCUDA conditional typedef to attention.cpp, gemm.cpp, and normalization.cpp, replacing at::cuda::CUDAGuard with the platform-aware DeviceGuard. This is necessary because hipify cannot properly convert the standard at::cuda::CUDAGuard into the masquerading HIP device guard.
  2. cast.cpp — guarded NVFP4-specific enum values (BULK_NVFP4, FUSED_NVFP4) and code paths with #ifndef USE_ROCM; moved the #endif for the bulk allocation function to the correct location
  3. quantizer.cpp — resolved scale shape computation conflicts: adopted upstream's ceildiv() helper while preserving ROCm's guard to skip roundup(..., 4) padding (ROCm doesn't need 4-aligned scale dimensions)
  4. common.cpp / common.h — kept both ROCm's atomic amax functions and upstream's new ceildiv utility; resolved kFloat8E8M0 dtype mapping conflict keeping ROCm's FNUZ variant
  5. extensions.h / util.h — moved #ifndef USE_ROCM guard to correct positions, keeping inplace_swizzle_scale_for_gemm declaration visible while guarding NVSHMEM APIs
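The scale-shape resolution described for quantizer.cpp (and mirrored in float8_blockwise_tensor.py below) can be sketched in Python. The helper names, block size, and shape logic here are illustrative stand-ins, not the actual TE code; the sketch only shows the pattern of adopting upstream's ceildiv() while skipping the 4-aligned roundup padding on ROCm:

```python
def ceildiv(a: int, b: int) -> int:
    # Upstream's helper: smallest integer >= a / b, using floor division.
    return -(a // -b)


def roundup(value: int, multiple: int) -> int:
    # Pad value up to the next multiple (e.g. 4-aligned scale dims on CUDA).
    return ceildiv(value, multiple) * multiple


def scale_dim(tensor_dim: int, block_size: int, use_rocm: bool) -> int:
    # One scale entry per block of elements along this dimension.
    dim = ceildiv(tensor_dim, block_size)
    # Upstream pads scale dimensions to multiples of 4; the ROCm guard
    # skips the padding because ROCm doesn't need 4-aligned scale dims.
    return dim if use_rocm else roundup(dim, 4)
```

For example, a dimension of 900 elements with 32-element blocks yields 29 scale entries on ROCm, but is padded to 32 on the CUDA path.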

Python module changes:

  1. gemm.py — added ROCm workspace sizing to get_cublas_workspace_size_bytes() (64 MiB for gfx950, 32 MiB for other ROCm GPUs) to avoid a HIPBLASLT error
  2. grouped_linear.py, layernorm_linear.py, layernorm_mlp.py, linear.py — adopted upstream's non_tensor_args tuple pattern for autograd forward, integrating ROCm-specific fields (m_splits_tensor, use_grouped_gemm_triton, keep_fp8_weight_transpose_cache, use_fsdp2) into the tuple. Removed the old per-parameter signatures and their corresponding None returns in backward
  3. base.py — kept the old get_cublas_workspace_size_bytes() (still used by legacy code paths) while integrating upstream's NVFP4TensorStorage import and get_nvtx_range_context utility
  4. jit.py, distributed.py — resolved trivial conflicts (import merges, trailing whitespace)
  5. float8_blockwise_tensor.py — resolved conflicts in scale shape computation, same pattern as quantizer.cpp (skip roundup padding on ROCm)
  6. Test changes — test_attention.py skips GQA with packed QKV layouts; run_attention_with_cp.py uses IS_HIP_EXTENSION to branch cu_seqlens handling; test_numerics.py keeps both ROCm wgrad skip and upstream's debug mode skip
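The non_tensor_args refactor behind item 2 (and the -590 line delta) can be sketched framework-free. The class and field names below are illustrative, not the actual TE code; in the real modules this lives in torch.autograd.Function subclasses:

```python
class Ctx:
    """Stand-in for the ctx object torch.autograd passes to forward()."""


# Old per-parameter style: every non-tensor flag is its own forward()
# argument, and backward() must return a matching None for each one, so
# adding a ROCm-specific flag changes both signatures.
def forward_old(ctx, inp, weight, use_fsdp2, keep_fp8_weight_transpose_cache):
    ctx.use_fsdp2 = use_fsdp2
    ctx.keep_cache = keep_fp8_weight_transpose_cache
    return inp


# Upstream's new style: all non-tensor arguments travel as one tuple, so
# ROCm-specific fields are folded into the tuple without changing the
# forward()/backward() arity or adding per-flag None returns in backward.
def forward_new(ctx, inp, weight, non_tensor_args):
    (use_fsdp2, keep_fp8_weight_transpose_cache) = non_tensor_args
    ctx.use_fsdp2 = use_fsdp2
    ctx.keep_cache = keep_fp8_weight_transpose_cache
    return inp
```

This is why the ROCm-specific fields (m_splits_tensor, use_grouped_gemm_triton, etc.) could be integrated by extending the tuple rather than the signatures.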

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Micky774 mentioned this pull request on Mar 18, 2026.