Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions examples/jax/encoder/common.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
<<<<<<< HEAD
# This file was modified for portability to AMDGPU
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

=======
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
>>>>>>> 99df88
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache
Expand Down
6 changes: 2 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,6 @@ def test_self_attn(
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
],
)
<<<<<<< HEAD
@pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0")
=======
@pytest.mark.parametrize(
"softmax_type",
[
Expand All @@ -205,7 +202,6 @@ def test_self_attn(
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
>>>>>>> 99df88
def test_self_attn_shardy(
self,
device_count,
Expand Down Expand Up @@ -558,6 +554,8 @@ def test_context_parallel_allgather_striped_attn(
):
if not qkv_layout.is_thd():
pytest.skip("Only THD layout is supported for CP + AG + Striped attention")
if is_hip_extension():
pytest.skip("THD + ALL_GATHER + Striped attention is not yet supported on ROCm")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
Expand Down
53 changes: 29 additions & 24 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,18 +495,13 @@ def _setup_inputs(self):
self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)

<<<<<<< HEAD
# only support new-style RNGs on AMD hardware since they will crash otherwise
if is_hip_extension() and not self.use_old_rng:
key = jax.random.key(0)
else:
key = jax.random.PRNGKey(0)

q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
=======
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)
>>>>>>> 99df88

q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
Expand Down Expand Up @@ -799,17 +794,6 @@ def to_dp_shardings(x):
self.bias_pspec = PartitionSpec()
self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

<<<<<<< HEAD
# New-style RNG fix is only applied for AMD GPUs
if is_hip_extension():
if self.dropout_rng is not None and jnp.issubdtype(self.dropout_rng.dtype, jax.dtypes.prng_key):
self.dropout_rng_pspec = PartitionSpec()
else:
self.dropout_rng_pspec = PartitionSpec(None,)
else:
self.dropout_rng_pspec = PartitionSpec(None,)

=======
# Softmax offset sharding (1, num_heads, 1, 1)
# Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
head_resource = (
Expand All @@ -823,7 +807,14 @@ def to_dp_shardings(x):
self.dropout_rng_pspec = PartitionSpec(
None,
)
>>>>>>> 99df88
# New-style RNG fix is only applied for AMD GPUs
if (
is_hip_extension() and
self.dropout_rng is not None and
jnp.issubdtype(self.dropout_rng.dtype, jax.dtypes.prng_key)
):
self.dropout_rng_pspec = PartitionSpec()

self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
Expand Down Expand Up @@ -1153,19 +1144,15 @@ def check_dqkv(primitive, reference, pad, idx):
),
pytest.param(
2,
2048,
512,
1024,
12,
12,
64,
64,
jnp.bfloat16,
<<<<<<< HEAD
id="2-2048-1024-12-12-64-64-BF16-CROSS",
=======
QKVLayout.THD_T2HD,
id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED",
>>>>>>> 99df88
),
# large data size + bf16 + cross attn + diff hidden v dim + qkv separate
pytest.param(
Expand Down Expand Up @@ -1293,10 +1280,28 @@ def check_dqkv(primitive, reference, pad, idx):
id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-RAGGED_SEPARATE",
),
pytest.param(
10, 4096, 4096, 16, 16, 192, 128, jnp.float16, id="10-4096-4096-16-16-192-128-FP16-MLA",
10,
4096,
4096,
16,
16,
192,
128,
jnp.float16,
QKVLayout.BSHD_BSHD_BSHD,
id="10-4096-4096-16-16-192-128-FP16-MLA",
),
pytest.param(
10, 4096, 4096, 16, 16, 192, 128, jnp.bfloat16, id="10-4096-4096-16-16-192-128-BF16-MLA",
10,
4096,
4096,
16,
16,
192,
128,
jnp.bfloat16,
QKVLayout.BSHD_BSHD_BSHD,
id="10-4096-4096-16-16-192-128-BF16-MLA",
),
],
)
Expand Down
4 changes: 1 addition & 3 deletions tests/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def is_devices_enough(required):
return len(jax.devices()) >= required


<<<<<<< HEAD
def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False):
if not is_hip_extension():
return
Expand Down Expand Up @@ -145,14 +144,13 @@ def _check_mxfp8_layernorm_mlp_grad_support(
m, n, k,
use_bias
)
=======

def is_devices_equal(required):
"""
Check if the available GPUs is exactly equal
"""
return len(jax.devices()) == required

>>>>>>> 99df88

def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path.
Expand Down
8 changes: 3 additions & 5 deletions transformer_engine/jax/csrc/extensions/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
#ifndef USE_ROCM
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#else
{
#endif
// ROCm fused attn has two backends: aotriton and ck
// They both have the same shape and stride for softmax and rng aux tensors
Expand Down Expand Up @@ -89,10 +91,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
bias_aux_data.dtype = static_cast<NVTEDType>(dtype);
nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data);
}
<<<<<<< HEAD
#ifndef USE_ROCM
=======

// include softmax_offset if provided
if (softmax_offset_buf != nullptr) {
NVTETensor &softmax_offset_aux = tensor_pack->tensors[size];
Expand All @@ -107,12 +106,11 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
softmax_offset_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);
nvte_set_tensor_param(&softmax_offset_aux, kNVTERowwiseData, &softmax_offset_aux_data);
}
#endif

// Set final size
tensor_pack->size = size;
>>>>>>> 99df88
}
#endif
nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data);
}
/*
Expand Down
11 changes: 3 additions & 8 deletions transformer_engine/jax/csrc/extensions/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,11 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
<<<<<<< HEAD
} else { // Swizzle for NVFP4
input.set_with_gemm_swizzled_scales(true);
} else if (is_nvfp4) { // Swizzle for NVFP4
#ifdef USE_ROCM
NVTE_ERROR("ROCm TE does not support NVFP4 yet.");
}
#else
=======
input.set_with_gemm_swizzled_scales(true);
} else if (is_nvfp4) { // Swizzle for NVFP4
>>>>>>> 99df88
NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS");
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
// Create tensor to hold swizzled scale factor
Expand All @@ -108,14 +103,14 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
// Set swizzled scales into the input tensor
input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape);
input.set_with_gemm_swizzled_scales(true);
#endif
} else { // Tensor scaling
if (rowwise) {
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
}
#endif // #ifdef USE_ROCM
}

return std::make_tuple(std::move(input), input_shape);
Expand Down