Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build_tools/hipify/custom_map.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"<nvtx3/nvToolsExt.h>" : "<roctracer/roctx.h>",
"cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)",
"__nv_bfloat162":"__hip_bfloat162",
"cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA"
"cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA",
"at::cuda::CUDAGuard" : "c10::hip::HIPGuardMasqueradingAsCUDA"
}
}

25 changes: 12 additions & 13 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,23 @@ def generate_input_shapes(
seqlens_q_padded.cumsum(0, dtype=torch.int32),
]
).cuda()
<<<<<<< HEAD
if kernel_backend == "FlashAttention":
cu_seqlens_q = cu_seqlens_q_padded[:-1]
if IS_HIP_EXTENSION:
if kernel_backend == "FlashAttention":
cu_seqlens_q = cu_seqlens_q_padded[:-1]
else:
cu_seqlens_q = torch.cat(
[torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)]
).cuda()
else:
cu_seqlens_q = torch.cat(
[torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)]
).cuda()
=======
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)

# Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
# Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()

# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
# will not be the same as `cu_seqlens_q` and `cu_seqlens_q_padded` respectively.
>>>>>>> 99df88
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded

Expand Down
2 changes: 2 additions & 0 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,8 @@ def find_factors(x):

for num_q_per_gqa_group in num_querys_per_gqa_group:
config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
continue
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
Expand Down
3 changes: 0 additions & 3 deletions tests/pytorch/distributed/run_fsdp2_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#!/usr/bin/python3
<<<<<<< HEAD
# This file was modified for portability to AMDGPU
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
=======

>>>>>>> 99df88
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down
3 changes: 1 addition & 2 deletions tests/pytorch/test_gemm_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from torch.utils.cpp_extension import IS_HIP_EXTENSION

from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace


storage_fname = "te_algo"
Expand Down Expand Up @@ -107,7 +106,7 @@ def run_gemm():
N = 32
datatype = torch.float16
inp = torch.randn((N, N), device="cuda", dtype=datatype)
_, _, _, _ = general_gemm(A=inp, B=inp, out_dtype=datatype, workspace=get_workspace())
_, _, _, _ = general_gemm(A=inp, B=inp, out_dtype=datatype)


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions tests/pytorch/test_gemm_sm_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Dict

from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
import logging


Expand Down Expand Up @@ -40,7 +39,7 @@ def test_gemm_sm_count():
datatype = torch.float32
A = torch.randn((K, M), device="cuda", dtype=datatype)
B = torch.randn((N, K), device="cuda", dtype=datatype)
gemm_parameters = {'A': A, 'B': B, 'layout': "NN", 'workspace': get_workspace()}
gemm_parameters = {'A': A, 'B': B, 'layout': "NN"}

with torch.cuda.stream(torch.cuda.Stream()):
full_timing = _run_gemm_timing("Full", gemm_parameters)
Expand Down
3 changes: 0 additions & 3 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,15 +1707,12 @@ def test_layernorm_linear_accuracy(
def test_layernorm_linear_accuracy_delay_wgrad_compute(
dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
<<<<<<< HEAD
if IS_HIP_EXTENSION:
if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
=======
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("Delayed wgrad compute is not supported in debug mode.")

>>>>>>> 99df88
config = model_configs[model]

ln_linear_ref = LayerNormLinear(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,9 @@ def get_attention_backend(
----------
use_flash_attention : bool
Whether the `FlashAttention` backend has been selected.
<<<<<<< HEAD
flash_attention_backend: PkgVersion
If `use_flash_attention = True`, the version of the selected `FlashAttention` backend.
use_fused_attention: bool
=======
use_fused_attention : bool
>>>>>>> 99df88
Whether the `FusedAttention` backend has been selected.
fused_attention_backend : tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
Expand Down
13 changes: 11 additions & 2 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand All @@ -8,9 +10,10 @@
import os
import functools
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor
from ..utils import get_sm_count, _empty_tensor, get_device_compute_capability

from ..quantized_tensor import Quantizer
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
Expand All @@ -29,8 +32,14 @@


def get_cublas_workspace_size_bytes() -> None:
"""Return workspace size needed for current architecture."""
if IS_HIP_EXTENSION:
"""Return 64 MiB for gfx50x, 32 MiB for all other architectures."""
if get_device_compute_capability(0) == 95:
return 67_108_864
return 33_554_432
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
if get_device_compute_capability(0) >= 90:
# 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
return 32 * 1024 * 1024 + 1024
return 4_194_304
Expand Down
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/csrc/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,6 @@ size_t roundup(size_t value, size_t multiple) {
return ((value + multiple - 1) / multiple) * multiple;
}

<<<<<<< HEAD
#ifdef USE_ROCM

inline bool nvte_use_atomic_amax() {
Expand All @@ -335,9 +334,8 @@ at::Tensor allocate_amax_workspace(const TensorWrapper& input_tensor) {
}

#endif
=======

size_t ceildiv(size_t numer, size_t denom) { return (numer + denom - 1) / denom; }
>>>>>>> 99df88

void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
NVTE_SCOPED_GIL_RELEASE({
Expand Down
3 changes: 0 additions & 3 deletions transformer_engine/pytorch/csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,11 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) {
case transformer_engine::DType::kFloat8E5M2:
#ifndef USE_ROCM
return at::kFloat8_e5m2;
<<<<<<< HEAD
#else
return te_fp8_fnuz()? at::kFloat8_e5m2fnuz : at::kFloat8_e5m2;
#endif // USE_ROCM
=======
case transformer_engine::DType::kFloat8E8M0:
return at::kByte; // e8m0 dtype requires PyTorch 2.7.0+
>>>>>>> 99df88
default:
NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
}
Expand Down
5 changes: 1 addition & 4 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -489,17 +489,14 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> unpadded_input_row_list);
<<<<<<< HEAD
#ifndef USE_ROCM
=======

/***************************************************************************************************
* Scale swizzling for GEMM
**************************************************************************************************/

void inplace_swizzle_scale_for_gemm(py::handle &tensor);

>>>>>>> 99df88
#ifndef USE_ROCM
/***************************************************************************************************
* NVSHMEM APIs
**************************************************************************************************/
Expand Down
20 changes: 7 additions & 13 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc

return retval;
}
#endif // #ifndef USE_ROCM

// Owns all allocations/wrappers backing quant_config_list[*].set_rng_state(...).
struct StochasticRngStateResources {
Expand Down Expand Up @@ -1100,6 +1099,7 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
}
});
}
#endif // #ifndef USE_ROCM

} // namespace

