From cd0f9fbccf3f01c35f723ba550a34ea257bedc01 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Wed, 13 May 2026 15:28:28 +0100 Subject: [PATCH 1/2] Add FP8 quantization support for the ONNX backend --- src/nncf/onnx/graph/model_transformer.py | 4 ++ src/nncf/onnx/quantization/quantize_model.py | 12 +++--- .../onnx/quantization/quantizer_parameters.py | 39 ++++++++++++++++++- .../algorithms/min_max/algorithm.py | 6 ++- .../algorithms/min_max/backend.py | 4 ++ .../algorithms/min_max/onnx_backend.py | 15 +++++-- .../algorithms/min_max/openvino_backend.py | 2 + .../algorithms/min_max/torch_backend.py | 2 + .../algorithms/min_max/torch_fx_backend.py | 2 + .../weight_compression/onnx_backend.py | 18 ++++++++- src/nncf/quantization/quantize_model.py | 1 - 11 files changed, 91 insertions(+), 14 deletions(-) diff --git a/src/nncf/onnx/graph/model_transformer.py b/src/nncf/onnx/graph/model_transformer.py index e9bc08df804..ae494bdccfe 100644 --- a/src/nncf/onnx/graph/model_transformer.py +++ b/src/nncf/onnx/graph/model_transformer.py @@ -284,13 +284,17 @@ def _get_scale_zero_point_tensors( dims = scale.shape if per_channel else [] onnx_scale = [scale.tolist()] if not per_channel else scale onnx_zero_point = [zero_point.tolist()] if not per_channel else zero_point + if tensor_type == np.uint8: onnx_tensor_type = onnx.TensorProto.UINT8 elif tensor_type == np.int8: onnx_tensor_type = onnx.TensorProto.INT8 + elif tensor_type in (onnx.TensorProto.FLOAT8E5M2, onnx.TensorProto.FLOAT8E4M3FN): + onnx_tensor_type = tensor_type else: msg = f"Incorrect tensor type - {tensor_type}." raise nncf.ValidationError(msg) + assert quantizer.input[1] == dequantizer.input[1] and quantizer.input[2] == dequantizer.input[2] scale_tensor_name = quantizer.input[1] zero_point_tensor_name = quantizer.input[2] diff --git a/src/nncf/onnx/quantization/quantize_model.py b/src/nncf/onnx/quantization/quantize_model.py index 4ec6b8c6111..d76179564c9 100644 --- a/src/nncf/onnx/quantization/quantize_model.py +++ b/src/nncf/onnx/quantization/quantize_model.py @@ -139,13 +139,14 @@ def quantize_impl( if target_device == TargetDevice.CPU_SPR: msg = "target_device == CPU_SPR is not supported." raise nncf.ValidationError(msg) - if mode is not None: - msg = f"mode={mode} is not supported" - raise ValueError(msg) - if model.opset_import[0].version < 10: + + opset_version = model.opset_import[0].version + if opset_version < 21 and mode is not None: + msg = f"FP8 quantization requires opset >= 21, got {opset_version}" + if opset_version < 10: msg = "ONNX models with opset version < 10 do not support quantization." raise nncf.ValidationError(msg) - if model.opset_import[0].version < 13: + if opset_version < 13: nncf_logger.warning( "ONNX models with 10 < opset version < 13 do not support per-channel quantization." " Per-tensor quantization will be applied." @@ -163,6 +164,7 @@ def quantize_impl( model = apply_preprocess_passes(model) quantization_algorithm = PostTrainingQuantization( + mode=mode, preset=preset, target_device=target_device, subset_size=subset_size, diff --git a/src/nncf/onnx/quantization/quantizer_parameters.py b/src/nncf/onnx/quantization/quantizer_parameters.py index 45ae65a758f..3677a4b9425 100644 --- a/src/nncf/onnx/quantization/quantizer_parameters.py +++ b/src/nncf/onnx/quantization/quantizer_parameters.py @@ -12,7 +12,10 @@ from dataclasses import dataclass import numpy as np +import onnx +from nncf.quantization.advanced_parameters import FP8Type +from nncf.quantization.fake_quantize import FakeConvertParameters from nncf.quantization.fake_quantize import FakeQuantizeParameters from nncf.quantization.fake_quantize import calculate_scale_zero_point from nncf.tensor import functions as fns @@ -31,10 +34,44 @@ class ONNXQuantizerLayerParameters: scale: np.ndarray zero_point: np.ndarray - tensor_type: np.dtype + tensor_type: onnx.TensorProto.DataType | np.dtype axis: int | None = None +def convert_fc_params_to_onnx_params( + parameters: FakeConvertParameters, axis: int | None +) -> ONNXQuantizerLayerParameters: + """ + Converts common FakeConvertParameters to ONNXQuantizerLayerParameters. + + :param parameters: FakeConvertParameters representation. + :param axis: Axis for per-channel quantization. + :return: Quantizer layer attributes. + """ + if parameters.destination_type == FP8Type.E4M3: + tensor_type = onnx.TensorProto.FLOAT8E4M3FN + elif parameters.destination_type == FP8Type.E5M2: + tensor_type = onnx.TensorProto.FLOAT8E5M2 + else: + msg = f"Unsupported FP8type: {parameters.destination_type}. Expected FP8Type.E4M3 or FP8Type.E5M2" + raise ValueError(msg) + + scale = parameters.scale + zero_point = parameters.shift + + # TODO(andrey-churkin): Check that scale and zero_point are calculated correctly. + + # NOTE: adding machine epsilon to avoid division by zero + eps = fns.finfo(scale).eps + scale = fns.where(fns.abs(scale) < eps, eps, scale) + scale = 1.0 / scale + # ONNX demands parameters to be a scalar or 1-D Tensor. + scale = fns.squeeze(scale) + zero_point = fns.squeeze(zero_point) + + return ONNXQuantizerLayerParameters(scale.data, zero_point.data, tensor_type, axis) + + def convert_fq_params_to_onnx_params( parameters: FakeQuantizeParameters, num_bits: int, tensor_type: np.dtype, axis: tuple[int] ) -> ONNXQuantizerLayerParameters: diff --git a/src/nncf/quantization/algorithms/min_max/algorithm.py b/src/nncf/quantization/algorithms/min_max/algorithm.py index 17d1d4ff840..fdf00745e23 100644 --- a/src/nncf/quantization/algorithms/min_max/algorithm.py +++ b/src/nncf/quantization/algorithms/min_max/algorithm.py @@ -1030,7 +1030,9 @@ def filter_func(point: StatisticPoint) -> bool: ) for quantization_target_point in unified_scale_group: transformation_layout.register( - self._backend_entity.create_convert_insertion_command(quantization_target_point, parameters) + self._backend_entity.create_convert_insertion_command( + graph, quantization_target_point, qconfig, parameters + ) ) unified_ops_list.add(quantization_target_point) continue @@ -1069,7 +1071,7 @@ def filter_func(point: StatisticPoint) -> bool: statistics, is_per_channel=qconfig.per_channel, destination_type=destination_type ) command = self._backend_entity.create_convert_insertion_command( - quantization_target_point, parameters + graph, quantization_target_point, qconfig, parameters ) else: parameters = calculate_quantizer_parameters(statistics, qconfig, quant_group, half_range) diff --git a/src/nncf/quantization/algorithms/min_max/backend.py b/src/nncf/quantization/algorithms/min_max/backend.py index bf1a235cdfa..180921ae34e 100644 --- a/src/nncf/quantization/algorithms/min_max/backend.py +++ b/src/nncf/quantization/algorithms/min_max/backend.py @@ -198,13 +198,17 @@ def create_unified_scales_quantizers_insertion_commands( @staticmethod @abstractmethod def create_convert_insertion_command( + nncf_graph: NNCFGraph, target_point: TargetPoint, + quantizer_config: QuantizerConfig, parameters: FakeConvertParameters, ) -> Command: """ Returns backend-specific convert insertion command. + :param nncf_graph: NNCFGraph to get input/output shapes for the target point. :param target_point: Target location for the correction. + :param quantizer_config: QuantizerConfig instance for the current layer. :param parameters: FakeConvertParameters to calculate activation quantization parameters. :return: Backend-specific Command for the quantizer insertion operation. """ diff --git a/src/nncf/quantization/algorithms/min_max/onnx_backend.py b/src/nncf/quantization/algorithms/min_max/onnx_backend.py index ee710e9b508..38b27bff70b 100644 --- a/src/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/src/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -12,7 +12,6 @@ import numpy as np -import nncf from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype @@ -33,6 +32,7 @@ from nncf.onnx.graph.transformations.commands import ONNXTargetPoint from nncf.onnx.hardware.config import ONNXHWConfig from nncf.onnx.quantization.default_quantization import DEFAULT_ONNX_QUANT_TRAIT_TO_OP_DICT +from nncf.onnx.quantization.quantizer_parameters import convert_fc_params_to_onnx_params from nncf.onnx.quantization.quantizer_parameters import convert_fq_params_to_onnx_params from nncf.parameters import ModelType from nncf.parameters import TargetDevice @@ -158,11 +158,20 @@ def create_unified_scales_quantizers_insertion_commands( @staticmethod def create_convert_insertion_command( + nncf_graph: NNCFGraph, target_point: ONNXTargetPoint, + quantizer_config: QuantizerConfig, parameters: FakeConvertParameters, ) -> TransformationCommand: - msg = "FakeConvert insertion not implemented in ONNX backend!" - raise nncf.InternalError(msg) + axis = None + if quantizer_config.per_channel: + node = nncf_graph.get_node_by_name(target_point.target_node_name) + axis = ( + get_weight_quantization_axis(node, target_point.port_id) if target_point.is_weight_target_point() else 1 + ) + onnx_parameters = convert_fc_params_to_onnx_params(parameters, axis) + nncf_input_node_next_nodes = ONNXMinMaxAlgoBackend._get_input_edges_mapping(nncf_graph) + return ONNXQuantizerInsertionCommand(target_point, nncf_input_node_next_nodes, onnx_parameters) @staticmethod def _get_input_edges_mapping(nncf_graph: NNCFGraph): diff --git a/src/nncf/quantization/algorithms/min_max/openvino_backend.py b/src/nncf/quantization/algorithms/min_max/openvino_backend.py index 04a84d33c9e..3fe65abddaf 100644 --- a/src/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/src/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -136,7 +136,9 @@ def create_unified_scales_quantizers_insertion_commands( @staticmethod def create_convert_insertion_command( + nncf_graph: NNCFGraph, target_point: OVTargetPoint, + quantizer_config: QuantizerConfig, parameters: FakeConvertParameters, ) -> OVQuantizerInsertionCommand: return OVConvertInsertionCommand(target_point, parameters) diff --git a/src/nncf/quantization/algorithms/min_max/torch_backend.py b/src/nncf/quantization/algorithms/min_max/torch_backend.py index 9e396bbdea9..4fddf7eea04 100644 --- a/src/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/src/nncf/quantization/algorithms/min_max/torch_backend.py @@ -149,7 +149,9 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) - @staticmethod def create_convert_insertion_command( + nncf_graph: NNCFGraph, target_point: PTTargetPoint, + quantizer_config: QuantizerConfig, parameters: FakeConvertParameters, ) -> TransformationCommand: msg = "FakeConvert insertion not implemented in PyTorch backend!" diff --git a/src/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/src/nncf/quantization/algorithms/min_max/torch_fx_backend.py index b166950a019..071ff56c2d4 100644 --- a/src/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/src/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -137,7 +137,9 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) - @staticmethod def create_convert_insertion_command( + nncf_graph: NNCFGraph, target_point: PTTargetPoint, + quantizer_config: QuantizerConfig, parameters: FakeConvertParameters, ) -> TransformationCommand: msg = "FakeConvert insertion not implemented in PyTorch backend!" diff --git a/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py b/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py index 7c4a1f92ef4..d58eb65df89 100644 --- a/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/onnx_backend.py @@ -71,6 +71,7 @@ class ONNXWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): CompressWeightsMode.INT8_ASYM: onnx.TensorProto.UINT8, CompressWeightsMode.INT4_SYM: onnx.TensorProto.INT4, CompressWeightsMode.INT4_ASYM: onnx.TensorProto.UINT4, + CompressWeightsMode.FP8_E4M3: onnx.TensorProto.FLOAT8E4M3FN, } def __init__(self, model: onnx.ModelProto): @@ -363,8 +364,14 @@ def _add_dequantize_linear_layer( zero_point = pack_4_bits(zero_point) # Create initializers for the quantized weights, scale, and zero point + if weight_dtype == onnx.TensorProto.FLOAT8E4M3FN: + np_dtype = helper.tensor_dtype_to_np_dtype(weight_dtype) + vals = onnx.numpy_helper.saturate_cast(np.asarray(quantized_weights), np_dtype).flatten() + else: + vals = quantized_weights + quantized_weights_initializer = onnx.helper.make_tensor( - quantized_weight_name, weight_dtype, orig_shape, quantized_weights.tobytes(), raw=True + quantized_weight_name, weight_dtype, orig_shape, vals.tobytes(), raw=True ) scale_initializer = numpy_helper.from_array( np.array(scale, dtype=helper.tensor_dtype_to_np_dtype(scale_dtype)), name=scale_name @@ -374,8 +381,15 @@ def _add_dequantize_linear_layer( if zero_point is not None: deq_inputs.append(weight_name + "_zero_point") + + if weight_dtype == onnx.TensorProto.FLOAT8E4M3FN: + np_dtype = helper.tensor_dtype_to_np_dtype(weight_dtype) + vals = onnx.numpy_helper.saturate_cast(np.asarray(zero_point), np_dtype).flatten() + else: + vals = zero_point + zero_point_initializer = onnx.helper.make_tensor( - weight_name + "_zero_point", weight_dtype, orig_zero_point_shape, zero_point.tobytes(), raw=True + weight_name + "_zero_point", weight_dtype, orig_zero_point_shape, vals.tobytes(), raw=True ) new_initializers.append(zero_point_initializer) diff --git a/src/nncf/quantization/quantize_model.py b/src/nncf/quantization/quantize_model.py index 6ce83d496ef..e507673b200 100644 --- a/src/nncf/quantization/quantize_model.py +++ b/src/nncf/quantization/quantize_model.py @@ -627,7 +627,6 @@ def compress_weights( CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.MXFP8_E4M3, - CompressWeightsMode.FP8_E4M3, CompressWeightsMode.FP4, CompressWeightsMode.NVFP4, CompressWeightsMode.CODEBOOK, From 1dc817f9e72071b2fb1cb06f7fb5488dd1407659 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Fri, 15 May 2026 09:45:13 +0100 Subject: [PATCH 2/2] update test --- tests/onnx/quantization/test_weights_compression.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/onnx/quantization/test_weights_compression.py b/tests/onnx/quantization/test_weights_compression.py index 0598b08d229..2f354ae7414 100644 --- a/tests/onnx/quantization/test_weights_compression.py +++ b/tests/onnx/quantization/test_weights_compression.py @@ -49,7 +49,6 @@ CompressWeightsMode.NVFP4, CompressWeightsMode.MXFP4, CompressWeightsMode.MXFP8_E4M3, - CompressWeightsMode.FP8_E4M3, CompressWeightsMode.FP4, )