From b8a402495948d87c55355ef2f8ee4a912dccde98 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Mon, 2 Feb 2026 14:16:13 -0600 Subject: [PATCH 01/41] [ROCm] resolve the conflicts in common dir --- hipify_custom_map.json | 3 +- transformer_engine/common/CMakeLists.txt | 161 +- transformer_engine/common/__init__.py | 63 +- transformer_engine/common/cast/cast.cu | 2 + .../common/cast/core/common.cuh | 2 + .../common/cast/dispatch/dequantize.cuh | 8 + .../common/cast/dispatch/gated.cuh | 18 + .../common/cast/dispatch/quantize.cuh | 8 + .../common/cast/fp8/dequantize_fp8.cuh | 2 + .../common/cast/fp8/gated_fp8.cuh | 4 + .../common/cast/fp8/quantize_fp8.cuh | 159 ++ .../common/cast/mxfp8/dequantize_mxfp8.cuh | 55 +- .../common/cast/mxfp8/gated_mxfp8.cuh | 563 +----- .../common/cast/mxfp8/quantize_mxfp8.cuh | 40 +- .../mxfp8/rocm_dequantize_mxfp8.cuh} | 26 +- .../mxfp8/rocm_gated_mxfp8.cuh} | 77 +- .../mxfp8/rocm_quantize_mxfp8.cuh} | 206 +-- transformer_engine/common/common.cu | 2 +- transformer_engine/common/common.h | 29 +- .../common/fused_attn_rocm/fused_attn.cpp | 124 +- .../fused_attn_rocm/fused_attn_aotriton.cpp | 6 +- .../fused_attn_rocm/fused_attn_aotriton.h | 1 + .../common/fused_attn_rocm/fused_attn_ck.cpp | 9 + .../common/fused_attn_rocm/fused_attn_ck.h | 1 + .../common/gemm/cublaslt_gemm.cu | 56 +- transformer_engine/common/gemm/rocm_gemm.cu | 5 +- .../include/transformer_engine/fused_attn.h | 10 +- .../common/normalization/common.h | 17 +- .../common/normalization/layernorm/ln_api.cpp | 4 - .../normalization/rmsnorm/rmsnorm_api.cpp | 4 - transformer_engine/common/recipe/__init__.py | 22 +- .../common/recipe/current_scaling.cu | 17 +- transformer_engine/common/swizzle/swizzle.cu | 131 -- .../common/util/cast_kernels.cuh | 1546 ----------------- transformer_engine/common/util/logging.h | 5 +- transformer_engine/common/util/ptx.cuh | 108 +- .../common/util/rocm_vectorized_2d.cuh | 68 - transformer_engine/common/utils.cuh | 3 - 38 files changed, 597 insertions(+), 
2968 deletions(-) rename transformer_engine/common/{util/rocm_dequantize_kernels.cuh => cast/mxfp8/rocm_dequantize_mxfp8.cuh} (89%) rename transformer_engine/common/{util/rocm_cast_gated_kernels.cuh => cast/mxfp8/rocm_gated_mxfp8.cuh} (87%) rename transformer_engine/common/{util/rocm_cast_kernels.cuh => cast/mxfp8/rocm_quantize_mxfp8.cuh} (66%) delete mode 100644 transformer_engine/common/util/cast_kernels.cuh diff --git a/hipify_custom_map.json b/hipify_custom_map.json index 97824bbdb..812ea384d 100644 --- a/hipify_custom_map.json +++ b/hipify_custom_map.json @@ -6,7 +6,8 @@ "ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h", "CUfunc_cache" : "hipFuncCache_t", "" : "", - "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)" + "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)", + "__nv_bfloat162":"__hip_bfloat162" } } diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 399137484..46eb5dba5 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -33,20 +33,8 @@ else() endif() # Language options -<<<<<<< HEAD if(USE_CUDA) # Removed indent to minimize code diff with NV upstream -if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) - elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) - else () - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) - endif() -endif() -======= ->>>>>>> 389a6b set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) @@ -180,31 +168,21 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) 
set(transformer_engine_SOURCES) -<<<<<<< HEAD -# Source files in both cuda and rocm -list(APPEND transformer_engine_SOURCES -======= set(transformer_engine_cpp_sources) set(transformer_engine_cuda_sources) set(transformer_engine_cuda_arch_specific_sources) +# Source files in both cuda and rocm list(APPEND transformer_engine_cpp_sources - cudnn_utils.cpp ->>>>>>> 389a6b transformer_engine.cpp - fused_attn/fused_attn.cpp gemm/config.cpp normalization/common.cpp normalization/layernorm/ln_api.cpp normalization/rmsnorm/rmsnorm_api.cpp util/cuda_driver.cpp - util/cuda_nvml.cpp util/cuda_runtime.cpp util/multi_stream.cpp - util/rtc.cpp - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/comm_gemm_overlap.cpp) + util/rtc.cpp) list(APPEND transformer_engine_cuda_sources common.cu @@ -218,43 +196,18 @@ list(APPEND transformer_engine_cuda_sources transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu -<<<<<<< HEAD -======= - transpose/quantize_transpose_vector_blockwise.cu ->>>>>>> 389a6b transpose/swap_first_dims.cu dropout/dropout.cu fused_attn/flash_attn.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu -<<<<<<< HEAD - activation/relu.cu - activation/swiglu.cu gemm/cublaslt_gemm.cu - normalization/common.cpp - normalization/layernorm/ln_api.cpp -======= - fused_attn/fused_attn_f16_max512_seqlen.cu - fused_attn/fused_attn_f16_arbitrary_seqlen.cu - fused_attn/fused_attn_fp8.cu - fused_attn/utils.cu - gemm/cublaslt_gemm.cu ->>>>>>> 389a6b normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu util/padding.cu -<<<<<<< HEAD - util/cuda_driver.cpp - util/cuda_runtime.cpp - util/multi_stream.cpp - util/rtc.cpp -======= ->>>>>>> 389a6b - swizzle/swizzle.cu - swizzle/swizzle_block_scaling.cu 
fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -264,41 +217,51 @@ list(APPEND transformer_engine_cuda_sources fused_router/fused_topk_with_score_function.cu recipe/current_scaling.cu recipe/delayed_scaling.cu -<<<<<<< HEAD recipe/fp8_block_scaling.cu) -if(USE_CUDA) -# Removed indent to minimize code diff with NV upstream -# Files unique in cuda building -list(APPEND transformer_engine_SOURCES - cudnn_utils.cpp - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise.cu - fused_attn/fused_attn_f16_max512_seqlen.cu - fused_attn/fused_attn_f16_arbitrary_seqlen.cu - fused_attn/fused_attn_fp8.cu - fused_attn/fused_attn.cpp - fused_attn/utils.cu - gemm/cutlass_grouped_gemm.cu - util/cuda_nvml.cpp - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) -======= - recipe/fp8_block_scaling.cu - recipe/nvfp4.cu - comm_gemm_overlap/userbuffers/userbuffers.cu) list(APPEND transformer_engine_cuda_arch_specific_sources - gemm/cutlass_grouped_gemm.cu cast/cast.cu activation/gelu.cu activation/relu.cu - activation/swiglu.cu - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise_fp4.cu - hadamard_transform/hadamard_transform.cu - hadamard_transform/hadamard_transform_cast_fusion.cu) + activation/swiglu.cu) + +if(USE_CUDA) +#NV specific source codes + list(APPEND transformer_engine_cpp_sources + cudnn_utils.cpp + fused_attn/fused_attn.cpp + util/cuda_nvml.cpp + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/comm_gemm_overlap.cpp) + list(APPEND transformer_engine_cuda_sources + transpose/quantize_transpose_vector_blockwise.cu + fused_attn/fused_attn_f16_max512_seqlen.cu + 
fused_attn/fused_attn_f16_arbitrary_seqlen.cu + fused_attn/fused_attn_fp8.cu + fused_attn/utils.cu + swizzle/swizzle.cu + swizzle/swizzle_block_scaling.cu + recipe/nvfp4.cu + comm_gemm_overlap/userbuffers/userbuffers.cu) + list(APPEND transformer_engine_cuda_arch_specific_sources + gemm/cutlass_grouped_gemm.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu + hadamard_transform/hadamard_transform.cu + hadamard_transform/hadamard_transform_cast_fusion.cu) +else() +#ROCm specific source codes + list(APPEND transformer_engine_cpp_sources + fused_attn_rocm/fused_attn.cpp + gemm/rocm_gemm.cu + amd_detail/system.cpp) + list(APPEND transformer_engine_cuda_sources + fused_attn_rocm/fused_attn_aotriton.cpp + fused_attn_rocm/fused_attn_ck.cpp + fused_attn_rocm/utils.cpp) +endif() + # Compiling the files with the worst compilation time first to hopefully overlap # better with the faster-compiling cpp files @@ -306,6 +269,7 @@ list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_s ${transformer_engine_cuda_sources} ${transformer_engine_cpp_sources}) +if(USE_CUDA) # Set compile options for CUDA sources with generic architectures foreach(cuda_source IN LISTS transformer_engine_cuda_sources) set(arch_compile_options) @@ -339,7 +303,6 @@ foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources) ) endif() endforeach() ->>>>>>> 389a6b if (NVTE_WITH_CUBLASMP) list(APPEND transformer_engine_SOURCES @@ -347,14 +310,8 @@ list(APPEND transformer_engine_SOURCES endif() add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) + else() - list(APPEND transformer_engine_SOURCES - fused_attn_rocm/fused_attn.cpp - fused_attn_rocm/fused_attn_aotriton.cpp - fused_attn_rocm/fused_attn_ck.cpp - fused_attn_rocm/utils.cpp - gemm/rocm_gemm.cu - amd_detail/system.cpp) # process source code files set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..) 
@@ -386,32 +343,20 @@ else() message(STATUS "nvte hipified sources: ${te_hip_sources}") add_library(transformer_engine SHARED ${te_hip_sources}) - target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}") + target_include_directories(transformer_engine PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) endif() target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") -<<<<<<< HEAD if (USE_CUDA) -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "gemm/cutlass_grouped_gemm.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") -else() - message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") -endif() -endif() #USE_CUDA -======= # CUTLASS kernels require SM90a and cause hang in debug build set_property( SOURCE gemm/cutlass_grouped_gemm.cu APPEND PROPERTY COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0") ->>>>>>> 389a6b +endif() #USE_CUDA # Configure dependencies if (USE_CUDA) @@ -567,22 +512,7 @@ target_include_directories(transformer_engine PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/string_headers") # Compiler options -<<<<<<< HEAD -set_source_files_properties(fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu - fused_softmax/scaled_aligned_causal_masked_softmax.cu - multi_tensor/adam.cu - multi_tensor/compute_scale.cu - multi_tensor/l2norm.cu - multi_tensor/scale.cu - multi_tensor/sgd.cu - fused_attn/flash_attn.cu - fused_attn/context_parallel.cu - fused_attn/kv_cache.cu - PROPERTIES - COMPILE_OPTIONS "--use_fast_math") if(USE_CUDA) -======= set(nvte_sources_with_fast_math) list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -596,7 +526,6 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu) ->>>>>>> 389a6b 
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) list(APPEND nvte_sources_with_fast_math activation/gelu.cu diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index f8b302d49..cdda37508 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -17,16 +17,9 @@ import subprocess import sys import sysconfig -<<<<<<< HEAD -from typing import Optional - -import transformer_engine - -_logger = logging.getLogger(__name__) -======= from typing import Optional, Tuple ->>>>>>> 389a6b +import transformer_engine @functools.lru_cache(maxsize=None) def _is_package_installed(package) -> bool: @@ -145,8 +138,10 @@ def get_te_core_package_info() -> Tuple[bool, str, str]: Check if Tranformer Engine core package is installed. Returns the module name and version if found. """ - + te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13") + if te_rocm_build: + te_core_packages = ("transformer-engine-rocm") for package in te_core_packages: if _is_package_installed(package): return True, package, version(package) @@ -171,42 +166,6 @@ def load_framework_extension(framework: str) -> None: if framework == "torch": extra_dep_name = "pytorch" -<<<<<<< HEAD - te_cuda_vers = "rocm" if te_rocm_build else "cu12" - - # If the framework extension pip package is installed, it means that TE is installed via - # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework - # extension are all installed via PyPI and have matching version. - if _is_pip_package_installed(module_name): - assert _is_pip_package_installed( - "transformer_engine" - ), "Could not find `transformer-engine`." - assert _is_pip_package_installed( - f"transformer_engine_{te_cuda_vers}" - ), f"Could not find `transformer-engine-{te_cuda_vers}`." 
- assert ( - version(module_name) - == version("transformer-engine") - == version(f"transformer-engine-{te_cuda_vers}") - ), ( - "TransformerEngine package version mismatch. Found" - f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-{te_cuda_vers}" - f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" - ) - - # If the core package is installed via PyPI, log if - # the framework extension is not found from PyPI. - # Note: Should we error? This is a rare use case. - if _is_pip_package_installed(f"transformer-engine-{te_cuda_vers}"): - if not _is_pip_package_installed(module_name): - _logger.info( - "Could not find package %s. Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'", - module_name, - ) -======= # Find the TE packages. The core and framework packages can only be installed via PyPI. # For the `transformer-engine` package, we need to check explicity. te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() @@ -230,7 +189,6 @@ def load_framework_extension(framework: str) -> None: f" v{te_core_version}. Install transformer-engine using " f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'" ) ->>>>>>> 389a6b # After all checks are completed, load the shared object file. 
spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework)) @@ -438,7 +396,6 @@ def _load_core_library(): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): -<<<<<<< HEAD try: _CUDNN_LIB_CTYPES = _load_cudnn() _NVRTC_LIB_CTYPES = _load_nvrtc() @@ -446,9 +403,6 @@ def _load_core_library(): _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") - # Needed to find the correct headers for NVRTC kernels. - if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir(): - os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir() except (OSError, subprocess.CalledProcessError): pass finally: @@ -473,13 +427,4 @@ def _load_core_library(): assert (rocm_version == build_rocm_version), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}" except FileNotFoundError: pass -======= - sanity_checks_for_pypi_installation() - _CUDNN_LIB_CTYPES = _load_cudnn() - _NVRTC_LIB_CTYPES = _load_nvrtc() - _CURAND_LIB_CTYPES = _load_curand() - _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") - _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") - _TE_LIB_CTYPES = _load_core_library() ->>>>>>> 389a6b diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 1ed46a335..575106a53 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -5,7 +5,9 @@ ************************************************************************/ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include #include diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index b750142f5..ec36e941f 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ 
-12,7 +12,9 @@ #define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index b8547915c..f55719852 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -16,7 +16,9 @@ #include "../../common.h" #include "../fp8/dequantize_fp8.cuh" #include "../mxfp8/dequantize_mxfp8.cuh" +#ifndef __HIP_PLATFORM_AMD__ #include "../nvfp4/dequantize_nvfp4.cuh" +#endif //#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace dispatch { @@ -34,17 +36,23 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t break; } case NVTE_MXFP8_1D_SCALING: { +#ifndef __HIP_PLATFORM_AMD__ if (is_supported_by_CC_100()) { +#endif //#ifndef __HIP_PLATFORM_AMD__ mxfp8::dequantize(input, output, stream); +#ifndef __HIP_PLATFORM_AMD__ } else { NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); } +#endif //#ifndef __HIP_PLATFORM_AMD__ break; } +#ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { nvfp4::dequantize(input, output, stream); break; } +#endif //#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index 4373090b7..8f236023b 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See LICENSE for license information. @@ -45,6 +47,9 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { +#ifdef __HIP_PLATFORM_AMD__ + fp8::cast_gated_fwd(input, output, p, stream); +#else const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); if (use_tma_kernels) { Tensor dummy_grad_tensor; @@ -53,6 +58,7 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp } else { fp8::cast_gated_fwd(input, output, p, stream); } +#endif //#ifdef __HIP_PLATFORM_AMD__ break; } case NVTE_MXFP8_1D_SCALING: { @@ -68,8 +74,12 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), "The type of the columnwise output tensor should be FP8."); } +#ifdef __HIP_PLATFORM_AMD__ + //TODO: add gfx950 equivalent checking +#else NVTE_CHECK(is_supported_by_CC_100(), "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); +#endif Tensor dummy_grad_tensor; mxfp8::quantize_gated(input, dummy_grad_tensor, output, p, stream); @@ -122,6 +132,9 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { +#ifdef __HIP_PLATFORM_AMD__ + fp8::cast_gated_bwd(gated_input, grad, output, p, stream); +#else const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); if (use_tma_kernels) { fp8::cast_gated_tma(gated_input, grad, output, p, @@ -129,6 +142,7 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte } else { fp8::cast_gated_bwd(gated_input, grad, output, p, stream); } +#endif //#ifdef __HIP_PLATFORM_AMD__ break; } case NVTE_MXFP8_1D_SCALING: { @@ -144,8 +158,12 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), "The type of the columnwise 
output tensor should be FP8."); } +#ifdef __HIP_PLATFORM_AMD__ + // add gfx950 equivalent check +#else NVTE_CHECK(is_supported_by_CC_100(), "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); +#endif //#ifdef __HIP_PLATFORM_AMD__ mxfp8::quantize_gated(gated_input, grad, output, p, stream); diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9f7a4a9b0..8e8993668 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -19,8 +21,10 @@ #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" +#ifndef __HIP_PLATFORM_AMD__ #include "../nvfp4/quantize_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" +#endif //#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace dispatch { @@ -87,6 +91,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, dummy_workspace_tensor, stream); break; } +#ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); @@ -167,6 +172,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } +#endif//#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } @@ -232,6 +238,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens stream); break; } +#ifndef __HIP_PLATFORM_AMD__ case 
NVTE_NVFP4_1D_SCALING: { NVTE_CHECK((!IS_DBIAS && !IS_DACT), "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); @@ -315,6 +322,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } +#endif //#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh index 2514758b5..5d30a6c3f 100644 --- a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -12,7 +12,9 @@ #define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif // #ifndef __HIP_PLATFORM_AMD__ #include #include diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh index 225ef93ed..c9040a3da 100644 --- a/transformer_engine/common/cast/fp8/gated_fp8.cuh +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -12,7 +12,9 @@ #define TRANSFORMER_ENGINE_GATED_FP8_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include @@ -25,6 +27,7 @@ namespace transformer_engine { namespace dispatch { namespace fp8 { +#ifndef __HIP_PLATFORM_AMD__ namespace kernel { constexpr size_t CHUNK_DIM_Y = 128; @@ -348,6 +351,7 @@ void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *outpu NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) } +#endif //#ifndef __HIP_PLATFORM_AMD__ template void cast_gated_fwd(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index efc5015b7..9de093e96 100644 --- 
a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -12,7 +14,9 @@ #define TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include #include @@ -35,6 +39,58 @@ namespace transformer_engine { namespace dispatch { namespace fp8 { +#ifdef __HIP_PLATFORM_AMD__ +constexpr size_t TILE_DIM = 32; +template +__global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_output, int rows, int cols) { + __shared__ float tile[TILE_DIM][TILE_DIM]; + + int tile_start_col = blockIdx.x * TILE_DIM; + int tile_start_row = blockIdx.y * TILE_DIM; + int thread_col_in_tile = threadIdx.x; + int thread_row_in_tile = threadIdx.y; + + int global_col = tile_start_col + thread_col_in_tile; + int global_row = tile_start_row + thread_row_in_tile; + + if (global_row < rows && global_col < cols) { + tile[thread_row_in_tile][thread_col_in_tile] = static_cast(input[global_row * cols + global_col]); + } else { + tile[thread_row_in_tile][thread_col_in_tile] = 0.0f; + } + __syncthreads(); + + for (int stride = TILE_DIM / 2; stride > 0; stride /= 2) { + if (thread_row_in_tile < stride) { + tile[thread_row_in_tile][thread_col_in_tile] += tile[thread_row_in_tile + stride][thread_col_in_tile]; + } + __syncthreads(); + } + + if (thread_row_in_tile == 0 && global_col < cols) { + partial_output[blockIdx.y * cols + global_col] = tile[0][thread_col_in_tile]; + } +} + +template +void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows, + const size_t cols, cudaStream_t stream, 
Tensor* partial_sum_workspace) { + dim3 block_dim_partial(TILE_DIM, TILE_DIM); + dim3 grid_dim_partial(DIVUP(cols, TILE_DIM), DIVUP(rows, TILE_DIM)); + + const size_t partial_rows = grid_dim_partial.y; + float* partial_workspace = reinterpret_cast(partial_sum_workspace->data.dptr); + + partial_reduce_kernel<<>>( + workspace_ptr, + partial_workspace, + rows, cols); + + common::reduce_dbias(partial_workspace, dbias, partial_rows, cols, stream); +} + + +#else namespace quantize_2D_kernel { constexpr size_t FP8_CHUNK_DIM_Y = 128; @@ -454,16 +510,33 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T }); // NOLINT(*) ); // NOLINT(*) } +#endif //#ifdef __HIP_PLATFORM_AMD__ namespace detail { using Empty = transformer_engine::Empty; __device__ inline float identity(float value, const Empty &) { return value; } } // namespace detail +/* HIPCC has strict rules for __device__ functions usage on host. + It forbids not only calling but also other ODR-use assigning to variables + https://github.com/llvm/llvm-project/issues/105825 + Use templated struct wrapper to work around + */ +template +struct ActivationType +{ + static constexpr auto op = OP; +}; + + template void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; +#else //#ifdef __HIP_PLATFORM_AMD__ constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; +#endif //#ifdef __HIP_PLATFORM_AMD__ const size_t N = product(input.data.shape); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, IType, @@ -487,7 +560,11 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, template void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; +#else //#ifdef __HIP_PLATFORM_AMD__ constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; +#endif //#ifdef __HIP_PLATFORM_AMD__ const size_t N = product(input->data.shape); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input->data.dtype, IType, @@ -512,7 +589,9 @@ template void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { +#ifndef __HIP_PLATFORM_AMD__ using namespace quantize_1D_kernel; +#endif //#ifndef __HIP_PLATFORM_AMD__ CheckNoopTensor(*noop, "cast_noop"); CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); @@ -531,6 +610,85 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); +#ifdef __HIP_PLATFORM_AMD__ + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true."); + if (workspace->data.dptr == nullptr) { + if constexpr (IS_DACT) { + const size_t partial_rows = DIVUP(rows, TILE_DIM); + size_t total_elements = (rows * cols) + (partial_rows * cols); + workspace->data.shape = {total_elements}; + 
workspace->data.dtype = DType::kFloat32; + } else { + workspace->data.shape = {rows, cols}; + workspace->data.dtype = DType::kFloat32; + } + return; + } + + const void *ptr_to_reduce = nullptr; + DType dtype_to_reduce; + + workspace->amax = {}; + workspace->scale = {}; + workspace->scale_inv = {}; + + Tensor workspace_buffer; + Tensor partial_sum_buffer; + + if constexpr (IS_DACT) { + // The values to reduce are the result of the dAct function. + NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT."); + + const size_t partial_rows = DIVUP(rows, TILE_DIM); + const size_t full_size_bytes = rows * cols * sizeof(float); + workspace_buffer = *workspace; + workspace_buffer.data.shape = {rows, cols}; + partial_sum_buffer.data.dptr = reinterpret_cast(workspace->data.dptr) + full_size_bytes; + partial_sum_buffer.data.shape = {partial_rows, cols}; + partial_sum_buffer.data.dtype = DType::kFloat32; + workspace = &partial_sum_buffer; + + CastVectorizedUnaryGradKernelLauncher(input, act_input, &workspace_buffer, stream); + if (output && output->data.dptr) { + CastVectorizedUnaryKernelLauncher(workspace_buffer, noop, output, stream); + } + ptr_to_reduce = workspace_buffer.data.dptr; + dtype_to_reduce = workspace_buffer.data.dtype; + } else { + if (output && output->data.dptr) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + // The values to reduce are just the input values. 
+ ptr_to_reduce = input.data.dptr; + dtype_to_reduce = input.data.dtype; + } + + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias tensor."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + dbias->data.dtype, DBiasTypeOut, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + dtype_to_reduce, DTypeReduce, + reduce_dbias_rocm( + reinterpret_cast(ptr_to_reduce), + dbias, rows, cols, stream, workspace); + ); + ); + } else { + if (output && output->data.dptr) { + if constexpr (IS_DACT) { + NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output."); + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } else { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } + } +#else // Supported by the Arch >= 10.0 if (is_supported_by_CC_100()) { if (!IS_DBIAS && !IS_DACT) { @@ -571,6 +729,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); } } +#endif //#ifdef __HIP_PLATFORM_AMD__ } } // namespace fp8 diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 89391b21f..96aed3e88 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -20,34 +20,18 @@ #include #include -<<<<<<< HEAD:transformer_engine/common/util/dequantize_kernels.cuh -#include -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/transpose.h" -#ifdef __HIP_PLATFORM_AMD__ -#include "rocm_dequantize_kernels.cuh" -#endif -======= #include "../../common.h" #include "../../util/math.h" #include "../../util/ptx.cuh" #include "../../utils.cuh" ->>>>>>> 
389a6b:transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh namespace transformer_engine { namespace dispatch { namespace mxfp8 { namespace dequantize_kernel { - -#ifndef __HIP_PLATFORM_AMD__ +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_dequantize_mxfp8.cuh" +#else template __global__ void __launch_bounds__(THREADS_PER_CHUNK) dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, @@ -225,11 +209,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -<<<<<<< HEAD:transformer_engine/common/util/dequantize_kernels.cuh #endif // #ifndef __HIP_PLATFORM_AMD__ -======= } // namespace dequantize_kernel ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { using namespace dequantize_kernel; @@ -328,39 +309,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) #endif NVTE_CHECK_CUDA(cudaGetLastError()); } -<<<<<<< HEAD:transformer_engine/common/util/dequantize_kernels.cuh -} // namespace dequantization - -namespace detail { - -void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - if (is_tensor_scaling(input.scaling_mode)) { - dequantization::fp8_dequantize(input, output, stream); - } else if (is_mxfp_scaling(input.scaling_mode)) { -#ifdef __HIP_PLATFORM_AMD__ - if (1) { -#else - if (is_supported_by_CC_100()) { -#endif - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - } else { - // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } -} - -} // namespace detail -======= } // namespace mxfp8 } // namespace 
dispatch ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index a59e85659..28e46fc7a 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -20,258 +20,6 @@ #include #include -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh -#include - -#include "../common.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#ifdef __HIP_PLATFORM_AMD__ -#include "rocm_cast_gated_kernels.cuh" -#endif - -namespace transformer_engine { - -namespace gated_kernels { - -#ifndef __HIP_PLATFORM_AMD__ -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 512; -constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; -constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 -constexpr size_t BUFFERS_NUM = 2; -constexpr size_t BUFFER_DIM_Y = 32; -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 -static_assert(ITERATIONS >= 1); - -__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act, - const 
__grid_constant__ CUtensorMap tensor_map_output_gate, - float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t grad_mem = IS_DGATED ? 
buff_size_aligned_in : 0; - - constexpr size_t in_act_mem = buff_size_aligned_in; - constexpr size_t in_gate_mem = buff_size_aligned_in; - constexpr size_t in_mem = in_act_mem + in_gate_mem; - - constexpr size_t out_act_mem = buff_size_aligned_out; - constexpr size_t in_transaction_size = buff_elems * sizeof(IType); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); - - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); - const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. 
-#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - // Prefetch data of the first stage - - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, - TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, - chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } else { - copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } - -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const size_t buff = it % BUFFERS_NUM; - const size_t next_it = it + 1; - if (next_it < ITERATIONS) { - const size_t next_buff = next_it % BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3( - &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, - &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, - in_transaction_size, &mbar[next_it], is_master_thread); - } else { - copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, - chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, - &mbar[next_it], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); - - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * 
buff_elems; - OType *out_act_sh_curr = out_act_sh + buff * buff_elems; - OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; - -#pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); - } - - float after_dact = dact_x * grad_elt * gate_elt; - float after_dgate = act_x * grad_elt; - - out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); - out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); - - amax = fmaxf(amax, fabsf(after_dact)); - amax = fmaxf(amax, fabsf(after_dgate)); - } else { - const float after_act = ActOP(act_elt, {}) * gate_elt; - out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); - amax = fmaxf(amax, fabsf(after_act)); - } - } - - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. 
- - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - - // dGeLU - ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, - chunk_it_offset_y, - reinterpret_cast(out_act_sh_curr)); - - if constexpr (IS_DGATED) { - // dGate - ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_sh_curr)); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. 
- if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -namespace mxfp8_kernel { -======= #include "../../common.h" #include "../../util/math.h" #include "../../util/ptx.cuh" @@ -281,8 +29,9 @@ namespace transformer_engine { namespace dispatch { namespace mxfp8 { namespace gated_kernel { ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh - +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_gated_mxfp8.cuh" +#else constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; constexpr size_t THREADS_PER_CHUNK_COLWISE = 128; @@ -925,99 +674,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#endif //#ifndef __HIP_PLATFORM_AMD__ } // namespace gated_kernel template void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p, cudaStream_t stream) { -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh - checkCuDriverContext(stream); - - if (output->has_data()) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - - NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act{}; - alignas(64) CUtensorMap tensor_map_output_gate{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, - cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, cols, - typeToNumBits(output->dtype())); - - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = 
(IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_fp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_fp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, - tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) -} -#endif //#ifdef __HIP_PLATFORM_AMD__ - -template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, - cudaStream_t stream) { -======= using namespace gated_kernel; ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh checkCuDriverContext(stream); const bool USE_ROWWISE_SCALING = output->has_data(); @@ -1045,26 +709,18 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_BWD ? 
2 : 1) * cols; -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh #ifdef __HIP_PLATFORM_AMD__ constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; constexpr size_t BUFF_DIM_Y = BUFFER_DIM_Y; constexpr size_t BUFF_DIM_X = BUFFER_DIM_X; constexpr size_t BUFFS_NUM = BUFFERS_NUM; +#endif const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); -#else - - constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; - constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; - constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; -======= - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +#ifndef __HIP_PLATFORM_AMD__ const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) ? THREADS_PER_CHUNK_COLWISE : THREADS_PER_CHUNK_NON_COLWISE; @@ -1087,7 +743,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ - const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *tensor_map_grad = IS_BWD ? reinterpret_cast(grad.data.dptr) : nullptr; const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr); const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols; OType *tensor_map_output_act_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) : nullptr; @@ -1153,15 +809,11 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh #ifdef __HIP_PLATFORM_AMD__ const size_t out_gate_mem = buff_size_aligned_out; #else - const size_t out_gate_mem = (IS_DGATED ? 
buff_size_aligned_out : 0); -#endif -======= const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0); ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +#endif size_t out_mem = out_act_mem + out_gate_mem; if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } @@ -1175,18 +827,18 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out (USE_ROWWISE_SCALING ? 32 : 1), SCALE_DIM_X, TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, + quantize_gated_mxfp8_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - cast_mxfp8_gated_kernel + quantize_gated_mxfp8_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); }))); // NOLINT(*) #else @@ -1232,200 +884,15 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh - } + } + } NVTE_CHECK_CUDA(cudaGetLastError()); // NOLINT(*) #endif ); // NOLINT(*) ); // NOLINT(*) } -template -void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "gated_act_input"); - CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(), - "Wrong output shape. 
Expected (after flattening) [", input.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(input.flat_last_dim() % 2 == 0, - "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2, - "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), - output->flat_last_dim(), {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(grad, "dgated_act_grad"); - CheckInputTensor(input, "dgated_act_input"); - CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, - "Wrong output shape. 
Expected (after flattening) [*, ", grad.flat_last_dim() * 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match. Input shape: ", input.data.shape, - ", output shape: ", output->data.shape, "."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), - grad.flat_last_dim(), {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, - cudaStream_t stream) { - constexpr bool allow_empty = false; - CheckInputTensor(gated_input, "gated_input"); - CheckOutputTensor(*output, "output", allow_empty); - - NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; - - if constexpr (IS_DGATED) { - CheckInputTensor(grad, "grad"); - NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); - NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); - NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); - } - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - bool is_fp8_rowwise_output = true; - bool is_fp8_colwise_output = true; - if (output->has_data()) { - is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - if (output->has_columnwise_data()) { - is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - - const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; - - if (is_delayed_tensor_scaling(output->scaling_mode)) { -#ifdef __HIP_PLATFORM_AMD__ - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); - } else { - cast_gated(gated_input, output, stream); - } -#else - if (use_tma_kernels) { - cast_fp8_gated(grad, gated_input, output, stream); - } else { - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); - } else { - cast_gated(gated_input, output, stream); - } - } -#endif - } else if (is_mxfp_scaling(output->scaling_mode)) { - if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, stream); - } else { - NVTE_ERROR("Invalid input shape. 
Expected the last dimension to be divisible ", - "by 32, got input of shape ", gated_input.data.shape); - } - } else { - NVTE_ERROR("Not supported scaling mode"); - } -} -} // namespace gated_kernels - -namespace detail { - -template -void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - cudaStream_t stream) { - using namespace gated_kernels; - Tensor grad_empty_tensor; - const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; - const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input); - Tensor *output_tensor = convertNVTETensorCheck(output); - -#ifdef __HIP_PLATFORM_AMD__ - if (1) { -#else - if (is_supported_by_CC_100()) { -#endif - quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, stream); - } else { - if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { - if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); - } else { - cast_gated(gated_input_tensor, output_tensor, stream); - } - } else { - // MX scaling - NVTE_ERROR("Not supported by the Arch < 10.0"); - } - } -} -} // namespace detail - -======= - } - } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) -} - } // namespace mxfp8 } // namespace dispatch ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 5505de605..19234e9b4 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. 
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -12,7 +14,9 @@ #define TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include @@ -26,7 +30,9 @@ namespace transformer_engine { namespace dispatch { namespace mxfp8 { namespace quantize_kernel { - +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_quantize_mxfp8.cuh" +#else constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; @@ -536,6 +542,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#endif //#ifndef __HIP_PLATFORM_AMD__ } // namespace quantize_kernel template has_data(); bool use_colwise_scaling = output->has_columnwise_data(); @@ -562,6 +571,11 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; + constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; + constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; +#else constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 
128 : 64; @@ -572,6 +586,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; constexpr size_t BUFF_DIM_Y = THREADS_Y; constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +#endif const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); @@ -589,6 +604,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; +#ifndef __HIP_PLATFORM_AMD__ ScalingType scaling_type; if (use_rowwise_scaling && (!use_colwise_scaling)) { scaling_type = ScalingType::ROWWISE; @@ -597,6 +613,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } else if (use_rowwise_scaling && use_colwise_scaling) { scaling_type = ScalingType::BIDIMENSIONAL; } +#endif if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); @@ -619,6 +636,26 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, +#ifdef __HIP_PLATFORM_AMD__ + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + !(cols % (32 * sizeof(IType))), IS_ALIGNED, + quantize_mxfp8_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? 
reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + ))); // NOLINT(*) +#else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_act_input{}; alignas(64) CUtensorMap tensor_map_output_rowwise{}; @@ -708,6 +745,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } } +#endif // #ifdef __HIP_PLATFORM_AMD__ if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh similarity index 89% rename from transformer_engine/common/util/rocm_dequantize_kernels.cuh rename to transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh index 398e4c0ad..02224a69f 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh @@ -5,26 +5,7 @@ ************************************************************************/ #pragma once - -#include -#include -#include -#include - -#include "common.h" -#include "math.h" -#include "ptx.cuh" -#include "rocm_vectorized_2d.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/cast.h" -#include "transpose/cast_transpose.h" -#include "transformer_engine/transpose.h" -#include "utils.cuh" -#include "vectorized_pointwise.h" - -namespace transformer_engine { - -namespace dequantization { +// drop-in rocm replacement for mxfp8 dequantize kernel constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; @@ -127,12 +108,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) 
__syncthreads(); - bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); } } -} // namespace dequantization -} // namespace transformer_engine + diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh similarity index 87% rename from transformer_engine/common/util/rocm_cast_gated_kernels.cuh rename to transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index a53fd51c5..7382b8aab 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -5,22 +5,7 @@ ************************************************************************/ #pragma once - -#include -#include -#include - -#include "common.h" -#include "math.h" -#include "ptx.cuh" -#include "rocm_vectorized_2d.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/cast.h" -#include "vectorized_pointwise.h" -#include "utils.cuh" - -namespace transformer_engine { -namespace gated_kernels { +// drop-in rocm replacement for mxfp8 gated quantize kernel constexpr size_t ALIGNMENT_SIZE = 128; // TODO: Identify optimal chunk/thread size for MI350+ @@ -45,16 +30,17 @@ template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_gated_kernel(const IType *grad_ptr, - const IType *input_act, - const IType *input_gate, - OType *output_act_rowwise, - OType *output_gate_rowwise, - OType *output_act_colwise, - OType *output_gate_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + quantize_gated_mxfp8_kernel( + const IType *grad_ptr, + const IType *input_act, + const IType *input_gate, + OType *output_act_rowwise, + 
OType *output_gate_rowwise, + OType *output_act_colwise, + OType *output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise, const ParamOP p) { constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; @@ -171,24 +157,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh[shmem_idx]); float gate_elt = static_cast(in_gate_sh[shmem_idx]); + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } + if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_idx]); const float x = act_elt; float act_x; float dact_x; - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = act_x * grad_elt; + after_dgate_reg[stage] = dgate_elt ? 
act_x * grad_elt : 0.0f; } else { - after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; + after_dact_reg[stage] = ActOP(act_elt, p) * gate_elt; } // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 @@ -355,24 +356,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } __syncthreads(); } } -} // namespace gated_kernels -} // namespace transformer_engine diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh similarity index 66% rename from transformer_engine/common/util/rocm_cast_kernels.cuh rename to transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index e39e0a4a7..dc36fb42d 100644 --- 
a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -4,28 +4,7 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #pragma once - -#include -#include -#include - -#include "common.h" -#include "math.h" -#include "ptx.cuh" -#include "rocm_vectorized_2d.cuh" -#include "transformer_engine/cast.h" -#include "transpose/cast_transpose.h" -#include "vectorized_pointwise.h" -#include "utils.cuh" - -namespace transformer_engine { - -// Forward declaration, definition is in cast_kernels.cuh -template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream); - +// drop-in replacement for rocm quantize_mxfp8 kernels constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; @@ -53,14 +32,15 @@ template __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const IType *input_ptr, - const IType *act_input_ptr, - OType *output_rowwise, - OType *output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + quantize_mxfp8_kernel( + const IType *input_ptr, + const IType *act_input_ptr, + OType *output_rowwise, + OType *output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { if (noop != nullptr && noop[0] == 1.0f) return; } @@ -310,12 +290,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) 
__syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } @@ -393,165 +373,3 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) atomicMaxFloat(amax_ptr, block_amax); } } - -// Forward declaration of functions defined in `cast_kernels.cuh` -template -void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, - cudaStream_t stream); - -template -void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream); - -template -void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, - cudaStream_t stream); - -constexpr size_t TILE_DIM = 32; -template -__global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_output, int rows, int cols) { - __shared__ float tile[TILE_DIM][TILE_DIM]; - - int tile_start_col = blockIdx.x * TILE_DIM; - int tile_start_row = blockIdx.y * TILE_DIM; - int thread_col_in_tile = threadIdx.x; - int thread_row_in_tile = threadIdx.y; - - int global_col = tile_start_col + thread_col_in_tile; - int global_row = tile_start_row + thread_row_in_tile; - - if (global_row < rows && global_col < cols) { - tile[thread_row_in_tile][thread_col_in_tile] = static_cast(input[global_row * cols + global_col]); - } else { - tile[thread_row_in_tile][thread_col_in_tile] = 0.0f; - } - __syncthreads(); - - for (int stride = TILE_DIM / 2; stride > 0; 
stride /= 2) { - if (thread_row_in_tile < stride) { - tile[thread_row_in_tile][thread_col_in_tile] += tile[thread_row_in_tile + stride][thread_col_in_tile]; - } - __syncthreads(); - } - - if (thread_row_in_tile == 0 && global_col < cols) { - partial_output[blockIdx.y * cols + global_col] = tile[0][thread_col_in_tile]; - } -} - -template -void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows, - const size_t cols, cudaStream_t stream, Tensor* partial_sum_workspace) { - dim3 block_dim_partial(TILE_DIM, TILE_DIM); - dim3 grid_dim_partial(DIVUP(cols, TILE_DIM), DIVUP(rows, TILE_DIM)); - - const size_t partial_rows = grid_dim_partial.y; - float* partial_workspace = reinterpret_cast(partial_sum_workspace->data.dptr); - - partial_reduce_kernel<<>>( - workspace_ptr, - partial_workspace, - rows, cols); - - reduce_dbias(partial_workspace, dbias, partial_rows, cols, stream); -} - -template -void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias, "DBias tensor must be provided when IS_DBIAS is true."); - NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true."); - if (workspace->data.dptr == nullptr) { - if constexpr (IS_DACT) { - const size_t partial_rows = DIVUP(rows, TILE_DIM); - size_t total_elements = (rows * cols) + (partial_rows * cols); - workspace->data.shape = {total_elements}; - workspace->data.dtype = DType::kFloat32; - } else { - workspace->data.shape = {rows, cols}; - workspace->data.dtype = DType::kFloat32; - } - return; - } - - const void *ptr_to_reduce = nullptr; - DType dtype_to_reduce; - - workspace->amax = {}; - workspace->scale = {}; - workspace->scale_inv = {}; - - Tensor 
workspace_buffer; - Tensor partial_sum_buffer; - - if constexpr (IS_DACT) { - // The values to reduce are the result of the dAct function. - NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT."); - - const size_t partial_rows = DIVUP(rows, TILE_DIM); - const size_t full_size_bytes = rows * cols * sizeof(float); - workspace_buffer = *workspace; - workspace_buffer.data.shape = {rows, cols}; - partial_sum_buffer.data.dptr = reinterpret_cast(workspace->data.dptr) + full_size_bytes; - partial_sum_buffer.data.shape = {partial_rows, cols}; - partial_sum_buffer.data.dtype = DType::kFloat32; - workspace = &partial_sum_buffer; - - CastVectorizedUnaryGradKernelLauncher(input, act_input, &workspace_buffer, stream); - if (output && output->data.dptr) { - CastVectorizedUnaryKernelLauncher(workspace_buffer, noop, output, stream); - } - ptr_to_reduce = workspace_buffer.data.dptr; - dtype_to_reduce = workspace_buffer.data.dtype; - } else { - if (output && output->data.dptr) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - // The values to reduce are just the input values. 
- ptr_to_reduce = input.data.dptr; - dtype_to_reduce = input.data.dtype; - } - - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias tensor."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - dbias->data.dtype, DBiasTypeOut, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - dtype_to_reduce, DTypeReduce, - reduce_dbias_rocm( - reinterpret_cast(ptr_to_reduce), - dbias, rows, cols, stream, workspace); - ); - ); - } else { - if (output && output->data.dptr) { - if constexpr (IS_DACT) { - NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output."); - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } else { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - } - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - - -} // namespace transformer_engine diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index af3a51373..ab574256c 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -230,13 +230,13 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, // Any element that is outside of bounds will be set to zero by the TMA transfer. 
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); } -#endif //#ifndef __HIP_PLATFORM_AMD__ bool is_supported_by_CC_100() { int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); return deviceComputeCapability >= 100; } +#endif //#ifndef __HIP_PLATFORM_AMD__ std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size) { diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ae90ea4e5..03b90febb 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -354,7 +354,8 @@ using fp8e8m0 = __nv_fp8_e8m0; #endif // CUDA_VERSION >= 12080 #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; -<<<<<<< HEAD +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif //FP4_TYPE_SUPPORTED #else using bf16 = hip_bfloat16; @@ -362,11 +363,6 @@ using fp8e4m3 = te_hip_fp8_e4m3; using fp8e5m2 = te_hip_fp8_e5m2; #endif //__HIP_PLATFORM_AMD__ -======= -using fp4e2m1x2 = __nv_fp4x2_e2m1; -using fp4e2m1x4 = __nv_fp4x4_e2m1; -#endif ->>>>>>> 389a6b using e8m0_t = uint8_t; namespace detail { @@ -416,15 +412,14 @@ template <> struct TypeExtrema { #ifndef __HIP_PLATFORM_AMD__ static constexpr float max = 448.0f; -<<<<<<< HEAD + static constexpr float max_inverse = 1.0 / max; #elif defined(__HIP_DEVICE_COMPILE__) - static constexpr float maxNorm = te_fp8_fnuz() ? 240.0f : 448.0f; + static constexpr float max = te_fp8_fnuz() ? 240.0f : 448.0f; + static constexpr float max_inverse = 1.0 / max; #else - static float maxNorm; + static float max; + static float max_inverse; #endif -======= - static constexpr float max_inverse = 1.0 / max; ->>>>>>> 389a6b }; template <> @@ -820,21 +815,15 @@ void checkCuDriverContext(CUstream stream); CUtensorMapDataType get_CUtensorMapDataType(DType dtype); // Set up parameters to create TMA descriptor. 
-<<<<<<< HEAD -void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, - const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, - const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits); -#endif //#ifdef __HIP_PLATFORM_AMD__ -======= void create_2D_tensor_map( CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits, const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); ->>>>>>> 389a6b bool is_supported_by_CC_100(); +#endif //#ifdef __HIP_PLATFORM_AMD__ + std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index bb5e22887..d39fccbce 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -276,9 +276,10 @@ void log_fused_attn_config( // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace 
transformer_engine; // by default, fused attn is enabled @@ -311,6 +312,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, @@ -325,6 +327,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, @@ -339,12 +342,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, bool return_max_logit, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + bool cuda_graph, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); @@ -384,9 +389,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, 
window_size_left, window_size_right); - + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, + cuda_graph); + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_qkvpacked( b, h, max_seqlen, d, @@ -416,15 +422,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } // NVTE fused attention BWD with packed QKV -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -468,8 +473,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); 
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)){ @@ -505,14 +510,17 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } // NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, 
+ NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -556,8 +564,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_kvpacked( @@ -596,11 +605,12 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const 
NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; @@ -649,9 +659,10 @@ void nvte_fused_attn_bwd_kvpacked( // fix the incompatible window size from upstream frameworks pytorch/jax std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, + d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -694,14 +705,16 @@ void nvte_fused_attn_bwd_kvpacked( // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor 
cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); @@ -740,8 +753,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd( @@ -780,14 +794,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const 
NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -830,8 +845,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, + cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index a8a151b40..1c25fa031 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -44,6 +44,7 @@ bool is_aotriton_backend_supported( NVTE_QKV_Layout qkv_layout, 
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, @@ -68,7 +69,10 @@ bool is_aotriton_backend_supported( if(!(is_no_mask_window_size || is_causal_mask_window_size)){ return false; } - + + if(softmax_type!=NVTE_VANILLA_SOFTMAX){ + return false; + } //aotriton fused attn does not support gqa mode now if(num_attn_heads!=num_gqa_groups){ return false; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index b016acc67..178bd8d8f 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -23,6 +23,7 @@ bool is_aotriton_backend_supported( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 7ca6fc95f..8d639c47c 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -26,6 +26,7 @@ bool is_ck_backend_supported( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, @@ -80,6 +81,14 @@ bool is_ck_backend_supported( return false; } + // filter based on softmax type + if(softmax_type!=NVTE_VANILLA_SOFTMAX){ + if(nvte_log_ck_config){ + std::cout<<"AITER/CK fused attn does not support learnable sink yet"<>>>>>> 389a6b +#endif #ifndef __HIP_PLATFORM_AMD__ namespace { @@ -333,7 +328,7 @@ namespace 
transformer_engine { void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void* workspace, size_t workspaceSize, - float alpha, float beta, bool use_split_accumulator, int math_sm_count, + const void* alpha_ptr, const void* beta_ptr, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset = -1); #else // Use cublasLt @@ -928,12 +923,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_atomic_gemm); using namespace transformer_engine; -<<<<<<< HEAD #ifndef __HIP_PLATFORM_AMD__ // Check CUDA and cuBLAS versions -======= ->>>>>>> 389a6b #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); @@ -951,7 +943,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor cublas_version() >= 120205 && cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", cublas_version()); -#endif //__HIP_PLATFORM_AMD__ +#endif const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); @@ -971,7 +963,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); -#endif +#endif //#ifndef __HIP_PLATFORM_AMD__ } void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, @@ -992,28 +984,24 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens } for (int i = 0; i < num_gemms; i++) { -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ - { - const Tensor *inputA = convertNVTETensorCheck(A[i]); - const Tensor *inputB = convertNVTETensorCheck(B[i]); - Tensor *outputD = convertNVTETensorCheck(D[i]); - const Tensor *biasTensor = convertNVTETensorCheck(bias[i]); - Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out[i]); - Tensor *wspace = convertNVTETensorCheck(workspace[i % num_streams]); - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], 1.0f, (accumulate) ? 1.0f : 0.0f, - use_split_accumulator, math_sm_count, 0, 0, false, nullptr, - detail::get_compute_stream(i % num_streams), i % num_streams); - } + const Tensor *inputA = convertNVTETensorCheck(A[i]); + const Tensor *inputB = convertNVTETensorCheck(B[i]); + Tensor *outputD = convertNVTETensorCheck(D[i]); + const Tensor *biasTensor = convertNVTETensorCheck(bias[i]); + Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out[i]); + Tensor *wspace = convertNVTETensorCheck(workspace[i % num_streams]); + + // Scales + const float alpha = 1; + const float beta = accumulate ? 1 : 0; + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, + (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, + wspace->data.dptr, wspace->data.shape[0], &alpha, &beta, + use_split_accumulator, math_sm_count, 0, 0, false, nullptr, + detail::get_compute_stream(i % num_streams), i % num_streams); #else - nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, - workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, - detail::get_compute_stream(i % num_streams)); -#endif -======= // Check whether GELU or dGELU epilogue is requested Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]); bool with_gelu_dgelu_epilogue = @@ -1038,7 +1026,7 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i], workspace[i % num_streams], &config, detail::get_compute_stream(i % num_streams)); ->>>>>>> 389a6b +#endif } // record events on compute streams diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index fef3966a5..97bd2e8a7 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1501,7 +1501,7 @@ void release_service_stream(hipStream_t stream, struct ServiceStreamCtl &ctl) void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, - float alpha, float beta, bool use_split_accumulator, int math_sm_count, + const void* alpha_ptr, const void* beta_ptr, bool use_split_accumulator, int math_sm_count, [[maybe_unused]] int m_split, [[maybe_unused]] int n_split, [[maybe_unused]] bool gemm_producer, [[maybe_unused]] const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset) @@ -1527,6 +1527,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int ldb = is_transb ? 
n : k; const int ldd = m; + float alpha = *reinterpret_cast(alpha_ptr); // Assumed to be on CPU + float beta = *reinterpret_cast(beta_ptr); // Assumed to be on CPU + ServiceStreamCtl ss_ctl; bool use_service_stream = (math_sm_count != 0) ? get_service_stream(math_sm_count, stream, ss_ctl) : false; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index e4a86698c..158d8ea5d 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -307,15 +307,12 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, /*! \brief Compute the backward of the dot product attention with packed QKV input. * -<<<<<<< HEAD + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * Support Matrix for ROCm AOTriton: \verbatim | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | aotriton| FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO | NO/CAUSAL | Yes | arbitrary | arbitrary | \endverbatim -======= - * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. ->>>>>>> 389a6b * * Support Matrix: \verbatim @@ -462,15 +459,12 @@ void nvte_fused_attn_fwd_kvpacked( /*! \brief Compute the backward of the dot product attention with packed KV input. * -<<<<<<< HEAD + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * Support Matrix for ROCm AOTriton: \verbatim | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | aotriton| FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO | NO/CAUSAL | Yes | arbitrary | arbitrary | \endverbatim -======= - * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. 
->>>>>>> 389a6b * * Support Matrix: \verbatim diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index cfcd91646..a5278522c 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -29,7 +29,7 @@ #ifndef __HIP_PLATFORM_AMD__ #include "../cudnn_utils.h" #else -#include "../util/rocm_cast_kernels.cuh" +#include "../cast/mxfp8/quantize_mxfp8.cuh" #endif #include "../util/system.h" @@ -447,10 +447,10 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params) const size_t scale_dim_X_rowwise = 32; const size_t scale_dim_Y_colwise = launch_params.training ? 32 : 1; - const size_t chunks_Y = DIVUP(rows, transformer_engine::MXFP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, transformer_engine::MXFP8_CHUNK_DIM_X); - const size_t blocks_Y = DIVUP(chunks_Y, transformer_engine::MXFP8_CHUNKS_PER_BLOCK_Y); - const size_t blocks_X = DIVUP(chunks_X, transformer_engine::MXFP8_CHUNKS_PER_BLOCK_X); + const size_t chunks_Y = DIVUP(rows, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_X); const size_t scale_stride_rowwise = launch_params.z_tensor->scale_inv.shape[1]; const size_t scale_stride_colwise = launch_params.training ? launch_params.z_tensor->columnwise_scale_inv.shape[1] : 1; @@ -459,17 +459,18 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params) e8m0_t *const scales_colwise_ptr = launch_params.training ? 
reinterpret_cast(launch_params.z_tensor->columnwise_scale_inv.dptr) : nullptr; - const dim3 block(transformer_engine::MXFP8_THREADS_PER_CHUNK); + const dim3 block(dispatch::mxfp8::quantize_kernel::MXFP8_THREADS_PER_CHUNK); const dim3 grid(blocks_X, blocks_Y); + using namespace dispatch::mxfp8::quantize_kernel; TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( scale_dim_Y_colwise, SCALE_DIM_Y, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( launch_params.z_tensor->dtype(), OType, TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(compute_t))), IS_ALIGNED, - cast_mxfp8_2D_kernel<<>>( + quantize_mxfp8_kernel<<>>( reinterpret_cast(launch_params.params.z), nullptr, reinterpret_cast(launch_params.z_tensor->data.dptr), diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index cec7da248..6c21eab7b 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -65,12 +65,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; -<<<<<<< HEAD #ifndef __HIP_PLATFORM_AMD__ - bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); -======= bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); ->>>>>>> 389a6b if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 54815851d..598e0ca08 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -51,12 +51,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend norm_backend; bool is_aligned = true; -<<<<<<< HEAD #ifndef __HIP_PLATFORM_AMD__ - bool 
cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); -======= bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); ->>>>>>> 389a6b if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 7c6055629..c55f1f612 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -8,15 +8,10 @@ from __future__ import annotations import os from enum import Enum -<<<<<<< HEAD -from typing import Optional, Union, Callable, NamedTuple -from typing_extensions import Literal -======= from typing import Any, Literal, Optional, Union, Callable, NamedTuple from dataclasses import field ->>>>>>> 389a6b from pydantic.dataclasses import dataclass -from transformer_engine.common import is_fp8_fnuz +from transformer_engine.common import is_fp8_fnuz, te_rocm_build class _FormatHelper(NamedTuple): @@ -58,17 +53,12 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ -<<<<<<< HEAD - E4M3 = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) - E5M2 = _FormatHelper(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) - HYBRID = _FormatHelper(fwd=E4M3.fwd, bwd=E5M2.bwd) -======= - - E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) - E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) - E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) + #TODO: bring E2M1 back after rocm support MXFP4 + if not te_rocm_build: + E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) + E4M3 = _FormatHelper(max_fwd=_FormatMaxVals.E4M3.value, max_bwd=_FormatMaxVals.E4M3.value) + E5M2 = _FormatHelper(max_fwd=_FormatMaxVals.E5M2.value, max_bwd=_FormatMaxVals.E5M2.value) HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) ->>>>>>> 389a6b @dataclass(frozen=True) diff --git 
a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index af8eaaf67..69b44494b 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -28,7 +28,6 @@ using bf16__ = __hip_bfloat16; constexpr int amax_kernel_threads = 512; -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ template @@ -52,7 +51,6 @@ __global__ void amax_final_reduce(const float* __restrict__ block_amax, #endif -======= __launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) { if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { return; @@ -60,7 +58,6 @@ __launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const flo *amax_ptr = 0; } ->>>>>>> 389a6b template __launch_bounds__(amax_kernel_threads) __global__ void amax_kernel(const InputType *input, float *amax, @@ -280,19 +277,13 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt float *amax_ptr = reinterpret_cast( (output.amax.dptr != nullptr) ? 
output.amax.dptr : output.columnwise_amax.dptr); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( -<<<<<<< HEAD input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); - launch_amax_kernel(reinterpret_cast(input.data.dptr), - reinterpret_cast(output.amax.dptr), input.data.numel(), + launch_amax_kernel( + reinterpret_cast(input.data.dptr), amax_ptr, input.data.numel(), #ifdef __HIP_PLATFORM_AMD__ - block_amax, block_capacity, + block_amax, block_capacity, #endif - noop_ptr, stream);); // NOLINT(*) -======= - input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel( - reinterpret_cast(input.data.dptr), amax_ptr, input.data.numel(), noop_ptr, - stream);); // NOLINT(*) ->>>>>>> 389a6b + noop_ptr, stream);); // NOLINT(*) } } // anonymous namespace diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index d0a7cb85e..881b134e7 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -20,18 +20,14 @@ namespace transformer_engine { namespace { -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ #define __ldg(x) (*(x)) #endif #ifndef __HIP_PLATFORM_AMD__ -constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; -======= constexpr int MXFP8_BLOCK_SIZE = 32; constexpr int NVFP4_BLOCK_SIZE = 16; ->>>>>>> 389a6b constexpr __device__ __host__ int TB_DIM = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; @@ -376,138 +372,11 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; // 1D block scaling, row-wise or colum-wise -<<<<<<< HEAD - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - const int m = - input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; - const int k = - input->has_data() ? 
input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; - - constexpr int SF_TILE_DIM_M = 128; - constexpr int SF_TILE_DIM_K = 4; - - NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); - NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); - NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); - if (output->has_data()) { - NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), - output->scale_inv.shape.end(), 1, std::multiplies()), - "Input.scale_inv size is not equal to Output.scale_inv size!"); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), - output->columnwise_scale_inv.shape.end(), 1, - std::multiplies()), - "Input.columnwise_scale_inv size is not equal to " - "Output.columnwise_scale_inv size!"); - } - - int num_tiles_m = m / SF_TILE_DIM_M; - int num_tiles_k = k / SF_TILE_DIM_K; - - dim3 block_size(TB_DIM, TB_DIM); - if (input->has_data()) { - int vec_load_size = (num_tiles_k - 1) % 4 + 1; - /* there is no int3 and misaligned if using int4/int2 */ - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_first_dim(); - const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 2: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - 
swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 1: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); - } - if (input->has_columnwise_data()) { - int vec_load_size = (num_tiles_m - 1) % 4 + 1; - if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_last_dim(); - const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 2: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 1: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - 
output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); - } - // 2D block scaling -======= int m, k; if (input->has_data()) { m = input->scale_inv.shape[0]; k = input->scale_inv.shape[1]; ->>>>>>> 389a6b } else { if (nvfp4) { m = input->columnwise_scale_inv.shape[0]; diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh deleted file mode 100644 index b7c4cf837..000000000 --- a/transformer_engine/common/util/cast_kernels.cuh +++ /dev/null @@ -1,1546 +0,0 @@ -/************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file cast_kernels.cuh - * \brief CUDA kernels to cast to/from FP8/MXFP8. 
- */ - -#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ - -#include -#ifndef __HIP_PLATFORM_AMD__ -#include -#endif //#ifndef __HIP_PLATFORM_AMD__ -#include -#include - -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/transformer_engine.h" -#ifdef __HIP_PLATFORM_AMD__ -#include "rocm_cast_kernels.cuh" -#endif - -namespace transformer_engine { - -#ifndef __HIP_PLATFORM_AMD__ -namespace mxfp8_kernel { - -constexpr size_t SCALE_DIM_Y = 32; -constexpr size_t SCALE_DIM_X = 32; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t PACK_SIZE = 4; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; - -// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; - constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; - - if constexpr (NO_ACTIVATIONS) { - if 
(noop != nullptr && noop[0] == 1.0f) { - return; - } - } - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; - static_assert(BUFF_DIM_Y == 32); - - constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; - static_assert(STAGES >= 1); - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - - const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; - const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X; - const size_t tid_Y_colwise = 0; - const size_t tid_X_colwise = threadIdx.x; - - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t thread_offset_Y_colwise = tid_Y_colwise; - const size_t thread_offset_X_colwise = tid_X_colwise; - - const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t 
scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - // helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! 
- uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - OType *out_rowwise_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } - - float block_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); - - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - // 1. Read/Compute elements. 
Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - - // 2. 
Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -// 3. Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } - - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. 
Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - 
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const size_t stage_scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - - // 3. Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); - } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - // Wait for shared memory 
writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t buff_offset = buff * BUFF_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_sh[buff_offset])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - } - } - - parity ^= 1; - - if constexpr (IS_DBIAS) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); - - constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - - const size_t shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const size_t shmem_elt_idx = swizzled_group_offset + e; - 
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; - } - } - __syncthreads(); -#pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element offset per MXFP8 scaling block [1x32] - const size_t scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; - } - } - const size_t dbias_stride = cols; - const size_t dbias_offset_Y = blockIdx.y; - const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; - const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); - } - - if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace mxfp8_kernel - -constexpr size_t FP8_CHUNK_DIM_Y = 128; -constexpr size_t FP8_CHUNK_DIM_X = 128; -constexpr size_t FP8_THREADS_PER_CHUNK = 128; -constexpr size_t FP8_BUFFERS_NUM = 2; -constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); - -constexpr size_t FP8_BUFFER_DIM_Y = 16; -constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 -constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 - -constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 -static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); - -template -__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) - cast_fp8_2D_kernel(const 
__grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output, - float *const dbias_workspace, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, - const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; - const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - const size_t dbias_offset_Y = blockIdx.y + tid_Y; - const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; - const bool col_out_of_bounds = my_column >= cols; - const size_t dbias_stride = cols; - - float partial_dbias = 0.f; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - - constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. 
-#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - const size_t chunk_offset_Y = block_offset_Y; - const size_t chunk_offset_X = block_offset_X; - -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; - const size_t chunk_stage_offset_X = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); - } - } - -#pragma unroll - for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { - const size_t buff = iter % FP8_BUFFERS_NUM; - const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; - if (next_iter < FP8_ITERATIONS) { - const size_t next_buff = next_iter % FP8_BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); - } - } - - // Wait for the data to have arrived - 
ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = row >= rows; - const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; - - float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if constexpr (IS_DACT) { - if (!out_of_bounds) { - partial_dbias += elt; - } - } else { - // If no activation, elt is 0 so we can safely do this - partial_dbias += elt; - } - } - __builtin_assume(amax >= 0); - if (IS_DACT) { - if (!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); - } - out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. 
- ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - parity ^= 1; - - if constexpr (IS_DBIAS) { - const size_t dbias_offset_X = my_column; - const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -constexpr size_t CHUNKS_PER_BLOCK = 128; -constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; -constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; -constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; -constexpr size_t CHUNKS_PER_ITERATION = 32; -constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; -constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; -constexpr size_t SHMEM_BUFFERS = 2; -static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); - -template -__global__ void __launch_bounds__(THREADS_PER_BLOCK) - cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; - const IType *input = input_ptr + block_offset; - OType *output = output_ptr + block_offset; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; - - constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; - constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); - -#pragma unroll - for (int iter = 0; iter < ITERATIONS; ++iter) { - const size_t buff = iter % SHMEM_BUFFERS; - const size_t it_offset = iter * SHMEM_DIM; - - const size_t next_iter = iter + 1; - const size_t next_buff = next_iter % SHMEM_BUFFERS; - const size_t next_iter_offset = next_iter * SHMEM_DIM; - - if (next_iter < ITERATIONS) { - copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, - &(mbar[next_iter]), is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { - const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; - float elt = static_cast(in_sh[buff][shmem_offset]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(elt)); - out_sh[buff][shmem_offset] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. 
- ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - ptx::cp_async_bulk_tensor_1d_shared_to_global( - reinterpret_cast(output + it_offset), - reinterpret_cast(&out_sh[buff]), transaction_size_OUT); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read<1>(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -#endif // #ifndef __HIP_PLATFORM_AMD__ - -constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; -template -__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) - reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, - const size_t rows, const size_t cols) { - using ComputeVec = Vec; - using OutputVec = Vec; - - const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; - - if (thread_id * nvec >= cols) { - return; - } - - const float *const thread_in_base = dbias_partial + thread_id * nvec; - OType *const thread_out_base = dbias_output + thread_id * nvec; - - ComputeVec ldg_vec; - ComputeVec acc_vec; - acc_vec.clear(); - for (int i = 0; i < rows; ++i) { - ldg_vec.load_from(thread_in_base + i * cols); -#pragma unroll - for (int e = 0; e < nvec; ++e) { - acc_vec.data.elt[e] += 
ldg_vec.data.elt[e]; - } - } - - OutputVec stg_vec; -#pragma unroll - for (int e = 0; e < nvec; ++e) { - stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); - } - stg_vec.store_to(thread_out_base); -} - -template -void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, - cudaStream_t stream) { - constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 - constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); - - NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); - const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); - - reduce_dbias_kernel - <<>>( - reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -#ifndef __HIP_PLATFORM_AMD__ -template -static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { - const size_t N = product(input.data.shape); - - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - NVTE_CHECK(isFullTile, "Only full tiles are supported."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - const size_t chunks = DIVUP(N, CHUNK_SIZE); - const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - const float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(THREADS_PER_BLOCK); - const dim3 grid(blocks); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - const IType *input_ptr = reinterpret_cast(input.data.dptr); - OType *output_ptr = reinterpret_cast(output->data.dptr); - - cast_fp8_1D_kernel<<>>( - input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) - ); // 
NOLINT(*) - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, - Tensor *workspace, cudaStream_t stream) { - checkCuDriverContext(stream); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); - const size_t blocks_Y = chunks_Y; - const size_t blocks_X = chunks_X; - - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - float *const workspace_ptr = IS_DBIAS ? 
reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(FP8_THREADS_PER_CHUNK); - const dim3 grid(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->data.dtype, OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - } - - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); - - cast_fp8_2D_kernel - <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, - workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - NVTE_CHECK_CUDA(cudaGetLastError()); - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} -#endif // #ifndef __HIP_PLATFORM_AMD__ - -template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, - const Tensor *noop, // TODO (ksivamani) - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { -#ifndef __HIP_PLATFORM_AMD__ - using namespace mxfp8_kernel; - checkCuDriverContext(stream); -#endif - - bool use_rowwise_scaling = output->has_data(); - bool use_colwise_scaling = output->has_columnwise_data(); - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - 
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - - if (use_rowwise_scaling) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - } - if (use_colwise_scaling) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated"); - } - CheckNoopTensor(*noop, "cast_noop"); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - -#ifdef __HIP_PLATFORM_AMD__ - constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; - constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; - constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; -#else - constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); - - constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; - - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; -#endif - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_PER_CHUNK; - - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - - e8m0_t *const scales_rowwise_ptr = - use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; - e8m0_t *const scales_colwise_ptr = - use_colwise_scaling ? 
reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - -#ifndef __HIP_PLATFORM_AMD__ - ScalingType scaling_type; - if (use_rowwise_scaling && (!use_colwise_scaling)) { - scaling_type = ScalingType::ROWWISE; - } else if ((!use_rowwise_scaling) && use_colwise_scaling) { - scaling_type = ScalingType::COLWISE; - } else if (use_rowwise_scaling && use_colwise_scaling) { - scaling_type = ScalingType::BIDIMENSIONAL; - } -#endif - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - - float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, -#ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - cast_mxfp8_2D_kernel - <<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? 
reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - ))); // NOLINT(*) -#else // #ifdef __HIP_PLATFORM_AMD__ - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, - cols, 0, input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? 
buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - switch (scaling_type) { - case ScalingType::ROWWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::COLWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::BIDIMENSIONAL: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } -#endif // #ifdef __HIP_PLATFORM_AMD__ - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); 
// NOLINT(*) -} - -namespace detail { - -using Empty = transformer_engine::Empty; - -__device__ inline float identity(float value, const Empty &) { return value; } - -struct DequantizeParam { - const float *scale_inv; -}; - -__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { - return value * (*(param.scale_inv)); -} - -} // namespace detail - -/* HIPCC has strict rules for __device__ functions usage on host. - It forbids not only calling but also other ODR-use assigning to variables - https://github.com/llvm/llvm-project/issues/105825 - Use templated struct wrapper to work around - */ -template -struct ActivationType -{ - static constexpr auto op = OP; -}; - -template -void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { -#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; -#else //#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; -#endif //#ifdef __HIP_PLATFORM_AMD__ - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(noop->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, - cudaStream_t stream) { -#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; -#else //#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; -#endif //#ifdef __HIP_PLATFORM_AMD__ - const size_t N = product(input->data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input->data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace { - -#ifndef __HIP_PLATFORM_AMD__ -static bool is_full_tile_1D_tensor(const Tensor *const t) { - const size_t N = product(t->data.shape); - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - return isFullTile; -} - -bool dimensions_supported_by_TMA(const Tensor *const t) { - const size_t cols = t->flat_last_dim(); - constexpr size_t TMA_bytes = 16; - const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); - return cols % alignment_requirement == 0; -} -#endif //#ifndef __HIP_PLATFORM_AMD__ - -} // namespace - -#ifndef __HIP_PLATFORM_AMD__ -// Supported by the Arch >= 10.0 -template -void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DBIAS && !IS_DACT) { - if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 - cast_fp8_1D(input, output, stream); - } else { - // Unaligned - 
CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - } else if (!IS_DBIAS && IS_DACT) { - if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 (+dAct) - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } else { - // Unaligned - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - } else { - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - -// Supported by the Arch < 10.0 -template -void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { - // zhongboz: should we just ignore IS_ACT here? 
- NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + - " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); - } - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DACT) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } else { - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} -#endif //#ifndef __HIP_PLATFORM_AMD__ - -template -void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, - Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - CheckNoopTensor(*noop, "cast_noop"); - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias != nullptr); - CheckOutputTensor(*dbias, "dbias"); - } - if constexpr (IS_DACT) { - NVTE_CHECK(act_input != nullptr); - CheckInputTensor(*act_input, "activation_input"); - NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); - } - - NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - -#ifndef __HIP_PLATFORM_AMD__ - // NVIDIA - // Supported by the Arch >= 10.0 - if (is_supported_by_CC_100()) { - fp8_quantize_arch_ge_100(input, act_input, noop, output, - dbias, workspace, stream); - } else { // Supported by the Arch < 10.0 - fp8_quantize_arch_l_100(input, act_input, noop, output, - dbias, workspace, stream); - } -#else - // AMD - fp8_quantize_rocm(input, act_input, noop, output, - dbias, workspace, stream); -#endif //#ifndef __HIP_PLATFORM_AMD__ -} - -namespace detail { - -template -void quantize_helper(const NVTETensor 
input, const NVTETensor grad, NVTETensor output, - NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - const Tensor *input_tensor; - const Tensor *activation_input_tensor; - if constexpr (IS_DBIAS || IS_DACT) { - // backward - input is incoming gradient - input_tensor = convertNVTETensorCheck(grad); - activation_input_tensor = convertNVTETensor(input); - } else { - // forward = input is activation input - input_tensor = convertNVTETensorCheck(input); - activation_input_tensor = nullptr; - } - auto output_tensor = convertNVTETensorCheck(output); - auto dbias_tensor = convertNVTETensor(dbias); - auto workspace_tensor = convertNVTETensor(workspace); - - const QuantizationConfig *quant_config_cpp = - reinterpret_cast(quant_config); - - // extract noop tensor from quant_config_cpp if it's not null - const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; - const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); - - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - } else if (output_tensor->has_data()) { - fp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - break; - } -#ifndef __HIP_PLATFORM_AMD__ - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, 
IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor.data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - bool rowwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; - rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT - : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; - columnwise_option = columnwise_compact - ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT - : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor.data, stream); - break; - } -#endif - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } -} - -} // namespace detail -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index c83322f93..09187069e 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -18,7 +18,9 @@ #endif // __HIP_PLATFORM_AMD__ #include +#ifndef __HIP_PLATFORM_AMD__ #include "nccl.h" +#endif //#ifndef __HIP_PLATFORM_AMD__ #ifdef NVTE_WITH_CUBLASMP #include @@ -123,6 +125,7 @@ #endif // NVTE_WITH_CUBLASMP +#ifndef __HIP_PLATFORM_AMD__ #define NVTE_CHECK_NCCL(expr) \ do { \ const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ @@ -130,5 +133,5 @@ NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \ } \ } while (false) - +#endif //#ifndef __HIP_PLATFORM_AMD__ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 4f0b888c5..ef53c2670 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * * See LICENSE for license information. @@ -21,11 +21,15 @@ #endif // CUDA_VERSION >= 12080 #include "common/utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ +#include "../util/vectorized_pointwise.h" +#endif //#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace ptx { +#ifndef __HIP_PLATFORM_AMD__ template struct ArchSpecific { constexpr static int id = N * 10; @@ -125,6 +129,8 @@ constexpr bool is_supported_arch() { #define ARCH_HAS_STOCHASTIC_ROUNDING \ NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>) +#endif //#ifndef __HIP_PLATFORM_AMD__ + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -259,26 +265,8 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { } __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { -<<<<<<< HEAD -#ifdef __HIP_PLATFORM_AMD__ -#define __CUDA_ARCH_HAS_FEATURE__(x) 0 -#endif -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) - uint16_t out; - asm volatile( - "{\n" - "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" - "}" - : "=h"(out) - : "f"(val)); - return *reinterpret_cast(&out); -#else - // TODO: nan/inf needs to be set for any value - // of nan/inf in input not just amax. - if (isnan(val)) { - return 0xFF; -======= +#ifndef __HIP_PLATFORM_AMD__ + constexpr bool is_blackwell = false; constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; if constexpr (is_blackwell) { uint16_t out; @@ -290,6 +278,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { : "f"(val)); return *reinterpret_cast(&out); } else { +#endif //#ifndef __HIP_PLATFORM_AMD__ // TODO: nan/inf needs to be set for any value // of nan/inf in input not just amax. 
if (isnan(val)) { @@ -309,8 +298,9 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { ++exponent; } return exponent; ->>>>>>> 389a6b +#ifndef __HIP_PLATFORM_AMD__ } +#endif //#ifndef __HIP_PLATFORM_AMD__ } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -328,6 +318,38 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_ #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +#ifdef __HIP_PLATFORM_AMD__ +template +__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col, + size_t g_start_row, size_t g_stride, size_t chunk_dim_y, + size_t chunk_dim_x, size_t total_rows, + size_t total_cols) { + const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; + const size_t l_idx = threadIdx.x; + + for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { + size_t l_y = (i_vec / chunk_dim_x_vec_elements); + size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); + + size_t g_row = g_start_row + l_y; + size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; + + const T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedLoader shared_loader(current_sh_row_base_ptr, chunk_dim_x); + + T* current_g_row_base_ptr = g_ptr + g_row * g_stride; + VectorizedStorer global_storer(current_g_row_base_ptr, total_cols); + + shared_loader.load(l_x_vec, chunk_dim_x); + + if (g_row < total_rows) { + global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; + global_storer.store(g_col_primitive_start / N_VEC, total_cols); + } + } +} +#endif //#ifdef __HIP_PLATFORM_AMD__ + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( @@ -909,6 +931,47 @@ 
__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#ifdef __HIP_PLATFORM_AMD__ +// These 2d copy functions replace TMA tensormap async copies for AMD GPUs. +template +__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, + size_t g_start_row, size_t g_stride, size_t chunk_dim_y, + size_t chunk_dim_x, size_t total_rows, + size_t total_cols) { + size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; + const size_t l_idx = threadIdx.x; + + for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { + size_t l_y = (i_vec / chunk_dim_x_vec_elements); + size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); + + size_t g_row = g_start_row + l_y; + size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; + + if (g_row < total_rows) { + const T* current_g_row_base_ptr = g_ptr + g_row * g_stride; + VectorizedLoaderglobal_loader(current_g_row_base_ptr, total_cols); + + T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedStorershared_storer(current_sh_row_base_ptr, chunk_dim_x); + + global_loader.load(g_col_primitive_start / N_VEC, total_cols); + shared_storer.storage_.scratch_ = global_loader.storage_.scratch_; + shared_storer.store(l_x_vec, chunk_dim_x); + + } else { + T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedStorer shared_storer(current_sh_row_base_ptr, chunk_dim_x); + +#pragma unroll + for (int i = 0; i < N_VEC; ++i) { + shared_storer.separate()[i] = static_cast(0); + } + shared_storer.store(l_x_vec, chunk_dim_x); + } + } +} +#else __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, const size_t chunk_Y, const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) { @@ -929,6 +992,7 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co 
NVTE_DEVICE_ERROR("copy_2d_to_shared is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#endif //#ifdef __HIP_PLATFORM_AMD__ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, diff --git a/transformer_engine/common/util/rocm_vectorized_2d.cuh b/transformer_engine/common/util/rocm_vectorized_2d.cuh index 5877ddd87..eda0f437f 100644 --- a/transformer_engine/common/util/rocm_vectorized_2d.cuh +++ b/transformer_engine/common/util/rocm_vectorized_2d.cuh @@ -9,73 +9,5 @@ #include "../util/vectorized_pointwise.h" namespace transformer_engine { -// These 2d copy functions replace TMA tensormap async copies for AMD GPUs. -template -__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, - size_t g_start_row, size_t g_stride, size_t chunk_dim_y, - size_t chunk_dim_x, size_t total_rows, - size_t total_cols) { - size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; - const size_t l_idx = threadIdx.x; - for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { - size_t l_y = (i_vec / chunk_dim_x_vec_elements); - size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); - - size_t g_row = g_start_row + l_y; - size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - - if (g_row < total_rows) { - const T* current_g_row_base_ptr = g_ptr + g_row * g_stride; - VectorizedLoaderglobal_loader(current_g_row_base_ptr, total_cols); - - T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedStorershared_storer(current_sh_row_base_ptr, chunk_dim_x); - - global_loader.load(g_col_primitive_start / N_VEC, total_cols); - shared_storer.storage_.scratch_ = global_loader.storage_.scratch_; - shared_storer.store(l_x_vec, chunk_dim_x); - - } else { - T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedStorer 
shared_storer(current_sh_row_base_ptr, chunk_dim_x); - -#pragma unroll - for (int i = 0; i < N_VEC; ++i) { - shared_storer.separate()[i] = static_cast(0); - } - shared_storer.store(l_x_vec, chunk_dim_x); - } - } -} - -template -__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col, - size_t g_start_row, size_t g_stride, size_t chunk_dim_y, - size_t chunk_dim_x, size_t total_rows, - size_t total_cols) { - const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; - const size_t l_idx = threadIdx.x; - - for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { - size_t l_y = (i_vec / chunk_dim_x_vec_elements); - size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); - - size_t g_row = g_start_row + l_y; - size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - - const T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedLoader shared_loader(current_sh_row_base_ptr, chunk_dim_x); - - T* current_g_row_base_ptr = g_ptr + g_row * g_stride; - VectorizedStorer global_storer(current_g_row_base_ptr, total_cols); - - shared_loader.load(l_x_vec, chunk_dim_x); - - if (g_row < total_rows) { - global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; - global_storer.store(g_col_primitive_start / N_VEC, total_cols); - } - } -} } // namespace transformer_engine diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 7e150ed6f..c56242d34 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -53,8 +53,6 @@ constexpr uint32_t THREADS_PER_WARP = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// -<<<<<<< HEAD -======= // Device-side error #define NVTE_DEVICE_ERROR(message) \ do { \ @@ -88,7 +86,6 @@ inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*) 
//////////////////////////////////////////////////////////////////////////////////////////////////// ->>>>>>> 389a6b template struct Sum { inline __device__ Sum() {} From 0519b4ba1298f7b599c7a7e8330fb88dbdf4a9bb Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Tue, 10 Feb 2026 11:07:06 -0600 Subject: [PATCH 02/41] [ROCm] resolve the conflicts on jax side --- build_tools/jax.py | 2 +- .../jax/cpp_extensions/activation.py | 37 +--------- transformer_engine/jax/cpp_extensions/base.py | 31 ++------- transformer_engine/jax/cpp_extensions/misc.py | 6 -- .../jax/cpp_extensions/normalization.py | 8 +-- .../jax/cpp_extensions/quantization.py | 18 ----- transformer_engine/jax/csrc/extensions.h | 5 +- .../jax/csrc/extensions/attention.cpp | 67 ------------------- .../jax/csrc/extensions/cgemm_helper.cpp | 4 ++ .../jax/csrc/extensions/gemm.cpp | 18 +++-- transformer_engine/jax/csrc/extensions/misc.h | 2 + .../jax/csrc/extensions/pybind.cpp | 16 ++--- transformer_engine/jax/quantize/helper.py | 31 +-------- transformer_engine/jax/setup.py | 6 +- 14 files changed, 41 insertions(+), 210 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 7886b8ba2..e67036f49 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -105,7 +105,7 @@ def setup_jax_extension( sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, - libraries=["nccl"], + libraries=["nccl"] if not rocm_build() else [], ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 5897c2a74..df148265d 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,14 +11,8 @@ import jax import jax.numpy as jnp -<<<<<<< HEAD -from jax import dtypes -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule -======= from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING ->>>>>>> 389a6b from jax.sharding import PartitionSpec import numpy as np @@ -579,15 +573,6 @@ def shardy_sharding_rule( value_types, result_types, ): -<<<<<<< HEAD - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - prefix = "ActLuPrimitive_" - x_rank = len(value_types[0].shape) - scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 -======= del ( out_dtype, act_enum, @@ -600,7 +585,6 @@ def shardy_sharding_rule( is_outer, mesh, result_types, ->>>>>>> 389a6b ) prefix = "ActLu" input_shape = value_types[0].shape @@ -1134,25 +1118,6 @@ def shardy_sharding_rule( value_types, result_types, ): -<<<<<<< HEAD - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - prefix = "BaseDActLuDBiasQuantizePrimitive_" - scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 - ) - x_axes = scale_rules.input_spec - dz_axes = (*x_axes[:-2], x_axes[-1]) - out = x_axes - colwise_out = (prefix + "out_colwise",) - if is_2x: - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) - 
else: - colwise_out = out -======= ->>>>>>> 389a6b del ( out_dtype, diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 6c53317af..176e0eadc 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -17,11 +17,7 @@ from jax._src import dispatch from jax import ffi -<<<<<<< HEAD from .misc import is_hip_extension -import jax -======= ->>>>>>> 389a6b import transformer_engine_jax @@ -193,24 +189,14 @@ def register_primitive(cls, outer_only=False): def name_of_wrapper_p(): return cls.name + "_wrapper" -<<<<<<< HEAD - inner_p = core.Primitive(cls.name) - dispatch.prim_requires_devices_during_lowering.add(inner_p) - inner_p.multiple_results = cls.multiple_results - inner_p.def_impl(partial(xla.apply_primitive, inner_p)) - inner_p.def_abstract_eval(cls.abstract) - mlir.register_lowering(inner_p, cls.lowering, platform="rocm" if is_hip_extension() else "cuda") - cls.inner_primitive = inner_p -======= if not outer_only: inner_p = core.Primitive(cls.name) dispatch.prim_requires_devices_during_lowering.add(inner_p) inner_p.multiple_results = cls.multiple_results inner_p.def_impl(partial(xla.apply_primitive, inner_p)) inner_p.def_abstract_eval(cls.abstract) - mlir.register_lowering(inner_p, cls.lowering, platform="cuda") + mlir.register_lowering(inner_p, cls.lowering, platform="rocm" if is_hip_extension() else "cuda") cls.inner_primitive = inner_p ->>>>>>> 389a6b outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) @@ -219,16 +205,11 @@ def name_of_wrapper_p(): outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) - if version.parse(jax.__version__) >= version.parse("0.5.0"): - outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, - 
partition=cls.partition, - sharding_rule=cls.shardy_sharding_rule, - ) - else: - outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition - ) + outer_p_lower.def_partition( + infer_sharding_from_operands=cls.infer_sharding_from_operands, + partition=cls.partition, + sharding_rule=cls.shardy_sharding_rule, + ) mlir.register_lowering( outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 9443262c8..6c4be68ec 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -74,16 +74,10 @@ def jax_dtype_to_te_dtype(jax_dtype): jnp.bfloat16.dtype: TEDType.kBFloat16, jnp.int32.dtype: TEDType.kInt32, jnp.int64.dtype: TEDType.kInt64, -<<<<<<< HEAD get_jnp_float8_e4m3_type().dtype: TEDType.kFloat8E4M3, get_jnp_float8_e5m2_type().dtype: TEDType.kFloat8E5M2, - jnp.uint8.dtype: TEDType.kByte, -======= - jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3, - jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2, jnp.float8_e8m0fnu.dtype: TEDType.kFloat8E8M0, jnp.float4_e2m1fn.dtype: TEDType.kFloat4E2M1, ->>>>>>> 389a6b } if jax_dtype not in converter: diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 37b5b077b..1bf6ec943 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -12,14 +12,8 @@ import jax import jax.numpy as jnp -<<<<<<< HEAD -from jax import dtypes -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule -======= from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING ->>>>>>> 389a6b from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec from .misc import is_hip_extension diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index fd7a101c8..bd2176170 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -12,14 +12,8 @@ import jax import jax.numpy as jnp -<<<<<<< HEAD -from jax import dtypes -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule -======= from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING ->>>>>>> 389a6b from jax.sharding import PartitionSpec import transformer_engine_jax @@ -639,17 +633,6 @@ def shardy_sharding_rule( value_types, result_types, ): -<<<<<<< HEAD - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - del out_dtype, scale_dtype, is_outer, mesh, result_types - - prefix = "BaseDBiasQuantizePrimitive_" - scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), - unique_var=prefix + "x", - flatten_axis=flatten_axis, -======= del ( out_dtype, scale_dtype, @@ -658,7 +641,6 @@ def shardy_sharding_rule( use_rht, mesh, result_types, ->>>>>>> 389a6b ) prefix = "DBiasQuantize" diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 396d7c089..845176080 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ 
b/transformer_engine/jax/csrc/extensions.h @@ -17,7 +17,9 @@ #include #include #include +#ifndef USE_ROCM #include +#endif #include #include @@ -143,14 +145,11 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); -<<<<<<< HEAD #ifndef USE_ROCM -======= // Amax XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler); ->>>>>>> 389a6b // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index a3c1a262b..1281eb272 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -522,72 +522,6 @@ static void FusedAttnBackwardImpl( auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { -<<<<<<< HEAD - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); - if (is_ragged) { - (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype), - stream); - } - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector{input_batch 
* q_max_seqlen, attn_heads, qk_head_dim}; - auto kv_shape = - std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(k, kv_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); - if (is_ragged) { - (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); - (void)cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype), - stream); - } - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv_tensor = TensorWrapper(dv, v_shape, dtype); - if (is_ragged) { - (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); - (void)cudaMemsetAsync(dk, 0, 
transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream); - (void)cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream); - } - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, workspace_tensor.data(), stream); - } else { - NVTE_ERROR("Unsupported qkv_layout."); -======= // QKV packed in q: [batch*seqlen, 3, heads, dim] NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen"); NVTE_CHECK(qk_head_dim == v_head_dim, @@ -614,7 +548,6 @@ static void FusedAttnBackwardImpl( dv_ptr = static_cast(static_cast(dk) + stride); // V has same shape as K since they're packed together v_shape = k_shape; ->>>>>>> 389a6b } auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype); diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index 7082bfb03..4d44bb4a8 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -1,9 +1,12 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ +#ifndef USE_ROCM #include "cgemm_helper.h" #include "common/util/system.h" @@ -257,3 +260,4 @@ CommunicatorHandler::~CommunicatorHandler() { } // namespace jax } // namespace transformer_engine +#endif //#ifndef USE_ROCM diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 8f6383c0f..41b78f117 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -15,13 +15,17 @@ #include #include "../extensions.h" +#ifndef USE_ROCM #include "cgemm_helper.h" +#endif //#ifndef USE_ROCM #include "common.h" #include "common/util/cuda_runtime.h" #include "common/util/string.h" #include "common/util/system.h" #include "cuda_runtime.h" +#ifndef USE_ROCM #include "nccl.h" +#endif //#ifndef USE_ROCM #include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" @@ -98,6 +102,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } +#ifndef USE_ROCM Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, @@ -162,6 +167,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, .Attr("grad") .Attr("use_split_accumulator") .Attr("collective_op")); +#endif //#ifndef USE_ROCM Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, @@ -273,6 +279,10 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/, out_.data() /*D*/, workspace_.data(), config, stream); } else { +#ifdef USE_ROCM + //TODO: better assert + std::cerr<<"ROCm TE jax does not integrate userbuffer for 
now"< buffer_shape{0, 0}; DType buffer_dtype = out_dtype; auto &comm_handler = CommunicatorHandler::get(); @@ -318,6 +328,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, workspace_, grad, false, use_split_accumulator, aux_out_, stream); } +#endif //#ifdef USE_ROCM } return ffi_with_cuda_error_check(); @@ -346,14 +357,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_bias") .Attr("fuse_gelu") .Attr("grad") -<<<<<<< HEAD - .Attr("use_split_accumulator"), - GemmFFI_CudaGraph_Traits); -======= .Attr("use_split_accumulator") .Attr("collective_op"), - FFI_CudaGraph_Traits); ->>>>>>> 389a6b + GemmFFI_CudaGraph_Traits); size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes, int32_t *host_group_sizes) { diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 21b50c1af..a0c5db5a8 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -128,6 +128,7 @@ enum class JAXX_Collective_Op : int64_t { REDUCE_SCATTER = 2, }; +#ifndef USE_ROCM static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) { switch (op) { case JAXX_Collective_Op::ALL_GATHER: @@ -141,6 +142,7 @@ static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) { break; } } +#endif } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 7998af062..bc47ef6bd 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -7,7 +7,9 @@ ************************************************************************/ #include "../extensions.h" +#ifndef USE_ROCM #include "cgemm_helper.h" +#endif //#ifndef USE_ROCM 
#include "common/util/cuda_runtime.h" namespace transformer_engine { @@ -78,12 +80,15 @@ pybind11::dict Registrations() { dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + // Amax + dict["te_rht_amax_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); #else // Normalization dict["te_norm_forward_ffi"] = EncapsulateFFI(NormForwardHandler); dict["te_norm_backward_ffi"] = EncapsulateFFI(NormBackwardHandler); -<<<<<<< HEAD // Attention dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler); dict["te_fused_attn_backward_ffi"] = EncapsulateFFI(FusedAttnBackwardHandler); @@ -91,13 +96,6 @@ pybind11::dict Registrations() { dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); dict["te_grouped_gemm_ffi"] = EncapsulateFFI(GroupedGemmHandler); #endif -======= - // Amax - dict["te_rht_amax_ffi"] = pybind11::dict( - pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), - pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); - ->>>>>>> 389a6b return dict; } @@ -121,8 +119,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); +#ifndef USE_ROCM m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); +#endif pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index ced397371..792173ed1 100644 --- a/transformer_engine/jax/quantize/helper.py +++ 
b/transformer_engine/jax/quantize/helper.py @@ -26,7 +26,6 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -<<<<<<< HEAD from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type from transformer_engine_jax import DType @@ -35,10 +34,6 @@ get_cublasLt_version, get_cuda_version, ) -from transformer_engine.common import recipe -from transformer_engine.jax.sharding import global_shard_guard, MeshResource -======= -from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version from transformer_engine.common.recipe import ( Recipe, DelayedScaling, @@ -54,7 +49,6 @@ get_all_mesh_axes, with_sharding_constraint, ) ->>>>>>> 389a6b from .metadata import QuantizeMeta from .scaling_modes import ScalingMode @@ -102,16 +96,11 @@ def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: Returns: A tuple of (bool, str) indicating support and any error message """ -<<<<<<< HEAD if is_hip_extension(): if gpu_arch in [94, 95]: return True, "" else: return False, "Device arch gfx94x or gfx95x required for FP8 execution." - if gpu_arch >= 90: # hopper and above - return True, "" -======= ->>>>>>> 389a6b if gpu_arch < 89: # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." if get_cublasLt_version() < 120103: @@ -130,13 +119,8 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: Returns: A tuple of (bool, str) indicating support and any error message """ -<<<<<<< HEAD if is_hip_extension(): return False, "FP8 block scaled gemm not yet supported for ROCm" - if gpu_arch >= 100: # blackwell and above - return True, "" -======= ->>>>>>> 389a6b if gpu_arch < 99: # pre-blackwell return False, "Device compute capability 9.9 or higher required for MXFP8 execution." 
if get_cublasLt_version() < 120800: @@ -259,23 +243,14 @@ def _format2dtypes(format_: Format): Returns: A tuple of (forward_dtype, backward_dtype) for the given format """ -<<<<<<< HEAD - if format_ == recipe.Format.E4M3: - return get_jnp_float8_e4m3_type(), get_jnp_float8_e4m3_type() - if format_ == recipe.Format.E5M2: - return get_jnp_float8_e5m2_type(), get_jnp_float8_e5m2_type() - if format_ == recipe.Format.HYBRID: - return get_jnp_float8_e4m3_type(), get_jnp_float8_e5m2_type() -======= if format_ == Format.E4M3: - return jnp.float8_e4m3fn, jnp.float8_e4m3fn + return get_jnp_float8_e4m3_type(), get_jnp_float8_e4m3_type() if format_ == Format.E5M2: - return jnp.float8_e5m2, jnp.float8_e5m2 + return get_jnp_float8_e5m2_type(), get_jnp_float8_e5m2_type() if format_ == Format.HYBRID: - return jnp.float8_e4m3fn, jnp.float8_e5m2 + return get_jnp_float8_e4m3_type(), get_jnp_float8_e5m2_type() if format_ == Format.E2M1: return jnp.float4_e2m1fn, jnp.float4_e2m1fn ->>>>>>> 389a6b return jnp.bfloat16, jnp.bfloat16 diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 589a96470..0b958d3ad 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -46,12 +46,8 @@ from build_tools.build_ext import get_build_ext -<<<<<<< HEAD -from build_tools.utils import ( rocm_build, copy_common_headers, copy_hipify_tools, - clear_hipify_tools_copy) -======= from build_tools.utils import copy_common_headers, min_python_version_str ->>>>>>> 389a6b +from build_tools.utils import rocm_build, copy_hipify_tools, clear_hipify_tools_copy from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements From 8f4b04db1d2d9debf0760eb27293b82b3a5cbb39 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Tue, 10 Feb 2026 11:08:54 -0600 Subject: [PATCH 03/41] [ROCm] resolve the conflicts on pytorch side --- .../dot_product_attention/backends.py | 4 - 
.../dot_product_attention/context_parallel.py | 53 +- .../pytorch/cpp_extensions/fused_attn.py | 21 +- transformer_engine/pytorch/csrc/common.cpp | 5 +- transformer_engine/pytorch/csrc/common.h | 4 +- .../pytorch/csrc/extensions/gemm.cpp | 16 +- transformer_engine/pytorch/csrc/util.cpp | 5 +- transformer_engine/pytorch/csrc/util.h | 5 +- transformer_engine/pytorch/fp8.py | 1093 +---------------- transformer_engine/pytorch/module/_common.py | 9 +- transformer_engine/pytorch/module/base.py | 16 +- .../pytorch/module/layernorm_linear.py | 10 +- .../pytorch/module/layernorm_mlp.py | 21 +- transformer_engine/pytorch/module/linear.py | 13 +- transformer_engine/pytorch/ops/fuser.py | 4 - transformer_engine/pytorch/quantization.py | 30 +- .../pytorch/quantized_tensor.py | 74 -- transformer_engine/pytorch/setup.py | 6 +- .../pytorch/tensor/_quantization_helpers.py | 1 + .../pytorch/tensor/float8_tensor.py | 26 +- .../pytorch/tensor/mxfp8_tensor.py | 42 +- .../pytorch/triton/cross_entropy.py | 211 ---- transformer_engine/pytorch/utils.py | 31 +- 23 files changed, 95 insertions(+), 1605 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index b455a0bd6..5437b73bc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -15,12 +15,8 @@ from packaging.version import Version as PkgVersion import torch -<<<<<<< HEAD from torch.utils.cpp_extension import IS_HIP_EXTENSION - -======= import torch.nn.functional as F ->>>>>>> 389a6b import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( get_device_compute_capability, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index c5e81516a..13b41345b 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1191,43 +1191,7 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, -<<<<<<< HEAD - ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) - - if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha - if is_input_fp8: - QKV_quantizer = q._quantizer - q, k, v = q._data, k._data, v._data - else: - q_f16, k_f16, v_f16 = q, k, v - if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = QKV_quantizer(q_f16)._data - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - # partial result quantizer - for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() - O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) - else: - assert False, "FP8 is only supported with Fused Attention!" 
- else: - q_f16 = q - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] -======= ) = dpa_utils.get_attention_quantizers(fp8, quantizers) ->>>>>>> 389a6b q_f16 = None q_fp8, k_fp8, v_fp8 = (None, None, None) @@ -1293,7 +1257,7 @@ def forward( # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype q_f16 = q if use_fused_attention: - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] if return_max_logit: max_logit_per_step = [ torch.empty(q.shape[-2], dtype=q.dtype, device=q.device) for _ in range(cp_size) @@ -2080,14 +2044,8 @@ def backward(ctx, dout, *_args): ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: -<<<<<<< HEAD - fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] -======= bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] ->>>>>>> 389a6b + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: @@ -3527,13 +3485,8 @@ def backward(ctx, dout, *_args): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} -<<<<<<< HEAD - fused_attn_dqkv_dtype = TE_DType[dout_dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] -======= dqkv_te_dtype = TE_DType[dout.dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] ->>>>>>> 389a6b + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py 
b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b4a6e7ba8..852dcdb59 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -90,7 +90,12 @@ "padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, } -<<<<<<< HEAD +SoftmaxType = { + "vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + "off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX, + "learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX, +} + if not IS_HIP_EXTENSION: FusedAttnBackend = { "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, @@ -104,20 +109,6 @@ "CK": NVTE_Fused_Attn_Backend.NVTE_CK, "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, } -======= -SoftmaxType = { - "vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, - "off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX, - "learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX, -} - -FusedAttnBackend = { - "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, - "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - "FP8": NVTE_Fused_Attn_Backend.NVTE_FP8, - "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, -} ->>>>>>> 389a6b BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 2edb210ef..e1a78d49a 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -312,7 +312,6 @@ size_t roundup(const size_t value, const size_t multiple) { return ((value + multiple - 1) / multiple) * multiple; } -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ inline bool nvte_use_atomic_amax() { @@ -336,7 +335,6 @@ at::Tensor allocate_amax_workspace(const TensorWrapper& input_tensor) { #endif -======= void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { NVTE_SCOPED_GIL_RELEASE({ nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, @@ -353,5 +351,4 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_pe return philox_args; } ->>>>>>> 389a6b } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index a98818c88..74852b22d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -505,11 +505,10 @@ size_t roundup(const size_t value, const size_t multiple); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ at::Tensor allocate_amax_workspace(const TensorWrapper& input_tensor); #endif -======= + std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); // unpack the PhiloxCudaState into CUDA tensor @@ -518,7 +517,6 @@ void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); // extract PhiloxCudaState from CUDA random number generator at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread); ->>>>>>> 389a6b } // namespace transformer_engine::pytorch namespace std { diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index cf99c2256..4a438d366 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -219,9 +219,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans 
const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); -<<<<<<< HEAD -#ifndef USE_ROCM -======= // Construct GEMM config transformer_engine::MatmulConfigWrapper config; if (grad) { @@ -235,7 +232,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans config.set_use_split_accumulator(use_split_accumulator); config.set_sm_count(num_math_sms); ->>>>>>> 389a6b +#ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; #endif @@ -246,7 +243,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa))); swizzled_scale_inverses_list.emplace_back( std::move(swizzle_scaling_factors(B_tensor, !transb))); -#endif // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt @@ -260,6 +256,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans transa = true; transb = false; } +#endif if (comm_overlap) { #ifndef USE_ROCM @@ -494,17 +491,9 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } -<<<<<<< HEAD #ifndef USE_ROCM - // Optionally swizzle the scaling factors - // Keep the swizzled scaling factor tensors alive during the GEMMs. - auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); - auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); -#endif -======= // Keep the swizzled scaling factor tensors alive during the GEMM. 
std::vector> swizzled_scale_inverses_list; ->>>>>>> 389a6b // Optionally swizzle the scaling factors swizzled_scale_inverses_list.emplace_back( @@ -544,6 +533,7 @@ std::optional> te_general_grouped_gemm( transb = false; } } +#endif std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, te_pre_gelu_out_vector; diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 9d25f67df..3948c6403 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -190,9 +190,6 @@ std::optional multi_tensor_swizzle_scaling_factors( return buffer; } -<<<<<<< HEAD -#endif //!USE_ROCM -======= at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input, bool rowwise) { using namespace transformer_engine::pytorch; @@ -261,4 +258,4 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp input = std::move(output_cu); return swizzled_scale_inv; } ->>>>>>> 389a6b +#endif //!USE_ROCM diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 1305c9afc..9a46ae86d 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -31,9 +31,6 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap std::optional multi_tensor_swizzle_scaling_factors( std::vector &inputs, bool rowwise); -<<<<<<< HEAD -#endif //!USE_ROCM -======= /*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. 
* * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid @@ -45,6 +42,6 @@ std::optional multi_tensor_swizzle_scaling_factors( */ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, bool rowwise); ->>>>>>> 389a6b +#endif //!USE_ROCM #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index c364d5e45..b36302db2 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,11 +10,6 @@ # pylint: disable=wrong-import-position,unused-import -<<<<<<< HEAD -import torch -from torch.utils.cpp_extension import IS_HIP_EXTENSION -import transformer_engine_torch as tex -======= import warnings warnings.warn( @@ -29,7 +24,6 @@ # There are some users indirectly importing these classes # from fp8.py. This ensure backwards compatibility. # https://github.com/Lightning-AI/lightning-thunder/pull/2635. ->>>>>>> 389a6b from transformer_engine.common.recipe import ( Recipe, DelayedScaling, @@ -41,1090 +35,6 @@ CustomRecipe, ) -<<<<<<< HEAD -from .constants import dist_group_type -from .utils import get_device_compute_capability, get_torch_float8_e4m3_type, get_torch_float8_e5m2_type -from .jit import jit_fuser - -__all__ = ["fp8_autocast", "fp8_model_init"] - -def check_fp8_support() -> Tuple[bool, str]: - if IS_HIP_EXTENSION: - gpu_arch = get_device_compute_capability() - if gpu_arch in ((9, 4), (9, 5)): - return True, "" - else: - return False, "Device arch gfx94x or gfx95x required for FP8 execution." 
- else: - """Return if fp8 support is available""" - if get_device_compute_capability() >= (9, 0): # hopper and above - return True, "" - if get_device_compute_capability() < (8, 9): # pre-ada - return False, "Device compute capability 8.9 or higher required for FP8 execution." - if tex.get_cublasLt_version() < 120103: - return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." - if float(torch.version.cuda) < 12.1: - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." - return True, "" - -def check_mxfp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" - if IS_HIP_EXTENSION: - if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") == "0": - return False, "MXFP8 support is not enabled." - gpu_arch = get_device_compute_capability() - if gpu_arch == (9, 5): - return True, "" - return False, "Gfx95x is required for MXFP8 execution." - if get_device_compute_capability() >= (12, 0): - return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." - if get_device_compute_capability() >= (10, 0): # blackwell and above - return True, "" - return False, "Device compute capability 10.0 or higher required for MXFP8 execution." - - -def check_fp8_block_scaling_support() -> Tuple[bool, str]: - """Return if fp8 block scaling support is available""" - if IS_HIP_EXTENSION: - return False, "FP8 block scaled gemm not yet supported for ROCm" - if ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ): - return True, "" - return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." 
- - -def check_recipe_support(recipe: Recipe) -> None: - """Check if the given recipe is supported.""" - recipe_supported = True - unsupported_reason = "" - if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): - recipe_supported, unsupported_reason = check_fp8_support() - elif isinstance(recipe, Float8BlockScaling): - recipe_supported, unsupported_reason = check_fp8_block_scaling_support() - elif isinstance(recipe, MXFP8BlockScaling): - recipe_supported, unsupported_reason = check_mxfp8_support() - assert recipe_supported, unsupported_reason - - -def get_default_fp8_recipe() -> Recipe: - """FP8 recipe with default args.""" - if IS_HIP_EXTENSION: - if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") != "2": - return DelayedScaling() - gpu_arch = get_device_compute_capability() - if gpu_arch == (9, 5): - return MXFP8BlockScaling() - return DelayedScaling() - if check_mxfp8_support()[0]: - return MXFP8BlockScaling() - if get_device_compute_capability() >= (12, 0): - # This is a temporary restriction until MXFP8 is supported for all gemm layouts. 
- return Float8CurrentScaling() - return DelayedScaling() - - -def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return get_torch_float8_e4m3_type() - return get_torch_float8_e5m2_type() - - -def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return tex.DType.kFloat8E4M3 - return tex.DType.kFloat8E5M2 - - -def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: - """Get max representible FP8 value.""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return Format.E4M3.value.max_fwd - return Format.E5M2.value.max_fwd - - -class FP8GlobalStateManager: - """Class to keep track of and manipulate the global - FP8 state at different stages of execution. 
- """ - - FP8_ENABLED = False - FP8_CALIBRATION = False - FP8_RECIPE = None - FP8_DISTRIBUTED_GROUP = None - FP8_PARAMETERS = False - HIGH_PRECISION_INIT_VAL = False - IS_FIRST_FP8_MODULE = False - FP8_GRAPH_CAPTURING = False - SKIP_FP8_REDUCTION_FOR_FSDP2 = False - FP8_AUTOCAST_DEPTH = 0 - global_amax_buffer = {} - global_amax_history_buffer = {} - global_scale_buffer = {} - fp8_tensors_recompute_buffer = [] - fp8_available = None - reason_for_no_fp8 = "" - autocast_arguments = {} - autocast_to_fp8_params = {} - fp8_param_to_autocast = {} - skip_fp8_weight_update_tensor = None - mxfp8_available = None - reason_for_no_mxfp8 = "" - fp8_block_scaling_available = None - reason_for_no_fp8_block_scaling = None - - @classmethod - def reset(cls) -> None: - """Reset the global state""" - cls.FP8_ENABLED = False - cls.FP8_CALIBRATION = False - cls.FP8_RECIPE = None - cls.FP8_DISTRIBUTED_GROUP = None - cls.FP8_PARAMETERS = False - cls.HIGH_PRECISION_INIT_VAL = False - cls.IS_FIRST_FP8_MODULE = False - cls.FP8_GRAPH_CAPTURING = False - cls.FP8_AUTOCAST_DEPTH = 0 - cls.global_amax_buffer = {} - cls.global_amax_history_buffer = {} - cls.global_scale_buffer = {} - cls.fp8_tensors_recompute_buffer = [] - cls.fp8_available = None - cls.reason_for_no_fp8 = "" - cls.autocast_arguments = {} - cls.autocast_to_fp8_params = {} - cls.fp8_param_to_autocast = {} - cls.skip_fp8_weight_update_tensor = None - cls.mxfp8_available = None - cls.reason_for_no_mxfp8 = "" - cls.fp8_block_scaling_available = None - cls.reason_for_no_fp8_block_scaling = "" - - @classmethod - def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: - """`skip_fp8_weight_update_tensor` inplace setter.""" - if cls.skip_fp8_weight_update_tensor is None: - cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") - cls.skip_fp8_weight_update_tensor.fill_(skip) - - @classmethod - def get_skip_fp8_weight_update_tensor(cls) -> None: - """`skip_fp8_weight_update_tensor` getter.""" - return 
cls.skip_fp8_weight_update_tensor - - @classmethod - def is_fp8_available(cls) -> Tuple[bool, str]: - """Return if fp8 support is available""" - if cls.fp8_available is None: - cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() - return cls.fp8_available, cls.reason_for_no_fp8 - - @classmethod - def is_mxfp8_available(cls) -> Tuple[bool, str]: - """Return if MXFP8/current scaling support is available.""" - if cls.mxfp8_available is None: - cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() - return cls.mxfp8_available, cls.reason_for_no_mxfp8 - - @classmethod - def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: - """Return if Float8 block scaling support is available.""" - if cls.fp8_block_scaling_available is None: - cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = ( - check_fp8_block_scaling_support() - ) - return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling - - @staticmethod - def get_meta_tensor_key(forward: bool = True) -> str: - """Returns scaling key in `fp8_meta`.""" - if forward: - return "scaling_fwd" - return "scaling_bwd" - - @staticmethod - def get_fwd_bwd_key(forward: bool = True) -> str: - """Convert bool `forward` to string.""" - return "forward" if forward else "backward" - - @classmethod - def get_buffer_info(cls) -> str: - """ - Returns a key for `fp8_meta` that stores the module's index - in the global buffers along with autocast information. 
- """ - return "buffer_index_and_autocast_key" - - @classmethod - def get_key_in_buffer( - cls, - forward: bool, - fp8_recipe: Recipe, - fp8_group: dist_group_type, - ) -> str: - """Returns a key into the global FP8 buffers.""" - autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - fwd_bwd_key = cls.get_fwd_bwd_key(forward) - return f"{fwd_bwd_key}_{autocast_key}" - - @classmethod - def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: - """Splits buffer key into relevant parts.""" - forward, autocast_key = key.split("_", 1) - forward = forward == "forward" - return forward, autocast_key - - @classmethod - def add_fp8_tensors_to_global_buffer( - cls, - fp8_meta: Dict[str, Any], - ) -> None: - """ - Delayed scaling only. - - The amax reduction process happens completely outside the FP8 modules. - To participate in the reduction, the only role played by a module is - to call this function in order to append it's FP8 tensor into a global - buffer. There are 5 global buffers maintained, one each for amax, amax - history, scale, scale-inverse, and non-weight-mask. Each buffer has - keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix - to indicate the type of FP8 tensor, since the forward and backward - reductions happen separately. - - Note: For CG capture, this method is called from the graphed - wrapper. For non CG case, it's called from within the module. - """ - - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - # Every module must call this function exactly once since - # the amax tensors are static. Ensures that compatibility - # with non-graphed modules is maintained. - index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. 
- if index_in_buffer in fp8_meta: - return - - fp8_meta[index_in_buffer] = [] - for forward in (True, False): - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - if fp8_meta_tensor_key not in fp8_meta: - # Handles non-parameter FP8 modules, e.g. DPA. - continue - - key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) - - if key not in cls.global_amax_buffer: - cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] - cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] - else: - cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - cls.global_amax_history_buffer[key].append( - fp8_meta[fp8_meta_tensor_key].amax_history - ) - cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) - fp8_meta[index_in_buffer].append(key) - - @classmethod - def is_fp8_enabled(cls) -> bool: - """Is FP8 enabled""" - return cls.FP8_ENABLED - - @classmethod - def is_fp8_calibration(cls) -> bool: - """Is FP8 calibration""" - return cls.FP8_CALIBRATION - - @classmethod - def with_fp8_parameters(cls) -> bool: - """Should the parameters be stored as FP8""" - return cls.FP8_PARAMETERS - - @classmethod - def with_high_precision_init_val(cls) -> bool: - """Should the high precision initial values be stored with FP8 parameters""" - return cls.HIGH_PRECISION_INIT_VAL - - @classmethod - def fp8_graph_capturing(cls) -> bool: - """Is CUDA graph capture under way?""" - return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() - - @classmethod - def is_first_fp8_module(cls): - """Returns `True` only the first time when called multiple - times from within the same `fp8_autocast` context. 
- """ - tmp = cls.IS_FIRST_FP8_MODULE - cls.IS_FIRST_FP8_MODULE = False - return tmp - - @classmethod - def get_fp8_recipe(cls) -> Recipe: - """Return the fp8 recipe""" - if cls.FP8_RECIPE is not None: - return cls.FP8_RECIPE - return get_default_fp8_recipe() - - @classmethod - def get_fp8_group(cls) -> Union[dist_group_type, None]: - """Return the fp8 group for scale/amax comm""" - return cls.FP8_DISTRIBUTED_GROUP - - @classmethod - def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: - """FP8 autocast state getter""" - return ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) - - @classmethod - def set_fp8_autocast_state( - cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] - ) -> None: - """FP8 autocast state setter""" - ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) = fp8_state - - @staticmethod - def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: - """Reduce tensor across given group.""" - if torch.distributed.is_initialized(): - torch.distributed.all_reduce( - tensor, - op=torch.distributed.ReduceOp.MAX, - group=group, - async_op=False, - ) - - @classmethod - def reduce_and_update_fp8_tensors( - cls, - forward: bool = True, - ) -> None: - """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" - # global_amax_buffer should only be non-empty for fp8 delayed scaling - for buffer_key, amax_buffer in cls.global_amax_buffer.items(): - # Check for forward or backward reduction. - fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) - if fwd_update != forward: - continue - if len(amax_buffer) == 0: - continue - - # Retrieve autocast specific args and concat amaxes. 
- recipe, group = cls.autocast_arguments[autocast_key] - contiguous_amax = torch.cat(amax_buffer) - - # Reduction. - if ( - recipe.reduce_amax - and torch.distributed.is_initialized() - and torch.distributed.get_world_size(group=group) > 1 - ): - cls.reduce_tensor_across_group_op_max(contiguous_amax, group) - - # Amax and scale update. - unfused_update = ( - bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) - or callable(recipe.amax_compute_algo) - or callable(recipe.scaling_factor_compute_algo) - ) - - if not unfused_update: - tex.fused_amax_and_scale_update_after_reduction( - contiguous_amax, - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], - recipe.amax_compute_algo, - get_fp8_te_dtype(recipe, forward), - recipe.margin, - ) - else: - split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) - - for amax_history, scale in zip( - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], - ): - _amax_and_scale_update( - amax_history, scale, get_fp8_max(recipe, forward), recipe - ) - - @classmethod - def get_unique_autocast_key( - cls, - recipe: Optional[Recipe] = None, - group: Optional[dist_group_type] = None, - ): - """ - For FP8, each autocast can be uniquely identified by the recipe and fp8 group. - Safely using `hash` as we never cross checkpoint boundaries. 
- """ - return f"{str(recipe)}:{hash(group)}" - - @classmethod - def fp8_autocast_enter( - cls, - enabled: bool = False, - calibrating: bool = False, - fp8_recipe: Optional[Recipe] = None, - fp8_group: Optional[dist_group_type] = None, - _graph: bool = False, - ) -> None: - """Set state and tracking variables for entry into FP8 region.""" - - fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe - autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - - cls.FP8_ENABLED = enabled - cls.FP8_CALIBRATION = calibrating - cls.FP8_RECIPE = fp8_recipe - cls.FP8_DISTRIBUTED_GROUP = fp8_group - cls.FP8_GRAPH_CAPTURING = _graph - - if cls.FP8_AUTOCAST_DEPTH == 0: - cls.IS_FIRST_FP8_MODULE = True - cls.FP8_AUTOCAST_DEPTH += 1 - - if enabled: - fp8_available, reason_for_no_fp8 = cls.is_fp8_available() - assert fp8_available, reason_for_no_fp8 - if isinstance(fp8_recipe, MXFP8BlockScaling): - mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() - assert mxfp8_available, reason_for_no_mxfp8 - if isinstance(fp8_recipe, Float8BlockScaling): - fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() - assert fp8_block_available, reason_for_no_fp8_block - - @classmethod - def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: - """Set state and tracking variables for exit from FP8 region.""" - cls.FP8_AUTOCAST_DEPTH -= 1 - # Reduce only the non-FP8 weight modules here. - # FP8 weight modules are reduced at the end of the optimizer - # step after the weight amax is populated. 
- if not cls.SKIP_FP8_REDUCTION_FOR_FSDP2 and enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): - # delayed scaling only function, for other recipes (current scaling with any granularity), - # this is noop for other recipes because cls.global_amax_buffer is empty list - cls.reduce_and_update_fp8_tensors(forward=True) - - @classmethod - def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: - """Copy the scaling factors and amaxes for recompute forward phase - to ensure both forward steps are numerically same. - """ - - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - - to_copy = [ - fp8_meta["scaling_fwd"].amax_history.clone(), - fp8_meta["scaling_fwd"].scale.clone(), - ] - - if buffer_position_key in fp8_meta: - cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) - else: - if len(cls.fp8_tensors_recompute_buffer) == 0: - cls.fp8_tensors_recompute_buffer = [deque()] - else: - cls.fp8_tensors_recompute_buffer.append(deque()) - cls.fp8_tensors_recompute_buffer[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1 - - @classmethod - def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: - """Switch to the copied scaling factors and amaxes from phase - 1 forward for indentical numerical outputs. - """ - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - # Store updated amaxes and scales from phase 1 post forward. - fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone() - fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone() - - # Retrieve stashed amaxes and scales from phase 1 pre forward. 
- buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() - - # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) - fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) - - @staticmethod - def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: - """Restore latest scaling factors and amaxes after recompute forward run.""" - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) - fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) - - -@contextmanager -def fp8_model_init( - enabled: bool = True, - recipe: Optional[Recipe] = None, - preserve_high_precision_init_val: bool = False, -) -> None: - """ - Context manager for FP8 initialization of parameters. - - Example usage: - - .. code-block:: python - - with fp8_model_init(enabled=True): - model = transformer_engine.pytorch.Linear(768, 768) - - # Preserving high precision initial value to initialize master weight - with fp8_model_init(enabled=True, preserve_high_precision_init_val=True): - model = transformer_engine.pytorch.Linear(768, 768) - master_weight = model.weight.get_high_precision_init_val() - model.weight.clear_high_precision_init_val() - - Parameters - ---------- - enabled: bool, default = `True` - when enabled, Transformer Engine modules created inside this `fp8_model_init` - region will hold only FP8 copies of its parameters, as opposed to the default - behavior where both higher precision and FP8 copies are present. 
Setting this - option to `True` may result in lower memory consumption and is especially - useful for scenarios like: - - * full model training using optimizer with master weights, where the high - precision copies of weights are already present in the optimizer. - * inference, where only the FP8 copies of the parameters are used. - * LoRA-like fine-tuning, where the main parameters of the model do not change. - recipe: transformer_engine.common.recipe.Recipe, default = `None` - Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. - preserve_high_precision_init_val: bool, default = `False` - when enabled, store the high precision tensor used to initialize FP8 parameters - in CPU memory, and add two function attributes named `get_high_precision_init_val()` - and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high - precision tensor. The purpose is that users can use this high-precision copy - to initialize master weights, avoiding the loss of precision that can occur when - using FP8 parameters directly. Note that after the master weights are initialized, - users should call `clear_high_precision_init_val()` to release this CPU memory. - - This functionality is *EXPERIMENTAL*. 
- """ - _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS - _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE - _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL - FP8GlobalStateManager.FP8_PARAMETERS = enabled - FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val - try: - yield - finally: - FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters - FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val - - -@contextmanager -def fp8_autocast( - enabled: bool = True, - calibrating: bool = False, - fp8_recipe: Optional[Recipe] = None, - fp8_group: Optional[dist_group_type] = None, - _graph: bool = False, -) -> None: - """ - Context manager for FP8 usage. - - .. code-block:: python - - with fp8_autocast(enabled=True): - out = model(inp) - - .. note:: - - Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors - with shapes where both dimensions are divisible by 16. In terms of the input to the full - Transformer network, this typically requires padding sequence length to be multiple of 16. - - .. note:: - - When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once - inside a single `fp8_autocast` region. This is unsupported behavior because the amax - reduction is handled during the exit of the `fp8_autocast` context. Calling the same - module more than once inside an `fp8_autocast` region overrides the amax tensors - before reduction can occur. - - Parameters - ---------- - enabled: bool, default = `True` - whether or not to enable fp8 - calibrating: bool, default = `False` - calibration mode allows collecting statistics such as amax and scale - data of fp8 tensors even when executing without fp8 enabled. 
This is - useful for saving an inference ready fp8 checkpoint while training - using a higher precision. - fp8_recipe: recipe.Recipe, default = `None` - recipe used for FP8 training. - fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` - distributed group over which amaxes for the fp8 tensors - are reduced at the end of each training step. - """ - if enabled: - check_recipe_support(fp8_recipe) - fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() - FP8GlobalStateManager.fp8_autocast_enter( - enabled=enabled, - calibrating=calibrating, - fp8_recipe=fp8_recipe, - fp8_group=fp8_group, - _graph=_graph, - ) - try: - yield - finally: - FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) - FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph) - - -def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: - """Update amax history and set next amax to zero.""" - if amax_history.shape[0] > 1: - new_amax_history = torch.roll(amax_history, -1, 0) - amax_history.copy_(new_amax_history) - amax_history[0].fill_(0.0) - return amax_history - - -@torch.jit.script -def _default_get_amax_and_update_history( - amax_history: torch.Tensor, - amax_compute_algo: str, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Default function to obtain amax from history.""" - if amax_compute_algo == "max": - amax = torch.max(amax_history, dim=0).values - else: # amax_compute_algo == "most_recent" - amax = amax_history[0].clone() - - amax_history = _update_amax_history(amax_history) - return amax_history, amax - - -@jit_fuser -def _default_sf_compute( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - margin: int, - _fp32_max: float = torch.finfo(torch.float32).max, # finfo not available in jitter -) -> torch.Tensor: - """Default function to convert amax to scaling factor. - Computing the scaling factor requires consideration of the following scenarios: - 1. amax == 0: - No action is possible, set scale to the previous scale (or 1). - 2. 
0 < amax < tiny_amax - The amax is too tiny that the scale becomes infinite in FP32. - Set scale = FP32_max - 3. tiny_amax <= amax < FP32_max: - Set scale = FP8_max (or scaled_max) / amax - 4. When amax == inf or amax == nan: - No action is possible, set scale to the previous scale (or 1). - """ - sf = (fp8_max / amax) / (2**margin) - sf = torch.where(amax > 0.0, sf, scale) - sf = torch.where(torch.isfinite(amax), sf, scale) - sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf) - scale.copy_(sf) - return scale - - -def _compute_amax_and_update_history( - amax_history: torch.Tensor, - amax_compute_algo: Union[Callable, str], -) -> Tuple[torch.Tensor, torch.Tensor]: - """Obtain the amax from the history.""" - - if callable(amax_compute_algo): - amax = amax_compute_algo(amax_history) - amax_history = _update_amax_history(amax_history) - return amax_history, amax - return _default_get_amax_and_update_history( - amax_history, - amax_compute_algo, - ) - - -def _compute_scaling_factor( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - recipe: DelayedScaling, -) -> torch.Tensor: - """Convert amax to scaling factor.""" - - if recipe.scaling_factor_compute_algo is None: - return _default_sf_compute( - amax, - scale, - fp8_max, - recipe.margin, - ) - return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) - - -def _amax_and_scale_update( - amax_history: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - recipe: DelayedScaling, -) -> None: - """Updates FP8 meta tensors.""" - new_amax_history, amax = _compute_amax_and_update_history( - amax_history, - recipe.amax_compute_algo, - ) - new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) - scale.copy_(new_scale) - amax_history.copy_(new_amax_history) - - -def split_and_copy( - buffer: torch.Tensor, - outputs: List[torch.Tensor], - chunk_sizes: List[int], -) -> None: - """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" - splits = 
buffer.split(chunk_sizes) - torch._foreach_copy_(outputs, splits) - - -class RecipeState(abc.ABC): - """Configuration and state for a quantization recipe. - - This is a builder class for quantizers, which are in turn builder - classes for quantized tensors. - - This class may pack together the state for multiple quantizers, - which is helpful for applying fused kernels with less overhead. - - """ - - @staticmethod - def create( - recipe: Recipe, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> RecipeState: - """Factory method to create the state for a quantization recipe - - Parameters - ---------- - recipe: Recipe - Quantization recipe. - mode: {"forward", "backward"} - Training stage where quantization will be performed. - num_quantizers: int, default = 1 - Number of quantizers to create state for. - device: torch.device, default = default CUDA device - Device for quantized tensors. - - Returns - ------- - RecipeState: - Quantization recipe state. - - """ - - cls = None - if recipe.delayed(): - cls = DelayedScalingRecipeState - elif recipe.mxfp8(): - cls = MXFP8BlockScalingRecipeState - elif recipe.float8_current_scaling(): - cls = Float8CurrentScalingRecipeState - elif recipe.float8_block_scaling(): - cls = Float8BlockScalingRecipeState - else: - raise ValueError(f"{recipe.__class__.__name__} is not supported") - return cls( - recipe, - mode=mode, - num_quantizers=num_quantizers, - device=device, - ) - - @abc.abstractmethod - def make_quantizers(self) -> list: - """Convert recipe state to quantizers. - - Quantizers are builder classes for quantized tensors. They are - typically used to convert a high-precision tensor (e.g. in - FP32 or BF16) into a quantized tensor (e.g. in FP8). - - """ - - -class DelayedScalingRecipeState(RecipeState): - """State for FP8 quantization with per-tensor delayed scaling. 
- - Delayed scaling recipe requires a scaling factor (applied when - casting to FP8) and a history of max-abs values ("amax") from - recent FP8 casts for updating the scaling factor. The scale update - is handled externally by `FP8GlobalStateManager`. - - """ - - recipe: DelayedScaling - mode: str - dtype: tex.DType - scale: torch.Tensor - amax_history: torch.Tensor - - def __init__( - self, - recipe: DelayedScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) - self.amax_history = torch.zeros( - recipe.amax_history_len, - num_quantizers, - dtype=torch.float32, - device=device, - ) - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.float8_tensor import Float8Quantizer - - return [ - Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) - for i in range(self.num_quantizers) - ] - - -class Float8CurrentScalingRecipeState(RecipeState): - """Configuration for Per-tensor current scaling quantization. - - Per-tensor current quantization does not require state. 
- - """ - - recipe: Float8CurrentScaling - mode: str - dtype: tex.DType - device: torch.device - - def __init__( - self, - recipe: Float8CurrentScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.device = device - - def make_quantizers(self) -> list: - from .tensor.float8_tensor import Float8CurrentScalingQuantizer - - return [ - Float8CurrentScalingQuantizer(self.dtype, device=self.device) - for i in range(self.num_quantizers) - ] - - -class MXFP8BlockScalingRecipeState(RecipeState): - """Configuration for MXFP8 quantization. - - MXFP8 quantization does not require state. - - """ - - recipe: MXFP8BlockScaling - mode: str - dtype: tex.DType - - def __init__( - self, - recipe: MXFP8BlockScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.mxfp8_tensor import MXFP8Quantizer - - return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] - - -class Float8BlockScalingRecipeState(RecipeState): - """Configuration for Float8BlockScaling quantization. - - Float8BlockScaling quantization does not require state, - but different quantizers use different modes. 
- """ - - recipe: Float8BlockScaling - mode: str - qx_dtype: tex.DType - qw_dtype: tex.DType - qgrad_dtype: tex.DType - - def __init__( - self, - recipe: Float8BlockScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.qx_dtype = get_fp8_te_dtype(recipe, True) - self.qw_dtype = get_fp8_te_dtype(recipe, True) - self.qgrad_dtype = get_fp8_te_dtype(recipe, False) - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.device = device - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.float8_blockwise_tensor import Float8BlockQuantizer - - if self.mode == "forward": - # The index convention (coming from base.py set_meta_tensor) - # is somewhat awkward, and doesn't play nicely with QuantizeOp, - # which is not associated with a GEMM. - assert self.num_quantizers % 3 == 0 # x, w, output per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qw_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, - block_scaling_dim=self.recipe.w_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 3) - ] - ) - ) - - assert 
self.mode == "backward", f"Unexpected mode {self.mode}" - assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 2) - ] - ) - ) -======= # Importing each function instead of 'import *' allows us specify '__all__' in # quantize.py and also makes any newer additions to quantize.py invisible via @@ -1158,4 +68,3 @@ def make_quantizers(self) -> list: NVFP4BlockScalingRecipeState, CustomRecipeState, ) ->>>>>>> 389a6b diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 591fa60c2..4ba5da68d 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -1,19 +1,13 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Internal function used by multiple modules.""" -<<<<<<< HEAD -import os -from typing import Any, List, Optional, Tuple, Union, Callable -from dataclasses import dataclass -======= import dataclasses ->>>>>>> 389a6b import queue from typing import Any, Callable, List, Optional, Tuple, Union @@ -28,6 +22,7 @@ if IS_HIP_EXTENSION: from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton + import os def _get_normalization_func(normalization: str, forward: bool): use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6a3bda0c6..661cf3f2e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -19,11 +19,8 @@ import torch import torch.nn.functional as F -<<<<<<< HEAD -from torch.utils.cpp_extension import IS_HIP_EXTENSION -======= from torch.distributed.tensor import DTensor ->>>>>>> 389a6b +from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe @@ -49,21 +46,13 @@ from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -<<<<<<< HEAD if IS_HIP_EXTENSION: from ..tensor.fsdp2_allgather_tensor import FSDPAGTensor -from ..tensor._internal.float8_tensor_base import Float8TensorBase -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -if IS_HIP_EXTENSION: from ..triton_kernels.cast import te_quantize_triton -from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -======= from 
..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage ->>>>>>> 389a6b from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -1328,11 +1317,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: raise RuntimeError("Weight quantizer has not been initialized") quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False -<<<<<<< HEAD if IS_HIP_EXTENSION and not self.keep_fp8_weight_transpose_cache: quantizer.columnwise_usage=False -======= if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): device_mesh = dtensor_param.device_mesh amax_reduction_group = ( @@ -1342,7 +1329,6 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: ) quantizer.amax_reduction_group = amax_reduction_group quantizer.with_amax_reduction = True ->>>>>>> 389a6b # Quantize parameter param = quantizer(param) if IS_HIP_EXTENSION and self.use_fsdp2 and not self.primary_weights_in_fp8 and fp8_meta_index is not None: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 0d1f3f5b0..a906ea42d 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -305,16 +305,11 @@ def forward( is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) # Configure quantizer -<<<<<<< HEAD - if weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) -======= # If weight is already quantized, no need to 
set quantizer states if is_weight_param_quantized: weight_quantizer = weight._quantizer elif weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) ->>>>>>> 389a6b + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -447,14 +442,11 @@ def forward( ): ln_out.update_usage(rowwise_usage=False) -<<<<<<< HEAD # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: if isinstance(weightmat, QuantizedTensorBase): weightmat.update_usage(columnwise_usage=True) -======= ->>>>>>> 389a6b if cpu_offloading: mark_activation_offload(inputmat, mu, rsigma, ln_out) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4a0397502..3fefb650e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -371,22 +371,17 @@ def forward( # which handles weight caching etc. 
# FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch -<<<<<<< HEAD - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) -======= # No need to set the quantizer states if weights are already quantized if isinstance(fc1_weight, QuantizedTensorStorage): fc1_weight_quantizer = fc1_weight._quantizer elif fc1_weight_quantizer is not None: - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) if isinstance(fc2_weight, QuantizedTensorStorage): fc2_weight_quantizer = fc2_weight._quantizer elif fc2_weight_quantizer is not None: - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) ->>>>>>> 389a6b fc1_weight_final = module.get_weight_workspace( tensor=fc1_weight, quantizer=fc1_weight_quantizer, @@ -579,8 +574,6 @@ def forward( # Cache state for backward pass if is_grad_enabled: -<<<<<<< HEAD - # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. 
if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: if isinstance(fc1_weight_final, QuantizedTensorBase): @@ -588,8 +581,6 @@ def forward( if isinstance(fc2_weight_final, QuantizedTensorBase): fc2_weight_final.update_usage(columnwise_usage=True) -======= ->>>>>>> 389a6b if cpu_offloading: mark_activation_offload( inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out @@ -906,11 +897,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ctx.fc2_weight_quantizer is not None and isinstance( -<<<<<<< HEAD - fc2_weight, QuantizedTensorBase -======= ctx.fc2_weight, QuantizedTensorStorage ->>>>>>> 389a6b ): fc2_weight.update_usage(columnwise_usage=True) @@ -1168,11 +1155,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ctx.fc1_weight_quantizer is not None and isinstance( -<<<<<<< HEAD - fc1_weight, QuantizedTensorBase -======= ctx.fc1_weight_quantizer, QuantizedTensorStorage ->>>>>>> 389a6b ): fc1_weight.update_usage(columnwise_usage=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e7d2cf8d8..ffa47d986 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -253,16 +253,10 @@ def forward( weightmat = weight if fp8 or debug: # Configure quantizer -<<<<<<< HEAD - if weight_quantizer is not None: - columnwise_usage = is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache - if not columnwise_usage and keep_fp8_weight_transpose_cache: -======= # No need to set the quantizer states if weight is already quantized if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): - columnwise_usage = is_grad_enabled and inp.requires_grad - if not columnwise_usage: ->>>>>>> 389a6b + columnwise_usage = is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache + if not columnwise_usage and 
keep_fp8_weight_transpose_cache: columnwise_usage = ( is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() @@ -414,14 +408,11 @@ def forward( if backward_needs_input: saved_inputmat = inputmat -<<<<<<< HEAD # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: if isinstance(weightmat, QuantizedTensorBase): weightmat.update_usage(columnwise_usage=True) -======= ->>>>>>> 389a6b if cpu_offloading and saved_inputmat is not None: mark_activation_offload(saved_inputmat) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index c1985e04c..7eb04fa27 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -29,16 +29,12 @@ fuse_forward_linear_bias_add, fuse_forward_linear_scale_add, ) -<<<<<<< HEAD if not IS_HIP_EXTENSION: from transformer_engine.pytorch.ops.fused import ( fuse_userbuffers_backward_linear, fuse_userbuffers_forward_linear, ) -from transformer_engine.pytorch.tensor.quantized_tensor import ( -======= from transformer_engine.pytorch.quantized_tensor import ( ->>>>>>> 389a6b prepare_for_saving, restore_from_saved, ) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 030370b9d..0b7eddb9f 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -31,6 +31,8 @@ from .utils import get_device_compute_capability from .jit import jit_fuser +from torch.utils.cpp_extension import IS_HIP_EXTENSION +from .utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type __all__ = [ "autocast", @@ -46,6 +48,12 @@ @functools.lru_cache(maxsize=None) def check_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" + if IS_HIP_EXTENSION: + gpu_arch = get_device_compute_capability() + if gpu_arch in ((9, 4), (9, 5)): + 
return True, "" + else: + return False, "Device arch gfx94x or gfx95x required for FP8 execution." if get_device_compute_capability() >= (9, 0): # hopper and above return True, "" if get_device_compute_capability() < (8, 9): # pre-ada @@ -60,6 +68,13 @@ def check_fp8_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_mxfp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" + if IS_HIP_EXTENSION: + if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") == "0": + return False, "MXFP8 support is not enabled." + gpu_arch = get_device_compute_capability() + if gpu_arch == (9, 5): + return True, "" + return False, "Gfx95x is required for MXFP8 execution." if get_device_compute_capability() >= (12, 0): return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -69,6 +84,8 @@ def check_mxfp8_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_nvfp4_support() -> Tuple[bool, str]: + if IS_HIP_EXTENSION: + return False, "ROCm TE currently not supporting NVFP4" """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" @@ -78,6 +95,8 @@ def check_nvfp4_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" + if IS_HIP_EXTENSION: + return False, "FP8 block scaled gemm not yet supported for ROCm" if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" return ( @@ -101,6 +120,13 @@ def check_recipe_support(recipe: Recipe) -> None: def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" + if IS_HIP_EXTENSION: + if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") != "2": + return DelayedScaling() + gpu_arch = get_device_compute_capability() + if gpu_arch == (9, 5): + return 
MXFP8BlockScaling() + return DelayedScaling() if check_mxfp8_support()[0]: return MXFP8BlockScaling() if get_device_compute_capability() >= (12, 0): @@ -119,8 +145,8 @@ def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch. if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor ): - return torch.float8_e4m3fn - return torch.float8_e5m2 + return get_torch_float8_e4m3_type() + return get_torch_float8_e5m2_type() def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 35bacb32a..dd6a7ebc5 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -331,86 +331,12 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-a """Returns whether or not given tensor can be quantized""" return True -<<<<<<< HEAD:transformer_engine/pytorch/tensor/quantized_tensor.py -class _QuantizeFunc(torch.autograd.Function): - """Cast to FP8 from other dtype""" - - @staticmethod - def forward( - _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: torch.Tensor, - quantizer: Quantizer, - ) -> QuantizedTensor: - # pylint: disable=missing-function-docstring - if IS_HIP_EXTENSION: - from ..triton_kernels.cast import te_quantize_triton - use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) - quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize - return quantize_func(tensor, quantizer) - else: - return tex.quantize(tensor, quantizer) - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None - - -class 
_IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. - - """ - - @staticmethod - def forward( - ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None - ) -> QuantizedTensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if constructor kwargs are not provided - if init_kwargs is None: - return tensor.detach() - - # Construct new tensor if constructor kwargs are provided - ctx.input_dtype = tensor.dtype - kwargs = tensor.get_metadata() - for key, val in init_kwargs.items(): - kwargs[key] = val - return type(tensor)(tensor.shape, tensor.dtype, **kwargs) - - @staticmethod - def backward(ctx, grad_output): - # pylint: disable=missing-function-docstring - grad_input = grad_output - if grad_input.dtype == ctx.input_dtype: - grad_input = grad_input.detach() - else: - grad_input = grad_input.to(ctx.input_dtype) - return grad_input, None - - -def _stride_from_shape(shape: list[int]): - if len(shape) == 0: - return [] - rstride = [1] - for d in reversed(shape[1:]): - rstride.append(rstride[-1] * d) - return list(reversed(rstride)) -======= def get_usages(self) -> Dict[str, bool]: """Get the usage of the quantizer""" return { "rowwise": self.rowwise_usage, "columnwise": self.columnwise_usage, } ->>>>>>> 389a6b:transformer_engine/pytorch/quantized_tensor.py - class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index c59001e61..73f926c61 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -47,12 +47,8 @@ from build_tools.build_ext import get_build_ext -<<<<<<< HEAD -from build_tools.utils import ( - rocm_build, copy_common_headers, copy_hipify_tools, clear_hipify_tools_copy ) -======= +from build_tools.utils import rocm_build, 
copy_hipify_tools, clear_hipify_tools_copy from build_tools.utils import copy_common_headers, min_python_version_str ->>>>>>> 389a6b from build_tools.te_version import te_version from build_tools.pytorch import ( setup_pytorch_extension, diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index 2214edbff..55fc4785d 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -26,6 +26,7 @@ def forward( quantize_impl: Callable, ) -> QuantizedTensor: # pylint: disable=missing-function-docstring + # TODO: bring back triton based quantization return quantize_impl(tensor) @staticmethod diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 789d207b0..8f741b7f2 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -6,16 +6,9 @@ """Tensor class with FP8 data""" from __future__ import annotations -<<<<<<< HEAD -import os -from typing import Optional, Tuple, Iterable, Union -import warnings -from torch.utils.cpp_extension import IS_HIP_EXTENSION -======= from typing import Any, Optional, Tuple, Iterable, Union import warnings ->>>>>>> 389a6b import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex @@ -27,8 +20,11 @@ from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc from ..constants import dist_group_type + +from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: from ..triton_kernels.cast import te_quantize_triton + import os aten = torch.ops.aten @@ -109,7 +105,13 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self) + if 
IS_HIP_EXTENSION: + from ..triton_kernels.cast import te_quantize_triton + use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + return quantize_func(tensor, self) + else: + return tex.quantize(tensor, self) def make_empty( self, @@ -304,7 +306,13 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self) + if IS_HIP_EXTENSION: + from ..triton_kernels.cast import te_quantize_triton + use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + return quantize_func(tensor, self) + else: + return tex.quantize(tensor, self) def make_empty( self, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6008f0503..1848a60cf 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -8,22 +8,12 @@ from __future__ import annotations from collections.abc import Iterable import math -<<<<<<< HEAD -import os -from typing import Optional, Tuple, Union -from torch.utils.cpp_extension import IS_HIP_EXTENSION - -import torch -if IS_HIP_EXTENSION: - from ..triton_kernels.cast import te_quantize_triton -======= from typing import Optional, Tuple, Union, Any import warnings import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState ->>>>>>> 389a6b import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -34,6 +24,11 @@ from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc +from torch.utils.cpp_extension import IS_HIP_EXTENSION +if IS_HIP_EXTENSION: + import os + from ..triton_kernels.cast 
import te_quantize_triton + aten = torch.ops.aten @@ -89,7 +84,13 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self) + if IS_HIP_EXTENSION: + from ..triton_kernels.cast import te_quantize_triton + use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + return quantize_func(tensor, self) + else: + return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" @@ -124,41 +125,26 @@ def make_empty( ) # Allocate FP8 data -<<<<<<< HEAD - data = torch.empty(shape, dtype=torch.uint8, device=device) - # ROCm TE does not implement fuse padding zeros so use zero tensor here - scale_inv = torch.zeros( - round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), - dtype=torch.uint8, - device=device, - ) -======= data = None scale_inv = None if self.rowwise_usage: data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) - scale_inv = torch.empty( + # ROCm TE does not implement fuse padding zeros so use zero tensor here + scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), dtype=torch.uint8, device=device, pin_memory=pin_memory, ) ->>>>>>> 389a6b # Allocate FP8 data transpose if needed columnwise_data = None columnwise_scale_inv = None if self.columnwise_usage: -<<<<<<< HEAD columnwise_data = torch.empty_like(data) # ROCm TE does not implement fuse padding zeros so use zero tensor here columnwise_scale_inv = torch.zeros( -======= - columnwise_data = torch.empty_like(data, pin_memory=pin_memory) - columnwise_scale_inv = torch.empty( ->>>>>>> 389a6b 
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 815c2836c..498dd7cdd 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -14,218 +14,7 @@ import torch.distributed as dist import triton -<<<<<<< HEAD -import triton.language as tl from torch.utils.cpp_extension import IS_HIP_EXTENSION - - -@triton.jit -def online_softmax_kernel( - X_ptr, - X_stride, - Y_ptr, - Y_stride, - m_d_X_y_ptr, - m_d_X_y_stride, - rank, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - """ - This kernel computes the m/d components on this TP rank for the online softmax. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - m_d_X_y_ptr: Pointer to m/d/X_y tensor. - m_d_X_y_stride (int): The stride of the m/d/X_y tensor. - rank (int): The rank of this device in the TP group. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - program_id = tl.program_id(0).to(tl.int64) - - # locate the start index - X_ptr += program_id * X_stride - - # Load Y_ptr - Y_ptr += program_id * Y_stride - y = tl.load(Y_ptr) - - vocab_start_idx = rank * n_cols - vocab_end_idx = (rank + 1) * n_cols - if y >= vocab_start_idx: - if y < vocab_end_idx: - X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32) - else: - X_y = float("-inf") - else: - X_y = float("-inf") - - m_d_X_y_ptr += program_id * m_d_X_y_stride * 3 - - # 3. [Online softmax] first pass: find max + sum - m = float("-inf") # m is the max value. use the notation from the paper - d = 0.0 # d is the sum. 
use the notation from the paper - - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to( - tl.float32 - ) - block_max = tl.max(X_block) - m_new = tl.maximum(m, block_max) - d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) - m = m_new - - tl.store(m_d_X_y_ptr, m) - tl.store(m_d_X_y_ptr + m_d_X_y_stride, d) - tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y) - - -@triton.jit -def cross_entropy_kernel( - X_ptr, - X_stride, - Y_ptr, - Y_stride, - loss_ptr, - loss_stride, - m_d_X_y_ptr, - m_d_X_y_stride, - rank, - world_size, - ignore_idx, - n_cols, - n_non_ignore, - reduce_loss: tl.constexpr, - label_smoothing: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """ - This kernel computes both cross entropy loss and the gradient of the input. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - loss_ptr: Pointer to tensor to store the loss. - loss_stride (int): The stride of the loss tensor. - m_d_X_y_ptr: Pointer to m/d/X_y tensor. - m_d_X_y_stride: The stride of m/d/X_y tensor. - rank (int): The rank of this device in the TP group. - world_size (int): The size of world involved in this distributed loss calculation. - ignore_idx (int): Tokens to be ignored for loss and gradient calculation. - n_cols (int): The number of columns in the input tensor. - n_non_ignore (int): The number of non-ignored elements in the batch. - label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. - BLOCK_SIZE (int): The block size for Triton operations. 
- """ - - program_id = tl.program_id(0).to(tl.int64) - - # locate the start index - X_ptr += program_id * X_stride - - # Load Y_ptr - Y_ptr += program_id * Y_stride - y = tl.load(Y_ptr) - - if y == ignore_idx: - # set all X_ptr as 0 - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) - return - - loss_ptr += program_id * loss_stride - m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride - - # Need to reduce the m/d/X_y values from other TP ranks - m = tl.load(m_d_X_y_ptr) - d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) - ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) - - for i in range(1, world_size): - offset = i * 3 * n_non_ignore * m_d_X_y_stride - access_ptr = m_d_X_y_ptr + offset - m_new = tl.load(access_ptr) - d_new = tl.load(access_ptr + m_d_X_y_stride) - X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride)) - - d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new)) - m = tl.maximum(m, m_new) - ori_X_y = tl.maximum(ori_X_y, X_y_new) - - # Label smoothing is a general case of normal cross entropy - scaled_x_sum = 0.0 - eps = label_smoothing / (n_cols * world_size) - - # 4. 
[Online softmax] second pass: calculate the gradients - # dx_y = (softmax(x_y) - 1) / N - # dx_i = softmax(x_i) / N, i != y - # N is the number of non ignored elements in the batch - # For label smoothing: - # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y - # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N - # = dx_i - (1 - label_smoothing) / N - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) - grad_dtype = X_block.dtype - X_block = X_block.to(tl.float32) - if label_smoothing > 0: - # scale X beforehand to avoid overflow - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) - # Scale gradients based on reduction mode - # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore - # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here - if reduce_loss: - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) - else: - X_block = tl.exp(X_block - m) / d - eps - tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) - - # We need tl.debug_barrier() to ensure the new result of X_ptr is written - tl.debug_barrier() - - # 5. 
Calculate the loss - - # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) - # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) - loss = -(ori_X_y - m - tl.log(d)) - - # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps - # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) - # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) - # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: - # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) - # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 - if label_smoothing > 0: - smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) - loss = loss * (1 - label_smoothing) + smooth_loss - - # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` - vocab_start_idx = rank * n_cols - vocab_end_idx = (rank + 1) * n_cols - if y >= vocab_start_idx: - if y < vocab_end_idx: - X_y = tl.load(X_ptr + y - vocab_start_idx) - # Apply the same conditional scaling logic for the target token - if reduce_loss: - X_y += -(1 - label_smoothing) / (n_non_ignore) - else: - X_y += -(1 - label_smoothing) - tl.store(X_ptr + y - vocab_start_idx, X_y) - - tl.store(loss_ptr, loss) -======= ->>>>>>> 389a6b - from transformer_engine.common.triton.cross_entropy import ( online_softmax_kernel, cross_entropy_kernel, diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index d91c07c45..86acb7932 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -468,7 +468,15 @@ def is_fp8_fnuz(): get_torch_float8_e4m3_type = lambda: torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn get_torch_float8_e5m2_type = lambda: torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2 -<<<<<<< HEAD +def assert_dim_for_all_gather( + tensor: torch.Tensor, 
with_all_gather: bool, quantizer: Quantizer +) -> None: + """Assert that tensor dimensions are supported for all-gather""" + if with_all_gather: + assert quantizer.is_quantizable(tensor), ( + "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ + ) + def is_bf16_compatible() -> None: if IS_HIP_EXTENSION: # only MI200 and newer machines support bf16 @@ -481,24 +489,6 @@ def is_bf16_compatible() -> None: check on device compute capability to enforce sm_80 or higher. """ return torch.cuda.get_device_capability()[0] >= 8 -======= -def assert_dim_for_all_gather( - tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer -) -> None: - """Assert that tensor dimensions are supported for all-gather""" - if with_all_gather: - assert quantizer.is_quantizable(tensor), ( - "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ - ) - - -def is_bf16_compatible() -> bool: - """Replaces torch.cuda.is_bf16_compatible() with an explicit - check on device compute capability to enforce sm_80 or higher. 
- """ - return torch.cuda.get_device_capability()[0] >= 8 ->>>>>>> 389a6b - def is_bf16_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ @@ -535,14 +525,11 @@ def is_non_tn_fp8_gemm_supported() -> bool: @functools.lru_cache(maxsize=None) def get_cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" -<<<<<<< HEAD # ROCm fused attn does not use cudnn, return high numbers to avoid tests filtering out if IS_HIP_EXTENSION: return (99, 0, 0) -======= import transformer_engine.pytorch.cpp_extensions as ext ->>>>>>> 389a6b encoded_version = ext.get_cudnn_version() major_version_magnitude = 1000 if encoded_version < 90000 else 10000 major, encoded_version = divmod(encoded_version, major_version_magnitude) From e60ff21fdd4420faf9573b32eb22f821ec32d585 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Tue, 10 Feb 2026 11:09:24 -0600 Subject: [PATCH 04/41] [ROCm] resolve the conflicts in setup --- pyproject.toml | 6 +----- setup.py | 15 +++++---------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c32fc31a4..3814aabd0 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,15 +1,11 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
[build-system] -<<<<<<< HEAD -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax", "flax>=0.7.1"] -======= requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] ->>>>>>> 389a6b # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/setup.py b/setup.py index c6199c387..bec4943e1 100644 --- a/setup.py +++ b/setup.py @@ -245,17 +245,17 @@ def git_check_submodules() -> None: cmdclass = {} package_data = {} include_package_data = False -<<<<<<< HEAD - install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],) -======= install_requires = [] ->>>>>>> 389a6b extras_require = { "core": [f"transformer_engine_cu12=={__version__}"], "core_cu12": [f"transformer_engine_cu12=={__version__}"], "core_cu13": [f"transformer_engine_cu13=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], + } if not rocm_build() else { + "core": [f"transformer_engine_{te_cuda_vers}=={__version__}"], + "pytorch": [f"transformer_engine_torch=={__version__}"], + "jax": [f"transformer_engine_jax=={__version__}"], } else: install_requires, test_requires = setup_requirements() @@ -303,13 +303,8 @@ def git_check_submodules() -> None: long_description=long_description, long_description_content_type="text/x-rst", ext_modules=ext_modules, -<<<<<<< HEAD - cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, - python_requires=">=3.8", -======= - cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, + cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} if not rocm_build() else {"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=f">={min_python_version_str()}", ->>>>>>> 389a6b 
classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, license_files=("LICENSE",), From 8bbb16214277009bed6a8327ed6312a5b44b3f59 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 11 Feb 2026 10:55:18 -0600 Subject: [PATCH 05/41] [ROCm] resolve the cpp gtest --- tests/cpp/operator/CMakeLists.txt | 38 ++---------- tests/cpp/operator/test_cast_mxfp8.cu | 45 ++++---------- .../operator/test_cast_mxfp8_gated_swiglu.cu | 60 +++++++------------ tests/cpp/test_common.cu | 42 ++++--------- tests/cpp/test_common.h | 13 +--- 5 files changed, 52 insertions(+), 146 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index f0123ccf7..cd36993ce 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -4,7 +4,6 @@ # # See LICENSE for license information. -<<<<<<< HEAD list(APPEND test_cuda_sources test_cast.cu test_cast_current_scaling.cu @@ -24,16 +23,18 @@ list(APPEND test_cuda_sources test_act.cu test_normalization.cu test_normalization_mxfp8.cu + test_memset.cu test_multi_cast_transpose.cu test_multi_padding.cu test_multi_unpadding.cu test_causal_softmax.cu - test_swizzle.cu test_swap_first_dims.cu ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources - test_cast_float8blockwise.cu) + test_cast_nvfp4_transpose.cu + test_cast_float8blockwise.cu + test_swizzle.cu) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu) @@ -70,37 +71,6 @@ else() add_executable(test_operator ${test_hip_sources}) endif() -======= -add_executable(test_operator - test_cast.cu - test_cast_current_scaling.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu - test_cast_gated_swiglu.cu - test_cast_mxfp8_gated_swiglu.cu - test_qdq.cu - test_cast_mxfp8.cu - test_cast_nvfp4_transpose.cu - test_cast_float8blockwise.cu - test_dequantize_mxfp8.cu - test_transpose.cu - test_cast_transpose.cu - test_cast_transpose_current_scaling.cu - test_cast_transpose_dbias.cu - 
test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_normalization_mxfp8.cu - test_memset.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_multi_unpadding.cu - test_causal_softmax.cu - test_swizzle.cu - test_swap_first_dims.cu - ../test_common.cu) ->>>>>>> 389a6b # Find required packages find_package(OpenMP REQUIRED) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index acba36464..a029e4f3f 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -315,29 +315,22 @@ void performTest_x1(const ProcessingMethod processing_method, #ifdef __HIP_PLATFORM_AMD__ const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; + std::vector mismatches_scales_indices; #else const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; #endif - std::vector mismatches_scales_indices; size_t mismatches_scales = 0; -<<<<<<< HEAD - compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales_indices, mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= - compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ mismatches_scales, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); ->>>>>>> 389a6b #ifdef __HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; @@ -510,50 +503,36 @@ void performTest_x2(const ProcessingMethod processing_method, #ifdef __HIP_PLATFORM_AMD__ const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 
1.0e-4; + std::vector mismatches_scales_indices_rowwise; + std::vector mismatches_scales_indices_colwise; #else const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; #endif - std::vector mismatches_scales_indices_rowwise; size_t mismatches_scales_rowwise = 0; -<<<<<<< HEAD - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_indices_rowwise, mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_rowwise, +#endif //#ifdef __HIP_PLATFORM_AMD__ mismatches_scales_rowwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); ->>>>>>> 389a6b - std::vector mismatches_scales_indices_colwise; size_t mismatches_scales_colwise = 0; -<<<<<<< HEAD - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_indices_colwise, mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_colwise, +#endif //#ifdef __HIP_PLATFORM_AMD__ mismatches_scales_colwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); ->>>>>>> 389a6b #ifdef 
__HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index d6bcfef30..ba4144a7c 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -264,7 +264,9 @@ void performTest_x1(const size_t rows, rowwise, colwise); +#ifdef __HIP_PLATFORM_AMD__ std::vector mismatches_scales_indices; +#endif size_t mismatches_scales = 0; const size_t scale_diff_abs_tolerance = 0; const double abs_tolerable_mismatches_limit = 1.0; @@ -274,25 +276,11 @@ void performTest_x1(const size_t rows, ? output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { -<<<<<<< HEAD - compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales_indices, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); - } else { - compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales_indices, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ mismatches_scales, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, @@ -300,12 +288,13 @@ void performTest_x1(const size_t rows, } else { compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ 
mismatches_scales, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); - ->>>>>>> 389a6b } #ifdef __HIP_PLATFORM_AMD__ @@ -411,44 +400,35 @@ void performTest_x2(const size_t rows, const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; +#ifdef __HIP_PLATFORM_AMD__ std::vector mismatches_scales_indices_rowwise; +#endif size_t mismatches_scales_rowwise = 0; -<<<<<<< HEAD - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_indices_rowwise, mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); - std::vector mismatches_scales_indices_colwise; - size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_indices_colwise, mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_rowwise, +#endif mismatches_scales_rowwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); +#ifdef __HIP_PLATFORM_AMD__ + std::vector mismatches_scales_indices_colwise; +#endif size_t mismatches_scales_colwise = 0; compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, +#ifdef __HIP_PLATFORM_AMD__ + 
mismatches_scales_indices_colwise, +#endif mismatches_scales_colwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); ->>>>>>> 389a6b - #ifdef __HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; adjust_ref_for_e8m0_scale_error("scales_rowwise", mismatches_scales_indices_rowwise, @@ -514,7 +494,7 @@ class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam bool>> {}; TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { - #ifdef __HIP_PLATFORM_AMD__ +#ifdef __HIP_PLATFORM_AMD__ omp_set_num_threads(std::min(128, omp_get_max_threads())); // Using threads = # of vcpus causes occasional errors. #else // #ifdef __HIP_PLATFORM_AMD__ // Skip tests for pre-Blackwell architectures diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 3286b1527..5427bc118 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -412,20 +412,13 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (columnwise_) { -<<<<<<< HEAD - (void)cudaMemcpy(cpu_data_columnwise_.get(), - tensor_.get_columnwise_data().data_ptr, - size, - cudaMemcpyDeviceToHost); -======= const DType colwise_type = tensor_.dtype(); const size_t colwise_size = bytes(s, colwise_type); - cudaMemcpy(cpu_data_columnwise_.get(), + (void)cudaMemcpy(cpu_data_columnwise_.get(), tensor_.get_columnwise_data().data_ptr, colwise_size, cudaMemcpyDeviceToHost); ->>>>>>> 389a6b } if (isFp8Type(dtype()) || isFp4Type(dtype())) { if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) { @@ -759,14 +752,6 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } } -<<<<<<< HEAD -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - std::vector &mismatch_indices, - size_t& mismatches_num, const size_t atol, - const double abs_tolerable_mismatches_limit, - const double 
rel_tolerable_mismatches_limit) -======= template struct CastToType; @@ -783,10 +768,12 @@ struct CastToType { template void compare_scaling_factors(const std::string &name, const T *test, const T *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ size_t& mismatches_num, const size_t atol, const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit) ->>>>>>> 389a6b { using UpcastType = typename CastToType::type; auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3); @@ -796,6 +783,9 @@ void compare_scaling_factors(const std::string &name, const T *test, const T *re const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, std::floor(N * rel_tolerable_mismatches_limit)); mismatches_num = 0; +#ifndef __HIP_PLATFORM_AMD__ + std::vector mismatch_indices; +#endif //#ifndef __HIP_PLATFORM_AMD__ for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { @@ -842,8 +832,6 @@ void compare_scaling_factors(const std::string &name, const T *test, const T *re } } -<<<<<<< HEAD - #ifdef __HIP_PLATFORM_AMD__ void adjust_ref_for_e8m0_scale_error(const std::string &name, const std::vector &mismatch_idx, @@ -887,11 +875,13 @@ void adjust_ref_for_e8m0_scale_error(const std::string &name, } } #endif // #ifdef __HIP_PLATFORM_AMD__ -======= // Instantiate templates template void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ size_t& mismatches_num, const size_t atol, const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit); @@ -899,11 +889,13 @@ void compare_scaling_factors(const std::string &name, const uint8_t *te template void 
compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ size_t& mismatches_num, const size_t atol, const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit); ->>>>>>> 389a6b std::pair getTolerances(const DType type) { switch(type) { @@ -1069,13 +1061,6 @@ bool isFp8Type(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } -<<<<<<< HEAD -int32_t getDeviceComputeCapability() -{ - cudaDeviceProp deviceProp; - (void)cudaGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; -======= bool isFp4Type(DType type) { return type == DType::kFloat4E2M1; } @@ -1084,7 +1069,6 @@ int32_t getDeviceComputeCapability() { cudaDeviceProp deviceProp; cudaGetDeviceProperties(&deviceProp, 0); return 10 * deviceProp.major + deviceProp.minor; ->>>>>>> 389a6b } size_t first_dimension(const std::vector &shape) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 1d6d9107e..56154c9d9 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -488,24 +488,17 @@ void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); -<<<<<<< HEAD -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - std::vector &mismatch_indices, size_t& mismatches_num, - const size_t scale_diff_abs_tolerance = 0, - const double abs_tolerable_mismatches_limit = 0, - const double rel_tolerable_mismatches_limit = 0); -======= template void compare_scaling_factors(const 
std::string &name, const T *test, const T *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef USE_ROCM + std::vector& mismatch_indices, +#endif //#ifdef USE_ROCM size_t& mismatches_num, const size_t scale_diff_abs_tolerance = 0, const double abs_tolerable_mismatches_limit = 0, const double rel_tolerable_mismatches_limit = 0); ->>>>>>> 389a6b - #ifdef USE_ROCM void adjust_ref_for_e8m0_scale_error(const std::string &name, const std::vector &mismatch_idx, From f573b40081340199654061e518abb9c195e96a81 Mon Sep 17 00:00:00 2001 From: alextmagro Date: Wed, 11 Feb 2026 21:02:08 +0000 Subject: [PATCH 06/41] [ROCm] resolve pytorch and jax tests Resolve wheels and examples --- build_tools/wheel_utils/build_wheels.sh | 103 +++++++----------- examples/pytorch/mnist/main.py | 20 +--- tests/jax/distributed_test_base.py | 6 +- tests/jax/test_custom_call_compute.py | 75 +------------ tests/jax/test_distributed_layernorm_mlp.py | 7 -- tests/jax/test_fused_attn.py | 4 - .../attention/run_attention_with_cp.py | 35 +----- tests/pytorch/attention/test_attention.py | 58 +--------- .../attention/test_attention_with_cp.py | 24 +--- tests/pytorch/attention/test_kv_cache.py | 10 +- tests/pytorch/distributed/run_fsdp2_model.py | 18 +-- tests/pytorch/distributed/test_fusible_ops.py | 5 +- tests/pytorch/test_cpu_offloading.py | 6 - .../test_float8_current_scaling_exact.py | 7 -- tests/pytorch/test_fusible_ops.py | 17 --- tests/pytorch/test_numerics.py | 48 +------- tests/pytorch/test_recipe.py | 7 +- tests/pytorch/utils.py | 16 +-- 18 files changed, 65 insertions(+), 401 deletions(-) diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 6db223691..0be852c8a 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -43,19 +43,10 @@ else fi if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install setuptools wheel -fi - -# Install deps -<<<<<<< HEAD -if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install pybind11[global] ninja + ${PYBINDIR}pip install pybind11[global] ninja setuptools wheel else - ${PYBINDIR}pip install cmake pybind11[global] ninja + /opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel fi -======= -/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel ->>>>>>> 389a6b if $BUILD_METAPACKAGE ; then cd /TransformerEngine @@ -83,70 +74,52 @@ if $BUILD_COMMON ; then # Create the wheel. ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt -<<<<<<< HEAD - # Repack the wheel for cuda specific package, i.e. cu12. - ${PYBINDIR}wheel unpack dist/* - # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). - sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info" - ${PYBINDIR}wheel pack ${WHL_BASE} -======= - # Repack the wheel for specific cuda version. - /opt/python/cp310-cp310/bin/wheel unpack dist/* - # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). 
- sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_MAJOR}-${VERSION}.dist-info" - /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE} ->>>>>>> 389a6b + if [ "$ROCM_BUILD" = "1" ]; then + # Repack the wheel for cuda specific package, i.e. cu12. + ${PYBINDIR}wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info" + ${PYBINDIR}wheel pack ${WHL_BASE} + else + # Repack the wheel for specific cuda version. + /opt/python/cp310-cp310/bin/wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_MAJOR}-${VERSION}.dist-info" + /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE} + fi # Rename the wheel to make it python version agnostic. 
whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" -<<<<<<< HEAD whl_name_target="${whl_parts[0]}_${TE_CUDA_VERS}-${whl_parts[1]}-py3-none-${whl_parts[4]}" -======= - whl_name_target="${whl_parts[0]}_cu${CUDA_MAJOR}-${whl_parts[1]}-py3-none-${whl_parts[4]}" ->>>>>>> 389a6b rm -rf $WHL_BASE dist mv *.whl /wheelhouse/"$whl_name_target" fi if $BUILD_PYTORCH ; then -<<<<<<< HEAD - cd /TransformerEngine/transformer_engine/pytorch - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 - else - PYBINDIR=/opt/python/cp38-cp38/bin/ - ${PYBINDIR}pip install torch - fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt - cp dist/* /wheelhouse/ -fi - -if $BUILD_JAX ; then - cd /TransformerEngine/transformer_engine/jax - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install jax - else - PYBINDIR=/opt/python/cp310-cp310/bin/ - ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib - fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt - cp dist/* /wheelhouse/ -======= - cd /TransformerEngine/transformer_engine/pytorch - /opt/python/cp310-cp310/bin/pip install torch - /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt - cp dist/* /wheelhouse/ + cd /TransformerEngine/transformer_engine/pytorch + if [ "$ROCM_BUILD" = "1" ]; then + ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 + else + PYBINDIR=/opt/python/cp310-cp310/bin/ + ${PYBINDIR}pip install torch + fi + ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt + cp dist/* /wheelhouse/ fi if $BUILD_JAX ; then - cd /TransformerEngine/transformer_engine/jax - /opt/python/cp310-cp310/bin/pip install "jax[cuda${CUDA_MAJOR}_local]" jaxlib - /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt - cp dist/* /wheelhouse/ ->>>>>>> 389a6b + cd /TransformerEngine/transformer_engine/jax + if [ 
"$ROCM_BUILD" = "1" ]; then + ${PYBINDIR}pip install jax + else + PYBINDIR=/opt/python/cp310-cp310/bin/ + ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib + fi + ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + cp dist/* /wheelhouse/ fi diff --git a/examples/pytorch/mnist/main.py b/examples/pytorch/mnist/main.py index 3516d5275..347d36e7c 100644 --- a/examples/pytorch/mnist/main.py +++ b/examples/pytorch/mnist/main.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -54,17 +54,8 @@ def train(args, model, device, train_loader, optimizer, epoch, use_amp, use_fp8) for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() -<<<<<<< HEAD - if use_amp: - with autocast(device_type='cuda', dtype=torch.float16): - output = model(data) - else: - with te.fp8_autocast(enabled=use_fp8): - output = model(data) -======= with te.autocast(enabled=use_fp8): output = model(data) ->>>>>>> 389a6b loss = F.nll_loss(output, target) loss.backward() optimizer.step() @@ -99,17 +90,8 @@ def test(model, device, test_loader, use_amp, use_fp8): with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) -<<<<<<< HEAD - if use_amp: - with autocast(device_type='cuda', dtype=torch.float16): - output = model(data) - else: - with te.fp8_autocast(enabled=use_fp8): - output = model(data) -======= with te.autocast(enabled=use_fp8): output = model(data) ->>>>>>> 389a6b test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += 
pred.eq(target.view_as(pred)).sum().item() diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 244e8b5ee..e8d9cefd6 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,10 +10,6 @@ import pytest import jax -<<<<<<< HEAD -from jax._src.pjit import pjit -======= ->>>>>>> 389a6b from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED from transformer_engine.jax.sharding import MeshResource diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index ad5ddf0d3..75b606a62 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -614,16 +614,12 @@ def test_norm_forward_with_tensor_scaling_fp8( ) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) -<<<<<<< HEAD - @pytest.mark.parametrize("out_dtype", FP8_COMPUTE_TYPE) -======= @pytest.mark.parametrize( "out_dtype", [ - jnp.float8_e4m3fn, + jnp_float8_e4m3_type if is_hip_extension() else jnp.float8_e4m3fn, ], ) ->>>>>>> 389a6b def test_norm_forward_with_block_scaling_fp8( self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype ): @@ -640,15 +636,9 @@ def test_norm_forward_with_block_scaling_fp8( ) -<<<<<<< HEAD QUANTIZE_OUTPUT_DTYPES = { "L0": [jnp_float8_e4m3_type], "L2": FP8_COMPUTE_TYPE, -======= -QUANTIZE_OUTPUT_FP8_DTYPES = { - "L0": [jnp.float8_e4m3fn], - "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2], ->>>>>>> 389a6b } QUANTIZE_OUTPUT_DTYPES = { test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn] @@ -692,11 +682,7 @@ def test_norm_forward_with_block_scaling_fp8( @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) -<<<<<<< HEAD @pytest_parametrize_wrapper("q_dtype", FP8_COMPUTE_TYPE) -======= -@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn]) ->>>>>>> 389a6b @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper( @@ -1085,13 +1071,8 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) -<<<<<<< HEAD @pytest_parametrize_wrapper("q_dtype", [jnp_float8_e4m3_type]) -@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) -======= -@pytest_parametrize_wrapper("q_dtype", 
[jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) ->>>>>>> 389a6b @pytest_parametrize_wrapper("flatten_axis", [-1]) @pytest_parametrize_wrapper("with_group_sizes", [True, False]) @pytest_parametrize_wrapper( @@ -1487,17 +1468,10 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( -<<<<<<< HEAD - scaling_mode=scaling_mode, - fwd_dtype=jnp_float8_e4m3_type, - bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, - is_2x2x=True, -======= fp8_recipe=recipe, quantize_meta_set=QuantizeMetaSet( x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() ), ->>>>>>> 389a6b ) n_iterations = 3 if recipe.delayed() else 1 @@ -1511,17 +1485,10 @@ def ref_func(x, w, bias, data_layout): x, w, bias, data_layout ) -<<<<<<< HEAD - assert_allclose(primitive_out, ref_out, dtype=jnp_float8_e4m3_type) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp_float8_e5m2_type) -======= assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype) assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype) ->>>>>>> 389a6b @pytest.fixture(name="random_inputs") @@ -1568,17 +1535,10 @@ def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm): gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) quantizer_set = QuantizerFactory.create_set( -<<<<<<< HEAD - scaling_mode=scaling_mode, - fwd_dtype=jnp_float8_e4m3_type, - bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, - is_2x2x=True, 
-======= fp8_recipe=recipe, quantize_meta_set=QuantizeMetaSet( x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() ), ->>>>>>> 389a6b ) if norm_type == "layernorm": @@ -1624,21 +1584,12 @@ def ref_func(x, w, gamma, beta): prim_beta_grad, ) = value_n_grad_prim_func(x, w, gamma, beta) -<<<<<<< HEAD - assert_allclose(prim_out, ref_out, dtype=jnp_float8_e4m3_type) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp_float8_e5m2_type) - if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp_float8_e5m2_type) -======= assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype) assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype) if beta is not None: assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype) ->>>>>>> 389a6b @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 128, 128)]) @@ -1676,17 +1627,10 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, -<<<<<<< HEAD - scaling_mode=scaling_mode, - fwd_dtype=jnp_float8_e4m3_type, - bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, - is_2x2x=True, -======= fp8_recipe=recipe, quantize_meta_set=QuantizeMetaSet( x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() ), ->>>>>>> 389a6b ) if norm_type == "layernorm": @@ -1754,20 +1698,6 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) -<<<<<<< HEAD - assert_allclose(prim_out, ref_out, dtype=jnp_float8_e4m3_type) - - 
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp_float8_e5m2_type) - if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp_float8_e5m2_type) - - assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp_float8_e5m2_type) - if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp_float8_e5m2_type) - - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) -======= fwd_dtype = quantizer_sets[0].x.q_dtype bwd_dtype = quantizer_sets[0].dgrad.q_dtype assert_allclose(prim_out, ref_out, dtype=fwd_dtype) @@ -1778,7 +1708,6 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): if use_bias: assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype) assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype) ->>>>>>> 389a6b # E5M2 * E5M2 is not supported diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index d58ebcef5..c67528f04 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -275,23 +275,16 @@ def _test_layernorm_mlp_grad( ) # +1 for multi_gpus multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) -<<<<<<< HEAD # TODO: skip cases with single fwd as nan/inf if is_hip_extension() and (jnp.any(jnp.isnan(single_fwd)) or jnp.any(jnp.isinf(single_fwd))): pytest.skip("skip tests with nan/inf single fwd.") - - fwd_test_type = dtype if fp8_recipe is None else jnp_float8_e4m3_type - bwd_test_type = dtype if fp8_recipe is None else jnp_float8_e5m2_type -======= - fwd_test_type = bwd_test_type = dtype if quantization_recipe is not None: quantize_config = get_quantize_config_with_recipe(quantization_recipe) fwd_test_type = quantize_config.FWD_DTYPE bwd_test_type = quantize_config.BWD_DTYPE ->>>>>>> 389a6b if fwd_test_type == jnp.float16 and use_bias: 
assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index a29725909..f33961455 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -24,12 +24,8 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike -<<<<<<< HEAD -from transformer_engine.jax import fp8_autocast from transformer_engine.jax.cpp_extensions.misc import is_hip_extension -======= from transformer_engine.jax import autocast ->>>>>>> 389a6b from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( AttnBiasType, diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 8fc914053..80f21048f 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -10,23 +10,16 @@ from contextlib import nullcontext import torch import torch.distributed as dist -<<<<<<< HEAD +import warnings + from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch.attention import DotProductAttention -======= ->>>>>>> 389a6b + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( get_cu_seqlens_on_cp_rank, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize import transformer_engine_torch as tex from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn -<<<<<<< HEAD -from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer -from transformer_engine.common.recipe import DelayedScaling -import warnings -======= from transformer_engine.pytorch import ( autocast, DotProductAttention, @@ -35,7 +28,6 @@ ) from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling from utils import ModelConfig, compare_and_assert ->>>>>>> 389a6b dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -336,16 +328,11 @@ def run_dpa_with_cp( core_attention_bias=bias, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, -<<<<<<< HEAD cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), -======= - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ->>>>>>> 389a6b ) if config.return_max_logit: out, max_logit = out @@ -438,16 +425,11 @@ def run_dpa_with_cp( core_attention_bias=bias_, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, -<<<<<<< HEAD cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else 
cu_seqlens_kv_padded[:-1] ), -======= - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ->>>>>>> 389a6b ) if config.return_max_logit: out_, max_logit_ = out_ @@ -491,17 +473,10 @@ def run_dpa_with_cp( for x in [dq_, dk_, dv_, out_] ] elif qkv_format == "thd": -<<<<<<< HEAD - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] - dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size -======= dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded // world_size ->>>>>>> 389a6b + cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 23463cc32..c0cf64803 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -12,12 +12,9 @@ import pytest import torch -<<<<<<< HEAD from torch.utils.cpp_extension import IS_HIP_EXTENSION -======= from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype ->>>>>>> 389a6b from transformer_engine.common import recipe from transformer_engine.pytorch import ( TransformerLayer, @@ -91,7 +88,6 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() -<<<<<<< HEAD if IS_HIP_EXTENSION: from utils import EnvVarCleaner @pytest.fixture(autouse=True) @@ -101,13 +97,11 @@ def reset_attn_backend(): "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"]) yield -======= # Define F16 data types to test param_types = [torch.float16] if 
is_bf16_available(): param_types.append(torch.bfloat16) param_types_lean = [torch.bfloat16] ->>>>>>> 389a6b model_configs_base = { # test: ModelConfig(b, sq, hq, dqk) @@ -126,7 +120,6 @@ def reset_attn_backend(): } -<<<<<<< HEAD param_types = [torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) @@ -136,10 +129,8 @@ def reset_attn_backend(): # backend is capable of supporting it. @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") def test_gqa_mla_thd(): - """ - Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration - post-processing for BWD FA with native padding support. - """ + """Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration + post-processing for BWD FA with native padding support.""" # b, sq, h, dqk config = ModelConfig(8, 128, 16, 128, num_gqa_groups= 4, head_dim_v=64, attn_mask_type="padding") qkv_layout = "thd_thd_thd" @@ -156,11 +147,10 @@ def test_gqa_mla_thd(): test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False) + @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") def test_dot_product_mem_calc(): - """ - Non-regression test for memory workspace calculation integer overflow issue. 
- """ + """Non-regression test for memory workspace calculation integer overflow issue.""" ckpt_attn = False pad_between_seqs = False if not is_bf16_compatible(): @@ -197,8 +187,6 @@ def test_dot_product_mem_calc(): del os.environ["NVTE_FUSED_ATTN_AOTRITON"] -======= ->>>>>>> 389a6b @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_base]) @@ -306,13 +294,9 @@ def test_dot_product_attention( ) if len(fused_attn_backends) == 2: os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" -<<<<<<< HEAD os.environ["NVTE_FUSED_ATTN_CK"] = "0" os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1" - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( -======= fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention( ->>>>>>> 389a6b dtype, config, "FusedAttention", @@ -324,15 +308,11 @@ def test_dot_product_attention( share_cu_seqlens_ref, # Not used by AOT ) os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" -<<<<<<< HEAD os.environ["NVTE_FUSED_ATTN_CK"] = "1" os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0" os.environ["NVTE_CK_USES_FWD_V3"] = "1" os.environ["NVTE_CK_USES_BWD_V3"] = "1" - fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( -======= fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention( ->>>>>>> 389a6b dtype, config, "FusedAttention", @@ -1926,37 +1906,7 @@ def get_model(dtype, config): qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] -<<<<<<< HEAD -def _rmse(a, b): - return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum()) - - -def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): - logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) - logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) - try: - if a.dtype != b.dtype: - a = a.to(b.dtype) - torch.testing.assert_close(a, b, atol=atol, rtol=rtol) - except Exception as e: - logging.debug(e) - - 
rmse = _rmse(a, b) - logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) - rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - assert rmse < rmse_tol * rmse_range, ( - name_a - + " vs " - + name_b - + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - rmse, rmse_tol * rmse_range, rmse_tol, rmse_range - ) - ) - - @pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm") -======= ->>>>>>> 389a6b @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index a9b0afe89..9ac96dcef 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -12,12 +12,10 @@ import pytest import torch -<<<<<<< HEAD + from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch.utils import ( -======= + from transformer_engine.pytorch import ( ->>>>>>> 389a6b get_device_compute_capability, get_cudnn_version, ) @@ -87,13 +85,8 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") -<<<<<<< HEAD @pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) -======= -@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.parametrize("dtype", dtypes) ->>>>>>> 389a6b @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) @@ -123,11 +116,9 @@ def 
test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ) if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") -<<<<<<< HEAD if IS_HIP_EXTENSION: if config.head_dim_qk != config.head_dim_v and not FlashAttentionUtils.v3_is_installed: pytest.skip("MLA FlashAttention requires v3+!") -======= dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} available_backends, *_ = get_available_attention_backends( config, @@ -137,7 +128,6 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): flash_attn_supported, *_ = available_backends if not flash_attn_supported: pytest.skip("No attention backend available.") ->>>>>>> 389a6b subprocess.run( get_bash_arguments( @@ -207,13 +197,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") -<<<<<<< HEAD @pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) -======= -@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.parametrize("dtype", dtypes) ->>>>>>> 389a6b @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) @@ -235,15 +220,12 @@ def test_cp_with_fused_attention( pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if (not IS_HIP_EXTENSION) and dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") -<<<<<<< HEAD if IS_HIP_EXTENSION and dtype == "fp8": - pytest.skip("FP8 attention has not been supported on ROCm yet!") -======= + pytest.skip("FP8 attention is not supported on ROCm yet!") if dtype == "fp8" and 
not fp8_dpa and fp8_mha: pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") if dtype != "fp8" and fp8_bwd: pytest.skip("Only fp8 works with fp8_bwd=True!") ->>>>>>> 389a6b config = model_configs_fused_attn[model] config.context_parallel = True diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index bab34ef28..eb86c0776 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -15,19 +15,13 @@ import pytest import torch -from torch.distributions import Exponential -<<<<<<< HEAD from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch import make_graphed_callables -from transformer_engine.common import recipe -from transformer_engine.pytorch import fp8_autocast, fp8_model_init -from transformer_engine.pytorch.transformer import ( -======= + +from torch.distributions import Exponential from transformer_engine.pytorch import ( make_graphed_callables, autocast, quantized_model_init, ->>>>>>> 389a6b TransformerLayer, DotProductAttention, InferenceParams, diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index 2ec97518f..b9fe33593 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -17,6 +17,7 @@ Float8CurrentScaling, MXFP8BlockScaling, ) +from transformer_engine.pytorch import torch_version import torch import torch.distributed as dist @@ -28,23 +29,8 @@ from torch.distributed.device_mesh import init_device_mesh from transformer_engine.pytorch import QuantizedTensor from contextlib import nullcontext -<<<<<<< HEAD -from transformer_engine.pytorch import torch_version - -class SimpleNet(nn.Module): - def __init__(self, input_size, hidden_size, output_size): - super(SimpleNet, self).__init__() - self.fc1 = te.Linear(input_size, hidden_size) - self.fc2 = te.Linear(hidden_size, output_size) - - def forward(self, x): - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return x -======= LOCAL_RANK = None ->>>>>>> 389a6b def dist_print(msg): diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 3ce4ca7cd..85ae2d85b 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -31,10 +31,7 @@ is_bf16_available, ) import transformer_engine.pytorch.ops as te_ops -<<<<<<< HEAD -from transformer_engine.pytorch.utils import is_bf16_compatible, is_fp8_fnuz -======= ->>>>>>> 389a6b +from transformer_engine.pytorch.utils import is_fp8_fnuz import transformer_engine_torch as tex # Import utility functions diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index f3f649fb8..c5b4b48b6 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -46,15 +46,9 @@ NUM_LAYERS = model_config["small"].num_layers EPSILON = model_config["small"].eps -<<<<<<< HEAD -# Flash attention saves some internal tensor for the backward pass -# that cannot be offloaded to CPU. -assert os.getenv("NVTE_FLASH_ATTN", "1") == "0" -======= # Disable garbage collection to tests if there are reference cycles. # We do not want them, because they can result in CUDA out of memory errors. 
import gc ->>>>>>> 389a6b gc.disable() diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index c5c26e873..d21c2e366 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -9,15 +9,8 @@ import transformer_engine.pytorch as te -<<<<<<< HEAD -import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.common.recipe import Float8CurrentScaling, Format -from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype -======= from transformer_engine.common.recipe import Float8CurrentScaling from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype ->>>>>>> 389a6b # read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 37b5e9ee9..a67fd4f45 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -35,12 +35,7 @@ NVFP4Quantizer, is_bf16_available, ) -<<<<<<< HEAD -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer -from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import get_device_compute_capability -======= ->>>>>>> 389a6b import transformer_engine_torch as tex from torch.utils.cpp_extension import IS_HIP_EXTENSION @@ -1377,20 +1372,8 @@ def test_rmsnorm( # Expected numerical error tols = dtype_tols(dtype) - # Explicit checks for quantization if quantized_compute: -<<<<<<< HEAD - tols = dtype_tols(y_test._quantizer.dtype) - expected_tensor_cls = { - Float8Quantizer:Float8Tensor, - Float8CurrentScalingQuantizer:Float8Tensor, - MXFP8Quantizer:MXFP8Tensor - }[type(y_test._quantizer)] - assert isinstance(y_test, expected_tensor_cls) - y_test = 
y_test.dequantize(dtype=torch.float32) -======= tols = quantization_tols(quantization) ->>>>>>> 389a6b # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 0aa932dfc..50578fc1a 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -38,15 +38,6 @@ LayerNorm, Fp8Padding, Fp8Unpadding, -<<<<<<< HEAD -) -from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils as fa_utils -from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm -from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend -from transformer_engine.pytorch.tensor.float8_tensor import ( -======= ->>>>>>> 389a6b Float8Quantizer, Float8CurrentScalingQuantizer, MXFP8Quantizer, @@ -57,19 +48,15 @@ is_bf16_available, is_nvfp4_available, ) +from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.common import recipe import transformer_engine_torch as tex -<<<<<<< HEAD -from utils import ModelConfig, reset_rng_states, get_available_attention_backends +from utils import ModelConfig, reset_rng_states if IS_HIP_EXTENSION: from utils import EnvVarCleaner -======= -from utils import ModelConfig, reset_rng_states ->>>>>>> 389a6b - # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) @@ -202,28 +189,6 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> use_cutlass_grouped_gemm.append(True) -<<<<<<< HEAD -def is_fused_attn_available( - config: ModelConfig, - dtype: torch.dtype, - qkv_layout="bshd_bshd_bshd", - is_training=True, - deterministic=False, -): - _, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - is_training=is_training, - deterministic=deterministic, - ) - if IS_HIP_EXTENSION: - return fused_attn_backends != [] - return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends - - -======= ->>>>>>> 389a6b def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -806,7 +771,6 @@ def test_gpt_full_activation_recompute( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") -<<<<<<< HEAD if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): if (dtype == torch.bfloat16 and not fp8 @@ -814,13 +778,11 @@ def test_gpt_full_activation_recompute( and recipe.float8_per_tensor_scaling() ): pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") -======= if fp8 and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" ) ->>>>>>> 389a6b config = model_configs[model] torch.compiler.reset() # avoid cache size limit overflow @@ -972,12 +934,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] -<<<<<<< HEAD - if not is_fused_attn_available(config, dtype, deterministic=True): - pytest.skip("No attention backend available.") - -======= 
->>>>>>> 389a6b outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index d8870d3da..6850be9b4 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -30,12 +30,7 @@ _amax_and_scale_update, ) import transformer_engine.pytorch.ops as te_ops -<<<<<<< HEAD from transformer_engine.pytorch.utils import is_fp8_fnuz -from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear -from transformer_engine.pytorch.distributed import fp8_autocast -======= ->>>>>>> 389a6b from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling import transformer_engine_torch as tex diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 28236a18b..05555626b 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -15,12 +15,6 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine -<<<<<<< HEAD -import transformer_engine.common.recipe -import transformer_engine.pytorch as te -from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type -======= ->>>>>>> 389a6b import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import InferenceParams @@ -32,6 +26,7 @@ check_set_window_size, ) from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend +from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, 
get_torch_float8_e5m2_type torch_float8_e4m3_type = get_torch_float8_e4m3_type() torch_float8_e5m2_type = get_torch_float8_e5m2_type() @@ -105,15 +100,10 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: return dict(rtol=1.3e-6, atol=1e-5) if dtype == torch.float64: return dict(rtol=1e-7, atol=1e-7) - if dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz: + if dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 -<<<<<<< HEAD - if dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 -======= - if dtype == torch.float8_e5m2: + if dtype in (torch.float8_e5m2, torch.float8_e5m2fnuz): return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 ->>>>>>> 389a6b raise ValueError(f"Unsupported dtype ({dtype})") From eaaae946f976c5a3ef634bea4193bbde2f68743a Mon Sep 17 00:00:00 2001 From: alextmagro Date: Thu, 19 Feb 2026 18:49:23 +0000 Subject: [PATCH 07/41] pytest, example, wheels conflict resolution --- ci/pytorch.sh | 3 ++- tests/pytorch/test_cpu_offloading.py | 11 +++++++++++ tests/pytorch/test_cpu_offloading_v1.py | 4 +++- .../test_float8_current_scaling_exact.py | 6 ++++-- .../test_layernorm_saved_tensors_logic.py | 8 ++++---- tests/pytorch/test_numerics.py | 6 ++++-- .../transformer_engine/hadamard_transform.h | 6 ++++++ transformer_engine/common/recipe/__init__.py | 12 ++++++------ transformer_engine/pytorch/csrc/common.h | 2 ++ .../pytorch/csrc/extensions/activation.cpp | 12 ++++++++++++ .../pytorch/csrc/extensions/bias.cpp | 6 ++++++ .../pytorch/csrc/extensions/cast.cpp | 4 ++++ .../pytorch/csrc/extensions/normalization.cpp | 16 ++++++++++++++++ transformer_engine/pytorch/csrc/pybind.h | 4 ++++ transformer_engine/pytorch/csrc/quantizer.cpp | 2 ++ transformer_engine/pytorch/module/base.py | 2 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 8 ++++---- 
transformer_engine/pytorch/module/linear.py | 2 +- transformer_engine/pytorch/quantization.py | 1 + .../pytorch/tensor/mxfp8_tensor.py | 4 +++- .../pytorch/triton_kernels/cast.py | 18 +++++++++--------- .../pytorch/triton_kernels/layernorm.py | 2 +- .../pytorch/triton_kernels/rmsnorm.py | 2 +- 24 files changed, 108 insertions(+), 35 deletions(-) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index be150485f..5558beca3 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -51,7 +51,8 @@ run_test_config(){ run_default_fa 1 test_deferred_init.py run_default_fa 1 test_float8tensor.py run_default_fa 1 test_float8_current_scaling_exact.py - test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 run 1 test_cpu_offloading.py + run 1 test_cpu_offloading.py + test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 run 3 test_cpu_offloading_v1.py run_default_fa 1 test_fused_rope.py run_default_fa 1 test_fused_router.py run_default_fa 1 test_fusible_ops.py diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index c5b4b48b6..4e4c71e14 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -21,6 +23,8 @@ from utils import ModelConfig import transformer_engine_torch as tex +from torch.utils.cpp_extension import IS_HIP_EXTENSION + # Check supported quantization schemes fp8_available, _ = FP8GlobalStateManager.is_fp8_available() fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() @@ -626,6 +630,13 @@ def test_numerics( "Fused attention + cuda graphs is temporarily broken, not because of cpu offloading" ) + if (IS_HIP_EXTENSION + and backend == "FusedAttention" + and not use_cuda_graphs + and layer_type in ("multihead_attention", "transformer_layer") + ): + pytest.skip("No dot product attention backend is available for the provided inputs") + os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_UNFUSED_ATTN"] = "0" diff --git a/tests/pytorch/test_cpu_offloading_v1.py b/tests/pytorch/test_cpu_offloading_v1.py index 8a8e03630..07091ee7a 100644 --- a/tests/pytorch/test_cpu_offloading_v1.py +++ b/tests/pytorch/test_cpu_offloading_v1.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -34,7 +36,7 @@ # Flash attention saves some internal tensor for the backward pass # that cannot be offloaded to CPU. -assert os.getenv("NVTE_FLASH_ATTN") == "0" +assert os.getenv("NVTE_FLASH_ATTN", "1") == "0" # CPU offload v1 code path is enabled assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1" diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index d21c2e366..21fe4700b 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. 
All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -9,7 +11,7 @@ import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Float8CurrentScaling +from transformer_engine.common.recipe import Float8CurrentScaling, Format from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype @@ -847,7 +849,7 @@ def test_fp8_current_scaling_linear_large_numel_e4m3(self, dtype, shape): pytest.skip(f"Skipping {shape}: insufficient device memory for allocation.") try: - with fp8_autocast(enabled=True, fp8_recipe=recipe): + with autocast(enabled=True, recipe=recipe): y = layer(x) except torch.OutOfMemoryError: pytest.skip(f"Skipping {shape}: OOM during forward.") diff --git a/tests/pytorch/test_layernorm_saved_tensors_logic.py b/tests/pytorch/test_layernorm_saved_tensors_logic.py index cb7760b5f..ab7d5d9d4 100644 --- a/tests/pytorch/test_layernorm_saved_tensors_logic.py +++ b/tests/pytorch/test_layernorm_saved_tensors_logic.py @@ -1,11 +1,11 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information import pytest import torch import torch.nn as nn from unittest.mock import patch -from transformer_engine.pytorch import LayerNormLinear, LayerNormMLP, fp8_autocast +from transformer_engine.pytorch import LayerNormLinear, LayerNormMLP, autocast from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager @@ -84,7 +84,7 @@ def spy_on_ctx(ctx, *args, **kwargs): weight_tensor.requires_grad_(True) with patch(config["backward_target"], side_effect=spy_on_ctx) as mock_backward: - with fp8_autocast(enabled=True): + with autocast(enabled=True): out, ln_out_returned = model(inp) out.backward(grad_output, retain_graph=True) @@ -99,7 +99,7 @@ def spy_on_ctx(ctx, *args, **kwargs): saved_ln_out_container.clear() with patch(config["backward_target"], side_effect=spy_on_ctx) as mock_backward: - with fp8_autocast(enabled=True): + with autocast(enabled=True): out, ln_out_returned = model(inp) out.backward(grad_output) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 50578fc1a..9fe6304d4 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1196,7 +1196,7 @@ def _test_granular_accuracy_with_fp8(block, bs, dtype, config): ) inp_hidden_states.retain_grad() - with fp8_autocast(enabled=True): + with autocast(enabled=True): out = block(inp_hidden_states) loss = out.sum() loss.backward() @@ -1357,10 +1357,11 @@ def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model module = LayerNormLinear config = model_configs[model] - with fp8_model_init(enabled=fp8_model_params): + with quantized_model_init(enabled=fp8_model_params): layer = module( config.hidden_size, 4 * config.hidden_size, + config.eps, bias=True, params_dtype=dtype, device="cuda", @@ -1371,6 +1372,7 @@ def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model ref_layer = module( config.hidden_size, 4 * config.hidden_size, 
+ config.eps, bias=True, params_dtype=dtype, device="cuda", diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index a0dd325da..73edf23a3 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -11,6 +13,8 @@ #ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ #define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ +#ifndef __HIP_PLATFORM_AMD__ + #include "transformer_engine.h" #ifdef __cplusplus @@ -65,4 +69,6 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE } // extern "C" #endif +#endif + #endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index c55f1f612..674d4e4cb 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -33,6 +33,7 @@ class _FormatMaxVals(Enum): """ Tuples of FP8 (OCP, FNUZ) values for different formats. 
""" + E2M1 = (6, 6) E4M3 = (448, 240) E5M2 = (57344, 57344) @@ -53,12 +54,11 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ - #TODO: bring E2M1 back after rocm support MXFP4 - if not te_rocm_build: - E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) - E4M3 = _FormatHelper(max_fwd=_FormatMaxVals.E4M3.value, max_bwd=_FormatMaxVals.E4M3.value) - E5M2 = _FormatHelper(max_fwd=_FormatMaxVals.E5M2.value, max_bwd=_FormatMaxVals.E5M2.value) - HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) + #TODO: Change max vals after rocm support MXFP4 + E2M1 = _FormatHelper(fwd=_FormatMaxVals.E2M1.value, bwd=_FormatMaxVals.E2M1.value) + E4M3 = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) + E5M2 = _FormatHelper(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) + HYBRID = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E5M2.value) @dataclass(frozen=True) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 74852b22d..55d8aafd6 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -293,6 +293,7 @@ class MXFP8Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; +#ifndef __HIP_PLATFORM_AMD__ class NVFP4Quantizer : public Quantizer { public: // fp4 dtype @@ -346,6 +347,7 @@ class NVFP4Quantizer : public Quantizer { void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; +#endif // #ifndef __HIP_PLATFORM_AMD__ std::unique_ptr convert_quantizer(py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index de1e3ccbd..ebdbb5817 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -41,6 +41,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { impl = Impl::FUSED_ACTIVATION_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -51,6 +54,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; } } +#endif // Perform compute auto stream = at::cuda::getCurrentCUDAStream(); @@ -101,6 +105,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_ACTIVATION_AMAX_NVFP4: // Compute activation and amax in high precision, then quantize to NVFP4 { @@ -119,6 +124,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; +#endif default: NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } @@ -153,6 +159,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { impl = Impl::FUSED_ACTIVATION_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -163,6 +172,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; } } +#endif // Perform compute auto stream 
= at::cuda::getCurrentCUDAStream(); @@ -213,6 +223,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_ACTIVATION_AMAX_NVFP4: // Compute activation and amax in high precision, then quantize to NVFP4 { @@ -231,6 +242,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; +#endif default: NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index dcff95887..1d3e27a14 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -151,6 +151,9 @@ std::vector dact_dbias( impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { impl = Impl::FUSED_DACT_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -161,6 +164,7 @@ std::vector dact_dbias( impl = Impl::FUSED_DACT_AMAX_NVFP4; } } +#endif // Perform compute auto stream = at::cuda::getCurrentCUDAStream(); @@ -220,6 +224,7 @@ std::vector dact_dbias( fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_DACT_AMAX_NVFP4: // Fused dact-amax kernel, unfused dbias and NVFP4 quantize { @@ -237,6 +242,7 @@ std::vector dact_dbias( nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } +#endif default: NVTE_ERROR("Invalid implementation"); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp 
b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 48c02215f..97b9e7ca6 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -493,6 +493,7 @@ std::tuple, std::vector> bulk_allocate_mx return retval; } +#ifndef __HIP_PLATFORM_AMD__ // allocate fp4 data, fp8 scalings, and amax values // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate @@ -693,6 +694,7 @@ std::tuple, std::vector> bulk_allocate_nv return retval; } +#endif // #ifndef __HIP_PLATFORM_AMD__ } // namespace @@ -791,6 +793,7 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); +#ifndef __HIP_PLATFORM_AMD__ } else if (is_nvfp4) { // NVFP4: construct output tensors with bulk allocations std::vector nvfp4_quantizers; @@ -799,6 +802,7 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); +#endif } else { NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer"); } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 060a342a6..6d635e1c2 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,6 +120,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); impl = Impl::FUSED_NORM_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto 
nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -131,6 +134,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe impl = Impl::FUSED_NORM_AMAX_NVFP4; } } + #endif // Construct unquantized output tensor if needed TensorWrapper unquantized_out_nvte; @@ -148,12 +152,14 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); std::tie(unquantized_out_nvte, unquantized_out) = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; +#endif default: { } } @@ -191,10 +197,12 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; +#endif default: { } } @@ -344,6 +352,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); impl = Impl::FUSED_NORM_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -355,6 +366,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w impl = 
Impl::FUSED_NORM_AMAX_NVFP4; } } +#endif // Construct unquantized output tensor if needed TensorWrapper unquantized_out_nvte; @@ -372,12 +384,14 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); std::tie(unquantized_out_nvte, unquantized_out) = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; +#endif default: { } } @@ -413,10 +427,12 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; +#endif default: { } } diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 65665d01b..1c1855669 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -108,8 +108,12 @@ constexpr std::array custom_types_converters = { CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), +#ifdef __HIP_PLATFORM_AMD__ +}; +#else std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, CreateQuantizer)}; +#endif } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index ffef3e59c..7240c3bf3 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ 
b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1143,6 +1143,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s return scale_shape; } +#ifndef __HIP_PLATFORM_AMD__ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->with_rht = quantizer.attr("with_rht").cast(); @@ -1719,5 +1720,6 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s } return scale_shape; } +#endif } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 661cf3f2e..0ad1e86a4 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -51,7 +51,7 @@ from ..triton_kernels.cast import te_quantize_triton from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage -from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype +from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a906ea42d..89af05f93 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -444,7 +444,7 @@ def forward( # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. 
if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 3fefb650e..fb89d6195 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -576,9 +576,9 @@ def forward( if is_grad_enabled: # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: - if isinstance(fc1_weight_final, QuantizedTensorBase): + if isinstance(fc1_weight_final, QuantizedTensorStorage): fc1_weight_final.update_usage(columnwise_usage=True) - if isinstance(fc2_weight_final, QuantizedTensorBase): + if isinstance(fc2_weight_final, QuantizedTensorStorage): fc2_weight_final.update_usage(columnwise_usage=True) if cpu_offloading: @@ -897,7 +897,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorStorage + fc2_weight, QuantizedTensorStorage ): fc2_weight.update_usage(columnwise_usage=True) @@ -1155,7 +1155,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorStorage + fc1_weight, QuantizedTensorStorage # this fixes a bug with upstream usage of fc1_weight_quantizer ): fc1_weight.update_usage(columnwise_usage=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ffa47d986..0d43776f1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -410,7 +410,7 @@ 
def forward( # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading and saved_inputmat is not None: diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0b7eddb9f..915527736 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -260,6 +260,7 @@ class FP8GlobalStateManager: HIGH_PRECISION_INIT_VAL = False IS_FIRST_FP8_MODULE = False FP8_GRAPH_CAPTURING = False + SKIP_FP8_REDUCTION_FOR_FSDP2 = False AUTOCAST_DEPTH = 0 global_amax_buffer = {} global_amax_history_buffer = {} diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 1848a60cf..f9ff4b77b 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -142,7 +142,9 @@ def make_empty( columnwise_data = None columnwise_scale_inv = None if self.columnwise_usage: - columnwise_data = torch.empty_like(data) + columnwise_data = torch.empty( + shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + ) # ROCm TE does not implement fuse padding zeros so use zero tensor here columnwise_scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py index b6a7270a3..eae6bb79c 100644 --- a/transformer_engine/pytorch/triton_kernels/cast.py +++ b/transformer_engine/pytorch/triton_kernels/cast.py @@ -10,11 +10,11 @@ from ..utils import is_non_tn_fp8_gemm_supported -from ..tensor._internal.float8_tensor_base import Float8TensorBase +from 
..tensor.storage.float8_tensor_storage import Float8TensorStorage from .cast_transpose import te_cast_transpose_mxfp8_triton, te_cast_transpose_noop_triton, te_dequantize_mxfp8_triton import transformer_engine_torch as tex -from ..tensor.quantized_tensor import QuantizedTensor, Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..quantized_tensor import QuantizedTensor, Quantizer +from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage @functools.lru_cache(maxsize=None) def _empty_tensor() -> torch.Tensor: @@ -72,7 +72,7 @@ def te_quantize_triton( _setup_conditional_transpose_storage(out) else: out = quantizer.make_empty(input_tensor.shape, dtype=fake_tensor_type) - if isinstance(out, Float8TensorBase): + if isinstance(out, Float8TensorStorage): _setup_conditional_transpose_storage(out) else: # Create a QuantizedTensor from the provided output tensor @@ -82,11 +82,11 @@ def te_quantize_triton( if noop_flag is None: noop_flag = _empty_tensor() # if it's mxfp8, we'll check if both rowwise and columnwise data are none - if (isinstance(out, MXFP8TensorBase) and out._rowwise_data is None and out._columnwise_data is None) or (not isinstance(out, MXFP8TensorBase) and out.size().numel() == 0): + if (isinstance(out, MXFP8TensorStorage) and out._rowwise_data is None and out._columnwise_data is None) or (not isinstance(out, MXFP8TensorStorage) and out.size().numel() == 0): # Return empty output if the quantized tensor has no elements return out - if isinstance(out, Float8TensorBase): + if isinstance(out, Float8TensorStorage): if input_tensor.nelement() > 0: if not out._transpose_invalid: quantizer = out._get_quantizer() @@ -117,7 +117,7 @@ def te_quantize_triton( else: out.remove_caches() #Make sure to remove transpose if it is marked as invalid out = tex.quantize(input_tensor, quantizer, out, noop_flag) - elif isinstance(out, MXFP8TensorBase): + elif isinstance(out, MXFP8TensorStorage): te_cast_transpose_mxfp8_triton(input_tensor, 
out) else: raise NotImplementedError(f"Not implemented for tensor type: '{type(out).__name__}'") @@ -125,9 +125,9 @@ def te_quantize_triton( return out def te_dequantize_triton(input, dtype: tex.DType): - if isinstance(input, MXFP8TensorBase): + if isinstance(input, MXFP8TensorStorage): return te_dequantize_mxfp8_triton(input, dtype) - elif isinstance(input, Float8TensorBase): + elif isinstance(input, Float8TensorStorage): return tex.dequantize(input, dtype) else: raise NotImplementedError(f"Not implemented for tensor type: '{type(input).__name__}'") diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 3baa64697..37f504e4c 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -10,7 +10,7 @@ from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..constants import TE_DType from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer from ..triton_kernels.cast import te_quantize_triton import triton import triton.language as tl diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 9f152582e..3acbf0835 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -13,7 +13,7 @@ te_dtype_to_triton_dtype, ) from .common import get_fp8_max -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer import transformer_engine_torch as tex def dg_tmp_rows(x, sm_margin=None): From 8f94cf652f6989483e4525f8f3812c6046fa9543 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 24 Feb 2026 09:47:40 -0600 Subject: [PATCH 08/41] jax and pytorch bugfix --- tests/jax/test_custom_call_compute.py | 15 ++++++++------- tests/jax/test_fused_attn.py | 4 ++-- 
tests/jax/utils.py | 2 +- tests/pytorch/attention/run_attention_with_cp.py | 4 +++- tests/pytorch/attention/test_attention_with_cp.py | 7 ++++--- .../jax/cpp_extensions/attention.py | 1 + transformer_engine/jax/cpp_extensions/gemm.py | 6 +++--- .../jax/cpp_extensions/normalization.py | 1 + transformer_engine/jax/csrc/extensions/amax.cpp | 4 ++++ transformer_engine/jax/csrc/extensions/gemm.cpp | 6 ++++++ .../jax/csrc/extensions/quantization.cpp | 2 ++ transformer_engine/pytorch/distributed.py | 6 ++++++ 12 files changed, 41 insertions(+), 17 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 75b606a62..3b9ee0034 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -636,7 +636,7 @@ def test_norm_forward_with_block_scaling_fp8( ) -QUANTIZE_OUTPUT_DTYPES = { +QUANTIZE_OUTPUT_FP8_DTYPES = { "L0": [jnp_float8_e4m3_type], "L2": FP8_COMPUTE_TYPE, } @@ -1790,11 +1790,12 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( dtype, input_shape, layout ) - num_gemms = input_shape[0] - _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( - group_sizes, - num_gemms=num_gemms, - ) + if not is_hip_extension(): + num_gemms = input_shape[0] + _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( + group_sizes, + num_gemms=num_gemms, + ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm @@ -1805,7 +1806,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): rhs, group_sizes, contracting_dims, - use_async_d2h_group_sizes=True, + use_async_d2h_group_sizes=not is_hip_extension(), ) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f33961455..a08a1fe42 100644 --- 
a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1026,14 +1026,14 @@ def check_dqkv(primitive, reference, pad, idx): ), pytest.param( 2, - 512, + 2048, 1024, 12, 12, 64, 64, jnp.bfloat16, - id="2-512-1024-12-12-64-64-BF16-CROSS", + id="2-2048-1024-12-12-64-64-BF16-CROSS", ), pytest.param( 2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index a0e5e708b..60b01348f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 80f21048f..b59fe6451 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -451,9 +451,11 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[4] = tensors_to_deq - for tensor in tensors: + i = 0 + for tensor in tensors[4:]: assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) + i += 1 out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors ############ compare results between CP and no-CP ############ diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 9ac96dcef..d0956c226 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -24,6 +24,7 @@ Float8CurrentScaling, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils +from transformer_engine.pytorch.utils import 
get_torch_float8_e4m3_type _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) @@ -144,8 +145,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): model_configs_fused_attn = { # test: ModelConfig(b, sq, hq, dqk) - "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True), # MHA - "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True), # MHA + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=not IS_HIP_EXTENSION), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=not IS_HIP_EXTENSION), # MHA "cp_1_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA @@ -305,7 +306,7 @@ def test_cp_with_fused_attention( ] available_backends, _, fused_attn_backends = get_available_attention_backends( config, - qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, + qkv_dtype=dtypes[dtype] if dtype != "fp8" else get_torch_float8_e4m3_type(), qkv_layout="_".join([qkv_format] * 3), fp8=fp8, fp8_meta=fp8_meta, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 428f3ba2e..ab2a4562e 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -9,6 +9,7 @@ import warnings from dataclasses import dataclass, replace from functools import partial, reduce +from packaging import version from typing import Optional, Tuple import jax diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 756913c91..a04a98d97 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -24,8 +24,8 @@ get_num_compute_streams, JAXX_Collective_Op, get_device_compute_capability, - initialize_cgemm_communicator, - get_cgemm_num_max_streams, + 
#initialize_cgemm_communicator, + #get_cgemm_num_max_streams, ) from .base import BasePrimitive, register_primitive @@ -83,7 +83,7 @@ def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture""" if is_hip_extension(): """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" - if tex.get_device_compute_capability(0) == 95: + if get_device_compute_capability(0) == 95: return 67_108_864 return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 1bf6ec943..e53d63625 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -8,6 +8,7 @@ import warnings import operator from functools import partial, cache, reduce +from packaging import version from typing import Optional, Union import jax diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index 46f167fca..a4b590250 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -1,8 +1,11 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
************************************************************************/ +#ifndef __HIP_PLATFORM_AMD__ #include #include @@ -98,3 +101,4 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( } // namespace jax } // namespace transformer_engine +#endif // #ifndef __HIP_PLATFORM_AMD__ \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 41b78f117..f038101b2 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -87,6 +87,9 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } } else { // Swizzle for NVFP4 NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); +#ifdef __HIP_PLATFORM_AMD__ + } +#else input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); @@ -97,6 +100,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set swizzled scales into the input tensor input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); } +#endif // #ifdef __HIP_PLATFORM_AMD__ } return std::make_tuple(std::move(input), input_shape); @@ -767,6 +771,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); +#ifndef __HIP_PLATFORM_AMD__ if (is_mxfp8_scaling) { for (int i = 0; i < num_non_empty_gemms; i++) { // The i-th GEMM will use the (i % num_streams)-th stream to compute, @@ -778,6 +783,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); } } +#endif // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 
1f7db8438..626c47276 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -177,6 +177,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T } if (is_quantize_colwise(quantize_layout)) { +#ifndef __HIP_PLATFORM_AMD__ if (is_nvfp4 && use_rht) { if (is_quantize_2x2x(quantize_layout)) { // Do regular rowwise quantization without RHT @@ -218,6 +219,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T return ffi_with_cuda_error_check(); } +#endif // #ifndef __HIP_PLATFORM_AMD__ bool const is_colwise_transposed = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4; diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index c7329caca..04ffa324d 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -952,7 +952,13 @@ def _all_gather_fp8( if isinstance(inp, Float8Tensor): dtype = inp.dtype device = inp.device + # Temporarily ensure rowwise usage for output tensor creation + # since we're gathering rowwise data, not the transpose + init_rowwise_usage = quantizer.rowwise_usage + init_columnwise_usage = quantizer.columnwise_usage + quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage) out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage) elif isinstance(inp, Float8Tensor): out = inp.make_like(inp, shape=out_shape) out._data = torch.empty( From bac79938e79698010f749470bb802260a76a64f0 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 24 Feb 2026 10:41:15 -0600 Subject: [PATCH 09/41] copyrights and fp8_autocast->autocast fix --- build_tools/jax.py | 2 +- build_tools/pytorch.py | 2 +- build_tools/utils.py | 2 +- tests/cpp/operator/CMakeLists.txt | 2 +- tests/pytorch/distributed/test_fusible_ops.py | 2 +- 
tests/pytorch/test_fused_optimizer.py | 2 +- tests/pytorch/triton_kernels/test_cast.py | 10 +++++----- transformer_engine/common/CMakeLists.txt | 2 +- transformer_engine/common/__init__.py | 2 +- transformer_engine/common/cast/cast.cu | 2 ++ transformer_engine/common/cast/dispatch/dequantize.cuh | 2 ++ transformer_engine/common/cast/fp8/dequantize_fp8.cuh | 2 ++ transformer_engine/common/cast/fp8/gated_fp8.cuh | 2 ++ .../common/cast/mxfp8/dequantize_mxfp8.cuh | 2 +- transformer_engine/common/common.h | 2 +- .../common/fused_attn_rocm/fused_attn.cpp | 2 +- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 2 +- .../common/fused_attn_rocm/fused_attn_aotriton.h | 2 +- .../common/fused_attn_rocm/fused_attn_ck.cpp | 2 +- .../common/fused_attn_rocm/fused_attn_ck.h | 2 +- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- transformer_engine/common/gemm/rocm_gemm.cu | 2 +- transformer_engine/common/recipe/__init__.py | 2 +- transformer_engine/common/util/logging.h | 2 +- transformer_engine/jax/csrc/extensions.h | 2 +- transformer_engine/jax/csrc/extensions/pybind.cpp | 2 +- transformer_engine/jax/quantize/helper.py | 2 +- .../attention/dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 2 +- transformer_engine/pytorch/csrc/common.h | 2 +- .../pytorch/csrc/extensions/activation.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/cast.cpp | 2 +- .../pytorch/csrc/extensions/normalization.cpp | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/triton_kernels/layernorm.py | 2 +- transformer_engine/pytorch/triton_kernels/rmsnorm.py | 2 +- 36 files changed, 44 insertions(+), 36 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index e67036f49..6f2e57f87 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 35d910bcd..4bbefd730 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/utils.py b/build_tools/utils.py index 0c34bedde..0c18d7ecf 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index cd36993ce..d3b75bbbf 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 85ae2d85b..b9fbfb2a5 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index eeeda171e..5526103d5 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/triton_kernels/test_cast.py b/tests/pytorch/triton_kernels/test_cast.py index 3f725c496..f85773d65 100644 --- a/tests/pytorch/triton_kernels/test_cast.py +++ b/tests/pytorch/triton_kernels/test_cast.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information import pytest @@ -10,7 +10,7 @@ from transformer_engine.pytorch.triton_kernels.common import te_dtype_to_torch_dtype import transformer_engine_torch as tex from test_common import te_compare_results, fill_uniform, get_tolerances -from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.pytorch.fp8 import autocast from transformer_engine.common import recipe from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type @@ -43,7 +43,7 @@ def test_quantize(scaling, shape, in_dtype, out_dtype): triton_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda") tex_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda") - with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()): + with autocast(enabled=True, recipe=recipe.Float8CurrentScaling()): quantized_out_triton = te_quantize_triton(input_tensor, quantizer=triton_quantizer) quantized_out_tex = tex.quantize(input_tensor, tex_quantizer) @@ -187,13 +187,13 @@ def test_amax_atomic_vs_two_stage(shape, in_dtype, out_dtype): # atomic amax os.environ[env_key] = "1" - with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()): + with autocast(enabled=True, recipe=recipe.Float8CurrentScaling()): out_atomic = te_quantize_triton(input_tensor, quantizer=quantizer_atomic) # 2-stage amax os.environ[env_key] = "0" - with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()): + with autocast(enabled=True, recipe=recipe.Float8CurrentScaling()): out_2stage = te_quantize_triton(input_tensor, quantizer=quantizer_2stage) te_compare_results( diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 46eb5dba5..831de2b45 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, 
Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index cdda37508..f0335f44c 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 575106a53..7ecc05d2e 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index f55719852..4ba64ca97 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh index 5d30a6c3f..22a3929e3 100644 --- a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh index c9040a3da..aa46a574c 100644 --- a/transformer_engine/common/cast/fp8/gated_fp8.cuh +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 96aed3e88..5701a446d 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 03b90febb..5feeb600c 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index d39fccbce..48a309118 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 1c25fa031..9109ddb15 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index 178bd8d8f..fd4dffd73 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 8d639c47c..02bc9ce94 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h index 926c90866..0772609ff 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 343b3cecb..0127a9edf 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 97bd2e8a7..205a0a058 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 674d4e4cb..b90cd5ce3 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. 
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 09187069e..ebcf99afe 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 845176080..2bfa4c89f 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index bc47ef6bd..937dde228 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 792173ed1..b8a8809fc 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5437b73bc..a0aaab1f3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 13b41345b..096eca809 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. 
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 55d8aafd6..fd83f20d4 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index ebdbb5817..6936d6bc8 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 97b9e7ca6..f3c77a332 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 6d635e1c2..805579ff4 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index fb89d6195..fb3327156 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 37f504e4c..86b7b46c7 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. 
# License for AMD contributions = MIT. See LICENSE for more information diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 3acbf0835..1ca6183c9 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information import torch From 8ae38e8e66df970c2cc165dee6d7f0bd98d35250 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 24 Feb 2026 13:35:14 -0600 Subject: [PATCH 10/41] Enable test_distributed_dense.py --- ci/jax.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/jax.sh b/ci/jax.sh index 81d994585..d350ebac7 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -86,6 +86,7 @@ run_test_config_mgpu() { if [ "$TEST_LEVEL" -le 3 ]; then TEST_ERROR_IGNORE="1" fi + run_default_fa 2 test_distributed_dense.py run $_dfa_level test_distributed_fused_attn.py $_timeout_args TEST_ERROR_IGNORE="" run_default_fa 3 test_distributed_layernorm.py From 05a977a8e089c4e3beb308267ce0d8cc3a6416d9 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 3 Mar 2026 00:18:46 -0600 Subject: [PATCH 11/41] address IFU comments --- build_tools/wheel_utils/build_wheels.sh | 8 +- ci/jax.sh | 12 +-- setup.py | 7 +- tests/jax/distributed_test_base.py | 2 - tests/jax/test_fused_attn.py | 1 + .../attention/run_attention_with_cp.py | 1 - tests/pytorch/utils.py | 4 +- .../common/cast/core/common.cuh | 2 + .../common/cast/mxfp8/dequantize_mxfp8.cuh | 2 + .../common/cast/mxfp8/gated_mxfp8.cuh | 2 + .../common/cast/mxfp8/quantize_mxfp8.cuh | 2 + .../cast/mxfp8/rocm_dequantize_mxfp8.cuh | 7 +- .../common/cast/mxfp8/rocm_gated_mxfp8.cuh | 14 ++-- .../common/cast/mxfp8/rocm_quantize_mxfp8.cuh | 8 +- transformer_engine/common/common.cu | 2 +- 
.../common/fused_attn_rocm/fused_attn.cpp | 1 + transformer_engine/common/recipe/__init__.py | 32 +++----- .../common/recipe/current_scaling.cu | 2 +- transformer_engine/common/swizzle/swizzle.cu | 3 +- transformer_engine/common/util/ptx.cuh | 73 ------------------- .../common/util/rocm_vectorized_2d.cuh | 13 ---- .../jax/cpp_extensions/attention.py | 4 +- transformer_engine/jax/cpp_extensions/gemm.py | 3 +- .../jax/csrc/extensions/amax.cpp | 4 +- .../jax/csrc/extensions/cgemm_helper.h | 4 + .../jax/csrc/extensions/gemm.cpp | 11 +-- transformer_engine/jax/csrc/extensions/misc.h | 2 + .../jax/csrc/extensions/quantization.cpp | 4 +- transformer_engine/jax/quantize/helper.py | 2 + transformer_engine/jax/setup.py | 5 +- .../dot_product_attention/backends.py | 4 + .../dot_product_attention.py | 2 + .../attention/dot_product_attention/utils.py | 7 +- .../pytorch/cpp_extensions/fused_attn.py | 4 + transformer_engine/pytorch/csrc/common.cpp | 4 +- .../pytorch/csrc/extensions/activation.cpp | 8 +- .../pytorch/csrc/extensions/bias.cpp | 4 +- .../pytorch/csrc/extensions/cast.cpp | 6 +- .../pytorch/csrc/extensions/normalization.cpp | 12 +-- .../pytorch/csrc/extensions/recipe.cpp | 2 +- transformer_engine/pytorch/csrc/pybind.h | 2 + transformer_engine/pytorch/csrc/quantizer.cpp | 4 +- transformer_engine/pytorch/fp8.py | 2 - transformer_engine/pytorch/module/base.py | 2 +- transformer_engine/pytorch/quantization.py | 2 + transformer_engine/pytorch/setup.py | 12 ++- .../pytorch/tensor/float8_tensor.py | 2 +- transformer_engine/pytorch/utils.py | 4 +- 48 files changed, 130 insertions(+), 190 deletions(-) delete mode 100644 transformer_engine/common/util/rocm_vectorized_2d.cuh diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 0be852c8a..0e8ab68a8 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -75,12 +75,12 @@ if $BUILD_COMMON ; then ${PYBINDIR}python setup.py bdist_wheel 
--verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt if [ "$ROCM_BUILD" = "1" ]; then - # Repack the wheel for cuda specific package, i.e. cu12. + # Repack the wheel for specific rocm package. ${PYBINDIR}wheel unpack dist/* # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). - sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info" + sed -i "s/Name: transformer-engine/Name: transformer-engine-rocm/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_rocm/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_rocm-${VERSION}.dist-info" ${PYBINDIR}wheel pack ${WHL_BASE} else # Repack the wheel for specific cuda version. 
diff --git a/ci/jax.sh b/ci/jax.sh index d350ebac7..ef9dbe124 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -59,8 +59,7 @@ run_test_config() { run_default_fa 1 test_functions.py run 1 test_fused_attn.py NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass - run_default_fa 1 test_helper.py - run_default_fa 1 test_layer.py #it effectevly always uses unfused attention + run_default_fa 1 test_layer.py # it effectively always uses unfused attention run_default_fa 1 test_sanity_import.py run_default_fa 1 test_softmax.py } @@ -71,7 +70,7 @@ run_test_config_mgpu() { # Mitigate distributed tests hang by adding 5min timeout _timeout_args="--timeout 300 --timeout-method thread" - # Workaround for some distributed tests hang/abotrion + # Workaround for some distributed tests hang/abortion export XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then @@ -81,12 +80,13 @@ run_test_config_mgpu() { _dfa_level=3 export NVTE_JAX_UNITTEST_LEVEL=L2 fi + + run_default_fa 2 test_distributed_dense.py # Do not fail automated CI if test_distributed_fused_attn is hung - # If the sctipt run w/o TEST_LEVEL the test error will be honored + # If the script runs w/o TEST_LEVEL the test error will be honored if [ "$TEST_LEVEL" -le 3 ]; then TEST_ERROR_IGNORE="1" fi - run_default_fa 2 test_distributed_dense.py run $_dfa_level test_distributed_fused_attn.py $_timeout_args TEST_ERROR_IGNORE="" run_default_fa 3 test_distributed_layernorm.py @@ -96,7 +96,7 @@ run_test_config_mgpu() { run_default_fa 3 test_sanity_import.py } -# Single config mode, run it synchroniously and return result +# Single config mode, run it synchronously and return result if [ -n "$SINGLE_CONFIG" ]; then _fus_attn="$SINGLE_CONFIG" configure_fused_attn_env $_fus_attn && run_test_config diff --git a/setup.py b/setup.py index bec4943e1..eb241f5cb 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ import 
time from pathlib import Path from typing import List, Tuple -import subprocess import setuptools from setuptools.command.egg_info import egg_info @@ -240,7 +239,7 @@ def git_check_submodules() -> None: assert bool( int(os.getenv("NVTE_RELEASE_BUILD", "0")) ), "NVTE_RELEASE_BUILD env must be set for metapackage build." - te_cuda_vers = "rocm" if rocm_build() else "cu12" + te_cuda_vers = "cu12" ext_modules = [] cmdclass = {} package_data = {} @@ -253,7 +252,7 @@ def git_check_submodules() -> None: "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } if not rocm_build() else { - "core": [f"transformer_engine_{te_cuda_vers}=={__version__}"], + "core": [f"transformer_engine_rocm=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } @@ -303,7 +302,7 @@ def git_check_submodules() -> None: long_description=long_description, long_description_content_type="text/x-rst", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} if not rocm_build() else {"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, + cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=f">={min_python_version_str()}", classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index e8d9cefd6..137fa480d 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index a08a1fe42..72797f556 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -387,6 +387,7 @@ def _check_configs(self): get_device_compute_capability(0) >= 100 and self.dropout_prob == 0.1 and self.attn_bias_type is not AttnBiasType.NO_BIAS + and not is_hip_extension() ): pytest.skip( "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index b59fe6451..9e59f4f6a 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -455,7 +455,6 @@ def run_dpa_with_cp( for tensor in tensors[4:]: assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) - i += 1 out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors ############ compare results between CP and no-CP ############ diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 05555626b..ed5a12995 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -100,9 +100,9 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: return dict(rtol=1.3e-6, atol=1e-5) if dtype == torch.float64: return dict(rtol=1e-7, atol=1e-7) - if dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + if dtype in torch_float8_e4m3_type: return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype in (torch.float8_e5m2, torch.float8_e5m2fnuz): + if dtype in torch_float8_e5m2_type: return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 raise ValueError(f"Unsupported dtype ({dtype})") diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index ec36e941f..540f7e252 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -1,4 +1,6 @@ 
/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 5701a446d..38eead606 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -25,6 +25,8 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" +#include "./rocm_vectorized_2d.cuh" + namespace transformer_engine { namespace dispatch { namespace mxfp8 { diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 28e46fc7a..69e30680c 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -25,6 +25,8 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" +#include "./rocm_vectorized_2d.cuh" + namespace transformer_engine { namespace dispatch { namespace mxfp8 { diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 19234e9b4..8e25b3f65 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -26,6 +26,8 @@ #include "../../utils.cuh" #include "../core/common.cuh" +#include "./rocm_vectorized_2d.cuh" + namespace transformer_engine { namespace dispatch { namespace mxfp8 { diff --git a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh index 02224a69f..49c57737c 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh +++ 
b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh @@ -67,7 +67,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + transformer_engine::rocm::copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -108,9 +108,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); - ptx::bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, - chunk_it_offset_y, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, rows, cols); + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, + chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); } diff --git a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index 7382b8aab..a8c02e4f8 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -122,16 +122,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate bulk tensor copy if constexpr (IS_DGATED) { - copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, + transformer_engine::rocm::copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } // Act - copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, + transformer_engine::rocm::copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); // Gate - copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, + transformer_engine::rocm::copy_2d_to_shared(&in_gate_sh[0], input_gate, 
chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -356,19 +356,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - ptx::bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - ptx::bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } if constexpr (USE_COLWISE_SCALING) { - ptx::bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - ptx::bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } diff --git a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index dc36fb42d..d5b51a2f4 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -143,11 +143,11 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int chunk_it_offset_x = chunk_offset_X; const size_t row_base = chunk_it_offset_y; if 
constexpr (IS_DACT) { - copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, + transformer_engine::rocm::copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + transformer_engine::rocm::copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -290,12 +290,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - ptx::bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } if constexpr (USE_COLWISE_SCALING) { - ptx::bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index ab574256c..a0193e95f 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 48a309118..97aecf4de 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -282,6 +282,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace transformer_engine; + NVTE_CHECK(!(return_max_logit || cuda_graph), "ROCm does not support return_max_logit and cuda_graph for fused_attn yet."); // by default, fused attn is enabled bool nvte_fused_attn = true; if (const char* env_p = std::getenv("NVTE_FUSED_ATTN") ) { diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b90cd5ce3..ae86e492d 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -18,24 +18,10 @@ class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. """ - fwd: tuple - bwd: tuple - @property - def max_fwd(self) -> float: - return self.fwd[is_fp8_fnuz()] + max_fwd: float + max_bwd: float - @property - def max_bwd(self) -> float: - return self.bwd[is_fp8_fnuz()] - -class _FormatMaxVals(Enum): - """ - Tuples of FP8 (OCP, FNUZ) values for different formats. 
- """ - E2M1 = (6, 6) - E4M3 = (448, 240) - E5M2 = (57344, 57344) class Format(Enum): """ @@ -54,11 +40,15 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ - #TODO: Change max vals after rocm support MXFP4 - E2M1 = _FormatHelper(fwd=_FormatMaxVals.E2M1.value, bwd=_FormatMaxVals.E2M1.value) - E4M3 = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) - E5M2 = _FormatHelper(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) - HYBRID = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E5M2.value) + + E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) + if te_rocm_build: + max_e4m3_val = 240 if is_fp8_fnuz() else 448 + E4M3 = _FormatHelper(max_fwd=max_e4m3_val, max_bwd=max_e4m3_val) + else: + E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) + E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) + HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) @dataclass(frozen=True) diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index 69b44494b..55aa2907e 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 881b134e7..a00c30a9c 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -24,8 +24,8 @@ namespace { #define __ldg(x) (*(x)) #endif -#ifndef __HIP_PLATFORM_AMD__ constexpr int MXFP8_BLOCK_SIZE = 32; +#ifndef __HIP_PLATFORM_AMD__ constexpr int NVFP4_BLOCK_SIZE = 16; constexpr __device__ __host__ int TB_DIM = 32; @@ -38,7 +38,6 @@ constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; #else // HIPCC does not support __host__ qualifier for variables // and constexpr values do not need __device__ qualifier because they are compile-time constants -constexpr int MXFP8_BLOCK_SIZE = 32; constexpr int TB_DIM = 32; constexpr int NEW_SF_TILE_DIM_K = 16; constexpr int N_SF_PER_TD_PER_TILE = 4; diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index ef53c2670..312890db0 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -318,38 +318,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_ #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } -#ifdef __HIP_PLATFORM_AMD__ -template -__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col, - size_t g_start_row, size_t g_stride, size_t chunk_dim_y, - size_t chunk_dim_x, size_t total_rows, - size_t total_cols) { - const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; - const size_t l_idx = threadIdx.x; - - for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { - size_t l_y = (i_vec / chunk_dim_x_vec_elements); - size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); - - size_t g_row = g_start_row + l_y; - size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - - const T* current_sh_row_base_ptr = sh_ptr_base + l_y 
* chunk_dim_x; - VectorizedLoader shared_loader(current_sh_row_base_ptr, chunk_dim_x); - - T* current_g_row_base_ptr = g_ptr + g_row * g_stride; - VectorizedStorer global_storer(current_g_row_base_ptr, total_cols); - - shared_loader.load(l_x_vec, chunk_dim_x); - - if (g_row < total_rows) { - global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; - global_storer.store(g_col_primitive_start / N_VEC, total_cols); - } - } -} -#endif //#ifdef __HIP_PLATFORM_AMD__ - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( @@ -931,47 +899,7 @@ __forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -#ifdef __HIP_PLATFORM_AMD__ -// These 2d copy functions replace TMA tensormap async copies for AMD GPUs. -template -__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, - size_t g_start_row, size_t g_stride, size_t chunk_dim_y, - size_t chunk_dim_x, size_t total_rows, - size_t total_cols) { - size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; - const size_t l_idx = threadIdx.x; - - for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { - size_t l_y = (i_vec / chunk_dim_x_vec_elements); - size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); - - size_t g_row = g_start_row + l_y; - size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - - if (g_row < total_rows) { - const T* current_g_row_base_ptr = g_ptr + g_row * g_stride; - VectorizedLoaderglobal_loader(current_g_row_base_ptr, total_cols); - - T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedStorershared_storer(current_sh_row_base_ptr, chunk_dim_x); - - global_loader.load(g_col_primitive_start / N_VEC, total_cols); - 
shared_storer.storage_.scratch_ = global_loader.storage_.scratch_; - shared_storer.store(l_x_vec, chunk_dim_x); - } else { - T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedStorer shared_storer(current_sh_row_base_ptr, chunk_dim_x); - -#pragma unroll - for (int i = 0; i < N_VEC; ++i) { - shared_storer.separate()[i] = static_cast(0); - } - shared_storer.store(l_x_vec, chunk_dim_x); - } - } -} -#else __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, const size_t chunk_Y, const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) { @@ -992,7 +920,6 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co NVTE_DEVICE_ERROR("copy_2d_to_shared is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -#endif //#ifdef __HIP_PLATFORM_AMD__ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, diff --git a/transformer_engine/common/util/rocm_vectorized_2d.cuh b/transformer_engine/common/util/rocm_vectorized_2d.cuh deleted file mode 100644 index eda0f437f..000000000 --- a/transformer_engine/common/util/rocm_vectorized_2d.cuh +++ /dev/null @@ -1,13 +0,0 @@ -/************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. 
See LICENSE for more information - ************************************************************************/ - -#pragma once - -#include "../util/vectorized_pointwise.h" - -namespace transformer_engine { - -} // namespace transformer_engine diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index ab2a4562e..ba6d01a9f 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -2786,7 +2786,7 @@ def fused_attn_bwd( # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on # sm100+ compute_capabilities = get_all_device_compute_capability() - if any(x >= 100 for x in compute_capabilities): + if any(x >= 100 for x in compute_capabilities) and not is_hip_extension(): assert not ( attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index a04a98d97..2daecedfa 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -285,7 +285,8 @@ def collective_gemm_bootstrap( and before any collective GEMM operations. Each process should call this function with its own unique process_id. """ - + if is_hip_extension(): + assert 0, "collective_gemm_bootstrap is not supported for ROCm yet." assert ( num_devices_per_process == 1 and jax.local_device_count() == 1 ), "Only single device per process is supported at the moment!" 
diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index a4b590250..050d0fd23 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -5,7 +5,7 @@ * * See LICENSE for license information. ************************************************************************/ -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM #include #include @@ -101,4 +101,4 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( } // namespace jax } // namespace transformer_engine -#endif // #ifndef __HIP_PLATFORM_AMD__ \ No newline at end of file +#endif // #ifndef USE_ROCM \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h index 84b2b8154..03d86c168 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.h +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -7,6 +9,7 @@ #ifndef TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ #define TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ +#ifndef USE_ROCM #include #include @@ -186,4 +189,5 @@ int GetCgemmNumMaxStreams(); } // namespace jax } // namespace transformer_engine +#endif // #ifndef USE_ROCM #endif // TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f038101b2..d35b2d072 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -86,10 +86,11 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } } else { // Swizzle for NVFP4 - NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM + NVTE_ERROR("ROCm TE does not support NVFP4 yet."); } #else + NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); @@ -100,7 +101,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set swizzled scales into the input tensor input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); } -#endif // #ifdef __HIP_PLATFORM_AMD__ +#endif // #ifdef USE_ROCM } return std::make_tuple(std::move(input), input_shape); @@ -285,7 +286,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i } else { #ifdef USE_ROCM //TODO: better assert - std::cerr<<"ROCm TE jax does not integrate userbuffer for now"< buffer_shape{0, 0}; DType buffer_dtype = out_dtype; @@ -771,7 +772,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM if (is_mxfp8_scaling) { for 
(int i = 0; i < num_non_empty_gemms; i++) { // The i-th GEMM will use the (i % num_streams)-th stream to compute, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index a0c5db5a8..c2d3d6f25 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 626c47276..9a0a87d69 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -177,7 +177,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T } if (is_quantize_colwise(quantize_layout)) { -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM if (is_nvfp4 && use_rht) { if (is_quantize_2x2x(quantize_layout)) { // Do regular rowwise quantization without RHT @@ -219,7 +219,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T return ffi_with_cuda_error_check(); } -#endif // #ifndef __HIP_PLATFORM_AMD__ +#endif // #ifndef USE_ROCM bool const is_colwise_transposed = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4; diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index b8a8809fc..95d5aea21 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -134,6 +134,8 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: def _check_fp4_support(gpu_arch) -> Tuple[bool, str]: """Check if FP4 is 
supported for the given GPU architecture.""" + if is_hip_extension(): + return False, "FP4 not yet supported for ROCm" if gpu_arch < 100: # pre-blackwell return False, "Device compute capability 10.0 or higher required for NVFP4 execution." if get_cublasLt_version() < 120800: diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 0b958d3ad..619b6070b 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -122,7 +122,10 @@ def get_cuda_major_version() -> int: # us to detect CUDA version dynamically during compilation and # choose the correct wheel for te core lib. __version__ = te_version() - te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}" + if not rocm_build(): + te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}" + else: + te_core = f"transformer_engine_rocm=={__version__}" install_requires = install_requirements() + [te_core] # Configure package diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a0aaab1f3..038ebc3c0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -219,6 +219,8 @@ def __init__( softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, ) -> None: + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." super().__init__() self.softmax_scale = softmax_scale @@ -1676,6 +1678,8 @@ def __init__( softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, ) -> None: + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." 
super().__init__() self.softmax_scale = softmax_scale diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 4157e8d3a..ef601e4c4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -319,6 +319,8 @@ def __init__( softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, ) -> None: + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." super().__init__() self.logger = logging.getLogger("DotProductAttention") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index c8da3161b..54fe21d81 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -482,7 +482,7 @@ def get_attention_backend( fp8_recipe = fp8_meta["recipe"] if fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - if use_fused_attention and fp8_recipe.float8_current_scaling(): + if use_fused_attention and fp8_recipe.float8_current_scaling() and not IS_HIP_EXTENSION: if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") use_fused_attention = False @@ -502,7 +502,7 @@ def get_attention_backend( ) use_fused_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability == (12, 0) and not IS_HIP_EXTENSION: if use_flash_attention: logger.debug( "Disabling FlashAttention as FP8 is not supported" @@ -599,6 +599,7 @@ def get_attention_backend( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) and is_training + and not 
IS_HIP_EXTENSION ): if use_fused_attention: logger.debug( @@ -679,7 +680,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability == (12, 0) and not IS_HIP_EXTENSION: if use_fused_attention: logger.debug( "Disabling FusedAttention as qkv_format = thd is" diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 852dcdb59..e5492ebc6 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -264,6 +264,10 @@ def fused_attn_fwd( max_logit: if return_max_logit = True, shape [h] and same data type as O; otherwise None """ + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." + assert not cuda_graph, "ROCm does not support cuda_graph." 
+ if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index e1a78d49a..59f57743b 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -12,7 +12,7 @@ #include "pybind.h" #include "transformer_engine/transformer_engine.h" -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM #include "common/common.h" #endif @@ -312,7 +312,7 @@ size_t roundup(const size_t value, const size_t multiple) { return ((value + multiple - 1) / multiple) * multiple; } -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM inline bool nvte_use_atomic_amax() { const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX"); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 6936d6bc8..205605312 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -41,7 +41,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { impl = Impl::FUSED_ACTIVATION_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { @@ -105,7 +105,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_ACTIVATION_AMAX_NVFP4: // Compute activation and amax in high precision, then quantize to NVFP4 { @@ -159,7 +159,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { impl = Impl::FUSED_ACTIVATION_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef 
USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { @@ -223,7 +223,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_ACTIVATION_AMAX_NVFP4: // Compute activation and amax in high precision, then quantize to NVFP4 { diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 1d3e27a14..e8a735966 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -151,7 +151,7 @@ std::vector dact_dbias( impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { impl = Impl::FUSED_DACT_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { @@ -224,7 +224,7 @@ std::vector dact_dbias( fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_DACT_AMAX_NVFP4: // Fused dact-amax kernel, unfused dbias and NVFP4 quantize { diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f3c77a332..8fc4e1e97 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -493,7 +493,7 @@ std::tuple, std::vector> bulk_allocate_mx return retval; } -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM // allocate fp4 data, fp8 scalings, and amax values // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate @@ -694,7 +694,7 @@ std::tuple, std::vector> bulk_allocate_nv return retval; } -#endif // #ifndef 
__HIP_PLATFORM_AMD__ +#endif // #ifndef USE_ROCM } // namespace @@ -793,7 +793,7 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM } else if (is_nvfp4) { // NVFP4: construct output tensors with bulk allocations std::vector nvfp4_quantizers; diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 805579ff4..839bb694a 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,7 +120,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); impl = Impl::FUSED_NORM_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { @@ -152,7 +152,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); std::tie(unquantized_out_nvte, unquantized_out) = @@ -197,7 +197,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); @@ -352,7 +352,7 @@ std::vector 
rmsnorm_fwd(const py::handle &input, const py::handle &w auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); impl = Impl::FUSED_NORM_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { @@ -384,7 +384,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); std::tie(unquantized_out_nvte, unquantized_out) = @@ -427,7 +427,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index fafdc3761..577a938f2 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -28,7 +28,7 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { DType::kFloat8E4M3, // It doesn't matter because we only compute amax. 
amax_ptr); -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM at::Tensor ws = allocate_amax_workspace(te_input); TensorWrapper tw = makeTransformerEngineTensor(ws); nvte_compute_amax_with_workspace(te_input.data(), fake_te_output.data(), diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 1c1855669..a84641364 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7240c3bf3..90ed2a99f 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -516,7 +516,7 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te // Compute amax if (compute_amax) { -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM at::Tensor ws = allocate_amax_workspace(input); TensorWrapper tw = makeTransformerEngineTensor(ws); NVTE_SCOPED_GIL_RELEASE({ @@ -1143,7 +1143,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s return scale_shape; } -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->with_rht = quantizer.attr("with_rht").cast(); diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index b36302db2..f937b3de9 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2024-2026, Advanced Micro 
Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0ad1e86a4..634c188ce 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 915527736..d8dff33d5 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 73f926c61..12a87d4bd 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -155,9 +155,13 @@ def run(self): # us to detect CUDA version dynamically during compilation and # choose the correct wheel for te core lib. 
__version__ = te_version() - cuda_major_version = parse(torch.version.cuda).major - assert cuda_major_version in (12, 13), f"Unsupported cuda version {torch.version.cuda}." - te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}" + if not rocm_build(): + cuda_major_version = parse(torch.version.cuda).major + assert cuda_major_version in (12, 13), f"Unsupported cuda version {torch.version.cuda}." + te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}" + install_requires = install_requirements() + [te_core] + else: + te_core = f"transformer_engine_rocm=={__version__}" install_requires = install_requirements() + [te_core] # Configure package diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 8f741b7f2..316733e31 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 86acb7932..12c62437d 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -21,6 +21,8 @@ __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"] +if IS_HIP_EXTENSION: + __all__.extend(["is_mi200", "is_mi308", "is_fp8_fnuz"]) def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" From 0385852c1f825f14410a4cb071e256050f568134 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 3 Mar 2026 14:53:26 -0600 Subject: [PATCH 12/41] _FormatHelperFP8 and missing file add --- .../common/cast/mxfp8/rocm_vectorized_2d.cuh | 81 +++++++++++++++++++ transformer_engine/common/recipe/__init__.py | 33 +++++--- .../jax/csrc/extensions/amax.cpp | 2 +- 3 files changed, 106 insertions(+), 10 deletions(-) create mode 100644 transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh diff --git a/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh new file mode 100644 index 000000000..50474f308 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh @@ -0,0 +1,81 @@ +/************************************************************************* + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once + +#include "../../util/vectorized_pointwise.h" + +namespace transformer_engine::rocm { +// These 2d copy functions replace TMA tensormap async copies for AMD GPUs. 
+template <typename T, int N_VEC> +__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, + size_t g_start_row, size_t g_stride, size_t chunk_dim_y, + size_t chunk_dim_x, size_t total_rows, + size_t total_cols) { + size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; + const size_t l_idx = threadIdx.x; + + for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { + size_t l_y = (i_vec / chunk_dim_x_vec_elements); + size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); + + size_t g_row = g_start_row + l_y; + size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; + + if (g_row < total_rows) { + const T* current_g_row_base_ptr = g_ptr + g_row * g_stride; + VectorizedLoader<T, N_VEC> global_loader(current_g_row_base_ptr, total_cols); + + T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedStorer<T, N_VEC> shared_storer(current_sh_row_base_ptr, chunk_dim_x); + + global_loader.load(g_col_primitive_start / N_VEC, total_cols); + shared_storer.storage_.scratch_ = global_loader.storage_.scratch_; + shared_storer.store(l_x_vec, chunk_dim_x); + + } else { + T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedStorer<T, N_VEC> shared_storer(current_sh_row_base_ptr, chunk_dim_x); + +#pragma unroll + for (int i = 0; i < N_VEC; ++i) { + shared_storer.separate()[i] = static_cast<T>(0); + } + shared_storer.store(l_x_vec, chunk_dim_x); + } + } +} + +template <typename T, int N_VEC> +__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col, + size_t g_start_row, size_t g_stride, size_t chunk_dim_y, + size_t chunk_dim_x, size_t total_rows, + size_t total_cols) { + const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; + const size_t l_idx = threadIdx.x; + + for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { + size_t l_y = (i_vec / chunk_dim_x_vec_elements); + size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); + 
+ size_t g_row = g_start_row + l_y; + size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; + + const T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedLoader<T, N_VEC> shared_loader(current_sh_row_base_ptr, chunk_dim_x); + + T* current_g_row_base_ptr = g_ptr + g_row * g_stride; + VectorizedStorer<T, N_VEC> global_storer(current_g_row_base_ptr, total_cols); + + shared_loader.load(l_x_vec, chunk_dim_x); + + if (g_row < total_rows) { + global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; + global_storer.store(g_col_primitive_start / N_VEC, total_cols); + } + } +} +} // namespace transformer_engine::rocm \ No newline at end of file diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index ae86e492d..223f7a720 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -18,10 +18,30 @@ class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. """ - max_fwd: float max_bwd: float +class _FormatHelperFP8(NamedTuple): + """ + Stores max FP8 values for fprop and bprop a `Format`. + """ + fwd: tuple + bwd: tuple + + @property + def max_fwd(self) -> float: + return self.fwd[is_fp8_fnuz()] + + @property + def max_bwd(self) -> float: + return self.bwd[is_fp8_fnuz()] + +class _FormatMaxVals(Enum): + """ + Tuples of FP8 (OCP, FNUZ) values for different formats. 
+ """ + E4M3 = (448, 240) + E5M2 = (57344, 57344) class Format(Enum): """ @@ -40,15 +60,10 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ - E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) - if te_rocm_build: - max_e4m3_val = 240 if is_fp8_fnuz() else 448 - E4M3 = _FormatHelper(max_fwd=max_e4m3_val, max_bwd=max_e4m3_val) - else: - E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) - E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) - HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) + E4M3 = _FormatHelperFP8(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) + E5M2 = _FormatHelperFP8(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) + HYBRID = _FormatHelperFP8(fwd=E4M3.fwd, bwd=E5M2.bwd) @dataclass(frozen=True) diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index 050d0fd23..aa40a8e35 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -101,4 +101,4 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( } // namespace jax } // namespace transformer_engine -#endif // #ifndef USE_ROCM \ No newline at end of file +#endif // #ifndef USE_ROCM From 46d382db16b620e02c06c01e822d413f23ddd898 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 3 Mar 2026 14:59:50 -0600 Subject: [PATCH 13/41] add use_async_d2h_group_size as a test parameter --- tests/jax/test_custom_call_compute.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 3b9ee0034..9303d6da8 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1786,11 +1786,14 @@ def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) @pytest_parametrize_wrapper("layout", 
["NN"]) - def test_grouped_gemm_fp16(self, dtype, input_shape, layout): + @pytest_parametrize_wrapper("use_async_d2h_group_size", [True, False]) + def test_grouped_gemm_fp16(self, dtype, input_shape, layout, use_async_d2h_group_size): lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( dtype, input_shape, layout ) - if not is_hip_extension(): + if use_async_d2h_group_size: + if is_hip_extension(): + pytest.skip("ROCm does not support use_async_d2h_group_sizes yet.") num_gemms = input_shape[0] _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( group_sizes, @@ -1806,7 +1809,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): rhs, group_sizes, contracting_dims, - use_async_d2h_group_sizes=not is_hip_extension(), + use_async_d2h_group_sizes=use_async_d2h_group_size, ) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) From 15416f1713980bb50b8c6c7fd2f1caa64578c4bc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 3 Mar 2026 17:26:59 -0600 Subject: [PATCH 14/41] enable FP4 tests --- tests/cpp/operator/CMakeLists.txt | 2 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 26 ++++++++++++++++--- tests/cpp/test_common.h | 13 +++++++--- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index d3b75bbbf..65c3c1702 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -11,6 +11,7 @@ list(APPEND test_cuda_sources test_cast_dbias_dgelu.cu test_cast_gated_swiglu.cu test_cast_mxfp8_gated_swiglu.cu + test_cast_nvfp4_transpose.cu test_qdq.cu test_cast_mxfp8.cu test_dequantize_mxfp8.cu @@ -32,7 +33,6 @@ list(APPEND test_cuda_sources ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources - test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_swizzle.cu) else() diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu 
b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index afd7927da..014441b3f 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -6,7 +8,9 @@ #include #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include #include @@ -31,9 +35,13 @@ enum ActivationType { }; double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { +#ifdef __HIP_PLATFORM_AMD__ + const __half2_raw raw_truncated_to_fp4e2m1_pair = + __hip_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__hip_fp4x2_storage_t*>(&fp4_pair), __HIP_E2M1); +#else const __half2_raw raw_truncated_to_fp4e2m1_pair = __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); - +#endif const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); const double truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); const double truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); @@ -631,14 +639,24 @@ void performTest(float (*OP)(const float), const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); size_t scale_mismatches_num = 0; +#ifdef __HIP_PLATFORM_AMD__ + std::vector mismatches_scales_indices; +#endif + compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), ref_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif scale_mismatches_num); compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), ref_scales_t.get(), unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif scale_mismatches_num); 
} @@ -675,9 +693,9 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { // Skip tests for pre-Blackwell architectures - if (getDeviceComputeCapability() < blackwellComputeCapability) { - GTEST_SKIP(); - } + // if (getDeviceComputeCapability() < blackwellComputeCapability) { + // GTEST_SKIP(); + // } using namespace transformer_engine; using namespace test; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 56154c9d9..05189b4af 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -19,12 +19,13 @@ #include #if FP4_TYPE_SUPPORTED #include -#endif +#endif //FP4_TYPE_SUPPORTED #else -#define FP4_TYPE_SUPPORTED (false) +#define FP4_TYPE_SUPPORTED (true) #include #include "amd_detail/hip_float8.h" -#endif +#include +#endif //USE_ROCM #include #include @@ -73,9 +74,15 @@ using fp8e5m2 = te_hip_fp8_e5m2; #endif //USE_ROCM using fp8e8m0 = uint8_t; #if FP4_TYPE_SUPPORTED +#ifndef USE_ROCM using fp4e2m1 = __nv_fp4_e2m1; using fp4e2m1x2 = __nv_fp4x2_e2m1; using fp4e2m1x4 = __nv_fp4x4_e2m1; +#else +using fp4e2m1 = __hip_fp4_e2m1; +using fp4e2m1x2 = __hip_fp4x2_e2m1; +using fp4e2m1x4 = __hip_fp4x4_e2m1; +#endif //USE_ROCM #endif template From bac5096fa2f2a5e5add28b3a2ec2973b81ede541 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Mar 2026 16:13:50 -0600 Subject: [PATCH 15/41] rough initial version --- transformer_engine/common/CMakeLists.txt | 1 + .../common/cast/dispatch/quantize.cuh | 7 +- .../common/cast/nvfp4/core_nvfp4.cuh | 2 + .../common/cast/nvfp4/quantize_nvfp4.cuh | 2 + .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 2 + transformer_engine/common/common.h | 18 +- .../transformer_engine/transformer_engine.h | 10 - ...quantize_transpose_vector_blockwise_fp4.cu | 120 +++++---- transformer_engine/common/util/ptx.cuh | 254 +++++++++--------- 9 files changed, 229 insertions(+), 187 deletions(-) diff --git 
a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 831de2b45..23b091363 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -257,6 +257,7 @@ else() gemm/rocm_gemm.cu amd_detail/system.cpp) list(APPEND transformer_engine_cuda_sources + transpose/quantize_transpose_vector_blockwise_fp4.cu fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp fused_attn_rocm/utils.cpp) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 8e8993668..f26e93551 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -91,7 +91,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, dummy_workspace_tensor, stream); break; } -#ifndef __HIP_PLATFORM_AMD__ +// #ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); @@ -108,6 +108,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel +#ifndef __HIP_PLATFORM_AMD__ if (use_optimized_kernel) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_transpose( @@ -117,6 +118,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); } } else { +#endif auto &global_amax = (output_tensor->amax.dptr != nullptr) ? 
output_tensor->amax : output_tensor->columnwise_amax; quantize_transpose_vector_blockwise_fp4( @@ -131,9 +133,12 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); +#ifndef __HIP_PLATFORM_AMD__ } +#endif break; } +#ifndef __HIP_PLATFORM_AMD__ case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index cff846490..e2eda60c5 100644 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -12,7 +12,9 @@ #define TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include #include diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh index 83ad8fd40..665124e3d 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -12,7 +12,9 @@ #define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include #include diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 7322bf265..4c0a565da 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -12,7 +12,9 @@ #define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include #include diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 
5feeb600c..505706239 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -13,14 +13,22 @@ #include #define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) #else +#ifdef __HIPCC__ +#define FP4_TYPE_SUPPORTED true +#else #define FP4_TYPE_SUPPORTED false +#endif #endif //#ifndef __HIP_PLATFORM_AMD__ #include #include #include #if FP4_TYPE_SUPPORTED +#ifndef __HIP_PLATFORM_AMD__ #include +#else +#include +#endif #endif #include @@ -361,6 +369,11 @@ using fp4e2m1x4 = __nv_fp4x4_e2m1; using bf16 = hip_bfloat16; using fp8e4m3 = te_hip_fp8_e4m3; using fp8e5m2 = te_hip_fp8_e5m2; +#if FP4_TYPE_SUPPORTED +using fp4e2m1 = __hip_fp4_e2m1; +using fp4e2m1x2 = __hip_fp4x2_e2m1; +using fp4e2m1x4 = __hip_fp4x4_e2m1; +#endif //FP4_TYPE_SUPPORTED #endif //__HIP_PLATFORM_AMD__ using e8m0_t = uint8_t; @@ -384,6 +397,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(half) TRANSFORMER_ENGINE_TYPE_NAME(hip_bfloat16) TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e4m3) TRANSFORMER_ENGINE_TYPE_NAME(te_hip_fp8_e5m2) +#if FP4_TYPE_SUPPORTED +TRANSFORMER_ENGINE_TYPE_NAME(__hip_fp4_e2m1) +#endif #else TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3) @@ -644,7 +660,7 @@ struct TypeInfo { switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat4E2M1: { \ - using type = __nv_fp4x2_storage_t; \ + using type = __hip_fp4x2_storage_t; \ { __VA_ARGS__ } \ } break; \ default: \ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 044e021e6..70c01c67f 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -425,13 +425,8 @@ enum class DType { kFloat8E4M3 = 7, kFloat8E5M2 = 8, kFloat8E8M0 = 9, -#ifndef __HIP_PLATFORM_AMD__ kFloat4E2M1 = 10, kNumTypes -#else - kNumTypes = 10, - kFloat4E2M1 -#endif // #ifndef 
__HIP_PLATFORM_AMD__ }; /*! \brief Check if TE datatype is FP8 @@ -443,17 +438,12 @@ inline bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; } -#ifndef __HIP_PLATFORM_AMD__ /*! \brief Check if TE datatype is FP4 * * Return true if TE datatype is FP4 * \param[in] DType TE Datatype of interest */ inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; } -#else -//TODO: fp4 types not supported on AMD GPUs -inline bool is_fp4_dtype(const DType t) { return false; } -#endif // #ifndef __HIP_PLATFORM_AMD__ /*! \brief Check if TE datatype is high precision (FP32, FP16, BF16) * diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index b49a54fbd..24e4bc42d 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -5,13 +5,17 @@ ************************************************************************/ #include +#ifndef __HIP_PLATFORM_AMD__ #include #include +#endif #include #include #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include #include "common/common.h" @@ -23,7 +27,7 @@ namespace transformer_engine { -#if CUDA_VERSION >= 12080 +// #if CUDA_VERSION >= 12080 namespace quantize_transpose_nvfp4 { namespace { @@ -155,12 +159,26 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); // for 2D block scaling, we need to reduce amax in warp +#ifdef __HIP_PLATFORM_AMD__ +static __device__ constexpr uint64_t WARP_REDUCE_AMAX_GROUP_MASKS[8] = { + 0x0101010101010101ULL, 0x0202020202020202ULL, + 0x0404040404040404ULL, 0x0808080808080808ULL, + 0x1010101010101010ULL, 0x2020202020202020ULL, + 0x4040404040404040ULL, 0x8080808080808080ULL}; +#else 
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { - 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080}; + 0x01010101, 0x02020202, 0x04040404, 0x08080808, + 0x10101010, 0x20202020, 0x40404040, 0x80808080}; +#endif // max for every group_size elements in warp template -__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) { +__device__ __forceinline__ float groupMax(float val, +#ifdef __HIP_PLATFORM_AMD__ + uint64_t groupMask) { +#else + unsigned int groupMask) { +#endif for (int offset = group_size / 2; offset > 0; offset /= 2) { val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride)); } @@ -189,7 +207,7 @@ __device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scal } __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { - constexpr float fp8_max = TypeExtrema::max; + const float fp8_max = TypeExtrema::max; constexpr float fp4_max = TypeExtrema::max; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 @@ -257,56 +275,56 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5; } -__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( +__device__ __forceinline__ __hip_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( const float2 in01, const float2 in23, const uint32_t rbits) { - constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; - if constexpr (has_rs) { - uint16_t out_4x; - asm volatile( - "{\n" - "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" - "}" - : "=h"(out_4x) - : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); - return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); - } else { + // constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + // 
if constexpr (has_rs) { + // uint16_t out_4x; + // asm volatile( + // "{\n" + // "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" + // "}" + // : "=h"(out_4x) + // : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); + // return *reinterpret_cast<__hip_fp4x4_e2m1*>(&out_4x); + // } else { NVTE_DEVICE_ERROR( "FP4 cvt.rs PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); uint16_t dummy = 0; - return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); - } + return *reinterpret_cast<__hip_fp4x4_e2m1*>(&dummy); + // } } -__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, +__device__ __forceinline__ __hip_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, const float2 in23, const uint32_t rbits) { - constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY; - if constexpr (has_fp4) { - // NOTE: rbits unused for rn. - uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. - asm volatile( - "{\n" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); - return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; - } else { + // constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY; + // if constexpr (has_fp4) { + // // NOTE: rbits unused for rn. + // uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. + // asm volatile( + // "{\n" + // ".reg.b8 f0; \n\t" + // ".reg.b8 f1; \n\t" + // "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" + // "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" + // "mov.b32 %0, {f0, f1, f0, f1};\n\t" + // "}" + // : "=r"(out_4x) + // : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); + // return reinterpret_cast<__hip_fp4x4_e2m1*>(&out_4x)[0]; + // } else { NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. 
" "Try recompiling with sm_XXXa instead of sm_XXX."); uint16_t dummy = 0; - return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); - } + return *reinterpret_cast<__hip_fp4x4_e2m1*>(&dummy); + // } } template -__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, +__device__ __forceinline__ __hip_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, const uint32_t rbits) { if constexpr (kApplyStochasticRounding) { return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits); @@ -540,11 +558,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; - // Convert to __nv_fp4x4_e2m1 - __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + // Convert to __hip_fp4x4_e2m1 + __hip_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); - output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; - output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + output_vec.data.elt[i] = reinterpret_cast<__hip_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__hip_fp4x2_storage_t*>(&out_4x)[1]; } // Step 2.7: Store output_c if constexpr (kAligned) { @@ -668,11 +686,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo encode_scale); const uint32_t rbits = kApplyStochasticRounding ? 
get_rbits(rng, random_uint4, rnd_idx) : 0; - // Convert to __nv_fp4x4_e2m1 - __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + // Convert to __hip_fp4x4_e2m1 + __hip_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); - output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; - output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + output_vec.data.elt[i] = reinterpret_cast<__hip_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__hip_fp4x2_storage_t*>(&out_4x)[1]; } // Step 3.7: Store output_t if constexpr (kAligned) { @@ -697,7 +715,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } // namespace } // namespace quantize_transpose_nvfp4 -#endif // CUDA_VERSION >= 12080 +// #endif // CUDA_VERSION >= 12080 namespace detail { @@ -709,7 +727,7 @@ void quantize_transpose_vector_blockwise_fp4( const NVTETensor rng_state_tensor, const bool use_2d_quantization, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); -#if CUDA_VERSION >= 12080 +// #if CUDA_VERSION >= 12080 // pow 2 scale is for MXFP4 since it's using E8M0 scaling // raise error if pow2_scale is true @@ -830,9 +848,9 @@ void quantize_transpose_vector_blockwise_fp4( ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); -#else - NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); -#endif // CUDA_VERSION >= 12080 +// #else +// NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +// #endif // CUDA_VERSION >= 12080 } } // namespace detail diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index ef53c2670..e93549db5 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -495,9 +495,15 @@ static_assert(sizeof(fp8e4m3x2) == 2); 
static_assert(sizeof(fp8e5m2x2) == 2); #if FP4_TYPE_SUPPORTED +#ifndef __HIP_PLATFORM_AMD__ using fp4e2m1 = __nv_fp4_e2m1; using fp4e2m1x2 = __nv_fp4x2_e2m1; using fp4e2m1x4 = __nv_fp4x4_e2m1; +#else +using fp4e2m1 = __hip_fp4_e2m1; +using fp4e2m1x2 = __hip_fp4x2_e2m1; +using fp4e2m1x4 = __hip_fp4x4_e2m1; +#endif static_assert(sizeof(fp4e2m1x2) == 1); static_assert(sizeof(fp4e2m1x4) == 2); @@ -521,86 +527,86 @@ __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( const uint64_t in_4x, const float2 scale, const uint32_t rbits) { uint16_t out_4x = 0; - constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; - if constexpr (has_rs) { - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b16 v0_bf16; \n\t" - ".reg.b16 v1_bf16; \n\t" - ".reg.b16 v2_bf16; \n\t" - ".reg.b16 v3_bf16; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - "cvt.f32.bf16 v0, v0_bf16; \n\t" - "cvt.f32.bf16 v1, v1_bf16; \n\t" - "cvt.f32.bf16 v2, v2_bf16; \n\t" - "cvt.f32.bf16 v3, v3_bf16; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order - "}" - : "=h"(out_4x) - : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); - } else { + // constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + // if constexpr (has_rs) { + // asm volatile( + // "{\n" + // ".reg.b64 v01; \n\t" + // ".reg.b64 v23; \n\t" + // ".reg.b16 v0_bf16; \n\t" + // ".reg.b16 v1_bf16; \n\t" + // ".reg.b16 v2_bf16; \n\t" + // ".reg.b16 v3_bf16; \n\t" + // ".reg.b32 v0; \n\t" + // 
".reg.b32 v1; \n\t" + // ".reg.b32 v2; \n\t" + // ".reg.b32 v3; \n\t" + // "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + // "cvt.f32.bf16 v0, v0_bf16; \n\t" + // "cvt.f32.bf16 v1, v1_bf16; \n\t" + // "cvt.f32.bf16 v2, v2_bf16; \n\t" + // "cvt.f32.bf16 v3, v3_bf16; \n\t" + // "mov.b64 v01, {v0, v1}; \n\t" + // "mov.b64 v23, {v2, v3}; \n\t" + // "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + // "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + // "mov.b64 {v1, v0}, v01; \n\t" + // "mov.b64 {v3, v2}, v23; \n\t" + // "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + // "}" + // : "=h"(out_4x) + // : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); + // } else { NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); - } + // } return *reinterpret_cast(&out_4x); } __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, const float2 scale, const uint32_t rbits) { - constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + // constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. - if constexpr (is_blackwell) { - // NOTE: rbits unused for rn. 
- asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b16 v0_bf16; \n\t" - ".reg.b16 v1_bf16; \n\t" - ".reg.b16 v2_bf16; \n\t" - ".reg.b16 v3_bf16; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - "cvt.f32.bf16 v0, v0_bf16; \n\t" - "cvt.f32.bf16 v1, v1_bf16; \n\t" - "cvt.f32.bf16 v2, v2_bf16; \n\t" - "cvt.f32.bf16 v3, v3_bf16; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "l"(in_4x), "l"(reinterpret_cast(scale))); - } else { + // if constexpr (is_blackwell) { + // // NOTE: rbits unused for rn. 
+ // asm volatile( + // "{\n" + // ".reg.b64 v01; \n\t" + // ".reg.b64 v23; \n\t" + // ".reg.b16 v0_bf16; \n\t" + // ".reg.b16 v1_bf16; \n\t" + // ".reg.b16 v2_bf16; \n\t" + // ".reg.b16 v3_bf16; \n\t" + // ".reg.b32 v0; \n\t" + // ".reg.b32 v1; \n\t" + // ".reg.b32 v2; \n\t" + // ".reg.b32 v3; \n\t" + // ".reg.b8 f0; \n\t" + // ".reg.b8 f1; \n\t" + // "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + // "cvt.f32.bf16 v0, v0_bf16; \n\t" + // "cvt.f32.bf16 v1, v1_bf16; \n\t" + // "cvt.f32.bf16 v2, v2_bf16; \n\t" + // "cvt.f32.bf16 v3, v3_bf16; \n\t" + // "mov.b64 v01, {v0, v1}; \n\t" + // "mov.b64 v23, {v2, v3}; \n\t" + // "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + // "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + // "mov.b64 {v1, v0}, v01; \n\t" + // "mov.b64 {v3, v2}, v23; \n\t" + // "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + // "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + // "mov.b32 %0, {f0, f1, f0, f1};\n\t" + // "}" + // : "=r"(out_4x) + // : "l"(in_4x), "l"(reinterpret_cast(scale))); + // } else { NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. 
" "Try recompiling with sm_XXXa instead of sm_XXX."); - } + // } return reinterpret_cast(&out_4x)[0]; } @@ -618,35 +624,35 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { uint16_t out_4x = 0; - constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; - if constexpr (has_rs) { - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - "mov.b64 {v0, v1} , %1; \n\t" - "mov.b64 {v2, v3} , %2; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order - "}" - : "=h"(out_4x) - : "l"(reinterpret_cast(in01)), - "l"(reinterpret_cast(in23)), - "l"(reinterpret_cast(scale)), "r"(rbits)); - } else { + // constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + // if constexpr (has_rs) { + // asm volatile( + // "{\n" + // ".reg.b64 v01; \n\t" + // ".reg.b64 v23; \n\t" + // ".reg.b32 v0; \n\t" + // ".reg.b32 v1; \n\t" + // ".reg.b32 v2; \n\t" + // ".reg.b32 v3; \n\t" + // "mov.b64 {v0, v1} , %1; \n\t" + // "mov.b64 {v2, v3} , %2; \n\t" + // "mov.b64 v01, {v0, v1}; \n\t" + // "mov.b64 v23, {v2, v3}; \n\t" + // "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + // "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + // "mov.b64 {v1, v0}, v01; \n\t" + // "mov.b64 {v3, v2}, v23; \n\t" + // "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order + // "}" + // : "=h"(out_4x) + // : "l"(reinterpret_cast(in01)), 
+ // "l"(reinterpret_cast(in23)), + // "l"(reinterpret_cast(scale)), "r"(rbits)); + // } else { NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); - } + // } return *reinterpret_cast(&out_4x); } @@ -654,41 +660,41 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 const float2 in23, const float2 scale, const uint32_t rbits) { - constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + // constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. - if constexpr (is_blackwell) { - // NOTE: rbits unused for rn. - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "mov.b64 {v0, v1} , %1; \n\t" - "mov.b64 {v2, v3} , %2; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "l"(reinterpret_cast(in01)), - "l"(reinterpret_cast(in23)), - "l"(reinterpret_cast(scale))); - } else { + // if constexpr (is_blackwell) { + // // NOTE: rbits unused for rn. 
+ // asm volatile( + // "{\n" + // ".reg.b64 v01; \n\t" + // ".reg.b64 v23; \n\t" + // ".reg.b32 v0; \n\t" + // ".reg.b32 v1; \n\t" + // ".reg.b32 v2; \n\t" + // ".reg.b32 v3; \n\t" + // ".reg.b8 f0; \n\t" + // ".reg.b8 f1; \n\t" + // "mov.b64 {v0, v1} , %1; \n\t" + // "mov.b64 {v2, v3} , %2; \n\t" + // "mov.b64 v01, {v0, v1}; \n\t" + // "mov.b64 v23, {v2, v3}; \n\t" + // "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + // "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + // "mov.b64 {v1, v0}, v01; \n\t" + // "mov.b64 {v3, v2}, v23; \n\t" + // "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + // "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + // "mov.b32 %0, {f0, f1, f0, f1};\n\t" + // "}" + // : "=r"(out_4x) + // : "l"(reinterpret_cast(in01)), + // "l"(reinterpret_cast(in23)), + // "l"(reinterpret_cast(scale))); + // } else { NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); - } + // } return reinterpret_cast(&out_4x)[0]; } From da242239b333c8a54cdf7f74459002efe7016dd1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Mar 2026 18:09:02 -0600 Subject: [PATCH 16/41] initial working version --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 14 +++-- .../common/cast/dispatch/quantize.cuh | 11 +++- ...quantize_transpose_vector_blockwise_fp4.cu | 54 ++++++++++++------- 3 files changed, 55 insertions(+), 24 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 014441b3f..47f2d39ab 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -34,18 +34,26 @@ enum ActivationType { SReLU }; +static constexpr float E2M1_LUT[16] = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, +}; + double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { #ifdef __HIP_PLATFORM_AMD__ - 
const __half2_raw raw_truncated_to_fp4e2m1_pair = - __hip_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__hip_fp4x2_storage_t*>(&fp4_pair), __HIP_E2M1); + uint8_t raw = *reinterpret_cast(&fp4_pair); + // Decode manually + float lo = E2M1_LUT[raw & 0xF]; + float hi = E2M1_LUT[(raw >> 4) & 0xF]; + return {static_cast(lo), static_cast(hi)}; #else const __half2_raw raw_truncated_to_fp4e2m1_pair = __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); -#endif const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); const double truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); const double truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y}; +#endif } template diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index f26e93551..0378be46f 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -121,8 +121,17 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, #endif auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax : output_tensor->columnwise_amax; + + // If amax was not explicitly set, fall back to the scale field which + // holds the same value when set via set_scale(). + NVTE_CHECK(global_amax.dptr != nullptr || output_tensor->scale.dptr != nullptr, + "NVFP4 quantization requires global_amax (output_tensor->amax) " + "or scale to be set. Call output.set_scale(amax_value) before quantizing."); + const SimpleTensor& effective_amax = + (global_amax.dptr != nullptr) ? 
global_amax : output_tensor->scale; + quantize_transpose_vector_blockwise_fp4( - /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*input=*/input_tensor->data, /*global_amax=*/effective_amax, /*scale_inv=*/output_tensor->scale_inv, /*scale_inv_t=*/output_tensor->columnwise_scale_inv, /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 24e4bc42d..4387e916a 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -167,13 +167,12 @@ static __device__ constexpr uint64_t WARP_REDUCE_AMAX_GROUP_MASKS[8] = { 0x4040404040404040ULL, 0x8080808080808080ULL}; #else static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { - 0x01010101, 0x02020202, 0x04040404, 0x08080808, - 0x10101010, 0x20202020, 0x40404040, 0x80808080}; + 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080}; #endif // max for every group_size elements in warp template -__device__ __forceinline__ float groupMax(float val, +__device__ __forceinline__ float groupMax(float val, #ifdef __HIP_PLATFORM_AMD__ uint64_t groupMask) { #else @@ -299,28 +298,43 @@ __device__ __forceinline__ __hip_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_r __device__ __forceinline__ __hip_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, const float2 in23, const uint32_t rbits) { - // constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY; - // if constexpr (has_fp4) { - // // NOTE: rbits unused for rn. - // uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. 
- // asm volatile( - // "{\n" - // ".reg.b8 f0; \n\t" - // ".reg.b8 f1; \n\t" - // "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" - // "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" - // "mov.b32 %0, {f0, f1, f0, f1};\n\t" - // "}" - // : "=r"(out_4x) - // : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); - // return reinterpret_cast<__hip_fp4x4_e2m1*>(&out_4x)[0]; - // } else { +#ifdef __HIP_PLATFORM_AMD__ + const __hip_fp4_storage_t q0 = __hip_cvt_float_to_fp4(in01.x, __HIP_E2M1, hipRoundNearest); + const __hip_fp4_storage_t q1 = __hip_cvt_float_to_fp4(in01.y, __HIP_E2M1, hipRoundNearest); + const __hip_fp4_storage_t q2 = __hip_cvt_float_to_fp4(in23.x, __HIP_E2M1, hipRoundNearest); + const __hip_fp4_storage_t q3 = __hip_cvt_float_to_fp4(in23.y, __HIP_E2M1, hipRoundNearest); + + uint16_t packed = static_cast( + (q0 & 0xFu) + | ((q1 & 0xFu) << 4) + | ((q2 & 0xFu) << 8) + | ((q3 & 0xFu) << 12)); + + return *reinterpret_cast<__hip_fp4x4_e2m1*>(&packed); +#else + constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY; + if constexpr (has_fp4) { + // NOTE: rbits unused for rn. + uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. + asm volatile( + "{\n" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); + return reinterpret_cast<__hip_fp4x4_e2m1*>(&out_4x)[0]; + } else { NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. 
" "Try recompiling with sm_XXXa instead of sm_XXX."); uint16_t dummy = 0; return *reinterpret_cast<__hip_fp4x4_e2m1*>(&dummy); - // } + } + #endif } template From c03b7bb9b2aca9a5a347d519b2c1f377fbbcee03 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Thu, 5 Mar 2026 10:56:50 -0600 Subject: [PATCH 17/41] Addressing comments and small fixes --- build_tools/wheel_utils/build_wheels.sh | 2 +- setup.py | 1 - tests/pytorch/utils.py | 4 ++-- .../common/cast/mxfp8/dequantize_mxfp8.cuh | 2 ++ .../common/cast/mxfp8/gated_mxfp8.cuh | 2 ++ .../common/cast/mxfp8/quantize_mxfp8.cuh | 2 ++ .../common/cast/mxfp8/rocm_dequantize_mxfp8.cuh | 4 ++-- .../common/cast/mxfp8/rocm_gated_mxfp8.cuh | 14 +++++++------- .../common/cast/mxfp8/rocm_quantize_mxfp8.cuh | 8 ++++---- .../common/cast/mxfp8/rocm_vectorized_2d.cuh | 4 ++-- .../common/fused_attn_rocm/fused_attn.cpp | 4 +++- transformer_engine/common/util/ptx.cuh | 3 --- transformer_engine/jax/cpp_extensions/gemm.py | 2 +- transformer_engine/jax/setup.py | 2 +- .../dot_product_attention/dot_product_attention.py | 4 ++++ .../attention/dot_product_attention/utils.py | 2 +- .../pytorch/cpp_extensions/fused_attn.py | 2 +- .../pytorch/csrc/extensions/bias.cpp | 2 +- .../pytorch/csrc/extensions/recipe.cpp | 2 +- 19 files changed, 37 insertions(+), 29 deletions(-) diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 0e8ab68a8..9ba647296 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -67,7 +67,7 @@ if $BUILD_COMMON ; then #hipify expects python in PATH, also ninja may be installed to python bindir test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true else - TE_CUDA_VERS="cu12" + TE_CUDA_VERS="${CUDA_MAJOR}" PYBINDIR=/opt/python/cp38-cp38/bin/ fi diff --git a/setup.py b/setup.py index eb241f5cb..bee036ada 100644 --- a/setup.py +++ b/setup.py @@ -239,7 +239,6 @@ def git_check_submodules() -> None: assert bool( 
int(os.getenv("NVTE_RELEASE_BUILD", "0")) ), "NVTE_RELEASE_BUILD env must be set for metapackage build." - te_cuda_vers = "cu12" ext_modules = [] cmdclass = {} package_data = {} diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index ed5a12995..4c75893d0 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -100,9 +100,9 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: return dict(rtol=1.3e-6, atol=1e-5) if dtype == torch.float64: return dict(rtol=1e-7, atol=1e-7) - if dtype in torch_float8_e4m3_type: + if dtype == torch_float8_e4m3_type: return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype in torch_float8_e5m2_type: + if dtype == torch_float8_e5m2_type: return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 raise ValueError(f"Unsupported dtype ({dtype})") diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 38eead606..2309e038a 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -25,7 +25,9 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ #include "./rocm_vectorized_2d.cuh" +#endif namespace transformer_engine { namespace dispatch { diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 69e30680c..966191a4e 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -25,7 +25,9 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ #include "./rocm_vectorized_2d.cuh" +#endif namespace transformer_engine { namespace dispatch { diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 8e25b3f65..beea39651 100644 --- 
a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -26,7 +26,9 @@ #include "../../utils.cuh" #include "../core/common.cuh" +#ifdef __HIP_PLATFORM_AMD__ #include "./rocm_vectorized_2d.cuh" +#endif namespace transformer_engine { namespace dispatch { diff --git a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh index 49c57737c..98492bdf6 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh @@ -67,7 +67,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - transformer_engine::rocm::copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -108,7 +108,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); - transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, + bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); diff --git a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index a8c02e4f8..998779594 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -122,16 +122,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate bulk tensor copy if constexpr (IS_DGATED) { - transformer_engine::rocm::copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, + copy_2d_to_shared(&in_grad_sh[0], grad_ptr, 
chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } // Act - transformer_engine::rocm::copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, + copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); // Gate - transformer_engine::rocm::copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, + copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -356,19 +356,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, + bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, + bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } if constexpr (USE_COLWISE_SCALING) { - transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, + bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, + bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); 
} } diff --git a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index d5b51a2f4..b2881bf47 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -143,11 +143,11 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int chunk_it_offset_x = chunk_offset_X; const size_t row_base = chunk_it_offset_y; if constexpr (IS_DACT) { - transformer_engine::rocm::copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, + copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } - transformer_engine::rocm::copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -290,12 +290,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, + bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } if constexpr (USE_COLWISE_SCALING) { - transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, + bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } diff --git a/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh index 50474f308..81dc46a85 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh +++ 
b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh @@ -8,7 +8,7 @@ #include "../../util/vectorized_pointwise.h" -namespace transformer_engine::rocm { +namespace transformer_engine { // These 2d copy functions replace TMA tensormap async copies for AMD GPUs. template __device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, @@ -78,4 +78,4 @@ __device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T * } } } -} // namespace transformer_engine::rocm \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 97aecf4de..abd98d1f7 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -282,7 +282,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace transformer_engine; - NVTE_CHECK(!(return_max_logit || cuda_graph), "ROCm does not support return_max_logit and cuda_graph for fused_attn yet."); + // TODO: Add return_max_logit support + if (return_max_logit || cuda_graph) return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + // by default, fused attn is enabled bool nvte_fused_attn = true; if (const char* env_p = std::getenv("NVTE_FUSED_ATTN") ) { diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 312890db0..98a8cc998 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -21,9 +21,6 @@ #endif // CUDA_VERSION >= 12080 #include "common/utils.cuh" -#ifdef __HIP_PLATFORM_AMD__ -#include "../util/vectorized_pointwise.h" -#endif //#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2daecedfa..c1524dbd8 
100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 619b6070b..0d632556f 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index ef601e4c4..886cea3c4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -61,6 +63,8 @@ FlashAttention, ) +from torch.utils.cpp_extension import IS_HIP_EXTENSION + # Setup Attention Logging attn_log.setup_logging() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 54fe21d81..7706dd61e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index e5492ebc6..a5f8a8f21 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index e8a735966..06a5fb4a5 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. 
All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index 577a938f2..61f3ba1d8 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. From c453dba4883484cf490043c777bc7893a9fcc999 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 5 Mar 2026 13:16:01 -0600 Subject: [PATCH 18/41] various cleanups --- tests/cpp/CMakeLists.txt | 2 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 9 +- .../common/cast/dispatch/quantize.cuh | 1 - .../common/cast/nvfp4/core_nvfp4.cuh | 2 + .../common/cast/nvfp4/quantize_nvfp4.cuh | 2 + .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 2 + transformer_engine/common/common.h | 5 +- .../transformer_engine/transformer_engine.h | 2 +- ...quantize_transpose_vector_blockwise_fp4.cu | 79 +++--- transformer_engine/common/util/ptx.cuh | 268 ++++++++++-------- 10 files changed, 209 insertions(+), 163 deletions(-) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index b71addebf..f8af20665 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # # See LICENSE for license information. diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 47f2d39ab..3f8cc4353 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -49,6 +49,7 @@ double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { #else const __half2_raw raw_truncated_to_fp4e2m1_pair = __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); + const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); const double truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); const double truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); @@ -700,10 +701,12 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam transformer_engine::DType>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { +#ifndef __HIP_PLATFORM_AMD__ // Skip tests for pre-Blackwell architectures - // if (getDeviceComputeCapability() < blackwellComputeCapability) { - // GTEST_SKIP(); - // } + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } +#endif using namespace transformer_engine; using namespace test; diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 0378be46f..a99d51558 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -91,7 +91,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, dummy_workspace_tensor, stream); break; } -// #ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index e2eda60c5..80c2586e7 100644 --- 
a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh index 665124e3d..7265ff5e3 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 4c0a565da..50920fd3b 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 505706239..cf63b1461 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -656,11 +656,14 @@ struct TypeInfo { } // Add a pack_size argument to select the packed type for FP4 +#ifdef __HIP_PLATFORM_AMD__ +#define __nv_fp4x2_storage_t __hip_fp4x2_storage_t +#endif #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat4E2M1: { \ - using type = __hip_fp4x2_storage_t; \ + using type = __nv_fp4x2_storage_t; \ { __VA_ARGS__ } \ } break; \ default: \ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 70c01c67f..fccb882ff 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 4387e916a..da02644ba 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -27,7 +29,12 @@ namespace transformer_engine { -// #if CUDA_VERSION >= 12080 +#ifdef __HIP_PLATFORM_AMD__ +#define __nv_fp4x4_e2m1 __hip_fp4x4_e2m1 +#define __nv_fp4x2_storage_t __hip_fp4x2_storage_t +#endif + +#if defined(__HIP_PLATFORM_AMD__) || CUDA_VERSION >= 12080 namespace quantize_transpose_nvfp4 { namespace { @@ -206,7 +213,11 @@ __device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scal } __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { +#ifdef __HIP_PLATFORM_AMD__ const float fp8_max = TypeExtrema::max; +#else + constexpr float fp8_max = TypeExtrema::max; +#endif constexpr float fp4_max = TypeExtrema::max; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 @@ -274,28 +285,32 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5; } -__device__ __forceinline__ __hip_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( const float2 in01, const float2 in23, const uint32_t rbits) { - // constexpr bool has_rs = 
ARCH_HAS_STOCHASTIC_ROUNDING; - // if constexpr (has_rs) { - // uint16_t out_4x; - // asm volatile( - // "{\n" - // "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" - // "}" - // : "=h"(out_4x) - // : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); - // return *reinterpret_cast<__hip_fp4x4_e2m1*>(&out_4x); - // } else { +#ifndef __HIP_PLATFORM_AMD__ + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + uint16_t out_4x; + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" + "}" + : "=h"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); + } else { +#endif NVTE_DEVICE_ERROR( "FP4 cvt.rs PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); uint16_t dummy = 0; - return *reinterpret_cast<__hip_fp4x4_e2m1*>(&dummy); - // } + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); +#ifndef __HIP_PLATFORM_AMD__ + } +#endif } -__device__ __forceinline__ __hip_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, const float2 in23, const uint32_t rbits) { #ifdef __HIP_PLATFORM_AMD__ @@ -326,19 +341,19 @@ __device__ __forceinline__ __hip_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const flo "}" : "=r"(out_4x) : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); - return reinterpret_cast<__hip_fp4x4_e2m1*>(&out_4x)[0]; + return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; } else { NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. 
" "Try recompiling with sm_XXXa instead of sm_XXX."); uint16_t dummy = 0; - return *reinterpret_cast<__hip_fp4x4_e2m1*>(&dummy); + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); } #endif } template -__device__ __forceinline__ __hip_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, const uint32_t rbits) { if constexpr (kApplyStochasticRounding) { return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits); @@ -572,11 +587,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; - // Convert to __hip_fp4x4_e2m1 - __hip_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); - output_vec.data.elt[i] = reinterpret_cast<__hip_fp4x2_storage_t*>(&out_4x)[0]; - output_vec.data.elt[i + 1] = reinterpret_cast<__hip_fp4x2_storage_t*>(&out_4x)[1]; + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; } // Step 2.7: Store output_c if constexpr (kAligned) { @@ -700,11 +715,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo encode_scale); const uint32_t rbits = kApplyStochasticRounding ? 
get_rbits(rng, random_uint4, rnd_idx) : 0; - // Convert to __hip_fp4x4_e2m1 - __hip_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); - output_vec.data.elt[i] = reinterpret_cast<__hip_fp4x2_storage_t*>(&out_4x)[0]; - output_vec.data.elt[i + 1] = reinterpret_cast<__hip_fp4x2_storage_t*>(&out_4x)[1]; + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; } // Step 3.7: Store output_t if constexpr (kAligned) { @@ -729,7 +744,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } // namespace } // namespace quantize_transpose_nvfp4 -// #endif // CUDA_VERSION >= 12080 +#endif // CUDA_VERSION >= 12080 namespace detail { @@ -741,7 +756,7 @@ void quantize_transpose_vector_blockwise_fp4( const NVTETensor rng_state_tensor, const bool use_2d_quantization, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); -// #if CUDA_VERSION >= 12080 +#if defined(__HIP_PLATFORM_AMD__) || CUDA_VERSION >= 12080 // pow 2 scale is for MXFP4 since it's using E8M0 scaling // raise error if pow2_scale is true @@ -862,9 +877,9 @@ void quantize_transpose_vector_blockwise_fp4( ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); -// #else -// NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); -// #endif // CUDA_VERSION >= 12080 +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 } } // namespace detail diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index e93549db5..bbc5bcff6 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -527,86 +527,96 @@ __device__ __forceinline__ void 
mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( const uint64_t in_4x, const float2 scale, const uint32_t rbits) { uint16_t out_4x = 0; - // constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; - // if constexpr (has_rs) { - // asm volatile( - // "{\n" - // ".reg.b64 v01; \n\t" - // ".reg.b64 v23; \n\t" - // ".reg.b16 v0_bf16; \n\t" - // ".reg.b16 v1_bf16; \n\t" - // ".reg.b16 v2_bf16; \n\t" - // ".reg.b16 v3_bf16; \n\t" - // ".reg.b32 v0; \n\t" - // ".reg.b32 v1; \n\t" - // ".reg.b32 v2; \n\t" - // ".reg.b32 v3; \n\t" - // "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - // "cvt.f32.bf16 v0, v0_bf16; \n\t" - // "cvt.f32.bf16 v1, v1_bf16; \n\t" - // "cvt.f32.bf16 v2, v2_bf16; \n\t" - // "cvt.f32.bf16 v3, v3_bf16; \n\t" - // "mov.b64 v01, {v0, v1}; \n\t" - // "mov.b64 v23, {v2, v3}; \n\t" - // "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - // "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - // "mov.b64 {v1, v0}, v01; \n\t" - // "mov.b64 {v3, v2}, v23; \n\t" - // "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order - // "}" - // : "=h"(out_4x) - // : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); - // } else { +#ifndef __HIP_PLATFORM_AMD__ + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // 
mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { +#endif NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); - // } +#ifndef __HIP_PLATFORM_AMD__ + } +#endif return *reinterpret_cast(&out_4x); } __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, const float2 scale, const uint32_t rbits) { - // constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; +#ifndef __HIP_PLATFORM_AMD__ + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; +#endif uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. - // if constexpr (is_blackwell) { - // // NOTE: rbits unused for rn. 
- // asm volatile( - // "{\n" - // ".reg.b64 v01; \n\t" - // ".reg.b64 v23; \n\t" - // ".reg.b16 v0_bf16; \n\t" - // ".reg.b16 v1_bf16; \n\t" - // ".reg.b16 v2_bf16; \n\t" - // ".reg.b16 v3_bf16; \n\t" - // ".reg.b32 v0; \n\t" - // ".reg.b32 v1; \n\t" - // ".reg.b32 v2; \n\t" - // ".reg.b32 v3; \n\t" - // ".reg.b8 f0; \n\t" - // ".reg.b8 f1; \n\t" - // "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - // "cvt.f32.bf16 v0, v0_bf16; \n\t" - // "cvt.f32.bf16 v1, v1_bf16; \n\t" - // "cvt.f32.bf16 v2, v2_bf16; \n\t" - // "cvt.f32.bf16 v3, v3_bf16; \n\t" - // "mov.b64 v01, {v0, v1}; \n\t" - // "mov.b64 v23, {v2, v3}; \n\t" - // "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - // "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - // "mov.b64 {v1, v0}, v01; \n\t" - // "mov.b64 {v3, v2}, v23; \n\t" - // "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - // "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - // "mov.b32 %0, {f0, f1, f0, f1};\n\t" - // "}" - // : "=r"(out_4x) - // : "l"(in_4x), "l"(reinterpret_cast(scale))); - // } else { +#ifndef __HIP_PLATFORM_AMD__ + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. 
+ asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale))); + } else { +#endif NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. 
" "Try recompiling with sm_XXXa instead of sm_XXX."); - // } +#ifndef __HIP_PLATFORM_AMD__ + } +#endif return reinterpret_cast(&out_4x)[0]; } @@ -624,35 +634,39 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { uint16_t out_4x = 0; - // constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; - // if constexpr (has_rs) { - // asm volatile( - // "{\n" - // ".reg.b64 v01; \n\t" - // ".reg.b64 v23; \n\t" - // ".reg.b32 v0; \n\t" - // ".reg.b32 v1; \n\t" - // ".reg.b32 v2; \n\t" - // ".reg.b32 v3; \n\t" - // "mov.b64 {v0, v1} , %1; \n\t" - // "mov.b64 {v2, v3} , %2; \n\t" - // "mov.b64 v01, {v0, v1}; \n\t" - // "mov.b64 v23, {v2, v3}; \n\t" - // "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - // "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - // "mov.b64 {v1, v0}, v01; \n\t" - // "mov.b64 {v3, v2}, v23; \n\t" - // "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order - // "}" - // : "=h"(out_4x) - // : "l"(reinterpret_cast(in01)), - // "l"(reinterpret_cast(in23)), - // "l"(reinterpret_cast(scale)), "r"(rbits)); - // } else { +#ifndef __HIP_PLATFORM_AMD__ + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the 
shuffled elements order + "}" + : "=h"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { +#endif NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); - // } +#ifndef __HIP_PLATFORM_AMD__ + } +#endif return *reinterpret_cast(&out_4x); } @@ -660,41 +674,47 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 const float2 in23, const float2 scale, const uint32_t rbits) { - // constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; +#ifndef __HIP_PLATFORM_AMD__ + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; +#endif uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. - // if constexpr (is_blackwell) { - // // NOTE: rbits unused for rn. - // asm volatile( - // "{\n" - // ".reg.b64 v01; \n\t" - // ".reg.b64 v23; \n\t" - // ".reg.b32 v0; \n\t" - // ".reg.b32 v1; \n\t" - // ".reg.b32 v2; \n\t" - // ".reg.b32 v3; \n\t" - // ".reg.b8 f0; \n\t" - // ".reg.b8 f1; \n\t" - // "mov.b64 {v0, v1} , %1; \n\t" - // "mov.b64 {v2, v3} , %2; \n\t" - // "mov.b64 v01, {v0, v1}; \n\t" - // "mov.b64 v23, {v2, v3}; \n\t" - // "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - // "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - // "mov.b64 {v1, v0}, v01; \n\t" - // "mov.b64 {v3, v2}, v23; \n\t" - // "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - // "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - // "mov.b32 %0, {f0, f1, f0, f1};\n\t" - // "}" - // : "=r"(out_4x) - // : "l"(reinterpret_cast(in01)), - // "l"(reinterpret_cast(in23)), - // "l"(reinterpret_cast(scale))); - // } else { +#ifndef __HIP_PLATFORM_AMD__ + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. 
+ asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale))); + } else { +#endif NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); - // } +#ifndef __HIP_PLATFORM_AMD__ + } +#endif return reinterpret_cast(&out_4x)[0]; } From 4a843ba66a4d4e1973e94eab7a62dc1f80938933 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 5 Mar 2026 13:17:00 -0600 Subject: [PATCH 19/41] manually update runner labels --- .github/workflows/rocm-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 4fc34e391..26a26fab1 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -46,7 +46,7 @@ jobs: strategy: fail-fast: false matrix: - runner: [linux-mi325-8, linux-mi355-8] + runner: [linux-te-mi325-8, linux-te-mi355-8] steps: - name: Checkout repository uses: actions/checkout@v4 From 316dffb60374c4ff021e8e0284c84ddfb6090596 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Thu, 5 Mar 2026 13:23:14 -0600 Subject: [PATCH 20/41] Comment cleanup --- build_tools/wheel_utils/build_wheels.sh | 2 +- transformer_engine/jax/csrc/extensions/gemm.cpp | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) 
diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 9ba647296..6ada06d0c 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -67,7 +67,7 @@ if $BUILD_COMMON ; then #hipify expects python in PATH, also ninja may be installed to python bindir test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true else - TE_CUDA_VERS="${CUDA_MAJOR}" + TE_CUDA_VERS="cu${CUDA_MAJOR}" PYBINDIR=/opt/python/cp38-cp38/bin/ fi diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d35b2d072..6f0b7ff22 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -285,8 +285,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i out_.data() /*D*/, workspace_.data(), config, stream); } else { #ifdef USE_ROCM - //TODO: better assert - NVTE_ERROR("ROCm TE jax does not integrate userbuffer for now"); + NVTE_ERROR("ROCm TE JAX does not support comm-comp overlap yet."); #else std::vector buffer_shape{0, 0}; DType buffer_dtype = out_dtype; From 5c747bde06088e491e90af7f9e9ee25cea1bfb53 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 5 Mar 2026 15:56:48 -0600 Subject: [PATCH 21/41] only enable on gfx950 --- tests/cpp/test_common.h | 2 +- transformer_engine/common/common.h | 2 +- transformer_engine/common/util/ptx.cuh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 05189b4af..715d40f4f 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -21,7 +21,7 @@ #include #endif //FP4_TYPE_SUPPORTED #else -#define FP4_TYPE_SUPPORTED (true) +#define FP4_TYPE_SUPPORTED __gfx950__ #include #include "amd_detail/hip_float8.h" #include diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index cf63b1461..7d28a6783 100644 --- 
a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -14,7 +14,7 @@ #define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) #else #ifdef __HIPCC__ -#define FP4_TYPE_SUPPORTED true +#define FP4_TYPE_SUPPORTED __gfx950__ #else #define FP4_TYPE_SUPPORTED false #endif diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 3ea8a8fcd..590242b74 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -536,7 +536,7 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_roun __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, const float2 scale, const uint32_t rbits) { -#ifndef __HIP_PLATFORM_AMD__ +#ifndef __HIP_PLATFORM_AMD__ constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; #endif uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. From db56b8f8a666f07b4ccaf783d2c6463e2c82836c Mon Sep 17 00:00:00 2001 From: alextmagro Date: Thu, 5 Mar 2026 16:09:17 -0600 Subject: [PATCH 22/41] Update jax gemm.py --- transformer_engine/jax/cpp_extensions/gemm.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c1524dbd8..276480f24 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -20,19 +20,22 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.experimental.custom_partitioning import SdyShardingRule +from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type + from transformer_engine_jax import ( get_num_compute_streams, JAXX_Collective_Op, get_device_compute_capability, - #initialize_cgemm_communicator, - #get_cgemm_num_max_streams, ) +if not is_hip_extension(): + from transformer_engine_jax import ( + initialize_cgemm_communicator, + 
get_cgemm_num_max_streams, + ) from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize -from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type - from ..quantize import ( AbstractBaseTensor, NoScaleTensor, From 62eea94bd2d518d52255dad154d0740032a16792 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 6 Mar 2026 16:45:50 +0000 Subject: [PATCH 23/41] Revert "only enable on gfx950" This reverts commit 5c747bde06088e491e90af7f9e9ee25cea1bfb53. --- tests/cpp/test_common.h | 2 +- transformer_engine/common/common.h | 2 +- transformer_engine/common/util/ptx.cuh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 715d40f4f..05189b4af 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -21,7 +21,7 @@ #include #endif //FP4_TYPE_SUPPORTED #else -#define FP4_TYPE_SUPPORTED __gfx950__ +#define FP4_TYPE_SUPPORTED (true) #include #include "amd_detail/hip_float8.h" #include diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 7d28a6783..cf63b1461 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -14,7 +14,7 @@ #define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) #else #ifdef __HIPCC__ -#define FP4_TYPE_SUPPORTED __gfx950__ +#define FP4_TYPE_SUPPORTED true #else #define FP4_TYPE_SUPPORTED false #endif diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 590242b74..3ea8a8fcd 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -536,7 +536,7 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_roun __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, const float2 scale, const uint32_t rbits) { -#ifndef __HIP_PLATFORM_AMD__ +#ifndef __HIP_PLATFORM_AMD__ constexpr bool is_blackwell = 
ARCH_BLACKWELL_FAMILY; #endif uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. From 6d459ec10edab81c7c0d0d39be7c2d1cacba5c6c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 6 Mar 2026 13:15:05 -0600 Subject: [PATCH 24/41] reenable in NVTEDType --- .../common/include/transformer_engine/transformer_engine.h | 6 ------ 1 file changed, 6 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index fccb882ff..d4e10673d 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -34,14 +34,8 @@ enum NVTEDType { kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */ kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */ kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */ -#ifndef __HIP_PLATFORM_AMD__ kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */ kNVTENumTypes /*!< Number of supported types */ -#else - //switch the order since rocm platform does not support e2m1 - kNVTENumTypes = 10, /*!< Number of supported types */ - kNVTEFloat4E2M1 = 11 /*!< 4-bit float (E2M1) */ -#endif // #ifndef __HIP_PLATFORM_AMD__ }; /*! 
\struct NVTEShape From 6eb2707145b80b0b37b26486928423837ad20da0 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Fri, 6 Mar 2026 14:14:08 -0600 Subject: [PATCH 25/41] Fix dev merge conflicts --- tests/pytorch/triton_kernels/test_cast.py | 2 +- transformer_engine/common/CMakeLists.txt | 12 ------------ transformer_engine/common/gemm/cublaslt_gemm.cu | 1 + transformer_engine/jax/csrc/extensions/misc.h | 16 ++++------------ transformer_engine/pytorch/module/_common.py | 1 + .../pytorch/module/grouped_linear.py | 14 +------------- .../pytorch/triton_kernels/norms_common.py | 2 +- 7 files changed, 9 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/triton_kernels/test_cast.py b/tests/pytorch/triton_kernels/test_cast.py index f85773d65..a595ebb8e 100644 --- a/tests/pytorch/triton_kernels/test_cast.py +++ b/tests/pytorch/triton_kernels/test_cast.py @@ -10,7 +10,7 @@ from transformer_engine.pytorch.triton_kernels.common import te_dtype_to_torch_dtype import transformer_engine_torch as tex from test_common import te_compare_results, fill_uniform, get_tolerances -from transformer_engine.pytorch.fp8 import autocast +from transformer_engine.pytorch.quantization import autocast from transformer_engine.common import recipe from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 2349cf9f6..5ed01f2ee 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -350,27 +350,15 @@ target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") if (USE_CUDA) -<<<<<<< HEAD # CUTLASS kernels require SM90a and cause hang in debug build set_property( SOURCE gemm/cutlass_grouped_gemm.cu APPEND PROPERTY COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0") -======= -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - 
"gemm/cutlass_grouped_gemm.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") -else() - message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") -endif() else() set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include) ->>>>>>> origin/dev endif() #USE_CUDA # Configure dependencies diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index bad80c109..b86efd5cb 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -27,6 +27,7 @@ #include "../util/handle_manager.h" #include "../util/logging.h" #include "../util/multi_stream.h" +#include "../util/system.h" #include "./config.h" #ifndef __HIP_PLATFORM_AMD__ #include "./cutlass_grouped_gemm.cuh" diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 9e195c079..b67d0f67c 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -107,25 +107,17 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { struct BLOCK_SIZE { size_t x; size_t y; -<<<<<<< HEAD constexpr BLOCK_SIZE(int _x, int _y) : x(_x), y(_y) {} }; -======= -} MXFP8_BLOCK_SIZE{1, 32}; -constexpr struct Alignment { - size_t x; - size_t y; -#ifndef __HIP_PLATFORM_AMD__ -} MXFP8_ALIGNMENT{128, 4}; -#else -} MXFP8_ALIGNMENT{1, 1}; -#endif ->>>>>>> origin/dev constexpr BLOCK_SIZE MXFP8_BLOCK_SIZE{1, 32}; constexpr BLOCK_SIZE NVFP4_BLOCK_SIZE{1, 16}; +#ifdef __HIP_PLATFORM_AMD__ +constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{1, 1}; +#else constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{128, 4}; +#endif std::vector get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N, bool is_colwise); diff --git a/transformer_engine/pytorch/module/_common.py 
b/transformer_engine/pytorch/module/_common.py index 5927ac786..04de0d061 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -26,6 +26,7 @@ te_rmsnorm_fwd_triton, te_rmsnorm_bwd_triton ) + import os def _get_normalization_func(normalization: str, forward: bool): use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 22a578474..fc191b706 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -146,13 +146,10 @@ def forward( inputmats = [cast_if_needed(inp_view, activation_dtype)] else: inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) -<<<<<<< HEAD if cpu_offloading: start_offload(*inputmats) -======= ->>>>>>> origin/dev # Initialize weights weights_fp8: list if fp8: @@ -392,17 +389,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) -<<<<<<< HEAD - # Make sure weights are available in column-wise format - # for dgrad computation. 
- for weight in weights: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) - general_grouped_gemm( -======= for weight, quantizer in zip(weights, ctx.weight_quantizers): - if quantizer is not None and isinstance(weight, QuantizedTensorBase): + if quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_usage( rowwise_usage=quantizer.rowwise_usage, columnwise_usage=quantizer.columnwise_usage, @@ -414,7 +403,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], general_grouped_gemm_func = general_grouped_gemm kwargs = {} general_grouped_gemm_func( ->>>>>>> origin/dev weights, grad_output, [dgrad], diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index d40526e4c..87cfa722e 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -12,7 +12,7 @@ te_dtype_to_torch_dtype, te_dtype_to_triton_dtype, ) -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer from .utils import num_programs, block_size, use_blocked, make_ln_out from .common import get_fp8_max from .rmsnorm import ( From 8cec975624f23cbbac892660a58d6ded64b6ba1f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 6 Mar 2026 17:33:16 -0600 Subject: [PATCH 26/41] enable in bwd_helper --- transformer_engine/common/cast/dispatch/quantize.cuh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index a99d51558..187cc67f1 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -251,7 +251,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens stream); break; } -#ifndef __HIP_PLATFORM_AMD__ 
case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK((!IS_DBIAS && !IS_DACT), "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); @@ -269,6 +268,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel +#ifndef __HIP_PLATFORM_AMD__ if (use_optimized_kernel) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_transpose( @@ -278,6 +278,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); } } else { +#endif auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax : output_tensor->columnwise_amax; quantize_transpose_vector_blockwise_fp4( @@ -292,9 +293,12 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); +#ifndef __HIP_PLATFORM_AMD__ } +#endif break; } +#ifndef __HIP_PLATFORM_AMD__ case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. 
NVTE_CHECK((!IS_DBIAS && !IS_DACT), From ccda439ec20e71b46f38eb9a5115bae5b11ac823 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 9 Mar 2026 12:32:05 -0500 Subject: [PATCH 27/41] alignment fixes --- tests/cpp/test_common.cu | 14 ++++++++++++++ tests/cpp/test_common.h | 7 ------- transformer_engine/common/CMakeLists.txt | 3 ++- .../common/cast/dispatch/quantize.cuh | 2 -- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 66474f556..ed7d77699 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1219,12 +1219,26 @@ std::array get_scale_tensor_dims(const size_t rows, const bool is_rowwise = (block_size_rows == 1) && ((block_size_cols == 32) || (block_size_cols == 16)); + // On AMD, MXFP8 scales (block_size=32) are passed unpadded to hipBLASlt. + // NVFP4 scales (block_size=16) still require [128,4] padding for kernel indexing. +#ifdef __HIP_PLATFORM_AMD__ + const bool needs_padding = (block_size_cols == 16 || block_size_rows == 16); + const size_t alignment_Y = needs_padding + ? (is_rowwise ? scale_tensor_alignment_Y_rowwise + : scale_tensor_alignment_Y_colwise) + : 1; + const size_t alignment_X = needs_padding + ? (is_rowwise ? scale_tensor_alignment_X_rowwise + : scale_tensor_alignment_X_colwise) + : 1; +#else const size_t alignment_Y = is_rowwise ? scale_tensor_alignment_Y_rowwise : scale_tensor_alignment_Y_colwise; const size_t alignment_X = is_rowwise ? 
scale_tensor_alignment_X_rowwise : scale_tensor_alignment_X_colwise; +#endif const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows); const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index ec686fc17..e02dacd61 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -345,17 +345,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement -#ifdef __HIP_PLATFORM_AMD__ -constexpr size_t scale_tensor_alignment_X_rowwise = 1; -constexpr size_t scale_tensor_alignment_Y_rowwise = 1; -constexpr size_t scale_tensor_alignment_X_colwise = 1; -constexpr size_t scale_tensor_alignment_Y_colwise = 1; -#else constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4; constexpr size_t scale_tensor_alignment_X_colwise = 128; -#endif inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 5ed01f2ee..4e968f145 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -259,7 +259,8 @@ else() list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp - fused_attn_rocm/utils.cpp) + fused_attn_rocm/utils.cpp + transpose/quantize_transpose_vector_blockwise_fp4.cu) endif() diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9add948c5..187cc67f1 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -91,7 +91,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, 
dummy_workspace_tensor, stream); break; } -#ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); @@ -252,7 +251,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens stream); break; } -#ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK((!IS_DBIAS && !IS_DACT), "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); From 4b0fd34740e245dadc273614ac767e91f2ee22ba Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 9 Mar 2026 12:35:28 -0500 Subject: [PATCH 28/41] fix merge error --- transformer_engine/jax/cpp_extensions/gemm.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index cfaa1beba..6a6a106b5 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -20,8 +20,6 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.experimental.custom_partitioning import SdyShardingRule -from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type - from transformer_engine_jax import ( get_num_compute_streams, JAXX_Collective_Op, @@ -29,15 +27,12 @@ #initialize_cgemm_communicator, #get_cgemm_num_max_streams, ) -if not is_hip_extension(): - from transformer_engine_jax import ( - initialize_cgemm_communicator, - get_cgemm_num_max_streams, - ) from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize +from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type + from ..quantize import ( AbstractBaseTensor, NoScaleTensor, From 84934c22cb94805e59af636b0728b174e051cbe6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 9 Mar 2026 16:38:25 -0500 Subject: [PATCH 29/41] minor fixes --- tests/cpp/operator/CMakeLists.txt | 2 +- 
tests/pytorch/attention/test_attention.py | 1 - transformer_engine/common/util/ptx.cuh | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 53db01492..fee06f3c7 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -11,7 +11,6 @@ list(APPEND test_cuda_sources test_cast_dbias_dgelu.cu test_cast_gated_swiglu.cu test_cast_mxfp8_gated_swiglu.cu - test_cast_nvfp4_transpose.cu test_qdq.cu test_cast_mxfp8.cu test_dequantize_mxfp8.cu @@ -38,6 +37,7 @@ if(USE_CUDA) test_swizzle.cu) else() list(APPEND test_cuda_sources + test_cast_nvfp4_transpose.cu test_cublaslt_gemm.cu) endif() diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 713bef645..dc3cef901 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -148,7 +148,6 @@ def test_gqa_mla_thd(): test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True) - @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") def test_dot_product_mem_calc(): """Non-regression test for memory workspace calculation integer overflow issue.""" diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 3ea8a8fcd..590242b74 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -536,7 +536,7 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_roun __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, const float2 scale, const uint32_t rbits) { -#ifndef __HIP_PLATFORM_AMD__ +#ifndef __HIP_PLATFORM_AMD__ constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; #endif uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. 
From 586bd09964e34b79686c1b2eaab1b40bc2069942 Mon Sep 17 00:00:00 2001 From: leo-amd Date: Thu, 12 Mar 2026 13:55:57 +0100 Subject: [PATCH 30/41] Run CI From aa18e9a7a83ccf32e0ed908d7e06f06ea64fabab Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 13 Mar 2026 16:34:36 -0500 Subject: [PATCH 31/41] more scales fixing --- tests/cpp/test_common.cu | 27 +++++++++++++++------------ tests/cpp/test_common.h | 7 +++++++ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index ed7d77699..1a27752ce 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -181,12 +181,21 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; - size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); - size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise); +#ifdef __HIP_PLATFORM_AMD__ + // NVFP4 requires [128,4] padding on AMD regardless of MXFP8 alignment constants + constexpr size_t nvfp4_align_Y = 128; + constexpr size_t nvfp4_align_X = 4; +#else + constexpr size_t nvfp4_align_Y = scale_tensor_alignment_Y_rowwise; + constexpr size_t nvfp4_align_X = scale_tensor_alignment_X_rowwise; +#endif + + size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, nvfp4_align_Y); + size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), nvfp4_align_X); ret_rowwise.shape = {scale_dim_Y, scale_dim_X}; - size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise); - size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise); + size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, nvfp4_align_Y); + size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), nvfp4_align_X); ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t}; ret_rowwise.type = DType::kFloat8E4M3; @@ -1223,14 +1232,8 @@ std::array get_scale_tensor_dims(const size_t rows, // NVFP4 scales 
(block_size=16) still require [128,4] padding for kernel indexing. #ifdef __HIP_PLATFORM_AMD__ const bool needs_padding = (block_size_cols == 16 || block_size_rows == 16); - const size_t alignment_Y = needs_padding - ? (is_rowwise ? scale_tensor_alignment_Y_rowwise - : scale_tensor_alignment_Y_colwise) - : 1; - const size_t alignment_X = needs_padding - ? (is_rowwise ? scale_tensor_alignment_X_rowwise - : scale_tensor_alignment_X_colwise) - : 1; + const size_t alignment_Y = needs_padding ? (is_rowwise ? 128 : 4) : 1; + const size_t alignment_X = needs_padding ? (is_rowwise ? 4 : 128) : 1; #else const size_t alignment_Y = is_rowwise ? scale_tensor_alignment_Y_rowwise diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index e02dacd61..ec686fc17 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -345,10 +345,17 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement +#ifdef __HIP_PLATFORM_AMD__ +constexpr size_t scale_tensor_alignment_X_rowwise = 1; +constexpr size_t scale_tensor_alignment_Y_rowwise = 1; +constexpr size_t scale_tensor_alignment_X_colwise = 1; +constexpr size_t scale_tensor_alignment_Y_colwise = 1; +#else constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4; constexpr size_t scale_tensor_alignment_X_colwise = 128; +#endif inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; From 95d0c9fd42987f87302d506443c5a3a3c5d6418d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Mar 2026 16:10:53 -0500 Subject: [PATCH 32/41] address review comments --- tests/cpp/operator/CMakeLists.txt | 3 +-- transformer_engine/common/CMakeLists.txt | 7 +++---- transformer_engine/common/cast/dispatch/quantize.cuh | 6 ++++-- .../transpose/quantize_transpose_vector_blockwise_fp4.cu | 6 +++++- 4 
files changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 64857af1a..dfd8fba29 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -14,6 +14,7 @@ list(APPEND test_cuda_sources test_qdq.cu test_cast_mxfp8.cu test_dequantize_mxfp8.cu + test_cast_nvfp4_transpose.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu @@ -32,12 +33,10 @@ list(APPEND test_cuda_sources ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources - test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_swizzle.cu) else() list(APPEND test_cuda_sources - test_cast_nvfp4_transpose.cu test_cublaslt_gemm.cu) endif() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 08a7a592c..8d5537368 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -222,7 +222,8 @@ list(APPEND transformer_engine_cuda_arch_specific_sources cast/cast.cu activation/gelu.cu activation/relu.cu - activation/swiglu.cu) + activation/swiglu.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu) if(USE_CUDA) #NV specific source codes @@ -246,7 +247,6 @@ if(USE_CUDA) list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise_fp4.cu hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu) else() @@ -259,8 +259,7 @@ else() list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp - fused_attn_rocm/utils.cpp - transpose/quantize_transpose_vector_blockwise_fp4.cu) + fused_attn_rocm/utils.cpp) endif() diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index ef9927ec4..3e986466b 100644 --- 
a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -120,7 +120,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, #endif auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax : output_tensor->columnwise_amax; - +#ifdef __HIP_PLATFORM_AMD__ // If amax was not explicitly set, fall back to the scale field which // holds the same value when set via set_scale(). NVTE_CHECK(global_amax.dptr != nullptr || output_tensor->scale.dptr != nullptr, @@ -128,9 +128,11 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, "or scale to be set. Call output.set_scale(amax_value) before quantizing."); const SimpleTensor& effective_amax = (global_amax.dptr != nullptr) ? global_amax : output_tensor->scale; - quantize_transpose_vector_blockwise_fp4( /*input=*/input_tensor->data, /*global_amax=*/effective_amax, +#else + /*input=*/input_tensor->data, /*global_amax=*/global_amax, +#endif /*scale_inv=*/output_tensor->scale_inv, /*scale_inv_t=*/output_tensor->columnwise_scale_inv, /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 05b1df680..0d58e3c03 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -213,7 +213,8 @@ __device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scal } __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__) + // On AMD host, TypeExtrema::max is non-constexpr (runtime FNUZ detection) const float fp8_max = TypeExtrema::max; #else constexpr float fp8_max = 
TypeExtrema::max; @@ -299,6 +300,9 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); } else { +#else + NVTE_DEVICE_ERROR( + "cvt_fp32_to_fp4_4x_with_stochastic_rounding is not support on AMDGPU."); #endif NVTE_DEVICE_ERROR( "FP4 cvt.rs PTX instructions are architecture-specific. " From 6cd60387bacce1452f9cbb475d5a1b8d4e55c37c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Mar 2026 16:17:09 -0500 Subject: [PATCH 33/41] adjust error message slightly --- .../transpose/quantize_transpose_vector_blockwise_fp4.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 0d58e3c03..bbf8a6d41 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -300,13 +300,13 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); } else { + NVTE_DEVICE_ERROR( + "FP4 cvt.rs PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); #else NVTE_DEVICE_ERROR( - "cvt_fp32_to_fp4_4x_with_stochastic_rounding is not support on AMDGPU."); + "cvt_fp32_to_fp4_4x_with_stochastic_rounding is not supported on AMDGPU."); #endif - NVTE_DEVICE_ERROR( - "FP4 cvt.rs PTX instructions are architecture-specific. 
" - "Try recompiling with sm_XXXa instead of sm_XXX."); uint16_t dummy = 0; return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); #ifndef __HIP_PLATFORM_AMD__ From 55a8c849ee2c598487389d6a680c11754e18dc59 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Mar 2026 17:07:27 -0500 Subject: [PATCH 34/41] simplify via hipify map --- build_tools/hipify/custom_map.json | 9 ++++++++- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 2 -- tests/cpp/test_common.h | 6 ------ transformer_engine/common/cast/nvfp4/core_nvfp4.cuh | 4 ---- .../common/cast/nvfp4/quantize_nvfp4.cuh | 4 ---- .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 4 ---- transformer_engine/common/common.h | 7 ------- .../quantize_transpose_vector_blockwise_fp4.cu | 13 ++----------- transformer_engine/common/util/ptx.cuh | 6 ------ 9 files changed, 10 insertions(+), 45 deletions(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 63b516906..335cfca2b 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -2,13 +2,20 @@ "custom_map" : { "" : "", "" : "\"common/amd_detail/hip_float8.h\"", + "" : "", "cuda_runtime.h\"" : "hip_runtime.h\"", "ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h", "CUfunc_cache" : "hipFuncCache_t", "" : "", "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)", "__nv_bfloat162":"__hip_bfloat162", - "cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA" + "cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA", + "__nv_fp4_e2m1" : "__hip_fp4_e2m1", + "__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1", + "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", + "__nv_fp4x2_storage_t" : "__hip_fp4x2_storage_t", + "" : "", + "" : "" } } diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 9bbad4b9f..79e11b723 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ 
b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -8,9 +8,7 @@ #include #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include #include diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 6fd2aa717..0885855a1 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -74,15 +74,9 @@ using fp8e5m2 = te_hip_fp8_e5m2; #endif //USE_ROCM using fp8e8m0 = uint8_t; #if FP4_TYPE_SUPPORTED -#ifndef USE_ROCM using fp4e2m1 = __nv_fp4_e2m1; using fp4e2m1x2 = __nv_fp4x2_e2m1; using fp4e2m1x4 = __nv_fp4x4_e2m1; -#else -using fp4e2m1 = __hip_fp4_e2m1; -using fp4e2m1x2 = __hip_fp4x2_e2m1; -using fp4e2m1x4 = __hip_fp4x4_e2m1; -#endif //USE_ROCM #endif template diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index 074f2ba51..bdbe5cddc 100644 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -1,6 +1,4 @@ /************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -14,9 +12,7 @@ #define TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include #include diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh index a1fb977b1..b4bccf239 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -1,6 +1,4 @@ /************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * * See LICENSE for license information. @@ -14,9 +12,7 @@ #define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include #include diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index f4a8f2b48..455074e32 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1,6 +1,4 @@ /************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -14,9 +12,7 @@ #define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include #include diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 0c1c6ef9c..10a65bb1d 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -24,11 +24,7 @@ #include #include #if FP4_TYPE_SUPPORTED -#ifndef __HIP_PLATFORM_AMD__ #include -#else -#include -#endif #endif #include @@ -656,9 +652,6 @@ struct TypeInfo { } // Add a pack_size argument to select the packed type for FP4 -#ifdef __HIP_PLATFORM_AMD__ -#define __nv_fp4x2_storage_t __hip_fp4x2_storage_t -#endif #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) 
\ switch (dtype) { \ using namespace transformer_engine; \ diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index bbf8a6d41..4adf5af50 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -7,17 +7,13 @@ ************************************************************************/ #include -#ifndef __HIP_PLATFORM_AMD__ #include #include -#endif #include #include #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include #include "common/common.h" @@ -29,11 +25,6 @@ namespace transformer_engine { -#ifdef __HIP_PLATFORM_AMD__ -#define __nv_fp4x4_e2m1 __hip_fp4x4_e2m1 -#define __nv_fp4x2_storage_t __hip_fp4x2_storage_t -#endif - #if defined(__HIP_PLATFORM_AMD__) || CUDA_VERSION >= 12080 namespace quantize_transpose_nvfp4 { namespace { @@ -301,8 +292,8 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); } else { NVTE_DEVICE_ERROR( - "FP4 cvt.rs PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); + "FP4 cvt.rs PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); #else NVTE_DEVICE_ERROR( "cvt_fp32_to_fp4_4x_with_stochastic_rounding is not supported on AMDGPU."); diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 6c9d5da23..3d1258ee0 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -460,15 +460,9 @@ static_assert(sizeof(fp8e4m3x2) == 2); static_assert(sizeof(fp8e5m2x2) == 2); #if FP4_TYPE_SUPPORTED -#ifndef __HIP_PLATFORM_AMD__ using fp4e2m1 = __nv_fp4_e2m1; using fp4e2m1x2 = __nv_fp4x2_e2m1; using fp4e2m1x4 = __nv_fp4x4_e2m1; -#else -using fp4e2m1 = __hip_fp4_e2m1; -using fp4e2m1x2 = __hip_fp4x2_e2m1; -using fp4e2m1x4 = __hip_fp4x4_e2m1; -#endif static_assert(sizeof(fp4e2m1x2) == 1); static_assert(sizeof(fp4e2m1x4) == 2); From 10d88bfa10b7128487079e8fac02db429e6da057 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Mar 2026 17:21:39 -0500 Subject: [PATCH 35/41] adjust more error messages --- transformer_engine/common/util/ptx.cuh | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 3d1258ee0..fb4ccdd73 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -517,12 +517,13 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_roun : "=h"(out_4x) : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); } else { -#endif NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. 
" "Try recompiling with sm_XXXa instead of sm_XXX."); -#ifndef __HIP_PLATFORM_AMD__ } +#else + NVTE_DEVICE_ERROR( + "mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding is not supported on AMDGPU."); #endif return *reinterpret_cast(&out_4x); } @@ -569,12 +570,13 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64 : "=r"(out_4x) : "l"(in_4x), "l"(reinterpret_cast(scale))); } else { -#endif NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); -#ifndef __HIP_PLATFORM_AMD__ } +#else + NVTE_DEVICE_ERROR( + "mul_cvt_bf16_to_fp4_4x_with_rn is not supported on AMDGPU."); #endif return reinterpret_cast(&out_4x)[0]; } @@ -619,12 +621,13 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_roun "l"(reinterpret_cast(in23)), "l"(reinterpret_cast(scale)), "r"(rbits)); } else { -#endif NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); -#ifndef __HIP_PLATFORM_AMD__ } +#else + NVTE_DEVICE_ERROR( + "mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding is not supported on AMDGPU."); #endif return *reinterpret_cast(&out_4x); } @@ -667,12 +670,13 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 "l"(reinterpret_cast(in23)), "l"(reinterpret_cast(scale))); } else { -#endif NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. 
" "Try recompiling with sm_XXXa instead of sm_XXX."); -#ifndef __HIP_PLATFORM_AMD__ } +#else + NVTE_DEVICE_ERROR( + "mul_cvt_fp32_to_fp4_4x_with_rn is not supported on AMDGPU."); #endif return reinterpret_cast(&out_4x)[0]; } From b4caf6f7e89f0129befe2dd394123bd6a15bb688 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 Mar 2026 12:56:17 -0500 Subject: [PATCH 36/41] change disabling of header includes --- build_tools/hipify/custom_map.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 335cfca2b..872d38efa 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -14,8 +14,8 @@ "__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1", "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", "__nv_fp4x2_storage_t" : "__hip_fp4x2_storage_t", - "" : "", - "" : "" + "#include " : "", + "#include " : "" } } From 511db6171b356817a3a855728645860bb8480ea7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 Mar 2026 13:23:41 -0500 Subject: [PATCH 37/41] address review comments --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 10 ++++++++++ tests/cpp/test_common.cu | 4 ++-- tests/cpp/test_common.h | 7 +++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 79e11b723..6674240e5 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -32,10 +32,12 @@ enum ActivationType { SReLU }; +#ifdef __HIP_PLATFORM_AMD__ static constexpr float E2M1_LUT[16] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, }; +#endif double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { #ifdef __HIP_PLATFORM_AMD__ @@ -582,7 +584,12 @@ void performTest(float (*OP)(const float), // Set 2nd stage NVFP4 scaling factor output.set_scale(amax); +#ifndef 
__HIP_PLATFORM_AMD__ bool use_2d_quantization = false; +#else + // Test both 1D and 2D quantization paths on AMDGPU + for (bool use_2d_quantization : {false, true}) { +#endif compute_ref(OP, input.rowwise_cpu_dptr(), @@ -665,6 +672,9 @@ void performTest(float (*OP)(const float), mismatches_scales_indices, #endif scale_mismatches_num); +#ifdef __HIP_PLATFORM_AMD__ + } +#endif } std::vector> tensor_dims = { diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index f32526be3..f2f169826 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1232,8 +1232,8 @@ std::array get_scale_tensor_dims(const size_t rows, // NVFP4 scales (block_size=16) still require [128,4] padding for kernel indexing. #ifdef __HIP_PLATFORM_AMD__ const bool needs_padding = (block_size_cols == 16 || block_size_rows == 16); - const size_t alignment_Y = needs_padding ? (is_rowwise ? 128 : 4) : 1; - const size_t alignment_X = needs_padding ? (is_rowwise ? 4 : 128) : 1; + const size_t alignment_Y = needs_padding ? (is_rowwise ? nvfp4_scale_tensor_alignment_Y_rowwise : nvfp4_scale_tensor_alignment_Y_colwise) : 1; + const size_t alignment_X = needs_padding ? (is_rowwise ? nvfp4_scale_tensor_alignment_X_rowwise : nvfp4_scale_tensor_alignment_X_colwise) : 1; #else const size_t alignment_Y = is_rowwise ? 
scale_tensor_alignment_Y_rowwise diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 0885855a1..97670e384 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -340,10 +340,17 @@ constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement #ifdef __HIP_PLATFORM_AMD__ +// For mxfp8: constexpr size_t scale_tensor_alignment_X_rowwise = 1; constexpr size_t scale_tensor_alignment_Y_rowwise = 1; constexpr size_t scale_tensor_alignment_X_colwise = 1; constexpr size_t scale_tensor_alignment_Y_colwise = 1; + +// For nvfp4: +constexpr size_t nvfp4_scale_tensor_alignment_Y_rowwise = 128; +constexpr size_t nvfp4_scale_tensor_alignment_X_rowwise = 4; +constexpr size_t nvfp4_scale_tensor_alignment_Y_colwise = 4; +constexpr size_t nvfp4_scale_tensor_alignment_X_colwise = 128; #else constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_rowwise = 4; From 36cf73a67ddaa650466eb33ea99b965ac48e6049 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 Mar 2026 15:54:23 -0500 Subject: [PATCH 38/41] implement SR --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 13 +++++ ...quantize_transpose_vector_blockwise_fp4.cu | 56 +++++++++++++++++-- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 6674240e5..5a0691a0c 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -586,8 +586,14 @@ void performTest(float (*OP)(const float), #ifndef __HIP_PLATFORM_AMD__ bool use_2d_quantization = false; + for (bool use_stochastic_rounding : {false}) { #else // Test both 1D and 2D quantization paths on AMDGPU + hipDeviceProp_t prop; + hipGetDeviceProperties(&prop, 0); + const bool is_gfx950 = std::string(prop.gcnArchName).find("gfx950") != std::string::npos; + for (bool use_stochastic_rounding : 
(is_gfx950 ? std::vector{false, true} + : std::vector{false})) { for (bool use_2d_quantization : {false, true}) { #endif @@ -611,7 +617,11 @@ void performTest(float (*OP)(const float), rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed rng_state.rowwise_cpu_dptr()[1] = 321; // rng_sequence rng_state.from_cpu(); +#ifdef __HIP_PLATFORM_AMD__ + quant_config.set_stochastic_rounding(use_stochastic_rounding); +#else quant_config.set_stochastic_rounding(false); +#endif quant_config.set_rng_state(rng_state.data()); // Set 2D quantization based on compile-time flag @@ -674,6 +684,9 @@ void performTest(float (*OP)(const float), scale_mismatches_num); #ifdef __HIP_PLATFORM_AMD__ } + } +#else + } #endif } diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 4adf5af50..5508813d6 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -26,6 +26,13 @@ namespace transformer_engine { #if defined(__HIP_PLATFORM_AMD__) || CUDA_VERSION >= 12080 + +#ifdef __HIP_PLATFORM_AMD__ +using fp4x4_storage_t = __hip_fp4x4_storage_t; +#else +using fp4x4_storage_t = __nv_fp4x4_e2m1; +#endif + namespace quantize_transpose_nvfp4 { namespace { @@ -277,7 +284,11 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5; } +#ifdef __HIP_PLATFORM_AMD__ +__device__ __forceinline__ fp4x4_storage_t cvt_fp32_to_fp4_4x_with_stochastic_rounding( +#else __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( +#endif const float2 in01, const float2 in23, const uint32_t rbits) { #ifndef __HIP_PLATFORM_AMD__ constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; @@ -295,17 +306,36 @@ __device__ __forceinline__ 
__nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro "FP4 cvt.rs PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); #else - NVTE_DEVICE_ERROR( - "cvt_fp32_to_fp4_4x_with_stochastic_rounding is not supported on AMDGPU."); -#endif +#ifdef __gfx950__ + // opsel=1 always writes to byte 1, result read from fp4x2[1] + union { uint32_t ui32; __hip_fp4x2_storage_t fp4x2[4]; } u{0}; + __amd_floatx2_storage_t packed01{in01.x, in01.y}; + __amd_floatx2_storage_t packed23{in23.x, in23.y}; + u.ui32 = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(u.ui32, packed01, rbits, 1.0f, 1); + const __hip_fp4x2_storage_t lo = u.fp4x2[1]; + u.ui32 = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(u.ui32, packed23, rbits, 1.0f, 1); + const __hip_fp4x2_storage_t hi = u.fp4x2[1]; + return static_cast(lo | (static_cast(hi) << 8)); +#else + NVTE_DEVICE_ERROR("FP4 stochastic rounding on AMDGPU requires gfx950 or later."); +#endif // __gfx950__ +#endif // !__HIP_PLATFORM_AMD__ uint16_t dummy = 0; +#ifdef __HIP_PLATFORM_AMD__ + return *reinterpret_cast(&dummy); +#else return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); +#endif #ifndef __HIP_PLATFORM_AMD__ } #endif } +#ifdef __HIP_PLATFORM_AMD__ +__device__ __forceinline__ fp4x4_storage_t cvt_fp32_to_fp4_4x_with_rn(const float2 in01, +#else __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, +#endif const float2 in23, const uint32_t rbits) { #ifdef __HIP_PLATFORM_AMD__ @@ -314,13 +344,11 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const floa const __hip_fp4_storage_t q2 = __hip_cvt_float_to_fp4(in23.x, __HIP_E2M1, hipRoundNearest); const __hip_fp4_storage_t q3 = __hip_cvt_float_to_fp4(in23.y, __HIP_E2M1, hipRoundNearest); - uint16_t packed = static_cast( + return static_cast( (q0 & 0xFu) | ((q1 & 0xFu) << 4) | ((q2 & 0xFu) << 8) | ((q3 & 0xFu) << 12)); - - return *reinterpret_cast<__hip_fp4x4_e2m1*>(&packed); #else constexpr bool has_fp4 = 
ARCH_BLACKWELL_FAMILY; if constexpr (has_fp4) { @@ -348,7 +376,11 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const floa } template +#ifdef __HIP_PLATFORM_AMD__ +__device__ __forceinline__ fp4x4_storage_t cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, +#else __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, +#endif const uint32_t rbits) { if constexpr (kApplyStochasticRounding) { return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits); @@ -583,10 +615,16 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; // Convert to __nv_fp4x4_e2m1 +#ifdef __HIP_PLATFORM_AMD__ + fp4x4_storage_t out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + output_vec.data.elt[i] = static_cast<__hip_fp4x2_storage_t>(out_4x & 0xFF); + output_vec.data.elt[i + 1] = static_cast<__hip_fp4x2_storage_t>((out_4x >> 8) & 0xFF); +#else __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; +#endif } // Step 2.7: Store output_c if constexpr (kAligned) { @@ -711,10 +749,16 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const uint32_t rbits = kApplyStochasticRounding ? 
get_rbits(rng, random_uint4, rnd_idx) : 0; // Convert to __nv_fp4x4_e2m1 +#ifdef __HIP_PLATFORM_AMD__ + fp4x4_storage_t out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + output_vec.data.elt[i] = static_cast<__hip_fp4x2_storage_t>(out_4x & 0xFF); + output_vec.data.elt[i + 1] = static_cast<__hip_fp4x2_storage_t>((out_4x >> 8) & 0xFF); +#else __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; +#endif } // Step 3.7: Store output_t if constexpr (kAligned) { From a85f68f42a77bbc8b01c9a0ea37a70e0cf10c9b1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 Mar 2026 16:35:29 -0500 Subject: [PATCH 39/41] simplify slightly --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 5a0691a0c..de4216bc5 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -586,9 +586,9 @@ void performTest(float (*OP)(const float), #ifndef __HIP_PLATFORM_AMD__ bool use_2d_quantization = false; - for (bool use_stochastic_rounding : {false}) { #else - // Test both 1D and 2D quantization paths on AMDGPU + // Test both 1D and 2D quantization paths on AMDGPU, + // as well as stochastic rounding. 
hipDeviceProp_t prop; hipGetDeviceProperties(&prop, 0); const bool is_gfx950 = std::string(prop.gcnArchName).find("gfx950") != std::string::npos; @@ -685,8 +685,6 @@ void performTest(float (*OP)(const float), #ifdef __HIP_PLATFORM_AMD__ } } -#else - } #endif } From a607feb45c39632d9b79aafefd478839a3b56df5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 19 Mar 2026 14:23:27 -0500 Subject: [PATCH 40/41] address review comments --- tests/cpp/test_common.cu | 5 +++-- tests/cpp/test_common.h | 15 +++++++-------- .../common/cast/dispatch/quantize.cuh | 17 ++++++++++++++--- .../quantize_transpose_vector_blockwise_fp4.cu | 4 ++-- transformer_engine/common/util/ptx.cuh | 9 +++++++++ 5 files changed, 35 insertions(+), 15 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index f2f169826..7e72b2fc7 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1228,9 +1228,10 @@ std::array get_scale_tensor_dims(const size_t rows, const bool is_rowwise = (block_size_rows == 1) && ((block_size_cols == 32) || (block_size_cols == 16)); - // On AMD, MXFP8 scales (block_size=32) are passed unpadded to hipBLASlt. - // NVFP4 scales (block_size=16) still require [128,4] padding for kernel indexing. #ifdef __HIP_PLATFORM_AMD__ + // On AMD, MXFP8 scales (block_size=32) are allocated unpadded to match + // TE's internal allocation (which avoids padding for hipBLASlt compatibility). + // NVFP4 scales (block_size=16) still require [128,4] padding for kernel indexing. const bool needs_padding = (block_size_cols == 16 || block_size_rows == 16); const size_t alignment_Y = needs_padding ? (is_rowwise ? nvfp4_scale_tensor_alignment_Y_rowwise : nvfp4_scale_tensor_alignment_Y_colwise) : 1; const size_t alignment_X = needs_padding ? (is_rowwise ? 
nvfp4_scale_tensor_alignment_X_rowwise : nvfp4_scale_tensor_alignment_X_colwise) : 1; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 97670e384..2afbc60c9 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -12,20 +12,19 @@ #include #include #include - +#include #ifndef USE_ROCM #define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) +#else +#define FP4_TYPE_SUPPORTED (true) +#endif + #include +#include #include #if FP4_TYPE_SUPPORTED #include -#endif //FP4_TYPE_SUPPORTED -#else -#define FP4_TYPE_SUPPORTED (true) -#include -#include "amd_detail/hip_float8.h" -#include -#endif //USE_ROCM +#endif #include #include diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 3e986466b..6e9f6811d 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -121,8 +121,8 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax : output_tensor->columnwise_amax; #ifdef __HIP_PLATFORM_AMD__ - // If amax was not explicitly set, fall back to the scale field which - // holds the same value when set via set_scale(). + // Fix for upstream bug: if amax was not explicitly set, fall back to the + // scale field which holds the same value when set via set_scale(). NVTE_CHECK(global_amax.dptr != nullptr || output_tensor->scale.dptr != nullptr, "NVFP4 quantization requires global_amax (output_tensor->amax) " "or scale to be set. Call output.set_scale(amax_value) before quantizing."); @@ -283,8 +283,19 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens #endif auto &global_amax = (output_tensor->amax.dptr != nullptr) ? 
output_tensor->amax : output_tensor->columnwise_amax; +#ifdef __HIP_PLATFORM_AMD__ + // Fix for upstream bug: if amax was not explicitly set, fall back to the + // scale field which holds the same value when set via set_scale(). + NVTE_CHECK(global_amax.dptr != nullptr || output_tensor->scale.dptr != nullptr, + "NVFP4 quantization requires global_amax (output_tensor->amax) " + "or scale to be set. Call output.set_scale(amax_value) before quantizing."); + const SimpleTensor& effective_amax = + (global_amax.dptr != nullptr) ? global_amax : output_tensor->scale; quantize_transpose_vector_blockwise_fp4( - /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + /*input=*/input_tensor->data, /*global_amax=*/effective_amax, +#else + /*input=*/input_tensor->data, /*global_amax=*/global_amax, +#endif /*scale_inv=*/output_tensor->scale_inv, /*scale_inv_t=*/output_tensor->columnwise_scale_inv, /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 5508813d6..66afdffab 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -306,7 +306,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro "FP4 cvt.rs PTX instructions are architecture-specific. 
" "Try recompiling with sm_XXXa instead of sm_XXX."); #else -#ifdef __gfx950__ +#ifdef ARCH_HAS_STOCHASTIC_ROUNDING // opsel=1 always writes to byte 1, result read from fp4x2[1] union { uint32_t ui32; __hip_fp4x2_storage_t fp4x2[4]; } u{0}; __amd_floatx2_storage_t packed01{in01.x, in01.y}; @@ -318,7 +318,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro return static_cast(lo | (static_cast(hi) << 8)); #else NVTE_DEVICE_ERROR("FP4 stochastic rounding on AMDGPU requires gfx950 or later."); -#endif // __gfx950__ +#endif // ARCH_HAS_STOCHASTIC_ROUNDING #endif // !__HIP_PLATFORM_AMD__ uint16_t dummy = 0; #ifdef __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index fb4ccdd73..f8e426ab9 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -126,6 +126,15 @@ constexpr bool is_supported_arch() { #define ARCH_HAS_STOCHASTIC_ROUNDING \ NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>) +#else + +// Native FP4 stochastic rounding is available on gfx950 and later. 
+#if defined(__gfx950__) +#define ARCH_HAS_STOCHASTIC_ROUNDING (true) +#else +#define ARCH_HAS_STOCHASTIC_ROUNDING (false) +#endif + #endif //#ifndef __HIP_PLATFORM_AMD__ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init From ca2e444da1d92c827245bbf215f08733b2d3eadc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 19 Mar 2026 15:55:05 -0500 Subject: [PATCH 41/41] bugfix arch SR support --- .../common/transpose/quantize_transpose_vector_blockwise_fp4.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 66afdffab..a923919fc 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -306,7 +306,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro "FP4 cvt.rs PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); #else -#ifdef ARCH_HAS_STOCHASTIC_ROUNDING +#if ARCH_HAS_STOCHASTIC_ROUNDING // opsel=1 always writes to byte 1, result read from fp4x2[1] union { uint32_t ui32; __hip_fp4x2_storage_t fp4x2[4]; } u{0}; __amd_floatx2_storage_t packed01{in01.x, in01.y};