Expand Down Expand Up @@ -1169,12 +1169,14 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
return detail::IsMXFP8Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_MXFP8;
#ifndef USE_ROCM
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsNVFP4Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_NVFP4;
quantization_method = QuantizationMethod::FUSED_NVFP4;
#endif
}
}

Expand All @@ -1200,34 +1202,25 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
}
std::tie(output_py_list, output_cpp_list) =
bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
<<<<<<< HEAD
#ifndef USE_ROCM
} else if (is_nvfp4) {
// NVFP4: construct output tensors with bulk allocations
=======
break;
}
#ifndef USE_ROCM
case AllocationMethod::BULK_NVFP4: {
// Bulk allocation for NVFP4 tensors
>>>>>>> 99df88
std::vector<NVFP4Quantizer *> nvfp4_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
nvfp4_quantizers.push_back(static_cast<NVFP4Quantizer *>(quantizer.get()));
}
bool contiguous_data_and_scale;
std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) =
bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers);
<<<<<<< HEAD
#endif
} else {
NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer");
=======
if (!contiguous_data_and_scale) {
// Avoid fused quantize kernel if data is not contiguous
quantization_method = QuantizationMethod::UNFUSED;
}
break;
}
#endif
default: {
// Allocate output tensors individually
for (size_t i = 0; i < num_splits; ++i) {
Expand All @@ -1236,12 +1229,12 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
output_cpp_list.emplace_back(std::move(output_cpp));
output_py_list.emplace_back(std::move(output_py));
}
>>>>>>> 99df88
}
}

// Quantize into output tensors
switch (quantization_method) {
#ifndef USE_ROCM
case QuantizationMethod::FUSED_NVFP4: {
// Fused NVFP4 quantize kernel
auto input_nvte = makeTransformerEngineTensor(input_dptr, input_shape, input_dtype);
Expand All @@ -1253,6 +1246,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
nvfp4_quantizers);
break;
}
#endif
default:
// General multi-tensor quantization
multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
Expand Down
53 changes: 10 additions & 43 deletions transformer_engine/pytorch/csrc/quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,37 +827,20 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
<<<<<<< HEAD
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv0 = ceildiv(m_dim, kBlockLen);
#ifdef USE_ROCM
sinv1 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = ceildiv(k_dim, kBlockLen);
#else
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
#endif
=======
sinv0 = ceildiv(m_dim, kBlockLen);
sinv1 = roundup(ceildiv(k_dim, kBlockLen), 4);
>>>>>>> 99df88
#endif
} else if (block_scaling_dim == 1) {
// default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
<<<<<<< HEAD
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv0 = ceildiv(k_dim, kBlockLen);
#ifdef USE_ROCM
sinv1 = m_dim;
#else
sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
#endif
// if the rowwise format is compact, the scaling factor is not be transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
=======
sinv0 = ceildiv(k_dim, kBlockLen);
sinv1 = roundup(m_dim, 4);
>>>>>>> 99df88
#endif
} else {
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor rowwise."
Expand All @@ -870,35 +853,19 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
<<<<<<< HEAD
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv0 = ceildiv(k_dim, kBlockLen);
#ifdef USE_ROCM
sinv1 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = ceildiv(m_dim, kBlockLen);
#else
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
sinv1 = roundup(ceildiv(m_dim, kBlockLen), 4);
#endif
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv0 = ceildiv(m_dim, kBlockLen);
#ifdef USE_ROCM
sinv1 = k_dim;
#else
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
#endif
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
=======
sinv0 = ceildiv(k_dim, kBlockLen);
sinv1 = roundup(ceildiv(m_dim, kBlockLen), 4);
} else if (block_scaling_dim == 1) {
sinv0 = ceildiv(m_dim, kBlockLen);
sinv1 = roundup(k_dim, 4);
>>>>>>> 99df88
#endif
} else {
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor columnwise."
Expand Down
Loading