From 05bf128b4e62ecd7c49b42a4403d6be6b3a6c2b5 Mon Sep 17 00:00:00 2001 From: sararora Date: Wed, 18 Mar 2026 20:02:00 -0500 Subject: [PATCH] mxfp4 cast transpose hadamard shuffle fused --- build_tools/pytorch.py | 4 +- transformer_engine/common/CMakeLists.txt | 3 +- .../common/cast/mxfp4/cast_transpose_mxfp4.h | 37 +++++++++++++++++++ transformer_engine/common/recipe/__init__.py | 1 + transformer_engine/pytorch/csrc/extensions.h | 16 ++++++++ .../pytorch/csrc/extensions/pybind.cpp | 34 +++++++++++++++++ transformer_engine/pytorch/tensor/__init__.py | 16 ++++++++ .../pytorch/tensor/mxfp4_tensor.py | 34 ++++++++++++++--- 8 files changed, 137 insertions(+), 8 deletions(-) create mode 100644 transformer_engine/common/cast/mxfp4/cast_transpose_mxfp4.h diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index d8eb9a81e..699735220 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -38,8 +38,10 @@ def setup_pytorch_extension( ) -> setuptools.Extension: """Setup CUDA extension for PyTorch support""" - # Source files + # Source files (.cu files are hipified to .hip for ROCm builds) sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp") + cu_sources = all_files_in_dir(Path(csrc_source_files), name_extension="cu") + sources.extend(cu_sources) # Header files if rocm_build(): diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index bb84765fc..6fdec2b50 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -259,7 +259,8 @@ else() list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp - fused_attn_rocm/utils.cpp) + fused_attn_rocm/utils.cpp + cast/mxfp4/cast_transpose_mxfp4.hip) endif() diff --git a/transformer_engine/common/cast/mxfp4/cast_transpose_mxfp4.h b/transformer_engine/common/cast/mxfp4/cast_transpose_mxfp4.h new file mode 100644 index 000000000..f54ebf3cf --- /dev/null 
+++ b/transformer_engine/common/cast/mxfp4/cast_transpose_mxfp4.h @@ -0,0 +1,37 @@ +/************************************************************************* + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once + +#include + +namespace te_mxfp4 { + +extern "C" void launch_cast_transpose_mxfp4( + const void* input, + void* rowwise_fp4, + void* rowwise_scale, + void* colwise_fp4, + void* colwise_scale, + int M, int N, + bool use_rowwise, + bool use_colwise, + bool shuffle_scales, + bool use_hadamard, + bool shuffle_rowwise_fp4, + bool shuffle_colwise_fp4, + int rowwise_scale_stride, + int colwise_scale_stride, + int rowwise_scale_N, + int rowwise_scale_M_pad, + int rowwise_scale_N_pad, + int colwise_scale_M, + int colwise_scale_N, + int colwise_scale_M_pad, + int colwise_scale_N_pad, + hipStream_t stream); + +} // namespace te_mxfp4 diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 5ac53bc4d..53e3b253e 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -570,6 +570,7 @@ class MXFP4BlockScaling(Recipe): fp4_format: Format = Format.E2M1 fp8_dpa: bool = False fp8_mha: bool = False + use_hadamard: bool = False @property def fp8_format(self) -> Format: diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index cbdc63dc2..b68739ec4 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -259,6 +259,22 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object dequantize(const py::handle &input, DType otype); +/*************************************************************************************************** + * MXFP4 
Quantization + **************************************************************************************************/ + +std::tuple cast_transpose_mxfp4_fused_shuffle( + at::Tensor input, + std::optional rowwise_fp4_out, + std::optional rowwise_scale_out, + std::optional colwise_fp4_out, + std::optional colwise_scale_out, + bool shuffle_rowwise_scale, + bool shuffle_colwise_scale, + bool shuffle_rowwise_fp4, + bool shuffle_colwise_fp4, + bool use_hadamard); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index db70dfbf1..2d5dc5443 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -31,6 +31,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *MXFP8TensorStoragePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; +PyTypeObject *MXFP4TensorPythonClass = nullptr; +PyTypeObject *MXFP4TensorBasePythonClass = nullptr; +PyTypeObject *MXFP4QuantizerClass = nullptr; PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; PyTypeObject *Float8BlockwiseQTensorStoragePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; @@ -70,6 +73,21 @@ void init_mxfp8_extension() { "Internal error: could not initialize pyTorch MXFP8 extension."); } +void init_mxfp4_extension() { + if (MXFP4TensorPythonClass) return; + auto fp4_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp4_tensor"); + MXFP4QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp4_module.ptr(), "MXFP4Quantizer")); + MXFP4TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp4_module.ptr(), "MXFP4Tensor")); + auto fp4_base_module = + 
py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp4_tensor_base"); + MXFP4TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp4_base_module.ptr(), "MXFP4TensorBase")); + NVTE_CHECK(MXFP4TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch MXFP4 extension."); +} + void init_float8blockwise_extension() { if (Float8BlockwiseQTensorStoragePythonClass) return; auto fp8_module = @@ -109,6 +127,7 @@ void init_nvfp4_extensions() { void init_extension() { init_float8_extension(); init_mxfp8_extension(); + init_mxfp4_extension(); init_float8blockwise_extension(); init_nvfp4_extensions(); } @@ -246,6 +265,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm"); m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add, "Fused backward of RMSNorm + add"); + // MXFP4 Quantization + m.def("cast_transpose_mxfp4_fused_shuffle", + &transformer_engine::pytorch::cast_transpose_mxfp4_fused_shuffle, + "MXFP4 cast and transpose with fused weight shuffle for GEMM", + py::arg("input"), + py::arg("rowwise_fp4_out") = py::none(), + py::arg("rowwise_scale_out") = py::none(), + py::arg("colwise_fp4_out") = py::none(), + py::arg("colwise_scale_out") = py::none(), + py::arg("shuffle_rowwise_scale") = true, + py::arg("shuffle_colwise_scale") = true, + py::arg("shuffle_rowwise_fp4") = true, + py::arg("shuffle_colwise_fp4") = true, + py::arg("use_hadamard") = false, + py::call_guard()); m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize, "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index cb199d24b..acdfd1e12 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ 
b/transformer_engine/pytorch/tensor/__init__.py @@ -44,6 +44,14 @@ "restore_from_saved", ] +# Import MXFP4 tensor classes if available +try: + from .mxfp4_tensor import MXFP4Tensor, MXFP4Quantizer + from ._internal.mxfp4_tensor_base import MXFP4TensorBase + __all__.extend(["MXFP4Tensor", "MXFP4Quantizer", "MXFP4TensorBase"]) +except ImportError: + pass + def _make_module_cast_func(dtype): """Make module cast function that can handle QuantizedTensor""" @@ -90,4 +98,12 @@ def get_all_tensor_types(): NVFP4Tensor, NVFP4TensorStorage, ] + + try: + from transformer_engine.pytorch.tensor.mxfp4_tensor import MXFP4Tensor + from transformer_engine.pytorch.tensor._internal.mxfp4_tensor_base import MXFP4TensorBase + all_tensor_types.extend([MXFP4Tensor, MXFP4TensorBase]) + except ImportError: + pass + return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/mxfp4_tensor.py b/transformer_engine/pytorch/tensor/mxfp4_tensor.py index adbe9802b..4b52378a6 100644 --- a/transformer_engine/pytorch/tensor/mxfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp4_tensor.py @@ -9,11 +9,12 @@ from typing import Optional, Tuple, Union import torch -from ..triton_kernels.cast import te_quantize_triton import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType +_HAS_HIP_MXFP4 = hasattr(tex, "cast_transpose_mxfp4_fused_shuffle") + from transformer_engine.common.recipe import MXFP4BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE # MXFP4 uses same block size from ..utils import devices_match, round_up_to_nearest_multiple @@ -42,7 +43,8 @@ class MXFP4Quantizer(Quantizer): High-precision tensors (e.g. in FP32 or BF16) are quantized to FP4 by dividing them into groups of 32 elements, each scaled and cast - separately using AITER's per_1x32_f4_quant_hip kernel. + separately. On ROCm (gfx950), uses the fused HIP cast-transpose kernel + with optional Hadamard transform and AITER-compatible shuffled layout. 
The quantization produces: - FP4 data: [M, K/2] uint8 (2 FP4 values packed per byte) @@ -59,10 +61,12 @@ def __init__( rowwise: bool = True, columnwise: bool = True, shuffle_B_matrix_for_aiter: bool = False, + use_hadamard: bool = False, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) self.dtype = fp4_dtype self.shuffle_B_matrix_for_aiter = shuffle_B_matrix_for_aiter + self.use_hadamard = use_hadamard assert self.dtype == tex.DType.kFloat4E2M1, "Only E2M1 format supported for MXFP4" def update_quantized( @@ -75,17 +79,35 @@ def update_quantized( assert isinstance(dst, MXFP4Tensor), f"Cannot store quantized MXFP4 in {type(dst)} type." - # Make sure input is in expected format if not devices_match(src.device, dst.device): src = src.to(device=dst.device) if not src.is_contiguous(): src = src.contiguous() - te_quantize_triton(src, self, dst, noop_flag) + if _HAS_HIP_MXFP4: + if src.dtype != torch.bfloat16: + src = src.to(torch.bfloat16) + if src.dim() > 2: + src = src.view(-1, src.shape[-1]) + + with torch._C._DisableTorchDispatch(): + tex.cast_transpose_mxfp4_fused_shuffle( + src, + dst._rowwise_data.view(torch.uint8) if dst._rowwise_data is not None else None, + dst._rowwise_scale_inv.view(torch.uint8) if dst._rowwise_scale_inv is not None else None, + dst._columnwise_data.view(torch.uint8) if dst._columnwise_data is not None else None, + dst._columnwise_scale_inv.view(torch.uint8) if dst._columnwise_scale_inv is not None else None, + shuffle_rowwise_scale=True, + shuffle_colwise_scale=True, + shuffle_rowwise_fp4=self.shuffle_B_matrix_for_aiter, + shuffle_colwise_fp4=self.shuffle_B_matrix_for_aiter, + use_hadamard=self.use_hadamard, + ) + else: + from ..triton_kernels.cast import te_quantize_triton + te_quantize_triton(src, self, dst, noop_flag) - # Update FP4 dtype dst._fp4_dtype = self.dtype - return dst def is_quantizable(self, inp: torch.Tensor) -> bool: