From ed2d2b32b1e8b02ef6d50068a01423ea0824f42b Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Tue, 3 Feb 2026 13:44:43 -0500 Subject: [PATCH 1/4] Make distinctive ROCm TE wheels names --- build_tools/build_ext.py | 8 +- build_tools/utils.py | 20 +++- .../wheel_utils/Dockerfile.rocm.manylinux.x86 | 4 +- build_tools/wheel_utils/build_wheels.sh | 106 ++++++++++-------- setup.py | 19 +++- transformer_engine/__init__.py | 4 +- transformer_engine/common/CMakeLists.txt | 4 +- transformer_engine/common/__init__.py | 16 +-- transformer_engine/jax/setup.py | 4 +- transformer_engine/pytorch/setup.py | 9 +- 10 files changed, 123 insertions(+), 71 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 8bcfc5a69..5d96a9287 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -25,7 +25,6 @@ cmake_bin, debug_build_enabled, found_ninja, - get_frameworks, nvcc_path, get_max_jobs_for_parallel_build, ) @@ -158,8 +157,11 @@ def run(self) -> None: def build_extensions(self): # For core lib + JAX install, fix build_ext from pybind11.setup_helpers # to handle CUDA files correctly. + # Upstream uses get_frameworks() here which is incorrectly works when install from + # release (sdist) wheel on a system with both frameworks installed. ext_names = [ext.name for ext in self.extensions] - if "transformer_engine_pytorch" not in ext_names: + if ("transformer_engine_torch" not in ext_names and + "transformer_engine_rocm_torch" not in ext_names): # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when # extra_compile_args is a dict. for ext in self.extensions: diff --git a/build_tools/utils.py b/build_tools/utils.py index e3c5b6be8..2f9b2e031 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -208,6 +208,7 @@ def rocm_build() -> bool: # If neither ROCm nor CUDA is detected, raise an error raise FileNotFoundError("Could not detect ROCm or CUDA platform") + @functools.lru_cache(maxsize=None) def rocm_path() -> Tuple[str, str]: """ROCm root path and HIPCC binary path as a tuple""" @@ -227,6 +228,18 @@ def rocm_path() -> Tuple[str, str]: return rocm_home, hipcc_bin +def rocm_version() -> Tuple[int, ...]: + """ROCm version as a (major, minor) tuple. + Try to get ROCm version by parsing .info/version. + """ + rocm_home, _ = rocm_path() + try: + with open(rocm_home / ".info" / "version", "r") as f: + rocm_version= f.read().strip().split('.')[:2] + return tuple(int(v) for v in rocm_version) + except FileNotFoundError: + raise RuntimeError("Could not determine ROCm version.") + def cuda_toolkit_include_path() -> Tuple[str, str]: """Returns root path for cuda toolkit includes. @@ -495,10 +508,13 @@ def uninstall_te_wheel_packages(): "pip", "uninstall", "-y", - "transformer_engine_rocm", # te_cuda_vers for ROCm build + "transformer_engine", "transformer_engine_cu12", "transformer_engine_torch", "transformer_engine_jax", + "transformer_engine_rocm", + "transformer_engine_rocm_jax", + "transformer_engine_rocm_torch", ] ) diff --git a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 index dc0cd112b..318a0696f 100644 --- a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 +++ b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. @@ -45,4 +45,4 @@ COPY build_wheels.sh / WORKDIR /TransformerEngine/ RUN git clone https://github.com/ROCm/TransformerEngine.git /TransformerEngine -CMD ["/bin/bash", "/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"] +CMD ["/bin/bash", "/build_wheels.sh", "manylinux_2_28_x86_64", "false", "true", "true", "true"] diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 4a6653479..eacb197d9 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -14,11 +14,13 @@ BUILD_JAX=${5:-true} export NVTE_RELEASE_BUILD=1 export TARGET_BRANCH=${TARGET_BRANCH:-} -mkdir -p /wheelhouse/logs + +WHEEL_ROOT=${WHEEL_ROOT:-/wheelhouse} +mkdir -p $WHEEL_ROOT/logs # Generate wheels for common library. -git config --global --add safe.directory /TransformerEngine -cd /TransformerEngine +TE_ROOT=${TE_ROOT:-/TransformerEngine} +cd $TE_ROOT #If there is default Python installation, use it PYTHON=`which python || true` @@ -29,90 +31,102 @@ else fi ROCM_BUILD=`${PYBINDIR}python -c "import build_tools.utils as u; print(int(u.rocm_build()))"` - -if [ "$LOCAL_TREE_BUILD" != "1" ]; then - if [ "$ROCM_BUILD" = "1" ]; then - git pull - fi - git checkout $TARGET_BRANCH - git submodule update --init --recursive +if [ "$ROCM_BUILD" = "1" ]; then + ROCM_BUILD=true else - git submodule status --recursive | cut -d' ' -f3 | xargs -l -P1 -I_SUB_ git config --global --add safe.directory /TransformerEngine/_SUB_ + ROCM_BUILD=false fi -if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install setuptools wheel +if [ "$LOCAL_TREE_BUILD" != "1" ]; then + git config --global --add safe.directory $TE_ROOT + if [ "$SKIP_REPO_UPDATE" = "1" ]; then + git submodule status --recursive | cut -d' ' -f3 | xargs -l -P1 -I_SUB_ git config --global --add safe.directory $TE_ROOT/_SUB_ + else + if [ $ROCM_BUILD ]; then + git pull + fi + git checkout $TARGET_BRANCH + git submodule update --init --recursive + fi fi # Install deps -if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install pybind11[global] ninja +if [ $ROCM_BUILD ]; then + ${PYBINDIR}pip install setuptools wheel pybind11[global] ninja else ${PYBINDIR}pip install cmake pybind11[global] ninja fi if $BUILD_METAPACKAGE ; then - cd /TransformerEngine - if [ "$ROCM_BUILD" != "1" ]; then + cd $TE_ROOT + if [ ! $ROCM_BUILD ]; then PYBINDIR=/opt/python/cp310-cp310/bin/ fi - NVTE_BUILD_METAPACKAGE=1 ${PYBINDIR}python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt - mv dist/* /wheelhouse/ + NVTE_BUILD_METAPACKAGE=1 ${PYBINDIR}python setup.py bdist_wheel 2>&1 | tee $WHEEL_ROOT/logs/metapackage.txt + mv dist/* $WHEEL_ROOT/ fi -if $BUILD_COMMON ; then +if $BUILD_COMMON -a $ROCM_BUILD; then + VERSION=`cat build_tools/VERSION.txt` + WHL_BASE="transformer_engine_rocm-${VERSION}" + #dataclasses, psutil are needed for AITER + ${PYBINDIR}pip install dataclasses psutil + #hipify expects python in PATH, also ninja may be installed to python bindir + test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true + + # Create the wheel. + ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee $WHEEL_ROOT/logs/common.txt + + # Rename the wheel to make it python version agnostic. + whl_name=$(basename dist/*) + IFS='-' read -ra whl_parts <<< "$whl_name" + whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}" + mv dist/*.whl $WHEEL_ROOT/"$whl_name_target" + +elif $BUILD_COMMON; then VERSION=`cat build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" - if [ "$ROCM_BUILD" = "1" ]; then - TE_CUDA_VERS="rocm" - #dataclasses, psutil are needed for AITER - ${PYBINDIR}pip install dataclasses psutil - #hipify expects python in PATH, also ninja may be installed to python bindir - test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true - else - TE_CUDA_VERS="cu12" - PYBINDIR=/opt/python/cp38-cp38/bin/ - fi + PYBINDIR=/opt/python/cp38-cp38/bin/ # Create the wheel. - ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt + ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee $WHEEL_ROOT/logs/common.txt # Repack the wheel for cuda specific package, i.e. cu12. ${PYBINDIR}wheel unpack dist/* # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). - sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info" + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" ${PYBINDIR}wheel pack ${WHL_BASE} # Rename the wheel to make it python version agnostic. whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" - whl_name_target="${whl_parts[0]}_${TE_CUDA_VERS}-${whl_parts[1]}-py3-none-${whl_parts[4]}" + whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}" rm -rf $WHL_BASE dist - mv *.whl /wheelhouse/"$whl_name_target" + mv *.whl $WHEEL_ROOT/"$whl_name_target" fi if $BUILD_PYTORCH ; then - cd /TransformerEngine/transformer_engine/pytorch - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 + cd $TE_ROOT/transformer_engine/pytorch + if [ $ROCM_BUILD ]; then + ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/cpu else PYBINDIR=/opt/python/cp38-cp38/bin/ ${PYBINDIR}pip install torch fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt - cp dist/* /wheelhouse/ + ${PYBINDIR}python setup.py sdist 2>&1 | tee $WHEEL_ROOT/logs/torch.txt + cp dist/* $WHEEL_ROOT/ fi if $BUILD_JAX ; then - cd /TransformerEngine/transformer_engine/jax - if [ "$ROCM_BUILD" = "1" ]; then + cd $TE_ROOT/transformer_engine/jax + if [ $ROCM_BUILD ]; then ${PYBINDIR}pip install jax else PYBINDIR=/opt/python/cp310-cp310/bin/ ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt - cp dist/* /wheelhouse/ + ${PYBINDIR}python setup.py sdist 2>&1 | tee $WHEEL_ROOT/logs/jax.txt + cp dist/* $WHEEL_ROOT/ fi diff --git a/setup.py b/setup.py index 1ae476311..bdde17b38 100644 --- a/setup.py +++ b/setup.py @@ -182,12 +182,11 @@ def setup_requirements() -> Tuple[List[str], List[str]]: assert bool( int(os.getenv("NVTE_RELEASE_BUILD", "0")) ), "NVTE_RELEASE_BUILD env must be set for metapackage build." - te_cuda_vers = "rocm" if rocm_build() else "cu12" ext_modules = [] cmdclass = {} package_data = {} include_package_data = False - install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],) + install_requires = ([f"transformer_engine_cu12=={__version__}"],) extras_require = { "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], @@ -222,9 +221,21 @@ def setup_requirements() -> Tuple[List[str], List[str]]: ) ) + PACKAGE_NAME="transformer_engine" + if rocm_build(): + if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): + install_requires = ([f"transformer_engine_rocm=={__version__}"],) + else: + PACKAGE_NAME="transformer_engine_rocm" + #On ROCm make add extra to core package too so it can be installed w/o metapackage + extras_require = { + "pytorch": [f"transformer_engine_rocm_torch=={__version__}"], + "jax": [f"transformer_engine_rocm_jax=={__version__}"], + } # Configure package setuptools.setup( - name="transformer_engine", + name=PACKAGE_NAME, version=__version__, packages=setuptools.find_packages( include=[ @@ -239,7 +250,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]: long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, - python_requires=">=3.8", + python_requires=">=3.9", classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, license_files=("LICENSE",), diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index 050abc8f7..9a3758a1e 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -83,4 +83,4 @@ category=RuntimeWarning, ) -__version__ = str(metadata.version("transformer_engine")) +__version__ = str(metadata.version("transformer_engine_rocm" if transformer_engine.common.te_rocm_build else "transformer_engine")) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cefec6d06..312df074d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -473,9 +473,11 @@ if (USE_ROCM) file(READ "${ROCM_PATH}/.info/version" ROCM_VER) string(STRIP "${ROCM_VER}" ROCM_VER) string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER}") + get_git_commit("${TE}" TE_COMMIT_ID) file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" "ROCM_VERSION: ${ROCM_VER}\n" "GPU_TARGETS: ${CMAKE_HIP_ARCHITECTURES}\n" + "COMMIT_ID: ${TE_COMMIT_ID}\n" ) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" DESTINATION "transformer_engine/") endif() diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 26672bafd..02497fcb5 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -133,7 +133,7 @@ def load_framework_extension(framework: str) -> None: if framework == "torch": extra_dep_name = "pytorch" - te_cuda_vers = "rocm" if te_rocm_build else "cu12" + te_core_tag = "rocm" if te_rocm_build else "cu12" # If the framework extension pip package is installed, it means that TE is installed via # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework @@ -143,24 +143,24 @@ def load_framework_extension(framework: str) -> None: "transformer_engine" ), "Could not find `transformer-engine`." assert _is_pip_package_installed( - f"transformer_engine_{te_cuda_vers}" - ), f"Could not find `transformer-engine-{te_cuda_vers}`." + f"transformer_engine_{te_core_tag}" + ), f"Could not find `transformer-engine-{te_core_tag}`." assert ( version(module_name) == version("transformer-engine") - == version(f"transformer-engine-{te_cuda_vers}") + == version(f"transformer-engine-{te_core_tag}") ), ( "TransformerEngine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-{te_cuda_vers}" - f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using " + f" v{version('transformer-engine')}, and transformer-engine-{te_core_tag}" + f" v{version(f'transformer-engine-{te_core_tag}')}. Install transformer-engine using " f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" ) # If the core package is installed via PyPI, log if # the framework extension is not found from PyPI. # Note: Should we error? This is a rare use case. - if _is_pip_package_installed(f"transformer-engine-{te_cuda_vers}"): + if _is_pip_package_installed(f"transformer-engine-{te_core_tag}"): if not _is_pip_package_installed(module_name): _logger.info( "Could not find package %s. Install transformer-engine using " diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index b58d2df7f..28a0ec029 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -99,7 +99,7 @@ # Configure package setuptools.setup( - name="transformer_engine_jax", + name="transformer_engine_rocm_jax" if rocm_build() else "transformer_engine_jax", version=te_version(), description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index e86873b12..da6906476 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -56,6 +56,9 @@ test_requirements, ) +if rocm_build(): + PACKAGE_NAME = "transformer_engine_rocm_torch" + os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension, True) @@ -112,6 +115,10 @@ class CachedWheelsCommand(_bdist_wheel): """ def run(self): + if rocm_build(): + print("ROCm build detected, building from source...") + return super().run() + if FORCE_BUILD: super().run() From dc8efdbc1721aef6c04ca7475a7dccfa0ac71c38 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Wed, 4 Feb 2026 00:04:26 -0500 Subject: [PATCH 2/4] Fix dev (monolithic) installation --- build_tools/wheel_utils/build_wheels.sh | 2 +- transformer_engine/__init__.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index eacb197d9..076273c0b 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -39,7 +39,7 @@ fi if [ "$LOCAL_TREE_BUILD" != "1" ]; then git config --global --add safe.directory $TE_ROOT - if [ "$SKIP_REPO_UPDATE" = "1" ]; then + if [ "$NO_REPO_UPDATE" = "1" ]; then git submodule status --recursive | cut -d' ' -f3 | xargs -l -P1 -I_SUB_ git config --global --add safe.directory $TE_ROOT/_SUB_ else if [ $ROCM_BUILD ]; then diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index 9a3758a1e..da8c33749 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -83,4 +83,9 @@ category=RuntimeWarning, ) -__version__ = str(metadata.version("transformer_engine_rocm" if transformer_engine.common.te_rocm_build else "transformer_engine")) +try: + __version__ = str(metadata.version("transformer_engine")) +except metadata.PackageNotFoundError: + if not transformer_engine.common.te_rocm_build: + raise + __version__ = str(metadata.version("transformer_engine_rocm")) From cebfd811a31c7621022d0899d168f0c7d6d10ba2 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Fri, 13 Mar 2026 23:31:47 -0400 Subject: [PATCH 3/4] Add ROCm major version to wheel package. Documentation update --- README.rst | 67 ++------------ build_tools/utils.py | 2 +- docs/installation.rst | 126 +++++++++++++++++++++++++- setup.py | 10 +- transformer_engine/__init__.py | 4 +- transformer_engine/common/__init__.py | 2 +- transformer_engine/jax/setup.py | 4 +- transformer_engine/pytorch/setup.py | 4 +- 8 files changed, 148 insertions(+), 71 deletions(-) diff --git a/README.rst b/README.rst index da0b86c27..40074e352 100644 --- a/README.rst +++ b/README.rst @@ -1,7 +1,7 @@ .. This file was modified to include portability information to AMDGPU. - Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. @@ -28,70 +28,21 @@ Feature Support Status Installation ============ -Install from manylinux wheels -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Starting from ROCm 7.0, we provide manylinux wheels for Transformer Engine releases on `https://repo.radeon.com/rocm/manylinux`. For example, the wheels for ROCm 7.1.1 are at `https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/`. From the page, you can find four files related to Transformer Engine: - -* transformer_engine_rocm-*-py3-none-manylinux_2_28_x86_64.whl - This is the wheel file for installing the common library. It should not be installed by itself. -* transformer_engine-*-py3-none-any.whl - This is the wheel file for installing the common TE Python package. -* transformer_engine_jax-*.tar.gz - This is the source tar ball for the JAX extension. -* transformer_engine_torch-*.tar.gz - This is the source tar ball for the Pytorch extension. - -Below are the example commands to download and install the wheels. They install both Pytorch and JAX extensions on the system where both frameworks are installed. - -.. code-block:: bash - - wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_rocm-2.2.0-py3-none-manylinux_2_28_x86_64.whl - wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine-2.2.0-py3-none-any.whl - wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_jax-2.2.0.tar.gz - wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_torch-2.2.0.tar.gz - - pip install ./transformer_engine* --no-build-isolation - -Install TE from source -^^^^^^^^^^^^^^^^^^ - -Execute the following commands to install ROCm Transformer Engine from source on AMDGPUs: - -.. code-block:: bash +See docs/installation.rst for detailed installation instructions on ROCm and AMDGPU. +For addtional build configuration parameters see `Fused Attention Backends on ROCm` section below. - # Clone TE repo and submodules - git clone --recursive https://github.com/ROCm/TransformerEngine.git - - cd TransformerEngine - export NVTE_FRAMEWORK=pytorch,jax #optionally set framework, currently only support pytorch and jax; if not set will try to detect installed frameworks - export NVTE_ROCM_ARCH="gfx942;gfx950" # gfx942 for support of MI300/MI325, and gfx950 for support of MI350 - - # Build Platform Selection (optional) - # Note: Useful when both ROCm and CUDA platforms are present in the Docker - export NVTE_USE_ROCM=1 #Use 1 for ROCm, or set to 0 to use CUDA; If not set will try to detect installed platform, prioritizing ROCm - - pip install . --no-build-isolation - -It is also possible to build wheels for later installation with "pip wheel ." although those wheels will not be portable to systems with -different libraries installed. If the build still fails with the "--no-build-isolation" flag try installing setuptools<80.0.0 - -Note on Switching between Installation from Source and Installation from Wheels -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Sometimes, issues might occur when installing from source on a system where a previous installation with wheels, or vice versa. It is safe to uninstall TE first before -switching between installing from source and installing from wheels. Here is the example command: - -.. code-block:: bash +AITER rebuilding +^^^^^^^^^^^^^^^^ - # The package name pattern might be transformer_engine or transformer-engine depending on setuptools version - pip list | grep transformer.engine | xargs pip uninstall -y +TE uses AITER submodule as fused attention backend on ROCm. Rebuilding of this library takes a long time so build scripts cache the built library in `build/aiter-prebuilts`. If you want rebuild AITER, delete the cache and rebuild TE. Known Issue with ROCm 6.4 PyTorch Release -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Using the docker image ``rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.5.1`` triggers a failure in the unit-test ``tests/pytorch/test_permutation.py`` (tracked in Jira ticket SWDEV-534311). Rebuilding PyTorch at commit ``f929e0d602a71aa393ca2e6097674b210bdf321c`` resolves the issue. -Re-install PyTorch -^^^^^^^^^^^^^^^^^^ - .. code-block:: bash # Remove the pre-installed pytorch @@ -299,7 +250,7 @@ Certain settings can be enabled to potentially optimize workloads depending on t * NVTE_CK_ZERO_OUT_PAD - by default 1, if set to 0 then the output of the FA forward pass will not be initialized to zero, meaning invalid regions (representing padding) may take nonzero values. Only used if input has padding. AITER FA v3 Kernels -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~ ROCm TE supports flash-attention v3 fwd/bwd kernels on gfx942 and gfx950 using AITER backend. This functionality can be controlled by the following environment variables: @@ -309,7 +260,7 @@ This functionality can be controlled by the following environment variables: * NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when v3 is enabled, 0:RTNE; 1:RTNA; 2:RTZ, only applicable to the gfx942 architecture. Float to BFloat16 Conversion in CK Backend (gfx942 only) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ How fp32 converts to bf16 affects both the performance and accuracy in ck fused attn. ROCm TE provides the compile-time env NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT with the following values available to choose from: diff --git a/build_tools/utils.py b/build_tools/utils.py index edb6c2f6d..4e4492fc9 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -504,7 +504,7 @@ def uninstall_te_wheel_packages(): "transformer_engine_cu12", "transformer_engine_torch", "transformer_engine_jax", - "transformer_engine_rocm", + "transformer_engine_rocm7", "transformer_engine_rocm_jax", "transformer_engine_rocm_torch", ] diff --git a/docs/installation.rst b/docs/installation.rst index a8bb74fd1..df9a00474 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,10 +1,132 @@ .. + This file was modified to include portability information to AMDGPU. + + Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. -Installation -============ +Installation on AMD GPUs +======================== + +Prerequisites +------------- +1. `AMD Instinct GPU `__. Other GPUs are not supported while they can still work. +2. Linux x86_64 +3. `ROCm stack `__ + 3.1. For ROCm TheRock (ROCm 7.11 and newer), install amdrocm-core-sdk* package + +Additional Prerequisites +^^^^^^^^^^^^^^^^^^^^^^^^ + +1. [For PyTorch support] `https://rocm.docs.amd.com/projects/install-on-linux/en/develop/install/3rd-party/pytorch-install.html`__ +2. [For JAX support] `https://rocm.docs.amd.com/projects/install-on-linux/en/develop/install/3rd-party/jax-install.html`__ + +if HIP compiler complains it cannot detect the platform set `HIP_PLATFORM=amd` in the environment. +if ROCm is installed in a non-standard location, set `ROCM_PATH` to the root of the ROCm installation in the environment, e.g. `ROCM_PATH=/opt/venv/lib/python3.12/site-packages/_rocm_sdk_devel` and additonally set the following environment variables: + +- `HIP_DEVICE_LIB_PATH=$ROCM_PATH/llvm/amdgcn/bitcode/` +- `CMAKE_PREFIX_PATH=$ROCM_PATH/lib/cmake/`` + +pip - from wheels +----------------- + +Transformer Engine for ROCm 7.0 and newer can be installed from manylinux wheels published at `https://repo.radeon.com/rocm/manylinux`. For example, the wheels for ROCm 7.2 are at `https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/`. Four files related to Transformer Engine can be found there: + +- transformer_engine-\*-py3-none-any.whl - the main TE pure Python metapackage. +- transformer_engine_rocm-\*-py3-none-manylinux_2_28_x86_64.whl - the core library package. +- transformer_engine_jax-\*.tar.gz - source tarball (sdist) for the JAX extension. +- transformer_engine_torch-\*.tar.gz - source tarball (sdist) for the Pytorch extension. + +Below are the example commands to download and install the wheels. They install both Pytorch and JAX extensions on the system where both frameworks are installed. + +.. code-block:: bash + + wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/transformer_engine_rocm-2.4.0-py3-none-manylinux_2_28_x86_64.whl + wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/transformer_engine-2.4.0-py3-none-any.whl + wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/transformer_engine_jax-2.4.0.tar.gz + wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/transformer_engine_torch-2.4.0.tar.gz + + pip install ./transformer_engine* --no-build-isolation + +Starting from version 2.10 core library wheel can be installed by itself or as an extenstion for TE Python metapackage. + +Example of installing ROCm core library wheel without any framework extensions: + +.. code-block:: bash + + pip install transformer_engine_rocm7-2.10.0-py3-none-manylinux_2_28_x86_64.whl + +Additionaly install framework extensions using ROCm package name and pip extras syntax. + +.. code-block:: bash + + pip install --find-links transformer_engine_rocm7[pytorch,jax] --no-build-isolation + +Installing the common library and frameworks extensions as extras for TE Python metapackage + +.. code-block:: bash + + pip install ./transformer_engine-2.4.0-py3-none-any.whl[rocm7,pytorch,jax] --no-build-isolation + +It is not recommended to install TE Python metapackage using just package name transformer_engine because of possible installing of the NVIDIA GPU version. It is recommended to use either transformer_engine_rocm7 or wheel file name to make sure the correct common library is installed. + + +Installation from source +^^^^^^^^^^^^^^^^^^^^^^^^^^ +Execute the following commands to install Transformer Engine from source: + +.. code-block:: bash + + # Clone repository, checkout stable branch, clone submodules + git clone --recursive https://github.com/ROCm/TransformerEngine.git + + cd TransformerEngine + export NVTE_FRAMEWORK=pytorch,jax # Optionally set framework(s) + export NVTE_ROCM_ARCH="gfx942;gfx950" # Optionally set target GPU achs; gfx942 for MI300/MI325, and gfx950 for MI350 + export NVTE_USE_ROCM=1 # Optionally force building for ROCm, useful when both ROCm and CUDA build environments are installed. If set to 0, it will force building for CUDA. + pip3 install --no-build-isolation . # Build and install + +Or instead of immediate istall, create wheel file for later installation: + +.. code-block:: bash + + pip wheel . --no-build-isolation + pip3 install ./transformer_engine-*.whl + +If the Git repository has already been cloned, make sure to also clone the submodules: + +.. code-block:: bash + + git submodule update --init --recursive + +Extra dependencies for testing can be installed by setting the "test" option: + +.. code-block:: bash + + pip3 install --no-build-isolation .[test] + +To build the C++ extensions with debug symbols, e.g. with the `-g` flag: + +.. code-block:: bash + + NVTE_BUILD_DEBUG=1 pip3 install --no-build-isolation . + + +Switching between Installation from Source and Installation from Wheels +----------------------------------------------------------------------- +Sometimes, issues might occur when installing from source on a system where a previous installation with wheels, or vice versa. It is safe to uninstall TE first before +switching between installing from source and installing from wheels. Here is the example command: + +.. code-block:: bash + + # The package name pattern might be transformer_engine or transformer-engine depending on setuptools version + pip list | grep transformer.engine | xargs pip uninstall -y + + +Installation on NVIDIA GPUs +=========================== Prerequisites ------------- diff --git a/setup.py b/setup.py index aa6653db5..8b5826f38 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ from build_tools.te_version import te_version from build_tools.utils import ( rocm_build, + rocm_version, all_files_in_dir, cuda_archs, cuda_version, @@ -250,9 +251,10 @@ def git_check_submodules() -> None: "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } if not rocm_build() else { - "rocm": [f"transformer_engine_rocm=={__version__}"], - "pytorch": [f"transformer_engine_rocm[pytorch]=={__version__}"], - "jax": [f"transformer_engine_rocm[jax]=={__version__}"], + "rocm": [f"transformer_engine_rocm7=={__version__}"], + "rocm7": [f"transformer_engine_rocm7=={__version__}"], + "pytorch": [f"transformer_engine_rocm7[pytorch]=={__version__}"], + "jax": [f"transformer_engine_rocm7[jax]=={__version__}"], } else: install_requires, test_requires = setup_requirements() @@ -290,7 +292,7 @@ def git_check_submodules() -> None: PACKAGE_NAME="transformer_engine" if (rocm_build() and bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) and not bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))) ): - PACKAGE_NAME="transformer_engine_rocm" + PACKAGE_NAME=f"transformer_engine_rocm{rocm_version()[0]}" #On ROCm add extras to core package so it can be installed w/o metapackage extras_require.update({ "pytorch": [f"transformer_engine_rocm_torch=={__version__}"], diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index da8c33749..f1476ce27 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -88,4 +88,6 @@ except metadata.PackageNotFoundError: if not transformer_engine.common.te_rocm_build: raise - __version__ = str(metadata.version("transformer_engine_rocm")) + _te_core_installed, _, __version__ = transformer_engine.common.get_te_core_package_info() + if not _te_core_installed: + raise diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 63c524404..cc9e04ca5 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -141,7 +141,7 @@ def get_te_core_package_info() -> Tuple[bool, str, str]: te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13") if te_rocm_build: - te_core_packages = ("transformer-engine-rocm",) + te_core_packages = ("transformer-engine-rocm7",) for package in te_core_packages: if _is_package_installed(package): return True, package, version(package) diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index ad7619db6..e9f04658a 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -47,7 +47,7 @@ from build_tools.build_ext import get_build_ext, SdistWithLocalVersion from build_tools.utils import copy_common_headers, min_python_version_str -from build_tools.utils import rocm_build +from build_tools.utils import rocm_build, rocm_version from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements @@ -129,7 +129,7 @@ def get_cuda_major_version() -> int: if not rocm_build(): te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}" else: - te_core = f"transformer_engine_rocm=={__version__}" + te_core = f"transformer_engine_rocm{rocm_version()[0]}=={__version__}" install_requires = install_requirements() + [te_core] # Configure package diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 5ef1b9d1c..185dd41ce 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -47,7 +47,7 @@ from build_tools.build_ext import get_build_ext, SdistWithLocalVersion -from build_tools.utils import rocm_build +from build_tools.utils import rocm_build, rocm_version from build_tools.utils import copy_common_headers, min_python_version_str from build_tools.te_version import te_version from build_tools.pytorch import ( @@ -169,7 +169,7 @@ def run(self): te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}" install_requires = install_requirements() + [te_core] else: - te_core = f"transformer_engine_rocm=={__version__}" + te_core = f"transformer_engine_rocm{rocm_version()[0]}=={__version__}" install_requires = install_requirements() + [te_core] # Configure package From 6b6a664848292cdae443b9e3d0978cb73ac1d236 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Tue, 17 Mar 2026 23:27:00 -0400 Subject: [PATCH 4/4] Make extension extras uniq. Build metapackage by default --- build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 | 2 +- docs/installation.rst | 2 +- setup.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 index 318a0696f..6b908f9bc 100644 --- a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 +++ b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 @@ -45,4 +45,4 @@ COPY build_wheels.sh / WORKDIR /TransformerEngine/ RUN git clone https://github.com/ROCm/TransformerEngine.git /TransformerEngine -CMD ["/bin/bash", "/build_wheels.sh", "manylinux_2_28_x86_64", "false", "true", "true", "true"] +CMD ["/bin/bash", "/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"] diff --git a/docs/installation.rst b/docs/installation.rst index c3bcd9558..b7b0388cc 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -68,7 +68,7 @@ Installing the common library and frameworks extensions as extras for TE Python .. code-block:: bash - pip install ./transformer_engine-2.4.0-py3-none-any.whl[rocm7,pytorch,jax] --no-build-isolation + pip install ./transformer_engine-2.4.0-py3-none-any.whl[rocm7,rocm_pytorch,rocm_jax] --no-build-isolation It is not recommended to install TE Python metapackage using just package name transformer_engine because of possible installing of the NVIDIA GPU version. It is recommended to use either transformer_engine_rocm7 or wheel file name to make sure the correct common library is installed. diff --git a/setup.py b/setup.py index 8980b2106..d201641f3 100644 --- a/setup.py +++ b/setup.py @@ -253,8 +253,8 @@ def git_check_submodules() -> None: } if not rocm_build() else { "rocm": [f"transformer_engine_rocm7=={__version__}"], "rocm7": [f"transformer_engine_rocm7=={__version__}"], - "pytorch": [f"transformer_engine_rocm7[pytorch]=={__version__}"], - "jax": [f"transformer_engine_rocm7[jax]=={__version__}"], + "rocm_pytorch": [f"transformer_engine_rocm7[pytorch]=={__version__}"], + "rocm_jax": [f"transformer_engine_rocm7[jax]=={__version__}"], } else: install_requires, test_requires = setup_requirements()