65 changes: 8 additions & 57 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,70 +28,21 @@ Feature Support Status
Installation
============

Install from manylinux wheels
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Starting from ROCm 7.0, we provide manylinux wheels for Transformer Engine releases on `https://repo.radeon.com/rocm/manylinux`. For example, the wheels for ROCm 7.1.1 are at `https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/`. From the page, you can find four files related to Transformer Engine:

* transformer_engine_rocm-*-py3-none-manylinux_2_28_x86_64.whl - the wheel for the common library; it is not meant to be installed on its own.
* transformer_engine-*-py3-none-any.whl - the wheel for the common TE Python package.
* transformer_engine_jax-*.tar.gz - the source tarball (sdist) for the JAX extension.
* transformer_engine_torch-*.tar.gz - the source tarball (sdist) for the PyTorch extension.

Below are example commands to download and install the wheels. They install both the PyTorch and JAX extensions on a system where both frameworks are present.

.. code-block:: bash

wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_rocm-2.2.0-py3-none-manylinux_2_28_x86_64.whl
wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine-2.2.0-py3-none-any.whl
wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_jax-2.2.0.tar.gz
wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_torch-2.2.0.tar.gz

pip install ./transformer_engine* --no-build-isolation

Install TE from source
^^^^^^^^^^^^^^^^^^^^^^

Execute the following commands to install ROCm Transformer Engine from source on AMD GPUs:

.. code-block:: bash
See docs/installation.rst for detailed installation instructions on ROCm and AMD GPUs.
For additional build configuration parameters, see the `Fused Attention Backends on ROCm` section below.

# Clone TE repo and submodules
git clone --recursive https://github.com/ROCm/TransformerEngine.git

cd TransformerEngine
export NVTE_FRAMEWORK=pytorch,jax # Optionally set framework(s); currently only pytorch and jax are supported. If unset, installed frameworks are auto-detected
export NVTE_ROCM_ARCH="gfx942;gfx950" # gfx942 for support of MI300/MI325, and gfx950 for support of MI350

# Build Platform Selection (optional)
# Note: Useful when both ROCm and CUDA platforms are present in the Docker
export NVTE_USE_ROCM=1 # Set to 1 for ROCm or 0 for CUDA; if unset, the installed platform is auto-detected, prioritizing ROCm

pip install . --no-build-isolation

It is also possible to build wheels for later installation with `pip wheel .`, although those wheels will not be portable to systems with
different libraries installed. If the build still fails with the `--no-build-isolation` flag, try installing setuptools<80.0.0.

Note on Switching between Installation from Source and Installation from Wheels
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Issues can occur when installing from source on a system with a previous installation from wheels, or vice versa. It is safest to uninstall TE
before switching between the two installation methods. Here is an example command:

.. code-block:: bash
AITER rebuilding
^^^^^^^^^^^^^^^^

# The package name pattern might be transformer_engine or transformer-engine depending on setuptools version
pip list | grep transformer.engine | xargs pip uninstall -y
TE uses the AITER submodule as the fused attention backend on ROCm. Rebuilding this library takes a long time, so the build scripts cache the built library in `build/aiter-prebuilts`. To force an AITER rebuild, delete the cache and rebuild TE.
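
A minimal sketch of forcing an AITER rebuild (assuming the current directory is the TransformerEngine checkout; the cache path comes from the note above):

.. code-block:: bash

   # Delete the cached AITER library so the next TE build recompiles it
   rm -rf build/aiter-prebuilts

Then rerun the TE build, e.g. `pip install . --no-build-isolation`.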

Known Issue with ROCm 6.4 PyTorch Release
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Using the docker image ``rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.5.1`` triggers a failure in the unit-test ``tests/pytorch/test_permutation.py`` (tracked in Jira ticket SWDEV-534311).

Rebuilding PyTorch at commit ``f929e0d602a71aa393ca2e6097674b210bdf321c`` resolves the issue.

Re-install PyTorch
^^^^^^^^^^^^^^^^^^

.. code-block:: bash

# Remove the pre-installed pytorch
Expand Down Expand Up @@ -299,7 +250,7 @@ Certain settings can be enabled to potentially optimize workloads depending on t
* NVTE_CK_ZERO_OUT_PAD - by default 1; if set to 0, the output of the FA forward pass is not zero-initialized, so invalid regions (representing padding) may contain nonzero values. Only used if the input has padding.

AITER FA v3 Kernels
~~~~~~~~~~~~~~~~~~~
ROCm TE supports flash-attention v3 forward/backward kernels on gfx942 and gfx950 via the AITER backend.
This functionality can be controlled by the following environment variables:

Expand All @@ -309,7 +260,7 @@ This functionality can be controlled by the following environment variables:
* NVTE_CK_HOW_V3_BF16_CVT - by default 1; float-to-bf16 conversion type when v3 is enabled (0: RTNE, 1: RTNA, 2: RTZ); only applicable to the gfx942 architecture.

Float to BFloat16 Conversion in CK Backend (gfx942 only)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
How fp32 is converted to bf16 affects both performance and accuracy in CK fused attention.
ROCm TE provides the compile-time environment variable NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT with the following values to choose from:

Expand Down
6 changes: 4 additions & 2 deletions build_tools/build_ext.py
Expand Up @@ -27,7 +27,6 @@
cmake_bin,
debug_build_enabled,
found_ninja,
get_frameworks,
nvcc_path,
get_max_jobs_for_parallel_build,
)
Expand Down Expand Up @@ -161,8 +160,11 @@ def run(self) -> None:
def build_extensions(self):
# For core lib + JAX install, fix build_ext from pybind11.setup_helpers
# to handle CUDA files correctly.
# Upstream uses get_frameworks() here, which works incorrectly when installing from a
# release (sdist) wheel on a system with both frameworks installed.
ext_names = [ext.name for ext in self.extensions]
if "transformer_engine_pytorch" not in ext_names:
if ("transformer_engine_torch" not in ext_names and
"transformer_engine_rocm_torch" not in ext_names):
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict.
for ext in self.extensions:
Expand Down
18 changes: 17 additions & 1 deletion build_tools/utils.py
Expand Up @@ -227,6 +227,7 @@ def rocm_build() -> bool:
# If neither ROCm nor CUDA is detected, raise an error
raise FileNotFoundError("Could not detect ROCm or CUDA platform")


@functools.lru_cache(maxsize=None)
def rocm_path() -> Tuple[str, str]:
"""ROCm root path and HIPCC binary path as a tuple"""
Expand All @@ -246,6 +247,18 @@ def rocm_path() -> Tuple[str, str]:
return rocm_home, hipcc_bin


def rocm_version() -> Tuple[int, ...]:
    """ROCm version as a (major, minor) tuple.

    Determined by parsing <rocm_home>/.info/version.
    """
    rocm_home, _ = rocm_path()
    try:
        with open(os.path.join(rocm_home, ".info", "version"), "r") as f:
            version_fields = f.read().strip().split(".")[:2]
        return tuple(int(v) for v in version_fields)
    except FileNotFoundError:
        raise RuntimeError("Could not determine ROCm version.")
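
A standalone sanity check of the version-parsing logic above (the helper name and sample strings are illustrative, not part of the patch):

.. code-block:: python

   # Parse the first two fields of a ROCm ".info/version" string into (major, minor),
   # mirroring the rocm_version() helper added in this file.
   def parse_rocm_version(text: str) -> tuple:
       fields = text.strip().split(".")[:2]
       return tuple(int(v) for v in fields)

   print(parse_rocm_version("7.1.1"))       # -> (7, 1)
   print(parse_rocm_version("6.4.0-66\n"))  # -> (6, 4)
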


def cuda_toolkit_include_path() -> Tuple[str, str]:
"""Returns root path for cuda toolkit includes.
Expand Down Expand Up @@ -487,9 +500,12 @@ def uninstall_te_wheel_packages():
"pip",
"uninstall",
"-y",
"transformer_engine_rocm", # te_cuda_vers for ROCm build
"transformer_engine",
"transformer_engine_cu12",
"transformer_engine_torch",
"transformer_engine_jax",
"transformer_engine_rocm7",
"transformer_engine_rocm_jax",
"transformer_engine_rocm_torch",
]
)
126 changes: 124 additions & 2 deletions docs/installation.rst
@@ -1,10 +1,132 @@
..
This file was modified to include portability information to AMDGPU.

Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.

Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

Installation
============
Installation on AMD GPUs
========================

Prerequisites
-------------
1. `AMD Instinct GPU <https://www.amd.com/en/products/accelerators/instinct.html>`__. Other GPUs may work but are not supported.
2. Linux x86_64
3. `ROCm stack <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/index.html>`__
3.1. For ROCm TheRock (ROCm 7.11 and newer), install the amdrocm-core-sdk* package

Additional Prerequisites
^^^^^^^^^^^^^^^^^^^^^^^^

1. [For PyTorch support] `PyTorch on ROCm installation guide <https://rocm.docs.amd.com/projects/install-on-linux/en/develop/install/3rd-party/pytorch-install.html>`__
2. [For JAX support] `JAX on ROCm installation guide <https://rocm.docs.amd.com/projects/install-on-linux/en/develop/install/3rd-party/jax-install.html>`__

If the HIP compiler complains that it cannot detect the platform, set `HIP_PLATFORM=amd` in the environment.
If ROCm is installed in a non-standard location, set `ROCM_PATH` to the root of the ROCm installation, e.g. `ROCM_PATH=/opt/venv/lib/python3.12/site-packages/_rocm_sdk_devel`, and additionally set the following environment variables:

- `HIP_DEVICE_LIB_PATH=$ROCM_PATH/llvm/amdgcn/bitcode/`
- `CMAKE_PREFIX_PATH=$ROCM_PATH/lib/cmake/`
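
Putting the variables above together for a hypothetical non-standard install location (the path below is the example from the text, not a required location):

.. code-block:: bash

   # Hypothetical non-standard ROCm root (example path from the note above)
   export ROCM_PATH=/opt/venv/lib/python3.12/site-packages/_rocm_sdk_devel
   export HIP_PLATFORM=amd  # only needed if the HIP compiler cannot detect the platform
   export HIP_DEVICE_LIB_PATH=$ROCM_PATH/llvm/amdgcn/bitcode/
   export CMAKE_PREFIX_PATH=$ROCM_PATH/lib/cmake/
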

pip - from wheels
-----------------

Transformer Engine for ROCm 7.0 and newer can be installed from manylinux wheels published at `https://repo.radeon.com/rocm/manylinux`. For example, the wheels for ROCm 7.2 are at `https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/`. Four files related to Transformer Engine can be found there:

- transformer_engine-\*-py3-none-any.whl - the main TE pure Python metapackage.
- transformer_engine_rocm-\*-py3-none-manylinux_2_28_x86_64.whl - the core library package.
- transformer_engine_jax-\*.tar.gz - source tarball (sdist) for the JAX extension.
- transformer_engine_torch-\*.tar.gz - source tarball (sdist) for the PyTorch extension.

Below are example commands to download and install the wheels. They install both the PyTorch and JAX extensions on a system where both frameworks are present.

.. code-block:: bash

wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/transformer_engine_rocm-2.4.0-py3-none-manylinux_2_28_x86_64.whl
wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/transformer_engine-2.4.0-py3-none-any.whl
wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/transformer_engine_jax-2.4.0.tar.gz
wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/transformer_engine_torch-2.4.0.tar.gz

pip install ./transformer_engine* --no-build-isolation

Starting from version 2.10, the core library wheel can be installed by itself or as an extension of the TE Python metapackage.

Example of installing the ROCm core library wheel without any framework extensions:

.. code-block:: bash

pip install transformer_engine_rocm7-2.10.0-py3-none-manylinux_2_28_x86_64.whl

Additionally, install framework extensions using the ROCm package name and pip extras syntax:

.. code-block:: bash

pip install --find-links <url/or/local/directory/with/te/wheels/> transformer_engine_rocm7[pytorch,jax] --no-build-isolation

Installing the common library and framework extensions as extras of the TE Python metapackage:

.. code-block:: bash

pip install ./transformer_engine-2.4.0-py3-none-any.whl[rocm7,rocm_pytorch,rocm_jax] --no-build-isolation

Installing the TE Python metapackage by the bare package name transformer_engine is not recommended, because it may pull in the NVIDIA GPU version. Use either the transformer_engine_rocm7 package name or the wheel file name to make sure the correct common library is installed.
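
A hedged way to verify which TE core package actually got installed (the helper below is illustrative; the real detection logic lives in `transformer_engine.common.get_te_core_package_info`):

.. code-block:: python

   # Report the first installed TE core package, preferring the ROCm name.
   from importlib import metadata

   def installed_te_core(candidates=("transformer-engine-rocm7",
                                     "transformer-engine-cu12",
                                     "transformer-engine-cu13")):
       for name in candidates:
           try:
               return name, metadata.version(name)
           except metadata.PackageNotFoundError:
               continue
       return None, None

   print(installed_te_core())
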


Installation from source
------------------------
Execute the following commands to install Transformer Engine from source:

.. code-block:: bash

# Clone repository, checkout stable branch, clone submodules
git clone --recursive https://github.com/ROCm/TransformerEngine.git

cd TransformerEngine
export NVTE_FRAMEWORK=pytorch,jax # Optionally set framework(s)
export NVTE_ROCM_ARCH="gfx942;gfx950" # Optionally set target GPU archs; gfx942 for MI300/MI325, gfx950 for MI350
export NVTE_USE_ROCM=1 # Optionally force a ROCm build, useful when both ROCm and CUDA build environments are present; set to 0 to force a CUDA build
pip3 install --no-build-isolation . # Build and install

Or, instead of installing immediately, create a wheel file for later installation:

.. code-block:: bash

pip wheel . --no-build-isolation
pip3 install ./transformer_engine-*.whl

If the Git repository has already been cloned, make sure to also clone the submodules:

.. code-block:: bash

git submodule update --init --recursive

Extra dependencies for testing can be installed by setting the "test" option:

.. code-block:: bash

pip3 install --no-build-isolation .[test]

To build the C++ extensions with debug symbols, e.g. with the `-g` flag:

.. code-block:: bash

NVTE_BUILD_DEBUG=1 pip3 install --no-build-isolation .


Switching between Installation from Source and Installation from Wheels
-----------------------------------------------------------------------
Issues can occur when installing from source on a system with a previous installation from wheels, or vice versa. It is safest to uninstall TE
before switching between the two installation methods. Here is an example command:

.. code-block:: bash

# The package name pattern might be transformer_engine or transformer-engine depending on setuptools version
pip list | grep transformer.engine | xargs pip uninstall -y
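
The `transformer.engine` pattern in the grep above catches both spellings because `.` matches any single character, covering both the hyphen and underscore forms. A small illustration:

.. code-block:: python

   import re

   # "." in the pattern matches either "-" or "_" between the two words
   pattern = re.compile(r"transformer.engine")
   names = ["transformer-engine", "transformer_engine", "transformer_engine_rocm7", "torch"]
   print([n for n in names if pattern.search(n)])
   # -> ['transformer-engine', 'transformer_engine', 'transformer_engine_rocm7']
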


Installation on NVIDIA GPUs
===========================

Prerequisites
-------------
Expand Down
19 changes: 13 additions & 6 deletions setup.py
Expand Up @@ -22,6 +22,7 @@
from build_tools.te_version import te_version
from build_tools.utils import (
rocm_build,
rocm_version,
all_files_in_dir,
cuda_archs,
cuda_version,
Expand Down Expand Up @@ -250,9 +251,10 @@ def git_check_submodules() -> None:
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
} if not rocm_build() else {
"core": [f"transformer_engine_rocm=={__version__}"],
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
"rocm": [f"transformer_engine_rocm7=={__version__}"],
"rocm7": [f"transformer_engine_rocm7=={__version__}"],
"rocm_pytorch": [f"transformer_engine_rocm7[pytorch]=={__version__}"],
"rocm_jax": [f"transformer_engine_rocm7[jax]=={__version__}"],
}
else:
install_requires, test_requires = setup_requirements()
Expand Down Expand Up @@ -288,9 +290,14 @@ def git_check_submodules() -> None:
)

PACKAGE_NAME="transformer_engine"
if rocm_build() and bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if not bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
PACKAGE_NAME="transformer_engine_rocm"
if (rocm_build() and bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))
        and not bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0")))):
    PACKAGE_NAME = f"transformer_engine_rocm{rocm_version()[0]}"
    # On ROCm, add extras to the core package so it can be installed without the metapackage
    extras_require.update({
        "pytorch": [f"transformer_engine_rocm_torch=={__version__}"],
        "jax": [f"transformer_engine_rocm_jax=={__version__}"],
    })

# Configure package
setuptools.setup(
Expand Down
9 changes: 8 additions & 1 deletion transformer_engine/__init__.py
Expand Up @@ -83,4 +83,11 @@
category=RuntimeWarning,
)

__version__ = str(metadata.version("transformer_engine"))
try:
    __version__ = str(metadata.version("transformer_engine"))
except metadata.PackageNotFoundError:
    if not transformer_engine.common.te_rocm_build:
        raise
    _te_core_installed, _, __version__ = transformer_engine.common.get_te_core_package_info()
    if not _te_core_installed:
        raise
10 changes: 8 additions & 2 deletions transformer_engine/common/__init__.py
Expand Up @@ -141,7 +141,7 @@ def get_te_core_package_info() -> Tuple[bool, str, str]:

te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13")
if te_rocm_build:
te_core_packages = ("transformer-engine-rocm",)
te_core_packages = ("transformer-engine-rocm7",)
for package in te_core_packages:
if _is_package_installed(package):
return True, package, version(package)
Expand Down Expand Up @@ -173,14 +173,20 @@ def load_framework_extension(framework: str) -> None:
te_installed = _is_package_installed("transformer_engine")
te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")

# The metapackage is optional for ROCm builds.
if te_rocm_build and te_core_installed and not te_installed:
te_installed = True
te_installed_via_pypi = True

assert te_installed, "Could not find `transformer_engine`."

# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching versions.
if te_framework_installed:
assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."
assert te_core_installed, ( "Could not find TE core package "
f"`transformer-engine-{'rocm' if te_rocm_build else 'cu'}*`." )

assert version(module_name) == version("transformer-engine") == te_core_version, (
"Transformer Engine package version mismatch. Found"
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/jax/setup.py
Expand Up @@ -47,7 +47,7 @@

from build_tools.build_ext import get_build_ext, SdistWithLocalVersion
from build_tools.utils import copy_common_headers, min_python_version_str
from build_tools.utils import rocm_build
from build_tools.utils import rocm_build, rocm_version
from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension, install_requirements, test_requirements

Expand Down Expand Up @@ -129,12 +129,12 @@ def get_cuda_major_version() -> int:
if not rocm_build():
te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}"
else:
te_core = f"transformer_engine_rocm=={__version__}"
te_core = f"transformer_engine_rocm{rocm_version()[0]}=={__version__}"
install_requires = install_requirements() + [te_core]

# Configure package
setuptools.setup(
name="transformer_engine_jax",
name="transformer_engine_rocm_jax" if rocm_build() else "transformer_engine_jax",
version=__version__,
description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules,
Expand Down