From 3b424c482e6416ddc3ebc80a3585ea59690f093a Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 17 Mar 2026 16:33:56 -0500 Subject: [PATCH 1/4] IFU resolution JAX --- examples/jax/encoder/common.py | 5 --- tests/jax/test_distributed_fused_attn.py | 3 -- tests/jax/test_fused_attn.py | 35 ++++++------------- tests/jax/utils.py | 4 +-- .../jax/csrc/extensions/attention.cpp | 8 ++--- .../jax/csrc/extensions/gemm.cpp | 11 ++---- 6 files changed, 17 insertions(+), 49 deletions(-) diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index a83008254..8cc666443 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -1,12 +1,7 @@ -<<<<<<< HEAD # This file was modified for portability to AMDGPU # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - -======= -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # ->>>>>>> 99df88 # See LICENSE for license information. """Shared functions for the encoder tests""" from functools import lru_cache diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index c02740aff..caf9b1143 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -194,9 +194,7 @@ def test_self_attn( pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), ], ) -<<<<<<< HEAD @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") -======= @pytest.mark.parametrize( "softmax_type", [ @@ -205,7 +203,6 @@ def test_self_attn( pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"), ], ) ->>>>>>> 99df88 def test_self_attn_shardy( self, device_count, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 7b75f37d9..26257f977 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -495,18 +495,13 @@ def _setup_inputs(self): self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1) self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1) -<<<<<<< HEAD # only support new-style RNGs on AMD hardware since they will crash otherwise if is_hip_extension() and not self.use_old_rng: key = jax.random.key(0) else: key = jax.random.PRNGKey(0) - q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) -======= - key = jax.random.PRNGKey(0) q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6) ->>>>>>> 99df88 q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk) k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk) @@ -799,17 +794,6 @@ def to_dp_shardings(x): self.bias_pspec = PartitionSpec() self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec) -<<<<<<< HEAD - # New-style RNG fix is only applied for AMD GPUs - if is_hip_extension(): - if self.dropout_rng is not None and jnp.issubdtype(self.dropout_rng.dtype, jax.dtypes.prng_key): - self.dropout_rng_pspec = PartitionSpec() - else: - self.dropout_rng_pspec = PartitionSpec(None,) - else: - self.dropout_rng_pspec = PartitionSpec(None,) - -======= # Softmax offset sharding (1, num_heads, 1, 1) # Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource head_resource = ( @@ -820,10 +804,15 @@ def to_dp_shardings(x): self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None) self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec) - self.dropout_rng_pspec = PartitionSpec( - None, - ) ->>>>>>> 99df88 + # New-style RNG fix is only applied for AMD GPUs + self.dropout_rng_pspec = PartitionSpec(None,) + if ( + is_hip_extension() and + self.dropout_rng is not None and + jnp.issubdtype(self.dropout_rng.dtype, jax.dtypes.prng_key) + ): + self.dropout_rng_pspec = PartitionSpec() + self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec) self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None) @@ -1160,12 +1149,8 @@ def check_dqkv(primitive, reference, pad, idx): 64, 64, jnp.bfloat16, -<<<<<<< HEAD - id="2-2048-1024-12-12-64-64-BF16-CROSS", -======= QKVLayout.THD_T2HD, - id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED", ->>>>>>> 99df88 + id="2-2048-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED", ), # large data size + bf16 + cross attn + diff hidden v dim + qkv separate pytest.param( diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 554a145dc..373f0a938 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -52,7 +52,6 @@ def is_devices_enough(required): return len(jax.devices()) >= required -<<<<<<< HEAD def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False): if not is_hip_extension(): return @@ -145,14 +144,13 @@ def _check_mxfp8_layernorm_mlp_grad_support( m, n, k, use_bias ) -======= + def is_devices_equal(required): """ Check if the available GPUs is exactly equal """ return len(jax.devices()) == required ->>>>>>> 99df88 def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]: # Generate broadcast dims for drop_path. diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 9e13c62b7..41347a85e 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -57,6 +57,8 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax #ifndef USE_ROCM if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#else + { #endif // ROCm fused attn has two backends: aotriton and ck // They both have the same shape and stride for softmax and rng aux tensors @@ -89,10 +91,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t bias_aux_data.dtype = static_cast(dtype); nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data); } -<<<<<<< HEAD #ifndef USE_ROCM -======= - // include softmax_offset if provided if (softmax_offset_buf != nullptr) { NVTETensor &softmax_offset_aux = tensor_pack->tensors[size]; @@ -107,12 +106,11 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t softmax_offset_aux_data.dtype = static_cast(DType::kFloat32); nvte_set_tensor_param(&softmax_offset_aux, kNVTERowwiseData, &softmax_offset_aux_data); } +#endif // Set final size tensor_pack->size = size; ->>>>>>> 99df88 } -#endif nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data); } /* diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 7117f5101..40121049a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -86,16 +86,11 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } else { input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } -<<<<<<< HEAD - } else { // Swizzle for NVFP4 + input.set_with_gemm_swizzled_scales(true); + } else if (is_nvfp4) { // Swizzle for NVFP4 #ifdef USE_ROCM NVTE_ERROR("ROCm TE does not support NVFP4 yet."); - } #else -======= - input.set_with_gemm_swizzled_scales(true); - } else if (is_nvfp4) { // Swizzle for NVFP4 ->>>>>>> 99df88 NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor @@ -108,6 +103,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set swizzled scales into the input tensor input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); input.set_with_gemm_swizzled_scales(true); +#endif } else { // Tensor scaling if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); @@ -115,7 +111,6 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } } -#endif // #ifdef USE_ROCM } return std::make_tuple(std::move(input), input_shape); From 2bd4a059124452d6d886c80bfad0bfa279068e60 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 18 Mar 2026 14:46:25 -0500 Subject: [PATCH 2/4] Removed version guard, minimized diff --- tests/jax/test_distributed_fused_attn.py | 1 - tests/jax/test_fused_attn.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index caf9b1143..27fb4b0c0 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -194,7 +194,6 @@ def test_self_attn( pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), ], ) - @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") @pytest.mark.parametrize( "softmax_type", [ diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 26257f977..bf6ce781c 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -804,8 +804,10 @@ def to_dp_shardings(x): self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None) self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec) + self.dropout_rng_pspec = PartitionSpec( + None, + ) # New-style RNG fix is only applied for AMD GPUs - self.dropout_rng_pspec = PartitionSpec(None,) if ( is_hip_extension() and self.dropout_rng is not None and From fb831e1ac429bc00a43c84356395b29fe8e46e89 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 18 Mar 2026 15:22:39 -0500 Subject: [PATCH 3/4] Corrected size for test config --- tests/jax/test_fused_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index bf6ce781c..98d1089c3 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1144,7 +1144,7 @@ def check_dqkv(primitive, reference, pad, idx): ), pytest.param( 2, - 2048, + 512, 1024, 12, 12, @@ -1152,7 +1152,7 @@ def check_dqkv(primitive, reference, pad, idx): 64, jnp.bfloat16, QKVLayout.THD_T2HD, - id="2-2048-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED", + id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED", ), # large data size + bf16 + cross attn + diff hidden v dim + qkv separate pytest.param( From 024a76bada8cec12a1a6b75aedb4b79b52d7f026 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 19 Mar 2026 10:54:46 -0500 Subject: [PATCH 4/4] Updated test params and skipped unsupported striped attn --- tests/jax/test_distributed_fused_attn.py | 2 ++ tests/jax/test_fused_attn.py | 22 ++++++++++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 27fb4b0c0..fb40e2b1f 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -554,6 +554,8 @@ def test_context_parallel_allgather_striped_attn( ): if not qkv_layout.is_thd(): pytest.skip("Only THD layout is supported for CP + AG + Striped attention") + if is_hip_extension(): + pytest.skip("THD + ALL_GATHER + Striped attention is not yet supported on ROCm") self.impl_test_context_parallel_attn( device_count, mesh_shape, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 98d1089c3..dc5cffc28 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1280,10 +1280,28 @@ def check_dqkv(primitive, reference, pad, idx): id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-RAGGED_SEPARATE", ), pytest.param( - 10, 4096, 4096, 16, 16, 192, 128, jnp.float16, id="10-4096-4096-16-16-192-128-FP16-MLA", + 10, + 4096, + 4096, + 16, + 16, + 192, + 128, + jnp.float16, + QKVLayout.BSHD_BSHD_BSHD, + id="10-4096-4096-16-16-192-128-FP16-MLA", ), pytest.param( - 10, 4096, 4096, 16, 16, 192, 128, jnp.bfloat16, id="10-4096-4096-16-16-192-128-BF16-MLA", + 10, + 4096, + 4096, + 16, + 16, + 192, + 128, + jnp.bfloat16, + QKVLayout.BSHD_BSHD_BSHD, + id="10-4096-4096-16-16-192-128-BF16-MLA", ), ], )