diff --git a/ci/pytorch.sh b/ci/pytorch.sh index be150485f..d6ae9fcb4 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -89,6 +89,7 @@ run_test_config_mgpu(){ configure_omp_threads 8 run_default_fa 1 test_fused_optimizer.py run_default_fa 3 test_sanity_import.py + run_default_fa 3 distributed/test_cast_master_weights_to_fp8.py run_default_fa 2 distributed/test_fusible_ops.py run_default_fa 2 distributed/test_numerics.py run_default_fa 1 distributed/test_torch_fsdp2.py diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index a4bdf5e07..8103fc276 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -1,4 +1,6 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -15,6 +17,7 @@ from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from ..optimizers.multi_tensor_apply import multi_tensor_applier +from ..utils import is_non_tn_fp8_gemm_supported, is_fp8_fnuz def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): @@ -282,7 +285,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo # Step 3: Update scales and scale_invs. # --------------------------------------------------------------------------------------------- if fp8_dtype == tex.DType.kFloat8E4M3: - max_fp8 = 448.0 + max_fp8 = 240.0 if is_fp8_fnuz() else 448.0 elif fp8_dtype == tex.DType.kFloat8E5M2: max_fp8 = 57344.0 else: @@ -412,7 +415,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Step 3: Update scales and scale_invs. # --------------------------------------------------------------------------------------------- if fp8_dtype == tex.DType.kFloat8E4M3: - max_fp8 = 448.0 + max_fp8 = 240.0 if is_fp8_fnuz() else 448.0 elif fp8_dtype == tex.DType.kFloat8E5M2: max_fp8 = 57344.0 else: diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index ceb88108f..514e1cdba 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -12,11 +12,16 @@ from triton.language import core from triton.language.standard import _log2 +from packaging import version # The following three argsort related kernels are adapted from # the issue https://github.com/triton-lang/triton/issues/3698 +get_int_dtype = core.get_int_dtype +if version.parse(triton.__version__) >= version.parse("3.5.0"): + get_int_dtype = triton.constexpr_function(get_int_dtype) + @triton.jit def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr): @@ -37,7 +42,7 @@ def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr): l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape) r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape) - idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) il_value = l_value.to(idtype, bitcast=True) ir_value = r_value.to(idtype, bitcast=True)