diff --git a/README.rst b/README.rst index 9f038887d..2a6d88dd9 100644 --- a/README.rst +++ b/README.rst @@ -1,8 +1,6 @@ .. This file was modified to include portability information to AMDGPU. - Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. - Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. @@ -28,70 +26,21 @@ Feature Support Status Installation ============ -Install from manylinux wheels -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Starting from ROCm 7.0, we provide manylinux wheels for Transformer Engine releases on `https://repo.radeon.com/rocm/manylinux`. For example, the wheels for ROCm 7.1.1 are at `https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/`. From the page, you can find four files related to Transformer Engine: - -* transformer_engine_rocm-*-py3-none-manylinux_2_28_x86_64.whl - This is the wheel file for installing the common library. It should not be installed by itself. -* transformer_engine-*-py3-none-any.whl - This is the wheel file for installing the common TE Python package. -* transformer_engine_jax-*.tar.gz - This is the source tar ball for the JAX extension. -* transformer_engine_torch-*.tar.gz - This is the source tar ball for the Pytorch extension. - -Below are the example commands to download and install the wheels. They install both Pytorch and JAX extensions on the system where both frameworks are installed. - -.. 
code-block:: bash - - wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_rocm-2.2.0-py3-none-manylinux_2_28_x86_64.whl - wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine-2.2.0-py3-none-any.whl - wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_jax-2.2.0.tar.gz - wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_torch-2.2.0.tar.gz - - pip install ./transformer_engine* --no-build-isolation - -Install TE from source -^^^^^^^^^^^^^^^^^^ - -Execute the following commands to install ROCm Transformer Engine from source on AMDGPUs: - -.. code-block:: bash - - # Clone TE repo and submodules - git clone --recursive https://github.com/ROCm/TransformerEngine.git - - cd TransformerEngine - export NVTE_FRAMEWORK=pytorch,jax #optionally set framework, currently only support pytorch and jax; if not set will try to detect installed frameworks - export NVTE_ROCM_ARCH="gfx942;gfx950" # gfx942 for support of MI300/MI325, and gfx950 for support of MI350 - - # Build Platform Selection (optional) - # Note: Useful when both ROCm and CUDA platforms are present in the Docker - export NVTE_USE_ROCM=1 #Use 1 for ROCm, or set to 0 to use CUDA; If not set will try to detect installed platform, prioritizing ROCm - - pip install . --no-build-isolation - -It is also possible to build wheels for later installation with "pip wheel ." although those wheels will not be portable to systems with -different libraries installed. If the build still fails with the "--no-build-isolation" flag try installing setuptools<80.0.0 - -Note on Switching between Installation from Source and Installation from Wheels -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Sometimes, issues might occur when installing from source on a system where a previous installation with wheels, or vice versa. It is safe to uninstall TE first before -switching between installing from source and installing from wheels. 
Here is the example command: +See docs/installation.rst for detailed installation instructions on ROCm and AMDGPU. +For additional build configuration parameters, see the `Fused Attention Backends on ROCm` section below. -.. code-block:: bash +AITER rebuilding +^^^^^^^^^^^^^^^^ - # The package name pattern might be transformer_engine or transformer-engine depending on setuptools version - pip list | grep transformer.engine | xargs pip uninstall -y +TE uses the AITER submodule as a fused attention backend on ROCm. Rebuilding this library takes a long time, so the build scripts cache the built library in `build/aiter-prebuilts`. To rebuild AITER, delete the cache and rebuild TE. Known Issue with ROCm 6.4 PyTorch Release -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Using the docker image ``rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.5.1`` triggers a failure in the unit-test ``tests/pytorch/test_permutation.py`` (tracked in Jira ticket SWDEV-534311). Rebuilding PyTorch at commit ``f929e0d602a71aa393ca2e6097674b210bdf321c`` resolves the issue. -Re-install PyTorch -^^^^^^^^^^^^^^^^^^ - .. code-block:: bash # Remove the pre-installed pytorch @@ -299,7 +248,7 @@ Certain settings can be enabled to potentially optimize workloads depending on t * NVTE_CK_ZERO_OUT_PAD - by default 1, if set to 0 then the output of the FA forward pass will not be initialized to zero, meaning invalid regions (representing padding) may take nonzero values. Only used if input has padding. AITER FA v3 Kernels -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~ ROCm TE supports flash-attention v3 fwd/bwd kernels on gfx942 and gfx950 using AITER backend.
This functionality can be controlled by the following environment variables: @@ -309,7 +258,7 @@ This functionality can be controlled by the following environment variables: * NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when v3 is enabled, 0:RTNE; 1:RTNA; 2:RTZ, only applicable to the gfx942 architecture. Float to BFloat16 Conversion in CK Backend (gfx942 only) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ How fp32 converts to bf16 affects both the performance and accuracy in ck fused attn. ROCm TE provides the compile-time env NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT with the following values available to choose from: diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index ce4655cbb..6a2f35aaa 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -27,7 +27,6 @@ cmake_bin, debug_build_enabled, found_ninja, - get_frameworks, nvcc_path, get_max_jobs_for_parallel_build, ) @@ -161,8 +160,11 @@ def run(self) -> None: def build_extensions(self): # For core lib + JAX install, fix build_ext from pybind11.setup_helpers # to handle CUDA files correctly. + # Upstream uses get_frameworks() here, which works incorrectly when installing from a + # release (sdist) wheel on a system with both frameworks installed. ext_names = [ext.name for ext in self.extensions] - if "transformer_engine_pytorch" not in ext_names: + if ("transformer_engine_torch" not in ext_names and + "transformer_engine_rocm_torch" not in ext_names): # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when # extra_compile_args is a dict.
for ext in self.extensions: diff --git a/build_tools/utils.py b/build_tools/utils.py index 6f1622d69..2cb7a3768 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -227,6 +227,7 @@ def rocm_build() -> bool: # If neither ROCm nor CUDA is detected, raise an error raise FileNotFoundError("Could not detect ROCm or CUDA platform") + @functools.lru_cache(maxsize=None) def rocm_path() -> Tuple[str, str]: """ROCm root path and HIPCC binary path as a tuple""" @@ -246,6 +247,18 @@ def rocm_path() -> Tuple[str, str]: return rocm_home, hipcc_bin +def rocm_version() -> Tuple[int, ...]: + """ROCm version as a (major, minor) tuple. + Determined by parsing .info/version. + """ + rocm_home, _ = rocm_path() + try: + with open(os.path.join(rocm_home, ".info", "version"), "r") as f: + version_parts = f.read().strip().split(".")[:2] + return tuple(int(v) for v in version_parts) + except FileNotFoundError: + raise RuntimeError("Could not determine ROCm version.") + def cuda_toolkit_include_path() -> Tuple[str, str]: """Returns root path for cuda toolkit includes. @@ -487,9 +500,12 @@ def uninstall_te_wheel_packages(): "pip", "uninstall", "-y", - "transformer_engine_rocm", # te_cuda_vers for ROCm build + "transformer_engine", "transformer_engine_cu12", "transformer_engine_torch", "transformer_engine_jax", + "transformer_engine_rocm7", + "transformer_engine_rocm_jax", + "transformer_engine_rocm_torch", ] ) diff --git a/docs/installation.rst b/docs/installation.rst index dd3064b62..15f482cda 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,10 +1,121 @@ .. + This file was modified to include portability information to AMDGPU. + Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. -Installation -============ +Installation on AMD GPUs +======================== + +Prerequisites +------------- +1. `AMD Instinct GPU `__.
Other GPUs may still work but are not supported. +2. Linux x86_64 +3. `ROCm stack `__. For ROCm TheRock (ROCm 7.11 and newer), install the amdrocm-core-sdk* package + +Additional Prerequisites +^^^^^^^^^^^^^^^^^^^^^^^^ + +1. [For PyTorch support] `PyTorch `__ +2. [For JAX support] `JAX `__ + +If the HIP compiler complains that it cannot detect the platform, set `HIP_PLATFORM=amd` in the environment. +If ROCm is installed in a non-standard location, set `ROCM_PATH` to the root of the ROCm installation in the environment, e.g. `ROCM_PATH=/opt/venv/lib/python3.12/site-packages/_rocm_sdk_devel`, and additionally set the following environment variables: + +- HIP_DEVICE_LIB_PATH=$ROCM_PATH/llvm/amdgcn/bitcode/ +- CMAKE_PREFIX_PATH=$ROCM_PATH/lib/cmake/ + +pip - from wheels +----------------- + +Transformer Engine for ROCm 7.0 and newer can be installed from `Manylinux wheels `__. Four files related to Transformer Engine can be found there: + +- transformer_engine-\*-py3-none-any.whl - the main TE pure Python metapackage. +- transformer_engine_rocm-\*-py3-none-manylinux_2_28_x86_64.whl - the core library package. +- transformer_engine_jax-\*.tar.gz - source tarball (sdist) for the JAX extension. +- transformer_engine_torch-\*.tar.gz - source tarball (sdist) for the PyTorch extension. + +Below are example commands to download and install the wheels published with ROCm 7.2. They install both the PyTorch and JAX extensions on a system where both frameworks are installed. + +.. code-block:: bash + + wget -r -l1 -nd -A 'transformer_engine*' https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/ + pip install ./transformer_engine* --no-build-isolation + +Starting from version 2.10, the core library wheel can be installed by itself or as an extension of the TE Python metapackage. + +Example of installing the ROCm core library wheel without any framework extensions: + ..
code-block:: bash + + pip install transformer_engine_rocm7-2.10.0-py3-none-manylinux_2_28_x86_64.whl + +Framework extensions can additionally be installed using the ROCm package name and the pip extras syntax. + +.. code-block:: bash + + pip install --find-links . "transformer_engine_rocm7[pytorch,jax]" --no-build-isolation + +Install the common library and framework extensions as extras of the TE Python metapackage: + +.. code-block:: bash + + pip install "./transformer_engine-2.10.0-py3-none-any.whl[rocm7,rocm_pytorch,rocm_jax]" --no-build-isolation + +Installing the TE Python metapackage by the bare package name transformer_engine is not recommended, because it may pull in the NVIDIA GPU version. Use either transformer_engine_rocm7 or the wheel file name to make sure the correct common library is installed. + + +Installation from source +^^^^^^^^^^^^^^^^^^^^^^^^ +Execute the following commands to install Transformer Engine from source: + +.. code-block:: bash + + # Clone repository, checkout stable branch, clone submodules + git clone --recursive https://github.com/ROCm/TransformerEngine.git + + cd TransformerEngine + export NVTE_FRAMEWORK=pytorch,jax # Optionally set framework(s) + export NVTE_ROCM_ARCH="gfx942;gfx950" # Optionally set target GPU architectures; gfx942 for MI300/MI325, and gfx950 for MI350 + export NVTE_USE_ROCM=1 # Optionally force building for ROCm, useful when both ROCm and CUDA build environments are installed. If set to 0, it will force building for CUDA. + pip3 install --no-build-isolation . # Build and install + +Or, instead of installing immediately, create a wheel file to install later: + +.. code-block:: bash + + pip wheel . --no-build-isolation + pip3 install ./transformer_engine-*.whl + +If the Git repository has already been cloned, make sure the submodules do not have any local changes; otherwise the build will try to reset them unless `NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD=1` is set.
+ +Extra dependencies for testing can be installed by setting the "test" option: + +.. code-block:: bash + + pip3 install --no-build-isolation '.[test]' + +To build the C++ extensions with debug symbols, e.g. with the `-g` flag: + +.. code-block:: bash + + NVTE_BUILD_DEBUG=1 pip3 install --no-build-isolation . + + +Switching between Installation from Source and Installation from Wheels +----------------------------------------------------------------------- +Issues might occur when installing from source on a system with a previous installation from wheels, or vice versa. It is safe to uninstall TE before +switching between installing from source and installing from wheels. Here is an example command: + +.. code-block:: bash + + # The package name pattern might be transformer_engine or transformer-engine depending on the setuptools version + pip list | grep transformer.engine | cut -d' ' -f1 | xargs pip uninstall -y + + +Installation on NVIDIA GPUs +=========================== Prerequisites ------------- diff --git a/setup.py b/setup.py index 80435f9aa..d201641f3 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ from build_tools.te_version import te_version from build_tools.utils import ( rocm_build, + rocm_version, all_files_in_dir, cuda_archs, cuda_version, @@ -250,9 +251,10 @@ def git_check_submodules() -> None: "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } if not rocm_build() else { - "core": [f"transformer_engine_rocm=={__version__}"], - "pytorch": [f"transformer_engine_torch=={__version__}"], - "jax": [f"transformer_engine_jax=={__version__}"], + "rocm": [f"transformer_engine_rocm7=={__version__}"], + "rocm7": [f"transformer_engine_rocm7=={__version__}"], + "rocm_pytorch": [f"transformer_engine_rocm7[pytorch]=={__version__}"], + "rocm_jax": [f"transformer_engine_rocm7[jax]=={__version__}"], } else: install_requires, test_requires = setup_requirements() @@ -288,9 +290,14 @@ def
git_check_submodules() -> None: ) PACKAGE_NAME="transformer_engine" - if rocm_build() and bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): - if not bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): - PACKAGE_NAME="transformer_engine_rocm" + if (rocm_build() and bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) + and not bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))) ): + PACKAGE_NAME=f"transformer_engine_rocm{rocm_version()[0]}" + #On ROCm add extras to core package so it can be installed w/o metapackage + extras_require.update({ + "pytorch": [f"transformer_engine_rocm_torch=={__version__}"], + "jax": [f"transformer_engine_rocm_jax=={__version__}"], + }) # Configure package setuptools.setup( diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index a8c0c26df..71219deb1 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -83,4 +83,11 @@ category=RuntimeWarning, ) -__version__ = str(metadata.version("transformer_engine")) +try: + __version__ = str(metadata.version("transformer_engine")) +except metadata.PackageNotFoundError: + if not transformer_engine.common.te_rocm_build: + raise + _te_core_installed, _, __version__ = transformer_engine.common.get_te_core_package_info() + if not _te_core_installed: + raise diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 075392f14..95719e188 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -141,7 +141,7 @@ def get_te_core_package_info() -> Tuple[bool, str, str]: te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13") if te_rocm_build: - te_core_packages = ("transformer-engine-rocm",) + te_core_packages = ("transformer-engine-rocm7",) for package in te_core_packages: if _is_package_installed(package): return True, package, version(package) @@ -173,6 +173,11 @@ def load_framework_extension(framework: str) -> None: te_installed = 
_is_package_installed("transformer_engine") te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine") + # Meta package is optional for ROCm build. + if te_rocm_build and te_core_installed and not te_installed: + te_installed = True + te_installed_via_pypi = True + assert te_installed, "Could not find `transformer_engine`." # If the framework extension pip package is installed, it means that TE is installed via @@ -180,7 +185,8 @@ def load_framework_extension(framework: str) -> None: # extension are all installed via PyPI and have matching versions. if te_framework_installed: assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package." - assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`." + assert te_core_installed, ( "Could not find TE core package " + f"`transformer-engine-{'rocm' if te_rocm_build else 'cu'}*`." ) assert version(module_name) == version("transformer-engine") == te_core_version, ( "Transformer Engine package version mismatch. 
Found" diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index f18155cb9..47710a1f6 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -47,7 +47,7 @@ from build_tools.build_ext import get_build_ext, SdistWithLocalVersion from build_tools.utils import copy_common_headers, min_python_version_str -from build_tools.utils import rocm_build +from build_tools.utils import rocm_build, rocm_version from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements @@ -129,12 +129,12 @@ def get_cuda_major_version() -> int: if not rocm_build(): te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}" else: - te_core = f"transformer_engine_rocm=={__version__}" + te_core = f"transformer_engine_rocm{rocm_version()[0]}=={__version__}" install_requires = install_requirements() + [te_core] # Configure package setuptools.setup( - name="transformer_engine_jax", + name="transformer_engine_rocm_jax" if rocm_build() else "transformer_engine_jax", version=__version__, description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 4d4fb10cb..a05934412 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -47,7 +47,7 @@ from build_tools.build_ext import get_build_ext, SdistWithLocalVersion -from build_tools.utils import rocm_build +from build_tools.utils import rocm_build, rocm_version from build_tools.utils import copy_common_headers, min_python_version_str from build_tools.te_version import te_version from build_tools.pytorch import ( @@ -58,6 +58,7 @@ if rocm_build(): from build_tools.hipify.hipify import copy_hipify_tools, clear_hipify_tools_copy + PACKAGE_NAME = "transformer_engine_rocm_torch" os.environ["NVTE_PROJECT_BUILDING"] = "1" @@ -168,7 +169,7 @@ def run(self): te_core = 
f"transformer_engine_cu{cuda_major_version}=={__version__}" install_requires = install_requirements() + [te_core] else: - te_core = f"transformer_engine_rocm=={__version__}" + te_core = f"transformer_engine_rocm{rocm_version()[0]}=={__version__}" install_requires = install_requirements() + [te_core] # Configure package
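The `rocm_version()` helper introduced in `build_tools/utils.py` above drives the versioned package naming (`transformer_engine_rocm{major}`) by parsing `.info/version` under the ROCm root. Below is a minimal standalone sketch of just that parsing step, with the filesystem lookup factored out so it can be exercised directly; the `parse_rocm_version` name is hypothetical and not part of this diff:

```python
from typing import Tuple


def parse_rocm_version(text: str) -> Tuple[int, ...]:
    """Parse the contents of ROCm's .info/version file, e.g. "7.2.0-123",
    into a (major, minor) tuple of ints."""
    # Keep only the leading "major.minor" components; the patch level and
    # any build suffix after the second dot are discarded.
    major_minor = text.strip().split(".")[:2]
    return tuple(int(v) for v in major_minor)


print(parse_rocm_version("7.2.0-123\n"))  # -> (7, 2)
```

Keeping the parsing pure makes the error path explicit: the caller remains responsible for opening `$ROCM_PATH/.info/version` and turning a missing file into a `RuntimeError`, as the diff's `rocm_version()` does.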