From 42765969eceb49f203b3cc67f0d336010a6d7699 Mon Sep 17 00:00:00 2001 From: Nikolay Date: Fri, 6 Mar 2026 16:11:53 +0100 Subject: [PATCH] [Torch] INT2 symmetric decompression support --- .gitignore | 6 + .../distillation_qat_with_lora/lm-eval.sh | 8 + .../main_export_to_openvino.py | 50 +++ .../test_synthetic_int2_export.py | 152 +++++++ src/nncf/parameters.py | 3 + .../algorithms/weight_compression/config.py | 27 +- .../weight_compression/constants.py | 1 + src/nncf/tensor/definitions.py | 2 + src/nncf/torch/function_hook/strip.py | 2 +- src/nncf/torch/quantization/layers.py | 58 +++ .../torch/quantization/quantize_functions.py | 50 ++- src/nncf/torch/quantization/strip.py | 16 +- .../quantization/test_weights_compression.py | 49 +++ .../torch/quantization/test_symmetric_2bit.py | 378 ++++++++++++++++++ 14 files changed, 787 insertions(+), 15 deletions(-) create mode 100755 examples/llm_compression/torch/distillation_qat_with_lora/lm-eval.sh create mode 100644 examples/llm_compression/torch/distillation_qat_with_lora/main_export_to_openvino.py create mode 100644 examples/llm_compression/torch/distillation_qat_with_lora/test_synthetic_int2_export.py create mode 100644 tests/torch/quantization/test_symmetric_2bit.py diff --git a/.gitignore b/.gitignore index 03a23b98880..550288fac19 100644 --- a/.gitignore +++ b/.gitignore @@ -138,3 +138,9 @@ nncf-tests.xml compressed_graph.dot original_graph.dot tests/post_training/**/*memory_logs + +output_* +*eval* +debug* +*.pth +*.db \ No newline at end of file diff --git a/examples/llm_compression/torch/distillation_qat_with_lora/lm-eval.sh b/examples/llm_compression/torch/distillation_qat_with_lora/lm-eval.sh new file mode 100755 index 00000000000..1300b296695 --- /dev/null +++ b/examples/llm_compression/torch/distillation_qat_with_lora/lm-eval.sh @@ -0,0 +1,8 @@ +IR_DIR=u2_u4_ov_model +lm_eval \ +--model openvino \ +--model_args pretrained=$IR_DIR \ +--device cpu \ +--output_path ov_eval \ +--limit 100 \ +--tasks lambada_openai diff --git a/examples/llm_compression/torch/distillation_qat_with_lora/main_export_to_openvino.py b/examples/llm_compression/torch/distillation_qat_with_lora/main_export_to_openvino.py new file mode 100644 index 00000000000..873184c53f9 --- /dev/null +++ b/examples/llm_compression/torch/distillation_qat_with_lora/main_export_to_openvino.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path + +import torch +from optimum.exporters.openvino.convert import export_from_model +from torch import nn +from transformers import AutoModelForCausalLM + +import nncf +from nncf.parameters import StripFormat +from nncf.torch.function_hook.wrapper import get_hook_storage +from nncf.torch.model_creation import load_from_config +from nncf.torch.quantization.layers import SymmetricLoraQuantizer # noqa: F401 + + +def load_checkpoint(model: nn.Module, ckpt_file: Path) -> nn.Module: + """ + Loads the state of a tuned model from a checkpoint. This function restores the placement of Fake Quantizers (FQs) + with absorbable LoRA adapters and loads their parameters. + + :param model: The model to load the checkpoint into. + :param ckpt_file: Path to the checkpoint file. + :returns: The model with the loaded NNCF state from checkpoint. + """ + ckpt = torch.load(ckpt_file, weights_only=False, map_location="cpu") + model = load_from_config(model, ckpt["nncf_config"]) + if "model_state" in ckpt: + model.load_state_dict(ckpt["model_state"]) + hook_storage = get_hook_storage(model) + hook_storage.load_state_dict(ckpt["nncf_state_dict"]) + return model + + +pretrained = "Qwen/Qwen3-4B" +ckpt_file = "nncf_checkpoint_epoch10.pth" +ir_dir = "u2_u4_ov_model" +with torch.no_grad(): + model_to_eval = AutoModelForCausalLM.from_pretrained(pretrained, torch_dtype=torch.float32, device_map="cpu") + model_to_eval = load_checkpoint(model_to_eval, ckpt_file) + model_to_eval = nncf.strip(model_to_eval, do_copy=False, strip_format=StripFormat.DQ) + export_from_model(model_to_eval, ir_dir, device="cpu") diff --git a/examples/llm_compression/torch/distillation_qat_with_lora/test_synthetic_int2_export.py b/examples/llm_compression/torch/distillation_qat_with_lora/test_synthetic_int2_export.py new file mode 100644 index 00000000000..3f998d1379f --- /dev/null +++ b/examples/llm_compression/torch/distillation_qat_with_lora/test_synthetic_int2_export.py @@ -0,0 +1,152 @@ +# Copyright (c) 2026 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Synthetic test to verify INT2 symmetric decompression subgraph +can be exported to OpenVINO IR via torch.jit.trace + openvino.convert_model. +""" + +import numpy as np +import openvino as ov +import torch + + +def pack_uint2(tensor: torch.Tensor) -> torch.Tensor: + packed_tensor = tensor.contiguous().reshape(-1, 4) + packed_tensor = ( + torch.bitwise_and(packed_tensor[..., 0], 3) + | (torch.bitwise_and(packed_tensor[..., 1], 3) << 2) + | (torch.bitwise_and(packed_tensor[..., 2], 3) << 4) + | (torch.bitwise_and(packed_tensor[..., 3], 3) << 6) + ) + return packed_tensor + + +def unpack_uint2(packed_tensor: torch.Tensor) -> torch.Tensor: + return torch.stack( + ( + torch.bitwise_and(packed_tensor, 3), + torch.bitwise_and(torch.bitwise_right_shift(packed_tensor, 2), 3), + torch.bitwise_and(torch.bitwise_right_shift(packed_tensor, 4), 3), + torch.bitwise_and(torch.bitwise_right_shift(packed_tensor, 6), 3), + ), + dim=-1, + ) + + +def decompress_symmetric(input, scale): + input = input.type(dtype=scale.dtype) + return input * scale + + +class INT2SymmetricLinear(torch.nn.Module): + """ + A simple linear layer that uses INT2 symmetric weight decompression, + matching the NNCF INT2SymmetricWeightsDecompressor pattern. + """ + + ZERO_POINT_VALUE = 2 + + def __init__(self, in_features, out_features, group_size): + super().__init__() + assert out_features % group_size == 0 + ngroups = out_features // group_size + + compressed_weight_shape = (ngroups, group_size, in_features) + scale_shape = (ngroups, 1, in_features) + + # Random uint2 weights [0, 3] + rng = np.random.default_rng(seed=42) + raw_weights = rng.integers(0, 4, size=compressed_weight_shape, dtype=np.uint8) + scale = (rng.random(scale_shape, dtype=np.float32) * 2.0 - 1.0).astype(np.float32) + + self.compressed_weight_shape = compressed_weight_shape + self.packed_weight = torch.nn.Parameter(pack_uint2(torch.from_numpy(raw_weights)), requires_grad=False) + self.register_buffer("_scale", torch.from_numpy(scale).to(torch.float16)) + self.register_buffer("_zero_point", torch.tensor(self.ZERO_POINT_VALUE, dtype=torch.uint8)) + self.result_shape = (out_features, in_features) + self.result_dtype = torch.float32 + + def forward(self, x): + # NNCF INT2 symmetric decompression pattern + w = unpack_uint2(self.packed_weight) + w = w.reshape(self.compressed_weight_shape) + w = w.type(dtype=self.result_dtype) - self._zero_point.type(dtype=self.result_dtype) + w = decompress_symmetric(w, self._scale) + w = w.reshape(self.result_shape) + w = w.type(dtype=self.result_dtype) + return torch.matmul(x, w.t()) + + +class SmallModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = INT2SymmetricLinear(16, 32, group_size=4) + self.linear2 = INT2SymmetricLinear(32, 16, group_size=4) + + def forward(self, x): + x = self.linear1(x) + x = torch.relu(x) + x = self.linear2(x) + return x + + +def main(): + print("=== Synthetic INT2 export test ===") + model = SmallModel() + model.eval() + + dummy_input = torch.randn(1, 16) + + # Step 1: Convert to OpenVINO IR + print("[1/4] Converting to OpenVINO IR...") + ov_model = ov.convert_model(model, example_input=dummy_input) + print(" Conversion successful.") + + # Step 2: Check u2 constants in the converted OV model + print("[2/4] Checking u2 constants in OV model...") + u2_constants = [] + for op in ov_model.get_ordered_ops(): + if op.get_type_name() == "Constant" and "uint2" in str(op.get_output_element_type(0)): + u2_constants.append(op) + + expected_u2_count = 2 # one per INT2SymmetricLinear layer + print(f" Found {len(u2_constants)} u2 constant(s) (expected {expected_u2_count}).") + for c in u2_constants: + print(f" - {c.get_friendly_name()}: shape={c.get_output_partial_shape(0)}") + assert len(u2_constants) == expected_u2_count, f"Expected {expected_u2_count} u2 constants, got {len(u2_constants)}" + print(" PASSED - u2 constants detected.") + + # Step 3: Save IR + ir_path = "/tmp/test_int2_synthetic_ir" + print(f"[3/4] Saving IR to {ir_path}...") + ov.save_model(ov_model, f"{ir_path}/model.xml") + print(" Save successful.") + + # Step 4: Verify inference + print("[4/4] Running inference comparison...") + with torch.no_grad(): + torch_out = model(dummy_input).numpy() + + compiled = ov.Core().compile_model(ov_model, "CPU") + ov_out = compiled(dummy_input.numpy())[0] + + max_diff = np.max(np.abs(torch_out - ov_out)) + print(f" Max absolute difference: {max_diff:.6e}") + if max_diff < 1e-2: + print(" PASSED - Outputs match within tolerance.") + else: + print(f" WARNING - Large difference detected: {max_diff}") + + print("\n=== All steps completed successfully! ===") + + +if __name__ == "__main__": + main() diff --git a/src/nncf/parameters.py b/src/nncf/parameters.py index 0729465dbf4..d02940d4d6d 100644 --- a/src/nncf/parameters.py +++ b/src/nncf/parameters.py @@ -88,6 +88,8 @@ class CompressWeightsMode(StrEnum): :param INT4_ASYM: The same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically with a typical non-fixed zero point. https://github.com/openvinotoolkit/nncf/blob/develop/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#asymmetric-quantization + :param INT2_SYM: Stands for 2-bit integer symmetric quantization without zero point. + Similar to INT4_SYM but with a 2-bit primary precision. :param NF4: The the same as INT4_SYM mode, but primary precision is NF4 data type without zero point. :param INT8: Mode is deprecated and will be removed in future releases. Please use `INT8_ASYM` instead. :param MXFP4: MX-compliant FP4 format with E2M1 values sharing group-level E8M0 scale. The size of group is 32. @@ -103,6 +105,7 @@ class CompressWeightsMode(StrEnum): INT8_ASYM = "int8_asym" INT4_SYM = "int4_sym" INT4_ASYM = "int4_asym" + INT2_SYM = "int2_sym" NF4 = "nf4" CB4 = "cb4" INT8 = "int8" # Deprecated mode diff --git a/src/nncf/quantization/algorithms/weight_compression/config.py b/src/nncf/quantization/algorithms/weight_compression/config.py index 99b5b38e391..a0d0c651078 100644 --- a/src/nncf/quantization/algorithms/weight_compression/config.py +++ b/src/nncf/quantization/algorithms/weight_compression/config.py @@ -46,18 +46,26 @@ def num_bits(self): """ :return: number of bits that is used for storing a single quantized value in the given mode. """ - if self.mode in [ - CompressWeightsMode.INT8_SYM, - CompressWeightsMode.INT8_ASYM, - CompressWeightsMode.FP8_E4M3, - CompressWeightsMode.MXFP8_E4M3, - ]: - return 8 - return 4 + return { + CompressWeightsMode.INT8_SYM: 8, + CompressWeightsMode.INT8_ASYM: 8, + CompressWeightsMode.FP8_E4M3: 8, + CompressWeightsMode.MXFP8_E4M3: 8, + CompressWeightsMode.INT4_SYM: 4, + CompressWeightsMode.INT4_ASYM: 4, + CompressWeightsMode.NF4: 4, + CompressWeightsMode.MXFP4: 4, + CompressWeightsMode.FP4: 4, + CompressWeightsMode.CB4: 4, + CompressWeightsMode.INT2_SYM: 2, + }.get(self.mode, 4) @property def is_asym_mode(self): - return self.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT8_ASYM] + return self.mode in [ + CompressWeightsMode.INT4_ASYM, + CompressWeightsMode.INT8_ASYM, + ] @property def is_integer(self): @@ -101,6 +109,7 @@ def compression_dtype(self) -> TensorDataType: dtype_per_mode = { CompressWeightsMode.INT4_SYM: TensorDataType.int4, CompressWeightsMode.INT4_ASYM: TensorDataType.uint4, + CompressWeightsMode.INT2_SYM: TensorDataType.int2, CompressWeightsMode.INT8_ASYM: TensorDataType.uint8, CompressWeightsMode.INT8_SYM: TensorDataType.int8, CompressWeightsMode.NF4: TensorDataType.nf4, diff --git a/src/nncf/quantization/algorithms/weight_compression/constants.py b/src/nncf/quantization/algorithms/weight_compression/constants.py index 449c1b67ff9..370c6f0a03b 100644 --- a/src/nncf/quantization/algorithms/weight_compression/constants.py +++ b/src/nncf/quantization/algorithms/weight_compression/constants.py @@ -120,6 +120,7 @@ CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT4_SYM, + CompressWeightsMode.INT2_SYM, ) OPTIMIZED_COMPRESSION_COMPATIBLE_FLOAT_MODES = ( diff --git a/src/nncf/tensor/definitions.py b/src/nncf/tensor/definitions.py index b38e6fbd025..ef4448a2260 100644 --- a/src/nncf/tensor/definitions.py +++ b/src/nncf/tensor/definitions.py @@ -53,6 +53,7 @@ class TensorDataType(StrEnum): uint8 = auto() uint4 = auto() int4 = auto() + int2 = auto() def is_float(self) -> bool: """ @@ -78,6 +79,7 @@ def itemsize(self) -> int: TensorDataType.nf4: 4, TensorDataType.uint4: 4, TensorDataType.int4: 4, + TensorDataType.int2: 2, TensorDataType.f8e4m3: 8, TensorDataType.f8e5m2: 8, TensorDataType.int8: 8, diff --git a/src/nncf/torch/function_hook/strip.py b/src/nncf/torch/function_hook/strip.py index 2a7d64b4221..464722b63d8 100644 --- a/src/nncf/torch/function_hook/strip.py +++ b/src/nncf/torch/function_hook/strip.py @@ -126,7 +126,7 @@ def replace_quantizer_to_compressed_weight_with_decompressor(model: TModel) -> T msg = "" if hook_module._qspec.half_range or hook_module._qspec.narrow_range: msg += "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False.\n" - if hook_module.num_bits not in [4, 8]: + if hook_module.num_bits not in [2, 4, 8]: msg += f"Unsupported number of bits {hook_module.num_bits} for the quantizer {hook_module}.\n" if msg: raise nncf.ValidationError(msg) diff --git a/src/nncf/torch/quantization/layers.py b/src/nncf/torch/quantization/layers.py index 50c76ccabdc..95063e5fa74 100644 --- a/src/nncf/torch/quantization/layers.py +++ b/src/nncf/torch/quantization/layers.py @@ -47,10 +47,12 @@ from nncf.torch.quantization.quantize_functions import decompress_symmetric from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high from nncf.torch.quantization.quantize_functions import pack_int4 +from nncf.torch.quantization.quantize_functions import pack_uint2 from nncf.torch.quantization.quantize_functions import pack_uint4 from nncf.torch.quantization.quantize_functions import symmetric_quantize from nncf.torch.quantization.quantize_functions import symmetric_quantize_lora from nncf.torch.quantization.quantize_functions import unpack_int4 +from nncf.torch.quantization.quantize_functions import unpack_uint2 from nncf.torch.quantization.quantize_functions import unpack_uint4 from nncf.torch.return_types import maybe_get_values_from_torch_return_type from nncf.torch.return_types import maybe_wrap_to_torch_return_type @@ -1467,6 +1469,62 @@ def forward(self, x): return result +class INT2SymmetricWeightsDecompressor(BaseWeightsDecompressor): + """ + Applies symmetric decompression of 2-bit compressed weights in the forward pass. + + Weights with values in [-2, -1, 0, 1] are stored as uint2 [0, 1, 2, 3] using + a hardcoded zero point of 2. Four uint2 values are packed into each uint8 byte. + """ + + ZERO_POINT_VALUE = 2 + + def __init__( + self, + scale: torch.Tensor, + compressed_weight_shape: tuple[int, ...], + result_shape: tuple[int, ...] | None = None, + result_dtype: torch.dtype | None = None, + ): + """ + :param scale: A scale in quantization scheme + :param compressed_weight_shape: A compressed weight shape + :param result_shape: (Optional) A shape that result should be reshaped to + :param result_dtype: (Optional) A data type that result should be cast to + """ + super().__init__() + self.register_buffer("_scale", scale.type(dtype=torch.float16)) + self.register_buffer( + "_zero_point", + torch.tensor(self.ZERO_POINT_VALUE, dtype=torch.uint8), + ) + + self.compressed_weight_shape = compressed_weight_shape + self.result_shape = result_shape + self.result_dtype = result_dtype + + @property + def quantization_mode(self) -> QuantizationMode: + return QuantizationMode.SYMMETRIC + + def pack_weight(self, weight: torch.Tensor) -> torch.Tensor: + if torch.any((weight < 0) | (weight > 3)): + msg = "Weight values are not in [0, 3]." + raise ValueError(msg) + return pack_uint2(weight.type(dtype=torch.uint8)) + + def forward(self, x): + x = unpack_uint2(x) + x = x.reshape(self.compressed_weight_shape) + + x = x.type(dtype=self.result_dtype) - self._zero_point.type(dtype=self.result_dtype) + + result = decompress_symmetric(x, self._scale) + result = result.reshape(self.result_shape) if self.result_shape is not None else result + result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result + return result + + @COMPRESSION_MODULES.register() class SQMultiply(torch.nn.Module, StatefulModuleInterface): SCALE_SHAPE_KEY = "scale_shape" diff --git a/src/nncf/torch/quantization/quantize_functions.py b/src/nncf/torch/quantization/quantize_functions.py index 047b376164a..0d95746eb45 100644 --- a/src/nncf/torch/quantization/quantize_functions.py +++ b/src/nncf/torch/quantization/quantize_functions.py @@ -121,10 +121,12 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any: class QuantizeSymmetricTorch(torch.autograd.Function): @staticmethod def forward(ctx, input_, input_shape, scale, level_low, level_high, levels): - # range: [-scale, 7/8 * scale] if scale > 0 else [7/8 * scale, -scale] + # Signed-scale formula: scale sign selects which side gets more quants. + # scale > 0 → range [-scale, level_high/|level_low| * scale] (more quants for negatives) + # scale < 0 → range [level_high/|level_low| * scale, -scale] (more quants for positives) + # Works for any bit-width (2, 4, 8, …). input_low = torch.where(scale > 0, -scale, -scale / level_low * level_high) - # 15/8 * scale or (2-1/8) * scale - input_range = torch.abs((2 + 1 / level_low) * scale) + input_range = (levels - 1) * torch.abs(scale) / (-level_low) dtype = input_.dtype original_shape = input_.shape input_ = input_.reshape(input_shape) @@ -471,3 +473,45 @@ def unpack_int4(packed_tensor: torch.Tensor) -> torch.Tensor: """ t = unpack_uint4(packed_tensor) return t.type(torch.int8) - 8 + + +def pack_uint2(tensor: torch.Tensor) -> torch.Tensor: + """ + Packs a tensor containing uint2 values (in the range [0, 3]) into a tensor with uint8 values, + where each element stores four uint2 values. + + :param tensor: A tensor of dtype `torch.uint8` where each element represents a uint2 value. + The tensor should contain values in the range [0, 3]. + :return: A packed tensor of dtype `torch.uint8` where each element packs four uint2 values. + :raises nncf.errors.ValidationError: If the input tensor is not of type `torch.uint8`. + """ + if tensor.dtype != torch.uint8: + msg = f"Invalid tensor dtype {tensor.type}. torch.uint8 type is supported." + raise ValidationError(msg) + packed_tensor = tensor.contiguous().reshape(-1, 4) + packed_tensor = ( + torch.bitwise_and(packed_tensor[..., 0], 3) + | (torch.bitwise_and(packed_tensor[..., 1], 3) << 2) + | (torch.bitwise_and(packed_tensor[..., 2], 3) << 4) + | (torch.bitwise_and(packed_tensor[..., 3], 3) << 6) + ) + return packed_tensor + + +def unpack_uint2(packed_tensor: torch.Tensor) -> torch.Tensor: + """ + Unpacks a tensor, where each uint8 element stores four uint2 values, back into a tensor with + individual uint2 values. + + :param packed_tensor: A tensor of dtype `torch.uint8` where each element packs four uint2 values. + :return: A tensor of dtype `torch.uint8` where each element represents a uint2 value. + """ + return torch.stack( + ( + torch.bitwise_and(packed_tensor, 3), + torch.bitwise_and(torch.bitwise_right_shift(packed_tensor, 2), 3), + torch.bitwise_and(torch.bitwise_right_shift(packed_tensor, 4), 3), + torch.bitwise_and(torch.bitwise_right_shift(packed_tensor, 6), 3), + ), + dim=-1, + ) diff --git a/src/nncf/torch/quantization/strip.py b/src/nncf/torch/quantization/strip.py index 09eebaa7b48..fa45772676f 100644 --- a/src/nncf/torch/quantization/strip.py +++ b/src/nncf/torch/quantization/strip.py @@ -19,6 +19,7 @@ import nncf from nncf.torch.quantization.layers import AsymmetricQuantizer from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import INT2SymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor @@ -26,7 +27,7 @@ from nncf.torch.quantization.layers import SymmetricQuantizer from nncf.torch.quantization.quantize_functions import TuneRange -SUPPORTED_NUM_BITS_FOR_STRIP_MODEL = [8] +SUPPORTED_NUM_BITS_FOR_STRIP_MODEL = [8, 4, 2] def convert_to_torch_fakequantizer(nncf_quantizer: BaseQuantizer) -> FakeQuantize: @@ -145,7 +146,9 @@ def asym_fq_to_decompressor( def sym_fq_to_decompressor( quantizer: SymmetricQuantizer, weight: torch.Tensor -) -> tuple[INT8SymmetricWeightsDecompressor | INT4SymmetricWeightsDecompressor, torch.Tensor]: +) -> tuple[ + INT8SymmetricWeightsDecompressor | INT4SymmetricWeightsDecompressor | INT2SymmetricWeightsDecompressor, torch.Tensor +]: """ Converts an asymmetric quantizer and original weight tensor to a decompressor and quantized weight tensor. @@ -178,6 +181,15 @@ def sym_fq_to_decompressor( if quantizer.num_bits == 8: decompressor = INT8SymmetricWeightsDecompressor(scale=scale, result_dtype=weight_dtype) + elif quantizer.num_bits == 2: + # Shift signed weights to unsigned: [-2, 1] -> [0, 3] + q_weight = (q_weight + 2).to(torch.uint8) + decompressor = INT2SymmetricWeightsDecompressor( + scale=scale, + compressed_weight_shape=q_weight.shape, + result_shape=weight_shape, + result_dtype=weight_dtype, + ) else: decompressor = INT4SymmetricWeightsDecompressor( scale=scale, diff --git a/tests/torch/function_hook/quantization/test_weights_compression.py b/tests/torch/function_hook/quantization/test_weights_compression.py index 8ab004d2c09..371e27b9085 100644 --- a/tests/torch/function_hook/quantization/test_weights_compression.py +++ b/tests/torch/function_hook/quantization/test_weights_compression.py @@ -32,13 +32,16 @@ from nncf.torch.function_hook import get_hook_storage from nncf.torch.function_hook import wrap_model from nncf.torch.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper +from nncf.torch.quantization.layers import INT2SymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor from nncf.torch.quantization.quantize_functions import pack_int4 +from nncf.torch.quantization.quantize_functions import pack_uint2 from nncf.torch.quantization.quantize_functions import pack_uint4 from nncf.torch.quantization.quantize_functions import unpack_int4 +from nncf.torch.quantization.quantize_functions import unpack_uint2 from nncf.torch.quantization.quantize_functions import unpack_uint4 from tests.cross_fw.test_templates.helpers import RoPEModel from tests.cross_fw.test_templates.helpers import SAMPEModel @@ -567,6 +570,52 @@ def test_pack_int4(): assert torch.all(unpacked_w == w_int8) +def test_pack_uint2(): + w_uint8 = torch.randint(0, 4, (4, 4), dtype=torch.uint8) + packed_w = pack_uint2(w_uint8) + assert packed_w.dtype == torch.uint8 + assert packed_w.numel() * 4 == w_uint8.numel() + unpacked_w = unpack_uint2(packed_w).reshape(w_uint8.shape) + assert torch.all(unpacked_w == w_uint8) + + +def test_pack_uint2_single_value(): + """pack_uint2 requires multiples of 4 elements (4 uint2 per uint8 byte). + A single u2 value cannot be packed alone — verify that 4 identical values + round-trip correctly through pack/unpack.""" + # 4 copies of the value 2 (the minimum packable unit) + w_uint8 = torch.tensor([2, 2, 2, 2], dtype=torch.uint8) + packed_w = pack_uint2(w_uint8) + assert packed_w.shape == (1,) # 4 uint2 -> 1 uint8 + # Expected: 2 | (2<<2) | (2<<4) | (2<<6) = 2 + 8 + 32 + 128 = 170 = 0xAA + assert packed_w.item() == 0xAA + unpacked_w = unpack_uint2(packed_w).reshape(w_uint8.shape) + assert torch.all(unpacked_w == w_uint8) + + +def test_pack_uint2_all_values(): + """Verify all four possible uint2 values [0, 1, 2, 3] pack and unpack correctly.""" + w_uint8 = torch.tensor([0, 1, 2, 3], dtype=torch.uint8) + packed_w = pack_uint2(w_uint8) + assert packed_w.shape == (1,) + # Expected: 0 | (1<<2) | (2<<4) | (3<<6) = 0 + 4 + 32 + 192 = 228 = 0xE4 + assert packed_w.item() == 0xE4 + unpacked_w = unpack_uint2(packed_w).reshape(w_uint8.shape) + assert torch.all(unpacked_w == w_uint8) + + +def test_int2_symmetric_weights_decompressor(): + scale = torch.tensor([[0.5], [1.0]], dtype=torch.float32) + weight_signed = torch.tensor([[-2, -1, 0, 1], [-1, 0, 1, -2]], dtype=torch.int8) + weight_unsigned = (weight_signed + 2).to(torch.uint8) + + decompressor = INT2SymmetricWeightsDecompressor(scale, compressed_weight_shape=(2, 4), result_dtype=torch.float32) + packed_w = decompressor.pack_weight(weight_unsigned) + result = decompressor(packed_w) + expected = weight_signed.float() * scale.float() + assert torch.allclose(result, expected, atol=1e-3) + + class TestPTTemplateWeightCompression(TemplateWeightCompression): @staticmethod def get_matmul_model() -> torch.nn.Module: diff --git a/tests/torch/quantization/test_symmetric_2bit.py b/tests/torch/quantization/test_symmetric_2bit.py new file mode 100644 index 00000000000..7e09ef9fdda --- /dev/null +++ b/tests/torch/quantization/test_symmetric_2bit.py @@ -0,0 +1,378 @@ +# Copyright (c) 2026 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for symmetric quantization at 2-bit (and other bit-widths). + +QuantizeSymmetric and QuantizeSymmetricTorch use different scale conventions: + + QuantizeSymmetric: + scale = level_high × step_size (max representable positive value) + input_low = scale * (level_low / level_high) + input_range = scale - input_low + + QuantizeSymmetricTorch (signed-scale, from _calculate_signed_scale): + scale = ±|level_low| × step_size (sign selects which side gets more quants) + input_low differs by sign of scale (flips asymmetry) + input_range = (levels - 1) * |scale| / |level_low| + +Both are correct; they just interpret `scale` differently. +QuantizeSymmetricTorch's signed-scale design allocates more quants to the +dominant side (positive or negative), which is especially important at low +bit-widths where a single extra level matters. +""" + +import pytest +import torch + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# 2-bit signed: levels=4, level_low=-2, level_high=1 +PARAMS_2BIT = dict(level_low=-2, level_high=1, levels=4) +# 4-bit signed: levels=16, level_low=-8, level_high=7 +PARAMS_4BIT = dict(level_low=-8, level_high=7, levels=16) +# 8-bit signed: levels=256, level_low=-128, level_high=127 +PARAMS_8BIT = dict(level_low=-128, level_high=127, levels=256) + + +def _step_size(scale_val: float, level_low: int) -> float: + """Compute the step size for QuantizeSymmetricTorch's scale convention.""" + return abs(scale_val) / (-level_low) + + +def _expected_levels(scale_val: float, level_low: int, level_high: int, **_) -> list[float]: + """ + Compute expected dequantized output levels for QuantizeSymmetricTorch. + + For scale > 0: levels are {level_low, …, level_high} * step, more quants on negative side. + For scale < 0: levels are {-level_high, …, -level_low} * step, more quants on positive side. + """ + step = _step_size(scale_val, level_low) + if scale_val > 0: + return [i * step for i in range(level_low, level_high + 1)] + return [i * step for i in range(-level_high, -level_low + 1)] + + +def _run_quantize_symmetric(input_tensor, scale, level_low, level_high, levels): + """Run QuantizeSymmetric (original, CUDA/CPU extension).""" + from nncf.torch.quantization.quantize_functions import QuantizeSymmetric + + return QuantizeSymmetric.apply(input_tensor, scale, level_low, level_high, levels) + + +def _run_quantize_symmetric_torch(input_tensor, scale, level_low, level_high, levels): + """Run QuantizeSymmetricTorch (pure-torch, used for LoRA).""" + from nncf.torch.quantization.quantize_functions import QuantizeSymmetricTorch + + input_shape = input_tensor.shape + return QuantizeSymmetricTorch.apply(input_tensor, input_shape, scale, level_low, level_high, levels) + + +# --------------------------------------------------------------------------- +# Tests: QuantizeSymmetric (baseline) at 2-bit +# --------------------------------------------------------------------------- + + +class TestQuantizeSymmetric2Bit: + """Verify QuantizeSymmetric (original) is correct at 2-bit.""" + + @pytest.mark.parametrize( + "input_val, expected", + [ + (-3.0, -2.0), + (-2.0, -2.0), + (-1.5, -2.0), + (-1.0, -1.0), + (-0.5, 0.0), + (0.0, 0.0), + (0.5, 0.0), + (1.0, 1.0), + (1.5, 1.0), + ], + ) + def test_scale_1(self, input_val, expected): + scale = torch.tensor(1.0) + x = torch.tensor(input_val) + out = _run_quantize_symmetric(x, scale, **PARAMS_2BIT) + assert out.item() == pytest.approx(expected, abs=1e-5) + + +# --------------------------------------------------------------------------- +# Tests: QuantizeSymmetricTorch at 2-bit — correct with its own scale convention +# --------------------------------------------------------------------------- + + +class TestQuantizeSymmetricTorch2Bit: + """ + QuantizeSymmetricTorch at 2-bit with the signed-scale convention. + + scale is produced by _calculate_signed_scale: scale = max_abs / |level_low| + step_size = |scale| / |level_low| + + For 2-bit (level_low=-2, level_high=1) with scale=1.0 (positive -> negatives dominate): + step = 1.0/2 = 0.5 + range: [-scale, level_high/|level_low| * scale] = [-1.0, 0.5] + output levels: {-2, -1, 0, 1} * 0.5 = {-1.0, -0.5, 0.0, 0.5} + -> 2 levels negative, 1 level positive (more quants for negatives) + """ + + @pytest.mark.parametrize( + "input_val, expected", + [ + # scale=1.0 positive -> step=0.5, range [-1.0, 0.5], output {-1.0, -0.5, 0.0, 0.5} + (-2.0, -1.0), # clipped to input_low + (-1.0, -1.0), # exact min output + (-0.8, -1.0), # rounds to -1.0 + (-0.6, -0.5), # rounds to -0.5 + (-0.5, -0.5), # exact + (-0.25, 0.0), # rounds to 0 + (0.0, 0.0), # exact zero + (0.25, 0.0), # rounds to 0 + (0.5, 0.5), # exact max output + (1.0, 0.5), # clipped to input_high + ], + ) + def test_positive_scale(self, input_val, expected): + """Positive scale: more quants allocated to negatives.""" + scale = torch.tensor(1.0) + x = torch.tensor([input_val]) + out = _run_quantize_symmetric_torch(x, scale, **PARAMS_2BIT) + assert out.item() == pytest.approx(expected, abs=1e-5) + + @pytest.mark.parametrize( + "input_val, expected", + [ + # scale=-1.0 negative -> step=0.5, range [-0.5, 1.0], output {-0.5, 0.0, 0.5, 1.0} + (-1.0, -0.5), # clipped to input_low + (-0.5, -0.5), # exact min output + (-0.25, 0.0), # rounds to 0 + (0.0, 0.0), # exact zero + (0.25, 0.0), # rounds to 0 + (0.5, 0.5), # exact + (0.6, 0.5), # rounds to 0.5 + (0.8, 1.0), # rounds to 1.0 + (1.0, 1.0), # exact max output + (2.0, 1.0), # clipped to input_high + ], + ) + def test_negative_scale(self, input_val, expected): + """Negative scale: more quants allocated to positives.""" + scale = torch.tensor(-1.0) + x = torch.tensor([input_val]) + out = _run_quantize_symmetric_torch(x, scale, **PARAMS_2BIT) + assert out.item() == pytest.approx(expected, abs=1e-5) + + @pytest.mark.parametrize("scale_val", [0.5, 1.0, 2.0, -0.5, -1.0, -2.0]) + def test_output_has_exactly_4_levels(self, scale_val): + """All outputs must land on exactly 4 quantization levels.""" + scale = torch.tensor(scale_val) + x = torch.linspace(-5 * abs(scale_val), 5 * abs(scale_val), 1000) + out = _run_quantize_symmetric_torch(x, scale, **PARAMS_2BIT) + actual = sorted(set(round(v, 5) for v in out.tolist())) + expected = sorted(round(v, 5) for v in _expected_levels(scale_val, **PARAMS_2BIT)) + assert actual == pytest.approx(expected, abs=1e-5) + + @pytest.mark.parametrize("scale_val", [0.5, 1.0, 2.0, -0.5, -1.0, -2.0]) + def test_zero_maps_to_zero(self, scale_val): + """Zero input always maps to zero output.""" + scale = torch.tensor(scale_val) + x = torch.tensor([0.0]) + out = _run_quantize_symmetric_torch(x, scale, **PARAMS_2BIT) + assert out.item() == pytest.approx(0.0, abs=1e-5) + + +# --------------------------------------------------------------------------- +# Tests: QuantizeSymmetricTorch at 4-bit (signed-scale asymmetry) +# --------------------------------------------------------------------------- + + +class TestQuantizeSymmetricTorch4Bit: + """ + 4-bit with signed scale. step = |scale| / 8. + scale > 0: range [-scale, 7/8*scale], 8 neg + 7 pos levels + scale < 0: range [-7/8*|scale|, |scale|], 7 neg + 8 pos levels + """ + + @pytest.mark.parametrize("scale_val", [0.5, 1.0, 2.0, -0.5, -1.0, -2.0]) + def test_output_has_exactly_16_levels(self, scale_val): + scale = torch.tensor(scale_val) + x = torch.linspace(-5 * abs(scale_val), 5 * abs(scale_val), 5000) + out = _run_quantize_symmetric_torch(x, scale, **PARAMS_4BIT) + actual = sorted(set(round(v, 6) for v in out.tolist())) + expected = sorted(round(v, 6) for v in _expected_levels(scale_val, **PARAMS_4BIT)) + assert actual == pytest.approx(expected, abs=1e-5) + + @pytest.mark.parametrize("scale_val", [1.0, -1.0]) + def test_zero_maps_to_zero(self, scale_val): + scale = torch.tensor(scale_val) + x = torch.tensor([0.0]) + out = _run_quantize_symmetric_torch(x, scale, **PARAMS_4BIT) + assert out.item() == pytest.approx(0.0, abs=1e-5) + + +# --------------------------------------------------------------------------- +# Tests: QuantizeSymmetricTorch at 8-bit +# --------------------------------------------------------------------------- + + +class TestQuantizeSymmetricTorch8Bit: + """8-bit verification.""" + + @pytest.mark.parametrize("scale_val", [0.5, 1.0, 2.0]) + def test_output_has_exactly_256_levels(self, scale_val): + scale = torch.tensor(scale_val) + x = torch.linspace(-2 * abs(scale_val), 2 * abs(scale_val), 50000) + out = _run_quantize_symmetric_torch(x, scale, **PARAMS_8BIT) + actual = sorted(set(round(v, 8) for v in out.tolist())) + expected = sorted(round(v, 8) for v in _expected_levels(scale_val, **PARAMS_8BIT)) + assert actual == pytest.approx(expected, abs=1e-5) + + @pytest.mark.parametrize("scale_val", [0.5, 1.0, 2.0]) + def test_zero_maps_to_zero(self, scale_val): + scale = torch.tensor(scale_val) + x = torch.tensor([0.0]) + out = _run_quantize_symmetric_torch(x, scale, **PARAMS_8BIT) + assert out.item() == pytest.approx(0.0, abs=1e-5) + + +# --------------------------------------------------------------------------- +# Tests: scale convention difference (documenting, not a bug) +# --------------------------------------------------------------------------- + + +class TestScaleConventionDifference: + """ + QuantizeSymmetric and QuantizeSymmetricTorch use different scale conventions + and thus produce different outputs for the same scale value. + This is by design, not a bug. + """ + + def test_different_scale_semantics_2bit(self): + """ + With scale=1.0 at 2-bit: + QuantizeSymmetric: step = scale / level_high = 1.0, levels = {-2, -1, 0, 1} + QuantizeSymmetricTorch: step = |scale| / |level_low| = 0.5, levels = {-1, -0.5, 0, 0.5} + """ + scale = torch.tensor(1.0) + x = torch.linspace(-5, 5, 1000) + + out_qs = _run_quantize_symmetric(x, scale, **PARAMS_2BIT) + qs_levels = sorted(set(round(v, 5) for v in out_qs.tolist())) + + out_qst = _run_quantize_symmetric_torch(x, scale, **PARAMS_2BIT) + qst_levels = sorted(set(round(v, 5) for v in out_qst.tolist())) + + # Different step sizes, both have 4 levels + assert len(qs_levels) == 4 + assert len(qst_levels) == 4 + assert qs_levels == pytest.approx([-2.0, -1.0, 0.0, 1.0], abs=1e-5) + assert qst_levels == pytest.approx([-1.0, -0.5, 0.0, 0.5], abs=1e-5) + + def test_equivalent_when_scale_converted_2bit(self): + """ + QuantizeSymmetricTorch with scale_torch = QuantizeSymmetric's scale * |level_low| / level_high + should produce the same output levels (just different step). + """ + qs_scale = torch.tensor(1.0) # QuantizeSymmetric convention + # Convert to QuantizeSymmetricTorch convention: + # step = qs_scale / level_high = 1.0, qst_scale = step * |level_low| = 2.0 + qst_scale = qs_scale * (-PARAMS_2BIT["level_low"]) / PARAMS_2BIT["level_high"] + + x = torch.linspace(-3, 3, 500) + out_qs = _run_quantize_symmetric(x, qs_scale, **PARAMS_2BIT) + out_qst = _run_quantize_symmetric_torch(x, qst_scale, **PARAMS_2BIT) + torch.testing.assert_close(out_qst, out_qs) + + +# --------------------------------------------------------------------------- +# Tests: formula internals +# --------------------------------------------------------------------------- + + +class TestFormulaInternals: + """Verify the generalized input_low / input_range formulas.""" + + @pytest.mark.parametrize( + "bits, params", + [ + (2, PARAMS_2BIT), + (4, PARAMS_4BIT), + (8, PARAMS_8BIT), + ], + ) + def test_input_range_formula(self, bits, params): + """ + The generalized formula (levels-1)*|scale|/|level_low| should equal + the old formula |( 2 + 1/level_low ) * scale| for all standard signed ranges. + """ + scale = 1.0 + level_low = params["level_low"] + levels = params["levels"] + + new_range = (levels - 1) * abs(scale) / (-level_low) + old_range = abs((2 + 1 / level_low) * scale) + assert new_range == pytest.approx(old_range, rel=1e-12) + + @pytest.mark.parametrize( + "scale_val, level_low, level_high, levels", + [ + (1.0, -2, 1, 4), # 2-bit positive scale + (-1.0, -2, 1, 4), # 2-bit negative scale + (1.0, -8, 7, 16), # 4-bit positive scale + (-1.0, -8, 7, 16), # 4-bit negative scale + ], + ) + def test_input_low_and_range_consistency(self, scale_val, level_low, level_high, levels): + """ + input_low + input_range should map to the correct upper bound. + """ + scale = torch.tensor(scale_val) + input_low = torch.where(scale > 0, -scale, -scale / level_low * level_high) + input_range = (levels - 1) * torch.abs(scale) / (-level_low) + upper = input_low + input_range + + step = abs(scale_val) / (-level_low) + if scale_val > 0: + assert input_low.item() == pytest.approx(level_low * step, abs=1e-10) + assert upper.item() == pytest.approx(level_high * step, abs=1e-10) + else: + # Flipped: more quants on positive side + assert input_low.item() == pytest.approx(-level_high * step, abs=1e-10) + assert upper.item() == pytest.approx(-level_low * step, abs=1e-10) + + +# --------------------------------------------------------------------------- +# Tests: gradient flow (QuantizeSymmetricTorch backward) +# --------------------------------------------------------------------------- + + +class TestQuantizeSymmetricTorchGradient: + """Ensure backward pass runs and produces non-zero gradients.""" + + @pytest.mark.parametrize("bits_params", [PARAMS_2BIT, PARAMS_4BIT, PARAMS_8BIT]) + def test_gradient_flows(self, bits_params): + scale = torch.tensor(1.0, requires_grad=True) + x = torch.randn(10, requires_grad=True) + out = _run_quantize_symmetric_torch(x, scale, **bits_params) + loss = out.sum() + loss.backward() + assert scale.grad is not None + assert x.grad is not None + + def test_gradient_with_negative_scale_2bit(self): + scale = torch.tensor(-1.0, requires_grad=True) + x = torch.randn(10, requires_grad=True) + out = _run_quantize_symmetric_torch(x, scale, **PARAMS_2BIT) + loss = out.sum() + loss.backward() + assert scale.grad is not None + assert x.grad is not None