Changes from all commits (110 commits)
66aed3a
Updated VERSION to 2.11.0.dev0
ptrendx Nov 15, 2025
42d2274
[JAX] Quickstart documentation (#2310)
tdophung Nov 15, 2025
e1edaae
[PyTorch] Reduce CPU overheads (#2377)
ksivaman Nov 17, 2025
1df4a69
[PyTorch] Enable reference Current Scaling recipe (#2368)
negvet Nov 17, 2025
7e593c3
Add num_splits support for FA3 backend (#2380)
cyanguwa Nov 17, 2025
15cefbc
[JAX] Add support for sink attention in JAX (#2225)
pggPL Nov 18, 2025
d677a26
Show quickstart_jax.ipynb along with quickstart.ipynb on html documen…
tdophung Nov 18, 2025
e122173
[PyTorch] Cache RHT device tensors properly (#2395)
ksivaman Nov 18, 2025
30c0120
[PyTorch] Fix small errors (#2396)
pggPL Nov 18, 2025
05bfa3f
[PyTorch] Implement Selective Activation Checkpointing for LayerNormM…
jaimec00 Nov 18, 2025
41fb9bc
[PyTorch] fix `test_current_device` test (#2398)
cyanguwa Nov 19, 2025
877b796
Feature fast cast-only mxfp8 (#2062)
Jianbing-D Nov 19, 2025
e6da012
[PyTorch] Disable Flash Attention backend in Userbuffers tests (#2399)
timmoon10 Nov 19, 2025
49f7c1d
Avoid autogenerating docs for Python files with leading underscore (#…
timmoon10 Nov 19, 2025
8ef8285
Minor improvements to CPU overhead (#2400)
ksivaman Nov 19, 2025
4142547
[PyTorch] Fix ONNX export errors (#2406)
pggPL Nov 21, 2025
15dead1
[PyTorch] Fix for CPU offloading (#2403)
pggPL Nov 21, 2025
6f4bc33
Make grad_output contiguous in cross_entropy.py (#2402)
LitLeo Nov 21, 2025
632c4c3
ci: Build and attach bdist wheels to release page (#2138)
ko3n1g Nov 21, 2025
b14f417
[PyTorch] Fix assertion error message formatting in DotProductAttenti…
janbernloehr Nov 21, 2025
beed55b
[JAX] Set BSHD as default in Unfused DPA, DPA and MHA API calls (#2392)
KshitijLakhani Nov 21, 2025
4654b70
[JAX] Remove unnecessary SWA calculation in _segment_ids_pos_to_seqle…
KshitijLakhani Nov 21, 2025
a75da0c
Enable SWA with CP for THD input format (#2220)
sudhakarsingh27 Nov 21, 2025
f8cb598
[PyTorch] Only disable Flash Attention in Userbuffers test on SM 8.0 …
timmoon10 Nov 21, 2025
0056b98
[PyTorch] Change arguments order in triton kernels to make jax-triton…
tdophung Nov 25, 2025
f612b74
docs: Document NVTE_CUDA_ARCHS environment variable in README (#2414)
satias10 Nov 25, 2025
66ae303
[JAX] Allow DP + FSDP and fixed sr_rng_state partitioning (#2418)
phu0ngng Nov 25, 2025
3b8d9a8
[Pytorch] remove redundant error check in Linear module (#2420)
vthumbe1503 Nov 25, 2025
89cc2a7
[PyTorch][NVFP4][MOE] NVFP4 Grouped Hadamard Amax Kernel (#2351)
zhongbozhu Nov 25, 2025
d52ed47
FSDP2 Allgather Perf improvement and support for FusedAdam with FSDP2…
vthumbe1503 Nov 25, 2025
b3c2505
[Pytorch] Fix backward_dw cuda graph order (#2376)
Wohox Nov 25, 2025
9f61f8a
[PyTorch Debug] Debug support for GroupedLinear (#1953)
pggPL Nov 25, 2025
9ca89e9
[PyTorch] Avoid initializing recipe state in fusible op base class co…
timmoon10 Nov 26, 2025
ca468eb
Extend docs with quantizers/quantized_tensors/custom_recipe (#2428)
negvet Nov 26, 2025
df39a7c
Docs fix (#2301)
pggPL Nov 26, 2025
3ff0b8d
Change Flax MHA to DPA to remove the duplicated QKV projection step (…
tdophung Nov 27, 2025
f1512b2
[JAX] Triton binding (#2437)
phu0ngng Dec 2, 2025
14b5331
[Common] NVTEGroupedTensor class and helpers (#2388)
phu0ngng Dec 2, 2025
cc42a57
[JAX] Make test_layer.py tolerances stricter (#2306)
jberchtold-nvidia Dec 2, 2025
d126cdd
Add primary weighs fp8 support for mxfp8 (#2055)
kunlunl Dec 2, 2025
6182206
[Core] Fix inconsistent logic in C++ tensor class (#2330)
timmoon10 Dec 4, 2025
50be029
[JAX] Enable TE/JAX test timings in CI (#2475)
jberchtold-nvidia Dec 5, 2025
f0572aa
Fix bugs from refactoring C++ tensor class (#2481)
timmoon10 Dec 5, 2025
fd0cd12
[JAX] Add CP + THD + AG + Striped>1 + SWA support (#2379)
KshitijLakhani Dec 6, 2025
fd91bae
Changed VERSION to 2.12.0.dev0
ptrendx Dec 8, 2025
c09411d
[Pytorch][Bug]MXFP8 Split tensor Bug fix (#2427)
vthumbe1503 Dec 8, 2025
8ef3a33
Fix runtime lib loading logic (#2297)
ksivaman Dec 9, 2025
e05f87e
[PyTorch] Change order of args in another permutation triton kernel …
tdophung Dec 9, 2025
dbaa02d
Fix the sm120 compilation with CUDA 12 (#2482)
ptrendx Dec 9, 2025
46c6ef3
Jax primitives for permutation on single GPU (#2473)
tdophung Dec 9, 2025
5afbb0e
[JAX] Make softmax_type in FFI optional (#2491)
jberchtold-nvidia Dec 10, 2025
e411547
[PyTorch Debug] Add nvdlfw-inspect to dependencies (#2173)
pggPL Dec 10, 2025
93c5c65
[PyTorch] Add THD support for max_logit/MuonClip (#2480)
cyanguwa Dec 10, 2025
a5694f2
Add separate RNG states for column-wise quantization with Stochastic …
negvet Dec 11, 2025
811e090
[PyTorch] Update RNG global states in tracker set_states (#2501)
buptzyb Dec 11, 2025
5035232
[PyTorch] Convert sample tuple to list in cudagraph input reuse (#2426)
buptzyb Dec 11, 2025
887a4fc
[JAX] Unset NVTE_FUSED_RING_ATTENTION_USE_SCAN by default (#2503)
KshitijLakhani Dec 11, 2025
8c9f7c2
[PyTorch] Add triton requirement (#2490)
ksivaman Dec 12, 2025
36f2dfd
fix ce loss calculation when some tokens are ignored (#2476)
yashaswikarnati Dec 15, 2025
b215116
Check calling convention for amax switch. (#2506)
kwyss-nvidia Dec 15, 2025
2886cbc
[PyTorch debug] Fix test for debug tools (#2507)
pggPL Dec 15, 2025
eac8af6
Remove test skip logic for GEMM-AR tests (#2516)
vcherepanov-nv Dec 16, 2025
dbd0197
Reset cache logic of weight workspace for NVFP4TensorStorage (#2524)
jinhangchoi Dec 17, 2025
5c2f2ff
Add ccache support to TE and use it in GitHub actions (#2444)
ptrendx Dec 17, 2025
442513c
[JAX] Add tutorial for integrating TE/JAX quantization into an existi…
jberchtold-nvidia Dec 17, 2025
14ddb43
Fix meta device check failure when passing torch.device objects (#2519)
LucienXian Dec 18, 2025
3e69397
ci: Use whitelisted sha for `get-release` (#2531)
ko3n1g Dec 18, 2025
6fd6209
[PyTorch] Make sure Float8Tensor.contiguous supports autograd (#2533)
sudhakarsingh27 Dec 19, 2025
d46d5db
[JAX] Handle meshs set with jax.set_mesh (#2532)
jberchtold-nvidia Dec 19, 2025
47902e9
[JAX] Remove unused TE DPA module dtype which fixes cuDNN backend det…
jberchtold-nvidia Dec 20, 2025
eb8e792
[PyTorch][NVFP4][MOE] NVFP4 Grouped Quantize with Hadamard Transform …
zhongbozhu Dec 20, 2025
97a09c2
Fix ptxas compilation on sm103 for triton kernels (#2539)
tdophung Dec 22, 2025
5ba01fa
[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimizati…
xiaoxi-wangfj Dec 27, 2025
26c82db
[JAX] Fix incorrect calculation of segment pos from segment ids in us…
KshitijLakhani Dec 31, 2025
697b52c
Fix overflow of padding/unpadding kernel (#2548)
adamantboy Dec 31, 2025
324be33
[PyTorch] Support cudagraph recomputation (#2518)
buptzyb Dec 31, 2025
830ef60
Update copyright to include year 2026 (#2553)
ksivaman Jan 2, 2026
27dc83b
Document environment variables (#2552)
ksivaman Jan 2, 2026
c988548
[PyTorch] Fix garbage initialized permuted_scale (#2547)
xiaoxi-wangfj Jan 2, 2026
4f364c8
Fix out of bound ID passed to `cutlass::arch::NamedBarrier::sync` (#2…
ksivaman Jan 5, 2026
c90a921
Add tests that reset_parameters doesn't change parameter initial valu…
pstjohn Jan 5, 2026
a976740
[docs] Getting started refactor (#2534)
pggPL Jan 6, 2026
df69100
[Common] Fix long compile time in padding.cu on arch 75 (#2562)
jberchtold-nvidia Jan 6, 2026
404a3ee
[JAX] Fix test_layer to support fused attention and adjust test encod…
jberchtold-nvidia Jan 6, 2026
702fc5e
Fix 50% comparison mismatch in sort_chunks_by_index (#2566)
tdophung Jan 7, 2026
de51c96
[NVFP4][MOE] Bug Fix for NVFP4 Grouped Quant (#2564)
zhongbozhu Jan 7, 2026
08dc786
Fix 50% comparison mismatch in sort_chunks_by_index (Cont.) (#2575)
tdophung Jan 7, 2026
5f828c2
Solve pytorch-triton and triton package contention (#2540)
tdophung Jan 8, 2026
5f0e3b9
[JAX] Refactor and trim TE JAX Attn testing (#2542)
KshitijLakhani Jan 9, 2026
32f403f
Update list of authorized CI users (#2581)
timmoon10 Jan 9, 2026
2f8ae81
Debug doc generation (#2576)
timmoon10 Jan 10, 2026
fe8fad5
[PyTorch] Bunch of fixes for cpu offloading (#2535)
pggPL Jan 13, 2026
69636a0
ONNX: Fix FP8 quantization for the second MLP in LayerNormMLP (#2577)
victoroliv2 Jan 13, 2026
bd00799
Revert adding pytorch-triton as a build requirement (#2592)
tdophung Jan 14, 2026
fcfa0c3
(Bug fix) Fix accuracy issue for blockwise scaling+E8 scale on Blackw…
lhb8125 Jan 15, 2026
4df43db
docs: Update README Latest News section (#2583)
sbhavani Jan 15, 2026
2236292
[JAX] Disable fused attention in encoder tests for determinism (#2601)
jberchtold-nvidia Jan 15, 2026
6cbdb04
[JAX] Install Cmake in TE/JAX build Github Action (#2603)
jberchtold-nvidia Jan 15, 2026
6a34b65
fix: enable opt for cutlass sources to avoid infinite compile time (#…
kainzhong Jan 15, 2026
a652730
[JAX] Custom partitioning for Permutation primitives (#2591)
tdophung Jan 16, 2026
99df881
Add logic for block-scaled tensors with GEMM swizzled scales (#2486)
timmoon10 Jan 17, 2026
ccbe825
[ROCm] commit with conflicts as they are
wangye805 Mar 16, 2026
10348ab
[ROCm] resolve copyright conflicts
wangye805 Mar 16, 2026
838b2bd
[ROCm] resolve conflicts in common dir
wangye805 Mar 17, 2026
73bbe88
Common build improvements
Micky774 Mar 18, 2026
af4a06a
Compilation correction
Micky774 Mar 18, 2026
6b6dd70
Squashed commit of jax resolutions:
Micky774 Mar 19, 2026
3e06da0
Squashed commit of pytorch resolutions:
Micky774 Mar 19, 2026
2bd196b
IFU PR feedback
Micky774 Mar 19, 2026
55b64ed
Guarded NV-only env vars
Micky774 Mar 19, 2026
49 changes: 49 additions & 0 deletions .github/actions/build-pytorch-wheel/Dockerfile
@@ -0,0 +1,49 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive

ENV CUDA_HOME=/usr/local/cuda
ENV PATH=$PATH:$CUDA_HOME/bin
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
ENV TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6;9.0"

ARG PYTHON_VERSION=3.12
ARG TORCH_VERSION=2.9.1
ARG CUDA_VERSION=12.9.1
ARG CUDNN_MAJOR_VERSION=9
ENV PATH=/opt/venv/bin:$PATH
ENV PYTHONUNBUFFERED=1
ARG AARCH=x86_64

# Install Python
RUN apt-get update && \
    apt-get install -y software-properties-common wget && \
    add-apt-repository ppa:deadsnakes/ppa -y && \
    apt-get install -y python$PYTHON_VERSION-dev python$PYTHON_VERSION-venv python3-pip && \
    python$PYTHON_VERSION -m venv /opt/venv

# Install cuda-toolkit
RUN CUDA_MAJOR_VERSION=$(echo $CUDA_VERSION | awk -F \. {'print $1'}) && \
    CUDA_MINOR_VERSION=$(echo $CUDA_VERSION | awk -F \. {'print $2'}) && \
    rm /etc/apt/sources.list.d/cuda*.list || true && \
    rm /etc/apt/sources.list.d/nvidia-cuda.list || true && \
    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${AARCH}/cuda-keyring_1.1-1_all.deb && \
    dpkg -i cuda-keyring_1.1-1_all.deb && \
    rm cuda-keyring_1.1-1_all.deb && \
    apt-get update && \
    apt-get install -y cuda-toolkit-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} cudnn-cuda-$CUDA_MAJOR_VERSION libcudnn$CUDNN_MAJOR_VERSION-cuda-$CUDA_MAJOR_VERSION libnccl2 libnccl-dev cmake

# Install PyTorch
RUN export MATRIX_CUDA_VERSION=$(echo $CUDA_VERSION | awk -F \. {'print $1 $2'}) && \
    export MATRIX_TORCH_VERSION=$(echo $TORCH_VERSION | awk -F \. {'print $1 "." $2'}) && \
    export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
        minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \
        maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \
        print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
    ) && \
    pip install --no-cache-dir torch==${TORCH_VERSION} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
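The inline Python in the last `RUN` step picks the PyTorch wheel index tag (`cuNNN`) from the torch/CUDA version pair: each torch minor release supports a range of CUDA builds, and the closest supported tag is chosen. A standalone sketch of that lookup, using the same tables as the Dockerfile (the function name is ours, not part of the build):

```python
# Sketch of the CUDA-tag selection embedded in the Dockerfile's
# "Install PyTorch" step: map (torch version, CUDA version) to the
# cuNNN tag of the PyTorch wheel index.
def torch_cuda_tag(torch_version: str, cuda_version: str) -> int:
    minor = ".".join(torch_version.split(".")[:2])    # e.g. "2.9"
    cuda = int("".join(cuda_version.split(".")[:2]))  # e.g. "12.9.1" -> 129
    # Oldest and newest CUDA builds published for each torch minor release.
    minv = {"2.5": 118, "2.6": 118, "2.7": 118, "2.8": 126, "2.9": 126}[minor]
    maxv = {"2.5": 124, "2.6": 126, "2.7": 128, "2.8": 129, "2.9": 130}[minor]
    # CUDA 11.x falls back to the oldest build; CUDA 12+ takes the newest.
    return minv if cuda < 120 else maxv

print(torch_cuda_tag("2.9.1", "12.9.1"))  # -> 130, i.e. the cu130 wheel index
```

With the Dockerfile defaults (torch 2.9.1, CUDA 12.9.1) this resolves to `cu130`, so pip installs from `https://download.pytorch.org/whl/cu130`.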
118 changes: 118 additions & 0 deletions .github/actions/build-pytorch-wheel/action.yml
@@ -0,0 +1,118 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

name: Build PyTorch Wheel
description: Builds a PyTorch wheel for TransformerEngine

inputs:
  release-version:
    description: 'The release version to use for the build'
    required: true
  python-version:
    description: 'The Python version to use for the build'
    required: true
  cuda-version:
    description: 'The CUDA version to use for the build'
    required: true
  cudnn-version:
    description: 'The cuDNN version to use for the build'
    required: true
  torch-version:
    description: 'The PyTorch version to use for the build'
    required: true
  cxx11_abi:
    description: 'Enable torch flag C++11 ABI (TRUE/FALSE)'
    required: true
  base-image:
    description: 'The base image to use for the build'
    required: false
  aarch:
    description: 'The architecture to use for the build'
    required: true

outputs:
  wheel_name:
    description: 'The name of the built wheel'
    value: ${{ steps.build_wheel.outputs.wheel_name }}

runs:
  using: 'composite'
  steps:
    - name: Move /var/lib/docker/
      shell: bash -euxo pipefail {0}
      run: sudo mv /var/lib/docker/ "${GITHUB_WORKSPACE}/docker"

    - name: Maximize build space
      uses: easimon/maximize-build-space@c28619d8999a147d5e09c1199f84ff6af6ad5794
      with:
        root-reserve-mb: 5120
        temp-reserve-mb: 32
        swap-size-mb: 10240
        remove-dotnet: 'true'
        remove-android: 'true'
        remove-haskell: 'true'
        remove-codeql: 'true'
        build-mount-path: '/var/lib/docker/'

    - name: Restore /var/lib/docker/
      shell: bash -euxo pipefail {0}
      run: sudo sh -c "mv ${GITHUB_WORKSPACE}/docker/* /var/lib/docker"

    - name: Checkout
      uses: actions/checkout@v4
      with:
        ref: ${{ inputs.release-version }}
        submodules: recursive

    - name: Checkout build tools
      uses: actions/checkout@v4
      with:
        path: build-tools
        submodules: recursive

    - name: Build image
      shell: bash -euxo pipefail {0}
      env:
        BASE_IMAGE: ${{ inputs.base-image }}
      run: |
        if [[ "${BASE_IMAGE}" == "" ]]; then
          docker build \
            -t transformer-engine-build \
            -f build-tools/.github/actions/build-pytorch-wheel/Dockerfile \
            --build-arg PYTHON_VERSION=${{ inputs.python-version }} \
            --build-arg TORCH_VERSION=${{ inputs.torch-version }} \
            --build-arg CUDA_VERSION=${{ inputs.cuda-version }} \
            --build-arg CUDNN_MAJOR_VERSION=${{ inputs.cudnn-version }} \
            --build-arg AARCH=${{ inputs.aarch }} \
            .
        else
          docker pull ${BASE_IMAGE}
          docker tag ${BASE_IMAGE} transformer-engine-build
        fi

    - name: Build wheel
      shell: bash -euxo pipefail {0}
      id: build_wheel
      env:
        CXX11_ABI: ${{ inputs.cxx11_abi }}
      run: |
        echo ::group::Build wheel
        EXIT_CODE=$(docker run \
          --rm \
          --shm-size=64g \
          --workdir /workspace/transformer_engine/pytorch \
          --volume $(pwd):/workspace \
          --volume $GITHUB_OUTPUT:$GITHUB_OUTPUT \
          -e PIP_CONSTRAINT= \
          -e CXX11_ABI=$CXX11_ABI \
          -e GITHUB_OUTPUT=$GITHUB_OUTPUT \
          transformer-engine-build bash /workspace/build-tools/.github/actions/build-pytorch-wheel/build.sh | tail -n 1)
        echo ::endgroup::
        # The container always exits 0; build.sh reports the real build status
        # on its last line of output, which is captured and propagated here.
        exit $EXIT_CODE

    - name: Log Built Wheels
      shell: bash -euxo pipefail {0}
      run: |
        ls transformer_engine/pytorch/dist
26 changes: 26 additions & 0 deletions .github/actions/build-pytorch-wheel/build.sh
@@ -0,0 +1,26 @@
#!/bin/bash

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -eoxu pipefail

export NVTE_PYTORCH_FORCE_BUILD=TRUE
export NVTE_NO_LOCAL_VERSION=1
export NVTE_PYTORCH_FORCE_CXX11_ABI=$CXX11_ABI
export PIP_CONSTRAINT=

pip install wheel packaging nvidia-mathdx ninja pybind11

# 5h timeout since GH allows max 6h and we want some buffer
EXIT_CODE=0
timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?

if [ $EXIT_CODE -eq 0 ]; then
    wheel_name=$(python -c "import setup; print(setup.get_wheel_url()[1])" | tail -n 1)
    ls dist/*.whl | xargs -I {} mv {} dist/${wheel_name}
    echo "wheel_name=${wheel_name}" | tee -a "$GITHUB_OUTPUT"
fi

echo $EXIT_CODE
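build.sh caps the build at 5 hours with `timeout`, swallows the failure with `|| EXIT_CODE=$?`, and prints the exit code as its last line so the calling action can capture it through `tail -n 1`. The same run-with-timeout pattern, sketched in Python (the function name is ours; returning 124 on timeout mirrors GNU `timeout`'s convention — this is not TE code):

```python
import subprocess

def run_with_timeout(cmd, timeout_s):
    """Run cmd and return its exit code; return 124 if the deadline expires,
    matching GNU `timeout` so callers can treat both paths uniformly."""
    try:
        proc = subprocess.run(cmd, timeout=timeout_s)
        return proc.returncode
    except subprocess.TimeoutExpired:
        return 124

code = run_with_timeout(["true"], timeout_s=60)
print(code)  # -> 0 for a command that succeeds within the deadline
```

Reporting the status on stdout instead of via the script's own exit code is what lets the `docker run | tail -n 1` pipeline in action.yml recover it even though the container itself exits cleanly.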
69 changes: 69 additions & 0 deletions .github/scripts/check_for_ngc_images.sh
@@ -0,0 +1,69 @@
#!/bin/bash

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Configuration
BASE_IMAGE="nvcr.io/nvidia/pytorch"
TAG_SUFFIX="-py3"
MONTHS_TO_CHECK=5 # Check current month and previous 4 months (total 5)

# Initialize an array to store existing tags
EXISTING_TAGS=()

echo "Checking for existence of the last ${MONTHS_TO_CHECK} NGC PyTorch images: ${BASE_IMAGE}:YY.MM${TAG_SUFFIX}"
echo "---------------------------------------------------------------------"

# Loop through the last N months
for i in $(seq 0 $((MONTHS_TO_CHECK - 1))); do
    # Calculate Year and Month for the tag
    CURRENT_YEAR=$(date +%Y)
    CURRENT_MONTH=$(date +%m)

    # Calculate target month and year
    TARGET_DATE=$(date -d "$CURRENT_YEAR-$CURRENT_MONTH-01 -$i months" +%y.%m)

    # Construct the full image tag and the tag-only string
    IMAGE_TAG="${TARGET_DATE}${TAG_SUFFIX}"
    FULL_IMAGE="${BASE_IMAGE}:${IMAGE_TAG}"

    echo "Checking: ${FULL_IMAGE}"

    # Use 'docker manifest inspect' to check for image existence without pulling.
    if docker manifest inspect "${FULL_IMAGE}" > /dev/null 2>&1; then
        echo "✅ EXISTS: Found."
        # Add the tag-only string to the array
        EXISTING_TAGS+=("nvcr.io/nvidia/pytorch:${IMAGE_TAG}")
    else
        echo "❌ MISSING: Not found."
    fi
done

echo "---------------------------------------------------------------------"

## JSON Output Generation
# Build a JSON array from the collected tags.

# 1. Convert the shell array to a newline-separated string.
TAGS_NL_SEP=$(printf "%s\n" "${EXISTING_TAGS[@]}")

# 2. Use jq to read the newline-separated list and format it into a JSON array.
#    split("\n") | .[:-1] splits the input by newline and drops the trailing
#    empty element.
if command -v jq &> /dev/null; then
    JSON_STRING=$(echo -e "${TAGS_NL_SEP}" | jq -R -s 'split("\n") | .[:-1]')

    echo "Generated JSON String of Existing Tags:"
    echo "${JSON_STRING}"
else
    echo "WARNING: 'jq' is not installed. Cannot format output as JSON."
    echo "Found Tags: ${EXISTING_TAGS[*]}"
    JSON_STRING="[]"
fi

echo "---"
echo "Check complete."

echo "${JSON_STRING}" > ngc_images.json
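The loop above derives the `YY.MM` tags with GNU `date -d "... -$i months"`, which is not portable to BSD/macOS `date`. For reference, the same month arithmetic in plain Python (function name is ours; the `-py3` suffix matches the script's `TAG_SUFFIX`):

```python
from datetime import date

def last_n_month_tags(today: date, n: int, suffix: str = "-py3"):
    """Return NGC-style YY.MM tags for the current month and the n-1 before it,
    newest first, wrapping the year boundary correctly."""
    tags = []
    year, month = today.year, today.month
    for _ in range(n):
        tags.append(f"{year % 100:02d}.{month:02d}{suffix}")
        month -= 1
        if month == 0:  # stepped past January: roll back to December
            month, year = 12, year - 1
    return tags

print(last_n_month_tags(date(2026, 1, 15), 5))
# -> ['26.01-py3', '25.12-py3', '25.11-py3', '25.10-py3', '25.09-py3']
```

Pinning the day to `01` before subtracting months, as the shell script does, avoids the end-of-month skew that `date -d "-1 month"` exhibits on the 31st; the integer month arithmetic above sidesteps the same issue.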